------------------------------------------------
-- |
-- Module    : Data.MemoCombinators
-- Copyright : (c) Luke Palmer 2008-2010
-- License   : BSD3
--
-- Maintainer : Luke Palmer <lrpalmer@gmail.com>
-- Stability  : experimental
--
-- This module provides combinators for building memo tables
-- over various data types, so that the type of table can
-- be customized depending on the application.
--
-- This module is designed to be imported /qualified/, e.g.
--
-- > import qualified Data.MemoCombinators as Memo
--
-- Usage is straightforward: apply an object of type @Memo a@
-- to a function of type @a -> b@, and get a memoized function
-- of type @a -> b@.  For example:
--
-- > fib = Memo.integral fib'
-- >    where
-- >    fib' 0 = 0
-- >    fib' 1 = 1
-- >    fib' x = fib (x-1) + fib (x-2)
------------------------------------------------

module Data.MemoCombinators
    ( Memo
    , wrap
    , memo2, memo3, memoSecond, memoThird
    , bool, char, list, boundedList, either, maybe, unit, pair
    , enum, integral, bits
    , switch
    , RangeMemo
    , arrayRange, unsafeArrayRange, chunks
    )
where

import Prelude hiding (either, maybe)
import Data.Bits
import qualified Data.Array as Array
import Data.Char (ord,chr)
import qualified Data.IntTrie as IntTrie

-- | A memo table for functions whose argument has type @a@: it takes a
-- function and returns one with the same behaviour whose results are
-- cached by the table.  The result type is universally quantified, so a
-- single table works for every result type.  (The @forall@ on the
-- right-hand side needs RankNTypes, which GHC2021 enables.)
type Memo a = forall b. (a -> b) -> a -> b

-- | Given a memoizer for @a@ and an isomorphism between @a@ and @b@, build
-- a memoizer for @b@.
--
-- @wrap i j m@ converts each incoming @b@ to an @a@ with @j@ and looks it
-- up in the table @m@ built over @f . i@.  The memoized function therefore
-- returns @f (i (j x))@, so @i . j@ must be the identity on the values you
-- actually pass in.
wrap :: (a -> b) -> (b -> a) -> Memo a -> Memo b
wrap i j m f = m (f . i) . j

-- | Memoize a two-argument function.  For single-argument functions, just
-- apply the table directly.
--
-- The outer table @ma@ is keyed on the first argument.  Each of its
-- entries is a table built by @mb@ over the second argument.
memo2 :: Memo a -> Memo b -> (a -> b -> r) -> (a -> b -> r)
memo2 ma mb = ma . (mb .)

-- | Memoize a three-argument function.
--
-- The outer table @ma@ is keyed on the first argument.  Each entry is a
-- two-level table (see 'memo2') over the remaining two arguments.
memo3
    :: Memo a -> Memo b -> Memo c
    -> (a -> b -> c -> r) -> (a -> b -> c -> r)
memo3 ma mb mc = ma . (memo2 mb mc .)

-- | Memoize the second argument of a function.
--
-- @memoSecond mb f x@ is @mb (f x)@: a table over the second argument is
-- built for each application to a first argument.  Results are therefore
-- shared only between calls that go through the same partial application
-- @memoSecond mb f x@.
memoSecond :: Memo b -> (a -> b -> r) -> (a -> b -> r)
memoSecond mb = (mb .)

-- | Memoize the third argument of a function.
--
-- Like 'memoSecond', the table over the third argument is built per
-- application to the first argument(s).  It is shared only through the
-- same partial application @memoThird mc f x@ (further applied to the
-- second argument).
memoThird :: Memo c -> (a -> b -> c -> r) -> (a -> b -> c -> r)
memoThird mc = (memoSecond mc .)

-- | Memoize functions on 'Bool'.
--
-- The table is just the two results @f True@ and @f False@, passed as
-- (lazy, shared) arguments to a selector.
bool :: Memo Bool
bool f = cond (f True) (f False)
    where
    cond whenTrue _         True  = whenTrue
    cond _        whenFalse False = whenFalse

-- | Memoize functions on lists, given a memo table for the elements.
--
-- The table is a lazily built trie with one slot for @[]@.  For @x:xs@,
-- the element table @m@ is keyed on @x@, and each of its entries is
-- itself a list table over @xs@.
list :: Memo a -> Memo [a]
list m f = table (f []) (m (\x -> list m (f . (x:))))
    where
    table nil _    []     = nil
    table _   cons (x:xs) = cons x xs

-- | Memoize functions on 'Char', keyed by the character's code point
-- ('ord') in an 'integral' table.  'chr' is only applied to keys produced
-- by 'ord', so it always receives a valid code point.
char :: Memo Char
char = wrap chr ord integral

-- | Build a table which memoizes all lists of length less than the given
-- bound.
--
-- Longer lists are still accepted, but their results are recomputed on
-- every call.  A bound of zero or less memoizes nothing.  Previously a
-- negative bound never reached the base case and memoized without limit.
boundedList :: Int -> Memo a -> Memo [a]
boundedList n m f
    | n <= 0    = f
    | otherwise = table (f []) (m (\x -> boundedList (n - 1) m (f . (x:))))
    where
    table nil _    []     = nil
    table _   cons (x:xs) = cons x xs

