summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Chain.hs
blob: fa518d1035efab9ed0535595a716419b28050bed (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
{-# LANGUAGE FlexibleContexts #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Internal.Chain
-- Copyright   :  (c) Vivian McPhail 2010
-- License     :  BSD3
--
-- Maintainer  :  Vivian McPhail <haskell.vivian.mcphail <at> gmail.com>
-- Stability   :  provisional
-- Portability :  portable
--
-- optimisation of association order for chains of matrix multiplication
--
-----------------------------------------------------------------------------

{-# LANGUAGE FlexibleContexts #-}

module Internal.Chain (
                      optimiseMult,
                     ) where

import Data.Maybe

import Internal.Matrix hiding (order)
import Internal.Numeric

import qualified Data.Array.IArray as A

-----------------------------------------------------------------------------
{- |
     Multiply a chain of matrices, first choosing the association order that
     minimises the number of scalar multiplications.

     The order is found with the classic O(n\^3) dynamic-programming
     algorithm, which fills a triangular table (a \"pyramid\") with the best
     split point of every contiguous sub-chain.  Calling it on an empty list
     is an error; a single matrix is returned unchanged.

     For example, given

> m1, m2, m3, m4 :: Matrix Double
> m1 = (10><15) [1..]
> m2 = (15><20) [1..]
> m3 = (20><5) [1..]
> m4 = (5><10) [1..]

     the call

> optimiseMult [m1,m2,m3,m4]

     evaluates @((m1 `multiply` (m2 `multiply` m3)) `multiply` m4)@.

     Naive left-to-right multiplication would need @4500@ scalar
     multiplications, whereas the chosen order needs only @2750@.
-}
optimiseMult :: Product t => [Matrix t] -> Matrix t
optimiseMult ms = chain ms

-----------------------------------------------------------------------------

-- | A triangular table: an array of rows, each row an array of columns.
type Pyramid e  = A.Array Int (A.Array Int e)

-- | The input matrices, indexed from 1.
type Matrices a = A.Array Int (Matrix a)
-- | @(rows, cols)@ of each input matrix, indexed from 1.
type Sizes      = A.Array Int (Int,Int)
-- | Best known cost of each sub-chain; 'Nothing' until it has been computed.
type Cost       = Pyramid (Maybe Int)
-- | Best split of each sub-chain as @(left part, right part)@; 'Nothing' until computed.
type Indexes    = Pyramid (Maybe ((Int,Int),(Int,Int)))

-- | Pure single-cell update of an array of rows: row @r@ is rebuilt with
-- column @c@ set to @e@; every other row is left unchanged.
update :: A.Array Int (A.Array Int a) -> (Int,Int) -> a -> A.Array Int (A.Array Int a)
update a (r,c) e = a A.// [(r, newRow)]
  where
    newRow = (a A.! r) A.// [(c, e)]

-- | An empty triangular work space: row @i@ (for @i = 1 .. n@) has columns
-- @1 .. i@, and every cell starts out as 'Nothing'.
emptyPyramid :: Int -> A.Array Int (A.Array Int (Maybe b))
emptyPyramid n =
    A.listArray (1,n) [ A.listArray (1,i) (replicate i Nothing) | i <- [1..n] ]

-- | Empty cost table for a chain of @n@ matrices.
newWorkSpaceCost :: Int -> A.Array Int (A.Array Int (Maybe Int))
newWorkSpaceCost = emptyPyramid

-- | Empty split-point table for a chain of @n@ matrices.
newWorkSpaceIndexes :: Int -> A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int))))
newWorkSpaceIndexes = emptyPyramid

-- | The @(rows, cols)@ shape of every input matrix, indexed from 1.
matricesToSizes :: [Matrix a] -> Sizes
matricesToSizes ms = A.listArray (1, length ms) (map shape ms)
  where
    shape m = (rows m, cols m)

-- | Multiply a list of matrices.  Up to two matrices are handled directly;
-- longer chains are multiplied in the order recorded by the cost table.
chain :: Product a => [Matrix a] -> Matrix a
chain ms = case ms of
    []       -> error "chain: zero matrices to multiply"
    [m]      -> m
    [ml, mr] -> ml `multiply` mr
    _        -> chain_paren (n, n) (chain_cost sizes) table
  where
    n     = length ms
    table = A.listArray (1, n) ms
    sizes = matricesToSizes ms

-- | Run the dynamic programme over the whole pyramid and return the table of
-- best split points; the accompanying cost table is discarded.
chain_cost :: Sizes -> Indexes
chain_cost mz = splits
  where
    n              = snd (A.bounds mz)
    initial        = (mz, newWorkSpaceCost n, newWorkSpaceIndexes n)
    (_, _, splits) = foldl chain_cost' initial (order n)

