From 25d7892ac78f0f1a4fda538dd35430ebff02baaa Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 12 Nov 2007 12:24:12 +0000 Subject: withMatrix --- lib/Data/Packed/Internal/Common.hs | 5 ++++ lib/Data/Packed/Internal/Matrix.hs | 41 +++++++++++++++++++++-------- lib/Data/Packed/Internal/Vector.hs | 34 ++++++++++++++---------- lib/Numeric/GSL/Differentiation.hs | 2 +- lib/Numeric/GSL/Fourier.hs | 3 ++- lib/Numeric/GSL/Integration.hs | 4 +-- lib/Numeric/GSL/Matrix.hs | 52 +++++++++++++++++++++++-------------- lib/Numeric/GSL/Minimization.hs | 34 ++++++++++++++---------- lib/Numeric/GSL/Polynomials.hs | 3 ++- lib/Numeric/GSL/Special/Internal.hs | 4 +-- lib/Numeric/GSL/Vector.hs | 12 ++++++--- lib/Numeric/LinearAlgebra/LAPACK.hs | 34 +++++++++++++++--------- 12 files changed, 146 insertions(+), 82 deletions(-) diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs index 2b3ec28..c3a733c 100644 --- a/lib/Data/Packed/Internal/Common.hs +++ b/lib/Data/Packed/Internal/Common.hs @@ -60,6 +60,11 @@ common f = commonval . map f where infixl 0 // (//) = flip ($) +-- hmm.. +ww2 w1 o1 w2 o2 f = w1 o1 $ \a1 -> w2 o2 $ \a2 -> f a1 a2 +ww3 w1 o1 w2 o2 w3 o3 f = w1 o1 $ \a1 -> ww2 w2 o2 w3 o3 (f a1) +ww4 w1 o1 w2 o2 w3 o3 w4 o4 f = w1 o1 $ \a1 -> ww3 w2 o2 w3 o3 w4 o4 (f a1) + -- GSL error codes are <= 1024 -- | error codes for the auxiliary functions required by the wrappers errorCode :: Int -> String diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 5617996..90a96b5 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs @@ -79,9 +79,20 @@ cmat MF {rows = r, cols = c, fdat = d } = MC {rows = r, cols = c, cdat = transda fmat m@MF{} = m fmat MC {rows = r, cols = c, cdat = d } = MF {rows = r, cols = c, fdat = transdata c d r} -matc m f = f (rows m) (cols m) (ptr (cdat m)) -matf m f = f (rows m) (cols m) (ptr (fdat m)) +--matc m f = f (rows m) (cols m) (ptr (cdat m)) +--matf m f = f (rows m) (cols m) (ptr (fdat m)) +withMatrix MC {rows = r, cols = c, cdat = d } f = + withForeignPtr (fptr d) $ \p -> do + let m f = do + f r c p + f m + +withMatrix MF {rows = r, cols = c, fdat = d } f = + withForeignPtr (fptr d) $ \p -> do + let m f = do + f r c p + f m {- | Creates a vector by concatenation of rows @@ -236,7 +247,9 @@ transdataAux fun c1 d c2 = then d else unsafePerformIO $ do v <- createVector (dim d) - fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d,v] + withForeignPtr (fptr d) $ \pd -> + withForeignPtr (fptr v) $ \pv -> + fun r1 c1 pd r2 c2 pv // check "transdataAux" -- putStrLn $ "---> transdataAux" ++ show (toList d) ++ show (toList v) return v where r1 = dim d `div` c1 @@ -250,8 +263,8 @@ foreign import ccall safe "auxi.h transC" ------------------------------------------------------------------ -gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d) -gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d) +gmatC MF {rows = r, cols = c, fdat = d} p f = f 1 c r p +gmatC MC {rows = r, cols = c, cdat = d} p f = f 0 r c p dtt MC { cdat = d } = d dtt MF { fdat = d } = d @@ -260,7 +273,9 @@ multiplyAux fun a b = unsafePerformIO $ do when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ show (rows a,cols a) ++ " x " ++ show (rows b, cols b) r <- createMatrix RowMajor (rows a) (cols b) - fun // gmatC a // gmatC b // matc r // check "multiplyAux" [dtt a, dtt b, cdat r] + withForeignPtr (fptr (dtt a)) $ \pa -> withForeignPtr (fptr (dtt b)) $ \pb -> + withMatrix r $ \r -> + fun // gmatC a pa // gmatC b pb // r // check "multiplyAux" return r multiplyR = multiplyAux cmultiplyR @@ -293,7 +308,8 @@ subMatrixR :: (Int,Int) -> (Int,Int) -> Matrix Double -> Matrix Double subMatrixR (r0,c0) (rt,ct) x' = unsafePerformIO $ do r <- createMatrix RowMajor rt ct let x = cmat x' - c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // matc x // matc r // check "subMatrixR" [cdat x] + ww2 withMatrix x withMatrix r $ \x r -> + c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // x // r // check "subMatrixR" return r foreign import ccall "auxi.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM @@ -317,8 +333,9 @@ subMatrix = subMatrixD diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do m <- createMatrix RowMajor n n - fun // vec v // matc m // check msg [v] - return m -- {tdat = dat m} + ww2 withVector v withMatrix m $ \v m -> + fun // v // m // check msg + return m -- | diagonal matrix from a real vector diagR :: Vector Double -> Matrix Double @@ -339,7 +356,8 @@ diag = diagD constantAux fun x n = unsafePerformIO $ do v <- createVector n px <- newArray [x] - fun px // vec v // check "constantAux" [] + withVector v $ \v -> + fun px // v // check "constantAux" free px return v @@ -385,7 +403,8 @@ fromFile :: FilePath -> (Int,Int) -> IO (Matrix Double) fromFile filename (r,c) = do charname <- newCString filename res <- createMatrix RowMajor r c - c_gslReadMatrix charname // matc res // check "gslReadMatrix" [] + withMatrix res $ \res -> + c_gslReadMatrix charname // res // check "gslReadMatrix" --free charname -- TO DO: free the auxiliary CString return res foreign import ccall "auxi.h matrix_fscanf" c_gslReadMatrix:: Ptr CChar -> TM diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index dc86484..386ebb5 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs @@ -31,11 +31,11 @@ data Vector t = V { dim :: Int -- ^ number of elements , fptr :: ForeignPtr t -- ^ foreign pointer to the memory block } -ptr (V _ fptr) = unsafeForeignPtrToPtr fptr +--ptr (V _ fptr) = unsafeForeignPtrToPtr fptr --- | check the error code and touch foreign ptr of vector arguments (if any) -check :: String -> [Vector a] -> IO Int -> IO () -check msg ls f = do +-- | check the error code +check :: String -> IO Int -> IO () +check msg f = do err <- f when (err/=0) $ if err > 1024 then (error (msg++": "++errorCode err)) -- our errors @@ -43,7 +43,6 @@ check msg ls f = do ps <- gsl_strerror err s <- peekCString ps error (msg++": "++s) - mapM_ (touchForeignPtr . fptr) ls return () -- | description of GSL error codes @@ -55,9 +54,14 @@ type Vc t s = Int -> Ptr t -> s -- infixr 5 :> -- type t :> s = Vc t s --- | adaptation of our vectors to be admitted by foreign functions: @f \/\/ vec v@ -vec :: Vector t -> (Vc t s) -> s -vec v f = f (dim v) (ptr v) +--- | adaptation of our vectors to be admitted by foreign functions: @f \/\/ vec v@ +--vec :: Vector t -> (Vc t s) -> s +--vec v f = f (dim v) (ptr v) + +withVector (V n fp) f = withForeignPtr fp $ \p -> do + let v f = do + f n p + f v -- | allocates memory for a new vector createVector :: Storable a => Int -> IO (Vector a) @@ -76,7 +80,8 @@ fromList :: Storable a => [a] -> Vector a fromList l = unsafePerformIO $ do v <- createVector (length l) let f _ p = pokeArray p l >> return 0 - f // vec v // check "fromList" [] + withVector v $ \v -> + f // v // check "fromList" return v safeRead v = unsafePerformIO . withForeignPtr (fptr v) @@ -118,8 +123,9 @@ subVector k l (v@V {dim=n}) | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" | otherwise = unsafePerformIO $ do r <- createVector l - let f = copyArray (ptr r) (advancePtr (ptr v) k) l >> return 0 - f // check "subVector" [v,r] + let f _ s _ d = copyArray d (advancePtr s k) l >> return 0 + ww2 withVector v withVector r $ \v r -> + f // v // r // check "subVector" return r {- | Reads a vector position: @@ -144,12 +150,12 @@ join [] = error "joining zero vectors" join as = unsafePerformIO $ do let tot = sum (map dim as) r@V {fptr = p} <- createVector tot - withForeignPtr p $ \_ -> - joiner as tot (ptr r) + withForeignPtr p $ \ptr -> + joiner as tot ptr return r where joiner [] _ _ = return () joiner (r@V {dim = n, fptr = b} : cs) _ p = do - withForeignPtr b $ \_ -> copyArray p (ptr r) n + withForeignPtr b $ \pb -> copyArray p pb n joiner cs 0 (advancePtr p n) diff --git a/lib/Numeric/GSL/Differentiation.hs b/lib/Numeric/GSL/Differentiation.hs index e7fea92..09236bd 100644 --- a/lib/Numeric/GSL/Differentiation.hs +++ b/lib/Numeric/GSL/Differentiation.hs @@ -35,7 +35,7 @@ derivGen c h f x = unsafePerformIO $ do r <- malloc e <- malloc fp <- mkfun (\x _ -> f x) - c_deriv c fp x h r e // check "deriv" [] + c_deriv c fp x h r e // check "deriv" vr <- peek r ve <- peek e let result = (vr,ve) diff --git a/lib/Numeric/GSL/Fourier.hs b/lib/Numeric/GSL/Fourier.hs index e975fbf..4b08625 100644 --- a/lib/Numeric/GSL/Fourier.hs +++ b/lib/Numeric/GSL/Fourier.hs @@ -26,7 +26,8 @@ import Foreign genfft code v = unsafePerformIO $ do r <- createVector (dim v) - c_fft code // vec v // vec r // check "fft" [v] + ww2 withVector v withVector r $ \ v r -> + c_fft code // v // r // check "fft" return r foreign import ccall "gsl-aux.h fft" c_fft :: Int -> TCVCV diff --git a/lib/Numeric/GSL/Integration.hs b/lib/Numeric/GSL/Integration.hs index d756417..747b34c 100644 --- a/lib/Numeric/GSL/Integration.hs +++ b/lib/Numeric/GSL/Integration.hs @@ -44,7 +44,7 @@ integrateQAGS prec n f a b = unsafePerformIO $ do r <- malloc e <- malloc fp <- mkfun (\x _ -> f x) - c_integrate_qags fp a b prec n r e // check "integrate_qags" [] + c_integrate_qags fp a b prec n r e // check "integrate_qags" vr <- peek r ve <- peek e let result = (vr,ve) @@ -75,7 +75,7 @@ integrateQNG prec f a b = unsafePerformIO $ do r <- malloc e <- malloc fp <- mkfun (\x _ -> f x) - c_integrate_qng fp a b prec r e // check "integrate_qng" [] + c_integrate_qng fp a b prec r e // check "integrate_qng" vr <- peek r ve <- peek e let result = (vr,ve) diff --git a/lib/Numeric/GSL/Matrix.hs b/lib/Numeric/GSL/Matrix.hs index e803c53..09a0be4 100644 --- a/lib/Numeric/GSL/Matrix.hs +++ b/lib/Numeric/GSL/Matrix.hs @@ -51,7 +51,8 @@ eigSg' m | otherwise = unsafePerformIO $ do l <- createVector r v <- createMatrix RowMajor r r - c_eigS // matc m // vec l // matc v // check "eigSg" [cdat m] + ww3 withMatrix m withVector l withMatrix v $ \m l v -> + c_eigS // m // l // v // check "eigSg" return (l,v) where r = rows m foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM @@ -84,7 +85,8 @@ eigHg' m | otherwise = unsafePerformIO $ do l <- createVector r v <- createMatrix RowMajor r r - c_eigH // matc m // vec l // matc v // check "eigHg" [cdat m] + ww3 withMatrix m withVector l withMatrix v $ \m l v -> + c_eigH // m // l // v // check "eigHg" return (l,v) where r = rows m foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM @@ -120,7 +122,8 @@ svd' x = unsafePerformIO $ do u <- createMatrix RowMajor r c s <- createVector c v <- createMatrix RowMajor c c - c_svd // matc x // matc u // vec s // matc v // check "svdg" [cdat x] + ww4 withMatrix x withMatrix u withVector s withMatrix v $ \x u s v -> + c_svd // x // u // s // v // check "svdg" return (u,s,v) where r = rows x c = cols x @@ -149,7 +152,8 @@ qr = qr' . cmat qr' x = unsafePerformIO $ do q <- createMatrix RowMajor r r rot <- createMatrix RowMajor r c - c_qr // matc x // matc q // matc rot // check "qr" [cdat x] + ww3 withMatrix x withMatrix q withMatrix rot $ \x q rot -> + c_qr // x // q // rot // check "qr" return (q,rot) where r = rows x c = cols x @@ -161,7 +165,8 @@ qrPacked = qrPacked' . cmat qrPacked' x = unsafePerformIO $ do qr <- createMatrix RowMajor r c tau <- createVector (min r c) - c_qrPacked // matc x // matc qr // vec tau // check "qrUnpacked" [cdat x] + ww3 withMatrix x withMatrix qr withVector tau $ \x qr tau -> + c_qrPacked // x // qr // tau // check "qrUnpacked" return (qr,tau) where r = rows x c = cols x @@ -172,9 +177,10 @@ unpackQR (qr,tau) = unpackQR' (cmat qr, tau) unpackQR' (qr,tau) = unsafePerformIO $ do q <- createMatrix RowMajor r r - rot <- createMatrix RowMajor r c - c_qrUnpack // matc qr // vec tau // matc q // matc rot // check "qrUnpack" [cdat qr,tau] - return (q,rot) + res <- createMatrix RowMajor r c + ww4 withMatrix qr withVector tau withMatrix q withMatrix res $ \qr tau q res -> + c_qrUnpack // qr // tau // q // res // check "qrUnpack" + return (q,res) where r = rows qr c = cols qr foreign import ccall "gsl-aux.h QRunpack" c_qrUnpack :: TMVMM @@ -196,20 +202,22 @@ cholR :: Matrix Double -> Matrix Double cholR = cholR' . cmat cholR' x = unsafePerformIO $ do - res <- createMatrix RowMajor r r - c_cholR // matc x // matc res // check "cholR" [cdat x] - return res - where r = rows x + r <- createMatrix RowMajor n n + ww2 withMatrix x withMatrix r $ \x r -> + c_cholR // x // r // check "cholR" + return r + where n = rows x foreign import ccall "gsl-aux.h cholR" c_cholR :: TMM cholC :: Matrix (Complex Double) -> Matrix (Complex Double) cholC = cholC' . cmat cholC' x = unsafePerformIO $ do - res <- createMatrix RowMajor r r - c_cholC // matc x // matc res // check "cholC" [cdat x] - return res - where r = rows x + r <- createMatrix RowMajor n n + ww2 withMatrix x withMatrix r $ \x r -> + c_cholC // x // r // check "cholC" + return r + where n = rows x foreign import ccall "gsl-aux.h cholC" c_cholC :: TCMCM @@ -223,7 +231,8 @@ luSolveR a b = luSolveR' (cmat a) (cmat b) luSolveR' a b | n1==n2 && n1==r = unsafePerformIO $ do s <- createMatrix RowMajor r c - c_luSolveR // matc a // matc b // matc s // check "luSolveR" [cdat a, cdat b] + ww3 withMatrix a withMatrix b withMatrix s $ \ a b s -> + c_luSolveR // a // b // s // check "luSolveR" return s | otherwise = error "luSolveR of nonsquare matrix" where n1 = rows a @@ -240,7 +249,8 @@ luSolveC a b = luSolveC' (cmat a) (cmat b) luSolveC' a b | n1==n2 && n1==r = unsafePerformIO $ do s <- createMatrix RowMajor r c - c_luSolveC // matc a // matc b // matc s // check "luSolveC" [cdat a, cdat b] + ww3 withMatrix a withMatrix b withMatrix s $ \ a b s -> + c_luSolveC // a // b // s // check "luSolveC" return s | otherwise = error "luSolveC of nonsquare matrix" where n1 = rows a @@ -256,7 +266,8 @@ luRaux = luRaux' . cmat luRaux' x = unsafePerformIO $ do res <- createVector (r*r+r+1) - c_luRaux // matc x // vec res // check "luRaux" [cdat x] + ww2 withMatrix x withVector res $ \x res -> + c_luRaux // x // res // check "luRaux" return res where r = rows x c = cols x @@ -269,7 +280,8 @@ luCaux = luCaux' . cmat luCaux' x = unsafePerformIO $ do res <- createVector (r*r+r+1) - c_luCaux // matc x // vec res // check "luCaux" [cdat x] + ww2 withMatrix x withVector res $ \x res -> + c_luCaux // x // res // check "luCaux" return res where r = rows x c = cols x diff --git a/lib/Numeric/GSL/Minimization.hs b/lib/Numeric/GSL/Minimization.hs index f523849..e44b2e5 100644 --- a/lib/Numeric/GSL/Minimization.hs +++ b/lib/Numeric/GSL/Minimization.hs @@ -26,7 +26,6 @@ import Data.Packed.Matrix import Foreign import Complex - ------------------------------------------------------------------------- {- | The method of Nelder and Mead, implemented by /gsl_multimin_fminimizer_nmsimplex/. The gradient of the function is not required. This is the example in the GSL manual: @@ -85,9 +84,10 @@ minimizeNMSimplex f xi sz tol maxit = unsafePerformIO $ do szv = fromList sz n = dim xiv fp <- mkVecfun (iv (f.toList)) - rawpath <- createMIO maxit (n+3) - (c_minimizeNMSimplex fp tol maxit // vec xiv // vec szv) - "minimizeNMSimplex" [xiv,szv] + rawpath <- ww2 withVector xiv withVector szv $ \xiv szv -> + createMIO maxit (n+3) + (c_minimizeNMSimplex fp tol maxit // xiv // szv) + "minimizeNMSimplex" let it = round (rawpath @@> (maxit-1,0)) path = takeRows it rawpath [sol] = toLists $ dropRows (it-1) path @@ -150,9 +150,10 @@ minimizeConjugateGradient istep minimpar tol maxit f df xi = unsafePerformIO $ d df' = (fromList . df . toList) fp <- mkVecfun (iv f') dfp <- mkVecVecfun (aux_vTov df') - rawpath <- createMIO maxit (n+2) - (c_minimizeConjugateGradient fp dfp istep minimpar tol maxit // vec xiv) - "minimizeDerivV" [xiv] + rawpath <- withVector xiv $ \xiv -> + createMIO maxit (n+2) + (c_minimizeConjugateGradient fp dfp istep minimpar tol maxit // xiv) + "minimizeDerivV" let it = round (rawpath @@> (maxit-1,0)) path = takeRows it rawpath sol = toList $ cdat $ dropColumns 2 $ dropRows (it-1) path @@ -169,7 +170,7 @@ foreign import ccall "gsl-aux.h minimizeWithDeriv" --------------------------------------------------------------------- iv :: (Vector Double -> Double) -> (Int -> Ptr Double -> Double) -iv f n p = f (createV n copy "iv" []) where +iv f n p = f (createV n copy "iv") where copy n q = do copyArray q p n return 0 @@ -187,25 +188,30 @@ foreign import ccall "wrapper" aux_vTov :: (Vector Double -> Vector Double) -> (Int -> Ptr Double -> Ptr Double -> IO()) aux_vTov f n p r = g where v@V {fptr = pr} = f x - x = createV n copy "aux_vTov" [] + x = createV n copy "aux_vTov" copy n q = do copyArray q p n return 0 - g = withForeignPtr pr $ \_ -> copyArray r (ptr v) n + g = withForeignPtr pr $ \p -> copyArray r p n -------------------------------------------------------------------- -createV n fun msg ptrs = unsafePerformIO $ do + +createV n fun msg = unsafePerformIO $ do r <- createVector n - fun // vec r // check msg ptrs + withVector r $ \ r -> + fun // r // check msg return r +{- createM r c fun msg ptrs = unsafePerformIO $ do r <- createMatrix RowMajor r c fun // matc r // check msg ptrs return r +-} -createMIO r c fun msg ptrs = do +createMIO r c fun msg = do r <- createMatrix RowMajor r c - fun // matc r // check msg ptrs + withMatrix r $ \ r -> + fun // r // check msg return r diff --git a/lib/Numeric/GSL/Polynomials.hs b/lib/Numeric/GSL/Polynomials.hs index 42694f0..e663711 100644 --- a/lib/Numeric/GSL/Polynomials.hs +++ b/lib/Numeric/GSL/Polynomials.hs @@ -47,7 +47,8 @@ polySolve = toList . polySolve' . fromList polySolve' :: Vector Double -> Vector (Complex Double) polySolve' v | dim v > 1 = unsafePerformIO $ do r <- createVector (dim v-1) - c_polySolve // vec v // vec r // check "polySolve" [v] + ww2 withVector v withVector r $ \ v r -> + c_polySolve // v // r // check "polySolve" return r | otherwise = error "polySolve on a polynomial of degree zero" diff --git a/lib/Numeric/GSL/Special/Internal.hs b/lib/Numeric/GSL/Special/Internal.hs index a08809b..ca36009 100644 --- a/lib/Numeric/GSL/Special/Internal.hs +++ b/lib/Numeric/GSL/Special/Internal.hs @@ -45,7 +45,7 @@ type Size_t = Int createSFR :: Storable a => String -> (Ptr a -> IO Int) -> (a, a) createSFR s f = unsafePerformIO $ do p <- mallocArray 2 - f p // check s [] + f p // check s [val,err] <- peekArray 2 p free p return (val,err) @@ -60,7 +60,7 @@ createSFR_E10 s f = unsafePerformIO $ do let sd = sizeOf (0::Double) let si = sizeOf (0::Int) p <- mallocBytes (2*sd + si) - f p // check s [] + f p // check s val <- peekByteOff p 0 err <- peekByteOff p sd expo <- peekByteOff p (2*sd) diff --git a/lib/Numeric/GSL/Vector.hs b/lib/Numeric/GSL/Vector.hs index d94b377..65f3a2e 100644 --- a/lib/Numeric/GSL/Vector.hs +++ b/lib/Numeric/GSL/Vector.hs @@ -73,24 +73,28 @@ data FunCodeS = Norm2 toScalarAux fun code v = unsafePerformIO $ do r <- createVector 1 - fun (fromEnum code) // vec v // vec r // check "toScalarAux" [v] + ww2 withVector v withVector r $ \v r -> + fun (fromEnum code) // v // r // check "toScalarAux" return (r `at` 0) vectorMapAux fun code v = unsafePerformIO $ do r <- createVector (dim v) - fun (fromEnum code) // vec v // vec r // check "vectorMapAux" [v] + ww2 withVector v withVector r $ \v r -> + fun (fromEnum code) // v // r // check "vectorMapAux" return r vectorMapValAux fun code val v = unsafePerformIO $ do r <- createVector (dim v) pval <- newArray [val] - fun (fromEnum code) pval // vec v // vec r // check "vectorMapValAux" [v] + ww2 withVector v withVector r $ \v r -> + fun (fromEnum code) pval // v // r // check "vectorMapValAux" free pval return r vectorZipAux fun code u v = unsafePerformIO $ do r <- createVector (dim u) - fun (fromEnum code) // vec u // vec v // vec r // check "vectorZipAux" [u,v] + ww3 withVector u withVector v withVector r $ \u v r -> + fun (fromEnum code) // u // v // r // check "vectorZipAux" return r --------------------------------------------------------------------- diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index 315be17..19516e3 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs @@ -61,7 +61,8 @@ svdAux f st x = unsafePerformIO $ do u <- createMatrix ColumnMajor r r s <- createVector (min r c) v <- createMatrix ColumnMajor c c - f // matf x // matf u // vec s // matf v // check st [fdat x] + ww4 withMatrix x withMatrix u withVector s withMatrix v $ \x u s v -> + f // x // u // s // v // check st return (u,s,trans v) where r = rows x c = cols x @@ -73,7 +74,8 @@ eigAux f st m l <- createVector r v <- createMatrix ColumnMajor r r dummy <- createMatrix ColumnMajor 1 1 - f // matf m // matf dummy // vec l // matf v // check st [fdat m] + ww4 withMatrix m withMatrix dummy withVector l withMatrix v $ \m dummy l v -> + f // m // dummy // l // v // check st return (l,v) where r = rows m @@ -115,7 +117,8 @@ eigRaux m l <- createVector r v <- createMatrix ColumnMajor r r dummy <- createMatrix ColumnMajor 1 1 - dgeev // matf m // matf dummy // vec l // matf v // check "eigR" [fdat m] + ww4 withMatrix m withMatrix dummy withVector l withMatrix v $ \m dummy l v -> + dgeev // m // dummy // l // v // check "eigR" return (l,v) where r = rows m @@ -144,7 +147,8 @@ eigS' m | otherwise = unsafePerformIO $ do l <- createVector r v <- createMatrix ColumnMajor r r - dsyev // matf m // vec l // matf v // check "eigS" [fdat m] + ww3 withMatrix m withVector l withMatrix v $ \m l v -> + dsyev // m // l // v // check "eigS" return (l,v) where r = rows m @@ -166,7 +170,8 @@ eigH' m | otherwise = unsafePerformIO $ do l <- createVector r v <- createMatrix ColumnMajor r r - zheev // matf m // vec l // matf v // check "eigH" [fdat m] + ww3 withMatrix m withVector l withMatrix v $ \m l v -> + zheev // m // l // v // check "eigH" return (l,v) where r = rows m @@ -177,7 +182,8 @@ foreign import ccall "LAPACK/lapack-aux.h linearSolveC_l" zgesv :: TCMCMCM linearSolveSQAux f st a b | n1==n2 && n1==r = unsafePerformIO $ do s <- createMatrix ColumnMajor r c - f // matf a // matf b // matf s // check st [fdat a, fdat b] + ww3 withMatrix a withMatrix b withMatrix s $ \a b s -> + f // a // b // s // check st return s | otherwise = error $ st ++ " of nonsquare matrix" where n1 = rows a @@ -201,7 +207,8 @@ foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDC_l" zgelss :: Double -> linearSolveAux f st a b = unsafePerformIO $ do r <- createMatrix ColumnMajor (max m n) nrhs - f // matf a // matf b // matf r // check st [fdat a, fdat b] + ww3 withMatrix a withMatrix b withMatrix r $ \a b r -> + f // a // b // r // check st return r where m = rows a n = cols a @@ -251,7 +258,8 @@ cholS = cholAux dpotrf "cholS" . fmat cholAux f st a = unsafePerformIO $ do r <- createMatrix ColumnMajor n n - f // matf a // matf r // check st [fdat a] + ww2 withMatrix a withMatrix r $ \a r -> + f // a // r // check st return r where n = rows a @@ -270,8 +278,8 @@ qrC = qrAux zgeqr2 "qrC" . fmat qrAux f st a = unsafePerformIO $ do r <- createMatrix ColumnMajor m n tau <- createVector mn - withForeignPtr (fptr $ fdat $ a) $ \p -> - f m n p // vec tau // matf r // check st [fdat a] + ww3 withMatrix a withMatrix r withVector tau $ \ a r tau -> + f // a // tau // r // check st return (r,tau) where m = rows a n = cols a @@ -292,7 +300,8 @@ hessC = hessAux zgehrd "hessC" . fmat hessAux f st a = unsafePerformIO $ do r <- createMatrix ColumnMajor m n tau <- createVector (mn-1) - f // matf a // vec tau // matf r // check st [fdat a] + ww3 withMatrix a withMatrix r withVector tau $ \ a r tau -> + f // a // tau // r // check st return (r,tau) where m = rows a n = cols a @@ -313,7 +322,8 @@ schurC = schurAux zgees "schurC" . fmat schurAux f st a = unsafePerformIO $ do u <- createMatrix ColumnMajor n n s <- createMatrix ColumnMajor n n - f // matf a // matf u // matf s // check st [fdat a] + ww3 withMatrix a withMatrix u withMatrix s $ \ a u s -> + f // a // u // s // check st return (u,s) where n = rows a -- cgit v1.2.3