summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2009-06-04 09:01:56 +0000
committerAlberto Ruiz <aruiz@um.es>2009-06-04 09:01:56 +0000
commit6e0dd472ef8c570ec1924ac641e5872db30ac142 (patch)
tree64963c6af75cdbc02336de82b51136964f36dc73
parentf49ac4def26b38d3d084375007715156be347412 (diff)
added some root finding algorithms
-rw-r--r--examples/root.hs30
-rw-r--r--hmatrix.cabal2
-rw-r--r--lib/Numeric/GSL.hs2
-rw-r--r--lib/Numeric/GSL/Root.hs117
-rw-r--r--lib/Numeric/GSL/gsl-aux.c101
-rw-r--r--lib/Numeric/GSL/gsl-aux.h6
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs10
7 files changed, 256 insertions, 12 deletions
diff --git a/examples/root.hs b/examples/root.hs
new file mode 100644
index 0000000..9a674fd
--- /dev/null
+++ b/examples/root.hs
@@ -0,0 +1,30 @@
1-- root finding examples
2import Numeric.GSL
3import Numeric.LinearAlgebra
4import Graphics.Plot
5import Text.Printf(printf)
6
7rosenbrock a b [x,y] = [ a*(1-x), b*(y-x^2) ]
8
9disp = putStrLn . format " " (printf "%.3f")
10
11-- Numerical estimation of the gradient
12gradient f v = [partialDerivative k f v | k <- [0 .. length v -1]]
13
14partialDerivative n f v = fst (derivCentral 0.01 g (v!!n)) where
15 g x = f (concat [a,x:b])
16 (a,_:b) = splitAt n v
17
18test method = do
19 print method
20 let (s,p) = root method 1E-7 30 (rosenbrock 1 10) [-10,-5]
21 print s -- solution
22 disp p -- evolution of the algorithm
23-- let [x,y] = tail (toColumns p)
24-- mplot [x,y] -- path from the starting point to the solution
25
26main = do
27 test Hybrids
28 test Hybrid
29 test DNewton
30 test Broyden
diff --git a/hmatrix.cabal b/hmatrix.cabal
index 87436ae..70de11e 100644
--- a/hmatrix.cabal
+++ b/hmatrix.cabal
@@ -24,6 +24,7 @@ extra-source-files: examples/tests.hs
24 examples/deriv.hs 24 examples/deriv.hs
25 examples/integrate.hs 25 examples/integrate.hs
26 examples/minimize.hs 26 examples/minimize.hs
27 examples/root.hs
27 examples/pca1.hs 28 examples/pca1.hs
28 examples/pca2.hs 29 examples/pca2.hs
29 examples/pinv.hs 30 examples/pinv.hs
@@ -80,6 +81,7 @@ library
80 Numeric.GSL.Fourier, 81 Numeric.GSL.Fourier,
81 Numeric.GSL.Polynomials, 82 Numeric.GSL.Polynomials,
82 Numeric.GSL.Minimization, 83 Numeric.GSL.Minimization,
84 Numeric.GSL.Root,
83 Numeric.GSL.Vector, 85 Numeric.GSL.Vector,
84 Numeric.GSL.Special, 86 Numeric.GSL.Special,
85 Numeric.GSL.Special.Gamma, 87 Numeric.GSL.Special.Gamma,
diff --git a/lib/Numeric/GSL.hs b/lib/Numeric/GSL.hs
index 32962ef..2e90fff 100644
--- a/lib/Numeric/GSL.hs
+++ b/lib/Numeric/GSL.hs
@@ -18,6 +18,7 @@ module Numeric.GSL (
18, module Numeric.GSL.Fourier 18, module Numeric.GSL.Fourier
19, module Numeric.GSL.Polynomials 19, module Numeric.GSL.Polynomials
20, module Numeric.GSL.Minimization 20, module Numeric.GSL.Minimization
21, module Numeric.GSL.Root
21, module Numeric.GSL.Special 22, module Numeric.GSL.Special
22, module Complex 23, module Complex
23, setErrorHandlerOff 24, setErrorHandlerOff
@@ -29,6 +30,7 @@ import Numeric.GSL.Special
29import Numeric.GSL.Fourier 30import Numeric.GSL.Fourier
30import Numeric.GSL.Polynomials 31import Numeric.GSL.Polynomials
31import Numeric.GSL.Minimization 32import Numeric.GSL.Minimization
33import Numeric.GSL.Root
32import Complex 34import Complex
33import Numeric.GSL.Special 35import Numeric.GSL.Special
34 36
diff --git a/lib/Numeric/GSL/Root.hs b/lib/Numeric/GSL/Root.hs
new file mode 100644
index 0000000..ad1b72c
--- /dev/null
+++ b/lib/Numeric/GSL/Root.hs
@@ -0,0 +1,117 @@
1{- |
2Module : Numeric.GSL.Root
3Copyright : (c) Alberto Ruiz 2009
4License : GPL
5
6Maintainer : Alberto Ruiz (aruiz at um dot es)
7Stability : provisional
8Portability : uses ffi
9
10Multidimensional root finding.
11
12<http://www.gnu.org/software/gsl/manual/html_node/Multidimensional-Root_002dFinding.html>
13
14The example in the GSL manual:
15
16@import Numeric.GSL
17import Numeric.LinearAlgebra(format)
18import Text.Printf(printf)
19
20rosenbrock a b [x,y] = [ a*(1-x), b*(y-x^2) ]
21
22disp = putStrLn . format \" \" (printf \"%.3f\")
23
24main = do
25 let (sol,path) = root Hybrids 1E-7 30 (rosenbrock 1 10) [-10,-5]
26 print sol
27 disp path
28
29\> main
30[1.0,1.0]
31 0.000 -10.000 -5.000 11.000 -1050.000
32 1.000 -3.976 24.827 4.976 90.203
33 2.000 -3.976 24.827 4.976 90.203
34 3.000 -3.976 24.827 4.976 90.203
35 4.000 -1.274 -5.680 2.274 -73.018
36 5.000 -1.274 -5.680 2.274 -73.018
37 6.000 0.249 0.298 0.751 2.359
38 7.000 0.249 0.298 0.751 2.359
39 8.000 1.000 0.878 -0.000 -1.218
40 9.000 1.000 0.989 -0.000 -0.108
4110.000 1.000 1.000 0.000 0.000
42@
43
44-}
45-----------------------------------------------------------------------------
46
47module Numeric.GSL.Root (
48 root, RootMethod(..)
49) where
50
51import Data.Packed.Internal
52import Data.Packed.Matrix
53import Foreign
54import Foreign.C.Types(CInt)
55
56-------------------------------------------------------------------------
57
58data RootMethod = Hybrids
59 | Hybrid
60 | DNewton
61 | Broyden
62 deriving (Enum,Eq,Show)
63
64-- | Nonlinear multidimensional root finding using algorithms that do not require
65-- any derivative information to be supplied by the user.
66-- Any derivatives needed are approximated by finite differences.
67root :: RootMethod
68 -> Double -- ^ maximum residual
69 -> Int -- ^ maximum number of iterations allowed
70 -> ([Double] -> [Double]) -- ^ function to minimize
71 -> [Double] -- ^ starting point
72 -> ([Double], Matrix Double) -- ^ solution vector and optimization path
73
74root method epsabs maxit fun xinit = rootGen (fi (fromEnum method)) fun xinit epsabs maxit
75
76rootGen m f xi epsabs maxit = unsafePerformIO $ do
77 let xiv = fromList xi
78 n = dim xiv
79 fp <- mkVecVecfun (aux_vTov (fromList.f.toList))
80 rawpath <- withVector xiv $ \xiv' ->
81 createMIO maxit (2*n+1)
82 (c_root m fp epsabs (fi maxit) // xiv')
83 "root"
84 let it = round (rawpath @@> (maxit-1,0))
85 path = takeRows it rawpath
86 [sol] = toLists $ dropRows (it-1) path
87 freeHaskellFunPtr fp
88 return (take n $ drop 1 sol, path)
89
90
91foreign import ccall "root"
92 c_root:: CInt -> FunPtr (CInt -> Ptr Double -> Ptr Double -> IO ()) -> Double -> CInt -> TVM
93
94---------------------------------------------------------------------
95
96foreign import ccall "wrapper"
97 mkVecVecfun :: (CInt -> Ptr Double -> Ptr Double -> IO ())
98 -> IO (FunPtr (CInt -> Ptr Double -> Ptr Double->IO()))
99
100aux_vTov :: (Vector Double -> Vector Double) -> (CInt -> Ptr Double -> Ptr Double -> IO())
101aux_vTov f n p r = g where
102 V {fptr = pr} = f x
103 x = createV (fromIntegral n) copy "aux_vTov"
104 copy n' q = do
105 copyArray q p (fromIntegral n')
106 return 0
107 g = withForeignPtr pr $ \p' -> copyArray r p' (fromIntegral n)
108
109createV n fun msg = unsafePerformIO $ do
110 r <- createVector n
111 app1 fun vec r msg
112 return r
113
114createMIO r c fun msg = do
115 res <- createMatrix RowMajor r c
116 app1 fun mat res msg
117 return res
diff --git a/lib/Numeric/GSL/gsl-aux.c b/lib/Numeric/GSL/gsl-aux.c
index 3802574..80c23fc 100644
--- a/lib/Numeric/GSL/gsl-aux.c
+++ b/lib/Numeric/GSL/gsl-aux.c
@@ -7,6 +7,7 @@
7#include <gsl/gsl_deriv.h> 7#include <gsl/gsl_deriv.h>
8#include <gsl/gsl_poly.h> 8#include <gsl/gsl_poly.h>
9#include <gsl/gsl_multimin.h> 9#include <gsl/gsl_multimin.h>
10#include <gsl/gsl_multiroots.h>
10#include <gsl/gsl_complex.h> 11#include <gsl/gsl_complex.h>
11#include <gsl/gsl_complex_math.h> 12#include <gsl/gsl_complex_math.h>
12#include <string.h> 13#include <string.h>
@@ -288,6 +289,22 @@ int fft(int code, KCVEC(X), CVEC(R)) {
288} 289}
289 290
290 291
292int deriv(int code, double f(double, void*), double x, double h, double * result, double * abserr)
293{
294 gsl_function F;
295 F.function = f;
296 F.params = 0;
297
298 if(code==0) return gsl_deriv_central (&F, x, h, result, abserr);
299
300 if(code==1) return gsl_deriv_forward (&F, x, h, result, abserr);
301
302 if(code==2) return gsl_deriv_backward (&F, x, h, result, abserr);
303
304 return 0;
305}
306
307
291int integrate_qng(double f(double, void*), double a, double b, double prec, 308int integrate_qng(double f(double, void*), double a, double b, double prec,
292 double *result, double*error) { 309 double *result, double*error) {
293 DEBUGMSG("integrate_qng"); 310 DEBUGMSG("integrate_qng");
@@ -440,7 +457,7 @@ void fdf_aux_min(const gsl_vector * x, void * pars, double * f, gsl_vector * g)
440 df_aux_min(x,pars,g); 457 df_aux_min(x,pars,g);
441} 458}
442 459
443// conjugate gradient 460
444int minimizeWithDeriv(int method, double f(int, double*), void df(int, double*, double*), 461int minimizeWithDeriv(int method, double f(int, double*), void df(int, double*, double*),
445 double initstep, double minimpar, double tolgrad, int maxit, 462 double initstep, double minimpar, double tolgrad, int maxit,
446 KRVEC(xi), RMAT(sol)) { 463 KRVEC(xi), RMAT(sol)) {
@@ -492,18 +509,82 @@ int minimizeWithDeriv(int method, double f(int, double*), void df(int, double*,
492 OK 509 OK
493} 510}
494 511
512//---------------------------------------------------------------
495 513
496int deriv(int code, double f(double, void*), double x, double h, double * result, double * abserr) 514typedef void TrawfunV(int, double*, double*);
497{
498 gsl_function F;
499 F.function = f;
500 F.params = 0;
501 515
502 if(code==0) return gsl_deriv_central (&F, x, h, result, abserr); 516int only_f_aux_root(const gsl_vector*x, void *pars, gsl_vector*y) {
517 TrawfunV * f = (TrawfunV*) pars;
518 double* p = (double*)calloc(x->size,sizeof(double));
519 double* q = (double*)calloc(x->size,sizeof(double));
520 int k;
521 for(k=0;k<x->size;k++) {
522 p[k] = gsl_vector_get(x,k);
523 }
524 f(x->size,p,q);
525 for(k=0;k<y->size;k++) {
526 gsl_vector_set(y,k,q[k]);
527 }
528 free(p);
529 free(q);
530 return 0; //hmmm
531}
503 532
504 if(code==1) return gsl_deriv_forward (&F, x, h, result, abserr); 533int root(int method, void f(int, double*, int, double*),
534 double epsabs, int maxit,
535 KRVEC(xi), RMAT(sol)) {
536 REQUIRES(solr == maxit && solc == 1+2*xin,BAD_SIZE);
537 DEBUGMSG("root_only_f");
538 gsl_multiroot_function my_func;
539 // extract function from pars
540 my_func.f = only_f_aux_root;
541 my_func.n = xin;
542 my_func.params = f;
543 size_t iter = 0;
544 int status;
545 const gsl_multiroot_fsolver_type *T;
546 gsl_multiroot_fsolver *s;
547 // Starting point
548 KDVVIEW(xi);
549 switch(method) {
550 case 0 : {T = gsl_multiroot_fsolver_hybrids;; break; }
551 case 1 : {T = gsl_multiroot_fsolver_hybrid; break; }
552 case 2 : {T = gsl_multiroot_fsolver_dnewton; break; }
553 case 3 : {T = gsl_multiroot_fsolver_broyden; break; }
554 default: ERROR(BAD_CODE);
555 }
556 s = gsl_multiroot_fsolver_alloc (T, my_func.n);
557 gsl_multiroot_fsolver_set (s, &my_func, V(xi));
505 558
506 if(code==2) return gsl_deriv_backward (&F, x, h, result, abserr); 559 do {
560 status = gsl_multiroot_fsolver_iterate (s);
507 561
508 return 0; 562 solp[iter*solc+0] = iter;
563
564 int k;
565 for(k=0;k<xin;k++) {
566 solp[iter*solc+k+1] = gsl_vector_get(s->x,k);
567 }
568 for(k=xin;k<2*xin;k++) {
569 solp[iter*solc+k+1] = gsl_vector_get(s->f,k-xin);
570 }
571
572 iter++;
573 if (status) /* check if solver is stuck */
574 break;
575
576 status =
577 gsl_multiroot_test_residual (s->f, epsabs);
578 }
579 while (status == GSL_CONTINUE && iter < maxit);
580
581 int i,j;
582 for (i=iter; i<solr; i++) {
583 solp[i*solc+0] = iter;
584 for(j=1;j<solc;j++) {
585 solp[i*solc+j]=0.;
586 }
587 }
588 gsl_multiroot_fsolver_free(s);
589 OK
509} 590}
diff --git a/lib/Numeric/GSL/gsl-aux.h b/lib/Numeric/GSL/gsl-aux.h
index e88322c..c9fd546 100644
--- a/lib/Numeric/GSL/gsl-aux.h
+++ b/lib/Numeric/GSL/gsl-aux.h
@@ -28,6 +28,8 @@ int zipC(int code, KCVEC(a), KCVEC(b), CVEC(r));
28 28
29int fft(int code, KCVEC(a), CVEC(b)); 29int fft(int code, KCVEC(a), CVEC(b));
30 30
31int deriv(int code, double f(double, void*), double x, double h, double * result, double * abserr);
32
31int integrate_qng(double f(double, void*), double a, double b, double prec, 33int integrate_qng(double f(double, void*), double a, double b, double prec,
32 double *result, double*error); 34 double *result, double*error);
33 35
@@ -43,4 +45,6 @@ int minimizeWithDeriv(int method, double f(int, double*), void df(int, double*,
43 double initstep, double minimpar, double tolgrad, int maxit, 45 double initstep, double minimpar, double tolgrad, int maxit,
44 KRVEC(xi), RMAT(sol)); 46 KRVEC(xi), RMAT(sol));
45 47
46int deriv(int code, double f(double, void*), double x, double h, double * result, double * abserr); 48int root(int method, void f(int, double*, int, double*),
49 double epsabs, int maxit,
50 KRVEC(xi), RMAT(sol));
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index 278df78..4f73e3a 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -36,6 +36,8 @@ a ^ b = a Prelude.^ (b :: Int)
36 36
37utest str b = TestCase $ assertBool str b 37utest str b = TestCase $ assertBool str b
38 38
39a ~~ b = fromList a |~| fromList b
40
39feye n = flipud (ident n) :: Matrix Double 41feye n = flipud (ident n) :: Matrix Double
40 42
41detTest1 = det m == 26 43detTest1 = det m == 26
@@ -112,12 +114,17 @@ minimizationTest = TestList [ utest "minimization conj grad" (minim1 f df [5,7]
112 ] 114 ]
113 where f [x,y] = 10*(x-1)^2 + 20*(y-2)^2 + 30 115 where f [x,y] = 10*(x-1)^2 + 20*(y-2)^2 + 30
114 df [x,y] = [20*(x-1), 40*(y-2)] 116 df [x,y] = [20*(x-1), 40*(y-2)]
115 a ~~ b = fromList a |~| fromList b
116 minim1 g dg ini = fst $ minimizeConjugateGradient 1E-2 1E-4 1E-3 30 g dg ini 117 minim1 g dg ini = fst $ minimizeConjugateGradient 1E-2 1E-4 1E-3 30 g dg ini
117 minim2 g dg ini = fst $ minimizeVectorBFGS2 1E-2 1E-2 1E-3 30 g dg ini 118 minim2 g dg ini = fst $ minimizeVectorBFGS2 1E-2 1E-2 1E-3 30 g dg ini
118 119
119--------------------------------------------------------------------- 120---------------------------------------------------------------------
120 121
122rootFindingTest = utest "root Hybrids" (sol ~~ [1,1])
123 where sol = fst $ root Hybrids 1E-7 30 (rosenbrock 1 10) [-10,-5]
124 rosenbrock a b [x,y] = [ a*(1-x), b*(y-x^2) ]
125
126---------------------------------------------------------------------
127
121rot :: Double -> Matrix Double 128rot :: Double -> Matrix Double
122rot a = (3><3) [ c,0,s 129rot a = (3><3) [ c,0,s
123 , 0,1,0 130 , 0,1,0
@@ -217,6 +224,7 @@ runTests n = do
217 , utest "integrate" (abs (volSphere 2.5 - 4/3*pi*2.5^3) < 1E-8) 224 , utest "integrate" (abs (volSphere 2.5 - 4/3*pi*2.5^3) < 1E-8)
218 , utest "polySolve" (polySolveProp [1,2,3,4]) 225 , utest "polySolve" (polySolveProp [1,2,3,4])
219 , minimizationTest 226 , minimizationTest
227 , rootFindingTest
220 ] 228 ]
221 return () 229 return ()
222 230