-- | Fill in pyramid entry @(r,c)@.  Column 1 holds single matrices: they cost
-- nothing and their split points at themselves, which is how 'chain_paren'
-- recognises a leaf.  Longer sub-chains try every split via 'minimum_cost'.
chain_cost' :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes)
chain_cost' (mz,cost,ixes) pos@(_,c) =
    if c /= 1
       then minimum_cost (mz,cost,ixes) pos
       else (mz, update cost pos (Just 0), update ixes pos (Just (pos,pos)))

-- | Offer every split point of sub-chain @fu@ to 'smaller_cost' in turn,
-- left to right, so the table ends up holding the cheapest one.
minimum_cost :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes)
minimum_cost sci fu = go sci (fulcrum_order fu)
  where
    go acc []          = acc
    go acc (cand:rest) = go (smaller_cost fu acc cand) rest

-- | Relax pyramid entry @(r,c)@ with one candidate split.
--
-- Entry @(r,c)@ stands for the @c@ consecutive matrices ending at matrix @r@,
-- i.e. matrices @r-c+1 .. r@.  The candidate @((lr,lc),(rr,rc))@ splits that
-- sub-chain into a left part (the @lc@ matrices ending at @lr@) and a right
-- part (the @rc@ matrices ending at @rr@).  Its cost is the best cost of each
-- part plus the scalar multiplications of the final product:
--
-- > rows (first left matrix) * cols (last left matrix) * cols (last right matrix)
--
-- The last left matrix is @lr@ (not @lc@, which is only the part's length).
-- Both parts are shorter sub-chains that 'order' visits before @(r,c)@, so
-- their cost lookups are always 'Just'.  The candidate is recorded when the
-- entry is still empty or when it is strictly cheaper, so ties keep the
-- earlier candidate.
smaller_cost :: (Int,Int) -> (Sizes,Cost,Indexes) -> ((Int,Int),(Int,Int)) -> (Sizes,Cost,Indexes)
smaller_cost (r,c) (mz,cost,ixes) ix@((lr,lc),(rr,rc)) =
    let op_cost = fromJust ((cost A.! lr) A.! lc)   -- best cost of the left part
                + fromJust ((cost A.! rr) A.! rc)   -- best cost of the right part
                + fst (mz A.! (lr - lc + 1))        -- rows of the first left matrix
                  * snd (mz A.! lr)                 -- cols of the last left matrix
                  * snd (mz A.! rr)                 -- cols of the last right matrix
        improves = maybe True (op_cost <) ((cost A.! r) A.! c)
    in if improves
          then (mz, update cost (r,c) (Just op_cost), update ixes (r,c) (Just ix))
          else (mz, cost, ixes)
                                                                         

-- | Every way to split sub-chain @(r,c)@ (the @c@ matrices ending at matrix
-- @r@) into a non-empty left part and a non-empty right part.  For
-- @b = 1 .. c-1@ the right part is the @b@ matrices ending at @r@ and the
-- left part is the remaining @c-b@ matrices ending at @r-b@.  Yields @[]@
-- when @c <= 1@.
fulcrum_order :: (Int,Int) -> [((Int,Int),(Int,Int))]
fulcrum_order (r,c) = map (partner (r,c)) [ (r,b) | b <- [1 .. c-1] ]

-- | Pair the right part @(a,b)@ (callers pass @a == r@) with the left part
-- that completes sub-chain @(r,c)@: the @c-b@ matrices ending at @r-b@.
partner :: (Int,Int) -> (Int,Int) -> ((Int,Int),(Int,Int))
partner (r,c) (a,b) = ((r-b, c-b), (a,b))

-- | Traversal order of the pyramid: row by row, @(r,c)@ for @r = 1 .. n@ and
-- @c = 1 .. r@.  Every shorter sub-chain that 'smaller_cost' consults for
-- @(r,c)@ (row @r-b@, or row @r@ with a smaller column) appears earlier.
-- Built directly in O(n^2); yields @[]@ for @n <= 0@ instead of diverging.
order :: Int -> [(Int,Int)]
order n = [ (r,c) | r <- [1..n], c <- [1..r] ]

-- | Rebuild the product of sub-chain @(r,c)@ from the table of best split
-- points.  A leaf entry (its split points at itself) is the matrix stored at
-- index @r@; otherwise both halves are built recursively and multiplied.
chain_paren :: Product a => (Int,Int) -> Indexes -> Matrices a -> Matrix a
chain_paren (r,c) ixes ma =
    case fromJust ((ixes A.! r) A.! c) of
      (lhs@(lr,_), rhs)
        | lhs == rhs -> ma A.! lr
        | otherwise  -> build lhs `multiply` build rhs
  where
    build p = chain_paren p ixes ma

--------------------------------------------------------------------------

{- TESTS

-- optimal association is ((m1*(m2*m3))*m4)
m1, m2, m3, m4 :: Matrix Double
m1 = (10><15) [1..]
m2 = (15><20) [1..]
m3 = (20><5) [1..]
m4 = (5><10) [1..]

-}