summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Numeric/Chain.hs143
-rw-r--r--packages/base/src/Numeric/Matrix.hs100
-rw-r--r--packages/base/src/Numeric/Vector.hs159
3 files changed, 402 insertions, 0 deletions
diff --git a/packages/base/src/Numeric/Chain.hs b/packages/base/src/Numeric/Chain.hs
new file mode 100644
index 0000000..fbdb01b
--- /dev/null
+++ b/packages/base/src/Numeric/Chain.hs
@@ -0,0 +1,143 @@
1-----------------------------------------------------------------------------
2-- |
3-- Module : Numeric.Chain
4-- Copyright : (c) Vivian McPhail 2010
5-- License : GPL-style
6--
7-- Maintainer : Vivian McPhail <haskell.vivian.mcphail <at> gmail.com>
8-- Stability : provisional
9-- Portability : portable
10--
11-- optimisation of association order for chains of matrix multiplication
12--
13-----------------------------------------------------------------------------
14
15module Numeric.Chain (
16 optimiseMult,
17 ) where
18
19import Data.Maybe
20
21import Data.Packed.Matrix
22import Data.Packed.Numeric
23
24import qualified Data.Array.IArray as A
25
26-----------------------------------------------------------------------------
27{- |
28 Provide optimal association order for a chain of matrix multiplications
29 and apply the multiplications.
30
31 The algorithm is the well-known O(n\^3) dynamic programming algorithm
32 that builds a pyramid of optimal associations.
33
34> m1, m2, m3, m4 :: Matrix Double
35> m1 = (10><15) [1..]
36> m2 = (15><20) [1..]
37> m3 = (20><5) [1..]
38> m4 = (5><10) [1..]
39
40> >>> optimiseMult [m1,m2,m3,m4]
41
42will perform @((m1 `multiply` (m2 `multiply` m3)) `multiply` m4)@
43
44The naive left-to-right multiplication would take @4500@ scalar multiplications
45whereas the optimised version performs @2750@ scalar multiplications. The complexity
46in this case is 32 (= 4^3/2) * (2 comparisons, 3 scalar multiplications, 3 scalar additions,
475 lookups, 2 updates) + a constant (= three table allocations)
48-}
49optimiseMult :: Product t => [Matrix t] -> Matrix t
50optimiseMult = chain
51
52-----------------------------------------------------------------------------
53
54type Matrices a = A.Array Int (Matrix a)
55type Sizes = A.Array Int (Int,Int)
56type Cost = A.Array Int (A.Array Int (Maybe Int))
57type Indexes = A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int))))
58
59update :: A.Array Int (A.Array Int a) -> (Int,Int) -> a -> A.Array Int (A.Array Int a)
60update a (r,c) e = a A.// [(r,(a A.! r) A.// [(c,e)])]
61
62newWorkSpaceCost :: Int -> A.Array Int (A.Array Int (Maybe Int))
63newWorkSpaceCost n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n]
64 where subArray i = A.listArray (1,i) (repeat Nothing)
65
66newWorkSpaceIndexes :: Int -> A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int))))
67newWorkSpaceIndexes n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n]
68 where subArray i = A.listArray (1,i) (repeat Nothing)
69
70matricesToSizes :: [Matrix a] -> Sizes
71matricesToSizes ms = A.listArray (1,length ms) $ map (\m -> (rows m,cols m)) ms
72
73chain :: Product a => [Matrix a] -> Matrix a
74chain [] = error "chain: zero matrices to multiply"
75chain [m] = m
76chain [ml,mr] = ml `multiply` mr
77chain ms = let ln = length ms
78 ma = A.listArray (1,ln) ms
79 mz = matricesToSizes ms
80 i = chain_cost mz
81 in chain_paren (ln,ln) i ma
82
83chain_cost :: Sizes -> Indexes
84chain_cost mz = let (_,u) = A.bounds mz
85 cost = newWorkSpaceCost u
86 ixes = newWorkSpaceIndexes u
87 (_,_,i) = foldl chain_cost' (mz,cost,ixes) (order u)
88 in i
89
90chain_cost' :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes)
91chain_cost' sci@(mz,cost,ixes) (r,c)
92 | c == 1 = let cost' = update cost (r,c) (Just 0)
93 ixes' = update ixes (r,c) (Just ((r,c),(r,c)))
94 in (mz,cost',ixes')
95 | otherwise = minimum_cost sci (r,c)
96
97minimum_cost :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes)
98minimum_cost sci fu = foldl (smaller_cost fu) sci (fulcrum_order fu)
99
100smaller_cost :: (Int,Int) -> (Sizes,Cost,Indexes) -> ((Int,Int),(Int,Int)) -> (Sizes,Cost,Indexes)
101smaller_cost (r,c) (mz,cost,ixes) ix@((lr,lc),(rr,rc)) = let op_cost = fromJust ((cost A.! lr) A.! lc)
102 + fromJust ((cost A.! rr) A.! rc)
103 + fst (mz A.! (lr-lc+1))
104 * snd (mz A.! lc)
105 * snd (mz A.! rr)
106 cost' = (cost A.! r) A.! c
107 in case cost' of
108 Nothing -> let cost'' = update cost (r,c) (Just op_cost)
109 ixes'' = update ixes (r,c) (Just ix)
110 in (mz,cost'',ixes'')
111 Just ct -> if op_cost < ct then
112 let cost'' = update cost (r,c) (Just op_cost)
113 ixes'' = update ixes (r,c) (Just ix)
114 in (mz,cost'',ixes'')
115 else (mz,cost,ixes)
116
117
118fulcrum_order (r,c) = let fs' = zip (repeat r) [1..(c-1)]
119 in map (partner (r,c)) fs'
120
121partner (r,c) (a,b) = ((r-b, c-b), (a,b))
122
123order 0 = []
124order n = order (n-1) ++ zip (repeat n) [1..n]
125
126chain_paren :: Product a => (Int,Int) -> Indexes -> Matrices a -> Matrix a
127chain_paren (r,c) ixes ma = let ((lr,lc),(rr,rc)) = fromJust $ (ixes A.! r) A.! c
128 in if lr == rr && lc == rc then (ma A.! lr)
129 else (chain_paren (lr,lc) ixes ma) `multiply` (chain_paren (rr,rc) ixes ma)
130
131--------------------------------------------------------------------------
132
133{- TESTS
134
135-- optimal association is ((m1*(m2*m3))*m4)
136m1, m2, m3, m4 :: Matrix Double
137m1 = (10><15) [1..]
138m2 = (15><20) [1..]
139m3 = (20><5) [1..]
140m4 = (5><10) [1..]
141
142-}
143
diff --git a/packages/base/src/Numeric/Matrix.hs b/packages/base/src/Numeric/Matrix.hs
new file mode 100644
index 0000000..3478aae
--- /dev/null
+++ b/packages/base/src/Numeric/Matrix.hs
@@ -0,0 +1,100 @@
1{-# LANGUAGE TypeFamilies #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE UndecidableInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-}
6
7-----------------------------------------------------------------------------
8-- |
9-- Module : Numeric.Matrix
10-- Copyright : (c) Alberto Ruiz 2010
11-- License : GPL-style
12--
13-- Maintainer : Alberto Ruiz <aruiz@um.es>
14-- Stability : provisional
15-- Portability : portable
16--
17-- Provides instances of standard classes 'Show', 'Read', 'Eq',
18-- 'Num', 'Fractional', and 'Floating' for 'Matrix'.
19--
20-- In arithmetic operations one-component
21-- vectors and matrices automatically expand to match the dimensions of the other operand.
22
23-----------------------------------------------------------------------------
24
25module Numeric.Matrix (
26 ) where
27
28-------------------------------------------------------------------
29
30import Data.Packed
31import Data.Packed.Numeric
32import qualified Data.Monoid as M
33import Data.List(partition)
34import Numeric.Chain
35
36-------------------------------------------------------------------
37
38instance Container Matrix a => Eq (Matrix a) where
39 (==) = equal
40
41instance (Container Matrix a, Num (Vector a)) => Num (Matrix a) where
42 (+) = liftMatrix2Auto (+)
43 (-) = liftMatrix2Auto (-)
44 negate = liftMatrix negate
45 (*) = liftMatrix2Auto (*)
46 signum = liftMatrix signum
47 abs = liftMatrix abs
48 fromInteger = (1><1) . return . fromInteger
49
50---------------------------------------------------
51
52instance (Container Vector a, Fractional (Vector a), Num (Matrix a)) => Fractional (Matrix a) where
53 fromRational n = (1><1) [fromRational n]
54 (/) = liftMatrix2Auto (/)
55
56---------------------------------------------------------
57
58instance (Floating a, Container Vector a, Floating (Vector a), Fractional (Matrix a)) => Floating (Matrix a) where
59 sin = liftMatrix sin
60 cos = liftMatrix cos
61 tan = liftMatrix tan
62 asin = liftMatrix asin
63 acos = liftMatrix acos
64 atan = liftMatrix atan
65 sinh = liftMatrix sinh
66 cosh = liftMatrix cosh
67 tanh = liftMatrix tanh
68 asinh = liftMatrix asinh
69 acosh = liftMatrix acosh
70 atanh = liftMatrix atanh
71 exp = liftMatrix exp
72 log = liftMatrix log
73 (**) = liftMatrix2Auto (**)
74 sqrt = liftMatrix sqrt
75 pi = (1><1) [pi]
76
77--------------------------------------------------------------------------------
78
79isScalar m = rows m == 1 && cols m == 1
80
81adaptScalarM f1 f2 f3 x y
82 | isScalar x = f1 (x @@>(0,0) ) y
83 | isScalar y = f3 x (y @@>(0,0) )
84 | otherwise = f2 x y
85
86instance (Container Vector t, Eq t, Num (Vector t), Product t) => M.Monoid (Matrix t)
87 where
88 mempty = 1
89 mappend = adaptScalarM scale mXm (flip scale)
90
91 mconcat xs = work (partition isScalar xs)
92 where
93 work (ss,[]) = product ss
94 work (ss,ms) = scale' (product ss) (optimiseMult ms)
95 scale' x m
96 | isScalar x && x00 == 1 = m
97 | otherwise = scale x00 m
98 where
99 x00 = x @@> (0,0)
100
diff --git a/packages/base/src/Numeric/Vector.hs b/packages/base/src/Numeric/Vector.hs
new file mode 100644
index 0000000..2769cd9
--- /dev/null
+++ b/packages/base/src/Numeric/Vector.hs
@@ -0,0 +1,159 @@
1{-# LANGUAGE TypeFamilies #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE UndecidableInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-}
6-----------------------------------------------------------------------------
7-- |
8-- Module : Numeric.Vector
9-- Copyright : (c) Alberto Ruiz 2011
10-- License : GPL-style
11--
12-- Maintainer : Alberto Ruiz <aruiz@um.es>
13-- Stability : provisional
14-- Portability : portable
15--
16-- Provides instances of standard classes 'Show', 'Read', 'Eq',
17-- 'Num', 'Fractional', and 'Floating' for 'Vector'.
18--
19-----------------------------------------------------------------------------
20
21module Numeric.Vector () where
22
23import Numeric.Vectorized
24import Data.Packed.Vector
25import Data.Packed.Numeric
26
27-------------------------------------------------------------------
28
29adaptScalar f1 f2 f3 x y
30 | dim x == 1 = f1 (x@>0) y
31 | dim y == 1 = f3 x (y@>0)
32 | otherwise = f2 x y
33
34------------------------------------------------------------------
35
36instance Num (Vector Float) where
37 (+) = adaptScalar addConstant add (flip addConstant)
38 negate = scale (-1)
39 (*) = adaptScalar scale mul (flip scale)
40 signum = vectorMapF Sign
41 abs = vectorMapF Abs
42 fromInteger = fromList . return . fromInteger
43
44instance Num (Vector Double) where
45 (+) = adaptScalar addConstant add (flip addConstant)
46 negate = scale (-1)
47 (*) = adaptScalar scale mul (flip scale)
48 signum = vectorMapR Sign
49 abs = vectorMapR Abs
50 fromInteger = fromList . return . fromInteger
51
52instance Num (Vector (Complex Double)) where
53 (+) = adaptScalar addConstant add (flip addConstant)
54 negate = scale (-1)
55 (*) = adaptScalar scale mul (flip scale)
56 signum = vectorMapC Sign
57 abs = vectorMapC Abs
58 fromInteger = fromList . return . fromInteger
59
60instance Num (Vector (Complex Float)) where
61 (+) = adaptScalar addConstant add (flip addConstant)
62 negate = scale (-1)
63 (*) = adaptScalar scale mul (flip scale)
64 signum = vectorMapQ Sign
65 abs = vectorMapQ Abs
66 fromInteger = fromList . return . fromInteger
67
68---------------------------------------------------
69
70instance (Container Vector a, Num (Vector a)) => Fractional (Vector a) where
71 fromRational n = fromList [fromRational n]
72 (/) = adaptScalar f divide g where
73 r `f` v = scaleRecip r v
74 v `g` r = scale (recip r) v
75
76-------------------------------------------------------
77
78instance Floating (Vector Float) where
79 sin = vectorMapF Sin
80 cos = vectorMapF Cos
81 tan = vectorMapF Tan
82 asin = vectorMapF ASin
83 acos = vectorMapF ACos
84 atan = vectorMapF ATan
85 sinh = vectorMapF Sinh
86 cosh = vectorMapF Cosh
87 tanh = vectorMapF Tanh
88 asinh = vectorMapF ASinh
89 acosh = vectorMapF ACosh
90 atanh = vectorMapF ATanh
91 exp = vectorMapF Exp
92 log = vectorMapF Log
93 sqrt = vectorMapF Sqrt
94 (**) = adaptScalar (vectorMapValF PowSV) (vectorZipF Pow) (flip (vectorMapValF PowVS))
95 pi = fromList [pi]
96
97-------------------------------------------------------------
98
99instance Floating (Vector Double) where
100 sin = vectorMapR Sin
101 cos = vectorMapR Cos
102 tan = vectorMapR Tan
103 asin = vectorMapR ASin
104 acos = vectorMapR ACos
105 atan = vectorMapR ATan
106 sinh = vectorMapR Sinh
107 cosh = vectorMapR Cosh
108 tanh = vectorMapR Tanh
109 asinh = vectorMapR ASinh
110 acosh = vectorMapR ACosh
111 atanh = vectorMapR ATanh
112 exp = vectorMapR Exp
113 log = vectorMapR Log
114 sqrt = vectorMapR Sqrt
115 (**) = adaptScalar (vectorMapValR PowSV) (vectorZipR Pow) (flip (vectorMapValR PowVS))
116 pi = fromList [pi]
117
118-------------------------------------------------------------
119
120instance Floating (Vector (Complex Double)) where
121 sin = vectorMapC Sin
122 cos = vectorMapC Cos
123 tan = vectorMapC Tan
124 asin = vectorMapC ASin
125 acos = vectorMapC ACos
126 atan = vectorMapC ATan
127 sinh = vectorMapC Sinh
128 cosh = vectorMapC Cosh
129 tanh = vectorMapC Tanh
130 asinh = vectorMapC ASinh
131 acosh = vectorMapC ACosh
132 atanh = vectorMapC ATanh
133 exp = vectorMapC Exp
134 log = vectorMapC Log
135 sqrt = vectorMapC Sqrt
136 (**) = adaptScalar (vectorMapValC PowSV) (vectorZipC Pow) (flip (vectorMapValC PowVS))
137 pi = fromList [pi]
138
139-----------------------------------------------------------
140
141instance Floating (Vector (Complex Float)) where
142 sin = vectorMapQ Sin
143 cos = vectorMapQ Cos
144 tan = vectorMapQ Tan
145 asin = vectorMapQ ASin
146 acos = vectorMapQ ACos
147 atan = vectorMapQ ATan
148 sinh = vectorMapQ Sinh
149 cosh = vectorMapQ Cosh
150 tanh = vectorMapQ Tanh
151 asinh = vectorMapQ ASinh
152 acosh = vectorMapQ ACosh
153 atanh = vectorMapQ ATanh
154 exp = vectorMapQ Exp
155 log = vectorMapQ Log
156 sqrt = vectorMapQ Sqrt
157 (**) = adaptScalar (vectorMapValQ PowSV) (vectorZipQ Pow) (flip (vectorMapValQ PowVS))
158 pi = fromList [pi]
159