summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Matrix.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r--packages/base/src/Internal/Matrix.hs190
1 files changed, 77 insertions, 113 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index f76b9dc..bdf2785 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -32,49 +32,13 @@ import Foreign.C.Types ( CInt(..) )
32import Foreign.C.String ( CString, newCString ) 32import Foreign.C.String ( CString, newCString )
33import System.IO.Unsafe ( unsafePerformIO ) 33import System.IO.Unsafe ( unsafePerformIO )
34import Control.DeepSeq ( NFData(..) ) 34import Control.DeepSeq ( NFData(..) )
35import Data.List.Split(chunksOf) 35import Text.Printf
36 36
37----------------------------------------------------------------- 37-----------------------------------------------------------------
38 38
39{- Design considerations for the Matrix Type
40 -----------------------------------------
41
42- we must easily handle both row major and column major order,
43 for bindings to LAPACK and GSL/C
44
45- we'd like to simplify redundant matrix transposes:
46 - Some of them arise from the order requirements of some functions
47 - some functions (matrix product) admit transposed arguments
48
49- maybe we don't really need this kind of simplification:
50 - more complex code
51 - some computational overhead
52 - only appreciable gain in code with a lot of redundant transpositions
53 and cheap matrix computations
54
55- we could carry both the matrix and its (lazily computed) transpose.
56 This may save some transpositions, but it is necessary to keep track of the
57 data which is actually computed to be used by functions like the matrix product
58 which admit both orders.
59
60- but if we need the transposed data and it is not in the structure, we must make
61 sure that we touch the same foreignptr that is used in the computation.
62
63- a reasonable solution is using two constructors for a matrix. Transposition just
64 "flips" the constructor. Actual data transposition is not done if followed by a
65 matrix product or another transpose.
66
67-}
68
69data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) 39data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
70 40
71transOrder RowMajor = ColumnMajor 41-- | Matrix representation suitable for BLAS\/LAPACK computations.
72transOrder ColumnMajor = RowMajor
73{- | Matrix representation suitable for BLAS\/LAPACK computations.
74
75The elements are stored in a continuous memory array.
76
77-}
78 42
79data Matrix t = Matrix 43data Matrix t = Matrix
80 { irows :: {-# UNPACK #-} !Int 44 { irows :: {-# UNPACK #-} !Int
@@ -83,8 +47,6 @@ data Matrix t = Matrix
83 , xCol :: {-# UNPACK #-} !Int 47 , xCol :: {-# UNPACK #-} !Int
84 , xdat :: {-# UNPACK #-} !(Vector t) 48 , xdat :: {-# UNPACK #-} !(Vector t)
85 } 49 }
86-- RowMajor: preferred by C, fdat may require a transposition
87-- ColumnMajor: preferred by LAPACK, cdat may require a transposition
88 50
89 51
90rows :: Matrix t -> Int 52rows :: Matrix t -> Int
@@ -95,32 +57,55 @@ cols :: Matrix t -> Int
95cols = icols 57cols = icols
96{-# INLINE cols #-} 58{-# INLINE cols #-}
97 59
98rowOrder m = xRow m > 1 60size m = (irows m, icols m)
61{-# INLINE size #-}
62
63rowOrder m = xCol m == 1 || cols m == 1
99{-# INLINE rowOrder #-} 64{-# INLINE rowOrder #-}
100 65
101isSlice m = cols m < xRow m || rows m < xCol m 66colOrder m = xRow m == 1 || rows m == 1
67{-# INLINE colOrder #-}
68
69is1d (size->(r,c)) = r==1 || c==1
70{-# INLINE is1d #-}
71
72-- data is not contiguous
73isSlice m@(size->(r,c)) = (c < xRow m || r < xCol m) && min r c > 1
102{-# INLINE isSlice #-} 74{-# INLINE isSlice #-}
103 75
104orderOf :: Matrix t -> MatrixOrder 76orderOf :: Matrix t -> MatrixOrder
105orderOf m = if rowOrder m then RowMajor else ColumnMajor 77orderOf m = if rowOrder m then RowMajor else ColumnMajor
106 78
107 79
80showInternal :: Storable t => Matrix t -> IO ()
81showInternal m = printf "%dx%d %s %s %d:%d (%d)\n" r c slc ord xr xc dv
82 where
83 r = rows m
84 c = cols m
85 xr = xRow m
86 xc = xCol m
87 slc = if isSlice m then "slice" else "full"
88 ord = if is1d m then "1d" else if rowOrder m then "rows" else "cols"
89 dv = dim (xdat m)
90
91--------------------------------------------------------------------------------
92
108-- | Matrix transpose. 93-- | Matrix transpose.
109trans :: Matrix t -> Matrix t 94trans :: Matrix t -> Matrix t
110trans m@Matrix { irows = r, icols = c } | rowOrder m = 95trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } =
111 m { irows = c, icols = r, xRow = 1, xCol = c } 96 m { irows = c, icols = r, xRow = xc, xCol = xr }
112trans m@Matrix { irows = r, icols = c } = 97
113 m { irows = c, icols = r, xRow = r, xCol = 1 }
114 98
115cmat :: (Element t) => Matrix t -> Matrix t 99cmat :: (Element t) => Matrix t -> Matrix t
116cmat m | rowOrder m = m 100cmat m
117cmat m@Matrix { irows = r, icols = c, xdat = d } = 101 | rowOrder m = m
118 m { xdat = transdata r d c, xRow = c, xCol = 1 } 102 | otherwise = extractAll RowMajor m
103
119 104
120fmat :: (Element t) => Matrix t -> Matrix t 105fmat :: (Element t) => Matrix t -> Matrix t
121fmat m | not (rowOrder m) = m 106fmat m
122fmat m@Matrix { irows = r, icols = c, xdat = d} = 107 | colOrder m = m
123 m { xdat = transdata c d r, xRow = 1, xCol = r } 108 | otherwise = extractAll ColumnMajor m
124 109
125 110
126-- C-Haskell matrix adapters 111-- C-Haskell matrix adapters
@@ -157,6 +142,11 @@ a # b = apply a b
157 142
158-------------------------------------------------------------------------------- 143--------------------------------------------------------------------------------
159 144
145extractAll ord m = unsafePerformIO $
146 extractR ord m
147 0 (idxs[0,rows m-1])
148 0 (idxs[0,cols m-1])
149
160{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. 150{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
161 151
162>>> flatten (ident 3) 152>>> flatten (ident 3)
@@ -164,12 +154,14 @@ fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]
164 154
165-} 155-}
166flatten :: Element t => Matrix t -> Vector t 156flatten :: Element t => Matrix t -> Vector t
167flatten = xdat . cmat 157flatten m
158 | isSlice m || not (rowOrder m) = xdat (extractAll RowMajor m)
159 | otherwise = xdat m
168 160
169 161
170-- | the inverse of 'Data.Packed.Matrix.fromLists' 162-- | the inverse of 'Data.Packed.Matrix.fromLists'
171toLists :: (Element t) => Matrix t -> [[t]] 163toLists :: (Element t) => Matrix t -> [[t]]
172toLists m = chunksOf (cols m) . toList . flatten $ m 164toLists = map toList . toRows
173 165
174 166
175 167
@@ -205,6 +197,14 @@ fromRows vs = case compatdim (map dim vs) of
205-- | extracts the rows of a matrix as a list of vectors 197-- | extracts the rows of a matrix as a list of vectors
206toRows :: Element t => Matrix t -> [Vector t] 198toRows :: Element t => Matrix t -> [Vector t]
207toRows m 199toRows m
200 | rowOrder m = map sub rowRange
201 | otherwise = map ext rowRange
202 where
203 rowRange = [0..rows m-1]
204 sub k = subVector (k*xRow m) (cols m) (xdat m)
205 ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1])
206
207{-
208 | c == 0 = replicate r (fromList[]) 208 | c == 0 = replicate r (fromList[])
209 | otherwise = toRows' 0 209 | otherwise = toRows' 0
210 where 210 where
@@ -213,6 +213,7 @@ toRows m
213 c = cols m 213 c = cols m
214 toRows' k | k == r*c = [] 214 toRows' k | k == r*c = []
215 | otherwise = subVector k c v : toRows' (k+c) 215 | otherwise = subVector k c v : toRows' (k+c)
216-}
216 217
217-- | Creates a matrix from a list of vectors, as columns 218-- | Creates a matrix from a list of vectors, as columns
218fromColumns :: Element t => [Vector t] -> Matrix t 219fromColumns :: Element t => [Vector t] -> Matrix t
@@ -240,7 +241,7 @@ matrixFromVector o r c v
240 | r * c == dim v = m 241 | r * c == dim v = m
241 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m 242 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
242 where 243 where
243 m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 } 244 m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 }
244 | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r } 245 | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r }
245 246
246-- allocates memory for a new matrix 247-- allocates memory for a new matrix
@@ -263,31 +264,26 @@ reshape :: Storable t => Int -> Vector t -> Matrix t
263reshape 0 v = matrixFromVector RowMajor 0 0 v 264reshape 0 v = matrixFromVector RowMajor 0 0 v
264reshape c v = matrixFromVector RowMajor (dim v `div` c) c v 265reshape c v = matrixFromVector RowMajor (dim v `div` c) c v
265 266
266--singleton x = reshape 1 (fromList [x])
267 267
268-- | application of a vector function on the flattened matrix elements 268-- | application of a vector function on the flattened matrix elements
269liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b 269liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
270liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} = matrixFromVector (orderOf m) r c (f d) 270liftMatrix f m@Matrix { irows = r, icols = c, xdat = d}
271 | isSlice m = matrixFromVector RowMajor r c (f (flatten m))
272 | otherwise = matrixFromVector (orderOf m) r c (f d)
271 273
272-- | application of a vector function on the flattened matrices elements 274-- | application of a vector function on the flattened matrices elements
273liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 275liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
274liftMatrix2 f m1 m2 276liftMatrix2 f m1@(size->(r,c)) m2
275 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" 277 | (r,c)/=size m2 = error "nonconformant matrices in liftMatrix2"
276 | otherwise = case orderOf m1 of 278 | rowOrder m1 = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2))
277 RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2)) 279 | otherwise = matrixFromVector ColumnMajor r c (f (flatten (trans m1)) (flatten (trans m2)))
278 ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2))
279
280
281compat :: Matrix a -> Matrix b -> Bool
282compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
283 280
284------------------------------------------------------------------ 281------------------------------------------------------------------
285 282
286-- | Supported matrix elements. 283-- | Supported matrix elements.
287class (Storable a) => Element a where 284class (Storable a) => Element a where
288 transdata :: Int -> Vector a -> Int -> Vector a
289 constantD :: a -> Int -> Vector a 285 constantD :: a -> Int -> Vector a
290 extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) 286 extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
291 setRect :: Int -> Int -> Matrix a -> Matrix a -> IO () 287 setRect :: Int -> Int -> Matrix a -> Matrix a -> IO ()
292 sortI :: Ord a => Vector a -> Vector CInt 288 sortI :: Ord a => Vector a -> Vector CInt
293 sortV :: Ord a => Vector a -> Vector a 289 sortV :: Ord a => Vector a -> Vector a
@@ -299,7 +295,6 @@ class (Storable a) => Element a where
299 295
300 296
301instance Element Float where 297instance Element Float where
302 transdata = transdataAux ctransF
303 constantD = constantAux cconstantF 298 constantD = constantAux cconstantF
304 extractR = extractAux c_extractF 299 extractR = extractAux c_extractF
305 setRect = setRectAux c_setRectF 300 setRect = setRectAux c_setRectF
@@ -312,7 +307,6 @@ instance Element Float where
312 gemm = gemmg c_gemmF 307 gemm = gemmg c_gemmF
313 308
314instance Element Double where 309instance Element Double where
315 transdata = transdataAux ctransR
316 constantD = constantAux cconstantR 310 constantD = constantAux cconstantR
317 extractR = extractAux c_extractD 311 extractR = extractAux c_extractD
318 setRect = setRectAux c_setRectD 312 setRect = setRectAux c_setRectD
@@ -325,7 +319,6 @@ instance Element Double where
325 gemm = gemmg c_gemmD 319 gemm = gemmg c_gemmD
326 320
327instance Element (Complex Float) where 321instance Element (Complex Float) where
328 transdata = transdataAux ctransQ
329 constantD = constantAux cconstantQ 322 constantD = constantAux cconstantQ
330 extractR = extractAux c_extractQ 323 extractR = extractAux c_extractQ
331 setRect = setRectAux c_setRectQ 324 setRect = setRectAux c_setRectQ
@@ -338,7 +331,6 @@ instance Element (Complex Float) where
338 gemm = gemmg c_gemmQ 331 gemm = gemmg c_gemmQ
339 332
340instance Element (Complex Double) where 333instance Element (Complex Double) where
341 transdata = transdataAux ctransC
342 constantD = constantAux cconstantC 334 constantD = constantAux cconstantC
343 extractR = extractAux c_extractC 335 extractR = extractAux c_extractC
344 setRect = setRectAux c_setRectC 336 setRect = setRectAux c_setRectC
@@ -351,7 +343,6 @@ instance Element (Complex Double) where
351 gemm = gemmg c_gemmC 343 gemm = gemmg c_gemmC
352 344
353instance Element (CInt) where 345instance Element (CInt) where
354 transdata = transdataAux ctransI
355 constantD = constantAux cconstantI 346 constantD = constantAux cconstantI
356 extractR = extractAux c_extractI 347 extractR = extractAux c_extractI
357 setRect = setRectAux c_setRectI 348 setRect = setRectAux c_setRectI
@@ -364,7 +355,6 @@ instance Element (CInt) where
364 gemm = gemmg c_gemmI 355 gemm = gemmg c_gemmI
365 356
366instance Element Z where 357instance Element Z where
367 transdata = transdataAux ctransL
368 constantD = constantAux cconstantL 358 constantD = constantAux cconstantL
369 extractR = extractAux c_extractL 359 extractR = extractAux c_extractL
370 setRect = setRectAux c_setRectL 360 setRect = setRectAux c_setRectL
@@ -378,32 +368,6 @@ instance Element Z where
378 368
379------------------------------------------------------------------- 369-------------------------------------------------------------------
380 370
381transdataAux fun c1 d c2 =
382 if noneed
383 then d
384 else unsafePerformIO $ do
385 -- putStrLn "T"
386 v <- createVector (dim d)
387 unsafeWith d $ \pd ->
388 unsafeWith v $ \pv ->
389 fun (fi r1) (fi c1) pd (fi r2) (fi c2) pv // check "transdataAux"
390 return v
391 where r1 = dim d `div` c1
392 r2 = dim d `div` c2
393 noneed = dim d == 0 || r1 == 1 || c1 == 1
394
395
396type TMM t = t ..> t ..> Ok
397
398foreign import ccall unsafe "transF" ctransF :: TMM Float
399foreign import ccall unsafe "transR" ctransR :: TMM Double
400foreign import ccall unsafe "transQ" ctransQ :: TMM (Complex Float)
401foreign import ccall unsafe "transC" ctransC :: TMM (Complex Double)
402foreign import ccall unsafe "transI" ctransI :: TMM CInt
403foreign import ccall unsafe "transL" ctransL :: TMM Z
404
405----------------------------------------------------------------------
406
407subMatrix :: Element a 371subMatrix :: Element a
408 => (Int,Int) -- ^ (r0,c0) starting position 372 => (Int,Int) -- ^ (r0,c0) starting position
409 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix 373 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
@@ -411,9 +375,8 @@ subMatrix :: Element a
411 -> Matrix a -- ^ result 375 -> Matrix a -- ^ result
412subMatrix (r0,c0) (rt,ct) m 376subMatrix (r0,c0) (rt,ct) m
413 | 0 <= r0 && 0 <= rt && r0+rt <= rows m && 377 | 0 <= r0 && 0 <= rt && r0+rt <= rows m &&
414 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) 378 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR RowMajor m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1])
415 | otherwise = error $ "wrong subMatrix "++ 379 | otherwise = error $ "wrong subMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m
416 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
417 380
418 381
419sliceMatrix :: Element a 382sliceMatrix :: Element a
@@ -424,11 +387,12 @@ sliceMatrix :: Element a
424sliceMatrix (r0,c0) (rt,ct) m 387sliceMatrix (r0,c0) (rt,ct) m
425 | 0 <= r0 && 0 <= rt && r0+rt <= rows m && 388 | 0 <= r0 && 0 <= rt && r0+rt <= rows m &&
426 0 <= c0 && 0 <= ct && c0+ct <= cols m = res 389 0 <= c0 && 0 <= ct && c0+ct <= cols m = res
427 | otherwise = error $ "wrong sliceMatrix "++ 390 | otherwise = error $ "wrong sliceMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m
428 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
429 where 391 where
430 t = r0 * xRow m + c0 * xCol m 392 p = r0 * xRow m + c0 * xCol m
431 res = m { irows = rt, icols = ct, xdat = subVector t (rt*ct) (xdat m) } 393 tot | rowOrder m = ct + (rt-1) * xRow m
394 | otherwise = rt + (ct-1) * xCol m
395 res = m { irows = rt, icols = ct, xdat = subVector p tot (xdat m) }
432 396
433-------------------------------------------------------------------------- 397--------------------------------------------------------------------------
434 398
@@ -449,7 +413,7 @@ conformMTo (r,c) m
449 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) 413 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
450 | size m == (r,1) = repCols c m 414 | size m == (r,1) = repCols c m
451 | size m == (1,c) = repRows r m 415 | size m == (1,c) = repRows r m
452 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" 416 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c)
453 417
454conformVTo n v 418conformVTo n v
455 | dim v == n = v 419 | dim v == n = v
@@ -459,9 +423,9 @@ conformVTo n v
459repRows n x = fromRows (replicate n (flatten x)) 423repRows n x = fromRows (replicate n (flatten x))
460repCols n x = fromColumns (replicate n (flatten x)) 424repCols n x = fromColumns (replicate n (flatten x))
461 425
462size m = (rows m, cols m) 426shSize = shDim . size
463 427
464shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" 428shDim (r,c) = "(" ++ show r ++"x"++ show c ++")"
465 429
466emptyM r c = matrixFromVector RowMajor r c (fromList[]) 430emptyM r c = matrixFromVector RowMajor r c (fromList[])
467 431
@@ -477,10 +441,10 @@ instance (Storable t, NFData t) => NFData (Matrix t)
477 441
478--------------------------------------------------------------- 442---------------------------------------------------------------
479 443
480extractAux f m moder vr modec vc = do 444extractAux f ord m moder vr modec vc = do
481 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr 445 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
482 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc 446 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
483 r <- createMatrix RowMajor nr nc 447 r <- createMatrix ord nr nc
484 f moder modec # vr # vc # m # r #|"extract" 448 f moder modec # vr # vc # m # r #|"extract"
485 return r 449 return r
486 450