1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
|
{-# LANGUAGE FlexibleContexts #-}
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
-----------------------------------------------------------------------------
-- |
-- Module : Internal.Chain
-- Copyright : (c) Vivian McPhail 2010
-- License : BSD3
--
-- Maintainer : Vivian McPhail <haskell.vivian.mcphail <at> gmail.com>
-- Stability : provisional
-- Portability : portable
--
-- optimisation of association order for chains of matrix multiplication
--
-----------------------------------------------------------------------------
{-# LANGUAGE FlexibleContexts #-}
module Internal.Chain (
optimiseMult,
) where
import Data.Maybe
import Internal.Matrix
import Internal.Numeric
import qualified Data.Array.IArray as A
-----------------------------------------------------------------------------
{- |
Choose an association order for a chain of matrix multiplications, aiming to
minimise the number of scalar multiplications, and then carry out the
multiplications in that order.

The order is found with the well-known O(n\^3) dynamic programming algorithm
that builds a pyramid of optimal associations.

> m1, m2, m3, m4 :: Matrix Double
> m1 = (10><15) [1..]
> m2 = (15><20) [1..]
> m3 = (20><5) [1..]
> m4 = (5><10) [1..]

> >>> optimiseMult [m1,m2,m3,m4]

will perform @((m1 `multiply` (m2 `multiply` m3)) `multiply` m4)@

The naive left-to-right multiplication would take @4500@ scalar multiplications
whereas the optimised version performs @2750@ scalar multiplications.  The
overhead of choosing the order in this case is
32 (= 4^3/2) * (2 comparisons, 3 scalar multiplications, 3 scalar additions,
5 lookups, 2 updates) + a constant (= three table allocations).

Calls 'error' when given an empty list.
-}
optimiseMult :: Product t => [Matrix t] -> Matrix t
optimiseMult ms = chain ms
-----------------------------------------------------------------------------
-- | The matrices of the chain, indexed from 1 in multiplication order.
type Matrices a = A.Array Int (Matrix a)
-- | The @(rows, cols)@ of each matrix of the chain, indexed from 1.
type Sizes = A.Array Int (Int,Int)
-- | Triangular cost table: row @r@ has columns @1..r@.  Cell @(r,c)@ describes
-- the sub-chain of @c@ consecutive matrices ending at matrix @r@ and holds the
-- cheapest cost found so far for it, or 'Nothing' if none has been recorded.
type Cost = A.Array Int (A.Array Int (Maybe Int))
-- | Table with the same layout as 'Cost'.  Cell @(r,c)@ holds the cells of the
-- left and right sub-chains of the cheapest split found so far; a cell for a
-- single matrix (@c == 1@) holds its own coordinates twice.
type Indexes = A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int))))
-- | Return a copy of a two-level table in which the element at row @r@,
-- column @c@ has been replaced by the given value.  Both indices must lie
-- within the bounds of the table and of the selected row.
update :: A.Array Int (A.Array Int a) -> (Int,Int) -> a -> A.Array Int (A.Array Int a)
update table (row,col) val =
    let patchedRow = (table A.! row) A.// [(col,val)]
    in table A.// [(row,patchedRow)]
-- | Allocate an empty triangular cost table for a chain of @n@ matrices:
-- rows @1..n@, where row @r@ has columns @1..r@, every cell 'Nothing'.
newWorkSpaceCost :: Int -> A.Array Int (A.Array Int (Maybe Int))
newWorkSpaceCost n =
    A.array (1,n) [ (r, A.listArray (1,r) (replicate r Nothing)) | r <- [1..n] ]
-- | Allocate an empty triangular table of split choices for a chain of @n@
-- matrices: rows @1..n@, where row @r@ has columns @1..r@, every cell 'Nothing'.
newWorkSpaceIndexes :: Int -> A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int))))
newWorkSpaceIndexes n =
    A.array (1,n) [ (r, A.listArray (1,r) (replicate r Nothing)) | r <- [1..n] ]
-- | Record the @(rows, cols)@ of every matrix in the chain, indexed from 1 in
-- list order.
matricesToSizes :: [Matrix a] -> Sizes
matricesToSizes ms = A.listArray (1,length ms) [ (rows m, cols m) | m <- ms ]
-- | Multiply a list of matrices.  Lists of one or two matrices are handled
-- directly; longer lists are first planned with 'chain_cost' and then
-- multiplied by 'chain_paren' following that plan.  An empty list is an error.
chain :: Product a => [Matrix a] -> Matrix a
chain []      = error "chain: zero matrices to multiply"
chain [m]     = m
chain [ml,mr] = ml `multiply` mr
chain ms      = chain_paren (count,count) plan matrices
  where
    count    = length ms
    matrices = A.listArray (1,count) ms
    -- the cell (count,count) of the plan describes the whole chain
    plan     = chain_cost (matricesToSizes ms)
