diff options
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs')
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 55 |
1 files changed, 29 insertions, 26 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 2577b8e..b6a59e2 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -101,7 +101,7 @@ import Numeric.LinearAlgebra.Devel (createVector) | |||
101 | 101 | ||
102 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), | 102 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), |
103 | subMatrix, rows, cols, toLists, | 103 | subMatrix, rows, cols, toLists, |
104 | size) | 104 | size, subVector) |
105 | 105 | ||
106 | import qualified Types as T | 106 | import qualified Types as T |
107 | import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3, sDIRK_5_3_4) | 107 | import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3, sDIRK_5_3_4) |
@@ -482,22 +482,24 @@ solveOdeC method jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do | |||
482 | else do | 482 | else do |
483 | return $ Left res | 483 | return $ Left res |
484 | 484 | ||
485 | btGet :: ODEMethod -> Matrix Double | 485 | btGet :: ODEMethod -> (Matrix Double, Vector Double) |
486 | btGet method = case getBT method of | 486 | btGet method = case getBT method of |
487 | Left c -> error $ show c -- FIXME | 487 | Left c -> error $ show c -- FIXME |
488 | Right (v, sqp) -> subMatrix (0, 0) (s, s) $ | 488 | Right ((v, w), sqp) -> ( subMatrix (0, 0) (s, s) $ |
489 | (B.arkSMax >< B.arkSMax) (V.toList v) | 489 | (B.arkSMax >< B.arkSMax) (V.toList v) |
490 | , subVector 0 s w) | ||
490 | where | 491 | where |
491 | s = fromIntegral $ sqp V.! 0 | 492 | s = fromIntegral $ sqp V.! 0 |
492 | 493 | ||
493 | getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int) | 494 | getBT :: ODEMethod -> Either Int ((V.Vector Double, V.Vector Double), V.Vector Int) |
494 | getBT method = case getButcherTable method of | 495 | getBT method = case getButcherTable method of |
495 | Left c -> Left $ fromIntegral c | 496 | Left c -> Left $ fromIntegral c |
496 | Right (v, sqp) -> Right $ (coerce v, V.map fromIntegral sqp) | 497 | Right ((v, w), sqp) -> Right $ ((coerce v, coerce w), V.map fromIntegral sqp) |
497 | 498 | ||
498 | getButcherTable :: ODEMethod -> Either CInt ((V.Vector CDouble), V.Vector CInt) | 499 | getButcherTable :: ODEMethod |
500 | -> Either CInt ((V.Vector CDouble, V.Vector CDouble), V.Vector CInt) | ||
499 | getButcherTable method = unsafePerformIO $ do | 501 | getButcherTable method = unsafePerformIO $ do |
500 | -- arkode seems to want an ODE in order to set and then get the | 502 | -- ARKode seems to want an ODE in order to set and then get the |
501 | -- Butcher tableau so here's one to keep it happy | 503 | -- Butcher tableau so here's one to keep it happy |
502 | let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble | 504 | let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble |
503 | fun _t ys = V.fromList [ ys V.! 0 ] | 505 | fun _t ys = V.fromList [ ys V.! 0 ] |
@@ -509,41 +511,37 @@ getButcherTable method = unsafePerformIO $ do | |||
509 | mN :: CInt | 511 | mN :: CInt |
510 | mN = fromIntegral $ getMethod method | 512 | mN = fromIntegral $ getMethod method |
511 | 513 | ||
512 | -- FIXME: I believe these gets taken from the ghc heap and so should | ||
513 | -- be subject to garbage collection. | ||
514 | btSQP :: V.Vector CInt <- createVector 3 | 514 | btSQP :: V.Vector CInt <- createVector 3 |
515 | btSQPMut <- V.thaw btSQP | 515 | btSQPMut <- V.thaw btSQP |
516 | btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax) | 516 | btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax) |
517 | btAsMut <- V.thaw btAs | 517 | btAsMut <- V.thaw btAs |
518 | -- We need the types that sundials expects. These are tied together | 518 | btCs :: V.Vector CDouble <- createVector B.arkSMax |
519 | -- in 'Types'. FIXME: The Haskell type is currently empty! | 519 | btCsMut <- V.thaw btCs |
520 | let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | 520 | let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt |
521 | funIO x y f _ptr = do | 521 | funIO x y f _ptr = do |
522 | -- Convert the pointer we get from C (y) to a vector, and then | ||
523 | -- apply the user-supplied function. | ||
524 | fImm <- fun x <$> getDataFromContents dim y | 522 | fImm <- fun x <$> getDataFromContents dim y |
525 | -- Fill in the provided pointer with the resulting vector. | ||
526 | putDataInContents fImm dim f | 523 | putDataInContents fImm dim f |
527 | -- I don't understand what this comment means | 524 | -- FIXME: I don't understand what this comment means |
528 | -- Unsafe since the function will be called many times. | 525 | -- Unsafe since the function will be called many times. |
529 | [CU.exp| int{ 0 } |] | 526 | [CU.exp| int{ 0 } |] |
530 | res <- [C.block| int { | 527 | res <- [C.block| int { |
531 | /* general problem variables */ | 528 | /* general problem variables */ |
532 | int flag; /* reusable error-checking flag */ | 529 | |
533 | N_Vector y = NULL; /* empty vector for storing solution */ | 530 | int flag; /* reusable error-checking flag */ |
534 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | 531 | N_Vector y = NULL; /* empty vector for storing solution */ |
532 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | ||
533 | int i, j; /* reusable loop indices */ | ||
535 | 534 | ||
536 | /* general problem parameters */ | 535 | /* general problem parameters */ |
537 | /* initial time */ | 536 | |
538 | realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); | 537 | realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */ |
539 | /* number of dependent vars. */ | 538 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars */ |
540 | sunindextype NEQ = $(sunindextype nEq); | ||
541 | 539 | ||
542 | /* Initialize data structures */ | 540 | /* Initialize data structures */ |
543 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | 541 | |
542 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
544 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | 543 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; |
545 | /* Specify initial condition */ | 544 | /* Specify initial condition */ |
546 | int i, j; | ||
547 | for (i = 0; i < NEQ; i++) { | 545 | for (i = 0; i < NEQ; i++) { |
548 | NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; | 546 | NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; |
549 | }; | 547 | }; |
@@ -577,6 +575,10 @@ getButcherTable method = unsafePerformIO $ do | |||
577 | } | 575 | } |
578 | } | 576 | } |
579 | 577 | ||
578 | for (i = 0; i < s; i++) { | ||
579 | ($vec-ptr:(double *btCsMut))[i] = ci[i]; | ||
580 | } | ||
581 | |||
580 | /* Clean up and return */ | 582 | /* Clean up and return */ |
581 | N_VDestroy(y); /* Free y vector */ | 583 | N_VDestroy(y); /* Free y vector */ |
582 | ARKodeFree(&arkode_mem); /* Free integrator memory */ | 584 | ARKodeFree(&arkode_mem); /* Free integrator memory */ |
@@ -587,7 +589,8 @@ getButcherTable method = unsafePerformIO $ do | |||
587 | then do | 589 | then do |
588 | x <- V.freeze btAsMut | 590 | x <- V.freeze btAsMut |
589 | y <- V.freeze btSQPMut | 591 | y <- V.freeze btSQPMut |
590 | return $ Right (x, y) | 592 | z <- V.freeze btCsMut |
593 | return $ Right ((x, z), y) | ||
591 | else do | 594 | else do |
592 | return $ Left res | 595 | return $ Left res |
593 | 596 | ||