diff options
Diffstat (limited to 'lib/Data')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 33 |
1 files changed, 31 insertions, 2 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 7a17ef0..090826d 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -252,10 +252,11 @@ class (Storable a, Floating a) => Element a where | |||
252 | -> Matrix a -> Matrix a | 252 | -> Matrix a -> Matrix a |
253 | subMatrixD = subMatrix' | 253 | subMatrixD = subMatrix' |
254 | transdata :: Int -> Vector a -> Int -> Vector a | 254 | transdata :: Int -> Vector a -> Int -> Vector a |
255 | transdata = transdata' | 255 | transdata = transdataP -- transdata' |
256 | constantD :: a -> Int -> Vector a | 256 | constantD :: a -> Int -> Vector a |
257 | constantD = constant' | 257 | constantD = constantP -- constant' |
258 | conjugateD :: Vector a -> Vector a | 258 | conjugateD :: Vector a -> Vector a |
259 | conjugateD = id | ||
259 | 260 | ||
260 | instance Element Float where | 261 | instance Element Float where |
261 | transdata = transdataAux ctransF | 262 | transdata = transdataAux ctransF |
@@ -320,10 +321,27 @@ transdataAux fun c1 d c2 = | |||
320 | r2 = dim d `div` c2 | 321 | r2 = dim d `div` c2 |
321 | noneed = r1 == 1 || c1 == 1 | 322 | noneed = r1 == 1 || c1 == 1 |
322 | 323 | ||
324 | transdataP :: Storable a => Int -> Vector a -> Int -> Vector a | ||
325 | transdataP c1 d c2 = | ||
326 | if noneed | ||
327 | then d | ||
328 | else unsafePerformIO $ do | ||
329 | v <- createVector (dim d) | ||
330 | unsafeWith d $ \pd -> | ||
331 | unsafeWith v $ \pv -> | ||
332 | ctransP (fi r1) (fi c1) (castPtr pd) (fi sz) (fi r2) (fi c2) (castPtr pv) (fi sz) // check "transdataStorable" | ||
333 | return v | ||
334 | where r1 = dim d `div` c1 | ||
335 | r2 = dim d `div` c2 | ||
336 | sz = sizeOf (d @> 0) | ||
337 | noneed = r1 == 1 || c1 == 1 | ||
338 | |||
323 | foreign import ccall "transF" ctransF :: TFMFM | 339 | foreign import ccall "transF" ctransF :: TFMFM |
324 | foreign import ccall "transR" ctransR :: TMM | 340 | foreign import ccall "transR" ctransR :: TMM |
325 | foreign import ccall "transQ" ctransQ :: TQMQM | 341 | foreign import ccall "transQ" ctransQ :: TQMQM |
326 | foreign import ccall "transC" ctransC :: TCMCM | 342 | foreign import ccall "transC" ctransC :: TCMCM |
343 | foreign import ccall "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt | ||
344 | |||
327 | ---------------------------------------------------------------------- | 345 | ---------------------------------------------------------------------- |
328 | 346 | ||
329 | constant' v n = unsafePerformIO $ do | 347 | constant' v n = unsafePerformIO $ do |
@@ -359,6 +377,17 @@ constantC :: Complex Double -> Int -> Vector (Complex Double) | |||
359 | constantC = constantAux cconstantC | 377 | constantC = constantAux cconstantC |
360 | foreign import ccall "constantC" cconstantC :: Ptr (Complex Double) -> TCV | 378 | foreign import ccall "constantC" cconstantC :: Ptr (Complex Double) -> TCV |
361 | 379 | ||
380 | constantP :: Storable a => a -> Int -> Vector a | ||
381 | constantP a n = unsafePerformIO $ do | ||
382 | let sz = sizeOf a | ||
383 | v <- createVector n | ||
384 | unsafeWith v $ \p -> do | ||
385 | alloca $ \k -> do | ||
386 | poke k a | ||
387 | cconstantP (castPtr k) (fi n) (castPtr p) (fi sz) // check "constantP" | ||
388 | return v | ||
389 | foreign import ccall "constantP" cconstantP :: Ptr () -> CInt -> Ptr () -> CInt -> IO CInt | ||
390 | |||
362 | --------------------------------------- | 391 | --------------------------------------- |
363 | 392 | ||
364 | conjugateAux fun x = unsafePerformIO $ do | 393 | conjugateAux fun x = unsafePerformIO $ do |