diff options
Diffstat (limited to 'lib/Numeric')
-rw-r--r-- | lib/Numeric/Chain.hs | 38 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests.hs | 2 | ||||
-rw-r--r-- | lib/Numeric/Matrix.hs | 2 |
3 files changed, 36 insertions, 6 deletions
diff --git a/lib/Numeric/Chain.hs b/lib/Numeric/Chain.hs index 0c33f76..299d8fa 100644 --- a/lib/Numeric/Chain.hs +++ b/lib/Numeric/Chain.hs | |||
@@ -1,10 +1,10 @@ | |||
1 | ----------------------------------------------------------------------------- | 1 | ----------------------------------------------------------------------------- |
2 | -- | | 2 | -- | |
3 | -- Module : Numeric.Chain | 3 | -- Module : Numeric.Chain |
4 | -- Copyright : (c) Alberto Ruiz 2010 | 4 | -- Copyright : (c) Vivian McPhail 2010 |
5 | -- License : GPL-style | 5 | -- License : GPL-style |
6 | -- | 6 | -- |
7 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | 7 | -- Maintainer : Vivian McPhail <haskell.vivian.mcphail <at> gmail.com> |
8 | -- Stability : provisional | 8 | -- Stability : provisional |
9 | -- Portability : portable | 9 | -- Portability : portable |
10 | -- | 10 | -- |
@@ -13,7 +13,7 @@ | |||
13 | ----------------------------------------------------------------------------- | 13 | ----------------------------------------------------------------------------- |
14 | 14 | ||
15 | module Numeric.Chain ( | 15 | module Numeric.Chain ( |
16 | chain | 16 | optimiseMult, |
17 | ) where | 17 | ) where |
18 | 18 | ||
19 | import Data.Maybe | 19 | import Data.Maybe |
@@ -23,6 +23,34 @@ import Numeric.Container | |||
23 | 23 | ||
24 | import qualified Data.Array.IArray as A | 24 | import qualified Data.Array.IArray as A |
25 | 25 | ||
26 | ----------------------------------------------------------------------------- | ||
27 | {- | | ||
28 | Provide optimal association order for a chain of matrix multiplications | ||
29 | and apply the multiplications. | ||
30 | |||
31 | The algorithm is the well-known O(n\^3) dynamic programming algorithm | ||
32 | that builds a pyramid of optimal associations. | ||
33 | |||
34 | > m1, m2, m3, m4 :: Matrix Double | ||
35 | > m1 = (10><15) [1..] | ||
36 | > m2 = (15><20) [1..] | ||
37 | > m3 = (20><5) [1..] | ||
38 | > m4 = (5><10) [1..] | ||
39 | |||
40 | > >>> optimiseMult [m1,m2,m3,m4] | ||
41 | |||
42 | will perform @((m1 `multiply` (m2 `multiply` m3)) `multiply` m4)@ | ||
43 | |||
44 | The naive left-to-right multiplication would take @4500@ scalar multiplications | ||
45 | whereas the optimised version performs @2750@ scalar multiplications. The complexity | ||
46 | in this case is 32 (= 4^3/2) * (2 comparisons, 3 scalar multiplications, 3 scalar additions, | ||
47 | 5 lookups, 2 updates) + a constant (= three table allocations) | ||
48 | -} | ||
49 | optimiseMult :: Product t => [Matrix t] -> Matrix t | ||
50 | optimiseMult = chain | ||
51 | |||
52 | ----------------------------------------------------------------------------- | ||
53 | |||
26 | type Matrices a = A.Array Int (Matrix a) | 54 | type Matrices a = A.Array Int (Matrix a) |
27 | type Sizes = A.Array Int (Int,Int) | 55 | type Sizes = A.Array Int (Int,Int) |
28 | type Cost = A.Array Int (A.Array Int (Maybe Int)) | 56 | type Cost = A.Array Int (A.Array Int (Maybe Int)) |
@@ -42,8 +70,10 @@ newWorkSpaceIndexes n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n] | |||
42 | matricesToSizes :: [Matrix a] -> Sizes | 70 | matricesToSizes :: [Matrix a] -> Sizes |
43 | matricesToSizes ms = A.listArray (1,length ms) $ map (\m -> (rows m,cols m)) ms | 71 | matricesToSizes ms = A.listArray (1,length ms) $ map (\m -> (rows m,cols m)) ms |
44 | 72 | ||
45 | -- | provide optimal association order for a chain of matrix multiplications and apply the multiplications | ||
46 | chain :: Product a => [Matrix a] -> Matrix a | 73 | chain :: Product a => [Matrix a] -> Matrix a |
74 | chain [] = error "chain: zero matrices to multiply" | ||
75 | chain [m] = m | ||
76 | chain [ml,mr] = ml `multiply` mr | ||
47 | chain ms = let ln = length ms | 77 | chain ms = let ln = length ms |
48 | ma = A.listArray (1,ln) ms | 78 | ma = A.listArray (1,ln) ms |
49 | mz = matricesToSizes ms | 79 | mz = matricesToSizes ms |
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 2d5b0da..aa7b01c 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs | |||
@@ -287,7 +287,7 @@ sumprodTest = TestList [ | |||
287 | 287 | ||
288 | --------------------------------------------------------------------- | 288 | --------------------------------------------------------------------- |
289 | 289 | ||
290 | chainTest = utest "chain" $ foldl1' (<>) ms |~| chain ms where | 290 | chainTest = utest "chain" $ foldl1' (<>) ms |~| optimiseMult ms where |
291 | ms = [ diag (fromList [1,2,3 :: Double]) | 291 | ms = [ diag (fromList [1,2,3 :: Double]) |
292 | , konst 3 (3,5) | 292 | , konst 3 (3,5) |
293 | , (5><10) [1 .. ] | 293 | , (5><10) [1 .. ] |
diff --git a/lib/Numeric/Matrix.hs b/lib/Numeric/Matrix.hs index fe0ec13..c766381 100644 --- a/lib/Numeric/Matrix.hs +++ b/lib/Numeric/Matrix.hs | |||
@@ -26,7 +26,7 @@ module Numeric.Matrix ( | |||
26 | module Data.Packed.Matrix, | 26 | module Data.Packed.Matrix, |
27 | module Numeric.Vector, | 27 | module Numeric.Vector, |
28 | --module Numeric.Container, | 28 | --module Numeric.Container, |
29 | chain, | 29 | optimiseMult, |
30 | -- * Operators | 30 | -- * Operators |
31 | (<>), (<\>), | 31 | (<>), (<\>), |
32 | -- * Deprecated | 32 | -- * Deprecated |