summaryrefslogtreecommitdiff
path: root/lib/Numeric
diff options
context:
space:
mode:
authorVivian McPhail <haskell.vivian.mcphail@gmail.com>2010-09-11 12:06:13 +0000
committerVivian McPhail <haskell.vivian.mcphail@gmail.com>2010-09-11 12:06:13 +0000
commita519a29770a6ef8d08dea3b3e7971ed1f4084126 (patch)
treef5420ebeded9aa5b524a50af4709898e2de4e588 /lib/Numeric
parent6859c5712a85950b5bc3de3fe8352f4592bc273b (diff)
add optimised chain matrix multiplication
Diffstat (limited to 'lib/Numeric')
-rw-r--r--lib/Numeric/Chain.hs110
-rw-r--r--lib/Numeric/Container.hs2
-rw-r--r--lib/Numeric/Matrix.hs2
3 files changed, 114 insertions, 0 deletions
diff --git a/lib/Numeric/Chain.hs b/lib/Numeric/Chain.hs
new file mode 100644
index 0000000..0c33f76
--- /dev/null
+++ b/lib/Numeric/Chain.hs
@@ -0,0 +1,110 @@
1-----------------------------------------------------------------------------
2-- |
3-- Module : Numeric.Chain
4-- Copyright : (c) Alberto Ruiz 2010
5-- License : GPL-style
6--
7-- Maintainer : Alberto Ruiz <aruiz@um.es>
8-- Stability : provisional
9-- Portability : portable
10--
11-- optimisation of association order for chains of matrix multiplication
12--
13-----------------------------------------------------------------------------
14
15module Numeric.Chain (
16 chain
17 ) where
18
19import Data.Maybe
20
21import Data.Packed.Matrix
22import Numeric.Container
23
24import qualified Data.Array.IArray as A
25
26type Matrices a = A.Array Int (Matrix a)
27type Sizes = A.Array Int (Int,Int)
28type Cost = A.Array Int (A.Array Int (Maybe Int))
29type Indexes = A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int))))
30
31update :: A.Array Int (A.Array Int a) -> (Int,Int) -> a -> A.Array Int (A.Array Int a)
32update a (r,c) e = a A.// [(r,(a A.! r) A.// [(c,e)])]
33
34newWorkSpaceCost :: Int -> A.Array Int (A.Array Int (Maybe Int))
35newWorkSpaceCost n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n]
36 where subArray i = A.listArray (1,i) (repeat Nothing)
37
38newWorkSpaceIndexes :: Int -> A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int))))
39newWorkSpaceIndexes n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n]
40 where subArray i = A.listArray (1,i) (repeat Nothing)
41
42matricesToSizes :: [Matrix a] -> Sizes
43matricesToSizes ms = A.listArray (1,length ms) $ map (\m -> (rows m,cols m)) ms
44
45-- | provide optimal association order for a chain of matrix multiplications and apply the multiplications
46chain :: Product a => [Matrix a] -> Matrix a
47chain ms = let ln = length ms
48 ma = A.listArray (1,ln) ms
49 mz = matricesToSizes ms
50 i = chain_cost mz
51 in chain_paren (ln,ln) i ma
52
53chain_cost :: Sizes -> Indexes
54chain_cost mz = let (_,u) = A.bounds mz
55 cost = newWorkSpaceCost u
56 ixes = newWorkSpaceIndexes u
57 (_,_,i) = foldl chain_cost' (mz,cost,ixes) (order u)
58 in i
59
60chain_cost' :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes)
61chain_cost' sci@(mz,cost,ixes) (r,c)
62 | c == 1 = let cost' = update cost (r,c) (Just 0)
63 ixes' = update ixes (r,c) (Just ((r,c),(r,c)))
64 in (mz,cost',ixes')
65 | otherwise = minimum_cost sci (r,c)
66
67minimum_cost :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes)
68minimum_cost sci fu = foldl (smaller_cost fu) sci (fulcrum_order fu)
69
70smaller_cost :: (Int,Int) -> (Sizes,Cost,Indexes) -> ((Int,Int),(Int,Int)) -> (Sizes,Cost,Indexes)
71smaller_cost (r,c) (mz,cost,ixes) ix@((lr,lc),(rr,rc)) = let op_cost = (fromJust ((cost A.! lr) A.! lc))
72 + (fromJust ((cost A.! rr) A.! rc))
73 + ((fst $ mz A.! (lr-lc+1))
74 *(snd $ mz A.! lc)
75 *(snd $ mz A.! rr))
76 cost' = (cost A.! r) A.! c
77 in case cost' of
78 Nothing -> let cost'' = update cost (r,c) (Just op_cost)
79 ixes'' = update ixes (r,c) (Just ix)
80 in (mz,cost'',ixes'')
81 Just ct -> if op_cost < ct then
82 let cost'' = update cost (r,c) (Just op_cost)
83 ixes'' = update ixes (r,c) (Just ix)
84 in (mz,cost'',ixes'')
85 else (mz,cost,ixes)
86
87
88fulcrum_order (r,c) = let fs' = zip (repeat r) [1..(c-1)]
89 in map (partner (r,c)) fs'
90
91partner (r,c) (a,b) = (((r-b),(c-b)),(a,b))
92
93order 0 = []
94order n = (order (n-1)) ++ (zip (repeat n) [1..n])
95
96chain_paren :: Product a => (Int,Int) -> Indexes -> Matrices a -> Matrix a
97chain_paren (r,c) ixes ma = let ((lr,lc),(rr,rc)) = fromJust $ (ixes A.! r) A.! c
98 in if lr == rr && lc == rc then (ma A.! lr)
99 else (chain_paren (lr,lc) ixes ma) `multiply` (chain_paren (rr,rc) ixes ma)
100
101--------------------------------------------------------------------------
102
103{- TESTS -}
104
105-- optimal association is ((m1*(m2*m3))*m4)
106m1, m2, m3, m4 :: Matrix Double
107m1 = (10><15) [1..]
108m2 = (15><20) [1..]
109m3 = (20><5) [1..]
110m4 = (5><10) [1..] \ No newline at end of file
diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs
index 2c9c500..aaa068f 100644
--- a/lib/Numeric/Container.hs
+++ b/lib/Numeric/Container.hs
@@ -308,3 +308,5 @@ kronecker a b = fromBlocks
308 . map (reshape (cols b)) 308 . map (reshape (cols b))
309 . toRows 309 . toRows
310 $ flatten a `outer` flatten b 310 $ flatten a `outer` flatten b
311
312----------------------------------------------------------
diff --git a/lib/Numeric/Matrix.hs b/lib/Numeric/Matrix.hs
index f462e88..ce36ef2 100644
--- a/lib/Numeric/Matrix.hs
+++ b/lib/Numeric/Matrix.hs
@@ -26,6 +26,7 @@ module Numeric.Matrix (
26 module Data.Packed.Matrix, 26 module Data.Packed.Matrix,
27 module Numeric.Vector, 27 module Numeric.Vector,
28 --module Numeric.Container, 28 --module Numeric.Container,
29 chain,
29 -- * Operators 30 -- * Operators
30 (<>), (<\>), 31 (<>), (<\>),
31 -- * Deprecated 32 -- * Deprecated
@@ -40,6 +41,7 @@ import Data.Packed.Matrix
40 41
41import Numeric.Container 42import Numeric.Container
42import Numeric.Vector 43import Numeric.Vector
44import Numeric.Chain
43import Numeric.LinearAlgebra.Algorithms 45import Numeric.LinearAlgebra.Algorithms
44 46
45------------------------------------------------------------------- 47-------------------------------------------------------------------