summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r--packages/base/src/Internal/Chain.hs148
1 files changed, 148 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Chain.hs b/packages/base/src/Internal/Chain.hs
new file mode 100644
index 0000000..fa518d1
--- /dev/null
+++ b/packages/base/src/Internal/Chain.hs
@@ -0,0 +1,148 @@
1{-# LANGUAGE FlexibleContexts #-}
2
3-----------------------------------------------------------------------------
4-- |
5-- Module : Internal.Chain
6-- Copyright : (c) Vivian McPhail 2010
7-- License : BSD3
8--
9-- Maintainer : Vivian McPhail <haskell.vivian.mcphail <at> gmail.com>
10-- Stability : provisional
11-- Portability : portable
12--
13-- optimisation of association order for chains of matrix multiplication
14--
15-----------------------------------------------------------------------------
16
17{-# LANGUAGE FlexibleContexts #-}
18
19module Internal.Chain (
20 optimiseMult,
21 ) where
22
23import Data.Maybe
24
25import Internal.Matrix hiding (order)
26import Internal.Numeric
27
28import qualified Data.Array.IArray as A
29
30-----------------------------------------------------------------------------
31{- |
32 Provide optimal association order for a chain of matrix multiplications
33 and apply the multiplications.
34
35 The algorithm is the well-known O(n\^3) dynamic programming algorithm
36 that builds a pyramid of optimal associations.
37
38> m1, m2, m3, m4 :: Matrix Double
39> m1 = (10><15) [1..]
40> m2 = (15><20) [1..]
41> m3 = (20><5) [1..]
42> m4 = (5><10) [1..]
43
44> >>> optimiseMult [m1,m2,m3,m4]
45
46will perform @((m1 `multiply` (m2 `multiply` m3)) `multiply` m4)@
47
48The naive left-to-right multiplication would take @4500@ scalar multiplications
49whereas the optimised version performs @2750@ scalar multiplications. The complexity
50in this case is 32 (= 4^3/2) * (2 comparisons, 3 scalar multiplications, 3 scalar additions,
515 lookups, 2 updates) + a constant (= three table allocations)
52-}
53optimiseMult :: Product t => [Matrix t] -> Matrix t
54optimiseMult = chain
55
56-----------------------------------------------------------------------------
57
58type Matrices a = A.Array Int (Matrix a)
59type Sizes = A.Array Int (Int,Int)
60type Cost = A.Array Int (A.Array Int (Maybe Int))
61type Indexes = A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int))))
62
63update :: A.Array Int (A.Array Int a) -> (Int,Int) -> a -> A.Array Int (A.Array Int a)
64update a (r,c) e = a A.// [(r,(a A.! r) A.// [(c,e)])]
65
66newWorkSpaceCost :: Int -> A.Array Int (A.Array Int (Maybe Int))
67newWorkSpaceCost n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n]
68 where subArray i = A.listArray (1,i) (repeat Nothing)
69
70newWorkSpaceIndexes :: Int -> A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int))))
71newWorkSpaceIndexes n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n]
72 where subArray i = A.listArray (1,i) (repeat Nothing)
73
74matricesToSizes :: [Matrix a] -> Sizes
75matricesToSizes ms = A.listArray (1,length ms) $ map (\m -> (rows m,cols m)) ms
76
77chain :: Product a => [Matrix a] -> Matrix a
78chain [] = error "chain: zero matrices to multiply"
79chain [m] = m
80chain [ml,mr] = ml `multiply` mr
81chain ms = let ln = length ms
82 ma = A.listArray (1,ln) ms
83 mz = matricesToSizes ms
84 i = chain_cost mz
85 in chain_paren (ln,ln) i ma
86
87chain_cost :: Sizes -> Indexes
88chain_cost mz = let (_,u) = A.bounds mz
89 cost = newWorkSpaceCost u
90 ixes = newWorkSpaceIndexes u
91 (_,_,i) = foldl chain_cost' (mz,cost,ixes) (order u)
92 in i
93
94chain_cost' :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes)
95chain_cost' sci@(mz,cost,ixes) (r,c)
96 | c == 1 = let cost' = update cost (r,c) (Just 0)
97 ixes' = update ixes (r,c) (Just ((r,c),(r,c)))
98 in (mz,cost',ixes')
99 | otherwise = minimum_cost sci (r,c)
100
101minimum_cost :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes)
102minimum_cost sci fu = foldl (smaller_cost fu) sci (fulcrum_order fu)
103
104smaller_cost :: (Int,Int) -> (Sizes,Cost,Indexes) -> ((Int,Int),(Int,Int)) -> (Sizes,Cost,Indexes)
105smaller_cost (r,c) (mz,cost,ixes) ix@((lr,lc),(rr,rc)) =
106 let op_cost = fromJust ((cost A.! lr) A.! lc)
107 + fromJust ((cost A.! rr) A.! rc)
108 + fst (mz A.! (lr-lc+1))
109 * snd (mz A.! lc)
110 * snd (mz A.! rr)
111 cost' = (cost A.! r) A.! c
112 in case cost' of
113 Nothing -> let cost'' = update cost (r,c) (Just op_cost)
114 ixes'' = update ixes (r,c) (Just ix)
115 in (mz,cost'',ixes'')
116 Just ct -> if op_cost < ct then
117 let cost'' = update cost (r,c) (Just op_cost)
118 ixes'' = update ixes (r,c) (Just ix)
119 in (mz,cost'',ixes'')
120 else (mz,cost,ixes)
121
122
123fulcrum_order (r,c) = let fs' = zip (repeat r) [1..(c-1)]
124 in map (partner (r,c)) fs'
125
126partner (r,c) (a,b) = ((r-b, c-b), (a,b))
127
128order 0 = []
129order n = order (n-1) ++ zip (repeat n) [1..n]
130
131chain_paren :: Product a => (Int,Int) -> Indexes -> Matrices a -> Matrix a
132chain_paren (r,c) ixes ma = let ((lr,lc),(rr,rc)) = fromJust $ (ixes A.! r) A.! c
133 in if lr == rr && lc == rc then (ma A.! lr)
134 else (chain_paren (lr,lc) ixes ma) `multiply` (chain_paren (rr,rc) ixes ma)
135
136--------------------------------------------------------------------------
137
138{- TESTS
139
140-- optimal association is ((m1*(m2*m3))*m4)
141m1, m2, m3, m4 :: Matrix Double
142m1 = (10><15) [1..]
143m2 = (15><20) [1..]
144m3 = (20><5) [1..]
145m4 = (5><10) [1..]
146
147-}
148