summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2009-06-07 13:01:03 +0000
committerAlberto Ruiz <aruiz@um.es>2009-06-07 13:01:03 +0000
commit7697c6dc27fd0d9601728af576e8d7b9d1c800ee (patch)
treec4b2c0f52f3884fbbd93bd6c0739e0fc73b921a2
parent49a3d719221cd9484a64688ffcdbeb13cb8e55a0 (diff)
root finding with jacobian
-rw-r--r--examples/root.hs14
-rw-r--r--lib/Numeric/GSL/Root.hs83
-rw-r--r--lib/Numeric/GSL/gsl-aux.c121
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs9
4 files changed, 209 insertions, 18 deletions
diff --git a/examples/root.hs b/examples/root.hs
index 69db243..2a24f0f 100644
--- a/examples/root.hs
+++ b/examples/root.hs
@@ -11,6 +11,15 @@ test method = do
11 print s -- solution 11 print s -- solution
12 disp p -- evolution of the algorithm 12 disp p -- evolution of the algorithm
13 13
14jacobian a b [x,y] = [ [-a , 0]
15 , [-2*b*x, b] ]
16
17testJ method = do
18 print method
19 let (s,p) = rootJ method 1E-7 30 (rosenbrock 1 10) (jacobian 1 10) [-10,-5]
20 print s
21 disp p
22
14disp = putStrLn . format " " (printf "%.3f") 23disp = putStrLn . format " " (printf "%.3f")
15 24
16main = do 25main = do
@@ -18,3 +27,8 @@ main = do
18 test Hybrid 27 test Hybrid
19 test DNewton 28 test DNewton
20 test Broyden 29 test Broyden
30
31 testJ HybridsJ
32 testJ HybridJ
33 testJ Newton
34 testJ GNewton
diff --git a/lib/Numeric/GSL/Root.hs b/lib/Numeric/GSL/Root.hs
index d674fad..6ce2c4c 100644
--- a/lib/Numeric/GSL/Root.hs
+++ b/lib/Numeric/GSL/Root.hs
@@ -45,7 +45,8 @@ main = do
45----------------------------------------------------------------------------- 45-----------------------------------------------------------------------------
46 46
47module Numeric.GSL.Root ( 47module Numeric.GSL.Root (
48 root, RootMethod(..) 48 root, RootMethod(..),
49 rootJ, RootMethodJ(..),
49) where 50) where
50 51
51import Data.Packed.Internal 52import Data.Packed.Internal
@@ -76,7 +77,7 @@ root method epsabs maxit fun xinit = rootGen (fi (fromEnum method)) fun xinit ep
76rootGen m f xi epsabs maxit = unsafePerformIO $ do 77rootGen m f xi epsabs maxit = unsafePerformIO $ do
77 let xiv = fromList xi 78 let xiv = fromList xi
78 n = dim xiv 79 n = dim xiv
79 fp <- mkVecVecfun (aux_vTov (fromList . checkdim n f . toList)) 80 fp <- mkVecVecfun (aux_vTov (checkdim1 n . fromList . f . toList))
80 rawpath <- withVector xiv $ \xiv' -> 81 rawpath <- withVector xiv $ \xiv' ->
81 createMIO maxit (2*n+1) 82 createMIO maxit (2*n+1)
82 (c_root m fp epsabs (fi maxit) // xiv') 83 (c_root m fp epsabs (fi maxit) // xiv')
@@ -89,22 +90,74 @@ rootGen m f xi epsabs maxit = unsafePerformIO $ do
89 90
90 91
91foreign import ccall "root" 92foreign import ccall "root"
92 c_root:: CInt -> FunPtr (CInt -> Ptr Double -> Ptr Double -> IO ()) -> Double -> CInt -> TVM 93 c_root:: CInt -> FunPtr TVV -> Double -> CInt -> TVM
94
95-------------------------------------------------------------------------
96
97data RootMethodJ = HybridsJ
98 | HybridJ
99 | Newton
100 | GNewton
101 deriving (Enum,Eq,Show)
102
103-- | Nonlinear multidimensional root finding using both the function and its derivatives.
104rootJ :: RootMethodJ
105 -> Double -- ^ maximum residual
106 -> Int -- ^ maximum number of iterations allowed
107 -> ([Double] -> [Double]) -- ^ function to minimize
108 -> ([Double] -> [[Double]]) -- ^ Jacobian
109 -> [Double] -- ^ starting point
110 -> ([Double], Matrix Double) -- ^ solution vector and optimization path
111
112rootJ method epsabs maxit fun jac xinit = rootJGen (fi (fromEnum method)) fun jac xinit epsabs maxit
113
114rootJGen m f jac xi epsabs maxit = unsafePerformIO $ do
115 let xiv = fromList xi
116 n = dim xiv
117 fp <- mkVecVecfun (aux_vTov (checkdim1 n . fromList . f . toList))
118 jp <- mkVecMatfun (aux_vTom (checkdim2 n . fromLists . jac . toList))
119 rawpath <- withVector xiv $ \xiv' ->
120 createMIO maxit (2*n+1)
121 (c_rootj m fp jp epsabs (fi maxit) // xiv')
122 "root"
123 let it = round (rawpath @@> (maxit-1,0))
124 path = takeRows it rawpath
125 [sol] = toLists $ dropRows (it-1) path
126 freeHaskellFunPtr fp
127 return (take n $ drop 1 sol, path)
128
129
130foreign import ccall "rootj"
131 c_rootj:: CInt -> FunPtr TVV -> FunPtr TVM -> Double -> CInt -> TVM
132
93 133
94--------------------------------------------------------------------- 134---------------------------------------------------------------------
95 135
96foreign import ccall "wrapper" 136foreign import ccall "wrapper"
97 mkVecVecfun :: (CInt -> Ptr Double -> Ptr Double -> IO ()) 137 mkVecVecfun :: TVV -> IO (FunPtr TVV)
98 -> IO (FunPtr (CInt -> Ptr Double -> Ptr Double->IO()))
99 138
100aux_vTov :: (Vector Double -> Vector Double) -> (CInt -> Ptr Double -> Ptr Double -> IO()) 139aux_vTov :: (Vector Double -> Vector Double) -> TVV
101aux_vTov f n p r = g where 140aux_vTov f n p nr r = g where
102 V {fptr = pr} = f x 141 V {fptr = pr} = f x
103 x = createV (fromIntegral n) copy "aux_vTov" 142 x = createV (fromIntegral n) copy "aux_vTov"
104 copy n' q = do 143 copy n' q = do
105 copyArray q p (fromIntegral n') 144 copyArray q p (fromIntegral n')
106 return 0 145 return 0
107 g = withForeignPtr pr $ \p' -> copyArray r p' (fromIntegral n) 146 g = do withForeignPtr pr $ \p' -> copyArray r p' (fromIntegral nr)
147 return 0
148
149foreign import ccall "wrapper"
150 mkVecMatfun :: TVM -> IO (FunPtr TVM)
151
152aux_vTom :: (Vector Double -> Matrix Double) -> TVM
153aux_vTom f n p rr cr r = g where
154 V {fptr = pr} = flatten $ f x
155 x = createV (fromIntegral n) copy "aux_vTov"
156 copy n' q = do
157 copyArray q p (fromIntegral n')
158 return 0
159 g = do withForeignPtr pr $ \p' -> copyArray r p' (fromIntegral $ rr*cr)
160 return 0
108 161
109createV n fun msg = unsafePerformIO $ do 162createV n fun msg = unsafePerformIO $ do
110 r <- createVector n 163 r <- createVector n
@@ -116,8 +169,12 @@ createMIO r c fun msg = do
116 app1 fun mat res msg 169 app1 fun mat res msg
117 return res 170 return res
118 171
119checkdim n f x 172checkdim1 n v
120 | length y /= n = error $ "Error: "++ show n 173 | dim v == n = v
121 ++ " results expected in the function supplied to root" 174 | otherwise = error $ "Error: "++ show n
122 | otherwise = y 175 ++ " results expected in the function supplied to root"
123 where y = f x 176
177checkdim2 n m
178 | rows m == n && cols m == n = m
179 | otherwise = error $ "Error: "++ show n ++ "x" ++ show n
180 ++ " Jacobian expected in root"
diff --git a/lib/Numeric/GSL/gsl-aux.c b/lib/Numeric/GSL/gsl-aux.c
index 80c23fc..c6b052f 100644
--- a/lib/Numeric/GSL/gsl-aux.c
+++ b/lib/Numeric/GSL/gsl-aux.c
@@ -511,17 +511,17 @@ int minimizeWithDeriv(int method, double f(int, double*), void df(int, double*,
511 511
512//--------------------------------------------------------------- 512//---------------------------------------------------------------
513 513
514typedef void TrawfunV(int, double*, double*); 514typedef void TrawfunV(int, double*, int, double*);
515 515
516int only_f_aux_root(const gsl_vector*x, void *pars, gsl_vector*y) { 516int only_f_aux_root(const gsl_vector*x, void *pars, gsl_vector*y) {
517 TrawfunV * f = (TrawfunV*) pars; 517 TrawfunV * f = (TrawfunV*) pars;
518 double* p = (double*)calloc(x->size,sizeof(double)); 518 double* p = (double*)calloc(x->size,sizeof(double));
519 double* q = (double*)calloc(x->size,sizeof(double)); 519 double* q = (double*)calloc(y->size,sizeof(double));
520 int k; 520 int k;
521 for(k=0;k<x->size;k++) { 521 for(k=0;k<x->size;k++) {
522 p[k] = gsl_vector_get(x,k); 522 p[k] = gsl_vector_get(x,k);
523 } 523 }
524 f(x->size,p,q); 524 f(x->size,p,y->size,q);
525 for(k=0;k<y->size;k++) { 525 for(k=0;k<y->size;k++) {
526 gsl_vector_set(y,k,q[k]); 526 gsl_vector_set(y,k,q[k]);
527 } 527 }
@@ -588,3 +588,118 @@ int root(int method, void f(int, double*, int, double*),
588 gsl_multiroot_fsolver_free(s); 588 gsl_multiroot_fsolver_free(s);
589 OK 589 OK
590} 590}
591
592// working with the jacobian
593
594typedef struct {int (*f)(int, double*, int, double *); int (*jf)(int, double*, int, int, double*);} Tfjf;
595
596int f_aux_root(const gsl_vector*x, void *pars, gsl_vector*y) {
597 Tfjf * fjf = ((Tfjf*) pars);
598 double* p = (double*)calloc(x->size,sizeof(double));
599 double* q = (double*)calloc(y->size,sizeof(double));
600 int k;
601 for(k=0;k<x->size;k++) {
602 p[k] = gsl_vector_get(x,k);
603 }
604 (fjf->f)(x->size,p,y->size,q);
605 for(k=0;k<y->size;k++) {
606 gsl_vector_set(y,k,q[k]);
607 }
608 free(p);
609 free(q);
610 return 0;
611}
612
613int jf_aux_root(const gsl_vector * x, void * pars, gsl_matrix * jac) {
614 Tfjf * fjf = ((Tfjf*) pars);
615 double* p = (double*)calloc(x->size,sizeof(double));
616 double* q = (double*)calloc((x->size)*(x->size),sizeof(double));
617 int i,j,k;
618 for(k=0;k<x->size;k++) {
619 p[k] = gsl_vector_get(x,k);
620 }
621
622 (fjf->jf)(x->size,p,x->size,x->size,q);
623
624 k=0;
625 for(i=0;i<x->size;i++) {
626 for(j=0;j<x->size;j++){
627 gsl_matrix_set(jac,i,j,q[k++]);
628 }
629 }
630 free(p);
631 free(q);
632 return 0;
633}
634
635int fjf_aux_root(const gsl_vector * x, void * pars, gsl_vector * f, gsl_matrix * g) {
636 f_aux_root(x,pars,f);
637 jf_aux_root(x,pars,g);
638 return 0;
639}
640
641int rootj(int method, int f(int, double*, int, double*),
642 int jac(int, double*, int, int, double*),
643 double epsabs, int maxit,
644 KRVEC(xi), RMAT(sol)) {
645 REQUIRES(solr == maxit && solc == 1+2*xin,BAD_SIZE);
646 DEBUGMSG("root_fjf");
647 gsl_multiroot_function_fdf my_func;
648 // extract function from pars
649 my_func.f = f_aux_root;
650 my_func.df = jf_aux_root;
651 my_func.fdf = fjf_aux_root;
652 my_func.n = xin;
653 Tfjf stfjf;
654 stfjf.f = f;
655 stfjf.jf = jac;
656 my_func.params = &stfjf;
657 size_t iter = 0;
658 int status;
659 const gsl_multiroot_fdfsolver_type *T;
660 gsl_multiroot_fdfsolver *s;
661 // Starting point
662 KDVVIEW(xi);
663 switch(method) {
664 case 0 : {T = gsl_multiroot_fdfsolver_hybridsj;; break; }
665 case 1 : {T = gsl_multiroot_fdfsolver_hybridj; break; }
666 case 2 : {T = gsl_multiroot_fdfsolver_newton; break; }
667 case 3 : {T = gsl_multiroot_fdfsolver_gnewton; break; }
668 default: ERROR(BAD_CODE);
669 }
670 s = gsl_multiroot_fdfsolver_alloc (T, my_func.n);
671
672 gsl_multiroot_fdfsolver_set (s, &my_func, V(xi));
673
674 do {
675 status = gsl_multiroot_fdfsolver_iterate (s);
676
677 solp[iter*solc+0] = iter;
678
679 int k;
680 for(k=0;k<xin;k++) {
681 solp[iter*solc+k+1] = gsl_vector_get(s->x,k);
682 }
683 for(k=xin;k<2*xin;k++) {
684 solp[iter*solc+k+1] = gsl_vector_get(s->f,k-xin);
685 }
686
687 iter++;
688 if (status) /* check if solver is stuck */
689 break;
690
691 status =
692 gsl_multiroot_test_residual (s->f, epsabs);
693 }
694 while (status == GSL_CONTINUE && iter < maxit);
695
696 int i,j;
697 for (i=iter; i<solr; i++) {
698 solp[i*solc+0] = iter;
699 for(j=1;j<solc;j++) {
700 solp[i*solc+j]=0.;
701 }
702 }
703 gsl_multiroot_fdfsolver_free(s);
704 OK
705} \ No newline at end of file
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index 174e418..83f581f 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -119,9 +119,14 @@ minimizationTest = TestList [ utest "minimization conj grad" (minim1 f df [5,7]
119 119
120--------------------------------------------------------------------- 120---------------------------------------------------------------------
121 121
122rootFindingTest = utest "root Hybrids" (sol ~~ [1,1]) 122rootFindingTest = TestList [ utest "root Hybrids" (fst sol1 ~~ [1,1])
123 where sol = fst $ root Hybrids 1E-7 30 (rosenbrock 1 10) [-10,-5] 123 , utest "root Newton" (rows (snd sol2) == 2)
124 ]
125 where sol1 = root Hybrids 1E-7 30 (rosenbrock 1 10) [-10,-5]
126 sol2 = rootJ Newton 1E-7 30 (rosenbrock 1 10) (jacobian 1 10) [-10,-5]
124 rosenbrock a b [x,y] = [ a*(1-x), b*(y-x^2) ] 127 rosenbrock a b [x,y] = [ a*(1-x), b*(y-x^2) ]
128 jacobian a b [x,_y] = [ [-a , 0]
129 , [-2*b*x, b] ]
125 130
126--------------------------------------------------------------------- 131---------------------------------------------------------------------
127 132