diff options
Diffstat (limited to 'packages/base/src/Data')
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Matrix.hs | 45 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Signatures.hs | 7 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Vector.hs | 6 |
3 files changed, 55 insertions, 3 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs index 150b978..be5fb03 100644 --- a/packages/base/src/Data/Packed/Internal/Matrix.hs +++ b/packages/base/src/Data/Packed/Internal/Matrix.hs | |||
@@ -265,23 +265,33 @@ class (Storable a) => Element a where | |||
265 | transdata = transdataP -- transdata' | 265 | transdata = transdataP -- transdata' |
266 | constantD :: a -> Int -> Vector a | 266 | constantD :: a -> Int -> Vector a |
267 | constantD = constantP -- constant' | 267 | constantD = constantP -- constant' |
268 | 268 | extractR :: Matrix a -> Idxs -> Matrix a | |
269 | 269 | ||
270 | instance Element Float where | 270 | instance Element Float where |
271 | transdata = transdataAux ctransF | 271 | transdata = transdataAux ctransF |
272 | constantD = constantAux cconstantF | 272 | constantD = constantAux cconstantF |
273 | extractR = extractAux c_extractRF | ||
273 | 274 | ||
274 | instance Element Double where | 275 | instance Element Double where |
275 | transdata = transdataAux ctransR | 276 | transdata = transdataAux ctransR |
276 | constantD = constantAux cconstantR | 277 | constantD = constantAux cconstantR |
278 | extractR = extractAux c_extractRD | ||
277 | 279 | ||
278 | instance Element (Complex Float) where | 280 | instance Element (Complex Float) where |
279 | transdata = transdataAux ctransQ | 281 | transdata = transdataAux ctransQ |
280 | constantD = constantAux cconstantQ | 282 | constantD = constantAux cconstantQ |
283 | extractR = extractAux c_extractRQ | ||
281 | 284 | ||
282 | instance Element (Complex Double) where | 285 | instance Element (Complex Double) where |
283 | transdata = transdataAux ctransC | 286 | transdata = transdataAux ctransC |
284 | constantD = constantAux cconstantC | 287 | constantD = constantAux cconstantC |
288 | extractR = extractAux c_extractRC | ||
289 | |||
290 | instance Element (CInt) where | ||
291 | transdata = transdataAux ctransI | ||
292 | constantD = constantAux cconstantI | ||
293 | extractR = extractAux c_extractRI | ||
294 | |||
285 | 295 | ||
286 | ------------------------------------------------------------------- | 296 | ------------------------------------------------------------------- |
287 | 297 | ||
@@ -289,6 +299,7 @@ transdataAux fun c1 d c2 = | |||
289 | if noneed | 299 | if noneed |
290 | then d | 300 | then d |
291 | else unsafePerformIO $ do | 301 | else unsafePerformIO $ do |
302 | -- putStrLn "T" | ||
292 | v <- createVector (dim d) | 303 | v <- createVector (dim d) |
293 | unsafeWith d $ \pd -> | 304 | unsafeWith d $ \pd -> |
294 | unsafeWith v $ \pv -> | 305 | unsafeWith v $ \pv -> |
@@ -317,6 +328,7 @@ foreign import ccall unsafe "transF" ctransF :: TFMFM | |||
317 | foreign import ccall unsafe "transR" ctransR :: TMM | 328 | foreign import ccall unsafe "transR" ctransR :: TMM |
318 | foreign import ccall unsafe "transQ" ctransQ :: TQMQM | 329 | foreign import ccall unsafe "transQ" ctransQ :: TQMQM |
319 | foreign import ccall unsafe "transC" ctransC :: TCMCM | 330 | foreign import ccall unsafe "transC" ctransC :: TCMCM |
331 | foreign import ccall unsafe "transI" ctransI :: CM CInt (CM CInt (IO CInt)) | ||
320 | foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt | 332 | foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt |
321 | 333 | ||
322 | ---------------------------------------------------------------------- | 334 | ---------------------------------------------------------------------- |
@@ -336,6 +348,8 @@ foreign import ccall unsafe "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV | |||
336 | 348 | ||
337 | foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV | 349 | foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV |
338 | 350 | ||
351 | foreign import ccall unsafe "constantI" cconstantI :: Ptr CInt -> CV CInt (IO CInt) | ||
352 | |||
339 | constantP :: Storable a => a -> Int -> Vector a | 353 | constantP :: Storable a => a -> Int -> Vector a |
340 | constantP a n = unsafePerformIO $ do | 354 | constantP a n = unsafePerformIO $ do |
341 | let sz = sizeOf a | 355 | let sz = sizeOf a |
@@ -421,3 +435,32 @@ instance (Storable t, NFData t) => NFData (Matrix t) | |||
421 | d = dim v | 435 | d = dim v |
422 | v = xdat m | 436 | v = xdat m |
423 | 437 | ||
438 | --------------------------------------------------------------- | ||
439 | |||
440 | isT Matrix{order = ColumnMajor} = 1 | ||
441 | isT Matrix{order = RowMajor} = 0 | ||
442 | |||
443 | tt x@Matrix{order = ColumnMajor} = trans x | ||
444 | tt x@Matrix{order = RowMajor} = x | ||
445 | |||
446 | --extractAux :: Matrix Double -> Idxs -> Matrix Double | ||
447 | extractAux f m v = unsafePerformIO $ do | ||
448 | r <- createMatrix RowMajor (dim v) (cols m) | ||
449 | app3 (f (isT m)) vec v mat (tt m) mat r "extractAux" | ||
450 | return r | ||
451 | |||
452 | foreign import ccall unsafe "extractRD" c_extractRD | ||
453 | :: CInt -> CIdxs (CM Double (CM Double (IO CInt))) | ||
454 | |||
455 | foreign import ccall unsafe "extractRF" c_extractRF | ||
456 | :: CInt -> CIdxs (CM Float (CM Float (IO CInt))) | ||
457 | |||
458 | foreign import ccall unsafe "extractRC" c_extractRC | ||
459 | :: CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))) | ||
460 | |||
461 | foreign import ccall unsafe "extractRQ" c_extractRQ | ||
462 | :: CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))) | ||
463 | |||
464 | foreign import ccall unsafe "extractRI" c_extractRI | ||
465 | :: CInt -> CIdxs (CM CInt (CM CInt (IO CInt))) | ||
466 | |||
diff --git a/packages/base/src/Data/Packed/Internal/Signatures.hs b/packages/base/src/Data/Packed/Internal/Signatures.hs index acc3070..37dac16 100644 --- a/packages/base/src/Data/Packed/Internal/Signatures.hs +++ b/packages/base/src/Data/Packed/Internal/Signatures.hs | |||
@@ -1,6 +1,6 @@ | |||
1 | -- | | 1 | -- | |
2 | -- Module : Data.Packed.Internal.Signatures | 2 | -- Module : Data.Packed.Internal.Signatures |
3 | -- Copyright : (c) Alberto Ruiz 2009 | 3 | -- Copyright : (c) Alberto Ruiz 2009-15 |
4 | -- License : BSD3 | 4 | -- License : BSD3 |
5 | -- Maintainer : Alberto Ruiz | 5 | -- Maintainer : Alberto Ruiz |
6 | -- Stability : provisional | 6 | -- Stability : provisional |
@@ -68,3 +68,8 @@ type TCVM = CInt -> PC -> TM -- | |||
68 | type TMCVM = CInt -> CInt -> PD -> TCVM -- | 68 | type TMCVM = CInt -> CInt -> PD -> TCVM -- |
69 | type TMMCVM = CInt -> CInt -> PD -> TMCVM -- | 69 | type TMMCVM = CInt -> CInt -> PD -> TMCVM -- |
70 | 70 | ||
71 | type CM b r = CInt -> CInt -> Ptr b -> r | ||
72 | type CV b r = CInt -> Ptr b -> r | ||
73 | |||
74 | type CIdxs r = CV CInt r | ||
75 | |||
diff --git a/packages/base/src/Data/Packed/Internal/Vector.hs b/packages/base/src/Data/Packed/Internal/Vector.hs index b49f379..2a6ed2c 100644 --- a/packages/base/src/Data/Packed/Internal/Vector.hs +++ b/packages/base/src/Data/Packed/Internal/Vector.hs | |||
@@ -24,7 +24,8 @@ module Data.Packed.Internal.Vector ( | |||
24 | cloneVector, | 24 | cloneVector, |
25 | unsafeToForeignPtr, | 25 | unsafeToForeignPtr, |
26 | unsafeFromForeignPtr, | 26 | unsafeFromForeignPtr, |
27 | unsafeWith | 27 | unsafeWith, |
28 | Idxs | ||
28 | ) where | 29 | ) where |
29 | 30 | ||
30 | import Data.Packed.Internal.Common | 31 | import Data.Packed.Internal.Common |
@@ -56,6 +57,8 @@ import Data.Vector.Storable(Vector, | |||
56 | unsafeWith) | 57 | unsafeWith) |
57 | 58 | ||
58 | 59 | ||
60 | type Idxs = Vector CInt | ||
61 | |||
59 | -- | Number of elements | 62 | -- | Number of elements |
60 | dim :: (Storable t) => Vector t -> Int | 63 | dim :: (Storable t) => Vector t -> Int |
61 | dim = Vector.length | 64 | dim = Vector.length |
@@ -243,6 +246,7 @@ double2FloatV v = unsafePerformIO $ do | |||
243 | foreign import ccall unsafe "float2double" c_float2double:: TFV | 246 | foreign import ccall unsafe "float2double" c_float2double:: TFV |
244 | foreign import ccall unsafe "double2float" c_double2float:: TVF | 247 | foreign import ccall unsafe "double2float" c_double2float:: TVF |
245 | 248 | ||
249 | |||
246 | --------------------------------------------------------------- | 250 | --------------------------------------------------------------- |
247 | 251 | ||
248 | stepF :: Vector Float -> Vector Float | 252 | stepF :: Vector Float -> Vector Float |