From 9c087c435a5c8762fa66458899a36ac505e45128 Mon Sep 17 00:00:00 2001 From: Vivian McPhail Date: Sun, 12 Sep 2010 01:16:21 +0000 Subject: improve chain -> optimiseMult function / documentation --- lib/Numeric/Chain.hs | 38 ++++++++++++++++++++++++++++++++++---- lib/Numeric/LinearAlgebra/Tests.hs | 2 +- lib/Numeric/Matrix.hs | 2 +- 3 files changed, 36 insertions(+), 6 deletions(-) (limited to 'lib/Numeric') diff --git a/lib/Numeric/Chain.hs b/lib/Numeric/Chain.hs index 0c33f76..299d8fa 100644 --- a/lib/Numeric/Chain.hs +++ b/lib/Numeric/Chain.hs @@ -1,10 +1,10 @@ ----------------------------------------------------------------------------- -- | -- Module : Numeric.Chain --- Copyright : (c) Alberto Ruiz 2010 +-- Copyright : (c) Vivian McPhail 2010 -- License : GPL-style -- --- Maintainer : Alberto Ruiz +-- Maintainer : Vivian McPhail gmail.com> -- Stability : provisional -- Portability : portable -- @@ -13,7 +13,7 @@ ----------------------------------------------------------------------------- module Numeric.Chain ( - chain + optimiseMult, ) where import Data.Maybe @@ -23,6 +23,34 @@ import Numeric.Container import qualified Data.Array.IArray as A +----------------------------------------------------------------------------- +{- | + Provide optimal association order for a chain of matrix multiplications + and apply the multiplications. + + The algorithm is the well-known O(n\^3) dynamic programming algorithm + that builds a pyramid of optimal associations. + +> m1, m2, m3, m4 :: Matrix Double +> m1 = (10><15) [1..] +> m2 = (15><20) [1..] +> m3 = (20><5) [1..] +> m4 = (5><10) [1..] + +> >>> optimiseMult [m1,m2,m3,m4] + +will perform @((m1 `multiply` (m2 `multiply` m3)) `multiply` m4)@ + +The naive left-to-right multiplication would take @4500@ scalar multiplications +whereas the optimised version performs @2750@ scalar multiplications. The complexity +in this case is 32 (= 4^3/2) * (2 comparisons, 3 scalar multiplications, 3 scalar additions, +5 lookups, 2 updates) + a constant (= three table allocations) +-} +optimiseMult :: Product t => [Matrix t] -> Matrix t +optimiseMult = chain + +----------------------------------------------------------------------------- + type Matrices a = A.Array Int (Matrix a) type Sizes = A.Array Int (Int,Int) type Cost = A.Array Int (A.Array Int (Maybe Int)) @@ -42,8 +70,10 @@ newWorkSpaceIndexes n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n] matricesToSizes :: [Matrix a] -> Sizes matricesToSizes ms = A.listArray (1,length ms) $ map (\m -> (rows m,cols m)) ms --- | provide optimal association order for a chain of matrix multiplications and apply the multiplications chain :: Product a => [Matrix a] -> Matrix a +chain [] = error "chain: zero matrices to multiply" +chain [m] = m +chain [ml,mr] = ml `multiply` mr chain ms = let ln = length ms ma = A.listArray (1,ln) ms mz = matricesToSizes ms diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 2d5b0da..aa7b01c 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs @@ -287,7 +287,7 @@ sumprodTest = TestList [ --------------------------------------------------------------------- -chainTest = utest "chain" $ foldl1' (<>) ms |~| chain ms where +chainTest = utest "chain" $ foldl1' (<>) ms |~| optimiseMult ms where ms = [ diag (fromList [1,2,3 :: Double]) , konst 3 (3,5) , (5><10) [1 .. ] diff --git a/lib/Numeric/Matrix.hs b/lib/Numeric/Matrix.hs index fe0ec13..c766381 100644 --- a/lib/Numeric/Matrix.hs +++ b/lib/Numeric/Matrix.hs @@ -26,7 +26,7 @@ module Numeric.Matrix ( module Data.Packed.Matrix, module Numeric.Vector, --module Numeric.Container, - chain, + optimiseMult, -- * Operators (<>), (<\>), -- * Deprecated -- cgit v1.2.3