diff options
Diffstat (limited to 'packages/base/src')
22 files changed, 226 insertions, 20 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs index 5fe7796..6027c46 100644 --- a/packages/base/src/Internal/Algorithms.hs +++ b/packages/base/src/Internal/Algorithms.hs | |||
@@ -4,6 +4,8 @@ | |||
4 | {-# LANGUAGE UndecidableInstances #-} | 4 | {-# LANGUAGE UndecidableInstances #-} |
5 | {-# LANGUAGE TypeFamilies #-} | 5 | {-# LANGUAGE TypeFamilies #-} |
6 | 6 | ||
7 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
8 | |||
7 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
8 | {- | | 10 | {- | |
9 | Module : Internal.Algorithms | 11 | Module : Internal.Algorithms |
diff --git a/packages/base/src/Internal/CG.hs b/packages/base/src/Internal/CG.hs index cc10ad8..29edd35 100644 --- a/packages/base/src/Internal/CG.hs +++ b/packages/base/src/Internal/CG.hs | |||
@@ -1,6 +1,8 @@ | |||
1 | {-# LANGUAGE FlexibleContexts, FlexibleInstances #-} | 1 | {-# LANGUAGE FlexibleContexts, FlexibleInstances #-} |
2 | {-# LANGUAGE RecordWildCards #-} | 2 | {-# LANGUAGE RecordWildCards #-} |
3 | 3 | ||
4 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
5 | |||
4 | module Internal.CG( | 6 | module Internal.CG( |
5 | cgSolve, cgSolve', | 7 | cgSolve, cgSolve', |
6 | CGState(..), R, V | 8 | CGState(..), R, V |
diff --git a/packages/base/src/Internal/Chain.hs b/packages/base/src/Internal/Chain.hs index f87eb02..4000c2b 100644 --- a/packages/base/src/Internal/Chain.hs +++ b/packages/base/src/Internal/Chain.hs | |||
@@ -1,5 +1,7 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | 1 | {-# LANGUAGE FlexibleContexts #-} |
2 | 2 | ||
3 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
4 | |||
3 | ----------------------------------------------------------------------------- | 5 | ----------------------------------------------------------------------------- |
4 | -- | | 6 | -- | |
5 | -- Module : Internal.Chain | 7 | -- Module : Internal.Chain |
diff --git a/packages/base/src/Internal/Container.hs b/packages/base/src/Internal/Container.hs index cdcdad0..41b8214 100644 --- a/packages/base/src/Internal/Container.hs +++ b/packages/base/src/Internal/Container.hs | |||
@@ -5,6 +5,8 @@ | |||
5 | {-# LANGUAGE FunctionalDependencies #-} | 5 | {-# LANGUAGE FunctionalDependencies #-} |
6 | {-# LANGUAGE UndecidableInstances #-} | 6 | {-# LANGUAGE UndecidableInstances #-} |
7 | 7 | ||
8 | {-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-} | ||
9 | |||
8 | ----------------------------------------------------------------------------- | 10 | ----------------------------------------------------------------------------- |
9 | -- | | 11 | -- | |
10 | -- Module : Internal.Container | 12 | -- Module : Internal.Container |
diff --git a/packages/base/src/Internal/Devel.hs b/packages/base/src/Internal/Devel.hs index 3887663..f72d8aa 100644 --- a/packages/base/src/Internal/Devel.hs +++ b/packages/base/src/Internal/Devel.hs | |||
@@ -54,6 +54,7 @@ check msg f = do | |||
54 | 54 | ||
55 | -- | postfix error code check | 55 | -- | postfix error code check |
56 | infixl 0 #| | 56 | infixl 0 #| |
57 | (#|) :: IO CInt -> String -> IO () | ||
57 | (#|) = flip check | 58 | (#|) = flip check |
58 | 59 | ||
59 | -- | Error capture and conversion to Maybe | 60 | -- | Error capture and conversion to Maybe |
diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs index eb3a25b..2e330ee 100644 --- a/packages/base/src/Internal/Element.hs +++ b/packages/base/src/Internal/Element.hs | |||
@@ -4,6 +4,8 @@ | |||
4 | {-# LANGUAGE UndecidableInstances #-} | 4 | {-# LANGUAGE UndecidableInstances #-} |
5 | {-# LANGUAGE MultiParamTypeClasses #-} | 5 | {-# LANGUAGE MultiParamTypeClasses #-} |
6 | 6 | ||
7 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
8 | |||
7 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
8 | -- | | 10 | -- | |
9 | -- Module : Data.Packed.Matrix | 11 | -- Module : Data.Packed.Matrix |
@@ -31,6 +33,7 @@ import Data.List.Split(chunksOf) | |||
31 | import Foreign.Storable(Storable) | 33 | import Foreign.Storable(Storable) |
32 | import System.IO.Unsafe(unsafePerformIO) | 34 | import System.IO.Unsafe(unsafePerformIO) |
33 | import Control.Monad(liftM) | 35 | import Control.Monad(liftM) |
36 | import Foreign.C.Types(CInt) | ||
34 | 37 | ||
35 | ------------------------------------------------------------------- | 38 | ------------------------------------------------------------------- |
36 | 39 | ||
@@ -53,8 +56,10 @@ instance (Show a, Element a) => (Show (Matrix a)) where | |||
53 | show m | rows m == 0 || cols m == 0 = sizes m ++" []" | 56 | show m | rows m == 0 || cols m == 0 = sizes m ++" []" |
54 | show m = (sizes m++) . dsp . map (map show) . toLists $ m | 57 | show m = (sizes m++) . dsp . map (map show) . toLists $ m |
55 | 58 | ||
59 | sizes :: Matrix t -> [Char] | ||
56 | sizes m = "("++show (rows m)++"><"++show (cols m)++")\n" | 60 | sizes m = "("++show (rows m)++"><"++show (cols m)++")\n" |
57 | 61 | ||
62 | dsp :: [[[Char]]] -> [Char] | ||
58 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp | 63 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp |
59 | where | 64 | where |
60 | mt = transpose as | 65 | mt = transpose as |
@@ -73,6 +78,7 @@ instance (Element a, Read a) => Read (Matrix a) where | |||
73 | rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims | 78 | rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims |
74 | 79 | ||
75 | 80 | ||
81 | breakAt :: Eq a => a -> [a] -> ([a], [a]) | ||
76 | breakAt c l = (a++[c],tail b) where | 82 | breakAt c l = (a++[c],tail b) where |
77 | (a,b) = break (==c) l | 83 | (a,b) = break (==c) l |
78 | 84 | ||
@@ -88,7 +94,8 @@ data Extractor | |||
88 | | Drop Int | 94 | | Drop Int |
89 | | DropLast Int | 95 | | DropLast Int |
90 | deriving Show | 96 | deriving Show |
91 | 97 | ||
98 | ppext :: Extractor -> [Char] | ||
92 | ppext All = ":" | 99 | ppext All = ":" |
93 | ppext (Range a 1 c) = printf "%d:%d" a c | 100 | ppext (Range a 1 c) = printf "%d:%d" a c |
94 | ppext (Range a b c) = printf "%d:%d:%d" a b c | 101 | ppext (Range a b c) = printf "%d:%d:%d" a b c |
@@ -128,10 +135,14 @@ ppext (DropLast n) = printf "DropLast %d" n | |||
128 | infixl 9 ?? | 135 | infixl 9 ?? |
129 | (??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t | 136 | (??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t |
130 | 137 | ||
138 | minEl :: Vector CInt -> CInt | ||
131 | minEl = toScalarI Min | 139 | minEl = toScalarI Min |
140 | maxEl :: Vector CInt -> CInt | ||
132 | maxEl = toScalarI Max | 141 | maxEl = toScalarI Max |
142 | cmodi :: Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt | ||
133 | cmodi = vectorMapValI ModVS | 143 | cmodi = vectorMapValI ModVS |
134 | 144 | ||
145 | extractError :: Matrix t1 -> (Extractor, Extractor) -> t | ||
135 | extractError m (e1,e2)= error $ printf "can't extract (%s,%s) from matrix %dx%d" (ppext e1::String) (ppext e2::String) (rows m) (cols m) | 146 | extractError m (e1,e2)= error $ printf "can't extract (%s,%s) from matrix %dx%d" (ppext e1::String) (ppext e2::String) (rows m) (cols m) |
136 | 147 | ||
137 | m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e) | 148 | m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e) |
@@ -232,8 +243,10 @@ disp = putStr . dispf 2 | |||
232 | fromBlocks :: Element t => [[Matrix t]] -> Matrix t | 243 | fromBlocks :: Element t => [[Matrix t]] -> Matrix t |
233 | fromBlocks = fromBlocksRaw . adaptBlocks | 244 | fromBlocks = fromBlocksRaw . adaptBlocks |
234 | 245 | ||
246 | fromBlocksRaw :: Element t => [[Matrix t]] -> Matrix t | ||
235 | fromBlocksRaw mms = joinVert . map joinHoriz $ mms | 247 | fromBlocksRaw mms = joinVert . map joinHoriz $ mms |
236 | 248 | ||
249 | adaptBlocks :: Element t => [[Matrix t]] -> [[Matrix t]] | ||
237 | adaptBlocks ms = ms' where | 250 | adaptBlocks ms = ms' where |
238 | bc = case common length ms of | 251 | bc = case common length ms of |
239 | Just c -> c | 252 | Just c -> c |
@@ -486,6 +499,9 @@ liftMatrix2Auto f m1 m2 | |||
486 | m2' = conformMTo (r,c) m2 | 499 | m2' = conformMTo (r,c) m2 |
487 | 500 | ||
488 | -- FIXME do not flatten if equal order | 501 | -- FIXME do not flatten if equal order |
502 | lM :: (Storable t, Element t1, Element t2) | ||
503 | => (Vector t1 -> Vector t2 -> Vector t) | ||
504 | -> Matrix t1 -> Matrix t2 -> Matrix t | ||
489 | lM f m1 m2 = matrixFromVector | 505 | lM f m1 m2 = matrixFromVector |
490 | RowMajor | 506 | RowMajor |
491 | (max' (rows m1) (rows m2)) | 507 | (max' (rows m1) (rows m2)) |
@@ -504,6 +520,7 @@ compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 | |||
504 | 520 | ||
505 | ------------------------------------------------------------ | 521 | ------------------------------------------------------------ |
506 | 522 | ||
523 | toBlockRows :: Element t => [Int] -> Matrix t -> [Matrix t] | ||
507 | toBlockRows [r] m | 524 | toBlockRows [r] m |
508 | | r == rows m = [m] | 525 | | r == rows m = [m] |
509 | toBlockRows rs m | 526 | toBlockRows rs m |
@@ -513,6 +530,7 @@ toBlockRows rs m | |||
513 | szs = map (* cols m) rs | 530 | szs = map (* cols m) rs |
514 | g k = (k><0)[] | 531 | g k = (k><0)[] |
515 | 532 | ||
533 | toBlockCols :: Element t => [Int] -> Matrix t -> [Matrix t] | ||
516 | toBlockCols [c] m | c == cols m = [m] | 534 | toBlockCols [c] m | c == cols m = [m] |
517 | toBlockCols cs m = map trans . toBlockRows cs . trans $ m | 535 | toBlockCols cs m = map trans . toBlockRows cs . trans $ m |
518 | 536 | ||
@@ -576,7 +594,7 @@ Just (3><3) | |||
576 | mapMatrixWithIndexM | 594 | mapMatrixWithIndexM |
577 | :: (Element a, Storable b, Monad m) => | 595 | :: (Element a, Storable b, Monad m) => |
578 | ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) | 596 | ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) |
579 | mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m | 597 | mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m |
580 | where | 598 | where |
581 | c = cols m | 599 | c = cols m |
582 | 600 | ||
@@ -598,4 +616,3 @@ mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m | |||
598 | 616 | ||
599 | mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b | 617 | mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b |
600 | mapMatrix f = liftMatrix (mapVector f) | 618 | mapMatrix f = liftMatrix (mapVector f) |
601 | |||
diff --git a/packages/base/src/Internal/IO.hs b/packages/base/src/Internal/IO.hs index a899cfd..b0f5606 100644 --- a/packages/base/src/Internal/IO.hs +++ b/packages/base/src/Internal/IO.hs | |||
@@ -20,7 +20,7 @@ import Internal.Devel | |||
20 | import Internal.Vector | 20 | import Internal.Vector |
21 | import Internal.Matrix | 21 | import Internal.Matrix |
22 | import Internal.Vectorized | 22 | import Internal.Vectorized |
23 | import Text.Printf(printf) | 23 | import Text.Printf(printf, PrintfArg, PrintfType) |
24 | import Data.List(intersperse,transpose) | 24 | import Data.List(intersperse,transpose) |
25 | import Data.Complex | 25 | import Data.Complex |
26 | 26 | ||
@@ -78,12 +78,18 @@ disps d x = sdims x ++ " " ++ formatScaled d x | |||
78 | dispf :: Int -> Matrix Double -> String | 78 | dispf :: Int -> Matrix Double -> String |
79 | dispf d x = sdims x ++ "\n" ++ formatFixed (if isInt x then 0 else d) x | 79 | dispf d x = sdims x ++ "\n" ++ formatFixed (if isInt x then 0 else d) x |
80 | 80 | ||
81 | sdims :: Matrix t -> [Char] | ||
81 | sdims x = show (rows x) ++ "x" ++ show (cols x) | 82 | sdims x = show (rows x) ++ "x" ++ show (cols x) |
82 | 83 | ||
84 | formatFixed :: (Show a, Text.Printf.PrintfArg t, Element t) | ||
85 | => a -> Matrix t -> String | ||
83 | formatFixed d x = format " " (printf ("%."++show d++"f")) $ x | 86 | formatFixed d x = format " " (printf ("%."++show d++"f")) $ x |
84 | 87 | ||
88 | isInt :: Matrix Double -> Bool | ||
85 | isInt = all lookslikeInt . toList . flatten | 89 | isInt = all lookslikeInt . toList . flatten |
86 | 90 | ||
91 | formatScaled :: (Text.Printf.PrintfArg b, RealFrac b, Floating b, Num t, Element b, Show t) | ||
92 | => t -> Matrix b -> [Char] | ||
87 | formatScaled dec t = "E"++show o++"\n" ++ ss | 93 | formatScaled dec t = "E"++show o++"\n" ++ ss |
88 | where ss = format " " (printf fmt. g) t | 94 | where ss = format " " (printf fmt. g) t |
89 | g x | o >= 0 = x/10^(o::Int) | 95 | g x | o >= 0 = x/10^(o::Int) |
@@ -133,14 +139,18 @@ showComplex d (a:+b) | |||
133 | s2 = if b<0 then "-" else "" | 139 | s2 = if b<0 then "-" else "" |
134 | s3 = if b<0 then "-" else "+" | 140 | s3 = if b<0 then "-" else "+" |
135 | 141 | ||
142 | shcr :: (Show a, Show t1, Text.Printf.PrintfType t, Text.Printf.PrintfArg t1, RealFrac t1) | ||
143 | => a -> t1 -> t | ||
136 | shcr d a | lookslikeInt a = printf "%.0f" a | 144 | shcr d a | lookslikeInt a = printf "%.0f" a |
137 | | otherwise = printf ("%."++show d++"f") a | 145 | | otherwise = printf ("%."++show d++"f") a |
138 | 146 | ||
139 | 147 | lookslikeInt :: (Show a, RealFrac a) => a -> Bool | |
140 | lookslikeInt x = show (round x :: Int) ++".0" == shx || "-0.0" == shx | 148 | lookslikeInt x = show (round x :: Int) ++".0" == shx || "-0.0" == shx |
141 | where shx = show x | 149 | where shx = show x |
142 | 150 | ||
151 | isZero :: Show a => a -> Bool | ||
143 | isZero x = show x `elem` ["0.0","-0.0"] | 152 | isZero x = show x `elem` ["0.0","-0.0"] |
153 | isOne :: Show a => a -> Bool | ||
144 | isOne x = show x `elem` ["1.0","-1.0"] | 154 | isOne x = show x `elem` ["1.0","-1.0"] |
145 | 155 | ||
146 | -- | Pretty print a complex matrix with at most n decimal digits. | 156 | -- | Pretty print a complex matrix with at most n decimal digits. |
@@ -168,6 +178,6 @@ loadMatrix f = do | |||
168 | else | 178 | else |
169 | return (reshape c v) | 179 | return (reshape c v) |
170 | 180 | ||
171 | 181 | loadMatrix' :: FilePath -> IO (Maybe (Matrix Double)) | |
172 | loadMatrix' name = mbCatch (loadMatrix name) | 182 | loadMatrix' name = mbCatch (loadMatrix name) |
173 | 183 | ||
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index e306454..64cf2f5 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -1,6 +1,8 @@ | |||
1 | {-# LANGUAGE TypeOperators #-} | 1 | {-# LANGUAGE TypeOperators #-} |
2 | {-# LANGUAGE ViewPatterns #-} | 2 | {-# LANGUAGE ViewPatterns #-} |
3 | 3 | ||
4 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
5 | |||
4 | ----------------------------------------------------------------------------- | 6 | ----------------------------------------------------------------------------- |
5 | -- | | 7 | -- | |
6 | -- Module : Numeric.LinearAlgebra.LAPACK | 8 | -- Module : Numeric.LinearAlgebra.LAPACK |
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 2856ec2..5436e59 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -57,19 +57,24 @@ cols :: Matrix t -> Int | |||
57 | cols = icols | 57 | cols = icols |
58 | {-# INLINE cols #-} | 58 | {-# INLINE cols #-} |
59 | 59 | ||
60 | size :: Matrix t -> (Int, Int) | ||
60 | size m = (irows m, icols m) | 61 | size m = (irows m, icols m) |
61 | {-# INLINE size #-} | 62 | {-# INLINE size #-} |
62 | 63 | ||
64 | rowOrder :: Matrix t -> Bool | ||
63 | rowOrder m = xCol m == 1 || cols m == 1 | 65 | rowOrder m = xCol m == 1 || cols m == 1 |
64 | {-# INLINE rowOrder #-} | 66 | {-# INLINE rowOrder #-} |
65 | 67 | ||
68 | colOrder :: Matrix t -> Bool | ||
66 | colOrder m = xRow m == 1 || rows m == 1 | 69 | colOrder m = xRow m == 1 || rows m == 1 |
67 | {-# INLINE colOrder #-} | 70 | {-# INLINE colOrder #-} |
68 | 71 | ||
72 | is1d :: Matrix t -> Bool | ||
69 | is1d (size->(r,c)) = r==1 || c==1 | 73 | is1d (size->(r,c)) = r==1 || c==1 |
70 | {-# INLINE is1d #-} | 74 | {-# INLINE is1d #-} |
71 | 75 | ||
72 | -- data is not contiguous | 76 | -- data is not contiguous |
77 | isSlice :: Storable t => Matrix t -> Bool | ||
73 | isSlice m@(size->(r,c)) = r*c < dim (xdat m) | 78 | isSlice m@(size->(r,c)) = r*c < dim (xdat m) |
74 | {-# INLINE isSlice #-} | 79 | {-# INLINE isSlice #-} |
75 | 80 | ||
@@ -136,16 +141,20 @@ instance Storable t => TransArray (Matrix t) | |||
136 | {-# INLINE applyRaw #-} | 141 | {-# INLINE applyRaw #-} |
137 | 142 | ||
138 | infixr 1 # | 143 | infixr 1 # |
144 | (#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r | ||
139 | a # b = apply a b | 145 | a # b = apply a b |
140 | {-# INLINE (#) #-} | 146 | {-# INLINE (#) #-} |
141 | 147 | ||
148 | (#!) :: (TransArray c, TransArray c1) => c1 -> c -> Trans c1 (Trans c (IO r)) -> IO r | ||
142 | a #! b = a # b # id | 149 | a #! b = a # b # id |
143 | {-# INLINE (#!) #-} | 150 | {-# INLINE (#!) #-} |
144 | 151 | ||
145 | -------------------------------------------------------------------------------- | 152 | -------------------------------------------------------------------------------- |
146 | 153 | ||
154 | copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t) | ||
147 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) | 155 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) |
148 | 156 | ||
157 | extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t | ||
149 | extractAll ord m = unsafePerformIO (copy ord m) | 158 | extractAll ord m = unsafePerformIO (copy ord m) |
150 | 159 | ||
151 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. | 160 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. |
@@ -224,11 +233,13 @@ m@Matrix {irows = r, icols = c} @@> (i,j) | |||
224 | {-# INLINE (@@>) #-} | 233 | {-# INLINE (@@>) #-} |
225 | 234 | ||
226 | -- Unsafe matrix access without range checking | 235 | -- Unsafe matrix access without range checking |
236 | atM' :: Storable t => Matrix t -> Int -> Int -> t | ||
227 | atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) | 237 | atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) |
228 | {-# INLINE atM' #-} | 238 | {-# INLINE atM' #-} |
229 | 239 | ||
230 | ------------------------------------------------------------------ | 240 | ------------------------------------------------------------------ |
231 | 241 | ||
242 | matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t | ||
232 | matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } | 243 | matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } |
233 | matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } | 244 | matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } |
234 | matrixFromVector o r c v | 245 | matrixFromVector o r c v |
@@ -388,18 +399,21 @@ subMatrix (r0,c0) (rt,ct) m | |||
388 | 399 | ||
389 | -------------------------------------------------------------------------- | 400 | -------------------------------------------------------------------------- |
390 | 401 | ||
402 | maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1 | ||
391 | maxZ xs = if minimum xs == 0 then 0 else maximum xs | 403 | maxZ xs = if minimum xs == 0 then 0 else maximum xs |
392 | 404 | ||
405 | conformMs :: Element t => [Matrix t] -> [Matrix t] | ||
393 | conformMs ms = map (conformMTo (r,c)) ms | 406 | conformMs ms = map (conformMTo (r,c)) ms |
394 | where | 407 | where |
395 | r = maxZ (map rows ms) | 408 | r = maxZ (map rows ms) |
396 | c = maxZ (map cols ms) | 409 | c = maxZ (map cols ms) |
397 | 410 | ||
398 | 411 | conformVs :: Element t => [Vector t] -> [Vector t] | |
399 | conformVs vs = map (conformVTo n) vs | 412 | conformVs vs = map (conformVTo n) vs |
400 | where | 413 | where |
401 | n = maxZ (map dim vs) | 414 | n = maxZ (map dim vs) |
402 | 415 | ||
416 | conformMTo :: Element t => (Int, Int) -> Matrix t -> Matrix t | ||
403 | conformMTo (r,c) m | 417 | conformMTo (r,c) m |
404 | | size m == (r,c) = m | 418 | | size m == (r,c) = m |
405 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) | 419 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) |
@@ -407,18 +421,24 @@ conformMTo (r,c) m | |||
407 | | size m == (1,c) = repRows r m | 421 | | size m == (1,c) = repRows r m |
408 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) | 422 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) |
409 | 423 | ||
424 | conformVTo :: Element t => Int -> Vector t -> Vector t | ||
410 | conformVTo n v | 425 | conformVTo n v |
411 | | dim v == n = v | 426 | | dim v == n = v |
412 | | dim v == 1 = constantD (v@>0) n | 427 | | dim v == 1 = constantD (v@>0) n |
413 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n | 428 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n |
414 | 429 | ||
430 | repRows :: Element t => Int -> Matrix t -> Matrix t | ||
415 | repRows n x = fromRows (replicate n (flatten x)) | 431 | repRows n x = fromRows (replicate n (flatten x)) |
432 | repCols :: Element t => Int -> Matrix t -> Matrix t | ||
416 | repCols n x = fromColumns (replicate n (flatten x)) | 433 | repCols n x = fromColumns (replicate n (flatten x)) |
417 | 434 | ||
435 | shSize :: Matrix t -> [Char] | ||
418 | shSize = shDim . size | 436 | shSize = shDim . size |
419 | 437 | ||
438 | shDim :: (Show a, Show a1) => (a1, a) -> [Char] | ||
420 | shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" | 439 | shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" |
421 | 440 | ||
441 | emptyM :: Storable t => Int -> Int -> Matrix t | ||
422 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) | 442 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) |
423 | 443 | ||
424 | ---------------------------------------------------------------------- | 444 | ---------------------------------------------------------------------- |
@@ -433,6 +453,11 @@ instance (Storable t, NFData t) => NFData (Matrix t) | |||
433 | 453 | ||
434 | --------------------------------------------------------------- | 454 | --------------------------------------------------------------- |
435 | 455 | ||
456 | extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, | ||
457 | Storable t, Num t3, Num t2, Integral t1, Integral t) | ||
458 | => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t | ||
459 | -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) | ||
460 | -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) | ||
436 | extractAux f ord m moder vr modec vc = do | 461 | extractAux f ord m moder vr modec vc = do |
437 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | 462 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr |
438 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | 463 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc |
@@ -452,6 +477,9 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z | |||
452 | 477 | ||
453 | --------------------------------------------------------------- | 478 | --------------------------------------------------------------- |
454 | 479 | ||
480 | setRectAux :: (TransArray c1, TransArray c) | ||
481 | => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) | ||
482 | -> Int -> Int -> c1 -> c -> IO () | ||
455 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" | 483 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" |
456 | 484 | ||
457 | type SetRect x = I -> I -> x ::> x::> Ok | 485 | type SetRect x = I -> I -> x ::> x::> Ok |
@@ -465,19 +493,29 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z | |||
465 | 493 | ||
466 | -------------------------------------------------------------------------------- | 494 | -------------------------------------------------------------------------------- |
467 | 495 | ||
496 | sortG :: (Storable t, Storable a) | ||
497 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
468 | sortG f v = unsafePerformIO $ do | 498 | sortG f v = unsafePerformIO $ do |
469 | r <- createVector (dim v) | 499 | r <- createVector (dim v) |
470 | (v #! r) f #|"sortG" | 500 | (v #! r) f #|"sortG" |
471 | return r | 501 | return r |
472 | 502 | ||
503 | sortIdxD :: Vector Double -> Vector CInt | ||
473 | sortIdxD = sortG c_sort_indexD | 504 | sortIdxD = sortG c_sort_indexD |
505 | sortIdxF :: Vector Float -> Vector CInt | ||
474 | sortIdxF = sortG c_sort_indexF | 506 | sortIdxF = sortG c_sort_indexF |
507 | sortIdxI :: Vector CInt -> Vector CInt | ||
475 | sortIdxI = sortG c_sort_indexI | 508 | sortIdxI = sortG c_sort_indexI |
509 | sortIdxL :: Vector Z -> Vector I | ||
476 | sortIdxL = sortG c_sort_indexL | 510 | sortIdxL = sortG c_sort_indexL |
477 | 511 | ||
512 | sortValD :: Vector Double -> Vector Double | ||
478 | sortValD = sortG c_sort_valD | 513 | sortValD = sortG c_sort_valD |
514 | sortValF :: Vector Float -> Vector Float | ||
479 | sortValF = sortG c_sort_valF | 515 | sortValF = sortG c_sort_valF |
516 | sortValI :: Vector CInt -> Vector CInt | ||
480 | sortValI = sortG c_sort_valI | 517 | sortValI = sortG c_sort_valI |
518 | sortValL :: Vector Z -> Vector Z | ||
481 | sortValL = sortG c_sort_valL | 519 | sortValL = sortG c_sort_valL |
482 | 520 | ||
483 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) | 521 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) |
@@ -492,14 +530,21 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok | |||
492 | 530 | ||
493 | -------------------------------------------------------------------------------- | 531 | -------------------------------------------------------------------------------- |
494 | 532 | ||
533 | compareG :: (TransArray c, Storable t, Storable a) | ||
534 | => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) | ||
535 | -> c -> Vector t -> Vector a | ||
495 | compareG f u v = unsafePerformIO $ do | 536 | compareG f u v = unsafePerformIO $ do |
496 | r <- createVector (dim v) | 537 | r <- createVector (dim v) |
497 | (u # v #! r) f #|"compareG" | 538 | (u # v #! r) f #|"compareG" |
498 | return r | 539 | return r |
499 | 540 | ||
541 | compareD :: Vector Double -> Vector Double -> Vector CInt | ||
500 | compareD = compareG c_compareD | 542 | compareD = compareG c_compareD |
543 | compareF :: Vector Float -> Vector Float -> Vector CInt | ||
501 | compareF = compareG c_compareF | 544 | compareF = compareG c_compareF |
545 | compareI :: Vector CInt -> Vector CInt -> Vector CInt | ||
502 | compareI = compareG c_compareI | 546 | compareI = compareG c_compareI |
547 | compareL :: Vector Z -> Vector Z -> Vector CInt | ||
503 | compareL = compareG c_compareL | 548 | compareL = compareG c_compareL |
504 | 549 | ||
505 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) | 550 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) |
@@ -509,16 +554,33 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok | |||
509 | 554 | ||
510 | -------------------------------------------------------------------------------- | 555 | -------------------------------------------------------------------------------- |
511 | 556 | ||
557 | selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) | ||
558 | => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) | ||
559 | -> c2 -> c1 -> Vector t -> c -> Vector a | ||
512 | selectG f c u v w = unsafePerformIO $ do | 560 | selectG f c u v w = unsafePerformIO $ do |
513 | r <- createVector (dim v) | 561 | r <- createVector (dim v) |
514 | (c # u # v # w #! r) f #|"selectG" | 562 | (c # u # v # w #! r) f #|"selectG" |
515 | return r | 563 | return r |
516 | 564 | ||
565 | selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double | ||
517 | selectD = selectG c_selectD | 566 | selectD = selectG c_selectD |
567 | selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float | ||
518 | selectF = selectG c_selectF | 568 | selectF = selectG c_selectF |
569 | selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt | ||
519 | selectI = selectG c_selectI | 570 | selectI = selectG c_selectI |
571 | selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z | ||
520 | selectL = selectG c_selectL | 572 | selectL = selectG c_selectL |
573 | selectC :: Vector CInt | ||
574 | -> Vector (Complex Double) | ||
575 | -> Vector (Complex Double) | ||
576 | -> Vector (Complex Double) | ||
577 | -> Vector (Complex Double) | ||
521 | selectC = selectG c_selectC | 578 | selectC = selectG c_selectC |
579 | selectQ :: Vector CInt | ||
580 | -> Vector (Complex Float) | ||
581 | -> Vector (Complex Float) | ||
582 | -> Vector (Complex Float) | ||
583 | -> Vector (Complex Float) | ||
522 | selectQ = selectG c_selectQ | 584 | selectQ = selectG c_selectQ |
523 | 585 | ||
524 | type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) | 586 | type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) |
@@ -532,16 +594,29 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z | |||
532 | 594 | ||
533 | --------------------------------------------------------------------------- | 595 | --------------------------------------------------------------------------- |
534 | 596 | ||
597 | remapG :: (TransArray c, TransArray c1, Storable t, Storable a) | ||
598 | => (CInt -> CInt -> CInt -> CInt -> Ptr t | ||
599 | -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) | ||
600 | -> Matrix t -> c1 -> c -> Matrix a | ||
535 | remapG f i j m = unsafePerformIO $ do | 601 | remapG f i j m = unsafePerformIO $ do |
536 | r <- createMatrix RowMajor (rows i) (cols i) | 602 | r <- createMatrix RowMajor (rows i) (cols i) |
537 | (i # j # m #! r) f #|"remapG" | 603 | (i # j # m #! r) f #|"remapG" |
538 | return r | 604 | return r |
539 | 605 | ||
606 | remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double | ||
540 | remapD = remapG c_remapD | 607 | remapD = remapG c_remapD |
608 | remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float | ||
541 | remapF = remapG c_remapF | 609 | remapF = remapG c_remapF |
610 | remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt | ||
542 | remapI = remapG c_remapI | 611 | remapI = remapG c_remapI |
612 | remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z | ||
543 | remapL = remapG c_remapL | 613 | remapL = remapG c_remapL |
614 | remapC :: Matrix CInt | ||
615 | -> Matrix CInt | ||
616 | -> Matrix (Complex Double) | ||
617 | -> Matrix (Complex Double) | ||
544 | remapC = remapG c_remapC | 618 | remapC = remapG c_remapC |
619 | remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) | ||
545 | remapQ = remapG c_remapQ | 620 | remapQ = remapG c_remapQ |
546 | 621 | ||
547 | type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) | 622 | type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) |
@@ -555,6 +630,9 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z | |||
555 | 630 | ||
556 | -------------------------------------------------------------------------------- | 631 | -------------------------------------------------------------------------------- |
557 | 632 | ||
633 | rowOpAux :: (TransArray c, Storable a) => | ||
634 | (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) | ||
635 | -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () | ||
558 | rowOpAux f c x i1 i2 j1 j2 m = do | 636 | rowOpAux f c x i1 i2 j1 j2 m = do |
559 | px <- newArray [x] | 637 | px <- newArray [x] |
560 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" | 638 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" |
@@ -573,6 +651,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | |||
573 | 651 | ||
574 | -------------------------------------------------------------------------------- | 652 | -------------------------------------------------------------------------------- |
575 | 653 | ||
654 | gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) | ||
655 | => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt)))) | ||
656 | -> c3 -> c2 -> c1 -> c -> IO () | ||
576 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" | 657 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" |
577 | 658 | ||
578 | type Tgemm x = x :> x ::> x ::> x ::> Ok | 659 | type Tgemm x = x :> x ::> x ::> x ::> Ok |
@@ -588,6 +669,10 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z | |||
588 | 669 | ||
589 | -------------------------------------------------------------------------------- | 670 | -------------------------------------------------------------------------------- |
590 | 671 | ||
672 | reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => | ||
673 | (CInt -> Ptr a -> CInt -> Ptr t1 | ||
674 | -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)) | ||
675 | -> Vector t1 -> c -> Vector t -> Vector a1 | ||
591 | reorderAux f s d v = unsafePerformIO $ do | 676 | reorderAux f s d v = unsafePerformIO $ do |
592 | k <- createVector (dim s) | 677 | k <- createVector (dim s) |
593 | r <- createVector (dim v) | 678 | r <- createVector (dim v) |
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 9d51444..eb0c5a8 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -13,6 +13,9 @@ | |||
13 | {-# LANGUAGE TypeFamilies #-} | 13 | {-# LANGUAGE TypeFamilies #-} |
14 | {-# LANGUAGE TypeOperators #-} | 14 | {-# LANGUAGE TypeOperators #-} |
15 | 15 | ||
16 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
17 | {-# OPTIONS_GHC -fno-warn-missing-methods #-} | ||
18 | |||
16 | {- | | 19 | {- | |
17 | Module : Internal.Modular | 20 | Module : Internal.Modular |
18 | Copyright : (c) Alberto Ruiz 2015 | 21 | Copyright : (c) Alberto Ruiz 2015 |
diff --git a/packages/base/src/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs index c9ef0c5..fd0a217 100644 --- a/packages/base/src/Internal/Numeric.hs +++ b/packages/base/src/Internal/Numeric.hs | |||
@@ -5,6 +5,8 @@ | |||
5 | {-# LANGUAGE FunctionalDependencies #-} | 5 | {-# LANGUAGE FunctionalDependencies #-} |
6 | {-# LANGUAGE UndecidableInstances #-} | 6 | {-# LANGUAGE UndecidableInstances #-} |
7 | 7 | ||
8 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
9 | |||
8 | ----------------------------------------------------------------------------- | 10 | ----------------------------------------------------------------------------- |
9 | -- | | 11 | -- | |
10 | -- Module : Data.Packed.Internal.Numeric | 12 | -- Module : Data.Packed.Internal.Numeric |
@@ -788,13 +790,7 @@ type instance RealOf (Complex Float) = Float | |||
788 | type instance RealOf I = I | 790 | type instance RealOf I = I |
789 | type instance RealOf Z = Z | 791 | type instance RealOf Z = Z |
790 | 792 | ||
791 | type family ComplexOf x | 793 | type ComplexOf x = Complex (RealOf x) |
792 | |||
793 | type instance ComplexOf Double = Complex Double | ||
794 | type instance ComplexOf (Complex Double) = Complex Double | ||
795 | |||
796 | type instance ComplexOf Float = Complex Float | ||
797 | type instance ComplexOf (Complex Float) = Complex Float | ||
798 | 794 | ||
799 | type family SingleOf x | 795 | type family SingleOf x |
800 | 796 | ||
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 544c9e4..7d54e6d 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -81,6 +81,8 @@ unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t) | |||
81 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x | 81 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x |
82 | 82 | ||
83 | {-# INLINE safeIndexV #-} | 83 | {-# INLINE safeIndexV #-} |
84 | safeIndexV :: Storable t2 | ||
85 | => (STVector s t2 -> Int -> t) -> STVector t1 t2 -> Int -> t | ||
84 | safeIndexV f (STVector v) k | 86 | safeIndexV f (STVector v) k |
85 | | k < 0 || k>= dim v = error $ "out of range error in vector (dim=" | 87 | | k < 0 || k>= dim v = error $ "out of range error in vector (dim=" |
86 | ++show (dim v)++", pos="++show k++")" | 88 | ++show (dim v)++", pos="++show k++")" |
@@ -150,9 +152,12 @@ unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | |||
150 | freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) | 152 | freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) |
151 | freezeMatrix m = liftSTMatrix id m | 153 | freezeMatrix m = liftSTMatrix id m |
152 | 154 | ||
155 | cloneMatrix :: Element t => Matrix t -> IO (Matrix t) | ||
153 | cloneMatrix m = copy (orderOf m) m | 156 | cloneMatrix m = copy (orderOf m) m |
154 | 157 | ||
155 | {-# INLINE safeIndexM #-} | 158 | {-# INLINE safeIndexM #-} |
159 | safeIndexM :: (STMatrix s t2 -> Int -> Int -> t) | ||
160 | -> STMatrix t1 t2 -> Int -> Int -> t | ||
156 | safeIndexM f (STMatrix m) r c | 161 | safeIndexM f (STMatrix m) r c |
157 | | r<0 || r>=rows m || | 162 | | r<0 || r>=rows m || |
158 | c<0 || c>=cols m = error $ "out of range error in matrix (size=" | 163 | c<0 || c>=cols m = error $ "out of range error in matrix (size=" |
@@ -184,6 +189,7 @@ data ColRange = AllCols | |||
184 | | Col Int | 189 | | Col Int |
185 | | FromCol Int | 190 | | FromCol Int |
186 | 191 | ||
192 | getColRange :: Int -> ColRange -> (Int, Int) | ||
187 | getColRange c AllCols = (0,c-1) | 193 | getColRange c AllCols = (0,c-1) |
188 | getColRange c (ColRange a b) = (a `mod` c, b `mod` c) | 194 | getColRange c (ColRange a b) = (a `mod` c, b `mod` c) |
189 | getColRange c (Col a) = (a `mod` c, a `mod` c) | 195 | getColRange c (Col a) = (a `mod` c, a `mod` c) |
@@ -194,6 +200,7 @@ data RowRange = AllRows | |||
194 | | Row Int | 200 | | Row Int |
195 | | FromRow Int | 201 | | FromRow Int |
196 | 202 | ||
203 | getRowRange :: Int -> RowRange -> (Int, Int) | ||
197 | getRowRange r AllRows = (0,r-1) | 204 | getRowRange r AllRows = (0,r-1) |
198 | getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) | 205 | getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) |
199 | getRowRange r (Row a) = (a `mod` r, a `mod` r) | 206 | getRowRange r (Row a) = (a `mod` r, a `mod` r) |
@@ -223,6 +230,7 @@ rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m | |||
223 | i2' = i2 `mod` (rows m) | 230 | i2' = i2 `mod` (rows m) |
224 | 231 | ||
225 | 232 | ||
233 | extractMatrix :: Element a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a) | ||
226 | extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) | 234 | extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) |
227 | where | 235 | where |
228 | (i1,i2) = getRowRange (rows m) rr | 236 | (i1,i2) = getRowRange (rows m) rr |
@@ -231,6 +239,7 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[ | |||
231 | -- | r0 c0 height width | 239 | -- | r0 c0 height width |
232 | data Slice s t = Slice (STMatrix s t) Int Int Int Int | 240 | data Slice s t = Slice (STMatrix s t) Int Int Int Int |
233 | 241 | ||
242 | slice :: Element a => Slice t a -> Matrix a | ||
234 | slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m | 243 | slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m |
235 | 244 | ||
236 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () | 245 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () |
@@ -238,7 +247,7 @@ gemmm beta (slice->r) alpha (slice->a) (slice->b) = res | |||
238 | where | 247 | where |
239 | res = unsafeIOToST (gemm v a b r) | 248 | res = unsafeIOToST (gemm v a b r) |
240 | v = fromList [alpha,beta] | 249 | v = fromList [alpha,beta] |
241 | 250 | ||
242 | 251 | ||
243 | mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) | 252 | mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) |
244 | mutable f a = runST $ do | 253 | mutable f a = runST $ do |
@@ -246,4 +255,3 @@ mutable f a = runST $ do | |||
246 | info <- f (rows a, cols a) x | 255 | info <- f (rows a, cols a) x |
247 | r <- unsafeFreezeMatrix x | 256 | r <- unsafeFreezeMatrix x |
248 | return (r,info) | 257 | return (r,info) |
249 | |||
diff --git a/packages/base/src/Internal/Sparse.hs b/packages/base/src/Internal/Sparse.hs index a8a5fe0..fbea11a 100644 --- a/packages/base/src/Internal/Sparse.hs +++ b/packages/base/src/Internal/Sparse.hs | |||
@@ -2,6 +2,8 @@ | |||
2 | {-# LANGUAGE MultiParamTypeClasses #-} | 2 | {-# LANGUAGE MultiParamTypeClasses #-} |
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | 4 | ||
5 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
6 | |||
5 | module Internal.Sparse( | 7 | module Internal.Sparse( |
6 | GMatrix(..), CSR(..), mkCSR, fromCSR, | 8 | GMatrix(..), CSR(..), mkCSR, fromCSR, |
7 | mkSparse, mkDiagR, mkDense, | 9 | mkSparse, mkDiagR, mkDense, |
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs index 357645e..566506c 100644 --- a/packages/base/src/Internal/Static.hs +++ b/packages/base/src/Internal/Static.hs | |||
@@ -15,6 +15,9 @@ | |||
15 | {-# LANGUAGE BangPatterns #-} | 15 | {-# LANGUAGE BangPatterns #-} |
16 | {-# LANGUAGE DeriveGeneric #-} | 16 | {-# LANGUAGE DeriveGeneric #-} |
17 | 17 | ||
18 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
19 | {-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-} | ||
20 | |||
18 | {- | | 21 | {- | |
19 | Module : Internal.Static | 22 | Module : Internal.Static |
20 | Copyright : (c) Alberto Ruiz 2006-14 | 23 | Copyright : (c) Alberto Ruiz 2006-14 |
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 959e58f..f642e8d 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -6,6 +6,8 @@ | |||
6 | {-# LANGUAGE ScopedTypeVariables #-} | 6 | {-# LANGUAGE ScopedTypeVariables #-} |
7 | {-# LANGUAGE ViewPatterns #-} | 7 | {-# LANGUAGE ViewPatterns #-} |
8 | 8 | ||
9 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
10 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
9 | 11 | ||
10 | ----------------------------------------------------------------------------- | 12 | ----------------------------------------------------------------------------- |
11 | {- | | 13 | {- | |
diff --git a/packages/base/src/Internal/Vector.hs b/packages/base/src/Internal/Vector.hs index e1e4aa8..6271bb6 100644 --- a/packages/base/src/Internal/Vector.hs +++ b/packages/base/src/Internal/Vector.hs | |||
@@ -1,6 +1,7 @@ | |||
1 | {-# LANGUAGE MagicHash, UnboxedTuples, BangPatterns, FlexibleContexts #-} | 1 | {-# LANGUAGE MagicHash, UnboxedTuples, BangPatterns, FlexibleContexts #-} |
2 | {-# LANGUAGE TypeSynonymInstances #-} | 2 | {-# LANGUAGE TypeSynonymInstances #-} |
3 | 3 | ||
4 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
4 | 5 | ||
5 | -- | | 6 | -- | |
6 | -- Module : Internal.Vector | 7 | -- Module : Internal.Vector |
@@ -40,6 +41,7 @@ import qualified Data.Vector.Storable as Vector | |||
40 | import Data.Vector.Storable(Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith) | 41 | import Data.Vector.Storable(Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith) |
41 | 42 | ||
42 | import Data.Binary | 43 | import Data.Binary |
44 | import Data.Binary.Put | ||
43 | import Control.Monad(replicateM) | 45 | import Control.Monad(replicateM) |
44 | import qualified Data.ByteString.Internal as BS | 46 | import qualified Data.ByteString.Internal as BS |
45 | import Data.Vector.Storable.Internal(updPtr) | 47 | import Data.Vector.Storable.Internal(updPtr) |
@@ -92,6 +94,7 @@ createVector n = do | |||
92 | 94 | ||
93 | -} | 95 | -} |
94 | 96 | ||
97 | safeRead :: Storable a => Vector a -> (Ptr a -> IO c) -> c | ||
95 | safeRead v = inlinePerformIO . unsafeWith v | 98 | safeRead v = inlinePerformIO . unsafeWith v |
96 | {-# INLINE safeRead #-} | 99 | {-# INLINE safeRead #-} |
97 | 100 | ||
@@ -287,11 +290,13 @@ foldVectorWithIndex f x v = unsafePerformIO $ | |||
287 | go (dim v -1) x | 290 | go (dim v -1) x |
288 | {-# INLINE foldVectorWithIndex #-} | 291 | {-# INLINE foldVectorWithIndex #-} |
289 | 292 | ||
293 | foldLoop :: (Int -> t -> t) -> t -> Int -> t | ||
290 | foldLoop f s0 d = go (d - 1) s0 | 294 | foldLoop f s0 d = go (d - 1) s0 |
291 | where | 295 | where |
292 | go 0 s = f (0::Int) s | 296 | go 0 s = f (0::Int) s |
293 | go !j !s = go (j - 1) (f j s) | 297 | go !j !s = go (j - 1) (f j s) |
294 | 298 | ||
299 | foldVectorG :: Storable t1 => (Int -> (Int -> t1) -> t -> t) -> t -> Vector t1 -> t | ||
295 | foldVectorG f s0 v = foldLoop g s0 (dim v) | 300 | foldVectorG f s0 v = foldLoop g s0 (dim v) |
296 | where g !k !s = f k (safeRead v . flip peekElemOff) s | 301 | where g !k !s = f k (safeRead v . flip peekElemOff) s |
297 | {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479) | 302 | {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479) |
@@ -394,8 +399,10 @@ chunks d = let c = d `div` chunk | |||
394 | m = d `mod` chunk | 399 | m = d `mod` chunk |
395 | in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk) | 400 | in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk) |
396 | 401 | ||
402 | putVector :: (Storable t, Binary t) => Vector t -> Data.Binary.Put.PutM () | ||
397 | putVector v = mapM_ put $! toList v | 403 | putVector v = mapM_ put $! toList v |
398 | 404 | ||
405 | getVector :: (Storable a, Binary a) => Int -> Get (Vector a) | ||
399 | getVector d = do | 406 | getVector d = do |
400 | xs <- replicateM d get | 407 | xs <- replicateM d get |
401 | return $! fromList xs | 408 | return $! fromList xs |
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs index 2990173..32430c6 100644 --- a/packages/base/src/Internal/Vectorized.hs +++ b/packages/base/src/Internal/Vectorized.hs | |||
@@ -28,12 +28,15 @@ import System.IO.Unsafe(unsafePerformIO) | |||
28 | import Control.Monad(when) | 28 | import Control.Monad(when) |
29 | 29 | ||
30 | infixr 1 # | 30 | infixr 1 # |
31 | (#) :: TransArray c => c -> (b -> IO r) -> TransRaw c b -> IO r | ||
31 | a # b = applyRaw a b | 32 | a # b = applyRaw a b |
32 | {-# INLINE (#) #-} | 33 | {-# INLINE (#) #-} |
33 | 34 | ||
35 | (#!) :: (TransArray c, TransArray c1) => c1 -> c -> TransRaw c1 (TransRaw c (IO r)) -> IO r | ||
34 | a #! b = a # b # id | 36 | a #! b = a # b # id |
35 | {-# INLINE (#!) #-} | 37 | {-# INLINE (#!) #-} |
36 | 38 | ||
39 | fromei :: Enum a => a -> CInt | ||
37 | fromei x = fromIntegral (fromEnum x) :: CInt | 40 | fromei x = fromIntegral (fromEnum x) :: CInt |
38 | 41 | ||
39 | data FunCodeV = Sin | 42 | data FunCodeV = Sin |
@@ -100,10 +103,20 @@ sumQ = sumg c_sumQ | |||
100 | sumC :: Vector (Complex Double) -> Complex Double | 103 | sumC :: Vector (Complex Double) -> Complex Double |
101 | sumC = sumg c_sumC | 104 | sumC = sumg c_sumC |
102 | 105 | ||
106 | sumI :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr I -> I :> Ok) | ||
107 | , TransArray c | ||
108 | , Storable a | ||
109 | ) | ||
110 | => I -> c -> a | ||
103 | sumI m = sumg (c_sumI m) | 111 | sumI m = sumg (c_sumI m) |
104 | 112 | ||
113 | sumL :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr Z -> Z :> Ok) | ||
114 | , TransArray c | ||
115 | , Storable a | ||
116 | ) => Z -> c -> a | ||
105 | sumL m = sumg (c_sumL m) | 117 | sumL m = sumg (c_sumL m) |
106 | 118 | ||
119 | sumg :: (TransArray c, Storable a) => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a | ||
107 | sumg f x = unsafePerformIO $ do | 120 | sumg f x = unsafePerformIO $ do |
108 | r <- createVector 1 | 121 | r <- createVector 1 |
109 | (x #! r) f #| "sum" | 122 | (x #! r) f #| "sum" |
@@ -140,6 +153,8 @@ prodI = prodg . c_prodI | |||
140 | prodL :: Z-> Vector Z -> Z | 153 | prodL :: Z-> Vector Z -> Z |
141 | prodL = prodg . c_prodL | 154 | prodL = prodg . c_prodL |
142 | 155 | ||
156 | prodg :: (TransArray c, Storable a) | ||
157 | => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a | ||
143 | prodg f x = unsafePerformIO $ do | 158 | prodg f x = unsafePerformIO $ do |
144 | r <- createVector 1 | 159 | r <- createVector 1 |
145 | (x #! r) f #| "prod" | 160 | (x #! r) f #| "prod" |
@@ -155,16 +170,25 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z | |||
155 | 170 | ||
156 | ------------------------------------------------------------------ | 171 | ------------------------------------------------------------------ |
157 | 172 | ||
173 | toScalarAux :: (Enum a, TransArray c, Storable a1) | ||
174 | => (CInt -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) -> a -> c -> a1 | ||
158 | toScalarAux fun code v = unsafePerformIO $ do | 175 | toScalarAux fun code v = unsafePerformIO $ do |
159 | r <- createVector 1 | 176 | r <- createVector 1 |
160 | (v #! r) (fun (fromei code)) #|"toScalarAux" | 177 | (v #! r) (fun (fromei code)) #|"toScalarAux" |
161 | return (r @> 0) | 178 | return (r @> 0) |
162 | 179 | ||
180 | |||
181 | vectorMapAux :: (Enum a, Storable t, Storable a1) | ||
182 | => (CInt -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) | ||
183 | -> a -> Vector t -> Vector a1 | ||
163 | vectorMapAux fun code v = unsafePerformIO $ do | 184 | vectorMapAux fun code v = unsafePerformIO $ do |
164 | r <- createVector (dim v) | 185 | r <- createVector (dim v) |
165 | (v #! r) (fun (fromei code)) #|"vectorMapAux" | 186 | (v #! r) (fun (fromei code)) #|"vectorMapAux" |
166 | return r | 187 | return r |
167 | 188 | ||
189 | vectorMapValAux :: (Enum a, Storable a2, Storable t, Storable a1) | ||
190 | => (CInt -> Ptr a2 -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) | ||
191 | -> a -> a2 -> Vector t -> Vector a1 | ||
168 | vectorMapValAux fun code val v = unsafePerformIO $ do | 192 | vectorMapValAux fun code val v = unsafePerformIO $ do |
169 | r <- createVector (dim v) | 193 | r <- createVector (dim v) |
170 | pval <- newArray [val] | 194 | pval <- newArray [val] |
@@ -172,6 +196,9 @@ vectorMapValAux fun code val v = unsafePerformIO $ do | |||
172 | free pval | 196 | free pval |
173 | return r | 197 | return r |
174 | 198 | ||
199 | vectorZipAux :: (Enum a, TransArray c, Storable t, Storable a1) | ||
200 | => (CInt -> CInt -> Ptr t -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) | ||
201 | -> a -> Vector t -> c -> Vector a1 | ||
175 | vectorZipAux fun code u v = unsafePerformIO $ do | 202 | vectorZipAux fun code u v = unsafePerformIO $ do |
176 | r <- createVector (dim u) | 203 | r <- createVector (dim u) |
177 | (u # v #! r) (fun (fromei code)) #|"vectorZipAux" | 204 | (u # v #! r) (fun (fromei code)) #|"vectorZipAux" |
@@ -378,6 +405,7 @@ foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> D | |||
378 | 405 | ||
379 | -------------------------------------------------------------------------------- | 406 | -------------------------------------------------------------------------------- |
380 | 407 | ||
408 | roundVector :: Vector Double -> Vector Double | ||
381 | roundVector v = unsafePerformIO $ do | 409 | roundVector v = unsafePerformIO $ do |
382 | r <- createVector (dim v) | 410 | r <- createVector (dim v) |
383 | (v #! r) c_round_vector #|"roundVector" | 411 | (v #! r) c_round_vector #|"roundVector" |
@@ -433,6 +461,8 @@ long2intV :: Vector Z -> Vector I | |||
433 | long2intV = tog c_long2int | 461 | long2intV = tog c_long2int |
434 | 462 | ||
435 | 463 | ||
464 | tog :: (Storable t, Storable a) | ||
465 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
436 | tog f v = unsafePerformIO $ do | 466 | tog f v = unsafePerformIO $ do |
437 | r <- createVector (dim v) | 467 | r <- createVector (dim v) |
438 | (v #! r) f #|"tog" | 468 | (v #! r) f #|"tog" |
@@ -452,6 +482,8 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok | |||
452 | 482 | ||
453 | --------------------------------------------------------------- | 483 | --------------------------------------------------------------- |
454 | 484 | ||
485 | stepg :: (Storable t, Storable a) | ||
486 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
455 | stepg f v = unsafePerformIO $ do | 487 | stepg f v = unsafePerformIO $ do |
456 | r <- createVector (dim v) | 488 | r <- createVector (dim v) |
457 | (v #! r) f #|"step" | 489 | (v #! r) f #|"step" |
@@ -477,6 +509,8 @@ foreign import ccall unsafe "stepL" c_stepL :: TVV Z | |||
477 | 509 | ||
478 | -------------------------------------------------------------------------------- | 510 | -------------------------------------------------------------------------------- |
479 | 511 | ||
512 | conjugateAux :: (Storable t, Storable a) | ||
513 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
480 | conjugateAux fun x = unsafePerformIO $ do | 514 | conjugateAux fun x = unsafePerformIO $ do |
481 | v <- createVector (dim x) | 515 | v <- createVector (dim x) |
482 | (x #! v) fun #|"conjugateAux" | 516 | (x #! v) fun #|"conjugateAux" |
@@ -502,6 +536,8 @@ cloneVector v = do | |||
502 | 536 | ||
503 | -------------------------------------------------------------------------------- | 537 | -------------------------------------------------------------------------------- |
504 | 538 | ||
539 | constantAux :: (Storable a1, Storable a) | ||
540 | => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a | ||
505 | constantAux fun x n = unsafePerformIO $ do | 541 | constantAux fun x n = unsafePerformIO $ do |
506 | v <- createVector n | 542 | v <- createVector n |
507 | px <- newArray [x] | 543 | px <- newArray [x] |
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 520eeb7..91923e9 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs | |||
@@ -1,6 +1,8 @@ | |||
1 | {-# LANGUAGE CPP #-} | 1 | {-# LANGUAGE CPP #-} |
2 | {-# LANGUAGE FlexibleContexts #-} | 2 | {-# LANGUAGE FlexibleContexts #-} |
3 | 3 | ||
4 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
5 | |||
4 | ----------------------------------------------------------------------------- | 6 | ----------------------------------------------------------------------------- |
5 | {- | | 7 | {- | |
6 | Module : Numeric.LinearAlgebra | 8 | Module : Numeric.LinearAlgebra |
diff --git a/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs b/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs index 3a84645..57e5cf1 100644 --- a/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs +++ b/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs | |||
@@ -28,7 +28,9 @@ infixr 8 <·> | |||
28 | (<·>) :: Numeric t => Vector t -> Vector t -> t | 28 | (<·>) :: Numeric t => Vector t -> Vector t -> t |
29 | (<·>) = dot | 29 | (<·>) = dot |
30 | 30 | ||
31 | app :: Numeric t => Matrix t -> Vector t -> Vector t | ||
31 | app m v = m #> v | 32 | app m v = m #> v |
32 | 33 | ||
34 | mul :: Numeric t => Matrix t -> Matrix t -> Matrix t | ||
33 | mul a b = a <> b | 35 | mul a b = a <> b |
34 | 36 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index e328904..2e05c90 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -14,6 +14,8 @@ | |||
14 | {-# LANGUAGE GADTs #-} | 14 | {-# LANGUAGE GADTs #-} |
15 | {-# LANGUAGE TypeFamilies #-} | 15 | {-# LANGUAGE TypeFamilies #-} |
16 | 16 | ||
17 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
18 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
17 | 19 | ||
18 | {- | | 20 | {- | |
19 | Module : Numeric.LinearAlgebra.Static | 21 | Module : Numeric.LinearAlgebra.Static |
diff --git a/packages/base/src/Numeric/Matrix.hs b/packages/base/src/Numeric/Matrix.hs index 06da150..6e3db61 100644 --- a/packages/base/src/Numeric/Matrix.hs +++ b/packages/base/src/Numeric/Matrix.hs | |||
@@ -4,6 +4,8 @@ | |||
4 | {-# LANGUAGE UndecidableInstances #-} | 4 | {-# LANGUAGE UndecidableInstances #-} |
5 | {-# LANGUAGE MultiParamTypeClasses #-} | 5 | {-# LANGUAGE MultiParamTypeClasses #-} |
6 | 6 | ||
7 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
8 | |||
7 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
8 | -- | | 10 | -- | |
9 | -- Module : Numeric.Matrix | 11 | -- Module : Numeric.Matrix |
@@ -35,6 +37,7 @@ import Data.List(partition) | |||
35 | import qualified Data.Foldable as F | 37 | import qualified Data.Foldable as F |
36 | import qualified Data.Semigroup as S | 38 | import qualified Data.Semigroup as S |
37 | import Internal.Chain | 39 | import Internal.Chain |
40 | import Foreign.Storable(Storable) | ||
38 | 41 | ||
39 | 42 | ||
40 | ------------------------------------------------------------------- | 43 | ------------------------------------------------------------------- |
@@ -80,8 +83,16 @@ instance (Floating a, Container Vector a, Floating (Vector a), Fractional (Matri | |||
80 | 83 | ||
81 | -------------------------------------------------------------------------------- | 84 | -------------------------------------------------------------------------------- |
82 | 85 | ||
86 | isScalar :: Matrix t -> Bool | ||
83 | isScalar m = rows m == 1 && cols m == 1 | 87 | isScalar m = rows m == 1 && cols m == 1 |
84 | 88 | ||
89 | adaptScalarM :: (Foreign.Storable.Storable t1, Foreign.Storable.Storable t2) | ||
90 | => (t1 -> Matrix t2 -> t) | ||
91 | -> (Matrix t1 -> Matrix t2 -> t) | ||
92 | -> (Matrix t1 -> t2 -> t) | ||
93 | -> Matrix t1 | ||
94 | -> Matrix t2 | ||
95 | -> t | ||
85 | adaptScalarM f1 f2 f3 x y | 96 | adaptScalarM f1 f2 f3 x y |
86 | | isScalar x = f1 (x @@>(0,0) ) y | 97 | | isScalar x = f1 (x @@>(0,0) ) y |
87 | | isScalar y = f3 x (y @@>(0,0) ) | 98 | | isScalar y = f3 x (y @@>(0,0) ) |
@@ -96,7 +107,7 @@ instance (Container Vector t, Eq t, Num (Vector t), Product t) => M.Monoid (Matr | |||
96 | where | 107 | where |
97 | mempty = 1 | 108 | mempty = 1 |
98 | mappend = adaptScalarM scale mXm (flip scale) | 109 | mappend = adaptScalarM scale mXm (flip scale) |
99 | 110 | ||
100 | mconcat xs = work (partition isScalar xs) | 111 | mconcat xs = work (partition isScalar xs) |
101 | where | 112 | where |
102 | work (ss,[]) = product ss | 113 | work (ss,[]) = product ss |
@@ -106,4 +117,3 @@ instance (Container Vector t, Eq t, Num (Vector t), Product t) => M.Monoid (Matr | |||
106 | | otherwise = scale x00 m | 117 | | otherwise = scale x00 m |
107 | where | 118 | where |
108 | x00 = x @@> (0,0) | 119 | x00 = x @@> (0,0) |
109 | |||
diff --git a/packages/base/src/Numeric/Vector.hs b/packages/base/src/Numeric/Vector.hs index 017196c..1e5877d 100644 --- a/packages/base/src/Numeric/Vector.hs +++ b/packages/base/src/Numeric/Vector.hs | |||
@@ -3,6 +3,9 @@ | |||
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | {-# LANGUAGE UndecidableInstances #-} | 4 | {-# LANGUAGE UndecidableInstances #-} |
5 | {-# LANGUAGE MultiParamTypeClasses #-} | 5 | {-# LANGUAGE MultiParamTypeClasses #-} |
6 | |||
7 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
8 | |||
6 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
7 | -- | | 10 | -- | |
8 | -- Module : Numeric.Vector | 11 | -- Module : Numeric.Vector |
@@ -14,7 +17,7 @@ | |||
14 | -- | 17 | -- |
15 | -- Provides instances of standard classes 'Show', 'Read', 'Eq', | 18 | -- Provides instances of standard classes 'Show', 'Read', 'Eq', |
16 | -- 'Num', 'Fractional', and 'Floating' for 'Vector'. | 19 | -- 'Num', 'Fractional', and 'Floating' for 'Vector'. |
17 | -- | 20 | -- |
18 | ----------------------------------------------------------------------------- | 21 | ----------------------------------------------------------------------------- |
19 | 22 | ||
20 | module Numeric.Vector () where | 23 | module Numeric.Vector () where |
@@ -23,9 +26,17 @@ import Internal.Vectorized | |||
23 | import Internal.Vector | 26 | import Internal.Vector |
24 | import Internal.Numeric | 27 | import Internal.Numeric |
25 | import Internal.Conversion | 28 | import Internal.Conversion |
29 | import Foreign.Storable(Storable) | ||
26 | 30 | ||
27 | ------------------------------------------------------------------- | 31 | ------------------------------------------------------------------- |
28 | 32 | ||
33 | adaptScalar :: (Foreign.Storable.Storable t1, Foreign.Storable.Storable t2) | ||
34 | => (t1 -> Vector t2 -> t) | ||
35 | -> (Vector t1 -> Vector t2 -> t) | ||
36 | -> (Vector t1 -> t2 -> t) | ||
37 | -> Vector t1 | ||
38 | -> Vector t2 | ||
39 | -> t | ||
29 | adaptScalar f1 f2 f3 x y | 40 | adaptScalar f1 f2 f3 x y |
30 | | dim x == 1 = f1 (x@>0) y | 41 | | dim x == 1 = f1 (x@>0) y |
31 | | dim y == 1 = f3 x (y@>0) | 42 | | dim y == 1 = f3 x (y@>0) |
@@ -172,4 +183,3 @@ instance Floating (Vector (Complex Float)) where | |||
172 | sqrt = vectorMapQ Sqrt | 183 | sqrt = vectorMapQ Sqrt |
173 | (**) = adaptScalar (vectorMapValQ PowSV) (vectorZipQ Pow) (flip (vectorMapValQ PowVS)) | 184 | (**) = adaptScalar (vectorMapValQ PowSV) (vectorZipQ Pow) (flip (vectorMapValQ PowVS)) |
174 | pi = fromList [pi] | 185 | pi = fromList [pi] |
175 | |||