From 755175a557d07c6f73683f358ddd8f8ee07f26a9 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Tue, 20 Mar 2018 12:03:09 +0000 Subject: Handle arbitrary systems --- packages/sundials/src/Main.hs | 47 +++++++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 13 deletions(-) (limited to 'packages/sundials/src/Main.hs') diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index b3ebcb3..fc48710 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs @@ -78,11 +78,27 @@ stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] u = v V.! 0 lamda = -100.0 +brusselator :: Double -> V.Vector Double -> V.Vector Double +brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 + , w * u - v * u^2 + , (b - w) / eps - w * u + ] + where + a = 1.0 + b = 3.5 + eps = 5.0e-6 + u = x V.! 0 + v = x V.! 1 + w = x V.! 2 + solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> - V.Vector Double -> - CInt + V.Vector CDouble -> + Either CInt (V.Vector CDouble) solveOdeC fun f0 = unsafePerformIO $ do let dim = V.length f0 + nEq :: CLong + nEq = fromIntegral dim + fMut <- V.thaw f0 -- We need the types that sundials expects. These are tied together -- in 'Types'. The Haskell type is currently empty! let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt @@ -110,21 +126,22 @@ solveOdeC fun f0 = unsafePerformIO $ do realtype T0 = RCONST(0.0); /* initial time */ realtype Tf = RCONST(10.0); /* final time */ realtype dTout = RCONST(1.0); /* time between outputs */ - sunindextype NEQ = 1; /* number of dependent vars. */ + sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ realtype reltol = 1.0e-6; /* tolerances */ realtype abstol = 1.0e-10; - realtype lamda = -100.0; /* stiffness parameter */ /* Initial diagnostics output */ printf("\nAnalytical ODE test problem:\n"); - printf(" lamda = %"GSYM"\n", lamda); printf(" reltol = %.1"ESYM"\n", reltol); printf(" abstol = %.1"ESYM"\n\n",abstol); /* Initialize data structures */ y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; - N_VConst(0.0, y); /* Specify initial condition */ + int i; + for (i = 0; i < NEQ; i++) { + NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; + }; /* Specify initial condition */ arkode_mem = ARKodeCreate(); /* Create the solver memory */ if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; @@ -139,8 +156,6 @@ solveOdeC fun f0 = unsafePerformIO $ do if (check_flag(&flag, "ARKodeInit", 1)) return 1; /* Set routines */ - flag = ARKodeSetUserData(arkode_mem, (void *) &lamda); /* Pass lamda to user functions */ - if (check_flag(&flag, "ARKodeSetUserData", 1)) return 1; flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */ if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; @@ -182,6 +197,10 @@ solveOdeC fun f0 = unsafePerformIO $ do printf(" ---------------------\n"); fclose(UFID); + for (i = 0; i < NEQ; i++) { + ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i); + }; + /* Get/print some final statistics on how the solve progressed */ flag = ARKodeGetNumSteps(arkode_mem, &nst); check_flag(&flag, "ARKodeGetNumSteps", 1); @@ -212,9 +231,6 @@ solveOdeC fun f0 = unsafePerformIO $ do printf(" Total number of linear solver convergence failures = %li\n", ncfn); printf(" Total number of error test failures = %li\n\n", netf); - /* check the solution error */ - flag = check_ans(y, t, reltol, abstol); - /* Clean up and return */ N_VDestroy(y); /* Free y vector */ ARKodeFree(&arkode_mem); /* Free integrator memory */ @@ -223,9 +239,14 @@ solveOdeC fun f0 = unsafePerformIO $ do return flag; } |] - return res + if res ==0 + then do + v <- V.freeze fMut + return $ Right v + else do + return $ Left res main :: IO () main = do - let res = solveOdeC (coerce stiffish) (V.fromList [1.0]) + let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) putStrLn $ show res -- cgit v1.2.3