{-# Language ImportQualifiedPost, QuasiQuotes, TemplateHaskell #-}
{-|
Module      : Main
Description : Day 14 solution
Copyright   : (c) Eric Mertens, 2020
License     : ISC
Maintainer  : emertens@gmail.com

<https://adventofcode.com/2020/day/14>

@
>>> :{
let cmds = parseInput
      "mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X\n\
      \mem[8] = 11\n\
      \mem[7] = 101\n\
      \mem[8] = 0\n"
in run1 [] IntMap.empty cmds
:}
165

>>> :{
let cmds = parseInput
      "mask = 000000000000000000000000000000X1001X\n\
      \mem[42] = 100\n\
      \mask = 00000000000000000000000000000000X0XX\n\
      \mem[26] = 1\n"
in run2 [] IntMap.empty cmds
:}
208

@

-}

module Main where

import Advent (format, stageTH)
import Data.Bits (setBit, clearBit)
import Data.IntMap (IntMap)
import Data.IntMap qualified as IntMap
import Data.List (foldl')

-- | One program statement: either a new mask ('Left') or a memory write
-- @(address, value)@ ('Right').
type Cmd = Either [M] (Int,Int)

-- | A single mask character: @1@ ('M1'), @0@ ('M0'), or @X@ ('MX').
data M = M1 | M0 | MX deriving Show

-- Template Haskell stage split from "Advent". NOTE(review): presumably
-- needed so that the 'M' declaration above is visible to the @M directive
-- in the format splice below — confirm against Advent's docs.
stageTH

-- Input format: a sequence of lines, each either @mask = <M chars>@ or
-- @mem[<addr>] = <value>@. NOTE(review): this splice presumably generates
-- the 'Input' type, 'parseInput' and 'getInput' used elsewhere in this
-- module (none are defined in this file) — confirm in "Advent".
[format|((mask = @M*|mem[%u] = %u)%n)*|]

-- | Load the puzzle input and print the final memory sums produced by the
-- 'run1' (part 1) and 'run2' (part 2) simulations, in that order. Both
-- simulations start with an empty mask and empty memory.
--
-- >>> :main
-- 17934269678453
-- 3440662844064
main :: IO ()
main =
  do inp <- getInput 2020 14
     print (run1 [] IntMap.empty inp)
     print (run2 [] IntMap.empty inp)

-- | Simulate the computer using the 'mask1' rule.
--
-- A 'Left' statement replaces the current mask. A 'Right' @(k, v)@
-- statement stores @v@, with the current mask applied via 'mask1', at
-- address @k@. The result is the sum of all values left in memory.
run1 ::
  [M]        {- ^ initial mask       -} ->
  IntMap Int {- ^ initial memory     -} ->
  [Cmd]      {- ^ program statements -} ->
  Int        {- ^ sum of memory      -}
run1 _    mem []                 = sum mem
run1 _    mem (Left mask   : xs) = run1 mask mem xs
run1 mask mem (Right (k,v) : xs) = mem' `seq` run1 mask mem' xs
  where
    -- The mask is 36 bits wide, so its first character addresses bit 35.
    v'   = mask1 v 35 mask
    -- Force the masked value and the updated map before recursing. Otherwise
    -- the lazy IntMap accumulator grows into a chain of unevaluated inserts
    -- that is only collapsed by the final 'sum'.
    mem' = v' `seq` IntMap.insert k v' mem

-- | Apply a mask where @1@ and @0@ overwrite bits and @X@ leaves the
-- corresponding bit of the target value unchanged. The mask is consumed
-- from its first character, which addresses the given bit index; each
-- subsequent character addresses the next lower bit.
--
-- >>> mask1 11 6 [M1,MX,MX,MX,MX,M0,MX]
-- 73
--
-- >>> mask1 101 6 [M1,MX,MX,MX,MX,M0,MX]
-- 101
--
-- >>> mask1 0 6 [M1,MX,MX,MX,MX,M0,MX]
-- 64
mask1 ::
  Int {- ^ target value                   -} ->
  Int {- ^ bit index of beginning of mask -} ->
  [M] {- ^ mask, most significant first   -} ->
  Int {- ^ masked value                   -}
mask1 acc i (M1:xs) = mask1 (setBit   acc i) (i-1) xs
mask1 acc i (M0:xs) = mask1 (clearBit acc i) (i-1) xs
mask1 acc i (MX:xs) = mask1 acc              (i-1) xs
mask1 acc _ []      = acc

-- | Simulate the computer using the 'mask2' rule.
--
-- A 'Left' statement replaces the current mask. A 'Right' @(k, v)@
-- statement stores @v@ at every address produced by applying the current
-- mask to @k@ via 'mask2'. The result is the sum of all values left in
-- memory.
run2 ::
  [M]        {- ^ initial mask       -} ->
  IntMap Int {- ^ initial memory     -} ->
  [Cmd]      {- ^ program statements -} ->
  Int        {- ^ sum of memory      -}
run2 _    mem []                 = sum mem
run2 _    mem (Left mask   : xs) = run2 mask mem xs
run2 mask mem (Right (k,v) : xs) = mem' `seq` run2 mask mem' xs
  where
    -- Write v at each decoded address. The mask is 36 bits wide, so its
    -- first character addresses bit 35. The updated map is forced before
    -- recursing so the accumulator is not passed along as a pending thunk.
    mem' = foldl' (\m_ k_ -> IntMap.insert k_ v m_) mem (mask2 k 35 mask)

-- | Apply a mask where @1@ overwrites the bit with 1, @0@ leaves the bit
-- unchanged, and @X@ takes both bit values. Every resulting value is
-- returned. The mask is consumed from its first character, which addresses
-- the given bit index; each subsequent character addresses the next lower
-- bit.
--
-- >>> mask2 42 5 [MX,M1,M0,M0,M1,MX]
-- [59,27,58,26]
mask2 ::
  Int   {- ^ target value                   -} ->
  Int   {- ^ bit index of beginning of mask -} ->
  [M]   {- ^ mask, most significant first   -} ->
  [Int] {- ^ every value the mask produces  -}
mask2 x i (M1:xs) = mask2 (setBit x i) (i-1) xs
mask2 x i (M0:xs) = mask2 x (i-1) xs
mask2 x i (MX:xs) = do y <- mask2 (setBit x i) (i-1) xs; [y, clearBit y i]
mask2 x _ []      = [x]