summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r--packages/base/src/Internal/Algorithms.hs2
-rw-r--r--packages/base/src/Internal/CG.hs2
-rw-r--r--packages/base/src/Internal/Chain.hs2
-rw-r--r--packages/base/src/Internal/Devel.hs1
-rw-r--r--packages/base/src/Internal/Element.hs23
-rw-r--r--packages/base/src/Internal/IO.hs16
-rw-r--r--packages/base/src/Internal/LAPACK.hs2
-rw-r--r--packages/base/src/Internal/Matrix.hs87
-rw-r--r--packages/base/src/Internal/Modular.hs3
-rw-r--r--packages/base/src/Internal/Numeric.hs2
-rw-r--r--packages/base/src/Internal/ST.hs12
-rw-r--r--packages/base/src/Internal/Sparse.hs2
-rw-r--r--packages/base/src/Internal/Static.hs2
-rw-r--r--packages/base/src/Internal/Util.hs2
-rw-r--r--packages/base/src/Internal/Vector.hs7
-rw-r--r--packages/base/src/Internal/Vectorized.hs36
16 files changed, 192 insertions, 9 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs
index 99c9e34..cea06ce 100644
--- a/packages/base/src/Internal/Algorithms.hs
+++ b/packages/base/src/Internal/Algorithms.hs
@@ -4,6 +4,8 @@
4{-# LANGUAGE UndecidableInstances #-} 4{-# LANGUAGE UndecidableInstances #-}
5{-# LANGUAGE TypeFamilies #-} 5{-# LANGUAGE TypeFamilies #-}
6 6
7{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
8
7----------------------------------------------------------------------------- 9-----------------------------------------------------------------------------
8{- | 10{- |
9Module : Internal.Algorithms 11Module : Internal.Algorithms
diff --git a/packages/base/src/Internal/CG.hs b/packages/base/src/Internal/CG.hs
index cc10ad8..29edd35 100644
--- a/packages/base/src/Internal/CG.hs
+++ b/packages/base/src/Internal/CG.hs
@@ -1,6 +1,8 @@
1{-# LANGUAGE FlexibleContexts, FlexibleInstances #-} 1{-# LANGUAGE FlexibleContexts, FlexibleInstances #-}
2{-# LANGUAGE RecordWildCards #-} 2{-# LANGUAGE RecordWildCards #-}
3 3
4{-# OPTIONS_GHC -fno-warn-orphans #-}
5
4module Internal.CG( 6module Internal.CG(
5 cgSolve, cgSolve', 7 cgSolve, cgSolve',
6 CGState(..), R, V 8 CGState(..), R, V
diff --git a/packages/base/src/Internal/Chain.hs b/packages/base/src/Internal/Chain.hs
index f87eb02..4000c2b 100644
--- a/packages/base/src/Internal/Chain.hs
+++ b/packages/base/src/Internal/Chain.hs
@@ -1,5 +1,7 @@
1{-# LANGUAGE FlexibleContexts #-} 1{-# LANGUAGE FlexibleContexts #-}
2 2
3{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
4
3----------------------------------------------------------------------------- 5-----------------------------------------------------------------------------
4-- | 6-- |
5-- Module : Internal.Chain 7-- Module : Internal.Chain
diff --git a/packages/base/src/Internal/Devel.hs b/packages/base/src/Internal/Devel.hs
index 3887663..f72d8aa 100644
--- a/packages/base/src/Internal/Devel.hs
+++ b/packages/base/src/Internal/Devel.hs
@@ -54,6 +54,7 @@ check msg f = do
54 54
55-- | postfix error code check 55-- | postfix error code check
56infixl 0 #| 56infixl 0 #|
57(#|) :: IO CInt -> String -> IO ()
57(#|) = flip check 58(#|) = flip check
58 59
59-- | Error capture and conversion to Maybe 60-- | Error capture and conversion to Maybe
diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs
index eb3a25b..2e330ee 100644
--- a/packages/base/src/Internal/Element.hs
+++ b/packages/base/src/Internal/Element.hs
@@ -4,6 +4,8 @@
4{-# LANGUAGE UndecidableInstances #-} 4{-# LANGUAGE UndecidableInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-} 5{-# LANGUAGE MultiParamTypeClasses #-}
6 6
7{-# OPTIONS_GHC -fno-warn-orphans #-}
8
7----------------------------------------------------------------------------- 9-----------------------------------------------------------------------------
8-- | 10-- |
9-- Module : Data.Packed.Matrix 11-- Module : Data.Packed.Matrix
@@ -31,6 +33,7 @@ import Data.List.Split(chunksOf)
31import Foreign.Storable(Storable) 33import Foreign.Storable(Storable)
32import System.IO.Unsafe(unsafePerformIO) 34import System.IO.Unsafe(unsafePerformIO)
33import Control.Monad(liftM) 35import Control.Monad(liftM)
36import Foreign.C.Types(CInt)
34 37
35------------------------------------------------------------------- 38-------------------------------------------------------------------
36 39
@@ -53,8 +56,10 @@ instance (Show a, Element a) => (Show (Matrix a)) where
53 show m | rows m == 0 || cols m == 0 = sizes m ++" []" 56 show m | rows m == 0 || cols m == 0 = sizes m ++" []"
54 show m = (sizes m++) . dsp . map (map show) . toLists $ m 57 show m = (sizes m++) . dsp . map (map show) . toLists $ m
55 58
59sizes :: Matrix t -> [Char]
56sizes m = "("++show (rows m)++"><"++show (cols m)++")\n" 60sizes m = "("++show (rows m)++"><"++show (cols m)++")\n"
57 61
62dsp :: [[[Char]]] -> [Char]
58dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp 63dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp
59 where 64 where
60 mt = transpose as 65 mt = transpose as
@@ -73,6 +78,7 @@ instance (Element a, Read a) => Read (Matrix a) where
73 rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims 78 rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims
74 79
75 80
81breakAt :: Eq a => a -> [a] -> ([a], [a])
76breakAt c l = (a++[c],tail b) where 82breakAt c l = (a++[c],tail b) where
77 (a,b) = break (==c) l 83 (a,b) = break (==c) l
78 84
@@ -88,7 +94,8 @@ data Extractor
88 | Drop Int 94 | Drop Int
89 | DropLast Int 95 | DropLast Int
90 deriving Show 96 deriving Show
91 97
98ppext :: Extractor -> [Char]
92ppext All = ":" 99ppext All = ":"
93ppext (Range a 1 c) = printf "%d:%d" a c 100ppext (Range a 1 c) = printf "%d:%d" a c
94ppext (Range a b c) = printf "%d:%d:%d" a b c 101ppext (Range a b c) = printf "%d:%d:%d" a b c
@@ -128,10 +135,14 @@ ppext (DropLast n) = printf "DropLast %d" n
128infixl 9 ?? 135infixl 9 ??
129(??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t 136(??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t
130 137
138minEl :: Vector CInt -> CInt
131minEl = toScalarI Min 139minEl = toScalarI Min
140maxEl :: Vector CInt -> CInt
132maxEl = toScalarI Max 141maxEl = toScalarI Max
142cmodi :: Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt
133cmodi = vectorMapValI ModVS 143cmodi = vectorMapValI ModVS
134 144
145extractError :: Matrix t1 -> (Extractor, Extractor) -> t
135extractError m (e1,e2)= error $ printf "can't extract (%s,%s) from matrix %dx%d" (ppext e1::String) (ppext e2::String) (rows m) (cols m) 146extractError m (e1,e2)= error $ printf "can't extract (%s,%s) from matrix %dx%d" (ppext e1::String) (ppext e2::String) (rows m) (cols m)
136 147
137m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e) 148m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e)
@@ -232,8 +243,10 @@ disp = putStr . dispf 2
232fromBlocks :: Element t => [[Matrix t]] -> Matrix t 243fromBlocks :: Element t => [[Matrix t]] -> Matrix t
233fromBlocks = fromBlocksRaw . adaptBlocks 244fromBlocks = fromBlocksRaw . adaptBlocks
234 245
246fromBlocksRaw :: Element t => [[Matrix t]] -> Matrix t
235fromBlocksRaw mms = joinVert . map joinHoriz $ mms 247fromBlocksRaw mms = joinVert . map joinHoriz $ mms
236 248
249adaptBlocks :: Element t => [[Matrix t]] -> [[Matrix t]]
237adaptBlocks ms = ms' where 250adaptBlocks ms = ms' where
238 bc = case common length ms of 251 bc = case common length ms of
239 Just c -> c 252 Just c -> c
@@ -486,6 +499,9 @@ liftMatrix2Auto f m1 m2
486 m2' = conformMTo (r,c) m2 499 m2' = conformMTo (r,c) m2
487 500
488-- FIXME do not flatten if equal order 501-- FIXME do not flatten if equal order
502lM :: (Storable t, Element t1, Element t2)
503 => (Vector t1 -> Vector t2 -> Vector t)
504 -> Matrix t1 -> Matrix t2 -> Matrix t
489lM f m1 m2 = matrixFromVector 505lM f m1 m2 = matrixFromVector
490 RowMajor 506 RowMajor
491 (max' (rows m1) (rows m2)) 507 (max' (rows m1) (rows m2))
@@ -504,6 +520,7 @@ compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2
504 520
505------------------------------------------------------------ 521------------------------------------------------------------
506 522
523toBlockRows :: Element t => [Int] -> Matrix t -> [Matrix t]
507toBlockRows [r] m 524toBlockRows [r] m
508 | r == rows m = [m] 525 | r == rows m = [m]
509toBlockRows rs m 526toBlockRows rs m
@@ -513,6 +530,7 @@ toBlockRows rs m
513 szs = map (* cols m) rs 530 szs = map (* cols m) rs
514 g k = (k><0)[] 531 g k = (k><0)[]
515 532
533toBlockCols :: Element t => [Int] -> Matrix t -> [Matrix t]
516toBlockCols [c] m | c == cols m = [m] 534toBlockCols [c] m | c == cols m = [m]
517toBlockCols cs m = map trans . toBlockRows cs . trans $ m 535toBlockCols cs m = map trans . toBlockRows cs . trans $ m
518 536
@@ -576,7 +594,7 @@ Just (3><3)
576mapMatrixWithIndexM 594mapMatrixWithIndexM
577 :: (Element a, Storable b, Monad m) => 595 :: (Element a, Storable b, Monad m) =>
578 ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) 596 ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b)
579mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m 597mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m
580 where 598 where
581 c = cols m 599 c = cols m
582 600
@@ -598,4 +616,3 @@ mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m
598 616
599mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b 617mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b
600mapMatrix f = liftMatrix (mapVector f) 618mapMatrix f = liftMatrix (mapVector f)
601
diff --git a/packages/base/src/Internal/IO.hs b/packages/base/src/Internal/IO.hs
index a899cfd..b0f5606 100644
--- a/packages/base/src/Internal/IO.hs
+++ b/packages/base/src/Internal/IO.hs
@@ -20,7 +20,7 @@ import Internal.Devel
20import Internal.Vector 20import Internal.Vector
21import Internal.Matrix 21import Internal.Matrix
22import Internal.Vectorized 22import Internal.Vectorized
23import Text.Printf(printf) 23import Text.Printf(printf, PrintfArg, PrintfType)
24import Data.List(intersperse,transpose) 24import Data.List(intersperse,transpose)
25import Data.Complex 25import Data.Complex
26 26
@@ -78,12 +78,18 @@ disps d x = sdims x ++ " " ++ formatScaled d x
78dispf :: Int -> Matrix Double -> String 78dispf :: Int -> Matrix Double -> String
79dispf d x = sdims x ++ "\n" ++ formatFixed (if isInt x then 0 else d) x 79dispf d x = sdims x ++ "\n" ++ formatFixed (if isInt x then 0 else d) x
80 80
81sdims :: Matrix t -> [Char]
81sdims x = show (rows x) ++ "x" ++ show (cols x) 82sdims x = show (rows x) ++ "x" ++ show (cols x)
82 83
84formatFixed :: (Show a, Text.Printf.PrintfArg t, Element t)
85 => a -> Matrix t -> String
83formatFixed d x = format " " (printf ("%."++show d++"f")) $ x 86formatFixed d x = format " " (printf ("%."++show d++"f")) $ x
84 87
88isInt :: Matrix Double -> Bool
85isInt = all lookslikeInt . toList . flatten 89isInt = all lookslikeInt . toList . flatten
86 90
91formatScaled :: (Text.Printf.PrintfArg b, RealFrac b, Floating b, Num t, Element b, Show t)
92 => t -> Matrix b -> [Char]
87formatScaled dec t = "E"++show o++"\n" ++ ss 93formatScaled dec t = "E"++show o++"\n" ++ ss
88 where ss = format " " (printf fmt. g) t 94 where ss = format " " (printf fmt. g) t
89 g x | o >= 0 = x/10^(o::Int) 95 g x | o >= 0 = x/10^(o::Int)
@@ -133,14 +139,18 @@ showComplex d (a:+b)
133 s2 = if b<0 then "-" else "" 139 s2 = if b<0 then "-" else ""
134 s3 = if b<0 then "-" else "+" 140 s3 = if b<0 then "-" else "+"
135 141
142shcr :: (Show a, Show t1, Text.Printf.PrintfType t, Text.Printf.PrintfArg t1, RealFrac t1)
143 => a -> t1 -> t
136shcr d a | lookslikeInt a = printf "%.0f" a 144shcr d a | lookslikeInt a = printf "%.0f" a
137 | otherwise = printf ("%."++show d++"f") a 145 | otherwise = printf ("%."++show d++"f") a
138 146
139 147lookslikeInt :: (Show a, RealFrac a) => a -> Bool
140lookslikeInt x = show (round x :: Int) ++".0" == shx || "-0.0" == shx 148lookslikeInt x = show (round x :: Int) ++".0" == shx || "-0.0" == shx
141 where shx = show x 149 where shx = show x
142 150
151isZero :: Show a => a -> Bool
143isZero x = show x `elem` ["0.0","-0.0"] 152isZero x = show x `elem` ["0.0","-0.0"]
153isOne :: Show a => a -> Bool
144isOne x = show x `elem` ["1.0","-1.0"] 154isOne x = show x `elem` ["1.0","-1.0"]
145 155
146-- | Pretty print a complex matrix with at most n decimal digits. 156-- | Pretty print a complex matrix with at most n decimal digits.
@@ -168,6 +178,6 @@ loadMatrix f = do
168 else 178 else
169 return (reshape c v) 179 return (reshape c v)
170 180
171 181loadMatrix' :: FilePath -> IO (Maybe (Matrix Double))
172loadMatrix' name = mbCatch (loadMatrix name) 182loadMatrix' name = mbCatch (loadMatrix name)
173 183
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index e306454..64cf2f5 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -1,6 +1,8 @@
1{-# LANGUAGE TypeOperators #-} 1{-# LANGUAGE TypeOperators #-}
2{-# LANGUAGE ViewPatterns #-} 2{-# LANGUAGE ViewPatterns #-}
3 3
4{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
5
4----------------------------------------------------------------------------- 6-----------------------------------------------------------------------------
5-- | 7-- |
6-- Module : Numeric.LinearAlgebra.LAPACK 8-- Module : Numeric.LinearAlgebra.LAPACK
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 4905f61..4bfa13d 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -57,19 +57,24 @@ cols :: Matrix t -> Int
57cols = icols 57cols = icols
58{-# INLINE cols #-} 58{-# INLINE cols #-}
59 59
60size :: Matrix t -> (Int, Int)
60size m = (irows m, icols m) 61size m = (irows m, icols m)
61{-# INLINE size #-} 62{-# INLINE size #-}
62 63
64rowOrder :: Matrix t -> Bool
63rowOrder m = xCol m == 1 || cols m == 1 65rowOrder m = xCol m == 1 || cols m == 1
64{-# INLINE rowOrder #-} 66{-# INLINE rowOrder #-}
65 67
68colOrder :: Matrix t -> Bool
66colOrder m = xRow m == 1 || rows m == 1 69colOrder m = xRow m == 1 || rows m == 1
67{-# INLINE colOrder #-} 70{-# INLINE colOrder #-}
68 71
72is1d :: Matrix t -> Bool
69is1d (size->(r,c)) = r==1 || c==1 73is1d (size->(r,c)) = r==1 || c==1
70{-# INLINE is1d #-} 74{-# INLINE is1d #-}
71 75
72-- data is not contiguous 76-- data is not contiguous
77isSlice :: Storable t => Matrix t -> Bool
73isSlice m@(size->(r,c)) = r*c < dim (xdat m) 78isSlice m@(size->(r,c)) = r*c < dim (xdat m)
74{-# INLINE isSlice #-} 79{-# INLINE isSlice #-}
75 80
@@ -136,16 +141,20 @@ instance Storable t => TransArray (Matrix t)
136 {-# INLINE applyRaw #-} 141 {-# INLINE applyRaw #-}
137 142
138infixr 1 # 143infixr 1 #
144(#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r
139a # b = apply a b 145a # b = apply a b
140{-# INLINE (#) #-} 146{-# INLINE (#) #-}
141 147
148(#!) :: (TransArray c, TransArray c1) => c1 -> c -> Trans c1 (Trans c (IO r)) -> IO r
142a #! b = a # b # id 149a #! b = a # b # id
143{-# INLINE (#!) #-} 150{-# INLINE (#!) #-}
144 151
145-------------------------------------------------------------------------------- 152--------------------------------------------------------------------------------
146 153
154copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t)
147copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) 155copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1])
148 156
157extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t
149extractAll ord m = unsafePerformIO (copy ord m) 158extractAll ord m = unsafePerformIO (copy ord m)
150 159
151{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. 160{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
@@ -223,11 +232,13 @@ m@Matrix {irows = r, icols = c} @@> (i,j)
223{-# INLINE (@@>) #-} 232{-# INLINE (@@>) #-}
224 233
225-- Unsafe matrix access without range checking 234-- Unsafe matrix access without range checking
235atM' :: Storable t => Matrix t -> Int -> Int -> t
226atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) 236atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m))
227{-# INLINE atM' #-} 237{-# INLINE atM' #-}
228 238
229------------------------------------------------------------------ 239------------------------------------------------------------------
230 240
241matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t
231matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } 242matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 }
232matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } 243matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d }
233matrixFromVector o r c v 244matrixFromVector o r c v
@@ -387,18 +398,21 @@ subMatrix (r0,c0) (rt,ct) m
387 398
388-------------------------------------------------------------------------- 399--------------------------------------------------------------------------
389 400
401maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1
390maxZ xs = if minimum xs == 0 then 0 else maximum xs 402maxZ xs = if minimum xs == 0 then 0 else maximum xs
391 403
404conformMs :: Element t => [Matrix t] -> [Matrix t]
392conformMs ms = map (conformMTo (r,c)) ms 405conformMs ms = map (conformMTo (r,c)) ms
393 where 406 where
394 r = maxZ (map rows ms) 407 r = maxZ (map rows ms)
395 c = maxZ (map cols ms) 408 c = maxZ (map cols ms)
396 409
397 410conformVs :: Element t => [Vector t] -> [Vector t]
398conformVs vs = map (conformVTo n) vs 411conformVs vs = map (conformVTo n) vs
399 where 412 where
400 n = maxZ (map dim vs) 413 n = maxZ (map dim vs)
401 414
415conformMTo :: Element t => (Int, Int) -> Matrix t -> Matrix t
402conformMTo (r,c) m 416conformMTo (r,c) m
403 | size m == (r,c) = m 417 | size m == (r,c) = m
404 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) 418 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
@@ -406,18 +420,24 @@ conformMTo (r,c) m
406 | size m == (1,c) = repRows r m 420 | size m == (1,c) = repRows r m
407 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) 421 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c)
408 422
423conformVTo :: Element t => Int -> Vector t -> Vector t
409conformVTo n v 424conformVTo n v
410 | dim v == n = v 425 | dim v == n = v
411 | dim v == 1 = constantD (v@>0) n 426 | dim v == 1 = constantD (v@>0) n
412 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n 427 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
413 428
429repRows :: Element t => Int -> Matrix t -> Matrix t
414repRows n x = fromRows (replicate n (flatten x)) 430repRows n x = fromRows (replicate n (flatten x))
431repCols :: Element t => Int -> Matrix t -> Matrix t
415repCols n x = fromColumns (replicate n (flatten x)) 432repCols n x = fromColumns (replicate n (flatten x))
416 433
434shSize :: Matrix t -> [Char]
417shSize = shDim . size 435shSize = shDim . size
418 436
437shDim :: (Show a, Show a1) => (a1, a) -> [Char]
419shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" 438shDim (r,c) = "(" ++ show r ++"x"++ show c ++")"
420 439
440emptyM :: Storable t => Int -> Int -> Matrix t
421emptyM r c = matrixFromVector RowMajor r c (fromList[]) 441emptyM r c = matrixFromVector RowMajor r c (fromList[])
422 442
423---------------------------------------------------------------------- 443----------------------------------------------------------------------
@@ -432,6 +452,11 @@ instance (Storable t, NFData t) => NFData (Matrix t)
432 452
433--------------------------------------------------------------- 453---------------------------------------------------------------
434 454
455extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1,
456 Storable t, Num t3, Num t2, Integral t1, Integral t)
457 => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t
458 -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))
459 -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a)
435extractAux f ord m moder vr modec vc = do 460extractAux f ord m moder vr modec vc = do
436 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr 461 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
437 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc 462 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
@@ -451,6 +476,9 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z
451 476
452--------------------------------------------------------------- 477---------------------------------------------------------------
453 478
479setRectAux :: (TransArray c1, TransArray c)
480 => (CInt -> CInt -> Trans c1 (Trans c (IO CInt)))
481 -> Int -> Int -> c1 -> c -> IO ()
454setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" 482setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect"
455 483
456type SetRect x = I -> I -> x ::> x::> Ok 484type SetRect x = I -> I -> x ::> x::> Ok
@@ -464,19 +492,29 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
464 492
465-------------------------------------------------------------------------------- 493--------------------------------------------------------------------------------
466 494
495sortG :: (Storable t, Storable a)
496 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
467sortG f v = unsafePerformIO $ do 497sortG f v = unsafePerformIO $ do
468 r <- createVector (dim v) 498 r <- createVector (dim v)
469 (v #! r) f #|"sortG" 499 (v #! r) f #|"sortG"
470 return r 500 return r
471 501
502sortIdxD :: Vector Double -> Vector CInt
472sortIdxD = sortG c_sort_indexD 503sortIdxD = sortG c_sort_indexD
504sortIdxF :: Vector Float -> Vector CInt
473sortIdxF = sortG c_sort_indexF 505sortIdxF = sortG c_sort_indexF
506sortIdxI :: Vector CInt -> Vector CInt
474sortIdxI = sortG c_sort_indexI 507sortIdxI = sortG c_sort_indexI
508sortIdxL :: Vector Z -> Vector I
475sortIdxL = sortG c_sort_indexL 509sortIdxL = sortG c_sort_indexL
476 510
511sortValD :: Vector Double -> Vector Double
477sortValD = sortG c_sort_valD 512sortValD = sortG c_sort_valD
513sortValF :: Vector Float -> Vector Float
478sortValF = sortG c_sort_valF 514sortValF = sortG c_sort_valF
515sortValI :: Vector CInt -> Vector CInt
479sortValI = sortG c_sort_valI 516sortValI = sortG c_sort_valI
517sortValL :: Vector Z -> Vector Z
480sortValL = sortG c_sort_valL 518sortValL = sortG c_sort_valL
481 519
482foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) 520foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt))
@@ -491,14 +529,21 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok
491 529
492-------------------------------------------------------------------------------- 530--------------------------------------------------------------------------------
493 531
532compareG :: (TransArray c, Storable t, Storable a)
533 => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt)
534 -> c -> Vector t -> Vector a
494compareG f u v = unsafePerformIO $ do 535compareG f u v = unsafePerformIO $ do
495 r <- createVector (dim v) 536 r <- createVector (dim v)
496 (u # v #! r) f #|"compareG" 537 (u # v #! r) f #|"compareG"
497 return r 538 return r
498 539
540compareD :: Vector Double -> Vector Double -> Vector CInt
499compareD = compareG c_compareD 541compareD = compareG c_compareD
542compareF :: Vector Float -> Vector Float -> Vector CInt
500compareF = compareG c_compareF 543compareF = compareG c_compareF
544compareI :: Vector CInt -> Vector CInt -> Vector CInt
501compareI = compareG c_compareI 545compareI = compareG c_compareI
546compareL :: Vector Z -> Vector Z -> Vector CInt
502compareL = compareG c_compareL 547compareL = compareG c_compareL
503 548
504foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) 549foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt)))
@@ -508,16 +553,33 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok
508 553
509-------------------------------------------------------------------------------- 554--------------------------------------------------------------------------------
510 555
556selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a)
557 => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt)))
558 -> c2 -> c1 -> Vector t -> c -> Vector a
511selectG f c u v w = unsafePerformIO $ do 559selectG f c u v w = unsafePerformIO $ do
512 r <- createVector (dim v) 560 r <- createVector (dim v)
513 (c # u # v # w #! r) f #|"selectG" 561 (c # u # v # w #! r) f #|"selectG"
514 return r 562 return r
515 563
564selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double
516selectD = selectG c_selectD 565selectD = selectG c_selectD
566selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float
517selectF = selectG c_selectF 567selectF = selectG c_selectF
568selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt
518selectI = selectG c_selectI 569selectI = selectG c_selectI
570selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z
519selectL = selectG c_selectL 571selectL = selectG c_selectL
572selectC :: Vector CInt
573 -> Vector (Complex Double)
574 -> Vector (Complex Double)
575 -> Vector (Complex Double)
576 -> Vector (Complex Double)
520selectC = selectG c_selectC 577selectC = selectG c_selectC
578selectQ :: Vector CInt
579 -> Vector (Complex Float)
580 -> Vector (Complex Float)
581 -> Vector (Complex Float)
582 -> Vector (Complex Float)
521selectQ = selectG c_selectQ 583selectQ = selectG c_selectQ
522 584
523type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) 585type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt)))))
@@ -531,16 +593,29 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
531 593
532--------------------------------------------------------------------------- 594---------------------------------------------------------------------------
533 595
596remapG :: (TransArray c, TransArray c1, Storable t, Storable a)
597 => (CInt -> CInt -> CInt -> CInt -> Ptr t
598 -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)))
599 -> Matrix t -> c1 -> c -> Matrix a
534remapG f i j m = unsafePerformIO $ do 600remapG f i j m = unsafePerformIO $ do
535 r <- createMatrix RowMajor (rows i) (cols i) 601 r <- createMatrix RowMajor (rows i) (cols i)
536 (i # j # m #! r) f #|"remapG" 602 (i # j # m #! r) f #|"remapG"
537 return r 603 return r
538 604
605remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double
539remapD = remapG c_remapD 606remapD = remapG c_remapD
607remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float
540remapF = remapG c_remapF 608remapF = remapG c_remapF
609remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt
541remapI = remapG c_remapI 610remapI = remapG c_remapI
611remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z
542remapL = remapG c_remapL 612remapL = remapG c_remapL
613remapC :: Matrix CInt
614 -> Matrix CInt
615 -> Matrix (Complex Double)
616 -> Matrix (Complex Double)
543remapC = remapG c_remapC 617remapC = remapG c_remapC
618remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float)
544remapQ = remapG c_remapQ 619remapQ = remapG c_remapQ
545 620
546type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) 621type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt))))
@@ -554,6 +629,9 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z
554 629
555-------------------------------------------------------------------------------- 630--------------------------------------------------------------------------------
556 631
632rowOpAux :: (TransArray c, Storable a) =>
633 (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt))
634 -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO ()
557rowOpAux f c x i1 i2 j1 j2 m = do 635rowOpAux f c x i1 i2 j1 j2 m = do
558 px <- newArray [x] 636 px <- newArray [x]
559 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" 637 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp"
@@ -572,6 +650,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
572 650
573-------------------------------------------------------------------------------- 651--------------------------------------------------------------------------------
574 652
653gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3)
654 => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt))))
655 -> c3 -> c2 -> c1 -> c -> IO ()
575gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" 656gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg"
576 657
577type Tgemm x = x :> x ::> x ::> x ::> Ok 658type Tgemm x = x :> x ::> x ::> x ::> Ok
@@ -587,6 +668,10 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
587 668
588-------------------------------------------------------------------------------- 669--------------------------------------------------------------------------------
589 670
671reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) =>
672 (CInt -> Ptr a -> CInt -> Ptr t1
673 -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt))
674 -> Vector t1 -> c -> Vector t -> Vector a1
590reorderAux f s d v = unsafePerformIO $ do 675reorderAux f s d v = unsafePerformIO $ do
591 k <- createVector (dim s) 676 k <- createVector (dim s)
592 r <- createVector (dim v) 677 r <- createVector (dim v)
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index 9d51444..eb0c5a8 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -13,6 +13,9 @@
13{-# LANGUAGE TypeFamilies #-} 13{-# LANGUAGE TypeFamilies #-}
14{-# LANGUAGE TypeOperators #-} 14{-# LANGUAGE TypeOperators #-}
15 15
16{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
17{-# OPTIONS_GHC -fno-warn-missing-methods #-}
18
16{- | 19{- |
17Module : Internal.Modular 20Module : Internal.Modular
18Copyright : (c) Alberto Ruiz 2015 21Copyright : (c) Alberto Ruiz 2015
diff --git a/packages/base/src/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs
index c9ef0c5..216f142 100644
--- a/packages/base/src/Internal/Numeric.hs
+++ b/packages/base/src/Internal/Numeric.hs
@@ -5,6 +5,8 @@
5{-# LANGUAGE FunctionalDependencies #-} 5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE UndecidableInstances #-} 6{-# LANGUAGE UndecidableInstances #-}
7 7
8{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
9
8----------------------------------------------------------------------------- 10-----------------------------------------------------------------------------
9-- | 11-- |
10-- Module : Data.Packed.Internal.Numeric 12-- Module : Data.Packed.Internal.Numeric
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 544c9e4..7d54e6d 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -81,6 +81,8 @@ unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
81unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x 81unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x
82 82
83{-# INLINE safeIndexV #-} 83{-# INLINE safeIndexV #-}
84safeIndexV :: Storable t2
85 => (STVector s t2 -> Int -> t) -> STVector t1 t2 -> Int -> t
84safeIndexV f (STVector v) k 86safeIndexV f (STVector v) k
85 | k < 0 || k>= dim v = error $ "out of range error in vector (dim=" 87 | k < 0 || k>= dim v = error $ "out of range error in vector (dim="
86 ++show (dim v)++", pos="++show k++")" 88 ++show (dim v)++", pos="++show k++")"
@@ -150,9 +152,12 @@ unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
150freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) 152freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t)
151freezeMatrix m = liftSTMatrix id m 153freezeMatrix m = liftSTMatrix id m
152 154
155cloneMatrix :: Element t => Matrix t -> IO (Matrix t)
153cloneMatrix m = copy (orderOf m) m 156cloneMatrix m = copy (orderOf m) m
154 157
155{-# INLINE safeIndexM #-} 158{-# INLINE safeIndexM #-}
159safeIndexM :: (STMatrix s t2 -> Int -> Int -> t)
160 -> STMatrix t1 t2 -> Int -> Int -> t
156safeIndexM f (STMatrix m) r c 161safeIndexM f (STMatrix m) r c
157 | r<0 || r>=rows m || 162 | r<0 || r>=rows m ||
158 c<0 || c>=cols m = error $ "out of range error in matrix (size=" 163 c<0 || c>=cols m = error $ "out of range error in matrix (size="
@@ -184,6 +189,7 @@ data ColRange = AllCols
184 | Col Int 189 | Col Int
185 | FromCol Int 190 | FromCol Int
186 191
192getColRange :: Int -> ColRange -> (Int, Int)
187getColRange c AllCols = (0,c-1) 193getColRange c AllCols = (0,c-1)
188getColRange c (ColRange a b) = (a `mod` c, b `mod` c) 194getColRange c (ColRange a b) = (a `mod` c, b `mod` c)
189getColRange c (Col a) = (a `mod` c, a `mod` c) 195getColRange c (Col a) = (a `mod` c, a `mod` c)
@@ -194,6 +200,7 @@ data RowRange = AllRows
194 | Row Int 200 | Row Int
195 | FromRow Int 201 | FromRow Int
196 202
203getRowRange :: Int -> RowRange -> (Int, Int)
197getRowRange r AllRows = (0,r-1) 204getRowRange r AllRows = (0,r-1)
198getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) 205getRowRange r (RowRange a b) = (a `mod` r, b `mod` r)
199getRowRange r (Row a) = (a `mod` r, a `mod` r) 206getRowRange r (Row a) = (a `mod` r, a `mod` r)
@@ -223,6 +230,7 @@ rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m
223 i2' = i2 `mod` (rows m) 230 i2' = i2 `mod` (rows m)
224 231
225 232
233extractMatrix :: Element a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a)
226extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) 234extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2]))
227 where 235 where
228 (i1,i2) = getRowRange (rows m) rr 236 (i1,i2) = getRowRange (rows m) rr
@@ -231,6 +239,7 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[
231-- | r0 c0 height width 239-- | r0 c0 height width
232data Slice s t = Slice (STMatrix s t) Int Int Int Int 240data Slice s t = Slice (STMatrix s t) Int Int Int Int
233 241
242slice :: Element a => Slice t a -> Matrix a
234slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m 243slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m
235 244
236gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () 245gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s ()
@@ -238,7 +247,7 @@ gemmm beta (slice->r) alpha (slice->a) (slice->b) = res
238 where 247 where
239 res = unsafeIOToST (gemm v a b r) 248 res = unsafeIOToST (gemm v a b r)
240 v = fromList [alpha,beta] 249 v = fromList [alpha,beta]
241 250
242 251
243mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 252mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
244mutable f a = runST $ do 253mutable f a = runST $ do
@@ -246,4 +255,3 @@ mutable f a = runST $ do
246 info <- f (rows a, cols a) x 255 info <- f (rows a, cols a) x
247 r <- unsafeFreezeMatrix x 256 r <- unsafeFreezeMatrix x
248 return (r,info) 257 return (r,info)
249
diff --git a/packages/base/src/Internal/Sparse.hs b/packages/base/src/Internal/Sparse.hs
index 1ff3f57..6233b03 100644
--- a/packages/base/src/Internal/Sparse.hs
+++ b/packages/base/src/Internal/Sparse.hs
@@ -2,6 +2,8 @@
2{-# LANGUAGE MultiParamTypeClasses #-} 2{-# LANGUAGE MultiParamTypeClasses #-}
3{-# LANGUAGE FlexibleInstances #-} 3{-# LANGUAGE FlexibleInstances #-}
4 4
5{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
6
5module Internal.Sparse( 7module Internal.Sparse(
6 GMatrix(..), CSR(..), mkCSR, fromCSR, 8 GMatrix(..), CSR(..), mkCSR, fromCSR,
7 mkSparse, mkDiagR, mkDense, 9 mkSparse, mkDiagR, mkDense,
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs
index f9dfff0..6ef1350 100644
--- a/packages/base/src/Internal/Static.hs
+++ b/packages/base/src/Internal/Static.hs
@@ -15,6 +15,8 @@
15{-# LANGUAGE BangPatterns #-} 15{-# LANGUAGE BangPatterns #-}
16{-# LANGUAGE DeriveGeneric #-} 16{-# LANGUAGE DeriveGeneric #-}
17 17
18{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
19
18{- | 20{- |
19Module : Internal.Static 21Module : Internal.Static
20Copyright : (c) Alberto Ruiz 2006-14 22Copyright : (c) Alberto Ruiz 2006-14
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs
index 8c8a31e..def7cc3 100644
--- a/packages/base/src/Internal/Util.hs
+++ b/packages/base/src/Internal/Util.hs
@@ -6,6 +6,8 @@
6{-# LANGUAGE ScopedTypeVariables #-} 6{-# LANGUAGE ScopedTypeVariables #-}
7{-# LANGUAGE ViewPatterns #-} 7{-# LANGUAGE ViewPatterns #-}
8 8
9{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
10{-# OPTIONS_GHC -fno-warn-orphans #-}
9 11
10----------------------------------------------------------------------------- 12-----------------------------------------------------------------------------
11{- | 13{- |
diff --git a/packages/base/src/Internal/Vector.hs b/packages/base/src/Internal/Vector.hs
index 67d0416..dedb822 100644
--- a/packages/base/src/Internal/Vector.hs
+++ b/packages/base/src/Internal/Vector.hs
@@ -1,6 +1,7 @@
1{-# LANGUAGE MagicHash, UnboxedTuples, BangPatterns, FlexibleContexts #-} 1{-# LANGUAGE MagicHash, UnboxedTuples, BangPatterns, FlexibleContexts #-}
2{-# LANGUAGE TypeSynonymInstances #-} 2{-# LANGUAGE TypeSynonymInstances #-}
3 3
4{-# OPTIONS_GHC -fno-warn-orphans #-}
4 5
5-- | 6-- |
6-- Module : Internal.Vector 7-- Module : Internal.Vector
@@ -40,6 +41,7 @@ import qualified Data.Vector.Storable as Vector
40import Data.Vector.Storable(Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith) 41import Data.Vector.Storable(Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith)
41 42
42import Data.Binary 43import Data.Binary
44import Data.Binary.Put
43import Control.Monad(replicateM) 45import Control.Monad(replicateM)
44import qualified Data.ByteString.Internal as BS 46import qualified Data.ByteString.Internal as BS
45import Data.Vector.Storable.Internal(updPtr) 47import Data.Vector.Storable.Internal(updPtr)
@@ -92,6 +94,7 @@ createVector n = do
92 94
93-} 95-}
94 96
97safeRead :: Storable a => Vector a -> (Ptr a -> IO c) -> c
95safeRead v = inlinePerformIO . unsafeWith v 98safeRead v = inlinePerformIO . unsafeWith v
96{-# INLINE safeRead #-} 99{-# INLINE safeRead #-}
97 100
@@ -283,11 +286,13 @@ foldVectorWithIndex f x v = unsafePerformIO $
283 go (dim v -1) x 286 go (dim v -1) x
284{-# INLINE foldVectorWithIndex #-} 287{-# INLINE foldVectorWithIndex #-}
285 288
289foldLoop :: (Int -> t -> t) -> t -> Int -> t
286foldLoop f s0 d = go (d - 1) s0 290foldLoop f s0 d = go (d - 1) s0
287 where 291 where
288 go 0 s = f (0::Int) s 292 go 0 s = f (0::Int) s
289 go !j !s = go (j - 1) (f j s) 293 go !j !s = go (j - 1) (f j s)
290 294
295foldVectorG :: Storable t1 => (Int -> (Int -> t1) -> t -> t) -> t -> Vector t1 -> t
291foldVectorG f s0 v = foldLoop g s0 (dim v) 296foldVectorG f s0 v = foldLoop g s0 (dim v)
292 where g !k !s = f k (safeRead v . flip peekElemOff) s 297 where g !k !s = f k (safeRead v . flip peekElemOff) s
293 {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479) 298 {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479)
@@ -390,8 +395,10 @@ chunks d = let c = d `div` chunk
390 m = d `mod` chunk 395 m = d `mod` chunk
391 in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk) 396 in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk)
392 397
398putVector :: (Storable t, Binary t) => Vector t -> Data.Binary.Put.PutM ()
393putVector v = mapM_ put $! toList v 399putVector v = mapM_ put $! toList v
394 400
401getVector :: (Storable a, Binary a) => Int -> Get (Vector a)
395getVector d = do 402getVector d = do
396 xs <- replicateM d get 403 xs <- replicateM d get
397 return $! fromList xs 404 return $! fromList xs
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs
index a410bb2..c00c324 100644
--- a/packages/base/src/Internal/Vectorized.hs
+++ b/packages/base/src/Internal/Vectorized.hs
@@ -28,12 +28,15 @@ import System.IO.Unsafe(unsafePerformIO)
28import Control.Monad(when) 28import Control.Monad(when)
29 29
30infixr 1 # 30infixr 1 #
31(#) :: TransArray c => c -> (b -> IO r) -> TransRaw c b -> IO r
31a # b = applyRaw a b 32a # b = applyRaw a b
32{-# INLINE (#) #-} 33{-# INLINE (#) #-}
33 34
35(#!) :: (TransArray c, TransArray c1) => c1 -> c -> TransRaw c1 (TransRaw c (IO r)) -> IO r
34a #! b = a # b # id 36a #! b = a # b # id
35{-# INLINE (#!) #-} 37{-# INLINE (#!) #-}
36 38
39fromei :: Enum a => a -> CInt
37fromei x = fromIntegral (fromEnum x) :: CInt 40fromei x = fromIntegral (fromEnum x) :: CInt
38 41
39data FunCodeV = Sin 42data FunCodeV = Sin
@@ -100,10 +103,20 @@ sumQ = sumg c_sumQ
100sumC :: Vector (Complex Double) -> Complex Double 103sumC :: Vector (Complex Double) -> Complex Double
101sumC = sumg c_sumC 104sumC = sumg c_sumC
102 105
106sumI :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr I -> I :> Ok)
107 , TransArray c
108 , Storable a
109 )
110 => I -> c -> a
103sumI m = sumg (c_sumI m) 111sumI m = sumg (c_sumI m)
104 112
113sumL :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr Z -> Z :> Ok)
114 , TransArray c
115 , Storable a
116 ) => Z -> c -> a
105sumL m = sumg (c_sumL m) 117sumL m = sumg (c_sumL m)
106 118
119sumg :: (TransArray c, Storable a) => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a
107sumg f x = unsafePerformIO $ do 120sumg f x = unsafePerformIO $ do
108 r <- createVector 1 121 r <- createVector 1
109 (x #! r) f #| "sum" 122 (x #! r) f #| "sum"
@@ -140,6 +153,8 @@ prodI = prodg . c_prodI
140prodL :: Z-> Vector Z -> Z 153prodL :: Z-> Vector Z -> Z
141prodL = prodg . c_prodL 154prodL = prodg . c_prodL
142 155
156prodg :: (TransArray c, Storable a)
157 => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a
143prodg f x = unsafePerformIO $ do 158prodg f x = unsafePerformIO $ do
144 r <- createVector 1 159 r <- createVector 1
145 (x #! r) f #| "prod" 160 (x #! r) f #| "prod"
@@ -155,16 +170,25 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z
155 170
156------------------------------------------------------------------ 171------------------------------------------------------------------
157 172
173toScalarAux :: (Enum a, TransArray c, Storable a1)
174 => (CInt -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) -> a -> c -> a1
158toScalarAux fun code v = unsafePerformIO $ do 175toScalarAux fun code v = unsafePerformIO $ do
159 r <- createVector 1 176 r <- createVector 1
160 (v #! r) (fun (fromei code)) #|"toScalarAux" 177 (v #! r) (fun (fromei code)) #|"toScalarAux"
161 return (r @> 0) 178 return (r @> 0)
162 179
180
181vectorMapAux :: (Enum a, Storable t, Storable a1)
182 => (CInt -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)
183 -> a -> Vector t -> Vector a1
163vectorMapAux fun code v = unsafePerformIO $ do 184vectorMapAux fun code v = unsafePerformIO $ do
164 r <- createVector (dim v) 185 r <- createVector (dim v)
165 (v #! r) (fun (fromei code)) #|"vectorMapAux" 186 (v #! r) (fun (fromei code)) #|"vectorMapAux"
166 return r 187 return r
167 188
189vectorMapValAux :: (Enum a, Storable a2, Storable t, Storable a1)
190 => (CInt -> Ptr a2 -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)
191 -> a -> a2 -> Vector t -> Vector a1
168vectorMapValAux fun code val v = unsafePerformIO $ do 192vectorMapValAux fun code val v = unsafePerformIO $ do
169 r <- createVector (dim v) 193 r <- createVector (dim v)
170 pval <- newArray [val] 194 pval <- newArray [val]
@@ -172,6 +196,9 @@ vectorMapValAux fun code val v = unsafePerformIO $ do
172 free pval 196 free pval
173 return r 197 return r
174 198
199vectorZipAux :: (Enum a, TransArray c, Storable t, Storable a1)
200 => (CInt -> CInt -> Ptr t -> TransRaw c (CInt -> Ptr a1 -> IO CInt))
201 -> a -> Vector t -> c -> Vector a1
175vectorZipAux fun code u v = unsafePerformIO $ do 202vectorZipAux fun code u v = unsafePerformIO $ do
176 r <- createVector (dim u) 203 r <- createVector (dim u)
177 (u # v #! r) (fun (fromei code)) #|"vectorZipAux" 204 (u # v #! r) (fun (fromei code)) #|"vectorZipAux"
@@ -378,6 +405,7 @@ foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> D
378 405
379-------------------------------------------------------------------------------- 406--------------------------------------------------------------------------------
380 407
408roundVector :: Vector Double -> Vector Double
381roundVector v = unsafePerformIO $ do 409roundVector v = unsafePerformIO $ do
382 r <- createVector (dim v) 410 r <- createVector (dim v)
383 (v #! r) c_round_vector #|"roundVector" 411 (v #! r) c_round_vector #|"roundVector"
@@ -432,6 +460,8 @@ long2intV :: Vector Z -> Vector I
432long2intV = tog c_long2int 460long2intV = tog c_long2int
433 461
434 462
463tog :: (Storable t, Storable a)
464 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
435tog f v = unsafePerformIO $ do 465tog f v = unsafePerformIO $ do
436 r <- createVector (dim v) 466 r <- createVector (dim v)
437 (v #! r) f #|"tog" 467 (v #! r) f #|"tog"
@@ -451,6 +481,8 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok
451 481
452--------------------------------------------------------------- 482---------------------------------------------------------------
453 483
484stepg :: (Storable t, Storable a)
485 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
454stepg f v = unsafePerformIO $ do 486stepg f v = unsafePerformIO $ do
455 r <- createVector (dim v) 487 r <- createVector (dim v)
456 (v #! r) f #|"step" 488 (v #! r) f #|"step"
@@ -476,6 +508,8 @@ foreign import ccall unsafe "stepL" c_stepL :: TVV Z
476 508
477-------------------------------------------------------------------------------- 509--------------------------------------------------------------------------------
478 510
511conjugateAux :: (Storable t, Storable a)
512 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
479conjugateAux fun x = unsafePerformIO $ do 513conjugateAux fun x = unsafePerformIO $ do
480 v <- createVector (dim x) 514 v <- createVector (dim x)
481 (x #! v) fun #|"conjugateAux" 515 (x #! v) fun #|"conjugateAux"
@@ -501,6 +535,8 @@ cloneVector v = do
501 535
502-------------------------------------------------------------------------------- 536--------------------------------------------------------------------------------
503 537
538constantAux :: (Storable a1, Storable a)
539 => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a
504constantAux fun x n = unsafePerformIO $ do 540constantAux fun x n = unsafePerformIO $ do
505 v <- createVector n 541 v <- createVector n
506 px <- newArray [x] 542 px <- newArray [x]