summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r--packages/base/src/Internal/Algorithms.hs2
-rw-r--r--packages/base/src/Internal/CG.hs2
-rw-r--r--packages/base/src/Internal/Chain.hs2
-rw-r--r--packages/base/src/Internal/Container.hs2
-rw-r--r--packages/base/src/Internal/Devel.hs1
-rw-r--r--packages/base/src/Internal/Element.hs23
-rw-r--r--packages/base/src/Internal/IO.hs16
-rw-r--r--packages/base/src/Internal/LAPACK.hs2
-rw-r--r--packages/base/src/Internal/Matrix.hs87
-rw-r--r--packages/base/src/Internal/Modular.hs3
-rw-r--r--packages/base/src/Internal/Numeric.hs10
-rw-r--r--packages/base/src/Internal/ST.hs12
-rw-r--r--packages/base/src/Internal/Sparse.hs2
-rw-r--r--packages/base/src/Internal/Static.hs3
-rw-r--r--packages/base/src/Internal/Util.hs2
-rw-r--r--packages/base/src/Internal/Vector.hs7
-rw-r--r--packages/base/src/Internal/Vectorized.hs36
17 files changed, 196 insertions, 16 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs
index 5fe7796..6027c46 100644
--- a/packages/base/src/Internal/Algorithms.hs
+++ b/packages/base/src/Internal/Algorithms.hs
@@ -4,6 +4,8 @@
4{-# LANGUAGE UndecidableInstances #-} 4{-# LANGUAGE UndecidableInstances #-}
5{-# LANGUAGE TypeFamilies #-} 5{-# LANGUAGE TypeFamilies #-}
6 6
7{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
8
7----------------------------------------------------------------------------- 9-----------------------------------------------------------------------------
8{- | 10{- |
9Module : Internal.Algorithms 11Module : Internal.Algorithms
diff --git a/packages/base/src/Internal/CG.hs b/packages/base/src/Internal/CG.hs
index cc10ad8..29edd35 100644
--- a/packages/base/src/Internal/CG.hs
+++ b/packages/base/src/Internal/CG.hs
@@ -1,6 +1,8 @@
1{-# LANGUAGE FlexibleContexts, FlexibleInstances #-} 1{-# LANGUAGE FlexibleContexts, FlexibleInstances #-}
2{-# LANGUAGE RecordWildCards #-} 2{-# LANGUAGE RecordWildCards #-}
3 3
4{-# OPTIONS_GHC -fno-warn-orphans #-}
5
4module Internal.CG( 6module Internal.CG(
5 cgSolve, cgSolve', 7 cgSolve, cgSolve',
6 CGState(..), R, V 8 CGState(..), R, V
diff --git a/packages/base/src/Internal/Chain.hs b/packages/base/src/Internal/Chain.hs
index f87eb02..4000c2b 100644
--- a/packages/base/src/Internal/Chain.hs
+++ b/packages/base/src/Internal/Chain.hs
@@ -1,5 +1,7 @@
1{-# LANGUAGE FlexibleContexts #-} 1{-# LANGUAGE FlexibleContexts #-}
2 2
3{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
4
3----------------------------------------------------------------------------- 5-----------------------------------------------------------------------------
4-- | 6-- |
5-- Module : Internal.Chain 7-- Module : Internal.Chain
diff --git a/packages/base/src/Internal/Container.hs b/packages/base/src/Internal/Container.hs
index 20114a0..a498069 100644
--- a/packages/base/src/Internal/Container.hs
+++ b/packages/base/src/Internal/Container.hs
@@ -5,6 +5,8 @@
5{-# LANGUAGE FunctionalDependencies #-} 5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE UndecidableInstances #-} 6{-# LANGUAGE UndecidableInstances #-}
7 7
8{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-}
9
8----------------------------------------------------------------------------- 10-----------------------------------------------------------------------------
9-- | 11-- |
10-- Module : Internal.Container 12-- Module : Internal.Container
diff --git a/packages/base/src/Internal/Devel.hs b/packages/base/src/Internal/Devel.hs
index 3887663..f72d8aa 100644
--- a/packages/base/src/Internal/Devel.hs
+++ b/packages/base/src/Internal/Devel.hs
@@ -54,6 +54,7 @@ check msg f = do
54 54
55-- | postfix error code check 55-- | postfix error code check
56infixl 0 #| 56infixl 0 #|
57(#|) :: IO CInt -> String -> IO ()
57(#|) = flip check 58(#|) = flip check
58 59
59-- | Error capture and conversion to Maybe 60-- | Error capture and conversion to Maybe
diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs
index eb3a25b..2e330ee 100644
--- a/packages/base/src/Internal/Element.hs
+++ b/packages/base/src/Internal/Element.hs
@@ -4,6 +4,8 @@
4{-# LANGUAGE UndecidableInstances #-} 4{-# LANGUAGE UndecidableInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-} 5{-# LANGUAGE MultiParamTypeClasses #-}
6 6
7{-# OPTIONS_GHC -fno-warn-orphans #-}
8
7----------------------------------------------------------------------------- 9-----------------------------------------------------------------------------
8-- | 10-- |
9-- Module : Data.Packed.Matrix 11-- Module : Data.Packed.Matrix
@@ -31,6 +33,7 @@ import Data.List.Split(chunksOf)
31import Foreign.Storable(Storable) 33import Foreign.Storable(Storable)
32import System.IO.Unsafe(unsafePerformIO) 34import System.IO.Unsafe(unsafePerformIO)
33import Control.Monad(liftM) 35import Control.Monad(liftM)
36import Foreign.C.Types(CInt)
34 37
35------------------------------------------------------------------- 38-------------------------------------------------------------------
36 39
@@ -53,8 +56,10 @@ instance (Show a, Element a) => (Show (Matrix a)) where
53 show m | rows m == 0 || cols m == 0 = sizes m ++" []" 56 show m | rows m == 0 || cols m == 0 = sizes m ++" []"
54 show m = (sizes m++) . dsp . map (map show) . toLists $ m 57 show m = (sizes m++) . dsp . map (map show) . toLists $ m
55 58
59sizes :: Matrix t -> [Char]
56sizes m = "("++show (rows m)++"><"++show (cols m)++")\n" 60sizes m = "("++show (rows m)++"><"++show (cols m)++")\n"
57 61
62dsp :: [[[Char]]] -> [Char]
58dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp 63dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp
59 where 64 where
60 mt = transpose as 65 mt = transpose as
@@ -73,6 +78,7 @@ instance (Element a, Read a) => Read (Matrix a) where
73 rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims 78 rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims
74 79
75 80
81breakAt :: Eq a => a -> [a] -> ([a], [a])
76breakAt c l = (a++[c],tail b) where 82breakAt c l = (a++[c],tail b) where
77 (a,b) = break (==c) l 83 (a,b) = break (==c) l
78 84
@@ -88,7 +94,8 @@ data Extractor
88 | Drop Int 94 | Drop Int
89 | DropLast Int 95 | DropLast Int
90 deriving Show 96 deriving Show
91 97
98ppext :: Extractor -> [Char]
92ppext All = ":" 99ppext All = ":"
93ppext (Range a 1 c) = printf "%d:%d" a c 100ppext (Range a 1 c) = printf "%d:%d" a c
94ppext (Range a b c) = printf "%d:%d:%d" a b c 101ppext (Range a b c) = printf "%d:%d:%d" a b c
@@ -128,10 +135,14 @@ ppext (DropLast n) = printf "DropLast %d" n
128infixl 9 ?? 135infixl 9 ??
129(??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t 136(??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t
130 137
138minEl :: Vector CInt -> CInt
131minEl = toScalarI Min 139minEl = toScalarI Min
140maxEl :: Vector CInt -> CInt
132maxEl = toScalarI Max 141maxEl = toScalarI Max
142cmodi :: Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt
133cmodi = vectorMapValI ModVS 143cmodi = vectorMapValI ModVS
134 144
145extractError :: Matrix t1 -> (Extractor, Extractor) -> t
135extractError m (e1,e2)= error $ printf "can't extract (%s,%s) from matrix %dx%d" (ppext e1::String) (ppext e2::String) (rows m) (cols m) 146extractError m (e1,e2)= error $ printf "can't extract (%s,%s) from matrix %dx%d" (ppext e1::String) (ppext e2::String) (rows m) (cols m)
136 147
137m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e) 148m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e)
@@ -232,8 +243,10 @@ disp = putStr . dispf 2
232fromBlocks :: Element t => [[Matrix t]] -> Matrix t 243fromBlocks :: Element t => [[Matrix t]] -> Matrix t
233fromBlocks = fromBlocksRaw . adaptBlocks 244fromBlocks = fromBlocksRaw . adaptBlocks
234 245
246fromBlocksRaw :: Element t => [[Matrix t]] -> Matrix t
235fromBlocksRaw mms = joinVert . map joinHoriz $ mms 247fromBlocksRaw mms = joinVert . map joinHoriz $ mms
236 248
249adaptBlocks :: Element t => [[Matrix t]] -> [[Matrix t]]
237adaptBlocks ms = ms' where 250adaptBlocks ms = ms' where
238 bc = case common length ms of 251 bc = case common length ms of
239 Just c -> c 252 Just c -> c
@@ -486,6 +499,9 @@ liftMatrix2Auto f m1 m2
486 m2' = conformMTo (r,c) m2 499 m2' = conformMTo (r,c) m2
487 500
488-- FIXME do not flatten if equal order 501-- FIXME do not flatten if equal order
502lM :: (Storable t, Element t1, Element t2)
503 => (Vector t1 -> Vector t2 -> Vector t)
504 -> Matrix t1 -> Matrix t2 -> Matrix t
489lM f m1 m2 = matrixFromVector 505lM f m1 m2 = matrixFromVector
490 RowMajor 506 RowMajor
491 (max' (rows m1) (rows m2)) 507 (max' (rows m1) (rows m2))
@@ -504,6 +520,7 @@ compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2
504 520
505------------------------------------------------------------ 521------------------------------------------------------------
506 522
523toBlockRows :: Element t => [Int] -> Matrix t -> [Matrix t]
507toBlockRows [r] m 524toBlockRows [r] m
508 | r == rows m = [m] 525 | r == rows m = [m]
509toBlockRows rs m 526toBlockRows rs m
@@ -513,6 +530,7 @@ toBlockRows rs m
513 szs = map (* cols m) rs 530 szs = map (* cols m) rs
514 g k = (k><0)[] 531 g k = (k><0)[]
515 532
533toBlockCols :: Element t => [Int] -> Matrix t -> [Matrix t]
516toBlockCols [c] m | c == cols m = [m] 534toBlockCols [c] m | c == cols m = [m]
517toBlockCols cs m = map trans . toBlockRows cs . trans $ m 535toBlockCols cs m = map trans . toBlockRows cs . trans $ m
518 536
@@ -576,7 +594,7 @@ Just (3><3)
576mapMatrixWithIndexM 594mapMatrixWithIndexM
577 :: (Element a, Storable b, Monad m) => 595 :: (Element a, Storable b, Monad m) =>
578 ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) 596 ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b)
579mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m 597mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m
580 where 598 where
581 c = cols m 599 c = cols m
582 600
@@ -598,4 +616,3 @@ mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m
598 616
599mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b 617mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b
600mapMatrix f = liftMatrix (mapVector f) 618mapMatrix f = liftMatrix (mapVector f)
601
diff --git a/packages/base/src/Internal/IO.hs b/packages/base/src/Internal/IO.hs
index a899cfd..b0f5606 100644
--- a/packages/base/src/Internal/IO.hs
+++ b/packages/base/src/Internal/IO.hs
@@ -20,7 +20,7 @@ import Internal.Devel
20import Internal.Vector 20import Internal.Vector
21import Internal.Matrix 21import Internal.Matrix
22import Internal.Vectorized 22import Internal.Vectorized
23import Text.Printf(printf) 23import Text.Printf(printf, PrintfArg, PrintfType)
24import Data.List(intersperse,transpose) 24import Data.List(intersperse,transpose)
25import Data.Complex 25import Data.Complex
26 26
@@ -78,12 +78,18 @@ disps d x = sdims x ++ " " ++ formatScaled d x
78dispf :: Int -> Matrix Double -> String 78dispf :: Int -> Matrix Double -> String
79dispf d x = sdims x ++ "\n" ++ formatFixed (if isInt x then 0 else d) x 79dispf d x = sdims x ++ "\n" ++ formatFixed (if isInt x then 0 else d) x
80 80
81sdims :: Matrix t -> [Char]
81sdims x = show (rows x) ++ "x" ++ show (cols x) 82sdims x = show (rows x) ++ "x" ++ show (cols x)
82 83
84formatFixed :: (Show a, Text.Printf.PrintfArg t, Element t)
85 => a -> Matrix t -> String
83formatFixed d x = format " " (printf ("%."++show d++"f")) $ x 86formatFixed d x = format " " (printf ("%."++show d++"f")) $ x
84 87
88isInt :: Matrix Double -> Bool
85isInt = all lookslikeInt . toList . flatten 89isInt = all lookslikeInt . toList . flatten
86 90
91formatScaled :: (Text.Printf.PrintfArg b, RealFrac b, Floating b, Num t, Element b, Show t)
92 => t -> Matrix b -> [Char]
87formatScaled dec t = "E"++show o++"\n" ++ ss 93formatScaled dec t = "E"++show o++"\n" ++ ss
88 where ss = format " " (printf fmt. g) t 94 where ss = format " " (printf fmt. g) t
89 g x | o >= 0 = x/10^(o::Int) 95 g x | o >= 0 = x/10^(o::Int)
@@ -133,14 +139,18 @@ showComplex d (a:+b)
133 s2 = if b<0 then "-" else "" 139 s2 = if b<0 then "-" else ""
134 s3 = if b<0 then "-" else "+" 140 s3 = if b<0 then "-" else "+"
135 141
142shcr :: (Show a, Show t1, Text.Printf.PrintfType t, Text.Printf.PrintfArg t1, RealFrac t1)
143 => a -> t1 -> t
136shcr d a | lookslikeInt a = printf "%.0f" a 144shcr d a | lookslikeInt a = printf "%.0f" a
137 | otherwise = printf ("%."++show d++"f") a 145 | otherwise = printf ("%."++show d++"f") a
138 146
139 147lookslikeInt :: (Show a, RealFrac a) => a -> Bool
140lookslikeInt x = show (round x :: Int) ++".0" == shx || "-0.0" == shx 148lookslikeInt x = show (round x :: Int) ++".0" == shx || "-0.0" == shx
141 where shx = show x 149 where shx = show x
142 150
151isZero :: Show a => a -> Bool
143isZero x = show x `elem` ["0.0","-0.0"] 152isZero x = show x `elem` ["0.0","-0.0"]
153isOne :: Show a => a -> Bool
144isOne x = show x `elem` ["1.0","-1.0"] 154isOne x = show x `elem` ["1.0","-1.0"]
145 155
146-- | Pretty print a complex matrix with at most n decimal digits. 156-- | Pretty print a complex matrix with at most n decimal digits.
@@ -168,6 +178,6 @@ loadMatrix f = do
168 else 178 else
169 return (reshape c v) 179 return (reshape c v)
170 180
171 181loadMatrix' :: FilePath -> IO (Maybe (Matrix Double))
172loadMatrix' name = mbCatch (loadMatrix name) 182loadMatrix' name = mbCatch (loadMatrix name)
173 183
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index e306454..64cf2f5 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -1,6 +1,8 @@
1{-# LANGUAGE TypeOperators #-} 1{-# LANGUAGE TypeOperators #-}
2{-# LANGUAGE ViewPatterns #-} 2{-# LANGUAGE ViewPatterns #-}
3 3
4{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
5
4----------------------------------------------------------------------------- 6-----------------------------------------------------------------------------
5-- | 7-- |
6-- Module : Numeric.LinearAlgebra.LAPACK 8-- Module : Numeric.LinearAlgebra.LAPACK
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 2856ec2..5436e59 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -57,19 +57,24 @@ cols :: Matrix t -> Int
57cols = icols 57cols = icols
58{-# INLINE cols #-} 58{-# INLINE cols #-}
59 59
60size :: Matrix t -> (Int, Int)
60size m = (irows m, icols m) 61size m = (irows m, icols m)
61{-# INLINE size #-} 62{-# INLINE size #-}
62 63
64rowOrder :: Matrix t -> Bool
63rowOrder m = xCol m == 1 || cols m == 1 65rowOrder m = xCol m == 1 || cols m == 1
64{-# INLINE rowOrder #-} 66{-# INLINE rowOrder #-}
65 67
68colOrder :: Matrix t -> Bool
66colOrder m = xRow m == 1 || rows m == 1 69colOrder m = xRow m == 1 || rows m == 1
67{-# INLINE colOrder #-} 70{-# INLINE colOrder #-}
68 71
72is1d :: Matrix t -> Bool
69is1d (size->(r,c)) = r==1 || c==1 73is1d (size->(r,c)) = r==1 || c==1
70{-# INLINE is1d #-} 74{-# INLINE is1d #-}
71 75
72-- data is not contiguous 76-- data is not contiguous
77isSlice :: Storable t => Matrix t -> Bool
73isSlice m@(size->(r,c)) = r*c < dim (xdat m) 78isSlice m@(size->(r,c)) = r*c < dim (xdat m)
74{-# INLINE isSlice #-} 79{-# INLINE isSlice #-}
75 80
@@ -136,16 +141,20 @@ instance Storable t => TransArray (Matrix t)
136 {-# INLINE applyRaw #-} 141 {-# INLINE applyRaw #-}
137 142
138infixr 1 # 143infixr 1 #
144(#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r
139a # b = apply a b 145a # b = apply a b
140{-# INLINE (#) #-} 146{-# INLINE (#) #-}
141 147
148(#!) :: (TransArray c, TransArray c1) => c1 -> c -> Trans c1 (Trans c (IO r)) -> IO r
142a #! b = a # b # id 149a #! b = a # b # id
143{-# INLINE (#!) #-} 150{-# INLINE (#!) #-}
144 151
145-------------------------------------------------------------------------------- 152--------------------------------------------------------------------------------
146 153
154copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t)
147copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) 155copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1])
148 156
157extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t
149extractAll ord m = unsafePerformIO (copy ord m) 158extractAll ord m = unsafePerformIO (copy ord m)
150 159
151{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. 160{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
@@ -224,11 +233,13 @@ m@Matrix {irows = r, icols = c} @@> (i,j)
224{-# INLINE (@@>) #-} 233{-# INLINE (@@>) #-}
225 234
226-- Unsafe matrix access without range checking 235-- Unsafe matrix access without range checking
236atM' :: Storable t => Matrix t -> Int -> Int -> t
227atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) 237atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m))
228{-# INLINE atM' #-} 238{-# INLINE atM' #-}
229 239
230------------------------------------------------------------------ 240------------------------------------------------------------------
231 241
242matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t
232matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } 243matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 }
233matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } 244matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d }
234matrixFromVector o r c v 245matrixFromVector o r c v
@@ -388,18 +399,21 @@ subMatrix (r0,c0) (rt,ct) m
388 399
389-------------------------------------------------------------------------- 400--------------------------------------------------------------------------
390 401
402maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1
391maxZ xs = if minimum xs == 0 then 0 else maximum xs 403maxZ xs = if minimum xs == 0 then 0 else maximum xs
392 404
405conformMs :: Element t => [Matrix t] -> [Matrix t]
393conformMs ms = map (conformMTo (r,c)) ms 406conformMs ms = map (conformMTo (r,c)) ms
394 where 407 where
395 r = maxZ (map rows ms) 408 r = maxZ (map rows ms)
396 c = maxZ (map cols ms) 409 c = maxZ (map cols ms)
397 410
398 411conformVs :: Element t => [Vector t] -> [Vector t]
399conformVs vs = map (conformVTo n) vs 412conformVs vs = map (conformVTo n) vs
400 where 413 where
401 n = maxZ (map dim vs) 414 n = maxZ (map dim vs)
402 415
416conformMTo :: Element t => (Int, Int) -> Matrix t -> Matrix t
403conformMTo (r,c) m 417conformMTo (r,c) m
404 | size m == (r,c) = m 418 | size m == (r,c) = m
405 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) 419 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
@@ -407,18 +421,24 @@ conformMTo (r,c) m
407 | size m == (1,c) = repRows r m 421 | size m == (1,c) = repRows r m
408 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) 422 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c)
409 423
424conformVTo :: Element t => Int -> Vector t -> Vector t
410conformVTo n v 425conformVTo n v
411 | dim v == n = v 426 | dim v == n = v
412 | dim v == 1 = constantD (v@>0) n 427 | dim v == 1 = constantD (v@>0) n
413 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n 428 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
414 429
430repRows :: Element t => Int -> Matrix t -> Matrix t
415repRows n x = fromRows (replicate n (flatten x)) 431repRows n x = fromRows (replicate n (flatten x))
432repCols :: Element t => Int -> Matrix t -> Matrix t
416repCols n x = fromColumns (replicate n (flatten x)) 433repCols n x = fromColumns (replicate n (flatten x))
417 434
435shSize :: Matrix t -> [Char]
418shSize = shDim . size 436shSize = shDim . size
419 437
438shDim :: (Show a, Show a1) => (a1, a) -> [Char]
420shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" 439shDim (r,c) = "(" ++ show r ++"x"++ show c ++")"
421 440
441emptyM :: Storable t => Int -> Int -> Matrix t
422emptyM r c = matrixFromVector RowMajor r c (fromList[]) 442emptyM r c = matrixFromVector RowMajor r c (fromList[])
423 443
424---------------------------------------------------------------------- 444----------------------------------------------------------------------
@@ -433,6 +453,11 @@ instance (Storable t, NFData t) => NFData (Matrix t)
433 453
434--------------------------------------------------------------- 454---------------------------------------------------------------
435 455
456extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1,
457 Storable t, Num t3, Num t2, Integral t1, Integral t)
458 => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t
459 -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))
460 -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a)
436extractAux f ord m moder vr modec vc = do 461extractAux f ord m moder vr modec vc = do
437 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr 462 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
438 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc 463 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
@@ -452,6 +477,9 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z
452 477
453--------------------------------------------------------------- 478---------------------------------------------------------------
454 479
480setRectAux :: (TransArray c1, TransArray c)
481 => (CInt -> CInt -> Trans c1 (Trans c (IO CInt)))
482 -> Int -> Int -> c1 -> c -> IO ()
455setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" 483setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect"
456 484
457type SetRect x = I -> I -> x ::> x::> Ok 485type SetRect x = I -> I -> x ::> x::> Ok
@@ -465,19 +493,29 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
465 493
466-------------------------------------------------------------------------------- 494--------------------------------------------------------------------------------
467 495
496sortG :: (Storable t, Storable a)
497 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
468sortG f v = unsafePerformIO $ do 498sortG f v = unsafePerformIO $ do
469 r <- createVector (dim v) 499 r <- createVector (dim v)
470 (v #! r) f #|"sortG" 500 (v #! r) f #|"sortG"
471 return r 501 return r
472 502
503sortIdxD :: Vector Double -> Vector CInt
473sortIdxD = sortG c_sort_indexD 504sortIdxD = sortG c_sort_indexD
505sortIdxF :: Vector Float -> Vector CInt
474sortIdxF = sortG c_sort_indexF 506sortIdxF = sortG c_sort_indexF
507sortIdxI :: Vector CInt -> Vector CInt
475sortIdxI = sortG c_sort_indexI 508sortIdxI = sortG c_sort_indexI
509sortIdxL :: Vector Z -> Vector I
476sortIdxL = sortG c_sort_indexL 510sortIdxL = sortG c_sort_indexL
477 511
512sortValD :: Vector Double -> Vector Double
478sortValD = sortG c_sort_valD 513sortValD = sortG c_sort_valD
514sortValF :: Vector Float -> Vector Float
479sortValF = sortG c_sort_valF 515sortValF = sortG c_sort_valF
516sortValI :: Vector CInt -> Vector CInt
480sortValI = sortG c_sort_valI 517sortValI = sortG c_sort_valI
518sortValL :: Vector Z -> Vector Z
481sortValL = sortG c_sort_valL 519sortValL = sortG c_sort_valL
482 520
483foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) 521foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt))
@@ -492,14 +530,21 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok
492 530
493-------------------------------------------------------------------------------- 531--------------------------------------------------------------------------------
494 532
533compareG :: (TransArray c, Storable t, Storable a)
534 => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt)
535 -> c -> Vector t -> Vector a
495compareG f u v = unsafePerformIO $ do 536compareG f u v = unsafePerformIO $ do
496 r <- createVector (dim v) 537 r <- createVector (dim v)
497 (u # v #! r) f #|"compareG" 538 (u # v #! r) f #|"compareG"
498 return r 539 return r
499 540
541compareD :: Vector Double -> Vector Double -> Vector CInt
500compareD = compareG c_compareD 542compareD = compareG c_compareD
543compareF :: Vector Float -> Vector Float -> Vector CInt
501compareF = compareG c_compareF 544compareF = compareG c_compareF
545compareI :: Vector CInt -> Vector CInt -> Vector CInt
502compareI = compareG c_compareI 546compareI = compareG c_compareI
547compareL :: Vector Z -> Vector Z -> Vector CInt
503compareL = compareG c_compareL 548compareL = compareG c_compareL
504 549
505foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) 550foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt)))
@@ -509,16 +554,33 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok
509 554
510-------------------------------------------------------------------------------- 555--------------------------------------------------------------------------------
511 556
557selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a)
558 => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt)))
559 -> c2 -> c1 -> Vector t -> c -> Vector a
512selectG f c u v w = unsafePerformIO $ do 560selectG f c u v w = unsafePerformIO $ do
513 r <- createVector (dim v) 561 r <- createVector (dim v)
514 (c # u # v # w #! r) f #|"selectG" 562 (c # u # v # w #! r) f #|"selectG"
515 return r 563 return r
516 564
565selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double
517selectD = selectG c_selectD 566selectD = selectG c_selectD
567selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float
518selectF = selectG c_selectF 568selectF = selectG c_selectF
569selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt
519selectI = selectG c_selectI 570selectI = selectG c_selectI
571selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z
520selectL = selectG c_selectL 572selectL = selectG c_selectL
573selectC :: Vector CInt
574 -> Vector (Complex Double)
575 -> Vector (Complex Double)
576 -> Vector (Complex Double)
577 -> Vector (Complex Double)
521selectC = selectG c_selectC 578selectC = selectG c_selectC
579selectQ :: Vector CInt
580 -> Vector (Complex Float)
581 -> Vector (Complex Float)
582 -> Vector (Complex Float)
583 -> Vector (Complex Float)
522selectQ = selectG c_selectQ 584selectQ = selectG c_selectQ
523 585
524type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) 586type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt)))))
@@ -532,16 +594,29 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
532 594
533--------------------------------------------------------------------------- 595---------------------------------------------------------------------------
534 596
597remapG :: (TransArray c, TransArray c1, Storable t, Storable a)
598 => (CInt -> CInt -> CInt -> CInt -> Ptr t
599 -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)))
600 -> Matrix t -> c1 -> c -> Matrix a
535remapG f i j m = unsafePerformIO $ do 601remapG f i j m = unsafePerformIO $ do
536 r <- createMatrix RowMajor (rows i) (cols i) 602 r <- createMatrix RowMajor (rows i) (cols i)
537 (i # j # m #! r) f #|"remapG" 603 (i # j # m #! r) f #|"remapG"
538 return r 604 return r
539 605
606remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double
540remapD = remapG c_remapD 607remapD = remapG c_remapD
608remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float
541remapF = remapG c_remapF 609remapF = remapG c_remapF
610remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt
542remapI = remapG c_remapI 611remapI = remapG c_remapI
612remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z
543remapL = remapG c_remapL 613remapL = remapG c_remapL
614remapC :: Matrix CInt
615 -> Matrix CInt
616 -> Matrix (Complex Double)
617 -> Matrix (Complex Double)
544remapC = remapG c_remapC 618remapC = remapG c_remapC
619remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float)
545remapQ = remapG c_remapQ 620remapQ = remapG c_remapQ
546 621
547type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) 622type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt))))
@@ -555,6 +630,9 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z
555 630
556-------------------------------------------------------------------------------- 631--------------------------------------------------------------------------------
557 632
633rowOpAux :: (TransArray c, Storable a) =>
634 (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt))
635 -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO ()
558rowOpAux f c x i1 i2 j1 j2 m = do 636rowOpAux f c x i1 i2 j1 j2 m = do
559 px <- newArray [x] 637 px <- newArray [x]
560 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" 638 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp"
@@ -573,6 +651,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
573 651
574-------------------------------------------------------------------------------- 652--------------------------------------------------------------------------------
575 653
654gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3)
655 => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt))))
656 -> c3 -> c2 -> c1 -> c -> IO ()
576gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" 657gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg"
577 658
578type Tgemm x = x :> x ::> x ::> x ::> Ok 659type Tgemm x = x :> x ::> x ::> x ::> Ok
@@ -588,6 +669,10 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
588 669
589-------------------------------------------------------------------------------- 670--------------------------------------------------------------------------------
590 671
672reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) =>
673 (CInt -> Ptr a -> CInt -> Ptr t1
674 -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt))
675 -> Vector t1 -> c -> Vector t -> Vector a1
591reorderAux f s d v = unsafePerformIO $ do 676reorderAux f s d v = unsafePerformIO $ do
592 k <- createVector (dim s) 677 k <- createVector (dim s)
593 r <- createVector (dim v) 678 r <- createVector (dim v)
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index 9d51444..eb0c5a8 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -13,6 +13,9 @@
13{-# LANGUAGE TypeFamilies #-} 13{-# LANGUAGE TypeFamilies #-}
14{-# LANGUAGE TypeOperators #-} 14{-# LANGUAGE TypeOperators #-}
15 15
16{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
17{-# OPTIONS_GHC -fno-warn-missing-methods #-}
18
16{- | 19{- |
17Module : Internal.Modular 20Module : Internal.Modular
18Copyright : (c) Alberto Ruiz 2015 21Copyright : (c) Alberto Ruiz 2015
diff --git a/packages/base/src/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs
index c9ef0c5..fd0a217 100644
--- a/packages/base/src/Internal/Numeric.hs
+++ b/packages/base/src/Internal/Numeric.hs
@@ -5,6 +5,8 @@
5{-# LANGUAGE FunctionalDependencies #-} 5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE UndecidableInstances #-} 6{-# LANGUAGE UndecidableInstances #-}
7 7
8{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
9
8----------------------------------------------------------------------------- 10-----------------------------------------------------------------------------
9-- | 11-- |
10-- Module : Data.Packed.Internal.Numeric 12-- Module : Data.Packed.Internal.Numeric
@@ -788,13 +790,7 @@ type instance RealOf (Complex Float) = Float
788type instance RealOf I = I 790type instance RealOf I = I
789type instance RealOf Z = Z 791type instance RealOf Z = Z
790 792
791type family ComplexOf x 793type ComplexOf x = Complex (RealOf x)
792
793type instance ComplexOf Double = Complex Double
794type instance ComplexOf (Complex Double) = Complex Double
795
796type instance ComplexOf Float = Complex Float
797type instance ComplexOf (Complex Float) = Complex Float
798 794
799type family SingleOf x 795type family SingleOf x
800 796
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 544c9e4..7d54e6d 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -81,6 +81,8 @@ unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
81unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x 81unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x
82 82
83{-# INLINE safeIndexV #-} 83{-# INLINE safeIndexV #-}
84safeIndexV :: Storable t2
85 => (STVector s t2 -> Int -> t) -> STVector t1 t2 -> Int -> t
84safeIndexV f (STVector v) k 86safeIndexV f (STVector v) k
85 | k < 0 || k>= dim v = error $ "out of range error in vector (dim=" 87 | k < 0 || k>= dim v = error $ "out of range error in vector (dim="
86 ++show (dim v)++", pos="++show k++")" 88 ++show (dim v)++", pos="++show k++")"
@@ -150,9 +152,12 @@ unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
150freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) 152freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t)
151freezeMatrix m = liftSTMatrix id m 153freezeMatrix m = liftSTMatrix id m
152 154
155cloneMatrix :: Element t => Matrix t -> IO (Matrix t)
153cloneMatrix m = copy (orderOf m) m 156cloneMatrix m = copy (orderOf m) m
154 157
155{-# INLINE safeIndexM #-} 158{-# INLINE safeIndexM #-}
159safeIndexM :: (STMatrix s t2 -> Int -> Int -> t)
160 -> STMatrix t1 t2 -> Int -> Int -> t
156safeIndexM f (STMatrix m) r c 161safeIndexM f (STMatrix m) r c
157 | r<0 || r>=rows m || 162 | r<0 || r>=rows m ||
158 c<0 || c>=cols m = error $ "out of range error in matrix (size=" 163 c<0 || c>=cols m = error $ "out of range error in matrix (size="
@@ -184,6 +189,7 @@ data ColRange = AllCols
184 | Col Int 189 | Col Int
185 | FromCol Int 190 | FromCol Int
186 191
192getColRange :: Int -> ColRange -> (Int, Int)
187getColRange c AllCols = (0,c-1) 193getColRange c AllCols = (0,c-1)
188getColRange c (ColRange a b) = (a `mod` c, b `mod` c) 194getColRange c (ColRange a b) = (a `mod` c, b `mod` c)
189getColRange c (Col a) = (a `mod` c, a `mod` c) 195getColRange c (Col a) = (a `mod` c, a `mod` c)
@@ -194,6 +200,7 @@ data RowRange = AllRows
194 | Row Int 200 | Row Int
195 | FromRow Int 201 | FromRow Int
196 202
203getRowRange :: Int -> RowRange -> (Int, Int)
197getRowRange r AllRows = (0,r-1) 204getRowRange r AllRows = (0,r-1)
198getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) 205getRowRange r (RowRange a b) = (a `mod` r, b `mod` r)
199getRowRange r (Row a) = (a `mod` r, a `mod` r) 206getRowRange r (Row a) = (a `mod` r, a `mod` r)
@@ -223,6 +230,7 @@ rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m
223 i2' = i2 `mod` (rows m) 230 i2' = i2 `mod` (rows m)
224 231
225 232
233extractMatrix :: Element a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a)
226extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) 234extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2]))
227 where 235 where
228 (i1,i2) = getRowRange (rows m) rr 236 (i1,i2) = getRowRange (rows m) rr
@@ -231,6 +239,7 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[
231-- | r0 c0 height width 239-- | r0 c0 height width
232data Slice s t = Slice (STMatrix s t) Int Int Int Int 240data Slice s t = Slice (STMatrix s t) Int Int Int Int
233 241
242slice :: Element a => Slice t a -> Matrix a
234slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m 243slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m
235 244
236gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () 245gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s ()
@@ -238,7 +247,7 @@ gemmm beta (slice->r) alpha (slice->a) (slice->b) = res
238 where 247 where
239 res = unsafeIOToST (gemm v a b r) 248 res = unsafeIOToST (gemm v a b r)
240 v = fromList [alpha,beta] 249 v = fromList [alpha,beta]
241 250
242 251
243mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 252mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
244mutable f a = runST $ do 253mutable f a = runST $ do
@@ -246,4 +255,3 @@ mutable f a = runST $ do
246 info <- f (rows a, cols a) x 255 info <- f (rows a, cols a) x
247 r <- unsafeFreezeMatrix x 256 r <- unsafeFreezeMatrix x
248 return (r,info) 257 return (r,info)
249
diff --git a/packages/base/src/Internal/Sparse.hs b/packages/base/src/Internal/Sparse.hs
index a8a5fe0..fbea11a 100644
--- a/packages/base/src/Internal/Sparse.hs
+++ b/packages/base/src/Internal/Sparse.hs
@@ -2,6 +2,8 @@
2{-# LANGUAGE MultiParamTypeClasses #-} 2{-# LANGUAGE MultiParamTypeClasses #-}
3{-# LANGUAGE FlexibleInstances #-} 3{-# LANGUAGE FlexibleInstances #-}
4 4
5{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
6
5module Internal.Sparse( 7module Internal.Sparse(
6 GMatrix(..), CSR(..), mkCSR, fromCSR, 8 GMatrix(..), CSR(..), mkCSR, fromCSR,
7 mkSparse, mkDiagR, mkDense, 9 mkSparse, mkDiagR, mkDense,
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs
index 357645e..566506c 100644
--- a/packages/base/src/Internal/Static.hs
+++ b/packages/base/src/Internal/Static.hs
@@ -15,6 +15,9 @@
15{-# LANGUAGE BangPatterns #-} 15{-# LANGUAGE BangPatterns #-}
16{-# LANGUAGE DeriveGeneric #-} 16{-# LANGUAGE DeriveGeneric #-}
17 17
18{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
19{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-}
20
18{- | 21{- |
19Module : Internal.Static 22Module : Internal.Static
20Copyright : (c) Alberto Ruiz 2006-14 23Copyright : (c) Alberto Ruiz 2006-14
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs
index 959e58f..f642e8d 100644
--- a/packages/base/src/Internal/Util.hs
+++ b/packages/base/src/Internal/Util.hs
@@ -6,6 +6,8 @@
6{-# LANGUAGE ScopedTypeVariables #-} 6{-# LANGUAGE ScopedTypeVariables #-}
7{-# LANGUAGE ViewPatterns #-} 7{-# LANGUAGE ViewPatterns #-}
8 8
9{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
10{-# OPTIONS_GHC -fno-warn-orphans #-}
9 11
10----------------------------------------------------------------------------- 12-----------------------------------------------------------------------------
11{- | 13{- |
diff --git a/packages/base/src/Internal/Vector.hs b/packages/base/src/Internal/Vector.hs
index e1e4aa8..6271bb6 100644
--- a/packages/base/src/Internal/Vector.hs
+++ b/packages/base/src/Internal/Vector.hs
@@ -1,6 +1,7 @@
1{-# LANGUAGE MagicHash, UnboxedTuples, BangPatterns, FlexibleContexts #-} 1{-# LANGUAGE MagicHash, UnboxedTuples, BangPatterns, FlexibleContexts #-}
2{-# LANGUAGE TypeSynonymInstances #-} 2{-# LANGUAGE TypeSynonymInstances #-}
3 3
4{-# OPTIONS_GHC -fno-warn-orphans #-}
4 5
5-- | 6-- |
6-- Module : Internal.Vector 7-- Module : Internal.Vector
@@ -40,6 +41,7 @@ import qualified Data.Vector.Storable as Vector
40import Data.Vector.Storable(Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith) 41import Data.Vector.Storable(Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith)
41 42
42import Data.Binary 43import Data.Binary
44import Data.Binary.Put
43import Control.Monad(replicateM) 45import Control.Monad(replicateM)
44import qualified Data.ByteString.Internal as BS 46import qualified Data.ByteString.Internal as BS
45import Data.Vector.Storable.Internal(updPtr) 47import Data.Vector.Storable.Internal(updPtr)
@@ -92,6 +94,7 @@ createVector n = do
92 94
93-} 95-}
94 96
97safeRead :: Storable a => Vector a -> (Ptr a -> IO c) -> c
95safeRead v = inlinePerformIO . unsafeWith v 98safeRead v = inlinePerformIO . unsafeWith v
96{-# INLINE safeRead #-} 99{-# INLINE safeRead #-}
97 100
@@ -287,11 +290,13 @@ foldVectorWithIndex f x v = unsafePerformIO $
287 go (dim v -1) x 290 go (dim v -1) x
288{-# INLINE foldVectorWithIndex #-} 291{-# INLINE foldVectorWithIndex #-}
289 292
293foldLoop :: (Int -> t -> t) -> t -> Int -> t
290foldLoop f s0 d = go (d - 1) s0 294foldLoop f s0 d = go (d - 1) s0
291 where 295 where
292 go 0 s = f (0::Int) s 296 go 0 s = f (0::Int) s
293 go !j !s = go (j - 1) (f j s) 297 go !j !s = go (j - 1) (f j s)
294 298
299foldVectorG :: Storable t1 => (Int -> (Int -> t1) -> t -> t) -> t -> Vector t1 -> t
295foldVectorG f s0 v = foldLoop g s0 (dim v) 300foldVectorG f s0 v = foldLoop g s0 (dim v)
296 where g !k !s = f k (safeRead v . flip peekElemOff) s 301 where g !k !s = f k (safeRead v . flip peekElemOff) s
297 {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479) 302 {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479)
@@ -394,8 +399,10 @@ chunks d = let c = d `div` chunk
394 m = d `mod` chunk 399 m = d `mod` chunk
395 in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk) 400 in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk)
396 401
402putVector :: (Storable t, Binary t) => Vector t -> Data.Binary.Put.PutM ()
397putVector v = mapM_ put $! toList v 403putVector v = mapM_ put $! toList v
398 404
405getVector :: (Storable a, Binary a) => Int -> Get (Vector a)
399getVector d = do 406getVector d = do
400 xs <- replicateM d get 407 xs <- replicateM d get
401 return $! fromList xs 408 return $! fromList xs
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs
index 2990173..32430c6 100644
--- a/packages/base/src/Internal/Vectorized.hs
+++ b/packages/base/src/Internal/Vectorized.hs
@@ -28,12 +28,15 @@ import System.IO.Unsafe(unsafePerformIO)
28import Control.Monad(when) 28import Control.Monad(when)
29 29
30infixr 1 # 30infixr 1 #
31(#) :: TransArray c => c -> (b -> IO r) -> TransRaw c b -> IO r
31a # b = applyRaw a b 32a # b = applyRaw a b
32{-# INLINE (#) #-} 33{-# INLINE (#) #-}
33 34
35(#!) :: (TransArray c, TransArray c1) => c1 -> c -> TransRaw c1 (TransRaw c (IO r)) -> IO r
34a #! b = a # b # id 36a #! b = a # b # id
35{-# INLINE (#!) #-} 37{-# INLINE (#!) #-}
36 38
39fromei :: Enum a => a -> CInt
37fromei x = fromIntegral (fromEnum x) :: CInt 40fromei x = fromIntegral (fromEnum x) :: CInt
38 41
39data FunCodeV = Sin 42data FunCodeV = Sin
@@ -100,10 +103,20 @@ sumQ = sumg c_sumQ
100sumC :: Vector (Complex Double) -> Complex Double 103sumC :: Vector (Complex Double) -> Complex Double
101sumC = sumg c_sumC 104sumC = sumg c_sumC
102 105
106sumI :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr I -> I :> Ok)
107 , TransArray c
108 , Storable a
109 )
110 => I -> c -> a
103sumI m = sumg (c_sumI m) 111sumI m = sumg (c_sumI m)
104 112
113sumL :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr Z -> Z :> Ok)
114 , TransArray c
115 , Storable a
116 ) => Z -> c -> a
105sumL m = sumg (c_sumL m) 117sumL m = sumg (c_sumL m)
106 118
119sumg :: (TransArray c, Storable a) => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a
107sumg f x = unsafePerformIO $ do 120sumg f x = unsafePerformIO $ do
108 r <- createVector 1 121 r <- createVector 1
109 (x #! r) f #| "sum" 122 (x #! r) f #| "sum"
@@ -140,6 +153,8 @@ prodI = prodg . c_prodI
140prodL :: Z-> Vector Z -> Z 153prodL :: Z-> Vector Z -> Z
141prodL = prodg . c_prodL 154prodL = prodg . c_prodL
142 155
156prodg :: (TransArray c, Storable a)
157 => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a
143prodg f x = unsafePerformIO $ do 158prodg f x = unsafePerformIO $ do
144 r <- createVector 1 159 r <- createVector 1
145 (x #! r) f #| "prod" 160 (x #! r) f #| "prod"
@@ -155,16 +170,25 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z
155 170
156------------------------------------------------------------------ 171------------------------------------------------------------------
157 172
173toScalarAux :: (Enum a, TransArray c, Storable a1)
174 => (CInt -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) -> a -> c -> a1
158toScalarAux fun code v = unsafePerformIO $ do 175toScalarAux fun code v = unsafePerformIO $ do
159 r <- createVector 1 176 r <- createVector 1
160 (v #! r) (fun (fromei code)) #|"toScalarAux" 177 (v #! r) (fun (fromei code)) #|"toScalarAux"
161 return (r @> 0) 178 return (r @> 0)
162 179
180
181vectorMapAux :: (Enum a, Storable t, Storable a1)
182 => (CInt -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)
183 -> a -> Vector t -> Vector a1
163vectorMapAux fun code v = unsafePerformIO $ do 184vectorMapAux fun code v = unsafePerformIO $ do
164 r <- createVector (dim v) 185 r <- createVector (dim v)
165 (v #! r) (fun (fromei code)) #|"vectorMapAux" 186 (v #! r) (fun (fromei code)) #|"vectorMapAux"
166 return r 187 return r
167 188
189vectorMapValAux :: (Enum a, Storable a2, Storable t, Storable a1)
190 => (CInt -> Ptr a2 -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)
191 -> a -> a2 -> Vector t -> Vector a1
168vectorMapValAux fun code val v = unsafePerformIO $ do 192vectorMapValAux fun code val v = unsafePerformIO $ do
169 r <- createVector (dim v) 193 r <- createVector (dim v)
170 pval <- newArray [val] 194 pval <- newArray [val]
@@ -172,6 +196,9 @@ vectorMapValAux fun code val v = unsafePerformIO $ do
172 free pval 196 free pval
173 return r 197 return r
174 198
199vectorZipAux :: (Enum a, TransArray c, Storable t, Storable a1)
200 => (CInt -> CInt -> Ptr t -> TransRaw c (CInt -> Ptr a1 -> IO CInt))
201 -> a -> Vector t -> c -> Vector a1
175vectorZipAux fun code u v = unsafePerformIO $ do 202vectorZipAux fun code u v = unsafePerformIO $ do
176 r <- createVector (dim u) 203 r <- createVector (dim u)
177 (u # v #! r) (fun (fromei code)) #|"vectorZipAux" 204 (u # v #! r) (fun (fromei code)) #|"vectorZipAux"
@@ -378,6 +405,7 @@ foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> D
378 405
379-------------------------------------------------------------------------------- 406--------------------------------------------------------------------------------
380 407
408roundVector :: Vector Double -> Vector Double
381roundVector v = unsafePerformIO $ do 409roundVector v = unsafePerformIO $ do
382 r <- createVector (dim v) 410 r <- createVector (dim v)
383 (v #! r) c_round_vector #|"roundVector" 411 (v #! r) c_round_vector #|"roundVector"
@@ -433,6 +461,8 @@ long2intV :: Vector Z -> Vector I
433long2intV = tog c_long2int 461long2intV = tog c_long2int
434 462
435 463
464tog :: (Storable t, Storable a)
465 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
436tog f v = unsafePerformIO $ do 466tog f v = unsafePerformIO $ do
437 r <- createVector (dim v) 467 r <- createVector (dim v)
438 (v #! r) f #|"tog" 468 (v #! r) f #|"tog"
@@ -452,6 +482,8 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok
452 482
453--------------------------------------------------------------- 483---------------------------------------------------------------
454 484
485stepg :: (Storable t, Storable a)
486 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
455stepg f v = unsafePerformIO $ do 487stepg f v = unsafePerformIO $ do
456 r <- createVector (dim v) 488 r <- createVector (dim v)
457 (v #! r) f #|"step" 489 (v #! r) f #|"step"
@@ -477,6 +509,8 @@ foreign import ccall unsafe "stepL" c_stepL :: TVV Z
477 509
478-------------------------------------------------------------------------------- 510--------------------------------------------------------------------------------
479 511
512conjugateAux :: (Storable t, Storable a)
513 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
480conjugateAux fun x = unsafePerformIO $ do 514conjugateAux fun x = unsafePerformIO $ do
481 v <- createVector (dim x) 515 v <- createVector (dim x)
482 (x #! v) fun #|"conjugateAux" 516 (x #! v) fun #|"conjugateAux"
@@ -502,6 +536,8 @@ cloneVector v = do
502 536
503-------------------------------------------------------------------------------- 537--------------------------------------------------------------------------------
504 538
539constantAux :: (Storable a1, Storable a)
540 => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a
505constantAux fun x n = unsafePerformIO $ do 541constantAux fun x n = unsafePerformIO $ do
506 v <- createVector n 542 v <- createVector n
507 px <- newArray [x] 543 px <- newArray [x]