From 25d7892ac78f0f1a4fda538dd35430ebff02baaa Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 12 Nov 2007 12:24:12 +0000 Subject: withMatrix --- lib/Numeric/GSL/Differentiation.hs | 2 +- lib/Numeric/GSL/Fourier.hs | 3 ++- lib/Numeric/GSL/Integration.hs | 4 +-- lib/Numeric/GSL/Matrix.hs | 52 +++++++++++++++++++++++-------------- lib/Numeric/GSL/Minimization.hs | 34 ++++++++++++++---------- lib/Numeric/GSL/Polynomials.hs | 3 ++- lib/Numeric/GSL/Special/Internal.hs | 4 +-- lib/Numeric/GSL/Vector.hs | 12 ++++++--- lib/Numeric/LinearAlgebra/LAPACK.hs | 34 +++++++++++++++--------- 9 files changed, 91 insertions(+), 57 deletions(-) (limited to 'lib/Numeric') diff --git a/lib/Numeric/GSL/Differentiation.hs b/lib/Numeric/GSL/Differentiation.hs index e7fea92..09236bd 100644 --- a/lib/Numeric/GSL/Differentiation.hs +++ b/lib/Numeric/GSL/Differentiation.hs @@ -35,7 +35,7 @@ derivGen c h f x = unsafePerformIO $ do r <- malloc e <- malloc fp <- mkfun (\x _ -> f x) - c_deriv c fp x h r e // check "deriv" [] + c_deriv c fp x h r e // check "deriv" vr <- peek r ve <- peek e let result = (vr,ve) diff --git a/lib/Numeric/GSL/Fourier.hs b/lib/Numeric/GSL/Fourier.hs index e975fbf..4b08625 100644 --- a/lib/Numeric/GSL/Fourier.hs +++ b/lib/Numeric/GSL/Fourier.hs @@ -26,7 +26,8 @@ import Foreign genfft code v = unsafePerformIO $ do r <- createVector (dim v) - c_fft code // vec v // vec r // check "fft" [v] + ww2 withVector v withVector r $ \ v r -> + c_fft code // v // r // check "fft" return r foreign import ccall "gsl-aux.h fft" c_fft :: Int -> TCVCV diff --git a/lib/Numeric/GSL/Integration.hs b/lib/Numeric/GSL/Integration.hs index d756417..747b34c 100644 --- a/lib/Numeric/GSL/Integration.hs +++ b/lib/Numeric/GSL/Integration.hs @@ -44,7 +44,7 @@ integrateQAGS prec n f a b = unsafePerformIO $ do r <- malloc e <- malloc fp <- mkfun (\x _ -> f x) - c_integrate_qags fp a b prec n r e // check "integrate_qags" [] + c_integrate_qags fp a b prec n r e // check "integrate_qags" vr <- peek r ve <- peek e let result = (vr,ve) @@ -75,7 +75,7 @@ integrateQNG prec f a b = unsafePerformIO $ do r <- malloc e <- malloc fp <- mkfun (\x _ -> f x) - c_integrate_qng fp a b prec r e // check "integrate_qng" [] + c_integrate_qng fp a b prec r e // check "integrate_qng" vr <- peek r ve <- peek e let result = (vr,ve) diff --git a/lib/Numeric/GSL/Matrix.hs b/lib/Numeric/GSL/Matrix.hs index e803c53..09a0be4 100644 --- a/lib/Numeric/GSL/Matrix.hs +++ b/lib/Numeric/GSL/Matrix.hs @@ -51,7 +51,8 @@ eigSg' m | otherwise = unsafePerformIO $ do l <- createVector r v <- createMatrix RowMajor r r - c_eigS // matc m // vec l // matc v // check "eigSg" [cdat m] + ww3 withMatrix m withVector l withMatrix v $ \m l v -> + c_eigS // m // l // v // check "eigSg" return (l,v) where r = rows m foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM @@ -84,7 +85,8 @@ eigHg' m | otherwise = unsafePerformIO $ do l <- createVector r v <- createMatrix RowMajor r r - c_eigH // matc m // vec l // matc v // check "eigHg" [cdat m] + ww3 withMatrix m withVector l withMatrix v $ \m l v -> + c_eigH // m // l // v // check "eigHg" return (l,v) where r = rows m foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM @@ -120,7 +122,8 @@ svd' x = unsafePerformIO $ do u <- createMatrix RowMajor r c s <- createVector c v <- createMatrix RowMajor c c - c_svd // matc x // matc u // vec s // matc v // check "svdg" [cdat x] + ww4 withMatrix x withMatrix u withVector s withMatrix v $ \x u s v -> + c_svd // x // u // s // v // check "svdg" return (u,s,v) where r = rows x c = cols x @@ -149,7 +152,8 @@ qr = qr' . cmat qr' x = unsafePerformIO $ do q <- createMatrix RowMajor r r rot <- createMatrix RowMajor r c - c_qr // matc x // matc q // matc rot // check "qr" [cdat x] + ww3 withMatrix x withMatrix q withMatrix rot $ \x q rot -> + c_qr // x // q // rot // check "qr" return (q,rot) where r = rows x c = cols x @@ -161,7 +165,8 @@ qrPacked = qrPacked' . cmat qrPacked' x = unsafePerformIO $ do qr <- createMatrix RowMajor r c tau <- createVector (min r c) - c_qrPacked // matc x // matc qr // vec tau // check "qrUnpacked" [cdat x] + ww3 withMatrix x withMatrix qr withVector tau $ \x qr tau -> + c_qrPacked // x // qr // tau // check "qrUnpacked" return (qr,tau) where r = rows x c = cols x @@ -172,9 +177,10 @@ unpackQR (qr,tau) = unpackQR' (cmat qr, tau) unpackQR' (qr,tau) = unsafePerformIO $ do q <- createMatrix RowMajor r r - rot <- createMatrix RowMajor r c - c_qrUnpack // matc qr // vec tau // matc q // matc rot // check "qrUnpack" [cdat qr,tau] - return (q,rot) + res <- createMatrix RowMajor r c + ww4 withMatrix qr withVector tau withMatrix q withMatrix res $ \qr tau q res -> + c_qrUnpack // qr // tau // q // res // check "qrUnpack" + return (q,res) where r = rows qr c = cols qr foreign import ccall "gsl-aux.h QRunpack" c_qrUnpack :: TMVMM @@ -196,20 +202,22 @@ cholR :: Matrix Double -> Matrix Double cholR = cholR' . cmat cholR' x = unsafePerformIO $ do - res <- createMatrix RowMajor r r - c_cholR // matc x // matc res // check "cholR" [cdat x] - return res - where r = rows x + r <- createMatrix RowMajor n n + ww2 withMatrix x withMatrix r $ \x r -> + c_cholR // x // r // check "cholR" + return r + where n = rows x foreign import ccall "gsl-aux.h cholR" c_cholR :: TMM cholC :: Matrix (Complex Double) -> Matrix (Complex Double) cholC = cholC' . cmat cholC' x = unsafePerformIO $ do - res <- createMatrix RowMajor r r - c_cholC // matc x // matc res // check "cholC" [cdat x] - return res - where r = rows x + r <- createMatrix RowMajor n n + ww2 withMatrix x withMatrix r $ \x r -> + c_cholC // x // r // check "cholC" + return r + where n = rows x foreign import ccall "gsl-aux.h cholC" c_cholC :: TCMCM @@ -223,7 +231,8 @@ luSolveR a b = luSolveR' (cmat a) (cmat b) luSolveR' a b | n1==n2 && n1==r = unsafePerformIO $ do s <- createMatrix RowMajor r c - c_luSolveR // matc a // matc b // matc s // check "luSolveR" [cdat a, cdat b] + ww3 withMatrix a withMatrix b withMatrix s $ \ a b s -> + c_luSolveR // a // b // s // check "luSolveR" return s | otherwise = error "luSolveR of nonsquare matrix" where n1 = rows a @@ -240,7 +249,8 @@ luSolveC a b = luSolveC' (cmat a) (cmat b) luSolveC' a b | n1==n2 && n1==r = unsafePerformIO $ do s <- createMatrix RowMajor r c - c_luSolveC // matc a // matc b // matc s // check "luSolveC" [cdat a, cdat b] + ww3 withMatrix a withMatrix b withMatrix s $ \ a b s -> + c_luSolveC // a // b // s // check "luSolveC" return s | otherwise = error "luSolveC of nonsquare matrix" where n1 = rows a @@ -256,7 +266,8 @@ luRaux = luRaux' . cmat luRaux' x = unsafePerformIO $ do res <- createVector (r*r+r+1) - c_luRaux // matc x // vec res // check "luRaux" [cdat x] + ww2 withMatrix x withVector res $ \x res -> + c_luRaux // x // res // check "luRaux" return res where r = rows x c = cols x @@ -269,7 +280,8 @@ luCaux = luCaux' . cmat luCaux' x = unsafePerformIO $ do res <- createVector (r*r+r+1) - c_luCaux // matc x // vec res // check "luCaux" [cdat x] + ww2 withMatrix x withVector res $ \x res -> + c_luCaux // x // res // check "luCaux" return res where r = rows x c = cols x diff --git a/lib/Numeric/GSL/Minimization.hs b/lib/Numeric/GSL/Minimization.hs index f523849..e44b2e5 100644 --- a/lib/Numeric/GSL/Minimization.hs +++ b/lib/Numeric/GSL/Minimization.hs @@ -26,7 +26,6 @@ import Data.Packed.Matrix import Foreign import Complex - ------------------------------------------------------------------------- {- | The method of Nelder and Mead, implemented by /gsl_multimin_fminimizer_nmsimplex/. The gradient of the function is not required. This is the example in the GSL manual: @@ -85,9 +84,10 @@ minimizeNMSimplex f xi sz tol maxit = unsafePerformIO $ do szv = fromList sz n = dim xiv fp <- mkVecfun (iv (f.toList)) - rawpath <- createMIO maxit (n+3) - (c_minimizeNMSimplex fp tol maxit // vec xiv // vec szv) - "minimizeNMSimplex" [xiv,szv] + rawpath <- ww2 withVector xiv withVector szv $ \xiv szv -> + createMIO maxit (n+3) + (c_minimizeNMSimplex fp tol maxit // xiv // szv) + "minimizeNMSimplex" let it = round (rawpath @@> (maxit-1,0)) path = takeRows it rawpath [sol] = toLists $ dropRows (it-1) path @@ -150,9 +150,10 @@ minimizeConjugateGradient istep minimpar tol maxit f df xi = unsafePerformIO $ d df' = (fromList . df . toList) fp <- mkVecfun (iv f') dfp <- mkVecVecfun (aux_vTov df') - rawpath <- createMIO maxit (n+2) - (c_minimizeConjugateGradient fp dfp istep minimpar tol maxit // vec xiv) - "minimizeDerivV" [xiv] + rawpath <- withVector xiv $ \xiv -> + createMIO maxit (n+2) + (c_minimizeConjugateGradient fp dfp istep minimpar tol maxit // xiv) + "minimizeDerivV" let it = round (rawpath @@> (maxit-1,0)) path = takeRows it rawpath sol = toList $ cdat $ dropColumns 2 $ dropRows (it-1) path @@ -169,7 +170,7 @@ foreign import ccall "gsl-aux.h minimizeWithDeriv" --------------------------------------------------------------------- iv :: (Vector Double -> Double) -> (Int -> Ptr Double -> Double) -iv f n p = f (createV n copy "iv" []) where +iv f n p = f (createV n copy "iv") where copy n q = do copyArray q p n return 0 @@ -187,25 +188,30 @@ foreign import ccall "wrapper" aux_vTov :: (Vector Double -> Vector Double) -> (Int -> Ptr Double -> Ptr Double -> IO()) aux_vTov f n p r = g where v@V {fptr = pr} = f x - x = createV n copy "aux_vTov" [] + x = createV n copy "aux_vTov" copy n q = do copyArray q p n return 0 - g = withForeignPtr pr $ \_ -> copyArray r (ptr v) n + g = withForeignPtr pr $ \p -> copyArray r p n -------------------------------------------------------------------- -createV n fun msg ptrs = unsafePerformIO $ do + +createV n fun msg = unsafePerformIO $ do r <- createVector n - fun // vec r // check msg ptrs + withVector r $ \ r -> + fun // r // check msg return r +{- createM r c fun msg ptrs = unsafePerformIO $ do r <- createMatrix RowMajor r c fun // matc r // check msg ptrs return r +-} -createMIO r c fun msg ptrs = do +createMIO r c fun msg = do r <- createMatrix RowMajor r c - fun // matc r // check msg ptrs + withMatrix r $ \ r -> + fun // r // check msg return r diff --git a/lib/Numeric/GSL/Polynomials.hs b/lib/Numeric/GSL/Polynomials.hs index 42694f0..e663711 100644 --- a/lib/Numeric/GSL/Polynomials.hs +++ b/lib/Numeric/GSL/Polynomials.hs @@ -47,7 +47,8 @@ polySolve = toList . polySolve' . fromList polySolve' :: Vector Double -> Vector (Complex Double) polySolve' v | dim v > 1 = unsafePerformIO $ do r <- createVector (dim v-1) - c_polySolve // vec v // vec r // check "polySolve" [v] + ww2 withVector v withVector r $ \ v r -> + c_polySolve // v // r // check "polySolve" return r | otherwise = error "polySolve on a polynomial of degree zero" diff --git a/lib/Numeric/GSL/Special/Internal.hs b/lib/Numeric/GSL/Special/Internal.hs index a08809b..ca36009 100644 --- a/lib/Numeric/GSL/Special/Internal.hs +++ b/lib/Numeric/GSL/Special/Internal.hs @@ -45,7 +45,7 @@ type Size_t = Int createSFR :: Storable a => String -> (Ptr a -> IO Int) -> (a, a) createSFR s f = unsafePerformIO $ do p <- mallocArray 2 - f p // check s [] + f p // check s [val,err] <- peekArray 2 p free p return (val,err) @@ -60,7 +60,7 @@ createSFR_E10 s f = unsafePerformIO $ do let sd = sizeOf (0::Double) let si = sizeOf (0::Int) p <- mallocBytes (2*sd + si) - f p // check s [] + f p // check s val <- peekByteOff p 0 err <- peekByteOff p sd expo <- peekByteOff p (2*sd) diff --git a/lib/Numeric/GSL/Vector.hs b/lib/Numeric/GSL/Vector.hs index d94b377..65f3a2e 100644 --- a/lib/Numeric/GSL/Vector.hs +++ b/lib/Numeric/GSL/Vector.hs @@ -73,24 +73,28 @@ data FunCodeS = Norm2 toScalarAux fun code v = unsafePerformIO $ do r <- createVector 1 - fun (fromEnum code) // vec v // vec r // check "toScalarAux" [v] + ww2 withVector v withVector r $ \v r -> + fun (fromEnum code) // v // r // check "toScalarAux" return (r `at` 0) vectorMapAux fun code v = unsafePerformIO $ do r <- createVector (dim v) - fun (fromEnum code) // vec v // vec r // check "vectorMapAux" [v] + ww2 withVector v withVector r $ \v r -> + fun (fromEnum code) // v // r // check "vectorMapAux" return r vectorMapValAux fun code val v = unsafePerformIO $ do r <- createVector (dim v) pval <- newArray [val] - fun (fromEnum code) pval // vec v // vec r // check "vectorMapValAux" [v] + ww2 withVector v withVector r $ \v r -> + fun (fromEnum code) pval // v // r // check "vectorMapValAux" free pval return r vectorZipAux fun code u v = unsafePerformIO $ do r <- createVector (dim u) - fun (fromEnum code) // vec u // vec v // vec r // check "vectorZipAux" [u,v] + ww3 withVector u withVector v withVector r $ \u v r -> + fun (fromEnum code) // u // v // r // check "vectorZipAux" return r --------------------------------------------------------------------- diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index 315be17..19516e3 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs @@ -61,7 +61,8 @@ svdAux f st x = unsafePerformIO $ do u <- createMatrix ColumnMajor r r s <- createVector (min r c) v <- createMatrix ColumnMajor c c - f // matf x // matf u // vec s // matf v // check st [fdat x] + ww4 withMatrix x withMatrix u withVector s withMatrix v $ \x u s v -> + f // x // u // s // v // check st return (u,s,trans v) where r = rows x c = cols x @@ -73,7 +74,8 @@ eigAux f st m l <- createVector r v <- createMatrix ColumnMajor r r dummy <- createMatrix ColumnMajor 1 1 - f // matf m // matf dummy // vec l // matf v // check st [fdat m] + ww4 withMatrix m withMatrix dummy withVector l withMatrix v $ \m dummy l v -> + f // m // dummy // l // v // check st return (l,v) where r = rows m @@ -115,7 +117,8 @@ eigRaux m l <- createVector r v <- createMatrix ColumnMajor r r dummy <- createMatrix ColumnMajor 1 1 - dgeev // matf m // matf dummy // vec l // matf v // check "eigR" [fdat m] + ww4 withMatrix m withMatrix dummy withVector l withMatrix v $ \m dummy l v -> + dgeev // m // dummy // l // v // check "eigR" return (l,v) where r = rows m @@ -144,7 +147,8 @@ eigS' m | otherwise = unsafePerformIO $ do l <- createVector r v <- createMatrix ColumnMajor r r - dsyev // matf m // vec l // matf v // check "eigS" [fdat m] + ww3 withMatrix m withVector l withMatrix v $ \m l v -> + dsyev // m // l // v // check "eigS" return (l,v) where r = rows m @@ -166,7 +170,8 @@ eigH' m | otherwise = unsafePerformIO $ do l <- createVector r v <- createMatrix ColumnMajor r r - zheev // matf m // vec l // matf v // check "eigH" [fdat m] + ww3 withMatrix m withVector l withMatrix v $ \m l v -> + zheev // m // l // v // check "eigH" return (l,v) where r = rows m @@ -177,7 +182,8 @@ foreign import ccall "LAPACK/lapack-aux.h linearSolveC_l" zgesv :: TCMCMCM linearSolveSQAux f st a b | n1==n2 && n1==r = unsafePerformIO $ do s <- createMatrix ColumnMajor r c - f // matf a // matf b // matf s // check st [fdat a, fdat b] + ww3 withMatrix a withMatrix b withMatrix s $ \a b s -> + f // a // b // s // check st return s | otherwise = error $ st ++ " of nonsquare matrix" where n1 = rows a @@ -201,7 +207,8 @@ foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDC_l" zgelss :: Double -> linearSolveAux f st a b = unsafePerformIO $ do r <- createMatrix ColumnMajor (max m n) nrhs - f // matf a // matf b // matf r // check st [fdat a, fdat b] + ww3 withMatrix a withMatrix b withMatrix r $ \a b r -> + f // a // b // r // check st return r where m = rows a n = cols a @@ -251,7 +258,8 @@ cholS = cholAux dpotrf "cholS" . fmat cholAux f st a = unsafePerformIO $ do r <- createMatrix ColumnMajor n n - f // matf a // matf r // check st [fdat a] + ww2 withMatrix a withMatrix r $ \a r -> + f // a // r // check st return r where n = rows a @@ -270,8 +278,8 @@ qrC = qrAux zgeqr2 "qrC" . fmat qrAux f st a = unsafePerformIO $ do r <- createMatrix ColumnMajor m n tau <- createVector mn - withForeignPtr (fptr $ fdat $ a) $ \p -> - f m n p // vec tau // matf r // check st [fdat a] + ww3 withMatrix a withMatrix r withVector tau $ \ a r tau -> + f // a // tau // r // check st return (r,tau) where m = rows a n = cols a @@ -292,7 +300,8 @@ hessC = hessAux zgehrd "hessC" . fmat hessAux f st a = unsafePerformIO $ do r <- createMatrix ColumnMajor m n tau <- createVector (mn-1) - f // matf a // vec tau // matf r // check st [fdat a] + ww3 withMatrix a withMatrix r withVector tau $ \ a r tau -> + f // a // tau // r // check st return (r,tau) where m = rows a n = cols a @@ -313,7 +322,8 @@ schurC = schurAux zgees "schurC" . fmat schurAux f st a = unsafePerformIO $ do u <- createMatrix ColumnMajor n n s <- createMatrix ColumnMajor n n - f // matf a // matf u // matf s // check st [fdat a] + ww3 withMatrix a withMatrix u withMatrix s $ \ a u s -> + f // a // u // s // check st return (u,s) where n = rows a -- cgit v1.2.3