summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs')
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs55
1 files changed, 29 insertions, 26 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index 2577b8e..b6a59e2 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -101,7 +101,7 @@ import Numeric.LinearAlgebra.Devel (createVector)
101 101
102import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), 102import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><),
103 subMatrix, rows, cols, toLists, 103 subMatrix, rows, cols, toLists,
104 size) 104 size, subVector)
105 105
106import qualified Types as T 106import qualified Types as T
107import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3, sDIRK_5_3_4) 107import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3, sDIRK_5_3_4)
@@ -482,22 +482,24 @@ solveOdeC method jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do
482 else do 482 else do
483 return $ Left res 483 return $ Left res
484 484
485btGet :: ODEMethod -> Matrix Double 485btGet :: ODEMethod -> (Matrix Double, Vector Double)
486btGet method = case getBT method of 486btGet method = case getBT method of
487 Left c -> error $ show c -- FIXME 487 Left c -> error $ show c -- FIXME
488 Right (v, sqp) -> subMatrix (0, 0) (s, s) $ 488 Right ((v, w), sqp) -> ( subMatrix (0, 0) (s, s) $
489 (B.arkSMax >< B.arkSMax) (V.toList v) 489 (B.arkSMax >< B.arkSMax) (V.toList v)
490 , subVector 0 s w)
490 where 491 where
491 s = fromIntegral $ sqp V.! 0 492 s = fromIntegral $ sqp V.! 0
492 493
493getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int) 494getBT :: ODEMethod -> Either Int ((V.Vector Double, V.Vector Double), V.Vector Int)
494getBT method = case getButcherTable method of 495getBT method = case getButcherTable method of
495 Left c -> Left $ fromIntegral c 496 Left c -> Left $ fromIntegral c
496 Right (v, sqp) -> Right $ (coerce v, V.map fromIntegral sqp) 497 Right ((v, w), sqp) -> Right $ ((coerce v, coerce w), V.map fromIntegral sqp)
497 498
498getButcherTable :: ODEMethod -> Either CInt ((V.Vector CDouble), V.Vector CInt) 499getButcherTable :: ODEMethod
500 -> Either CInt ((V.Vector CDouble, V.Vector CDouble), V.Vector CInt)
499getButcherTable method = unsafePerformIO $ do 501getButcherTable method = unsafePerformIO $ do
500 -- arkode seems to want an ODE in order to set and then get the 502 -- ARKode seems to want an ODE in order to set and then get the
501 -- Butcher tableau so here's one to keep it happy 503 -- Butcher tableau so here's one to keep it happy
502 let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble 504 let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble
503 fun _t ys = V.fromList [ ys V.! 0 ] 505 fun _t ys = V.fromList [ ys V.! 0 ]
@@ -509,41 +511,37 @@ getButcherTable method = unsafePerformIO $ do
509 mN :: CInt 511 mN :: CInt
510 mN = fromIntegral $ getMethod method 512 mN = fromIntegral $ getMethod method
511 513
512 -- FIXME: I believe these gets taken from the ghc heap and so should
513 -- be subject to garbage collection.
514 btSQP :: V.Vector CInt <- createVector 3 514 btSQP :: V.Vector CInt <- createVector 3
515 btSQPMut <- V.thaw btSQP 515 btSQPMut <- V.thaw btSQP
516 btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax) 516 btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax)
517 btAsMut <- V.thaw btAs 517 btAsMut <- V.thaw btAs
518 -- We need the types that sundials expects. These are tied together 518 btCs :: V.Vector CDouble <- createVector B.arkSMax
519 -- in 'Types'. FIXME: The Haskell type is currently empty! 519 btCsMut <- V.thaw btCs
520 let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt 520 let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt
521 funIO x y f _ptr = do 521 funIO x y f _ptr = do
522 -- Convert the pointer we get from C (y) to a vector, and then
523 -- apply the user-supplied function.
524 fImm <- fun x <$> getDataFromContents dim y 522 fImm <- fun x <$> getDataFromContents dim y
525 -- Fill in the provided pointer with the resulting vector.
526 putDataInContents fImm dim f 523 putDataInContents fImm dim f
527 -- I don't understand what this comment means 524 -- FIXME: I don't understand what this comment means
528 -- Unsafe since the function will be called many times. 525 -- Unsafe since the function will be called many times.
529 [CU.exp| int{ 0 } |] 526 [CU.exp| int{ 0 } |]
530 res <- [C.block| int { 527 res <- [C.block| int {
531 /* general problem variables */ 528 /* general problem variables */
532 int flag; /* reusable error-checking flag */ 529
533 N_Vector y = NULL; /* empty vector for storing solution */ 530 int flag; /* reusable error-checking flag */
534 void *arkode_mem = NULL; /* empty ARKode memory structure */ 531 N_Vector y = NULL; /* empty vector for storing solution */
532 void *arkode_mem = NULL; /* empty ARKode memory structure */
533 int i, j; /* reusable loop indices */
535 534
536 /* general problem parameters */ 535 /* general problem parameters */
537 /* initial time */ 536
538 realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); 537 realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */
539 /* number of dependent vars. */ 538 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars */
540 sunindextype NEQ = $(sunindextype nEq);
541 539
542 /* Initialize data structures */ 540 /* Initialize data structures */
543 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ 541
542 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
544 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; 543 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
545 /* Specify initial condition */ 544 /* Specify initial condition */
546 int i, j;
547 for (i = 0; i < NEQ; i++) { 545 for (i = 0; i < NEQ; i++) {
548 NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; 546 NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i];
549 }; 547 };
@@ -577,6 +575,10 @@ getButcherTable method = unsafePerformIO $ do
577 } 575 }
578 } 576 }
579 577
578 for (i = 0; i < s; i++) {
579 ($vec-ptr:(double *btCsMut))[i] = ci[i];
580 }
581
580 /* Clean up and return */ 582 /* Clean up and return */
581 N_VDestroy(y); /* Free y vector */ 583 N_VDestroy(y); /* Free y vector */
582 ARKodeFree(&arkode_mem); /* Free integrator memory */ 584 ARKodeFree(&arkode_mem); /* Free integrator memory */
@@ -587,7 +589,8 @@ getButcherTable method = unsafePerformIO $ do
587 then do 589 then do
588 x <- V.freeze btAsMut 590 x <- V.freeze btAsMut
589 y <- V.freeze btSQPMut 591 y <- V.freeze btSQPMut
590 return $ Right (x, y) 592 z <- V.freeze btCsMut
593 return $ Right ((x, z), y)
591 else do 594 else do
592 return $ Left res 595 return $ Left res
593 596