diff options
author | Vivian McPhail <haskell.vivian.mcphail@gmail.com> | 2010-09-11 12:06:13 +0000 |
---|---|---|
committer | Vivian McPhail <haskell.vivian.mcphail@gmail.com> | 2010-09-11 12:06:13 +0000 |
commit | a519a29770a6ef8d08dea3b3e7971ed1f4084126 (patch) | |
tree | f5420ebeded9aa5b524a50af4709898e2de4e588 /lib/Numeric | |
parent | 6859c5712a85950b5bc3de3fe8352f4592bc273b (diff) |
add optimised chain matrix multiplication
Diffstat (limited to 'lib/Numeric')
-rw-r--r-- | lib/Numeric/Chain.hs | 110 | ||||
-rw-r--r-- | lib/Numeric/Container.hs | 2 | ||||
-rw-r--r-- | lib/Numeric/Matrix.hs | 2 |
3 files changed, 114 insertions, 0 deletions
diff --git a/lib/Numeric/Chain.hs b/lib/Numeric/Chain.hs new file mode 100644 index 0000000..0c33f76 --- /dev/null +++ b/lib/Numeric/Chain.hs | |||
@@ -0,0 +1,110 @@ | |||
1 | ----------------------------------------------------------------------------- | ||
2 | -- | | ||
3 | -- Module : Numeric.Chain | ||
4 | -- Copyright : (c) Alberto Ruiz 2010 | ||
5 | -- License : GPL-style | ||
6 | -- | ||
7 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
8 | -- Stability : provisional | ||
9 | -- Portability : portable | ||
10 | -- | ||
11 | -- optimisation of association order for chains of matrix multiplication | ||
12 | -- | ||
13 | ----------------------------------------------------------------------------- | ||
14 | |||
15 | module Numeric.Chain ( | ||
16 | chain | ||
17 | ) where | ||
18 | |||
19 | import Data.Maybe | ||
20 | |||
21 | import Data.Packed.Matrix | ||
22 | import Numeric.Container | ||
23 | |||
24 | import qualified Data.Array.IArray as A | ||
25 | |||
26 | type Matrices a = A.Array Int (Matrix a) | ||
27 | type Sizes = A.Array Int (Int,Int) | ||
28 | type Cost = A.Array Int (A.Array Int (Maybe Int)) | ||
29 | type Indexes = A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int)))) | ||
30 | |||
31 | update :: A.Array Int (A.Array Int a) -> (Int,Int) -> a -> A.Array Int (A.Array Int a) | ||
32 | update a (r,c) e = a A.// [(r,(a A.! r) A.// [(c,e)])] | ||
33 | |||
34 | newWorkSpaceCost :: Int -> A.Array Int (A.Array Int (Maybe Int)) | ||
35 | newWorkSpaceCost n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n] | ||
36 | where subArray i = A.listArray (1,i) (repeat Nothing) | ||
37 | |||
38 | newWorkSpaceIndexes :: Int -> A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int)))) | ||
39 | newWorkSpaceIndexes n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n] | ||
40 | where subArray i = A.listArray (1,i) (repeat Nothing) | ||
41 | |||
42 | matricesToSizes :: [Matrix a] -> Sizes | ||
43 | matricesToSizes ms = A.listArray (1,length ms) $ map (\m -> (rows m,cols m)) ms | ||
44 | |||
45 | -- | provide optimal association order for a chain of matrix multiplications and apply the multiplications | ||
46 | chain :: Product a => [Matrix a] -> Matrix a | ||
47 | chain ms = let ln = length ms | ||
48 | ma = A.listArray (1,ln) ms | ||
49 | mz = matricesToSizes ms | ||
50 | i = chain_cost mz | ||
51 | in chain_paren (ln,ln) i ma | ||
52 | |||
53 | chain_cost :: Sizes -> Indexes | ||
54 | chain_cost mz = let (_,u) = A.bounds mz | ||
55 | cost = newWorkSpaceCost u | ||
56 | ixes = newWorkSpaceIndexes u | ||
57 | (_,_,i) = foldl chain_cost' (mz,cost,ixes) (order u) | ||
58 | in i | ||
59 | |||
60 | chain_cost' :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes) | ||
61 | chain_cost' sci@(mz,cost,ixes) (r,c) | ||
62 | | c == 1 = let cost' = update cost (r,c) (Just 0) | ||
63 | ixes' = update ixes (r,c) (Just ((r,c),(r,c))) | ||
64 | in (mz,cost',ixes') | ||
65 | | otherwise = minimum_cost sci (r,c) | ||
66 | |||
67 | minimum_cost :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes) | ||
68 | minimum_cost sci fu = foldl (smaller_cost fu) sci (fulcrum_order fu) | ||
69 | |||
70 | smaller_cost :: (Int,Int) -> (Sizes,Cost,Indexes) -> ((Int,Int),(Int,Int)) -> (Sizes,Cost,Indexes) | ||
71 | smaller_cost (r,c) (mz,cost,ixes) ix@((lr,lc),(rr,rc)) = let op_cost = (fromJust ((cost A.! lr) A.! lc)) | ||
72 | + (fromJust ((cost A.! rr) A.! rc)) | ||
73 | + ((fst $ mz A.! (lr-lc+1)) | ||
74 | *(snd $ mz A.! lc) | ||
75 | *(snd $ mz A.! rr)) | ||
76 | cost' = (cost A.! r) A.! c | ||
77 | in case cost' of | ||
78 | Nothing -> let cost'' = update cost (r,c) (Just op_cost) | ||
79 | ixes'' = update ixes (r,c) (Just ix) | ||
80 | in (mz,cost'',ixes'') | ||
81 | Just ct -> if op_cost < ct then | ||
82 | let cost'' = update cost (r,c) (Just op_cost) | ||
83 | ixes'' = update ixes (r,c) (Just ix) | ||
84 | in (mz,cost'',ixes'') | ||
85 | else (mz,cost,ixes) | ||
86 | |||
87 | |||
88 | fulcrum_order (r,c) = let fs' = zip (repeat r) [1..(c-1)] | ||
89 | in map (partner (r,c)) fs' | ||
90 | |||
91 | partner (r,c) (a,b) = (((r-b),(c-b)),(a,b)) | ||
92 | |||
93 | order 0 = [] | ||
94 | order n = (order (n-1)) ++ (zip (repeat n) [1..n]) | ||
95 | |||
96 | chain_paren :: Product a => (Int,Int) -> Indexes -> Matrices a -> Matrix a | ||
97 | chain_paren (r,c) ixes ma = let ((lr,lc),(rr,rc)) = fromJust $ (ixes A.! r) A.! c | ||
98 | in if lr == rr && lc == rc then (ma A.! lr) | ||
99 | else (chain_paren (lr,lc) ixes ma) `multiply` (chain_paren (rr,rc) ixes ma) | ||
100 | |||
101 | -------------------------------------------------------------------------- | ||
102 | |||
103 | {- TESTS -} | ||
104 | |||
105 | -- optimal association is ((m1*(m2*m3))*m4) | ||
106 | m1, m2, m3, m4 :: Matrix Double | ||
107 | m1 = (10><15) [1..] | ||
108 | m2 = (15><20) [1..] | ||
109 | m3 = (20><5) [1..] | ||
110 | m4 = (5><10) [1..] \ No newline at end of file | ||
diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs index 2c9c500..aaa068f 100644 --- a/lib/Numeric/Container.hs +++ b/lib/Numeric/Container.hs | |||
@@ -308,3 +308,5 @@ kronecker a b = fromBlocks | |||
308 | . map (reshape (cols b)) | 308 | . map (reshape (cols b)) |
309 | . toRows | 309 | . toRows |
310 | $ flatten a `outer` flatten b | 310 | $ flatten a `outer` flatten b |
311 | |||
312 | ---------------------------------------------------------- | ||
diff --git a/lib/Numeric/Matrix.hs b/lib/Numeric/Matrix.hs index f462e88..ce36ef2 100644 --- a/lib/Numeric/Matrix.hs +++ b/lib/Numeric/Matrix.hs | |||
@@ -26,6 +26,7 @@ module Numeric.Matrix ( | |||
26 | module Data.Packed.Matrix, | 26 | module Data.Packed.Matrix, |
27 | module Numeric.Vector, | 27 | module Numeric.Vector, |
28 | --module Numeric.Container, | 28 | --module Numeric.Container, |
29 | chain, | ||
29 | -- * Operators | 30 | -- * Operators |
30 | (<>), (<\>), | 31 | (<>), (<\>), |
31 | -- * Deprecated | 32 | -- * Deprecated |
@@ -40,6 +41,7 @@ import Data.Packed.Matrix | |||
40 | 41 | ||
41 | import Numeric.Container | 42 | import Numeric.Container |
42 | import Numeric.Vector | 43 | import Numeric.Vector |
44 | import Numeric.Chain | ||
43 | import Numeric.LinearAlgebra.Algorithms | 45 | import Numeric.LinearAlgebra.Algorithms |
44 | 46 | ||
45 | ------------------------------------------------------------------- | 47 | ------------------------------------------------------------------- |