summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal/Matrix.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Data/Packed/Internal/Matrix.hs')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs204
1 files changed, 118 insertions, 86 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 9895393..48652f3 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -1,4 +1,4 @@
1{-# OPTIONS_GHC -fglasgow-exts #-} 1{-# OPTIONS_GHC -fglasgow-exts -fallow-overlapping-instances #-}
2----------------------------------------------------------------------------- 2-----------------------------------------------------------------------------
3-- | 3-- |
4-- Module : Data.Packed.Internal.Matrix 4-- Module : Data.Packed.Internal.Matrix
@@ -22,9 +22,65 @@ import Foreign hiding (xor)
22import Complex 22import Complex
23import Control.Monad(when) 23import Control.Monad(when)
24import Data.List(transpose,intersperse) 24import Data.List(transpose,intersperse)
25import Data.Typeable 25--import Data.Typeable
26import Data.Maybe(fromJust) 26import Data.Maybe(fromJust)
27 27
28----------------------------------------------------------------
29
30class Storable a => Field a where
31 constant :: a -> Int -> Vector a
32 transdata :: Int -> Vector a -> Int -> Vector a
33 multiplyD :: MatrixOrder -> Matrix a -> Matrix a -> Matrix a
34 subMatrix :: (Int,Int) -- ^ (r0,c0) starting position
35 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
36 -> Matrix a -> Matrix a
37 diag :: Vector a -> Matrix a
38
39
40instance Field Double where
41 constant = constantR
42 transdata = transdataR
43 multiplyD = multiplyR
44 subMatrix = subMatrixR
45 diag = diagR
46
47instance Field (Complex Double) where
48 constant = constantC
49 transdata = transdataC
50 multiplyD = multiplyC
51 subMatrix = subMatrixC
52 diag = diagC
53
54-----------------------------------------------------------------
55
56transdataR :: Int -> Vector Double -> Int -> Vector Double
57transdataR = transdataAux ctransR
58
59transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double)
60transdataC = transdataAux ctransC
61
62transdataAux fun c1 d c2 =
63 if noneed
64 then d
65 else unsafePerformIO $ do
66 v <- createVector (dim d)
67 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d]
68 --putStrLn "---> transdataAux"
69 return v
70 where r1 = dim d `div` c1
71 r2 = dim d `div` c2
72 noneed = r1 == 1 || c1 == 1
73
74foreign import ccall safe "aux.h transR"
75 ctransR :: TMM -- Double ::> Double ::> IO Int
76foreign import ccall safe "aux.h transC"
77 ctransC :: TCMCM -- Complex Double ::> Complex Double ::> IO Int
78
79transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d
80
81
82
83
28 84
29data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) 85data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
30 86
@@ -34,9 +90,18 @@ data Matrix t = M { rows :: Int
34 , tdat :: Vector t 90 , tdat :: Vector t
35 , isTrans :: Bool 91 , isTrans :: Bool
36 , order :: MatrixOrder 92 , order :: MatrixOrder
37 } deriving Typeable 93 } -- deriving Typeable
38 94
39 95
96data NMat t = MC { rws, cls :: Int, dtc :: Vector t}
97 | MF { rws, cls :: Int, dtf :: Vector t}
98 | Tr (NMat t)
99
100ntrans (Tr m) = m
101ntrans m = Tr m
102
103viewC m@MC{} = m
104viewF m@MF{} = m
40 105
41fortran m = order m == ColumnMajor 106fortran m = order m == ColumnMajor
42 107
@@ -78,7 +143,11 @@ matrixFromVector RowMajor c v =
78 , tdat = transdata c v r 143 , tdat = transdata c v r
79 , order = RowMajor 144 , order = RowMajor
80 , isTrans = False 145 , isTrans = False
81 } where r = dim v `div` c -- TODO check mod=0 146 } where (d,m) = dim v `divMod` c
147 r | m==0 = d
148 | otherwise = error "matrixFromVector"
149
150-- r = dim v `div` c -- TODO check mod=0
82 151
83matrixFromVector ColumnMajor c v = 152matrixFromVector ColumnMajor c v =
84 M { rows = r 153 M { rows = r
@@ -87,7 +156,9 @@ matrixFromVector ColumnMajor c v =
87 , tdat = transdata r v c 156 , tdat = transdata r v c
88 , order = ColumnMajor 157 , order = ColumnMajor
89 , isTrans = False 158 , isTrans = False
90 } where r = dim v `div` c -- TODO check mod=0 159 } where (d,m) = dim v `divMod` c
160 r | m==0 = d
161 | otherwise = error "matrixFromVector"
91 162
92createMatrix order r c = do 163createMatrix order r c = do
93 p <- createVector (r*c) 164 p <- createVector (r*c)
@@ -102,48 +173,11 @@ createMatrix order r c = do
102 , 9.0, 10.0, 11.0, 12.0 ]@ 173 , 9.0, 10.0, 11.0, 12.0 ]@
103 174
104-} 175-}
105reshape :: (Field t) => Int -> Vector t -> Matrix t 176reshape :: Field t => Int -> Vector t -> Matrix t
106reshape c v = matrixFromVector RowMajor c v 177reshape c v = matrixFromVector RowMajor c v
107 178
108singleton x = reshape 1 (fromList [x]) 179singleton x = reshape 1 (fromList [x])
109 180
110transdataG :: Storable a => Int -> Vector a -> Int -> Vector a
111transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d
112
113transdataR :: Int -> Vector Double -> Int -> Vector Double
114transdataR = transdataAux ctransR
115
116transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double)
117transdataC = transdataAux ctransC
118
119transdataAux fun c1 d c2 =
120 if noneed
121 then d
122 else unsafePerformIO $ do
123 v <- createVector (dim d)
124 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d]
125 --putStrLn "---> transdataAux"
126 return v
127 where r1 = dim d `div` c1
128 r2 = dim d `div` c2
129 noneed = r1 == 1 || c1 == 1
130
131foreign import ccall safe "aux.h transR"
132 ctransR :: TMM -- Double ::> Double ::> IO Int
133foreign import ccall safe "aux.h transC"
134 ctransC :: TCMCM -- Complex Double ::> Complex Double ::> IO Int
135
136transdata :: Field a => Int -> Vector a -> Int -> Vector a
137transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2
138 | isComp baseOf d = scast $ transdataC c1 (scast d) c2
139 | otherwise = transdataG c1 d c2
140
141--transdata :: Storable a => Int -> Vector a -> Int -> Vector a
142--transdata = transdataG
143--{-# RULES "transdataR" transdata=transdataR #-}
144--{-# RULES "transdataC" transdata=transdataC #-}
145
146-----------------------------------------------------------------
147liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b 181liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
148liftMatrix f m = reshape (cols m) (f (cdat m)) 182liftMatrix f m = reshape (cols m) (f (cdat m))
149 183
@@ -163,7 +197,7 @@ multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a]
163 Nothing -> False 197 Nothing -> False
164 Just c -> c == length b 198 Just c -> c == length b
165 199
166transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) 200transL m = matrixFromVector RowMajor (rows m) $ transdata (cols m) (cdat m) (rows m)
167 201
168multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) 202multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b)
169 203
@@ -179,7 +213,7 @@ gmatC m f | fortran m =
179 else f 0 (rows m) (cols m) (ptr (dat m)) 213 else f 0 (rows m) (cols m) (ptr (dat m))
180 214
181 215
182multiplyAux order fun a b = unsafePerformIO $ do 216multiplyAux fun order a b = unsafePerformIO $ do
183 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ 217 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++
184 show (rows a,cols a) ++ " x " ++ show (rows b, cols b) 218 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
185 r <- createMatrix order (rows a) (cols b) 219 r <- createMatrix order (rows a) (cols b)
@@ -198,37 +232,14 @@ foreign import ccall safe "aux.h multiplyC"
198 -> Int -> Int -> Ptr (Complex Double) 232 -> Int -> Int -> Ptr (Complex Double)
199 -> IO Int 233 -> IO Int
200 234
201multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a 235multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a
202multiply RowMajor a b = multiplyD RowMajor a b 236multiply RowMajor a b = multiplyD RowMajor a b
203multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor} 237multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor}
204 where m = multiplyD RowMajor (trans b) (trans a) 238 where m = multiplyD RowMajor (trans b) (trans a)
205 239
206multiplyD order a b
207 | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b)
208 | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b)
209 | otherwise = multiplyG a b
210
211----------------------------------------------------------------------
212 240
213outer' u v = dat (outer u v) 241multiplyR = multiplyAux cmultiplyR
214 242multiplyC = multiplyAux cmultiplyC
215{- | Outer product of two vectors.
216
217@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3]
218(3><3)
219 [ 5.0, 2.0, 3.0
220 , 10.0, 4.0, 6.0
221 , 15.0, 6.0, 9.0 ]@
222-}
223outer :: (Num t, Field t) => Vector t -> Vector t -> Matrix t
224outer u v = multiply RowMajor r c
225 where r = matrixFromVector RowMajor 1 u
226 c = matrixFromVector RowMajor (dim v) v
227
228dot :: (Field t, Num t) => Vector t -> Vector t -> t
229dot u v = dat (multiply RowMajor r c) `at` 0
230 where r = matrixFromVector RowMajor (dim u) u
231 c = matrixFromVector RowMajor 1 v
232 243
233---------------------------------------------------------------------- 244----------------------------------------------------------------------
234 245
@@ -251,14 +262,14 @@ subMatrixC (r0,c0) (rt,ct) x =
251 subMatrixR (r0,2*c0) (rt,2*ct) . 262 subMatrixR (r0,2*c0) (rt,2*ct) .
252 reshape (2*cols x) . asReal . cdat $ x 263 reshape (2*cols x) . asReal . cdat $ x
253 264
254subMatrix :: (Field a) 265--subMatrix :: (Field a)
255 => (Int,Int) -- ^ (r0,c0) starting position 266-- => (Int,Int) -- ^ (r0,c0) starting position
256 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix 267-- -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
257 -> Matrix a -> Matrix a 268-- -> Matrix a -> Matrix a
258subMatrix st sz m 269--subMatrix st sz m
259 | isReal (baseOf.dat) m = scast $ subMatrixR st sz (scast m) 270-- | isReal (baseOf.dat) m = scast $ subMatrixR st sz (scast m)
260 | isComp (baseOf.dat) m = scast $ subMatrixC st sz (scast m) 271-- | isComp (baseOf.dat) m = scast $ subMatrixC st sz (scast m)
261 | otherwise = subMatrixG st sz m 272-- | otherwise = subMatrixG st sz m
262 273
263subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) 274subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x))
264 where subList s n = take n . drop s 275 where subList s n = take n . drop s
@@ -281,11 +292,11 @@ diagC = diagAux c_diagC "diagC"
281foreign import ccall "aux.h diagC" c_diagC :: TCVCM 292foreign import ccall "aux.h diagC" c_diagC :: TCVCM
282 293
283-- | diagonal matrix from a vector 294-- | diagonal matrix from a vector
284diag :: (Num a, Field a) => Vector a -> Matrix a 295--diag :: (Num a, Field a) => Vector a -> Matrix a
285diag v 296--diag v
286 | isReal (baseOf) v = scast $ diagR (scast v) 297-- | isReal (baseOf) v = scast $ diagR (scast v)
287 | isComp (baseOf) v = scast $ diagC (scast v) 298-- | isComp (baseOf) v = scast $ diagC (scast v)
288 | otherwise = diagG v 299-- | otherwise = diagG v
289 300
290diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] 301diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]]
291 where c = dim v 302 where c = dim v
@@ -313,13 +324,34 @@ fromColumns :: Field t => [Vector t] -> Matrix t
313fromColumns m = trans . fromRows $ m 324fromColumns m = trans . fromRows $ m
314 325
315-- | Creates a list of vectors from the columns of a matrix 326-- | Creates a list of vectors from the columns of a matrix
316toColumns :: Field t => Matrix t -> [Vector t] 327toColumns :: Storable t => Matrix t -> [Vector t]
317toColumns m = toRows . trans $ m 328toColumns m = toRows . trans $ m
318 329
319 330
320-- | Reads a matrix position. 331-- | Reads a matrix position.
321(@@>) :: Field t => Matrix t -> (Int,Int) -> t 332(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
322infixl 9 @@> 333infixl 9 @@>
323m@M {rows = r, cols = c} @@> (i,j) 334m@M {rows = r, cols = c} @@> (i,j)
324 | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" 335 | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range"
325 | otherwise = cdat m `at` (i*c+j) 336 | otherwise = cdat m `at` (i*c+j)
337
338------------------------------------------------------------------
339
340constantR :: Double -> Int -> Vector Double
341constantR = constantAux cconstantR
342
343constantC :: Complex Double -> Int -> Vector (Complex Double)
344constantC = constantAux cconstantC
345
346constantAux fun x n = unsafePerformIO $ do
347 v <- createVector n
348 px <- newArray [x]
349 fun px // vec v // check "constantAux" []
350 free px
351 return v
352
353foreign import ccall safe "aux.h constantR"
354 cconstantR :: Ptr Double -> TV -- Double :> IO Int
355
356foreign import ccall safe "aux.h constantC"
357 cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int