summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2009-06-07 13:01:03 +0000
committerAlberto Ruiz <aruiz@um.es>2009-06-07 13:01:03 +0000
commit7697c6dc27fd0d9601728af576e8d7b9d1c800ee (patch)
treec4b2c0f52f3884fbbd93bd6c0739e0fc73b921a2 /lib
parent49a3d719221cd9484a64688ffcdbeb13cb8e55a0 (diff)
root finding with jacobian
Diffstat (limited to 'lib')
-rw-r--r--lib/Numeric/GSL/Root.hs83
-rw-r--r--lib/Numeric/GSL/gsl-aux.c121
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs9
3 files changed, 195 insertions, 18 deletions
diff --git a/lib/Numeric/GSL/Root.hs b/lib/Numeric/GSL/Root.hs
index d674fad..6ce2c4c 100644
--- a/lib/Numeric/GSL/Root.hs
+++ b/lib/Numeric/GSL/Root.hs
@@ -45,7 +45,8 @@ main = do
45----------------------------------------------------------------------------- 45-----------------------------------------------------------------------------
46 46
47module Numeric.GSL.Root ( 47module Numeric.GSL.Root (
48 root, RootMethod(..) 48 root, RootMethod(..),
49 rootJ, RootMethodJ(..),
49) where 50) where
50 51
51import Data.Packed.Internal 52import Data.Packed.Internal
@@ -76,7 +77,7 @@ root method epsabs maxit fun xinit = rootGen (fi (fromEnum method)) fun xinit ep
76rootGen m f xi epsabs maxit = unsafePerformIO $ do 77rootGen m f xi epsabs maxit = unsafePerformIO $ do
77 let xiv = fromList xi 78 let xiv = fromList xi
78 n = dim xiv 79 n = dim xiv
79 fp <- mkVecVecfun (aux_vTov (fromList . checkdim n f . toList)) 80 fp <- mkVecVecfun (aux_vTov (checkdim1 n . fromList . f . toList))
80 rawpath <- withVector xiv $ \xiv' -> 81 rawpath <- withVector xiv $ \xiv' ->
81 createMIO maxit (2*n+1) 82 createMIO maxit (2*n+1)
82 (c_root m fp epsabs (fi maxit) // xiv') 83 (c_root m fp epsabs (fi maxit) // xiv')
@@ -89,22 +90,74 @@ rootGen m f xi epsabs maxit = unsafePerformIO $ do
89 90
90 91
91foreign import ccall "root" 92foreign import ccall "root"
92 c_root:: CInt -> FunPtr (CInt -> Ptr Double -> Ptr Double -> IO ()) -> Double -> CInt -> TVM 93 c_root:: CInt -> FunPtr TVV -> Double -> CInt -> TVM
94
95-------------------------------------------------------------------------
96
97data RootMethodJ = HybridsJ
98 | HybridJ
99 | Newton
100 | GNewton
101 deriving (Enum,Eq,Show)
102
103-- | Nonlinear multidimensional root finding using both the function and its derivatives.
104rootJ :: RootMethodJ
105 -> Double -- ^ maximum residual
106 -> Int -- ^ maximum number of iterations allowed
107 -> ([Double] -> [Double]) -- ^ function to minimize
108 -> ([Double] -> [[Double]]) -- ^ Jacobian
109 -> [Double] -- ^ starting point
110 -> ([Double], Matrix Double) -- ^ solution vector and optimization path
111
112rootJ method epsabs maxit fun jac xinit = rootJGen (fi (fromEnum method)) fun jac xinit epsabs maxit
113
114rootJGen m f jac xi epsabs maxit = unsafePerformIO $ do
115 let xiv = fromList xi
116 n = dim xiv
117 fp <- mkVecVecfun (aux_vTov (checkdim1 n . fromList . f . toList))
118 jp <- mkVecMatfun (aux_vTom (checkdim2 n . fromLists . jac . toList))
119 rawpath <- withVector xiv $ \xiv' ->
120 createMIO maxit (2*n+1)
121 (c_rootj m fp jp epsabs (fi maxit) // xiv')
122 "root"
123 let it = round (rawpath @@> (maxit-1,0))
124 path = takeRows it rawpath
125 [sol] = toLists $ dropRows (it-1) path
126 freeHaskellFunPtr fp
127 return (take n $ drop 1 sol, path)
128
129
130foreign import ccall "rootj"
131 c_rootj:: CInt -> FunPtr TVV -> FunPtr TVM -> Double -> CInt -> TVM
132
93 133
94--------------------------------------------------------------------- 134---------------------------------------------------------------------
95 135
96foreign import ccall "wrapper" 136foreign import ccall "wrapper"
97 mkVecVecfun :: (CInt -> Ptr Double -> Ptr Double -> IO ()) 137 mkVecVecfun :: TVV -> IO (FunPtr TVV)
98 -> IO (FunPtr (CInt -> Ptr Double -> Ptr Double->IO()))
99 138
100aux_vTov :: (Vector Double -> Vector Double) -> (CInt -> Ptr Double -> Ptr Double -> IO()) 139aux_vTov :: (Vector Double -> Vector Double) -> TVV
101aux_vTov f n p r = g where 140aux_vTov f n p nr r = g where
102 V {fptr = pr} = f x 141 V {fptr = pr} = f x
103 x = createV (fromIntegral n) copy "aux_vTov" 142 x = createV (fromIntegral n) copy "aux_vTov"
104 copy n' q = do 143 copy n' q = do
105 copyArray q p (fromIntegral n') 144 copyArray q p (fromIntegral n')
106 return 0 145 return 0
107 g = withForeignPtr pr $ \p' -> copyArray r p' (fromIntegral n) 146 g = do withForeignPtr pr $ \p' -> copyArray r p' (fromIntegral nr)
147 return 0
148
149foreign import ccall "wrapper"
150 mkVecMatfun :: TVM -> IO (FunPtr TVM)
151
152aux_vTom :: (Vector Double -> Matrix Double) -> TVM
153aux_vTom f n p rr cr r = g where
154 V {fptr = pr} = flatten $ f x
155 x = createV (fromIntegral n) copy "aux_vTov"
156 copy n' q = do
157 copyArray q p (fromIntegral n')
158 return 0
159 g = do withForeignPtr pr $ \p' -> copyArray r p' (fromIntegral $ rr*cr)
160 return 0
108 161
109createV n fun msg = unsafePerformIO $ do 162createV n fun msg = unsafePerformIO $ do
110 r <- createVector n 163 r <- createVector n
@@ -116,8 +169,12 @@ createMIO r c fun msg = do
116 app1 fun mat res msg 169 app1 fun mat res msg
117 return res 170 return res
118 171
119checkdim n f x 172checkdim1 n v
120 | length y /= n = error $ "Error: "++ show n 173 | dim v == n = v
121 ++ " results expected in the function supplied to root" 174 | otherwise = error $ "Error: "++ show n
122 | otherwise = y 175 ++ " results expected in the function supplied to root"
123 where y = f x 176
177checkdim2 n m
178 | rows m == n && cols m == n = m
179 | otherwise = error $ "Error: "++ show n ++ "x" ++ show n
180 ++ " Jacobian expected in root"
diff --git a/lib/Numeric/GSL/gsl-aux.c b/lib/Numeric/GSL/gsl-aux.c
index 80c23fc..c6b052f 100644
--- a/lib/Numeric/GSL/gsl-aux.c
+++ b/lib/Numeric/GSL/gsl-aux.c
@@ -511,17 +511,17 @@ int minimizeWithDeriv(int method, double f(int, double*), void df(int, double*,
511 511
512//--------------------------------------------------------------- 512//---------------------------------------------------------------
513 513
514typedef void TrawfunV(int, double*, double*); 514typedef void TrawfunV(int, double*, int, double*);
515 515
516int only_f_aux_root(const gsl_vector*x, void *pars, gsl_vector*y) { 516int only_f_aux_root(const gsl_vector*x, void *pars, gsl_vector*y) {
517 TrawfunV * f = (TrawfunV*) pars; 517 TrawfunV * f = (TrawfunV*) pars;
518 double* p = (double*)calloc(x->size,sizeof(double)); 518 double* p = (double*)calloc(x->size,sizeof(double));
519 double* q = (double*)calloc(x->size,sizeof(double)); 519 double* q = (double*)calloc(y->size,sizeof(double));
520 int k; 520 int k;
521 for(k=0;k<x->size;k++) { 521 for(k=0;k<x->size;k++) {
522 p[k] = gsl_vector_get(x,k); 522 p[k] = gsl_vector_get(x,k);
523 } 523 }
524 f(x->size,p,q); 524 f(x->size,p,y->size,q);
525 for(k=0;k<y->size;k++) { 525 for(k=0;k<y->size;k++) {
526 gsl_vector_set(y,k,q[k]); 526 gsl_vector_set(y,k,q[k]);
527 } 527 }
@@ -588,3 +588,118 @@ int root(int method, void f(int, double*, int, double*),
588 gsl_multiroot_fsolver_free(s); 588 gsl_multiroot_fsolver_free(s);
589 OK 589 OK
590} 590}
591
592// working with the jacobian
593
594typedef struct {int (*f)(int, double*, int, double *); int (*jf)(int, double*, int, int, double*);} Tfjf;
595
596int f_aux_root(const gsl_vector*x, void *pars, gsl_vector*y) {
597 Tfjf * fjf = ((Tfjf*) pars);
598 double* p = (double*)calloc(x->size,sizeof(double));
599 double* q = (double*)calloc(y->size,sizeof(double));
600 int k;
601 for(k=0;k<x->size;k++) {
602 p[k] = gsl_vector_get(x,k);
603 }
604 (fjf->f)(x->size,p,y->size,q);
605 for(k=0;k<y->size;k++) {
606 gsl_vector_set(y,k,q[k]);
607 }
608 free(p);
609 free(q);
610 return 0;
611}
612
613int jf_aux_root(const gsl_vector * x, void * pars, gsl_matrix * jac) {
614 Tfjf * fjf = ((Tfjf*) pars);
615 double* p = (double*)calloc(x->size,sizeof(double));
616 double* q = (double*)calloc((x->size)*(x->size),sizeof(double));
617 int i,j,k;
618 for(k=0;k<x->size;k++) {
619 p[k] = gsl_vector_get(x,k);
620 }
621
622 (fjf->jf)(x->size,p,x->size,x->size,q);
623
624 k=0;
625 for(i=0;i<x->size;i++) {
626 for(j=0;j<x->size;j++){
627 gsl_matrix_set(jac,i,j,q[k++]);
628 }
629 }
630 free(p);
631 free(q);
632 return 0;
633}
634
635int fjf_aux_root(const gsl_vector * x, void * pars, gsl_vector * f, gsl_matrix * g) {
636 f_aux_root(x,pars,f);
637 jf_aux_root(x,pars,g);
638 return 0;
639}
640
641int rootj(int method, int f(int, double*, int, double*),
642 int jac(int, double*, int, int, double*),
643 double epsabs, int maxit,
644 KRVEC(xi), RMAT(sol)) {
645 REQUIRES(solr == maxit && solc == 1+2*xin,BAD_SIZE);
646 DEBUGMSG("root_fjf");
647 gsl_multiroot_function_fdf my_func;
648 // extract function from pars
649 my_func.f = f_aux_root;
650 my_func.df = jf_aux_root;
651 my_func.fdf = fjf_aux_root;
652 my_func.n = xin;
653 Tfjf stfjf;
654 stfjf.f = f;
655 stfjf.jf = jac;
656 my_func.params = &stfjf;
657 size_t iter = 0;
658 int status;
659 const gsl_multiroot_fdfsolver_type *T;
660 gsl_multiroot_fdfsolver *s;
661 // Starting point
662 KDVVIEW(xi);
663 switch(method) {
664 case 0 : {T = gsl_multiroot_fdfsolver_hybridsj;; break; }
665 case 1 : {T = gsl_multiroot_fdfsolver_hybridj; break; }
666 case 2 : {T = gsl_multiroot_fdfsolver_newton; break; }
667 case 3 : {T = gsl_multiroot_fdfsolver_gnewton; break; }
668 default: ERROR(BAD_CODE);
669 }
670 s = gsl_multiroot_fdfsolver_alloc (T, my_func.n);
671
672 gsl_multiroot_fdfsolver_set (s, &my_func, V(xi));
673
674 do {
675 status = gsl_multiroot_fdfsolver_iterate (s);
676
677 solp[iter*solc+0] = iter;
678
679 int k;
680 for(k=0;k<xin;k++) {
681 solp[iter*solc+k+1] = gsl_vector_get(s->x,k);
682 }
683 for(k=xin;k<2*xin;k++) {
684 solp[iter*solc+k+1] = gsl_vector_get(s->f,k-xin);
685 }
686
687 iter++;
688 if (status) /* check if solver is stuck */
689 break;
690
691 status =
692 gsl_multiroot_test_residual (s->f, epsabs);
693 }
694 while (status == GSL_CONTINUE && iter < maxit);
695
696 int i,j;
697 for (i=iter; i<solr; i++) {
698 solp[i*solc+0] = iter;
699 for(j=1;j<solc;j++) {
700 solp[i*solc+j]=0.;
701 }
702 }
703 gsl_multiroot_fdfsolver_free(s);
704 OK
705} \ No newline at end of file
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index 174e418..83f581f 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -119,9 +119,14 @@ minimizationTest = TestList [ utest "minimization conj grad" (minim1 f df [5,7]
119 119
120--------------------------------------------------------------------- 120---------------------------------------------------------------------
121 121
122rootFindingTest = utest "root Hybrids" (sol ~~ [1,1]) 122rootFindingTest = TestList [ utest "root Hybrids" (fst sol1 ~~ [1,1])
123 where sol = fst $ root Hybrids 1E-7 30 (rosenbrock 1 10) [-10,-5] 123 , utest "root Newton" (rows (snd sol2) == 2)
124 ]
125 where sol1 = root Hybrids 1E-7 30 (rosenbrock 1 10) [-10,-5]
126 sol2 = rootJ Newton 1E-7 30 (rosenbrock 1 10) (jacobian 1 10) [-10,-5]
124 rosenbrock a b [x,y] = [ a*(1-x), b*(y-x^2) ] 127 rosenbrock a b [x,y] = [ a*(1-x), b*(y-x^2) ]
128 jacobian a b [x,_y] = [ [-a , 0]
129 , [-2*b*x, b] ]
125 130
126--------------------------------------------------------------------- 131---------------------------------------------------------------------
127 132