diff options
Diffstat (limited to 'packages')
68 files changed, 2279 insertions, 62 deletions
diff --git a/packages/base/CHANGELOG b/packages/base/CHANGELOG index 11a57dd..fd1c171 100644 --- a/packages/base/CHANGELOG +++ b/packages/base/CHANGELOG | |||
@@ -2,15 +2,15 @@ | |||
2 | -------- | 2 | -------- |
3 | 3 | ||
4 | * Many new functions and instances in the Static module | 4 | * Many new functions and instances in the Static module |
5 | 5 | ||
6 | * meanCov and gaussianSample use Herm type | 6 | * meanCov and gaussianSample use Herm type |
7 | 7 | ||
8 | * thinQR, thinRQ | 8 | * thinQR, thinRQ |
9 | 9 | ||
10 | * compactSVDTol | 10 | * compactSVDTol |
11 | 11 | ||
12 | * unitary changed to normalize, also admits Vector (Complex Double) | 12 | * unitary changed to normalize, also admits Vector (Complex Double) |
13 | 13 | ||
14 | 0.17.0.0 | 14 | 0.17.0.0 |
15 | -------- | 15 | -------- |
16 | 16 | ||
@@ -288,4 +288,3 @@ | |||
288 | * added NFData instances for Matrix and Vector. | 288 | * added NFData instances for Matrix and Vector. |
289 | 289 | ||
290 | * liftVector, liftVector2 replaced by mapVector, zipVector. | 290 | * liftVector, liftVector2 replaced by mapVector, zipVector. |
291 | |||
diff --git a/packages/base/hmatrix.cabal b/packages/base/hmatrix.cabal index 6368683..1380a36 100644 --- a/packages/base/hmatrix.cabal +++ b/packages/base/hmatrix.cabal | |||
@@ -1,5 +1,5 @@ | |||
1 | Name: hmatrix | 1 | Name: hmatrix |
2 | Version: 0.18.2.0 | 2 | Version: 0.19.0.0 |
3 | License: BSD3 | 3 | License: BSD3 |
4 | License-file: LICENSE | 4 | License-file: LICENSE |
5 | Author: Alberto Ruiz | 5 | Author: Alberto Ruiz |
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs index 5fe7796..6027c46 100644 --- a/packages/base/src/Internal/Algorithms.hs +++ b/packages/base/src/Internal/Algorithms.hs | |||
@@ -4,6 +4,8 @@ | |||
4 | {-# LANGUAGE UndecidableInstances #-} | 4 | {-# LANGUAGE UndecidableInstances #-} |
5 | {-# LANGUAGE TypeFamilies #-} | 5 | {-# LANGUAGE TypeFamilies #-} |
6 | 6 | ||
7 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
8 | |||
7 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
8 | {- | | 10 | {- | |
9 | Module : Internal.Algorithms | 11 | Module : Internal.Algorithms |
diff --git a/packages/base/src/Internal/CG.hs b/packages/base/src/Internal/CG.hs index cc10ad8..29edd35 100644 --- a/packages/base/src/Internal/CG.hs +++ b/packages/base/src/Internal/CG.hs | |||
@@ -1,6 +1,8 @@ | |||
1 | {-# LANGUAGE FlexibleContexts, FlexibleInstances #-} | 1 | {-# LANGUAGE FlexibleContexts, FlexibleInstances #-} |
2 | {-# LANGUAGE RecordWildCards #-} | 2 | {-# LANGUAGE RecordWildCards #-} |
3 | 3 | ||
4 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
5 | |||
4 | module Internal.CG( | 6 | module Internal.CG( |
5 | cgSolve, cgSolve', | 7 | cgSolve, cgSolve', |
6 | CGState(..), R, V | 8 | CGState(..), R, V |
diff --git a/packages/base/src/Internal/Chain.hs b/packages/base/src/Internal/Chain.hs index f87eb02..4000c2b 100644 --- a/packages/base/src/Internal/Chain.hs +++ b/packages/base/src/Internal/Chain.hs | |||
@@ -1,5 +1,7 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | 1 | {-# LANGUAGE FlexibleContexts #-} |
2 | 2 | ||
3 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
4 | |||
3 | ----------------------------------------------------------------------------- | 5 | ----------------------------------------------------------------------------- |
4 | -- | | 6 | -- | |
5 | -- Module : Internal.Chain | 7 | -- Module : Internal.Chain |
diff --git a/packages/base/src/Internal/Container.hs b/packages/base/src/Internal/Container.hs index cdcdad0..41b8214 100644 --- a/packages/base/src/Internal/Container.hs +++ b/packages/base/src/Internal/Container.hs | |||
@@ -5,6 +5,8 @@ | |||
5 | {-# LANGUAGE FunctionalDependencies #-} | 5 | {-# LANGUAGE FunctionalDependencies #-} |
6 | {-# LANGUAGE UndecidableInstances #-} | 6 | {-# LANGUAGE UndecidableInstances #-} |
7 | 7 | ||
8 | {-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-} | ||
9 | |||
8 | ----------------------------------------------------------------------------- | 10 | ----------------------------------------------------------------------------- |
9 | -- | | 11 | -- | |
10 | -- Module : Internal.Container | 12 | -- Module : Internal.Container |
diff --git a/packages/base/src/Internal/Devel.hs b/packages/base/src/Internal/Devel.hs index 3887663..f72d8aa 100644 --- a/packages/base/src/Internal/Devel.hs +++ b/packages/base/src/Internal/Devel.hs | |||
@@ -54,6 +54,7 @@ check msg f = do | |||
54 | 54 | ||
55 | -- | postfix error code check | 55 | -- | postfix error code check |
56 | infixl 0 #| | 56 | infixl 0 #| |
57 | (#|) :: IO CInt -> String -> IO () | ||
57 | (#|) = flip check | 58 | (#|) = flip check |
58 | 59 | ||
59 | -- | Error capture and conversion to Maybe | 60 | -- | Error capture and conversion to Maybe |
diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs index eb3a25b..2e330ee 100644 --- a/packages/base/src/Internal/Element.hs +++ b/packages/base/src/Internal/Element.hs | |||
@@ -4,6 +4,8 @@ | |||
4 | {-# LANGUAGE UndecidableInstances #-} | 4 | {-# LANGUAGE UndecidableInstances #-} |
5 | {-# LANGUAGE MultiParamTypeClasses #-} | 5 | {-# LANGUAGE MultiParamTypeClasses #-} |
6 | 6 | ||
7 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
8 | |||
7 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
8 | -- | | 10 | -- | |
9 | -- Module : Data.Packed.Matrix | 11 | -- Module : Data.Packed.Matrix |
@@ -31,6 +33,7 @@ import Data.List.Split(chunksOf) | |||
31 | import Foreign.Storable(Storable) | 33 | import Foreign.Storable(Storable) |
32 | import System.IO.Unsafe(unsafePerformIO) | 34 | import System.IO.Unsafe(unsafePerformIO) |
33 | import Control.Monad(liftM) | 35 | import Control.Monad(liftM) |
36 | import Foreign.C.Types(CInt) | ||
34 | 37 | ||
35 | ------------------------------------------------------------------- | 38 | ------------------------------------------------------------------- |
36 | 39 | ||
@@ -53,8 +56,10 @@ instance (Show a, Element a) => (Show (Matrix a)) where | |||
53 | show m | rows m == 0 || cols m == 0 = sizes m ++" []" | 56 | show m | rows m == 0 || cols m == 0 = sizes m ++" []" |
54 | show m = (sizes m++) . dsp . map (map show) . toLists $ m | 57 | show m = (sizes m++) . dsp . map (map show) . toLists $ m |
55 | 58 | ||
59 | sizes :: Matrix t -> [Char] | ||
56 | sizes m = "("++show (rows m)++"><"++show (cols m)++")\n" | 60 | sizes m = "("++show (rows m)++"><"++show (cols m)++")\n" |
57 | 61 | ||
62 | dsp :: [[[Char]]] -> [Char] | ||
58 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp | 63 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp |
59 | where | 64 | where |
60 | mt = transpose as | 65 | mt = transpose as |
@@ -73,6 +78,7 @@ instance (Element a, Read a) => Read (Matrix a) where | |||
73 | rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims | 78 | rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims |
74 | 79 | ||
75 | 80 | ||
81 | breakAt :: Eq a => a -> [a] -> ([a], [a]) | ||
76 | breakAt c l = (a++[c],tail b) where | 82 | breakAt c l = (a++[c],tail b) where |
77 | (a,b) = break (==c) l | 83 | (a,b) = break (==c) l |
78 | 84 | ||
@@ -88,7 +94,8 @@ data Extractor | |||
88 | | Drop Int | 94 | | Drop Int |
89 | | DropLast Int | 95 | | DropLast Int |
90 | deriving Show | 96 | deriving Show |
91 | 97 | ||
98 | ppext :: Extractor -> [Char] | ||
92 | ppext All = ":" | 99 | ppext All = ":" |
93 | ppext (Range a 1 c) = printf "%d:%d" a c | 100 | ppext (Range a 1 c) = printf "%d:%d" a c |
94 | ppext (Range a b c) = printf "%d:%d:%d" a b c | 101 | ppext (Range a b c) = printf "%d:%d:%d" a b c |
@@ -128,10 +135,14 @@ ppext (DropLast n) = printf "DropLast %d" n | |||
128 | infixl 9 ?? | 135 | infixl 9 ?? |
129 | (??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t | 136 | (??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t |
130 | 137 | ||
138 | minEl :: Vector CInt -> CInt | ||
131 | minEl = toScalarI Min | 139 | minEl = toScalarI Min |
140 | maxEl :: Vector CInt -> CInt | ||
132 | maxEl = toScalarI Max | 141 | maxEl = toScalarI Max |
142 | cmodi :: Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt | ||
133 | cmodi = vectorMapValI ModVS | 143 | cmodi = vectorMapValI ModVS |
134 | 144 | ||
145 | extractError :: Matrix t1 -> (Extractor, Extractor) -> t | ||
135 | extractError m (e1,e2)= error $ printf "can't extract (%s,%s) from matrix %dx%d" (ppext e1::String) (ppext e2::String) (rows m) (cols m) | 146 | extractError m (e1,e2)= error $ printf "can't extract (%s,%s) from matrix %dx%d" (ppext e1::String) (ppext e2::String) (rows m) (cols m) |
136 | 147 | ||
137 | m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e) | 148 | m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e) |
@@ -232,8 +243,10 @@ disp = putStr . dispf 2 | |||
232 | fromBlocks :: Element t => [[Matrix t]] -> Matrix t | 243 | fromBlocks :: Element t => [[Matrix t]] -> Matrix t |
233 | fromBlocks = fromBlocksRaw . adaptBlocks | 244 | fromBlocks = fromBlocksRaw . adaptBlocks |
234 | 245 | ||
246 | fromBlocksRaw :: Element t => [[Matrix t]] -> Matrix t | ||
235 | fromBlocksRaw mms = joinVert . map joinHoriz $ mms | 247 | fromBlocksRaw mms = joinVert . map joinHoriz $ mms |
236 | 248 | ||
249 | adaptBlocks :: Element t => [[Matrix t]] -> [[Matrix t]] | ||
237 | adaptBlocks ms = ms' where | 250 | adaptBlocks ms = ms' where |
238 | bc = case common length ms of | 251 | bc = case common length ms of |
239 | Just c -> c | 252 | Just c -> c |
@@ -486,6 +499,9 @@ liftMatrix2Auto f m1 m2 | |||
486 | m2' = conformMTo (r,c) m2 | 499 | m2' = conformMTo (r,c) m2 |
487 | 500 | ||
488 | -- FIXME do not flatten if equal order | 501 | -- FIXME do not flatten if equal order |
502 | lM :: (Storable t, Element t1, Element t2) | ||
503 | => (Vector t1 -> Vector t2 -> Vector t) | ||
504 | -> Matrix t1 -> Matrix t2 -> Matrix t | ||
489 | lM f m1 m2 = matrixFromVector | 505 | lM f m1 m2 = matrixFromVector |
490 | RowMajor | 506 | RowMajor |
491 | (max' (rows m1) (rows m2)) | 507 | (max' (rows m1) (rows m2)) |
@@ -504,6 +520,7 @@ compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 | |||
504 | 520 | ||
505 | ------------------------------------------------------------ | 521 | ------------------------------------------------------------ |
506 | 522 | ||
523 | toBlockRows :: Element t => [Int] -> Matrix t -> [Matrix t] | ||
507 | toBlockRows [r] m | 524 | toBlockRows [r] m |
508 | | r == rows m = [m] | 525 | | r == rows m = [m] |
509 | toBlockRows rs m | 526 | toBlockRows rs m |
@@ -513,6 +530,7 @@ toBlockRows rs m | |||
513 | szs = map (* cols m) rs | 530 | szs = map (* cols m) rs |
514 | g k = (k><0)[] | 531 | g k = (k><0)[] |
515 | 532 | ||
533 | toBlockCols :: Element t => [Int] -> Matrix t -> [Matrix t] | ||
516 | toBlockCols [c] m | c == cols m = [m] | 534 | toBlockCols [c] m | c == cols m = [m] |
517 | toBlockCols cs m = map trans . toBlockRows cs . trans $ m | 535 | toBlockCols cs m = map trans . toBlockRows cs . trans $ m |
518 | 536 | ||
@@ -576,7 +594,7 @@ Just (3><3) | |||
576 | mapMatrixWithIndexM | 594 | mapMatrixWithIndexM |
577 | :: (Element a, Storable b, Monad m) => | 595 | :: (Element a, Storable b, Monad m) => |
578 | ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) | 596 | ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) |
579 | mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m | 597 | mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m |
580 | where | 598 | where |
581 | c = cols m | 599 | c = cols m |
582 | 600 | ||
@@ -598,4 +616,3 @@ mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m | |||
598 | 616 | ||
599 | mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b | 617 | mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b |
600 | mapMatrix f = liftMatrix (mapVector f) | 618 | mapMatrix f = liftMatrix (mapVector f) |
601 | |||
diff --git a/packages/base/src/Internal/IO.hs b/packages/base/src/Internal/IO.hs index a899cfd..b0f5606 100644 --- a/packages/base/src/Internal/IO.hs +++ b/packages/base/src/Internal/IO.hs | |||
@@ -20,7 +20,7 @@ import Internal.Devel | |||
20 | import Internal.Vector | 20 | import Internal.Vector |
21 | import Internal.Matrix | 21 | import Internal.Matrix |
22 | import Internal.Vectorized | 22 | import Internal.Vectorized |
23 | import Text.Printf(printf) | 23 | import Text.Printf(printf, PrintfArg, PrintfType) |
24 | import Data.List(intersperse,transpose) | 24 | import Data.List(intersperse,transpose) |
25 | import Data.Complex | 25 | import Data.Complex |
26 | 26 | ||
@@ -78,12 +78,18 @@ disps d x = sdims x ++ " " ++ formatScaled d x | |||
78 | dispf :: Int -> Matrix Double -> String | 78 | dispf :: Int -> Matrix Double -> String |
79 | dispf d x = sdims x ++ "\n" ++ formatFixed (if isInt x then 0 else d) x | 79 | dispf d x = sdims x ++ "\n" ++ formatFixed (if isInt x then 0 else d) x |
80 | 80 | ||
81 | sdims :: Matrix t -> [Char] | ||
81 | sdims x = show (rows x) ++ "x" ++ show (cols x) | 82 | sdims x = show (rows x) ++ "x" ++ show (cols x) |
82 | 83 | ||
84 | formatFixed :: (Show a, Text.Printf.PrintfArg t, Element t) | ||
85 | => a -> Matrix t -> String | ||
83 | formatFixed d x = format " " (printf ("%."++show d++"f")) $ x | 86 | formatFixed d x = format " " (printf ("%."++show d++"f")) $ x |
84 | 87 | ||
88 | isInt :: Matrix Double -> Bool | ||
85 | isInt = all lookslikeInt . toList . flatten | 89 | isInt = all lookslikeInt . toList . flatten |
86 | 90 | ||
91 | formatScaled :: (Text.Printf.PrintfArg b, RealFrac b, Floating b, Num t, Element b, Show t) | ||
92 | => t -> Matrix b -> [Char] | ||
87 | formatScaled dec t = "E"++show o++"\n" ++ ss | 93 | formatScaled dec t = "E"++show o++"\n" ++ ss |
88 | where ss = format " " (printf fmt. g) t | 94 | where ss = format " " (printf fmt. g) t |
89 | g x | o >= 0 = x/10^(o::Int) | 95 | g x | o >= 0 = x/10^(o::Int) |
@@ -133,14 +139,18 @@ showComplex d (a:+b) | |||
133 | s2 = if b<0 then "-" else "" | 139 | s2 = if b<0 then "-" else "" |
134 | s3 = if b<0 then "-" else "+" | 140 | s3 = if b<0 then "-" else "+" |
135 | 141 | ||
142 | shcr :: (Show a, Show t1, Text.Printf.PrintfType t, Text.Printf.PrintfArg t1, RealFrac t1) | ||
143 | => a -> t1 -> t | ||
136 | shcr d a | lookslikeInt a = printf "%.0f" a | 144 | shcr d a | lookslikeInt a = printf "%.0f" a |
137 | | otherwise = printf ("%."++show d++"f") a | 145 | | otherwise = printf ("%."++show d++"f") a |
138 | 146 | ||
139 | 147 | lookslikeInt :: (Show a, RealFrac a) => a -> Bool | |
140 | lookslikeInt x = show (round x :: Int) ++".0" == shx || "-0.0" == shx | 148 | lookslikeInt x = show (round x :: Int) ++".0" == shx || "-0.0" == shx |
141 | where shx = show x | 149 | where shx = show x |
142 | 150 | ||
151 | isZero :: Show a => a -> Bool | ||
143 | isZero x = show x `elem` ["0.0","-0.0"] | 152 | isZero x = show x `elem` ["0.0","-0.0"] |
153 | isOne :: Show a => a -> Bool | ||
144 | isOne x = show x `elem` ["1.0","-1.0"] | 154 | isOne x = show x `elem` ["1.0","-1.0"] |
145 | 155 | ||
146 | -- | Pretty print a complex matrix with at most n decimal digits. | 156 | -- | Pretty print a complex matrix with at most n decimal digits. |
@@ -168,6 +178,6 @@ loadMatrix f = do | |||
168 | else | 178 | else |
169 | return (reshape c v) | 179 | return (reshape c v) |
170 | 180 | ||
171 | 181 | loadMatrix' :: FilePath -> IO (Maybe (Matrix Double)) | |
172 | loadMatrix' name = mbCatch (loadMatrix name) | 182 | loadMatrix' name = mbCatch (loadMatrix name) |
173 | 183 | ||
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index e306454..64cf2f5 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -1,6 +1,8 @@ | |||
1 | {-# LANGUAGE TypeOperators #-} | 1 | {-# LANGUAGE TypeOperators #-} |
2 | {-# LANGUAGE ViewPatterns #-} | 2 | {-# LANGUAGE ViewPatterns #-} |
3 | 3 | ||
4 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
5 | |||
4 | ----------------------------------------------------------------------------- | 6 | ----------------------------------------------------------------------------- |
5 | -- | | 7 | -- | |
6 | -- Module : Numeric.LinearAlgebra.LAPACK | 8 | -- Module : Numeric.LinearAlgebra.LAPACK |
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 2856ec2..5436e59 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -57,19 +57,24 @@ cols :: Matrix t -> Int | |||
57 | cols = icols | 57 | cols = icols |
58 | {-# INLINE cols #-} | 58 | {-# INLINE cols #-} |
59 | 59 | ||
60 | size :: Matrix t -> (Int, Int) | ||
60 | size m = (irows m, icols m) | 61 | size m = (irows m, icols m) |
61 | {-# INLINE size #-} | 62 | {-# INLINE size #-} |
62 | 63 | ||
64 | rowOrder :: Matrix t -> Bool | ||
63 | rowOrder m = xCol m == 1 || cols m == 1 | 65 | rowOrder m = xCol m == 1 || cols m == 1 |
64 | {-# INLINE rowOrder #-} | 66 | {-# INLINE rowOrder #-} |
65 | 67 | ||
68 | colOrder :: Matrix t -> Bool | ||
66 | colOrder m = xRow m == 1 || rows m == 1 | 69 | colOrder m = xRow m == 1 || rows m == 1 |
67 | {-# INLINE colOrder #-} | 70 | {-# INLINE colOrder #-} |
68 | 71 | ||
72 | is1d :: Matrix t -> Bool | ||
69 | is1d (size->(r,c)) = r==1 || c==1 | 73 | is1d (size->(r,c)) = r==1 || c==1 |
70 | {-# INLINE is1d #-} | 74 | {-# INLINE is1d #-} |
71 | 75 | ||
72 | -- data is not contiguous | 76 | -- data is not contiguous |
77 | isSlice :: Storable t => Matrix t -> Bool | ||
73 | isSlice m@(size->(r,c)) = r*c < dim (xdat m) | 78 | isSlice m@(size->(r,c)) = r*c < dim (xdat m) |
74 | {-# INLINE isSlice #-} | 79 | {-# INLINE isSlice #-} |
75 | 80 | ||
@@ -136,16 +141,20 @@ instance Storable t => TransArray (Matrix t) | |||
136 | {-# INLINE applyRaw #-} | 141 | {-# INLINE applyRaw #-} |
137 | 142 | ||
138 | infixr 1 # | 143 | infixr 1 # |
144 | (#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r | ||
139 | a # b = apply a b | 145 | a # b = apply a b |
140 | {-# INLINE (#) #-} | 146 | {-# INLINE (#) #-} |
141 | 147 | ||
148 | (#!) :: (TransArray c, TransArray c1) => c1 -> c -> Trans c1 (Trans c (IO r)) -> IO r | ||
142 | a #! b = a # b # id | 149 | a #! b = a # b # id |
143 | {-# INLINE (#!) #-} | 150 | {-# INLINE (#!) #-} |
144 | 151 | ||
145 | -------------------------------------------------------------------------------- | 152 | -------------------------------------------------------------------------------- |
146 | 153 | ||
154 | copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t) | ||
147 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) | 155 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) |
148 | 156 | ||
157 | extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t | ||
149 | extractAll ord m = unsafePerformIO (copy ord m) | 158 | extractAll ord m = unsafePerformIO (copy ord m) |
150 | 159 | ||
151 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. | 160 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. |
@@ -224,11 +233,13 @@ m@Matrix {irows = r, icols = c} @@> (i,j) | |||
224 | {-# INLINE (@@>) #-} | 233 | {-# INLINE (@@>) #-} |
225 | 234 | ||
226 | -- Unsafe matrix access without range checking | 235 | -- Unsafe matrix access without range checking |
236 | atM' :: Storable t => Matrix t -> Int -> Int -> t | ||
227 | atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) | 237 | atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) |
228 | {-# INLINE atM' #-} | 238 | {-# INLINE atM' #-} |
229 | 239 | ||
230 | ------------------------------------------------------------------ | 240 | ------------------------------------------------------------------ |
231 | 241 | ||
242 | matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t | ||
232 | matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } | 243 | matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } |
233 | matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } | 244 | matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } |
234 | matrixFromVector o r c v | 245 | matrixFromVector o r c v |
@@ -388,18 +399,21 @@ subMatrix (r0,c0) (rt,ct) m | |||
388 | 399 | ||
389 | -------------------------------------------------------------------------- | 400 | -------------------------------------------------------------------------- |
390 | 401 | ||
402 | maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1 | ||
391 | maxZ xs = if minimum xs == 0 then 0 else maximum xs | 403 | maxZ xs = if minimum xs == 0 then 0 else maximum xs |
392 | 404 | ||
405 | conformMs :: Element t => [Matrix t] -> [Matrix t] | ||
393 | conformMs ms = map (conformMTo (r,c)) ms | 406 | conformMs ms = map (conformMTo (r,c)) ms |
394 | where | 407 | where |
395 | r = maxZ (map rows ms) | 408 | r = maxZ (map rows ms) |
396 | c = maxZ (map cols ms) | 409 | c = maxZ (map cols ms) |
397 | 410 | ||
398 | 411 | conformVs :: Element t => [Vector t] -> [Vector t] | |
399 | conformVs vs = map (conformVTo n) vs | 412 | conformVs vs = map (conformVTo n) vs |
400 | where | 413 | where |
401 | n = maxZ (map dim vs) | 414 | n = maxZ (map dim vs) |
402 | 415 | ||
416 | conformMTo :: Element t => (Int, Int) -> Matrix t -> Matrix t | ||
403 | conformMTo (r,c) m | 417 | conformMTo (r,c) m |
404 | | size m == (r,c) = m | 418 | | size m == (r,c) = m |
405 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) | 419 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) |
@@ -407,18 +421,24 @@ conformMTo (r,c) m | |||
407 | | size m == (1,c) = repRows r m | 421 | | size m == (1,c) = repRows r m |
408 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) | 422 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) |
409 | 423 | ||
424 | conformVTo :: Element t => Int -> Vector t -> Vector t | ||
410 | conformVTo n v | 425 | conformVTo n v |
411 | | dim v == n = v | 426 | | dim v == n = v |
412 | | dim v == 1 = constantD (v@>0) n | 427 | | dim v == 1 = constantD (v@>0) n |
413 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n | 428 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n |
414 | 429 | ||
430 | repRows :: Element t => Int -> Matrix t -> Matrix t | ||
415 | repRows n x = fromRows (replicate n (flatten x)) | 431 | repRows n x = fromRows (replicate n (flatten x)) |
432 | repCols :: Element t => Int -> Matrix t -> Matrix t | ||
416 | repCols n x = fromColumns (replicate n (flatten x)) | 433 | repCols n x = fromColumns (replicate n (flatten x)) |
417 | 434 | ||
435 | shSize :: Matrix t -> [Char] | ||
418 | shSize = shDim . size | 436 | shSize = shDim . size |
419 | 437 | ||
438 | shDim :: (Show a, Show a1) => (a1, a) -> [Char] | ||
420 | shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" | 439 | shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" |
421 | 440 | ||
441 | emptyM :: Storable t => Int -> Int -> Matrix t | ||
422 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) | 442 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) |
423 | 443 | ||
424 | ---------------------------------------------------------------------- | 444 | ---------------------------------------------------------------------- |
@@ -433,6 +453,11 @@ instance (Storable t, NFData t) => NFData (Matrix t) | |||
433 | 453 | ||
434 | --------------------------------------------------------------- | 454 | --------------------------------------------------------------- |
435 | 455 | ||
456 | extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, | ||
457 | Storable t, Num t3, Num t2, Integral t1, Integral t) | ||
458 | => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t | ||
459 | -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) | ||
460 | -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) | ||
436 | extractAux f ord m moder vr modec vc = do | 461 | extractAux f ord m moder vr modec vc = do |
437 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | 462 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr |
438 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | 463 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc |
@@ -452,6 +477,9 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z | |||
452 | 477 | ||
453 | --------------------------------------------------------------- | 478 | --------------------------------------------------------------- |
454 | 479 | ||
480 | setRectAux :: (TransArray c1, TransArray c) | ||
481 | => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) | ||
482 | -> Int -> Int -> c1 -> c -> IO () | ||
455 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" | 483 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" |
456 | 484 | ||
457 | type SetRect x = I -> I -> x ::> x::> Ok | 485 | type SetRect x = I -> I -> x ::> x::> Ok |
@@ -465,19 +493,29 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z | |||
465 | 493 | ||
466 | -------------------------------------------------------------------------------- | 494 | -------------------------------------------------------------------------------- |
467 | 495 | ||
496 | sortG :: (Storable t, Storable a) | ||
497 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
468 | sortG f v = unsafePerformIO $ do | 498 | sortG f v = unsafePerformIO $ do |
469 | r <- createVector (dim v) | 499 | r <- createVector (dim v) |
470 | (v #! r) f #|"sortG" | 500 | (v #! r) f #|"sortG" |
471 | return r | 501 | return r |
472 | 502 | ||
503 | sortIdxD :: Vector Double -> Vector CInt | ||
473 | sortIdxD = sortG c_sort_indexD | 504 | sortIdxD = sortG c_sort_indexD |
505 | sortIdxF :: Vector Float -> Vector CInt | ||
474 | sortIdxF = sortG c_sort_indexF | 506 | sortIdxF = sortG c_sort_indexF |
507 | sortIdxI :: Vector CInt -> Vector CInt | ||
475 | sortIdxI = sortG c_sort_indexI | 508 | sortIdxI = sortG c_sort_indexI |
509 | sortIdxL :: Vector Z -> Vector I | ||
476 | sortIdxL = sortG c_sort_indexL | 510 | sortIdxL = sortG c_sort_indexL |
477 | 511 | ||
512 | sortValD :: Vector Double -> Vector Double | ||
478 | sortValD = sortG c_sort_valD | 513 | sortValD = sortG c_sort_valD |
514 | sortValF :: Vector Float -> Vector Float | ||
479 | sortValF = sortG c_sort_valF | 515 | sortValF = sortG c_sort_valF |
516 | sortValI :: Vector CInt -> Vector CInt | ||
480 | sortValI = sortG c_sort_valI | 517 | sortValI = sortG c_sort_valI |
518 | sortValL :: Vector Z -> Vector Z | ||
481 | sortValL = sortG c_sort_valL | 519 | sortValL = sortG c_sort_valL |
482 | 520 | ||
483 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) | 521 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) |
@@ -492,14 +530,21 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok | |||
492 | 530 | ||
493 | -------------------------------------------------------------------------------- | 531 | -------------------------------------------------------------------------------- |
494 | 532 | ||
533 | compareG :: (TransArray c, Storable t, Storable a) | ||
534 | => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) | ||
535 | -> c -> Vector t -> Vector a | ||
495 | compareG f u v = unsafePerformIO $ do | 536 | compareG f u v = unsafePerformIO $ do |
496 | r <- createVector (dim v) | 537 | r <- createVector (dim v) |
497 | (u # v #! r) f #|"compareG" | 538 | (u # v #! r) f #|"compareG" |
498 | return r | 539 | return r |
499 | 540 | ||
541 | compareD :: Vector Double -> Vector Double -> Vector CInt | ||
500 | compareD = compareG c_compareD | 542 | compareD = compareG c_compareD |
543 | compareF :: Vector Float -> Vector Float -> Vector CInt | ||
501 | compareF = compareG c_compareF | 544 | compareF = compareG c_compareF |
545 | compareI :: Vector CInt -> Vector CInt -> Vector CInt | ||
502 | compareI = compareG c_compareI | 546 | compareI = compareG c_compareI |
547 | compareL :: Vector Z -> Vector Z -> Vector CInt | ||
503 | compareL = compareG c_compareL | 548 | compareL = compareG c_compareL |
504 | 549 | ||
505 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) | 550 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) |
@@ -509,16 +554,33 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok | |||
509 | 554 | ||
510 | -------------------------------------------------------------------------------- | 555 | -------------------------------------------------------------------------------- |
511 | 556 | ||
557 | selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) | ||
558 | => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) | ||
559 | -> c2 -> c1 -> Vector t -> c -> Vector a | ||
512 | selectG f c u v w = unsafePerformIO $ do | 560 | selectG f c u v w = unsafePerformIO $ do |
513 | r <- createVector (dim v) | 561 | r <- createVector (dim v) |
514 | (c # u # v # w #! r) f #|"selectG" | 562 | (c # u # v # w #! r) f #|"selectG" |
515 | return r | 563 | return r |
516 | 564 | ||
565 | selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double | ||
517 | selectD = selectG c_selectD | 566 | selectD = selectG c_selectD |
567 | selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float | ||
518 | selectF = selectG c_selectF | 568 | selectF = selectG c_selectF |
569 | selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt | ||
519 | selectI = selectG c_selectI | 570 | selectI = selectG c_selectI |
571 | selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z | ||
520 | selectL = selectG c_selectL | 572 | selectL = selectG c_selectL |
573 | selectC :: Vector CInt | ||
574 | -> Vector (Complex Double) | ||
575 | -> Vector (Complex Double) | ||
576 | -> Vector (Complex Double) | ||
577 | -> Vector (Complex Double) | ||
521 | selectC = selectG c_selectC | 578 | selectC = selectG c_selectC |
579 | selectQ :: Vector CInt | ||
580 | -> Vector (Complex Float) | ||
581 | -> Vector (Complex Float) | ||
582 | -> Vector (Complex Float) | ||
583 | -> Vector (Complex Float) | ||
522 | selectQ = selectG c_selectQ | 584 | selectQ = selectG c_selectQ |
523 | 585 | ||
524 | type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) | 586 | type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) |
@@ -532,16 +594,29 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z | |||
532 | 594 | ||
533 | --------------------------------------------------------------------------- | 595 | --------------------------------------------------------------------------- |
534 | 596 | ||
597 | remapG :: (TransArray c, TransArray c1, Storable t, Storable a) | ||
598 | => (CInt -> CInt -> CInt -> CInt -> Ptr t | ||
599 | -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) | ||
600 | -> Matrix t -> c1 -> c -> Matrix a | ||
535 | remapG f i j m = unsafePerformIO $ do | 601 | remapG f i j m = unsafePerformIO $ do |
536 | r <- createMatrix RowMajor (rows i) (cols i) | 602 | r <- createMatrix RowMajor (rows i) (cols i) |
537 | (i # j # m #! r) f #|"remapG" | 603 | (i # j # m #! r) f #|"remapG" |
538 | return r | 604 | return r |
539 | 605 | ||
606 | remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double | ||
540 | remapD = remapG c_remapD | 607 | remapD = remapG c_remapD |
608 | remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float | ||
541 | remapF = remapG c_remapF | 609 | remapF = remapG c_remapF |
610 | remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt | ||
542 | remapI = remapG c_remapI | 611 | remapI = remapG c_remapI |
612 | remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z | ||
543 | remapL = remapG c_remapL | 613 | remapL = remapG c_remapL |
614 | remapC :: Matrix CInt | ||
615 | -> Matrix CInt | ||
616 | -> Matrix (Complex Double) | ||
617 | -> Matrix (Complex Double) | ||
544 | remapC = remapG c_remapC | 618 | remapC = remapG c_remapC |
619 | remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) | ||
545 | remapQ = remapG c_remapQ | 620 | remapQ = remapG c_remapQ |
546 | 621 | ||
547 | type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) | 622 | type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) |
@@ -555,6 +630,9 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z | |||
555 | 630 | ||
556 | -------------------------------------------------------------------------------- | 631 | -------------------------------------------------------------------------------- |
557 | 632 | ||
633 | rowOpAux :: (TransArray c, Storable a) => | ||
634 | (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) | ||
635 | -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () | ||
558 | rowOpAux f c x i1 i2 j1 j2 m = do | 636 | rowOpAux f c x i1 i2 j1 j2 m = do |
559 | px <- newArray [x] | 637 | px <- newArray [x] |
560 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" | 638 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" |
@@ -573,6 +651,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | |||
573 | 651 | ||
574 | -------------------------------------------------------------------------------- | 652 | -------------------------------------------------------------------------------- |
575 | 653 | ||
654 | gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) | ||
655 | => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt)))) | ||
656 | -> c3 -> c2 -> c1 -> c -> IO () | ||
576 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" | 657 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" |
577 | 658 | ||
578 | type Tgemm x = x :> x ::> x ::> x ::> Ok | 659 | type Tgemm x = x :> x ::> x ::> x ::> Ok |
@@ -588,6 +669,10 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z | |||
588 | 669 | ||
589 | -------------------------------------------------------------------------------- | 670 | -------------------------------------------------------------------------------- |
590 | 671 | ||
672 | reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => | ||
673 | (CInt -> Ptr a -> CInt -> Ptr t1 | ||
674 | -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)) | ||
675 | -> Vector t1 -> c -> Vector t -> Vector a1 | ||
591 | reorderAux f s d v = unsafePerformIO $ do | 676 | reorderAux f s d v = unsafePerformIO $ do |
592 | k <- createVector (dim s) | 677 | k <- createVector (dim s) |
593 | r <- createVector (dim v) | 678 | r <- createVector (dim v) |
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 9d51444..eb0c5a8 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -13,6 +13,9 @@ | |||
13 | {-# LANGUAGE TypeFamilies #-} | 13 | {-# LANGUAGE TypeFamilies #-} |
14 | {-# LANGUAGE TypeOperators #-} | 14 | {-# LANGUAGE TypeOperators #-} |
15 | 15 | ||
16 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
17 | {-# OPTIONS_GHC -fno-warn-missing-methods #-} | ||
18 | |||
16 | {- | | 19 | {- | |
17 | Module : Internal.Modular | 20 | Module : Internal.Modular |
18 | Copyright : (c) Alberto Ruiz 2015 | 21 | Copyright : (c) Alberto Ruiz 2015 |
diff --git a/packages/base/src/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs index c9ef0c5..fd0a217 100644 --- a/packages/base/src/Internal/Numeric.hs +++ b/packages/base/src/Internal/Numeric.hs | |||
@@ -5,6 +5,8 @@ | |||
5 | {-# LANGUAGE FunctionalDependencies #-} | 5 | {-# LANGUAGE FunctionalDependencies #-} |
6 | {-# LANGUAGE UndecidableInstances #-} | 6 | {-# LANGUAGE UndecidableInstances #-} |
7 | 7 | ||
8 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
9 | |||
8 | ----------------------------------------------------------------------------- | 10 | ----------------------------------------------------------------------------- |
9 | -- | | 11 | -- | |
10 | -- Module : Data.Packed.Internal.Numeric | 12 | -- Module : Data.Packed.Internal.Numeric |
@@ -788,13 +790,7 @@ type instance RealOf (Complex Float) = Float | |||
788 | type instance RealOf I = I | 790 | type instance RealOf I = I |
789 | type instance RealOf Z = Z | 791 | type instance RealOf Z = Z |
790 | 792 | ||
791 | type family ComplexOf x | 793 | type ComplexOf x = Complex (RealOf x) |
792 | |||
793 | type instance ComplexOf Double = Complex Double | ||
794 | type instance ComplexOf (Complex Double) = Complex Double | ||
795 | |||
796 | type instance ComplexOf Float = Complex Float | ||
797 | type instance ComplexOf (Complex Float) = Complex Float | ||
798 | 794 | ||
799 | type family SingleOf x | 795 | type family SingleOf x |
800 | 796 | ||
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 544c9e4..7d54e6d 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -81,6 +81,8 @@ unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t) | |||
81 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x | 81 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x |
82 | 82 | ||
83 | {-# INLINE safeIndexV #-} | 83 | {-# INLINE safeIndexV #-} |
84 | safeIndexV :: Storable t2 | ||
85 | => (STVector s t2 -> Int -> t) -> STVector t1 t2 -> Int -> t | ||
84 | safeIndexV f (STVector v) k | 86 | safeIndexV f (STVector v) k |
85 | | k < 0 || k>= dim v = error $ "out of range error in vector (dim=" | 87 | | k < 0 || k>= dim v = error $ "out of range error in vector (dim=" |
86 | ++show (dim v)++", pos="++show k++")" | 88 | ++show (dim v)++", pos="++show k++")" |
@@ -150,9 +152,12 @@ unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | |||
150 | freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) | 152 | freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) |
151 | freezeMatrix m = liftSTMatrix id m | 153 | freezeMatrix m = liftSTMatrix id m |
152 | 154 | ||
155 | cloneMatrix :: Element t => Matrix t -> IO (Matrix t) | ||
153 | cloneMatrix m = copy (orderOf m) m | 156 | cloneMatrix m = copy (orderOf m) m |
154 | 157 | ||
155 | {-# INLINE safeIndexM #-} | 158 | {-# INLINE safeIndexM #-} |
159 | safeIndexM :: (STMatrix s t2 -> Int -> Int -> t) | ||
160 | -> STMatrix t1 t2 -> Int -> Int -> t | ||
156 | safeIndexM f (STMatrix m) r c | 161 | safeIndexM f (STMatrix m) r c |
157 | | r<0 || r>=rows m || | 162 | | r<0 || r>=rows m || |
158 | c<0 || c>=cols m = error $ "out of range error in matrix (size=" | 163 | c<0 || c>=cols m = error $ "out of range error in matrix (size=" |
@@ -184,6 +189,7 @@ data ColRange = AllCols | |||
184 | | Col Int | 189 | | Col Int |
185 | | FromCol Int | 190 | | FromCol Int |
186 | 191 | ||
192 | getColRange :: Int -> ColRange -> (Int, Int) | ||
187 | getColRange c AllCols = (0,c-1) | 193 | getColRange c AllCols = (0,c-1) |
188 | getColRange c (ColRange a b) = (a `mod` c, b `mod` c) | 194 | getColRange c (ColRange a b) = (a `mod` c, b `mod` c) |
189 | getColRange c (Col a) = (a `mod` c, a `mod` c) | 195 | getColRange c (Col a) = (a `mod` c, a `mod` c) |
@@ -194,6 +200,7 @@ data RowRange = AllRows | |||
194 | | Row Int | 200 | | Row Int |
195 | | FromRow Int | 201 | | FromRow Int |
196 | 202 | ||
203 | getRowRange :: Int -> RowRange -> (Int, Int) | ||
197 | getRowRange r AllRows = (0,r-1) | 204 | getRowRange r AllRows = (0,r-1) |
198 | getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) | 205 | getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) |
199 | getRowRange r (Row a) = (a `mod` r, a `mod` r) | 206 | getRowRange r (Row a) = (a `mod` r, a `mod` r) |
@@ -223,6 +230,7 @@ rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m | |||
223 | i2' = i2 `mod` (rows m) | 230 | i2' = i2 `mod` (rows m) |
224 | 231 | ||
225 | 232 | ||
233 | extractMatrix :: Element a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a) | ||
226 | extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) | 234 | extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) |
227 | where | 235 | where |
228 | (i1,i2) = getRowRange (rows m) rr | 236 | (i1,i2) = getRowRange (rows m) rr |
@@ -231,6 +239,7 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[ | |||
231 | -- | r0 c0 height width | 239 | -- | r0 c0 height width |
232 | data Slice s t = Slice (STMatrix s t) Int Int Int Int | 240 | data Slice s t = Slice (STMatrix s t) Int Int Int Int |
233 | 241 | ||
242 | slice :: Element a => Slice t a -> Matrix a | ||
234 | slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m | 243 | slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m |
235 | 244 | ||
236 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () | 245 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () |
@@ -238,7 +247,7 @@ gemmm beta (slice->r) alpha (slice->a) (slice->b) = res | |||
238 | where | 247 | where |
239 | res = unsafeIOToST (gemm v a b r) | 248 | res = unsafeIOToST (gemm v a b r) |
240 | v = fromList [alpha,beta] | 249 | v = fromList [alpha,beta] |
241 | 250 | ||
242 | 251 | ||
243 | mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) | 252 | mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) |
244 | mutable f a = runST $ do | 253 | mutable f a = runST $ do |
@@ -246,4 +255,3 @@ mutable f a = runST $ do | |||
246 | info <- f (rows a, cols a) x | 255 | info <- f (rows a, cols a) x |
247 | r <- unsafeFreezeMatrix x | 256 | r <- unsafeFreezeMatrix x |
248 | return (r,info) | 257 | return (r,info) |
249 | |||
diff --git a/packages/base/src/Internal/Sparse.hs b/packages/base/src/Internal/Sparse.hs index a8a5fe0..fbea11a 100644 --- a/packages/base/src/Internal/Sparse.hs +++ b/packages/base/src/Internal/Sparse.hs | |||
@@ -2,6 +2,8 @@ | |||
2 | {-# LANGUAGE MultiParamTypeClasses #-} | 2 | {-# LANGUAGE MultiParamTypeClasses #-} |
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | 4 | ||
5 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
6 | |||
5 | module Internal.Sparse( | 7 | module Internal.Sparse( |
6 | GMatrix(..), CSR(..), mkCSR, fromCSR, | 8 | GMatrix(..), CSR(..), mkCSR, fromCSR, |
7 | mkSparse, mkDiagR, mkDense, | 9 | mkSparse, mkDiagR, mkDense, |
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs index 357645e..566506c 100644 --- a/packages/base/src/Internal/Static.hs +++ b/packages/base/src/Internal/Static.hs | |||
@@ -15,6 +15,9 @@ | |||
15 | {-# LANGUAGE BangPatterns #-} | 15 | {-# LANGUAGE BangPatterns #-} |
16 | {-# LANGUAGE DeriveGeneric #-} | 16 | {-# LANGUAGE DeriveGeneric #-} |
17 | 17 | ||
18 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
19 | {-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-} | ||
20 | |||
18 | {- | | 21 | {- | |
19 | Module : Internal.Static | 22 | Module : Internal.Static |
20 | Copyright : (c) Alberto Ruiz 2006-14 | 23 | Copyright : (c) Alberto Ruiz 2006-14 |
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 959e58f..f642e8d 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -6,6 +6,8 @@ | |||
6 | {-# LANGUAGE ScopedTypeVariables #-} | 6 | {-# LANGUAGE ScopedTypeVariables #-} |
7 | {-# LANGUAGE ViewPatterns #-} | 7 | {-# LANGUAGE ViewPatterns #-} |
8 | 8 | ||
9 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
10 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
9 | 11 | ||
10 | ----------------------------------------------------------------------------- | 12 | ----------------------------------------------------------------------------- |
11 | {- | | 13 | {- | |
diff --git a/packages/base/src/Internal/Vector.hs b/packages/base/src/Internal/Vector.hs index e1e4aa8..6271bb6 100644 --- a/packages/base/src/Internal/Vector.hs +++ b/packages/base/src/Internal/Vector.hs | |||
@@ -1,6 +1,7 @@ | |||
1 | {-# LANGUAGE MagicHash, UnboxedTuples, BangPatterns, FlexibleContexts #-} | 1 | {-# LANGUAGE MagicHash, UnboxedTuples, BangPatterns, FlexibleContexts #-} |
2 | {-# LANGUAGE TypeSynonymInstances #-} | 2 | {-# LANGUAGE TypeSynonymInstances #-} |
3 | 3 | ||
4 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
4 | 5 | ||
5 | -- | | 6 | -- | |
6 | -- Module : Internal.Vector | 7 | -- Module : Internal.Vector |
@@ -40,6 +41,7 @@ import qualified Data.Vector.Storable as Vector | |||
40 | import Data.Vector.Storable(Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith) | 41 | import Data.Vector.Storable(Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith) |
41 | 42 | ||
42 | import Data.Binary | 43 | import Data.Binary |
44 | import Data.Binary.Put | ||
43 | import Control.Monad(replicateM) | 45 | import Control.Monad(replicateM) |
44 | import qualified Data.ByteString.Internal as BS | 46 | import qualified Data.ByteString.Internal as BS |
45 | import Data.Vector.Storable.Internal(updPtr) | 47 | import Data.Vector.Storable.Internal(updPtr) |
@@ -92,6 +94,7 @@ createVector n = do | |||
92 | 94 | ||
93 | -} | 95 | -} |
94 | 96 | ||
97 | safeRead :: Storable a => Vector a -> (Ptr a -> IO c) -> c | ||
95 | safeRead v = inlinePerformIO . unsafeWith v | 98 | safeRead v = inlinePerformIO . unsafeWith v |
96 | {-# INLINE safeRead #-} | 99 | {-# INLINE safeRead #-} |
97 | 100 | ||
@@ -287,11 +290,13 @@ foldVectorWithIndex f x v = unsafePerformIO $ | |||
287 | go (dim v -1) x | 290 | go (dim v -1) x |
288 | {-# INLINE foldVectorWithIndex #-} | 291 | {-# INLINE foldVectorWithIndex #-} |
289 | 292 | ||
293 | foldLoop :: (Int -> t -> t) -> t -> Int -> t | ||
290 | foldLoop f s0 d = go (d - 1) s0 | 294 | foldLoop f s0 d = go (d - 1) s0 |
291 | where | 295 | where |
292 | go 0 s = f (0::Int) s | 296 | go 0 s = f (0::Int) s |
293 | go !j !s = go (j - 1) (f j s) | 297 | go !j !s = go (j - 1) (f j s) |
294 | 298 | ||
299 | foldVectorG :: Storable t1 => (Int -> (Int -> t1) -> t -> t) -> t -> Vector t1 -> t | ||
295 | foldVectorG f s0 v = foldLoop g s0 (dim v) | 300 | foldVectorG f s0 v = foldLoop g s0 (dim v) |
296 | where g !k !s = f k (safeRead v . flip peekElemOff) s | 301 | where g !k !s = f k (safeRead v . flip peekElemOff) s |
297 | {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479) | 302 | {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479) |
@@ -394,8 +399,10 @@ chunks d = let c = d `div` chunk | |||
394 | m = d `mod` chunk | 399 | m = d `mod` chunk |
395 | in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk) | 400 | in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk) |
396 | 401 | ||
402 | putVector :: (Storable t, Binary t) => Vector t -> Data.Binary.Put.PutM () | ||
397 | putVector v = mapM_ put $! toList v | 403 | putVector v = mapM_ put $! toList v |
398 | 404 | ||
405 | getVector :: (Storable a, Binary a) => Int -> Get (Vector a) | ||
399 | getVector d = do | 406 | getVector d = do |
400 | xs <- replicateM d get | 407 | xs <- replicateM d get |
401 | return $! fromList xs | 408 | return $! fromList xs |
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs index 2990173..32430c6 100644 --- a/packages/base/src/Internal/Vectorized.hs +++ b/packages/base/src/Internal/Vectorized.hs | |||
@@ -28,12 +28,15 @@ import System.IO.Unsafe(unsafePerformIO) | |||
28 | import Control.Monad(when) | 28 | import Control.Monad(when) |
29 | 29 | ||
30 | infixr 1 # | 30 | infixr 1 # |
31 | (#) :: TransArray c => c -> (b -> IO r) -> TransRaw c b -> IO r | ||
31 | a # b = applyRaw a b | 32 | a # b = applyRaw a b |
32 | {-# INLINE (#) #-} | 33 | {-# INLINE (#) #-} |
33 | 34 | ||
35 | (#!) :: (TransArray c, TransArray c1) => c1 -> c -> TransRaw c1 (TransRaw c (IO r)) -> IO r | ||
34 | a #! b = a # b # id | 36 | a #! b = a # b # id |
35 | {-# INLINE (#!) #-} | 37 | {-# INLINE (#!) #-} |
36 | 38 | ||
39 | fromei :: Enum a => a -> CInt | ||
37 | fromei x = fromIntegral (fromEnum x) :: CInt | 40 | fromei x = fromIntegral (fromEnum x) :: CInt |
38 | 41 | ||
39 | data FunCodeV = Sin | 42 | data FunCodeV = Sin |
@@ -100,10 +103,20 @@ sumQ = sumg c_sumQ | |||
100 | sumC :: Vector (Complex Double) -> Complex Double | 103 | sumC :: Vector (Complex Double) -> Complex Double |
101 | sumC = sumg c_sumC | 104 | sumC = sumg c_sumC |
102 | 105 | ||
106 | sumI :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr I -> I :> Ok) | ||
107 | , TransArray c | ||
108 | , Storable a | ||
109 | ) | ||
110 | => I -> c -> a | ||
103 | sumI m = sumg (c_sumI m) | 111 | sumI m = sumg (c_sumI m) |
104 | 112 | ||
113 | sumL :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr Z -> Z :> Ok) | ||
114 | , TransArray c | ||
115 | , Storable a | ||
116 | ) => Z -> c -> a | ||
105 | sumL m = sumg (c_sumL m) | 117 | sumL m = sumg (c_sumL m) |
106 | 118 | ||
119 | sumg :: (TransArray c, Storable a) => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a | ||
107 | sumg f x = unsafePerformIO $ do | 120 | sumg f x = unsafePerformIO $ do |
108 | r <- createVector 1 | 121 | r <- createVector 1 |
109 | (x #! r) f #| "sum" | 122 | (x #! r) f #| "sum" |
@@ -140,6 +153,8 @@ prodI = prodg . c_prodI | |||
140 | prodL :: Z-> Vector Z -> Z | 153 | prodL :: Z-> Vector Z -> Z |
141 | prodL = prodg . c_prodL | 154 | prodL = prodg . c_prodL |
142 | 155 | ||
156 | prodg :: (TransArray c, Storable a) | ||
157 | => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a | ||
143 | prodg f x = unsafePerformIO $ do | 158 | prodg f x = unsafePerformIO $ do |
144 | r <- createVector 1 | 159 | r <- createVector 1 |
145 | (x #! r) f #| "prod" | 160 | (x #! r) f #| "prod" |
@@ -155,16 +170,25 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z | |||
155 | 170 | ||
156 | ------------------------------------------------------------------ | 171 | ------------------------------------------------------------------ |
157 | 172 | ||
173 | toScalarAux :: (Enum a, TransArray c, Storable a1) | ||
174 | => (CInt -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) -> a -> c -> a1 | ||
158 | toScalarAux fun code v = unsafePerformIO $ do | 175 | toScalarAux fun code v = unsafePerformIO $ do |
159 | r <- createVector 1 | 176 | r <- createVector 1 |
160 | (v #! r) (fun (fromei code)) #|"toScalarAux" | 177 | (v #! r) (fun (fromei code)) #|"toScalarAux" |
161 | return (r @> 0) | 178 | return (r @> 0) |
162 | 179 | ||
180 | |||
181 | vectorMapAux :: (Enum a, Storable t, Storable a1) | ||
182 | => (CInt -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) | ||
183 | -> a -> Vector t -> Vector a1 | ||
163 | vectorMapAux fun code v = unsafePerformIO $ do | 184 | vectorMapAux fun code v = unsafePerformIO $ do |
164 | r <- createVector (dim v) | 185 | r <- createVector (dim v) |
165 | (v #! r) (fun (fromei code)) #|"vectorMapAux" | 186 | (v #! r) (fun (fromei code)) #|"vectorMapAux" |
166 | return r | 187 | return r |
167 | 188 | ||
189 | vectorMapValAux :: (Enum a, Storable a2, Storable t, Storable a1) | ||
190 | => (CInt -> Ptr a2 -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) | ||
191 | -> a -> a2 -> Vector t -> Vector a1 | ||
168 | vectorMapValAux fun code val v = unsafePerformIO $ do | 192 | vectorMapValAux fun code val v = unsafePerformIO $ do |
169 | r <- createVector (dim v) | 193 | r <- createVector (dim v) |
170 | pval <- newArray [val] | 194 | pval <- newArray [val] |
@@ -172,6 +196,9 @@ vectorMapValAux fun code val v = unsafePerformIO $ do | |||
172 | free pval | 196 | free pval |
173 | return r | 197 | return r |
174 | 198 | ||
199 | vectorZipAux :: (Enum a, TransArray c, Storable t, Storable a1) | ||
200 | => (CInt -> CInt -> Ptr t -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) | ||
201 | -> a -> Vector t -> c -> Vector a1 | ||
175 | vectorZipAux fun code u v = unsafePerformIO $ do | 202 | vectorZipAux fun code u v = unsafePerformIO $ do |
176 | r <- createVector (dim u) | 203 | r <- createVector (dim u) |
177 | (u # v #! r) (fun (fromei code)) #|"vectorZipAux" | 204 | (u # v #! r) (fun (fromei code)) #|"vectorZipAux" |
@@ -378,6 +405,7 @@ foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> D | |||
378 | 405 | ||
379 | -------------------------------------------------------------------------------- | 406 | -------------------------------------------------------------------------------- |
380 | 407 | ||
408 | roundVector :: Vector Double -> Vector Double | ||
381 | roundVector v = unsafePerformIO $ do | 409 | roundVector v = unsafePerformIO $ do |
382 | r <- createVector (dim v) | 410 | r <- createVector (dim v) |
383 | (v #! r) c_round_vector #|"roundVector" | 411 | (v #! r) c_round_vector #|"roundVector" |
@@ -433,6 +461,8 @@ long2intV :: Vector Z -> Vector I | |||
433 | long2intV = tog c_long2int | 461 | long2intV = tog c_long2int |
434 | 462 | ||
435 | 463 | ||
464 | tog :: (Storable t, Storable a) | ||
465 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
436 | tog f v = unsafePerformIO $ do | 466 | tog f v = unsafePerformIO $ do |
437 | r <- createVector (dim v) | 467 | r <- createVector (dim v) |
438 | (v #! r) f #|"tog" | 468 | (v #! r) f #|"tog" |
@@ -452,6 +482,8 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok | |||
452 | 482 | ||
453 | --------------------------------------------------------------- | 483 | --------------------------------------------------------------- |
454 | 484 | ||
485 | stepg :: (Storable t, Storable a) | ||
486 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
455 | stepg f v = unsafePerformIO $ do | 487 | stepg f v = unsafePerformIO $ do |
456 | r <- createVector (dim v) | 488 | r <- createVector (dim v) |
457 | (v #! r) f #|"step" | 489 | (v #! r) f #|"step" |
@@ -477,6 +509,8 @@ foreign import ccall unsafe "stepL" c_stepL :: TVV Z | |||
477 | 509 | ||
478 | -------------------------------------------------------------------------------- | 510 | -------------------------------------------------------------------------------- |
479 | 511 | ||
512 | conjugateAux :: (Storable t, Storable a) | ||
513 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
480 | conjugateAux fun x = unsafePerformIO $ do | 514 | conjugateAux fun x = unsafePerformIO $ do |
481 | v <- createVector (dim x) | 515 | v <- createVector (dim x) |
482 | (x #! v) fun #|"conjugateAux" | 516 | (x #! v) fun #|"conjugateAux" |
@@ -502,6 +536,8 @@ cloneVector v = do | |||
502 | 536 | ||
503 | -------------------------------------------------------------------------------- | 537 | -------------------------------------------------------------------------------- |
504 | 538 | ||
539 | constantAux :: (Storable a1, Storable a) | ||
540 | => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a | ||
505 | constantAux fun x n = unsafePerformIO $ do | 541 | constantAux fun x n = unsafePerformIO $ do |
506 | v <- createVector n | 542 | v <- createVector n |
507 | px <- newArray [x] | 543 | px <- newArray [x] |
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 520eeb7..91923e9 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs | |||
@@ -1,6 +1,8 @@ | |||
1 | {-# LANGUAGE CPP #-} | 1 | {-# LANGUAGE CPP #-} |
2 | {-# LANGUAGE FlexibleContexts #-} | 2 | {-# LANGUAGE FlexibleContexts #-} |
3 | 3 | ||
4 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
5 | |||
4 | ----------------------------------------------------------------------------- | 6 | ----------------------------------------------------------------------------- |
5 | {- | | 7 | {- | |
6 | Module : Numeric.LinearAlgebra | 8 | Module : Numeric.LinearAlgebra |
diff --git a/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs b/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs index 3a84645..57e5cf1 100644 --- a/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs +++ b/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs | |||
@@ -28,7 +28,9 @@ infixr 8 <·> | |||
28 | (<·>) :: Numeric t => Vector t -> Vector t -> t | 28 | (<·>) :: Numeric t => Vector t -> Vector t -> t |
29 | (<·>) = dot | 29 | (<·>) = dot |
30 | 30 | ||
31 | app :: Numeric t => Matrix t -> Vector t -> Vector t | ||
31 | app m v = m #> v | 32 | app m v = m #> v |
32 | 33 | ||
34 | mul :: Numeric t => Matrix t -> Matrix t -> Matrix t | ||
33 | mul a b = a <> b | 35 | mul a b = a <> b |
34 | 36 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index e328904..2e05c90 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -14,6 +14,8 @@ | |||
14 | {-# LANGUAGE GADTs #-} | 14 | {-# LANGUAGE GADTs #-} |
15 | {-# LANGUAGE TypeFamilies #-} | 15 | {-# LANGUAGE TypeFamilies #-} |
16 | 16 | ||
17 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
18 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
17 | 19 | ||
18 | {- | | 20 | {- | |
19 | Module : Numeric.LinearAlgebra.Static | 21 | Module : Numeric.LinearAlgebra.Static |
diff --git a/packages/base/src/Numeric/Matrix.hs b/packages/base/src/Numeric/Matrix.hs index 06da150..6e3db61 100644 --- a/packages/base/src/Numeric/Matrix.hs +++ b/packages/base/src/Numeric/Matrix.hs | |||
@@ -4,6 +4,8 @@ | |||
4 | {-# LANGUAGE UndecidableInstances #-} | 4 | {-# LANGUAGE UndecidableInstances #-} |
5 | {-# LANGUAGE MultiParamTypeClasses #-} | 5 | {-# LANGUAGE MultiParamTypeClasses #-} |
6 | 6 | ||
7 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
8 | |||
7 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
8 | -- | | 10 | -- | |
9 | -- Module : Numeric.Matrix | 11 | -- Module : Numeric.Matrix |
@@ -35,6 +37,7 @@ import Data.List(partition) | |||
35 | import qualified Data.Foldable as F | 37 | import qualified Data.Foldable as F |
36 | import qualified Data.Semigroup as S | 38 | import qualified Data.Semigroup as S |
37 | import Internal.Chain | 39 | import Internal.Chain |
40 | import Foreign.Storable(Storable) | ||
38 | 41 | ||
39 | 42 | ||
40 | ------------------------------------------------------------------- | 43 | ------------------------------------------------------------------- |
@@ -80,8 +83,16 @@ instance (Floating a, Container Vector a, Floating (Vector a), Fractional (Matri | |||
80 | 83 | ||
81 | -------------------------------------------------------------------------------- | 84 | -------------------------------------------------------------------------------- |
82 | 85 | ||
86 | isScalar :: Matrix t -> Bool | ||
83 | isScalar m = rows m == 1 && cols m == 1 | 87 | isScalar m = rows m == 1 && cols m == 1 |
84 | 88 | ||
89 | adaptScalarM :: (Foreign.Storable.Storable t1, Foreign.Storable.Storable t2) | ||
90 | => (t1 -> Matrix t2 -> t) | ||
91 | -> (Matrix t1 -> Matrix t2 -> t) | ||
92 | -> (Matrix t1 -> t2 -> t) | ||
93 | -> Matrix t1 | ||
94 | -> Matrix t2 | ||
95 | -> t | ||
85 | adaptScalarM f1 f2 f3 x y | 96 | adaptScalarM f1 f2 f3 x y |
86 | | isScalar x = f1 (x @@>(0,0) ) y | 97 | | isScalar x = f1 (x @@>(0,0) ) y |
87 | | isScalar y = f3 x (y @@>(0,0) ) | 98 | | isScalar y = f3 x (y @@>(0,0) ) |
@@ -96,7 +107,7 @@ instance (Container Vector t, Eq t, Num (Vector t), Product t) => M.Monoid (Matr | |||
96 | where | 107 | where |
97 | mempty = 1 | 108 | mempty = 1 |
98 | mappend = adaptScalarM scale mXm (flip scale) | 109 | mappend = adaptScalarM scale mXm (flip scale) |
99 | 110 | ||
100 | mconcat xs = work (partition isScalar xs) | 111 | mconcat xs = work (partition isScalar xs) |
101 | where | 112 | where |
102 | work (ss,[]) = product ss | 113 | work (ss,[]) = product ss |
@@ -106,4 +117,3 @@ instance (Container Vector t, Eq t, Num (Vector t), Product t) => M.Monoid (Matr | |||
106 | | otherwise = scale x00 m | 117 | | otherwise = scale x00 m |
107 | where | 118 | where |
108 | x00 = x @@> (0,0) | 119 | x00 = x @@> (0,0) |
109 | |||
diff --git a/packages/base/src/Numeric/Vector.hs b/packages/base/src/Numeric/Vector.hs index 017196c..1e5877d 100644 --- a/packages/base/src/Numeric/Vector.hs +++ b/packages/base/src/Numeric/Vector.hs | |||
@@ -3,6 +3,9 @@ | |||
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | {-# LANGUAGE UndecidableInstances #-} | 4 | {-# LANGUAGE UndecidableInstances #-} |
5 | {-# LANGUAGE MultiParamTypeClasses #-} | 5 | {-# LANGUAGE MultiParamTypeClasses #-} |
6 | |||
7 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
8 | |||
6 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
7 | -- | | 10 | -- | |
8 | -- Module : Numeric.Vector | 11 | -- Module : Numeric.Vector |
@@ -14,7 +17,7 @@ | |||
14 | -- | 17 | -- |
15 | -- Provides instances of standard classes 'Show', 'Read', 'Eq', | 18 | -- Provides instances of standard classes 'Show', 'Read', 'Eq', |
16 | -- 'Num', 'Fractional', and 'Floating' for 'Vector'. | 19 | -- 'Num', 'Fractional', and 'Floating' for 'Vector'. |
17 | -- | 20 | -- |
18 | ----------------------------------------------------------------------------- | 21 | ----------------------------------------------------------------------------- |
19 | 22 | ||
20 | module Numeric.Vector () where | 23 | module Numeric.Vector () where |
@@ -23,9 +26,17 @@ import Internal.Vectorized | |||
23 | import Internal.Vector | 26 | import Internal.Vector |
24 | import Internal.Numeric | 27 | import Internal.Numeric |
25 | import Internal.Conversion | 28 | import Internal.Conversion |
29 | import Foreign.Storable(Storable) | ||
26 | 30 | ||
27 | ------------------------------------------------------------------- | 31 | ------------------------------------------------------------------- |
28 | 32 | ||
33 | adaptScalar :: (Foreign.Storable.Storable t1, Foreign.Storable.Storable t2) | ||
34 | => (t1 -> Vector t2 -> t) | ||
35 | -> (Vector t1 -> Vector t2 -> t) | ||
36 | -> (Vector t1 -> t2 -> t) | ||
37 | -> Vector t1 | ||
38 | -> Vector t2 | ||
39 | -> t | ||
29 | adaptScalar f1 f2 f3 x y | 40 | adaptScalar f1 f2 f3 x y |
30 | | dim x == 1 = f1 (x@>0) y | 41 | | dim x == 1 = f1 (x@>0) y |
31 | | dim y == 1 = f3 x (y@>0) | 42 | | dim y == 1 = f3 x (y@>0) |
@@ -172,4 +183,3 @@ instance Floating (Vector (Complex Float)) where | |||
172 | sqrt = vectorMapQ Sqrt | 183 | sqrt = vectorMapQ Sqrt |
173 | (**) = adaptScalar (vectorMapValQ PowSV) (vectorZipQ Pow) (flip (vectorMapValQ PowVS)) | 184 | (**) = adaptScalar (vectorMapValQ PowSV) (vectorZipQ Pow) (flip (vectorMapValQ PowVS)) |
174 | pi = fromList [pi] | 185 | pi = fromList [pi] |
175 | |||
diff --git a/packages/glpk/hmatrix-glpk.cabal b/packages/glpk/hmatrix-glpk.cabal index 6b0032b..ca93775 100644 --- a/packages/glpk/hmatrix-glpk.cabal +++ b/packages/glpk/hmatrix-glpk.cabal | |||
@@ -1,6 +1,6 @@ | |||
1 | Name: hmatrix-glpk | 1 | Name: hmatrix-glpk |
2 | Version: 0.6.0.0 | 2 | Version: 0.19.0.0 |
3 | License: GPL | 3 | License: GPL-3 |
4 | License-file: LICENSE | 4 | License-file: LICENSE |
5 | Author: Alberto Ruiz | 5 | Author: Alberto Ruiz |
6 | Maintainer: Alberto Ruiz <aruiz@um.es> | 6 | Maintainer: Alberto Ruiz <aruiz@um.es> |
diff --git a/packages/gsl/CHANGELOG b/packages/gsl/CHANGELOG index 091dc0e..a2fe038 100644 --- a/packages/gsl/CHANGELOG +++ b/packages/gsl/CHANGELOG | |||
@@ -2,7 +2,7 @@ | |||
2 | -------- | 2 | -------- |
3 | 3 | ||
4 | * Added interpolation modules | 4 | * Added interpolation modules |
5 | 5 | ||
6 | * Added simulated annealing module | 6 | * Added simulated annealing module |
7 | 7 | ||
8 | * Added odeSolveVWith | 8 | * Added odeSolveVWith |
@@ -11,4 +11,3 @@ | |||
11 | -------- | 11 | -------- |
12 | 12 | ||
13 | * The modules Numeric.GSL.* have been moved from hmatrix to the new package hmatrix-gsl. | 13 | * The modules Numeric.GSL.* have been moved from hmatrix to the new package hmatrix-gsl. |
14 | |||
diff --git a/packages/gsl/hmatrix-gsl.cabal b/packages/gsl/hmatrix-gsl.cabal index d463ee8..76db835 100644 --- a/packages/gsl/hmatrix-gsl.cabal +++ b/packages/gsl/hmatrix-gsl.cabal | |||
@@ -1,21 +1,18 @@ | |||
1 | Name: hmatrix-gsl | 1 | Name: hmatrix-gsl |
2 | Version: 0.18.0.1 | 2 | Version: 0.19.0.0 |
3 | License: GPL | ||
4 | License-file: LICENSE | ||
5 | Author: Alberto Ruiz | ||
6 | Maintainer: Alberto Ruiz <aruiz@um.es> | ||
7 | Stability: provisional | ||
8 | Homepage: https://github.com/albertoruiz/hmatrix | ||
9 | Synopsis: Numerical computation | 3 | Synopsis: Numerical computation |
10 | Description: Purely functional interface to selected numerical computations, | 4 | Description: Purely functional interface to selected numerical computations, |
11 | internally implemented using GSL. | 5 | internally implemented using GSL. |
12 | 6 | Homepage: https://github.com/albertoruiz/hmatrix | |
7 | license: GPL-3 | ||
8 | license-file: LICENSE | ||
9 | Author: Alberto Ruiz | ||
10 | Maintainer: Alberto Ruiz <aruiz@um.es> | ||
11 | Stability: provisional | ||
13 | Category: Math | 12 | Category: Math |
14 | tested-with: GHC ==7.8 | ||
15 | |||
16 | cabal-version: >=1.8 | ||
17 | |||
18 | build-type: Simple | 13 | build-type: Simple |
14 | cabal-version: >=1.18 | ||
15 | |||
19 | 16 | ||
20 | extra-source-files: src/Numeric/GSL/gsl-ode.c | 17 | extra-source-files: src/Numeric/GSL/gsl-ode.c |
21 | 18 | ||
@@ -33,9 +30,6 @@ library | |||
33 | Build-Depends: base<5, hmatrix>=0.18, array, vector, | 30 | Build-Depends: base<5, hmatrix>=0.18, array, vector, |
34 | process, random | 31 | process, random |
35 | 32 | ||
36 | |||
37 | Extensions: ForeignFunctionInterface | ||
38 | |||
39 | hs-source-dirs: src | 33 | hs-source-dirs: src |
40 | Exposed-modules: Numeric.GSL.Differentiation, | 34 | Exposed-modules: Numeric.GSL.Differentiation, |
41 | Numeric.GSL.Integration, | 35 | Numeric.GSL.Integration, |
@@ -98,6 +92,8 @@ library | |||
98 | else | 92 | else |
99 | pkgconfig-depends: gsl | 93 | pkgconfig-depends: gsl |
100 | 94 | ||
95 | default-language: Haskell2010 | ||
96 | |||
101 | 97 | ||
102 | source-repository head | 98 | source-repository head |
103 | type: git | 99 | type: git |
diff --git a/packages/gsl/src/Graphics/Plot.hs b/packages/gsl/src/Graphics/Plot.hs index d2ea192..e422912 100644 --- a/packages/gsl/src/Graphics/Plot.hs +++ b/packages/gsl/src/Graphics/Plot.hs | |||
@@ -1,3 +1,5 @@ | |||
1 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
2 | |||
1 | ----------------------------------------------------------------------------- | 3 | ----------------------------------------------------------------------------- |
2 | -- | | 4 | -- | |
3 | -- Module : Graphics.Plot | 5 | -- Module : Graphics.Plot |
diff --git a/packages/gsl/src/Numeric/GSL/Fitting.hs b/packages/gsl/src/Numeric/GSL/Fitting.hs index 8f2eae3..a732c25 100644 --- a/packages/gsl/src/Numeric/GSL/Fitting.hs +++ b/packages/gsl/src/Numeric/GSL/Fitting.hs | |||
@@ -1,6 +1,8 @@ | |||
1 | {-# LANGUAGE CPP #-} | 1 | {-# LANGUAGE CPP #-} |
2 | {-# LANGUAGE FlexibleContexts #-} | 2 | {-# LANGUAGE FlexibleContexts #-} |
3 | 3 | ||
4 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
5 | |||
4 | {- | | 6 | {- | |
5 | Module : Numeric.GSL.Fitting | 7 | Module : Numeric.GSL.Fitting |
6 | Copyright : (c) Alberto Ruiz 2010 | 8 | Copyright : (c) Alberto Ruiz 2010 |
diff --git a/packages/gsl/src/Numeric/GSL/Fourier.hs b/packages/gsl/src/Numeric/GSL/Fourier.hs index bffab87..ed7353a 100644 --- a/packages/gsl/src/Numeric/GSL/Fourier.hs +++ b/packages/gsl/src/Numeric/GSL/Fourier.hs | |||
@@ -1,5 +1,7 @@ | |||
1 | {-# LANGUAGE TypeFamilies #-} | 1 | {-# LANGUAGE TypeFamilies #-} |
2 | 2 | ||
3 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
4 | |||
3 | {- | | 5 | {- | |
4 | Module : Numeric.GSL.Fourier | 6 | Module : Numeric.GSL.Fourier |
5 | Copyright : (c) Alberto Ruiz 2006 | 7 | Copyright : (c) Alberto Ruiz 2006 |
diff --git a/packages/gsl/src/Numeric/GSL/Integration.hs b/packages/gsl/src/Numeric/GSL/Integration.hs index 9c1d43a..0a1b4c6 100644 --- a/packages/gsl/src/Numeric/GSL/Integration.hs +++ b/packages/gsl/src/Numeric/GSL/Integration.hs | |||
@@ -1,3 +1,5 @@ | |||
1 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
2 | |||
1 | {- | | 3 | {- | |
2 | Module : Numeric.GSL.Integration | 4 | Module : Numeric.GSL.Integration |
3 | Copyright : (c) Alberto Ruiz 2006 | 5 | Copyright : (c) Alberto Ruiz 2006 |
diff --git a/packages/gsl/src/Numeric/GSL/Internal.hs b/packages/gsl/src/Numeric/GSL/Internal.hs index f70e167..e1f8d95 100644 --- a/packages/gsl/src/Numeric/GSL/Internal.hs +++ b/packages/gsl/src/Numeric/GSL/Internal.hs | |||
@@ -1,5 +1,8 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | 1 | {-# LANGUAGE FlexibleContexts #-} |
2 | 2 | ||
3 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
4 | {-# OPTIONS_GHC -fno-warn-unused-top-binds #-} | ||
5 | |||
3 | -- | | 6 | -- | |
4 | -- Module : Numeric.GSL.Internal | 7 | -- Module : Numeric.GSL.Internal |
5 | -- Copyright : (c) Alberto Ruiz 2009 | 8 | -- Copyright : (c) Alberto Ruiz 2009 |
@@ -128,8 +131,7 @@ type TVM = TV (TM Res) | |||
128 | ww2 w1 o1 w2 o2 f = w1 o1 $ \a1 -> w2 o2 $ \a2 -> f a1 a2 | 131 | ww2 w1 o1 w2 o2 f = w1 o1 $ \a1 -> w2 o2 $ \a2 -> f a1 a2 |
129 | 132 | ||
130 | vec x f = unsafeWith x $ \p -> do | 133 | vec x f = unsafeWith x $ \p -> do |
131 | let v g = do | 134 | let v g = g (fi $ V.length x) p |
132 | g (fi $ V.length x) p | ||
133 | f v | 135 | f v |
134 | {-# INLINE vec #-} | 136 | {-# INLINE vec #-} |
135 | 137 | ||
diff --git a/packages/gsl/src/Numeric/GSL/Interpolation.hs b/packages/gsl/src/Numeric/GSL/Interpolation.hs index 6f02405..484d2a2 100644 --- a/packages/gsl/src/Numeric/GSL/Interpolation.hs +++ b/packages/gsl/src/Numeric/GSL/Interpolation.hs | |||
@@ -1,5 +1,7 @@ | |||
1 | {-# LANGUAGE MagicHash, UnboxedTuples #-} | 1 | {-# LANGUAGE MagicHash, UnboxedTuples #-} |
2 | 2 | ||
3 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
4 | |||
3 | {- | | 5 | {- | |
4 | Module : Numeric.GSL.Interpolation | 6 | Module : Numeric.GSL.Interpolation |
5 | Copyright : (c) Matthew Peddie 2015 | 7 | Copyright : (c) Matthew Peddie 2015 |
diff --git a/packages/gsl/src/Numeric/GSL/LinearAlgebra.hs b/packages/gsl/src/Numeric/GSL/LinearAlgebra.hs index 1bf357b..aee64f7 100644 --- a/packages/gsl/src/Numeric/GSL/LinearAlgebra.hs +++ b/packages/gsl/src/Numeric/GSL/LinearAlgebra.hs | |||
@@ -1,3 +1,6 @@ | |||
1 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
2 | {-# OPTIONS_GHC -fno-warn-unused-top-binds #-} | ||
3 | |||
1 | ----------------------------------------------------------------------------- | 4 | ----------------------------------------------------------------------------- |
2 | -- | | 5 | -- | |
3 | -- Module : Numeric.GSL.LinearAlgebra | 6 | -- Module : Numeric.GSL.LinearAlgebra |
diff --git a/packages/gsl/src/Numeric/GSL/Minimization.hs b/packages/gsl/src/Numeric/GSL/Minimization.hs index a0e5306..1fd951b 100644 --- a/packages/gsl/src/Numeric/GSL/Minimization.hs +++ b/packages/gsl/src/Numeric/GSL/Minimization.hs | |||
@@ -1,3 +1,5 @@ | |||
1 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
2 | |||
1 | {-# LANGUAGE FlexibleContexts #-} | 3 | {-# LANGUAGE FlexibleContexts #-} |
2 | 4 | ||
3 | 5 | ||
diff --git a/packages/gsl/src/Numeric/GSL/ODE.hs b/packages/gsl/src/Numeric/GSL/ODE.hs index 987d47e..a1ccd38 100644 --- a/packages/gsl/src/Numeric/GSL/ODE.hs +++ b/packages/gsl/src/Numeric/GSL/ODE.hs | |||
@@ -1,5 +1,7 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | 1 | {-# LANGUAGE FlexibleContexts #-} |
2 | 2 | ||
3 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
4 | {-# OPTIONS_GHC -fno-warn-unused-top-binds #-} | ||
3 | 5 | ||
4 | {- | | 6 | {- | |
5 | Module : Numeric.GSL.ODE | 7 | Module : Numeric.GSL.ODE |
diff --git a/packages/gsl/src/Numeric/GSL/Root.hs b/packages/gsl/src/Numeric/GSL/Root.hs index 724f32f..9cdb061 100644 --- a/packages/gsl/src/Numeric/GSL/Root.hs +++ b/packages/gsl/src/Numeric/GSL/Root.hs | |||
@@ -1,5 +1,7 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | 1 | {-# LANGUAGE FlexibleContexts #-} |
2 | 2 | ||
3 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
4 | |||
3 | {- | | 5 | {- | |
4 | Module : Numeric.GSL.Root | 6 | Module : Numeric.GSL.Root |
5 | Copyright : (c) Alberto Ruiz 2009 | 7 | Copyright : (c) Alberto Ruiz 2009 |
diff --git a/packages/gsl/src/Numeric/GSL/Vector.hs b/packages/gsl/src/Numeric/GSL/Vector.hs index b1c0106..2ca7cc0 100644 --- a/packages/gsl/src/Numeric/GSL/Vector.hs +++ b/packages/gsl/src/Numeric/GSL/Vector.hs | |||
@@ -1,3 +1,6 @@ | |||
1 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
2 | {-# OPTIONS_GHC -fno-warn-unused-top-binds #-} | ||
3 | |||
1 | ----------------------------------------------------------------------------- | 4 | ----------------------------------------------------------------------------- |
2 | -- | | 5 | -- | |
3 | -- Module : Numeric.GSL.Vector | 6 | -- Module : Numeric.GSL.Vector |
diff --git a/packages/sparse/hmatrix-sparse.cabal b/packages/sparse/hmatrix-sparse.cabal index 55eb424..4399b72 100644 --- a/packages/sparse/hmatrix-sparse.cabal +++ b/packages/sparse/hmatrix-sparse.cabal | |||
@@ -1,5 +1,5 @@ | |||
1 | Name: hmatrix-sparse | 1 | Name: hmatrix-sparse |
2 | Version: 0.1.0 | 2 | Version: 0.19.0.0 |
3 | License: BSD3 | 3 | License: BSD3 |
4 | License-file: LICENSE | 4 | License-file: LICENSE |
5 | Author: Alberto Ruiz | 5 | Author: Alberto Ruiz |
@@ -19,7 +19,7 @@ build-type: Simple | |||
19 | 19 | ||
20 | 20 | ||
21 | library | 21 | library |
22 | Build-Depends: base, hmatrix>=0.16 | 22 | Build-Depends: base<5, hmatrix>=0.16 |
23 | 23 | ||
24 | hs-source-dirs: src | 24 | hs-source-dirs: src |
25 | 25 | ||
diff --git a/packages/special/hmatrix-special.cabal b/packages/special/hmatrix-special.cabal index 2848e39..0890bc7 100644 --- a/packages/special/hmatrix-special.cabal +++ b/packages/special/hmatrix-special.cabal | |||
@@ -1,6 +1,6 @@ | |||
1 | Name: hmatrix-special | 1 | Name: hmatrix-special |
2 | Version: 0.4.0.1 | 2 | Version: 0.19.0.0 |
3 | License: GPL | 3 | License: GPL-3 |
4 | License-file: LICENSE | 4 | License-file: LICENSE |
5 | Author: Alberto Ruiz | 5 | Author: Alberto Ruiz |
6 | Maintainer: Alberto Ruiz <aruiz@um.es> | 6 | Maintainer: Alberto Ruiz <aruiz@um.es> |
diff --git a/packages/special/lib/Numeric/GSL/Special/Bessel.hs b/packages/special/lib/Numeric/GSL/Special/Bessel.hs index 70066f8..84d4cf5 100644 --- a/packages/special/lib/Numeric/GSL/Special/Bessel.hs +++ b/packages/special/lib/Numeric/GSL/Special/Bessel.hs | |||
@@ -1,4 +1,7 @@ | |||
1 | {-# LANGUAGE CPP #-} | 1 | {-# LANGUAGE CPP #-} |
2 | |||
3 | {-# OPTIONS_GHC -fno-warn-unused-top-binds #-} | ||
4 | |||
2 | ------------------------------------------------------------ | 5 | ------------------------------------------------------------ |
3 | -- | | 6 | -- | |
4 | -- Module : Numeric.GSL.Special.Bessel | 7 | -- Module : Numeric.GSL.Special.Bessel |
diff --git a/packages/special/lib/Numeric/GSL/Special/Coulomb.hs b/packages/special/lib/Numeric/GSL/Special/Coulomb.hs index 6904739..3bd3ed6 100644 --- a/packages/special/lib/Numeric/GSL/Special/Coulomb.hs +++ b/packages/special/lib/Numeric/GSL/Special/Coulomb.hs | |||
@@ -1,4 +1,7 @@ | |||
1 | {-# LANGUAGE CPP #-} | 1 | {-# LANGUAGE CPP #-} |
2 | |||
3 | {-# OPTIONS_GHC -fno-warn-unused-top-binds #-} | ||
4 | |||
2 | ------------------------------------------------------------ | 5 | ------------------------------------------------------------ |
3 | -- | | 6 | -- | |
4 | -- Module : Numeric.GSL.Special.Coulomb | 7 | -- Module : Numeric.GSL.Special.Coulomb |
diff --git a/packages/special/lib/Numeric/GSL/Special/Coupling.hs b/packages/special/lib/Numeric/GSL/Special/Coupling.hs index ad120cc..e8d9aef 100644 --- a/packages/special/lib/Numeric/GSL/Special/Coupling.hs +++ b/packages/special/lib/Numeric/GSL/Special/Coupling.hs | |||
@@ -1,4 +1,7 @@ | |||
1 | {-# LANGUAGE CPP #-} | 1 | {-# LANGUAGE CPP #-} |
2 | |||
3 | {-# OPTIONS_GHC -fno-warn-unused-top-binds #-} | ||
4 | |||
2 | ------------------------------------------------------------ | 5 | ------------------------------------------------------------ |
3 | -- | | 6 | -- | |
4 | -- Module : Numeric.GSL.Special.Coupling | 7 | -- Module : Numeric.GSL.Special.Coupling |
diff --git a/packages/special/lib/Numeric/GSL/Special/Exp.hs b/packages/special/lib/Numeric/GSL/Special/Exp.hs index b6dfeef..54033c5 100644 --- a/packages/special/lib/Numeric/GSL/Special/Exp.hs +++ b/packages/special/lib/Numeric/GSL/Special/Exp.hs | |||
@@ -1,4 +1,7 @@ | |||
1 | {-# LANGUAGE CPP #-} | 1 | {-# LANGUAGE CPP #-} |
2 | |||
3 | {-# OPTIONS_GHC -fno-warn-unused-top-binds #-} | ||
4 | |||
2 | ------------------------------------------------------------ | 5 | ------------------------------------------------------------ |
3 | -- | | 6 | -- | |
4 | -- Module : Numeric.GSL.Special.Exp | 7 | -- Module : Numeric.GSL.Special.Exp |
diff --git a/packages/special/lib/Numeric/GSL/Special/Gamma.hs b/packages/special/lib/Numeric/GSL/Special/Gamma.hs index 41e24f0..55950cc 100644 --- a/packages/special/lib/Numeric/GSL/Special/Gamma.hs +++ b/packages/special/lib/Numeric/GSL/Special/Gamma.hs | |||
@@ -1,4 +1,7 @@ | |||
1 | {-# LANGUAGE CPP #-} | 1 | {-# LANGUAGE CPP #-} |
2 | |||
3 | {-# OPTIONS_GHC -fno-warn-unused-top-binds #-} | ||
4 | |||
2 | ------------------------------------------------------------ | 5 | ------------------------------------------------------------ |
3 | -- | | 6 | -- | |
4 | -- Module : Numeric.GSL.Special.Gamma | 7 | -- Module : Numeric.GSL.Special.Gamma |
diff --git a/packages/special/lib/Numeric/GSL/Special/Gegenbauer.hs b/packages/special/lib/Numeric/GSL/Special/Gegenbauer.hs index fb8bf3f..1dae1f1 100644 --- a/packages/special/lib/Numeric/GSL/Special/Gegenbauer.hs +++ b/packages/special/lib/Numeric/GSL/Special/Gegenbauer.hs | |||
@@ -1,4 +1,7 @@ | |||
1 | {-# LANGUAGE CPP #-} | 1 | {-# LANGUAGE CPP #-} |
2 | |||
3 | {-# OPTIONS_GHC -fno-warn-unused-top-binds #-} | ||
4 | |||
2 | ------------------------------------------------------------ | 5 | ------------------------------------------------------------ |
3 | -- | | 6 | -- | |
4 | -- Module : Numeric.GSL.Special.Gegenbauer | 7 | -- Module : Numeric.GSL.Special.Gegenbauer |
diff --git a/packages/special/lib/Numeric/GSL/Special/Legendre.hs b/packages/special/lib/Numeric/GSL/Special/Legendre.hs index 927fa2c..5f7d2b0 100644 --- a/packages/special/lib/Numeric/GSL/Special/Legendre.hs +++ b/packages/special/lib/Numeric/GSL/Special/Legendre.hs | |||
@@ -1,4 +1,7 @@ | |||
1 | {-# LANGUAGE CPP #-} | 1 | {-# LANGUAGE CPP #-} |
2 | |||
3 | {-# OPTIONS_GHC -fno-warn-unused-top-binds #-} | ||
4 | |||
2 | ------------------------------------------------------------ | 5 | ------------------------------------------------------------ |
3 | -- | | 6 | -- | |
4 | -- Module : Numeric.GSL.Special.Legendre | 7 | -- Module : Numeric.GSL.Special.Legendre |
diff --git a/packages/special/lib/Numeric/GSL/Special/Trig.hs b/packages/special/lib/Numeric/GSL/Special/Trig.hs index f2c1519..754bed1 100644 --- a/packages/special/lib/Numeric/GSL/Special/Trig.hs +++ b/packages/special/lib/Numeric/GSL/Special/Trig.hs | |||
@@ -1,4 +1,8 @@ | |||
1 | {-# LANGUAGE CPP #-} | 1 | {-# LANGUAGE CPP #-} |
2 | |||
3 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
4 | {-# OPTIONS_GHC -fno-warn-unused-top-binds #-} | ||
5 | |||
2 | ------------------------------------------------------------ | 6 | ------------------------------------------------------------ |
3 | -- | | 7 | -- | |
4 | -- Module : Numeric.GSL.Special.Trig | 8 | -- Module : Numeric.GSL.Special.Trig |
diff --git a/packages/sundials/ChangeLog.md b/packages/sundials/ChangeLog.md new file mode 100644 index 0000000..7b15777 --- /dev/null +++ b/packages/sundials/ChangeLog.md | |||
@@ -0,0 +1,5 @@ | |||
1 | # Revision history for hmatrix-sundials | ||
2 | |||
3 | ## 0.1.0.0 -- 2018-04-21 | ||
4 | |||
5 | * First version. Released on an unsuspecting world. Just Runge-Kutta methods to start with. | ||
diff --git a/packages/sundials/LICENSE b/packages/sundials/LICENSE new file mode 100644 index 0000000..a162e98 --- /dev/null +++ b/packages/sundials/LICENSE | |||
@@ -0,0 +1,30 @@ | |||
1 | Copyright (c) 2018, Dominic Steinitz, Novadiscovery | ||
2 | |||
3 | All rights reserved. | ||
4 | |||
5 | Redistribution and use in source and binary forms, with or without | ||
6 | modification, are permitted provided that the following conditions are met: | ||
7 | |||
8 | * Redistributions of source code must retain the above copyright | ||
9 | notice, this list of conditions and the following disclaimer. | ||
10 | |||
11 | * Redistributions in binary form must reproduce the above | ||
12 | copyright notice, this list of conditions and the following | ||
13 | disclaimer in the documentation and/or other materials provided | ||
14 | with the distribution. | ||
15 | |||
16 | * Neither the name of Dominic Steinitz nor the names of other | ||
17 | contributors may be used to endorse or promote products derived | ||
18 | from this software without specific prior written permission. | ||
19 | |||
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | ||
21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | ||
22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | ||
23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | ||
24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | ||
25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | ||
26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | ||
27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | ||
28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
diff --git a/packages/sundials/README.md b/packages/sundials/README.md new file mode 100644 index 0000000..2fac5c2 --- /dev/null +++ b/packages/sundials/README.md | |||
@@ -0,0 +1,8 @@ | |||
1 | Currently only an interface to the Runge-Kutta methods: | ||
2 | [ARKode](https://computation.llnl.gov/projects/sundials/arkode) | ||
3 | |||
4 | The interface is almost certainly going to change. Sundials gives a | ||
5 | rich set of "combinators" for controlling the solution of your problem | ||
6 | and reporting on how it performed. The idea is to initially mimic | ||
7 | hmatrix-gsl and add extra, richer functions but ultimately upgrade the | ||
8 | whole interface both for sundials and for gsl. | ||
diff --git a/packages/sundials/Setup.hs b/packages/sundials/Setup.hs new file mode 100644 index 0000000..9a994af --- /dev/null +++ b/packages/sundials/Setup.hs | |||
@@ -0,0 +1,2 @@ | |||
1 | import Distribution.Simple | ||
2 | main = defaultMain | ||
diff --git a/packages/sundials/diagrams/brusselator.png b/packages/sundials/diagrams/brusselator.png new file mode 100644 index 0000000..740cacb --- /dev/null +++ b/packages/sundials/diagrams/brusselator.png | |||
Binary files differ | |||
diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal new file mode 100644 index 0000000..cd2be4e --- /dev/null +++ b/packages/sundials/hmatrix-sundials.cabal | |||
@@ -0,0 +1,61 @@ | |||
1 | name: hmatrix-sundials | ||
2 | version: 0.19.0.0 | ||
3 | synopsis: hmatrix interface to sundials | ||
4 | description: An interface to the solving suite SUNDIALS. Currently, it | ||
5 | mimics the solving interface in hmstrix-gsl but | ||
6 | provides more diagnostic information and the | ||
7 | Butcher Tableaux (for Runge-Kutta methods). | ||
8 | homepage: https://github.com/idontgetoutmuch/hmatrix/tree/sundials | ||
9 | license: BSD3 | ||
10 | license-file: LICENSE | ||
11 | author: Dominic Steinitz | ||
12 | maintainer: dominic@steinitz.org | ||
13 | copyright: Dominic Steinitz 2018, Novadiscovery 2018 | ||
14 | category: Math | ||
15 | build-type: Simple | ||
16 | extra-source-files: ChangeLog.md, README.md, diagrams/*.png | ||
17 | extra-doc-files: diagrams/*.png | ||
18 | cabal-version: >=1.18 | ||
19 | |||
20 | |||
21 | library | ||
22 | build-depends: base >=4.10 && <4.11, | ||
23 | inline-c >=0.6 && <0.7, | ||
24 | vector >=0.12 && <0.13, | ||
25 | template-haskell >=2.12 && <2.13, | ||
26 | containers >=0.5 && <0.6, | ||
27 | hmatrix>=0.18 | ||
28 | extra-libraries: sundials_arkode, | ||
29 | sundials_cvode | ||
30 | other-extensions: QuasiQuotes | ||
31 | hs-source-dirs: src | ||
32 | exposed-modules: Numeric.Sundials.ODEOpts, | ||
33 | Numeric.Sundials.ARKode.ODE, | ||
34 | Numeric.Sundials.CVode.ODE | ||
35 | other-modules: Numeric.Sundials.Arkode | ||
36 | c-sources: src/helpers.c src/helpers.h | ||
37 | default-language: Haskell2010 | ||
38 | |||
39 | test-suite hmatrix-sundials-testsuite | ||
40 | type: exitcode-stdio-1.0 | ||
41 | main-is: Main.hs | ||
42 | other-modules: Numeric.Sundials.ODEOpts, | ||
43 | Numeric.Sundials.ARKode.ODE, | ||
44 | Numeric.Sundials.CVode.ODE, | ||
45 | Numeric.Sundials.Arkode | ||
46 | build-depends: base >=4.10 && <4.11, | ||
47 | inline-c >=0.6 && <0.7, | ||
48 | vector >=0.12 && <0.13, | ||
49 | template-haskell >=2.12 && <2.13, | ||
50 | containers >=0.5 && <0.6, | ||
51 | hmatrix>=0.18, | ||
52 | plots, | ||
53 | diagrams-lib, | ||
54 | diagrams-rasterific, | ||
55 | lens, | ||
56 | hspec | ||
57 | hs-source-dirs: src | ||
58 | extra-libraries: sundials_arkode, | ||
59 | sundials_cvode | ||
60 | c-sources: src/helpers.c src/helpers.h | ||
61 | default-language: Haskell2010 | ||
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs new file mode 100644 index 0000000..16c21c5 --- /dev/null +++ b/packages/sundials/src/Main.hs | |||
@@ -0,0 +1,186 @@ | |||
1 | {-# OPTIONS_GHC -Wall #-} | ||
2 | |||
3 | import qualified Numeric.Sundials.ARKode.ODE as ARK | ||
4 | import qualified Numeric.Sundials.CVode.ODE as CV | ||
5 | import Numeric.LinearAlgebra | ||
6 | |||
7 | import Plots as P | ||
8 | import qualified Diagrams.Prelude as D | ||
9 | import Diagrams.Backend.Rasterific | ||
10 | |||
11 | import Control.Lens | ||
12 | |||
13 | import Test.Hspec | ||
14 | |||
15 | |||
16 | lorenz :: Double -> [Double] -> [Double] | ||
17 | lorenz _t u = [ sigma * (y - x) | ||
18 | , x * (rho - z) - y | ||
19 | , x * y - beta * z | ||
20 | ] | ||
21 | where | ||
22 | rho = 28.0 | ||
23 | sigma = 10.0 | ||
24 | beta = 8.0 / 3.0 | ||
25 | x = u !! 0 | ||
26 | y = u !! 1 | ||
27 | z = u !! 2 | ||
28 | |||
29 | _lorenzJac :: Double -> Vector Double -> Matrix Double | ||
30 | _lorenzJac _t u = (3><3) [ (-sigma), rho - z, y | ||
31 | , sigma , -1.0 , x | ||
32 | , 0.0 , (-x) , (-beta) | ||
33 | ] | ||
34 | where | ||
35 | rho = 28.0 | ||
36 | sigma = 10.0 | ||
37 | beta = 8.0 / 3.0 | ||
38 | x = u ! 0 | ||
39 | y = u ! 1 | ||
40 | z = u ! 2 | ||
41 | |||
42 | brusselator :: Double -> [Double] -> [Double] | ||
43 | brusselator _t x = [ a - (w + 1) * u + v * u * u | ||
44 | , w * u - v * u * u | ||
45 | , (b - w) / eps - w * u | ||
46 | ] | ||
47 | where | ||
48 | a = 1.0 | ||
49 | b = 3.5 | ||
50 | eps = 5.0e-6 | ||
51 | u = x !! 0 | ||
52 | v = x !! 1 | ||
53 | w = x !! 2 | ||
54 | |||
55 | _brussJac :: Double -> Vector Double -> Matrix Double | ||
56 | _brussJac _t x = (3><3) [ (-(w + 1.0)) + 2.0 * u * v, w - 2.0 * u * v, (-w) | ||
57 | , u * u , (-(u * u)) , 0.0 | ||
58 | , (-u) , u , (-1.0) / eps - u | ||
59 | ] | ||
60 | where | ||
61 | y = toList x | ||
62 | u = y !! 0 | ||
63 | v = y !! 1 | ||
64 | w = y !! 2 | ||
65 | eps = 5.0e-6 | ||
66 | |||
67 | stiffish :: Double -> [Double] -> [Double] | ||
68 | stiffish t v = [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | ||
69 | where | ||
70 | lamda = -100.0 | ||
71 | u = v !! 0 | ||
72 | |||
73 | stiffishV :: Double -> Vector Double -> Vector Double | ||
74 | stiffishV t v = fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | ||
75 | where | ||
76 | lamda = -100.0 | ||
77 | u = v ! 0 | ||
78 | |||
79 | _stiffJac :: Double -> Vector Double -> Matrix Double | ||
80 | _stiffJac _t _v = (1><1) [ lamda ] | ||
81 | where | ||
82 | lamda = -100.0 | ||
83 | |||
84 | predatorPrey :: Double -> [Double] -> [Double] | ||
85 | predatorPrey _t v = [ x * a - b * x * y | ||
86 | , d * x * y - c * y - e * y * z | ||
87 | , (-f) * z + g * y * z | ||
88 | ] | ||
89 | where | ||
90 | x = v!!0 | ||
91 | y = v!!1 | ||
92 | z = v!!2 | ||
93 | a = 1.0 | ||
94 | b = 1.0 | ||
95 | c = 1.0 | ||
96 | d = 1.0 | ||
97 | e = 1.0 | ||
98 | f = 1.0 | ||
99 | g = 1.0 | ||
100 | |||
101 | lSaxis :: [[Double]] -> P.Axis B D.V2 Double | ||
102 | lSaxis xs = P.r2Axis &~ do | ||
103 | let ts = xs!!0 | ||
104 | us = xs!!1 | ||
105 | vs = xs!!2 | ||
106 | ws = xs!!3 | ||
107 | P.linePlot' $ zip ts us | ||
108 | P.linePlot' $ zip ts vs | ||
109 | P.linePlot' $ zip ts ws | ||
110 | |||
111 | kSaxis :: [(Double, Double)] -> P.Axis B D.V2 Double | ||
112 | kSaxis xs = P.r2Axis &~ do | ||
113 | P.linePlot' xs | ||
114 | |||
115 | main :: IO () | ||
116 | main = do | ||
117 | |||
118 | let res1 = ARK.odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) | ||
119 | renderRasterific "diagrams/brusselator.png" | ||
120 | (D.dims2D 500.0 500.0) | ||
121 | (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) | ||
122 | |||
123 | let res1a = ARK.odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) | ||
124 | renderRasterific "diagrams/brusselatorA.png" | ||
125 | (D.dims2D 500.0 500.0) | ||
126 | (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1a)) | ||
127 | |||
128 | let res2 = ARK.odeSolve stiffish [0.0] (fromList [0.0, 0.1 .. 10.0]) | ||
129 | renderRasterific "diagrams/stiffish.png" | ||
130 | (D.dims2D 500.0 500.0) | ||
131 | (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res2)) | ||
132 | |||
133 | let res2a = ARK.odeSolveV (ARK.SDIRK_5_3_4') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) | ||
134 | |||
135 | let res2b = ARK.odeSolveV (ARK.TRBDF2_3_3_2') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) | ||
136 | |||
137 | let maxDiffA = maximum $ map abs $ | ||
138 | zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2b)!!0) | ||
139 | |||
140 | let res2c = CV.odeSolveV (CV.BDF) Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) | ||
141 | |||
142 | let maxDiffB = maximum $ map abs $ | ||
143 | zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2c)!!0) | ||
144 | |||
145 | let maxDiffC = maximum $ map abs $ | ||
146 | zipWith (-) ((toLists $ tr res2b)!!0) ((toLists $ tr res2c)!!0) | ||
147 | |||
148 | let res3 = ARK.odeSolve lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0]) | ||
149 | |||
150 | renderRasterific "diagrams/lorenz.png" | ||
151 | (D.dims2D 500.0 500.0) | ||
152 | (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!0) ((toLists $ tr res3)!!1)) | ||
153 | |||
154 | renderRasterific "diagrams/lorenz1.png" | ||
155 | (D.dims2D 500.0 500.0) | ||
156 | (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!0) ((toLists $ tr res3)!!2)) | ||
157 | |||
158 | renderRasterific "diagrams/lorenz2.png" | ||
159 | (D.dims2D 500.0 500.0) | ||
160 | (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!1) ((toLists $ tr res3)!!2)) | ||
161 | |||
162 | let res4 = CV.odeSolve predatorPrey [0.5, 1.0, 2.0] (fromList [0.0, 0.01 .. 10.0]) | ||
163 | |||
164 | renderRasterific "diagrams/predatorPrey.png" | ||
165 | (D.dims2D 500.0 500.0) | ||
166 | (renderAxis $ kSaxis $ zip ((toLists $ tr res4)!!0) ((toLists $ tr res4)!!1)) | ||
167 | |||
168 | renderRasterific "diagrams/predatorPrey1.png" | ||
169 | (D.dims2D 500.0 500.0) | ||
170 | (renderAxis $ kSaxis $ zip ((toLists $ tr res4)!!0) ((toLists $ tr res4)!!2)) | ||
171 | |||
172 | renderRasterific "diagrams/predatorPrey2.png" | ||
173 | (D.dims2D 500.0 500.0) | ||
174 | (renderAxis $ kSaxis $ zip ((toLists $ tr res4)!!1) ((toLists $ tr res4)!!2)) | ||
175 | |||
176 | let res4a = ARK.odeSolve predatorPrey [0.5, 1.0, 2.0] (fromList [0.0, 0.01 .. 10.0]) | ||
177 | |||
178 | let maxDiffPpA = maximum $ map abs $ | ||
179 | zipWith (-) ((toLists $ tr res4)!!0) ((toLists $ tr res4a)!!0) | ||
180 | |||
181 | hspec $ describe "Compare results" $ do | ||
182 | it "for SDIRK_5_3_4' and TRBDF2_3_3_2'" $ maxDiffA < 1.0e-6 | ||
183 | it "for SDIRK_5_3_4' and BDF" $ maxDiffB < 1.0e-6 | ||
184 | it "for TRBDF2_3_3_2' and BDF" $ maxDiffC < 1.0e-6 | ||
185 | it "for CV and ARK for the Predator Prey model" $ maxDiffPpA < 1.0e-3 | ||
186 | |||
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs new file mode 100644 index 0000000..fafc237 --- /dev/null +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -0,0 +1,903 @@ | |||
1 | {-# LANGUAGE QuasiQuotes #-} | ||
2 | {-# LANGUAGE TemplateHaskell #-} | ||
3 | {-# LANGUAGE MultiWayIf #-} | ||
4 | {-# LANGUAGE OverloadedStrings #-} | ||
5 | {-# LANGUAGE ScopedTypeVariables #-} | ||
6 | {-# LANGUAGE DeriveGeneric #-} | ||
7 | {-# LANGUAGE TypeOperators #-} | ||
8 | {-# LANGUAGE KindSignatures #-} | ||
9 | {-# LANGUAGE TypeSynonymInstances #-} | ||
10 | {-# LANGUAGE FlexibleInstances #-} | ||
11 | {-# LANGUAGE FlexibleContexts #-} | ||
12 | |||
13 | ----------------------------------------------------------------------------- | ||
14 | -- | | ||
15 | -- Module : Numeric.Sundials.ARKode.ODE | ||
16 | -- Copyright : Dominic Steinitz 2018, | ||
17 | -- Novadiscovery 2018 | ||
18 | -- License : BSD | ||
19 | -- Maintainer : Dominic Steinitz | ||
20 | -- Stability : provisional | ||
21 | -- | ||
22 | -- Solution of ordinary differential equation (ODE) initial value problems. | ||
23 | -- See <https://computation.llnl.gov/projects/sundials/sundials-software> for more detail. | ||
24 | -- | ||
25 | -- A simple example: | ||
26 | -- | ||
27 | -- <<diagrams/brusselator.png#diagram=brusselator&height=400&width=500>> | ||
28 | -- | ||
29 | -- @ | ||
30 | -- import Numeric.Sundials.ARKode.ODE | ||
31 | -- import Numeric.LinearAlgebra | ||
32 | -- | ||
33 | -- import Plots as P | ||
34 | -- import qualified Diagrams.Prelude as D | ||
35 | -- import Diagrams.Backend.Rasterific | ||
36 | -- | ||
37 | -- brusselator :: Double -> [Double] -> [Double] | ||
38 | -- brusselator _t x = [ a - (w + 1) * u + v * u * u | ||
39 | -- , w * u - v * u * u | ||
40 | -- , (b - w) / eps - w * u | ||
41 | -- ] | ||
42 | -- where | ||
43 | -- a = 1.0 | ||
44 | -- b = 3.5 | ||
45 | -- eps = 5.0e-6 | ||
46 | -- u = x !! 0 | ||
47 | -- v = x !! 1 | ||
48 | -- w = x !! 2 | ||
49 | -- | ||
50 | -- lSaxis :: [[Double]] -> P.Axis B D.V2 Double | ||
51 | -- lSaxis xs = P.r2Axis &~ do | ||
52 | -- let ts = xs!!0 | ||
53 | -- us = xs!!1 | ||
54 | -- vs = xs!!2 | ||
55 | -- ws = xs!!3 | ||
56 | -- P.linePlot' $ zip ts us | ||
57 | -- P.linePlot' $ zip ts vs | ||
58 | -- P.linePlot' $ zip ts ws | ||
59 | -- | ||
60 | -- main = do | ||
61 | -- let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) | ||
62 | -- renderRasterific "diagrams/brusselator.png" | ||
63 | -- (D.dims2D 500.0 500.0) | ||
64 | -- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) | ||
65 | -- @ | ||
66 | -- | ||
67 | -- With Sundials ARKode, it is possible to retrieve the Butcher tableau for the solver. | ||
68 | -- | ||
69 | -- @ | ||
70 | -- import Numeric.Sundials.ARKode.ODE | ||
71 | -- import Numeric.LinearAlgebra | ||
72 | -- | ||
73 | -- import Data.List (intercalate) | ||
74 | -- | ||
75 | -- import Text.PrettyPrint.HughesPJClass | ||
76 | -- | ||
77 | -- | ||
78 | -- butcherTableauTex :: ButcherTable -> String | ||
79 | -- butcherTableauTex (ButcherTable m c b b2) = | ||
80 | -- render $ | ||
81 | -- vcat [ text ("\n\\begin{array}{c|" ++ (concat $ replicate n "c") ++ "}") | ||
82 | -- , us | ||
83 | -- , text "\\hline" | ||
84 | -- , text bs <+> text "\\\\" | ||
85 | -- , text b2s <+> text "\\\\" | ||
86 | -- , text "\\end{array}" | ||
87 | -- ] | ||
88 | -- where | ||
89 | -- n = rows m | ||
90 | -- rs = toLists m | ||
91 | -- ss = map (\r -> intercalate " & " $ map show r) rs | ||
92 | -- ts = zipWith (\i r -> show i ++ " & " ++ r) (toList c) ss | ||
93 | -- us = vcat $ map (\r -> text r <+> text "\\\\") ts | ||
94 | -- bs = " & " ++ (intercalate " & " $ map show $ toList b) | ||
95 | -- b2s = " & " ++ (intercalate " & " $ map show $ toList b2) | ||
96 | -- | ||
97 | -- main :: IO () | ||
98 | -- main = do | ||
99 | -- | ||
100 | -- let res = butcherTable (SDIRK_2_1_2 undefined) | ||
101 | -- putStrLn $ show res | ||
102 | -- putStrLn $ butcherTableauTex res | ||
103 | -- | ||
104 | -- let resA = butcherTable (KVAERNO_4_2_3 undefined) | ||
105 | -- putStrLn $ show resA | ||
106 | -- putStrLn $ butcherTableauTex resA | ||
107 | -- | ||
108 | -- let resB = butcherTable (SDIRK_5_3_4 undefined) | ||
109 | -- putStrLn $ show resB | ||
110 | -- putStrLn $ butcherTableauTex resB | ||
111 | -- @ | ||
112 | -- | ||
113 | -- Using the code above from the examples gives | ||
114 | -- | ||
115 | -- KVAERNO_4_2_3 | ||
116 | -- | ||
117 | -- \[ | ||
118 | -- \begin{array}{c|cccc} | ||
119 | -- 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\ | ||
120 | -- 0.871733043 & 0.4358665215 & 0.4358665215 & 0.0 & 0.0 \\ | ||
121 | -- 1.0 & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\ | ||
122 | -- 1.0 & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\ | ||
123 | -- \hline | ||
124 | -- & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\ | ||
125 | -- & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\ | ||
126 | -- \end{array} | ||
127 | -- \] | ||
128 | -- | ||
129 | -- SDIRK_2_1_2 | ||
130 | -- | ||
131 | -- \[ | ||
132 | -- \begin{array}{c|cc} | ||
133 | -- 1.0 & 1.0 & 0.0 \\ | ||
134 | -- 0.0 & -1.0 & 1.0 \\ | ||
135 | -- \hline | ||
136 | -- & 0.5 & 0.5 \\ | ||
137 | -- & 1.0 & 0.0 \\ | ||
138 | -- \end{array} | ||
139 | -- \] | ||
140 | -- | ||
141 | -- SDIRK_5_3_4 | ||
142 | -- | ||
143 | -- \[ | ||
144 | -- \begin{array}{c|ccccc} | ||
145 | -- 0.25 & 0.25 & 0.0 & 0.0 & 0.0 & 0.0 \\ | ||
146 | -- 0.75 & 0.5 & 0.25 & 0.0 & 0.0 & 0.0 \\ | ||
147 | -- 0.55 & 0.34 & -4.0e-2 & 0.25 & 0.0 & 0.0 \\ | ||
148 | -- 0.5 & 0.2727941176470588 & -5.036764705882353e-2 & 2.7573529411764705e-2 & 0.25 & 0.0 \\ | ||
149 | -- 1.0 & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\ | ||
150 | -- \hline | ||
151 | -- & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\ | ||
152 | -- & 1.2291666666666667 & -0.17708333333333334 & 7.03125 & -7.083333333333333 & 0.0 \\ | ||
153 | -- \end{array} | ||
154 | -- \] | ||
155 | ----------------------------------------------------------------------------- | ||
156 | module Numeric.Sundials.ARKode.ODE ( odeSolve | ||
157 | , odeSolveV | ||
158 | , odeSolveVWith | ||
159 | , odeSolveVWith' | ||
160 | , ButcherTable(..) | ||
161 | , butcherTable | ||
162 | , ODEMethod(..) | ||
163 | , StepControl(..) | ||
164 | ) where | ||
165 | |||
166 | import qualified Language.C.Inline as C | ||
167 | import qualified Language.C.Inline.Unsafe as CU | ||
168 | |||
169 | import Data.Monoid ((<>)) | ||
170 | import Data.Maybe (isJust) | ||
171 | |||
172 | import Foreign.C.Types (CDouble, CInt, CLong) | ||
173 | import Foreign.Ptr (Ptr) | ||
174 | import Foreign.Storable (poke) | ||
175 | |||
176 | import qualified Data.Vector.Storable as V | ||
177 | |||
178 | import Data.Coerce (coerce) | ||
179 | import System.IO.Unsafe (unsafePerformIO) | ||
180 | import GHC.Generics (C1, Constructor, (:+:)(..), D1, Rep, Generic, M1(..), | ||
181 | from, conName) | ||
182 | |||
183 | import Numeric.LinearAlgebra.Devel (createVector) | ||
184 | |||
185 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows, | ||
186 | cols, toLists, size, reshape, | ||
187 | subVector, subMatrix, (><)) | ||
188 | |||
189 | import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian, SundialsDiagnostics(..)) | ||
190 | import qualified Numeric.Sundials.Arkode as T | ||
191 | import Numeric.Sundials.Arkode (getDataFromContents, putDataInContents, arkSMax, | ||
192 | sDIRK_2_1_2, | ||
193 | bILLINGTON_3_3_2, | ||
194 | tRBDF2_3_3_2, | ||
195 | kVAERNO_4_2_3, | ||
196 | aRK324L2SA_DIRK_4_2_3, | ||
197 | cASH_5_2_4, | ||
198 | cASH_5_3_4, | ||
199 | sDIRK_5_3_4, | ||
200 | kVAERNO_5_3_4, | ||
201 | aRK436L2SA_DIRK_6_3_4, | ||
202 | kVAERNO_7_4_5, | ||
203 | aRK548L2SA_DIRK_8_4_5, | ||
204 | hEUN_EULER_2_1_2, | ||
205 | bOGACKI_SHAMPINE_4_2_3, | ||
206 | aRK324L2SA_ERK_4_2_3, | ||
207 | zONNEVELD_5_3_4, | ||
208 | aRK436L2SA_ERK_6_3_4, | ||
209 | sAYFY_ABURUB_6_3_4, | ||
210 | cASH_KARP_6_4_5, | ||
211 | fEHLBERG_6_4_5, | ||
212 | dORMAND_PRINCE_7_4_5, | ||
213 | aRK548L2SA_ERK_8_4_5, | ||
214 | vERNER_8_5_6, | ||
215 | fEHLBERG_13_7_8) | ||
216 | |||
217 | |||
218 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | ||
219 | |||
220 | C.include "<stdlib.h>" | ||
221 | C.include "<stdio.h>" | ||
222 | C.include "<math.h>" | ||
223 | C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts. | ||
224 | C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros | ||
225 | C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix | ||
226 | C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver | ||
227 | C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface | ||
228 | C.include "<sundials/sundials_types.h>" -- definition of type realtype | ||
229 | C.include "<sundials/sundials_math.h>" | ||
230 | C.include "../../../helpers.h" | ||
231 | C.include "Numeric/Sundials/Arkode_hsc.h" | ||
232 | |||
233 | |||
234 | -- | Stepping functions | ||
235 | data ODEMethod = SDIRK_2_1_2 Jacobian | ||
236 | | SDIRK_2_1_2' | ||
237 | | BILLINGTON_3_3_2 Jacobian | ||
238 | | BILLINGTON_3_3_2' | ||
239 | | TRBDF2_3_3_2 Jacobian | ||
240 | | TRBDF2_3_3_2' | ||
241 | | KVAERNO_4_2_3 Jacobian | ||
242 | | KVAERNO_4_2_3' | ||
243 | | ARK324L2SA_DIRK_4_2_3 Jacobian | ||
244 | | ARK324L2SA_DIRK_4_2_3' | ||
245 | | CASH_5_2_4 Jacobian | ||
246 | | CASH_5_2_4' | ||
247 | | CASH_5_3_4 Jacobian | ||
248 | | CASH_5_3_4' | ||
249 | | SDIRK_5_3_4 Jacobian | ||
250 | | SDIRK_5_3_4' | ||
251 | | KVAERNO_5_3_4 Jacobian | ||
252 | | KVAERNO_5_3_4' | ||
253 | | ARK436L2SA_DIRK_6_3_4 Jacobian | ||
254 | | ARK436L2SA_DIRK_6_3_4' | ||
255 | | KVAERNO_7_4_5 Jacobian | ||
256 | | KVAERNO_7_4_5' | ||
257 | | ARK548L2SA_DIRK_8_4_5 Jacobian | ||
258 | | ARK548L2SA_DIRK_8_4_5' | ||
259 | | HEUN_EULER_2_1_2 Jacobian | ||
260 | | HEUN_EULER_2_1_2' | ||
261 | | BOGACKI_SHAMPINE_4_2_3 Jacobian | ||
262 | | BOGACKI_SHAMPINE_4_2_3' | ||
263 | | ARK324L2SA_ERK_4_2_3 Jacobian | ||
264 | | ARK324L2SA_ERK_4_2_3' | ||
265 | | ZONNEVELD_5_3_4 Jacobian | ||
266 | | ZONNEVELD_5_3_4' | ||
267 | | ARK436L2SA_ERK_6_3_4 Jacobian | ||
268 | | ARK436L2SA_ERK_6_3_4' | ||
269 | | SAYFY_ABURUB_6_3_4 Jacobian | ||
270 | | SAYFY_ABURUB_6_3_4' | ||
271 | | CASH_KARP_6_4_5 Jacobian | ||
272 | | CASH_KARP_6_4_5' | ||
273 | | FEHLBERG_6_4_5 Jacobian | ||
274 | | FEHLBERG_6_4_5' | ||
275 | | DORMAND_PRINCE_7_4_5 Jacobian | ||
276 | | DORMAND_PRINCE_7_4_5' | ||
277 | | ARK548L2SA_ERK_8_4_5 Jacobian | ||
278 | | ARK548L2SA_ERK_8_4_5' | ||
279 | | VERNER_8_5_6 Jacobian | ||
280 | | VERNER_8_5_6' | ||
281 | | FEHLBERG_13_7_8 Jacobian | ||
282 | | FEHLBERG_13_7_8' | ||
283 | deriving Generic | ||
284 | |||
285 | constrName :: (HasConstructor (Rep a), Generic a)=> a -> String | ||
286 | constrName = genericConstrName . from | ||
287 | |||
288 | class HasConstructor (f :: * -> *) where | ||
289 | genericConstrName :: f x -> String | ||
290 | |||
291 | instance HasConstructor f => HasConstructor (D1 c f) where | ||
292 | genericConstrName (M1 x) = genericConstrName x | ||
293 | |||
294 | instance (HasConstructor x, HasConstructor y) => HasConstructor (x :+: y) where | ||
295 | genericConstrName (L1 l) = genericConstrName l | ||
296 | genericConstrName (R1 r) = genericConstrName r | ||
297 | |||
298 | instance Constructor c => HasConstructor (C1 c f) where | ||
299 | genericConstrName x = conName x | ||
300 | |||
301 | instance Show ODEMethod where | ||
302 | show x = constrName x | ||
303 | |||
304 | -- FIXME: We can probably do better here with generics | ||
305 | getMethod :: ODEMethod -> Int | ||
306 | getMethod (SDIRK_2_1_2 _) = sDIRK_2_1_2 | ||
307 | getMethod (SDIRK_2_1_2') = sDIRK_2_1_2 | ||
308 | getMethod (BILLINGTON_3_3_2 _) = bILLINGTON_3_3_2 | ||
309 | getMethod (BILLINGTON_3_3_2') = bILLINGTON_3_3_2 | ||
310 | getMethod (TRBDF2_3_3_2 _) = tRBDF2_3_3_2 | ||
311 | getMethod (TRBDF2_3_3_2') = tRBDF2_3_3_2 | ||
312 | getMethod (KVAERNO_4_2_3 _) = kVAERNO_4_2_3 | ||
313 | getMethod (KVAERNO_4_2_3') = kVAERNO_4_2_3 | ||
314 | getMethod (ARK324L2SA_DIRK_4_2_3 _) = aRK324L2SA_DIRK_4_2_3 | ||
315 | getMethod (ARK324L2SA_DIRK_4_2_3') = aRK324L2SA_DIRK_4_2_3 | ||
316 | getMethod (CASH_5_2_4 _) = cASH_5_2_4 | ||
317 | getMethod (CASH_5_2_4') = cASH_5_2_4 | ||
318 | getMethod (CASH_5_3_4 _) = cASH_5_3_4 | ||
319 | getMethod (CASH_5_3_4') = cASH_5_3_4 | ||
320 | getMethod (SDIRK_5_3_4 _) = sDIRK_5_3_4 | ||
321 | getMethod (SDIRK_5_3_4') = sDIRK_5_3_4 | ||
322 | getMethod (KVAERNO_5_3_4 _) = kVAERNO_5_3_4 | ||
323 | getMethod (KVAERNO_5_3_4') = kVAERNO_5_3_4 | ||
324 | getMethod (ARK436L2SA_DIRK_6_3_4 _) = aRK436L2SA_DIRK_6_3_4 | ||
325 | getMethod (ARK436L2SA_DIRK_6_3_4') = aRK436L2SA_DIRK_6_3_4 | ||
326 | getMethod (KVAERNO_7_4_5 _) = kVAERNO_7_4_5 | ||
327 | getMethod (KVAERNO_7_4_5') = kVAERNO_7_4_5 | ||
328 | getMethod (ARK548L2SA_DIRK_8_4_5 _) = aRK548L2SA_DIRK_8_4_5 | ||
329 | getMethod (ARK548L2SA_DIRK_8_4_5') = aRK548L2SA_DIRK_8_4_5 | ||
330 | getMethod (HEUN_EULER_2_1_2 _) = hEUN_EULER_2_1_2 | ||
331 | getMethod (HEUN_EULER_2_1_2') = hEUN_EULER_2_1_2 | ||
332 | getMethod (BOGACKI_SHAMPINE_4_2_3 _) = bOGACKI_SHAMPINE_4_2_3 | ||
333 | getMethod (BOGACKI_SHAMPINE_4_2_3') = bOGACKI_SHAMPINE_4_2_3 | ||
334 | getMethod (ARK324L2SA_ERK_4_2_3 _) = aRK324L2SA_ERK_4_2_3 | ||
335 | getMethod (ARK324L2SA_ERK_4_2_3') = aRK324L2SA_ERK_4_2_3 | ||
336 | getMethod (ZONNEVELD_5_3_4 _) = zONNEVELD_5_3_4 | ||
337 | getMethod (ZONNEVELD_5_3_4') = zONNEVELD_5_3_4 | ||
338 | getMethod (ARK436L2SA_ERK_6_3_4 _) = aRK436L2SA_ERK_6_3_4 | ||
339 | getMethod (ARK436L2SA_ERK_6_3_4') = aRK436L2SA_ERK_6_3_4 | ||
340 | getMethod (SAYFY_ABURUB_6_3_4 _) = sAYFY_ABURUB_6_3_4 | ||
341 | getMethod (SAYFY_ABURUB_6_3_4') = sAYFY_ABURUB_6_3_4 | ||
342 | getMethod (CASH_KARP_6_4_5 _) = cASH_KARP_6_4_5 | ||
343 | getMethod (CASH_KARP_6_4_5') = cASH_KARP_6_4_5 | ||
344 | getMethod (FEHLBERG_6_4_5 _) = fEHLBERG_6_4_5 | ||
345 | getMethod (FEHLBERG_6_4_5' ) = fEHLBERG_6_4_5 | ||
346 | getMethod (DORMAND_PRINCE_7_4_5 _) = dORMAND_PRINCE_7_4_5 | ||
347 | getMethod (DORMAND_PRINCE_7_4_5') = dORMAND_PRINCE_7_4_5 | ||
348 | getMethod (ARK548L2SA_ERK_8_4_5 _) = aRK548L2SA_ERK_8_4_5 | ||
349 | getMethod (ARK548L2SA_ERK_8_4_5') = aRK548L2SA_ERK_8_4_5 | ||
350 | getMethod (VERNER_8_5_6 _) = vERNER_8_5_6 | ||
351 | getMethod (VERNER_8_5_6') = vERNER_8_5_6 | ||
352 | getMethod (FEHLBERG_13_7_8 _) = fEHLBERG_13_7_8 | ||
353 | getMethod (FEHLBERG_13_7_8') = fEHLBERG_13_7_8 | ||
354 | |||
355 | getJacobian :: ODEMethod -> Maybe Jacobian | ||
356 | getJacobian (SDIRK_2_1_2 j) = Just j | ||
357 | getJacobian (BILLINGTON_3_3_2 j) = Just j | ||
358 | getJacobian (TRBDF2_3_3_2 j) = Just j | ||
359 | getJacobian (KVAERNO_4_2_3 j) = Just j | ||
360 | getJacobian (ARK324L2SA_DIRK_4_2_3 j) = Just j | ||
361 | getJacobian (CASH_5_2_4 j) = Just j | ||
362 | getJacobian (CASH_5_3_4 j) = Just j | ||
363 | getJacobian (SDIRK_5_3_4 j) = Just j | ||
364 | getJacobian (KVAERNO_5_3_4 j) = Just j | ||
365 | getJacobian (ARK436L2SA_DIRK_6_3_4 j) = Just j | ||
366 | getJacobian (KVAERNO_7_4_5 j) = Just j | ||
367 | getJacobian (ARK548L2SA_DIRK_8_4_5 j) = Just j | ||
368 | getJacobian (HEUN_EULER_2_1_2 j) = Just j | ||
369 | getJacobian (BOGACKI_SHAMPINE_4_2_3 j) = Just j | ||
370 | getJacobian (ARK324L2SA_ERK_4_2_3 j) = Just j | ||
371 | getJacobian (ZONNEVELD_5_3_4 j) = Just j | ||
372 | getJacobian (ARK436L2SA_ERK_6_3_4 j) = Just j | ||
373 | getJacobian (SAYFY_ABURUB_6_3_4 j) = Just j | ||
374 | getJacobian (CASH_KARP_6_4_5 j) = Just j | ||
375 | getJacobian (FEHLBERG_6_4_5 j) = Just j | ||
376 | getJacobian (DORMAND_PRINCE_7_4_5 j) = Just j | ||
377 | getJacobian (ARK548L2SA_ERK_8_4_5 j) = Just j | ||
378 | getJacobian (VERNER_8_5_6 j) = Just j | ||
379 | getJacobian (FEHLBERG_13_7_8 j) = Just j | ||
380 | getJacobian _ = Nothing | ||
381 | |||
382 | -- | A version of 'odeSolveVWith' with reasonable default step control. | ||
383 | odeSolveV | ||
384 | :: ODEMethod | ||
385 | -> Maybe Double -- ^ initial step size - by default, ARKode | ||
386 | -- estimates the initial step size to be the | ||
387 | -- solution \(h\) of the equation | ||
388 | -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where | ||
389 | -- \(\ddot{y}\) is an estimated value of the | ||
390 | -- second derivative of the solution at \(t_0\) | ||
391 | -> Double -- ^ absolute tolerance for the state vector | ||
392 | -> Double -- ^ relative tolerance for the state vector | ||
393 | -> (Double -> Vector Double -> Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
394 | -> Vector Double -- ^ initial conditions | ||
395 | -> Vector Double -- ^ desired solution times | ||
396 | -> Matrix Double -- ^ solution | ||
397 | odeSolveV meth hi epsAbs epsRel f y0 ts = | ||
398 | odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts | ||
399 | where | ||
400 | g t x0 = coerce $ f t x0 | ||
401 | |||
402 | -- | A version of 'odeSolveV' with reasonable default parameters and | ||
403 | -- system of equations defined using lists. FIXME: we should say | ||
404 | -- something about the fact we could use the Jacobian but don't for | ||
405 | -- compatibility with hmatrix-gsl. | ||
406 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
407 | -> [Double] -- ^ initial conditions | ||
408 | -> Vector Double -- ^ desired solution times | ||
409 | -> Matrix Double -- ^ solution | ||
410 | odeSolve f y0 ts = | ||
411 | -- FIXME: These tolerances are different from the ones in GSL | ||
412 | odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) | ||
413 | where | ||
414 | g t x0 = V.fromList $ f t (V.toList x0) | ||
415 | |||
416 | odeSolveVWith :: | ||
417 | ODEMethod | ||
418 | -> StepControl | ||
419 | -> Maybe Double -- ^ initial step size - by default, ARKode | ||
420 | -- estimates the initial step size to be the | ||
421 | -- solution \(h\) of the equation | ||
422 | -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where | ||
423 | -- \(\ddot{y}\) is an estimated value of the second | ||
424 | -- derivative of the solution at \(t_0\) | ||
425 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
426 | -> V.Vector Double -- ^ Initial conditions | ||
427 | -> V.Vector Double -- ^ Desired solution times | ||
428 | -> Matrix Double -- ^ Error code or solution | ||
429 | odeSolveVWith method control initStepSize f y0 tt = | ||
430 | case odeSolveVWith' opts method control initStepSize f y0 tt of | ||
431 | Left c -> error $ show c -- FIXME | ||
432 | Right (v, _d) -> v | ||
433 | where | ||
434 | opts = ODEOpts { maxNumSteps = 10000 | ||
435 | , minStep = 1.0e-12 | ||
436 | , relTol = error "relTol" | ||
437 | , absTols = error "absTol" | ||
438 | , initStep = error "initStep" | ||
439 | , maxFail = 10 | ||
440 | } | ||
441 | |||
442 | odeSolveVWith' :: | ||
443 | ODEOpts | ||
444 | -> ODEMethod | ||
445 | -> StepControl | ||
446 | -> Maybe Double -- ^ initial step size - by default, ARKode | ||
447 | -- estimates the initial step size to be the | ||
448 | -- solution \(h\) of the equation | ||
449 | -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where | ||
450 | -- \(\ddot{y}\) is an estimated value of the second | ||
451 | -- derivative of the solution at \(t_0\) | ||
452 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
453 | -> V.Vector Double -- ^ Initial conditions | ||
454 | -> V.Vector Double -- ^ Desired solution times | ||
455 | -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution | ||
456 | odeSolveVWith' opts method control initStepSize f y0 tt = | ||
457 | case solveOdeC (fromIntegral $ maxFail opts) | ||
458 | (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) | ||
459 | (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) | ||
460 | (coerce f) (coerce y0) (coerce tt) of | ||
461 | Left c -> Left $ fromIntegral c | ||
462 | Right (v, d) -> Right (reshape l (coerce v), d) | ||
463 | where | ||
464 | l = size y0 | ||
465 | scise (X aTol rTol) = coerce (V.replicate l aTol, rTol) | ||
466 | scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol) | ||
467 | scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol) | ||
468 | -- FIXME; Should we check that the length of ss is correct? | ||
469 | scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol) | ||
470 | jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ | ||
471 | getJacobian method | ||
472 | matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } | ||
473 | where | ||
474 | nr = fromIntegral $ rows m | ||
475 | nc = fromIntegral $ cols m | ||
476 | -- FIXME: efficiency | ||
477 | vs = V.fromList $ map coerce $ concat $ toLists m | ||
478 | |||
479 | solveOdeC :: | ||
480 | CInt -> | ||
481 | CLong -> | ||
482 | CDouble -> | ||
483 | CInt -> | ||
484 | Maybe CDouble -> | ||
485 | (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> | ||
486 | (V.Vector CDouble, CDouble) -> | ||
487 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
488 | -> V.Vector CDouble -- ^ Initial conditions | ||
489 | -> V.Vector CDouble -- ^ Desired solution times | ||
490 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution | ||
491 | solveOdeC maxErrTestFails maxNumSteps_ minStep_ method initStepSize | ||
492 | jacH (aTols, rTol) fun f0 ts = unsafePerformIO $ do | ||
493 | |||
494 | let isInitStepSize :: CInt | ||
495 | isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize | ||
496 | ss :: CDouble | ||
497 | ss = case initStepSize of | ||
498 | -- It would be better to put an error message here but | ||
499 | -- inline-c seems to evaluate this even if it is never | ||
500 | -- used :( | ||
501 | Nothing -> 0.0 | ||
502 | Just x -> x | ||
503 | |||
504 | let dim = V.length f0 | ||
505 | nEq :: CLong | ||
506 | nEq = fromIntegral dim | ||
507 | nTs :: CInt | ||
508 | nTs = fromIntegral $ V.length ts | ||
509 | -- FIXME: I believe this gets taken from the ghc heap and so should | ||
510 | -- be subject to garbage collection. | ||
511 | quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) | ||
512 | qMatMut <- V.thaw quasiMatrixRes | ||
513 | diagnostics :: V.Vector CLong <- createVector 10 -- FIXME | ||
514 | diagMut <- V.thaw diagnostics | ||
515 | -- We need the types that sundials expects. These are tied together | ||
516 | -- in 'CLangToHaskellTypes'. FIXME: The Haskell type is currently empty! | ||
517 | let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | ||
518 | funIO x y f _ptr = do | ||
519 | -- Convert the pointer we get from C (y) to a vector, and then | ||
520 | -- apply the user-supplied function. | ||
521 | fImm <- fun x <$> getDataFromContents dim y | ||
522 | -- Fill in the provided pointer with the resulting vector. | ||
523 | putDataInContents fImm dim f | ||
524 | -- FIXME: I don't understand what this comment means | ||
525 | -- Unsafe since the function will be called many times. | ||
526 | [CU.exp| int{ 0 } |] | ||
527 | let isJac :: CInt | ||
528 | isJac = fromIntegral $ fromEnum $ isJust jacH | ||
529 | jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix -> | ||
530 | Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector -> | ||
531 | IO CInt | ||
532 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do | ||
533 | case jacH of | ||
534 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" | ||
535 | Just jacI -> do j <- jacI t <$> getDataFromContents dim y | ||
536 | poke jacS j | ||
537 | -- FIXME: I don't understand what this comment means | ||
538 | -- Unsafe since the function will be called many times. | ||
539 | [CU.exp| int{ 0 } |] | ||
540 | |||
541 | res <- [C.block| int { | ||
542 | /* general problem variables */ | ||
543 | |||
544 | int flag; /* reusable error-checking flag */ | ||
545 | int i, j; /* reusable loop indices */ | ||
546 | N_Vector y = NULL; /* empty vector for storing solution */ | ||
547 | N_Vector tv = NULL; /* empty vector for storing absolute tolerances */ | ||
548 | SUNMatrix A = NULL; /* empty matrix for linear solver */ | ||
549 | SUNLinearSolver LS = NULL; /* empty linear solver object */ | ||
550 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | ||
551 | realtype t; | ||
552 | long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; | ||
553 | |||
554 | /* general problem parameters */ | ||
555 | |||
556 | realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */ | ||
557 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ | ||
558 | |||
559 | /* Initialize data structures */ | ||
560 | |||
561 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
562 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | ||
563 | /* Specify initial condition */ | ||
564 | for (i = 0; i < NEQ; i++) { | ||
565 | NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; | ||
566 | }; | ||
567 | |||
568 | tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */ | ||
569 | if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; | ||
570 | /* Specify tolerances */ | ||
571 | for (i = 0; i < NEQ; i++) { | ||
572 | NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i]; | ||
573 | }; | ||
574 | |||
575 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | ||
576 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | ||
577 | |||
578 | /* Call ARKodeInit to initialize the integrator memory and specify the */ | ||
579 | /* right-hand side function in y'=f(t,y), the inital time T0, and */ | ||
580 | /* the initial dependent variable vector y. Note: we treat the */ | ||
581 | /* problem as fully implicit and set f_E to NULL and f_I to f. */ | ||
582 | |||
583 | /* Here we use the C types defined in helpers.h which tie up with */ | ||
584 | /* the Haskell types defined in CLangToHaskellTypes */ | ||
585 | if ($(int method) < MIN_DIRK_NUM) { | ||
586 | flag = ARKodeInit(arkode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), NULL, T0, y); | ||
587 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
588 | } else { | ||
589 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); | ||
590 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
591 | } | ||
592 | |||
593 | flag = ARKodeSetMinStep(arkode_mem, $(double minStep_)); | ||
594 | if (check_flag(&flag, "ARKodeSetMinStep", 1)) return 1; | ||
595 | flag = ARKodeSetMaxNumSteps(arkode_mem, $(long int maxNumSteps_)); | ||
596 | if (check_flag(&flag, "ARKodeSetMaxNumSteps", 1)) return 1; | ||
597 | flag = ARKodeSetMaxErrTestFails(arkode_mem, $(int maxErrTestFails)); | ||
598 | if (check_flag(&flag, "ARKodeSetMaxErrTestFails", 1)) return 1; | ||
599 | |||
600 | /* Set routines */ | ||
601 | flag = ARKodeSVtolerances(arkode_mem, $(double rTol), tv); | ||
602 | if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1; | ||
603 | |||
604 | /* Initialize dense matrix data structure and solver */ | ||
605 | A = SUNDenseMatrix(NEQ, NEQ); | ||
606 | if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1; | ||
607 | LS = SUNDenseLinearSolver(y, A); | ||
608 | if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1; | ||
609 | |||
610 | /* Attach matrix and linear solver */ | ||
611 | flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); | ||
612 | if (check_flag(&flag, "ARKDlsSetLinearSolver", 1)) return 1; | ||
613 | |||
614 | /* Set the initial step size if there is one */ | ||
615 | if ($(int isInitStepSize)) { | ||
616 | /* FIXME: We could check if the initial step size is 0 */ | ||
617 | /* or even NaN and then throw an error */ | ||
618 | flag = ARKodeSetInitStep(arkode_mem, $(double ss)); | ||
619 | if (check_flag(&flag, "ARKodeSetInitStep", 1)) return 1; | ||
620 | } | ||
621 | |||
622 | /* Set the Jacobian if there is one */ | ||
623 | if ($(int isJac)) { | ||
624 | flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[]))); | ||
625 | if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1; | ||
626 | } | ||
627 | |||
628 | /* Store initial conditions */ | ||
629 | for (j = 0; j < NEQ; j++) { | ||
630 | ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); | ||
631 | } | ||
632 | |||
633 | /* Explicitly set the method */ | ||
634 | if ($(int method) >= MIN_DIRK_NUM) { | ||
635 | flag = ARKodeSetIRKTableNum(arkode_mem, $(int method)); | ||
636 | if (check_flag(&flag, "ARKodeSetIRKTableNum", 1)) return 1; | ||
637 | } else { | ||
638 | flag = ARKodeSetERKTableNum(arkode_mem, $(int method)); | ||
639 | if (check_flag(&flag, "ARKodeSetERKTableNum", 1)) return 1; | ||
640 | } | ||
641 | |||
642 | /* Main time-stepping loop: calls ARKode to perform the integration */ | ||
643 | /* Stops when the final time has been reached */ | ||
644 | for (i = 1; i < $(int nTs); i++) { | ||
645 | |||
646 | flag = ARKode(arkode_mem, ($vec-ptr:(double *ts))[i], y, &t, ARK_NORMAL); /* call integrator */ | ||
647 | if (check_flag(&flag, "ARKode", 1)) break; | ||
648 | |||
649 | /* Store the results for Haskell */ | ||
650 | for (j = 0; j < NEQ; j++) { | ||
651 | ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j); | ||
652 | } | ||
653 | |||
654 | /* unsuccessful solve: break */ | ||
655 | if (flag < 0) { | ||
656 | fprintf(stderr,"Solver failure, stopping integration\n"); | ||
657 | break; | ||
658 | } | ||
659 | } | ||
660 | |||
661 | /* Get some final statistics on how the solve progressed */ | ||
662 | |||
663 | flag = ARKodeGetNumSteps(arkode_mem, &nst); | ||
664 | check_flag(&flag, "ARKodeGetNumSteps", 1); | ||
665 | ($vec-ptr:(long int *diagMut))[0] = nst; | ||
666 | |||
667 | flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a); | ||
668 | check_flag(&flag, "ARKodeGetNumStepAttempts", 1); | ||
669 | ($vec-ptr:(long int *diagMut))[1] = nst_a; | ||
670 | |||
671 | flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi); | ||
672 | check_flag(&flag, "ARKodeGetNumRhsEvals", 1); | ||
673 | ($vec-ptr:(long int *diagMut))[2] = nfe; | ||
674 | ($vec-ptr:(long int *diagMut))[3] = nfi; | ||
675 | |||
676 | flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups); | ||
677 | check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1); | ||
678 | ($vec-ptr:(long int *diagMut))[4] = nsetups; | ||
679 | |||
680 | flag = ARKodeGetNumErrTestFails(arkode_mem, &netf); | ||
681 | check_flag(&flag, "ARKodeGetNumErrTestFails", 1); | ||
682 | ($vec-ptr:(long int *diagMut))[5] = netf; | ||
683 | |||
684 | flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni); | ||
685 | check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1); | ||
686 | ($vec-ptr:(long int *diagMut))[6] = nni; | ||
687 | |||
688 | flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn); | ||
689 | check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1); | ||
690 | ($vec-ptr:(long int *diagMut))[7] = ncfn; | ||
691 | |||
692 | flag = ARKDlsGetNumJacEvals(arkode_mem, &nje); | ||
693 | check_flag(&flag, "ARKDlsGetNumJacEvals", 1); | ||
694 | ($vec-ptr:(long int *diagMut))[8] = ncfn; | ||
695 | |||
696 | flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS); | ||
697 | check_flag(&flag, "ARKDlsGetNumRhsEvals", 1); | ||
698 | ($vec-ptr:(long int *diagMut))[9] = ncfn; | ||
699 | |||
700 | /* Clean up and return */ | ||
701 | N_VDestroy(y); /* Free y vector */ | ||
702 | N_VDestroy(tv); /* Free tv vector */ | ||
703 | ARKodeFree(&arkode_mem); /* Free integrator memory */ | ||
704 | SUNLinSolFree(LS); /* Free linear solver */ | ||
705 | SUNMatDestroy(A); /* Free A matrix */ | ||
706 | |||
707 | return flag; | ||
708 | } |] | ||
709 | if res == 0 | ||
710 | then do | ||
711 | preD <- V.freeze diagMut | ||
712 | let d = SundialsDiagnostics (fromIntegral $ preD V.!0) | ||
713 | (fromIntegral $ preD V.!1) | ||
714 | (fromIntegral $ preD V.!2) | ||
715 | (fromIntegral $ preD V.!3) | ||
716 | (fromIntegral $ preD V.!4) | ||
717 | (fromIntegral $ preD V.!5) | ||
718 | (fromIntegral $ preD V.!6) | ||
719 | (fromIntegral $ preD V.!7) | ||
720 | (fromIntegral $ preD V.!8) | ||
721 | (fromIntegral $ preD V.!9) | ||
722 | m <- V.freeze qMatMut | ||
723 | return $ Right (m, d) | ||
724 | else do | ||
725 | return $ Left res | ||
726 | |||
727 | data ButcherTable = ButcherTable { am :: Matrix Double | ||
728 | , cv :: Vector Double | ||
729 | , bv :: Vector Double | ||
730 | , b2v :: Vector Double | ||
731 | } | ||
732 | deriving Show | ||
733 | |||
734 | data ButcherTable' a = ButcherTable' { am' :: V.Vector a | ||
735 | , cv' :: V.Vector a | ||
736 | , bv' :: V.Vector a | ||
737 | , b2v' :: V.Vector a | ||
738 | } | ||
739 | deriving Show | ||
740 | |||
741 | butcherTable :: ODEMethod -> ButcherTable | ||
742 | butcherTable method = | ||
743 | case getBT method of | ||
744 | Left c -> error $ show c -- FIXME | ||
745 | Right (ButcherTable' v w x y, sqp) -> | ||
746 | ButcherTable { am = subMatrix (0, 0) (s, s) $ (arkSMax >< arkSMax) (V.toList v) | ||
747 | , cv = subVector 0 s w | ||
748 | , bv = subVector 0 s x | ||
749 | , b2v = subVector 0 s y | ||
750 | } | ||
751 | where | ||
752 | s = fromIntegral $ sqp V.! 0 | ||
753 | |||
754 | getBT :: ODEMethod -> Either Int (ButcherTable' Double, V.Vector Int) | ||
755 | getBT method = case getButcherTable method of | ||
756 | Left c -> | ||
757 | Left $ fromIntegral c | ||
758 | Right (ButcherTable' a b c d, sqp) -> | ||
759 | Right $ ( ButcherTable' (coerce a) (coerce b) (coerce c) (coerce d) | ||
760 | , V.map fromIntegral sqp ) | ||
761 | |||
762 | getButcherTable :: ODEMethod | ||
763 | -> Either CInt (ButcherTable' CDouble, V.Vector CInt) | ||
764 | getButcherTable method = unsafePerformIO $ do | ||
765 | -- ARKode seems to want an ODE in order to set and then get the | ||
766 | -- Butcher tableau so here's one to keep it happy | ||
767 | let funI :: CDouble -> V.Vector CDouble -> V.Vector CDouble | ||
768 | funI _t ys = V.fromList [ ys V.! 0 ] | ||
769 | let funE :: CDouble -> V.Vector CDouble -> V.Vector CDouble | ||
770 | funE _t ys = V.fromList [ ys V.! 0 ] | ||
771 | f0 = V.fromList [ 1.0 ] | ||
772 | ts = V.fromList [ 0.0 ] | ||
773 | dim = V.length f0 | ||
774 | nEq :: CLong | ||
775 | nEq = fromIntegral dim | ||
776 | mN :: CInt | ||
777 | mN = fromIntegral $ getMethod method | ||
778 | |||
779 | btSQP :: V.Vector CInt <- createVector 3 | ||
780 | btSQPMut <- V.thaw btSQP | ||
781 | btAs :: V.Vector CDouble <- createVector (arkSMax * arkSMax) | ||
782 | btAsMut <- V.thaw btAs | ||
783 | btCs :: V.Vector CDouble <- createVector arkSMax | ||
784 | btBs :: V.Vector CDouble <- createVector arkSMax | ||
785 | btB2s :: V.Vector CDouble <- createVector arkSMax | ||
786 | btCsMut <- V.thaw btCs | ||
787 | btBsMut <- V.thaw btBs | ||
788 | btB2sMut <- V.thaw btB2s | ||
789 | let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | ||
790 | funIOI x y f _ptr = do | ||
791 | fImm <- funI x <$> getDataFromContents dim y | ||
792 | putDataInContents fImm dim f | ||
793 | -- FIXME: I don't understand what this comment means | ||
794 | -- Unsafe since the function will be called many times. | ||
795 | [CU.exp| int{ 0 } |] | ||
796 | let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | ||
797 | funIOE x y f _ptr = do | ||
798 | fImm <- funE x <$> getDataFromContents dim y | ||
799 | putDataInContents fImm dim f | ||
800 | -- FIXME: I don't understand what this comment means | ||
801 | -- Unsafe since the function will be called many times. | ||
802 | [CU.exp| int{ 0 } |] | ||
803 | res <- [C.block| int { | ||
804 | /* general problem variables */ | ||
805 | |||
806 | int flag; /* reusable error-checking flag */ | ||
807 | N_Vector y = NULL; /* empty vector for storing solution */ | ||
808 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | ||
809 | int i, j; /* reusable loop indices */ | ||
810 | |||
811 | /* general problem parameters */ | ||
812 | |||
813 | realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */ | ||
814 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars */ | ||
815 | |||
816 | /* Initialize data structures */ | ||
817 | |||
818 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
819 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | ||
820 | /* Specify initial condition */ | ||
821 | for (i = 0; i < NEQ; i++) { | ||
822 | NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; | ||
823 | }; | ||
824 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | ||
825 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | ||
826 | |||
827 | flag = ARKodeInit(arkode_mem, $fun:(int (* funIOE) (double t, SunVector y[], SunVector dydt[], void * params)), $fun:(int (* funIOI) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); | ||
828 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
829 | |||
830 | if ($(int mN) >= MIN_DIRK_NUM) { | ||
831 | flag = ARKodeSetIRKTableNum(arkode_mem, $(int mN)); | ||
832 | if (check_flag(&flag, "ARKodeSetIRKTableNum", 1)) return 1; | ||
833 | } else { | ||
834 | flag = ARKodeSetERKTableNum(arkode_mem, $(int mN)); | ||
835 | if (check_flag(&flag, "ARKodeSetERKTableNum", 1)) return 1; | ||
836 | } | ||
837 | |||
838 | int s, q, p; | ||
839 | realtype *ai = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); | ||
840 | realtype *ae = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); | ||
841 | realtype *ci = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
842 | realtype *ce = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
843 | realtype *bi = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
844 | realtype *be = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
845 | realtype *b2i = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
846 | realtype *b2e = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
847 | flag = ARKodeGetCurrentButcherTables(arkode_mem, &s, &q, &p, ai, ae, ci, ce, bi, be, b2i, b2e); | ||
848 | if (check_flag(&flag, "ARKode", 1)) return 1; | ||
849 | $vec-ptr:(int *btSQPMut)[0] = s; | ||
850 | $vec-ptr:(int *btSQPMut)[1] = q; | ||
851 | $vec-ptr:(int *btSQPMut)[2] = p; | ||
852 | for (i = 0; i < s; i++) { | ||
853 | for (j = 0; j < s; j++) { | ||
854 | /* FIXME: double should be realtype */ | ||
855 | ($vec-ptr:(double *btAsMut))[i * ARK_S_MAX + j] = ai[i * ARK_S_MAX + j]; | ||
856 | } | ||
857 | } | ||
858 | |||
859 | for (i = 0; i < s; i++) { | ||
860 | ($vec-ptr:(double *btCsMut))[i] = ci[i]; | ||
861 | ($vec-ptr:(double *btBsMut))[i] = bi[i]; | ||
862 | ($vec-ptr:(double *btB2sMut))[i] = b2i[i]; | ||
863 | } | ||
864 | |||
865 | /* Clean up and return */ | ||
866 | N_VDestroy(y); /* Free y vector */ | ||
867 | ARKodeFree(&arkode_mem); /* Free integrator memory */ | ||
868 | |||
869 | return flag; | ||
870 | } |] | ||
871 | if res == 0 | ||
872 | then do | ||
873 | x <- V.freeze btAsMut | ||
874 | y <- V.freeze btSQPMut | ||
875 | z <- V.freeze btCsMut | ||
876 | u <- V.freeze btBsMut | ||
877 | v <- V.freeze btB2sMut | ||
878 | return $ Right (ButcherTable' { am' = x, cv' = z, bv' = u, b2v' = v }, y) | ||
879 | else do | ||
880 | return $ Left res | ||
881 | |||
882 | -- | Adaptive step-size control | ||
883 | -- functions. | ||
884 | -- | ||
885 | -- [GSL](https://www.gnu.org/software/gsl/doc/html/ode-initval.html#adaptive-step-size-control) | ||
886 | -- allows the user to control the step size adjustment using | ||
887 | -- \(D_i = \epsilon^{abs}s_i + \epsilon^{rel}(a_{y} |y_i| + a_{dy/dt} h |\dot{y}_i|)\) where | ||
888 | -- \(\epsilon^{abs}\) is the required absolute error, \(\epsilon^{rel}\) | ||
889 | -- is the required relative error, \(s_i\) is a vector of scaling | ||
890 | -- factors, \(a_{y}\) is a scaling factor for the solution \(y\) and | ||
891 | -- \(a_{dydt}\) is a scaling factor for the derivative of the solution \(dy/dt\). | ||
892 | -- | ||
893 | -- [ARKode](https://computation.llnl.gov/projects/sundials/arkode) | ||
894 | -- allows the user to control the step size adjustment using | ||
895 | -- \(\eta^{rel}|y_i| + \eta^{abs}_i\). For compatibility with | ||
896 | -- [hmatrix-gsl](https://hackage.haskell.org/package/hmatrix-gsl), | ||
897 | -- tolerances for \(y\) and \(\dot{y}\) can be specified but the latter have no | ||
898 | -- effect. | ||
899 | data StepControl = X Double Double -- ^ absolute and relative tolerance for \(y\); in GSL terms, \(a_{y} = 1\) and \(a_{dy/dt} = 0\); in ARKode terms, the \(\eta^{abs}_i\) are identical | ||
900 | | X' Double Double -- ^ absolute and relative tolerance for \(\dot{y}\); in GSL terms, \(a_{y} = 0\) and \(a_{dy/dt} = 1\); in ARKode terms, the latter is treated as the relative tolerance for \(y\) so this is the same as specifying 'X' which may be entirely incorrect for the given problem | ||
901 | | XX' Double Double Double Double -- ^ include both via relative tolerance | ||
902 | -- scaling factors \(a_y\), \(a_{{dy}/{dt}}\); in ARKode terms, the latter is ignored and \(\eta^{rel} = a_{y}\epsilon^{rel}\) | ||
903 | | ScXX' Double Double Double Double (Vector Double) -- ^ scale absolute tolerance of \(y_i\); in ARKode terms, \(a_{{dy}/{dt}}\) is ignored, \(\eta^{abs}_i = s_i \epsilon^{abs}\) and \(\eta^{rel} = a_{y}\epsilon^{rel}\) | ||
diff --git a/packages/sundials/src/Numeric/Sundials/Arkode.hsc b/packages/sundials/src/Numeric/Sundials/Arkode.hsc new file mode 100644 index 0000000..0850258 --- /dev/null +++ b/packages/sundials/src/Numeric/Sundials/Arkode.hsc | |||
@@ -0,0 +1,204 @@ | |||
1 | {-# LANGUAGE QuasiQuotes #-} | ||
2 | {-# LANGUAGE TemplateHaskell #-} | ||
3 | {-# LANGUAGE OverloadedStrings #-} | ||
4 | {-# LANGUAGE EmptyDataDecls #-} | ||
5 | |||
6 | module Numeric.Sundials.Arkode where | ||
7 | |||
8 | import Foreign | ||
9 | import Foreign.C.Types | ||
10 | |||
11 | import Language.C.Types as CT | ||
12 | |||
13 | import qualified Data.Vector.Storable as VS | ||
14 | import qualified Data.Vector.Storable.Mutable as VM | ||
15 | |||
16 | import qualified Language.Haskell.TH as TH | ||
17 | import qualified Data.Map as Map | ||
18 | import Language.C.Inline.Context | ||
19 | |||
20 | import qualified Data.Vector.Storable as V | ||
21 | |||
22 | |||
23 | #include <stdio.h> | ||
24 | #include <sundials/sundials_nvector.h> | ||
25 | #include <sundials/sundials_matrix.h> | ||
26 | #include <nvector/nvector_serial.h> | ||
27 | #include <sunmatrix/sunmatrix_dense.h> | ||
28 | #include <arkode/arkode.h> | ||
29 | #include <cvode/cvode.h> | ||
30 | |||
31 | |||
32 | data SunVector | ||
33 | data SunMatrix = SunMatrix { rows :: CInt | ||
34 | , cols :: CInt | ||
35 | , vals :: V.Vector CDouble | ||
36 | } | ||
37 | |||
38 | -- | This is true only if configured/ built as 64 bits | ||
39 | type SunIndexType = CLong | ||
40 | |||
41 | sunTypesTable :: Map.Map TypeSpecifier TH.TypeQ | ||
42 | sunTypesTable = Map.fromList | ||
43 | [ | ||
44 | (TypeName "sunindextype", [t| SunIndexType |] ) | ||
45 | , (TypeName "SunVector", [t| SunVector |] ) | ||
46 | , (TypeName "SunMatrix", [t| SunMatrix |] ) | ||
47 | ] | ||
48 | |||
49 | sunCtx :: Context | ||
50 | sunCtx = mempty {ctxTypesTable = sunTypesTable} | ||
51 | |||
52 | getMatrixDataFromContents :: Ptr SunMatrix -> IO SunMatrix | ||
53 | getMatrixDataFromContents ptr = do | ||
54 | qtr <- getContentMatrixPtr ptr | ||
55 | rs <- getNRows qtr | ||
56 | cs <- getNCols qtr | ||
57 | rtr <- getMatrixData qtr | ||
58 | vs <- vectorFromC (fromIntegral $ rs * cs) rtr | ||
59 | return $ SunMatrix { rows = rs, cols = cs, vals = vs } | ||
60 | |||
61 | putMatrixDataFromContents :: SunMatrix -> Ptr SunMatrix -> IO () | ||
62 | putMatrixDataFromContents mat ptr = do | ||
63 | let rs = rows mat | ||
64 | cs = cols mat | ||
65 | vs = vals mat | ||
66 | qtr <- getContentMatrixPtr ptr | ||
67 | putNRows rs qtr | ||
68 | putNCols cs qtr | ||
69 | rtr <- getMatrixData qtr | ||
70 | vectorToC vs (fromIntegral $ rs * cs) rtr | ||
71 | |||
72 | instance Storable SunMatrix where | ||
73 | poke = flip putMatrixDataFromContents | ||
74 | peek = getMatrixDataFromContents | ||
75 | sizeOf _ = error "sizeOf not supported for SunMatrix" | ||
76 | alignment _ = error "alignment not supported for SunMatrix" | ||
77 | |||
78 | vectorFromC :: Storable a => Int -> Ptr a -> IO (VS.Vector a) | ||
79 | vectorFromC len ptr = do | ||
80 | ptr' <- newForeignPtr_ ptr | ||
81 | VS.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
82 | |||
83 | vectorToC :: Storable a => VS.Vector a -> Int -> Ptr a -> IO () | ||
84 | vectorToC vec len ptr = do | ||
85 | ptr' <- newForeignPtr_ ptr | ||
86 | VS.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
87 | |||
88 | getDataFromContents :: Int -> Ptr SunVector -> IO (VS.Vector CDouble) | ||
89 | getDataFromContents len ptr = do | ||
90 | qtr <- getContentPtr ptr | ||
91 | rtr <- getData qtr | ||
92 | vectorFromC len rtr | ||
93 | |||
94 | putDataInContents :: Storable a => VS.Vector a -> Int -> Ptr b -> IO () | ||
95 | putDataInContents vec len ptr = do | ||
96 | qtr <- getContentPtr ptr | ||
97 | rtr <- getData qtr | ||
98 | vectorToC vec len rtr | ||
99 | |||
100 | #def typedef struct _generic_N_Vector SunVector; | ||
101 | #def typedef struct _N_VectorContent_Serial SunContent; | ||
102 | |||
103 | #def typedef struct _generic_SUNMatrix SunMatrix; | ||
104 | #def typedef struct _SUNMatrixContent_Dense SunMatrixContent; | ||
105 | |||
106 | getContentMatrixPtr :: Storable a => Ptr b -> IO a | ||
107 | getContentMatrixPtr ptr = (#peek SunMatrix, content) ptr | ||
108 | |||
109 | getNRows :: Ptr b -> IO CInt | ||
110 | getNRows ptr = (#peek SunMatrixContent, M) ptr | ||
111 | putNRows :: CInt -> Ptr b -> IO () | ||
112 | putNRows nr ptr = (#poke SunMatrixContent, M) ptr nr | ||
113 | |||
114 | getNCols :: Ptr b -> IO CInt | ||
115 | getNCols ptr = (#peek SunMatrixContent, N) ptr | ||
116 | putNCols :: CInt -> Ptr b -> IO () | ||
117 | putNCols nc ptr = (#poke SunMatrixContent, N) ptr nc | ||
118 | |||
119 | getMatrixData :: Storable a => Ptr b -> IO a | ||
120 | getMatrixData ptr = (#peek SunMatrixContent, data) ptr | ||
121 | |||
122 | getContentPtr :: Storable a => Ptr b -> IO a | ||
123 | getContentPtr ptr = (#peek SunVector, content) ptr | ||
124 | |||
125 | getData :: Storable a => Ptr b -> IO a | ||
126 | getData ptr = (#peek SunContent, data) ptr | ||
127 | |||
128 | cV_ADAMS :: Int | ||
129 | cV_ADAMS = #const CV_ADAMS | ||
130 | cV_BDF :: Int | ||
131 | cV_BDF = #const CV_BDF | ||
132 | |||
133 | arkSMax :: Int | ||
134 | arkSMax = #const ARK_S_MAX | ||
135 | |||
136 | mIN_DIRK_NUM, mAX_DIRK_NUM :: Int | ||
137 | mIN_DIRK_NUM = #const MIN_DIRK_NUM | ||
138 | mAX_DIRK_NUM = #const MAX_DIRK_NUM | ||
139 | |||
140 | -- FIXME: We could just use inline-c instead | ||
141 | |||
142 | -- Butcher table accessors -- implicit | ||
143 | sDIRK_2_1_2 :: Int | ||
144 | sDIRK_2_1_2 = #const SDIRK_2_1_2 | ||
145 | bILLINGTON_3_3_2 :: Int | ||
146 | bILLINGTON_3_3_2 = #const BILLINGTON_3_3_2 | ||
147 | tRBDF2_3_3_2 :: Int | ||
148 | tRBDF2_3_3_2 = #const TRBDF2_3_3_2 | ||
149 | kVAERNO_4_2_3 :: Int | ||
150 | kVAERNO_4_2_3 = #const KVAERNO_4_2_3 | ||
151 | aRK324L2SA_DIRK_4_2_3 :: Int | ||
152 | aRK324L2SA_DIRK_4_2_3 = #const ARK324L2SA_DIRK_4_2_3 | ||
153 | cASH_5_2_4 :: Int | ||
154 | cASH_5_2_4 = #const CASH_5_2_4 | ||
155 | cASH_5_3_4 :: Int | ||
156 | cASH_5_3_4 = #const CASH_5_3_4 | ||
157 | sDIRK_5_3_4 :: Int | ||
158 | sDIRK_5_3_4 = #const SDIRK_5_3_4 | ||
159 | kVAERNO_5_3_4 :: Int | ||
160 | kVAERNO_5_3_4 = #const KVAERNO_5_3_4 | ||
161 | aRK436L2SA_DIRK_6_3_4 :: Int | ||
162 | aRK436L2SA_DIRK_6_3_4 = #const ARK436L2SA_DIRK_6_3_4 | ||
163 | kVAERNO_7_4_5 :: Int | ||
164 | kVAERNO_7_4_5 = #const KVAERNO_7_4_5 | ||
165 | aRK548L2SA_DIRK_8_4_5 :: Int | ||
166 | aRK548L2SA_DIRK_8_4_5 = #const ARK548L2SA_DIRK_8_4_5 | ||
167 | |||
168 | -- #define DEFAULT_DIRK_2 SDIRK_2_1_2 | ||
169 | -- #define DEFAULT_DIRK_3 ARK324L2SA_DIRK_4_2_3 | ||
170 | -- #define DEFAULT_DIRK_4 SDIRK_5_3_4 | ||
171 | -- #define DEFAULT_DIRK_5 ARK548L2SA_DIRK_8_4_5 | ||
172 | |||
173 | -- Butcher table accessors -- explicit | ||
174 | hEUN_EULER_2_1_2 :: Int | ||
175 | hEUN_EULER_2_1_2 = #const HEUN_EULER_2_1_2 | ||
176 | bOGACKI_SHAMPINE_4_2_3 :: Int | ||
177 | bOGACKI_SHAMPINE_4_2_3 = #const BOGACKI_SHAMPINE_4_2_3 | ||
178 | aRK324L2SA_ERK_4_2_3 :: Int | ||
179 | aRK324L2SA_ERK_4_2_3 = #const ARK324L2SA_ERK_4_2_3 | ||
180 | zONNEVELD_5_3_4 :: Int | ||
181 | zONNEVELD_5_3_4 = #const ZONNEVELD_5_3_4 | ||
182 | aRK436L2SA_ERK_6_3_4 :: Int | ||
183 | aRK436L2SA_ERK_6_3_4 = #const ARK436L2SA_ERK_6_3_4 | ||
184 | sAYFY_ABURUB_6_3_4 :: Int | ||
185 | sAYFY_ABURUB_6_3_4 = #const SAYFY_ABURUB_6_3_4 | ||
186 | cASH_KARP_6_4_5 :: Int | ||
187 | cASH_KARP_6_4_5 = #const CASH_KARP_6_4_5 | ||
188 | fEHLBERG_6_4_5 :: Int | ||
189 | fEHLBERG_6_4_5 = #const FEHLBERG_6_4_5 | ||
190 | dORMAND_PRINCE_7_4_5 :: Int | ||
191 | dORMAND_PRINCE_7_4_5 = #const DORMAND_PRINCE_7_4_5 | ||
192 | aRK548L2SA_ERK_8_4_5 :: Int | ||
193 | aRK548L2SA_ERK_8_4_5 = #const ARK548L2SA_ERK_8_4_5 | ||
194 | vERNER_8_5_6 :: Int | ||
195 | vERNER_8_5_6 = #const VERNER_8_5_6 | ||
196 | fEHLBERG_13_7_8 :: Int | ||
197 | fEHLBERG_13_7_8 = #const FEHLBERG_13_7_8 | ||
198 | |||
199 | -- #define DEFAULT_ERK_2 HEUN_EULER_2_1_2 | ||
200 | -- #define DEFAULT_ERK_3 BOGACKI_SHAMPINE_4_2_3 | ||
201 | -- #define DEFAULT_ERK_4 ZONNEVELD_5_3_4 | ||
202 | -- #define DEFAULT_ERK_5 CASH_KARP_6_4_5 | ||
203 | -- #define DEFAULT_ERK_6 VERNER_8_5_6 | ||
204 | -- #define DEFAULT_ERK_8 FEHLBERG_13_7_8 | ||
diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs new file mode 100644 index 0000000..a6f185e --- /dev/null +++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs | |||
@@ -0,0 +1,476 @@ | |||
1 | {-# OPTIONS_GHC -Wall #-} | ||
2 | |||
3 | {-# LANGUAGE QuasiQuotes #-} | ||
4 | {-# LANGUAGE TemplateHaskell #-} | ||
5 | {-# LANGUAGE MultiWayIf #-} | ||
6 | {-# LANGUAGE OverloadedStrings #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | |||
9 | ----------------------------------------------------------------------------- | ||
10 | -- | | ||
11 | -- Module : Numeric.Sundials.CVode.ODE | ||
12 | -- Copyright : Dominic Steinitz 2018, | ||
13 | -- Novadiscovery 2018 | ||
14 | -- License : BSD | ||
15 | -- Maintainer : Dominic Steinitz | ||
16 | -- Stability : provisional | ||
17 | -- | ||
18 | -- Solution of ordinary differential equation (ODE) initial value problems. | ||
19 | -- | ||
20 | -- <https://computation.llnl.gov/projects/sundials/sundials-software> | ||
21 | -- | ||
22 | -- A simple example: | ||
23 | -- | ||
24 | -- <<diagrams/brusselator.png#diagram=brusselator&height=400&width=500>> | ||
25 | -- | ||
26 | -- @ | ||
27 | -- import Numeric.Sundials.CVode.ODE | ||
28 | -- import Numeric.LinearAlgebra | ||
29 | -- | ||
30 | -- import Plots as P | ||
31 | -- import qualified Diagrams.Prelude as D | ||
32 | -- import Diagrams.Backend.Rasterific | ||
33 | -- | ||
34 | -- brusselator :: Double -> [Double] -> [Double] | ||
35 | -- brusselator _t x = [ a - (w + 1) * u + v * u * u | ||
36 | -- , w * u - v * u * u | ||
37 | -- , (b - w) / eps - w * u | ||
38 | -- ] | ||
39 | -- where | ||
40 | -- a = 1.0 | ||
41 | -- b = 3.5 | ||
42 | -- eps = 5.0e-6 | ||
43 | -- u = x !! 0 | ||
44 | -- v = x !! 1 | ||
45 | -- w = x !! 2 | ||
46 | -- | ||
47 | -- lSaxis :: [[Double]] -> P.Axis B D.V2 Double | ||
48 | -- lSaxis xs = P.r2Axis &~ do | ||
49 | -- let ts = xs!!0 | ||
50 | -- us = xs!!1 | ||
51 | -- vs = xs!!2 | ||
52 | -- ws = xs!!3 | ||
53 | -- P.linePlot' $ zip ts us | ||
54 | -- P.linePlot' $ zip ts vs | ||
55 | -- P.linePlot' $ zip ts ws | ||
56 | -- | ||
57 | -- main = do | ||
58 | -- let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) | ||
59 | -- renderRasterific "diagrams/brusselator.png" | ||
60 | -- (D.dims2D 500.0 500.0) | ||
61 | -- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) | ||
62 | -- @ | ||
63 | -- | ||
64 | ----------------------------------------------------------------------------- | ||
65 | module Numeric.Sundials.CVode.ODE ( odeSolve | ||
66 | , odeSolveV | ||
67 | , odeSolveVWith | ||
68 | , odeSolveVWith' | ||
69 | , ODEMethod(..) | ||
70 | , StepControl(..) | ||
71 | ) where | ||
72 | |||
73 | import qualified Language.C.Inline as C | ||
74 | import qualified Language.C.Inline.Unsafe as CU | ||
75 | |||
76 | import Data.Monoid ((<>)) | ||
77 | import Data.Maybe (isJust) | ||
78 | |||
79 | import Foreign.C.Types (CDouble, CInt, CLong) | ||
80 | import Foreign.Ptr (Ptr) | ||
81 | import Foreign.Storable (poke) | ||
82 | |||
83 | import qualified Data.Vector.Storable as V | ||
84 | |||
85 | import Data.Coerce (coerce) | ||
86 | import System.IO.Unsafe (unsafePerformIO) | ||
87 | |||
88 | import Numeric.LinearAlgebra.Devel (createVector) | ||
89 | |||
90 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows, | ||
91 | cols, toLists, size, reshape) | ||
92 | |||
93 | import Numeric.Sundials.Arkode (cV_ADAMS, cV_BDF, | ||
94 | getDataFromContents, putDataInContents) | ||
95 | import qualified Numeric.Sundials.Arkode as T | ||
96 | import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian, SundialsDiagnostics(..)) | ||
97 | |||
98 | |||
99 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | ||
100 | |||
101 | C.include "<stdlib.h>" | ||
102 | C.include "<stdio.h>" | ||
103 | C.include "<math.h>" | ||
104 | C.include "<cvode/cvode.h>" -- prototypes for CVODE fcts., consts. | ||
105 | C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros | ||
106 | C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix | ||
107 | C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver | ||
108 | C.include "<cvode/cvode_direct.h>" -- access to CVDls interface | ||
109 | C.include "<sundials/sundials_types.h>" -- definition of type realtype | ||
110 | C.include "<sundials/sundials_math.h>" | ||
111 | C.include "../../../helpers.h" | ||
112 | C.include "Numeric/Sundials/Arkode_hsc.h" | ||
113 | |||
114 | |||
115 | -- | Stepping functions | ||
116 | data ODEMethod = ADAMS | ||
117 | | BDF | ||
118 | |||
119 | getMethod :: ODEMethod -> Int | ||
120 | getMethod (ADAMS) = cV_ADAMS | ||
121 | getMethod (BDF) = cV_BDF | ||
122 | |||
123 | getJacobian :: ODEMethod -> Maybe Jacobian | ||
124 | getJacobian _ = Nothing | ||
125 | |||
126 | -- | A version of 'odeSolveVWith' with reasonable default step control. | ||
127 | odeSolveV | ||
128 | :: ODEMethod | ||
129 | -> Maybe Double -- ^ initial step size - by default, CVode | ||
130 | -- estimates the initial step size to be the | ||
131 | -- solution \(h\) of the equation | ||
132 | -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where | ||
133 | -- \(\ddot{y}\) is an estimated value of the | ||
134 | -- second derivative of the solution at \(t_0\) | ||
135 | -> Double -- ^ absolute tolerance for the state vector | ||
136 | -> Double -- ^ relative tolerance for the state vector | ||
137 | -> (Double -> Vector Double -> Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
138 | -> Vector Double -- ^ initial conditions | ||
139 | -> Vector Double -- ^ desired solution times | ||
140 | -> Matrix Double -- ^ solution | ||
141 | odeSolveV meth hi epsAbs epsRel f y0 ts = | ||
142 | odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts | ||
143 | where | ||
144 | g t x0 = coerce $ f t x0 | ||
145 | |||
146 | -- | A version of 'odeSolveV' with reasonable default parameters and | ||
147 | -- system of equations defined using lists. FIXME: we should say | ||
148 | -- something about the fact we could use the Jacobian but don't for | ||
149 | -- compatibility with hmatrix-gsl. | ||
150 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
151 | -> [Double] -- ^ initial conditions | ||
152 | -> Vector Double -- ^ desired solution times | ||
153 | -> Matrix Double -- ^ solution | ||
154 | odeSolve f y0 ts = | ||
155 | -- FIXME: These tolerances are different from the ones in GSL | ||
156 | odeSolveVWith BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) | ||
157 | where | ||
158 | g t x0 = V.fromList $ f t (V.toList x0) | ||
159 | |||
160 | odeSolveVWith :: | ||
161 | ODEMethod | ||
162 | -> StepControl | ||
163 | -> Maybe Double -- ^ initial step size - by default, CVode | ||
164 | -- estimates the initial step size to be the | ||
165 | -- solution \(h\) of the equation | ||
166 | -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where | ||
167 | -- \(\ddot{y}\) is an estimated value of the second | ||
168 | -- derivative of the solution at \(t_0\) | ||
169 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
170 | -> V.Vector Double -- ^ Initial conditions | ||
171 | -> V.Vector Double -- ^ Desired solution times | ||
172 | -> Matrix Double -- ^ Error code or solution | ||
173 | odeSolveVWith method control initStepSize f y0 tt = | ||
174 | case odeSolveVWith' opts method control initStepSize f y0 tt of | ||
175 | Left c -> error $ show c -- FIXME | ||
176 | Right (v, _d) -> v | ||
177 | where | ||
178 | opts = ODEOpts { maxNumSteps = 10000 | ||
179 | , minStep = 1.0e-12 | ||
180 | , relTol = error "relTol" | ||
181 | , absTols = error "absTol" | ||
182 | , initStep = error "initStep" | ||
183 | , maxFail = 10 | ||
184 | } | ||
185 | |||
186 | odeSolveVWith' :: | ||
187 | ODEOpts | ||
188 | -> ODEMethod | ||
189 | -> StepControl | ||
190 | -> Maybe Double -- ^ initial step size - by default, CVode | ||
191 | -- estimates the initial step size to be the | ||
192 | -- solution \(h\) of the equation | ||
193 | -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where | ||
194 | -- \(\ddot{y}\) is an estimated value of the second | ||
195 | -- derivative of the solution at \(t_0\) | ||
196 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
197 | -> V.Vector Double -- ^ Initial conditions | ||
198 | -> V.Vector Double -- ^ Desired solution times | ||
199 | -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution | ||
200 | odeSolveVWith' opts method control initStepSize f y0 tt = | ||
201 | case solveOdeC (fromIntegral $ maxFail opts) | ||
202 | (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) | ||
203 | (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) | ||
204 | (coerce f) (coerce y0) (coerce tt) of | ||
205 | Left c -> Left $ fromIntegral c | ||
206 | Right (v, d) -> Right (reshape l (coerce v), d) | ||
207 | where | ||
208 | l = size y0 | ||
209 | scise (X aTol rTol) = coerce (V.replicate l aTol, rTol) | ||
210 | scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol) | ||
211 | scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol) | ||
212 | -- FIXME; Should we check that the length of ss is correct? | ||
213 | scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol) | ||
214 | jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ | ||
215 | getJacobian method | ||
216 | matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } | ||
217 | where | ||
218 | nr = fromIntegral $ rows m | ||
219 | nc = fromIntegral $ cols m | ||
220 | -- FIXME: efficiency | ||
221 | vs = V.fromList $ map coerce $ concat $ toLists m | ||
222 | |||
223 | solveOdeC :: | ||
224 | CInt -> | ||
225 | CLong -> | ||
226 | CDouble -> | ||
227 | CInt -> | ||
228 | Maybe CDouble -> | ||
229 | (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> | ||
230 | (V.Vector CDouble, CDouble) -> | ||
231 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
232 | -> V.Vector CDouble -- ^ Initial conditions | ||
233 | -> V.Vector CDouble -- ^ Desired solution times | ||
234 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution | ||
235 | solveOdeC maxErrTestFails maxNumSteps_ minStep_ method initStepSize | ||
236 | jacH (aTols, rTol) fun f0 ts = | ||
237 | unsafePerformIO $ do | ||
238 | |||
239 | let isInitStepSize :: CInt | ||
240 | isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize | ||
241 | ss :: CDouble | ||
242 | ss = case initStepSize of | ||
243 | -- It would be better to put an error message here but | ||
244 | -- inline-c seems to evaluate this even if it is never | ||
245 | -- used :( | ||
246 | Nothing -> 0.0 | ||
247 | Just x -> x | ||
248 | |||
249 | let dim = V.length f0 | ||
250 | nEq :: CLong | ||
251 | nEq = fromIntegral dim | ||
252 | nTs :: CInt | ||
253 | nTs = fromIntegral $ V.length ts | ||
254 | quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) | ||
255 | qMatMut <- V.thaw quasiMatrixRes | ||
256 | diagnostics :: V.Vector CLong <- createVector 10 -- FIXME | ||
257 | diagMut <- V.thaw diagnostics | ||
258 | -- We need the types that sundials expects. These are tied together | ||
259 | -- in 'CLangToHaskellTypes'. FIXME: The Haskell type is currently empty! | ||
260 | let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | ||
261 | funIO x y f _ptr = do | ||
262 | -- Convert the pointer we get from C (y) to a vector, and then | ||
263 | -- apply the user-supplied function. | ||
264 | fImm <- fun x <$> getDataFromContents dim y | ||
265 | -- Fill in the provided pointer with the resulting vector. | ||
266 | putDataInContents fImm dim f | ||
267 | -- FIXME: I don't understand what this comment means | ||
268 | -- Unsafe since the function will be called many times. | ||
269 | [CU.exp| int{ 0 } |] | ||
270 | let isJac :: CInt | ||
271 | isJac = fromIntegral $ fromEnum $ isJust jacH | ||
272 | jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix -> | ||
273 | Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector -> | ||
274 | IO CInt | ||
275 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do | ||
276 | case jacH of | ||
277 | Nothing -> error "Numeric.Sundials.CVode.ODE: Jacobian not defined" | ||
278 | Just jacI -> do j <- jacI t <$> getDataFromContents dim y | ||
279 | poke jacS j | ||
280 | -- FIXME: I don't understand what this comment means | ||
281 | -- Unsafe since the function will be called many times. | ||
282 | [CU.exp| int{ 0 } |] | ||
283 | |||
284 | res <- [C.block| int { | ||
285 | /* general problem variables */ | ||
286 | |||
287 | int flag; /* reusable error-checking flag */ | ||
288 | int i, j; /* reusable loop indices */ | ||
289 | N_Vector y = NULL; /* empty vector for storing solution */ | ||
290 | N_Vector tv = NULL; /* empty vector for storing absolute tolerances */ | ||
291 | |||
292 | SUNMatrix A = NULL; /* empty matrix for linear solver */ | ||
293 | SUNLinearSolver LS = NULL; /* empty linear solver object */ | ||
294 | void *cvode_mem = NULL; /* empty CVODE memory structure */ | ||
295 | realtype t; | ||
296 | long int nst, nfe, nsetups, nje, nfeLS, nni, ncfn, netf, nge; | ||
297 | |||
298 | /* general problem parameters */ | ||
299 | |||
300 | realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */ | ||
301 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ | ||
302 | |||
303 | /* Initialize data structures */ | ||
304 | |||
305 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
306 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | ||
307 | /* Specify initial condition */ | ||
308 | for (i = 0; i < NEQ; i++) { | ||
309 | NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; | ||
310 | }; | ||
311 | |||
312 | cvode_mem = CVodeCreate($(int method), CV_NEWTON); | ||
313 | if (check_flag((void *)cvode_mem, "CVodeCreate", 0)) return(1); | ||
314 | |||
315 | /* Call CVodeInit to initialize the integrator memory and specify the | ||
316 | * user's right hand side function in y'=f(t,y), the inital time T0, and | ||
317 | * the initial dependent variable vector y. */ | ||
318 | flag = CVodeInit(cvode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); | ||
319 | if (check_flag(&flag, "CVodeInit", 1)) return(1); | ||
320 | |||
321 | tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */ | ||
322 | if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; | ||
323 | /* Specify tolerances */ | ||
324 | for (i = 0; i < NEQ; i++) { | ||
325 | NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i]; | ||
326 | }; | ||
327 | |||
328 | flag = CVodeSetMinStep(cvode_mem, $(double minStep_)); | ||
329 | if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1; | ||
330 | flag = CVodeSetMaxNumSteps(cvode_mem, $(long int maxNumSteps_)); | ||
331 | if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1; | ||
332 | flag = CVodeSetMaxErrTestFails(cvode_mem, $(int maxErrTestFails)); | ||
333 | if (check_flag(&flag, "CVodeSetMaxErrTestFails", 1)) return 1; | ||
334 | |||
335 | /* Call CVodeSVtolerances to specify the scalar relative tolerance | ||
336 | * and vector absolute tolerances */ | ||
337 | flag = CVodeSVtolerances(cvode_mem, $(double rTol), tv); | ||
338 | if (check_flag(&flag, "CVodeSVtolerances", 1)) return(1); | ||
339 | |||
340 | /* Initialize dense matrix data structure and solver */ | ||
341 | A = SUNDenseMatrix(NEQ, NEQ); | ||
342 | if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1; | ||
343 | LS = SUNDenseLinearSolver(y, A); | ||
344 | if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1; | ||
345 | |||
346 | /* Attach matrix and linear solver */ | ||
347 | flag = CVDlsSetLinearSolver(cvode_mem, LS, A); | ||
348 | if (check_flag(&flag, "CVDlsSetLinearSolver", 1)) return 1; | ||
349 | |||
350 | /* Set the initial step size if there is one */ | ||
351 | if ($(int isInitStepSize)) { | ||
352 | /* FIXME: We could check if the initial step size is 0 */ | ||
353 | /* or even NaN and then throw an error */ | ||
354 | flag = CVodeSetInitStep(cvode_mem, $(double ss)); | ||
355 | if (check_flag(&flag, "CVodeSetInitStep", 1)) return 1; | ||
356 | } | ||
357 | |||
358 | /* Set the Jacobian if there is one */ | ||
359 | if ($(int isJac)) { | ||
360 | flag = CVDlsSetJacFn(cvode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[]))); | ||
361 | if (check_flag(&flag, "CVDlsSetJacFn", 1)) return 1; | ||
362 | } | ||
363 | |||
364 | /* Store initial conditions */ | ||
365 | for (j = 0; j < NEQ; j++) { | ||
366 | ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); | ||
367 | } | ||
368 | |||
369 | /* Main time-stepping loop: calls CVode to perform the integration */ | ||
370 | /* Stops when the final time has been reached */ | ||
371 | for (i = 1; i < $(int nTs); i++) { | ||
372 | |||
373 | flag = CVode(cvode_mem, ($vec-ptr:(double *ts))[i], y, &t, CV_NORMAL); /* call integrator */ | ||
374 | if (check_flag(&flag, "CVode", 1)) break; | ||
375 | |||
376 | /* Store the results for Haskell */ | ||
377 | for (j = 0; j < NEQ; j++) { | ||
378 | ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j); | ||
379 | } | ||
380 | |||
381 | /* unsuccessful solve: break */ | ||
382 | if (flag < 0) { | ||
383 | fprintf(stderr,"Solver failure, stopping integration\n"); | ||
384 | break; | ||
385 | } | ||
386 | } | ||
387 | |||
388 | /* Get some final statistics on how the solve progressed */ | ||
389 | |||
390 | flag = CVodeGetNumSteps(cvode_mem, &nst); | ||
391 | check_flag(&flag, "CVodeGetNumSteps", 1); | ||
392 | ($vec-ptr:(long int *diagMut))[0] = nst; | ||
393 | |||
394 | /* FIXME */ | ||
395 | ($vec-ptr:(long int *diagMut))[1] = 0; | ||
396 | |||
397 | flag = CVodeGetNumRhsEvals(cvode_mem, &nfe); | ||
398 | check_flag(&flag, "CVodeGetNumRhsEvals", 1); | ||
399 | ($vec-ptr:(long int *diagMut))[2] = nfe; | ||
400 | /* FIXME */ | ||
401 | ($vec-ptr:(long int *diagMut))[3] = 0; | ||
402 | |||
403 | flag = CVodeGetNumLinSolvSetups(cvode_mem, &nsetups); | ||
404 | check_flag(&flag, "CVodeGetNumLinSolvSetups", 1); | ||
405 | ($vec-ptr:(long int *diagMut))[4] = nsetups; | ||
406 | |||
407 | flag = CVodeGetNumErrTestFails(cvode_mem, &netf); | ||
408 | check_flag(&flag, "CVodeGetNumErrTestFails", 1); | ||
409 | ($vec-ptr:(long int *diagMut))[5] = netf; | ||
410 | |||
411 | flag = CVodeGetNumNonlinSolvIters(cvode_mem, &nni); | ||
412 | check_flag(&flag, "CVodeGetNumNonlinSolvIters", 1); | ||
413 | ($vec-ptr:(long int *diagMut))[6] = nni; | ||
414 | |||
415 | flag = CVodeGetNumNonlinSolvConvFails(cvode_mem, &ncfn); | ||
416 | check_flag(&flag, "CVodeGetNumNonlinSolvConvFails", 1); | ||
417 | ($vec-ptr:(long int *diagMut))[7] = ncfn; | ||
418 | |||
419 | flag = CVDlsGetNumJacEvals(cvode_mem, &nje); | ||
420 | check_flag(&flag, "CVDlsGetNumJacEvals", 1); | ||
421 | ($vec-ptr:(long int *diagMut))[8] = ncfn; | ||
422 | |||
423 | flag = CVDlsGetNumRhsEvals(cvode_mem, &nfeLS); | ||
424 | check_flag(&flag, "CVDlsGetNumRhsEvals", 1); | ||
425 | ($vec-ptr:(long int *diagMut))[9] = ncfn; | ||
426 | |||
427 | /* Clean up and return */ | ||
428 | |||
429 | N_VDestroy(y); /* Free y vector */ | ||
430 | N_VDestroy(tv); /* Free tv vector */ | ||
431 | CVodeFree(&cvode_mem); /* Free integrator memory */ | ||
432 | SUNLinSolFree(LS); /* Free linear solver */ | ||
433 | SUNMatDestroy(A); /* Free A matrix */ | ||
434 | |||
435 | return flag; | ||
436 | } |] | ||
437 | if res == 0 | ||
438 | then do | ||
439 | preD <- V.freeze diagMut | ||
440 | let d = SundialsDiagnostics (fromIntegral $ preD V.!0) | ||
441 | (fromIntegral $ preD V.!1) | ||
442 | (fromIntegral $ preD V.!2) | ||
443 | (fromIntegral $ preD V.!3) | ||
444 | (fromIntegral $ preD V.!4) | ||
445 | (fromIntegral $ preD V.!5) | ||
446 | (fromIntegral $ preD V.!6) | ||
447 | (fromIntegral $ preD V.!7) | ||
448 | (fromIntegral $ preD V.!8) | ||
449 | (fromIntegral $ preD V.!9) | ||
450 | m <- V.freeze qMatMut | ||
451 | return $ Right (m, d) | ||
452 | else do | ||
453 | return $ Left res | ||
454 | |||
455 | -- | Adaptive step-size control | ||
456 | -- functions. | ||
457 | -- | ||
458 | -- [GSL](https://www.gnu.org/software/gsl/doc/html/ode-initval.html#adaptive-step-size-control) | ||
459 | -- allows the user to control the step size adjustment using | ||
460 | -- \(D_i = \epsilon^{abs}s_i + \epsilon^{rel}(a_{y} |y_i| + a_{dy/dt} h |\dot{y}_i|)\) where | ||
461 | -- \(\epsilon^{abs}\) is the required absolute error, \(\epsilon^{rel}\) | ||
462 | -- is the required relative error, \(s_i\) is a vector of scaling | ||
463 | -- factors, \(a_{y}\) is a scaling factor for the solution \(y\) and | ||
464 | -- \(a_{dydt}\) is a scaling factor for the derivative of the solution \(dy/dt\). | ||
465 | -- | ||
466 | -- [ARKode](https://computation.llnl.gov/projects/sundials/arkode) | ||
467 | -- allows the user to control the step size adjustment using | ||
468 | -- \(\eta^{rel}|y_i| + \eta^{abs}_i\). For compatibility with | ||
469 | -- [hmatrix-gsl](https://hackage.haskell.org/package/hmatrix-gsl), | ||
470 | -- tolerances for \(y\) and \(\dot{y}\) can be specified but the latter have no | ||
471 | -- effect. | ||
472 | data StepControl = X Double Double -- ^ absolute and relative tolerance for \(y\); in GSL terms, \(a_{y} = 1\) and \(a_{dy/dt} = 0\); in ARKode terms, the \(\eta^{abs}_i\) are identical | ||
473 | | X' Double Double -- ^ absolute and relative tolerance for \(\dot{y}\); in GSL terms, \(a_{y} = 0\) and \(a_{dy/dt} = 1\); in ARKode terms, the latter is treated as the relative tolerance for \(y\) so this is the same as specifying 'X' which may be entirely incorrect for the given problem | ||
474 | | XX' Double Double Double Double -- ^ include both via relative tolerance | ||
475 | -- scaling factors \(a_y\), \(a_{{dy}/{dt}}\); in ARKode terms, the latter is ignored and \(\eta^{rel} = a_{y}\epsilon^{rel}\) | ||
476 | | ScXX' Double Double Double Double (Vector Double) -- ^ scale absolute tolerance of \(y_i\); in ARKode terms, \(a_{{dy}/{dt}}\) is ignored, \(\eta^{abs}_i = s_i \epsilon^{abs}\) and \(\eta^{rel} = a_{y}\epsilon^{rel}\) | ||
diff --git a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs new file mode 100644 index 0000000..027d99a --- /dev/null +++ b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs | |||
@@ -0,0 +1,32 @@ | |||
1 | module Numeric.Sundials.ODEOpts where | ||
2 | |||
3 | import Data.Word (Word32) | ||
4 | import qualified Data.Vector.Storable as VS | ||
5 | |||
6 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) | ||
7 | |||
8 | |||
9 | type Jacobian = Double -> Vector Double -> Matrix Double | ||
10 | |||
11 | data ODEOpts = ODEOpts { | ||
12 | maxNumSteps :: Word32 | ||
13 | , minStep :: Double | ||
14 | , relTol :: Double | ||
15 | , absTols :: VS.Vector Double | ||
16 | , initStep :: Maybe Double | ||
17 | , maxFail :: Word32 | ||
18 | } deriving (Read, Show, Eq, Ord) | ||
19 | |||
20 | data SundialsDiagnostics = SundialsDiagnostics { | ||
21 | aRKodeGetNumSteps :: Int | ||
22 | , aRKodeGetNumStepAttempts :: Int | ||
23 | , aRKodeGetNumRhsEvals_fe :: Int | ||
24 | , aRKodeGetNumRhsEvals_fi :: Int | ||
25 | , aRKodeGetNumLinSolvSetups :: Int | ||
26 | , aRKodeGetNumErrTestFails :: Int | ||
27 | , aRKodeGetNumNonlinSolvIters :: Int | ||
28 | , aRKodeGetNumNonlinSolvConvFails :: Int | ||
29 | , aRKDlsGetNumJacEvals :: Int | ||
30 | , aRKDlsGetNumRhsEvals :: Int | ||
31 | } deriving Show | ||
32 | |||
diff --git a/packages/sundials/src/helpers.c b/packages/sundials/src/helpers.c new file mode 100644 index 0000000..f0ca592 --- /dev/null +++ b/packages/sundials/src/helpers.c | |||
@@ -0,0 +1,44 @@ | |||
1 | #include <stdio.h> | ||
2 | #include <math.h> | ||
3 | #include <arkode/arkode.h> /* prototypes for ARKODE fcts., consts. */ | ||
4 | #include <nvector/nvector_serial.h> /* serial N_Vector types, fcts., macros */ | ||
5 | #include <sunmatrix/sunmatrix_dense.h> /* access to dense SUNMatrix */ | ||
6 | #include <sunlinsol/sunlinsol_dense.h> /* access to dense SUNLinearSolver */ | ||
7 | #include <arkode/arkode_direct.h> /* access to ARKDls interface */ | ||
8 | #include <sundials/sundials_types.h> /* definition of type realtype */ | ||
9 | #include <sundials/sundials_math.h> | ||
10 | |||
11 | /* Check function return value... | ||
12 | opt == 0 means SUNDIALS function allocates memory so check if | ||
13 | returned NULL pointer | ||
14 | opt == 1 means SUNDIALS function returns a flag so check if | ||
15 | flag >= 0 | ||
16 | opt == 2 means function allocates memory so check if returned | ||
17 | NULL pointer | ||
18 | */ | ||
19 | int check_flag(void *flagvalue, const char *funcname, int opt) | ||
20 | { | ||
21 | int *errflag; | ||
22 | |||
23 | /* Check if SUNDIALS function returned NULL pointer - no memory allocated */ | ||
24 | if (opt == 0 && flagvalue == NULL) { | ||
25 | fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n", | ||
26 | funcname); | ||
27 | return 1; } | ||
28 | |||
29 | /* Check if flag < 0 */ | ||
30 | else if (opt == 1) { | ||
31 | errflag = (int *) flagvalue; | ||
32 | if (*errflag < 0) { | ||
33 | fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with flag = %d\n\n", | ||
34 | funcname, *errflag); | ||
35 | return 1; }} | ||
36 | |||
37 | /* Check if function returned NULL pointer - no memory allocated */ | ||
38 | else if (opt == 2 && flagvalue == NULL) { | ||
39 | fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n", | ||
40 | funcname); | ||
41 | return 1; } | ||
42 | |||
43 | return 0; | ||
44 | } | ||
diff --git a/packages/sundials/src/helpers.h b/packages/sundials/src/helpers.h new file mode 100644 index 0000000..3d8fbc0 --- /dev/null +++ b/packages/sundials/src/helpers.h | |||
@@ -0,0 +1,9 @@ | |||
1 | /* Check function return value... | ||
2 | opt == 0 means SUNDIALS function allocates memory so check if | ||
3 | returned NULL pointer | ||
4 | opt == 1 means SUNDIALS function returns a flag so check if | ||
5 | flag >= 0 | ||
6 | opt == 2 means function allocates memory so check if returned | ||
7 | NULL pointer | ||
8 | */ | ||
9 | int check_flag(void *flagvalue, const char *funcname, int opt); | ||
diff --git a/packages/tests/hmatrix-tests.cabal b/packages/tests/hmatrix-tests.cabal index 00f3a38..31fa32e 100644 --- a/packages/tests/hmatrix-tests.cabal +++ b/packages/tests/hmatrix-tests.cabal | |||
@@ -1,5 +1,5 @@ | |||
1 | Name: hmatrix-tests | 1 | Name: hmatrix-tests |
2 | Version: 0.6.0.0 | 2 | Version: 0.19.0.0 |
3 | License: BSD3 | 3 | License: BSD3 |
4 | License-file: LICENSE | 4 | License-file: LICENSE |
5 | Author: Alberto Ruiz | 5 | Author: Alberto Ruiz |
diff --git a/packages/tests/src/Numeric/GSL/Tests.hs b/packages/tests/src/Numeric/GSL/Tests.hs index 025427b..ed15935 100644 --- a/packages/tests/src/Numeric/GSL/Tests.hs +++ b/packages/tests/src/Numeric/GSL/Tests.hs | |||
@@ -1,4 +1,4 @@ | |||
1 | {-# OPTIONS_GHC -fno-warn-unused-imports -fno-warn-incomplete-patterns #-} | 1 | {-# OPTIONS_GHC -fno-warn-unused-imports -fno-warn-incomplete-patterns -fno-warn-missing-signatures #-} |
2 | {- | | 2 | {- | |
3 | Module : Numeric.GLS.Tests | 3 | Module : Numeric.GLS.Tests |
4 | Copyright : (c) Alberto Ruiz 2014 | 4 | Copyright : (c) Alberto Ruiz 2014 |
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs index 2aefc87..4ed1462 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs | |||
@@ -1,5 +1,5 @@ | |||
1 | {-# LANGUAGE CPP #-} | 1 | {-# LANGUAGE CPP #-} |
2 | {-# OPTIONS_GHC -fno-warn-unused-imports -fno-warn-incomplete-patterns #-} | 2 | {-# OPTIONS_GHC -fno-warn-unused-imports -fno-warn-incomplete-patterns -fno-warn-missing-signatures #-} |
3 | {-# LANGUAGE DataKinds #-} | 3 | {-# LANGUAGE DataKinds #-} |
4 | {-# LANGUAGE TypeFamilies #-} | 4 | {-# LANGUAGE TypeFamilies #-} |
5 | {-# LANGUAGE FlexibleContexts #-} | 5 | {-# LANGUAGE FlexibleContexts #-} |
@@ -31,7 +31,7 @@ module Numeric.LinearAlgebra.Tests( | |||
31 | --, runBigTests | 31 | --, runBigTests |
32 | ) where | 32 | ) where |
33 | 33 | ||
34 | import Numeric.LinearAlgebra hiding (unitary) | 34 | import Numeric.LinearAlgebra |
35 | import Numeric.LinearAlgebra.Devel | 35 | import Numeric.LinearAlgebra.Devel |
36 | import Numeric.LinearAlgebra.Static(L) | 36 | import Numeric.LinearAlgebra.Static(L) |
37 | import Numeric.LinearAlgebra.Tests.Instances | 37 | import Numeric.LinearAlgebra.Tests.Instances |
@@ -514,7 +514,7 @@ indexProp g f x = a1 == g a2 && a2 == a3 && b1 == g b2 && b2 == b3 | |||
514 | 514 | ||
515 | -------------------------------------------------------------------------------- | 515 | -------------------------------------------------------------------------------- |
516 | 516 | ||
517 | sliceTest = utest "slice test" $ and | 517 | _sliceTest = TestList |
518 | [ testSlice (chol . trustSym) (gen 5 :: Matrix R) | 518 | [ testSlice (chol . trustSym) (gen 5 :: Matrix R) |
519 | , testSlice (chol . trustSym) (gen 5 :: Matrix C) | 519 | , testSlice (chol . trustSym) (gen 5 :: Matrix C) |
520 | , testSlice qr (rec :: Matrix R) | 520 | , testSlice qr (rec :: Matrix R) |
@@ -617,7 +617,7 @@ sliceTest = utest "slice test" $ and | |||
617 | 617 | ||
618 | test_qrgr n t x = qrgr n (QR x t) | 618 | test_qrgr n t x = qrgr n (QR x t) |
619 | 619 | ||
620 | ok_qrgr x = simeq 1E-15 q q' | 620 | ok_qrgr x = TestCase . assertBool "ok_qrgr" $ simeq 1E-15 q q' |
621 | where | 621 | where |
622 | (q,_) = qr x | 622 | (q,_) = qr x |
623 | atau = qrRaw x | 623 | atau = qrRaw x |
@@ -646,7 +646,8 @@ sliceTest = utest "slice test" $ and | |||
646 | rec :: Numeric t => Matrix t | 646 | rec :: Numeric t => Matrix t |
647 | rec = subMatrix (0,0) (4,5) (gen 5) | 647 | rec = subMatrix (0,0) (4,5) (gen 5) |
648 | 648 | ||
649 | testSlice f x@(size->sz@(r,c)) = all (==f x) (map f (g y1 ++ g y2)) | 649 | testSlice f x@(size->sz@(r,c)) = |
650 | TestList . map (TestCase . assertEqual "" (f x)) $ (map f (g y1 ++ g y2)) | ||
650 | where | 651 | where |
651 | subm = subMatrix | 652 | subm = subMatrix |
652 | g y = [ subm (a*r,b*c) sz y | a <-[0..2], b <- [0..2]] | 653 | g y = [ subm (a*r,b*c) sz y | a <-[0..2], b <- [0..2]] |
@@ -841,7 +842,7 @@ runTests n = do | |||
841 | , staticTest | 842 | , staticTest |
842 | , intTest | 843 | , intTest |
843 | , modularTest | 844 | , modularTest |
844 | , sliceTest | 845 | -- , sliceTest |
845 | ] | 846 | ] |
846 | when (errors c + failures c > 0) exitFailure | 847 | when (errors c + failures c > 0) exitFailure |
847 | return () | 848 | return () |
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs index f0bddd0..59230e0 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs | |||
@@ -1,4 +1,8 @@ | |||
1 | {-# LANGUAGE CPP, FlexibleContexts, UndecidableInstances, FlexibleInstances, ScopedTypeVariables #-} | 1 | {-# LANGUAGE CPP, FlexibleContexts, UndecidableInstances, FlexibleInstances, ScopedTypeVariables #-} |
2 | |||
3 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
4 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
5 | |||
2 | ----------------------------------------------------------------------------- | 6 | ----------------------------------------------------------------------------- |
3 | {- | | 7 | {- | |
4 | Module : Numeric.LinearAlgebra.Tests.Instances | 8 | Module : Numeric.LinearAlgebra.Tests.Instances |
@@ -62,7 +66,7 @@ instance KnownNat n => Arbitrary (Static.R n) where | |||
62 | n :: Int | 66 | n :: Int |
63 | n = fromIntegral (natVal (Proxy :: Proxy n)) | 67 | n = fromIntegral (natVal (Proxy :: Proxy n)) |
64 | 68 | ||
65 | shrink v = [] | 69 | shrink _v = [] |
66 | 70 | ||
67 | instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where | 71 | instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where |
68 | arbitrary = do | 72 | arbitrary = do |
@@ -89,7 +93,7 @@ instance (KnownNat n, KnownNat m) => Arbitrary (Static.L m n) where | |||
89 | n :: Int | 93 | n :: Int |
90 | n = fromIntegral (natVal (Proxy :: Proxy n)) | 94 | n = fromIntegral (natVal (Proxy :: Proxy n)) |
91 | 95 | ||
92 | shrink mat = [] | 96 | shrink _mat = [] |
93 | 97 | ||
94 | -- a square matrix | 98 | -- a square matrix |
95 | newtype (Sq a) = Sq (Matrix a) deriving Show | 99 | newtype (Sq a) = Sq (Matrix a) deriving Show |
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs index e3a6242..6cd3a9e 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs | |||
@@ -3,6 +3,8 @@ | |||
3 | {-# LANGUAGE TypeFamilies #-} | 3 | {-# LANGUAGE TypeFamilies #-} |
4 | {-# LANGUAGE DataKinds #-} | 4 | {-# LANGUAGE DataKinds #-} |
5 | 5 | ||
6 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | ||
7 | |||
6 | ----------------------------------------------------------------------------- | 8 | ----------------------------------------------------------------------------- |
7 | {- | | 9 | {- | |
8 | Module : Numeric.LinearAlgebra.Tests.Properties | 10 | Module : Numeric.LinearAlgebra.Tests.Properties |
@@ -51,14 +53,13 @@ module Numeric.LinearAlgebra.Tests.Properties ( | |||
51 | , staticVectorBinaryFailProp | 53 | , staticVectorBinaryFailProp |
52 | ) where | 54 | ) where |
53 | 55 | ||
54 | import Numeric.LinearAlgebra.HMatrix hiding (Testable,unitary) | 56 | import Numeric.LinearAlgebra.HMatrix hiding (Testable) |
55 | import qualified Numeric.LinearAlgebra.Static as Static | 57 | import qualified Numeric.LinearAlgebra.Static as Static |
56 | import Test.QuickCheck | 58 | import Test.QuickCheck |
57 | 59 | ||
58 | import Data.Binary | 60 | import Data.Binary |
59 | import Data.Binary.Get (runGet) | 61 | import Data.Binary.Get (runGet) |
60 | import Data.Either (isLeft) | 62 | import Data.Either (isLeft) |
61 | import Debug.Trace (traceShowId) | ||
62 | #if MIN_VERSION_base(4,11,0) | 63 | #if MIN_VERSION_base(4,11,0) |
63 | import Prelude hiding ((<>)) | 64 | import Prelude hiding ((<>)) |
64 | #endif | 65 | #endif |
diff --git a/packages/tests/src/TestBase.hs b/packages/tests/src/TestBase.hs index 23fd675..51867b1 100644 --- a/packages/tests/src/TestBase.hs +++ b/packages/tests/src/TestBase.hs | |||
@@ -1,3 +1,4 @@ | |||
1 | import Numeric.LinearAlgebra.Tests | 1 | import Numeric.LinearAlgebra.Tests |
2 | 2 | ||
3 | main :: IO () | ||
3 | main = runTests 20 | 4 | main = runTests 20 |
diff --git a/packages/tests/src/TestGSL.hs b/packages/tests/src/TestGSL.hs index 112422d..cc6b1e7 100644 --- a/packages/tests/src/TestGSL.hs +++ b/packages/tests/src/TestGSL.hs | |||
@@ -1,3 +1,4 @@ | |||
1 | import Numeric.GSL.Tests | 1 | import Numeric.GSL.Tests |
2 | 2 | ||
3 | main :: IO () | ||
3 | main = runTests 20 | 4 | main = runTests 20 |