From 2dae75e9d2b08a23945e936dcd5244b7f0c46107 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Fri, 5 Jun 2015 16:37:30 +0200 Subject: move chain --- packages/base/src/Internal/Chain.hs | 148 ++++++++++++++++++++++++++++++++++++ packages/base/src/Numeric/Chain.hs | 148 ------------------------------------ 2 files changed, 148 insertions(+), 148 deletions(-) create mode 100644 packages/base/src/Internal/Chain.hs delete mode 100644 packages/base/src/Numeric/Chain.hs (limited to 'packages/base') diff --git a/packages/base/src/Internal/Chain.hs b/packages/base/src/Internal/Chain.hs new file mode 100644 index 0000000..fa518d1 --- /dev/null +++ b/packages/base/src/Internal/Chain.hs @@ -0,0 +1,148 @@ +{-# LANGUAGE FlexibleContexts #-} + +----------------------------------------------------------------------------- +-- | +-- Module : Internal.Chain +-- Copyright : (c) Vivian McPhail 2010 +-- License : BSD3 +-- +-- Maintainer : Vivian McPhail gmail.com> +-- Stability : provisional +-- Portability : portable +-- +-- optimisation of association order for chains of matrix multiplication +-- +----------------------------------------------------------------------------- + +{-# LANGUAGE FlexibleContexts #-} + +module Internal.Chain ( + optimiseMult, + ) where + +import Data.Maybe + +import Internal.Matrix hiding (order) +import Internal.Numeric + +import qualified Data.Array.IArray as A + +----------------------------------------------------------------------------- +{- | + Provide optimal association order for a chain of matrix multiplications + and apply the multiplications. + + The algorithm is the well-known O(n\^3) dynamic programming algorithm + that builds a pyramid of optimal associations. + +> m1, m2, m3, m4 :: Matrix Double +> m1 = (10><15) [1..] +> m2 = (15><20) [1..] +> m3 = (20><5) [1..] +> m4 = (5><10) [1..] + +> >>> optimiseMult [m1,m2,m3,m4] + +will perform @((m1 `multiply` (m2 `multiply` m3)) `multiply` m4)@ + +The naive left-to-right multiplication would take @4500@ scalar multiplications +whereas the optimised version performs @2750@ scalar multiplications. The complexity +in this case is 32 (= 4^3/2) * (2 comparisons, 3 scalar multiplications, 3 scalar additions, +5 lookups, 2 updates) + a constant (= three table allocations) +-} +optimiseMult :: Product t => [Matrix t] -> Matrix t +optimiseMult = chain + +----------------------------------------------------------------------------- + +type Matrices a = A.Array Int (Matrix a) +type Sizes = A.Array Int (Int,Int) +type Cost = A.Array Int (A.Array Int (Maybe Int)) +type Indexes = A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int)))) + +update :: A.Array Int (A.Array Int a) -> (Int,Int) -> a -> A.Array Int (A.Array Int a) +update a (r,c) e = a A.// [(r,(a A.! r) A.// [(c,e)])] + +newWorkSpaceCost :: Int -> A.Array Int (A.Array Int (Maybe Int)) +newWorkSpaceCost n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n] + where subArray i = A.listArray (1,i) (repeat Nothing) + +newWorkSpaceIndexes :: Int -> A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int)))) +newWorkSpaceIndexes n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n] + where subArray i = A.listArray (1,i) (repeat Nothing) + +matricesToSizes :: [Matrix a] -> Sizes +matricesToSizes ms = A.listArray (1,length ms) $ map (\m -> (rows m,cols m)) ms + +chain :: Product a => [Matrix a] -> Matrix a +chain [] = error "chain: zero matrices to multiply" +chain [m] = m +chain [ml,mr] = ml `multiply` mr +chain ms = let ln = length ms + ma = A.listArray (1,ln) ms + mz = matricesToSizes ms + i = chain_cost mz + in chain_paren (ln,ln) i ma + +chain_cost :: Sizes -> Indexes +chain_cost mz = let (_,u) = A.bounds mz + cost = newWorkSpaceCost u + ixes = newWorkSpaceIndexes u + (_,_,i) = foldl chain_cost' (mz,cost,ixes) (order u) + in i + +chain_cost' :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes) +chain_cost' sci@(mz,cost,ixes) (r,c) + | c == 1 = let cost' = update cost (r,c) (Just 0) + ixes' = update ixes (r,c) (Just ((r,c),(r,c))) + in (mz,cost',ixes') + | otherwise = minimum_cost sci (r,c) + +minimum_cost :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes) +minimum_cost sci fu = foldl (smaller_cost fu) sci (fulcrum_order fu) + +smaller_cost :: (Int,Int) -> (Sizes,Cost,Indexes) -> ((Int,Int),(Int,Int)) -> (Sizes,Cost,Indexes) +smaller_cost (r,c) (mz,cost,ixes) ix@((lr,lc),(rr,rc)) = + let op_cost = fromJust ((cost A.! lr) A.! lc) + + fromJust ((cost A.! rr) A.! rc) + + fst (mz A.! (lr-lc+1)) + * snd (mz A.! lc) + * snd (mz A.! rr) + cost' = (cost A.! r) A.! c + in case cost' of + Nothing -> let cost'' = update cost (r,c) (Just op_cost) + ixes'' = update ixes (r,c) (Just ix) + in (mz,cost'',ixes'') + Just ct -> if op_cost < ct then + let cost'' = update cost (r,c) (Just op_cost) + ixes'' = update ixes (r,c) (Just ix) + in (mz,cost'',ixes'') + else (mz,cost,ixes) + + +fulcrum_order (r,c) = let fs' = zip (repeat r) [1..(c-1)] + in map (partner (r,c)) fs' + +partner (r,c) (a,b) = ((r-b, c-b), (a,b)) + +order 0 = [] +order n = order (n-1) ++ zip (repeat n) [1..n] + +chain_paren :: Product a => (Int,Int) -> Indexes -> Matrices a -> Matrix a +chain_paren (r,c) ixes ma = let ((lr,lc),(rr,rc)) = fromJust $ (ixes A.! r) A.! c + in if lr == rr && lc == rc then (ma A.! lr) + else (chain_paren (lr,lc) ixes ma) `multiply` (chain_paren (rr,rc) ixes ma) + +-------------------------------------------------------------------------- + +{- TESTS + +-- optimal association is ((m1*(m2*m3))*m4) +m1, m2, m3, m4 :: Matrix Double +m1 = (10><15) [1..] +m2 = (15><20) [1..] +m3 = (20><5) [1..] +m4 = (5><10) [1..] + +-} + diff --git a/packages/base/src/Numeric/Chain.hs b/packages/base/src/Numeric/Chain.hs deleted file mode 100644 index 64c09c0..0000000 --- a/packages/base/src/Numeric/Chain.hs +++ /dev/null @@ -1,148 +0,0 @@ -{-# LANGUAGE FlexibleContexts #-} - ------------------------------------------------------------------------------ --- | --- Module : Numeric.Chain --- Copyright : (c) Vivian McPhail 2010 --- License : BSD3 --- --- Maintainer : Vivian McPhail gmail.com> --- Stability : provisional --- Portability : portable --- --- optimisation of association order for chains of matrix multiplication --- ------------------------------------------------------------------------------ - -{-# LANGUAGE FlexibleContexts #-} - -module Numeric.Chain ( - optimiseMult, - ) where - -import Data.Maybe - -import Data.Packed.Matrix -import Data.Packed.Internal.Numeric - -import qualified Data.Array.IArray as A - ------------------------------------------------------------------------------ -{- | - Provide optimal association order for a chain of matrix multiplications - and apply the multiplications. - - The algorithm is the well-known O(n\^3) dynamic programming algorithm - that builds a pyramid of optimal associations. - -> m1, m2, m3, m4 :: Matrix Double -> m1 = (10><15) [1..] -> m2 = (15><20) [1..] -> m3 = (20><5) [1..] -> m4 = (5><10) [1..] - -> >>> optimiseMult [m1,m2,m3,m4] - -will perform @((m1 `multiply` (m2 `multiply` m3)) `multiply` m4)@ - -The naive left-to-right multiplication would take @4500@ scalar multiplications -whereas the optimised version performs @2750@ scalar multiplications. The complexity -in this case is 32 (= 4^3/2) * (2 comparisons, 3 scalar multiplications, 3 scalar additions, -5 lookups, 2 updates) + a constant (= three table allocations) --} -optimiseMult :: Product t => [Matrix t] -> Matrix t -optimiseMult = chain - ------------------------------------------------------------------------------ - -type Matrices a = A.Array Int (Matrix a) -type Sizes = A.Array Int (Int,Int) -type Cost = A.Array Int (A.Array Int (Maybe Int)) -type Indexes = A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int)))) - -update :: A.Array Int (A.Array Int a) -> (Int,Int) -> a -> A.Array Int (A.Array Int a) -update a (r,c) e = a A.// [(r,(a A.! r) A.// [(c,e)])] - -newWorkSpaceCost :: Int -> A.Array Int (A.Array Int (Maybe Int)) -newWorkSpaceCost n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n] - where subArray i = A.listArray (1,i) (repeat Nothing) - -newWorkSpaceIndexes :: Int -> A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int)))) -newWorkSpaceIndexes n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n] - where subArray i = A.listArray (1,i) (repeat Nothing) - -matricesToSizes :: [Matrix a] -> Sizes -matricesToSizes ms = A.listArray (1,length ms) $ map (\m -> (rows m,cols m)) ms - -chain :: Product a => [Matrix a] -> Matrix a -chain [] = error "chain: zero matrices to multiply" -chain [m] = m -chain [ml,mr] = ml `multiply` mr -chain ms = let ln = length ms - ma = A.listArray (1,ln) ms - mz = matricesToSizes ms - i = chain_cost mz - in chain_paren (ln,ln) i ma - -chain_cost :: Sizes -> Indexes -chain_cost mz = let (_,u) = A.bounds mz - cost = newWorkSpaceCost u - ixes = newWorkSpaceIndexes u - (_,_,i) = foldl chain_cost' (mz,cost,ixes) (order u) - in i - -chain_cost' :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes) -chain_cost' sci@(mz,cost,ixes) (r,c) - | c == 1 = let cost' = update cost (r,c) (Just 0) - ixes' = update ixes (r,c) (Just ((r,c),(r,c))) - in (mz,cost',ixes') - | otherwise = minimum_cost sci (r,c) - -minimum_cost :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes) -minimum_cost sci fu = foldl (smaller_cost fu) sci (fulcrum_order fu) - -smaller_cost :: (Int,Int) -> (Sizes,Cost,Indexes) -> ((Int,Int),(Int,Int)) -> (Sizes,Cost,Indexes) -smaller_cost (r,c) (mz,cost,ixes) ix@((lr,lc),(rr,rc)) = - let op_cost = fromJust ((cost A.! lr) A.! lc) - + fromJust ((cost A.! rr) A.! rc) - + fst (mz A.! (lr-lc+1)) - * snd (mz A.! lc) - * snd (mz A.! rr) - cost' = (cost A.! r) A.! c - in case cost' of - Nothing -> let cost'' = update cost (r,c) (Just op_cost) - ixes'' = update ixes (r,c) (Just ix) - in (mz,cost'',ixes'') - Just ct -> if op_cost < ct then - let cost'' = update cost (r,c) (Just op_cost) - ixes'' = update ixes (r,c) (Just ix) - in (mz,cost'',ixes'') - else (mz,cost,ixes) - - -fulcrum_order (r,c) = let fs' = zip (repeat r) [1..(c-1)] - in map (partner (r,c)) fs' - -partner (r,c) (a,b) = ((r-b, c-b), (a,b)) - -order 0 = [] -order n = order (n-1) ++ zip (repeat n) [1..n] - -chain_paren :: Product a => (Int,Int) -> Indexes -> Matrices a -> Matrix a -chain_paren (r,c) ixes ma = let ((lr,lc),(rr,rc)) = fromJust $ (ixes A.! r) A.! c - in if lr == rr && lc == rc then (ma A.! lr) - else (chain_paren (lr,lc) ixes ma) `multiply` (chain_paren (rr,rc) ixes ma) - --------------------------------------------------------------------------- - -{- TESTS - --- optimal association is ((m1*(m2*m3))*m4) -m1, m2, m3, m4 :: Matrix Double -m1 = (10><15) [1..] -m2 = (15><20) [1..] -m3 = (20><5) [1..] -m4 = (5><10) [1..] - --} - -- cgit v1.2.3