summaryrefslogtreecommitdiff
path: root/packages/hmatrix/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/hmatrix/src')
-rw-r--r--packages/hmatrix/src/Numeric/Chain.hs140
-rw-r--r--packages/hmatrix/src/Numeric/Container.hs6
-rw-r--r--packages/hmatrix/src/Numeric/Matrix.hs98
-rw-r--r--packages/hmatrix/src/Numeric/Vector.hs158
4 files changed, 5 insertions, 397 deletions
diff --git a/packages/hmatrix/src/Numeric/Chain.hs b/packages/hmatrix/src/Numeric/Chain.hs
deleted file mode 100644
index de6a86f..0000000
--- a/packages/hmatrix/src/Numeric/Chain.hs
+++ /dev/null
@@ -1,140 +0,0 @@
1-----------------------------------------------------------------------------
2-- |
3-- Module : Numeric.Chain
4-- Copyright : (c) Vivian McPhail 2010
5-- License : GPL-style
6--
7-- Maintainer : Vivian McPhail <haskell.vivian.mcphail <at> gmail.com>
8-- Stability : provisional
9-- Portability : portable
10--
11-- optimisation of association order for chains of matrix multiplication
12--
13-----------------------------------------------------------------------------
14
15module Numeric.Chain (
16 optimiseMult,
17 ) where
18
19import Data.Maybe
20
21import Data.Packed.Matrix
22import Data.Packed.Numeric
23
24import qualified Data.Array.IArray as A
25
26-----------------------------------------------------------------------------
27{- |
28 Provide optimal association order for a chain of matrix multiplications
29 and apply the multiplications.
30
31 The algorithm is the well-known O(n\^3) dynamic programming algorithm
32 that builds a pyramid of optimal associations.
33
34> m1, m2, m3, m4 :: Matrix Double
35> m1 = (10><15) [1..]
36> m2 = (15><20) [1..]
37> m3 = (20><5) [1..]
38> m4 = (5><10) [1..]
39
40> >>> optimiseMult [m1,m2,m3,m4]
41
42will perform @((m1 `multiply` (m2 `multiply` m3)) `multiply` m4)@
43
44The naive left-to-right multiplication would take @4500@ scalar multiplications
45whereas the optimised version performs @2750@ scalar multiplications. The complexity
46in this case is 32 (= 4^3/2) * (2 comparisons, 3 scalar multiplications, 3 scalar additions,
475 lookups, 2 updates) + a constant (= three table allocations)
48-}
49optimiseMult :: Product t => [Matrix t] -> Matrix t
50optimiseMult = chain
51
52-----------------------------------------------------------------------------
53
54type Matrices a = A.Array Int (Matrix a)
55type Sizes = A.Array Int (Int,Int)
56type Cost = A.Array Int (A.Array Int (Maybe Int))
57type Indexes = A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int))))
58
59update :: A.Array Int (A.Array Int a) -> (Int,Int) -> a -> A.Array Int (A.Array Int a)
60update a (r,c) e = a A.// [(r,(a A.! r) A.// [(c,e)])]
61
62newWorkSpaceCost :: Int -> A.Array Int (A.Array Int (Maybe Int))
63newWorkSpaceCost n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n]
64 where subArray i = A.listArray (1,i) (repeat Nothing)
65
66newWorkSpaceIndexes :: Int -> A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int))))
67newWorkSpaceIndexes n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n]
68 where subArray i = A.listArray (1,i) (repeat Nothing)
69
70matricesToSizes :: [Matrix a] -> Sizes
71matricesToSizes ms = A.listArray (1,length ms) $ map (\m -> (rows m,cols m)) ms
72
73chain :: Product a => [Matrix a] -> Matrix a
74chain [] = error "chain: zero matrices to multiply"
75chain [m] = m
76chain [ml,mr] = ml `multiply` mr
77chain ms = let ln = length ms
78 ma = A.listArray (1,ln) ms
79 mz = matricesToSizes ms
80 i = chain_cost mz
81 in chain_paren (ln,ln) i ma
82
83chain_cost :: Sizes -> Indexes
84chain_cost mz = let (_,u) = A.bounds mz
85 cost = newWorkSpaceCost u
86 ixes = newWorkSpaceIndexes u
87 (_,_,i) = foldl chain_cost' (mz,cost,ixes) (order u)
88 in i
89
90chain_cost' :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes)
91chain_cost' sci@(mz,cost,ixes) (r,c)
92 | c == 1 = let cost' = update cost (r,c) (Just 0)
93 ixes' = update ixes (r,c) (Just ((r,c),(r,c)))
94 in (mz,cost',ixes')
95 | otherwise = minimum_cost sci (r,c)
96
97minimum_cost :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes)
98minimum_cost sci fu = foldl (smaller_cost fu) sci (fulcrum_order fu)
99
100smaller_cost :: (Int,Int) -> (Sizes,Cost,Indexes) -> ((Int,Int),(Int,Int)) -> (Sizes,Cost,Indexes)
101smaller_cost (r,c) (mz,cost,ixes) ix@((lr,lc),(rr,rc)) = let op_cost = fromJust ((cost A.! lr) A.! lc)
102 + fromJust ((cost A.! rr) A.! rc)
103 + fst (mz A.! (lr-lc+1))
104 * snd (mz A.! lc)
105 * snd (mz A.! rr)
106 cost' = (cost A.! r) A.! c
107 in case cost' of
108 Nothing -> let cost'' = update cost (r,c) (Just op_cost)
109 ixes'' = update ixes (r,c) (Just ix)
110 in (mz,cost'',ixes'')
111 Just ct -> if op_cost < ct then
112 let cost'' = update cost (r,c) (Just op_cost)
113 ixes'' = update ixes (r,c) (Just ix)
114 in (mz,cost'',ixes'')
115 else (mz,cost,ixes)
116
117
118fulcrum_order (r,c) = let fs' = zip (repeat r) [1..(c-1)]
119 in map (partner (r,c)) fs'
120
121partner (r,c) (a,b) = ((r-b, c-b), (a,b))
122
123order 0 = []
124order n = order (n-1) ++ zip (repeat n) [1..n]
125
126chain_paren :: Product a => (Int,Int) -> Indexes -> Matrices a -> Matrix a
127chain_paren (r,c) ixes ma = let ((lr,lc),(rr,rc)) = fromJust $ (ixes A.! r) A.! c
128 in if lr == rr && lc == rc then (ma A.! lr)
129 else (chain_paren (lr,lc) ixes ma) `multiply` (chain_paren (rr,rc) ixes ma)
130
131--------------------------------------------------------------------------
132
133{- TESTS -}
134
135-- optimal association is ((m1*(m2*m3))*m4)
136m1, m2, m3, m4 :: Matrix Double
137m1 = (10><15) [1..]
138m2 = (15><20) [1..]
139m3 = (20><5) [1..]
140m4 = (5><10) [1..]
diff --git a/packages/hmatrix/src/Numeric/Container.hs b/packages/hmatrix/src/Numeric/Container.hs
index 645a83f..e7f23d4 100644
--- a/packages/hmatrix/src/Numeric/Container.hs
+++ b/packages/hmatrix/src/Numeric/Container.hs
@@ -66,11 +66,11 @@ module Numeric.Container (
66 66
67import Data.Packed hiding (stepD, stepF, condD, condF, conjugateC, conjugateQ) 67import Data.Packed hiding (stepD, stepF, condD, condF, conjugateC, conjugateQ)
68import Data.Packed.Numeric 68import Data.Packed.Numeric
69import Numeric.Chain
70import Numeric.IO 69import Numeric.IO
71import Data.Complex 70import Data.Complex
72import Numeric.LinearAlgebra.Algorithms(Field,linearSolveSVD) 71import Numeric.LinearAlgebra.Algorithms(Field,linearSolveSVD)
73import Numeric.Random 72import Numeric.Random
73import Data.Monoid(Monoid(mconcat))
74 74
75------------------------------------------------------------------ 75------------------------------------------------------------------
76 76
@@ -268,4 +268,8 @@ infixl 7 ◇
268dot :: (Container Vector t, Product t) => Vector t -> Vector t -> t 268dot :: (Container Vector t, Product t) => Vector t -> Vector t -> t
269dot u v = udot (conj u) v 269dot u v = udot (conj u) v
270 270
271--------------------------------------------------------------------------------
272
273optimiseMult :: Monoid (Matrix t) => [Matrix t] -> Matrix t
274optimiseMult = mconcat
271 275
diff --git a/packages/hmatrix/src/Numeric/Matrix.hs b/packages/hmatrix/src/Numeric/Matrix.hs
deleted file mode 100644
index e285ff2..0000000
--- a/packages/hmatrix/src/Numeric/Matrix.hs
+++ /dev/null
@@ -1,98 +0,0 @@
1{-# LANGUAGE TypeFamilies #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE UndecidableInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-}
6
7-----------------------------------------------------------------------------
8-- |
9-- Module : Numeric.Matrix
10-- Copyright : (c) Alberto Ruiz 2010
11-- License : GPL-style
12--
13-- Maintainer : Alberto Ruiz <aruiz@um.es>
14-- Stability : provisional
15-- Portability : portable
16--
17-- Provides instances of standard classes 'Show', 'Read', 'Eq',
18-- 'Num', 'Fractional', and 'Floating' for 'Matrix'.
19--
20-- In arithmetic operations one-component
21-- vectors and matrices automatically expand to match the dimensions of the other operand.
22
23-----------------------------------------------------------------------------
24
25module Numeric.Matrix (
26 ) where
27
28-------------------------------------------------------------------
29
30import Numeric.Container
31import qualified Data.Monoid as M
32import Data.List(partition)
33
34-------------------------------------------------------------------
35
36instance Container Matrix a => Eq (Matrix a) where
37 (==) = equal
38
39instance (Container Matrix a, Num (Vector a)) => Num (Matrix a) where
40 (+) = liftMatrix2Auto (+)
41 (-) = liftMatrix2Auto (-)
42 negate = liftMatrix negate
43 (*) = liftMatrix2Auto (*)
44 signum = liftMatrix signum
45 abs = liftMatrix abs
46 fromInteger = (1><1) . return . fromInteger
47
48---------------------------------------------------
49
50instance (Container Vector a, Fractional (Vector a), Num (Matrix a)) => Fractional (Matrix a) where
51 fromRational n = (1><1) [fromRational n]
52 (/) = liftMatrix2Auto (/)
53
54---------------------------------------------------------
55
56instance (Floating a, Container Vector a, Floating (Vector a), Fractional (Matrix a)) => Floating (Matrix a) where
57 sin = liftMatrix sin
58 cos = liftMatrix cos
59 tan = liftMatrix tan
60 asin = liftMatrix asin
61 acos = liftMatrix acos
62 atan = liftMatrix atan
63 sinh = liftMatrix sinh
64 cosh = liftMatrix cosh
65 tanh = liftMatrix tanh
66 asinh = liftMatrix asinh
67 acosh = liftMatrix acosh
68 atanh = liftMatrix atanh
69 exp = liftMatrix exp
70 log = liftMatrix log
71 (**) = liftMatrix2Auto (**)
72 sqrt = liftMatrix sqrt
73 pi = (1><1) [pi]
74
75--------------------------------------------------------------------------------
76
77isScalar m = rows m == 1 && cols m == 1
78
79adaptScalarM f1 f2 f3 x y
80 | isScalar x = f1 (x @@>(0,0) ) y
81 | isScalar y = f3 x (y @@>(0,0) )
82 | otherwise = f2 x y
83
84instance (Container Vector t, Eq t, Num (Vector t), Product t) => M.Monoid (Matrix t)
85 where
86 mempty = 1
87 mappend = adaptScalarM scale mXm (flip scale)
88
89 mconcat xs = work (partition isScalar xs)
90 where
91 work (ss,[]) = product ss
92 work (ss,ms) = scale' (product ss) (optimiseMult ms)
93 scale' x m
94 | isScalar x && x00 == 1 = m
95 | otherwise = scale x00 m
96 where
97 x00 = x @@> (0,0)
98
diff --git a/packages/hmatrix/src/Numeric/Vector.hs b/packages/hmatrix/src/Numeric/Vector.hs
deleted file mode 100644
index 4c59d32..0000000
--- a/packages/hmatrix/src/Numeric/Vector.hs
+++ /dev/null
@@ -1,158 +0,0 @@
1{-# LANGUAGE TypeFamilies #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE UndecidableInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-}
6-----------------------------------------------------------------------------
7-- |
8-- Module : Numeric.Vector
9-- Copyright : (c) Alberto Ruiz 2011
10-- License : GPL-style
11--
12-- Maintainer : Alberto Ruiz <aruiz@um.es>
13-- Stability : provisional
14-- Portability : portable
15--
16-- Provides instances of standard classes 'Show', 'Read', 'Eq',
17-- 'Num', 'Fractional', and 'Floating' for 'Vector'.
18--
19-----------------------------------------------------------------------------
20
21module Numeric.Vector () where
22
23import Numeric.Vectorized
24import Numeric.Container
25
26-------------------------------------------------------------------
27
28adaptScalar f1 f2 f3 x y
29 | dim x == 1 = f1 (x@>0) y
30 | dim y == 1 = f3 x (y@>0)
31 | otherwise = f2 x y
32
33------------------------------------------------------------------
34
35instance Num (Vector Float) where
36 (+) = adaptScalar addConstant add (flip addConstant)
37 negate = scale (-1)
38 (*) = adaptScalar scale mul (flip scale)
39 signum = vectorMapF Sign
40 abs = vectorMapF Abs
41 fromInteger = fromList . return . fromInteger
42
43instance Num (Vector Double) where
44 (+) = adaptScalar addConstant add (flip addConstant)
45 negate = scale (-1)
46 (*) = adaptScalar scale mul (flip scale)
47 signum = vectorMapR Sign
48 abs = vectorMapR Abs
49 fromInteger = fromList . return . fromInteger
50
51instance Num (Vector (Complex Double)) where
52 (+) = adaptScalar addConstant add (flip addConstant)
53 negate = scale (-1)
54 (*) = adaptScalar scale mul (flip scale)
55 signum = vectorMapC Sign
56 abs = vectorMapC Abs
57 fromInteger = fromList . return . fromInteger
58
59instance Num (Vector (Complex Float)) where
60 (+) = adaptScalar addConstant add (flip addConstant)
61 negate = scale (-1)
62 (*) = adaptScalar scale mul (flip scale)
63 signum = vectorMapQ Sign
64 abs = vectorMapQ Abs
65 fromInteger = fromList . return . fromInteger
66
67---------------------------------------------------
68
69instance (Container Vector a, Num (Vector a)) => Fractional (Vector a) where
70 fromRational n = fromList [fromRational n]
71 (/) = adaptScalar f divide g where
72 r `f` v = scaleRecip r v
73 v `g` r = scale (recip r) v
74
75-------------------------------------------------------
76
77instance Floating (Vector Float) where
78 sin = vectorMapF Sin
79 cos = vectorMapF Cos
80 tan = vectorMapF Tan
81 asin = vectorMapF ASin
82 acos = vectorMapF ACos
83 atan = vectorMapF ATan
84 sinh = vectorMapF Sinh
85 cosh = vectorMapF Cosh
86 tanh = vectorMapF Tanh
87 asinh = vectorMapF ASinh
88 acosh = vectorMapF ACosh
89 atanh = vectorMapF ATanh
90 exp = vectorMapF Exp
91 log = vectorMapF Log
92 sqrt = vectorMapF Sqrt
93 (**) = adaptScalar (vectorMapValF PowSV) (vectorZipF Pow) (flip (vectorMapValF PowVS))
94 pi = fromList [pi]
95
96-------------------------------------------------------------
97
98instance Floating (Vector Double) where
99 sin = vectorMapR Sin
100 cos = vectorMapR Cos
101 tan = vectorMapR Tan
102 asin = vectorMapR ASin
103 acos = vectorMapR ACos
104 atan = vectorMapR ATan
105 sinh = vectorMapR Sinh
106 cosh = vectorMapR Cosh
107 tanh = vectorMapR Tanh
108 asinh = vectorMapR ASinh
109 acosh = vectorMapR ACosh
110 atanh = vectorMapR ATanh
111 exp = vectorMapR Exp
112 log = vectorMapR Log
113 sqrt = vectorMapR Sqrt
114 (**) = adaptScalar (vectorMapValR PowSV) (vectorZipR Pow) (flip (vectorMapValR PowVS))
115 pi = fromList [pi]
116
117-------------------------------------------------------------
118
119instance Floating (Vector (Complex Double)) where
120 sin = vectorMapC Sin
121 cos = vectorMapC Cos
122 tan = vectorMapC Tan
123 asin = vectorMapC ASin
124 acos = vectorMapC ACos
125 atan = vectorMapC ATan
126 sinh = vectorMapC Sinh
127 cosh = vectorMapC Cosh
128 tanh = vectorMapC Tanh
129 asinh = vectorMapC ASinh
130 acosh = vectorMapC ACosh
131 atanh = vectorMapC ATanh
132 exp = vectorMapC Exp
133 log = vectorMapC Log
134 sqrt = vectorMapC Sqrt
135 (**) = adaptScalar (vectorMapValC PowSV) (vectorZipC Pow) (flip (vectorMapValC PowVS))
136 pi = fromList [pi]
137
138-----------------------------------------------------------
139
140instance Floating (Vector (Complex Float)) where
141 sin = vectorMapQ Sin
142 cos = vectorMapQ Cos
143 tan = vectorMapQ Tan
144 asin = vectorMapQ ASin
145 acos = vectorMapQ ACos
146 atan = vectorMapQ ATan
147 sinh = vectorMapQ Sinh
148 cosh = vectorMapQ Cosh
149 tanh = vectorMapQ Tanh
150 asinh = vectorMapQ ASinh
151 acosh = vectorMapQ ACosh
152 atanh = vectorMapQ ATanh
153 exp = vectorMapQ Exp
154 log = vectorMapQ Log
155 sqrt = vectorMapQ Sqrt
156 (**) = adaptScalar (vectorMapValQ PowSV) (vectorZipQ Pow) (flip (vectorMapValQ PowVS))
157 pi = fromList [pi]
158