diff options
Diffstat (limited to 'packages/hmatrix/src/Numeric')
-rw-r--r-- | packages/hmatrix/src/Numeric/Chain.hs | 140 | ||||
-rw-r--r-- | packages/hmatrix/src/Numeric/Container.hs | 6 | ||||
-rw-r--r-- | packages/hmatrix/src/Numeric/Matrix.hs | 98 | ||||
-rw-r--r-- | packages/hmatrix/src/Numeric/Vector.hs | 158 |
4 files changed, 5 insertions, 397 deletions
diff --git a/packages/hmatrix/src/Numeric/Chain.hs b/packages/hmatrix/src/Numeric/Chain.hs deleted file mode 100644 index de6a86f..0000000 --- a/packages/hmatrix/src/Numeric/Chain.hs +++ /dev/null | |||
@@ -1,140 +0,0 @@ | |||
1 | ----------------------------------------------------------------------------- | ||
2 | -- | | ||
3 | -- Module : Numeric.Chain | ||
4 | -- Copyright : (c) Vivian McPhail 2010 | ||
5 | -- License : GPL-style | ||
6 | -- | ||
7 | -- Maintainer : Vivian McPhail <haskell.vivian.mcphail <at> gmail.com> | ||
8 | -- Stability : provisional | ||
9 | -- Portability : portable | ||
10 | -- | ||
11 | -- optimisation of association order for chains of matrix multiplication | ||
12 | -- | ||
13 | ----------------------------------------------------------------------------- | ||
14 | |||
15 | module Numeric.Chain ( | ||
16 | optimiseMult, | ||
17 | ) where | ||
18 | |||
19 | import Data.Maybe | ||
20 | |||
21 | import Data.Packed.Matrix | ||
22 | import Data.Packed.Numeric | ||
23 | |||
24 | import qualified Data.Array.IArray as A | ||
25 | |||
26 | ----------------------------------------------------------------------------- | ||
27 | {- | | ||
28 | Provide optimal association order for a chain of matrix multiplications | ||
29 | and apply the multiplications. | ||
30 | |||
31 | The algorithm is the well-known O(n\^3) dynamic programming algorithm | ||
32 | that builds a pyramid of optimal associations. | ||
33 | |||
34 | > m1, m2, m3, m4 :: Matrix Double | ||
35 | > m1 = (10><15) [1..] | ||
36 | > m2 = (15><20) [1..] | ||
37 | > m3 = (20><5) [1..] | ||
38 | > m4 = (5><10) [1..] | ||
39 | |||
40 | > >>> optimiseMult [m1,m2,m3,m4] | ||
41 | |||
42 | will perform @((m1 `multiply` (m2 `multiply` m3)) `multiply` m4)@ | ||
43 | |||
44 | The naive left-to-right multiplication would take @4500@ scalar multiplications | ||
45 | whereas the optimised version performs @2750@ scalar multiplications. The complexity | ||
46 | in this case is 32 (= 4^3/2) * (2 comparisons, 3 scalar multiplications, 3 scalar additions, | ||
47 | 5 lookups, 2 updates) + a constant (= three table allocations) | ||
48 | -} | ||
49 | optimiseMult :: Product t => [Matrix t] -> Matrix t | ||
50 | optimiseMult = chain | ||
51 | |||
52 | ----------------------------------------------------------------------------- | ||
53 | |||
54 | type Matrices a = A.Array Int (Matrix a) | ||
55 | type Sizes = A.Array Int (Int,Int) | ||
56 | type Cost = A.Array Int (A.Array Int (Maybe Int)) | ||
57 | type Indexes = A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int)))) | ||
58 | |||
59 | update :: A.Array Int (A.Array Int a) -> (Int,Int) -> a -> A.Array Int (A.Array Int a) | ||
60 | update a (r,c) e = a A.// [(r,(a A.! r) A.// [(c,e)])] | ||
61 | |||
62 | newWorkSpaceCost :: Int -> A.Array Int (A.Array Int (Maybe Int)) | ||
63 | newWorkSpaceCost n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n] | ||
64 | where subArray i = A.listArray (1,i) (repeat Nothing) | ||
65 | |||
66 | newWorkSpaceIndexes :: Int -> A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int)))) | ||
67 | newWorkSpaceIndexes n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n] | ||
68 | where subArray i = A.listArray (1,i) (repeat Nothing) | ||
69 | |||
70 | matricesToSizes :: [Matrix a] -> Sizes | ||
71 | matricesToSizes ms = A.listArray (1,length ms) $ map (\m -> (rows m,cols m)) ms | ||
72 | |||
73 | chain :: Product a => [Matrix a] -> Matrix a | ||
74 | chain [] = error "chain: zero matrices to multiply" | ||
75 | chain [m] = m | ||
76 | chain [ml,mr] = ml `multiply` mr | ||
77 | chain ms = let ln = length ms | ||
78 | ma = A.listArray (1,ln) ms | ||
79 | mz = matricesToSizes ms | ||
80 | i = chain_cost mz | ||
81 | in chain_paren (ln,ln) i ma | ||
82 | |||
83 | chain_cost :: Sizes -> Indexes | ||
84 | chain_cost mz = let (_,u) = A.bounds mz | ||
85 | cost = newWorkSpaceCost u | ||
86 | ixes = newWorkSpaceIndexes u | ||
87 | (_,_,i) = foldl chain_cost' (mz,cost,ixes) (order u) | ||
88 | in i | ||
89 | |||
90 | chain_cost' :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes) | ||
91 | chain_cost' sci@(mz,cost,ixes) (r,c) | ||
92 | | c == 1 = let cost' = update cost (r,c) (Just 0) | ||
93 | ixes' = update ixes (r,c) (Just ((r,c),(r,c))) | ||
94 | in (mz,cost',ixes') | ||
95 | | otherwise = minimum_cost sci (r,c) | ||
96 | |||
97 | minimum_cost :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes) | ||
98 | minimum_cost sci fu = foldl (smaller_cost fu) sci (fulcrum_order fu) | ||
99 | |||
100 | smaller_cost :: (Int,Int) -> (Sizes,Cost,Indexes) -> ((Int,Int),(Int,Int)) -> (Sizes,Cost,Indexes) | ||
101 | smaller_cost (r,c) (mz,cost,ixes) ix@((lr,lc),(rr,rc)) = let op_cost = fromJust ((cost A.! lr) A.! lc) | ||
102 | + fromJust ((cost A.! rr) A.! rc) | ||
103 | + fst (mz A.! (lr-lc+1)) | ||
104 | * snd (mz A.! lc) | ||
105 | * snd (mz A.! rr) | ||
106 | cost' = (cost A.! r) A.! c | ||
107 | in case cost' of | ||
108 | Nothing -> let cost'' = update cost (r,c) (Just op_cost) | ||
109 | ixes'' = update ixes (r,c) (Just ix) | ||
110 | in (mz,cost'',ixes'') | ||
111 | Just ct -> if op_cost < ct then | ||
112 | let cost'' = update cost (r,c) (Just op_cost) | ||
113 | ixes'' = update ixes (r,c) (Just ix) | ||
114 | in (mz,cost'',ixes'') | ||
115 | else (mz,cost,ixes) | ||
116 | |||
117 | |||
118 | fulcrum_order (r,c) = let fs' = zip (repeat r) [1..(c-1)] | ||
119 | in map (partner (r,c)) fs' | ||
120 | |||
121 | partner (r,c) (a,b) = ((r-b, c-b), (a,b)) | ||
122 | |||
123 | order 0 = [] | ||
124 | order n = order (n-1) ++ zip (repeat n) [1..n] | ||
125 | |||
126 | chain_paren :: Product a => (Int,Int) -> Indexes -> Matrices a -> Matrix a | ||
127 | chain_paren (r,c) ixes ma = let ((lr,lc),(rr,rc)) = fromJust $ (ixes A.! r) A.! c | ||
128 | in if lr == rr && lc == rc then (ma A.! lr) | ||
129 | else (chain_paren (lr,lc) ixes ma) `multiply` (chain_paren (rr,rc) ixes ma) | ||
130 | |||
131 | -------------------------------------------------------------------------- | ||
132 | |||
133 | {- TESTS -} | ||
134 | |||
135 | -- optimal association is ((m1*(m2*m3))*m4) | ||
136 | m1, m2, m3, m4 :: Matrix Double | ||
137 | m1 = (10><15) [1..] | ||
138 | m2 = (15><20) [1..] | ||
139 | m3 = (20><5) [1..] | ||
140 | m4 = (5><10) [1..] | ||
diff --git a/packages/hmatrix/src/Numeric/Container.hs b/packages/hmatrix/src/Numeric/Container.hs index 645a83f..e7f23d4 100644 --- a/packages/hmatrix/src/Numeric/Container.hs +++ b/packages/hmatrix/src/Numeric/Container.hs | |||
@@ -66,11 +66,11 @@ module Numeric.Container ( | |||
66 | 66 | ||
67 | import Data.Packed hiding (stepD, stepF, condD, condF, conjugateC, conjugateQ) | 67 | import Data.Packed hiding (stepD, stepF, condD, condF, conjugateC, conjugateQ) |
68 | import Data.Packed.Numeric | 68 | import Data.Packed.Numeric |
69 | import Numeric.Chain | ||
70 | import Numeric.IO | 69 | import Numeric.IO |
71 | import Data.Complex | 70 | import Data.Complex |
72 | import Numeric.LinearAlgebra.Algorithms(Field,linearSolveSVD) | 71 | import Numeric.LinearAlgebra.Algorithms(Field,linearSolveSVD) |
73 | import Numeric.Random | 72 | import Numeric.Random |
73 | import Data.Monoid(Monoid(mconcat)) | ||
74 | 74 | ||
75 | ------------------------------------------------------------------ | 75 | ------------------------------------------------------------------ |
76 | 76 | ||
@@ -268,4 +268,8 @@ infixl 7 ◇ | |||
268 | dot :: (Container Vector t, Product t) => Vector t -> Vector t -> t | 268 | dot :: (Container Vector t, Product t) => Vector t -> Vector t -> t |
269 | dot u v = udot (conj u) v | 269 | dot u v = udot (conj u) v |
270 | 270 | ||
271 | -------------------------------------------------------------------------------- | ||
272 | |||
273 | optimiseMult :: Monoid (Matrix t) => [Matrix t] -> Matrix t | ||
274 | optimiseMult = mconcat | ||
271 | 275 | ||
diff --git a/packages/hmatrix/src/Numeric/Matrix.hs b/packages/hmatrix/src/Numeric/Matrix.hs deleted file mode 100644 index e285ff2..0000000 --- a/packages/hmatrix/src/Numeric/Matrix.hs +++ /dev/null | |||
@@ -1,98 +0,0 @@ | |||
1 | {-# LANGUAGE TypeFamilies #-} | ||
2 | {-# LANGUAGE FlexibleContexts #-} | ||
3 | {-# LANGUAGE FlexibleInstances #-} | ||
4 | {-# LANGUAGE UndecidableInstances #-} | ||
5 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
6 | |||
7 | ----------------------------------------------------------------------------- | ||
8 | -- | | ||
9 | -- Module : Numeric.Matrix | ||
10 | -- Copyright : (c) Alberto Ruiz 2010 | ||
11 | -- License : GPL-style | ||
12 | -- | ||
13 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
14 | -- Stability : provisional | ||
15 | -- Portability : portable | ||
16 | -- | ||
17 | -- Provides instances of standard classes 'Show', 'Read', 'Eq', | ||
18 | -- 'Num', 'Fractional', and 'Floating' for 'Matrix'. | ||
19 | -- | ||
20 | -- In arithmetic operations one-component | ||
21 | -- vectors and matrices automatically expand to match the dimensions of the other operand. | ||
22 | |||
23 | ----------------------------------------------------------------------------- | ||
24 | |||
25 | module Numeric.Matrix ( | ||
26 | ) where | ||
27 | |||
28 | ------------------------------------------------------------------- | ||
29 | |||
30 | import Numeric.Container | ||
31 | import qualified Data.Monoid as M | ||
32 | import Data.List(partition) | ||
33 | |||
34 | ------------------------------------------------------------------- | ||
35 | |||
36 | instance Container Matrix a => Eq (Matrix a) where | ||
37 | (==) = equal | ||
38 | |||
39 | instance (Container Matrix a, Num (Vector a)) => Num (Matrix a) where | ||
40 | (+) = liftMatrix2Auto (+) | ||
41 | (-) = liftMatrix2Auto (-) | ||
42 | negate = liftMatrix negate | ||
43 | (*) = liftMatrix2Auto (*) | ||
44 | signum = liftMatrix signum | ||
45 | abs = liftMatrix abs | ||
46 | fromInteger = (1><1) . return . fromInteger | ||
47 | |||
48 | --------------------------------------------------- | ||
49 | |||
50 | instance (Container Vector a, Fractional (Vector a), Num (Matrix a)) => Fractional (Matrix a) where | ||
51 | fromRational n = (1><1) [fromRational n] | ||
52 | (/) = liftMatrix2Auto (/) | ||
53 | |||
54 | --------------------------------------------------------- | ||
55 | |||
56 | instance (Floating a, Container Vector a, Floating (Vector a), Fractional (Matrix a)) => Floating (Matrix a) where | ||
57 | sin = liftMatrix sin | ||
58 | cos = liftMatrix cos | ||
59 | tan = liftMatrix tan | ||
60 | asin = liftMatrix asin | ||
61 | acos = liftMatrix acos | ||
62 | atan = liftMatrix atan | ||
63 | sinh = liftMatrix sinh | ||
64 | cosh = liftMatrix cosh | ||
65 | tanh = liftMatrix tanh | ||
66 | asinh = liftMatrix asinh | ||
67 | acosh = liftMatrix acosh | ||
68 | atanh = liftMatrix atanh | ||
69 | exp = liftMatrix exp | ||
70 | log = liftMatrix log | ||
71 | (**) = liftMatrix2Auto (**) | ||
72 | sqrt = liftMatrix sqrt | ||
73 | pi = (1><1) [pi] | ||
74 | |||
75 | -------------------------------------------------------------------------------- | ||
76 | |||
77 | isScalar m = rows m == 1 && cols m == 1 | ||
78 | |||
79 | adaptScalarM f1 f2 f3 x y | ||
80 | | isScalar x = f1 (x @@>(0,0) ) y | ||
81 | | isScalar y = f3 x (y @@>(0,0) ) | ||
82 | | otherwise = f2 x y | ||
83 | |||
84 | instance (Container Vector t, Eq t, Num (Vector t), Product t) => M.Monoid (Matrix t) | ||
85 | where | ||
86 | mempty = 1 | ||
87 | mappend = adaptScalarM scale mXm (flip scale) | ||
88 | |||
89 | mconcat xs = work (partition isScalar xs) | ||
90 | where | ||
91 | work (ss,[]) = product ss | ||
92 | work (ss,ms) = scale' (product ss) (optimiseMult ms) | ||
93 | scale' x m | ||
94 | | isScalar x && x00 == 1 = m | ||
95 | | otherwise = scale x00 m | ||
96 | where | ||
97 | x00 = x @@> (0,0) | ||
98 | |||
diff --git a/packages/hmatrix/src/Numeric/Vector.hs b/packages/hmatrix/src/Numeric/Vector.hs deleted file mode 100644 index 4c59d32..0000000 --- a/packages/hmatrix/src/Numeric/Vector.hs +++ /dev/null | |||
@@ -1,158 +0,0 @@ | |||
1 | {-# LANGUAGE TypeFamilies #-} | ||
2 | {-# LANGUAGE FlexibleContexts #-} | ||
3 | {-# LANGUAGE FlexibleInstances #-} | ||
4 | {-# LANGUAGE UndecidableInstances #-} | ||
5 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
6 | ----------------------------------------------------------------------------- | ||
7 | -- | | ||
8 | -- Module : Numeric.Vector | ||
9 | -- Copyright : (c) Alberto Ruiz 2011 | ||
10 | -- License : GPL-style | ||
11 | -- | ||
12 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
13 | -- Stability : provisional | ||
14 | -- Portability : portable | ||
15 | -- | ||
16 | -- Provides instances of standard classes 'Show', 'Read', 'Eq', | ||
17 | -- 'Num', 'Fractional', and 'Floating' for 'Vector'. | ||
18 | -- | ||
19 | ----------------------------------------------------------------------------- | ||
20 | |||
21 | module Numeric.Vector () where | ||
22 | |||
23 | import Numeric.Vectorized | ||
24 | import Numeric.Container | ||
25 | |||
26 | ------------------------------------------------------------------- | ||
27 | |||
28 | adaptScalar f1 f2 f3 x y | ||
29 | | dim x == 1 = f1 (x@>0) y | ||
30 | | dim y == 1 = f3 x (y@>0) | ||
31 | | otherwise = f2 x y | ||
32 | |||
33 | ------------------------------------------------------------------ | ||
34 | |||
35 | instance Num (Vector Float) where | ||
36 | (+) = adaptScalar addConstant add (flip addConstant) | ||
37 | negate = scale (-1) | ||
38 | (*) = adaptScalar scale mul (flip scale) | ||
39 | signum = vectorMapF Sign | ||
40 | abs = vectorMapF Abs | ||
41 | fromInteger = fromList . return . fromInteger | ||
42 | |||
43 | instance Num (Vector Double) where | ||
44 | (+) = adaptScalar addConstant add (flip addConstant) | ||
45 | negate = scale (-1) | ||
46 | (*) = adaptScalar scale mul (flip scale) | ||
47 | signum = vectorMapR Sign | ||
48 | abs = vectorMapR Abs | ||
49 | fromInteger = fromList . return . fromInteger | ||
50 | |||
51 | instance Num (Vector (Complex Double)) where | ||
52 | (+) = adaptScalar addConstant add (flip addConstant) | ||
53 | negate = scale (-1) | ||
54 | (*) = adaptScalar scale mul (flip scale) | ||
55 | signum = vectorMapC Sign | ||
56 | abs = vectorMapC Abs | ||
57 | fromInteger = fromList . return . fromInteger | ||
58 | |||
59 | instance Num (Vector (Complex Float)) where | ||
60 | (+) = adaptScalar addConstant add (flip addConstant) | ||
61 | negate = scale (-1) | ||
62 | (*) = adaptScalar scale mul (flip scale) | ||
63 | signum = vectorMapQ Sign | ||
64 | abs = vectorMapQ Abs | ||
65 | fromInteger = fromList . return . fromInteger | ||
66 | |||
67 | --------------------------------------------------- | ||
68 | |||
69 | instance (Container Vector a, Num (Vector a)) => Fractional (Vector a) where | ||
70 | fromRational n = fromList [fromRational n] | ||
71 | (/) = adaptScalar f divide g where | ||
72 | r `f` v = scaleRecip r v | ||
73 | v `g` r = scale (recip r) v | ||
74 | |||
75 | ------------------------------------------------------- | ||
76 | |||
77 | instance Floating (Vector Float) where | ||
78 | sin = vectorMapF Sin | ||
79 | cos = vectorMapF Cos | ||
80 | tan = vectorMapF Tan | ||
81 | asin = vectorMapF ASin | ||
82 | acos = vectorMapF ACos | ||
83 | atan = vectorMapF ATan | ||
84 | sinh = vectorMapF Sinh | ||
85 | cosh = vectorMapF Cosh | ||
86 | tanh = vectorMapF Tanh | ||
87 | asinh = vectorMapF ASinh | ||
88 | acosh = vectorMapF ACosh | ||
89 | atanh = vectorMapF ATanh | ||
90 | exp = vectorMapF Exp | ||
91 | log = vectorMapF Log | ||
92 | sqrt = vectorMapF Sqrt | ||
93 | (**) = adaptScalar (vectorMapValF PowSV) (vectorZipF Pow) (flip (vectorMapValF PowVS)) | ||
94 | pi = fromList [pi] | ||
95 | |||
96 | ------------------------------------------------------------- | ||
97 | |||
98 | instance Floating (Vector Double) where | ||
99 | sin = vectorMapR Sin | ||
100 | cos = vectorMapR Cos | ||
101 | tan = vectorMapR Tan | ||
102 | asin = vectorMapR ASin | ||
103 | acos = vectorMapR ACos | ||
104 | atan = vectorMapR ATan | ||
105 | sinh = vectorMapR Sinh | ||
106 | cosh = vectorMapR Cosh | ||
107 | tanh = vectorMapR Tanh | ||
108 | asinh = vectorMapR ASinh | ||
109 | acosh = vectorMapR ACosh | ||
110 | atanh = vectorMapR ATanh | ||
111 | exp = vectorMapR Exp | ||
112 | log = vectorMapR Log | ||
113 | sqrt = vectorMapR Sqrt | ||
114 | (**) = adaptScalar (vectorMapValR PowSV) (vectorZipR Pow) (flip (vectorMapValR PowVS)) | ||
115 | pi = fromList [pi] | ||
116 | |||
117 | ------------------------------------------------------------- | ||
118 | |||
119 | instance Floating (Vector (Complex Double)) where | ||
120 | sin = vectorMapC Sin | ||
121 | cos = vectorMapC Cos | ||
122 | tan = vectorMapC Tan | ||
123 | asin = vectorMapC ASin | ||
124 | acos = vectorMapC ACos | ||
125 | atan = vectorMapC ATan | ||
126 | sinh = vectorMapC Sinh | ||
127 | cosh = vectorMapC Cosh | ||
128 | tanh = vectorMapC Tanh | ||
129 | asinh = vectorMapC ASinh | ||
130 | acosh = vectorMapC ACosh | ||
131 | atanh = vectorMapC ATanh | ||
132 | exp = vectorMapC Exp | ||
133 | log = vectorMapC Log | ||
134 | sqrt = vectorMapC Sqrt | ||
135 | (**) = adaptScalar (vectorMapValC PowSV) (vectorZipC Pow) (flip (vectorMapValC PowVS)) | ||
136 | pi = fromList [pi] | ||
137 | |||
138 | ----------------------------------------------------------- | ||
139 | |||
140 | instance Floating (Vector (Complex Float)) where | ||
141 | sin = vectorMapQ Sin | ||
142 | cos = vectorMapQ Cos | ||
143 | tan = vectorMapQ Tan | ||
144 | asin = vectorMapQ ASin | ||
145 | acos = vectorMapQ ACos | ||
146 | atan = vectorMapQ ATan | ||
147 | sinh = vectorMapQ Sinh | ||
148 | cosh = vectorMapQ Cosh | ||
149 | tanh = vectorMapQ Tanh | ||
150 | asinh = vectorMapQ ASinh | ||
151 | acosh = vectorMapQ ACosh | ||
152 | atanh = vectorMapQ ATanh | ||
153 | exp = vectorMapQ Exp | ||
154 | log = vectorMapQ Log | ||
155 | sqrt = vectorMapQ Sqrt | ||
156 | (**) = adaptScalar (vectorMapValQ PowSV) (vectorZipQ Pow) (flip (vectorMapValQ PowVS)) | ||
157 | pi = fromList [pi] | ||
158 | |||