diff options
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs')
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 142 |
1 files changed, 132 insertions, 10 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index f432951..76ed61b 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -6,8 +6,25 @@ | |||
6 | {-# LANGUAGE OverloadedStrings #-} | 6 | {-# LANGUAGE OverloadedStrings #-} |
7 | {-# LANGUAGE ScopedTypeVariables #-} | 7 | {-# LANGUAGE ScopedTypeVariables #-} |
8 | 8 | ||
9 | -- | | ||
10 | -- Module: Numeric.Sundials.ARKode | ||
11 | -- | ||
12 | -- Blah | ||
13 | -- | ||
14 | -- \[ | ||
15 | -- \begin{array}{c|cccc} | ||
16 | -- c_1 & 0.0 & 0.0 & 0.0 & 0.0 \\ | ||
17 | -- c_2 & 0.4358665215 & 0.4358665215 & 0.0 & 0.0 \\ | ||
18 | -- c_3 & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\ | ||
19 | -- c_4 & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\ | ||
20 | -- \end{array} | ||
21 | -- \] | ||
22 | -- | ||
9 | module Numeric.Sundials.Arkode.ODE ( solveOde | 23 | module Numeric.Sundials.Arkode.ODE ( solveOde |
10 | , odeSolve | 24 | , odeSolve |
25 | , getButcherTable | ||
26 | , getBT | ||
27 | , btGet | ||
11 | ) where | 28 | ) where |
12 | 29 | ||
13 | import qualified Language.C.Inline as C | 30 | import qualified Language.C.Inline as C |
@@ -28,9 +45,10 @@ import System.IO.Unsafe (unsafePerformIO) | |||
28 | 45 | ||
29 | import Numeric.LinearAlgebra.Devel (createVector) | 46 | import Numeric.LinearAlgebra.Devel (createVector) |
30 | 47 | ||
31 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><)) | 48 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix) |
32 | 49 | ||
33 | import qualified Types as T | 50 | import qualified Types as T |
51 | import qualified Bar as B | ||
34 | 52 | ||
35 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | 53 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) |
36 | 54 | ||
@@ -83,16 +101,16 @@ vectorToC vec len ptr = do | |||
83 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | 101 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec |
84 | 102 | ||
85 | data SundialsDiagnostics = SundialsDiagnostics { | 103 | data SundialsDiagnostics = SundialsDiagnostics { |
86 | aRKodeGetNumSteps :: Int | 104 | aRKodeGetNumSteps :: Int |
87 | , aRKodeGetNumStepAttempts :: Int | 105 | , aRKodeGetNumStepAttempts :: Int |
88 | , aRKodeGetNumRhsEvals_fe :: Int | 106 | , aRKodeGetNumRhsEvals_fe :: Int |
89 | , aRKodeGetNumRhsEvals_fi :: Int | 107 | , aRKodeGetNumRhsEvals_fi :: Int |
90 | , aRKodeGetNumLinSolvSetups :: Int | 108 | , aRKodeGetNumLinSolvSetups :: Int |
91 | , aRKodeGetNumErrTestFails :: Int | 109 | , aRKodeGetNumErrTestFails :: Int |
92 | , aRKodeGetNumNonlinSolvIters :: Int | 110 | , aRKodeGetNumNonlinSolvIters :: Int |
93 | , aRKodeGetNumNonlinSolvConvFails :: Int | 111 | , aRKodeGetNumNonlinSolvConvFails :: Int |
94 | , aRKDlsGetNumJacEvals :: Int | 112 | , aRKDlsGetNumJacEvals :: Int |
95 | , aRKDlsGetNumRhsEvals :: Int | 113 | , aRKDlsGetNumRhsEvals :: Int |
96 | } deriving Show | 114 | } deriving Show |
97 | 115 | ||
98 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 116 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
@@ -312,3 +330,107 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
312 | else do | 330 | else do |
313 | return $ Left res | 331 | return $ Left res |
314 | 332 | ||
333 | btGet :: Matrix Double | ||
334 | btGet = case getBT of | ||
335 | Left c -> error $ show c -- FIXME | ||
336 | Right (v, sqp) -> subMatrix (0, 0) (4, 4) $ (B.arkSMax >< B.arkSMax) (V.toList v) | ||
337 | |||
338 | getBT :: Either Int (V.Vector Double, V.Vector Int) | ||
339 | getBT = case getButcherTable of | ||
340 | Left c -> Left $ fromIntegral c | ||
341 | Right (v, sqp) -> Right $ (coerce v, V.map fromIntegral sqp) | ||
342 | |||
343 | getButcherTable :: Either CInt ((V.Vector CDouble), V.Vector CInt) | ||
344 | getButcherTable = unsafePerformIO $ do | ||
345 | -- arkode seems to want an ODE in order to set and then get the | ||
346 | -- Butcher tableau so here's one to keep it happy | ||
347 | let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble | ||
348 | fun t ys = V.fromList [ ys V.! 0 ] | ||
349 | f0 = V.fromList [ 1.0 ] | ||
350 | ts = V.fromList [ 0.0 ] | ||
351 | dim = V.length f0 | ||
352 | nEq :: CLong | ||
353 | nEq = fromIntegral dim | ||
354 | |||
355 | -- FIXME: I believe these gets taken from the ghc heap and so should | ||
356 | -- be subject to garbage collection. | ||
357 | btSQP :: V.Vector CInt <- createVector 3 | ||
358 | btSQPMut <- V.thaw btSQP | ||
359 | btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax) | ||
360 | btAsMut <- V.thaw btAs | ||
361 | -- We need the types that sundials expects. These are tied together | ||
362 | -- in 'Types'. FIXME: The Haskell type is currently empty! | ||
363 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | ||
364 | funIO x y f _ptr = do | ||
365 | -- Convert the pointer we get from C (y) to a vector, and then | ||
366 | -- apply the user-supplied function. | ||
367 | fImm <- fun x <$> getDataFromContents dim y | ||
368 | -- Fill in the provided pointer with the resulting vector. | ||
369 | putDataInContents fImm dim f | ||
370 | -- I don't understand what this comment means | ||
371 | -- Unsafe since the function will be called many times. | ||
372 | [CU.exp| int{ 0 } |] | ||
373 | res <- [C.block| int { | ||
374 | /* general problem variables */ | ||
375 | int flag; /* reusable error-checking flag */ | ||
376 | N_Vector y = NULL; /* empty vector for storing solution */ | ||
377 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | ||
378 | |||
379 | /* general problem parameters */ | ||
380 | /* initial time */ | ||
381 | realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); | ||
382 | /* number of dependent vars. */ | ||
383 | sunindextype NEQ = $(sunindextype nEq); | ||
384 | |||
385 | /* Initialize data structures */ | ||
386 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
387 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | ||
388 | /* Specify initial condition */ | ||
389 | int i, j; | ||
390 | for (i = 0; i < NEQ; i++) { | ||
391 | NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; | ||
392 | }; | ||
393 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | ||
394 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | ||
395 | |||
396 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); | ||
397 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
398 | |||
399 | flag = ARKodeSetIRKTableNum(arkode_mem, KVAERNO_4_2_3); | ||
400 | if (check_flag(&flag, "ARKode", 1)) return 1; | ||
401 | |||
402 | int s, q, p; | ||
403 | realtype *ai = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); | ||
404 | realtype *ae = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); | ||
405 | realtype *ci = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
406 | realtype *ce = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
407 | realtype *bi = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
408 | realtype *be = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
409 | realtype *b2i = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
410 | realtype *b2e = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
411 | flag = ARKodeGetCurrentButcherTables(arkode_mem, &s, &q, &p, ai, ae, ci, ce, bi, be, b2i, b2e); | ||
412 | if (check_flag(&flag, "ARKode", 1)) return 1; | ||
413 | $vec-ptr:(int *btSQPMut)[0] = s; | ||
414 | $vec-ptr:(int *btSQPMut)[1] = q; | ||
415 | $vec-ptr:(int *btSQPMut)[2] = p; | ||
416 | for (i = 0; i < s; i++) { | ||
417 | for (j = 0; j < s; j++) { | ||
418 | /* FIXME: double should be realtype */ | ||
419 | ($vec-ptr:(double *btAsMut))[i * ARK_S_MAX + j] = ai[i * ARK_S_MAX + j]; | ||
420 | } | ||
421 | } | ||
422 | |||
423 | /* Clean up and return */ | ||
424 | N_VDestroy(y); /* Free y vector */ | ||
425 | ARKodeFree(&arkode_mem); /* Free integrator memory */ | ||
426 | |||
427 | return flag; | ||
428 | } |] | ||
429 | if res == 0 | ||
430 | then do | ||
431 | x <- V.freeze btAsMut | ||
432 | y <- V.freeze btSQPMut | ||
433 | return $ Right (x, y) | ||
434 | else do | ||
435 | return $ Left res | ||
436 | |||