{-# Language TypeFamilies, TypeOperators, BlockArguments #-}
{-|
Module      : Advent.SmallSet
Description : An efficient set representation for small integers.
Copyright   : (c) Eric Mertens, 2021
License     : ISC
Maintainer  : emertens@gmail.com

A compact set type for when you have very few elements to track.

-}
module Advent.SmallSet where

import Data.Bits
import Data.Coerce (coerce)
import Data.List (foldl')
import Data.MemoTrie (HasTrie(..))
import Data.Word (Word64)

-- | A set of 'Int's in the range 0 to 63, packed into the bits of one
-- 'Word64': bit @i@ is set exactly when @i@ is a member.
--
-- The derived 'Ord' compares the underlying bit patterns numerically. This
-- is a valid total order (usable for map keys), but it is not the
-- element-wise ordering that "Data.Set" uses.
newtype SmallSet = SmallSet Word64
  deriving (Ord, Eq)

-- | Build a set from a list of members. Duplicate members are harmless.
--
-- Every element must satisfy 'inRange'; otherwise 'insert' raises an error.
--
-- >>> fromList []
-- fromList []
--
-- >>> fromList [0]
-- fromList [0]
--
-- >>> fromList [63]
-- fromList [63]
--
-- >>> fromList [0,10,20]
-- fromList [0,10,20]
fromList :: [Int] -> SmallSet
fromList = foldl' step empty
  where
    -- strict left fold: add each element to the accumulated set in turn
    step acc x = insert x acc

-- | List the members of a set in ascending order.
--
-- Each step reports the index of the lowest set bit and then clears that
-- bit with @w .&. (w - 1)@. The word is never shifted, so a member at
-- position 63 needs no special handling.
--
-- >>> toList (fromList [63,0,10])
-- [0,10,63]
toList :: SmallSet -> [Int]
toList (SmallSet s) = go s
  where
    go :: Word64 -> [Int]
    go 0 = []
    go w = countTrailingZeros w : go (w .&. (w - 1))

-- | True for exactly the integers that can be stored in a 'SmallSet',
-- i.e. @0 <= x <= 63@.
inRange :: Int -> Bool
inRange x = x >= 0 && x <= 63

-- | The set containing no elements (all bits clear).
--
-- >>> empty
-- fromList []
empty :: SmallSet
empty = SmallSet zeroBits

-- | Test whether a set has no members.
null :: SmallSet -> Bool
null = (== empty)

-- | The set whose only member is the given element.
--
-- Raises an error when the element does not satisfy 'inRange'.
--
-- >>> singleton 42
-- fromList [42]
singleton :: Int -> SmallSet
singleton x =
  if inRange x
    then SmallSet (bit x)
    else error ("Advent.SmallSet.singleton: bad argument " ++ show x)

-- | All elements that appear in either set (bitwise OR of the words).
--
-- >>> union (fromList [3,4,5,6]) (fromList [5,6,7,8])
-- fromList [3,4,5,6,7,8]
union :: SmallSet -> SmallSet -> SmallSet
union = coerce ((.|.) :: Word64 -> Word64 -> Word64)

-- | Combine every set in the list into one. The empty list yields 'empty'.
--
-- >>> unions []
-- fromList []
--
-- >>> unions [singleton 1, fromList [2,4], fromList [2,3]]
-- fromList [1,2,3,4]
unions :: [SmallSet] -> SmallSet
unions = SmallSet . foldl' (.|.) zeroBits . map setRep

-- | Elements present in both sets (bitwise AND of the words).
--
-- >>> intersection (fromList [3,4,5,6]) (fromList [5,6,7,8])
-- fromList [5,6]
intersection :: SmallSet -> SmallSet -> SmallSet
intersection (SmallSet lhs) (SmallSet rhs) = SmallSet (rhs .&. lhs)

-- | Elements of the first set that are not in the second: the first set's
-- bits masked by the complement of the second's.
--
-- >>> difference (fromList [3,4,5,6]) (fromList [5,6,7,8])
-- fromList [3,4]
difference :: SmallSet -> SmallSet -> SmallSet
difference (SmallSet keep) (SmallSet removed) =
  SmallSet (complement removed .&. keep)

-- | Infix synonym for 'difference': @a \\\\ b@ removes the members of @b@
-- from @a@.
(\\) :: SmallSet -> SmallSet -> SmallSet
a \\ b = difference a b

infix 5 \\

-- | Add an element to a set. Inserting an existing member is a no-op.
--
-- Raises an error when the element does not satisfy 'inRange'.
--
-- >>> insert 10 (fromList [3,4,5])
-- fromList [3,4,5,10]
--
-- >>> insert 5 (fromList [3,4,5])
-- fromList [3,4,5]
insert :: Int -> SmallSet -> SmallSet
insert x (SmallSet bits) =
  if inRange x
    then SmallSet (bits `setBit` x)
    else error ("Advent.SmallSet.insert: bad argument " ++ show x)

-- | Remove an element from a set. Removing a non-member leaves the set
-- unchanged.
--
-- Raises an error when the element does not satisfy 'inRange'. The error
-- message names this function (it previously said @insert@ by mistake).
--
-- >>> delete 5 (fromList [3,4,5])
-- fromList [3,4]
--
-- >>> delete 8 (fromList [3,4,5])
-- fromList [3,4,5]
delete :: Int -> SmallSet -> SmallSet
delete x (SmallSet y)
  | inRange x = SmallSet (clearBit y x)
  | otherwise = error ("Advent.SmallSet.delete: bad argument " ++ show x)

-- | Test whether an element belongs to the set.
--
-- Raises an error when the element does not satisfy 'inRange'.
--
-- >>> member 8 (fromList [3,4,5])
-- False
--
-- >>> member 4 (fromList [3,4,5])
-- True
member :: Int -> SmallSet -> Bool
member x (SmallSet bits)
  | not (inRange x) = error ("Advent.SmallSet.member: bad argument " ++ show x)
  | otherwise       = bits `testBit` x

-- | True when the two sets share no element, i.e. their words have no set
-- bit in common.
disjoint :: SmallSet -> SmallSet -> Bool
disjoint (SmallSet a) (SmallSet b) = a .&. b == 0

-- | Number of members, computed as the population count of the word.
--
-- >>> size (fromList [1,2,3])
-- 3
--
-- >>> size empty
-- 0
size :: SmallSet -> Int
size = popCount . setRep

-- | The underlying bit representation: bit @i@ of the result is set exactly
-- when @i@ is a member of the set.
setRep :: SmallSet -> Word64
setRep = coerce

-- | Renders a set as @fromList [..]@ with its members in ascending order,
-- adding parentheses when shown as a function argument (precedence 11+).
instance Show SmallSet where
  showsPrec d s =
    showParen (d >= 11) (\rest -> "fromList " ++ shows (toList s) rest)

-- | Parses the @fromList [..]@ syntax produced by the 'Show' instance.
--
-- Lists containing an element outside 'inRange' are rejected as a parse
-- failure. Previously the parse succeeded and produced a set that raised an
-- error when first used, so e.g. @readMaybe "fromList [64]"@ returned
-- @Just@ of a crashing value.
instance Read SmallSet where
  readsPrec p = readParen (p > 10) \s ->
    [ (fromList xs, s2)
    | ("fromList", s1) <- lex s
    , (xs, s2)         <- reads s1
    , all inRange xs   -- reject unrepresentable elements up front
    ]

-- | Memo-trie instance that reuses the 'Word64' trie: a 'SmallSet' key is
-- stored and looked up under its underlying bit pattern.
instance HasTrie SmallSet where
  newtype SmallSet :->: a = T (Word64 :->: a)
  -- build the Word64 trie for the function seen through the constructor
  trie f = T (trie (f . SmallSet))
  -- unwrap both the trie and the key, then defer to the Word64 lookup
  untrie (T t) (SmallSet w) = untrie t w
  -- re-wrap each Word64 key from the underlying enumeration
  enumerate (T t) = [(SmallSet w, v) | (w, v) <- enumerate t]