summaryrefslogtreecommitdiff
path: root/lib/Numeric
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric')
-rw-r--r--lib/Numeric/Chain.hs38
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs2
-rw-r--r--lib/Numeric/Matrix.hs2
3 files changed, 36 insertions, 6 deletions
diff --git a/lib/Numeric/Chain.hs b/lib/Numeric/Chain.hs
index 0c33f76..299d8fa 100644
--- a/lib/Numeric/Chain.hs
+++ b/lib/Numeric/Chain.hs
@@ -1,10 +1,10 @@
1----------------------------------------------------------------------------- 1-----------------------------------------------------------------------------
2-- | 2-- |
3-- Module : Numeric.Chain 3-- Module : Numeric.Chain
4-- Copyright : (c) Alberto Ruiz 2010 4-- Copyright : (c) Vivian McPhail 2010
5-- License : GPL-style 5-- License : GPL-style
6-- 6--
7-- Maintainer : Alberto Ruiz <aruiz@um.es> 7-- Maintainer : Vivian McPhail <haskell.vivian.mcphail <at> gmail.com>
8-- Stability : provisional 8-- Stability : provisional
9-- Portability : portable 9-- Portability : portable
10-- 10--
@@ -13,7 +13,7 @@
13----------------------------------------------------------------------------- 13-----------------------------------------------------------------------------
14 14
15module Numeric.Chain ( 15module Numeric.Chain (
16 chain 16 optimiseMult,
17 ) where 17 ) where
18 18
19import Data.Maybe 19import Data.Maybe
@@ -23,6 +23,34 @@ import Numeric.Container
23 23
24import qualified Data.Array.IArray as A 24import qualified Data.Array.IArray as A
25 25
26-----------------------------------------------------------------------------
27{- |
28 Provide optimal association order for a chain of matrix multiplications
29 and apply the multiplications.
30
31 The algorithm is the well-known O(n\^3) dynamic programming algorithm
32 that builds a pyramid of optimal associations.
33
34> m1, m2, m3, m4 :: Matrix Double
35> m1 = (10><15) [1..]
36> m2 = (15><20) [1..]
37> m3 = (20><5) [1..]
38> m4 = (5><10) [1..]
39
40> >>> optimiseMult [m1,m2,m3,m4]
41
42will perform @((m1 `multiply` (m2 `multiply` m3)) `multiply` m4)@
43
44The naive left-to-right multiplication would take @4500@ scalar multiplications
45whereas the optimised version performs @2750@ scalar multiplications. The complexity
46in this case is 32 (= 4^3/2) * (2 comparisons, 3 scalar multiplications, 3 scalar additions,
475 lookups, 2 updates) + a constant (= three table allocations)
48-}
49optimiseMult :: Product t => [Matrix t] -> Matrix t
50optimiseMult = chain
51
52-----------------------------------------------------------------------------
53
26type Matrices a = A.Array Int (Matrix a) 54type Matrices a = A.Array Int (Matrix a)
27type Sizes = A.Array Int (Int,Int) 55type Sizes = A.Array Int (Int,Int)
28type Cost = A.Array Int (A.Array Int (Maybe Int)) 56type Cost = A.Array Int (A.Array Int (Maybe Int))
@@ -42,8 +70,10 @@ newWorkSpaceIndexes n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n]
42matricesToSizes :: [Matrix a] -> Sizes 70matricesToSizes :: [Matrix a] -> Sizes
43matricesToSizes ms = A.listArray (1,length ms) $ map (\m -> (rows m,cols m)) ms 71matricesToSizes ms = A.listArray (1,length ms) $ map (\m -> (rows m,cols m)) ms
44 72
45-- | provide optimal association order for a chain of matrix multiplications and apply the multiplications
46chain :: Product a => [Matrix a] -> Matrix a 73chain :: Product a => [Matrix a] -> Matrix a
74chain [] = error "chain: zero matrices to multiply"
75chain [m] = m
76chain [ml,mr] = ml `multiply` mr
47chain ms = let ln = length ms 77chain ms = let ln = length ms
48 ma = A.listArray (1,ln) ms 78 ma = A.listArray (1,ln) ms
49 mz = matricesToSizes ms 79 mz = matricesToSizes ms
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index 2d5b0da..aa7b01c 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -287,7 +287,7 @@ sumprodTest = TestList [
287 287
288--------------------------------------------------------------------- 288---------------------------------------------------------------------
289 289
290chainTest = utest "chain" $ foldl1' (<>) ms |~| chain ms where 290chainTest = utest "chain" $ foldl1' (<>) ms |~| optimiseMult ms where
291 ms = [ diag (fromList [1,2,3 :: Double]) 291 ms = [ diag (fromList [1,2,3 :: Double])
292 , konst 3 (3,5) 292 , konst 3 (3,5)
293 , (5><10) [1 .. ] 293 , (5><10) [1 .. ]
diff --git a/lib/Numeric/Matrix.hs b/lib/Numeric/Matrix.hs
index fe0ec13..c766381 100644
--- a/lib/Numeric/Matrix.hs
+++ b/lib/Numeric/Matrix.hs
@@ -26,7 +26,7 @@ module Numeric.Matrix (
26 module Data.Packed.Matrix, 26 module Data.Packed.Matrix,
27 module Numeric.Vector, 27 module Numeric.Vector,
28 --module Numeric.Container, 28 --module Numeric.Container,
29 chain, 29 optimiseMult,
30 -- * Operators 30 -- * Operators
31 (<>), (<\>), 31 (<>), (<\>),
32 -- * Deprecated 32 -- * Deprecated