summaryrefslogtreecommitdiff
path: root/packages/base/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Data')
-rw-r--r--packages/base/src/Data/Packed/Internal/Matrix.hs45
-rw-r--r--packages/base/src/Data/Packed/Internal/Signatures.hs7
-rw-r--r--packages/base/src/Data/Packed/Internal/Vector.hs6
3 files changed, 55 insertions, 3 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs
index 150b978..be5fb03 100644
--- a/packages/base/src/Data/Packed/Internal/Matrix.hs
+++ b/packages/base/src/Data/Packed/Internal/Matrix.hs
@@ -265,23 +265,33 @@ class (Storable a) => Element a where
265 transdata = transdataP -- transdata' 265 transdata = transdataP -- transdata'
266 constantD :: a -> Int -> Vector a 266 constantD :: a -> Int -> Vector a
267 constantD = constantP -- constant' 267 constantD = constantP -- constant'
268 268 extractR :: Matrix a -> Idxs -> Matrix a
269 269
270instance Element Float where 270instance Element Float where
271 transdata = transdataAux ctransF 271 transdata = transdataAux ctransF
272 constantD = constantAux cconstantF 272 constantD = constantAux cconstantF
273 extractR = extractAux c_extractRF
273 274
274instance Element Double where 275instance Element Double where
275 transdata = transdataAux ctransR 276 transdata = transdataAux ctransR
276 constantD = constantAux cconstantR 277 constantD = constantAux cconstantR
278 extractR = extractAux c_extractRD
277 279
278instance Element (Complex Float) where 280instance Element (Complex Float) where
279 transdata = transdataAux ctransQ 281 transdata = transdataAux ctransQ
280 constantD = constantAux cconstantQ 282 constantD = constantAux cconstantQ
283 extractR = extractAux c_extractRQ
281 284
282instance Element (Complex Double) where 285instance Element (Complex Double) where
283 transdata = transdataAux ctransC 286 transdata = transdataAux ctransC
284 constantD = constantAux cconstantC 287 constantD = constantAux cconstantC
288 extractR = extractAux c_extractRC
289
290instance Element (CInt) where
291 transdata = transdataAux ctransI
292 constantD = constantAux cconstantI
293 extractR = extractAux c_extractRI
294
285 295
286------------------------------------------------------------------- 296-------------------------------------------------------------------
287 297
@@ -289,6 +299,7 @@ transdataAux fun c1 d c2 =
289 if noneed 299 if noneed
290 then d 300 then d
291 else unsafePerformIO $ do 301 else unsafePerformIO $ do
302 -- putStrLn "T"
292 v <- createVector (dim d) 303 v <- createVector (dim d)
293 unsafeWith d $ \pd -> 304 unsafeWith d $ \pd ->
294 unsafeWith v $ \pv -> 305 unsafeWith v $ \pv ->
@@ -317,6 +328,7 @@ foreign import ccall unsafe "transF" ctransF :: TFMFM
317foreign import ccall unsafe "transR" ctransR :: TMM 328foreign import ccall unsafe "transR" ctransR :: TMM
318foreign import ccall unsafe "transQ" ctransQ :: TQMQM 329foreign import ccall unsafe "transQ" ctransQ :: TQMQM
319foreign import ccall unsafe "transC" ctransC :: TCMCM 330foreign import ccall unsafe "transC" ctransC :: TCMCM
331foreign import ccall unsafe "transI" ctransI :: CM CInt (CM CInt (IO CInt))
320foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt 332foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt
321 333
322---------------------------------------------------------------------- 334----------------------------------------------------------------------
@@ -336,6 +348,8 @@ foreign import ccall unsafe "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV
336 348
337foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV 349foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV
338 350
351foreign import ccall unsafe "constantI" cconstantI :: Ptr CInt -> CV CInt (IO CInt)
352
339constantP :: Storable a => a -> Int -> Vector a 353constantP :: Storable a => a -> Int -> Vector a
340constantP a n = unsafePerformIO $ do 354constantP a n = unsafePerformIO $ do
341 let sz = sizeOf a 355 let sz = sizeOf a
@@ -421,3 +435,32 @@ instance (Storable t, NFData t) => NFData (Matrix t)
421 d = dim v 435 d = dim v
422 v = xdat m 436 v = xdat m
423 437
438---------------------------------------------------------------
439
440isT Matrix{order = ColumnMajor} = 1
441isT Matrix{order = RowMajor} = 0
442
443tt x@Matrix{order = ColumnMajor} = trans x
444tt x@Matrix{order = RowMajor} = x
445
446--extractAux :: Matrix Double -> Idxs -> Matrix Double
447extractAux f m v = unsafePerformIO $ do
448 r <- createMatrix RowMajor (dim v) (cols m)
449 app3 (f (isT m)) vec v mat (tt m) mat r "extractAux"
450 return r
451
452foreign import ccall unsafe "extractRD" c_extractRD
453 :: CInt -> CIdxs (CM Double (CM Double (IO CInt)))
454
455foreign import ccall unsafe "extractRF" c_extractRF
456 :: CInt -> CIdxs (CM Float (CM Float (IO CInt)))
457
458foreign import ccall unsafe "extractRC" c_extractRC
459 :: CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt)))
460
461foreign import ccall unsafe "extractRQ" c_extractRQ
462 :: CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt)))
463
464foreign import ccall unsafe "extractRI" c_extractRI
465 :: CInt -> CIdxs (CM CInt (CM CInt (IO CInt)))
466
diff --git a/packages/base/src/Data/Packed/Internal/Signatures.hs b/packages/base/src/Data/Packed/Internal/Signatures.hs
index acc3070..37dac16 100644
--- a/packages/base/src/Data/Packed/Internal/Signatures.hs
+++ b/packages/base/src/Data/Packed/Internal/Signatures.hs
@@ -1,6 +1,6 @@
1-- | 1-- |
2-- Module : Data.Packed.Internal.Signatures 2-- Module : Data.Packed.Internal.Signatures
3-- Copyright : (c) Alberto Ruiz 2009 3-- Copyright : (c) Alberto Ruiz 2009-15
4-- License : BSD3 4-- License : BSD3
5-- Maintainer : Alberto Ruiz 5-- Maintainer : Alberto Ruiz
6-- Stability : provisional 6-- Stability : provisional
@@ -68,3 +68,8 @@ type TCVM = CInt -> PC -> TM --
68type TMCVM = CInt -> CInt -> PD -> TCVM -- 68type TMCVM = CInt -> CInt -> PD -> TCVM --
69type TMMCVM = CInt -> CInt -> PD -> TMCVM -- 69type TMMCVM = CInt -> CInt -> PD -> TMCVM --
70 70
71type CM b r = CInt -> CInt -> Ptr b -> r
72type CV b r = CInt -> Ptr b -> r
73
74type CIdxs r = CV CInt r
75
diff --git a/packages/base/src/Data/Packed/Internal/Vector.hs b/packages/base/src/Data/Packed/Internal/Vector.hs
index b49f379..2a6ed2c 100644
--- a/packages/base/src/Data/Packed/Internal/Vector.hs
+++ b/packages/base/src/Data/Packed/Internal/Vector.hs
@@ -24,7 +24,8 @@ module Data.Packed.Internal.Vector (
24 cloneVector, 24 cloneVector,
25 unsafeToForeignPtr, 25 unsafeToForeignPtr,
26 unsafeFromForeignPtr, 26 unsafeFromForeignPtr,
27 unsafeWith 27 unsafeWith,
28 Idxs
28) where 29) where
29 30
30import Data.Packed.Internal.Common 31import Data.Packed.Internal.Common
@@ -56,6 +57,8 @@ import Data.Vector.Storable(Vector,
56 unsafeWith) 57 unsafeWith)
57 58
58 59
60type Idxs = Vector CInt
61
59-- | Number of elements 62-- | Number of elements
60dim :: (Storable t) => Vector t -> Int 63dim :: (Storable t) => Vector t -> Int
61dim = Vector.length 64dim = Vector.length
@@ -243,6 +246,7 @@ double2FloatV v = unsafePerformIO $ do
243foreign import ccall unsafe "float2double" c_float2double:: TFV 246foreign import ccall unsafe "float2double" c_float2double:: TFV
244foreign import ccall unsafe "double2float" c_double2float:: TVF 247foreign import ccall unsafe "double2float" c_double2float:: TVF
245 248
249
246--------------------------------------------------------------- 250---------------------------------------------------------------
247 251
248stepF :: Vector Float -> Vector Float 252stepF :: Vector Float -> Vector Float