diff options
Diffstat (limited to 'packages/base/src/Numeric')
-rw-r--r-- | packages/base/src/Numeric/Chain.hs | 148 |
1 files changed, 0 insertions, 148 deletions
diff --git a/packages/base/src/Numeric/Chain.hs b/packages/base/src/Numeric/Chain.hs deleted file mode 100644 index 64c09c0..0000000 --- a/packages/base/src/Numeric/Chain.hs +++ /dev/null | |||
@@ -1,148 +0,0 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | ||
2 | |||
3 | ----------------------------------------------------------------------------- | ||
4 | -- | | ||
5 | -- Module : Numeric.Chain | ||
6 | -- Copyright : (c) Vivian McPhail 2010 | ||
7 | -- License : BSD3 | ||
8 | -- | ||
9 | -- Maintainer : Vivian McPhail <haskell.vivian.mcphail <at> gmail.com> | ||
10 | -- Stability : provisional | ||
11 | -- Portability : portable | ||
12 | -- | ||
13 | -- optimisation of association order for chains of matrix multiplication | ||
14 | -- | ||
15 | ----------------------------------------------------------------------------- | ||
16 | |||
17 | {-# LANGUAGE FlexibleContexts #-} | ||
18 | |||
19 | module Numeric.Chain ( | ||
20 | optimiseMult, | ||
21 | ) where | ||
22 | |||
23 | import Data.Maybe | ||
24 | |||
25 | import Data.Packed.Matrix | ||
26 | import Data.Packed.Internal.Numeric | ||
27 | |||
28 | import qualified Data.Array.IArray as A | ||
29 | |||
30 | ----------------------------------------------------------------------------- | ||
31 | {- | | ||
32 | Provide optimal association order for a chain of matrix multiplications | ||
33 | and apply the multiplications. | ||
34 | |||
35 | The algorithm is the well-known O(n\^3) dynamic programming algorithm | ||
36 | that builds a pyramid of optimal associations. | ||
37 | |||
38 | > m1, m2, m3, m4 :: Matrix Double | ||
39 | > m1 = (10><15) [1..] | ||
40 | > m2 = (15><20) [1..] | ||
41 | > m3 = (20><5) [1..] | ||
42 | > m4 = (5><10) [1..] | ||
43 | |||
44 | > >>> optimiseMult [m1,m2,m3,m4] | ||
45 | |||
46 | will perform @((m1 `multiply` (m2 `multiply` m3)) `multiply` m4)@ | ||
47 | |||
48 | The naive left-to-right multiplication would take @4500@ scalar multiplications | ||
49 | whereas the optimised version performs @2750@ scalar multiplications. The complexity | ||
50 | in this case is 32 (= 4^3/2) * (2 comparisons, 3 scalar multiplications, 3 scalar additions, | ||
51 | 5 lookups, 2 updates) + a constant (= three table allocations) | ||
52 | -} | ||
53 | optimiseMult :: Product t => [Matrix t] -> Matrix t | ||
54 | optimiseMult = chain | ||
55 | |||
56 | ----------------------------------------------------------------------------- | ||
57 | |||
58 | type Matrices a = A.Array Int (Matrix a) | ||
59 | type Sizes = A.Array Int (Int,Int) | ||
60 | type Cost = A.Array Int (A.Array Int (Maybe Int)) | ||
61 | type Indexes = A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int)))) | ||
62 | |||
63 | update :: A.Array Int (A.Array Int a) -> (Int,Int) -> a -> A.Array Int (A.Array Int a) | ||
64 | update a (r,c) e = a A.// [(r,(a A.! r) A.// [(c,e)])] | ||
65 | |||
66 | newWorkSpaceCost :: Int -> A.Array Int (A.Array Int (Maybe Int)) | ||
67 | newWorkSpaceCost n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n] | ||
68 | where subArray i = A.listArray (1,i) (repeat Nothing) | ||
69 | |||
70 | newWorkSpaceIndexes :: Int -> A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int)))) | ||
71 | newWorkSpaceIndexes n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n] | ||
72 | where subArray i = A.listArray (1,i) (repeat Nothing) | ||
73 | |||
74 | matricesToSizes :: [Matrix a] -> Sizes | ||
75 | matricesToSizes ms = A.listArray (1,length ms) $ map (\m -> (rows m,cols m)) ms | ||
76 | |||
77 | chain :: Product a => [Matrix a] -> Matrix a | ||
78 | chain [] = error "chain: zero matrices to multiply" | ||
79 | chain [m] = m | ||
80 | chain [ml,mr] = ml `multiply` mr | ||
81 | chain ms = let ln = length ms | ||
82 | ma = A.listArray (1,ln) ms | ||
83 | mz = matricesToSizes ms | ||
84 | i = chain_cost mz | ||
85 | in chain_paren (ln,ln) i ma | ||
86 | |||
87 | chain_cost :: Sizes -> Indexes | ||
88 | chain_cost mz = let (_,u) = A.bounds mz | ||
89 | cost = newWorkSpaceCost u | ||
90 | ixes = newWorkSpaceIndexes u | ||
91 | (_,_,i) = foldl chain_cost' (mz,cost,ixes) (order u) | ||
92 | in i | ||
93 | |||
94 | chain_cost' :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes) | ||
95 | chain_cost' sci@(mz,cost,ixes) (r,c) | ||
96 | | c == 1 = let cost' = update cost (r,c) (Just 0) | ||
97 | ixes' = update ixes (r,c) (Just ((r,c),(r,c))) | ||
98 | in (mz,cost',ixes') | ||
99 | | otherwise = minimum_cost sci (r,c) | ||
100 | |||
101 | minimum_cost :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes) | ||
102 | minimum_cost sci fu = foldl (smaller_cost fu) sci (fulcrum_order fu) | ||
103 | |||
104 | smaller_cost :: (Int,Int) -> (Sizes,Cost,Indexes) -> ((Int,Int),(Int,Int)) -> (Sizes,Cost,Indexes) | ||
105 | smaller_cost (r,c) (mz,cost,ixes) ix@((lr,lc),(rr,rc)) = | ||
106 | let op_cost = fromJust ((cost A.! lr) A.! lc) | ||
107 | + fromJust ((cost A.! rr) A.! rc) | ||
108 | + fst (mz A.! (lr-lc+1)) | ||
109 | * snd (mz A.! lc) | ||
110 | * snd (mz A.! rr) | ||
111 | cost' = (cost A.! r) A.! c | ||
112 | in case cost' of | ||
113 | Nothing -> let cost'' = update cost (r,c) (Just op_cost) | ||
114 | ixes'' = update ixes (r,c) (Just ix) | ||
115 | in (mz,cost'',ixes'') | ||
116 | Just ct -> if op_cost < ct then | ||
117 | let cost'' = update cost (r,c) (Just op_cost) | ||
118 | ixes'' = update ixes (r,c) (Just ix) | ||
119 | in (mz,cost'',ixes'') | ||
120 | else (mz,cost,ixes) | ||
121 | |||
122 | |||
123 | fulcrum_order (r,c) = let fs' = zip (repeat r) [1..(c-1)] | ||
124 | in map (partner (r,c)) fs' | ||
125 | |||
126 | partner (r,c) (a,b) = ((r-b, c-b), (a,b)) | ||
127 | |||
128 | order 0 = [] | ||
129 | order n = order (n-1) ++ zip (repeat n) [1..n] | ||
130 | |||
131 | chain_paren :: Product a => (Int,Int) -> Indexes -> Matrices a -> Matrix a | ||
132 | chain_paren (r,c) ixes ma = let ((lr,lc),(rr,rc)) = fromJust $ (ixes A.! r) A.! c | ||
133 | in if lr == rr && lc == rc then (ma A.! lr) | ||
134 | else (chain_paren (lr,lc) ixes ma) `multiply` (chain_paren (rr,rc) ixes ma) | ||
135 | |||
136 | -------------------------------------------------------------------------- | ||
137 | |||
138 | {- TESTS | ||
139 | |||
140 | -- optimal association is ((m1*(m2*m3))*m4) | ||
141 | m1, m2, m3, m4 :: Matrix Double | ||
142 | m1 = (10><15) [1..] | ||
143 | m2 = (15><20) [1..] | ||
144 | m3 = (20><5) [1..] | ||
145 | m4 = (5><10) [1..] | ||
146 | |||
147 | -} | ||
148 | |||