diff options
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/Internal/Algorithms.hs | 5 | ||||
-rw-r--r-- | packages/base/src/Internal/Container.hs | 96 | ||||
-rw-r--r-- | packages/base/src/Internal/Conversion.hs | 7 | ||||
-rw-r--r-- | packages/base/src/Internal/Convolution.hs | 7 | ||||
-rw-r--r-- | packages/base/src/Internal/Devel.hs | 21 | ||||
-rw-r--r-- | packages/base/src/Internal/Element.hs | 84 | ||||
-rw-r--r-- | packages/base/src/Internal/Extract.hs | 145 | ||||
-rw-r--r-- | packages/base/src/Internal/IO.hs | 9 | ||||
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 19 | ||||
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 307 | ||||
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 6 | ||||
-rw-r--r-- | packages/base/src/Internal/Numeric.hs | 80 | ||||
-rw-r--r-- | packages/base/src/Internal/ST.hs | 131 | ||||
-rw-r--r-- | packages/base/src/Internal/Sparse.hs | 16 | ||||
-rw-r--r-- | packages/base/src/Internal/Util.hs | 15 | ||||
-rw-r--r-- | packages/base/src/Internal/Vector.hs | 10 | ||||
-rw-r--r-- | packages/base/src/Internal/Vectorized.hs | 133 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra.hs | 2 |
18 files changed, 720 insertions, 373 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs index f5bddc6..aa51792 100644 --- a/packages/base/src/Internal/Algorithms.hs +++ b/packages/base/src/Internal/Algorithms.hs | |||
@@ -39,6 +39,7 @@ import qualified Data.Vector.Storable as Vector | |||
39 | import Internal.ST | 39 | import Internal.ST |
40 | import Internal.Vectorized(range) | 40 | import Internal.Vectorized(range) |
41 | import Control.DeepSeq | 41 | import Control.DeepSeq |
42 | import Foreign.Storable | ||
42 | 43 | ||
43 | {- | Generic linear algebra functions for double precision real and complex matrices. | 44 | {- | Generic linear algebra functions for double precision real and complex matrices. |
44 | 45 | ||
@@ -742,7 +743,7 @@ pinvTol t m = v' `mXm` diag s' `mXm` ctrans u' where | |||
742 | 743 | ||
743 | 744 | ||
744 | -- | Numeric rank of a matrix from the SVD decomposition. | 745 | -- | Numeric rank of a matrix from the SVD decomposition. |
745 | rankSVD :: Element t | 746 | rankSVD :: Storable t |
746 | => Double -- ^ numeric zero (e.g. 1*'eps') | 747 | => Double -- ^ numeric zero (e.g. 1*'eps') |
747 | -> Matrix t -- ^ input matrix m | 748 | -> Matrix t -- ^ input matrix m |
748 | -> Vector Double -- ^ 'sv' of m | 749 | -> Vector Double -- ^ 'sv' of m |
@@ -1003,7 +1004,7 @@ fixPerm' s = res $ mutable f s0 | |||
1003 | s0 = reshape 1 (range (length s)) | 1004 | s0 = reshape 1 (range (length s)) |
1004 | res = flatten . fst | 1005 | res = flatten . fst |
1005 | swap m i j = rowOper (SWAP i j AllCols) m | 1006 | swap m i j = rowOper (SWAP i j AllCols) m |
1006 | f :: (Num t, Element t) => (Int, Int) -> STMatrix s t -> ST s () -- needed because of TypeFamilies | 1007 | f :: (Num t, Storable t) => (Int, Int) -> STMatrix s t -> ST s () -- needed because of TypeFamilies |
1007 | f _ p = sequence_ $ zipWith (swap p) [0..] s | 1008 | f _ p = sequence_ $ zipWith (swap p) [0..] s |
1008 | 1009 | ||
1009 | triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]] | 1010 | triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]] |
diff --git a/packages/base/src/Internal/Container.hs b/packages/base/src/Internal/Container.hs index 41b8214..0f2e7d5 100644 --- a/packages/base/src/Internal/Container.hs +++ b/packages/base/src/Internal/Container.hs | |||
@@ -4,6 +4,8 @@ | |||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | 4 | {-# LANGUAGE MultiParamTypeClasses #-} |
5 | {-# LANGUAGE FunctionalDependencies #-} | 5 | {-# LANGUAGE FunctionalDependencies #-} |
6 | {-# LANGUAGE UndecidableInstances #-} | 6 | {-# LANGUAGE UndecidableInstances #-} |
7 | {-# LANGUAGE PatternSynonyms #-} | ||
8 | {-# LANGUAGE ScopedTypeVariables #-} | ||
7 | 9 | ||
8 | {-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-} | 10 | {-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-} |
9 | 11 | ||
@@ -30,8 +32,15 @@ module Internal.Container where | |||
30 | import Internal.Vector | 32 | import Internal.Vector |
31 | import Internal.Matrix | 33 | import Internal.Matrix |
32 | import Internal.Element | 34 | import Internal.Element |
35 | import Internal.Extract(requires,pattern BAD_SIZE) | ||
33 | import Internal.Numeric | 36 | import Internal.Numeric |
34 | import Internal.Algorithms(Field,linearSolveSVD,Herm,mTm) | 37 | import Internal.Algorithms(Field,linearSolveSVD,Herm,mTm) |
38 | import Control.Monad(when) | ||
39 | import Data.Function | ||
40 | import Data.Int | ||
41 | import Foreign.Ptr | ||
42 | import Foreign.Storable | ||
43 | import Foreign.Marshal.Array | ||
35 | #if MIN_VERSION_base(4,11,0) | 44 | #if MIN_VERSION_base(4,11,0) |
36 | import Prelude hiding ((<>)) | 45 | import Prelude hiding ((<>)) |
37 | #endif | 46 | #endif |
@@ -227,7 +236,7 @@ meanCov x = (med,cov) where | |||
227 | 236 | ||
228 | -------------------------------------------------------------------------------- | 237 | -------------------------------------------------------------------------------- |
229 | 238 | ||
230 | sortVector :: (Ord t, Element t) => Vector t -> Vector t | 239 | sortVector :: (Ord t, Storable t) => Vector t -> Vector t |
231 | sortVector = sortV | 240 | sortVector = sortV |
232 | 241 | ||
233 | {- | | 242 | {- | |
@@ -248,7 +257,7 @@ sortVector = sortV | |||
248 | -2.20 0.11 -1.58 -0.01 0.19 -0.29 1.04 1.06 -2.09 -0.75 | 257 | -2.20 0.11 -1.58 -0.01 0.19 -0.29 1.04 1.06 -2.09 -0.75 |
249 | 258 | ||
250 | -} | 259 | -} |
251 | sortIndex :: (Ord t, Element t) => Vector t -> Vector I | 260 | sortIndex :: (Ord t, Storable t) => Vector t -> Vector I |
252 | sortIndex = sortI | 261 | sortIndex = sortI |
253 | 262 | ||
254 | ccompare :: (Ord t, Container c t) => c t -> c t -> c I | 263 | ccompare :: (Ord t, Container c t) => c t -> c t -> c I |
@@ -296,10 +305,91 @@ The indexes are autoconformable. | |||
296 | , 10, 16, 22 ] | 305 | , 10, 16, 22 ] |
297 | 306 | ||
298 | -} | 307 | -} |
299 | remap :: Element t => Matrix I -> Matrix I -> Matrix t -> Matrix t | 308 | remap :: Storable t => Matrix I -> Matrix I -> Matrix t -> Matrix t |
300 | remap i j m | 309 | remap i j m |
301 | | minElement i >= 0 && maxElement i < fromIntegral (rows m) && | 310 | | minElement i >= 0 && maxElement i < fromIntegral (rows m) && |
302 | minElement j >= 0 && maxElement j < fromIntegral (cols m) = remapM i' j' m | 311 | minElement j >= 0 && maxElement j < fromIntegral (cols m) = remapM i' j' m |
303 | | otherwise = error $ "out of range index in remap" | 312 | | otherwise = error $ "out of range index in remap" |
304 | where | 313 | where |
305 | [i',j'] = conformMs [i,j] | 314 | [i',j'] = conformMs [i,j] |
315 | |||
316 | sortI :: (Storable a, Ord a) => Vector a -> Vector Int32 | ||
317 | sortI = sortG sort_index | ||
318 | |||
319 | type C_Compare a = Ptr a -> Ptr a -> IO Int32 | ||
320 | |||
321 | foreign import ccall "wrapper" wrapCompare :: C_Compare a -> IO (FunPtr (C_Compare a)) | ||
322 | |||
323 | foreign import ccall "qsort" | ||
324 | c_qsort :: Ptr a -- ^ base | ||
325 | -> Word -- ^ nmemb | ||
326 | -> Word -- ^ size | ||
327 | -> FunPtr (C_Compare a) -- ^ compar | ||
328 | -> IO () | ||
329 | |||
330 | sizeOfElem :: forall a. Storable a => Ptr a -> Int | ||
331 | sizeOfElem _ = sizeOf (undefined :: a) | ||
332 | |||
333 | sort_index :: (Storable a, Ord a) => | ||
334 | Int32 -> Ptr a | ||
335 | -> Int32 -> Ptr Int32 | ||
336 | -> IO Int32 | ||
337 | sort_index vn vp rn rp = do | ||
338 | requires (vn == rn) BAD_SIZE $ do | ||
339 | comp <- wrapCompare $ \ap bp -> do | ||
340 | a <- peekElemOff vp . fromIntegral =<< peek (ap :: Ptr Int32) | ||
341 | b <- peekElemOff vp . fromIntegral =<< peek bp | ||
342 | return $ case compare a b of | ||
343 | LT -> -1 | ||
344 | GT -> 1 | ||
345 | EQ -> 0 | ||
346 | sequence_ [ pokeElemOff rp (fromIntegral i) i | i <- [0 .. rn-1] ] | ||
347 | c_qsort rp (fromIntegral rn) 4 comp | ||
348 | freeHaskellFunPtr comp | ||
349 | return 0 | ||
350 | |||
351 | sortV :: (Storable a, Ord a) => Vector a -> Vector a | ||
352 | sortV = sortG sortStorable | ||
353 | |||
354 | sortStorable :: (Storable a, Ord a) => | ||
355 | Int32 -> Ptr a | ||
356 | -> Int32 -> Ptr a | ||
357 | -> IO Int32 | ||
358 | sortStorable vn vp rn rp = do | ||
359 | requires (vn == rn) BAD_SIZE $ do | ||
360 | copyArray rp vp (fromIntegral vn * sizeOfElem vp) | ||
361 | comp <- wrapCompare $ \ap bp -> do | ||
362 | a <- peek ap | ||
363 | b <- peek bp | ||
364 | return $ case compare a b of | ||
365 | LT -> -1 | ||
366 | GT -> 1 | ||
367 | EQ -> 0 | ||
368 | c_qsort rp (fromIntegral rn) (fromIntegral $ sizeOfElem rp) comp | ||
369 | freeHaskellFunPtr comp | ||
370 | return 0 | ||
371 | |||
372 | remapM :: Storable a => Matrix Int32 -> Matrix Int32 -> Matrix a -> Matrix a | ||
373 | remapM = remapG remapStorable | ||
374 | |||
375 | remapStorable :: Storable a => | ||
376 | Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int32 -- i | ||
377 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int32 -- j | ||
378 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -- m | ||
379 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -- r | ||
380 | -> IO Int32 | ||
381 | remapStorable ir ic iXr iXc ip | ||
382 | jr jc jXr jXc jp | ||
383 | mr mc mXr mXc mp | ||
384 | rr rc rXr rXc rp = do | ||
385 | requires (ir==jr && ic==jc && ir==rr && ic==rc) BAD_SIZE $ do | ||
386 | ($ 0) $ fix $ \aloop a -> when (a<rr) $ do | ||
387 | ($ 0) $ fix $ \bloop b -> when (b<rc) $ do | ||
388 | iab <- peekElemOff ip (fromIntegral $ iXr*a + iXc*b) | ||
389 | jab <- peekElemOff jp (fromIntegral $ jXr*a + jXc*b) | ||
390 | when (0 <= iab && iab < mr && 0 <= jab && jab < mc) $ | ||
391 | pokeElemOff rp (fromIntegral $ rXr*a + rXc*b) | ||
392 | =<< peekElemOff mp (fromIntegral $ mXr*iab + mXc*jab) | ||
393 | bloop (succ b) | ||
394 | aloop (succ a) | ||
395 | return 0 | ||
diff --git a/packages/base/src/Internal/Conversion.hs b/packages/base/src/Internal/Conversion.hs index 4541ec4..7eb8ec7 100644 --- a/packages/base/src/Internal/Conversion.hs +++ b/packages/base/src/Internal/Conversion.hs | |||
@@ -28,11 +28,12 @@ import Internal.Matrix | |||
28 | import Internal.Vectorized | 28 | import Internal.Vectorized |
29 | import Data.Complex | 29 | import Data.Complex |
30 | import Control.Arrow((***)) | 30 | import Control.Arrow((***)) |
31 | import Foreign.Storable | ||
31 | 32 | ||
32 | ------------------------------------------------------------------- | 33 | ------------------------------------------------------------------- |
33 | 34 | ||
34 | -- | Supported single-double precision type pairs | 35 | -- | Supported single-double precision type pairs |
35 | class (Element s, Element d) => Precision s d | s -> d, d -> s where | 36 | class (Storable s, Storable d) => Precision s d | s -> d, d -> s where |
36 | double2FloatG :: Vector d -> Vector s | 37 | double2FloatG :: Vector d -> Vector s |
37 | float2DoubleG :: Vector s -> Vector d | 38 | float2DoubleG :: Vector s -> Vector d |
38 | 39 | ||
@@ -50,7 +51,7 @@ instance Precision I Z where | |||
50 | 51 | ||
51 | 52 | ||
52 | -- | Supported real types | 53 | -- | Supported real types |
53 | class (Element t, Element (Complex t), RealFloat t) | 54 | class (Storable t, Storable (Complex t), RealFloat t) |
54 | => RealElement t | 55 | => RealElement t |
55 | 56 | ||
56 | instance RealElement Double | 57 | instance RealElement Double |
@@ -69,7 +70,7 @@ class Complexable c where | |||
69 | instance Complexable Vector where | 70 | instance Complexable Vector where |
70 | toComplex' = toComplexV | 71 | toComplex' = toComplexV |
71 | fromComplex' = fromComplexV | 72 | fromComplex' = fromComplexV |
72 | comp' v = toComplex' (v,constantD 0 (dim v)) | 73 | comp' v = toComplex' (v,constantAux 0 (dim v)) |
73 | single' = double2FloatG | 74 | single' = double2FloatG |
74 | double' = float2DoubleG | 75 | double' = float2DoubleG |
75 | 76 | ||
diff --git a/packages/base/src/Internal/Convolution.hs b/packages/base/src/Internal/Convolution.hs index 75fbef4..ae8ebc6 100644 --- a/packages/base/src/Internal/Convolution.hs +++ b/packages/base/src/Internal/Convolution.hs | |||
@@ -24,12 +24,13 @@ import Internal.Numeric | |||
24 | import Internal.Element | 24 | import Internal.Element |
25 | import Internal.Conversion | 25 | import Internal.Conversion |
26 | import Internal.Container | 26 | import Internal.Container |
27 | import Foreign.Storable | ||
27 | #if MIN_VERSION_base(4,11,0) | 28 | #if MIN_VERSION_base(4,11,0) |
28 | import Prelude hiding ((<>)) | 29 | import Prelude hiding ((<>)) |
29 | #endif | 30 | #endif |
30 | 31 | ||
31 | 32 | ||
32 | vectSS :: Element t => Int -> Vector t -> Matrix t | 33 | vectSS :: Storable t => Int -> Vector t -> Matrix t |
33 | vectSS n v = fromRows [ subVector k n v | k <- [0 .. dim v - n] ] | 34 | vectSS n v = fromRows [ subVector k n v | k <- [0 .. dim v - n] ] |
34 | 35 | ||
35 | 36 | ||
@@ -82,7 +83,7 @@ corrMin ker v | |||
82 | 83 | ||
83 | 84 | ||
84 | 85 | ||
85 | matSS :: Element t => Int -> Matrix t -> [Matrix t] | 86 | matSS :: Storable t => Int -> Matrix t -> [Matrix t] |
86 | matSS dr m = map (reshape c) [ subVector (k*c) n v | k <- [0 .. r - dr] ] | 87 | matSS dr m = map (reshape c) [ subVector (k*c) n v | k <- [0 .. r - dr] ] |
87 | where | 88 | where |
88 | v = flatten m | 89 | v = flatten m |
@@ -155,7 +156,7 @@ conv2 k m | |||
155 | empty = r == 0 || c == 0 | 156 | empty = r == 0 || c == 0 |
156 | 157 | ||
157 | 158 | ||
158 | separable :: Element t => (Vector t -> Vector t) -> Matrix t -> Matrix t | 159 | separable :: Storable t => (Vector t -> Vector t) -> Matrix t -> Matrix t |
159 | -- ^ matrix computation implemented as separated vector operations by rows and columns. | 160 | -- ^ matrix computation implemented as separated vector operations by rows and columns. |
160 | separable f = fromColumns . map f . toColumns . fromRows . map f . toRows | 161 | separable f = fromColumns . map f . toColumns . fromRows . map f . toRows |
161 | 162 | ||
diff --git a/packages/base/src/Internal/Devel.hs b/packages/base/src/Internal/Devel.hs index f72d8aa..b0594d4 100644 --- a/packages/base/src/Internal/Devel.hs +++ b/packages/base/src/Internal/Devel.hs | |||
@@ -13,6 +13,7 @@ module Internal.Devel where | |||
13 | 13 | ||
14 | 14 | ||
15 | import Control.Monad ( when ) | 15 | import Control.Monad ( when ) |
16 | import Data.Int | ||
16 | import Foreign.C.Types ( CInt ) | 17 | import Foreign.C.Types ( CInt ) |
17 | --import Foreign.Storable.Complex () | 18 | --import Foreign.Storable.Complex () |
18 | import Foreign.Ptr(Ptr) | 19 | import Foreign.Ptr(Ptr) |
@@ -28,7 +29,7 @@ infixl 0 // | |||
28 | 29 | ||
29 | -- GSL error codes are <= 1024 | 30 | -- GSL error codes are <= 1024 |
30 | -- | error codes for the auxiliary functions required by the wrappers | 31 | -- | error codes for the auxiliary functions required by the wrappers |
31 | errorCode :: CInt -> String | 32 | errorCode :: Int32 -> String |
32 | errorCode 2000 = "bad size" | 33 | errorCode 2000 = "bad size" |
33 | errorCode 2001 = "bad function code" | 34 | errorCode 2001 = "bad function code" |
34 | errorCode 2002 = "memory problem" | 35 | errorCode 2002 = "memory problem" |
@@ -44,7 +45,7 @@ errorCode n = "code "++show n | |||
44 | foreign import ccall unsafe "asm_finit" finit :: IO () | 45 | foreign import ccall unsafe "asm_finit" finit :: IO () |
45 | 46 | ||
46 | -- | check the error code | 47 | -- | check the error code |
47 | check :: String -> IO CInt -> IO () | 48 | check :: String -> IO Int32 -> IO () |
48 | check msg f = do | 49 | check msg f = do |
49 | -- finit | 50 | -- finit |
50 | err <- f | 51 | err <- f |
@@ -54,7 +55,7 @@ check msg f = do | |||
54 | 55 | ||
55 | -- | postfix error code check | 56 | -- | postfix error code check |
56 | infixl 0 #| | 57 | infixl 0 #| |
57 | (#|) :: IO CInt -> String -> IO () | 58 | (#|) :: IO Int32 -> String -> IO () |
58 | (#|) = flip check | 59 | (#|) = flip check |
59 | 60 | ||
60 | -- | Error capture and conversion to Maybe | 61 | -- | Error capture and conversion to Maybe |
@@ -65,12 +66,12 @@ mbCatch act = E.catch (Just `fmap` act) f | |||
65 | 66 | ||
66 | -------------------------------------------------------------------------------- | 67 | -------------------------------------------------------------------------------- |
67 | 68 | ||
68 | type CM b r = CInt -> CInt -> Ptr b -> r | 69 | type CM b r = Int32 -> Int32 -> Ptr b -> r |
69 | type CV b r = CInt -> Ptr b -> r | 70 | type CV b r = Int32 -> Ptr b -> r |
70 | type OM b r = CInt -> CInt -> CInt -> CInt -> Ptr b -> r | 71 | type OM b r = Int32 -> Int32 -> Int32 -> Int32 -> Ptr b -> r |
71 | 72 | ||
72 | type CIdxs r = CV CInt r | 73 | type CIdxs r = CV Int32 r |
73 | type Ok = IO CInt | 74 | type Ok = IO Int32 |
74 | 75 | ||
75 | infixr 5 :>, ::>, ..> | 76 | infixr 5 :>, ::>, ..> |
76 | type (:>) t r = CV t r | 77 | type (:>) t r = CV t r |
@@ -87,8 +88,8 @@ class TransArray c | |||
87 | 88 | ||
88 | instance Storable t => TransArray (Vector t) | 89 | instance Storable t => TransArray (Vector t) |
89 | where | 90 | where |
90 | type Trans (Vector t) b = CInt -> Ptr t -> b | 91 | type Trans (Vector t) b = Int32 -> Ptr t -> b |
91 | type TransRaw (Vector t) b = CInt -> Ptr t -> b | 92 | type TransRaw (Vector t) b = Int32 -> Ptr t -> b |
92 | apply = avec | 93 | apply = avec |
93 | {-# INLINE apply #-} | 94 | {-# INLINE apply #-} |
94 | applyRaw = avec | 95 | applyRaw = avec |
diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs index 2e330ee..80eda8d 100644 --- a/packages/base/src/Internal/Element.hs +++ b/packages/base/src/Internal/Element.hs | |||
@@ -33,14 +33,14 @@ import Data.List.Split(chunksOf) | |||
33 | import Foreign.Storable(Storable) | 33 | import Foreign.Storable(Storable) |
34 | import System.IO.Unsafe(unsafePerformIO) | 34 | import System.IO.Unsafe(unsafePerformIO) |
35 | import Control.Monad(liftM) | 35 | import Control.Monad(liftM) |
36 | import Foreign.C.Types(CInt) | 36 | import Data.Int |
37 | 37 | ||
38 | ------------------------------------------------------------------- | 38 | ------------------------------------------------------------------- |
39 | 39 | ||
40 | 40 | ||
41 | import Data.Binary | 41 | import Data.Binary |
42 | 42 | ||
43 | instance (Binary (Vector a), Element a) => Binary (Matrix a) where | 43 | instance (Binary (Vector a), Storable a) => Binary (Matrix a) where |
44 | put m = do | 44 | put m = do |
45 | put (cols m) | 45 | put (cols m) |
46 | put (flatten m) | 46 | put (flatten m) |
@@ -52,7 +52,7 @@ instance (Binary (Vector a), Element a) => Binary (Matrix a) where | |||
52 | 52 | ||
53 | ------------------------------------------------------------------- | 53 | ------------------------------------------------------------------- |
54 | 54 | ||
55 | instance (Show a, Element a) => (Show (Matrix a)) where | 55 | instance (Show a, Storable a) => (Show (Matrix a)) where |
56 | show m | rows m == 0 || cols m == 0 = sizes m ++" []" | 56 | show m | rows m == 0 || cols m == 0 = sizes m ++" []" |
57 | show m = (sizes m++) . dsp . map (map show) . toLists $ m | 57 | show m = (sizes m++) . dsp . map (map show) . toLists $ m |
58 | 58 | ||
@@ -70,7 +70,7 @@ dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unw | |||
70 | 70 | ||
71 | ------------------------------------------------------------------ | 71 | ------------------------------------------------------------------ |
72 | 72 | ||
73 | instance (Element a, Read a) => Read (Matrix a) where | 73 | instance (Storable a, Read a) => Read (Matrix a) where |
74 | readsPrec _ s = [((rs><cs) . read $ listnums, rest)] | 74 | readsPrec _ s = [((rs><cs) . read $ listnums, rest)] |
75 | where (thing,rest) = breakAt ']' s | 75 | where (thing,rest) = breakAt ']' s |
76 | (dims,listnums) = breakAt ')' thing | 76 | (dims,listnums) = breakAt ')' thing |
@@ -133,13 +133,13 @@ ppext (DropLast n) = printf "DropLast %d" n | |||
133 | 133 | ||
134 | -} | 134 | -} |
135 | infixl 9 ?? | 135 | infixl 9 ?? |
136 | (??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t | 136 | (??) :: Storable t => Matrix t -> (Extractor,Extractor) -> Matrix t |
137 | 137 | ||
138 | minEl :: Vector CInt -> CInt | 138 | minEl :: Vector Int32 -> Int32 |
139 | minEl = toScalarI Min | 139 | minEl = toScalarI Min |
140 | maxEl :: Vector CInt -> CInt | 140 | maxEl :: Vector Int32 -> Int32 |
141 | maxEl = toScalarI Max | 141 | maxEl = toScalarI Max |
142 | cmodi :: Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt | 142 | cmodi :: Int32 -> Vector Int32 -> Vector Int32 |
143 | cmodi = vectorMapValI ModVS | 143 | cmodi = vectorMapValI ModVS |
144 | 144 | ||
145 | extractError :: Matrix t1 -> (Extractor, Extractor) -> t | 145 | extractError :: Matrix t1 -> (Extractor, Extractor) -> t |
@@ -181,7 +181,7 @@ m ?? (e, TakeLast n) = m ?? (e, Drop (cols m - n)) | |||
181 | m ?? (DropLast n, e) = m ?? (Take (rows m - n), e) | 181 | m ?? (DropLast n, e) = m ?? (Take (rows m - n), e) |
182 | m ?? (e, DropLast n) = m ?? (e, Take (cols m - n)) | 182 | m ?? (e, DropLast n) = m ?? (e, Take (cols m - n)) |
183 | 183 | ||
184 | m ?? (er,ec) = unsafePerformIO $ extractR (orderOf m) m moder rs modec cs | 184 | m ?? (er,ec) = unsafePerformIO $ extractAux (orderOf m) m moder rs modec cs |
185 | where | 185 | where |
186 | (moder,rs) = mkExt (rows m) er | 186 | (moder,rs) = mkExt (rows m) er |
187 | (modec,cs) = mkExt (cols m) ec | 187 | (modec,cs) = mkExt (cols m) ec |
@@ -209,14 +209,14 @@ common f = commonval . map f | |||
209 | 209 | ||
210 | 210 | ||
211 | -- | creates a matrix from a vertical list of matrices | 211 | -- | creates a matrix from a vertical list of matrices |
212 | joinVert :: Element t => [Matrix t] -> Matrix t | 212 | joinVert :: Storable t => [Matrix t] -> Matrix t |
213 | joinVert [] = emptyM 0 0 | 213 | joinVert [] = emptyM 0 0 |
214 | joinVert ms = case common cols ms of | 214 | joinVert ms = case common cols ms of |
215 | Nothing -> error "(impossible) joinVert on matrices with different number of columns" | 215 | Nothing -> error "(impossible) joinVert on matrices with different number of columns" |
216 | Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms) | 216 | Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms) |
217 | 217 | ||
218 | -- | creates a matrix from a horizontal list of matrices | 218 | -- | creates a matrix from a horizontal list of matrices |
219 | joinHoriz :: Element t => [Matrix t] -> Matrix t | 219 | joinHoriz :: Storable t => [Matrix t] -> Matrix t |
220 | joinHoriz ms = trans. joinVert . map trans $ ms | 220 | joinHoriz ms = trans. joinVert . map trans $ ms |
221 | 221 | ||
222 | {- | Create a matrix from blocks given as a list of lists of matrices. | 222 | {- | Create a matrix from blocks given as a list of lists of matrices. |
@@ -240,13 +240,13 @@ disp = putStr . dispf 2 | |||
240 | 3 3 3 3 3 0 0 3 0 0 | 240 | 3 3 3 3 3 0 0 3 0 0 |
241 | 241 | ||
242 | -} | 242 | -} |
243 | fromBlocks :: Element t => [[Matrix t]] -> Matrix t | 243 | fromBlocks :: Storable t => [[Matrix t]] -> Matrix t |
244 | fromBlocks = fromBlocksRaw . adaptBlocks | 244 | fromBlocks = fromBlocksRaw . adaptBlocks |
245 | 245 | ||
246 | fromBlocksRaw :: Element t => [[Matrix t]] -> Matrix t | 246 | fromBlocksRaw :: Storable t => [[Matrix t]] -> Matrix t |
247 | fromBlocksRaw mms = joinVert . map joinHoriz $ mms | 247 | fromBlocksRaw mms = joinVert . map joinHoriz $ mms |
248 | 248 | ||
249 | adaptBlocks :: Element t => [[Matrix t]] -> [[Matrix t]] | 249 | adaptBlocks :: Storable t => [[Matrix t]] -> [[Matrix t]] |
250 | adaptBlocks ms = ms' where | 250 | adaptBlocks ms = ms' where |
251 | bc = case common length ms of | 251 | bc = case common length ms of |
252 | Just c -> c | 252 | Just c -> c |
@@ -258,7 +258,7 @@ adaptBlocks ms = ms' where | |||
258 | 258 | ||
259 | g [Just nr,Just nc] m | 259 | g [Just nr,Just nc] m |
260 | | nr == r && nc == c = m | 260 | | nr == r && nc == c = m |
261 | | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantD x (nr*nc)) | 261 | | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantAux x (nr*nc)) |
262 | | r == 1 = fromRows (replicate nr (flatten m)) | 262 | | r == 1 = fromRows (replicate nr (flatten m)) |
263 | | otherwise = fromColumns (replicate nc (flatten m)) | 263 | | otherwise = fromColumns (replicate nc (flatten m)) |
264 | where | 264 | where |
@@ -288,7 +288,7 @@ adaptBlocks ms = ms' where | |||
288 | , 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 ] | 288 | , 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 ] |
289 | 289 | ||
290 | -} | 290 | -} |
291 | diagBlock :: (Element t, Num t) => [Matrix t] -> Matrix t | 291 | diagBlock :: (Storable t, Num t) => [Matrix t] -> Matrix t |
292 | diagBlock ms = fromBlocks $ zipWith f ms [0..] | 292 | diagBlock ms = fromBlocks $ zipWith f ms [0..] |
293 | where | 293 | where |
294 | f m k = take n $ replicate k z ++ m : repeat z | 294 | f m k = take n $ replicate k z ++ m : repeat z |
@@ -299,13 +299,13 @@ diagBlock ms = fromBlocks $ zipWith f ms [0..] | |||
299 | 299 | ||
300 | 300 | ||
301 | -- | Reverse rows | 301 | -- | Reverse rows |
302 | flipud :: Element t => Matrix t -> Matrix t | 302 | flipud :: Storable t => Matrix t -> Matrix t |
303 | flipud m = extractRows [r-1,r-2 .. 0] $ m | 303 | flipud m = extractRows [r-1,r-2 .. 0] $ m |
304 | where | 304 | where |
305 | r = rows m | 305 | r = rows m |
306 | 306 | ||
307 | -- | Reverse columns | 307 | -- | Reverse columns |
308 | fliprl :: Element t => Matrix t -> Matrix t | 308 | fliprl :: Storable t => Matrix t -> Matrix t |
309 | fliprl m = extractColumns [c-1,c-2 .. 0] $ m | 309 | fliprl m = extractColumns [c-1,c-2 .. 0] $ m |
310 | where | 310 | where |
311 | c = cols m | 311 | c = cols m |
@@ -330,7 +330,7 @@ diagRect z v r c = ST.runSTMatrix $ do | |||
330 | return m | 330 | return m |
331 | 331 | ||
332 | -- | extracts the diagonal from a rectangular matrix | 332 | -- | extracts the diagonal from a rectangular matrix |
333 | takeDiag :: (Element t) => Matrix t -> Vector t | 333 | takeDiag :: (Storable t) => Matrix t -> Vector t |
334 | takeDiag m = fromList [flatten m @> (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] | 334 | takeDiag m = fromList [flatten m @> (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] |
335 | 335 | ||
336 | ------------------------------------------------------------ | 336 | ------------------------------------------------------------ |
@@ -363,32 +363,32 @@ r >< c = f where | |||
363 | 363 | ||
364 | ---------------------------------------------------------------- | 364 | ---------------------------------------------------------------- |
365 | 365 | ||
366 | takeRows :: Element t => Int -> Matrix t -> Matrix t | 366 | takeRows :: Storable t => Int -> Matrix t -> Matrix t |
367 | takeRows n mt = subMatrix (0,0) (n, cols mt) mt | 367 | takeRows n mt = subMatrix (0,0) (n, cols mt) mt |
368 | 368 | ||
369 | -- | Creates a matrix with the last n rows of another matrix | 369 | -- | Creates a matrix with the last n rows of another matrix |
370 | takeLastRows :: Element t => Int -> Matrix t -> Matrix t | 370 | takeLastRows :: Storable t => Int -> Matrix t -> Matrix t |
371 | takeLastRows n mt = subMatrix (rows mt - n, 0) (n, cols mt) mt | 371 | takeLastRows n mt = subMatrix (rows mt - n, 0) (n, cols mt) mt |
372 | 372 | ||
373 | dropRows :: Element t => Int -> Matrix t -> Matrix t | 373 | dropRows :: Storable t => Int -> Matrix t -> Matrix t |
374 | dropRows n mt = subMatrix (n,0) (rows mt - n, cols mt) mt | 374 | dropRows n mt = subMatrix (n,0) (rows mt - n, cols mt) mt |
375 | 375 | ||
376 | -- | Creates a copy of a matrix without the last n rows | 376 | -- | Creates a copy of a matrix without the last n rows |
377 | dropLastRows :: Element t => Int -> Matrix t -> Matrix t | 377 | dropLastRows :: Storable t => Int -> Matrix t -> Matrix t |
378 | dropLastRows n mt = subMatrix (0,0) (rows mt - n, cols mt) mt | 378 | dropLastRows n mt = subMatrix (0,0) (rows mt - n, cols mt) mt |
379 | 379 | ||
380 | takeColumns :: Element t => Int -> Matrix t -> Matrix t | 380 | takeColumns :: Storable t => Int -> Matrix t -> Matrix t |
381 | takeColumns n mt = subMatrix (0,0) (rows mt, n) mt | 381 | takeColumns n mt = subMatrix (0,0) (rows mt, n) mt |
382 | 382 | ||
383 | -- |Creates a matrix with the last n columns of another matrix | 383 | -- |Creates a matrix with the last n columns of another matrix |
384 | takeLastColumns :: Element t => Int -> Matrix t -> Matrix t | 384 | takeLastColumns :: Storable t => Int -> Matrix t -> Matrix t |
385 | takeLastColumns n mt = subMatrix (0, cols mt - n) (rows mt, n) mt | 385 | takeLastColumns n mt = subMatrix (0, cols mt - n) (rows mt, n) mt |
386 | 386 | ||
387 | dropColumns :: Element t => Int -> Matrix t -> Matrix t | 387 | dropColumns :: Storable t => Int -> Matrix t -> Matrix t |
388 | dropColumns n mt = subMatrix (0,n) (rows mt, cols mt - n) mt | 388 | dropColumns n mt = subMatrix (0,n) (rows mt, cols mt - n) mt |
389 | 389 | ||
390 | -- | Creates a copy of a matrix without the last n columns | 390 | -- | Creates a copy of a matrix without the last n columns |
391 | dropLastColumns :: Element t => Int -> Matrix t -> Matrix t | 391 | dropLastColumns :: Storable t => Int -> Matrix t -> Matrix t |
392 | dropLastColumns n mt = subMatrix (0,0) (rows mt, cols mt - n) mt | 392 | dropLastColumns n mt = subMatrix (0,0) (rows mt, cols mt - n) mt |
393 | 393 | ||
394 | ---------------------------------------------------------------- | 394 | ---------------------------------------------------------------- |
@@ -402,7 +402,7 @@ dropLastColumns n mt = subMatrix (0,0) (rows mt, cols mt - n) mt | |||
402 | , 5.0, 6.0 ] | 402 | , 5.0, 6.0 ] |
403 | 403 | ||
404 | -} | 404 | -} |
405 | fromLists :: Element t => [[t]] -> Matrix t | 405 | fromLists :: Storable t => [[t]] -> Matrix t |
406 | fromLists = fromRows . map fromList | 406 | fromLists = fromRows . map fromList |
407 | 407 | ||
408 | -- | creates a 1-row matrix from a vector | 408 | -- | creates a 1-row matrix from a vector |
@@ -443,7 +443,7 @@ Hilbert matrix of order N: | |||
443 | @hilb n = buildMatrix n n (\\(i,j)->1/(fromIntegral i + fromIntegral j +1))@ | 443 | @hilb n = buildMatrix n n (\\(i,j)->1/(fromIntegral i + fromIntegral j +1))@ |
444 | 444 | ||
445 | -} | 445 | -} |
446 | buildMatrix :: Element a => Int -> Int -> ((Int, Int) -> a) -> Matrix a | 446 | buildMatrix :: Storable a => Int -> Int -> ((Int, Int) -> a) -> Matrix a |
447 | buildMatrix rc cc f = | 447 | buildMatrix rc cc f = |
448 | fromLists $ map (map f) | 448 | fromLists $ map (map f) |
449 | $ map (\ ri -> map (\ ci -> (ri, ci)) [0 .. (cc - 1)]) [0 .. (rc - 1)] | 449 | $ map (\ ri -> map (\ ci -> (ri, ci)) [0 .. (cc - 1)]) [0 .. (rc - 1)] |
@@ -458,11 +458,11 @@ fromArray2D m = (r><c) (elems m) | |||
458 | 458 | ||
459 | 459 | ||
460 | -- | rearranges the rows of a matrix according to the order given in a list of integers. | 460 | -- | rearranges the rows of a matrix according to the order given in a list of integers. |
461 | extractRows :: Element t => [Int] -> Matrix t -> Matrix t | 461 | extractRows :: Storable t => [Int] -> Matrix t -> Matrix t |
462 | extractRows l m = m ?? (Pos (idxs l), All) | 462 | extractRows l m = m ?? (Pos (idxs l), All) |
463 | 463 | ||
464 | -- | rearranges the rows of a matrix according to the order given in a list of integers. | 464 | -- | rearranges the rows of a matrix according to the order given in a list of integers. |
465 | extractColumns :: Element t => [Int] -> Matrix t -> Matrix t | 465 | extractColumns :: Storable t => [Int] -> Matrix t -> Matrix t |
466 | extractColumns l m = m ?? (All, Pos (idxs l)) | 466 | extractColumns l m = m ?? (All, Pos (idxs l)) |
467 | 467 | ||
468 | 468 | ||
@@ -476,13 +476,13 @@ extractColumns l m = m ?? (All, Pos (idxs l)) | |||
476 | , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0 ] | 476 | , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0 ] |
477 | 477 | ||
478 | -} | 478 | -} |
479 | repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t | 479 | repmat :: (Storable t) => Matrix t -> Int -> Int -> Matrix t |
480 | repmat m r c | 480 | repmat m r c |
481 | | r == 0 || c == 0 = emptyM (r*rows m) (c*cols m) | 481 | | r == 0 || c == 0 = emptyM (r*rows m) (c*cols m) |
482 | | otherwise = fromBlocks $ replicate r $ replicate c $ m | 482 | | otherwise = fromBlocks $ replicate r $ replicate c $ m |
483 | 483 | ||
484 | -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. | 484 | -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. |
485 | liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 485 | liftMatrix2Auto :: (Storable t, Storable a, Storable b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
486 | liftMatrix2Auto f m1 m2 | 486 | liftMatrix2Auto f m1 m2 |
487 | | compat' m1 m2 = lM f m1 m2 | 487 | | compat' m1 m2 = lM f m1 m2 |
488 | | ok = lM f m1' m2' | 488 | | ok = lM f m1' m2' |
@@ -499,7 +499,7 @@ liftMatrix2Auto f m1 m2 | |||
499 | m2' = conformMTo (r,c) m2 | 499 | m2' = conformMTo (r,c) m2 |
500 | 500 | ||
501 | -- FIXME do not flatten if equal order | 501 | -- FIXME do not flatten if equal order |
502 | lM :: (Storable t, Element t1, Element t2) | 502 | lM :: (Storable t, Storable t1, Storable t2) |
503 | => (Vector t1 -> Vector t2 -> Vector t) | 503 | => (Vector t1 -> Vector t2 -> Vector t) |
504 | -> Matrix t1 -> Matrix t2 -> Matrix t | 504 | -> Matrix t1 -> Matrix t2 -> Matrix t |
505 | lM f m1 m2 = matrixFromVector | 505 | lM f m1 m2 = matrixFromVector |
@@ -520,7 +520,7 @@ compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 | |||
520 | 520 | ||
521 | ------------------------------------------------------------ | 521 | ------------------------------------------------------------ |
522 | 522 | ||
523 | toBlockRows :: Element t => [Int] -> Matrix t -> [Matrix t] | 523 | toBlockRows :: Storable t => [Int] -> Matrix t -> [Matrix t] |
524 | toBlockRows [r] m | 524 | toBlockRows [r] m |
525 | | r == rows m = [m] | 525 | | r == rows m = [m] |
526 | toBlockRows rs m | 526 | toBlockRows rs m |
@@ -530,13 +530,13 @@ toBlockRows rs m | |||
530 | szs = map (* cols m) rs | 530 | szs = map (* cols m) rs |
531 | g k = (k><0)[] | 531 | g k = (k><0)[] |
532 | 532 | ||
533 | toBlockCols :: Element t => [Int] -> Matrix t -> [Matrix t] | 533 | toBlockCols :: Storable t => [Int] -> Matrix t -> [Matrix t] |
534 | toBlockCols [c] m | c == cols m = [m] | 534 | toBlockCols [c] m | c == cols m = [m] |
535 | toBlockCols cs m = map trans . toBlockRows cs . trans $ m | 535 | toBlockCols cs m = map trans . toBlockRows cs . trans $ m |
536 | 536 | ||
537 | -- | Partition a matrix into blocks with the given numbers of rows and columns. | 537 | -- | Partition a matrix into blocks with the given numbers of rows and columns. |
538 | -- The remaining rows and columns are discarded. | 538 | -- The remaining rows and columns are discarded. |
539 | toBlocks :: (Element t) => [Int] -> [Int] -> Matrix t -> [[Matrix t]] | 539 | toBlocks :: (Storable t) => [Int] -> [Int] -> Matrix t -> [[Matrix t]] |
540 | toBlocks rs cs m | 540 | toBlocks rs cs m |
541 | | ok = map (toBlockCols cs) . toBlockRows rs $ m | 541 | | ok = map (toBlockCols cs) . toBlockRows rs $ m |
542 | | otherwise = error $ "toBlocks: bad partition: "++show rs++" "++show cs | 542 | | otherwise = error $ "toBlocks: bad partition: "++show rs++" "++show cs |
@@ -546,7 +546,7 @@ toBlocks rs cs m | |||
546 | 546 | ||
547 | -- | Fully partition a matrix into blocks of the same size. If the dimensions are not | 547 | -- | Fully partition a matrix into blocks of the same size. If the dimensions are not |
548 | -- a multiple of the given size the last blocks will be smaller. | 548 | -- a multiple of the given size the last blocks will be smaller. |
549 | toBlocksEvery :: (Element t) => Int -> Int -> Matrix t -> [[Matrix t]] | 549 | toBlocksEvery :: (Storable t) => Int -> Int -> Matrix t -> [[Matrix t]] |
550 | toBlocksEvery r c m | 550 | toBlocksEvery r c m |
551 | | r < 1 || c < 1 = error $ "toBlocksEvery expects block sizes > 0, given "++show r++" and "++ show c | 551 | | r < 1 || c < 1 = error $ "toBlocksEvery expects block sizes > 0, given "++show r++" and "++ show c |
552 | | otherwise = toBlocks rs cs m | 552 | | otherwise = toBlocks rs cs m |
@@ -576,7 +576,7 @@ m[1,2] = 6 | |||
576 | 576 | ||
577 | -} | 577 | -} |
578 | mapMatrixWithIndexM_ | 578 | mapMatrixWithIndexM_ |
579 | :: (Element a, Num a, Monad m) => | 579 | :: (Storable a, Num a, Monad m) => |
580 | ((Int, Int) -> a -> m ()) -> Matrix a -> m () | 580 | ((Int, Int) -> a -> m ()) -> Matrix a -> m () |
581 | mapMatrixWithIndexM_ g m = mapVectorWithIndexM_ (mk c g) . flatten $ m | 581 | mapMatrixWithIndexM_ g m = mapVectorWithIndexM_ (mk c g) . flatten $ m |
582 | where | 582 | where |
@@ -592,7 +592,7 @@ Just (3><3) | |||
592 | 592 | ||
593 | -} | 593 | -} |
594 | mapMatrixWithIndexM | 594 | mapMatrixWithIndexM |
595 | :: (Element a, Storable b, Monad m) => | 595 | :: (Storable a, Storable b, Monad m) => |
596 | ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) | 596 | ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) |
597 | mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m | 597 | mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m |
598 | where | 598 | where |
@@ -608,11 +608,11 @@ mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . fla | |||
608 | 608 | ||
609 | -} | 609 | -} |
610 | mapMatrixWithIndex | 610 | mapMatrixWithIndex |
611 | :: (Element a, Storable b) => | 611 | :: (Storable a, Storable b) => |
612 | ((Int, Int) -> a -> b) -> Matrix a -> Matrix b | 612 | ((Int, Int) -> a -> b) -> Matrix a -> Matrix b |
613 | mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m | 613 | mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m |
614 | where | 614 | where |
615 | c = cols m | 615 | c = cols m |
616 | 616 | ||
617 | mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b | 617 | mapMatrix :: (Storable a, Storable b) => (a -> b) -> Matrix a -> Matrix b |
618 | mapMatrix f = liftMatrix (mapVector f) | 618 | mapMatrix f = liftMatrix (mapVector f) |
diff --git a/packages/base/src/Internal/Extract.hs b/packages/base/src/Internal/Extract.hs new file mode 100644 index 0000000..84ee20f --- /dev/null +++ b/packages/base/src/Internal/Extract.hs | |||
@@ -0,0 +1,145 @@ | |||
1 | {-# LANGUAGE BangPatterns #-} | ||
2 | {-# LANGUAGE NondecreasingIndentation #-} | ||
3 | {-# LANGUAGE PatternSynonyms #-} | ||
4 | {-# LANGUAGE UnboxedTuples #-} | ||
5 | module Internal.Extract where | ||
6 | import Control.Monad | ||
7 | import Data.Complex | ||
8 | import Data.Function | ||
9 | import Data.Int | ||
10 | import Foreign.Ptr | ||
11 | import Foreign.Storable | ||
12 | |||
13 | type ConstPtr a = Ptr a | ||
14 | pattern ConstPtr a = a | ||
15 | |||
16 | extractStorable :: Storable t => | ||
17 | Int32 -- int modei | ||
18 | -> Int32 -- int modej | ||
19 | -> Int32 -- / KIVEC(i) | ||
20 | -> ConstPtr Int32 -- \ | ||
21 | -> Int32 -- / KIVEC(j) | ||
22 | -> ConstPtr Int32 -- \ | ||
23 | -> Int32 -- / | ||
24 | -> Int32 -- / | ||
25 | -> Int32 -- { KO##T##MAT(m) | ||
26 | -> Int32 -- \ | ||
27 | -> ConstPtr t -- \ | ||
28 | -> Int32 -- / | ||
29 | -> Int32 -- / | ||
30 | -> Int32 -- { O##T##MAT(r) | ||
31 | -> Int32 -- \ | ||
32 | -> Ptr t -- \ | ||
33 | -> IO Int32 | ||
34 | extractStorable modei | ||
35 | modej | ||
36 | in_ (ConstPtr ip) | ||
37 | jn (ConstPtr jp) | ||
38 | mr mc mXr mXc (ConstPtr mp) | ||
39 | rr rc rXr rXc rp = do | ||
40 | -- int i,j,si,sj,ni,nj; | ||
41 | ni <- if modei/=0 then return in_ | ||
42 | else fmap succ $ (-) <$> peekElemOff ip 1 <*> peekElemOff ip 0 | ||
43 | nj <- if modej/=0 then return jn | ||
44 | else fmap succ $ (-) <$> peekElemOff jp 1 <*> peekElemOff jp 0 | ||
45 | ($ 0) $ fix $ \iloop i -> when (i<ni) $ do | ||
46 | si <- if modei/=0 then peekElemOff ip (fromIntegral i) | ||
47 | else (+ i) <$> peek ip | ||
48 | ($ 0) $ fix $ \jloop j -> when (j<nj) $ do | ||
49 | sj <- if modej/=0 then peekElemOff jp (fromIntegral j) | ||
50 | else (+ j) <$> peek jp | ||
51 | pokeElemOff rp (fromIntegral $ i*rXr + j*rXc) | ||
52 | =<< peekElemOff mp (fromIntegral $ si*mXr + sj*mXc) | ||
53 | jloop $! succ j | ||
54 | iloop $! succ i | ||
55 | return 0 | ||
56 | |||
57 | {-# SPECIALIZE extractStorable :: | ||
58 | Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 | ||
59 | -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Double | ||
60 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Double | ||
61 | -> IO Int32 #-} | ||
62 | |||
63 | {-# SPECIALIZE extractStorable :: | ||
64 | Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 | ||
65 | -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Float | ||
66 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Float | ||
67 | -> IO Int32 #-} | ||
68 | |||
69 | {-# SPECIALIZE extractStorable :: | ||
70 | Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 | ||
71 | -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr (Complex Double) | ||
72 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr (Complex Double) | ||
73 | -> IO Int32 #-} | ||
74 | |||
75 | {-# SPECIALIZE extractStorable :: | ||
76 | Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 | ||
77 | -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr (Complex Float) | ||
78 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr (Complex Float) | ||
79 | -> IO Int32 #-} | ||
80 | |||
81 | {-# SPECIALIZE extractStorable :: | ||
82 | Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 | ||
83 | -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Int32 | ||
84 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int32 | ||
85 | -> IO Int32 #-} | ||
86 | |||
87 | {-# SPECIALIZE extractStorable :: | ||
88 | Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 | ||
89 | -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Int64 | ||
90 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int64 | ||
91 | -> IO Int32 #-} | ||
92 | |||
93 | {- | ||
94 | type Reorder x = CV Int32 (CV Int32 (CV Int32 (CV x (CV x (IO Int32))))) | ||
95 | |||
96 | foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double | ||
97 | foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float | ||
98 | foreign import ccall unsafe "reorderI" c_reorderI :: Reorder Int32 | ||
99 | foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) | ||
100 | foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) | ||
101 | foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z | ||
102 | -} | ||
103 | |||
104 | -- #define ERROR(CODE) MACRO(return CODE;) | ||
105 | -- #define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);}) | ||
106 | |||
107 | requires :: Monad m => Bool -> Int32 -> m Int32 -> m Int32 | ||
108 | requires cond code go = | ||
109 | if cond then go | ||
110 | else return code | ||
111 | |||
112 | pattern BAD_SIZE = 2000 | ||
113 | |||
114 | reorderStorable :: Storable a => | ||
115 | Int32 -> Ptr Int32 -- k | ||
116 | -> Int32 -> ConstPtr Int32 -- strides | ||
117 | -> Int32 -> ConstPtr Int32 -- dims | ||
118 | -> Int32 -> ConstPtr a -- v | ||
119 | -> Int32 -> Ptr a -- r | ||
120 | -> IO Int32 | ||
121 | reorderStorable kn kp stridesn stridesp dimsn dimsp vn vp rn rp = do | ||
122 | requires (kn == stridesn && stridesn == dimsn) BAD_SIZE $ do | ||
123 | let ijlloop !i !j l fin = do | ||
124 | pokeElemOff kp (fromIntegral l) 0 | ||
125 | dimspl <- peekElemOff dimsp (fromIntegral l) | ||
126 | stridespl <- peekElemOff stridesp (fromIntegral l) | ||
127 | if (l<kn) then ijlloop (i * dimspl) (j + stridespl*(dimspl - 1)) (l + 1) fin | ||
128 | else fin i j | ||
129 | ijlloop 1 0 0 $ \i j -> do | ||
130 | requires (i <= vn && j < rn) BAD_SIZE $ do | ||
131 | (\go -> go 0 0) $ fix $ \ijloop i j -> do | ||
132 | pokeElemOff rp (fromIntegral i) =<< peekElemOff vp (fromIntegral j) | ||
133 | (\go -> go (kn - 1) j) $ fix $ \lloop l !j -> do | ||
134 | kpl <- succ <$> peekElemOff kp (fromIntegral l) | ||
135 | pokeElemOff kp (fromIntegral l) kpl | ||
136 | dimspl <- peekElemOff dimsp (fromIntegral l) | ||
137 | if (kpl < dimspl) | ||
138 | then do | ||
139 | stridespl <- peekElemOff stridesp (fromIntegral l) | ||
140 | ijloop (succ i) (j + stridespl) | ||
141 | else do | ||
142 | if l == 0 then return 0 else do | ||
143 | pokeElemOff kp (fromIntegral l) 0 | ||
144 | stridespl <- peekElemOff stridesp (fromIntegral l) | ||
145 | lloop (pred l) (j - stridespl*(dimspl-1)) | ||
diff --git a/packages/base/src/Internal/IO.hs b/packages/base/src/Internal/IO.hs index b0f5606..de5eea5 100644 --- a/packages/base/src/Internal/IO.hs +++ b/packages/base/src/Internal/IO.hs | |||
@@ -23,6 +23,7 @@ import Internal.Vectorized | |||
23 | import Text.Printf(printf, PrintfArg, PrintfType) | 23 | import Text.Printf(printf, PrintfArg, PrintfType) |
24 | import Data.List(intersperse,transpose) | 24 | import Data.List(intersperse,transpose) |
25 | import Data.Complex | 25 | import Data.Complex |
26 | import Foreign.Storable | ||
26 | 27 | ||
27 | 28 | ||
28 | -- | Formatting tool | 29 | -- | Formatting tool |
@@ -45,7 +46,7 @@ this function the user can easily define any desired display function: | |||
45 | @disp = putStr . format \" \" (printf \"%.2f\")@ | 46 | @disp = putStr . format \" \" (printf \"%.2f\")@ |
46 | 47 | ||
47 | -} | 48 | -} |
48 | format :: (Element t) => String -> (t -> String) -> Matrix t -> String | 49 | format :: (Storable t) => String -> (t -> String) -> Matrix t -> String |
49 | format sep f m = table sep . map (map f) . toLists $ m | 50 | format sep f m = table sep . map (map f) . toLists $ m |
50 | 51 | ||
51 | {- | Show a matrix with \"autoscaling\" and a given number of decimal places. | 52 | {- | Show a matrix with \"autoscaling\" and a given number of decimal places. |
@@ -81,14 +82,14 @@ dispf d x = sdims x ++ "\n" ++ formatFixed (if isInt x then 0 else d) x | |||
81 | sdims :: Matrix t -> [Char] | 82 | sdims :: Matrix t -> [Char] |
82 | sdims x = show (rows x) ++ "x" ++ show (cols x) | 83 | sdims x = show (rows x) ++ "x" ++ show (cols x) |
83 | 84 | ||
84 | formatFixed :: (Show a, Text.Printf.PrintfArg t, Element t) | 85 | formatFixed :: (Show a, Text.Printf.PrintfArg t, Storable t) |
85 | => a -> Matrix t -> String | 86 | => a -> Matrix t -> String |
86 | formatFixed d x = format " " (printf ("%."++show d++"f")) $ x | 87 | formatFixed d x = format " " (printf ("%."++show d++"f")) $ x |
87 | 88 | ||
88 | isInt :: Matrix Double -> Bool | 89 | isInt :: Matrix Double -> Bool |
89 | isInt = all lookslikeInt . toList . flatten | 90 | isInt = all lookslikeInt . toList . flatten |
90 | 91 | ||
91 | formatScaled :: (Text.Printf.PrintfArg b, RealFrac b, Floating b, Num t, Element b, Show t) | 92 | formatScaled :: (Text.Printf.PrintfArg b, RealFrac b, Floating b, Num t, Storable b, Show t) |
92 | => t -> Matrix b -> [Char] | 93 | => t -> Matrix b -> [Char] |
93 | formatScaled dec t = "E"++show o++"\n" ++ ss | 94 | formatScaled dec t = "E"++show o++"\n" ++ ss |
94 | where ss = format " " (printf fmt. g) t | 95 | where ss = format " " (printf fmt. g) t |
@@ -104,7 +105,7 @@ formatScaled dec t = "E"++show o++"\n" ++ ss | |||
104 | 10 |> 0.00 0.11 0.22 0.33 0.44 0.56 0.67 0.78 0.89 1.00 | 105 | 10 |> 0.00 0.11 0.22 0.33 0.44 0.56 0.67 0.78 0.89 1.00 |
105 | 106 | ||
106 | -} | 107 | -} |
107 | vecdisp :: (Element t) => (Matrix t -> String) -> Vector t -> String | 108 | vecdisp :: (Storable t) => (Matrix t -> String) -> Vector t -> String |
108 | vecdisp f v | 109 | vecdisp f v |
109 | = ((show (dim v) ++ " |> ") ++) . (++"\n") | 110 | = ((show (dim v) ++ " |> ") ++) . (++"\n") |
110 | . unwords . lines . tail . dropWhile (not . (`elem` " \n")) | 111 | . unwords . lines . tail . dropWhile (not . (`elem` " \n")) |
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 27d1f95..d88ff6b 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -22,9 +22,12 @@ import Data.Bifunctor (first) | |||
22 | 22 | ||
23 | import Internal.Devel | 23 | import Internal.Devel |
24 | import Internal.Vector | 24 | import Internal.Vector |
25 | import Internal.Vectorized (constantAux) | ||
25 | import Internal.Matrix hiding ((#), (#!)) | 26 | import Internal.Matrix hiding ((#), (#!)) |
26 | import Internal.Conversion | 27 | import Internal.Conversion |
27 | import Internal.Element | 28 | import Internal.Element |
29 | import Internal.ST (setRect) | ||
30 | import Data.Int | ||
28 | import Foreign.Ptr(nullPtr) | 31 | import Foreign.Ptr(nullPtr) |
29 | import Foreign.C.Types | 32 | import Foreign.C.Types |
30 | import Control.Monad(when) | 33 | import Control.Monad(when) |
@@ -46,10 +49,10 @@ type TMMM t = t ::> t ::> t ::> Ok | |||
46 | type F = Float | 49 | type F = Float |
47 | type Q = Complex Float | 50 | type Q = Complex Float |
48 | 51 | ||
49 | foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM R | 52 | foreign import ccall unsafe "multiplyR" dgemmc :: Int32 -> Int32 -> TMMM R |
50 | foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C | 53 | foreign import ccall unsafe "multiplyC" zgemmc :: Int32 -> Int32 -> TMMM C |
51 | foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F | 54 | foreign import ccall unsafe "multiplyF" sgemmc :: Int32 -> Int32 -> TMMM F |
52 | foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q | 55 | foreign import ccall unsafe "multiplyQ" cgemmc :: Int32 -> Int32 -> TMMM Q |
53 | foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> TMMM I | 56 | foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> TMMM I |
54 | foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> TMMM Z | 57 | foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> TMMM Z |
55 | 58 | ||
@@ -82,7 +85,7 @@ multiplyF a b = multiplyAux sgemmc "sgemmc" a b | |||
82 | multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float) | 85 | multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float) |
83 | multiplyQ a b = multiplyAux cgemmc "cgemmc" a b | 86 | multiplyQ a b = multiplyAux cgemmc "cgemmc" a b |
84 | 87 | ||
85 | multiplyI :: I -> Matrix CInt -> Matrix CInt -> Matrix CInt | 88 | multiplyI :: I -> Matrix Int32 -> Matrix Int32 -> Matrix Int32 |
86 | multiplyI m a b = unsafePerformIO $ do | 89 | multiplyI m a b = unsafePerformIO $ do |
87 | when (cols a /= rows b) $ error $ | 90 | when (cols a /= rows b) $ error $ |
88 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b | 91 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b |
@@ -239,8 +242,8 @@ foreign import ccall unsafe "eig_l_R" dgeev :: R ::> R ::> C :> R ::> Ok | |||
239 | foreign import ccall unsafe "eig_l_G" dggev :: R ::> R ::> C :> R :> R ::> R ::> Ok | 242 | foreign import ccall unsafe "eig_l_G" dggev :: R ::> R ::> C :> R :> R ::> R ::> Ok |
240 | foreign import ccall unsafe "eig_l_C" zgeev :: C ::> C ::> C :> C ::> Ok | 243 | foreign import ccall unsafe "eig_l_C" zgeev :: C ::> C ::> C :> C ::> Ok |
241 | foreign import ccall unsafe "eig_l_GC" zggev :: C ::> C ::> C :> C :> C ::> C ::> Ok | 244 | foreign import ccall unsafe "eig_l_GC" zggev :: C ::> C ::> C :> C :> C ::> C ::> Ok |
242 | foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> R :> R ::> Ok | 245 | foreign import ccall unsafe "eig_l_S" dsyev :: Int32 -> R :> R ::> Ok |
243 | foreign import ccall unsafe "eig_l_H" zheev :: CInt -> R :> C ::> Ok | 246 | foreign import ccall unsafe "eig_l_H" zheev :: Int32 -> R :> C ::> Ok |
244 | 247 | ||
245 | eigAux f st m = unsafePerformIO $ do | 248 | eigAux f st m = unsafePerformIO $ do |
246 | a <- copy ColumnMajor m | 249 | a <- copy ColumnMajor m |
@@ -636,7 +639,7 @@ qrgrAux f st n (a, tau) = unsafePerformIO $ do | |||
636 | ((subVector 0 n tau') #! res) f #| st | 639 | ((subVector 0 n tau') #! res) f #| st |
637 | return res | 640 | return res |
638 | where | 641 | where |
639 | tau' = vjoin [tau, constantD 0 n] | 642 | tau' = vjoin [tau, constantAux 0 n] |
640 | 643 | ||
641 | ----------------------------------------------------------------------------------- | 644 | ----------------------------------------------------------------------------------- |
642 | foreign import ccall unsafe "hess_l_R" dgehrd :: R :> R ::> Ok | 645 | foreign import ccall unsafe "hess_l_R" dgehrd :: R :> R ::> Ok |
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 5436e59..04092f9 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -2,6 +2,7 @@ | |||
2 | {-# LANGUAGE FlexibleContexts #-} | 2 | {-# LANGUAGE FlexibleContexts #-} |
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | {-# LANGUAGE BangPatterns #-} | 4 | {-# LANGUAGE BangPatterns #-} |
5 | {-# LANGUAGE CPP #-} | ||
5 | {-# LANGUAGE TypeOperators #-} | 6 | {-# LANGUAGE TypeOperators #-} |
6 | {-# LANGUAGE TypeFamilies #-} | 7 | {-# LANGUAGE TypeFamilies #-} |
7 | {-# LANGUAGE ViewPatterns #-} | 8 | {-# LANGUAGE ViewPatterns #-} |
@@ -22,12 +23,14 @@ module Internal.Matrix where | |||
22 | 23 | ||
23 | import Internal.Vector | 24 | import Internal.Vector |
24 | import Internal.Devel | 25 | import Internal.Devel |
26 | import Internal.Extract | ||
25 | import Internal.Vectorized hiding ((#), (#!)) | 27 | import Internal.Vectorized hiding ((#), (#!)) |
26 | import Foreign.Marshal.Alloc ( free ) | 28 | import Foreign.Marshal.Alloc ( free ) |
27 | import Foreign.Marshal.Array(newArray) | 29 | import Foreign.Marshal.Array(newArray) |
28 | import Foreign.Ptr ( Ptr ) | 30 | import Foreign.Ptr ( Ptr ) |
29 | import Foreign.Storable ( Storable ) | 31 | import Foreign.Storable ( Storable ) |
30 | import Data.Complex ( Complex ) | 32 | import Data.Complex ( Complex ) |
33 | import Data.Int | ||
31 | import Foreign.C.Types ( CInt(..) ) | 34 | import Foreign.C.Types ( CInt(..) ) |
32 | import Foreign.C.String ( CString, newCString ) | 35 | import Foreign.C.String ( CString, newCString ) |
33 | import System.IO.Unsafe ( unsafePerformIO ) | 36 | import System.IO.Unsafe ( unsafePerformIO ) |
@@ -61,19 +64,23 @@ size :: Matrix t -> (Int, Int) | |||
61 | size m = (irows m, icols m) | 64 | size m = (irows m, icols m) |
62 | {-# INLINE size #-} | 65 | {-# INLINE size #-} |
63 | 66 | ||
67 | -- | True if the matrix is in RowMajor form. | ||
64 | rowOrder :: Matrix t -> Bool | 68 | rowOrder :: Matrix t -> Bool |
65 | rowOrder m = xCol m == 1 || cols m == 1 | 69 | rowOrder m = xCol m == 1 || cols m == 1 |
66 | {-# INLINE rowOrder #-} | 70 | {-# INLINE rowOrder #-} |
67 | 71 | ||
72 | -- | True if the matrix is in ColMajor form or if their is only one row. | ||
68 | colOrder :: Matrix t -> Bool | 73 | colOrder :: Matrix t -> Bool |
69 | colOrder m = xRow m == 1 || rows m == 1 | 74 | colOrder m = xRow m == 1 || rows m == 1 |
70 | {-# INLINE colOrder #-} | 75 | {-# INLINE colOrder #-} |
71 | 76 | ||
77 | -- | True if the matrix is a single row or column vector. | ||
72 | is1d :: Matrix t -> Bool | 78 | is1d :: Matrix t -> Bool |
73 | is1d (size->(r,c)) = r==1 || c==1 | 79 | is1d (size->(r,c)) = r==1 || c==1 |
74 | {-# INLINE is1d #-} | 80 | {-# INLINE is1d #-} |
75 | 81 | ||
76 | -- data is not contiguous | 82 | -- | True if the matrix is not contiguous. This usually |
83 | -- means it is a slice of some larger matrix. | ||
77 | isSlice :: Storable t => Matrix t -> Bool | 84 | isSlice :: Storable t => Matrix t -> Bool |
78 | isSlice m@(size->(r,c)) = r*c < dim (xdat m) | 85 | isSlice m@(size->(r,c)) = r*c < dim (xdat m) |
79 | {-# INLINE isSlice #-} | 86 | {-# INLINE isSlice #-} |
@@ -95,19 +102,23 @@ showInternal m = printf "%dx%d %s %s %d:%d (%d)\n" r c slc ord xr xc dv | |||
95 | 102 | ||
96 | -------------------------------------------------------------------------------- | 103 | -------------------------------------------------------------------------------- |
97 | 104 | ||
98 | -- | Matrix transpose. | 105 | -- | O(1) Matrix transpose. This is only a logical transposition that does not |
106 | -- re-order the element storage. If the storage order is important, use 'cmat' | ||
107 | -- or 'fmat'. | ||
99 | trans :: Matrix t -> Matrix t | 108 | trans :: Matrix t -> Matrix t |
100 | trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } = | 109 | trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } = |
101 | m { irows = c, icols = r, xRow = xc, xCol = xr } | 110 | m { irows = c, icols = r, xRow = xc, xCol = xr } |
102 | 111 | ||
103 | 112 | ||
104 | cmat :: (Element t) => Matrix t -> Matrix t | 113 | -- | Obtain the RowMajor equivalent of a given Matrix. |
114 | cmat :: (Storable t) => Matrix t -> Matrix t | ||
105 | cmat m | 115 | cmat m |
106 | | rowOrder m = m | 116 | | rowOrder m = m |
107 | | otherwise = extractAll RowMajor m | 117 | | otherwise = extractAll RowMajor m |
108 | 118 | ||
109 | 119 | ||
110 | fmat :: (Element t) => Matrix t -> Matrix t | 120 | -- | Obtain the ColumnMajor equivalent of a given Matrix. |
121 | fmat :: (Storable t) => Matrix t -> Matrix t | ||
111 | fmat m | 122 | fmat m |
112 | | colOrder m = m | 123 | | colOrder m = m |
113 | | otherwise = extractAll ColumnMajor m | 124 | | otherwise = extractAll ColumnMajor m |
@@ -115,14 +126,14 @@ fmat m | |||
115 | 126 | ||
116 | -- C-Haskell matrix adapters | 127 | -- C-Haskell matrix adapters |
117 | {-# INLINE amatr #-} | 128 | {-# INLINE amatr #-} |
118 | amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r | 129 | amatr :: Storable a => Matrix a -> (f -> IO r) -> (Int32 -> Int32 -> Ptr a -> f) -> IO r |
119 | amatr x f g = unsafeWith (xdat x) (f . g r c) | 130 | amatr x f g = unsafeWith (xdat x) (f . g r c) |
120 | where | 131 | where |
121 | r = fi (rows x) | 132 | r = fi (rows x) |
122 | c = fi (cols x) | 133 | c = fi (cols x) |
123 | 134 | ||
124 | {-# INLINE amat #-} | 135 | {-# INLINE amat #-} |
125 | amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r | 136 | amat :: Storable a => Matrix a -> (f -> IO r) -> (Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -> f) -> IO r |
126 | amat x f g = unsafeWith (xdat x) (f . g r c sr sc) | 137 | amat x f g = unsafeWith (xdat x) (f . g r c sr sc) |
127 | where | 138 | where |
128 | r = fi (rows x) | 139 | r = fi (rows x) |
@@ -133,8 +144,8 @@ amat x f g = unsafeWith (xdat x) (f . g r c sr sc) | |||
133 | 144 | ||
134 | instance Storable t => TransArray (Matrix t) | 145 | instance Storable t => TransArray (Matrix t) |
135 | where | 146 | where |
136 | type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b | 147 | type TransRaw (Matrix t) b = Int32 -> Int32 -> Ptr t -> b |
137 | type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b | 148 | type Trans (Matrix t) b = Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -> b |
138 | apply = amat | 149 | apply = amat |
139 | {-# INLINE apply #-} | 150 | {-# INLINE apply #-} |
140 | applyRaw = amatr | 151 | applyRaw = amatr |
@@ -151,10 +162,10 @@ a #! b = a # b # id | |||
151 | 162 | ||
152 | -------------------------------------------------------------------------------- | 163 | -------------------------------------------------------------------------------- |
153 | 164 | ||
154 | copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t) | 165 | copy :: Storable t => MatrixOrder -> Matrix t -> IO (Matrix t) |
155 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) | 166 | copy ord m = extractAux ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) |
156 | 167 | ||
157 | extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t | 168 | extractAll :: Storable t => MatrixOrder -> Matrix t -> Matrix t |
158 | extractAll ord m = unsafePerformIO (copy ord m) | 169 | extractAll ord m = unsafePerformIO (copy ord m) |
159 | 170 | ||
160 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. | 171 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. |
@@ -164,14 +175,14 @@ extractAll ord m = unsafePerformIO (copy ord m) | |||
164 | it :: (Num t, Element t) => Vector t | 175 | it :: (Num t, Element t) => Vector t |
165 | 176 | ||
166 | -} | 177 | -} |
167 | flatten :: Element t => Matrix t -> Vector t | 178 | flatten :: Storable t => Matrix t -> Vector t |
168 | flatten m | 179 | flatten m |
169 | | isSlice m || not (rowOrder m) = xdat (extractAll RowMajor m) | 180 | | isSlice m || not (rowOrder m) = xdat (extractAll RowMajor m) |
170 | | otherwise = xdat m | 181 | | otherwise = xdat m |
171 | 182 | ||
172 | 183 | ||
173 | -- | the inverse of 'Data.Packed.Matrix.fromLists' | 184 | -- | the inverse of 'Data.Packed.Matrix.fromLists' |
174 | toLists :: (Element t) => Matrix t -> [[t]] | 185 | toLists :: (Storable t) => Matrix t -> [[t]] |
175 | toLists = map toList . toRows | 186 | toLists = map toList . toRows |
176 | 187 | ||
177 | 188 | ||
@@ -192,7 +203,7 @@ compatdim (a:b:xs) | |||
192 | -- | Create a matrix from a list of vectors. | 203 | -- | Create a matrix from a list of vectors. |
193 | -- All vectors must have the same dimension, | 204 | -- All vectors must have the same dimension, |
194 | -- or dimension 1, which is are automatically expanded. | 205 | -- or dimension 1, which is are automatically expanded. |
195 | fromRows :: Element t => [Vector t] -> Matrix t | 206 | fromRows :: Storable t => [Vector t] -> Matrix t |
196 | fromRows [] = emptyM 0 0 | 207 | fromRows [] = emptyM 0 0 |
197 | fromRows vs = case compatdim (map dim vs) of | 208 | fromRows vs = case compatdim (map dim vs) of |
198 | Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs) | 209 | Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs) |
@@ -203,25 +214,25 @@ fromRows vs = case compatdim (map dim vs) of | |||
203 | adapt c v | 214 | adapt c v |
204 | | c == 0 = fromList[] | 215 | | c == 0 = fromList[] |
205 | | dim v == c = v | 216 | | dim v == c = v |
206 | | otherwise = constantD (v@>0) c | 217 | | otherwise = constantAux (v@>0) c |
207 | 218 | ||
208 | -- | extracts the rows of a matrix as a list of vectors | 219 | -- | extracts the rows of a matrix as a list of vectors |
209 | toRows :: Element t => Matrix t -> [Vector t] | 220 | toRows :: Storable t => Matrix t -> [Vector t] |
210 | toRows m | 221 | toRows m |
211 | | rowOrder m = map sub rowRange | 222 | | rowOrder m = map sub rowRange |
212 | | otherwise = map ext rowRange | 223 | | otherwise = map ext rowRange |
213 | where | 224 | where |
214 | rowRange = [0..rows m-1] | 225 | rowRange = [0..rows m-1] |
215 | sub k = subVector (k*xRow m) (cols m) (xdat m) | 226 | sub k = subVector (k*xRow m) (cols m) (xdat m) |
216 | ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) | 227 | ext k = xdat $ unsafePerformIO $ extractAux RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) |
217 | 228 | ||
218 | 229 | ||
219 | -- | Creates a matrix from a list of vectors, as columns | 230 | -- | Creates a matrix from a list of vectors, as columns |
220 | fromColumns :: Element t => [Vector t] -> Matrix t | 231 | fromColumns :: Storable t => [Vector t] -> Matrix t |
221 | fromColumns m = trans . fromRows $ m | 232 | fromColumns m = trans . fromRows $ m |
222 | 233 | ||
223 | -- | Creates a list of vectors from the columns of a matrix | 234 | -- | Creates a list of vectors from the columns of a matrix |
224 | toColumns :: Element t => Matrix t -> [Vector t] | 235 | toColumns :: Storable t => Matrix t -> [Vector t] |
225 | toColumns m = toRows . trans $ m | 236 | toColumns m = toRows . trans $ m |
226 | 237 | ||
227 | -- | Reads a matrix position. | 238 | -- | Reads a matrix position. |
@@ -271,13 +282,13 @@ reshape c v = matrixFromVector RowMajor (dim v `div` c) c v | |||
271 | 282 | ||
272 | 283 | ||
273 | -- | application of a vector function on the flattened matrix elements | 284 | -- | application of a vector function on the flattened matrix elements |
274 | liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | 285 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b |
275 | liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} | 286 | liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} |
276 | | isSlice m = matrixFromVector RowMajor r c (f (flatten m)) | 287 | | isSlice m = matrixFromVector RowMajor r c (f (flatten m)) |
277 | | otherwise = matrixFromVector (orderOf m) r c (f d) | 288 | | otherwise = matrixFromVector (orderOf m) r c (f d) |
278 | 289 | ||
279 | -- | application of a vector function on the flattened matrices elements | 290 | -- | application of a vector function on the flattened matrices elements |
280 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 291 | liftMatrix2 :: (Storable t, Storable a, Storable b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
281 | liftMatrix2 f m1@(size->(r,c)) m2 | 292 | liftMatrix2 f m1@(size->(r,c)) m2 |
282 | | (r,c)/=size m2 = error "nonconformant matrices in liftMatrix2" | 293 | | (r,c)/=size m2 = error "nonconformant matrices in liftMatrix2" |
283 | | rowOrder m1 = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2)) | 294 | | rowOrder m1 = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2)) |
@@ -285,103 +296,8 @@ liftMatrix2 f m1@(size->(r,c)) m2 | |||
285 | 296 | ||
286 | ------------------------------------------------------------------ | 297 | ------------------------------------------------------------------ |
287 | 298 | ||
288 | -- | Supported matrix elements. | ||
289 | class (Storable a) => Element a where | ||
290 | constantD :: a -> Int -> Vector a | ||
291 | extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) | ||
292 | setRect :: Int -> Int -> Matrix a -> Matrix a -> IO () | ||
293 | sortI :: Ord a => Vector a -> Vector CInt | ||
294 | sortV :: Ord a => Vector a -> Vector a | ||
295 | compareV :: Ord a => Vector a -> Vector a -> Vector CInt | ||
296 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a | ||
297 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | ||
298 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | ||
299 | gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () | ||
300 | reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation | ||
301 | |||
302 | |||
303 | instance Element Float where | ||
304 | constantD = constantAux cconstantF | ||
305 | extractR = extractAux c_extractF | ||
306 | setRect = setRectAux c_setRectF | ||
307 | sortI = sortIdxF | ||
308 | sortV = sortValF | ||
309 | compareV = compareF | ||
310 | selectV = selectF | ||
311 | remapM = remapF | ||
312 | rowOp = rowOpAux c_rowOpF | ||
313 | gemm = gemmg c_gemmF | ||
314 | reorderV = reorderAux c_reorderF | ||
315 | |||
316 | instance Element Double where | ||
317 | constantD = constantAux cconstantR | ||
318 | extractR = extractAux c_extractD | ||
319 | setRect = setRectAux c_setRectD | ||
320 | sortI = sortIdxD | ||
321 | sortV = sortValD | ||
322 | compareV = compareD | ||
323 | selectV = selectD | ||
324 | remapM = remapD | ||
325 | rowOp = rowOpAux c_rowOpD | ||
326 | gemm = gemmg c_gemmD | ||
327 | reorderV = reorderAux c_reorderD | ||
328 | |||
329 | instance Element (Complex Float) where | ||
330 | constantD = constantAux cconstantQ | ||
331 | extractR = extractAux c_extractQ | ||
332 | setRect = setRectAux c_setRectQ | ||
333 | sortI = undefined | ||
334 | sortV = undefined | ||
335 | compareV = undefined | ||
336 | selectV = selectQ | ||
337 | remapM = remapQ | ||
338 | rowOp = rowOpAux c_rowOpQ | ||
339 | gemm = gemmg c_gemmQ | ||
340 | reorderV = reorderAux c_reorderQ | ||
341 | |||
342 | instance Element (Complex Double) where | ||
343 | constantD = constantAux cconstantC | ||
344 | extractR = extractAux c_extractC | ||
345 | setRect = setRectAux c_setRectC | ||
346 | sortI = undefined | ||
347 | sortV = undefined | ||
348 | compareV = undefined | ||
349 | selectV = selectC | ||
350 | remapM = remapC | ||
351 | rowOp = rowOpAux c_rowOpC | ||
352 | gemm = gemmg c_gemmC | ||
353 | reorderV = reorderAux c_reorderC | ||
354 | |||
355 | instance Element (CInt) where | ||
356 | constantD = constantAux cconstantI | ||
357 | extractR = extractAux c_extractI | ||
358 | setRect = setRectAux c_setRectI | ||
359 | sortI = sortIdxI | ||
360 | sortV = sortValI | ||
361 | compareV = compareI | ||
362 | selectV = selectI | ||
363 | remapM = remapI | ||
364 | rowOp = rowOpAux c_rowOpI | ||
365 | gemm = gemmg c_gemmI | ||
366 | reorderV = reorderAux c_reorderI | ||
367 | |||
368 | instance Element Z where | ||
369 | constantD = constantAux cconstantL | ||
370 | extractR = extractAux c_extractL | ||
371 | setRect = setRectAux c_setRectL | ||
372 | sortI = sortIdxL | ||
373 | sortV = sortValL | ||
374 | compareV = compareL | ||
375 | selectV = selectL | ||
376 | remapM = remapL | ||
377 | rowOp = rowOpAux c_rowOpL | ||
378 | gemm = gemmg c_gemmL | ||
379 | reorderV = reorderAux c_reorderL | ||
380 | |||
381 | ------------------------------------------------------------------- | ||
382 | |||
383 | -- | reference to a rectangular slice of a matrix (no data copy) | 299 | -- | reference to a rectangular slice of a matrix (no data copy) |
384 | subMatrix :: Element a | 300 | subMatrix :: Storable a |
385 | => (Int,Int) -- ^ (r0,c0) starting position | 301 | => (Int,Int) -- ^ (r0,c0) starting position |
386 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | 302 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix |
387 | -> Matrix a -- ^ input matrix | 303 | -> Matrix a -- ^ input matrix |
@@ -402,34 +318,34 @@ subMatrix (r0,c0) (rt,ct) m | |||
402 | maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1 | 318 | maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1 |
403 | maxZ xs = if minimum xs == 0 then 0 else maximum xs | 319 | maxZ xs = if minimum xs == 0 then 0 else maximum xs |
404 | 320 | ||
405 | conformMs :: Element t => [Matrix t] -> [Matrix t] | 321 | conformMs :: Storable t => [Matrix t] -> [Matrix t] |
406 | conformMs ms = map (conformMTo (r,c)) ms | 322 | conformMs ms = map (conformMTo (r,c)) ms |
407 | where | 323 | where |
408 | r = maxZ (map rows ms) | 324 | r = maxZ (map rows ms) |
409 | c = maxZ (map cols ms) | 325 | c = maxZ (map cols ms) |
410 | 326 | ||
411 | conformVs :: Element t => [Vector t] -> [Vector t] | 327 | conformVs :: Storable t => [Vector t] -> [Vector t] |
412 | conformVs vs = map (conformVTo n) vs | 328 | conformVs vs = map (conformVTo n) vs |
413 | where | 329 | where |
414 | n = maxZ (map dim vs) | 330 | n = maxZ (map dim vs) |
415 | 331 | ||
416 | conformMTo :: Element t => (Int, Int) -> Matrix t -> Matrix t | 332 | conformMTo :: Storable t => (Int, Int) -> Matrix t -> Matrix t |
417 | conformMTo (r,c) m | 333 | conformMTo (r,c) m |
418 | | size m == (r,c) = m | 334 | | size m == (r,c) = m |
419 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) | 335 | | size m == (1,1) = matrixFromVector RowMajor r c (constantAux (m@@>(0,0)) (r*c)) |
420 | | size m == (r,1) = repCols c m | 336 | | size m == (r,1) = repCols c m |
421 | | size m == (1,c) = repRows r m | 337 | | size m == (1,c) = repRows r m |
422 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) | 338 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) |
423 | 339 | ||
424 | conformVTo :: Element t => Int -> Vector t -> Vector t | 340 | conformVTo :: Storable t => Int -> Vector t -> Vector t |
425 | conformVTo n v | 341 | conformVTo n v |
426 | | dim v == n = v | 342 | | dim v == n = v |
427 | | dim v == 1 = constantD (v@>0) n | 343 | | dim v == 1 = constantAux (v@>0) n |
428 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n | 344 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n |
429 | 345 | ||
430 | repRows :: Element t => Int -> Matrix t -> Matrix t | 346 | repRows :: Storable t => Int -> Matrix t -> Matrix t |
431 | repRows n x = fromRows (replicate n (flatten x)) | 347 | repRows n x = fromRows (replicate n (flatten x)) |
432 | repCols :: Element t => Int -> Matrix t -> Matrix t | 348 | repCols :: Storable t => Int -> Matrix t -> Matrix t |
433 | repCols n x = fromColumns (replicate n (flatten x)) | 349 | repCols n x = fromColumns (replicate n (flatten x)) |
434 | 350 | ||
435 | shSize :: Matrix t -> [Char] | 351 | shSize :: Matrix t -> [Char] |
@@ -453,32 +369,50 @@ instance (Storable t, NFData t) => NFData (Matrix t) | |||
453 | 369 | ||
454 | --------------------------------------------------------------- | 370 | --------------------------------------------------------------- |
455 | 371 | ||
372 | {- | ||
456 | extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, | 373 | extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, |
457 | Storable t, Num t3, Num t2, Integral t1, Integral t) | 374 | Storable t, Num t3, Num t2, Integral t1, Integral t) |
458 | => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t | 375 | => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) -- f |
459 | -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) | 376 | -> MatrixOrder -- ord |
460 | -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) | 377 | -> c -- m |
461 | extractAux f ord m moder vr modec vc = do | 378 | -> t3 -- moder |
379 | -> Vector t1 -- vr | ||
380 | -> t2 -- modec | ||
381 | -> Vector t -- vc | ||
382 | -> IO (Matrix a) | ||
383 | -} | ||
384 | |||
385 | extractAux :: Storable a => | ||
386 | MatrixOrder | ||
387 | -> Matrix a | ||
388 | -> Int32 | ||
389 | -> Vector Int32 | ||
390 | -> Int32 | ||
391 | -> Vector Int32 | ||
392 | -> IO (Matrix a) | ||
393 | extractAux ord m moder vr modec vc = do | ||
462 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | 394 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr |
463 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | 395 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc |
464 | r <- createMatrix ord nr nc | 396 | r <- createMatrix ord nr nc |
465 | (vr # vc # m #! r) (f moder modec) #|"extract" | 397 | (vr # vc # m #! r) (extractStorable moder modec) #|"extract" |
466 | 398 | ||
467 | return r | 399 | return r |
468 | 400 | ||
469 | type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) | 401 | {- |
402 | type Extr x = Int32 -> Int32 -> CIdxs (CIdxs (OM x (OM x (IO Int32)))) | ||
470 | 403 | ||
471 | foreign import ccall unsafe "extractD" c_extractD :: Extr Double | 404 | foreign import ccall unsafe "extractD" c_extractD :: Extr Double |
472 | foreign import ccall unsafe "extractF" c_extractF :: Extr Float | 405 | foreign import ccall unsafe "extractF" c_extractF :: Extr Float |
473 | foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double) | 406 | foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double) |
474 | foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float) | 407 | foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float) |
475 | foreign import ccall unsafe "extractI" c_extractI :: Extr CInt | 408 | foreign import ccall unsafe "extractI" c_extractI :: Extr Int32 |
476 | foreign import ccall unsafe "extractL" c_extractL :: Extr Z | 409 | foreign import ccall unsafe "extractL" c_extractL :: Extr Z |
410 | -} | ||
477 | 411 | ||
478 | --------------------------------------------------------------- | 412 | --------------------------------------------------------------- |
479 | 413 | ||
480 | setRectAux :: (TransArray c1, TransArray c) | 414 | setRectAux :: (TransArray c1, TransArray c) |
481 | => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) | 415 | => (Int32 -> Int32 -> Trans c1 (Trans c (IO Int32))) |
482 | -> Int -> Int -> c1 -> c -> IO () | 416 | -> Int -> Int -> c1 -> c -> IO () |
483 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" | 417 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" |
484 | 418 | ||
@@ -494,17 +428,17 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z | |||
494 | -------------------------------------------------------------------------------- | 428 | -------------------------------------------------------------------------------- |
495 | 429 | ||
496 | sortG :: (Storable t, Storable a) | 430 | sortG :: (Storable t, Storable a) |
497 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | 431 | => (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> Vector t -> Vector a |
498 | sortG f v = unsafePerformIO $ do | 432 | sortG f v = unsafePerformIO $ do |
499 | r <- createVector (dim v) | 433 | r <- createVector (dim v) |
500 | (v #! r) f #|"sortG" | 434 | (v #! r) f #|"sortG" |
501 | return r | 435 | return r |
502 | 436 | ||
503 | sortIdxD :: Vector Double -> Vector CInt | 437 | sortIdxD :: Vector Double -> Vector Int32 |
504 | sortIdxD = sortG c_sort_indexD | 438 | sortIdxD = sortG c_sort_indexD |
505 | sortIdxF :: Vector Float -> Vector CInt | 439 | sortIdxF :: Vector Float -> Vector Int32 |
506 | sortIdxF = sortG c_sort_indexF | 440 | sortIdxF = sortG c_sort_indexF |
507 | sortIdxI :: Vector CInt -> Vector CInt | 441 | sortIdxI :: Vector Int32 -> Vector Int32 |
508 | sortIdxI = sortG c_sort_indexI | 442 | sortIdxI = sortG c_sort_indexI |
509 | sortIdxL :: Vector Z -> Vector I | 443 | sortIdxL :: Vector Z -> Vector I |
510 | sortIdxL = sortG c_sort_indexL | 444 | sortIdxL = sortG c_sort_indexL |
@@ -513,81 +447,81 @@ sortValD :: Vector Double -> Vector Double | |||
513 | sortValD = sortG c_sort_valD | 447 | sortValD = sortG c_sort_valD |
514 | sortValF :: Vector Float -> Vector Float | 448 | sortValF :: Vector Float -> Vector Float |
515 | sortValF = sortG c_sort_valF | 449 | sortValF = sortG c_sort_valF |
516 | sortValI :: Vector CInt -> Vector CInt | 450 | sortValI :: Vector Int32 -> Vector Int32 |
517 | sortValI = sortG c_sort_valI | 451 | sortValI = sortG c_sort_valI |
518 | sortValL :: Vector Z -> Vector Z | 452 | sortValL :: Vector Z -> Vector Z |
519 | sortValL = sortG c_sort_valL | 453 | sortValL = sortG c_sort_valL |
520 | 454 | ||
521 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) | 455 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV Int32 (IO Int32)) |
522 | foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt)) | 456 | foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV Int32 (IO Int32)) |
523 | foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt)) | 457 | foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV Int32 (CV Int32 (IO Int32)) |
524 | foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok | 458 | foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok |
525 | 459 | ||
526 | foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt)) | 460 | foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO Int32)) |
527 | foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt)) | 461 | foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO Int32)) |
528 | foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt)) | 462 | foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV Int32 (CV Int32 (IO Int32)) |
529 | foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok | 463 | foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok |
530 | 464 | ||
531 | -------------------------------------------------------------------------------- | 465 | -------------------------------------------------------------------------------- |
532 | 466 | ||
533 | compareG :: (TransArray c, Storable t, Storable a) | 467 | compareG :: (TransArray c, Storable t, Storable a) |
534 | => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) | 468 | => Trans c (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) |
535 | -> c -> Vector t -> Vector a | 469 | -> c -> Vector t -> Vector a |
536 | compareG f u v = unsafePerformIO $ do | 470 | compareG f u v = unsafePerformIO $ do |
537 | r <- createVector (dim v) | 471 | r <- createVector (dim v) |
538 | (u # v #! r) f #|"compareG" | 472 | (u # v #! r) f #|"compareG" |
539 | return r | 473 | return r |
540 | 474 | ||
541 | compareD :: Vector Double -> Vector Double -> Vector CInt | 475 | compareD :: Vector Double -> Vector Double -> Vector Int32 |
542 | compareD = compareG c_compareD | 476 | compareD = compareG c_compareD |
543 | compareF :: Vector Float -> Vector Float -> Vector CInt | 477 | compareF :: Vector Float -> Vector Float -> Vector Int32 |
544 | compareF = compareG c_compareF | 478 | compareF = compareG c_compareF |
545 | compareI :: Vector CInt -> Vector CInt -> Vector CInt | 479 | compareI :: Vector Int32 -> Vector Int32 -> Vector Int32 |
546 | compareI = compareG c_compareI | 480 | compareI = compareG c_compareI |
547 | compareL :: Vector Z -> Vector Z -> Vector CInt | 481 | compareL :: Vector Z -> Vector Z -> Vector Int32 |
548 | compareL = compareG c_compareL | 482 | compareL = compareG c_compareL |
549 | 483 | ||
550 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) | 484 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV Int32 (IO Int32))) |
551 | foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt))) | 485 | foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV Int32 (IO Int32))) |
552 | foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt))) | 486 | foreign import ccall unsafe "compareI" c_compareI :: CV Int32 (CV Int32 (CV Int32 (IO Int32))) |
553 | foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok | 487 | foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok |
554 | 488 | ||
555 | -------------------------------------------------------------------------------- | 489 | -------------------------------------------------------------------------------- |
556 | 490 | ||
557 | selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) | 491 | selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) |
558 | => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) | 492 | => Trans c2 (Trans c1 (Int32 -> Ptr t -> Trans c (Int32 -> Ptr a -> IO Int32))) |
559 | -> c2 -> c1 -> Vector t -> c -> Vector a | 493 | -> c2 -> c1 -> Vector t -> c -> Vector a |
560 | selectG f c u v w = unsafePerformIO $ do | 494 | selectG f c u v w = unsafePerformIO $ do |
561 | r <- createVector (dim v) | 495 | r <- createVector (dim v) |
562 | (c # u # v # w #! r) f #|"selectG" | 496 | (c # u # v # w #! r) f #|"selectG" |
563 | return r | 497 | return r |
564 | 498 | ||
565 | selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double | 499 | selectD :: Vector Int32 -> Vector Double -> Vector Double -> Vector Double -> Vector Double |
566 | selectD = selectG c_selectD | 500 | selectD = selectG c_selectD |
567 | selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float | 501 | selectF :: Vector Int32 -> Vector Float -> Vector Float -> Vector Float -> Vector Float |
568 | selectF = selectG c_selectF | 502 | selectF = selectG c_selectF |
569 | selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt | 503 | selectI :: Vector Int32 -> Vector Int32 -> Vector Int32 -> Vector Int32 -> Vector Int32 |
570 | selectI = selectG c_selectI | 504 | selectI = selectG c_selectI |
571 | selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z | 505 | selectL :: Vector Int32 -> Vector Z -> Vector Z -> Vector Z -> Vector Z |
572 | selectL = selectG c_selectL | 506 | selectL = selectG c_selectL |
573 | selectC :: Vector CInt | 507 | selectC :: Vector Int32 |
574 | -> Vector (Complex Double) | 508 | -> Vector (Complex Double) |
575 | -> Vector (Complex Double) | 509 | -> Vector (Complex Double) |
576 | -> Vector (Complex Double) | 510 | -> Vector (Complex Double) |
577 | -> Vector (Complex Double) | 511 | -> Vector (Complex Double) |
578 | selectC = selectG c_selectC | 512 | selectC = selectG c_selectC |
579 | selectQ :: Vector CInt | 513 | selectQ :: Vector Int32 |
580 | -> Vector (Complex Float) | 514 | -> Vector (Complex Float) |
581 | -> Vector (Complex Float) | 515 | -> Vector (Complex Float) |
582 | -> Vector (Complex Float) | 516 | -> Vector (Complex Float) |
583 | -> Vector (Complex Float) | 517 | -> Vector (Complex Float) |
584 | selectQ = selectG c_selectQ | 518 | selectQ = selectG c_selectQ |
585 | 519 | ||
586 | type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) | 520 | type Sel x = CV Int32 (CV x (CV x (CV x (CV x (IO Int32))))) |
587 | 521 | ||
588 | foreign import ccall unsafe "chooseD" c_selectD :: Sel Double | 522 | foreign import ccall unsafe "chooseD" c_selectD :: Sel Double |
589 | foreign import ccall unsafe "chooseF" c_selectF :: Sel Float | 523 | foreign import ccall unsafe "chooseF" c_selectF :: Sel Float |
590 | foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt | 524 | foreign import ccall unsafe "chooseI" c_selectI :: Sel Int32 |
591 | foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) | 525 | foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) |
592 | foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) | 526 | foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) |
593 | foreign import ccall unsafe "chooseL" c_selectL :: Sel Z | 527 | foreign import ccall unsafe "chooseL" c_selectL :: Sel Z |
@@ -595,35 +529,35 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z | |||
595 | --------------------------------------------------------------------------- | 529 | --------------------------------------------------------------------------- |
596 | 530 | ||
597 | remapG :: (TransArray c, TransArray c1, Storable t, Storable a) | 531 | remapG :: (TransArray c, TransArray c1, Storable t, Storable a) |
598 | => (CInt -> CInt -> CInt -> CInt -> Ptr t | 532 | => (Int32 -> Int32 -> Int32 -> Int32 -> Ptr t |
599 | -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) | 533 | -> Trans c1 (Trans c (Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -> IO Int32))) |
600 | -> Matrix t -> c1 -> c -> Matrix a | 534 | -> Matrix t -> c1 -> c -> Matrix a |
601 | remapG f i j m = unsafePerformIO $ do | 535 | remapG f i j m = unsafePerformIO $ do |
602 | r <- createMatrix RowMajor (rows i) (cols i) | 536 | r <- createMatrix RowMajor (rows i) (cols i) |
603 | (i # j # m #! r) f #|"remapG" | 537 | (i # j # m #! r) f #|"remapG" |
604 | return r | 538 | return r |
605 | 539 | ||
606 | remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double | 540 | remapD :: Matrix Int32 -> Matrix Int32 -> Matrix Double -> Matrix Double |
607 | remapD = remapG c_remapD | 541 | remapD = remapG c_remapD |
608 | remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float | 542 | remapF :: Matrix Int32 -> Matrix Int32 -> Matrix Float -> Matrix Float |
609 | remapF = remapG c_remapF | 543 | remapF = remapG c_remapF |
610 | remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt | 544 | remapI :: Matrix Int32 -> Matrix Int32 -> Matrix Int32 -> Matrix Int32 |
611 | remapI = remapG c_remapI | 545 | remapI = remapG c_remapI |
612 | remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z | 546 | remapL :: Matrix Int32 -> Matrix Int32 -> Matrix Z -> Matrix Z |
613 | remapL = remapG c_remapL | 547 | remapL = remapG c_remapL |
614 | remapC :: Matrix CInt | 548 | remapC :: Matrix Int32 |
615 | -> Matrix CInt | 549 | -> Matrix Int32 |
616 | -> Matrix (Complex Double) | 550 | -> Matrix (Complex Double) |
617 | -> Matrix (Complex Double) | 551 | -> Matrix (Complex Double) |
618 | remapC = remapG c_remapC | 552 | remapC = remapG c_remapC |
619 | remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) | 553 | remapQ :: Matrix Int32 -> Matrix Int32 -> Matrix (Complex Float) -> Matrix (Complex Float) |
620 | remapQ = remapG c_remapQ | 554 | remapQ = remapG c_remapQ |
621 | 555 | ||
622 | type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) | 556 | type Rem x = OM Int32 (OM Int32 (OM x (OM x (IO Int32)))) |
623 | 557 | ||
624 | foreign import ccall unsafe "remapD" c_remapD :: Rem Double | 558 | foreign import ccall unsafe "remapD" c_remapD :: Rem Double |
625 | foreign import ccall unsafe "remapF" c_remapF :: Rem Float | 559 | foreign import ccall unsafe "remapF" c_remapF :: Rem Float |
626 | foreign import ccall unsafe "remapI" c_remapI :: Rem CInt | 560 | foreign import ccall unsafe "remapI" c_remapI :: Rem Int32 |
627 | foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) | 561 | foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) |
628 | foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) | 562 | foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) |
629 | foreign import ccall unsafe "remapL" c_remapL :: Rem Z | 563 | foreign import ccall unsafe "remapL" c_remapL :: Rem Z |
@@ -631,14 +565,14 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z | |||
631 | -------------------------------------------------------------------------------- | 565 | -------------------------------------------------------------------------------- |
632 | 566 | ||
633 | rowOpAux :: (TransArray c, Storable a) => | 567 | rowOpAux :: (TransArray c, Storable a) => |
634 | (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) | 568 | (Int32 -> Ptr a -> Int32 -> Int32 -> Int32 -> Int32 -> Trans c (IO Int32)) |
635 | -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () | 569 | -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () |
636 | rowOpAux f c x i1 i2 j1 j2 m = do | 570 | rowOpAux f c x i1 i2 j1 j2 m = do |
637 | px <- newArray [x] | 571 | px <- newArray [x] |
638 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" | 572 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" |
639 | free px | 573 | free px |
640 | 574 | ||
641 | type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok | 575 | type RowOp x = Int32 -> Ptr x -> Int32 -> Int32 -> Int32 -> Int32 -> x ::> Ok |
642 | 576 | ||
643 | foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R | 577 | foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R |
644 | foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float | 578 | foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float |
@@ -652,7 +586,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | |||
652 | -------------------------------------------------------------------------------- | 586 | -------------------------------------------------------------------------------- |
653 | 587 | ||
654 | gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) | 588 | gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) |
655 | => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt)))) | 589 | => Trans c3 (Trans c2 (Trans c1 (Trans c (IO Int32)))) |
656 | -> c3 -> c2 -> c1 -> c -> IO () | 590 | -> c3 -> c2 -> c1 -> c -> IO () |
657 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" | 591 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" |
658 | 592 | ||
@@ -669,21 +603,26 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z | |||
669 | 603 | ||
670 | -------------------------------------------------------------------------------- | 604 | -------------------------------------------------------------------------------- |
671 | 605 | ||
606 | {- | ||
672 | reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => | 607 | reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => |
673 | (CInt -> Ptr a -> CInt -> Ptr t1 | 608 | (Int32 -> Ptr a -> Int32 -> Ptr t1 |
674 | -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)) | 609 | -> Trans c (Int32 -> Ptr t -> Int32 -> Ptr a1 -> IO Int32)) |
675 | -> Vector t1 -> c -> Vector t -> Vector a1 | 610 | -> Vector t1 -> c -> Vector t -> Vector a1 |
611 | -} | ||
612 | reorderAux :: (TransArray c, Storable a, | ||
613 | Trans c (Int32 -> Ptr a -> Int32 -> Ptr a -> IO Int32) ~ (Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr a -> Int32 -> Ptr a -> IO Int32)) => | ||
614 | p -> Vector Int32 -> c -> Vector a -> Vector a | ||
676 | reorderAux f s d v = unsafePerformIO $ do | 615 | reorderAux f s d v = unsafePerformIO $ do |
677 | k <- createVector (dim s) | 616 | k <- createVector (dim s) |
678 | r <- createVector (dim v) | 617 | r <- createVector (dim v) |
679 | (k # s # d # v #! r) f #| "reorderV" | 618 | (k # s # d # v #! r) reorderStorable #| "reorderV" |
680 | return r | 619 | return r |
681 | 620 | ||
682 | type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt))))) | 621 | type Reorder x = CV Int32 (CV Int32 (CV Int32 (CV x (CV x (IO Int32))))) |
683 | 622 | ||
684 | foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double | 623 | foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double |
685 | foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float | 624 | foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float |
686 | foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt | 625 | foreign import ccall unsafe "reorderI" c_reorderI :: Reorder Int32 |
687 | foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) | 626 | foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) |
688 | foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) | 627 | foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) |
689 | foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z | 628 | foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z |
@@ -691,12 +630,12 @@ foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z | |||
691 | -- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, | 630 | -- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, |
692 | -- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ | 631 | -- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ |
693 | -- This function is intended to be used internally by tensor libraries. | 632 | -- This function is intended to be used internally by tensor libraries. |
694 | reorderVector :: Element a | 633 | reorderVector :: Storable a |
695 | => Vector CInt -- ^ @strides@: array strides | 634 | => Vector Int32 -- ^ @strides@: array strides |
696 | -> Vector CInt -- ^ @dims@: array dimensions of new array @v@ | 635 | -> Vector Int32 -- ^ @dims@: array dimensions of new array @v@ |
697 | -> Vector a -- ^ @v@: flattened input array | 636 | -> Vector a -- ^ @v@: flattened input array |
698 | -> Vector a -- ^ @v'@: flattened output array | 637 | -> Vector a -- ^ @v'@: flattened output array |
699 | reorderVector = reorderV | 638 | reorderVector = reorderAux () |
700 | 639 | ||
701 | -------------------------------------------------------------------------------- | 640 | -------------------------------------------------------------------------------- |
702 | 641 | ||
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index eb0c5a8..e67aa67 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -135,6 +135,7 @@ instance (Integral t, KnownNat n) => Num (Mod n t) | |||
135 | fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) | 135 | fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) |
136 | 136 | ||
137 | 137 | ||
138 | #if 0 | ||
138 | instance KnownNat m => Element (Mod m I) | 139 | instance KnownNat m => Element (Mod m I) |
139 | where | 140 | where |
140 | constantD x n = i2f (constantD (unMod x) n) | 141 | constantD x n = i2f (constantD (unMod x) n) |
@@ -168,6 +169,7 @@ instance KnownNat m => Element (Mod m Z) | |||
168 | gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c) | 169 | gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c) |
169 | where | 170 | where |
170 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 171 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
172 | #endif | ||
171 | 173 | ||
172 | 174 | ||
173 | instance KnownNat m => CTrans (Mod m I) | 175 | instance KnownNat m => CTrans (Mod m I) |
@@ -306,10 +308,10 @@ f2i :: Storable t => Vector (Mod n t) -> Vector t | |||
306 | f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) | 308 | f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) |
307 | where (fp,i,n) = unsafeToForeignPtr v | 309 | where (fp,i,n) = unsafeToForeignPtr v |
308 | 310 | ||
309 | f2iM :: (Element t, Element (Mod n t)) => Matrix (Mod n t) -> Matrix t | 311 | f2iM :: (Storable t, Storable (Mod n t)) => Matrix (Mod n t) -> Matrix t |
310 | f2iM m = m { xdat = f2i (xdat m) } | 312 | f2iM m = m { xdat = f2i (xdat m) } |
311 | 313 | ||
312 | i2fM :: (Element t, Element (Mod n t)) => Matrix t -> Matrix (Mod n t) | 314 | i2fM :: (Storable t, Storable (Mod n t)) => Matrix t -> Matrix (Mod n t) |
313 | i2fM m = m { xdat = i2f (xdat m) } | 315 | i2fM m = m { xdat = i2f (xdat m) } |
314 | 316 | ||
315 | vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t) | 317 | vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t) |
diff --git a/packages/base/src/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs index fd0a217..4f7bb82 100644 --- a/packages/base/src/Internal/Numeric.hs +++ b/packages/base/src/Internal/Numeric.hs | |||
@@ -4,6 +4,7 @@ | |||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | 4 | {-# LANGUAGE MultiParamTypeClasses #-} |
5 | {-# LANGUAGE FunctionalDependencies #-} | 5 | {-# LANGUAGE FunctionalDependencies #-} |
6 | {-# LANGUAGE UndecidableInstances #-} | 6 | {-# LANGUAGE UndecidableInstances #-} |
7 | {-# LANGUAGE PatternSynonyms #-} | ||
7 | 8 | ||
8 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | 9 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} |
9 | 10 | ||
@@ -22,12 +23,18 @@ module Internal.Numeric where | |||
22 | import Internal.Vector | 23 | import Internal.Vector |
23 | import Internal.Matrix | 24 | import Internal.Matrix |
24 | import Internal.Element | 25 | import Internal.Element |
26 | import Internal.Extract (requires,pattern BAD_SIZE) | ||
25 | import Internal.ST as ST | 27 | import Internal.ST as ST |
26 | import Internal.Conversion | 28 | import Internal.Conversion |
27 | import Internal.Vectorized | 29 | import Internal.Vectorized |
28 | import Internal.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ,multiplyI,multiplyL) | 30 | import Internal.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ,multiplyI,multiplyL) |
31 | import Control.Monad | ||
32 | import Data.Function | ||
33 | import Data.Int | ||
29 | import Data.List.Split(chunksOf) | 34 | import Data.List.Split(chunksOf) |
30 | import qualified Data.Vector.Storable as V | 35 | import qualified Data.Vector.Storable as V |
36 | import Foreign.Ptr | ||
37 | import Foreign.Storable | ||
31 | 38 | ||
32 | -------------------------------------------------------------------------------- | 39 | -------------------------------------------------------------------------------- |
33 | 40 | ||
@@ -44,7 +51,7 @@ type instance ArgOf Matrix a = a -> a -> a | |||
44 | -------------------------------------------------------------------------------- | 51 | -------------------------------------------------------------------------------- |
45 | 52 | ||
46 | -- | Basic element-by-element functions for numeric containers | 53 | -- | Basic element-by-element functions for numeric containers |
47 | class Element e => Container c e | 54 | class Storable e => Container c e |
48 | where | 55 | where |
49 | conj' :: c e -> c e | 56 | conj' :: c e -> c e |
50 | size' :: c e -> IndexOf c | 57 | size' :: c e -> IndexOf c |
@@ -56,7 +63,7 @@ class Element e => Container c e | |||
56 | -- | element by element multiplication | 63 | -- | element by element multiplication |
57 | mul :: c e -> c e -> c e | 64 | mul :: c e -> c e -> c e |
58 | equal :: c e -> c e -> Bool | 65 | equal :: c e -> c e -> Bool |
59 | cmap' :: (Element b) => (e -> b) -> c e -> c b | 66 | cmap' :: (Storable b) => (e -> b) -> c e -> c b |
60 | konst' :: e -> IndexOf c -> c e | 67 | konst' :: e -> IndexOf c -> c e |
61 | build' :: IndexOf c -> (ArgOf c e) -> c e | 68 | build' :: IndexOf c -> (ArgOf c e) -> c e |
62 | atIndex' :: c e -> IndexOf c -> e | 69 | atIndex' :: c e -> IndexOf c -> e |
@@ -107,7 +114,7 @@ instance Container Vector I | |||
107 | mul = vectorZipI Mul | 114 | mul = vectorZipI Mul |
108 | equal = (==) | 115 | equal = (==) |
109 | scalar' = V.singleton | 116 | scalar' = V.singleton |
110 | konst' = constantD | 117 | konst' = constantAux |
111 | build' = buildV | 118 | build' = buildV |
112 | cmap' = mapVector | 119 | cmap' = mapVector |
113 | atIndex' = (@>) | 120 | atIndex' = (@>) |
@@ -146,7 +153,7 @@ instance Container Vector Z | |||
146 | mul = vectorZipL Mul | 153 | mul = vectorZipL Mul |
147 | equal = (==) | 154 | equal = (==) |
148 | scalar' = V.singleton | 155 | scalar' = V.singleton |
149 | konst' = constantD | 156 | konst' = constantAux |
150 | build' = buildV | 157 | build' = buildV |
151 | cmap' = mapVector | 158 | cmap' = mapVector |
152 | atIndex' = (@>) | 159 | atIndex' = (@>) |
@@ -186,7 +193,7 @@ instance Container Vector Float | |||
186 | mul = vectorZipF Mul | 193 | mul = vectorZipF Mul |
187 | equal = (==) | 194 | equal = (==) |
188 | scalar' = V.singleton | 195 | scalar' = V.singleton |
189 | konst' = constantD | 196 | konst' = constantAux |
190 | build' = buildV | 197 | build' = buildV |
191 | cmap' = mapVector | 198 | cmap' = mapVector |
192 | atIndex' = (@>) | 199 | atIndex' = (@>) |
@@ -223,7 +230,7 @@ instance Container Vector Double | |||
223 | mul = vectorZipR Mul | 230 | mul = vectorZipR Mul |
224 | equal = (==) | 231 | equal = (==) |
225 | scalar' = V.singleton | 232 | scalar' = V.singleton |
226 | konst' = constantD | 233 | konst' = constantAux |
227 | build' = buildV | 234 | build' = buildV |
228 | cmap' = mapVector | 235 | cmap' = mapVector |
229 | atIndex' = (@>) | 236 | atIndex' = (@>) |
@@ -260,7 +267,7 @@ instance Container Vector (Complex Double) | |||
260 | mul = vectorZipC Mul | 267 | mul = vectorZipC Mul |
261 | equal = (==) | 268 | equal = (==) |
262 | scalar' = V.singleton | 269 | scalar' = V.singleton |
263 | konst' = constantD | 270 | konst' = constantAux |
264 | build' = buildV | 271 | build' = buildV |
265 | cmap' = mapVector | 272 | cmap' = mapVector |
266 | atIndex' = (@>) | 273 | atIndex' = (@>) |
@@ -296,7 +303,7 @@ instance Container Vector (Complex Float) | |||
296 | mul = vectorZipQ Mul | 303 | mul = vectorZipQ Mul |
297 | equal = (==) | 304 | equal = (==) |
298 | scalar' = V.singleton | 305 | scalar' = V.singleton |
299 | konst' = constantD | 306 | konst' = constantAux |
300 | build' = buildV | 307 | build' = buildV |
301 | cmap' = mapVector | 308 | cmap' = mapVector |
302 | atIndex' = (@>) | 309 | atIndex' = (@>) |
@@ -323,7 +330,7 @@ instance Container Vector (Complex Float) | |||
323 | 330 | ||
324 | --------------------------------------------------------------- | 331 | --------------------------------------------------------------- |
325 | 332 | ||
326 | instance (Num a, Element a, Container Vector a) => Container Matrix a | 333 | instance (Num a, Storable a, Container Vector a) => Container Matrix a |
327 | where | 334 | where |
328 | conj' = liftMatrix conj' | 335 | conj' = liftMatrix conj' |
329 | size' = size | 336 | size' = size |
@@ -418,8 +425,8 @@ fromZ = fromZ' | |||
418 | toZ :: (Container c e) => c e -> c Z | 425 | toZ :: (Container c e) => c e -> c Z |
419 | toZ = toZ' | 426 | toZ = toZ' |
420 | 427 | ||
421 | -- | like 'fmap' (cannot implement instance Functor because of Element class constraint) | 428 | -- | like 'fmap' (cannot implement instance Functor because of Storable class constraint) |
422 | cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b | 429 | cmap :: (Storable b, Container c e) => (e -> b) -> c e -> c b |
423 | cmap = cmap' | 430 | cmap = cmap' |
424 | 431 | ||
425 | -- | generic indexing function | 432 | -- | generic indexing function |
@@ -470,7 +477,7 @@ step | |||
470 | step = step' | 477 | step = step' |
471 | 478 | ||
472 | 479 | ||
473 | -- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@. | 480 | -- | Storable by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@. |
474 | -- | 481 | -- |
475 | -- Arguments with any dimension = 1 are automatically expanded: | 482 | -- Arguments with any dimension = 1 are automatically expanded: |
476 | -- | 483 | -- |
@@ -598,7 +605,7 @@ instance Numeric Z | |||
598 | -------------------------------------------------------------------------------- | 605 | -------------------------------------------------------------------------------- |
599 | 606 | ||
600 | -- | Matrix product and related functions | 607 | -- | Matrix product and related functions |
601 | class (Num e, Element e) => Product e where | 608 | class (Num e, Storable e) => Product e where |
602 | -- | matrix product | 609 | -- | matrix product |
603 | multiply :: Matrix e -> Matrix e -> Matrix e | 610 | multiply :: Matrix e -> Matrix e -> Matrix e |
604 | -- | sum of absolute value of elements (differs in complex case from @norm1@) | 611 | -- | sum of absolute value of elements (differs in complex case from @norm1@) |
@@ -823,12 +830,12 @@ buildV n f = fromList [f k | k <- ks] | |||
823 | -------------------------------------------------------- | 830 | -------------------------------------------------------- |
824 | 831 | ||
825 | -- | Creates a square matrix with a given diagonal. | 832 | -- | Creates a square matrix with a given diagonal. |
826 | diag :: (Num a, Element a) => Vector a -> Matrix a | 833 | diag :: (Num a, Storable a) => Vector a -> Matrix a |
827 | diag v = diagRect 0 v n n where n = dim v | 834 | diag v = diagRect 0 v n n where n = dim v |
828 | 835 | ||
829 | -- | creates the identity matrix of given dimension | 836 | -- | creates the identity matrix of given dimension |
830 | ident :: (Num a, Element a) => Int -> Matrix a | 837 | ident :: (Num a, Storable a) => Int -> Matrix a |
831 | ident n = diag (constantD 1 n) | 838 | ident n = diag (constantAux 1 n) |
832 | 839 | ||
833 | -------------------------------------------------------- | 840 | -------------------------------------------------------- |
834 | 841 | ||
@@ -943,3 +950,44 @@ class Testable t | |||
943 | 950 | ||
944 | -------------------------------------------------------------------------------- | 951 | -------------------------------------------------------------------------------- |
945 | 952 | ||
953 | compareV :: (Storable a, Ord a) => Vector a -> Vector a -> Vector Int32 | ||
954 | compareV = compareG compareStorable | ||
955 | |||
956 | compareStorable :: (Storable a, Ord a) => | ||
957 | Int32 -> Ptr a | ||
958 | -> Int32 -> Ptr a | ||
959 | -> Int32 -> Ptr Int32 | ||
960 | -> IO Int32 | ||
961 | compareStorable xn xp yn yp rn rp = do | ||
962 | requires (xn==yn && xn==rn) BAD_SIZE $ do | ||
963 | ($ 0) $ fix $ \kloop k -> when (k<xn) $ do | ||
964 | xk <- peekElemOff xp (fromIntegral k) | ||
965 | yk <- peekElemOff yp (fromIntegral k) | ||
966 | pokeElemOff rp (fromIntegral k) $ case compare xk yk of | ||
967 | LT -> -1 | ||
968 | GT -> 1 | ||
969 | EQ -> 0 | ||
970 | kloop (succ k) | ||
971 | return 0 | ||
972 | |||
973 | selectV :: Storable a => Vector Int32 -> Vector a -> Vector a -> Vector a -> Vector a | ||
974 | selectV = selectG selectStorable | ||
975 | |||
976 | selectStorable :: Storable a => | ||
977 | Int32 -> Ptr Int32 | ||
978 | -> Int32 -> Ptr a | ||
979 | -> Int32 -> Ptr a | ||
980 | -> Int32 -> Ptr a | ||
981 | -> Int32 -> Ptr a | ||
982 | -> IO Int32 | ||
983 | selectStorable condn condp ltn ltp eqn eqp gtn gtp rn rp = do | ||
984 | requires (condn==ltn && ltn==eqn && ltn==gtn && ltn==rn) BAD_SIZE $ do | ||
985 | ($ 0) $ fix $ \kloop k -> when (k<condn) $ do | ||
986 | condpk <- peekElemOff condp (fromIntegral k) | ||
987 | pokeElemOff rp (fromIntegral k) =<< case compare condpk 0 of | ||
988 | LT -> peekElemOff ltp (fromIntegral k) | ||
989 | GT -> peekElemOff gtp (fromIntegral k) | ||
990 | EQ -> peekElemOff eqp (fromIntegral k) | ||
991 | kloop (succ k) | ||
992 | return 0 | ||
993 | |||
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 7d54e6d..326b90a 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -1,6 +1,7 @@ | |||
1 | {-# LANGUAGE Rank2Types #-} | 1 | {-# LANGUAGE Rank2Types #-} |
2 | {-# LANGUAGE BangPatterns #-} | 2 | {-# LANGUAGE BangPatterns #-} |
3 | {-# LANGUAGE ViewPatterns #-} | 3 | {-# LANGUAGE ViewPatterns #-} |
4 | {-# LANGUAGE PatternSynonyms #-} | ||
4 | 5 | ||
5 | ----------------------------------------------------------------------------- | 6 | ----------------------------------------------------------------------------- |
6 | -- | | 7 | -- | |
@@ -30,14 +31,20 @@ module Internal.ST ( | |||
30 | unsafeThawVector, unsafeFreezeVector, | 31 | unsafeThawVector, unsafeFreezeVector, |
31 | newUndefinedMatrix, | 32 | newUndefinedMatrix, |
32 | unsafeReadMatrix, unsafeWriteMatrix, | 33 | unsafeReadMatrix, unsafeWriteMatrix, |
33 | unsafeThawMatrix, unsafeFreezeMatrix | 34 | unsafeThawMatrix, unsafeFreezeMatrix, |
35 | setRect | ||
34 | ) where | 36 | ) where |
35 | 37 | ||
36 | import Internal.Vector | 38 | import Internal.Vector |
37 | import Internal.Matrix | 39 | import Internal.Matrix |
38 | import Internal.Vectorized | 40 | import Internal.Vectorized |
41 | import Internal.Devel ((#|)) | ||
39 | import Control.Monad.ST(ST, runST) | 42 | import Control.Monad.ST(ST, runST) |
40 | import Foreign.Storable(Storable, peekElemOff, pokeElemOff) | 43 | import Control.Monad |
44 | import Data.Function | ||
45 | import Data.Int | ||
46 | import Foreign.Ptr | ||
47 | import Foreign.Storable | ||
41 | import Control.Monad.ST.Unsafe(unsafeIOToST) | 48 | import Control.Monad.ST.Unsafe(unsafeIOToST) |
42 | 49 | ||
43 | {-# INLINE ioReadV #-} | 50 | {-# INLINE ioReadV #-} |
@@ -121,7 +128,7 @@ ioWriteM m r c val = ioWriteV (xdat m) (r * xRow m + c * xCol m) val | |||
121 | 128 | ||
122 | newtype STMatrix s t = STMatrix (Matrix t) | 129 | newtype STMatrix s t = STMatrix (Matrix t) |
123 | 130 | ||
124 | thawMatrix :: Element t => Matrix t -> ST s (STMatrix s t) | 131 | thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) |
125 | thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix | 132 | thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix |
126 | 133 | ||
127 | unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) | 134 | unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) |
@@ -142,17 +149,17 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c | |||
142 | modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () | 149 | modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () |
143 | modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c | 150 | modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c |
144 | 151 | ||
145 | liftSTMatrix :: (Element t) => (Matrix t -> a) -> STMatrix s t -> ST s a | 152 | liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a |
146 | liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x | 153 | liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x |
147 | 154 | ||
148 | unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) | 155 | unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) |
149 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | 156 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x |
150 | 157 | ||
151 | 158 | ||
152 | freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) | 159 | freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) |
153 | freezeMatrix m = liftSTMatrix id m | 160 | freezeMatrix m = liftSTMatrix id m |
154 | 161 | ||
155 | cloneMatrix :: Element t => Matrix t -> IO (Matrix t) | 162 | cloneMatrix :: Storable t => Matrix t -> IO (Matrix t) |
156 | cloneMatrix m = copy (orderOf m) m | 163 | cloneMatrix m = copy (orderOf m) m |
157 | 164 | ||
158 | {-# INLINE safeIndexM #-} | 165 | {-# INLINE safeIndexM #-} |
@@ -172,7 +179,7 @@ readMatrix = safeIndexM unsafeReadMatrix | |||
172 | writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () | 179 | writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () |
173 | writeMatrix = safeIndexM unsafeWriteMatrix | 180 | writeMatrix = safeIndexM unsafeWriteMatrix |
174 | 181 | ||
175 | setMatrix :: Element t => STMatrix s t -> Int -> Int -> Matrix t -> ST s () | 182 | setMatrix :: Storable t => STMatrix s t -> Int -> Int -> Matrix t -> ST s () |
176 | setMatrix (STMatrix x) i j m = unsafeIOToST $ setRect i j m x | 183 | setMatrix (STMatrix x) i j m = unsafeIOToST $ setRect i j m x |
177 | 184 | ||
178 | newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) | 185 | newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) |
@@ -210,7 +217,7 @@ data RowOper t = AXPY t Int Int ColRange | |||
210 | | SCAL t RowRange ColRange | 217 | | SCAL t RowRange ColRange |
211 | | SWAP Int Int ColRange | 218 | | SWAP Int Int ColRange |
212 | 219 | ||
213 | rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s () | 220 | rowOper :: (Num t, Storable t) => RowOper t -> STMatrix s t -> ST s () |
214 | 221 | ||
215 | rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m | 222 | rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m |
216 | where | 223 | where |
@@ -230,8 +237,8 @@ rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m | |||
230 | i2' = i2 `mod` (rows m) | 237 | i2' = i2 `mod` (rows m) |
231 | 238 | ||
232 | 239 | ||
233 | extractMatrix :: Element a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a) | 240 | extractMatrix :: Storable a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a) |
234 | extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) | 241 | extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractAux (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) |
235 | where | 242 | where |
236 | (i1,i2) = getRowRange (rows m) rr | 243 | (i1,i2) = getRowRange (rows m) rr |
237 | (j1,j2) = getColRange (cols m) rc | 244 | (j1,j2) = getColRange (cols m) rc |
@@ -239,19 +246,117 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[ | |||
239 | -- | r0 c0 height width | 246 | -- | r0 c0 height width |
240 | data Slice s t = Slice (STMatrix s t) Int Int Int Int | 247 | data Slice s t = Slice (STMatrix s t) Int Int Int Int |
241 | 248 | ||
242 | slice :: Element a => Slice t a -> Matrix a | 249 | slice :: Storable a => Slice t a -> Matrix a |
243 | slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m | 250 | slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m |
244 | 251 | ||
245 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () | 252 | gemmm :: (Storable t, Num t) => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () |
246 | gemmm beta (slice->r) alpha (slice->a) (slice->b) = res | 253 | gemmm beta (slice->r) alpha (slice->a) (slice->b) = res |
247 | where | 254 | where |
248 | res = unsafeIOToST (gemm v a b r) | 255 | res = unsafeIOToST (gemm v a b r) |
249 | v = fromList [alpha,beta] | 256 | v = fromList [alpha,beta] |
250 | 257 | ||
251 | 258 | ||
252 | mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) | 259 | mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) |
253 | mutable f a = runST $ do | 260 | mutable f a = runST $ do |
254 | x <- thawMatrix a | 261 | x <- thawMatrix a |
255 | info <- f (rows a, cols a) x | 262 | info <- f (rows a, cols a) x |
256 | r <- unsafeFreezeMatrix x | 263 | r <- unsafeFreezeMatrix x |
257 | return (r,info) | 264 | return (r,info) |
265 | |||
266 | |||
267 | |||
268 | setRect :: Storable t => Int -> Int -> Matrix t -> Matrix t -> IO () | ||
269 | setRect i j m r = (m Internal.Matrix.#! r) (setRectStorable (fi i) (fi j)) #|"setRect" | ||
270 | |||
271 | setRectStorable :: Storable t => | ||
272 | Int32 -> Int32 | ||
273 | -> Int32 -> Int32 -> Int32 -> Int32 -> {- const -} Ptr t | ||
274 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t | ||
275 | -> IO Int32 | ||
276 | setRectStorable i j mr mc mXr mXc mp rr rc rXr rXc rp = do | ||
277 | ($ 0) $ fix $ \aloop a -> when (a<mr) $ do | ||
278 | ($ 0) $ fix $ \bloop b -> when (b<mc) $ do | ||
279 | let x = a+i | ||
280 | y = b+j | ||
281 | when (0<=x && x<rr && 0<=y && y<rc) $ do | ||
282 | pokeElemOff rp (fromIntegral $ rXr*x + rXc*y) | ||
283 | =<< peekElemOff mp (fromIntegral $ mXr*a + mXc*b) | ||
284 | bloop (succ b) | ||
285 | aloop (succ a) | ||
286 | return 0 | ||
287 | |||
288 | rowOp :: (Storable t, Num t) => Int -> t -> Int -> Int -> Int -> Int -> Matrix t -> IO () | ||
289 | rowOp = rowOpAux rowOpStorable | ||
290 | |||
291 | pattern BAD_CODE = 2001 | ||
292 | |||
293 | rowOpStorable :: (Storable t, Num t) => | ||
294 | Int32 -> Ptr t -> Int32 -> Int32 -> Int32 -> Int32 | ||
295 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t | ||
296 | -> IO Int32 | ||
297 | rowOpStorable 0 pa i1 i2 j1 j2 rr rc rXr rXc rp = do | ||
298 | -- AXPY_IMP | ||
299 | a <- peek pa | ||
300 | ($ j1) $ fix $ \jloop j -> when (j<=j2) $ do | ||
301 | ri1j <- peekElemOff rp $ fromIntegral $ rXr*i1 + rXc*j | ||
302 | let i2j = fromIntegral $ rXr*i2 + rXc*j | ||
303 | ri2j <- peekElemOff rp i2j | ||
304 | pokeElemOff rp i2j $ ri2j + a*ri1j | ||
305 | jloop (succ j) | ||
306 | return 0 | ||
307 | rowOpStorable 1 pa i1 i2 j1 j2 rr rc rXr rXc rp = do | ||
308 | -- SCAL_IMP | ||
309 | a <- peek pa | ||
310 | ($ i1) $ fix $ \iloop i -> when (i<=i2) $ do | ||
311 | ($ j1) $ fix $ \jloop j -> when (j<=j2) $ do | ||
312 | let rijp = rp `plusPtr` fromIntegral (rXr*i + rXc*j) | ||
313 | rij <- peek rijp | ||
314 | poke rijp $ a * rij | ||
315 | jloop (succ j) | ||
316 | iloop (succ i) | ||
317 | return 0 | ||
318 | rowOpStorable 2 pa i1 i2 j1 j2 rr rc rXr rXc rp | i1 == i2 = return 0 | ||
319 | rowOpStorable 2 pa i1 i2 j1 j2 rr rc rXr rXc rp = do | ||
320 | -- SWAP_IMP | ||
321 | ($ j1) $ fix $ \kloop k -> when (k<=j2) $ do | ||
322 | let i1k = fromIntegral $ rXr*i1 + rXc*k | ||
323 | i2k = fromIntegral $ rXr*i2 + rXc*k | ||
324 | aux <- peekElemOff rp i1k | ||
325 | pokeElemOff rp i1k =<< peekElemOff rp i2k | ||
326 | pokeElemOff rp i2k aux | ||
327 | kloop (succ k) | ||
328 | return 0 | ||
329 | rowOpStorable _ pa i1 i2 j1 j2 rr rc rXr rXc rp = do | ||
330 | return BAD_CODE | ||
331 | |||
332 | gemm :: (Storable t, Num t) => Vector t -> Matrix t -> Matrix t -> Matrix t -> IO () | ||
333 | gemm v m1 m2 m3 = (v Internal.Matrix.# m1 Internal.Matrix.# m2 Internal.Matrix.#! m3) gemmStorable #|"gemm" | ||
334 | |||
335 | -- ScalarLike t | ||
336 | gemmStorable :: (Storable t, Num t) => | ||
337 | Int32 -> Ptr t -- VECG(T,c) | ||
338 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -- MATG(T,a) | ||
339 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -- MATG(T,b) | ||
340 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -- MATG(T,r) | ||
341 | -> IO Int32 | ||
342 | gemmStorable cn cp | ||
343 | ar ac aXr aXc ap | ||
344 | br bc bXr bXc bp | ||
345 | rr rc rXr rXc rp = do | ||
346 | a <- peek cp | ||
347 | b <- peekElemOff cp 1 | ||
348 | ($ 0) $ fix $ \iloop i -> when (i<rr) $ do | ||
349 | ($ 0) $ fix $ \jloop j -> when (j<rc) $ do | ||
350 | let kloop k !t fin | ||
351 | | k<ac = do | ||
352 | aik <- peekElemOff ap (fromIntegral $ i*aXr + k*aXc) | ||
353 | bkj <- peekElemOff bp (fromIntegral $ k*bXr + j*bXc) | ||
354 | kloop (succ k) (t + aik*bkj) fin | ||
355 | | otherwise = fin t | ||
356 | kloop 0 0 $ \t -> do | ||
357 | let ij = fromIntegral $ i*rXr + j*rXc | ||
358 | rij <- peekElemOff rp ij | ||
359 | pokeElemOff rp ij (b*rij + a*t) | ||
360 | jloop (succ j) | ||
361 | iloop (succ i) | ||
362 | return 0 | ||
diff --git a/packages/base/src/Internal/Sparse.hs b/packages/base/src/Internal/Sparse.hs index fbea11a..423b169 100644 --- a/packages/base/src/Internal/Sparse.hs +++ b/packages/base/src/Internal/Sparse.hs | |||
@@ -20,7 +20,7 @@ import Data.Function(on) | |||
20 | import Control.Arrow((***)) | 20 | import Control.Arrow((***)) |
21 | import Control.Monad(when) | 21 | import Control.Monad(when) |
22 | import Data.List(groupBy, sort) | 22 | import Data.List(groupBy, sort) |
23 | import Foreign.C.Types(CInt(..)) | 23 | import Data.Int |
24 | 24 | ||
25 | import Internal.Devel | 25 | import Internal.Devel |
26 | import System.IO.Unsafe(unsafePerformIO) | 26 | import System.IO.Unsafe(unsafePerformIO) |
@@ -34,16 +34,16 @@ type AssocMatrix = [((Int,Int),Double)] | |||
34 | 34 | ||
35 | data CSR = CSR | 35 | data CSR = CSR |
36 | { csrVals :: Vector Double | 36 | { csrVals :: Vector Double |
37 | , csrCols :: Vector CInt | 37 | , csrCols :: Vector Int32 |
38 | , csrRows :: Vector CInt | 38 | , csrRows :: Vector Int32 |
39 | , csrNRows :: Int | 39 | , csrNRows :: Int |
40 | , csrNCols :: Int | 40 | , csrNCols :: Int |
41 | } deriving Show | 41 | } deriving Show |
42 | 42 | ||
43 | data CSC = CSC | 43 | data CSC = CSC |
44 | { cscVals :: Vector Double | 44 | { cscVals :: Vector Double |
45 | , cscRows :: Vector CInt | 45 | , cscRows :: Vector Int32 |
46 | , cscCols :: Vector CInt | 46 | , cscCols :: Vector Int32 |
47 | , cscNRows :: Int | 47 | , cscNRows :: Int |
48 | , cscNCols :: Int | 48 | , cscNCols :: Int |
49 | } deriving Show | 49 | } deriving Show |
@@ -138,9 +138,9 @@ mkDiagR r c v | |||
138 | diagVals = v | 138 | diagVals = v |
139 | 139 | ||
140 | 140 | ||
141 | type IV t = CInt -> Ptr CInt -> t | 141 | type IV t = Int32 -> Ptr Int32 -> t |
142 | type V t = CInt -> Ptr Double -> t | 142 | type V t = Int32 -> Ptr Double -> t |
143 | type SMxV = V (IV (IV (V (V (IO CInt))))) | 143 | type SMxV = V (IV (IV (V (V (IO Int32))))) |
144 | 144 | ||
145 | gmXv :: GMatrix -> Vector Double -> Vector Double | 145 | gmXv :: GMatrix -> Vector Double -> Vector Double |
146 | gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do | 146 | gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do |
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index f642e8d..6f3b4c8 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -83,6 +83,7 @@ import Control.Arrow((&&&),(***)) | |||
83 | import Data.Complex | 83 | import Data.Complex |
84 | import Data.Function(on) | 84 | import Data.Function(on) |
85 | import Internal.ST | 85 | import Internal.ST |
86 | import Foreign.Storable | ||
86 | #if MIN_VERSION_base(4,11,0) | 87 | #if MIN_VERSION_base(4,11,0) |
87 | import Prelude hiding ((<>)) | 88 | import Prelude hiding ((<>)) |
88 | #endif | 89 | #endif |
@@ -174,7 +175,7 @@ a & b = vjoin [a,b] | |||
174 | 175 | ||
175 | -} | 176 | -} |
176 | infixl 3 ||| | 177 | infixl 3 ||| |
177 | (|||) :: Element t => Matrix t -> Matrix t -> Matrix t | 178 | (|||) :: Storable t => Matrix t -> Matrix t -> Matrix t |
178 | a ||| b = fromBlocks [[a,b]] | 179 | a ||| b = fromBlocks [[a,b]] |
179 | 180 | ||
180 | -- | a synonym for ('|||') (unicode 0x00a6, broken bar) | 181 | -- | a synonym for ('|||') (unicode 0x00a6, broken bar) |
@@ -185,7 +186,7 @@ infixl 3 ¦ | |||
185 | 186 | ||
186 | -- | vertical concatenation | 187 | -- | vertical concatenation |
187 | -- | 188 | -- |
188 | (===) :: Element t => Matrix t -> Matrix t -> Matrix t | 189 | (===) :: Storable t => Matrix t -> Matrix t -> Matrix t |
189 | infixl 2 === | 190 | infixl 2 === |
190 | a === b = fromBlocks [[a],[b]] | 191 | a === b = fromBlocks [[a],[b]] |
191 | 192 | ||
@@ -225,7 +226,7 @@ col = asColumn . fromList | |||
225 | 226 | ||
226 | -} | 227 | -} |
227 | infixl 9 ? | 228 | infixl 9 ? |
228 | (?) :: Element t => Matrix t -> [Int] -> Matrix t | 229 | (?) :: Storable t => Matrix t -> [Int] -> Matrix t |
229 | (?) = flip extractRows | 230 | (?) = flip extractRows |
230 | 231 | ||
231 | {- | extract columns | 232 | {- | extract columns |
@@ -240,7 +241,7 @@ infixl 9 ? | |||
240 | 241 | ||
241 | -} | 242 | -} |
242 | infixl 9 ¿ | 243 | infixl 9 ¿ |
243 | (¿) :: Element t => Matrix t -> [Int] -> Matrix t | 244 | (¿) :: Storable t => Matrix t -> [Int] -> Matrix t |
244 | (¿)= flip extractColumns | 245 | (¿)= flip extractColumns |
245 | 246 | ||
246 | 247 | ||
@@ -329,7 +330,7 @@ instance Normed (Vector (Complex Float)) | |||
329 | norm_Inf = norm_Inf . double | 330 | norm_Inf = norm_Inf . double |
330 | 331 | ||
331 | -- | Frobenius norm (Schatten p-norm with p=2) | 332 | -- | Frobenius norm (Schatten p-norm with p=2) |
332 | norm_Frob :: (Normed (Vector t), Element t) => Matrix t -> R | 333 | norm_Frob :: (Normed (Vector t), Storable t) => Matrix t -> R |
333 | norm_Frob = norm_2 . flatten | 334 | norm_Frob = norm_2 . flatten |
334 | 335 | ||
335 | -- | Sum of singular values (Schatten p-norm with p=1) | 336 | -- | Sum of singular values (Schatten p-norm with p=1) |
@@ -346,7 +347,7 @@ True | |||
346 | True | 347 | True |
347 | 348 | ||
348 | -} | 349 | -} |
349 | magnit :: (Element t, Normed (Vector t)) => R -> t -> Bool | 350 | magnit :: (Storable t, Normed (Vector t)) => R -> t -> Bool |
350 | magnit e x = norm_1 (fromList [x]) > e | 351 | magnit e x = norm_1 (fromList [x]) > e |
351 | 352 | ||
352 | 353 | ||
@@ -415,7 +416,7 @@ instance Indexable (Vector (Complex Float)) (Complex Float) | |||
415 | where | 416 | where |
416 | (!) = (@>) | 417 | (!) = (@>) |
417 | 418 | ||
418 | instance Element t => Indexable (Matrix t) (Vector t) | 419 | instance Storable t => Indexable (Matrix t) (Vector t) |
419 | where | 420 | where |
420 | m!j = subVector (j*c) c (flatten m) | 421 | m!j = subVector (j*c) c (flatten m) |
421 | where | 422 | where |
diff --git a/packages/base/src/Internal/Vector.hs b/packages/base/src/Internal/Vector.hs index 6271bb6..3037019 100644 --- a/packages/base/src/Internal/Vector.hs +++ b/packages/base/src/Internal/Vector.hs | |||
@@ -32,7 +32,7 @@ import Foreign.ForeignPtr | |||
32 | import Foreign.Ptr | 32 | import Foreign.Ptr |
33 | import Foreign.Storable | 33 | import Foreign.Storable |
34 | import Foreign.C.Types(CInt) | 34 | import Foreign.C.Types(CInt) |
35 | import Data.Int(Int64) | 35 | import Data.Int |
36 | import Data.Complex | 36 | import Data.Complex |
37 | import System.IO.Unsafe(unsafePerformIO) | 37 | import System.IO.Unsafe(unsafePerformIO) |
38 | import GHC.ForeignPtr(mallocPlainForeignPtrBytes) | 38 | import GHC.ForeignPtr(mallocPlainForeignPtrBytes) |
@@ -46,18 +46,18 @@ import Control.Monad(replicateM) | |||
46 | import qualified Data.ByteString.Internal as BS | 46 | import qualified Data.ByteString.Internal as BS |
47 | import Data.Vector.Storable.Internal(updPtr) | 47 | import Data.Vector.Storable.Internal(updPtr) |
48 | 48 | ||
49 | type I = CInt | 49 | type I = Int32 |
50 | type Z = Int64 | 50 | type Z = Int64 |
51 | type R = Double | 51 | type R = Double |
52 | type C = Complex Double | 52 | type C = Complex Double |
53 | 53 | ||
54 | 54 | ||
55 | -- | specialized fromIntegral | 55 | -- | specialized fromIntegral |
56 | fi :: Int -> CInt | 56 | fi :: Int -> Int32 |
57 | fi = fromIntegral | 57 | fi = fromIntegral |
58 | 58 | ||
59 | -- | specialized fromIntegral | 59 | -- | specialized fromIntegral |
60 | ti :: CInt -> Int | 60 | ti :: Int32 -> Int |
61 | ti = fromIntegral | 61 | ti = fromIntegral |
62 | 62 | ||
63 | 63 | ||
@@ -69,7 +69,7 @@ dim = Vector.length | |||
69 | 69 | ||
70 | -- C-Haskell vector adapter | 70 | -- C-Haskell vector adapter |
71 | {-# INLINE avec #-} | 71 | {-# INLINE avec #-} |
72 | avec :: Storable a => Vector a -> (f -> IO r) -> ((CInt -> Ptr a -> f) -> IO r) | 72 | avec :: Storable a => Vector a -> (f -> IO r) -> ((Int32 -> Ptr a -> f) -> IO r) |
73 | avec v f g = unsafeWith v $ \ptr -> f (g (fromIntegral (Vector.length v)) ptr) | 73 | avec v f g = unsafeWith v $ \ptr -> f (g (fromIntegral (Vector.length v)) ptr) |
74 | 74 | ||
75 | -- allocates memory for a new vector | 75 | -- allocates memory for a new vector |
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs index 32430c6..ede3826 100644 --- a/packages/base/src/Internal/Vectorized.hs +++ b/packages/base/src/Internal/Vectorized.hs | |||
@@ -18,10 +18,12 @@ module Internal.Vectorized where | |||
18 | import Internal.Vector | 18 | import Internal.Vector |
19 | import Internal.Devel | 19 | import Internal.Devel |
20 | import Data.Complex | 20 | import Data.Complex |
21 | import Data.Function | ||
22 | import Data.Int | ||
21 | import Foreign.Marshal.Alloc(free,malloc) | 23 | import Foreign.Marshal.Alloc(free,malloc) |
22 | import Foreign.Marshal.Array(newArray,copyArray) | 24 | import Foreign.Marshal.Array(newArray,copyArray) |
23 | import Foreign.Ptr(Ptr) | 25 | import Foreign.Ptr(Ptr) |
24 | import Foreign.Storable(peek,Storable) | 26 | import Foreign.Storable(peek,pokeElemOff,Storable) |
25 | import Foreign.C.Types | 27 | import Foreign.C.Types |
26 | import Foreign.C.String | 28 | import Foreign.C.String |
27 | import System.IO.Unsafe(unsafePerformIO) | 29 | import System.IO.Unsafe(unsafePerformIO) |
@@ -36,8 +38,8 @@ a # b = applyRaw a b | |||
36 | a #! b = a # b # id | 38 | a #! b = a # b # id |
37 | {-# INLINE (#!) #-} | 39 | {-# INLINE (#!) #-} |
38 | 40 | ||
39 | fromei :: Enum a => a -> CInt | 41 | fromei :: Enum a => a -> Int32 |
40 | fromei x = fromIntegral (fromEnum x) :: CInt | 42 | fromei x = fromIntegral (fromEnum x) :: Int32 |
41 | 43 | ||
42 | data FunCodeV = Sin | 44 | data FunCodeV = Sin |
43 | | Cos | 45 | | Cos |
@@ -103,20 +105,20 @@ sumQ = sumg c_sumQ | |||
103 | sumC :: Vector (Complex Double) -> Complex Double | 105 | sumC :: Vector (Complex Double) -> Complex Double |
104 | sumC = sumg c_sumC | 106 | sumC = sumg c_sumC |
105 | 107 | ||
106 | sumI :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr I -> I :> Ok) | 108 | sumI :: ( TransRaw c (Int32 -> Ptr a -> IO Int32) ~ (Int32 -> Ptr I -> I :> Ok) |
107 | , TransArray c | 109 | , TransArray c |
108 | , Storable a | 110 | , Storable a |
109 | ) | 111 | ) |
110 | => I -> c -> a | 112 | => I -> c -> a |
111 | sumI m = sumg (c_sumI m) | 113 | sumI m = sumg (c_sumI m) |
112 | 114 | ||
113 | sumL :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr Z -> Z :> Ok) | 115 | sumL :: ( TransRaw c (Int32 -> Ptr a -> IO Int32) ~ (Int32 -> Ptr Z -> Z :> Ok) |
114 | , TransArray c | 116 | , TransArray c |
115 | , Storable a | 117 | , Storable a |
116 | ) => Z -> c -> a | 118 | ) => Z -> c -> a |
117 | sumL m = sumg (c_sumL m) | 119 | sumL m = sumg (c_sumL m) |
118 | 120 | ||
119 | sumg :: (TransArray c, Storable a) => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a | 121 | sumg :: (TransArray c, Storable a) => TransRaw c (Int32 -> Ptr a -> IO Int32) -> c -> a |
120 | sumg f x = unsafePerformIO $ do | 122 | sumg f x = unsafePerformIO $ do |
121 | r <- createVector 1 | 123 | r <- createVector 1 |
122 | (x #! r) f #| "sum" | 124 | (x #! r) f #| "sum" |
@@ -154,7 +156,7 @@ prodL :: Z-> Vector Z -> Z | |||
154 | prodL = prodg . c_prodL | 156 | prodL = prodg . c_prodL |
155 | 157 | ||
156 | prodg :: (TransArray c, Storable a) | 158 | prodg :: (TransArray c, Storable a) |
157 | => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a | 159 | => TransRaw c (Int32 -> Ptr a -> IO Int32) -> c -> a |
158 | prodg f x = unsafePerformIO $ do | 160 | prodg f x = unsafePerformIO $ do |
159 | r <- createVector 1 | 161 | r <- createVector 1 |
160 | (x #! r) f #| "prod" | 162 | (x #! r) f #| "prod" |
@@ -171,7 +173,7 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z | |||
171 | ------------------------------------------------------------------ | 173 | ------------------------------------------------------------------ |
172 | 174 | ||
173 | toScalarAux :: (Enum a, TransArray c, Storable a1) | 175 | toScalarAux :: (Enum a, TransArray c, Storable a1) |
174 | => (CInt -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) -> a -> c -> a1 | 176 | => (Int32 -> TransRaw c (Int32 -> Ptr a1 -> IO Int32)) -> a -> c -> a1 |
175 | toScalarAux fun code v = unsafePerformIO $ do | 177 | toScalarAux fun code v = unsafePerformIO $ do |
176 | r <- createVector 1 | 178 | r <- createVector 1 |
177 | (v #! r) (fun (fromei code)) #|"toScalarAux" | 179 | (v #! r) (fun (fromei code)) #|"toScalarAux" |
@@ -179,7 +181,7 @@ toScalarAux fun code v = unsafePerformIO $ do | |||
179 | 181 | ||
180 | 182 | ||
181 | vectorMapAux :: (Enum a, Storable t, Storable a1) | 183 | vectorMapAux :: (Enum a, Storable t, Storable a1) |
182 | => (CInt -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) | 184 | => (Int32 -> Int32 -> Ptr t -> Int32 -> Ptr a1 -> IO Int32) |
183 | -> a -> Vector t -> Vector a1 | 185 | -> a -> Vector t -> Vector a1 |
184 | vectorMapAux fun code v = unsafePerformIO $ do | 186 | vectorMapAux fun code v = unsafePerformIO $ do |
185 | r <- createVector (dim v) | 187 | r <- createVector (dim v) |
@@ -187,7 +189,7 @@ vectorMapAux fun code v = unsafePerformIO $ do | |||
187 | return r | 189 | return r |
188 | 190 | ||
189 | vectorMapValAux :: (Enum a, Storable a2, Storable t, Storable a1) | 191 | vectorMapValAux :: (Enum a, Storable a2, Storable t, Storable a1) |
190 | => (CInt -> Ptr a2 -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) | 192 | => (Int32 -> Ptr a2 -> Int32 -> Ptr t -> Int32 -> Ptr a1 -> IO Int32) |
191 | -> a -> a2 -> Vector t -> Vector a1 | 193 | -> a -> a2 -> Vector t -> Vector a1 |
192 | vectorMapValAux fun code val v = unsafePerformIO $ do | 194 | vectorMapValAux fun code val v = unsafePerformIO $ do |
193 | r <- createVector (dim v) | 195 | r <- createVector (dim v) |
@@ -197,7 +199,7 @@ vectorMapValAux fun code val v = unsafePerformIO $ do | |||
197 | return r | 199 | return r |
198 | 200 | ||
199 | vectorZipAux :: (Enum a, TransArray c, Storable t, Storable a1) | 201 | vectorZipAux :: (Enum a, TransArray c, Storable t, Storable a1) |
200 | => (CInt -> CInt -> Ptr t -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) | 202 | => (Int32 -> Int32 -> Ptr t -> TransRaw c (Int32 -> Ptr a1 -> IO Int32)) |
201 | -> a -> Vector t -> c -> Vector a1 | 203 | -> a -> Vector t -> c -> Vector a1 |
202 | vectorZipAux fun code u v = unsafePerformIO $ do | 204 | vectorZipAux fun code u v = unsafePerformIO $ do |
203 | r <- createVector (dim u) | 205 | r <- createVector (dim u) |
@@ -210,37 +212,37 @@ vectorZipAux fun code u v = unsafePerformIO $ do | |||
210 | toScalarR :: FunCodeS -> Vector Double -> Double | 212 | toScalarR :: FunCodeS -> Vector Double -> Double |
211 | toScalarR oper = toScalarAux c_toScalarR (fromei oper) | 213 | toScalarR oper = toScalarAux c_toScalarR (fromei oper) |
212 | 214 | ||
213 | foreign import ccall unsafe "toScalarR" c_toScalarR :: CInt -> TVV Double | 215 | foreign import ccall unsafe "toScalarR" c_toScalarR :: Int32 -> TVV Double |
214 | 216 | ||
215 | -- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. | 217 | -- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. |
216 | toScalarF :: FunCodeS -> Vector Float -> Float | 218 | toScalarF :: FunCodeS -> Vector Float -> Float |
217 | toScalarF oper = toScalarAux c_toScalarF (fromei oper) | 219 | toScalarF oper = toScalarAux c_toScalarF (fromei oper) |
218 | 220 | ||
219 | foreign import ccall unsafe "toScalarF" c_toScalarF :: CInt -> TVV Float | 221 | foreign import ccall unsafe "toScalarF" c_toScalarF :: Int32 -> TVV Float |
220 | 222 | ||
221 | -- | obtains different functions of a vector: only norm1, norm2 | 223 | -- | obtains different functions of a vector: only norm1, norm2 |
222 | toScalarC :: FunCodeS -> Vector (Complex Double) -> Double | 224 | toScalarC :: FunCodeS -> Vector (Complex Double) -> Double |
223 | toScalarC oper = toScalarAux c_toScalarC (fromei oper) | 225 | toScalarC oper = toScalarAux c_toScalarC (fromei oper) |
224 | 226 | ||
225 | foreign import ccall unsafe "toScalarC" c_toScalarC :: CInt -> Complex Double :> Double :> Ok | 227 | foreign import ccall unsafe "toScalarC" c_toScalarC :: Int32 -> Complex Double :> Double :> Ok |
226 | 228 | ||
227 | -- | obtains different functions of a vector: only norm1, norm2 | 229 | -- | obtains different functions of a vector: only norm1, norm2 |
228 | toScalarQ :: FunCodeS -> Vector (Complex Float) -> Float | 230 | toScalarQ :: FunCodeS -> Vector (Complex Float) -> Float |
229 | toScalarQ oper = toScalarAux c_toScalarQ (fromei oper) | 231 | toScalarQ oper = toScalarAux c_toScalarQ (fromei oper) |
230 | 232 | ||
231 | foreign import ccall unsafe "toScalarQ" c_toScalarQ :: CInt -> Complex Float :> Float :> Ok | 233 | foreign import ccall unsafe "toScalarQ" c_toScalarQ :: Int32 -> Complex Float :> Float :> Ok |
232 | 234 | ||
233 | -- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. | 235 | -- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. |
234 | toScalarI :: FunCodeS -> Vector CInt -> CInt | 236 | toScalarI :: FunCodeS -> Vector Int32 -> Int32 |
235 | toScalarI oper = toScalarAux c_toScalarI (fromei oper) | 237 | toScalarI oper = toScalarAux c_toScalarI (fromei oper) |
236 | 238 | ||
237 | foreign import ccall unsafe "toScalarI" c_toScalarI :: CInt -> TVV CInt | 239 | foreign import ccall unsafe "toScalarI" c_toScalarI :: Int32 -> TVV Int32 |
238 | 240 | ||
239 | -- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. | 241 | -- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. |
240 | toScalarL :: FunCodeS -> Vector Z -> Z | 242 | toScalarL :: FunCodeS -> Vector Z -> Z |
241 | toScalarL oper = toScalarAux c_toScalarL (fromei oper) | 243 | toScalarL oper = toScalarAux c_toScalarL (fromei oper) |
242 | 244 | ||
243 | foreign import ccall unsafe "toScalarL" c_toScalarL :: CInt -> TVV Z | 245 | foreign import ccall unsafe "toScalarL" c_toScalarL :: Int32 -> TVV Z |
244 | 246 | ||
245 | 247 | ||
246 | ------------------------------------------------------------------ | 248 | ------------------------------------------------------------------ |
@@ -249,37 +251,37 @@ foreign import ccall unsafe "toScalarL" c_toScalarL :: CInt -> TVV Z | |||
249 | vectorMapR :: FunCodeV -> Vector Double -> Vector Double | 251 | vectorMapR :: FunCodeV -> Vector Double -> Vector Double |
250 | vectorMapR = vectorMapAux c_vectorMapR | 252 | vectorMapR = vectorMapAux c_vectorMapR |
251 | 253 | ||
252 | foreign import ccall unsafe "mapR" c_vectorMapR :: CInt -> TVV Double | 254 | foreign import ccall unsafe "mapR" c_vectorMapR :: Int32 -> TVV Double |
253 | 255 | ||
254 | -- | map of complex vectors with given function | 256 | -- | map of complex vectors with given function |
255 | vectorMapC :: FunCodeV -> Vector (Complex Double) -> Vector (Complex Double) | 257 | vectorMapC :: FunCodeV -> Vector (Complex Double) -> Vector (Complex Double) |
256 | vectorMapC oper = vectorMapAux c_vectorMapC (fromei oper) | 258 | vectorMapC oper = vectorMapAux c_vectorMapC (fromei oper) |
257 | 259 | ||
258 | foreign import ccall unsafe "mapC" c_vectorMapC :: CInt -> TVV (Complex Double) | 260 | foreign import ccall unsafe "mapC" c_vectorMapC :: Int32 -> TVV (Complex Double) |
259 | 261 | ||
260 | -- | map of real vectors with given function | 262 | -- | map of real vectors with given function |
261 | vectorMapF :: FunCodeV -> Vector Float -> Vector Float | 263 | vectorMapF :: FunCodeV -> Vector Float -> Vector Float |
262 | vectorMapF = vectorMapAux c_vectorMapF | 264 | vectorMapF = vectorMapAux c_vectorMapF |
263 | 265 | ||
264 | foreign import ccall unsafe "mapF" c_vectorMapF :: CInt -> TVV Float | 266 | foreign import ccall unsafe "mapF" c_vectorMapF :: Int32 -> TVV Float |
265 | 267 | ||
266 | -- | map of real vectors with given function | 268 | -- | map of real vectors with given function |
267 | vectorMapQ :: FunCodeV -> Vector (Complex Float) -> Vector (Complex Float) | 269 | vectorMapQ :: FunCodeV -> Vector (Complex Float) -> Vector (Complex Float) |
268 | vectorMapQ = vectorMapAux c_vectorMapQ | 270 | vectorMapQ = vectorMapAux c_vectorMapQ |
269 | 271 | ||
270 | foreign import ccall unsafe "mapQ" c_vectorMapQ :: CInt -> TVV (Complex Float) | 272 | foreign import ccall unsafe "mapQ" c_vectorMapQ :: Int32 -> TVV (Complex Float) |
271 | 273 | ||
272 | -- | map of real vectors with given function | 274 | -- | map of real vectors with given function |
273 | vectorMapI :: FunCodeV -> Vector CInt -> Vector CInt | 275 | vectorMapI :: FunCodeV -> Vector Int32 -> Vector Int32 |
274 | vectorMapI = vectorMapAux c_vectorMapI | 276 | vectorMapI = vectorMapAux c_vectorMapI |
275 | 277 | ||
276 | foreign import ccall unsafe "mapI" c_vectorMapI :: CInt -> TVV CInt | 278 | foreign import ccall unsafe "mapI" c_vectorMapI :: Int32 -> TVV Int32 |
277 | 279 | ||
278 | -- | map of real vectors with given function | 280 | -- | map of real vectors with given function |
279 | vectorMapL :: FunCodeV -> Vector Z -> Vector Z | 281 | vectorMapL :: FunCodeV -> Vector Z -> Vector Z |
280 | vectorMapL = vectorMapAux c_vectorMapL | 282 | vectorMapL = vectorMapAux c_vectorMapL |
281 | 283 | ||
282 | foreign import ccall unsafe "mapL" c_vectorMapL :: CInt -> TVV Z | 284 | foreign import ccall unsafe "mapL" c_vectorMapL :: Int32 -> TVV Z |
283 | 285 | ||
284 | ------------------------------------------------------------------- | 286 | ------------------------------------------------------------------- |
285 | 287 | ||
@@ -287,37 +289,37 @@ foreign import ccall unsafe "mapL" c_vectorMapL :: CInt -> TVV Z | |||
287 | vectorMapValR :: FunCodeSV -> Double -> Vector Double -> Vector Double | 289 | vectorMapValR :: FunCodeSV -> Double -> Vector Double -> Vector Double |
288 | vectorMapValR oper = vectorMapValAux c_vectorMapValR (fromei oper) | 290 | vectorMapValR oper = vectorMapValAux c_vectorMapValR (fromei oper) |
289 | 291 | ||
290 | foreign import ccall unsafe "mapValR" c_vectorMapValR :: CInt -> Ptr Double -> TVV Double | 292 | foreign import ccall unsafe "mapValR" c_vectorMapValR :: Int32 -> Ptr Double -> TVV Double |
291 | 293 | ||
292 | -- | map of complex vectors with given function | 294 | -- | map of complex vectors with given function |
293 | vectorMapValC :: FunCodeSV -> Complex Double -> Vector (Complex Double) -> Vector (Complex Double) | 295 | vectorMapValC :: FunCodeSV -> Complex Double -> Vector (Complex Double) -> Vector (Complex Double) |
294 | vectorMapValC = vectorMapValAux c_vectorMapValC | 296 | vectorMapValC = vectorMapValAux c_vectorMapValC |
295 | 297 | ||
296 | foreign import ccall unsafe "mapValC" c_vectorMapValC :: CInt -> Ptr (Complex Double) -> TVV (Complex Double) | 298 | foreign import ccall unsafe "mapValC" c_vectorMapValC :: Int32 -> Ptr (Complex Double) -> TVV (Complex Double) |
297 | 299 | ||
298 | -- | map of real vectors with given function | 300 | -- | map of real vectors with given function |
299 | vectorMapValF :: FunCodeSV -> Float -> Vector Float -> Vector Float | 301 | vectorMapValF :: FunCodeSV -> Float -> Vector Float -> Vector Float |
300 | vectorMapValF oper = vectorMapValAux c_vectorMapValF (fromei oper) | 302 | vectorMapValF oper = vectorMapValAux c_vectorMapValF (fromei oper) |
301 | 303 | ||
302 | foreign import ccall unsafe "mapValF" c_vectorMapValF :: CInt -> Ptr Float -> TVV Float | 304 | foreign import ccall unsafe "mapValF" c_vectorMapValF :: Int32 -> Ptr Float -> TVV Float |
303 | 305 | ||
304 | -- | map of complex vectors with given function | 306 | -- | map of complex vectors with given function |
305 | vectorMapValQ :: FunCodeSV -> Complex Float -> Vector (Complex Float) -> Vector (Complex Float) | 307 | vectorMapValQ :: FunCodeSV -> Complex Float -> Vector (Complex Float) -> Vector (Complex Float) |
306 | vectorMapValQ oper = vectorMapValAux c_vectorMapValQ (fromei oper) | 308 | vectorMapValQ oper = vectorMapValAux c_vectorMapValQ (fromei oper) |
307 | 309 | ||
308 | foreign import ccall unsafe "mapValQ" c_vectorMapValQ :: CInt -> Ptr (Complex Float) -> TVV (Complex Float) | 310 | foreign import ccall unsafe "mapValQ" c_vectorMapValQ :: Int32 -> Ptr (Complex Float) -> TVV (Complex Float) |
309 | 311 | ||
310 | -- | map of real vectors with given function | 312 | -- | map of real vectors with given function |
311 | vectorMapValI :: FunCodeSV -> CInt -> Vector CInt -> Vector CInt | 313 | vectorMapValI :: FunCodeSV -> Int32 -> Vector Int32 -> Vector Int32 |
312 | vectorMapValI oper = vectorMapValAux c_vectorMapValI (fromei oper) | 314 | vectorMapValI oper = vectorMapValAux c_vectorMapValI (fromei oper) |
313 | 315 | ||
314 | foreign import ccall unsafe "mapValI" c_vectorMapValI :: CInt -> Ptr CInt -> TVV CInt | 316 | foreign import ccall unsafe "mapValI" c_vectorMapValI :: Int32 -> Ptr Int32 -> TVV Int32 |
315 | 317 | ||
316 | -- | map of real vectors with given function | 318 | -- | map of real vectors with given function |
317 | vectorMapValL :: FunCodeSV -> Z -> Vector Z -> Vector Z | 319 | vectorMapValL :: FunCodeSV -> Z -> Vector Z -> Vector Z |
318 | vectorMapValL oper = vectorMapValAux c_vectorMapValL (fromei oper) | 320 | vectorMapValL oper = vectorMapValAux c_vectorMapValL (fromei oper) |
319 | 321 | ||
320 | foreign import ccall unsafe "mapValL" c_vectorMapValL :: CInt -> Ptr Z -> TVV Z | 322 | foreign import ccall unsafe "mapValL" c_vectorMapValL :: Int32 -> Ptr Z -> TVV Z |
321 | 323 | ||
322 | 324 | ||
323 | ------------------------------------------------------------------- | 325 | ------------------------------------------------------------------- |
@@ -328,42 +330,42 @@ type TVVV t = t :> t :> t :> Ok | |||
328 | vectorZipR :: FunCodeVV -> Vector Double -> Vector Double -> Vector Double | 330 | vectorZipR :: FunCodeVV -> Vector Double -> Vector Double -> Vector Double |
329 | vectorZipR = vectorZipAux c_vectorZipR | 331 | vectorZipR = vectorZipAux c_vectorZipR |
330 | 332 | ||
331 | foreign import ccall unsafe "zipR" c_vectorZipR :: CInt -> TVVV Double | 333 | foreign import ccall unsafe "zipR" c_vectorZipR :: Int32 -> TVVV Double |
332 | 334 | ||
333 | -- | elementwise operation on complex vectors | 335 | -- | elementwise operation on complex vectors |
334 | vectorZipC :: FunCodeVV -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) | 336 | vectorZipC :: FunCodeVV -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) |
335 | vectorZipC = vectorZipAux c_vectorZipC | 337 | vectorZipC = vectorZipAux c_vectorZipC |
336 | 338 | ||
337 | foreign import ccall unsafe "zipC" c_vectorZipC :: CInt -> TVVV (Complex Double) | 339 | foreign import ccall unsafe "zipC" c_vectorZipC :: Int32 -> TVVV (Complex Double) |
338 | 340 | ||
339 | -- | elementwise operation on real vectors | 341 | -- | elementwise operation on real vectors |
340 | vectorZipF :: FunCodeVV -> Vector Float -> Vector Float -> Vector Float | 342 | vectorZipF :: FunCodeVV -> Vector Float -> Vector Float -> Vector Float |
341 | vectorZipF = vectorZipAux c_vectorZipF | 343 | vectorZipF = vectorZipAux c_vectorZipF |
342 | 344 | ||
343 | foreign import ccall unsafe "zipF" c_vectorZipF :: CInt -> TVVV Float | 345 | foreign import ccall unsafe "zipF" c_vectorZipF :: Int32 -> TVVV Float |
344 | 346 | ||
345 | -- | elementwise operation on complex vectors | 347 | -- | elementwise operation on complex vectors |
346 | vectorZipQ :: FunCodeVV -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float) | 348 | vectorZipQ :: FunCodeVV -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float) |
347 | vectorZipQ = vectorZipAux c_vectorZipQ | 349 | vectorZipQ = vectorZipAux c_vectorZipQ |
348 | 350 | ||
349 | foreign import ccall unsafe "zipQ" c_vectorZipQ :: CInt -> TVVV (Complex Float) | 351 | foreign import ccall unsafe "zipQ" c_vectorZipQ :: Int32 -> TVVV (Complex Float) |
350 | 352 | ||
351 | -- | elementwise operation on CInt vectors | 353 | -- | elementwise operation on Int32 vectors |
352 | vectorZipI :: FunCodeVV -> Vector CInt -> Vector CInt -> Vector CInt | 354 | vectorZipI :: FunCodeVV -> Vector Int32 -> Vector Int32 -> Vector Int32 |
353 | vectorZipI = vectorZipAux c_vectorZipI | 355 | vectorZipI = vectorZipAux c_vectorZipI |
354 | 356 | ||
355 | foreign import ccall unsafe "zipI" c_vectorZipI :: CInt -> TVVV CInt | 357 | foreign import ccall unsafe "zipI" c_vectorZipI :: Int32 -> TVVV Int32 |
356 | 358 | ||
357 | -- | elementwise operation on CInt vectors | 359 | -- | elementwise operation on Int32 vectors |
358 | vectorZipL :: FunCodeVV -> Vector Z -> Vector Z -> Vector Z | 360 | vectorZipL :: FunCodeVV -> Vector Z -> Vector Z -> Vector Z |
359 | vectorZipL = vectorZipAux c_vectorZipL | 361 | vectorZipL = vectorZipAux c_vectorZipL |
360 | 362 | ||
361 | foreign import ccall unsafe "zipL" c_vectorZipL :: CInt -> TVVV Z | 363 | foreign import ccall unsafe "zipL" c_vectorZipL :: Int32 -> TVVV Z |
362 | 364 | ||
363 | -------------------------------------------------------------------------------- | 365 | -------------------------------------------------------------------------------- |
364 | 366 | ||
365 | foreign import ccall unsafe "vectorScan" c_vectorScan | 367 | foreign import ccall unsafe "vectorScan" c_vectorScan |
366 | :: CString -> Ptr CInt -> Ptr (Ptr Double) -> IO CInt | 368 | :: CString -> Ptr Int32 -> Ptr (Ptr Double) -> IO Int32 |
367 | 369 | ||
368 | vectorScan :: FilePath -> IO (Vector Double) | 370 | vectorScan :: FilePath -> IO (Vector Double) |
369 | vectorScan s = do | 371 | vectorScan s = do |
@@ -401,7 +403,7 @@ randomVector seed dist n = unsafePerformIO $ do | |||
401 | (r # id) (c_random_vector (fi seed) ((fi.fromEnum) dist)) #|"randomVector" | 403 | (r # id) (c_random_vector (fi seed) ((fi.fromEnum) dist)) #|"randomVector" |
402 | return r | 404 | return r |
403 | 405 | ||
404 | foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> Double :> Ok | 406 | foreign import ccall unsafe "random_vector" c_random_vector :: Int32 -> Int32 -> Double :> Ok |
405 | 407 | ||
406 | -------------------------------------------------------------------------------- | 408 | -------------------------------------------------------------------------------- |
407 | 409 | ||
@@ -426,7 +428,7 @@ range n = unsafePerformIO $ do | |||
426 | (r # id) c_range_vector #|"range" | 428 | (r # id) c_range_vector #|"range" |
427 | return r | 429 | return r |
428 | 430 | ||
429 | foreign import ccall unsafe "range_vector" c_range_vector :: CInt :> Ok | 431 | foreign import ccall unsafe "range_vector" c_range_vector :: Int32 :> Ok |
430 | 432 | ||
431 | 433 | ||
432 | float2DoubleV :: Vector Float -> Vector Double | 434 | float2DoubleV :: Vector Float -> Vector Double |
@@ -435,10 +437,10 @@ float2DoubleV = tog c_float2double | |||
435 | double2FloatV :: Vector Double -> Vector Float | 437 | double2FloatV :: Vector Double -> Vector Float |
436 | double2FloatV = tog c_double2float | 438 | double2FloatV = tog c_double2float |
437 | 439 | ||
438 | double2IntV :: Vector Double -> Vector CInt | 440 | double2IntV :: Vector Double -> Vector Int32 |
439 | double2IntV = tog c_double2int | 441 | double2IntV = tog c_double2int |
440 | 442 | ||
441 | int2DoubleV :: Vector CInt -> Vector Double | 443 | int2DoubleV :: Vector Int32 -> Vector Double |
442 | int2DoubleV = tog c_int2double | 444 | int2DoubleV = tog c_int2double |
443 | 445 | ||
444 | double2longV :: Vector Double -> Vector Z | 446 | double2longV :: Vector Double -> Vector Z |
@@ -448,10 +450,10 @@ long2DoubleV :: Vector Z -> Vector Double | |||
448 | long2DoubleV = tog c_long2double | 450 | long2DoubleV = tog c_long2double |
449 | 451 | ||
450 | 452 | ||
451 | float2IntV :: Vector Float -> Vector CInt | 453 | float2IntV :: Vector Float -> Vector Int32 |
452 | float2IntV = tog c_float2int | 454 | float2IntV = tog c_float2int |
453 | 455 | ||
454 | int2floatV :: Vector CInt -> Vector Float | 456 | int2floatV :: Vector Int32 -> Vector Float |
455 | int2floatV = tog c_int2float | 457 | int2floatV = tog c_int2float |
456 | 458 | ||
457 | int2longV :: Vector I -> Vector Z | 459 | int2longV :: Vector I -> Vector Z |
@@ -462,7 +464,7 @@ long2intV = tog c_long2int | |||
462 | 464 | ||
463 | 465 | ||
464 | tog :: (Storable t, Storable a) | 466 | tog :: (Storable t, Storable a) |
465 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | 467 | => (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> Vector t -> Vector a |
466 | tog f v = unsafePerformIO $ do | 468 | tog f v = unsafePerformIO $ do |
467 | r <- createVector (dim v) | 469 | r <- createVector (dim v) |
468 | (v #! r) f #|"tog" | 470 | (v #! r) f #|"tog" |
@@ -470,12 +472,12 @@ tog f v = unsafePerformIO $ do | |||
470 | 472 | ||
471 | foreign import ccall unsafe "float2double" c_float2double :: Float :> Double :> Ok | 473 | foreign import ccall unsafe "float2double" c_float2double :: Float :> Double :> Ok |
472 | foreign import ccall unsafe "double2float" c_double2float :: Double :> Float :> Ok | 474 | foreign import ccall unsafe "double2float" c_double2float :: Double :> Float :> Ok |
473 | foreign import ccall unsafe "int2double" c_int2double :: CInt :> Double :> Ok | 475 | foreign import ccall unsafe "int2double" c_int2double :: Int32 :> Double :> Ok |
474 | foreign import ccall unsafe "double2int" c_double2int :: Double :> CInt :> Ok | 476 | foreign import ccall unsafe "double2int" c_double2int :: Double :> Int32 :> Ok |
475 | foreign import ccall unsafe "long2double" c_long2double :: Z :> Double :> Ok | 477 | foreign import ccall unsafe "long2double" c_long2double :: Z :> Double :> Ok |
476 | foreign import ccall unsafe "double2long" c_double2long :: Double :> Z :> Ok | 478 | foreign import ccall unsafe "double2long" c_double2long :: Double :> Z :> Ok |
477 | foreign import ccall unsafe "int2float" c_int2float :: CInt :> Float :> Ok | 479 | foreign import ccall unsafe "int2float" c_int2float :: Int32 :> Float :> Ok |
478 | foreign import ccall unsafe "float2int" c_float2int :: Float :> CInt :> Ok | 480 | foreign import ccall unsafe "float2int" c_float2int :: Float :> Int32 :> Ok |
479 | foreign import ccall unsafe "int2long" c_int2long :: I :> Z :> Ok | 481 | foreign import ccall unsafe "int2long" c_int2long :: I :> Z :> Ok |
480 | foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok | 482 | foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok |
481 | 483 | ||
@@ -483,7 +485,7 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok | |||
483 | --------------------------------------------------------------- | 485 | --------------------------------------------------------------- |
484 | 486 | ||
485 | stepg :: (Storable t, Storable a) | 487 | stepg :: (Storable t, Storable a) |
486 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | 488 | => (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> Vector t -> Vector a |
487 | stepg f v = unsafePerformIO $ do | 489 | stepg f v = unsafePerformIO $ do |
488 | r <- createVector (dim v) | 490 | r <- createVector (dim v) |
489 | (v #! r) f #|"step" | 491 | (v #! r) f #|"step" |
@@ -495,7 +497,7 @@ stepD = stepg c_stepD | |||
495 | stepF :: Vector Float -> Vector Float | 497 | stepF :: Vector Float -> Vector Float |
496 | stepF = stepg c_stepF | 498 | stepF = stepg c_stepF |
497 | 499 | ||
498 | stepI :: Vector CInt -> Vector CInt | 500 | stepI :: Vector Int32 -> Vector Int32 |
499 | stepI = stepg c_stepI | 501 | stepI = stepg c_stepI |
500 | 502 | ||
501 | stepL :: Vector Z -> Vector Z | 503 | stepL :: Vector Z -> Vector Z |
@@ -504,13 +506,13 @@ stepL = stepg c_stepL | |||
504 | 506 | ||
505 | foreign import ccall unsafe "stepF" c_stepF :: TVV Float | 507 | foreign import ccall unsafe "stepF" c_stepF :: TVV Float |
506 | foreign import ccall unsafe "stepD" c_stepD :: TVV Double | 508 | foreign import ccall unsafe "stepD" c_stepD :: TVV Double |
507 | foreign import ccall unsafe "stepI" c_stepI :: TVV CInt | 509 | foreign import ccall unsafe "stepI" c_stepI :: TVV Int32 |
508 | foreign import ccall unsafe "stepL" c_stepL :: TVV Z | 510 | foreign import ccall unsafe "stepL" c_stepL :: TVV Z |
509 | 511 | ||
510 | -------------------------------------------------------------------------------- | 512 | -------------------------------------------------------------------------------- |
511 | 513 | ||
512 | conjugateAux :: (Storable t, Storable a) | 514 | conjugateAux :: (Storable t, Storable a) |
513 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | 515 | => (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> Vector t -> Vector a |
514 | conjugateAux fun x = unsafePerformIO $ do | 516 | conjugateAux fun x = unsafePerformIO $ do |
515 | v <- createVector (dim x) | 517 | v <- createVector (dim x) |
516 | (x #! v) fun #|"conjugateAux" | 518 | (x #! v) fun #|"conjugateAux" |
@@ -536,22 +538,29 @@ cloneVector v = do | |||
536 | 538 | ||
537 | -------------------------------------------------------------------------------- | 539 | -------------------------------------------------------------------------------- |
538 | 540 | ||
539 | constantAux :: (Storable a1, Storable a) | 541 | constantAux :: Storable a => a -> Int -> Vector a |
540 | => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a | 542 | constantAux x n = unsafePerformIO $ do |
541 | constantAux fun x n = unsafePerformIO $ do | ||
542 | v <- createVector n | 543 | v <- createVector n |
543 | px <- newArray [x] | 544 | px <- newArray [x] |
544 | (v # id) (fun px) #|"constantAux" | 545 | (v # id) (constantStorable px) #|"constantAux" |
545 | free px | 546 | free px |
546 | return v | 547 | return v |
547 | 548 | ||
549 | constantStorable :: Storable a => Ptr a -> Int32 -> Ptr a -> IO Int32 | ||
550 | constantStorable pval n p = do | ||
551 | val <- peek pval | ||
552 | ($ 0) $ fix $ \iloop i -> when (i<n) $ do | ||
553 | pokeElemOff p (fromIntegral i) val | ||
554 | iloop $! succ i | ||
555 | return 0 | ||
556 | |||
548 | type TConst t = Ptr t -> t :> Ok | 557 | type TConst t = Ptr t -> t :> Ok |
549 | 558 | ||
550 | foreign import ccall unsafe "constantF" cconstantF :: TConst Float | 559 | foreign import ccall unsafe "constantF" cconstantF :: TConst Float |
551 | foreign import ccall unsafe "constantR" cconstantR :: TConst Double | 560 | foreign import ccall unsafe "constantR" cconstantR :: TConst Double |
552 | foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float) | 561 | foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float) |
553 | foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double) | 562 | foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double) |
554 | foreign import ccall unsafe "constantI" cconstantI :: TConst CInt | 563 | foreign import ccall unsafe "constantI" cconstantI :: TConst Int32 |
555 | foreign import ccall unsafe "constantL" cconstantL :: TConst Z | 564 | foreign import ccall unsafe "constantL" cconstantL :: TConst Z |
556 | 565 | ||
557 | ---------------------------------------------------------------------- | 566 | ---------------------------------------------------------------------- |
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 9670187..a0a23bd 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs | |||
@@ -167,7 +167,7 @@ module Numeric.LinearAlgebra ( | |||
167 | haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, | 167 | haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, |
168 | iC, sym, mTm, trustSym, unSym, | 168 | iC, sym, mTm, trustSym, unSym, |
169 | -- * Auxiliary classes | 169 | -- * Auxiliary classes |
170 | Element, Container, Product, Numeric, LSDiv, Herm, | 170 | Container, Product, Numeric, LSDiv, Herm, |
171 | Complexable, RealElement, | 171 | Complexable, RealElement, |
172 | RealOf, ComplexOf, SingleOf, DoubleOf, | 172 | RealOf, ComplexOf, SingleOf, DoubleOf, |
173 | IndexOf, | 173 | IndexOf, |