From 755175a557d07c6f73683f358ddd8f8ee07f26a9 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Tue, 20 Mar 2018 12:03:09 +0000 Subject: Handle arbitrary systems --- packages/sundials/hmatrix-sundials.cabal | 2 +- packages/sundials/src/Main.hs | 47 +++++++++++++++++++++++--------- packages/sundials/src/helpers.c | 22 --------------- packages/sundials/src/helpers.h | 3 -- 4 files changed, 35 insertions(+), 39 deletions(-) diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal index 43a83d0..d928ab1 100644 --- a/packages/sundials/hmatrix-sundials.cabal +++ b/packages/sundials/hmatrix-sundials.cabal @@ -1,7 +1,7 @@ -- Initial sundials.cabal generated by cabal init. For further -- documentation, see http://haskell.org/cabal/users-guide/ -name: sundials +name: hmatrix-sundials version: 0.1.0.0 -- synopsis: -- description: diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index b3ebcb3..fc48710 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs @@ -78,11 +78,27 @@ stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] u = v V.! 0 lamda = -100.0 +brusselator :: Double -> V.Vector Double -> V.Vector Double +brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 + , w * u - v * u^2 + , (b - w) / eps - w * u + ] + where + a = 1.0 + b = 3.5 + eps = 5.0e-6 + u = x V.! 0 + v = x V.! 1 + w = x V.! 2 + solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> - V.Vector Double -> - CInt + V.Vector CDouble -> + Either CInt (V.Vector CDouble) solveOdeC fun f0 = unsafePerformIO $ do let dim = V.length f0 + nEq :: CLong + nEq = fromIntegral dim + fMut <- V.thaw f0 -- We need the types that sundials expects. These are tied together -- in 'Types'. The Haskell type is currently empty! let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt @@ -110,21 +126,22 @@ solveOdeC fun f0 = unsafePerformIO $ do realtype T0 = RCONST(0.0); /* initial time */ realtype Tf = RCONST(10.0); /* final time */ realtype dTout = RCONST(1.0); /* time between outputs */ - sunindextype NEQ = 1; /* number of dependent vars. */ + sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ realtype reltol = 1.0e-6; /* tolerances */ realtype abstol = 1.0e-10; - realtype lamda = -100.0; /* stiffness parameter */ /* Initial diagnostics output */ printf("\nAnalytical ODE test problem:\n"); - printf(" lamda = %"GSYM"\n", lamda); printf(" reltol = %.1"ESYM"\n", reltol); printf(" abstol = %.1"ESYM"\n\n",abstol); /* Initialize data structures */ y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; - N_VConst(0.0, y); /* Specify initial condition */ + int i; + for (i = 0; i < NEQ; i++) { + NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; + }; /* Specify initial condition */ arkode_mem = ARKodeCreate(); /* Create the solver memory */ if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; @@ -139,8 +156,6 @@ solveOdeC fun f0 = unsafePerformIO $ do if (check_flag(&flag, "ARKodeInit", 1)) return 1; /* Set routines */ - flag = ARKodeSetUserData(arkode_mem, (void *) &lamda); /* Pass lamda to user functions */ - if (check_flag(&flag, "ARKodeSetUserData", 1)) return 1; flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */ if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; @@ -182,6 +197,10 @@ solveOdeC fun f0 = unsafePerformIO $ do printf(" ---------------------\n"); fclose(UFID); + for (i = 0; i < NEQ; i++) { + ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i); + }; + /* Get/print some final statistics on how the solve progressed */ flag = ARKodeGetNumSteps(arkode_mem, &nst); check_flag(&flag, "ARKodeGetNumSteps", 1); @@ -212,9 +231,6 @@ solveOdeC fun f0 = unsafePerformIO $ do printf(" Total number of linear solver convergence failures = %li\n", ncfn); printf(" Total number of error test failures = %li\n\n", netf); - /* check the solution error */ - flag = check_ans(y, t, reltol, abstol); - /* Clean up and return */ N_VDestroy(y); /* Free y vector */ ARKodeFree(&arkode_mem); /* Free integrator memory */ @@ -223,9 +239,14 @@ solveOdeC fun f0 = unsafePerformIO $ do return flag; } |] - return res + if res ==0 + then do + v <- V.freeze fMut + return $ Right v + else do + return $ Left res main :: IO () main = do - let res = solveOdeC (coerce stiffish) (V.fromList [1.0]) + let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) putStrLn $ show res diff --git a/packages/sundials/src/helpers.c b/packages/sundials/src/helpers.c index 420d3be..f0ca592 100644 --- a/packages/sundials/src/helpers.c +++ b/packages/sundials/src/helpers.c @@ -42,25 +42,3 @@ int check_flag(void *flagvalue, const char *funcname, int opt) return 0; } - -/* check the computed solution */ -int check_ans(N_Vector y, realtype t, realtype rtol, realtype atol) -{ - int passfail=0; /* answer pass (0) or fail (1) flag */ - realtype ans, err, ewt; /* answer data, error, and error weight */ - realtype ONE=RCONST(1.0); - - /* compute solution error */ - ans = atan(t); - ewt = ONE / (rtol * SUNRabs(ans) + atol); - err = ewt * SUNRabs(NV_Ith_S(y,0) - ans); - - /* is the solution within the tolerances? */ - passfail = (err < ONE) ? 0 : 1; - - if (passfail) { - fprintf(stdout, "\nSUNDIALS_WARNING: check_ans error=%g \n\n", err); - } - - return(passfail); -} diff --git a/packages/sundials/src/helpers.h b/packages/sundials/src/helpers.h index 69a3dfe..5c1d9f3 100644 --- a/packages/sundials/src/helpers.h +++ b/packages/sundials/src/helpers.h @@ -20,6 +20,3 @@ typedef struct _N_VectorContent_Serial BazType; NULL pointer */ int check_flag(void *flagvalue, const char *funcname, int opt); - -/* check the computed solution */ -int check_ans(N_Vector y, realtype t, realtype rtol, realtype atol); -- cgit v1.2.3