summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Crayne <joe@jerkface.net>2019-08-08 02:22:30 -0400
committerJoe Crayne <joe@jerkface.net>2019-08-08 22:47:46 -0400
commitbadcbdfddc4be31fc79a6df4553795af18069efe (patch)
tree90c38bd8793b53a5e6f00049eb78acaa8d88d711
parentd844a145f2e8808c9f75cd99c673d5f5c8960bf2 (diff)
Removed the Element class.tower
-rw-r--r--packages/base/hmatrix.cabal1
-rw-r--r--packages/base/src/Internal/Algorithms.hs5
-rw-r--r--packages/base/src/Internal/Container.hs96
-rw-r--r--packages/base/src/Internal/Conversion.hs7
-rw-r--r--packages/base/src/Internal/Convolution.hs7
-rw-r--r--packages/base/src/Internal/Devel.hs21
-rw-r--r--packages/base/src/Internal/Element.hs84
-rw-r--r--packages/base/src/Internal/Extract.hs145
-rw-r--r--packages/base/src/Internal/IO.hs9
-rw-r--r--packages/base/src/Internal/LAPACK.hs19
-rw-r--r--packages/base/src/Internal/Matrix.hs307
-rw-r--r--packages/base/src/Internal/Modular.hs6
-rw-r--r--packages/base/src/Internal/Numeric.hs80
-rw-r--r--packages/base/src/Internal/ST.hs131
-rw-r--r--packages/base/src/Internal/Sparse.hs16
-rw-r--r--packages/base/src/Internal/Util.hs15
-rw-r--r--packages/base/src/Internal/Vector.hs10
-rw-r--r--packages/base/src/Internal/Vectorized.hs133
-rw-r--r--packages/base/src/Numeric/LinearAlgebra.hs2
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs7
20 files changed, 725 insertions, 376 deletions
diff --git a/packages/base/hmatrix.cabal b/packages/base/hmatrix.cabal
index 4dc62e5..476a293 100644
--- a/packages/base/hmatrix.cabal
+++ b/packages/base/hmatrix.cabal
@@ -66,6 +66,7 @@ library
66 Internal.Devel 66 Internal.Devel
67 Internal.Vectorized 67 Internal.Vectorized
68 Internal.Matrix 68 Internal.Matrix
69 Internal.Extract
69 Internal.ST 70 Internal.ST
70 Internal.IO 71 Internal.IO
71 Internal.Element 72 Internal.Element
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs
index f5bddc6..aa51792 100644
--- a/packages/base/src/Internal/Algorithms.hs
+++ b/packages/base/src/Internal/Algorithms.hs
@@ -39,6 +39,7 @@ import qualified Data.Vector.Storable as Vector
39import Internal.ST 39import Internal.ST
40import Internal.Vectorized(range) 40import Internal.Vectorized(range)
41import Control.DeepSeq 41import Control.DeepSeq
42import Foreign.Storable
42 43
43{- | Generic linear algebra functions for double precision real and complex matrices. 44{- | Generic linear algebra functions for double precision real and complex matrices.
44 45
@@ -742,7 +743,7 @@ pinvTol t m = v' `mXm` diag s' `mXm` ctrans u' where
742 743
743 744
744-- | Numeric rank of a matrix from the SVD decomposition. 745-- | Numeric rank of a matrix from the SVD decomposition.
745rankSVD :: Element t 746rankSVD :: Storable t
746 => Double -- ^ numeric zero (e.g. 1*'eps') 747 => Double -- ^ numeric zero (e.g. 1*'eps')
747 -> Matrix t -- ^ input matrix m 748 -> Matrix t -- ^ input matrix m
748 -> Vector Double -- ^ 'sv' of m 749 -> Vector Double -- ^ 'sv' of m
@@ -1003,7 +1004,7 @@ fixPerm' s = res $ mutable f s0
1003 s0 = reshape 1 (range (length s)) 1004 s0 = reshape 1 (range (length s))
1004 res = flatten . fst 1005 res = flatten . fst
1005 swap m i j = rowOper (SWAP i j AllCols) m 1006 swap m i j = rowOper (SWAP i j AllCols) m
1006 f :: (Num t, Element t) => (Int, Int) -> STMatrix s t -> ST s () -- needed because of TypeFamilies 1007 f :: (Num t, Storable t) => (Int, Int) -> STMatrix s t -> ST s () -- needed because of TypeFamilies
1007 f _ p = sequence_ $ zipWith (swap p) [0..] s 1008 f _ p = sequence_ $ zipWith (swap p) [0..] s
1008 1009
1009triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]] 1010triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]]
diff --git a/packages/base/src/Internal/Container.hs b/packages/base/src/Internal/Container.hs
index 41b8214..0f2e7d5 100644
--- a/packages/base/src/Internal/Container.hs
+++ b/packages/base/src/Internal/Container.hs
@@ -4,6 +4,8 @@
4{-# LANGUAGE MultiParamTypeClasses #-} 4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-} 5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE UndecidableInstances #-} 6{-# LANGUAGE UndecidableInstances #-}
7{-# LANGUAGE PatternSynonyms #-}
8{-# LANGUAGE ScopedTypeVariables #-}
7 9
8{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-} 10{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-}
9 11
@@ -30,8 +32,15 @@ module Internal.Container where
30import Internal.Vector 32import Internal.Vector
31import Internal.Matrix 33import Internal.Matrix
32import Internal.Element 34import Internal.Element
35import Internal.Extract(requires,pattern BAD_SIZE)
33import Internal.Numeric 36import Internal.Numeric
34import Internal.Algorithms(Field,linearSolveSVD,Herm,mTm) 37import Internal.Algorithms(Field,linearSolveSVD,Herm,mTm)
38import Control.Monad(when)
39import Data.Function
40import Data.Int
41import Foreign.Ptr
42import Foreign.Storable
43import Foreign.Marshal.Array
35#if MIN_VERSION_base(4,11,0) 44#if MIN_VERSION_base(4,11,0)
36import Prelude hiding ((<>)) 45import Prelude hiding ((<>))
37#endif 46#endif
@@ -227,7 +236,7 @@ meanCov x = (med,cov) where
227 236
228-------------------------------------------------------------------------------- 237--------------------------------------------------------------------------------
229 238
230sortVector :: (Ord t, Element t) => Vector t -> Vector t 239sortVector :: (Ord t, Storable t) => Vector t -> Vector t
231sortVector = sortV 240sortVector = sortV
232 241
233{- | 242{- |
@@ -248,7 +257,7 @@ sortVector = sortV
248-2.20 0.11 -1.58 -0.01 0.19 -0.29 1.04 1.06 -2.09 -0.75 257-2.20 0.11 -1.58 -0.01 0.19 -0.29 1.04 1.06 -2.09 -0.75
249 258
250-} 259-}
251sortIndex :: (Ord t, Element t) => Vector t -> Vector I 260sortIndex :: (Ord t, Storable t) => Vector t -> Vector I
252sortIndex = sortI 261sortIndex = sortI
253 262
254ccompare :: (Ord t, Container c t) => c t -> c t -> c I 263ccompare :: (Ord t, Container c t) => c t -> c t -> c I
@@ -296,10 +305,91 @@ The indexes are autoconformable.
296 , 10, 16, 22 ] 305 , 10, 16, 22 ]
297 306
298-} 307-}
299remap :: Element t => Matrix I -> Matrix I -> Matrix t -> Matrix t 308remap :: Storable t => Matrix I -> Matrix I -> Matrix t -> Matrix t
300remap i j m 309remap i j m
301 | minElement i >= 0 && maxElement i < fromIntegral (rows m) && 310 | minElement i >= 0 && maxElement i < fromIntegral (rows m) &&
302 minElement j >= 0 && maxElement j < fromIntegral (cols m) = remapM i' j' m 311 minElement j >= 0 && maxElement j < fromIntegral (cols m) = remapM i' j' m
303 | otherwise = error $ "out of range index in remap" 312 | otherwise = error $ "out of range index in remap"
304 where 313 where
305 [i',j'] = conformMs [i,j] 314 [i',j'] = conformMs [i,j]
315
316sortI :: (Storable a, Ord a) => Vector a -> Vector Int32
317sortI = sortG sort_index
318
319type C_Compare a = Ptr a -> Ptr a -> IO Int32
320
321foreign import ccall "wrapper" wrapCompare :: C_Compare a -> IO (FunPtr (C_Compare a))
322
323foreign import ccall "qsort"
324 c_qsort :: Ptr a -- ^ base
325 -> Word -- ^ nmemb
326 -> Word -- ^ size
327 -> FunPtr (C_Compare a) -- ^ compar
328 -> IO ()
329
330sizeOfElem :: forall a. Storable a => Ptr a -> Int
331sizeOfElem _ = sizeOf (undefined :: a)
332
333sort_index :: (Storable a, Ord a) =>
334 Int32 -> Ptr a
335 -> Int32 -> Ptr Int32
336 -> IO Int32
337sort_index vn vp rn rp = do
338 requires (vn == rn) BAD_SIZE $ do
339 comp <- wrapCompare $ \ap bp -> do
340 a <- peekElemOff vp . fromIntegral =<< peek (ap :: Ptr Int32)
341 b <- peekElemOff vp . fromIntegral =<< peek bp
342 return $ case compare a b of
343 LT -> -1
344 GT -> 1
345 EQ -> 0
346 sequence_ [ pokeElemOff rp (fromIntegral i) i | i <- [0 .. rn-1] ]
347 c_qsort rp (fromIntegral rn) 4 comp
348 freeHaskellFunPtr comp
349 return 0
350
351sortV :: (Storable a, Ord a) => Vector a -> Vector a
352sortV = sortG sortStorable
353
354sortStorable :: (Storable a, Ord a) =>
355 Int32 -> Ptr a
356 -> Int32 -> Ptr a
357 -> IO Int32
358sortStorable vn vp rn rp = do
359 requires (vn == rn) BAD_SIZE $ do
360 copyArray rp vp (fromIntegral vn * sizeOfElem vp)
361 comp <- wrapCompare $ \ap bp -> do
362 a <- peek ap
363 b <- peek bp
364 return $ case compare a b of
365 LT -> -1
366 GT -> 1
367 EQ -> 0
368 c_qsort rp (fromIntegral rn) (fromIntegral $ sizeOfElem rp) comp
369 freeHaskellFunPtr comp
370 return 0
371
372remapM :: Storable a => Matrix Int32 -> Matrix Int32 -> Matrix a -> Matrix a
373remapM = remapG remapStorable
374
375remapStorable :: Storable a =>
376 Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int32 -- i
377 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int32 -- j
378 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -- m
379 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -- r
380 -> IO Int32
381remapStorable ir ic iXr iXc ip
382 jr jc jXr jXc jp
383 mr mc mXr mXc mp
384 rr rc rXr rXc rp = do
385 requires (ir==jr && ic==jc && ir==rr && ic==rc) BAD_SIZE $ do
386 ($ 0) $ fix $ \aloop a -> when (a<rr) $ do
387 ($ 0) $ fix $ \bloop b -> when (b<rc) $ do
388 iab <- peekElemOff ip (fromIntegral $ iXr*a + iXc*b)
389 jab <- peekElemOff jp (fromIntegral $ jXr*a + jXc*b)
390 when (0 <= iab && iab < mr && 0 <= jab && jab < mc) $
391 pokeElemOff rp (fromIntegral $ rXr*a + rXc*b)
392 =<< peekElemOff mp (fromIntegral $ mXr*iab + mXc*jab)
393 bloop (succ b)
394 aloop (succ a)
395 return 0
diff --git a/packages/base/src/Internal/Conversion.hs b/packages/base/src/Internal/Conversion.hs
index 4541ec4..7eb8ec7 100644
--- a/packages/base/src/Internal/Conversion.hs
+++ b/packages/base/src/Internal/Conversion.hs
@@ -28,11 +28,12 @@ import Internal.Matrix
28import Internal.Vectorized 28import Internal.Vectorized
29import Data.Complex 29import Data.Complex
30import Control.Arrow((***)) 30import Control.Arrow((***))
31import Foreign.Storable
31 32
32------------------------------------------------------------------- 33-------------------------------------------------------------------
33 34
34-- | Supported single-double precision type pairs 35-- | Supported single-double precision type pairs
35class (Element s, Element d) => Precision s d | s -> d, d -> s where 36class (Storable s, Storable d) => Precision s d | s -> d, d -> s where
36 double2FloatG :: Vector d -> Vector s 37 double2FloatG :: Vector d -> Vector s
37 float2DoubleG :: Vector s -> Vector d 38 float2DoubleG :: Vector s -> Vector d
38 39
@@ -50,7 +51,7 @@ instance Precision I Z where
50 51
51 52
52-- | Supported real types 53-- | Supported real types
53class (Element t, Element (Complex t), RealFloat t) 54class (Storable t, Storable (Complex t), RealFloat t)
54 => RealElement t 55 => RealElement t
55 56
56instance RealElement Double 57instance RealElement Double
@@ -69,7 +70,7 @@ class Complexable c where
69instance Complexable Vector where 70instance Complexable Vector where
70 toComplex' = toComplexV 71 toComplex' = toComplexV
71 fromComplex' = fromComplexV 72 fromComplex' = fromComplexV
72 comp' v = toComplex' (v,constantD 0 (dim v)) 73 comp' v = toComplex' (v,constantAux 0 (dim v))
73 single' = double2FloatG 74 single' = double2FloatG
74 double' = float2DoubleG 75 double' = float2DoubleG
75 76
diff --git a/packages/base/src/Internal/Convolution.hs b/packages/base/src/Internal/Convolution.hs
index 75fbef4..ae8ebc6 100644
--- a/packages/base/src/Internal/Convolution.hs
+++ b/packages/base/src/Internal/Convolution.hs
@@ -24,12 +24,13 @@ import Internal.Numeric
24import Internal.Element 24import Internal.Element
25import Internal.Conversion 25import Internal.Conversion
26import Internal.Container 26import Internal.Container
27import Foreign.Storable
27#if MIN_VERSION_base(4,11,0) 28#if MIN_VERSION_base(4,11,0)
28import Prelude hiding ((<>)) 29import Prelude hiding ((<>))
29#endif 30#endif
30 31
31 32
32vectSS :: Element t => Int -> Vector t -> Matrix t 33vectSS :: Storable t => Int -> Vector t -> Matrix t
33vectSS n v = fromRows [ subVector k n v | k <- [0 .. dim v - n] ] 34vectSS n v = fromRows [ subVector k n v | k <- [0 .. dim v - n] ]
34 35
35 36
@@ -82,7 +83,7 @@ corrMin ker v
82 83
83 84
84 85
85matSS :: Element t => Int -> Matrix t -> [Matrix t] 86matSS :: Storable t => Int -> Matrix t -> [Matrix t]
86matSS dr m = map (reshape c) [ subVector (k*c) n v | k <- [0 .. r - dr] ] 87matSS dr m = map (reshape c) [ subVector (k*c) n v | k <- [0 .. r - dr] ]
87 where 88 where
88 v = flatten m 89 v = flatten m
@@ -155,7 +156,7 @@ conv2 k m
155 empty = r == 0 || c == 0 156 empty = r == 0 || c == 0
156 157
157 158
158separable :: Element t => (Vector t -> Vector t) -> Matrix t -> Matrix t 159separable :: Storable t => (Vector t -> Vector t) -> Matrix t -> Matrix t
159-- ^ matrix computation implemented as separated vector operations by rows and columns. 160-- ^ matrix computation implemented as separated vector operations by rows and columns.
160separable f = fromColumns . map f . toColumns . fromRows . map f . toRows 161separable f = fromColumns . map f . toColumns . fromRows . map f . toRows
161 162
diff --git a/packages/base/src/Internal/Devel.hs b/packages/base/src/Internal/Devel.hs
index f72d8aa..b0594d4 100644
--- a/packages/base/src/Internal/Devel.hs
+++ b/packages/base/src/Internal/Devel.hs
@@ -13,6 +13,7 @@ module Internal.Devel where
13 13
14 14
15import Control.Monad ( when ) 15import Control.Monad ( when )
16import Data.Int
16import Foreign.C.Types ( CInt ) 17import Foreign.C.Types ( CInt )
17--import Foreign.Storable.Complex () 18--import Foreign.Storable.Complex ()
18import Foreign.Ptr(Ptr) 19import Foreign.Ptr(Ptr)
@@ -28,7 +29,7 @@ infixl 0 //
28 29
29-- GSL error codes are <= 1024 30-- GSL error codes are <= 1024
30-- | error codes for the auxiliary functions required by the wrappers 31-- | error codes for the auxiliary functions required by the wrappers
31errorCode :: CInt -> String 32errorCode :: Int32 -> String
32errorCode 2000 = "bad size" 33errorCode 2000 = "bad size"
33errorCode 2001 = "bad function code" 34errorCode 2001 = "bad function code"
34errorCode 2002 = "memory problem" 35errorCode 2002 = "memory problem"
@@ -44,7 +45,7 @@ errorCode n = "code "++show n
44foreign import ccall unsafe "asm_finit" finit :: IO () 45foreign import ccall unsafe "asm_finit" finit :: IO ()
45 46
46-- | check the error code 47-- | check the error code
47check :: String -> IO CInt -> IO () 48check :: String -> IO Int32 -> IO ()
48check msg f = do 49check msg f = do
49-- finit 50-- finit
50 err <- f 51 err <- f
@@ -54,7 +55,7 @@ check msg f = do
54 55
55-- | postfix error code check 56-- | postfix error code check
56infixl 0 #| 57infixl 0 #|
57(#|) :: IO CInt -> String -> IO () 58(#|) :: IO Int32 -> String -> IO ()
58(#|) = flip check 59(#|) = flip check
59 60
60-- | Error capture and conversion to Maybe 61-- | Error capture and conversion to Maybe
@@ -65,12 +66,12 @@ mbCatch act = E.catch (Just `fmap` act) f
65 66
66-------------------------------------------------------------------------------- 67--------------------------------------------------------------------------------
67 68
68type CM b r = CInt -> CInt -> Ptr b -> r 69type CM b r = Int32 -> Int32 -> Ptr b -> r
69type CV b r = CInt -> Ptr b -> r 70type CV b r = Int32 -> Ptr b -> r
70type OM b r = CInt -> CInt -> CInt -> CInt -> Ptr b -> r 71type OM b r = Int32 -> Int32 -> Int32 -> Int32 -> Ptr b -> r
71 72
72type CIdxs r = CV CInt r 73type CIdxs r = CV Int32 r
73type Ok = IO CInt 74type Ok = IO Int32
74 75
75infixr 5 :>, ::>, ..> 76infixr 5 :>, ::>, ..>
76type (:>) t r = CV t r 77type (:>) t r = CV t r
@@ -87,8 +88,8 @@ class TransArray c
87 88
88instance Storable t => TransArray (Vector t) 89instance Storable t => TransArray (Vector t)
89 where 90 where
90 type Trans (Vector t) b = CInt -> Ptr t -> b 91 type Trans (Vector t) b = Int32 -> Ptr t -> b
91 type TransRaw (Vector t) b = CInt -> Ptr t -> b 92 type TransRaw (Vector t) b = Int32 -> Ptr t -> b
92 apply = avec 93 apply = avec
93 {-# INLINE apply #-} 94 {-# INLINE apply #-}
94 applyRaw = avec 95 applyRaw = avec
diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs
index 2e330ee..80eda8d 100644
--- a/packages/base/src/Internal/Element.hs
+++ b/packages/base/src/Internal/Element.hs
@@ -33,14 +33,14 @@ import Data.List.Split(chunksOf)
33import Foreign.Storable(Storable) 33import Foreign.Storable(Storable)
34import System.IO.Unsafe(unsafePerformIO) 34import System.IO.Unsafe(unsafePerformIO)
35import Control.Monad(liftM) 35import Control.Monad(liftM)
36import Foreign.C.Types(CInt) 36import Data.Int
37 37
38------------------------------------------------------------------- 38-------------------------------------------------------------------
39 39
40 40
41import Data.Binary 41import Data.Binary
42 42
43instance (Binary (Vector a), Element a) => Binary (Matrix a) where 43instance (Binary (Vector a), Storable a) => Binary (Matrix a) where
44 put m = do 44 put m = do
45 put (cols m) 45 put (cols m)
46 put (flatten m) 46 put (flatten m)
@@ -52,7 +52,7 @@ instance (Binary (Vector a), Element a) => Binary (Matrix a) where
52 52
53------------------------------------------------------------------- 53-------------------------------------------------------------------
54 54
55instance (Show a, Element a) => (Show (Matrix a)) where 55instance (Show a, Storable a) => (Show (Matrix a)) where
56 show m | rows m == 0 || cols m == 0 = sizes m ++" []" 56 show m | rows m == 0 || cols m == 0 = sizes m ++" []"
57 show m = (sizes m++) . dsp . map (map show) . toLists $ m 57 show m = (sizes m++) . dsp . map (map show) . toLists $ m
58 58
@@ -70,7 +70,7 @@ dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unw
70 70
71------------------------------------------------------------------ 71------------------------------------------------------------------
72 72
73instance (Element a, Read a) => Read (Matrix a) where 73instance (Storable a, Read a) => Read (Matrix a) where
74 readsPrec _ s = [((rs><cs) . read $ listnums, rest)] 74 readsPrec _ s = [((rs><cs) . read $ listnums, rest)]
75 where (thing,rest) = breakAt ']' s 75 where (thing,rest) = breakAt ']' s
76 (dims,listnums) = breakAt ')' thing 76 (dims,listnums) = breakAt ')' thing
@@ -133,13 +133,13 @@ ppext (DropLast n) = printf "DropLast %d" n
133 133
134-} 134-}
135infixl 9 ?? 135infixl 9 ??
136(??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t 136(??) :: Storable t => Matrix t -> (Extractor,Extractor) -> Matrix t
137 137
138minEl :: Vector CInt -> CInt 138minEl :: Vector Int32 -> Int32
139minEl = toScalarI Min 139minEl = toScalarI Min
140maxEl :: Vector CInt -> CInt 140maxEl :: Vector Int32 -> Int32
141maxEl = toScalarI Max 141maxEl = toScalarI Max
142cmodi :: Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt 142cmodi :: Int32 -> Vector Int32 -> Vector Int32
143cmodi = vectorMapValI ModVS 143cmodi = vectorMapValI ModVS
144 144
145extractError :: Matrix t1 -> (Extractor, Extractor) -> t 145extractError :: Matrix t1 -> (Extractor, Extractor) -> t
@@ -181,7 +181,7 @@ m ?? (e, TakeLast n) = m ?? (e, Drop (cols m - n))
181m ?? (DropLast n, e) = m ?? (Take (rows m - n), e) 181m ?? (DropLast n, e) = m ?? (Take (rows m - n), e)
182m ?? (e, DropLast n) = m ?? (e, Take (cols m - n)) 182m ?? (e, DropLast n) = m ?? (e, Take (cols m - n))
183 183
184m ?? (er,ec) = unsafePerformIO $ extractR (orderOf m) m moder rs modec cs 184m ?? (er,ec) = unsafePerformIO $ extractAux (orderOf m) m moder rs modec cs
185 where 185 where
186 (moder,rs) = mkExt (rows m) er 186 (moder,rs) = mkExt (rows m) er
187 (modec,cs) = mkExt (cols m) ec 187 (modec,cs) = mkExt (cols m) ec
@@ -209,14 +209,14 @@ common f = commonval . map f
209 209
210 210
211-- | creates a matrix from a vertical list of matrices 211-- | creates a matrix from a vertical list of matrices
212joinVert :: Element t => [Matrix t] -> Matrix t 212joinVert :: Storable t => [Matrix t] -> Matrix t
213joinVert [] = emptyM 0 0 213joinVert [] = emptyM 0 0
214joinVert ms = case common cols ms of 214joinVert ms = case common cols ms of
215 Nothing -> error "(impossible) joinVert on matrices with different number of columns" 215 Nothing -> error "(impossible) joinVert on matrices with different number of columns"
216 Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms) 216 Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms)
217 217
218-- | creates a matrix from a horizontal list of matrices 218-- | creates a matrix from a horizontal list of matrices
219joinHoriz :: Element t => [Matrix t] -> Matrix t 219joinHoriz :: Storable t => [Matrix t] -> Matrix t
220joinHoriz ms = trans. joinVert . map trans $ ms 220joinHoriz ms = trans. joinVert . map trans $ ms
221 221
222{- | Create a matrix from blocks given as a list of lists of matrices. 222{- | Create a matrix from blocks given as a list of lists of matrices.
@@ -240,13 +240,13 @@ disp = putStr . dispf 2
2403 3 3 3 3 0 0 3 0 0 2403 3 3 3 3 0 0 3 0 0
241 241
242-} 242-}
243fromBlocks :: Element t => [[Matrix t]] -> Matrix t 243fromBlocks :: Storable t => [[Matrix t]] -> Matrix t
244fromBlocks = fromBlocksRaw . adaptBlocks 244fromBlocks = fromBlocksRaw . adaptBlocks
245 245
246fromBlocksRaw :: Element t => [[Matrix t]] -> Matrix t 246fromBlocksRaw :: Storable t => [[Matrix t]] -> Matrix t
247fromBlocksRaw mms = joinVert . map joinHoriz $ mms 247fromBlocksRaw mms = joinVert . map joinHoriz $ mms
248 248
249adaptBlocks :: Element t => [[Matrix t]] -> [[Matrix t]] 249adaptBlocks :: Storable t => [[Matrix t]] -> [[Matrix t]]
250adaptBlocks ms = ms' where 250adaptBlocks ms = ms' where
251 bc = case common length ms of 251 bc = case common length ms of
252 Just c -> c 252 Just c -> c
@@ -258,7 +258,7 @@ adaptBlocks ms = ms' where
258 258
259 g [Just nr,Just nc] m 259 g [Just nr,Just nc] m
260 | nr == r && nc == c = m 260 | nr == r && nc == c = m
261 | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantD x (nr*nc)) 261 | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantAux x (nr*nc))
262 | r == 1 = fromRows (replicate nr (flatten m)) 262 | r == 1 = fromRows (replicate nr (flatten m))
263 | otherwise = fromColumns (replicate nc (flatten m)) 263 | otherwise = fromColumns (replicate nc (flatten m))
264 where 264 where
@@ -288,7 +288,7 @@ adaptBlocks ms = ms' where
288 , 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 ] 288 , 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 ]
289 289
290-} 290-}
291diagBlock :: (Element t, Num t) => [Matrix t] -> Matrix t 291diagBlock :: (Storable t, Num t) => [Matrix t] -> Matrix t
292diagBlock ms = fromBlocks $ zipWith f ms [0..] 292diagBlock ms = fromBlocks $ zipWith f ms [0..]
293 where 293 where
294 f m k = take n $ replicate k z ++ m : repeat z 294 f m k = take n $ replicate k z ++ m : repeat z
@@ -299,13 +299,13 @@ diagBlock ms = fromBlocks $ zipWith f ms [0..]
299 299
300 300
301-- | Reverse rows 301-- | Reverse rows
302flipud :: Element t => Matrix t -> Matrix t 302flipud :: Storable t => Matrix t -> Matrix t
303flipud m = extractRows [r-1,r-2 .. 0] $ m 303flipud m = extractRows [r-1,r-2 .. 0] $ m
304 where 304 where
305 r = rows m 305 r = rows m
306 306
307-- | Reverse columns 307-- | Reverse columns
308fliprl :: Element t => Matrix t -> Matrix t 308fliprl :: Storable t => Matrix t -> Matrix t
309fliprl m = extractColumns [c-1,c-2 .. 0] $ m 309fliprl m = extractColumns [c-1,c-2 .. 0] $ m
310 where 310 where
311 c = cols m 311 c = cols m
@@ -330,7 +330,7 @@ diagRect z v r c = ST.runSTMatrix $ do
330 return m 330 return m
331 331
332-- | extracts the diagonal from a rectangular matrix 332-- | extracts the diagonal from a rectangular matrix
333takeDiag :: (Element t) => Matrix t -> Vector t 333takeDiag :: (Storable t) => Matrix t -> Vector t
334takeDiag m = fromList [flatten m @> (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] 334takeDiag m = fromList [flatten m @> (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]]
335 335
336------------------------------------------------------------ 336------------------------------------------------------------
@@ -363,32 +363,32 @@ r >< c = f where
363 363
364---------------------------------------------------------------- 364----------------------------------------------------------------
365 365
366takeRows :: Element t => Int -> Matrix t -> Matrix t 366takeRows :: Storable t => Int -> Matrix t -> Matrix t
367takeRows n mt = subMatrix (0,0) (n, cols mt) mt 367takeRows n mt = subMatrix (0,0) (n, cols mt) mt
368 368
369-- | Creates a matrix with the last n rows of another matrix 369-- | Creates a matrix with the last n rows of another matrix
370takeLastRows :: Element t => Int -> Matrix t -> Matrix t 370takeLastRows :: Storable t => Int -> Matrix t -> Matrix t
371takeLastRows n mt = subMatrix (rows mt - n, 0) (n, cols mt) mt 371takeLastRows n mt = subMatrix (rows mt - n, 0) (n, cols mt) mt
372 372
373dropRows :: Element t => Int -> Matrix t -> Matrix t 373dropRows :: Storable t => Int -> Matrix t -> Matrix t
374dropRows n mt = subMatrix (n,0) (rows mt - n, cols mt) mt 374dropRows n mt = subMatrix (n,0) (rows mt - n, cols mt) mt
375 375
376-- | Creates a copy of a matrix without the last n rows 376-- | Creates a copy of a matrix without the last n rows
377dropLastRows :: Element t => Int -> Matrix t -> Matrix t 377dropLastRows :: Storable t => Int -> Matrix t -> Matrix t
378dropLastRows n mt = subMatrix (0,0) (rows mt - n, cols mt) mt 378dropLastRows n mt = subMatrix (0,0) (rows mt - n, cols mt) mt
379 379
380takeColumns :: Element t => Int -> Matrix t -> Matrix t 380takeColumns :: Storable t => Int -> Matrix t -> Matrix t
381takeColumns n mt = subMatrix (0,0) (rows mt, n) mt 381takeColumns n mt = subMatrix (0,0) (rows mt, n) mt
382 382
383-- |Creates a matrix with the last n columns of another matrix 383-- |Creates a matrix with the last n columns of another matrix
384takeLastColumns :: Element t => Int -> Matrix t -> Matrix t 384takeLastColumns :: Storable t => Int -> Matrix t -> Matrix t
385takeLastColumns n mt = subMatrix (0, cols mt - n) (rows mt, n) mt 385takeLastColumns n mt = subMatrix (0, cols mt - n) (rows mt, n) mt
386 386
387dropColumns :: Element t => Int -> Matrix t -> Matrix t 387dropColumns :: Storable t => Int -> Matrix t -> Matrix t
388dropColumns n mt = subMatrix (0,n) (rows mt, cols mt - n) mt 388dropColumns n mt = subMatrix (0,n) (rows mt, cols mt - n) mt
389 389
390-- | Creates a copy of a matrix without the last n columns 390-- | Creates a copy of a matrix without the last n columns
391dropLastColumns :: Element t => Int -> Matrix t -> Matrix t 391dropLastColumns :: Storable t => Int -> Matrix t -> Matrix t
392dropLastColumns n mt = subMatrix (0,0) (rows mt, cols mt - n) mt 392dropLastColumns n mt = subMatrix (0,0) (rows mt, cols mt - n) mt
393 393
394---------------------------------------------------------------- 394----------------------------------------------------------------
@@ -402,7 +402,7 @@ dropLastColumns n mt = subMatrix (0,0) (rows mt, cols mt - n) mt
402 , 5.0, 6.0 ] 402 , 5.0, 6.0 ]
403 403
404-} 404-}
405fromLists :: Element t => [[t]] -> Matrix t 405fromLists :: Storable t => [[t]] -> Matrix t
406fromLists = fromRows . map fromList 406fromLists = fromRows . map fromList
407 407
408-- | creates a 1-row matrix from a vector 408-- | creates a 1-row matrix from a vector
@@ -443,7 +443,7 @@ Hilbert matrix of order N:
443@hilb n = buildMatrix n n (\\(i,j)->1/(fromIntegral i + fromIntegral j +1))@ 443@hilb n = buildMatrix n n (\\(i,j)->1/(fromIntegral i + fromIntegral j +1))@
444 444
445-} 445-}
446buildMatrix :: Element a => Int -> Int -> ((Int, Int) -> a) -> Matrix a 446buildMatrix :: Storable a => Int -> Int -> ((Int, Int) -> a) -> Matrix a
447buildMatrix rc cc f = 447buildMatrix rc cc f =
448 fromLists $ map (map f) 448 fromLists $ map (map f)
449 $ map (\ ri -> map (\ ci -> (ri, ci)) [0 .. (cc - 1)]) [0 .. (rc - 1)] 449 $ map (\ ri -> map (\ ci -> (ri, ci)) [0 .. (cc - 1)]) [0 .. (rc - 1)]
@@ -458,11 +458,11 @@ fromArray2D m = (r><c) (elems m)
458 458
459 459
460-- | rearranges the rows of a matrix according to the order given in a list of integers. 460-- | rearranges the rows of a matrix according to the order given in a list of integers.
461extractRows :: Element t => [Int] -> Matrix t -> Matrix t 461extractRows :: Storable t => [Int] -> Matrix t -> Matrix t
462extractRows l m = m ?? (Pos (idxs l), All) 462extractRows l m = m ?? (Pos (idxs l), All)
463 463
464-- | rearranges the rows of a matrix according to the order given in a list of integers. 464-- | rearranges the rows of a matrix according to the order given in a list of integers.
465extractColumns :: Element t => [Int] -> Matrix t -> Matrix t 465extractColumns :: Storable t => [Int] -> Matrix t -> Matrix t
466extractColumns l m = m ?? (All, Pos (idxs l)) 466extractColumns l m = m ?? (All, Pos (idxs l))
467 467
468 468
@@ -476,13 +476,13 @@ extractColumns l m = m ?? (All, Pos (idxs l))
476 , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0 ] 476 , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0 ]
477 477
478-} 478-}
479repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t 479repmat :: (Storable t) => Matrix t -> Int -> Int -> Matrix t
480repmat m r c 480repmat m r c
481 | r == 0 || c == 0 = emptyM (r*rows m) (c*cols m) 481 | r == 0 || c == 0 = emptyM (r*rows m) (c*cols m)
482 | otherwise = fromBlocks $ replicate r $ replicate c $ m 482 | otherwise = fromBlocks $ replicate r $ replicate c $ m
483 483
484-- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. 484-- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix.
485liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 485liftMatrix2Auto :: (Storable t, Storable a, Storable b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
486liftMatrix2Auto f m1 m2 486liftMatrix2Auto f m1 m2
487 | compat' m1 m2 = lM f m1 m2 487 | compat' m1 m2 = lM f m1 m2
488 | ok = lM f m1' m2' 488 | ok = lM f m1' m2'
@@ -499,7 +499,7 @@ liftMatrix2Auto f m1 m2
499 m2' = conformMTo (r,c) m2 499 m2' = conformMTo (r,c) m2
500 500
501-- FIXME do not flatten if equal order 501-- FIXME do not flatten if equal order
502lM :: (Storable t, Element t1, Element t2) 502lM :: (Storable t, Storable t1, Storable t2)
503 => (Vector t1 -> Vector t2 -> Vector t) 503 => (Vector t1 -> Vector t2 -> Vector t)
504 -> Matrix t1 -> Matrix t2 -> Matrix t 504 -> Matrix t1 -> Matrix t2 -> Matrix t
505lM f m1 m2 = matrixFromVector 505lM f m1 m2 = matrixFromVector
@@ -520,7 +520,7 @@ compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2
520 520
521------------------------------------------------------------ 521------------------------------------------------------------
522 522
523toBlockRows :: Element t => [Int] -> Matrix t -> [Matrix t] 523toBlockRows :: Storable t => [Int] -> Matrix t -> [Matrix t]
524toBlockRows [r] m 524toBlockRows [r] m
525 | r == rows m = [m] 525 | r == rows m = [m]
526toBlockRows rs m 526toBlockRows rs m
@@ -530,13 +530,13 @@ toBlockRows rs m
530 szs = map (* cols m) rs 530 szs = map (* cols m) rs
531 g k = (k><0)[] 531 g k = (k><0)[]
532 532
533toBlockCols :: Element t => [Int] -> Matrix t -> [Matrix t] 533toBlockCols :: Storable t => [Int] -> Matrix t -> [Matrix t]
534toBlockCols [c] m | c == cols m = [m] 534toBlockCols [c] m | c == cols m = [m]
535toBlockCols cs m = map trans . toBlockRows cs . trans $ m 535toBlockCols cs m = map trans . toBlockRows cs . trans $ m
536 536
537-- | Partition a matrix into blocks with the given numbers of rows and columns. 537-- | Partition a matrix into blocks with the given numbers of rows and columns.
538-- The remaining rows and columns are discarded. 538-- The remaining rows and columns are discarded.
539toBlocks :: (Element t) => [Int] -> [Int] -> Matrix t -> [[Matrix t]] 539toBlocks :: (Storable t) => [Int] -> [Int] -> Matrix t -> [[Matrix t]]
540toBlocks rs cs m 540toBlocks rs cs m
541 | ok = map (toBlockCols cs) . toBlockRows rs $ m 541 | ok = map (toBlockCols cs) . toBlockRows rs $ m
542 | otherwise = error $ "toBlocks: bad partition: "++show rs++" "++show cs 542 | otherwise = error $ "toBlocks: bad partition: "++show rs++" "++show cs
@@ -546,7 +546,7 @@ toBlocks rs cs m
546 546
547-- | Fully partition a matrix into blocks of the same size. If the dimensions are not 547-- | Fully partition a matrix into blocks of the same size. If the dimensions are not
548-- a multiple of the given size the last blocks will be smaller. 548-- a multiple of the given size the last blocks will be smaller.
549toBlocksEvery :: (Element t) => Int -> Int -> Matrix t -> [[Matrix t]] 549toBlocksEvery :: (Storable t) => Int -> Int -> Matrix t -> [[Matrix t]]
550toBlocksEvery r c m 550toBlocksEvery r c m
551 | r < 1 || c < 1 = error $ "toBlocksEvery expects block sizes > 0, given "++show r++" and "++ show c 551 | r < 1 || c < 1 = error $ "toBlocksEvery expects block sizes > 0, given "++show r++" and "++ show c
552 | otherwise = toBlocks rs cs m 552 | otherwise = toBlocks rs cs m
@@ -576,7 +576,7 @@ m[1,2] = 6
576 576
577-} 577-}
578mapMatrixWithIndexM_ 578mapMatrixWithIndexM_
579 :: (Element a, Num a, Monad m) => 579 :: (Storable a, Num a, Monad m) =>
580 ((Int, Int) -> a -> m ()) -> Matrix a -> m () 580 ((Int, Int) -> a -> m ()) -> Matrix a -> m ()
581mapMatrixWithIndexM_ g m = mapVectorWithIndexM_ (mk c g) . flatten $ m 581mapMatrixWithIndexM_ g m = mapVectorWithIndexM_ (mk c g) . flatten $ m
582 where 582 where
@@ -592,7 +592,7 @@ Just (3><3)
592 592
593-} 593-}
594mapMatrixWithIndexM 594mapMatrixWithIndexM
595 :: (Element a, Storable b, Monad m) => 595 :: (Storable a, Storable b, Monad m) =>
596 ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) 596 ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b)
597mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m 597mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m
598 where 598 where
@@ -608,11 +608,11 @@ mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . fla
608 608
609 -} 609 -}
610mapMatrixWithIndex 610mapMatrixWithIndex
611 :: (Element a, Storable b) => 611 :: (Storable a, Storable b) =>
612 ((Int, Int) -> a -> b) -> Matrix a -> Matrix b 612 ((Int, Int) -> a -> b) -> Matrix a -> Matrix b
613mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m 613mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m
614 where 614 where
615 c = cols m 615 c = cols m
616 616
617mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b 617mapMatrix :: (Storable a, Storable b) => (a -> b) -> Matrix a -> Matrix b
618mapMatrix f = liftMatrix (mapVector f) 618mapMatrix f = liftMatrix (mapVector f)
diff --git a/packages/base/src/Internal/Extract.hs b/packages/base/src/Internal/Extract.hs
new file mode 100644
index 0000000..84ee20f
--- /dev/null
+++ b/packages/base/src/Internal/Extract.hs
@@ -0,0 +1,145 @@
1{-# LANGUAGE BangPatterns #-}
2{-# LANGUAGE NondecreasingIndentation #-}
3{-# LANGUAGE PatternSynonyms #-}
4{-# LANGUAGE UnboxedTuples #-}
5module Internal.Extract where
6import Control.Monad
7import Data.Complex
8import Data.Function
9import Data.Int
10import Foreign.Ptr
11import Foreign.Storable
12
13type ConstPtr a = Ptr a
14pattern ConstPtr a = a
15
16extractStorable :: Storable t =>
17 Int32 -- int modei
18 -> Int32 -- int modej
19 -> Int32 -- / KIVEC(i)
20 -> ConstPtr Int32 -- \
21 -> Int32 -- / KIVEC(j)
22 -> ConstPtr Int32 -- \
23 -> Int32 -- /
24 -> Int32 -- /
25 -> Int32 -- { KO##T##MAT(m)
26 -> Int32 -- \
27 -> ConstPtr t -- \
28 -> Int32 -- /
29 -> Int32 -- /
30 -> Int32 -- { O##T##MAT(r)
31 -> Int32 -- \
32 -> Ptr t -- \
33 -> IO Int32
34extractStorable modei
35 modej
36 in_ (ConstPtr ip)
37 jn (ConstPtr jp)
38 mr mc mXr mXc (ConstPtr mp)
39 rr rc rXr rXc rp = do
40 -- int i,j,si,sj,ni,nj;
41 ni <- if modei/=0 then return in_
42 else fmap succ $ (-) <$> peekElemOff ip 1 <*> peekElemOff ip 0
43 nj <- if modej/=0 then return jn
44 else fmap succ $ (-) <$> peekElemOff jp 1 <*> peekElemOff jp 0
45 ($ 0) $ fix $ \iloop i -> when (i<ni) $ do
46 si <- if modei/=0 then peekElemOff ip (fromIntegral i)
47 else (+ i) <$> peek ip
48 ($ 0) $ fix $ \jloop j -> when (j<nj) $ do
49 sj <- if modej/=0 then peekElemOff jp (fromIntegral j)
50 else (+ j) <$> peek jp
51 pokeElemOff rp (fromIntegral $ i*rXr + j*rXc)
52 =<< peekElemOff mp (fromIntegral $ si*mXr + sj*mXc)
53 jloop $! succ j
54 iloop $! succ i
55 return 0
56
57{-# SPECIALIZE extractStorable ::
58 Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32
59 -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Double
60 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Double
61 -> IO Int32 #-}
62
63{-# SPECIALIZE extractStorable ::
64 Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32
65 -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Float
66 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Float
67 -> IO Int32 #-}
68
69{-# SPECIALIZE extractStorable ::
70 Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32
71 -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr (Complex Double)
72 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr (Complex Double)
73 -> IO Int32 #-}
74
75{-# SPECIALIZE extractStorable ::
76 Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32
77 -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr (Complex Float)
78 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr (Complex Float)
79 -> IO Int32 #-}
80
81{-# SPECIALIZE extractStorable ::
82 Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32
83 -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Int32
84 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int32
85 -> IO Int32 #-}
86
87{-# SPECIALIZE extractStorable ::
88 Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32
89 -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Int64
90 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int64
91 -> IO Int32 #-}
92
93{-
94type Reorder x = CV Int32 (CV Int32 (CV Int32 (CV x (CV x (IO Int32)))))
95
96foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double
97foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float
98foreign import ccall unsafe "reorderI" c_reorderI :: Reorder Int32
99foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double)
100foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float)
101foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z
102-}
103
104-- #define ERROR(CODE) MACRO(return CODE;)
105-- #define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);})
106
107requires :: Monad m => Bool -> Int32 -> m Int32 -> m Int32
108requires cond code go =
109 if cond then go
110 else return code
111
112pattern BAD_SIZE = 2000
113
114reorderStorable :: Storable a =>
115 Int32 -> Ptr Int32 -- k
116 -> Int32 -> ConstPtr Int32 -- strides
117 -> Int32 -> ConstPtr Int32 -- dims
118 -> Int32 -> ConstPtr a -- v
119 -> Int32 -> Ptr a -- r
120 -> IO Int32
121reorderStorable kn kp stridesn stridesp dimsn dimsp vn vp rn rp = do
122 requires (kn == stridesn && stridesn == dimsn) BAD_SIZE $ do
123 let ijlloop !i !j l fin = do
124 pokeElemOff kp (fromIntegral l) 0
125 dimspl <- peekElemOff dimsp (fromIntegral l)
126 stridespl <- peekElemOff stridesp (fromIntegral l)
127 if (l<kn) then ijlloop (i * dimspl) (j + stridespl*(dimspl - 1)) (l + 1) fin
128 else fin i j
129 ijlloop 1 0 0 $ \i j -> do
130 requires (i <= vn && j < rn) BAD_SIZE $ do
131 (\go -> go 0 0) $ fix $ \ijloop i j -> do
132 pokeElemOff rp (fromIntegral i) =<< peekElemOff vp (fromIntegral j)
133 (\go -> go (kn - 1) j) $ fix $ \lloop l !j -> do
134 kpl <- succ <$> peekElemOff kp (fromIntegral l)
135 pokeElemOff kp (fromIntegral l) kpl
136 dimspl <- peekElemOff dimsp (fromIntegral l)
137 if (kpl < dimspl)
138 then do
139 stridespl <- peekElemOff stridesp (fromIntegral l)
140 ijloop (succ i) (j + stridespl)
141 else do
142 if l == 0 then return 0 else do
143 pokeElemOff kp (fromIntegral l) 0
144 stridespl <- peekElemOff stridesp (fromIntegral l)
145 lloop (pred l) (j - stridespl*(dimspl-1))
diff --git a/packages/base/src/Internal/IO.hs b/packages/base/src/Internal/IO.hs
index b0f5606..de5eea5 100644
--- a/packages/base/src/Internal/IO.hs
+++ b/packages/base/src/Internal/IO.hs
@@ -23,6 +23,7 @@ import Internal.Vectorized
23import Text.Printf(printf, PrintfArg, PrintfType) 23import Text.Printf(printf, PrintfArg, PrintfType)
24import Data.List(intersperse,transpose) 24import Data.List(intersperse,transpose)
25import Data.Complex 25import Data.Complex
26import Foreign.Storable
26 27
27 28
28-- | Formatting tool 29-- | Formatting tool
@@ -45,7 +46,7 @@ this function the user can easily define any desired display function:
45@disp = putStr . format \" \" (printf \"%.2f\")@ 46@disp = putStr . format \" \" (printf \"%.2f\")@
46 47
47-} 48-}
48format :: (Element t) => String -> (t -> String) -> Matrix t -> String 49format :: (Storable t) => String -> (t -> String) -> Matrix t -> String
49format sep f m = table sep . map (map f) . toLists $ m 50format sep f m = table sep . map (map f) . toLists $ m
50 51
51{- | Show a matrix with \"autoscaling\" and a given number of decimal places. 52{- | Show a matrix with \"autoscaling\" and a given number of decimal places.
@@ -81,14 +82,14 @@ dispf d x = sdims x ++ "\n" ++ formatFixed (if isInt x then 0 else d) x
81sdims :: Matrix t -> [Char] 82sdims :: Matrix t -> [Char]
82sdims x = show (rows x) ++ "x" ++ show (cols x) 83sdims x = show (rows x) ++ "x" ++ show (cols x)
83 84
84formatFixed :: (Show a, Text.Printf.PrintfArg t, Element t) 85formatFixed :: (Show a, Text.Printf.PrintfArg t, Storable t)
85 => a -> Matrix t -> String 86 => a -> Matrix t -> String
86formatFixed d x = format " " (printf ("%."++show d++"f")) $ x 87formatFixed d x = format " " (printf ("%."++show d++"f")) $ x
87 88
88isInt :: Matrix Double -> Bool 89isInt :: Matrix Double -> Bool
89isInt = all lookslikeInt . toList . flatten 90isInt = all lookslikeInt . toList . flatten
90 91
91formatScaled :: (Text.Printf.PrintfArg b, RealFrac b, Floating b, Num t, Element b, Show t) 92formatScaled :: (Text.Printf.PrintfArg b, RealFrac b, Floating b, Num t, Storable b, Show t)
92 => t -> Matrix b -> [Char] 93 => t -> Matrix b -> [Char]
93formatScaled dec t = "E"++show o++"\n" ++ ss 94formatScaled dec t = "E"++show o++"\n" ++ ss
94 where ss = format " " (printf fmt. g) t 95 where ss = format " " (printf fmt. g) t
@@ -104,7 +105,7 @@ formatScaled dec t = "E"++show o++"\n" ++ ss
10410 |> 0.00 0.11 0.22 0.33 0.44 0.56 0.67 0.78 0.89 1.00 10510 |> 0.00 0.11 0.22 0.33 0.44 0.56 0.67 0.78 0.89 1.00
105 106
106-} 107-}
107vecdisp :: (Element t) => (Matrix t -> String) -> Vector t -> String 108vecdisp :: (Storable t) => (Matrix t -> String) -> Vector t -> String
108vecdisp f v 109vecdisp f v
109 = ((show (dim v) ++ " |> ") ++) . (++"\n") 110 = ((show (dim v) ++ " |> ") ++) . (++"\n")
110 . unwords . lines . tail . dropWhile (not . (`elem` " \n")) 111 . unwords . lines . tail . dropWhile (not . (`elem` " \n"))
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index 27d1f95..d88ff6b 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -22,9 +22,12 @@ import Data.Bifunctor (first)
22 22
23import Internal.Devel 23import Internal.Devel
24import Internal.Vector 24import Internal.Vector
25import Internal.Vectorized (constantAux)
25import Internal.Matrix hiding ((#), (#!)) 26import Internal.Matrix hiding ((#), (#!))
26import Internal.Conversion 27import Internal.Conversion
27import Internal.Element 28import Internal.Element
29import Internal.ST (setRect)
30import Data.Int
28import Foreign.Ptr(nullPtr) 31import Foreign.Ptr(nullPtr)
29import Foreign.C.Types 32import Foreign.C.Types
30import Control.Monad(when) 33import Control.Monad(when)
@@ -46,10 +49,10 @@ type TMMM t = t ::> t ::> t ::> Ok
46type F = Float 49type F = Float
47type Q = Complex Float 50type Q = Complex Float
48 51
49foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM R 52foreign import ccall unsafe "multiplyR" dgemmc :: Int32 -> Int32 -> TMMM R
50foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C 53foreign import ccall unsafe "multiplyC" zgemmc :: Int32 -> Int32 -> TMMM C
51foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F 54foreign import ccall unsafe "multiplyF" sgemmc :: Int32 -> Int32 -> TMMM F
52foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q 55foreign import ccall unsafe "multiplyQ" cgemmc :: Int32 -> Int32 -> TMMM Q
53foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> TMMM I 56foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> TMMM I
54foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> TMMM Z 57foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> TMMM Z
55 58
@@ -82,7 +85,7 @@ multiplyF a b = multiplyAux sgemmc "sgemmc" a b
82multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float) 85multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float)
83multiplyQ a b = multiplyAux cgemmc "cgemmc" a b 86multiplyQ a b = multiplyAux cgemmc "cgemmc" a b
84 87
85multiplyI :: I -> Matrix CInt -> Matrix CInt -> Matrix CInt 88multiplyI :: I -> Matrix Int32 -> Matrix Int32 -> Matrix Int32
86multiplyI m a b = unsafePerformIO $ do 89multiplyI m a b = unsafePerformIO $ do
87 when (cols a /= rows b) $ error $ 90 when (cols a /= rows b) $ error $
88 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b 91 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b
@@ -239,8 +242,8 @@ foreign import ccall unsafe "eig_l_R" dgeev :: R ::> R ::> C :> R ::> Ok
239foreign import ccall unsafe "eig_l_G" dggev :: R ::> R ::> C :> R :> R ::> R ::> Ok 242foreign import ccall unsafe "eig_l_G" dggev :: R ::> R ::> C :> R :> R ::> R ::> Ok
240foreign import ccall unsafe "eig_l_C" zgeev :: C ::> C ::> C :> C ::> Ok 243foreign import ccall unsafe "eig_l_C" zgeev :: C ::> C ::> C :> C ::> Ok
241foreign import ccall unsafe "eig_l_GC" zggev :: C ::> C ::> C :> C :> C ::> C ::> Ok 244foreign import ccall unsafe "eig_l_GC" zggev :: C ::> C ::> C :> C :> C ::> C ::> Ok
242foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> R :> R ::> Ok 245foreign import ccall unsafe "eig_l_S" dsyev :: Int32 -> R :> R ::> Ok
243foreign import ccall unsafe "eig_l_H" zheev :: CInt -> R :> C ::> Ok 246foreign import ccall unsafe "eig_l_H" zheev :: Int32 -> R :> C ::> Ok
244 247
245eigAux f st m = unsafePerformIO $ do 248eigAux f st m = unsafePerformIO $ do
246 a <- copy ColumnMajor m 249 a <- copy ColumnMajor m
@@ -636,7 +639,7 @@ qrgrAux f st n (a, tau) = unsafePerformIO $ do
636 ((subVector 0 n tau') #! res) f #| st 639 ((subVector 0 n tau') #! res) f #| st
637 return res 640 return res
638 where 641 where
639 tau' = vjoin [tau, constantD 0 n] 642 tau' = vjoin [tau, constantAux 0 n]
640 643
641----------------------------------------------------------------------------------- 644-----------------------------------------------------------------------------------
642foreign import ccall unsafe "hess_l_R" dgehrd :: R :> R ::> Ok 645foreign import ccall unsafe "hess_l_R" dgehrd :: R :> R ::> Ok
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 5436e59..04092f9 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -2,6 +2,7 @@
2{-# LANGUAGE FlexibleContexts #-} 2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-} 3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE BangPatterns #-} 4{-# LANGUAGE BangPatterns #-}
5{-# LANGUAGE CPP #-}
5{-# LANGUAGE TypeOperators #-} 6{-# LANGUAGE TypeOperators #-}
6{-# LANGUAGE TypeFamilies #-} 7{-# LANGUAGE TypeFamilies #-}
7{-# LANGUAGE ViewPatterns #-} 8{-# LANGUAGE ViewPatterns #-}
@@ -22,12 +23,14 @@ module Internal.Matrix where
22 23
23import Internal.Vector 24import Internal.Vector
24import Internal.Devel 25import Internal.Devel
26import Internal.Extract
25import Internal.Vectorized hiding ((#), (#!)) 27import Internal.Vectorized hiding ((#), (#!))
26import Foreign.Marshal.Alloc ( free ) 28import Foreign.Marshal.Alloc ( free )
27import Foreign.Marshal.Array(newArray) 29import Foreign.Marshal.Array(newArray)
28import Foreign.Ptr ( Ptr ) 30import Foreign.Ptr ( Ptr )
29import Foreign.Storable ( Storable ) 31import Foreign.Storable ( Storable )
30import Data.Complex ( Complex ) 32import Data.Complex ( Complex )
33import Data.Int
31import Foreign.C.Types ( CInt(..) ) 34import Foreign.C.Types ( CInt(..) )
32import Foreign.C.String ( CString, newCString ) 35import Foreign.C.String ( CString, newCString )
33import System.IO.Unsafe ( unsafePerformIO ) 36import System.IO.Unsafe ( unsafePerformIO )
@@ -61,19 +64,23 @@ size :: Matrix t -> (Int, Int)
61size m = (irows m, icols m) 64size m = (irows m, icols m)
62{-# INLINE size #-} 65{-# INLINE size #-}
63 66
67-- | True if the matrix is in RowMajor form.
64rowOrder :: Matrix t -> Bool 68rowOrder :: Matrix t -> Bool
65rowOrder m = xCol m == 1 || cols m == 1 69rowOrder m = xCol m == 1 || cols m == 1
66{-# INLINE rowOrder #-} 70{-# INLINE rowOrder #-}
67 71
72-- | True if the matrix is in ColMajor form or if their is only one row.
68colOrder :: Matrix t -> Bool 73colOrder :: Matrix t -> Bool
69colOrder m = xRow m == 1 || rows m == 1 74colOrder m = xRow m == 1 || rows m == 1
70{-# INLINE colOrder #-} 75{-# INLINE colOrder #-}
71 76
77-- | True if the matrix is a single row or column vector.
72is1d :: Matrix t -> Bool 78is1d :: Matrix t -> Bool
73is1d (size->(r,c)) = r==1 || c==1 79is1d (size->(r,c)) = r==1 || c==1
74{-# INLINE is1d #-} 80{-# INLINE is1d #-}
75 81
76-- data is not contiguous 82-- | True if the matrix is not contiguous. This usually
83-- means it is a slice of some larger matrix.
77isSlice :: Storable t => Matrix t -> Bool 84isSlice :: Storable t => Matrix t -> Bool
78isSlice m@(size->(r,c)) = r*c < dim (xdat m) 85isSlice m@(size->(r,c)) = r*c < dim (xdat m)
79{-# INLINE isSlice #-} 86{-# INLINE isSlice #-}
@@ -95,19 +102,23 @@ showInternal m = printf "%dx%d %s %s %d:%d (%d)\n" r c slc ord xr xc dv
95 102
96-------------------------------------------------------------------------------- 103--------------------------------------------------------------------------------
97 104
98-- | Matrix transpose. 105-- | O(1) Matrix transpose. This is only a logical transposition that does not
106-- re-order the element storage. If the storage order is important, use 'cmat'
107-- or 'fmat'.
99trans :: Matrix t -> Matrix t 108trans :: Matrix t -> Matrix t
100trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } = 109trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } =
101 m { irows = c, icols = r, xRow = xc, xCol = xr } 110 m { irows = c, icols = r, xRow = xc, xCol = xr }
102 111
103 112
104cmat :: (Element t) => Matrix t -> Matrix t 113-- | Obtain the RowMajor equivalent of a given Matrix.
114cmat :: (Storable t) => Matrix t -> Matrix t
105cmat m 115cmat m
106 | rowOrder m = m 116 | rowOrder m = m
107 | otherwise = extractAll RowMajor m 117 | otherwise = extractAll RowMajor m
108 118
109 119
110fmat :: (Element t) => Matrix t -> Matrix t 120-- | Obtain the ColumnMajor equivalent of a given Matrix.
121fmat :: (Storable t) => Matrix t -> Matrix t
111fmat m 122fmat m
112 | colOrder m = m 123 | colOrder m = m
113 | otherwise = extractAll ColumnMajor m 124 | otherwise = extractAll ColumnMajor m
@@ -115,14 +126,14 @@ fmat m
115 126
116-- C-Haskell matrix adapters 127-- C-Haskell matrix adapters
117{-# INLINE amatr #-} 128{-# INLINE amatr #-}
118amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r 129amatr :: Storable a => Matrix a -> (f -> IO r) -> (Int32 -> Int32 -> Ptr a -> f) -> IO r
119amatr x f g = unsafeWith (xdat x) (f . g r c) 130amatr x f g = unsafeWith (xdat x) (f . g r c)
120 where 131 where
121 r = fi (rows x) 132 r = fi (rows x)
122 c = fi (cols x) 133 c = fi (cols x)
123 134
124{-# INLINE amat #-} 135{-# INLINE amat #-}
125amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r 136amat :: Storable a => Matrix a -> (f -> IO r) -> (Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -> f) -> IO r
126amat x f g = unsafeWith (xdat x) (f . g r c sr sc) 137amat x f g = unsafeWith (xdat x) (f . g r c sr sc)
127 where 138 where
128 r = fi (rows x) 139 r = fi (rows x)
@@ -133,8 +144,8 @@ amat x f g = unsafeWith (xdat x) (f . g r c sr sc)
133 144
134instance Storable t => TransArray (Matrix t) 145instance Storable t => TransArray (Matrix t)
135 where 146 where
136 type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b 147 type TransRaw (Matrix t) b = Int32 -> Int32 -> Ptr t -> b
137 type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b 148 type Trans (Matrix t) b = Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -> b
138 apply = amat 149 apply = amat
139 {-# INLINE apply #-} 150 {-# INLINE apply #-}
140 applyRaw = amatr 151 applyRaw = amatr
@@ -151,10 +162,10 @@ a #! b = a # b # id
151 162
152-------------------------------------------------------------------------------- 163--------------------------------------------------------------------------------
153 164
154copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t) 165copy :: Storable t => MatrixOrder -> Matrix t -> IO (Matrix t)
155copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) 166copy ord m = extractAux ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1])
156 167
157extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t 168extractAll :: Storable t => MatrixOrder -> Matrix t -> Matrix t
158extractAll ord m = unsafePerformIO (copy ord m) 169extractAll ord m = unsafePerformIO (copy ord m)
159 170
160{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. 171{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
@@ -164,14 +175,14 @@ extractAll ord m = unsafePerformIO (copy ord m)
164it :: (Num t, Element t) => Vector t 175it :: (Num t, Element t) => Vector t
165 176
166-} 177-}
167flatten :: Element t => Matrix t -> Vector t 178flatten :: Storable t => Matrix t -> Vector t
168flatten m 179flatten m
169 | isSlice m || not (rowOrder m) = xdat (extractAll RowMajor m) 180 | isSlice m || not (rowOrder m) = xdat (extractAll RowMajor m)
170 | otherwise = xdat m 181 | otherwise = xdat m
171 182
172 183
173-- | the inverse of 'Data.Packed.Matrix.fromLists' 184-- | the inverse of 'Data.Packed.Matrix.fromLists'
174toLists :: (Element t) => Matrix t -> [[t]] 185toLists :: (Storable t) => Matrix t -> [[t]]
175toLists = map toList . toRows 186toLists = map toList . toRows
176 187
177 188
@@ -192,7 +203,7 @@ compatdim (a:b:xs)
192-- | Create a matrix from a list of vectors. 203-- | Create a matrix from a list of vectors.
193-- All vectors must have the same dimension, 204-- All vectors must have the same dimension,
194-- or dimension 1, which is are automatically expanded. 205-- or dimension 1, which is are automatically expanded.
195fromRows :: Element t => [Vector t] -> Matrix t 206fromRows :: Storable t => [Vector t] -> Matrix t
196fromRows [] = emptyM 0 0 207fromRows [] = emptyM 0 0
197fromRows vs = case compatdim (map dim vs) of 208fromRows vs = case compatdim (map dim vs) of
198 Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs) 209 Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs)
@@ -203,25 +214,25 @@ fromRows vs = case compatdim (map dim vs) of
203 adapt c v 214 adapt c v
204 | c == 0 = fromList[] 215 | c == 0 = fromList[]
205 | dim v == c = v 216 | dim v == c = v
206 | otherwise = constantD (v@>0) c 217 | otherwise = constantAux (v@>0) c
207 218
208-- | extracts the rows of a matrix as a list of vectors 219-- | extracts the rows of a matrix as a list of vectors
209toRows :: Element t => Matrix t -> [Vector t] 220toRows :: Storable t => Matrix t -> [Vector t]
210toRows m 221toRows m
211 | rowOrder m = map sub rowRange 222 | rowOrder m = map sub rowRange
212 | otherwise = map ext rowRange 223 | otherwise = map ext rowRange
213 where 224 where
214 rowRange = [0..rows m-1] 225 rowRange = [0..rows m-1]
215 sub k = subVector (k*xRow m) (cols m) (xdat m) 226 sub k = subVector (k*xRow m) (cols m) (xdat m)
216 ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) 227 ext k = xdat $ unsafePerformIO $ extractAux RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1])
217 228
218 229
219-- | Creates a matrix from a list of vectors, as columns 230-- | Creates a matrix from a list of vectors, as columns
220fromColumns :: Element t => [Vector t] -> Matrix t 231fromColumns :: Storable t => [Vector t] -> Matrix t
221fromColumns m = trans . fromRows $ m 232fromColumns m = trans . fromRows $ m
222 233
223-- | Creates a list of vectors from the columns of a matrix 234-- | Creates a list of vectors from the columns of a matrix
224toColumns :: Element t => Matrix t -> [Vector t] 235toColumns :: Storable t => Matrix t -> [Vector t]
225toColumns m = toRows . trans $ m 236toColumns m = toRows . trans $ m
226 237
227-- | Reads a matrix position. 238-- | Reads a matrix position.
@@ -271,13 +282,13 @@ reshape c v = matrixFromVector RowMajor (dim v `div` c) c v
271 282
272 283
273-- | application of a vector function on the flattened matrix elements 284-- | application of a vector function on the flattened matrix elements
274liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b 285liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
275liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} 286liftMatrix f m@Matrix { irows = r, icols = c, xdat = d}
276 | isSlice m = matrixFromVector RowMajor r c (f (flatten m)) 287 | isSlice m = matrixFromVector RowMajor r c (f (flatten m))
277 | otherwise = matrixFromVector (orderOf m) r c (f d) 288 | otherwise = matrixFromVector (orderOf m) r c (f d)
278 289
279-- | application of a vector function on the flattened matrices elements 290-- | application of a vector function on the flattened matrices elements
280liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 291liftMatrix2 :: (Storable t, Storable a, Storable b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
281liftMatrix2 f m1@(size->(r,c)) m2 292liftMatrix2 f m1@(size->(r,c)) m2
282 | (r,c)/=size m2 = error "nonconformant matrices in liftMatrix2" 293 | (r,c)/=size m2 = error "nonconformant matrices in liftMatrix2"
283 | rowOrder m1 = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2)) 294 | rowOrder m1 = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2))
@@ -285,103 +296,8 @@ liftMatrix2 f m1@(size->(r,c)) m2
285 296
286------------------------------------------------------------------ 297------------------------------------------------------------------
287 298
288-- | Supported matrix elements.
289class (Storable a) => Element a where
290 constantD :: a -> Int -> Vector a
291 extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
292 setRect :: Int -> Int -> Matrix a -> Matrix a -> IO ()
293 sortI :: Ord a => Vector a -> Vector CInt
294 sortV :: Ord a => Vector a -> Vector a
295 compareV :: Ord a => Vector a -> Vector a -> Vector CInt
296 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
297 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
298 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
299 gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
300 reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation
301
302
303instance Element Float where
304 constantD = constantAux cconstantF
305 extractR = extractAux c_extractF
306 setRect = setRectAux c_setRectF
307 sortI = sortIdxF
308 sortV = sortValF
309 compareV = compareF
310 selectV = selectF
311 remapM = remapF
312 rowOp = rowOpAux c_rowOpF
313 gemm = gemmg c_gemmF
314 reorderV = reorderAux c_reorderF
315
316instance Element Double where
317 constantD = constantAux cconstantR
318 extractR = extractAux c_extractD
319 setRect = setRectAux c_setRectD
320 sortI = sortIdxD
321 sortV = sortValD
322 compareV = compareD
323 selectV = selectD
324 remapM = remapD
325 rowOp = rowOpAux c_rowOpD
326 gemm = gemmg c_gemmD
327 reorderV = reorderAux c_reorderD
328
329instance Element (Complex Float) where
330 constantD = constantAux cconstantQ
331 extractR = extractAux c_extractQ
332 setRect = setRectAux c_setRectQ
333 sortI = undefined
334 sortV = undefined
335 compareV = undefined
336 selectV = selectQ
337 remapM = remapQ
338 rowOp = rowOpAux c_rowOpQ
339 gemm = gemmg c_gemmQ
340 reorderV = reorderAux c_reorderQ
341
342instance Element (Complex Double) where
343 constantD = constantAux cconstantC
344 extractR = extractAux c_extractC
345 setRect = setRectAux c_setRectC
346 sortI = undefined
347 sortV = undefined
348 compareV = undefined
349 selectV = selectC
350 remapM = remapC
351 rowOp = rowOpAux c_rowOpC
352 gemm = gemmg c_gemmC
353 reorderV = reorderAux c_reorderC
354
355instance Element (CInt) where
356 constantD = constantAux cconstantI
357 extractR = extractAux c_extractI
358 setRect = setRectAux c_setRectI
359 sortI = sortIdxI
360 sortV = sortValI
361 compareV = compareI
362 selectV = selectI
363 remapM = remapI
364 rowOp = rowOpAux c_rowOpI
365 gemm = gemmg c_gemmI
366 reorderV = reorderAux c_reorderI
367
368instance Element Z where
369 constantD = constantAux cconstantL
370 extractR = extractAux c_extractL
371 setRect = setRectAux c_setRectL
372 sortI = sortIdxL
373 sortV = sortValL
374 compareV = compareL
375 selectV = selectL
376 remapM = remapL
377 rowOp = rowOpAux c_rowOpL
378 gemm = gemmg c_gemmL
379 reorderV = reorderAux c_reorderL
380
381-------------------------------------------------------------------
382
383-- | reference to a rectangular slice of a matrix (no data copy) 299-- | reference to a rectangular slice of a matrix (no data copy)
384subMatrix :: Element a 300subMatrix :: Storable a
385 => (Int,Int) -- ^ (r0,c0) starting position 301 => (Int,Int) -- ^ (r0,c0) starting position
386 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix 302 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
387 -> Matrix a -- ^ input matrix 303 -> Matrix a -- ^ input matrix
@@ -402,34 +318,34 @@ subMatrix (r0,c0) (rt,ct) m
402maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1 318maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1
403maxZ xs = if minimum xs == 0 then 0 else maximum xs 319maxZ xs = if minimum xs == 0 then 0 else maximum xs
404 320
405conformMs :: Element t => [Matrix t] -> [Matrix t] 321conformMs :: Storable t => [Matrix t] -> [Matrix t]
406conformMs ms = map (conformMTo (r,c)) ms 322conformMs ms = map (conformMTo (r,c)) ms
407 where 323 where
408 r = maxZ (map rows ms) 324 r = maxZ (map rows ms)
409 c = maxZ (map cols ms) 325 c = maxZ (map cols ms)
410 326
411conformVs :: Element t => [Vector t] -> [Vector t] 327conformVs :: Storable t => [Vector t] -> [Vector t]
412conformVs vs = map (conformVTo n) vs 328conformVs vs = map (conformVTo n) vs
413 where 329 where
414 n = maxZ (map dim vs) 330 n = maxZ (map dim vs)
415 331
416conformMTo :: Element t => (Int, Int) -> Matrix t -> Matrix t 332conformMTo :: Storable t => (Int, Int) -> Matrix t -> Matrix t
417conformMTo (r,c) m 333conformMTo (r,c) m
418 | size m == (r,c) = m 334 | size m == (r,c) = m
419 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) 335 | size m == (1,1) = matrixFromVector RowMajor r c (constantAux (m@@>(0,0)) (r*c))
420 | size m == (r,1) = repCols c m 336 | size m == (r,1) = repCols c m
421 | size m == (1,c) = repRows r m 337 | size m == (1,c) = repRows r m
422 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) 338 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c)
423 339
424conformVTo :: Element t => Int -> Vector t -> Vector t 340conformVTo :: Storable t => Int -> Vector t -> Vector t
425conformVTo n v 341conformVTo n v
426 | dim v == n = v 342 | dim v == n = v
427 | dim v == 1 = constantD (v@>0) n 343 | dim v == 1 = constantAux (v@>0) n
428 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n 344 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
429 345
430repRows :: Element t => Int -> Matrix t -> Matrix t 346repRows :: Storable t => Int -> Matrix t -> Matrix t
431repRows n x = fromRows (replicate n (flatten x)) 347repRows n x = fromRows (replicate n (flatten x))
432repCols :: Element t => Int -> Matrix t -> Matrix t 348repCols :: Storable t => Int -> Matrix t -> Matrix t
433repCols n x = fromColumns (replicate n (flatten x)) 349repCols n x = fromColumns (replicate n (flatten x))
434 350
435shSize :: Matrix t -> [Char] 351shSize :: Matrix t -> [Char]
@@ -453,32 +369,50 @@ instance (Storable t, NFData t) => NFData (Matrix t)
453 369
454--------------------------------------------------------------- 370---------------------------------------------------------------
455 371
372{-
456extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, 373extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1,
457 Storable t, Num t3, Num t2, Integral t1, Integral t) 374 Storable t, Num t3, Num t2, Integral t1, Integral t)
458 => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t 375 => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) -- f
459 -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) 376 -> MatrixOrder -- ord
460 -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) 377 -> c -- m
461extractAux f ord m moder vr modec vc = do 378 -> t3 -- moder
379 -> Vector t1 -- vr
380 -> t2 -- modec
381 -> Vector t -- vc
382 -> IO (Matrix a)
383-}
384
385extractAux :: Storable a =>
386 MatrixOrder
387 -> Matrix a
388 -> Int32
389 -> Vector Int32
390 -> Int32
391 -> Vector Int32
392 -> IO (Matrix a)
393extractAux ord m moder vr modec vc = do
462 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr 394 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
463 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc 395 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
464 r <- createMatrix ord nr nc 396 r <- createMatrix ord nr nc
465 (vr # vc # m #! r) (f moder modec) #|"extract" 397 (vr # vc # m #! r) (extractStorable moder modec) #|"extract"
466 398
467 return r 399 return r
468 400
469type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) 401{-
402type Extr x = Int32 -> Int32 -> CIdxs (CIdxs (OM x (OM x (IO Int32))))
470 403
471foreign import ccall unsafe "extractD" c_extractD :: Extr Double 404foreign import ccall unsafe "extractD" c_extractD :: Extr Double
472foreign import ccall unsafe "extractF" c_extractF :: Extr Float 405foreign import ccall unsafe "extractF" c_extractF :: Extr Float
473foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double) 406foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double)
474foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float) 407foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float)
475foreign import ccall unsafe "extractI" c_extractI :: Extr CInt 408foreign import ccall unsafe "extractI" c_extractI :: Extr Int32
476foreign import ccall unsafe "extractL" c_extractL :: Extr Z 409foreign import ccall unsafe "extractL" c_extractL :: Extr Z
410-}
477 411
478--------------------------------------------------------------- 412---------------------------------------------------------------
479 413
480setRectAux :: (TransArray c1, TransArray c) 414setRectAux :: (TransArray c1, TransArray c)
481 => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) 415 => (Int32 -> Int32 -> Trans c1 (Trans c (IO Int32)))
482 -> Int -> Int -> c1 -> c -> IO () 416 -> Int -> Int -> c1 -> c -> IO ()
483setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" 417setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect"
484 418
@@ -494,17 +428,17 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
494-------------------------------------------------------------------------------- 428--------------------------------------------------------------------------------
495 429
496sortG :: (Storable t, Storable a) 430sortG :: (Storable t, Storable a)
497 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a 431 => (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> Vector t -> Vector a
498sortG f v = unsafePerformIO $ do 432sortG f v = unsafePerformIO $ do
499 r <- createVector (dim v) 433 r <- createVector (dim v)
500 (v #! r) f #|"sortG" 434 (v #! r) f #|"sortG"
501 return r 435 return r
502 436
503sortIdxD :: Vector Double -> Vector CInt 437sortIdxD :: Vector Double -> Vector Int32
504sortIdxD = sortG c_sort_indexD 438sortIdxD = sortG c_sort_indexD
505sortIdxF :: Vector Float -> Vector CInt 439sortIdxF :: Vector Float -> Vector Int32
506sortIdxF = sortG c_sort_indexF 440sortIdxF = sortG c_sort_indexF
507sortIdxI :: Vector CInt -> Vector CInt 441sortIdxI :: Vector Int32 -> Vector Int32
508sortIdxI = sortG c_sort_indexI 442sortIdxI = sortG c_sort_indexI
509sortIdxL :: Vector Z -> Vector I 443sortIdxL :: Vector Z -> Vector I
510sortIdxL = sortG c_sort_indexL 444sortIdxL = sortG c_sort_indexL
@@ -513,81 +447,81 @@ sortValD :: Vector Double -> Vector Double
513sortValD = sortG c_sort_valD 447sortValD = sortG c_sort_valD
514sortValF :: Vector Float -> Vector Float 448sortValF :: Vector Float -> Vector Float
515sortValF = sortG c_sort_valF 449sortValF = sortG c_sort_valF
516sortValI :: Vector CInt -> Vector CInt 450sortValI :: Vector Int32 -> Vector Int32
517sortValI = sortG c_sort_valI 451sortValI = sortG c_sort_valI
518sortValL :: Vector Z -> Vector Z 452sortValL :: Vector Z -> Vector Z
519sortValL = sortG c_sort_valL 453sortValL = sortG c_sort_valL
520 454
521foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) 455foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV Int32 (IO Int32))
522foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt)) 456foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV Int32 (IO Int32))
523foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt)) 457foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV Int32 (CV Int32 (IO Int32))
524foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok 458foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok
525 459
526foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt)) 460foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO Int32))
527foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt)) 461foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO Int32))
528foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt)) 462foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV Int32 (CV Int32 (IO Int32))
529foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok 463foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok
530 464
531-------------------------------------------------------------------------------- 465--------------------------------------------------------------------------------
532 466
533compareG :: (TransArray c, Storable t, Storable a) 467compareG :: (TransArray c, Storable t, Storable a)
534 => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) 468 => Trans c (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32)
535 -> c -> Vector t -> Vector a 469 -> c -> Vector t -> Vector a
536compareG f u v = unsafePerformIO $ do 470compareG f u v = unsafePerformIO $ do
537 r <- createVector (dim v) 471 r <- createVector (dim v)
538 (u # v #! r) f #|"compareG" 472 (u # v #! r) f #|"compareG"
539 return r 473 return r
540 474
541compareD :: Vector Double -> Vector Double -> Vector CInt 475compareD :: Vector Double -> Vector Double -> Vector Int32
542compareD = compareG c_compareD 476compareD = compareG c_compareD
543compareF :: Vector Float -> Vector Float -> Vector CInt 477compareF :: Vector Float -> Vector Float -> Vector Int32
544compareF = compareG c_compareF 478compareF = compareG c_compareF
545compareI :: Vector CInt -> Vector CInt -> Vector CInt 479compareI :: Vector Int32 -> Vector Int32 -> Vector Int32
546compareI = compareG c_compareI 480compareI = compareG c_compareI
547compareL :: Vector Z -> Vector Z -> Vector CInt 481compareL :: Vector Z -> Vector Z -> Vector Int32
548compareL = compareG c_compareL 482compareL = compareG c_compareL
549 483
550foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) 484foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV Int32 (IO Int32)))
551foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt))) 485foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV Int32 (IO Int32)))
552foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt))) 486foreign import ccall unsafe "compareI" c_compareI :: CV Int32 (CV Int32 (CV Int32 (IO Int32)))
553foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok 487foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok
554 488
555-------------------------------------------------------------------------------- 489--------------------------------------------------------------------------------
556 490
557selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) 491selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a)
558 => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) 492 => Trans c2 (Trans c1 (Int32 -> Ptr t -> Trans c (Int32 -> Ptr a -> IO Int32)))
559 -> c2 -> c1 -> Vector t -> c -> Vector a 493 -> c2 -> c1 -> Vector t -> c -> Vector a
560selectG f c u v w = unsafePerformIO $ do 494selectG f c u v w = unsafePerformIO $ do
561 r <- createVector (dim v) 495 r <- createVector (dim v)
562 (c # u # v # w #! r) f #|"selectG" 496 (c # u # v # w #! r) f #|"selectG"
563 return r 497 return r
564 498
565selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double 499selectD :: Vector Int32 -> Vector Double -> Vector Double -> Vector Double -> Vector Double
566selectD = selectG c_selectD 500selectD = selectG c_selectD
567selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float 501selectF :: Vector Int32 -> Vector Float -> Vector Float -> Vector Float -> Vector Float
568selectF = selectG c_selectF 502selectF = selectG c_selectF
569selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt 503selectI :: Vector Int32 -> Vector Int32 -> Vector Int32 -> Vector Int32 -> Vector Int32
570selectI = selectG c_selectI 504selectI = selectG c_selectI
571selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z 505selectL :: Vector Int32 -> Vector Z -> Vector Z -> Vector Z -> Vector Z
572selectL = selectG c_selectL 506selectL = selectG c_selectL
573selectC :: Vector CInt 507selectC :: Vector Int32
574 -> Vector (Complex Double) 508 -> Vector (Complex Double)
575 -> Vector (Complex Double) 509 -> Vector (Complex Double)
576 -> Vector (Complex Double) 510 -> Vector (Complex Double)
577 -> Vector (Complex Double) 511 -> Vector (Complex Double)
578selectC = selectG c_selectC 512selectC = selectG c_selectC
579selectQ :: Vector CInt 513selectQ :: Vector Int32
580 -> Vector (Complex Float) 514 -> Vector (Complex Float)
581 -> Vector (Complex Float) 515 -> Vector (Complex Float)
582 -> Vector (Complex Float) 516 -> Vector (Complex Float)
583 -> Vector (Complex Float) 517 -> Vector (Complex Float)
584selectQ = selectG c_selectQ 518selectQ = selectG c_selectQ
585 519
586type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) 520type Sel x = CV Int32 (CV x (CV x (CV x (CV x (IO Int32)))))
587 521
588foreign import ccall unsafe "chooseD" c_selectD :: Sel Double 522foreign import ccall unsafe "chooseD" c_selectD :: Sel Double
589foreign import ccall unsafe "chooseF" c_selectF :: Sel Float 523foreign import ccall unsafe "chooseF" c_selectF :: Sel Float
590foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt 524foreign import ccall unsafe "chooseI" c_selectI :: Sel Int32
591foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) 525foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double)
592foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) 526foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float)
593foreign import ccall unsafe "chooseL" c_selectL :: Sel Z 527foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
@@ -595,35 +529,35 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
595--------------------------------------------------------------------------- 529---------------------------------------------------------------------------
596 530
597remapG :: (TransArray c, TransArray c1, Storable t, Storable a) 531remapG :: (TransArray c, TransArray c1, Storable t, Storable a)
598 => (CInt -> CInt -> CInt -> CInt -> Ptr t 532 => (Int32 -> Int32 -> Int32 -> Int32 -> Ptr t
599 -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) 533 -> Trans c1 (Trans c (Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -> IO Int32)))
600 -> Matrix t -> c1 -> c -> Matrix a 534 -> Matrix t -> c1 -> c -> Matrix a
601remapG f i j m = unsafePerformIO $ do 535remapG f i j m = unsafePerformIO $ do
602 r <- createMatrix RowMajor (rows i) (cols i) 536 r <- createMatrix RowMajor (rows i) (cols i)
603 (i # j # m #! r) f #|"remapG" 537 (i # j # m #! r) f #|"remapG"
604 return r 538 return r
605 539
606remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double 540remapD :: Matrix Int32 -> Matrix Int32 -> Matrix Double -> Matrix Double
607remapD = remapG c_remapD 541remapD = remapG c_remapD
608remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float 542remapF :: Matrix Int32 -> Matrix Int32 -> Matrix Float -> Matrix Float
609remapF = remapG c_remapF 543remapF = remapG c_remapF
610remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt 544remapI :: Matrix Int32 -> Matrix Int32 -> Matrix Int32 -> Matrix Int32
611remapI = remapG c_remapI 545remapI = remapG c_remapI
612remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z 546remapL :: Matrix Int32 -> Matrix Int32 -> Matrix Z -> Matrix Z
613remapL = remapG c_remapL 547remapL = remapG c_remapL
614remapC :: Matrix CInt 548remapC :: Matrix Int32
615 -> Matrix CInt 549 -> Matrix Int32
616 -> Matrix (Complex Double) 550 -> Matrix (Complex Double)
617 -> Matrix (Complex Double) 551 -> Matrix (Complex Double)
618remapC = remapG c_remapC 552remapC = remapG c_remapC
619remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) 553remapQ :: Matrix Int32 -> Matrix Int32 -> Matrix (Complex Float) -> Matrix (Complex Float)
620remapQ = remapG c_remapQ 554remapQ = remapG c_remapQ
621 555
622type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) 556type Rem x = OM Int32 (OM Int32 (OM x (OM x (IO Int32))))
623 557
624foreign import ccall unsafe "remapD" c_remapD :: Rem Double 558foreign import ccall unsafe "remapD" c_remapD :: Rem Double
625foreign import ccall unsafe "remapF" c_remapF :: Rem Float 559foreign import ccall unsafe "remapF" c_remapF :: Rem Float
626foreign import ccall unsafe "remapI" c_remapI :: Rem CInt 560foreign import ccall unsafe "remapI" c_remapI :: Rem Int32
627foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) 561foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double)
628foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) 562foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float)
629foreign import ccall unsafe "remapL" c_remapL :: Rem Z 563foreign import ccall unsafe "remapL" c_remapL :: Rem Z
@@ -631,14 +565,14 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z
631-------------------------------------------------------------------------------- 565--------------------------------------------------------------------------------
632 566
633rowOpAux :: (TransArray c, Storable a) => 567rowOpAux :: (TransArray c, Storable a) =>
634 (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) 568 (Int32 -> Ptr a -> Int32 -> Int32 -> Int32 -> Int32 -> Trans c (IO Int32))
635 -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () 569 -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO ()
636rowOpAux f c x i1 i2 j1 j2 m = do 570rowOpAux f c x i1 i2 j1 j2 m = do
637 px <- newArray [x] 571 px <- newArray [x]
638 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" 572 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp"
639 free px 573 free px
640 574
641type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok 575type RowOp x = Int32 -> Ptr x -> Int32 -> Int32 -> Int32 -> Int32 -> x ::> Ok
642 576
643foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R 577foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R
644foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float 578foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float
@@ -652,7 +586,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
652-------------------------------------------------------------------------------- 586--------------------------------------------------------------------------------
653 587
654gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) 588gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3)
655 => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt)))) 589 => Trans c3 (Trans c2 (Trans c1 (Trans c (IO Int32))))
656 -> c3 -> c2 -> c1 -> c -> IO () 590 -> c3 -> c2 -> c1 -> c -> IO ()
657gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" 591gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg"
658 592
@@ -669,21 +603,26 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
669 603
670-------------------------------------------------------------------------------- 604--------------------------------------------------------------------------------
671 605
606{-
672reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => 607reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) =>
673 (CInt -> Ptr a -> CInt -> Ptr t1 608 (Int32 -> Ptr a -> Int32 -> Ptr t1
674 -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)) 609 -> Trans c (Int32 -> Ptr t -> Int32 -> Ptr a1 -> IO Int32))
675 -> Vector t1 -> c -> Vector t -> Vector a1 610 -> Vector t1 -> c -> Vector t -> Vector a1
611-}
612reorderAux :: (TransArray c, Storable a,
613 Trans c (Int32 -> Ptr a -> Int32 -> Ptr a -> IO Int32) ~ (Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr a -> Int32 -> Ptr a -> IO Int32)) =>
614 p -> Vector Int32 -> c -> Vector a -> Vector a
676reorderAux f s d v = unsafePerformIO $ do 615reorderAux f s d v = unsafePerformIO $ do
677 k <- createVector (dim s) 616 k <- createVector (dim s)
678 r <- createVector (dim v) 617 r <- createVector (dim v)
679 (k # s # d # v #! r) f #| "reorderV" 618 (k # s # d # v #! r) reorderStorable #| "reorderV"
680 return r 619 return r
681 620
682type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt))))) 621type Reorder x = CV Int32 (CV Int32 (CV Int32 (CV x (CV x (IO Int32)))))
683 622
684foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double 623foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double
685foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float 624foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float
686foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt 625foreign import ccall unsafe "reorderI" c_reorderI :: Reorder Int32
687foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) 626foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double)
688foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) 627foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float)
689foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z 628foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z
@@ -691,12 +630,12 @@ foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z
691-- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, 630-- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices,
692-- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ 631-- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@
693-- This function is intended to be used internally by tensor libraries. 632-- This function is intended to be used internally by tensor libraries.
694reorderVector :: Element a 633reorderVector :: Storable a
695 => Vector CInt -- ^ @strides@: array strides 634 => Vector Int32 -- ^ @strides@: array strides
696 -> Vector CInt -- ^ @dims@: array dimensions of new array @v@ 635 -> Vector Int32 -- ^ @dims@: array dimensions of new array @v@
697 -> Vector a -- ^ @v@: flattened input array 636 -> Vector a -- ^ @v@: flattened input array
698 -> Vector a -- ^ @v'@: flattened output array 637 -> Vector a -- ^ @v'@: flattened output array
699reorderVector = reorderV 638reorderVector = reorderAux ()
700 639
701-------------------------------------------------------------------------------- 640--------------------------------------------------------------------------------
702 641
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index eb0c5a8..e67aa67 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -135,6 +135,7 @@ instance (Integral t, KnownNat n) => Num (Mod n t)
135 fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) 135 fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m))
136 136
137 137
138#if 0
138instance KnownNat m => Element (Mod m I) 139instance KnownNat m => Element (Mod m I)
139 where 140 where
140 constantD x n = i2f (constantD (unMod x) n) 141 constantD x n = i2f (constantD (unMod x) n)
@@ -168,6 +169,7 @@ instance KnownNat m => Element (Mod m Z)
168 gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c) 169 gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c)
169 where 170 where
170 m' = fromIntegral . natVal $ (undefined :: Proxy m) 171 m' = fromIntegral . natVal $ (undefined :: Proxy m)
172#endif
171 173
172 174
173instance KnownNat m => CTrans (Mod m I) 175instance KnownNat m => CTrans (Mod m I)
@@ -306,10 +308,10 @@ f2i :: Storable t => Vector (Mod n t) -> Vector t
306f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) 308f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
307 where (fp,i,n) = unsafeToForeignPtr v 309 where (fp,i,n) = unsafeToForeignPtr v
308 310
309f2iM :: (Element t, Element (Mod n t)) => Matrix (Mod n t) -> Matrix t 311f2iM :: (Storable t, Storable (Mod n t)) => Matrix (Mod n t) -> Matrix t
310f2iM m = m { xdat = f2i (xdat m) } 312f2iM m = m { xdat = f2i (xdat m) }
311 313
312i2fM :: (Element t, Element (Mod n t)) => Matrix t -> Matrix (Mod n t) 314i2fM :: (Storable t, Storable (Mod n t)) => Matrix t -> Matrix (Mod n t)
313i2fM m = m { xdat = i2f (xdat m) } 315i2fM m = m { xdat = i2f (xdat m) }
314 316
315vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t) 317vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t)
diff --git a/packages/base/src/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs
index fd0a217..4f7bb82 100644
--- a/packages/base/src/Internal/Numeric.hs
+++ b/packages/base/src/Internal/Numeric.hs
@@ -4,6 +4,7 @@
4{-# LANGUAGE MultiParamTypeClasses #-} 4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-} 5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE UndecidableInstances #-} 6{-# LANGUAGE UndecidableInstances #-}
7{-# LANGUAGE PatternSynonyms #-}
7 8
8{-# OPTIONS_GHC -fno-warn-missing-signatures #-} 9{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
9 10
@@ -22,12 +23,18 @@ module Internal.Numeric where
22import Internal.Vector 23import Internal.Vector
23import Internal.Matrix 24import Internal.Matrix
24import Internal.Element 25import Internal.Element
26import Internal.Extract (requires,pattern BAD_SIZE)
25import Internal.ST as ST 27import Internal.ST as ST
26import Internal.Conversion 28import Internal.Conversion
27import Internal.Vectorized 29import Internal.Vectorized
28import Internal.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ,multiplyI,multiplyL) 30import Internal.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ,multiplyI,multiplyL)
31import Control.Monad
32import Data.Function
33import Data.Int
29import Data.List.Split(chunksOf) 34import Data.List.Split(chunksOf)
30import qualified Data.Vector.Storable as V 35import qualified Data.Vector.Storable as V
36import Foreign.Ptr
37import Foreign.Storable
31 38
32-------------------------------------------------------------------------------- 39--------------------------------------------------------------------------------
33 40
@@ -44,7 +51,7 @@ type instance ArgOf Matrix a = a -> a -> a
44-------------------------------------------------------------------------------- 51--------------------------------------------------------------------------------
45 52
46-- | Basic element-by-element functions for numeric containers 53-- | Basic element-by-element functions for numeric containers
47class Element e => Container c e 54class Storable e => Container c e
48 where 55 where
49 conj' :: c e -> c e 56 conj' :: c e -> c e
50 size' :: c e -> IndexOf c 57 size' :: c e -> IndexOf c
@@ -56,7 +63,7 @@ class Element e => Container c e
56 -- | element by element multiplication 63 -- | element by element multiplication
57 mul :: c e -> c e -> c e 64 mul :: c e -> c e -> c e
58 equal :: c e -> c e -> Bool 65 equal :: c e -> c e -> Bool
59 cmap' :: (Element b) => (e -> b) -> c e -> c b 66 cmap' :: (Storable b) => (e -> b) -> c e -> c b
60 konst' :: e -> IndexOf c -> c e 67 konst' :: e -> IndexOf c -> c e
61 build' :: IndexOf c -> (ArgOf c e) -> c e 68 build' :: IndexOf c -> (ArgOf c e) -> c e
62 atIndex' :: c e -> IndexOf c -> e 69 atIndex' :: c e -> IndexOf c -> e
@@ -107,7 +114,7 @@ instance Container Vector I
107 mul = vectorZipI Mul 114 mul = vectorZipI Mul
108 equal = (==) 115 equal = (==)
109 scalar' = V.singleton 116 scalar' = V.singleton
110 konst' = constantD 117 konst' = constantAux
111 build' = buildV 118 build' = buildV
112 cmap' = mapVector 119 cmap' = mapVector
113 atIndex' = (@>) 120 atIndex' = (@>)
@@ -146,7 +153,7 @@ instance Container Vector Z
146 mul = vectorZipL Mul 153 mul = vectorZipL Mul
147 equal = (==) 154 equal = (==)
148 scalar' = V.singleton 155 scalar' = V.singleton
149 konst' = constantD 156 konst' = constantAux
150 build' = buildV 157 build' = buildV
151 cmap' = mapVector 158 cmap' = mapVector
152 atIndex' = (@>) 159 atIndex' = (@>)
@@ -186,7 +193,7 @@ instance Container Vector Float
186 mul = vectorZipF Mul 193 mul = vectorZipF Mul
187 equal = (==) 194 equal = (==)
188 scalar' = V.singleton 195 scalar' = V.singleton
189 konst' = constantD 196 konst' = constantAux
190 build' = buildV 197 build' = buildV
191 cmap' = mapVector 198 cmap' = mapVector
192 atIndex' = (@>) 199 atIndex' = (@>)
@@ -223,7 +230,7 @@ instance Container Vector Double
223 mul = vectorZipR Mul 230 mul = vectorZipR Mul
224 equal = (==) 231 equal = (==)
225 scalar' = V.singleton 232 scalar' = V.singleton
226 konst' = constantD 233 konst' = constantAux
227 build' = buildV 234 build' = buildV
228 cmap' = mapVector 235 cmap' = mapVector
229 atIndex' = (@>) 236 atIndex' = (@>)
@@ -260,7 +267,7 @@ instance Container Vector (Complex Double)
260 mul = vectorZipC Mul 267 mul = vectorZipC Mul
261 equal = (==) 268 equal = (==)
262 scalar' = V.singleton 269 scalar' = V.singleton
263 konst' = constantD 270 konst' = constantAux
264 build' = buildV 271 build' = buildV
265 cmap' = mapVector 272 cmap' = mapVector
266 atIndex' = (@>) 273 atIndex' = (@>)
@@ -296,7 +303,7 @@ instance Container Vector (Complex Float)
296 mul = vectorZipQ Mul 303 mul = vectorZipQ Mul
297 equal = (==) 304 equal = (==)
298 scalar' = V.singleton 305 scalar' = V.singleton
299 konst' = constantD 306 konst' = constantAux
300 build' = buildV 307 build' = buildV
301 cmap' = mapVector 308 cmap' = mapVector
302 atIndex' = (@>) 309 atIndex' = (@>)
@@ -323,7 +330,7 @@ instance Container Vector (Complex Float)
323 330
324--------------------------------------------------------------- 331---------------------------------------------------------------
325 332
326instance (Num a, Element a, Container Vector a) => Container Matrix a 333instance (Num a, Storable a, Container Vector a) => Container Matrix a
327 where 334 where
328 conj' = liftMatrix conj' 335 conj' = liftMatrix conj'
329 size' = size 336 size' = size
@@ -418,8 +425,8 @@ fromZ = fromZ'
418toZ :: (Container c e) => c e -> c Z 425toZ :: (Container c e) => c e -> c Z
419toZ = toZ' 426toZ = toZ'
420 427
421-- | like 'fmap' (cannot implement instance Functor because of Element class constraint) 428-- | like 'fmap' (cannot implement instance Functor because of Storable class constraint)
422cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b 429cmap :: (Storable b, Container c e) => (e -> b) -> c e -> c b
423cmap = cmap' 430cmap = cmap'
424 431
425-- | generic indexing function 432-- | generic indexing function
@@ -470,7 +477,7 @@ step
470step = step' 477step = step'
471 478
472 479
473-- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@. 480-- | Storable by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@.
474-- 481--
475-- Arguments with any dimension = 1 are automatically expanded: 482-- Arguments with any dimension = 1 are automatically expanded:
476-- 483--
@@ -598,7 +605,7 @@ instance Numeric Z
598-------------------------------------------------------------------------------- 605--------------------------------------------------------------------------------
599 606
600-- | Matrix product and related functions 607-- | Matrix product and related functions
601class (Num e, Element e) => Product e where 608class (Num e, Storable e) => Product e where
602 -- | matrix product 609 -- | matrix product
603 multiply :: Matrix e -> Matrix e -> Matrix e 610 multiply :: Matrix e -> Matrix e -> Matrix e
604 -- | sum of absolute value of elements (differs in complex case from @norm1@) 611 -- | sum of absolute value of elements (differs in complex case from @norm1@)
@@ -823,12 +830,12 @@ buildV n f = fromList [f k | k <- ks]
823-------------------------------------------------------- 830--------------------------------------------------------
824 831
825-- | Creates a square matrix with a given diagonal. 832-- | Creates a square matrix with a given diagonal.
826diag :: (Num a, Element a) => Vector a -> Matrix a 833diag :: (Num a, Storable a) => Vector a -> Matrix a
827diag v = diagRect 0 v n n where n = dim v 834diag v = diagRect 0 v n n where n = dim v
828 835
829-- | creates the identity matrix of given dimension 836-- | creates the identity matrix of given dimension
830ident :: (Num a, Element a) => Int -> Matrix a 837ident :: (Num a, Storable a) => Int -> Matrix a
831ident n = diag (constantD 1 n) 838ident n = diag (constantAux 1 n)
832 839
833-------------------------------------------------------- 840--------------------------------------------------------
834 841
@@ -943,3 +950,44 @@ class Testable t
943 950
944-------------------------------------------------------------------------------- 951--------------------------------------------------------------------------------
945 952
953compareV :: (Storable a, Ord a) => Vector a -> Vector a -> Vector Int32
954compareV = compareG compareStorable
955
956compareStorable :: (Storable a, Ord a) =>
957 Int32 -> Ptr a
958 -> Int32 -> Ptr a
959 -> Int32 -> Ptr Int32
960 -> IO Int32
961compareStorable xn xp yn yp rn rp = do
962 requires (xn==yn && xn==rn) BAD_SIZE $ do
963 ($ 0) $ fix $ \kloop k -> when (k<xn) $ do
964 xk <- peekElemOff xp (fromIntegral k)
965 yk <- peekElemOff yp (fromIntegral k)
966 pokeElemOff rp (fromIntegral k) $ case compare xk yk of
967 LT -> -1
968 GT -> 1
969 EQ -> 0
970 kloop (succ k)
971 return 0
972
973selectV :: Storable a => Vector Int32 -> Vector a -> Vector a -> Vector a -> Vector a
974selectV = selectG selectStorable
975
976selectStorable :: Storable a =>
977 Int32 -> Ptr Int32
978 -> Int32 -> Ptr a
979 -> Int32 -> Ptr a
980 -> Int32 -> Ptr a
981 -> Int32 -> Ptr a
982 -> IO Int32
983selectStorable condn condp ltn ltp eqn eqp gtn gtp rn rp = do
984 requires (condn==ltn && ltn==eqn && ltn==gtn && ltn==rn) BAD_SIZE $ do
985 ($ 0) $ fix $ \kloop k -> when (k<condn) $ do
986 condpk <- peekElemOff condp (fromIntegral k)
987 pokeElemOff rp (fromIntegral k) =<< case compare condpk 0 of
988 LT -> peekElemOff ltp (fromIntegral k)
989 GT -> peekElemOff gtp (fromIntegral k)
990 EQ -> peekElemOff eqp (fromIntegral k)
991 kloop (succ k)
992 return 0
993
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 7d54e6d..326b90a 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -1,6 +1,7 @@
1{-# LANGUAGE Rank2Types #-} 1{-# LANGUAGE Rank2Types #-}
2{-# LANGUAGE BangPatterns #-} 2{-# LANGUAGE BangPatterns #-}
3{-# LANGUAGE ViewPatterns #-} 3{-# LANGUAGE ViewPatterns #-}
4{-# LANGUAGE PatternSynonyms #-}
4 5
5----------------------------------------------------------------------------- 6-----------------------------------------------------------------------------
6-- | 7-- |
@@ -30,14 +31,20 @@ module Internal.ST (
30 unsafeThawVector, unsafeFreezeVector, 31 unsafeThawVector, unsafeFreezeVector,
31 newUndefinedMatrix, 32 newUndefinedMatrix,
32 unsafeReadMatrix, unsafeWriteMatrix, 33 unsafeReadMatrix, unsafeWriteMatrix,
33 unsafeThawMatrix, unsafeFreezeMatrix 34 unsafeThawMatrix, unsafeFreezeMatrix,
35 setRect
34) where 36) where
35 37
36import Internal.Vector 38import Internal.Vector
37import Internal.Matrix 39import Internal.Matrix
38import Internal.Vectorized 40import Internal.Vectorized
41import Internal.Devel ((#|))
39import Control.Monad.ST(ST, runST) 42import Control.Monad.ST(ST, runST)
40import Foreign.Storable(Storable, peekElemOff, pokeElemOff) 43import Control.Monad
44import Data.Function
45import Data.Int
46import Foreign.Ptr
47import Foreign.Storable
41import Control.Monad.ST.Unsafe(unsafeIOToST) 48import Control.Monad.ST.Unsafe(unsafeIOToST)
42 49
43{-# INLINE ioReadV #-} 50{-# INLINE ioReadV #-}
@@ -121,7 +128,7 @@ ioWriteM m r c val = ioWriteV (xdat m) (r * xRow m + c * xCol m) val
121 128
122newtype STMatrix s t = STMatrix (Matrix t) 129newtype STMatrix s t = STMatrix (Matrix t)
123 130
124thawMatrix :: Element t => Matrix t -> ST s (STMatrix s t) 131thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
125thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix 132thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix
126 133
127unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) 134unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
@@ -142,17 +149,17 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
142modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () 149modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
143modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c 150modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c
144 151
145liftSTMatrix :: (Element t) => (Matrix t -> a) -> STMatrix s t -> ST s a 152liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a
146liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x 153liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
147 154
148unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) 155unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
149unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x 156unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
150 157
151 158
152freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) 159freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
153freezeMatrix m = liftSTMatrix id m 160freezeMatrix m = liftSTMatrix id m
154 161
155cloneMatrix :: Element t => Matrix t -> IO (Matrix t) 162cloneMatrix :: Storable t => Matrix t -> IO (Matrix t)
156cloneMatrix m = copy (orderOf m) m 163cloneMatrix m = copy (orderOf m) m
157 164
158{-# INLINE safeIndexM #-} 165{-# INLINE safeIndexM #-}
@@ -172,7 +179,7 @@ readMatrix = safeIndexM unsafeReadMatrix
172writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () 179writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
173writeMatrix = safeIndexM unsafeWriteMatrix 180writeMatrix = safeIndexM unsafeWriteMatrix
174 181
175setMatrix :: Element t => STMatrix s t -> Int -> Int -> Matrix t -> ST s () 182setMatrix :: Storable t => STMatrix s t -> Int -> Int -> Matrix t -> ST s ()
176setMatrix (STMatrix x) i j m = unsafeIOToST $ setRect i j m x 183setMatrix (STMatrix x) i j m = unsafeIOToST $ setRect i j m x
177 184
178newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) 185newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
@@ -210,7 +217,7 @@ data RowOper t = AXPY t Int Int ColRange
210 | SCAL t RowRange ColRange 217 | SCAL t RowRange ColRange
211 | SWAP Int Int ColRange 218 | SWAP Int Int ColRange
212 219
213rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s () 220rowOper :: (Num t, Storable t) => RowOper t -> STMatrix s t -> ST s ()
214 221
215rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m 222rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m
216 where 223 where
@@ -230,8 +237,8 @@ rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m
230 i2' = i2 `mod` (rows m) 237 i2' = i2 `mod` (rows m)
231 238
232 239
233extractMatrix :: Element a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a) 240extractMatrix :: Storable a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a)
234extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) 241extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractAux (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2]))
235 where 242 where
236 (i1,i2) = getRowRange (rows m) rr 243 (i1,i2) = getRowRange (rows m) rr
237 (j1,j2) = getColRange (cols m) rc 244 (j1,j2) = getColRange (cols m) rc
@@ -239,19 +246,117 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[
239-- | r0 c0 height width 246-- | r0 c0 height width
240data Slice s t = Slice (STMatrix s t) Int Int Int Int 247data Slice s t = Slice (STMatrix s t) Int Int Int Int
241 248
242slice :: Element a => Slice t a -> Matrix a 249slice :: Storable a => Slice t a -> Matrix a
243slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m 250slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m
244 251
245gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () 252gemmm :: (Storable t, Num t) => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s ()
246gemmm beta (slice->r) alpha (slice->a) (slice->b) = res 253gemmm beta (slice->r) alpha (slice->a) (slice->b) = res
247 where 254 where
248 res = unsafeIOToST (gemm v a b r) 255 res = unsafeIOToST (gemm v a b r)
249 v = fromList [alpha,beta] 256 v = fromList [alpha,beta]
250 257
251 258
252mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 259mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
253mutable f a = runST $ do 260mutable f a = runST $ do
254 x <- thawMatrix a 261 x <- thawMatrix a
255 info <- f (rows a, cols a) x 262 info <- f (rows a, cols a) x
256 r <- unsafeFreezeMatrix x 263 r <- unsafeFreezeMatrix x
257 return (r,info) 264 return (r,info)
265
266
267
268setRect :: Storable t => Int -> Int -> Matrix t -> Matrix t -> IO ()
269setRect i j m r = (m Internal.Matrix.#! r) (setRectStorable (fi i) (fi j)) #|"setRect"
270
271setRectStorable :: Storable t =>
272 Int32 -> Int32
273 -> Int32 -> Int32 -> Int32 -> Int32 -> {- const -} Ptr t
274 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t
275 -> IO Int32
276setRectStorable i j mr mc mXr mXc mp rr rc rXr rXc rp = do
277 ($ 0) $ fix $ \aloop a -> when (a<mr) $ do
278 ($ 0) $ fix $ \bloop b -> when (b<mc) $ do
279 let x = a+i
280 y = b+j
281 when (0<=x && x<rr && 0<=y && y<rc) $ do
282 pokeElemOff rp (fromIntegral $ rXr*x + rXc*y)
283 =<< peekElemOff mp (fromIntegral $ mXr*a + mXc*b)
284 bloop (succ b)
285 aloop (succ a)
286 return 0
287
288rowOp :: (Storable t, Num t) => Int -> t -> Int -> Int -> Int -> Int -> Matrix t -> IO ()
289rowOp = rowOpAux rowOpStorable
290
291pattern BAD_CODE = 2001
292
293rowOpStorable :: (Storable t, Num t) =>
294 Int32 -> Ptr t -> Int32 -> Int32 -> Int32 -> Int32
295 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t
296 -> IO Int32
297rowOpStorable 0 pa i1 i2 j1 j2 rr rc rXr rXc rp = do
298 -- AXPY_IMP
299 a <- peek pa
300 ($ j1) $ fix $ \jloop j -> when (j<=j2) $ do
301 ri1j <- peekElemOff rp $ fromIntegral $ rXr*i1 + rXc*j
302 let i2j = fromIntegral $ rXr*i2 + rXc*j
303 ri2j <- peekElemOff rp i2j
304 pokeElemOff rp i2j $ ri2j + a*ri1j
305 jloop (succ j)
306 return 0
307rowOpStorable 1 pa i1 i2 j1 j2 rr rc rXr rXc rp = do
308 -- SCAL_IMP
309 a <- peek pa
310 ($ i1) $ fix $ \iloop i -> when (i<=i2) $ do
311 ($ j1) $ fix $ \jloop j -> when (j<=j2) $ do
312 let rijp = rp `plusPtr` fromIntegral (rXr*i + rXc*j)
313 rij <- peek rijp
314 poke rijp $ a * rij
315 jloop (succ j)
316 iloop (succ i)
317 return 0
318rowOpStorable 2 pa i1 i2 j1 j2 rr rc rXr rXc rp | i1 == i2 = return 0
319rowOpStorable 2 pa i1 i2 j1 j2 rr rc rXr rXc rp = do
320 -- SWAP_IMP
321 ($ j1) $ fix $ \kloop k -> when (k<=j2) $ do
322 let i1k = fromIntegral $ rXr*i1 + rXc*k
323 i2k = fromIntegral $ rXr*i2 + rXc*k
324 aux <- peekElemOff rp i1k
325 pokeElemOff rp i1k =<< peekElemOff rp i2k
326 pokeElemOff rp i2k aux
327 kloop (succ k)
328 return 0
329rowOpStorable _ pa i1 i2 j1 j2 rr rc rXr rXc rp = do
330 return BAD_CODE
331
332gemm :: (Storable t, Num t) => Vector t -> Matrix t -> Matrix t -> Matrix t -> IO ()
333gemm v m1 m2 m3 = (v Internal.Matrix.# m1 Internal.Matrix.# m2 Internal.Matrix.#! m3) gemmStorable #|"gemm"
334
335-- ScalarLike t
336gemmStorable :: (Storable t, Num t) =>
337 Int32 -> Ptr t -- VECG(T,c)
338 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -- MATG(T,a)
339 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -- MATG(T,b)
340 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -- MATG(T,r)
341 -> IO Int32
342gemmStorable cn cp
343 ar ac aXr aXc ap
344 br bc bXr bXc bp
345 rr rc rXr rXc rp = do
346 a <- peek cp
347 b <- peekElemOff cp 1
348 ($ 0) $ fix $ \iloop i -> when (i<rr) $ do
349 ($ 0) $ fix $ \jloop j -> when (j<rc) $ do
350 let kloop k !t fin
351 | k<ac = do
352 aik <- peekElemOff ap (fromIntegral $ i*aXr + k*aXc)
353 bkj <- peekElemOff bp (fromIntegral $ k*bXr + j*bXc)
354 kloop (succ k) (t + aik*bkj) fin
355 | otherwise = fin t
356 kloop 0 0 $ \t -> do
357 let ij = fromIntegral $ i*rXr + j*rXc
358 rij <- peekElemOff rp ij
359 pokeElemOff rp ij (b*rij + a*t)
360 jloop (succ j)
361 iloop (succ i)
362 return 0
diff --git a/packages/base/src/Internal/Sparse.hs b/packages/base/src/Internal/Sparse.hs
index fbea11a..423b169 100644
--- a/packages/base/src/Internal/Sparse.hs
+++ b/packages/base/src/Internal/Sparse.hs
@@ -20,7 +20,7 @@ import Data.Function(on)
20import Control.Arrow((***)) 20import Control.Arrow((***))
21import Control.Monad(when) 21import Control.Monad(when)
22import Data.List(groupBy, sort) 22import Data.List(groupBy, sort)
23import Foreign.C.Types(CInt(..)) 23import Data.Int
24 24
25import Internal.Devel 25import Internal.Devel
26import System.IO.Unsafe(unsafePerformIO) 26import System.IO.Unsafe(unsafePerformIO)
@@ -34,16 +34,16 @@ type AssocMatrix = [((Int,Int),Double)]
34 34
35data CSR = CSR 35data CSR = CSR
36 { csrVals :: Vector Double 36 { csrVals :: Vector Double
37 , csrCols :: Vector CInt 37 , csrCols :: Vector Int32
38 , csrRows :: Vector CInt 38 , csrRows :: Vector Int32
39 , csrNRows :: Int 39 , csrNRows :: Int
40 , csrNCols :: Int 40 , csrNCols :: Int
41 } deriving Show 41 } deriving Show
42 42
43data CSC = CSC 43data CSC = CSC
44 { cscVals :: Vector Double 44 { cscVals :: Vector Double
45 , cscRows :: Vector CInt 45 , cscRows :: Vector Int32
46 , cscCols :: Vector CInt 46 , cscCols :: Vector Int32
47 , cscNRows :: Int 47 , cscNRows :: Int
48 , cscNCols :: Int 48 , cscNCols :: Int
49 } deriving Show 49 } deriving Show
@@ -138,9 +138,9 @@ mkDiagR r c v
138 diagVals = v 138 diagVals = v
139 139
140 140
141type IV t = CInt -> Ptr CInt -> t 141type IV t = Int32 -> Ptr Int32 -> t
142type V t = CInt -> Ptr Double -> t 142type V t = Int32 -> Ptr Double -> t
143type SMxV = V (IV (IV (V (V (IO CInt))))) 143type SMxV = V (IV (IV (V (V (IO Int32)))))
144 144
145gmXv :: GMatrix -> Vector Double -> Vector Double 145gmXv :: GMatrix -> Vector Double -> Vector Double
146gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do 146gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs
index f642e8d..6f3b4c8 100644
--- a/packages/base/src/Internal/Util.hs
+++ b/packages/base/src/Internal/Util.hs
@@ -83,6 +83,7 @@ import Control.Arrow((&&&),(***))
83import Data.Complex 83import Data.Complex
84import Data.Function(on) 84import Data.Function(on)
85import Internal.ST 85import Internal.ST
86import Foreign.Storable
86#if MIN_VERSION_base(4,11,0) 87#if MIN_VERSION_base(4,11,0)
87import Prelude hiding ((<>)) 88import Prelude hiding ((<>))
88#endif 89#endif
@@ -174,7 +175,7 @@ a & b = vjoin [a,b]
174 175
175-} 176-}
176infixl 3 ||| 177infixl 3 |||
177(|||) :: Element t => Matrix t -> Matrix t -> Matrix t 178(|||) :: Storable t => Matrix t -> Matrix t -> Matrix t
178a ||| b = fromBlocks [[a,b]] 179a ||| b = fromBlocks [[a,b]]
179 180
180-- | a synonym for ('|||') (unicode 0x00a6, broken bar) 181-- | a synonym for ('|||') (unicode 0x00a6, broken bar)
@@ -185,7 +186,7 @@ infixl 3 ¦
185 186
186-- | vertical concatenation 187-- | vertical concatenation
187-- 188--
188(===) :: Element t => Matrix t -> Matrix t -> Matrix t 189(===) :: Storable t => Matrix t -> Matrix t -> Matrix t
189infixl 2 === 190infixl 2 ===
190a === b = fromBlocks [[a],[b]] 191a === b = fromBlocks [[a],[b]]
191 192
@@ -225,7 +226,7 @@ col = asColumn . fromList
225 226
226-} 227-}
227infixl 9 ? 228infixl 9 ?
228(?) :: Element t => Matrix t -> [Int] -> Matrix t 229(?) :: Storable t => Matrix t -> [Int] -> Matrix t
229(?) = flip extractRows 230(?) = flip extractRows
230 231
231{- | extract columns 232{- | extract columns
@@ -240,7 +241,7 @@ infixl 9 ?
240 241
241-} 242-}
242infixl 9 ¿ 243infixl 9 ¿
243(¿) :: Element t => Matrix t -> [Int] -> Matrix t 244(¿) :: Storable t => Matrix t -> [Int] -> Matrix t
244(¿)= flip extractColumns 245(¿)= flip extractColumns
245 246
246 247
@@ -329,7 +330,7 @@ instance Normed (Vector (Complex Float))
329 norm_Inf = norm_Inf . double 330 norm_Inf = norm_Inf . double
330 331
331-- | Frobenius norm (Schatten p-norm with p=2) 332-- | Frobenius norm (Schatten p-norm with p=2)
332norm_Frob :: (Normed (Vector t), Element t) => Matrix t -> R 333norm_Frob :: (Normed (Vector t), Storable t) => Matrix t -> R
333norm_Frob = norm_2 . flatten 334norm_Frob = norm_2 . flatten
334 335
335-- | Sum of singular values (Schatten p-norm with p=1) 336-- | Sum of singular values (Schatten p-norm with p=1)
@@ -346,7 +347,7 @@ True
346True 347True
347 348
348-} 349-}
349magnit :: (Element t, Normed (Vector t)) => R -> t -> Bool 350magnit :: (Storable t, Normed (Vector t)) => R -> t -> Bool
350magnit e x = norm_1 (fromList [x]) > e 351magnit e x = norm_1 (fromList [x]) > e
351 352
352 353
@@ -415,7 +416,7 @@ instance Indexable (Vector (Complex Float)) (Complex Float)
415 where 416 where
416 (!) = (@>) 417 (!) = (@>)
417 418
418instance Element t => Indexable (Matrix t) (Vector t) 419instance Storable t => Indexable (Matrix t) (Vector t)
419 where 420 where
420 m!j = subVector (j*c) c (flatten m) 421 m!j = subVector (j*c) c (flatten m)
421 where 422 where
diff --git a/packages/base/src/Internal/Vector.hs b/packages/base/src/Internal/Vector.hs
index 6271bb6..3037019 100644
--- a/packages/base/src/Internal/Vector.hs
+++ b/packages/base/src/Internal/Vector.hs
@@ -32,7 +32,7 @@ import Foreign.ForeignPtr
32import Foreign.Ptr 32import Foreign.Ptr
33import Foreign.Storable 33import Foreign.Storable
34import Foreign.C.Types(CInt) 34import Foreign.C.Types(CInt)
35import Data.Int(Int64) 35import Data.Int
36import Data.Complex 36import Data.Complex
37import System.IO.Unsafe(unsafePerformIO) 37import System.IO.Unsafe(unsafePerformIO)
38import GHC.ForeignPtr(mallocPlainForeignPtrBytes) 38import GHC.ForeignPtr(mallocPlainForeignPtrBytes)
@@ -46,18 +46,18 @@ import Control.Monad(replicateM)
46import qualified Data.ByteString.Internal as BS 46import qualified Data.ByteString.Internal as BS
47import Data.Vector.Storable.Internal(updPtr) 47import Data.Vector.Storable.Internal(updPtr)
48 48
49type I = CInt 49type I = Int32
50type Z = Int64 50type Z = Int64
51type R = Double 51type R = Double
52type C = Complex Double 52type C = Complex Double
53 53
54 54
55-- | specialized fromIntegral 55-- | specialized fromIntegral
56fi :: Int -> CInt 56fi :: Int -> Int32
57fi = fromIntegral 57fi = fromIntegral
58 58
59-- | specialized fromIntegral 59-- | specialized fromIntegral
60ti :: CInt -> Int 60ti :: Int32 -> Int
61ti = fromIntegral 61ti = fromIntegral
62 62
63 63
@@ -69,7 +69,7 @@ dim = Vector.length
69 69
70-- C-Haskell vector adapter 70-- C-Haskell vector adapter
71{-# INLINE avec #-} 71{-# INLINE avec #-}
72avec :: Storable a => Vector a -> (f -> IO r) -> ((CInt -> Ptr a -> f) -> IO r) 72avec :: Storable a => Vector a -> (f -> IO r) -> ((Int32 -> Ptr a -> f) -> IO r)
73avec v f g = unsafeWith v $ \ptr -> f (g (fromIntegral (Vector.length v)) ptr) 73avec v f g = unsafeWith v $ \ptr -> f (g (fromIntegral (Vector.length v)) ptr)
74 74
75-- allocates memory for a new vector 75-- allocates memory for a new vector
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs
index 32430c6..ede3826 100644
--- a/packages/base/src/Internal/Vectorized.hs
+++ b/packages/base/src/Internal/Vectorized.hs
@@ -18,10 +18,12 @@ module Internal.Vectorized where
18import Internal.Vector 18import Internal.Vector
19import Internal.Devel 19import Internal.Devel
20import Data.Complex 20import Data.Complex
21import Data.Function
22import Data.Int
21import Foreign.Marshal.Alloc(free,malloc) 23import Foreign.Marshal.Alloc(free,malloc)
22import Foreign.Marshal.Array(newArray,copyArray) 24import Foreign.Marshal.Array(newArray,copyArray)
23import Foreign.Ptr(Ptr) 25import Foreign.Ptr(Ptr)
24import Foreign.Storable(peek,Storable) 26import Foreign.Storable(peek,pokeElemOff,Storable)
25import Foreign.C.Types 27import Foreign.C.Types
26import Foreign.C.String 28import Foreign.C.String
27import System.IO.Unsafe(unsafePerformIO) 29import System.IO.Unsafe(unsafePerformIO)
@@ -36,8 +38,8 @@ a # b = applyRaw a b
36a #! b = a # b # id 38a #! b = a # b # id
37{-# INLINE (#!) #-} 39{-# INLINE (#!) #-}
38 40
39fromei :: Enum a => a -> CInt 41fromei :: Enum a => a -> Int32
40fromei x = fromIntegral (fromEnum x) :: CInt 42fromei x = fromIntegral (fromEnum x) :: Int32
41 43
42data FunCodeV = Sin 44data FunCodeV = Sin
43 | Cos 45 | Cos
@@ -103,20 +105,20 @@ sumQ = sumg c_sumQ
103sumC :: Vector (Complex Double) -> Complex Double 105sumC :: Vector (Complex Double) -> Complex Double
104sumC = sumg c_sumC 106sumC = sumg c_sumC
105 107
106sumI :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr I -> I :> Ok) 108sumI :: ( TransRaw c (Int32 -> Ptr a -> IO Int32) ~ (Int32 -> Ptr I -> I :> Ok)
107 , TransArray c 109 , TransArray c
108 , Storable a 110 , Storable a
109 ) 111 )
110 => I -> c -> a 112 => I -> c -> a
111sumI m = sumg (c_sumI m) 113sumI m = sumg (c_sumI m)
112 114
113sumL :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr Z -> Z :> Ok) 115sumL :: ( TransRaw c (Int32 -> Ptr a -> IO Int32) ~ (Int32 -> Ptr Z -> Z :> Ok)
114 , TransArray c 116 , TransArray c
115 , Storable a 117 , Storable a
116 ) => Z -> c -> a 118 ) => Z -> c -> a
117sumL m = sumg (c_sumL m) 119sumL m = sumg (c_sumL m)
118 120
119sumg :: (TransArray c, Storable a) => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a 121sumg :: (TransArray c, Storable a) => TransRaw c (Int32 -> Ptr a -> IO Int32) -> c -> a
120sumg f x = unsafePerformIO $ do 122sumg f x = unsafePerformIO $ do
121 r <- createVector 1 123 r <- createVector 1
122 (x #! r) f #| "sum" 124 (x #! r) f #| "sum"
@@ -154,7 +156,7 @@ prodL :: Z-> Vector Z -> Z
154prodL = prodg . c_prodL 156prodL = prodg . c_prodL
155 157
156prodg :: (TransArray c, Storable a) 158prodg :: (TransArray c, Storable a)
157 => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a 159 => TransRaw c (Int32 -> Ptr a -> IO Int32) -> c -> a
158prodg f x = unsafePerformIO $ do 160prodg f x = unsafePerformIO $ do
159 r <- createVector 1 161 r <- createVector 1
160 (x #! r) f #| "prod" 162 (x #! r) f #| "prod"
@@ -171,7 +173,7 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z
171------------------------------------------------------------------ 173------------------------------------------------------------------
172 174
173toScalarAux :: (Enum a, TransArray c, Storable a1) 175toScalarAux :: (Enum a, TransArray c, Storable a1)
174 => (CInt -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) -> a -> c -> a1 176 => (Int32 -> TransRaw c (Int32 -> Ptr a1 -> IO Int32)) -> a -> c -> a1
175toScalarAux fun code v = unsafePerformIO $ do 177toScalarAux fun code v = unsafePerformIO $ do
176 r <- createVector 1 178 r <- createVector 1
177 (v #! r) (fun (fromei code)) #|"toScalarAux" 179 (v #! r) (fun (fromei code)) #|"toScalarAux"
@@ -179,7 +181,7 @@ toScalarAux fun code v = unsafePerformIO $ do
179 181
180 182
181vectorMapAux :: (Enum a, Storable t, Storable a1) 183vectorMapAux :: (Enum a, Storable t, Storable a1)
182 => (CInt -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) 184 => (Int32 -> Int32 -> Ptr t -> Int32 -> Ptr a1 -> IO Int32)
183 -> a -> Vector t -> Vector a1 185 -> a -> Vector t -> Vector a1
184vectorMapAux fun code v = unsafePerformIO $ do 186vectorMapAux fun code v = unsafePerformIO $ do
185 r <- createVector (dim v) 187 r <- createVector (dim v)
@@ -187,7 +189,7 @@ vectorMapAux fun code v = unsafePerformIO $ do
187 return r 189 return r
188 190
189vectorMapValAux :: (Enum a, Storable a2, Storable t, Storable a1) 191vectorMapValAux :: (Enum a, Storable a2, Storable t, Storable a1)
190 => (CInt -> Ptr a2 -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) 192 => (Int32 -> Ptr a2 -> Int32 -> Ptr t -> Int32 -> Ptr a1 -> IO Int32)
191 -> a -> a2 -> Vector t -> Vector a1 193 -> a -> a2 -> Vector t -> Vector a1
192vectorMapValAux fun code val v = unsafePerformIO $ do 194vectorMapValAux fun code val v = unsafePerformIO $ do
193 r <- createVector (dim v) 195 r <- createVector (dim v)
@@ -197,7 +199,7 @@ vectorMapValAux fun code val v = unsafePerformIO $ do
197 return r 199 return r
198 200
199vectorZipAux :: (Enum a, TransArray c, Storable t, Storable a1) 201vectorZipAux :: (Enum a, TransArray c, Storable t, Storable a1)
200 => (CInt -> CInt -> Ptr t -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) 202 => (Int32 -> Int32 -> Ptr t -> TransRaw c (Int32 -> Ptr a1 -> IO Int32))
201 -> a -> Vector t -> c -> Vector a1 203 -> a -> Vector t -> c -> Vector a1
202vectorZipAux fun code u v = unsafePerformIO $ do 204vectorZipAux fun code u v = unsafePerformIO $ do
203 r <- createVector (dim u) 205 r <- createVector (dim u)
@@ -210,37 +212,37 @@ vectorZipAux fun code u v = unsafePerformIO $ do
210toScalarR :: FunCodeS -> Vector Double -> Double 212toScalarR :: FunCodeS -> Vector Double -> Double
211toScalarR oper = toScalarAux c_toScalarR (fromei oper) 213toScalarR oper = toScalarAux c_toScalarR (fromei oper)
212 214
213foreign import ccall unsafe "toScalarR" c_toScalarR :: CInt -> TVV Double 215foreign import ccall unsafe "toScalarR" c_toScalarR :: Int32 -> TVV Double
214 216
215-- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. 217-- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc.
216toScalarF :: FunCodeS -> Vector Float -> Float 218toScalarF :: FunCodeS -> Vector Float -> Float
217toScalarF oper = toScalarAux c_toScalarF (fromei oper) 219toScalarF oper = toScalarAux c_toScalarF (fromei oper)
218 220
219foreign import ccall unsafe "toScalarF" c_toScalarF :: CInt -> TVV Float 221foreign import ccall unsafe "toScalarF" c_toScalarF :: Int32 -> TVV Float
220 222
221-- | obtains different functions of a vector: only norm1, norm2 223-- | obtains different functions of a vector: only norm1, norm2
222toScalarC :: FunCodeS -> Vector (Complex Double) -> Double 224toScalarC :: FunCodeS -> Vector (Complex Double) -> Double
223toScalarC oper = toScalarAux c_toScalarC (fromei oper) 225toScalarC oper = toScalarAux c_toScalarC (fromei oper)
224 226
225foreign import ccall unsafe "toScalarC" c_toScalarC :: CInt -> Complex Double :> Double :> Ok 227foreign import ccall unsafe "toScalarC" c_toScalarC :: Int32 -> Complex Double :> Double :> Ok
226 228
227-- | obtains different functions of a vector: only norm1, norm2 229-- | obtains different functions of a vector: only norm1, norm2
228toScalarQ :: FunCodeS -> Vector (Complex Float) -> Float 230toScalarQ :: FunCodeS -> Vector (Complex Float) -> Float
229toScalarQ oper = toScalarAux c_toScalarQ (fromei oper) 231toScalarQ oper = toScalarAux c_toScalarQ (fromei oper)
230 232
231foreign import ccall unsafe "toScalarQ" c_toScalarQ :: CInt -> Complex Float :> Float :> Ok 233foreign import ccall unsafe "toScalarQ" c_toScalarQ :: Int32 -> Complex Float :> Float :> Ok
232 234
233-- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. 235-- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc.
234toScalarI :: FunCodeS -> Vector CInt -> CInt 236toScalarI :: FunCodeS -> Vector Int32 -> Int32
235toScalarI oper = toScalarAux c_toScalarI (fromei oper) 237toScalarI oper = toScalarAux c_toScalarI (fromei oper)
236 238
237foreign import ccall unsafe "toScalarI" c_toScalarI :: CInt -> TVV CInt 239foreign import ccall unsafe "toScalarI" c_toScalarI :: Int32 -> TVV Int32
238 240
239-- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. 241-- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc.
240toScalarL :: FunCodeS -> Vector Z -> Z 242toScalarL :: FunCodeS -> Vector Z -> Z
241toScalarL oper = toScalarAux c_toScalarL (fromei oper) 243toScalarL oper = toScalarAux c_toScalarL (fromei oper)
242 244
243foreign import ccall unsafe "toScalarL" c_toScalarL :: CInt -> TVV Z 245foreign import ccall unsafe "toScalarL" c_toScalarL :: Int32 -> TVV Z
244 246
245 247
246------------------------------------------------------------------ 248------------------------------------------------------------------
@@ -249,37 +251,37 @@ foreign import ccall unsafe "toScalarL" c_toScalarL :: CInt -> TVV Z
249vectorMapR :: FunCodeV -> Vector Double -> Vector Double 251vectorMapR :: FunCodeV -> Vector Double -> Vector Double
250vectorMapR = vectorMapAux c_vectorMapR 252vectorMapR = vectorMapAux c_vectorMapR
251 253
252foreign import ccall unsafe "mapR" c_vectorMapR :: CInt -> TVV Double 254foreign import ccall unsafe "mapR" c_vectorMapR :: Int32 -> TVV Double
253 255
254-- | map of complex vectors with given function 256-- | map of complex vectors with given function
255vectorMapC :: FunCodeV -> Vector (Complex Double) -> Vector (Complex Double) 257vectorMapC :: FunCodeV -> Vector (Complex Double) -> Vector (Complex Double)
256vectorMapC oper = vectorMapAux c_vectorMapC (fromei oper) 258vectorMapC oper = vectorMapAux c_vectorMapC (fromei oper)
257 259
258foreign import ccall unsafe "mapC" c_vectorMapC :: CInt -> TVV (Complex Double) 260foreign import ccall unsafe "mapC" c_vectorMapC :: Int32 -> TVV (Complex Double)
259 261
260-- | map of real vectors with given function 262-- | map of real vectors with given function
261vectorMapF :: FunCodeV -> Vector Float -> Vector Float 263vectorMapF :: FunCodeV -> Vector Float -> Vector Float
262vectorMapF = vectorMapAux c_vectorMapF 264vectorMapF = vectorMapAux c_vectorMapF
263 265
264foreign import ccall unsafe "mapF" c_vectorMapF :: CInt -> TVV Float 266foreign import ccall unsafe "mapF" c_vectorMapF :: Int32 -> TVV Float
265 267
266-- | map of real vectors with given function 268-- | map of real vectors with given function
267vectorMapQ :: FunCodeV -> Vector (Complex Float) -> Vector (Complex Float) 269vectorMapQ :: FunCodeV -> Vector (Complex Float) -> Vector (Complex Float)
268vectorMapQ = vectorMapAux c_vectorMapQ 270vectorMapQ = vectorMapAux c_vectorMapQ
269 271
270foreign import ccall unsafe "mapQ" c_vectorMapQ :: CInt -> TVV (Complex Float) 272foreign import ccall unsafe "mapQ" c_vectorMapQ :: Int32 -> TVV (Complex Float)
271 273
272-- | map of real vectors with given function 274-- | map of real vectors with given function
273vectorMapI :: FunCodeV -> Vector CInt -> Vector CInt 275vectorMapI :: FunCodeV -> Vector Int32 -> Vector Int32
274vectorMapI = vectorMapAux c_vectorMapI 276vectorMapI = vectorMapAux c_vectorMapI
275 277
276foreign import ccall unsafe "mapI" c_vectorMapI :: CInt -> TVV CInt 278foreign import ccall unsafe "mapI" c_vectorMapI :: Int32 -> TVV Int32
277 279
278-- | map of real vectors with given function 280-- | map of real vectors with given function
279vectorMapL :: FunCodeV -> Vector Z -> Vector Z 281vectorMapL :: FunCodeV -> Vector Z -> Vector Z
280vectorMapL = vectorMapAux c_vectorMapL 282vectorMapL = vectorMapAux c_vectorMapL
281 283
282foreign import ccall unsafe "mapL" c_vectorMapL :: CInt -> TVV Z 284foreign import ccall unsafe "mapL" c_vectorMapL :: Int32 -> TVV Z
283 285
284------------------------------------------------------------------- 286-------------------------------------------------------------------
285 287
@@ -287,37 +289,37 @@ foreign import ccall unsafe "mapL" c_vectorMapL :: CInt -> TVV Z
287vectorMapValR :: FunCodeSV -> Double -> Vector Double -> Vector Double 289vectorMapValR :: FunCodeSV -> Double -> Vector Double -> Vector Double
288vectorMapValR oper = vectorMapValAux c_vectorMapValR (fromei oper) 290vectorMapValR oper = vectorMapValAux c_vectorMapValR (fromei oper)
289 291
290foreign import ccall unsafe "mapValR" c_vectorMapValR :: CInt -> Ptr Double -> TVV Double 292foreign import ccall unsafe "mapValR" c_vectorMapValR :: Int32 -> Ptr Double -> TVV Double
291 293
292-- | map of complex vectors with given function 294-- | map of complex vectors with given function
293vectorMapValC :: FunCodeSV -> Complex Double -> Vector (Complex Double) -> Vector (Complex Double) 295vectorMapValC :: FunCodeSV -> Complex Double -> Vector (Complex Double) -> Vector (Complex Double)
294vectorMapValC = vectorMapValAux c_vectorMapValC 296vectorMapValC = vectorMapValAux c_vectorMapValC
295 297
296foreign import ccall unsafe "mapValC" c_vectorMapValC :: CInt -> Ptr (Complex Double) -> TVV (Complex Double) 298foreign import ccall unsafe "mapValC" c_vectorMapValC :: Int32 -> Ptr (Complex Double) -> TVV (Complex Double)
297 299
298-- | map of real vectors with given function 300-- | map of real vectors with given function
299vectorMapValF :: FunCodeSV -> Float -> Vector Float -> Vector Float 301vectorMapValF :: FunCodeSV -> Float -> Vector Float -> Vector Float
300vectorMapValF oper = vectorMapValAux c_vectorMapValF (fromei oper) 302vectorMapValF oper = vectorMapValAux c_vectorMapValF (fromei oper)
301 303
302foreign import ccall unsafe "mapValF" c_vectorMapValF :: CInt -> Ptr Float -> TVV Float 304foreign import ccall unsafe "mapValF" c_vectorMapValF :: Int32 -> Ptr Float -> TVV Float
303 305
304-- | map of complex vectors with given function 306-- | map of complex vectors with given function
305vectorMapValQ :: FunCodeSV -> Complex Float -> Vector (Complex Float) -> Vector (Complex Float) 307vectorMapValQ :: FunCodeSV -> Complex Float -> Vector (Complex Float) -> Vector (Complex Float)
306vectorMapValQ oper = vectorMapValAux c_vectorMapValQ (fromei oper) 308vectorMapValQ oper = vectorMapValAux c_vectorMapValQ (fromei oper)
307 309
308foreign import ccall unsafe "mapValQ" c_vectorMapValQ :: CInt -> Ptr (Complex Float) -> TVV (Complex Float) 310foreign import ccall unsafe "mapValQ" c_vectorMapValQ :: Int32 -> Ptr (Complex Float) -> TVV (Complex Float)
309 311
310-- | map of real vectors with given function 312-- | map of real vectors with given function
311vectorMapValI :: FunCodeSV -> CInt -> Vector CInt -> Vector CInt 313vectorMapValI :: FunCodeSV -> Int32 -> Vector Int32 -> Vector Int32
312vectorMapValI oper = vectorMapValAux c_vectorMapValI (fromei oper) 314vectorMapValI oper = vectorMapValAux c_vectorMapValI (fromei oper)
313 315
314foreign import ccall unsafe "mapValI" c_vectorMapValI :: CInt -> Ptr CInt -> TVV CInt 316foreign import ccall unsafe "mapValI" c_vectorMapValI :: Int32 -> Ptr Int32 -> TVV Int32
315 317
316-- | map of real vectors with given function 318-- | map of real vectors with given function
317vectorMapValL :: FunCodeSV -> Z -> Vector Z -> Vector Z 319vectorMapValL :: FunCodeSV -> Z -> Vector Z -> Vector Z
318vectorMapValL oper = vectorMapValAux c_vectorMapValL (fromei oper) 320vectorMapValL oper = vectorMapValAux c_vectorMapValL (fromei oper)
319 321
320foreign import ccall unsafe "mapValL" c_vectorMapValL :: CInt -> Ptr Z -> TVV Z 322foreign import ccall unsafe "mapValL" c_vectorMapValL :: Int32 -> Ptr Z -> TVV Z
321 323
322 324
323------------------------------------------------------------------- 325-------------------------------------------------------------------
@@ -328,42 +330,42 @@ type TVVV t = t :> t :> t :> Ok
328vectorZipR :: FunCodeVV -> Vector Double -> Vector Double -> Vector Double 330vectorZipR :: FunCodeVV -> Vector Double -> Vector Double -> Vector Double
329vectorZipR = vectorZipAux c_vectorZipR 331vectorZipR = vectorZipAux c_vectorZipR
330 332
331foreign import ccall unsafe "zipR" c_vectorZipR :: CInt -> TVVV Double 333foreign import ccall unsafe "zipR" c_vectorZipR :: Int32 -> TVVV Double
332 334
333-- | elementwise operation on complex vectors 335-- | elementwise operation on complex vectors
334vectorZipC :: FunCodeVV -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) 336vectorZipC :: FunCodeVV -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double)
335vectorZipC = vectorZipAux c_vectorZipC 337vectorZipC = vectorZipAux c_vectorZipC
336 338
337foreign import ccall unsafe "zipC" c_vectorZipC :: CInt -> TVVV (Complex Double) 339foreign import ccall unsafe "zipC" c_vectorZipC :: Int32 -> TVVV (Complex Double)
338 340
339-- | elementwise operation on real vectors 341-- | elementwise operation on real vectors
340vectorZipF :: FunCodeVV -> Vector Float -> Vector Float -> Vector Float 342vectorZipF :: FunCodeVV -> Vector Float -> Vector Float -> Vector Float
341vectorZipF = vectorZipAux c_vectorZipF 343vectorZipF = vectorZipAux c_vectorZipF
342 344
343foreign import ccall unsafe "zipF" c_vectorZipF :: CInt -> TVVV Float 345foreign import ccall unsafe "zipF" c_vectorZipF :: Int32 -> TVVV Float
344 346
345-- | elementwise operation on complex vectors 347-- | elementwise operation on complex vectors
346vectorZipQ :: FunCodeVV -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float) 348vectorZipQ :: FunCodeVV -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float)
347vectorZipQ = vectorZipAux c_vectorZipQ 349vectorZipQ = vectorZipAux c_vectorZipQ
348 350
349foreign import ccall unsafe "zipQ" c_vectorZipQ :: CInt -> TVVV (Complex Float) 351foreign import ccall unsafe "zipQ" c_vectorZipQ :: Int32 -> TVVV (Complex Float)
350 352
351-- | elementwise operation on CInt vectors 353-- | elementwise operation on Int32 vectors
352vectorZipI :: FunCodeVV -> Vector CInt -> Vector CInt -> Vector CInt 354vectorZipI :: FunCodeVV -> Vector Int32 -> Vector Int32 -> Vector Int32
353vectorZipI = vectorZipAux c_vectorZipI 355vectorZipI = vectorZipAux c_vectorZipI
354 356
355foreign import ccall unsafe "zipI" c_vectorZipI :: CInt -> TVVV CInt 357foreign import ccall unsafe "zipI" c_vectorZipI :: Int32 -> TVVV Int32
356 358
357-- | elementwise operation on CInt vectors 359-- | elementwise operation on Int32 vectors
358vectorZipL :: FunCodeVV -> Vector Z -> Vector Z -> Vector Z 360vectorZipL :: FunCodeVV -> Vector Z -> Vector Z -> Vector Z
359vectorZipL = vectorZipAux c_vectorZipL 361vectorZipL = vectorZipAux c_vectorZipL
360 362
361foreign import ccall unsafe "zipL" c_vectorZipL :: CInt -> TVVV Z 363foreign import ccall unsafe "zipL" c_vectorZipL :: Int32 -> TVVV Z
362 364
363-------------------------------------------------------------------------------- 365--------------------------------------------------------------------------------
364 366
365foreign import ccall unsafe "vectorScan" c_vectorScan 367foreign import ccall unsafe "vectorScan" c_vectorScan
366 :: CString -> Ptr CInt -> Ptr (Ptr Double) -> IO CInt 368 :: CString -> Ptr Int32 -> Ptr (Ptr Double) -> IO Int32
367 369
368vectorScan :: FilePath -> IO (Vector Double) 370vectorScan :: FilePath -> IO (Vector Double)
369vectorScan s = do 371vectorScan s = do
@@ -401,7 +403,7 @@ randomVector seed dist n = unsafePerformIO $ do
401 (r # id) (c_random_vector (fi seed) ((fi.fromEnum) dist)) #|"randomVector" 403 (r # id) (c_random_vector (fi seed) ((fi.fromEnum) dist)) #|"randomVector"
402 return r 404 return r
403 405
404foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> Double :> Ok 406foreign import ccall unsafe "random_vector" c_random_vector :: Int32 -> Int32 -> Double :> Ok
405 407
406-------------------------------------------------------------------------------- 408--------------------------------------------------------------------------------
407 409
@@ -426,7 +428,7 @@ range n = unsafePerformIO $ do
426 (r # id) c_range_vector #|"range" 428 (r # id) c_range_vector #|"range"
427 return r 429 return r
428 430
429foreign import ccall unsafe "range_vector" c_range_vector :: CInt :> Ok 431foreign import ccall unsafe "range_vector" c_range_vector :: Int32 :> Ok
430 432
431 433
432float2DoubleV :: Vector Float -> Vector Double 434float2DoubleV :: Vector Float -> Vector Double
@@ -435,10 +437,10 @@ float2DoubleV = tog c_float2double
435double2FloatV :: Vector Double -> Vector Float 437double2FloatV :: Vector Double -> Vector Float
436double2FloatV = tog c_double2float 438double2FloatV = tog c_double2float
437 439
438double2IntV :: Vector Double -> Vector CInt 440double2IntV :: Vector Double -> Vector Int32
439double2IntV = tog c_double2int 441double2IntV = tog c_double2int
440 442
441int2DoubleV :: Vector CInt -> Vector Double 443int2DoubleV :: Vector Int32 -> Vector Double
442int2DoubleV = tog c_int2double 444int2DoubleV = tog c_int2double
443 445
444double2longV :: Vector Double -> Vector Z 446double2longV :: Vector Double -> Vector Z
@@ -448,10 +450,10 @@ long2DoubleV :: Vector Z -> Vector Double
448long2DoubleV = tog c_long2double 450long2DoubleV = tog c_long2double
449 451
450 452
451float2IntV :: Vector Float -> Vector CInt 453float2IntV :: Vector Float -> Vector Int32
452float2IntV = tog c_float2int 454float2IntV = tog c_float2int
453 455
454int2floatV :: Vector CInt -> Vector Float 456int2floatV :: Vector Int32 -> Vector Float
455int2floatV = tog c_int2float 457int2floatV = tog c_int2float
456 458
457int2longV :: Vector I -> Vector Z 459int2longV :: Vector I -> Vector Z
@@ -462,7 +464,7 @@ long2intV = tog c_long2int
462 464
463 465
464tog :: (Storable t, Storable a) 466tog :: (Storable t, Storable a)
465 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a 467 => (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> Vector t -> Vector a
466tog f v = unsafePerformIO $ do 468tog f v = unsafePerformIO $ do
467 r <- createVector (dim v) 469 r <- createVector (dim v)
468 (v #! r) f #|"tog" 470 (v #! r) f #|"tog"
@@ -470,12 +472,12 @@ tog f v = unsafePerformIO $ do
470 472
471foreign import ccall unsafe "float2double" c_float2double :: Float :> Double :> Ok 473foreign import ccall unsafe "float2double" c_float2double :: Float :> Double :> Ok
472foreign import ccall unsafe "double2float" c_double2float :: Double :> Float :> Ok 474foreign import ccall unsafe "double2float" c_double2float :: Double :> Float :> Ok
473foreign import ccall unsafe "int2double" c_int2double :: CInt :> Double :> Ok 475foreign import ccall unsafe "int2double" c_int2double :: Int32 :> Double :> Ok
474foreign import ccall unsafe "double2int" c_double2int :: Double :> CInt :> Ok 476foreign import ccall unsafe "double2int" c_double2int :: Double :> Int32 :> Ok
475foreign import ccall unsafe "long2double" c_long2double :: Z :> Double :> Ok 477foreign import ccall unsafe "long2double" c_long2double :: Z :> Double :> Ok
476foreign import ccall unsafe "double2long" c_double2long :: Double :> Z :> Ok 478foreign import ccall unsafe "double2long" c_double2long :: Double :> Z :> Ok
477foreign import ccall unsafe "int2float" c_int2float :: CInt :> Float :> Ok 479foreign import ccall unsafe "int2float" c_int2float :: Int32 :> Float :> Ok
478foreign import ccall unsafe "float2int" c_float2int :: Float :> CInt :> Ok 480foreign import ccall unsafe "float2int" c_float2int :: Float :> Int32 :> Ok
479foreign import ccall unsafe "int2long" c_int2long :: I :> Z :> Ok 481foreign import ccall unsafe "int2long" c_int2long :: I :> Z :> Ok
480foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok 482foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok
481 483
@@ -483,7 +485,7 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok
483--------------------------------------------------------------- 485---------------------------------------------------------------
484 486
485stepg :: (Storable t, Storable a) 487stepg :: (Storable t, Storable a)
486 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a 488 => (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> Vector t -> Vector a
487stepg f v = unsafePerformIO $ do 489stepg f v = unsafePerformIO $ do
488 r <- createVector (dim v) 490 r <- createVector (dim v)
489 (v #! r) f #|"step" 491 (v #! r) f #|"step"
@@ -495,7 +497,7 @@ stepD = stepg c_stepD
495stepF :: Vector Float -> Vector Float 497stepF :: Vector Float -> Vector Float
496stepF = stepg c_stepF 498stepF = stepg c_stepF
497 499
498stepI :: Vector CInt -> Vector CInt 500stepI :: Vector Int32 -> Vector Int32
499stepI = stepg c_stepI 501stepI = stepg c_stepI
500 502
501stepL :: Vector Z -> Vector Z 503stepL :: Vector Z -> Vector Z
@@ -504,13 +506,13 @@ stepL = stepg c_stepL
504 506
505foreign import ccall unsafe "stepF" c_stepF :: TVV Float 507foreign import ccall unsafe "stepF" c_stepF :: TVV Float
506foreign import ccall unsafe "stepD" c_stepD :: TVV Double 508foreign import ccall unsafe "stepD" c_stepD :: TVV Double
507foreign import ccall unsafe "stepI" c_stepI :: TVV CInt 509foreign import ccall unsafe "stepI" c_stepI :: TVV Int32
508foreign import ccall unsafe "stepL" c_stepL :: TVV Z 510foreign import ccall unsafe "stepL" c_stepL :: TVV Z
509 511
510-------------------------------------------------------------------------------- 512--------------------------------------------------------------------------------
511 513
512conjugateAux :: (Storable t, Storable a) 514conjugateAux :: (Storable t, Storable a)
513 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a 515 => (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> Vector t -> Vector a
514conjugateAux fun x = unsafePerformIO $ do 516conjugateAux fun x = unsafePerformIO $ do
515 v <- createVector (dim x) 517 v <- createVector (dim x)
516 (x #! v) fun #|"conjugateAux" 518 (x #! v) fun #|"conjugateAux"
@@ -536,22 +538,29 @@ cloneVector v = do
536 538
537-------------------------------------------------------------------------------- 539--------------------------------------------------------------------------------
538 540
539constantAux :: (Storable a1, Storable a) 541constantAux :: Storable a => a -> Int -> Vector a
540 => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a 542constantAux x n = unsafePerformIO $ do
541constantAux fun x n = unsafePerformIO $ do
542 v <- createVector n 543 v <- createVector n
543 px <- newArray [x] 544 px <- newArray [x]
544 (v # id) (fun px) #|"constantAux" 545 (v # id) (constantStorable px) #|"constantAux"
545 free px 546 free px
546 return v 547 return v
547 548
549constantStorable :: Storable a => Ptr a -> Int32 -> Ptr a -> IO Int32
550constantStorable pval n p = do
551 val <- peek pval
552 ($ 0) $ fix $ \iloop i -> when (i<n) $ do
553 pokeElemOff p (fromIntegral i) val
554 iloop $! succ i
555 return 0
556
548type TConst t = Ptr t -> t :> Ok 557type TConst t = Ptr t -> t :> Ok
549 558
550foreign import ccall unsafe "constantF" cconstantF :: TConst Float 559foreign import ccall unsafe "constantF" cconstantF :: TConst Float
551foreign import ccall unsafe "constantR" cconstantR :: TConst Double 560foreign import ccall unsafe "constantR" cconstantR :: TConst Double
552foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float) 561foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float)
553foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double) 562foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double)
554foreign import ccall unsafe "constantI" cconstantI :: TConst CInt 563foreign import ccall unsafe "constantI" cconstantI :: TConst Int32
555foreign import ccall unsafe "constantL" cconstantL :: TConst Z 564foreign import ccall unsafe "constantL" cconstantL :: TConst Z
556 565
557---------------------------------------------------------------------- 566----------------------------------------------------------------------
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs
index 9670187..a0a23bd 100644
--- a/packages/base/src/Numeric/LinearAlgebra.hs
+++ b/packages/base/src/Numeric/LinearAlgebra.hs
@@ -167,7 +167,7 @@ module Numeric.LinearAlgebra (
167 haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, 167 haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv,
168 iC, sym, mTm, trustSym, unSym, 168 iC, sym, mTm, trustSym, unSym,
169 -- * Auxiliary classes 169 -- * Auxiliary classes
170 Element, Container, Product, Numeric, LSDiv, Herm, 170 Container, Product, Numeric, LSDiv, Herm,
171 Complexable, RealElement, 171 Complexable, RealElement,
172 RealOf, ComplexOf, SingleOf, DoubleOf, 172 RealOf, ComplexOf, SingleOf, DoubleOf,
173 IndexOf, 173 IndexOf,
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
index 97cfd01..12eddb2 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
@@ -33,6 +33,7 @@ import System.Random
33import Numeric.LinearAlgebra.HMatrix hiding (vector) 33import Numeric.LinearAlgebra.HMatrix hiding (vector)
34import Control.Monad(replicateM) 34import Control.Monad(replicateM)
35import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink) 35import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink)
36import Foreign.Storable
36 37
37import GHC.TypeLits 38import GHC.TypeLits
38import Data.Proxy (Proxy(..)) 39import Data.Proxy (Proxy(..))
@@ -69,7 +70,7 @@ instance KnownNat n => Arbitrary (Static.R n) where
69 70
70 shrink _v = [] 71 shrink _v = []
71 72
72instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where 73instance (Storable a, Arbitrary a) => Arbitrary (Matrix a) where
73 arbitrary = do 74 arbitrary = do
74 m <- chooseDim 75 m <- chooseDim
75 n <- chooseDim 76 n <- chooseDim
@@ -98,7 +99,7 @@ instance (KnownNat n, KnownNat m) => Arbitrary (Static.L m n) where
98 99
99-- a square matrix 100-- a square matrix
100newtype (Sq a) = Sq (Matrix a) deriving Show 101newtype (Sq a) = Sq (Matrix a) deriving Show
101instance (Element a, Arbitrary a) => Arbitrary (Sq a) where 102instance (Storable a, Arbitrary a) => Arbitrary (Sq a) where
102 arbitrary = do 103 arbitrary = do
103 n <- chooseDim 104 n <- chooseDim
104 l <- vector (n*n) 105 l <- vector (n*n)
@@ -141,7 +142,7 @@ instance (Field a, Arbitrary a, Num (Vector a)) => Arbitrary (Herm a) where
141 return $ sym m' 142 return $ sym m'
142 143
143 144
144class (Field a, Arbitrary a, Element (RealOf a), Random (RealOf a)) => ArbitraryField a 145class (Field a, Arbitrary a, Storable (RealOf a), Random (RealOf a)) => ArbitraryField a
145instance ArbitraryField Double 146instance ArbitraryField Double
146instance ArbitraryField (Complex Double) 147instance ArbitraryField (Complex Double)
147 148