diff options
Diffstat (limited to 'packages/base')
21 files changed, 222 insertions, 13 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs index 99c9e34..cea06ce 100644 --- a/packages/base/src/Internal/Algorithms.hs +++ b/packages/base/src/Internal/Algorithms.hs | |||
@@ -4,6 +4,8 @@ | |||
4 | {-# LANGUAGE UndecidableInstances #-} | 4 | {-# LANGUAGE UndecidableInstances #-} |
5 | {-# LANGUAGE TypeFamilies #-} | 5 | {-# LANGUAGE TypeFamilies #-} |
6 | 6 | ||
7 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
8 | |||
7 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
8 | {- | | 10 | {- | |
9 | Module : Internal.Algorithms | 11 | Module : Internal.Algorithms |
diff --git a/packages/base/src/Internal/CG.hs b/packages/base/src/Internal/CG.hs index cc10ad8..29edd35 100644 --- a/packages/base/src/Internal/CG.hs +++ b/packages/base/src/Internal/CG.hs | |||
@@ -1,6 +1,8 @@ | |||
1 | {-# LANGUAGE FlexibleContexts, FlexibleInstances #-} | 1 | {-# LANGUAGE FlexibleContexts, FlexibleInstances #-} |
2 | {-# LANGUAGE RecordWildCards #-} | 2 | {-# LANGUAGE RecordWildCards #-} |
3 | 3 | ||
4 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
5 | |||
4 | module Internal.CG( | 6 | module Internal.CG( |
5 | cgSolve, cgSolve', | 7 | cgSolve, cgSolve', |
6 | CGState(..), R, V | 8 | CGState(..), R, V |
diff --git a/packages/base/src/Internal/Chain.hs b/packages/base/src/Internal/Chain.hs index f87eb02..4000c2b 100644 --- a/packages/base/src/Internal/Chain.hs +++ b/packages/base/src/Internal/Chain.hs | |||
@@ -1,5 +1,7 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | 1 | {-# LANGUAGE FlexibleContexts #-} |
2 | 2 | ||
3 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
4 | |||
3 | ----------------------------------------------------------------------------- | 5 | ----------------------------------------------------------------------------- |
4 | -- | | 6 | -- | |
5 | -- Module : Internal.Chain | 7 | -- Module : Internal.Chain |
diff --git a/packages/base/src/Internal/Devel.hs b/packages/base/src/Internal/Devel.hs index 3887663..f72d8aa 100644 --- a/packages/base/src/Internal/Devel.hs +++ b/packages/base/src/Internal/Devel.hs | |||
@@ -54,6 +54,7 @@ check msg f = do | |||
54 | 54 | ||
55 | -- | postfix error code check | 55 | -- | postfix error code check |
56 | infixl 0 #| | 56 | infixl 0 #| |
57 | (#|) :: IO CInt -> String -> IO () | ||
57 | (#|) = flip check | 58 | (#|) = flip check |
58 | 59 | ||
59 | -- | Error capture and conversion to Maybe | 60 | -- | Error capture and conversion to Maybe |
diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs index eb3a25b..2e330ee 100644 --- a/packages/base/src/Internal/Element.hs +++ b/packages/base/src/Internal/Element.hs | |||
@@ -4,6 +4,8 @@ | |||
4 | {-# LANGUAGE UndecidableInstances #-} | 4 | {-# LANGUAGE UndecidableInstances #-} |
5 | {-# LANGUAGE MultiParamTypeClasses #-} | 5 | {-# LANGUAGE MultiParamTypeClasses #-} |
6 | 6 | ||
7 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
8 | |||
7 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
8 | -- | | 10 | -- | |
9 | -- Module : Data.Packed.Matrix | 11 | -- Module : Data.Packed.Matrix |
@@ -31,6 +33,7 @@ import Data.List.Split(chunksOf) | |||
31 | import Foreign.Storable(Storable) | 33 | import Foreign.Storable(Storable) |
32 | import System.IO.Unsafe(unsafePerformIO) | 34 | import System.IO.Unsafe(unsafePerformIO) |
33 | import Control.Monad(liftM) | 35 | import Control.Monad(liftM) |
36 | import Foreign.C.Types(CInt) | ||
34 | 37 | ||
35 | ------------------------------------------------------------------- | 38 | ------------------------------------------------------------------- |
36 | 39 | ||
@@ -53,8 +56,10 @@ instance (Show a, Element a) => (Show (Matrix a)) where | |||
53 | show m | rows m == 0 || cols m == 0 = sizes m ++" []" | 56 | show m | rows m == 0 || cols m == 0 = sizes m ++" []" |
54 | show m = (sizes m++) . dsp . map (map show) . toLists $ m | 57 | show m = (sizes m++) . dsp . map (map show) . toLists $ m |
55 | 58 | ||
59 | sizes :: Matrix t -> [Char] | ||
56 | sizes m = "("++show (rows m)++"><"++show (cols m)++")\n" | 60 | sizes m = "("++show (rows m)++"><"++show (cols m)++")\n" |
57 | 61 | ||
62 | dsp :: [[[Char]]] -> [Char] | ||
58 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp | 63 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp |
59 | where | 64 | where |
60 | mt = transpose as | 65 | mt = transpose as |
@@ -73,6 +78,7 @@ instance (Element a, Read a) => Read (Matrix a) where | |||
73 | rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims | 78 | rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims |
74 | 79 | ||
75 | 80 | ||
81 | breakAt :: Eq a => a -> [a] -> ([a], [a]) | ||
76 | breakAt c l = (a++[c],tail b) where | 82 | breakAt c l = (a++[c],tail b) where |
77 | (a,b) = break (==c) l | 83 | (a,b) = break (==c) l |
78 | 84 | ||
@@ -88,7 +94,8 @@ data Extractor | |||
88 | | Drop Int | 94 | | Drop Int |
89 | | DropLast Int | 95 | | DropLast Int |
90 | deriving Show | 96 | deriving Show |
91 | 97 | ||
98 | ppext :: Extractor -> [Char] | ||
92 | ppext All = ":" | 99 | ppext All = ":" |
93 | ppext (Range a 1 c) = printf "%d:%d" a c | 100 | ppext (Range a 1 c) = printf "%d:%d" a c |
94 | ppext (Range a b c) = printf "%d:%d:%d" a b c | 101 | ppext (Range a b c) = printf "%d:%d:%d" a b c |
@@ -128,10 +135,14 @@ ppext (DropLast n) = printf "DropLast %d" n | |||
128 | infixl 9 ?? | 135 | infixl 9 ?? |
129 | (??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t | 136 | (??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t |
130 | 137 | ||
138 | minEl :: Vector CInt -> CInt | ||
131 | minEl = toScalarI Min | 139 | minEl = toScalarI Min |
140 | maxEl :: Vector CInt -> CInt | ||
132 | maxEl = toScalarI Max | 141 | maxEl = toScalarI Max |
142 | cmodi :: Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt | ||
133 | cmodi = vectorMapValI ModVS | 143 | cmodi = vectorMapValI ModVS |
134 | 144 | ||
145 | extractError :: Matrix t1 -> (Extractor, Extractor) -> t | ||
135 | extractError m (e1,e2)= error $ printf "can't extract (%s,%s) from matrix %dx%d" (ppext e1::String) (ppext e2::String) (rows m) (cols m) | 146 | extractError m (e1,e2)= error $ printf "can't extract (%s,%s) from matrix %dx%d" (ppext e1::String) (ppext e2::String) (rows m) (cols m) |
136 | 147 | ||
137 | m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e) | 148 | m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e) |
@@ -232,8 +243,10 @@ disp = putStr . dispf 2 | |||
232 | fromBlocks :: Element t => [[Matrix t]] -> Matrix t | 243 | fromBlocks :: Element t => [[Matrix t]] -> Matrix t |
233 | fromBlocks = fromBlocksRaw . adaptBlocks | 244 | fromBlocks = fromBlocksRaw . adaptBlocks |
234 | 245 | ||
246 | fromBlocksRaw :: Element t => [[Matrix t]] -> Matrix t | ||
235 | fromBlocksRaw mms = joinVert . map joinHoriz $ mms | 247 | fromBlocksRaw mms = joinVert . map joinHoriz $ mms |
236 | 248 | ||
249 | adaptBlocks :: Element t => [[Matrix t]] -> [[Matrix t]] | ||
237 | adaptBlocks ms = ms' where | 250 | adaptBlocks ms = ms' where |
238 | bc = case common length ms of | 251 | bc = case common length ms of |
239 | Just c -> c | 252 | Just c -> c |
@@ -486,6 +499,9 @@ liftMatrix2Auto f m1 m2 | |||
486 | m2' = conformMTo (r,c) m2 | 499 | m2' = conformMTo (r,c) m2 |
487 | 500 | ||
488 | -- FIXME do not flatten if equal order | 501 | -- FIXME do not flatten if equal order |
502 | lM :: (Storable t, Element t1, Element t2) | ||
503 | => (Vector t1 -> Vector t2 -> Vector t) | ||
504 | -> Matrix t1 -> Matrix t2 -> Matrix t | ||
489 | lM f m1 m2 = matrixFromVector | 505 | lM f m1 m2 = matrixFromVector |
490 | RowMajor | 506 | RowMajor |
491 | (max' (rows m1) (rows m2)) | 507 | (max' (rows m1) (rows m2)) |
@@ -504,6 +520,7 @@ compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 | |||
504 | 520 | ||
505 | ------------------------------------------------------------ | 521 | ------------------------------------------------------------ |
506 | 522 | ||
523 | toBlockRows :: Element t => [Int] -> Matrix t -> [Matrix t] | ||
507 | toBlockRows [r] m | 524 | toBlockRows [r] m |
508 | | r == rows m = [m] | 525 | | r == rows m = [m] |
509 | toBlockRows rs m | 526 | toBlockRows rs m |
@@ -513,6 +530,7 @@ toBlockRows rs m | |||
513 | szs = map (* cols m) rs | 530 | szs = map (* cols m) rs |
514 | g k = (k><0)[] | 531 | g k = (k><0)[] |
515 | 532 | ||
533 | toBlockCols :: Element t => [Int] -> Matrix t -> [Matrix t] | ||
516 | toBlockCols [c] m | c == cols m = [m] | 534 | toBlockCols [c] m | c == cols m = [m] |
517 | toBlockCols cs m = map trans . toBlockRows cs . trans $ m | 535 | toBlockCols cs m = map trans . toBlockRows cs . trans $ m |
518 | 536 | ||
@@ -576,7 +594,7 @@ Just (3><3) | |||
576 | mapMatrixWithIndexM | 594 | mapMatrixWithIndexM |
577 | :: (Element a, Storable b, Monad m) => | 595 | :: (Element a, Storable b, Monad m) => |
578 | ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) | 596 | ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) |
579 | mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m | 597 | mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m |
580 | where | 598 | where |
581 | c = cols m | 599 | c = cols m |
582 | 600 | ||
@@ -598,4 +616,3 @@ mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m | |||
598 | 616 | ||
599 | mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b | 617 | mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b |
600 | mapMatrix f = liftMatrix (mapVector f) | 618 | mapMatrix f = liftMatrix (mapVector f) |
601 | |||
diff --git a/packages/base/src/Internal/IO.hs b/packages/base/src/Internal/IO.hs index a899cfd..b0f5606 100644 --- a/packages/base/src/Internal/IO.hs +++ b/packages/base/src/Internal/IO.hs | |||
@@ -20,7 +20,7 @@ import Internal.Devel | |||
20 | import Internal.Vector | 20 | import Internal.Vector |
21 | import Internal.Matrix | 21 | import Internal.Matrix |
22 | import Internal.Vectorized | 22 | import Internal.Vectorized |
23 | import Text.Printf(printf) | 23 | import Text.Printf(printf, PrintfArg, PrintfType) |
24 | import Data.List(intersperse,transpose) | 24 | import Data.List(intersperse,transpose) |
25 | import Data.Complex | 25 | import Data.Complex |
26 | 26 | ||
@@ -78,12 +78,18 @@ disps d x = sdims x ++ " " ++ formatScaled d x | |||
78 | dispf :: Int -> Matrix Double -> String | 78 | dispf :: Int -> Matrix Double -> String |
79 | dispf d x = sdims x ++ "\n" ++ formatFixed (if isInt x then 0 else d) x | 79 | dispf d x = sdims x ++ "\n" ++ formatFixed (if isInt x then 0 else d) x |
80 | 80 | ||
81 | sdims :: Matrix t -> [Char] | ||
81 | sdims x = show (rows x) ++ "x" ++ show (cols x) | 82 | sdims x = show (rows x) ++ "x" ++ show (cols x) |
82 | 83 | ||
84 | formatFixed :: (Show a, Text.Printf.PrintfArg t, Element t) | ||
85 | => a -> Matrix t -> String | ||
83 | formatFixed d x = format " " (printf ("%."++show d++"f")) $ x | 86 | formatFixed d x = format " " (printf ("%."++show d++"f")) $ x |
84 | 87 | ||
88 | isInt :: Matrix Double -> Bool | ||
85 | isInt = all lookslikeInt . toList . flatten | 89 | isInt = all lookslikeInt . toList . flatten |
86 | 90 | ||
91 | formatScaled :: (Text.Printf.PrintfArg b, RealFrac b, Floating b, Num t, Element b, Show t) | ||
92 | => t -> Matrix b -> [Char] | ||
87 | formatScaled dec t = "E"++show o++"\n" ++ ss | 93 | formatScaled dec t = "E"++show o++"\n" ++ ss |
88 | where ss = format " " (printf fmt. g) t | 94 | where ss = format " " (printf fmt. g) t |
89 | g x | o >= 0 = x/10^(o::Int) | 95 | g x | o >= 0 = x/10^(o::Int) |
@@ -133,14 +139,18 @@ showComplex d (a:+b) | |||
133 | s2 = if b<0 then "-" else "" | 139 | s2 = if b<0 then "-" else "" |
134 | s3 = if b<0 then "-" else "+" | 140 | s3 = if b<0 then "-" else "+" |
135 | 141 | ||
142 | shcr :: (Show a, Show t1, Text.Printf.PrintfType t, Text.Printf.PrintfArg t1, RealFrac t1) | ||
143 | => a -> t1 -> t | ||
136 | shcr d a | lookslikeInt a = printf "%.0f" a | 144 | shcr d a | lookslikeInt a = printf "%.0f" a |
137 | | otherwise = printf ("%."++show d++"f") a | 145 | | otherwise = printf ("%."++show d++"f") a |
138 | 146 | ||
139 | 147 | lookslikeInt :: (Show a, RealFrac a) => a -> Bool | |
140 | lookslikeInt x = show (round x :: Int) ++".0" == shx || "-0.0" == shx | 148 | lookslikeInt x = show (round x :: Int) ++".0" == shx || "-0.0" == shx |
141 | where shx = show x | 149 | where shx = show x |
142 | 150 | ||
151 | isZero :: Show a => a -> Bool | ||
143 | isZero x = show x `elem` ["0.0","-0.0"] | 152 | isZero x = show x `elem` ["0.0","-0.0"] |
153 | isOne :: Show a => a -> Bool | ||
144 | isOne x = show x `elem` ["1.0","-1.0"] | 154 | isOne x = show x `elem` ["1.0","-1.0"] |
145 | 155 | ||
146 | -- | Pretty print a complex matrix with at most n decimal digits. | 156 | -- | Pretty print a complex matrix with at most n decimal digits. |
@@ -168,6 +178,6 @@ loadMatrix f = do | |||
168 | else | 178 | else |
169 | return (reshape c v) | 179 | return (reshape c v) |
170 | 180 | ||
171 | 181 | loadMatrix' :: FilePath -> IO (Maybe (Matrix Double)) | |
172 | loadMatrix' name = mbCatch (loadMatrix name) | 182 | loadMatrix' name = mbCatch (loadMatrix name) |
173 | 183 | ||
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index e306454..64cf2f5 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -1,6 +1,8 @@ | |||
1 | {-# LANGUAGE TypeOperators #-} | 1 | {-# LANGUAGE TypeOperators #-} |
2 | {-# LANGUAGE ViewPatterns #-} | 2 | {-# LANGUAGE ViewPatterns #-} |
3 | 3 | ||
4 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
5 | |||
4 | ----------------------------------------------------------------------------- | 6 | ----------------------------------------------------------------------------- |
5 | -- | | 7 | -- | |
6 | -- Module : Numeric.LinearAlgebra.LAPACK | 8 | -- Module : Numeric.LinearAlgebra.LAPACK |
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 4905f61..4bfa13d 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -57,19 +57,24 @@ cols :: Matrix t -> Int | |||
57 | cols = icols | 57 | cols = icols |
58 | {-# INLINE cols #-} | 58 | {-# INLINE cols #-} |
59 | 59 | ||
60 | size :: Matrix t -> (Int, Int) | ||
60 | size m = (irows m, icols m) | 61 | size m = (irows m, icols m) |
61 | {-# INLINE size #-} | 62 | {-# INLINE size #-} |
62 | 63 | ||
64 | rowOrder :: Matrix t -> Bool | ||
63 | rowOrder m = xCol m == 1 || cols m == 1 | 65 | rowOrder m = xCol m == 1 || cols m == 1 |
64 | {-# INLINE rowOrder #-} | 66 | {-# INLINE rowOrder #-} |
65 | 67 | ||
68 | colOrder :: Matrix t -> Bool | ||
66 | colOrder m = xRow m == 1 || rows m == 1 | 69 | colOrder m = xRow m == 1 || rows m == 1 |
67 | {-# INLINE colOrder #-} | 70 | {-# INLINE colOrder #-} |
68 | 71 | ||
72 | is1d :: Matrix t -> Bool | ||
69 | is1d (size->(r,c)) = r==1 || c==1 | 73 | is1d (size->(r,c)) = r==1 || c==1 |
70 | {-# INLINE is1d #-} | 74 | {-# INLINE is1d #-} |
71 | 75 | ||
72 | -- data is not contiguous | 76 | -- data is not contiguous |
77 | isSlice :: Storable t => Matrix t -> Bool | ||
73 | isSlice m@(size->(r,c)) = r*c < dim (xdat m) | 78 | isSlice m@(size->(r,c)) = r*c < dim (xdat m) |
74 | {-# INLINE isSlice #-} | 79 | {-# INLINE isSlice #-} |
75 | 80 | ||
@@ -136,16 +141,20 @@ instance Storable t => TransArray (Matrix t) | |||
136 | {-# INLINE applyRaw #-} | 141 | {-# INLINE applyRaw #-} |
137 | 142 | ||
138 | infixr 1 # | 143 | infixr 1 # |
144 | (#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r | ||
139 | a # b = apply a b | 145 | a # b = apply a b |
140 | {-# INLINE (#) #-} | 146 | {-# INLINE (#) #-} |
141 | 147 | ||
148 | (#!) :: (TransArray c, TransArray c1) => c1 -> c -> Trans c1 (Trans c (IO r)) -> IO r | ||
142 | a #! b = a # b # id | 149 | a #! b = a # b # id |
143 | {-# INLINE (#!) #-} | 150 | {-# INLINE (#!) #-} |
144 | 151 | ||
145 | -------------------------------------------------------------------------------- | 152 | -------------------------------------------------------------------------------- |
146 | 153 | ||
154 | copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t) | ||
147 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) | 155 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) |
148 | 156 | ||
157 | extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t | ||
149 | extractAll ord m = unsafePerformIO (copy ord m) | 158 | extractAll ord m = unsafePerformIO (copy ord m) |
150 | 159 | ||
151 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. | 160 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. |
@@ -223,11 +232,13 @@ m@Matrix {irows = r, icols = c} @@> (i,j) | |||
223 | {-# INLINE (@@>) #-} | 232 | {-# INLINE (@@>) #-} |
224 | 233 | ||
225 | -- Unsafe matrix access without range checking | 234 | -- Unsafe matrix access without range checking |
235 | atM' :: Storable t => Matrix t -> Int -> Int -> t | ||
226 | atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) | 236 | atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) |
227 | {-# INLINE atM' #-} | 237 | {-# INLINE atM' #-} |
228 | 238 | ||
229 | ------------------------------------------------------------------ | 239 | ------------------------------------------------------------------ |
230 | 240 | ||
241 | matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t | ||
231 | matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } | 242 | matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } |
232 | matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } | 243 | matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } |
233 | matrixFromVector o r c v | 244 | matrixFromVector o r c v |
@@ -387,18 +398,21 @@ subMatrix (r0,c0) (rt,ct) m | |||
387 | 398 | ||
388 | -------------------------------------------------------------------------- | 399 | -------------------------------------------------------------------------- |
389 | 400 | ||
401 | maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1 | ||
390 | maxZ xs = if minimum xs == 0 then 0 else maximum xs | 402 | maxZ xs = if minimum xs == 0 then 0 else maximum xs |
391 | 403 | ||
404 | conformMs :: Element t => [Matrix t] -> [Matrix t] | ||
392 | conformMs ms = map (conformMTo (r,c)) ms | 405 | conformMs ms = map (conformMTo (r,c)) ms |
393 | where | 406 | where |
394 | r = maxZ (map rows ms) | 407 | r = maxZ (map rows ms) |
395 | c = maxZ (map cols ms) | 408 | c = maxZ (map cols ms) |
396 | 409 | ||
397 | 410 | conformVs :: Element t => [Vector t] -> [Vector t] | |
398 | conformVs vs = map (conformVTo n) vs | 411 | conformVs vs = map (conformVTo n) vs |
399 | where | 412 | where |
400 | n = maxZ (map dim vs) | 413 | n = maxZ (map dim vs) |
401 | 414 | ||
415 | conformMTo :: Element t => (Int, Int) -> Matrix t -> Matrix t | ||
402 | conformMTo (r,c) m | 416 | conformMTo (r,c) m |
403 | | size m == (r,c) = m | 417 | | size m == (r,c) = m |
404 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) | 418 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) |
@@ -406,18 +420,24 @@ conformMTo (r,c) m | |||
406 | | size m == (1,c) = repRows r m | 420 | | size m == (1,c) = repRows r m |
407 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) | 421 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) |
408 | 422 | ||
423 | conformVTo :: Element t => Int -> Vector t -> Vector t | ||
409 | conformVTo n v | 424 | conformVTo n v |
410 | | dim v == n = v | 425 | | dim v == n = v |
411 | | dim v == 1 = constantD (v@>0) n | 426 | | dim v == 1 = constantD (v@>0) n |
412 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n | 427 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n |
413 | 428 | ||
429 | repRows :: Element t => Int -> Matrix t -> Matrix t | ||
414 | repRows n x = fromRows (replicate n (flatten x)) | 430 | repRows n x = fromRows (replicate n (flatten x)) |
431 | repCols :: Element t => Int -> Matrix t -> Matrix t | ||
415 | repCols n x = fromColumns (replicate n (flatten x)) | 432 | repCols n x = fromColumns (replicate n (flatten x)) |
416 | 433 | ||
434 | shSize :: Matrix t -> [Char] | ||
417 | shSize = shDim . size | 435 | shSize = shDim . size |
418 | 436 | ||
437 | shDim :: (Show a, Show a1) => (a1, a) -> [Char] | ||
419 | shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" | 438 | shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" |
420 | 439 | ||
440 | emptyM :: Storable t => Int -> Int -> Matrix t | ||
421 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) | 441 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) |
422 | 442 | ||
423 | ---------------------------------------------------------------------- | 443 | ---------------------------------------------------------------------- |
@@ -432,6 +452,11 @@ instance (Storable t, NFData t) => NFData (Matrix t) | |||
432 | 452 | ||
433 | --------------------------------------------------------------- | 453 | --------------------------------------------------------------- |
434 | 454 | ||
455 | extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, | ||
456 | Storable t, Num t3, Num t2, Integral t1, Integral t) | ||
457 | => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t | ||
458 | -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) | ||
459 | -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) | ||
435 | extractAux f ord m moder vr modec vc = do | 460 | extractAux f ord m moder vr modec vc = do |
436 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | 461 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr |
437 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | 462 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc |
@@ -451,6 +476,9 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z | |||
451 | 476 | ||
452 | --------------------------------------------------------------- | 477 | --------------------------------------------------------------- |
453 | 478 | ||
479 | setRectAux :: (TransArray c1, TransArray c) | ||
480 | => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) | ||
481 | -> Int -> Int -> c1 -> c -> IO () | ||
454 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" | 482 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" |
455 | 483 | ||
456 | type SetRect x = I -> I -> x ::> x::> Ok | 484 | type SetRect x = I -> I -> x ::> x::> Ok |
@@ -464,19 +492,29 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z | |||
464 | 492 | ||
465 | -------------------------------------------------------------------------------- | 493 | -------------------------------------------------------------------------------- |
466 | 494 | ||
495 | sortG :: (Storable t, Storable a) | ||
496 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
467 | sortG f v = unsafePerformIO $ do | 497 | sortG f v = unsafePerformIO $ do |
468 | r <- createVector (dim v) | 498 | r <- createVector (dim v) |
469 | (v #! r) f #|"sortG" | 499 | (v #! r) f #|"sortG" |
470 | return r | 500 | return r |
471 | 501 | ||
502 | sortIdxD :: Vector Double -> Vector CInt | ||
472 | sortIdxD = sortG c_sort_indexD | 503 | sortIdxD = sortG c_sort_indexD |
504 | sortIdxF :: Vector Float -> Vector CInt | ||
473 | sortIdxF = sortG c_sort_indexF | 505 | sortIdxF = sortG c_sort_indexF |
506 | sortIdxI :: Vector CInt -> Vector CInt | ||
474 | sortIdxI = sortG c_sort_indexI | 507 | sortIdxI = sortG c_sort_indexI |
508 | sortIdxL :: Vector Z -> Vector I | ||
475 | sortIdxL = sortG c_sort_indexL | 509 | sortIdxL = sortG c_sort_indexL |
476 | 510 | ||
511 | sortValD :: Vector Double -> Vector Double | ||
477 | sortValD = sortG c_sort_valD | 512 | sortValD = sortG c_sort_valD |
513 | sortValF :: Vector Float -> Vector Float | ||
478 | sortValF = sortG c_sort_valF | 514 | sortValF = sortG c_sort_valF |
515 | sortValI :: Vector CInt -> Vector CInt | ||
479 | sortValI = sortG c_sort_valI | 516 | sortValI = sortG c_sort_valI |
517 | sortValL :: Vector Z -> Vector Z | ||
480 | sortValL = sortG c_sort_valL | 518 | sortValL = sortG c_sort_valL |
481 | 519 | ||
482 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) | 520 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) |
@@ -491,14 +529,21 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok | |||
491 | 529 | ||
492 | -------------------------------------------------------------------------------- | 530 | -------------------------------------------------------------------------------- |
493 | 531 | ||
532 | compareG :: (TransArray c, Storable t, Storable a) | ||
533 | => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) | ||
534 | -> c -> Vector t -> Vector a | ||
494 | compareG f u v = unsafePerformIO $ do | 535 | compareG f u v = unsafePerformIO $ do |
495 | r <- createVector (dim v) | 536 | r <- createVector (dim v) |
496 | (u # v #! r) f #|"compareG" | 537 | (u # v #! r) f #|"compareG" |
497 | return r | 538 | return r |
498 | 539 | ||
540 | compareD :: Vector Double -> Vector Double -> Vector CInt | ||
499 | compareD = compareG c_compareD | 541 | compareD = compareG c_compareD |
542 | compareF :: Vector Float -> Vector Float -> Vector CInt | ||
500 | compareF = compareG c_compareF | 543 | compareF = compareG c_compareF |
544 | compareI :: Vector CInt -> Vector CInt -> Vector CInt | ||
501 | compareI = compareG c_compareI | 545 | compareI = compareG c_compareI |
546 | compareL :: Vector Z -> Vector Z -> Vector CInt | ||
502 | compareL = compareG c_compareL | 547 | compareL = compareG c_compareL |
503 | 548 | ||
504 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) | 549 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) |
@@ -508,16 +553,33 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok | |||
508 | 553 | ||
509 | -------------------------------------------------------------------------------- | 554 | -------------------------------------------------------------------------------- |
510 | 555 | ||
556 | selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) | ||
557 | => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) | ||
558 | -> c2 -> c1 -> Vector t -> c -> Vector a | ||
511 | selectG f c u v w = unsafePerformIO $ do | 559 | selectG f c u v w = unsafePerformIO $ do |
512 | r <- createVector (dim v) | 560 | r <- createVector (dim v) |
513 | (c # u # v # w #! r) f #|"selectG" | 561 | (c # u # v # w #! r) f #|"selectG" |
514 | return r | 562 | return r |
515 | 563 | ||
564 | selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double | ||
516 | selectD = selectG c_selectD | 565 | selectD = selectG c_selectD |
566 | selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float | ||
517 | selectF = selectG c_selectF | 567 | selectF = selectG c_selectF |
568 | selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt | ||
518 | selectI = selectG c_selectI | 569 | selectI = selectG c_selectI |
570 | selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z | ||
519 | selectL = selectG c_selectL | 571 | selectL = selectG c_selectL |
572 | selectC :: Vector CInt | ||
573 | -> Vector (Complex Double) | ||
574 | -> Vector (Complex Double) | ||
575 | -> Vector (Complex Double) | ||
576 | -> Vector (Complex Double) | ||
520 | selectC = selectG c_selectC | 577 | selectC = selectG c_selectC |
578 | selectQ :: Vector CInt | ||
579 | -> Vector (Complex Float) | ||
580 | -> Vector (Complex Float) | ||
581 | -> Vector (Complex Float) | ||
582 | -> Vector (Complex Float) | ||
521 | selectQ = selectG c_selectQ | 583 | selectQ = selectG c_selectQ |
522 | 584 | ||
523 | type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) | 585 | type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) |
@@ -531,16 +593,29 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z | |||
531 | 593 | ||
532 | --------------------------------------------------------------------------- | 594 | --------------------------------------------------------------------------- |
533 | 595 | ||
596 | remapG :: (TransArray c, TransArray c1, Storable t, Storable a) | ||
597 | => (CInt -> CInt -> CInt -> CInt -> Ptr t | ||
598 | -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) | ||
599 | -> Matrix t -> c1 -> c -> Matrix a | ||
534 | remapG f i j m = unsafePerformIO $ do | 600 | remapG f i j m = unsafePerformIO $ do |
535 | r <- createMatrix RowMajor (rows i) (cols i) | 601 | r <- createMatrix RowMajor (rows i) (cols i) |
536 | (i # j # m #! r) f #|"remapG" | 602 | (i # j # m #! r) f #|"remapG" |
537 | return r | 603 | return r |
538 | 604 | ||
605 | remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double | ||
539 | remapD = remapG c_remapD | 606 | remapD = remapG c_remapD |
607 | remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float | ||
540 | remapF = remapG c_remapF | 608 | remapF = remapG c_remapF |
609 | remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt | ||
541 | remapI = remapG c_remapI | 610 | remapI = remapG c_remapI |
611 | remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z | ||
542 | remapL = remapG c_remapL | 612 | remapL = remapG c_remapL |
613 | remapC :: Matrix CInt | ||
614 | -> Matrix CInt | ||
615 | -> Matrix (Complex Double) | ||
616 | -> Matrix (Complex Double) | ||
543 | remapC = remapG c_remapC | 617 | remapC = remapG c_remapC |
618 | remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) | ||
544 | remapQ = remapG c_remapQ | 619 | remapQ = remapG c_remapQ |
545 | 620 | ||
546 | type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) | 621 | type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) |
@@ -554,6 +629,9 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z | |||
554 | 629 | ||
555 | -------------------------------------------------------------------------------- | 630 | -------------------------------------------------------------------------------- |
556 | 631 | ||
632 | rowOpAux :: (TransArray c, Storable a) => | ||
633 | (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) | ||
634 | -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () | ||
557 | rowOpAux f c x i1 i2 j1 j2 m = do | 635 | rowOpAux f c x i1 i2 j1 j2 m = do |
558 | px <- newArray [x] | 636 | px <- newArray [x] |
559 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" | 637 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" |
@@ -572,6 +650,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | |||
572 | 650 | ||
573 | -------------------------------------------------------------------------------- | 651 | -------------------------------------------------------------------------------- |
574 | 652 | ||
653 | gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) | ||
654 | => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt)))) | ||
655 | -> c3 -> c2 -> c1 -> c -> IO () | ||
575 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" | 656 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" |
576 | 657 | ||
577 | type Tgemm x = x :> x ::> x ::> x ::> Ok | 658 | type Tgemm x = x :> x ::> x ::> x ::> Ok |
@@ -587,6 +668,10 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z | |||
587 | 668 | ||
588 | -------------------------------------------------------------------------------- | 669 | -------------------------------------------------------------------------------- |
589 | 670 | ||
671 | reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => | ||
672 | (CInt -> Ptr a -> CInt -> Ptr t1 | ||
673 | -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)) | ||
674 | -> Vector t1 -> c -> Vector t -> Vector a1 | ||
590 | reorderAux f s d v = unsafePerformIO $ do | 675 | reorderAux f s d v = unsafePerformIO $ do |
591 | k <- createVector (dim s) | 676 | k <- createVector (dim s) |
592 | r <- createVector (dim v) | 677 | r <- createVector (dim v) |
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 9d51444..eb0c5a8 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -13,6 +13,9 @@ | |||
13 | {-# LANGUAGE TypeFamilies #-} | 13 | {-# LANGUAGE TypeFamilies #-} |
14 | {-# LANGUAGE TypeOperators #-} | 14 | {-# LANGUAGE TypeOperators #-} |
15 | 15 | ||
16 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
17 | {-# OPTIONS_GHC -fno-warn-missing-methods #-} | ||
18 | |||
16 | {- | | 19 | {- | |
17 | Module : Internal.Modular | 20 | Module : Internal.Modular |
18 | Copyright : (c) Alberto Ruiz 2015 | 21 | Copyright : (c) Alberto Ruiz 2015 |
diff --git a/packages/base/src/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs index c9ef0c5..216f142 100644 --- a/packages/base/src/Internal/Numeric.hs +++ b/packages/base/src/Internal/Numeric.hs | |||
@@ -5,6 +5,8 @@ | |||
5 | {-# LANGUAGE FunctionalDependencies #-} | 5 | {-# LANGUAGE FunctionalDependencies #-} |
6 | {-# LANGUAGE UndecidableInstances #-} | 6 | {-# LANGUAGE UndecidableInstances #-} |
7 | 7 | ||
8 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
9 | |||
8 | ----------------------------------------------------------------------------- | 10 | ----------------------------------------------------------------------------- |
9 | -- | | 11 | -- | |
10 | -- Module : Data.Packed.Internal.Numeric | 12 | -- Module : Data.Packed.Internal.Numeric |
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 544c9e4..7d54e6d 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -81,6 +81,8 @@ unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t) | |||
81 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x | 81 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x |
82 | 82 | ||
83 | {-# INLINE safeIndexV #-} | 83 | {-# INLINE safeIndexV #-} |
84 | safeIndexV :: Storable t2 | ||
85 | => (STVector s t2 -> Int -> t) -> STVector t1 t2 -> Int -> t | ||
84 | safeIndexV f (STVector v) k | 86 | safeIndexV f (STVector v) k |
85 | | k < 0 || k>= dim v = error $ "out of range error in vector (dim=" | 87 | | k < 0 || k>= dim v = error $ "out of range error in vector (dim=" |
86 | ++show (dim v)++", pos="++show k++")" | 88 | ++show (dim v)++", pos="++show k++")" |
@@ -150,9 +152,12 @@ unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | |||
150 | freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) | 152 | freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) |
151 | freezeMatrix m = liftSTMatrix id m | 153 | freezeMatrix m = liftSTMatrix id m |
152 | 154 | ||
155 | cloneMatrix :: Element t => Matrix t -> IO (Matrix t) | ||
153 | cloneMatrix m = copy (orderOf m) m | 156 | cloneMatrix m = copy (orderOf m) m |
154 | 157 | ||
155 | {-# INLINE safeIndexM #-} | 158 | {-# INLINE safeIndexM #-} |
159 | safeIndexM :: (STMatrix s t2 -> Int -> Int -> t) | ||
160 | -> STMatrix t1 t2 -> Int -> Int -> t | ||
156 | safeIndexM f (STMatrix m) r c | 161 | safeIndexM f (STMatrix m) r c |
157 | | r<0 || r>=rows m || | 162 | | r<0 || r>=rows m || |
158 | c<0 || c>=cols m = error $ "out of range error in matrix (size=" | 163 | c<0 || c>=cols m = error $ "out of range error in matrix (size=" |
@@ -184,6 +189,7 @@ data ColRange = AllCols | |||
184 | | Col Int | 189 | | Col Int |
185 | | FromCol Int | 190 | | FromCol Int |
186 | 191 | ||
192 | getColRange :: Int -> ColRange -> (Int, Int) | ||
187 | getColRange c AllCols = (0,c-1) | 193 | getColRange c AllCols = (0,c-1) |
188 | getColRange c (ColRange a b) = (a `mod` c, b `mod` c) | 194 | getColRange c (ColRange a b) = (a `mod` c, b `mod` c) |
189 | getColRange c (Col a) = (a `mod` c, a `mod` c) | 195 | getColRange c (Col a) = (a `mod` c, a `mod` c) |
@@ -194,6 +200,7 @@ data RowRange = AllRows | |||
194 | | Row Int | 200 | | Row Int |
195 | | FromRow Int | 201 | | FromRow Int |
196 | 202 | ||
203 | getRowRange :: Int -> RowRange -> (Int, Int) | ||
197 | getRowRange r AllRows = (0,r-1) | 204 | getRowRange r AllRows = (0,r-1) |
198 | getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) | 205 | getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) |
199 | getRowRange r (Row a) = (a `mod` r, a `mod` r) | 206 | getRowRange r (Row a) = (a `mod` r, a `mod` r) |
@@ -223,6 +230,7 @@ rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m | |||
223 | i2' = i2 `mod` (rows m) | 230 | i2' = i2 `mod` (rows m) |
224 | 231 | ||
225 | 232 | ||
233 | extractMatrix :: Element a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a) | ||
226 | extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) | 234 | extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) |
227 | where | 235 | where |
228 | (i1,i2) = getRowRange (rows m) rr | 236 | (i1,i2) = getRowRange (rows m) rr |
@@ -231,6 +239,7 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[ | |||
231 | -- | r0 c0 height width | 239 | -- | r0 c0 height width |
232 | data Slice s t = Slice (STMatrix s t) Int Int Int Int | 240 | data Slice s t = Slice (STMatrix s t) Int Int Int Int |
233 | 241 | ||
242 | slice :: Element a => Slice t a -> Matrix a | ||
234 | slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m | 243 | slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m |
235 | 244 | ||
236 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () | 245 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () |
@@ -238,7 +247,7 @@ gemmm beta (slice->r) alpha (slice->a) (slice->b) = res | |||
238 | where | 247 | where |
239 | res = unsafeIOToST (gemm v a b r) | 248 | res = unsafeIOToST (gemm v a b r) |
240 | v = fromList [alpha,beta] | 249 | v = fromList [alpha,beta] |
241 | 250 | ||
242 | 251 | ||
243 | mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) | 252 | mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) |
244 | mutable f a = runST $ do | 253 | mutable f a = runST $ do |
@@ -246,4 +255,3 @@ mutable f a = runST $ do | |||
246 | info <- f (rows a, cols a) x | 255 | info <- f (rows a, cols a) x |
247 | r <- unsafeFreezeMatrix x | 256 | r <- unsafeFreezeMatrix x |
248 | return (r,info) | 257 | return (r,info) |
249 | |||
diff --git a/packages/base/src/Internal/Sparse.hs b/packages/base/src/Internal/Sparse.hs index 1ff3f57..6233b03 100644 --- a/packages/base/src/Internal/Sparse.hs +++ b/packages/base/src/Internal/Sparse.hs | |||
@@ -2,6 +2,8 @@ | |||
2 | {-# LANGUAGE MultiParamTypeClasses #-} | 2 | {-# LANGUAGE MultiParamTypeClasses #-} |
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | 4 | ||
5 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
6 | |||
5 | module Internal.Sparse( | 7 | module Internal.Sparse( |
6 | GMatrix(..), CSR(..), mkCSR, fromCSR, | 8 | GMatrix(..), CSR(..), mkCSR, fromCSR, |
7 | mkSparse, mkDiagR, mkDense, | 9 | mkSparse, mkDiagR, mkDense, |
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs index f9dfff0..6ef1350 100644 --- a/packages/base/src/Internal/Static.hs +++ b/packages/base/src/Internal/Static.hs | |||
@@ -15,6 +15,8 @@ | |||
15 | {-# LANGUAGE BangPatterns #-} | 15 | {-# LANGUAGE BangPatterns #-} |
16 | {-# LANGUAGE DeriveGeneric #-} | 16 | {-# LANGUAGE DeriveGeneric #-} |
17 | 17 | ||
18 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
19 | |||
18 | {- | | 20 | {- | |
19 | Module : Internal.Static | 21 | Module : Internal.Static |
20 | Copyright : (c) Alberto Ruiz 2006-14 | 22 | Copyright : (c) Alberto Ruiz 2006-14 |
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 8c8a31e..def7cc3 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -6,6 +6,8 @@ | |||
6 | {-# LANGUAGE ScopedTypeVariables #-} | 6 | {-# LANGUAGE ScopedTypeVariables #-} |
7 | {-# LANGUAGE ViewPatterns #-} | 7 | {-# LANGUAGE ViewPatterns #-} |
8 | 8 | ||
9 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
10 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
9 | 11 | ||
10 | ----------------------------------------------------------------------------- | 12 | ----------------------------------------------------------------------------- |
11 | {- | | 13 | {- | |
diff --git a/packages/base/src/Internal/Vector.hs b/packages/base/src/Internal/Vector.hs index 67d0416..dedb822 100644 --- a/packages/base/src/Internal/Vector.hs +++ b/packages/base/src/Internal/Vector.hs | |||
@@ -1,6 +1,7 @@ | |||
1 | {-# LANGUAGE MagicHash, UnboxedTuples, BangPatterns, FlexibleContexts #-} | 1 | {-# LANGUAGE MagicHash, UnboxedTuples, BangPatterns, FlexibleContexts #-} |
2 | {-# LANGUAGE TypeSynonymInstances #-} | 2 | {-# LANGUAGE TypeSynonymInstances #-} |
3 | 3 | ||
4 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
4 | 5 | ||
5 | -- | | 6 | -- | |
6 | -- Module : Internal.Vector | 7 | -- Module : Internal.Vector |
@@ -40,6 +41,7 @@ import qualified Data.Vector.Storable as Vector | |||
40 | import Data.Vector.Storable(Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith) | 41 | import Data.Vector.Storable(Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith) |
41 | 42 | ||
42 | import Data.Binary | 43 | import Data.Binary |
44 | import Data.Binary.Put | ||
43 | import Control.Monad(replicateM) | 45 | import Control.Monad(replicateM) |
44 | import qualified Data.ByteString.Internal as BS | 46 | import qualified Data.ByteString.Internal as BS |
45 | import Data.Vector.Storable.Internal(updPtr) | 47 | import Data.Vector.Storable.Internal(updPtr) |
@@ -92,6 +94,7 @@ createVector n = do | |||
92 | 94 | ||
93 | -} | 95 | -} |
94 | 96 | ||
97 | safeRead :: Storable a => Vector a -> (Ptr a -> IO c) -> c | ||
95 | safeRead v = inlinePerformIO . unsafeWith v | 98 | safeRead v = inlinePerformIO . unsafeWith v |
96 | {-# INLINE safeRead #-} | 99 | {-# INLINE safeRead #-} |
97 | 100 | ||
@@ -283,11 +286,13 @@ foldVectorWithIndex f x v = unsafePerformIO $ | |||
283 | go (dim v -1) x | 286 | go (dim v -1) x |
284 | {-# INLINE foldVectorWithIndex #-} | 287 | {-# INLINE foldVectorWithIndex #-} |
285 | 288 | ||
289 | foldLoop :: (Int -> t -> t) -> t -> Int -> t | ||
286 | foldLoop f s0 d = go (d - 1) s0 | 290 | foldLoop f s0 d = go (d - 1) s0 |
287 | where | 291 | where |
288 | go 0 s = f (0::Int) s | 292 | go 0 s = f (0::Int) s |
289 | go !j !s = go (j - 1) (f j s) | 293 | go !j !s = go (j - 1) (f j s) |
290 | 294 | ||
295 | foldVectorG :: Storable t1 => (Int -> (Int -> t1) -> t -> t) -> t -> Vector t1 -> t | ||
291 | foldVectorG f s0 v = foldLoop g s0 (dim v) | 296 | foldVectorG f s0 v = foldLoop g s0 (dim v) |
292 | where g !k !s = f k (safeRead v . flip peekElemOff) s | 297 | where g !k !s = f k (safeRead v . flip peekElemOff) s |
293 | {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479) | 298 | {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479) |
@@ -390,8 +395,10 @@ chunks d = let c = d `div` chunk | |||
390 | m = d `mod` chunk | 395 | m = d `mod` chunk |
391 | in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk) | 396 | in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk) |
392 | 397 | ||
398 | putVector :: (Storable t, Binary t) => Vector t -> Data.Binary.Put.PutM () | ||
393 | putVector v = mapM_ put $! toList v | 399 | putVector v = mapM_ put $! toList v |
394 | 400 | ||
401 | getVector :: (Storable a, Binary a) => Int -> Get (Vector a) | ||
395 | getVector d = do | 402 | getVector d = do |
396 | xs <- replicateM d get | 403 | xs <- replicateM d get |
397 | return $! fromList xs | 404 | return $! fromList xs |
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs index a410bb2..c00c324 100644 --- a/packages/base/src/Internal/Vectorized.hs +++ b/packages/base/src/Internal/Vectorized.hs | |||
@@ -28,12 +28,15 @@ import System.IO.Unsafe(unsafePerformIO) | |||
28 | import Control.Monad(when) | 28 | import Control.Monad(when) |
29 | 29 | ||
30 | infixr 1 # | 30 | infixr 1 # |
31 | (#) :: TransArray c => c -> (b -> IO r) -> TransRaw c b -> IO r | ||
31 | a # b = applyRaw a b | 32 | a # b = applyRaw a b |
32 | {-# INLINE (#) #-} | 33 | {-# INLINE (#) #-} |
33 | 34 | ||
35 | (#!) :: (TransArray c, TransArray c1) => c1 -> c -> TransRaw c1 (TransRaw c (IO r)) -> IO r | ||
34 | a #! b = a # b # id | 36 | a #! b = a # b # id |
35 | {-# INLINE (#!) #-} | 37 | {-# INLINE (#!) #-} |
36 | 38 | ||
39 | fromei :: Enum a => a -> CInt | ||
37 | fromei x = fromIntegral (fromEnum x) :: CInt | 40 | fromei x = fromIntegral (fromEnum x) :: CInt |
38 | 41 | ||
39 | data FunCodeV = Sin | 42 | data FunCodeV = Sin |
@@ -100,10 +103,20 @@ sumQ = sumg c_sumQ | |||
100 | sumC :: Vector (Complex Double) -> Complex Double | 103 | sumC :: Vector (Complex Double) -> Complex Double |
101 | sumC = sumg c_sumC | 104 | sumC = sumg c_sumC |
102 | 105 | ||
106 | sumI :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr I -> I :> Ok) | ||
107 | , TransArray c | ||
108 | , Storable a | ||
109 | ) | ||
110 | => I -> c -> a | ||
103 | sumI m = sumg (c_sumI m) | 111 | sumI m = sumg (c_sumI m) |
104 | 112 | ||
113 | sumL :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr Z -> Z :> Ok) | ||
114 | , TransArray c | ||
115 | , Storable a | ||
116 | ) => Z -> c -> a | ||
105 | sumL m = sumg (c_sumL m) | 117 | sumL m = sumg (c_sumL m) |
106 | 118 | ||
119 | sumg :: (TransArray c, Storable a) => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a | ||
107 | sumg f x = unsafePerformIO $ do | 120 | sumg f x = unsafePerformIO $ do |
108 | r <- createVector 1 | 121 | r <- createVector 1 |
109 | (x #! r) f #| "sum" | 122 | (x #! r) f #| "sum" |
@@ -140,6 +153,8 @@ prodI = prodg . c_prodI | |||
140 | prodL :: Z-> Vector Z -> Z | 153 | prodL :: Z-> Vector Z -> Z |
141 | prodL = prodg . c_prodL | 154 | prodL = prodg . c_prodL |
142 | 155 | ||
156 | prodg :: (TransArray c, Storable a) | ||
157 | => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a | ||
143 | prodg f x = unsafePerformIO $ do | 158 | prodg f x = unsafePerformIO $ do |
144 | r <- createVector 1 | 159 | r <- createVector 1 |
145 | (x #! r) f #| "prod" | 160 | (x #! r) f #| "prod" |
@@ -155,16 +170,25 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z | |||
155 | 170 | ||
156 | ------------------------------------------------------------------ | 171 | ------------------------------------------------------------------ |
157 | 172 | ||
173 | toScalarAux :: (Enum a, TransArray c, Storable a1) | ||
174 | => (CInt -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) -> a -> c -> a1 | ||
158 | toScalarAux fun code v = unsafePerformIO $ do | 175 | toScalarAux fun code v = unsafePerformIO $ do |
159 | r <- createVector 1 | 176 | r <- createVector 1 |
160 | (v #! r) (fun (fromei code)) #|"toScalarAux" | 177 | (v #! r) (fun (fromei code)) #|"toScalarAux" |
161 | return (r @> 0) | 178 | return (r @> 0) |
162 | 179 | ||
180 | |||
181 | vectorMapAux :: (Enum a, Storable t, Storable a1) | ||
182 | => (CInt -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) | ||
183 | -> a -> Vector t -> Vector a1 | ||
163 | vectorMapAux fun code v = unsafePerformIO $ do | 184 | vectorMapAux fun code v = unsafePerformIO $ do |
164 | r <- createVector (dim v) | 185 | r <- createVector (dim v) |
165 | (v #! r) (fun (fromei code)) #|"vectorMapAux" | 186 | (v #! r) (fun (fromei code)) #|"vectorMapAux" |
166 | return r | 187 | return r |
167 | 188 | ||
189 | vectorMapValAux :: (Enum a, Storable a2, Storable t, Storable a1) | ||
190 | => (CInt -> Ptr a2 -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) | ||
191 | -> a -> a2 -> Vector t -> Vector a1 | ||
168 | vectorMapValAux fun code val v = unsafePerformIO $ do | 192 | vectorMapValAux fun code val v = unsafePerformIO $ do |
169 | r <- createVector (dim v) | 193 | r <- createVector (dim v) |
170 | pval <- newArray [val] | 194 | pval <- newArray [val] |
@@ -172,6 +196,9 @@ vectorMapValAux fun code val v = unsafePerformIO $ do | |||
172 | free pval | 196 | free pval |
173 | return r | 197 | return r |
174 | 198 | ||
199 | vectorZipAux :: (Enum a, TransArray c, Storable t, Storable a1) | ||
200 | => (CInt -> CInt -> Ptr t -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) | ||
201 | -> a -> Vector t -> c -> Vector a1 | ||
175 | vectorZipAux fun code u v = unsafePerformIO $ do | 202 | vectorZipAux fun code u v = unsafePerformIO $ do |
176 | r <- createVector (dim u) | 203 | r <- createVector (dim u) |
177 | (u # v #! r) (fun (fromei code)) #|"vectorZipAux" | 204 | (u # v #! r) (fun (fromei code)) #|"vectorZipAux" |
@@ -378,6 +405,7 @@ foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> D | |||
378 | 405 | ||
379 | -------------------------------------------------------------------------------- | 406 | -------------------------------------------------------------------------------- |
380 | 407 | ||
408 | roundVector :: Vector Double -> Vector Double | ||
381 | roundVector v = unsafePerformIO $ do | 409 | roundVector v = unsafePerformIO $ do |
382 | r <- createVector (dim v) | 410 | r <- createVector (dim v) |
383 | (v #! r) c_round_vector #|"roundVector" | 411 | (v #! r) c_round_vector #|"roundVector" |
@@ -432,6 +460,8 @@ long2intV :: Vector Z -> Vector I | |||
432 | long2intV = tog c_long2int | 460 | long2intV = tog c_long2int |
433 | 461 | ||
434 | 462 | ||
463 | tog :: (Storable t, Storable a) | ||
464 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
435 | tog f v = unsafePerformIO $ do | 465 | tog f v = unsafePerformIO $ do |
436 | r <- createVector (dim v) | 466 | r <- createVector (dim v) |
437 | (v #! r) f #|"tog" | 467 | (v #! r) f #|"tog" |
@@ -451,6 +481,8 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok | |||
451 | 481 | ||
452 | --------------------------------------------------------------- | 482 | --------------------------------------------------------------- |
453 | 483 | ||
484 | stepg :: (Storable t, Storable a) | ||
485 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
454 | stepg f v = unsafePerformIO $ do | 486 | stepg f v = unsafePerformIO $ do |
455 | r <- createVector (dim v) | 487 | r <- createVector (dim v) |
456 | (v #! r) f #|"step" | 488 | (v #! r) f #|"step" |
@@ -476,6 +508,8 @@ foreign import ccall unsafe "stepL" c_stepL :: TVV Z | |||
476 | 508 | ||
477 | -------------------------------------------------------------------------------- | 509 | -------------------------------------------------------------------------------- |
478 | 510 | ||
511 | conjugateAux :: (Storable t, Storable a) | ||
512 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
479 | conjugateAux fun x = unsafePerformIO $ do | 513 | conjugateAux fun x = unsafePerformIO $ do |
480 | v <- createVector (dim x) | 514 | v <- createVector (dim x) |
481 | (x #! v) fun #|"conjugateAux" | 515 | (x #! v) fun #|"conjugateAux" |
@@ -501,6 +535,8 @@ cloneVector v = do | |||
501 | 535 | ||
502 | -------------------------------------------------------------------------------- | 536 | -------------------------------------------------------------------------------- |
503 | 537 | ||
538 | constantAux :: (Storable a1, Storable a) | ||
539 | => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a | ||
504 | constantAux fun x n = unsafePerformIO $ do | 540 | constantAux fun x n = unsafePerformIO $ do |
505 | v <- createVector n | 541 | v <- createVector n |
506 | px <- newArray [x] | 542 | px <- newArray [x] |
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 73d4a13..970c77e 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs | |||
@@ -1,6 +1,8 @@ | |||
1 | {-# LANGUAGE CPP #-} | 1 | {-# LANGUAGE CPP #-} |
2 | {-# LANGUAGE FlexibleContexts #-} | 2 | {-# LANGUAGE FlexibleContexts #-} |
3 | 3 | ||
4 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
5 | |||
4 | ----------------------------------------------------------------------------- | 6 | ----------------------------------------------------------------------------- |
5 | {- | | 7 | {- | |
6 | Module : Numeric.LinearAlgebra | 8 | Module : Numeric.LinearAlgebra |
diff --git a/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs b/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs index 3a84645..57e5cf1 100644 --- a/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs +++ b/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs | |||
@@ -28,7 +28,9 @@ infixr 8 <·> | |||
28 | (<·>) :: Numeric t => Vector t -> Vector t -> t | 28 | (<·>) :: Numeric t => Vector t -> Vector t -> t |
29 | (<·>) = dot | 29 | (<·>) = dot |
30 | 30 | ||
31 | app :: Numeric t => Matrix t -> Vector t -> Vector t | ||
31 | app m v = m #> v | 32 | app m v = m #> v |
32 | 33 | ||
34 | mul :: Numeric t => Matrix t -> Matrix t -> Matrix t | ||
33 | mul a b = a <> b | 35 | mul a b = a <> b |
34 | 36 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index e328904..2e05c90 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -14,6 +14,8 @@ | |||
14 | {-# LANGUAGE GADTs #-} | 14 | {-# LANGUAGE GADTs #-} |
15 | {-# LANGUAGE TypeFamilies #-} | 15 | {-# LANGUAGE TypeFamilies #-} |
16 | 16 | ||
17 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
18 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
17 | 19 | ||
18 | {- | | 20 | {- | |
19 | Module : Numeric.LinearAlgebra.Static | 21 | Module : Numeric.LinearAlgebra.Static |
diff --git a/packages/base/src/Numeric/Matrix.hs b/packages/base/src/Numeric/Matrix.hs index 06da150..6e3db61 100644 --- a/packages/base/src/Numeric/Matrix.hs +++ b/packages/base/src/Numeric/Matrix.hs | |||
@@ -4,6 +4,8 @@ | |||
4 | {-# LANGUAGE UndecidableInstances #-} | 4 | {-# LANGUAGE UndecidableInstances #-} |
5 | {-# LANGUAGE MultiParamTypeClasses #-} | 5 | {-# LANGUAGE MultiParamTypeClasses #-} |
6 | 6 | ||
7 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
8 | |||
7 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
8 | -- | | 10 | -- | |
9 | -- Module : Numeric.Matrix | 11 | -- Module : Numeric.Matrix |
@@ -35,6 +37,7 @@ import Data.List(partition) | |||
35 | import qualified Data.Foldable as F | 37 | import qualified Data.Foldable as F |
36 | import qualified Data.Semigroup as S | 38 | import qualified Data.Semigroup as S |
37 | import Internal.Chain | 39 | import Internal.Chain |
40 | import Foreign.Storable(Storable) | ||
38 | 41 | ||
39 | 42 | ||
40 | ------------------------------------------------------------------- | 43 | ------------------------------------------------------------------- |
@@ -80,8 +83,16 @@ instance (Floating a, Container Vector a, Floating (Vector a), Fractional (Matri | |||
80 | 83 | ||
81 | -------------------------------------------------------------------------------- | 84 | -------------------------------------------------------------------------------- |
82 | 85 | ||
86 | isScalar :: Matrix t -> Bool | ||
83 | isScalar m = rows m == 1 && cols m == 1 | 87 | isScalar m = rows m == 1 && cols m == 1 |
84 | 88 | ||
89 | adaptScalarM :: (Foreign.Storable.Storable t1, Foreign.Storable.Storable t2) | ||
90 | => (t1 -> Matrix t2 -> t) | ||
91 | -> (Matrix t1 -> Matrix t2 -> t) | ||
92 | -> (Matrix t1 -> t2 -> t) | ||
93 | -> Matrix t1 | ||
94 | -> Matrix t2 | ||
95 | -> t | ||
85 | adaptScalarM f1 f2 f3 x y | 96 | adaptScalarM f1 f2 f3 x y |
86 | | isScalar x = f1 (x @@>(0,0) ) y | 97 | | isScalar x = f1 (x @@>(0,0) ) y |
87 | | isScalar y = f3 x (y @@>(0,0) ) | 98 | | isScalar y = f3 x (y @@>(0,0) ) |
@@ -96,7 +107,7 @@ instance (Container Vector t, Eq t, Num (Vector t), Product t) => M.Monoid (Matr | |||
96 | where | 107 | where |
97 | mempty = 1 | 108 | mempty = 1 |
98 | mappend = adaptScalarM scale mXm (flip scale) | 109 | mappend = adaptScalarM scale mXm (flip scale) |
99 | 110 | ||
100 | mconcat xs = work (partition isScalar xs) | 111 | mconcat xs = work (partition isScalar xs) |
101 | where | 112 | where |
102 | work (ss,[]) = product ss | 113 | work (ss,[]) = product ss |
@@ -106,4 +117,3 @@ instance (Container Vector t, Eq t, Num (Vector t), Product t) => M.Monoid (Matr | |||
106 | | otherwise = scale x00 m | 117 | | otherwise = scale x00 m |
107 | where | 118 | where |
108 | x00 = x @@> (0,0) | 119 | x00 = x @@> (0,0) |
109 | |||
diff --git a/packages/base/src/Numeric/Vector.hs b/packages/base/src/Numeric/Vector.hs index 017196c..1e5877d 100644 --- a/packages/base/src/Numeric/Vector.hs +++ b/packages/base/src/Numeric/Vector.hs | |||
@@ -3,6 +3,9 @@ | |||
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | {-# LANGUAGE UndecidableInstances #-} | 4 | {-# LANGUAGE UndecidableInstances #-} |
5 | {-# LANGUAGE MultiParamTypeClasses #-} | 5 | {-# LANGUAGE MultiParamTypeClasses #-} |
6 | |||
7 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
8 | |||
6 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
7 | -- | | 10 | -- | |
8 | -- Module : Numeric.Vector | 11 | -- Module : Numeric.Vector |
@@ -14,7 +17,7 @@ | |||
14 | -- | 17 | -- |
15 | -- Provides instances of standard classes 'Show', 'Read', 'Eq', | 18 | -- Provides instances of standard classes 'Show', 'Read', 'Eq', |
16 | -- 'Num', 'Fractional', and 'Floating' for 'Vector'. | 19 | -- 'Num', 'Fractional', and 'Floating' for 'Vector'. |
17 | -- | 20 | -- |
18 | ----------------------------------------------------------------------------- | 21 | ----------------------------------------------------------------------------- |
19 | 22 | ||
20 | module Numeric.Vector () where | 23 | module Numeric.Vector () where |
@@ -23,9 +26,17 @@ import Internal.Vectorized | |||
23 | import Internal.Vector | 26 | import Internal.Vector |
24 | import Internal.Numeric | 27 | import Internal.Numeric |
25 | import Internal.Conversion | 28 | import Internal.Conversion |
29 | import Foreign.Storable(Storable) | ||
26 | 30 | ||
27 | ------------------------------------------------------------------- | 31 | ------------------------------------------------------------------- |
28 | 32 | ||
33 | adaptScalar :: (Foreign.Storable.Storable t1, Foreign.Storable.Storable t2) | ||
34 | => (t1 -> Vector t2 -> t) | ||
35 | -> (Vector t1 -> Vector t2 -> t) | ||
36 | -> (Vector t1 -> t2 -> t) | ||
37 | -> Vector t1 | ||
38 | -> Vector t2 | ||
39 | -> t | ||
29 | adaptScalar f1 f2 f3 x y | 40 | adaptScalar f1 f2 f3 x y |
30 | | dim x == 1 = f1 (x@>0) y | 41 | | dim x == 1 = f1 (x@>0) y |
31 | | dim y == 1 = f3 x (y@>0) | 42 | | dim y == 1 = f3 x (y@>0) |
@@ -172,4 +183,3 @@ instance Floating (Vector (Complex Float)) where | |||
172 | sqrt = vectorMapQ Sqrt | 183 | sqrt = vectorMapQ Sqrt |
173 | (**) = adaptScalar (vectorMapValQ PowSV) (vectorZipQ Pow) (flip (vectorMapValQ PowVS)) | 184 | (**) = adaptScalar (vectorMapValQ PowSV) (vectorZipQ Pow) (flip (vectorMapValQ PowVS)) |
174 | pi = fromList [pi] | 185 | pi = fromList [pi] |
175 | |||