diff options
Diffstat (limited to 'lib/Data/Packed/Internal/Matrix.hs')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 66 |
1 files changed, 42 insertions, 24 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index fbab33c..010c40b 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -62,21 +62,35 @@ import Data.List(transpose) | |||
62 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | 62 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) |
63 | 63 | ||
64 | -- | Matrix representation suitable for GSL and LAPACK computations. | 64 | -- | Matrix representation suitable for GSL and LAPACK computations. |
65 | data Matrix t = MC { rows :: Int, cols :: Int, cdat :: Vector t, fdat :: Vector t } | 65 | data Matrix t = MC { rows :: Int, cols :: Int, cdat :: Vector t } |
66 | | MF { rows :: Int, cols :: Int, fdat :: Vector t, cdat :: Vector t } | 66 | | MF { rows :: Int, cols :: Int, fdat :: Vector t } |
67 | 67 | ||
68 | -- MC: preferred by C, fdat may require a transposition | 68 | -- MC: preferred by C, fdat may require a transposition |
69 | -- MF: preferred by LAPACK, cdat may require a transposition | 69 | -- MF: preferred by LAPACK, cdat may require a transposition |
70 | 70 | ||
71 | -- | Matrix transpose. | 71 | -- | Matrix transpose. |
72 | trans :: Matrix t -> Matrix t | 72 | trans :: Matrix t -> Matrix t |
73 | trans MC {rows = r, cols = c, cdat = d, fdat = dt } = MF {rows = c, cols = r, fdat = d, cdat = dt } | 73 | trans MC {rows = r, cols = c, cdat = d } = MF {rows = c, cols = r, fdat = d } |
74 | trans MF {rows = r, cols = c, fdat = d, cdat = dt } = MC {rows = c, cols = r, cdat = d, fdat = dt } | 74 | trans MF {rows = r, cols = c, fdat = d } = MC {rows = c, cols = r, cdat = d } |
75 | 75 | ||
76 | dat MC { cdat = d } = d | 76 | cmat m@MC{} = m |
77 | dat MF { fdat = d } = d | 77 | cmat MF {rows = r, cols = c, fdat = d } = MC {rows = r, cols = c, cdat = transdata r d c} |
78 | |||
79 | fmat m@MF{} = m | ||
80 | fmat MC {rows = r, cols = c, cdat = d } = MF {rows = r, cols = c, fdat = transdata c d r} | ||
81 | |||
82 | matc m f = f (rows m) (cols m) (ptr (cdat m)) | ||
83 | matf m f = f (rows m) (cols m) (ptr (fdat m)) | ||
84 | |||
85 | |||
86 | {- | Creates a vector by concatenation of rows | ||
87 | |||
88 | @\> flatten ('ident' 3) | ||
89 | 9 |> [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@ | ||
90 | -} | ||
91 | flatten :: Element t => Matrix t -> Vector t | ||
92 | flatten = cdat . cmat | ||
78 | 93 | ||
79 | mat d m f = f (rows m) (cols m) (ptr (d m)) | ||
80 | 94 | ||
81 | type Mt t s = Int -> Int -> Ptr t -> s | 95 | type Mt t s = Int -> Int -> Ptr t -> s |
82 | -- not yet admitted by my haddock version | 96 | -- not yet admitted by my haddock version |
@@ -85,7 +99,7 @@ type Mt t s = Int -> Int -> Ptr t -> s | |||
85 | 99 | ||
86 | -- | the inverse of 'Data.Packed.Matrix.fromLists' | 100 | -- | the inverse of 'Data.Packed.Matrix.fromLists' |
87 | toLists :: (Element t) => Matrix t -> [[t]] | 101 | toLists :: (Element t) => Matrix t -> [[t]] |
88 | toLists m = partit (cols m) . toList . cdat $ m | 102 | toLists m = partit (cols m) . toList . flatten $ m |
89 | 103 | ||
90 | -- | creates a Matrix from a list of vectors | 104 | -- | creates a Matrix from a list of vectors |
91 | fromRows :: Element t => [Vector t] -> Matrix t | 105 | fromRows :: Element t => [Vector t] -> Matrix t |
@@ -96,7 +110,7 @@ fromRows vs = case common dim vs of | |||
96 | -- | extracts the rows of a matrix as a list of vectors | 110 | -- | extracts the rows of a matrix as a list of vectors |
97 | toRows :: Element t => Matrix t -> [Vector t] | 111 | toRows :: Element t => Matrix t -> [Vector t] |
98 | toRows m = toRows' 0 where | 112 | toRows m = toRows' 0 where |
99 | v = cdat m | 113 | v = flatten $ m |
100 | r = rows m | 114 | r = rows m |
101 | c = cols m | 115 | c = cols m |
102 | toRows' k | k == r*c = [] | 116 | toRows' k | k == r*c = [] |
@@ -128,12 +142,12 @@ MF {rows = r, cols = c, fdat = v} @@> (i,j) | |||
128 | 142 | ||
129 | ------------------------------------------------------------------ | 143 | ------------------------------------------------------------------ |
130 | 144 | ||
131 | matrixFromVector RowMajor c v = MC { rows = r, cols = c, cdat = v, fdat = transdata c v r } | 145 | matrixFromVector RowMajor c v = MC { rows = r, cols = c, cdat = v } |
132 | where (d,m) = dim v `divMod` c | 146 | where (d,m) = dim v `divMod` c |
133 | r | m==0 = d | 147 | r | m==0 = d |
134 | | otherwise = error "matrixFromVector" | 148 | | otherwise = error "matrixFromVector" |
135 | 149 | ||
136 | matrixFromVector ColumnMajor c v = MF { rows = r, cols = c, fdat = v, cdat = transdata r v c } | 150 | matrixFromVector ColumnMajor c v = MF { rows = r, cols = c, fdat = v } |
137 | where (d,m) = dim v `divMod` c | 151 | where (d,m) = dim v `divMod` c |
138 | r | m==0 = d | 152 | r | m==0 = d |
139 | | otherwise = error "matrixFromVector" | 153 | | otherwise = error "matrixFromVector" |
@@ -167,8 +181,8 @@ liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vec | |||
167 | liftMatrix2 f m1 m2 | 181 | liftMatrix2 f m1 m2 |
168 | | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" | 182 | | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" |
169 | | otherwise = case m1 of | 183 | | otherwise = case m1 of |
170 | MC {} -> matrixFromVector RowMajor (cols m1) (f (cdat m1) (cdat m2)) | 184 | MC {} -> matrixFromVector RowMajor (cols m1) (f (cdat m1) (flatten m2)) |
171 | MF {} -> matrixFromVector ColumnMajor (cols m1) (f (fdat m1) (fdat m2)) | 185 | MF {} -> matrixFromVector ColumnMajor (cols m1) (f (fdat m1) ((fdat.fmat) m2)) |
172 | 186 | ||
173 | 187 | ||
174 | compat :: Matrix a -> Matrix b -> Bool | 188 | compat :: Matrix a -> Matrix b -> Bool |
@@ -222,8 +236,8 @@ transdataAux fun c1 d c2 = | |||
222 | then d | 236 | then d |
223 | else unsafePerformIO $ do | 237 | else unsafePerformIO $ do |
224 | v <- createVector (dim d) | 238 | v <- createVector (dim d) |
225 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | 239 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d,v] |
226 | --putStrLn "---> transdataAux" | 240 | -- putStrLn $ "---> transdataAux" ++ show (toList d) ++ show (toList v) |
227 | return v | 241 | return v |
228 | where r1 = dim d `div` c1 | 242 | where r1 = dim d `div` c1 |
229 | r2 = dim d `div` c2 | 243 | r2 = dim d `div` c2 |
@@ -239,11 +253,14 @@ foreign import ccall safe "auxi.h transC" | |||
239 | gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d) | 253 | gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d) |
240 | gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d) | 254 | gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d) |
241 | 255 | ||
256 | dtt MC { cdat = d } = d | ||
257 | dtt MF { fdat = d } = d | ||
258 | |||
242 | multiplyAux fun a b = unsafePerformIO $ do | 259 | multiplyAux fun a b = unsafePerformIO $ do |
243 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | 260 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ |
244 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | 261 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) |
245 | r <- createMatrix RowMajor (rows a) (cols b) | 262 | r <- createMatrix RowMajor (rows a) (cols b) |
246 | fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] | 263 | fun // gmatC a // gmatC b // matc r // check "multiplyAux" [dtt a, dtt b, cdat r] |
247 | return r | 264 | return r |
248 | 265 | ||
249 | multiplyR = multiplyAux cmultiplyR | 266 | multiplyR = multiplyAux cmultiplyR |
@@ -273,18 +290,19 @@ multiply = multiplyD | |||
273 | 290 | ||
274 | -- | extraction of a submatrix from a real matrix | 291 | -- | extraction of a submatrix from a real matrix |
275 | subMatrixR :: (Int,Int) -> (Int,Int) -> Matrix Double -> Matrix Double | 292 | subMatrixR :: (Int,Int) -> (Int,Int) -> Matrix Double -> Matrix Double |
276 | subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do | 293 | subMatrixR (r0,c0) (rt,ct) x' = unsafePerformIO $ do |
277 | r <- createMatrix RowMajor rt ct | 294 | r <- createMatrix RowMajor rt ct |
278 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat dat r // check "subMatrixR" [dat r] | 295 | let x = cmat x' |
296 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // matc x // matc r // check "subMatrixR" [cdat x] | ||
279 | return r | 297 | return r |
280 | foreign import ccall "auxi.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM | 298 | foreign import ccall "auxi.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM |
281 | 299 | ||
282 | -- | extraction of a submatrix from a complex matrix | 300 | -- | extraction of a submatrix from a complex matrix |
283 | subMatrixC :: (Int,Int) -> (Int,Int) -> Matrix (Complex Double) -> Matrix (Complex Double) | 301 | subMatrixC :: (Int,Int) -> (Int,Int) -> Matrix (Complex Double) -> Matrix (Complex Double) |
284 | subMatrixC (r0,c0) (rt,ct) x = | 302 | subMatrixC (r0,c0) (rt,ct) x = |
285 | reshape ct . asComplex . cdat . | 303 | reshape ct . asComplex . flatten . |
286 | subMatrixR (r0,2*c0) (rt,2*ct) . | 304 | subMatrixR (r0,2*c0) (rt,2*ct) . |
287 | reshape (2*cols x) . asReal . cdat $ x | 305 | reshape (2*cols x) . asReal . flatten $ x |
288 | 306 | ||
289 | -- | Extracts a submatrix from a matrix. | 307 | -- | Extracts a submatrix from a matrix. |
290 | subMatrix :: Element a | 308 | subMatrix :: Element a |
@@ -299,7 +317,7 @@ subMatrix = subMatrixD | |||
299 | 317 | ||
300 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do | 318 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do |
301 | m <- createMatrix RowMajor n n | 319 | m <- createMatrix RowMajor n n |
302 | fun // vec v // mat cdat m // check msg [dat m] | 320 | fun // vec v // matc m // check msg [cdat m] |
303 | return m -- {tdat = dat m} | 321 | return m -- {tdat = dat m} |
304 | 322 | ||
305 | -- | diagonal matrix from a real vector | 323 | -- | diagonal matrix from a real vector |
@@ -347,11 +365,11 @@ constant = constantD | |||
347 | 365 | ||
348 | -- | obtains the complex conjugate of a complex vector | 366 | -- | obtains the complex conjugate of a complex vector |
349 | conj :: Vector (Complex Double) -> Vector (Complex Double) | 367 | conj :: Vector (Complex Double) -> Vector (Complex Double) |
350 | conj v = asComplex $ cdat $ reshape 2 (asReal v) `multiply` diag (fromList [1,-1]) | 368 | conj v = asComplex $ flatten $ reshape 2 (asReal v) `multiply` diag (fromList [1,-1]) |
351 | 369 | ||
352 | -- | creates a complex vector from vectors with real and imaginary parts | 370 | -- | creates a complex vector from vectors with real and imaginary parts |
353 | toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) | 371 | toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) |
354 | toComplex (r,i) = asComplex $ cdat $ fromColumns [r,i] | 372 | toComplex (r,i) = asComplex $ flatten $ fromColumns [r,i] |
355 | 373 | ||
356 | -- | the inverse of 'toComplex' | 374 | -- | the inverse of 'toComplex' |
357 | fromComplex :: Vector (Complex Double) -> (Vector Double, Vector Double) | 375 | fromComplex :: Vector (Complex Double) -> (Vector Double, Vector Double) |
@@ -367,7 +385,7 @@ fromFile :: FilePath -> (Int,Int) -> IO (Matrix Double) | |||
367 | fromFile filename (r,c) = do | 385 | fromFile filename (r,c) = do |
368 | charname <- newCString filename | 386 | charname <- newCString filename |
369 | res <- createMatrix RowMajor r c | 387 | res <- createMatrix RowMajor r c |
370 | c_gslReadMatrix charname // mat dat res // check "gslReadMatrix" [] | 388 | c_gslReadMatrix charname // matc res // check "gslReadMatrix" [] |
371 | --free charname -- TO DO: free the auxiliary CString | 389 | --free charname -- TO DO: free the auxiliary CString |
372 | return res | 390 | return res |
373 | foreign import ccall "auxi.h matrix_fscanf" c_gslReadMatrix:: Ptr CChar -> TM | 391 | foreign import ccall "auxi.h matrix_fscanf" c_gslReadMatrix:: Ptr CChar -> TM |