summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs')
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs142
1 files changed, 132 insertions, 10 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index f432951..76ed61b 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -6,8 +6,25 @@
6{-# LANGUAGE OverloadedStrings #-} 6{-# LANGUAGE OverloadedStrings #-}
7{-# LANGUAGE ScopedTypeVariables #-} 7{-# LANGUAGE ScopedTypeVariables #-}
8 8
9-- |
10-- Module: Numeric.Sundials.ARKode
11--
12-- Blah
13--
14-- \[
15-- \begin{array}{c|cccc}
16-- c_1 & 0.0 & 0.0 & 0.0 & 0.0 \\
17-- c_2 & 0.4358665215 & 0.4358665215 & 0.0 & 0.0 \\
18-- c_3 & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\
19-- c_4 & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\
20-- \end{array}
21-- \]
22--
9module Numeric.Sundials.Arkode.ODE ( solveOde 23module Numeric.Sundials.Arkode.ODE ( solveOde
10 , odeSolve 24 , odeSolve
25 , getButcherTable
26 , getBT
27 , btGet
11 ) where 28 ) where
12 29
13import qualified Language.C.Inline as C 30import qualified Language.C.Inline as C
@@ -28,9 +45,10 @@ import System.IO.Unsafe (unsafePerformIO)
28 45
29import Numeric.LinearAlgebra.Devel (createVector) 46import Numeric.LinearAlgebra.Devel (createVector)
30 47
31import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><)) 48import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix)
32 49
33import qualified Types as T 50import qualified Types as T
51import qualified Bar as B
34 52
35C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) 53C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx)
36 54
@@ -83,16 +101,16 @@ vectorToC vec len ptr = do
83 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec 101 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
84 102
85data SundialsDiagnostics = SundialsDiagnostics { 103data SundialsDiagnostics = SundialsDiagnostics {
86 aRKodeGetNumSteps :: Int 104 aRKodeGetNumSteps :: Int
87 , aRKodeGetNumStepAttempts :: Int 105 , aRKodeGetNumStepAttempts :: Int
88 , aRKodeGetNumRhsEvals_fe :: Int 106 , aRKodeGetNumRhsEvals_fe :: Int
89 , aRKodeGetNumRhsEvals_fi :: Int 107 , aRKodeGetNumRhsEvals_fi :: Int
90 , aRKodeGetNumLinSolvSetups :: Int 108 , aRKodeGetNumLinSolvSetups :: Int
91 , aRKodeGetNumErrTestFails :: Int 109 , aRKodeGetNumErrTestFails :: Int
92 , aRKodeGetNumNonlinSolvIters :: Int 110 , aRKodeGetNumNonlinSolvIters :: Int
93 , aRKodeGetNumNonlinSolvConvFails :: Int 111 , aRKodeGetNumNonlinSolvConvFails :: Int
94 , aRKDlsGetNumJacEvals :: Int 112 , aRKDlsGetNumJacEvals :: Int
95 , aRKDlsGetNumRhsEvals :: Int 113 , aRKDlsGetNumRhsEvals :: Int
96 } deriving Show 114 } deriving Show
97 115
98odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 116odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
@@ -312,3 +330,107 @@ solveOdeC fun f0 ts = unsafePerformIO $ do
312 else do 330 else do
313 return $ Left res 331 return $ Left res
314 332
333btGet :: Matrix Double
334btGet = case getBT of
335 Left c -> error $ show c -- FIXME
336 Right (v, sqp) -> subMatrix (0, 0) (4, 4) $ (B.arkSMax >< B.arkSMax) (V.toList v)
337
338getBT :: Either Int (V.Vector Double, V.Vector Int)
339getBT = case getButcherTable of
340 Left c -> Left $ fromIntegral c
341 Right (v, sqp) -> Right $ (coerce v, V.map fromIntegral sqp)
342
343getButcherTable :: Either CInt ((V.Vector CDouble), V.Vector CInt)
344getButcherTable = unsafePerformIO $ do
345 -- arkode seems to want an ODE in order to set and then get the
346 -- Butcher tableau so here's one to keep it happy
347 let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble
348 fun t ys = V.fromList [ ys V.! 0 ]
349 f0 = V.fromList [ 1.0 ]
350 ts = V.fromList [ 0.0 ]
351 dim = V.length f0
352 nEq :: CLong
353 nEq = fromIntegral dim
354
355 -- FIXME: I believe these gets taken from the ghc heap and so should
356 -- be subject to garbage collection.
357 btSQP :: V.Vector CInt <- createVector 3
358 btSQPMut <- V.thaw btSQP
359 btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax)
360 btAsMut <- V.thaw btAs
361 -- We need the types that sundials expects. These are tied together
362 -- in 'Types'. FIXME: The Haskell type is currently empty!
363 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt
364 funIO x y f _ptr = do
365 -- Convert the pointer we get from C (y) to a vector, and then
366 -- apply the user-supplied function.
367 fImm <- fun x <$> getDataFromContents dim y
368 -- Fill in the provided pointer with the resulting vector.
369 putDataInContents fImm dim f
370 -- I don't understand what this comment means
371 -- Unsafe since the function will be called many times.
372 [CU.exp| int{ 0 } |]
373 res <- [C.block| int {
374 /* general problem variables */
375 int flag; /* reusable error-checking flag */
376 N_Vector y = NULL; /* empty vector for storing solution */
377 void *arkode_mem = NULL; /* empty ARKode memory structure */
378
379 /* general problem parameters */
380 /* initial time */
381 realtype T0 = RCONST(($vec-ptr:(double *ts))[0]);
382 /* number of dependent vars. */
383 sunindextype NEQ = $(sunindextype nEq);
384
385 /* Initialize data structures */
386 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
387 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
388 /* Specify initial condition */
389 int i, j;
390 for (i = 0; i < NEQ; i++) {
391 NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i];
392 };
393 arkode_mem = ARKodeCreate(); /* Create the solver memory */
394 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
395
396 flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y);
397 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
398
399 flag = ARKodeSetIRKTableNum(arkode_mem, KVAERNO_4_2_3);
400 if (check_flag(&flag, "ARKode", 1)) return 1;
401
402 int s, q, p;
403 realtype *ai = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype));
404 realtype *ae = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype));
405 realtype *ci = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
406 realtype *ce = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
407 realtype *bi = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
408 realtype *be = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
409 realtype *b2i = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
410 realtype *b2e = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
411 flag = ARKodeGetCurrentButcherTables(arkode_mem, &s, &q, &p, ai, ae, ci, ce, bi, be, b2i, b2e);
412 if (check_flag(&flag, "ARKode", 1)) return 1;
413 $vec-ptr:(int *btSQPMut)[0] = s;
414 $vec-ptr:(int *btSQPMut)[1] = q;
415 $vec-ptr:(int *btSQPMut)[2] = p;
416 for (i = 0; i < s; i++) {
417 for (j = 0; j < s; j++) {
418 /* FIXME: double should be realtype */
419 ($vec-ptr:(double *btAsMut))[i * ARK_S_MAX + j] = ai[i * ARK_S_MAX + j];
420 }
421 }
422
423 /* Clean up and return */
424 N_VDestroy(y); /* Free y vector */
425 ARKodeFree(&arkode_mem); /* Free integrator memory */
426
427 return flag;
428 } |]
429 if res == 0
430 then do
431 x <- V.freeze btAsMut
432 y <- V.freeze btSQPMut
433 return $ Right (x, y)
434 else do
435 return $ Left res
436