-- | Build the table of chosen splits for a chain whose matrix sizes are given.
-- Fresh cost and split tables are allocated and every cell is filled in the
-- sequence produced by 'order'; only the split table is returned.
chain_cost :: Sizes -> Indexes
chain_cost mz = plan
  where
    (_,n)      = A.bounds mz
    start      = (mz, newWorkSpaceCost n, newWorkSpaceIndexes n)
    (_,_,plan) = foldl chain_cost' start (order n)
-- | Fill one cell of the work tables.  A cell in column 1 stands for a single
-- matrix: it costs nothing and is recorded as a split into itself, which is
-- how 'chain_paren' recognises a leaf.  Any other cell is resolved by
-- 'minimum_cost'.
chain_cost' :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes)
chain_cost' sci@(mz,cost,ixes) cell@(_,c)
    | c /= 1    = minimum_cost sci cell
    | otherwise = (mz, seeded_cost, seeded_ixes)
  where
    seeded_cost = update cost cell (Just 0)
    seeded_ixes = update ixes cell (Just (cell,cell))
-- | Resolve one multi-matrix cell: offer every split listed by
-- 'fulcrum_order' to 'smaller_cost' in turn, threading the tables through.
minimum_cost :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes)
minimum_cost sci fu = foldl consider sci candidates
  where
    consider   = smaller_cost fu
    candidates = fulcrum_order fu
-- | Offer one candidate split for the sub-chain at cell @(r,c)@ and record it
-- if no split is recorded yet or if it is strictly cheaper than the recorded
-- one; otherwise the tables are returned unchanged.
--
-- The candidate @((lr,lc),(rr,rc))@ names the cells of the left and right
-- sub-chains.  Both must already hold a cost ('fromJust' fails otherwise); the
-- fill sequence produced by 'order' visits them first, since the left cell
-- lies in an earlier row and the right cell earlier in the same row.
--
-- The cost of a split is the cost of both parts plus the cost of multiplying
-- their results:
-- @rows (first left matrix) * cols (last left matrix) * cols (last right matrix)@.
smaller_cost :: (Int,Int) -> (Sizes,Cost,Indexes) -> ((Int,Int),(Int,Int)) -> (Sizes,Cost,Indexes)
smaller_cost (r,c) (mz,cost,ixes) ix@((lr,lc),(rr,rc)) =
    let -- The left sub-chain covers matrices lr-lc+1 .. lr, so the dimension
        -- shared with the right part is the column count of matrix lr.
        -- (lc is the *length* of the left sub-chain, not a matrix index;
        -- indexing the sizes with it is only right for chains starting at 1.)
        op_cost = fromJust ((cost A.! lr) A.! lc)
                  + fromJust ((cost A.! rr) A.! rc)
                  + fst (mz A.! (lr-lc+1))
                  * snd (mz A.! lr)
                  * snd (mz A.! rr)
        improved = ( mz
                   , update cost (r,c) (Just op_cost)
                   , update ixes (r,c) (Just ix)
                   )
    in case (cost A.! r) A.! c of
         -- keep the recorded split unless the candidate is strictly cheaper
         Just ct | ct <= op_cost -> (mz,cost,ixes)
         _                       -> improved
-- | List every way of splitting the sub-chain at cell @(r,c)@ in two: for each
-- right-hand length @b@ from 1 to @c-1@, the right part is the cell @(r,b)@
-- and 'partner' supplies the matching left part.
fulcrum_order (r,c) = [ partner (r,c) (r,b) | b <- [1..(c-1)] ]
-- | Given the cell @(r,c)@ of a sub-chain and the cell of its right-hand part,
-- whose length is @b@, pair that right-hand cell with the cell of the
-- remaining left-hand part, @(r-b, c-b)@.
partner (r,c) right@(_,b) = ((r-b, c-b), right)
-- | Enumerate the cells of the triangular work tables in the sequence in which
-- they are filled: rows @1..n@ in turn and, within row @r@, columns @1..r@.
-- Yields the empty list when @n@ is zero or negative.
--
-- Written as a comprehension so the list is produced in time linear in its
-- length; appending each row to a recursively built prefix would copy that
-- prefix again at every step.
order n = [ (r,c) | r <- [1..n], c <- [1..r] ]
-- | Carry out the multiplications for the sub-chain at cell @(r,c)@ following
-- the recorded splits.  A cell whose two halves are the same cell is a leaf
-- and yields the matrix at that row index; any other cell multiplies the
-- results of its left and right halves.  The cell must have been filled in
-- ('fromJust' fails otherwise).
chain_paren :: Product a => (Int,Int) -> Indexes -> Matrices a -> Matrix a
chain_paren (r,c) ixes ma
    | left == right = ma A.! fst left
    | otherwise     = chain_paren left ixes ma `multiply` chain_paren right ixes ma
  where
    (left,right) = fromJust ((ixes A.! r) A.! c)
--------------------------------------------------------------------------
{- TESTS
-- optimal association is ((m1*(m2*m3))*m4)
m1, m2, m3, m4 :: Matrix Double
m1 = (10><15) [1..]
m2 = (15><20) [1..]
m3 = (20><5) [1..]
m4 = (5><10) [1..]
-}
|