diff options
Diffstat (limited to 'packages')
-rw-r--r-- | packages/sundials/hmatrix-sundials.cabal | 18 | ||||
-rw-r--r-- | packages/sundials/src/Bar.hsc | 6 | ||||
-rw-r--r-- | packages/sundials/src/Main.hs | 32 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 142 | ||||
-rw-r--r-- | packages/sundials/stack.yaml | 66 |
5 files changed, 180 insertions, 84 deletions
diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal index 762537e..dbe50e0 100644 --- a/packages/sundials/hmatrix-sundials.cabal +++ b/packages/sundials/hmatrix-sundials.cabal | |||
@@ -16,8 +16,6 @@ build-type: Simple | |||
16 | extra-source-files: ChangeLog.md | 16 | extra-source-files: ChangeLog.md |
17 | cabal-version: >=1.10 | 17 | cabal-version: >=1.10 |
18 | 18 | ||
19 | extra-source-files: src/helpers.c, src/helpers.h | ||
20 | |||
21 | 19 | ||
22 | library | 20 | library |
23 | build-depends: base >=4.10 && <4.11, | 21 | build-depends: base >=4.10 && <4.11, |
@@ -26,15 +24,19 @@ library | |||
26 | template-haskell >=2.12 && <2.13, | 24 | template-haskell >=2.12 && <2.13, |
27 | containers >=0.5 && <0.6, | 25 | containers >=0.5 && <0.6, |
28 | hmatrix>=0.18 | 26 | hmatrix>=0.18 |
29 | other-extensions: QuasiQuotes, TemplateHaskell, MultiWayIf, OverloadedStrings | 27 | extra-libraries: sundials_arkode |
28 | other-extensions: QuasiQuotes | ||
30 | hs-source-dirs: src | 29 | hs-source-dirs: src |
31 | exposed-modules: Numeric.Sundials.Arkode.ODE | 30 | exposed-modules: Numeric.Sundials.Arkode.ODE |
32 | other-modules: Types | 31 | other-modules: Types, |
32 | Bar | ||
33 | 33 | ||
34 | executable sundials | 34 | executable sundials |
35 | main-is: Main.hs | 35 | main-is: Main.hs |
36 | other-modules: Types, Numeric.Sundials.Arkode.ODE | 36 | other-modules: Types, |
37 | other-extensions: QuasiQuotes, TemplateHaskell, MultiWayIf, OverloadedStrings | 37 | Numeric.Sundials.Arkode.ODE, |
38 | Bar | ||
39 | other-extensions: QuasiQuotes | ||
38 | build-depends: base >=4.10 && <4.11, | 40 | build-depends: base >=4.10 && <4.11, |
39 | inline-c >=0.6 && <0.7, | 41 | inline-c >=0.6 && <0.7, |
40 | vector >=0.12 && <0.13, | 42 | vector >=0.12 && <0.13, |
@@ -43,7 +45,9 @@ executable sundials | |||
43 | hmatrix>=0.18, | 45 | hmatrix>=0.18, |
44 | plots, | 46 | plots, |
45 | diagrams-lib, | 47 | diagrams-lib, |
46 | diagrams-rasterific | 48 | diagrams-rasterific, |
49 | lens, | ||
50 | pretty | ||
47 | hs-source-dirs: src | 51 | hs-source-dirs: src |
48 | default-language: Haskell2010 | 52 | default-language: Haskell2010 |
49 | extra-libraries: sundials_arkode | 53 | extra-libraries: sundials_arkode |
diff --git a/packages/sundials/src/Bar.hsc b/packages/sundials/src/Bar.hsc index 434c4d4..7db0d4a 100644 --- a/packages/sundials/src/Bar.hsc +++ b/packages/sundials/src/Bar.hsc | |||
@@ -1,6 +1,6 @@ | |||
1 | {-# LANGUAGE RecordWildCards #-} | 1 | {-# LANGUAGE RecordWildCards #-} |
2 | 2 | ||
3 | module Example where | 3 | module Bar where |
4 | 4 | ||
5 | import Foreign | 5 | import Foreign |
6 | import Foreign.C.Types | 6 | import Foreign.C.Types |
@@ -8,6 +8,7 @@ import Foreign.C.String | |||
8 | 8 | ||
9 | #include "/Users/dom/sundials/include/sundials/sundials_nvector.h" | 9 | #include "/Users/dom/sundials/include/sundials/sundials_nvector.h" |
10 | #include "/Users/dom/sundials/include/nvector/nvector_serial.h" | 10 | #include "/Users/dom/sundials/include/nvector/nvector_serial.h" |
11 | #include "/Users/dom/sundials/include/arkode/arkode.h" | ||
11 | 12 | ||
12 | #def typedef struct _generic_N_Vector BarType; | 13 | #def typedef struct _generic_N_Vector BarType; |
13 | #def typedef struct _N_VectorContent_Serial BazType; | 14 | #def typedef struct _N_VectorContent_Serial BazType; |
@@ -19,6 +20,9 @@ getContentPtr ptr = (#peek BarType, content) ptr | |||
19 | getData :: Storable a => Ptr b -> IO a | 20 | getData :: Storable a => Ptr b -> IO a |
20 | getData ptr = (#peek BazType, data) ptr | 21 | getData ptr = (#peek BazType, data) ptr |
21 | 22 | ||
23 | arkSMax :: Int | ||
24 | arkSMax = #const ARK_S_MAX | ||
25 | |||
22 | 26 | ||
23 | 27 | ||
24 | 28 | ||
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 5e51372..895e610 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs | |||
@@ -11,6 +11,9 @@ import Diagrams.Backend.Rasterific | |||
11 | import Control.Lens | 11 | import Control.Lens |
12 | import Data.List (zip4) | 12 | import Data.List (zip4) |
13 | 13 | ||
14 | import Text.PrettyPrint.HughesPJClass | ||
15 | import Data.List (intercalate) | ||
16 | |||
14 | 17 | ||
15 | brusselator _t x = [ a - (w + 1) * u + v * u^2 | 18 | brusselator _t x = [ a - (w + 1) * u + v * u^2 |
16 | , w * u - v * u^2 | 19 | , w * u - v * u^2 |
@@ -43,8 +46,37 @@ kSaxis :: [(Double, Double)] -> P.Axis B D.V2 Double | |||
43 | kSaxis xs = P.r2Axis &~ do | 46 | kSaxis xs = P.r2Axis &~ do |
44 | P.linePlot' xs | 47 | P.linePlot' xs |
45 | 48 | ||
49 | butcherTableauTex :: (Show a, Element a) => Matrix a -> String | ||
50 | butcherTableauTex m = render $ | ||
51 | vcat [ text ("\n\\begin{array}{c|" ++ (concat $ replicate n "c") ++ "}") | ||
52 | , us | ||
53 | , text "\\end{array}" | ||
54 | ] | ||
55 | where | ||
56 | n = rows m | ||
57 | rs = toLists m | ||
58 | ss = map (\r -> intercalate " & " $ map show r) rs | ||
59 | ts = zipWith (\n r -> "c_" ++ show n ++ " & " ++ r) [1..n] ss | ||
60 | us = vcat $ map (\r -> text r <+> text "\\\\") ts | ||
61 | |||
46 | main :: IO () | 62 | main :: IO () |
47 | main = do | 63 | main = do |
64 | -- $$ | ||
65 | -- \begin{array}{c|cccc} | ||
66 | -- c_1 & a_{11} & a_{12}& \dots & a_{1s}\\ | ||
67 | -- c_2 & a_{21} & a_{22}& \dots & a_{2s}\\ | ||
68 | -- \vdots & \vdots & \vdots& \ddots& \vdots\\ | ||
69 | -- c_s & a_{s1} & a_{s2}& \dots & a_{ss} \\ | ||
70 | -- \hline | ||
71 | -- & b_1 & b_2 & \dots & b_s\\ | ||
72 | -- & b^*_1 & b^*_2 & \dots & b^*_s\\ | ||
73 | -- \end{array} | ||
74 | -- $$ | ||
75 | |||
76 | let res = btGet | ||
77 | putStrLn $ show res | ||
78 | putStrLn $ butcherTableauTex res | ||
79 | |||
48 | let res = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) | 80 | let res = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) |
49 | putStrLn $ show res | 81 | putStrLn $ show res |
50 | renderRasterific "diagrams/brusselator.png" | 82 | renderRasterific "diagrams/brusselator.png" |
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index f432951..76ed61b 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -6,8 +6,25 @@ | |||
6 | {-# LANGUAGE OverloadedStrings #-} | 6 | {-# LANGUAGE OverloadedStrings #-} |
7 | {-# LANGUAGE ScopedTypeVariables #-} | 7 | {-# LANGUAGE ScopedTypeVariables #-} |
8 | 8 | ||
9 | -- | | ||
10 | -- Module: Numeric.Sundials.ARKode | ||
11 | -- | ||
12 | -- Blah | ||
13 | -- | ||
14 | -- \[ | ||
15 | -- \begin{array}{c|cccc} | ||
16 | -- c_1 & 0.0 & 0.0 & 0.0 & 0.0 \\ | ||
17 | -- c_2 & 0.4358665215 & 0.4358665215 & 0.0 & 0.0 \\ | ||
18 | -- c_3 & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\ | ||
19 | -- c_4 & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\ | ||
20 | -- \end{array} | ||
21 | -- \] | ||
22 | -- | ||
9 | module Numeric.Sundials.Arkode.ODE ( solveOde | 23 | module Numeric.Sundials.Arkode.ODE ( solveOde |
10 | , odeSolve | 24 | , odeSolve |
25 | , getButcherTable | ||
26 | , getBT | ||
27 | , btGet | ||
11 | ) where | 28 | ) where |
12 | 29 | ||
13 | import qualified Language.C.Inline as C | 30 | import qualified Language.C.Inline as C |
@@ -28,9 +45,10 @@ import System.IO.Unsafe (unsafePerformIO) | |||
28 | 45 | ||
29 | import Numeric.LinearAlgebra.Devel (createVector) | 46 | import Numeric.LinearAlgebra.Devel (createVector) |
30 | 47 | ||
31 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><)) | 48 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix) |
32 | 49 | ||
33 | import qualified Types as T | 50 | import qualified Types as T |
51 | import qualified Bar as B | ||
34 | 52 | ||
35 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | 53 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) |
36 | 54 | ||
@@ -83,16 +101,16 @@ vectorToC vec len ptr = do | |||
83 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | 101 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec |
84 | 102 | ||
85 | data SundialsDiagnostics = SundialsDiagnostics { | 103 | data SundialsDiagnostics = SundialsDiagnostics { |
86 | aRKodeGetNumSteps :: Int | 104 | aRKodeGetNumSteps :: Int |
87 | , aRKodeGetNumStepAttempts :: Int | 105 | , aRKodeGetNumStepAttempts :: Int |
88 | , aRKodeGetNumRhsEvals_fe :: Int | 106 | , aRKodeGetNumRhsEvals_fe :: Int |
89 | , aRKodeGetNumRhsEvals_fi :: Int | 107 | , aRKodeGetNumRhsEvals_fi :: Int |
90 | , aRKodeGetNumLinSolvSetups :: Int | 108 | , aRKodeGetNumLinSolvSetups :: Int |
91 | , aRKodeGetNumErrTestFails :: Int | 109 | , aRKodeGetNumErrTestFails :: Int |
92 | , aRKodeGetNumNonlinSolvIters :: Int | 110 | , aRKodeGetNumNonlinSolvIters :: Int |
93 | , aRKodeGetNumNonlinSolvConvFails :: Int | 111 | , aRKodeGetNumNonlinSolvConvFails :: Int |
94 | , aRKDlsGetNumJacEvals :: Int | 112 | , aRKDlsGetNumJacEvals :: Int |
95 | , aRKDlsGetNumRhsEvals :: Int | 113 | , aRKDlsGetNumRhsEvals :: Int |
96 | } deriving Show | 114 | } deriving Show |
97 | 115 | ||
98 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 116 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
@@ -312,3 +330,107 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
312 | else do | 330 | else do |
313 | return $ Left res | 331 | return $ Left res |
314 | 332 | ||
333 | btGet :: Matrix Double | ||
334 | btGet = case getBT of | ||
335 | Left c -> error $ show c -- FIXME | ||
336 | Right (v, sqp) -> subMatrix (0, 0) (4, 4) $ (B.arkSMax >< B.arkSMax) (V.toList v) | ||
337 | |||
338 | getBT :: Either Int (V.Vector Double, V.Vector Int) | ||
339 | getBT = case getButcherTable of | ||
340 | Left c -> Left $ fromIntegral c | ||
341 | Right (v, sqp) -> Right $ (coerce v, V.map fromIntegral sqp) | ||
342 | |||
343 | getButcherTable :: Either CInt ((V.Vector CDouble), V.Vector CInt) | ||
344 | getButcherTable = unsafePerformIO $ do | ||
345 | -- arkode seems to want an ODE in order to set and then get the | ||
346 | -- Butcher tableau so here's one to keep it happy | ||
347 | let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble | ||
348 | fun t ys = V.fromList [ ys V.! 0 ] | ||
349 | f0 = V.fromList [ 1.0 ] | ||
350 | ts = V.fromList [ 0.0 ] | ||
351 | dim = V.length f0 | ||
352 | nEq :: CLong | ||
353 | nEq = fromIntegral dim | ||
354 | |||
355 | -- FIXME: I believe these gets taken from the ghc heap and so should | ||
356 | -- be subject to garbage collection. | ||
357 | btSQP :: V.Vector CInt <- createVector 3 | ||
358 | btSQPMut <- V.thaw btSQP | ||
359 | btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax) | ||
360 | btAsMut <- V.thaw btAs | ||
361 | -- We need the types that sundials expects. These are tied together | ||
362 | -- in 'Types'. FIXME: The Haskell type is currently empty! | ||
363 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | ||
364 | funIO x y f _ptr = do | ||
365 | -- Convert the pointer we get from C (y) to a vector, and then | ||
366 | -- apply the user-supplied function. | ||
367 | fImm <- fun x <$> getDataFromContents dim y | ||
368 | -- Fill in the provided pointer with the resulting vector. | ||
369 | putDataInContents fImm dim f | ||
370 | -- I don't understand what this comment means | ||
371 | -- Unsafe since the function will be called many times. | ||
372 | [CU.exp| int{ 0 } |] | ||
373 | res <- [C.block| int { | ||
374 | /* general problem variables */ | ||
375 | int flag; /* reusable error-checking flag */ | ||
376 | N_Vector y = NULL; /* empty vector for storing solution */ | ||
377 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | ||
378 | |||
379 | /* general problem parameters */ | ||
380 | /* initial time */ | ||
381 | realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); | ||
382 | /* number of dependent vars. */ | ||
383 | sunindextype NEQ = $(sunindextype nEq); | ||
384 | |||
385 | /* Initialize data structures */ | ||
386 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
387 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | ||
388 | /* Specify initial condition */ | ||
389 | int i, j; | ||
390 | for (i = 0; i < NEQ; i++) { | ||
391 | NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; | ||
392 | }; | ||
393 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | ||
394 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | ||
395 | |||
396 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); | ||
397 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
398 | |||
399 | flag = ARKodeSetIRKTableNum(arkode_mem, KVAERNO_4_2_3); | ||
400 | if (check_flag(&flag, "ARKode", 1)) return 1; | ||
401 | |||
402 | int s, q, p; | ||
403 | realtype *ai = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); | ||
404 | realtype *ae = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); | ||
405 | realtype *ci = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
406 | realtype *ce = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
407 | realtype *bi = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
408 | realtype *be = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
409 | realtype *b2i = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
410 | realtype *b2e = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
411 | flag = ARKodeGetCurrentButcherTables(arkode_mem, &s, &q, &p, ai, ae, ci, ce, bi, be, b2i, b2e); | ||
412 | if (check_flag(&flag, "ARKode", 1)) return 1; | ||
413 | $vec-ptr:(int *btSQPMut)[0] = s; | ||
414 | $vec-ptr:(int *btSQPMut)[1] = q; | ||
415 | $vec-ptr:(int *btSQPMut)[2] = p; | ||
416 | for (i = 0; i < s; i++) { | ||
417 | for (j = 0; j < s; j++) { | ||
418 | /* FIXME: double should be realtype */ | ||
419 | ($vec-ptr:(double *btAsMut))[i * ARK_S_MAX + j] = ai[i * ARK_S_MAX + j]; | ||
420 | } | ||
421 | } | ||
422 | |||
423 | /* Clean up and return */ | ||
424 | N_VDestroy(y); /* Free y vector */ | ||
425 | ARKodeFree(&arkode_mem); /* Free integrator memory */ | ||
426 | |||
427 | return flag; | ||
428 | } |] | ||
429 | if res == 0 | ||
430 | then do | ||
431 | x <- V.freeze btAsMut | ||
432 | y <- V.freeze btSQPMut | ||
433 | return $ Right (x, y) | ||
434 | else do | ||
435 | return $ Left res | ||
436 | |||
diff --git a/packages/sundials/stack.yaml b/packages/sundials/stack.yaml deleted file mode 100644 index 9c6b17c..0000000 --- a/packages/sundials/stack.yaml +++ /dev/null | |||
@@ -1,66 +0,0 @@ | |||
1 | # This file was automatically generated by 'stack init' | ||
2 | # | ||
3 | # Some commonly used options have been documented as comments in this file. | ||
4 | # For advanced use and comprehensive documentation of the format, please see: | ||
5 | # https://docs.haskellstack.org/en/stable/yaml_configuration/ | ||
6 | |||
7 | # Resolver to choose a 'specific' stackage snapshot or a compiler version. | ||
8 | # A snapshot resolver dictates the compiler version and the set of packages | ||
9 | # to be used for project dependencies. For example: | ||
10 | # | ||
11 | # resolver: lts-3.5 | ||
12 | # resolver: nightly-2015-09-21 | ||
13 | # resolver: ghc-7.10.2 | ||
14 | # resolver: ghcjs-0.1.0_ghc-7.10.2 | ||
15 | # resolver: | ||
16 | # name: custom-snapshot | ||
17 | # location: "./custom-snapshot.yaml" | ||
18 | resolver: lts-10.9 | ||
19 | |||
20 | # User packages to be built. | ||
21 | # Various formats can be used as shown in the example below. | ||
22 | # | ||
23 | # packages: | ||
24 | # - some-directory | ||
25 | # - https://example.com/foo/bar/baz-0.0.2.tar.gz | ||
26 | # - location: | ||
27 | # git: https://github.com/commercialhaskell/stack.git | ||
28 | # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a | ||
29 | # - location: https://github.com/commercialhaskell/stack/commit/e7b331f14bcffb8367cd58fbfc8b40ec7642100a | ||
30 | # extra-dep: true | ||
31 | # subdirs: | ||
32 | # - auto-update | ||
33 | # - wai | ||
34 | # | ||
35 | # A package marked 'extra-dep: true' will only be built if demanded by a | ||
36 | # non-dependency (i.e. a user package), and its test suites and benchmarks | ||
37 | # will not be run. This is useful for tweaking upstream packages. | ||
38 | packages: | ||
39 | - . | ||
40 | # Dependency packages to be pulled from upstream that are not in the resolver | ||
41 | # (e.g., acme-missiles-0.3) | ||
42 | # extra-deps: [] | ||
43 | |||
44 | # Override default flag values for local packages and extra-deps | ||
45 | # flags: {} | ||
46 | |||
47 | # Extra package databases containing global packages | ||
48 | # extra-package-dbs: [] | ||
49 | |||
50 | # Control whether we use the GHC we find on the path | ||
51 | # system-ghc: true | ||
52 | # | ||
53 | # Require a specific version of stack, using version ranges | ||
54 | # require-stack-version: -any # Default | ||
55 | # require-stack-version: ">=1.6" | ||
56 | # | ||
57 | # Override the architecture used by stack, especially useful on Windows | ||
58 | # arch: i386 | ||
59 | # arch: x86_64 | ||
60 | # | ||
61 | # Extra directories used by stack for building | ||
62 | # extra-include-dirs: [/path/to/dir] | ||
63 | # extra-lib-dirs: [/path/to/dir] | ||
64 | # | ||
65 | # Allow a newer minor version of GHC than the snapshot specifies | ||
66 | # compiler-check: newer-minor \ No newline at end of file | ||