-- | Memoize functions on 'Either', using table @m@ for 'Left' payloads and
-- table @m'@ for 'Right' payloads.
either :: Memo a -> Memo b -> Memo (Either a b)
either m m' f = table (m (f . Left)) (m' (f . Right))
    where
    table onLeft _       (Left x)  = onLeft x
    table _      onRight (Right x) = onRight x

-- | Memoize functions on 'Maybe'.  The 'Nothing' result is a single
-- shared slot; 'Just' payloads are looked up in the table @m@.
maybe :: Memo a -> Memo (Maybe a)
maybe m f = table (f Nothing) (m (f . Just))
    where
    table onNothing _      Nothing  = onNothing
    table _         onJust (Just x) = onJust x

-- | Memoize functions on @()@.  The single result @f ()@ is bound once and
-- returned for every call.  The returned function pattern-matches its
-- argument, so it is strict in it.
unit :: Memo ()
unit f = let cached = f () in \() -> cached

-- | Memoize functions on pairs.  The outer table @m@ is keyed on the first
-- component, and each entry is a table built by @m'@ over the second
-- component.  'uncurry' adapts the resulting curried table to tuples.
pair :: Memo a -> Memo b -> Memo (a,b)
pair m m' f = uncurry (m (\x -> m' (\y -> f (x, y))))

-- | Memoize an enum type, keyed by 'fromEnum' in an 'integral' table.
-- Correct as long as @toEnum . fromEnum@ is the identity on the values
-- used (see 'wrap').
enum :: (Enum a) => Memo a
enum = wrap toEnum fromEnum integral

-- | Memoize an integral type by converting to 'Integer' and using the
-- 'bits' table over 'Integer'.
integral :: (Integral a) => Memo a
integral = wrap fromInteger toInteger bits

-- | Memoize an ordered type with a 'Bits' instance, using an 'IntTrie'.
--
-- Mapping @f@ over 'IntTrie.identity' and reading it back with
-- 'IntTrie.apply' yields @f@ itself.  This relies on
-- @IntTrie.apply IntTrie.identity@ acting as the identity function; the
-- trie's entries are the cached results.
bits :: (Num a, Ord a, Bits a) => Memo a
bits f = IntTrie.apply (fmap f IntTrie.identity)

-- | @switch p a b@ uses the memo table @a@ whenever @p@ gives 'True' and
-- the memo table @b@ whenever @p@ gives 'False'.
--
-- Both tables are built over the same function @f@.  They are passed to
-- the selector as lazy arguments, so each is only built if some argument
-- is routed to it.
switch :: (a -> Bool) -> Memo a -> Memo a -> Memo a
switch p m m' f = table (m f) (m' f)
    where
    table whenTrue whenFalse x
        | p x       = whenTrue x
        | otherwise = whenFalse x

-- | Builders for ranged tables: given a @(lower, upper)@ bound pair,
-- produce a memo table for that range.  The builders in this module
-- interpret the bounds as 'Array.inRange' does (both ends inclusive).
type RangeMemo a = (a, a) -> Memo a

-- | Build a memo table for a range using a flat array.  Arguments outside
-- the range are not memoized; they are passed straight to the function
-- (via the 'id' table).
arrayRange :: (Array.Ix a) => RangeMemo a
arrayRange rng = switch (Array.inRange rng) (unsafeArrayRange rng) id

-- | Build a memo table for a range using a flat array.  If items are
-- given outside the range, behavior is unspecified.  With "Data.Array"'s
-- '(Array.!)' this is presumably an index error, but callers must not
-- rely on it.
--
-- The array holds one lazy thunk @f i@ per index in the range.  It is
-- bound in a @where@ clause so that every lookup through the same
-- partial application @unsafeArrayRange rng f@ shares it.
unsafeArrayRange :: (Array.Ix a) => RangeMemo a
unsafeArrayRange rng f = (table Array.!)
    where
    table = Array.listArray rng (map f (Array.range rng))


-- | Given a list of ranges, (lazily) build a memo table for each one
-- and combine them using linear search.
--
-- Ranges are tried in list order, and the first range containing the
-- argument (per 'Array.inRange') wins.  An argument in none of the ranges
-- raises an 'error'.
chunks :: (Array.Ix a) => RangeMemo a -> [(a,a)] -> Memo a
chunks rmemo cs f = search tables
    where
    -- Each range's table is an unevaluated thunk until a key in that range
    -- is looked up.  The list is where-bound, so it is shared by all calls
    -- made through the same partial application @chunks rmemo cs f@.
    tables = [ (rng, rmemo rng f) | rng <- cs ]

    search [] _ = error "Element not in table"
    search ((rng, table) : rest) x
        | Array.inRange rng x = table x
        | otherwise           = search rest x