summaryrefslogtreecommitdiff
path: root/packages/base/src/Data/Packed/Internal/Matrix.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Data/Packed/Internal/Matrix.hs')
-rw-r--r--packages/base/src/Data/Packed/Internal/Matrix.hs45
1 files changed, 44 insertions, 1 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs
index 150b978..be5fb03 100644
--- a/packages/base/src/Data/Packed/Internal/Matrix.hs
+++ b/packages/base/src/Data/Packed/Internal/Matrix.hs
@@ -265,23 +265,33 @@ class (Storable a) => Element a where
265 transdata = transdataP -- transdata' 265 transdata = transdataP -- transdata'
266 constantD :: a -> Int -> Vector a 266 constantD :: a -> Int -> Vector a
267 constantD = constantP -- constant' 267 constantD = constantP -- constant'
268 268 extractR :: Matrix a -> Idxs -> Matrix a
269 269
270instance Element Float where 270instance Element Float where
271 transdata = transdataAux ctransF 271 transdata = transdataAux ctransF
272 constantD = constantAux cconstantF 272 constantD = constantAux cconstantF
273 extractR = extractAux c_extractRF
273 274
274instance Element Double where 275instance Element Double where
275 transdata = transdataAux ctransR 276 transdata = transdataAux ctransR
276 constantD = constantAux cconstantR 277 constantD = constantAux cconstantR
278 extractR = extractAux c_extractRD
277 279
278instance Element (Complex Float) where 280instance Element (Complex Float) where
279 transdata = transdataAux ctransQ 281 transdata = transdataAux ctransQ
280 constantD = constantAux cconstantQ 282 constantD = constantAux cconstantQ
283 extractR = extractAux c_extractRQ
281 284
282instance Element (Complex Double) where 285instance Element (Complex Double) where
283 transdata = transdataAux ctransC 286 transdata = transdataAux ctransC
284 constantD = constantAux cconstantC 287 constantD = constantAux cconstantC
288 extractR = extractAux c_extractRC
289
290instance Element (CInt) where
291 transdata = transdataAux ctransI
292 constantD = constantAux cconstantI
293 extractR = extractAux c_extractRI
294
285 295
286------------------------------------------------------------------- 296-------------------------------------------------------------------
287 297
@@ -289,6 +299,7 @@ transdataAux fun c1 d c2 =
289 if noneed 299 if noneed
290 then d 300 then d
291 else unsafePerformIO $ do 301 else unsafePerformIO $ do
302 -- putStrLn "T"
292 v <- createVector (dim d) 303 v <- createVector (dim d)
293 unsafeWith d $ \pd -> 304 unsafeWith d $ \pd ->
294 unsafeWith v $ \pv -> 305 unsafeWith v $ \pv ->
@@ -317,6 +328,7 @@ foreign import ccall unsafe "transF" ctransF :: TFMFM
317foreign import ccall unsafe "transR" ctransR :: TMM 328foreign import ccall unsafe "transR" ctransR :: TMM
318foreign import ccall unsafe "transQ" ctransQ :: TQMQM 329foreign import ccall unsafe "transQ" ctransQ :: TQMQM
319foreign import ccall unsafe "transC" ctransC :: TCMCM 330foreign import ccall unsafe "transC" ctransC :: TCMCM
331foreign import ccall unsafe "transI" ctransI :: CM CInt (CM CInt (IO CInt))
320foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt 332foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt
321 333
322---------------------------------------------------------------------- 334----------------------------------------------------------------------
@@ -336,6 +348,8 @@ foreign import ccall unsafe "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV
336 348
337foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV 349foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV
338 350
351foreign import ccall unsafe "constantI" cconstantI :: Ptr CInt -> CV CInt (IO CInt)
352
339constantP :: Storable a => a -> Int -> Vector a 353constantP :: Storable a => a -> Int -> Vector a
340constantP a n = unsafePerformIO $ do 354constantP a n = unsafePerformIO $ do
341 let sz = sizeOf a 355 let sz = sizeOf a
@@ -421,3 +435,32 @@ instance (Storable t, NFData t) => NFData (Matrix t)
421 d = dim v 435 d = dim v
422 v = xdat m 436 v = xdat m
423 437
438---------------------------------------------------------------
439
440isT Matrix{order = ColumnMajor} = 1
441isT Matrix{order = RowMajor} = 0
442
443tt x@Matrix{order = ColumnMajor} = trans x
444tt x@Matrix{order = RowMajor} = x
445
446--extractAux :: Matrix Double -> Idxs -> Matrix Double
447extractAux f m v = unsafePerformIO $ do
448 r <- createMatrix RowMajor (dim v) (cols m)
449 app3 (f (isT m)) vec v mat (tt m) mat r "extractAux"
450 return r
451
452foreign import ccall unsafe "extractRD" c_extractRD
453 :: CInt -> CIdxs (CM Double (CM Double (IO CInt)))
454
455foreign import ccall unsafe "extractRF" c_extractRF
456 :: CInt -> CIdxs (CM Float (CM Float (IO CInt)))
457
458foreign import ccall unsafe "extractRC" c_extractRC
459 :: CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt)))
460
461foreign import ccall unsafe "extractRQ" c_extractRQ
462 :: CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt)))
463
464foreign import ccall unsafe "extractRI" c_extractRI
465 :: CInt -> CIdxs (CM CInt (CM CInt (IO CInt)))
466