summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-11-12 12:24:12 +0000
committerAlberto Ruiz <aruiz@um.es>2007-11-12 12:24:12 +0000
commit25d7892ac78f0f1a4fda538dd35430ebff02baaa (patch)
tree170572a869a5d73cd09bdf39b17fbb37b6e451fd
parent33a9909d0d59f468039597c405306b8d5fa9e008 (diff)
withMatrix
-rw-r--r--lib/Data/Packed/Internal/Common.hs5
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs41
-rw-r--r--lib/Data/Packed/Internal/Vector.hs34
-rw-r--r--lib/Numeric/GSL/Differentiation.hs2
-rw-r--r--lib/Numeric/GSL/Fourier.hs3
-rw-r--r--lib/Numeric/GSL/Integration.hs4
-rw-r--r--lib/Numeric/GSL/Matrix.hs52
-rw-r--r--lib/Numeric/GSL/Minimization.hs34
-rw-r--r--lib/Numeric/GSL/Polynomials.hs3
-rw-r--r--lib/Numeric/GSL/Special/Internal.hs4
-rw-r--r--lib/Numeric/GSL/Vector.hs12
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK.hs34
12 files changed, 146 insertions, 82 deletions
diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs
index 2b3ec28..c3a733c 100644
--- a/lib/Data/Packed/Internal/Common.hs
+++ b/lib/Data/Packed/Internal/Common.hs
@@ -60,6 +60,11 @@ common f = commonval . map f where
60infixl 0 // 60infixl 0 //
61(//) = flip ($) 61(//) = flip ($)
62 62
63-- hmm..
64ww2 w1 o1 w2 o2 f = w1 o1 $ \a1 -> w2 o2 $ \a2 -> f a1 a2
65ww3 w1 o1 w2 o2 w3 o3 f = w1 o1 $ \a1 -> ww2 w2 o2 w3 o3 (f a1)
66ww4 w1 o1 w2 o2 w3 o3 w4 o4 f = w1 o1 $ \a1 -> ww3 w2 o2 w3 o3 w4 o4 (f a1)
67
63-- GSL error codes are <= 1024 68-- GSL error codes are <= 1024
64-- | error codes for the auxiliary functions required by the wrappers 69-- | error codes for the auxiliary functions required by the wrappers
65errorCode :: Int -> String 70errorCode :: Int -> String
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 5617996..90a96b5 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -79,9 +79,20 @@ cmat MF {rows = r, cols = c, fdat = d } = MC {rows = r, cols = c, cdat = transda
79fmat m@MF{} = m 79fmat m@MF{} = m
80fmat MC {rows = r, cols = c, cdat = d } = MF {rows = r, cols = c, fdat = transdata c d r} 80fmat MC {rows = r, cols = c, cdat = d } = MF {rows = r, cols = c, fdat = transdata c d r}
81 81
82matc m f = f (rows m) (cols m) (ptr (cdat m)) 82--matc m f = f (rows m) (cols m) (ptr (cdat m))
83matf m f = f (rows m) (cols m) (ptr (fdat m)) 83--matf m f = f (rows m) (cols m) (ptr (fdat m))
84 84
85withMatrix MC {rows = r, cols = c, cdat = d } f =
86 withForeignPtr (fptr d) $ \p -> do
87 let m f = do
88 f r c p
89 f m
90
91withMatrix MF {rows = r, cols = c, fdat = d } f =
92 withForeignPtr (fptr d) $ \p -> do
93 let m f = do
94 f r c p
95 f m
85 96
86{- | Creates a vector by concatenation of rows 97{- | Creates a vector by concatenation of rows
87 98
@@ -236,7 +247,9 @@ transdataAux fun c1 d c2 =
236 then d 247 then d
237 else unsafePerformIO $ do 248 else unsafePerformIO $ do
238 v <- createVector (dim d) 249 v <- createVector (dim d)
239 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d,v] 250 withForeignPtr (fptr d) $ \pd ->
251 withForeignPtr (fptr v) $ \pv ->
252 fun r1 c1 pd r2 c2 pv // check "transdataAux"
240 -- putStrLn $ "---> transdataAux" ++ show (toList d) ++ show (toList v) 253 -- putStrLn $ "---> transdataAux" ++ show (toList d) ++ show (toList v)
241 return v 254 return v
242 where r1 = dim d `div` c1 255 where r1 = dim d `div` c1
@@ -250,8 +263,8 @@ foreign import ccall safe "auxi.h transC"
250 263
251------------------------------------------------------------------ 264------------------------------------------------------------------
252 265
253gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d) 266gmatC MF {rows = r, cols = c, fdat = d} p f = f 1 c r p
254gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d) 267gmatC MC {rows = r, cols = c, cdat = d} p f = f 0 r c p
255 268
256dtt MC { cdat = d } = d 269dtt MC { cdat = d } = d
257dtt MF { fdat = d } = d 270dtt MF { fdat = d } = d
@@ -260,7 +273,9 @@ multiplyAux fun a b = unsafePerformIO $ do
260 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ 273 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++
261 show (rows a,cols a) ++ " x " ++ show (rows b, cols b) 274 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
262 r <- createMatrix RowMajor (rows a) (cols b) 275 r <- createMatrix RowMajor (rows a) (cols b)
263 fun // gmatC a // gmatC b // matc r // check "multiplyAux" [dtt a, dtt b, cdat r] 276 withForeignPtr (fptr (dtt a)) $ \pa -> withForeignPtr (fptr (dtt b)) $ \pb ->
277 withMatrix r $ \r ->
278 fun // gmatC a pa // gmatC b pb // r // check "multiplyAux"
264 return r 279 return r
265 280
266multiplyR = multiplyAux cmultiplyR 281multiplyR = multiplyAux cmultiplyR
@@ -293,7 +308,8 @@ subMatrixR :: (Int,Int) -> (Int,Int) -> Matrix Double -> Matrix Double
293subMatrixR (r0,c0) (rt,ct) x' = unsafePerformIO $ do 308subMatrixR (r0,c0) (rt,ct) x' = unsafePerformIO $ do
294 r <- createMatrix RowMajor rt ct 309 r <- createMatrix RowMajor rt ct
295 let x = cmat x' 310 let x = cmat x'
296 c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // matc x // matc r // check "subMatrixR" [cdat x] 311 ww2 withMatrix x withMatrix r $ \x r ->
312 c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // x // r // check "subMatrixR"
297 return r 313 return r
298foreign import ccall "auxi.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM 314foreign import ccall "auxi.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM
299 315
@@ -317,8 +333,9 @@ subMatrix = subMatrixD
317 333
318diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do 334diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do
319 m <- createMatrix RowMajor n n 335 m <- createMatrix RowMajor n n
320 fun // vec v // matc m // check msg [v] 336 ww2 withVector v withMatrix m $ \v m ->
321 return m -- {tdat = dat m} 337 fun // v // m // check msg
338 return m
322 339
323-- | diagonal matrix from a real vector 340-- | diagonal matrix from a real vector
324diagR :: Vector Double -> Matrix Double 341diagR :: Vector Double -> Matrix Double
@@ -339,7 +356,8 @@ diag = diagD
339constantAux fun x n = unsafePerformIO $ do 356constantAux fun x n = unsafePerformIO $ do
340 v <- createVector n 357 v <- createVector n
341 px <- newArray [x] 358 px <- newArray [x]
342 fun px // vec v // check "constantAux" [] 359 withVector v $ \v ->
360 fun px // v // check "constantAux"
343 free px 361 free px
344 return v 362 return v
345 363
@@ -385,7 +403,8 @@ fromFile :: FilePath -> (Int,Int) -> IO (Matrix Double)
385fromFile filename (r,c) = do 403fromFile filename (r,c) = do
386 charname <- newCString filename 404 charname <- newCString filename
387 res <- createMatrix RowMajor r c 405 res <- createMatrix RowMajor r c
388 c_gslReadMatrix charname // matc res // check "gslReadMatrix" [] 406 withMatrix res $ \res ->
407 c_gslReadMatrix charname // res // check "gslReadMatrix"
389 --free charname -- TO DO: free the auxiliary CString 408 --free charname -- TO DO: free the auxiliary CString
390 return res 409 return res
391foreign import ccall "auxi.h matrix_fscanf" c_gslReadMatrix:: Ptr CChar -> TM 410foreign import ccall "auxi.h matrix_fscanf" c_gslReadMatrix:: Ptr CChar -> TM
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index dc86484..386ebb5 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -31,11 +31,11 @@ data Vector t = V { dim :: Int -- ^ number of elements
31 , fptr :: ForeignPtr t -- ^ foreign pointer to the memory block 31 , fptr :: ForeignPtr t -- ^ foreign pointer to the memory block
32 } 32 }
33 33
34ptr (V _ fptr) = unsafeForeignPtrToPtr fptr 34--ptr (V _ fptr) = unsafeForeignPtrToPtr fptr
35 35
36-- | check the error code and touch foreign ptr of vector arguments (if any) 36-- | check the error code
37check :: String -> [Vector a] -> IO Int -> IO () 37check :: String -> IO Int -> IO ()
38check msg ls f = do 38check msg f = do
39 err <- f 39 err <- f
40 when (err/=0) $ if err > 1024 40 when (err/=0) $ if err > 1024
41 then (error (msg++": "++errorCode err)) -- our errors 41 then (error (msg++": "++errorCode err)) -- our errors
@@ -43,7 +43,6 @@ check msg ls f = do
43 ps <- gsl_strerror err 43 ps <- gsl_strerror err
44 s <- peekCString ps 44 s <- peekCString ps
45 error (msg++": "++s) 45 error (msg++": "++s)
46 mapM_ (touchForeignPtr . fptr) ls
47 return () 46 return ()
48 47
49-- | description of GSL error codes 48-- | description of GSL error codes
@@ -55,9 +54,14 @@ type Vc t s = Int -> Ptr t -> s
55-- infixr 5 :> 54-- infixr 5 :>
56-- type t :> s = Vc t s 55-- type t :> s = Vc t s
57 56
58-- | adaptation of our vectors to be admitted by foreign functions: @f \/\/ vec v@ 57--- | adaptation of our vectors to be admitted by foreign functions: @f \/\/ vec v@
59vec :: Vector t -> (Vc t s) -> s 58--vec :: Vector t -> (Vc t s) -> s
60vec v f = f (dim v) (ptr v) 59--vec v f = f (dim v) (ptr v)
60
61withVector (V n fp) f = withForeignPtr fp $ \p -> do
62 let v f = do
63 f n p
64 f v
61 65
62-- | allocates memory for a new vector 66-- | allocates memory for a new vector
63createVector :: Storable a => Int -> IO (Vector a) 67createVector :: Storable a => Int -> IO (Vector a)
@@ -76,7 +80,8 @@ fromList :: Storable a => [a] -> Vector a
76fromList l = unsafePerformIO $ do 80fromList l = unsafePerformIO $ do
77 v <- createVector (length l) 81 v <- createVector (length l)
78 let f _ p = pokeArray p l >> return 0 82 let f _ p = pokeArray p l >> return 0
79 f // vec v // check "fromList" [] 83 withVector v $ \v ->
84 f // v // check "fromList"
80 return v 85 return v
81 86
82safeRead v = unsafePerformIO . withForeignPtr (fptr v) 87safeRead v = unsafePerformIO . withForeignPtr (fptr v)
@@ -118,8 +123,9 @@ subVector k l (v@V {dim=n})
118 | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" 123 | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range"
119 | otherwise = unsafePerformIO $ do 124 | otherwise = unsafePerformIO $ do
120 r <- createVector l 125 r <- createVector l
121 let f = copyArray (ptr r) (advancePtr (ptr v) k) l >> return 0 126 let f _ s _ d = copyArray d (advancePtr s k) l >> return 0
122 f // check "subVector" [v,r] 127 ww2 withVector v withVector r $ \v r ->
128 f // v // r // check "subVector"
123 return r 129 return r
124 130
125{- | Reads a vector position: 131{- | Reads a vector position:
@@ -144,12 +150,12 @@ join [] = error "joining zero vectors"
144join as = unsafePerformIO $ do 150join as = unsafePerformIO $ do
145 let tot = sum (map dim as) 151 let tot = sum (map dim as)
146 r@V {fptr = p} <- createVector tot 152 r@V {fptr = p} <- createVector tot
147 withForeignPtr p $ \_ -> 153 withForeignPtr p $ \ptr ->
148 joiner as tot (ptr r) 154 joiner as tot ptr
149 return r 155 return r
150 where joiner [] _ _ = return () 156 where joiner [] _ _ = return ()
151 joiner (r@V {dim = n, fptr = b} : cs) _ p = do 157 joiner (r@V {dim = n, fptr = b} : cs) _ p = do
152 withForeignPtr b $ \_ -> copyArray p (ptr r) n 158 withForeignPtr b $ \pb -> copyArray p pb n
153 joiner cs 0 (advancePtr p n) 159 joiner cs 0 (advancePtr p n)
154 160
155 161
diff --git a/lib/Numeric/GSL/Differentiation.hs b/lib/Numeric/GSL/Differentiation.hs
index e7fea92..09236bd 100644
--- a/lib/Numeric/GSL/Differentiation.hs
+++ b/lib/Numeric/GSL/Differentiation.hs
@@ -35,7 +35,7 @@ derivGen c h f x = unsafePerformIO $ do
35 r <- malloc 35 r <- malloc
36 e <- malloc 36 e <- malloc
37 fp <- mkfun (\x _ -> f x) 37 fp <- mkfun (\x _ -> f x)
38 c_deriv c fp x h r e // check "deriv" [] 38 c_deriv c fp x h r e // check "deriv"
39 vr <- peek r 39 vr <- peek r
40 ve <- peek e 40 ve <- peek e
41 let result = (vr,ve) 41 let result = (vr,ve)
diff --git a/lib/Numeric/GSL/Fourier.hs b/lib/Numeric/GSL/Fourier.hs
index e975fbf..4b08625 100644
--- a/lib/Numeric/GSL/Fourier.hs
+++ b/lib/Numeric/GSL/Fourier.hs
@@ -26,7 +26,8 @@ import Foreign
26 26
27genfft code v = unsafePerformIO $ do 27genfft code v = unsafePerformIO $ do
28 r <- createVector (dim v) 28 r <- createVector (dim v)
29 c_fft code // vec v // vec r // check "fft" [v] 29 ww2 withVector v withVector r $ \ v r ->
30 c_fft code // v // r // check "fft"
30 return r 31 return r
31 32
32foreign import ccall "gsl-aux.h fft" c_fft :: Int -> TCVCV 33foreign import ccall "gsl-aux.h fft" c_fft :: Int -> TCVCV
diff --git a/lib/Numeric/GSL/Integration.hs b/lib/Numeric/GSL/Integration.hs
index d756417..747b34c 100644
--- a/lib/Numeric/GSL/Integration.hs
+++ b/lib/Numeric/GSL/Integration.hs
@@ -44,7 +44,7 @@ integrateQAGS prec n f a b = unsafePerformIO $ do
44 r <- malloc 44 r <- malloc
45 e <- malloc 45 e <- malloc
46 fp <- mkfun (\x _ -> f x) 46 fp <- mkfun (\x _ -> f x)
47 c_integrate_qags fp a b prec n r e // check "integrate_qags" [] 47 c_integrate_qags fp a b prec n r e // check "integrate_qags"
48 vr <- peek r 48 vr <- peek r
49 ve <- peek e 49 ve <- peek e
50 let result = (vr,ve) 50 let result = (vr,ve)
@@ -75,7 +75,7 @@ integrateQNG prec f a b = unsafePerformIO $ do
75 r <- malloc 75 r <- malloc
76 e <- malloc 76 e <- malloc
77 fp <- mkfun (\x _ -> f x) 77 fp <- mkfun (\x _ -> f x)
78 c_integrate_qng fp a b prec r e // check "integrate_qng" [] 78 c_integrate_qng fp a b prec r e // check "integrate_qng"
79 vr <- peek r 79 vr <- peek r
80 ve <- peek e 80 ve <- peek e
81 let result = (vr,ve) 81 let result = (vr,ve)
diff --git a/lib/Numeric/GSL/Matrix.hs b/lib/Numeric/GSL/Matrix.hs
index e803c53..09a0be4 100644
--- a/lib/Numeric/GSL/Matrix.hs
+++ b/lib/Numeric/GSL/Matrix.hs
@@ -51,7 +51,8 @@ eigSg' m
51 | otherwise = unsafePerformIO $ do 51 | otherwise = unsafePerformIO $ do
52 l <- createVector r 52 l <- createVector r
53 v <- createMatrix RowMajor r r 53 v <- createMatrix RowMajor r r
54 c_eigS // matc m // vec l // matc v // check "eigSg" [cdat m] 54 ww3 withMatrix m withVector l withMatrix v $ \m l v ->
55 c_eigS // m // l // v // check "eigSg"
55 return (l,v) 56 return (l,v)
56 where r = rows m 57 where r = rows m
57foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM 58foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM
@@ -84,7 +85,8 @@ eigHg' m
84 | otherwise = unsafePerformIO $ do 85 | otherwise = unsafePerformIO $ do
85 l <- createVector r 86 l <- createVector r
86 v <- createMatrix RowMajor r r 87 v <- createMatrix RowMajor r r
87 c_eigH // matc m // vec l // matc v // check "eigHg" [cdat m] 88 ww3 withMatrix m withVector l withMatrix v $ \m l v ->
89 c_eigH // m // l // v // check "eigHg"
88 return (l,v) 90 return (l,v)
89 where r = rows m 91 where r = rows m
90foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM 92foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM
@@ -120,7 +122,8 @@ svd' x = unsafePerformIO $ do
120 u <- createMatrix RowMajor r c 122 u <- createMatrix RowMajor r c
121 s <- createVector c 123 s <- createVector c
122 v <- createMatrix RowMajor c c 124 v <- createMatrix RowMajor c c
123 c_svd // matc x // matc u // vec s // matc v // check "svdg" [cdat x] 125 ww4 withMatrix x withMatrix u withVector s withMatrix v $ \x u s v ->
126 c_svd // x // u // s // v // check "svdg"
124 return (u,s,v) 127 return (u,s,v)
125 where r = rows x 128 where r = rows x
126 c = cols x 129 c = cols x
@@ -149,7 +152,8 @@ qr = qr' . cmat
149qr' x = unsafePerformIO $ do 152qr' x = unsafePerformIO $ do
150 q <- createMatrix RowMajor r r 153 q <- createMatrix RowMajor r r
151 rot <- createMatrix RowMajor r c 154 rot <- createMatrix RowMajor r c
152 c_qr // matc x // matc q // matc rot // check "qr" [cdat x] 155 ww3 withMatrix x withMatrix q withMatrix rot $ \x q rot ->
156 c_qr // x // q // rot // check "qr"
153 return (q,rot) 157 return (q,rot)
154 where r = rows x 158 where r = rows x
155 c = cols x 159 c = cols x
@@ -161,7 +165,8 @@ qrPacked = qrPacked' . cmat
161qrPacked' x = unsafePerformIO $ do 165qrPacked' x = unsafePerformIO $ do
162 qr <- createMatrix RowMajor r c 166 qr <- createMatrix RowMajor r c
163 tau <- createVector (min r c) 167 tau <- createVector (min r c)
164 c_qrPacked // matc x // matc qr // vec tau // check "qrUnpacked" [cdat x] 168 ww3 withMatrix x withMatrix qr withVector tau $ \x qr tau ->
169 c_qrPacked // x // qr // tau // check "qrUnpacked"
165 return (qr,tau) 170 return (qr,tau)
166 where r = rows x 171 where r = rows x
167 c = cols x 172 c = cols x
@@ -172,9 +177,10 @@ unpackQR (qr,tau) = unpackQR' (cmat qr, tau)
172 177
173unpackQR' (qr,tau) = unsafePerformIO $ do 178unpackQR' (qr,tau) = unsafePerformIO $ do
174 q <- createMatrix RowMajor r r 179 q <- createMatrix RowMajor r r
175 rot <- createMatrix RowMajor r c 180 res <- createMatrix RowMajor r c
176 c_qrUnpack // matc qr // vec tau // matc q // matc rot // check "qrUnpack" [cdat qr,tau] 181 ww4 withMatrix qr withVector tau withMatrix q withMatrix res $ \qr tau q res ->
177 return (q,rot) 182 c_qrUnpack // qr // tau // q // res // check "qrUnpack"
183 return (q,res)
178 where r = rows qr 184 where r = rows qr
179 c = cols qr 185 c = cols qr
180foreign import ccall "gsl-aux.h QRunpack" c_qrUnpack :: TMVMM 186foreign import ccall "gsl-aux.h QRunpack" c_qrUnpack :: TMVMM
@@ -196,20 +202,22 @@ cholR :: Matrix Double -> Matrix Double
196cholR = cholR' . cmat 202cholR = cholR' . cmat
197 203
198cholR' x = unsafePerformIO $ do 204cholR' x = unsafePerformIO $ do
199 res <- createMatrix RowMajor r r 205 r <- createMatrix RowMajor n n
200 c_cholR // matc x // matc res // check "cholR" [cdat x] 206 ww2 withMatrix x withMatrix r $ \x r ->
201 return res 207 c_cholR // x // r // check "cholR"
202 where r = rows x 208 return r
209 where n = rows x
203foreign import ccall "gsl-aux.h cholR" c_cholR :: TMM 210foreign import ccall "gsl-aux.h cholR" c_cholR :: TMM
204 211
205cholC :: Matrix (Complex Double) -> Matrix (Complex Double) 212cholC :: Matrix (Complex Double) -> Matrix (Complex Double)
206cholC = cholC' . cmat 213cholC = cholC' . cmat
207 214
208cholC' x = unsafePerformIO $ do 215cholC' x = unsafePerformIO $ do
209 res <- createMatrix RowMajor r r 216 r <- createMatrix RowMajor n n
210 c_cholC // matc x // matc res // check "cholC" [cdat x] 217 ww2 withMatrix x withMatrix r $ \x r ->
211 return res 218 c_cholC // x // r // check "cholC"
212 where r = rows x 219 return r
220 where n = rows x
213foreign import ccall "gsl-aux.h cholC" c_cholC :: TCMCM 221foreign import ccall "gsl-aux.h cholC" c_cholC :: TCMCM
214 222
215 223
@@ -223,7 +231,8 @@ luSolveR a b = luSolveR' (cmat a) (cmat b)
223luSolveR' a b 231luSolveR' a b
224 | n1==n2 && n1==r = unsafePerformIO $ do 232 | n1==n2 && n1==r = unsafePerformIO $ do
225 s <- createMatrix RowMajor r c 233 s <- createMatrix RowMajor r c
226 c_luSolveR // matc a // matc b // matc s // check "luSolveR" [cdat a, cdat b] 234 ww3 withMatrix a withMatrix b withMatrix s $ \ a b s ->
235 c_luSolveR // a // b // s // check "luSolveR"
227 return s 236 return s
228 | otherwise = error "luSolveR of nonsquare matrix" 237 | otherwise = error "luSolveR of nonsquare matrix"
229 where n1 = rows a 238 where n1 = rows a
@@ -240,7 +249,8 @@ luSolveC a b = luSolveC' (cmat a) (cmat b)
240luSolveC' a b 249luSolveC' a b
241 | n1==n2 && n1==r = unsafePerformIO $ do 250 | n1==n2 && n1==r = unsafePerformIO $ do
242 s <- createMatrix RowMajor r c 251 s <- createMatrix RowMajor r c
243 c_luSolveC // matc a // matc b // matc s // check "luSolveC" [cdat a, cdat b] 252 ww3 withMatrix a withMatrix b withMatrix s $ \ a b s ->
253 c_luSolveC // a // b // s // check "luSolveC"
244 return s 254 return s
245 | otherwise = error "luSolveC of nonsquare matrix" 255 | otherwise = error "luSolveC of nonsquare matrix"
246 where n1 = rows a 256 where n1 = rows a
@@ -256,7 +266,8 @@ luRaux = luRaux' . cmat
256 266
257luRaux' x = unsafePerformIO $ do 267luRaux' x = unsafePerformIO $ do
258 res <- createVector (r*r+r+1) 268 res <- createVector (r*r+r+1)
259 c_luRaux // matc x // vec res // check "luRaux" [cdat x] 269 ww2 withMatrix x withVector res $ \x res ->
270 c_luRaux // x // res // check "luRaux"
260 return res 271 return res
261 where r = rows x 272 where r = rows x
262 c = cols x 273 c = cols x
@@ -269,7 +280,8 @@ luCaux = luCaux' . cmat
269 280
270luCaux' x = unsafePerformIO $ do 281luCaux' x = unsafePerformIO $ do
271 res <- createVector (r*r+r+1) 282 res <- createVector (r*r+r+1)
272 c_luCaux // matc x // vec res // check "luCaux" [cdat x] 283 ww2 withMatrix x withVector res $ \x res ->
284 c_luCaux // x // res // check "luCaux"
273 return res 285 return res
274 where r = rows x 286 where r = rows x
275 c = cols x 287 c = cols x
diff --git a/lib/Numeric/GSL/Minimization.hs b/lib/Numeric/GSL/Minimization.hs
index f523849..e44b2e5 100644
--- a/lib/Numeric/GSL/Minimization.hs
+++ b/lib/Numeric/GSL/Minimization.hs
@@ -26,7 +26,6 @@ import Data.Packed.Matrix
26import Foreign 26import Foreign
27import Complex 27import Complex
28 28
29
30------------------------------------------------------------------------- 29-------------------------------------------------------------------------
31 30
32{- | The method of Nelder and Mead, implemented by /gsl_multimin_fminimizer_nmsimplex/. The gradient of the function is not required. This is the example in the GSL manual: 31{- | The method of Nelder and Mead, implemented by /gsl_multimin_fminimizer_nmsimplex/. The gradient of the function is not required. This is the example in the GSL manual:
@@ -85,9 +84,10 @@ minimizeNMSimplex f xi sz tol maxit = unsafePerformIO $ do
85 szv = fromList sz 84 szv = fromList sz
86 n = dim xiv 85 n = dim xiv
87 fp <- mkVecfun (iv (f.toList)) 86 fp <- mkVecfun (iv (f.toList))
88 rawpath <- createMIO maxit (n+3) 87 rawpath <- ww2 withVector xiv withVector szv $ \xiv szv ->
89 (c_minimizeNMSimplex fp tol maxit // vec xiv // vec szv) 88 createMIO maxit (n+3)
90 "minimizeNMSimplex" [xiv,szv] 89 (c_minimizeNMSimplex fp tol maxit // xiv // szv)
90 "minimizeNMSimplex"
91 let it = round (rawpath @@> (maxit-1,0)) 91 let it = round (rawpath @@> (maxit-1,0))
92 path = takeRows it rawpath 92 path = takeRows it rawpath
93 [sol] = toLists $ dropRows (it-1) path 93 [sol] = toLists $ dropRows (it-1) path
@@ -150,9 +150,10 @@ minimizeConjugateGradient istep minimpar tol maxit f df xi = unsafePerformIO $ d
150 df' = (fromList . df . toList) 150 df' = (fromList . df . toList)
151 fp <- mkVecfun (iv f') 151 fp <- mkVecfun (iv f')
152 dfp <- mkVecVecfun (aux_vTov df') 152 dfp <- mkVecVecfun (aux_vTov df')
153 rawpath <- createMIO maxit (n+2) 153 rawpath <- withVector xiv $ \xiv ->
154 (c_minimizeConjugateGradient fp dfp istep minimpar tol maxit // vec xiv) 154 createMIO maxit (n+2)
155 "minimizeDerivV" [xiv] 155 (c_minimizeConjugateGradient fp dfp istep minimpar tol maxit // xiv)
156 "minimizeDerivV"
156 let it = round (rawpath @@> (maxit-1,0)) 157 let it = round (rawpath @@> (maxit-1,0))
157 path = takeRows it rawpath 158 path = takeRows it rawpath
158 sol = toList $ cdat $ dropColumns 2 $ dropRows (it-1) path 159 sol = toList $ cdat $ dropColumns 2 $ dropRows (it-1) path
@@ -169,7 +170,7 @@ foreign import ccall "gsl-aux.h minimizeWithDeriv"
169 170
170--------------------------------------------------------------------- 171---------------------------------------------------------------------
171iv :: (Vector Double -> Double) -> (Int -> Ptr Double -> Double) 172iv :: (Vector Double -> Double) -> (Int -> Ptr Double -> Double)
172iv f n p = f (createV n copy "iv" []) where 173iv f n p = f (createV n copy "iv") where
173 copy n q = do 174 copy n q = do
174 copyArray q p n 175 copyArray q p n
175 return 0 176 return 0
@@ -187,25 +188,30 @@ foreign import ccall "wrapper"
187aux_vTov :: (Vector Double -> Vector Double) -> (Int -> Ptr Double -> Ptr Double -> IO()) 188aux_vTov :: (Vector Double -> Vector Double) -> (Int -> Ptr Double -> Ptr Double -> IO())
188aux_vTov f n p r = g where 189aux_vTov f n p r = g where
189 v@V {fptr = pr} = f x 190 v@V {fptr = pr} = f x
190 x = createV n copy "aux_vTov" [] 191 x = createV n copy "aux_vTov"
191 copy n q = do 192 copy n q = do
192 copyArray q p n 193 copyArray q p n
193 return 0 194 return 0
194 g = withForeignPtr pr $ \_ -> copyArray r (ptr v) n 195 g = withForeignPtr pr $ \p -> copyArray r p n
195 196
196-------------------------------------------------------------------- 197--------------------------------------------------------------------
197 198
198createV n fun msg ptrs = unsafePerformIO $ do 199
200createV n fun msg = unsafePerformIO $ do
199 r <- createVector n 201 r <- createVector n
200 fun // vec r // check msg ptrs 202 withVector r $ \ r ->
203 fun // r // check msg
201 return r 204 return r
202 205
206{-
203createM r c fun msg ptrs = unsafePerformIO $ do 207createM r c fun msg ptrs = unsafePerformIO $ do
204 r <- createMatrix RowMajor r c 208 r <- createMatrix RowMajor r c
205 fun // matc r // check msg ptrs 209 fun // matc r // check msg ptrs
206 return r 210 return r
211-}
207 212
208createMIO r c fun msg ptrs = do 213createMIO r c fun msg = do
209 r <- createMatrix RowMajor r c 214 r <- createMatrix RowMajor r c
210 fun // matc r // check msg ptrs 215 withMatrix r $ \ r ->
216 fun // r // check msg
211 return r 217 return r
diff --git a/lib/Numeric/GSL/Polynomials.hs b/lib/Numeric/GSL/Polynomials.hs
index 42694f0..e663711 100644
--- a/lib/Numeric/GSL/Polynomials.hs
+++ b/lib/Numeric/GSL/Polynomials.hs
@@ -47,7 +47,8 @@ polySolve = toList . polySolve' . fromList
47polySolve' :: Vector Double -> Vector (Complex Double) 47polySolve' :: Vector Double -> Vector (Complex Double)
48polySolve' v | dim v > 1 = unsafePerformIO $ do 48polySolve' v | dim v > 1 = unsafePerformIO $ do
49 r <- createVector (dim v-1) 49 r <- createVector (dim v-1)
50 c_polySolve // vec v // vec r // check "polySolve" [v] 50 ww2 withVector v withVector r $ \ v r ->
51 c_polySolve // v // r // check "polySolve"
51 return r 52 return r
52 | otherwise = error "polySolve on a polynomial of degree zero" 53 | otherwise = error "polySolve on a polynomial of degree zero"
53 54
diff --git a/lib/Numeric/GSL/Special/Internal.hs b/lib/Numeric/GSL/Special/Internal.hs
index a08809b..ca36009 100644
--- a/lib/Numeric/GSL/Special/Internal.hs
+++ b/lib/Numeric/GSL/Special/Internal.hs
@@ -45,7 +45,7 @@ type Size_t = Int
45createSFR :: Storable a => String -> (Ptr a -> IO Int) -> (a, a) 45createSFR :: Storable a => String -> (Ptr a -> IO Int) -> (a, a)
46createSFR s f = unsafePerformIO $ do 46createSFR s f = unsafePerformIO $ do
47 p <- mallocArray 2 47 p <- mallocArray 2
48 f p // check s [] 48 f p // check s
49 [val,err] <- peekArray 2 p 49 [val,err] <- peekArray 2 p
50 free p 50 free p
51 return (val,err) 51 return (val,err)
@@ -60,7 +60,7 @@ createSFR_E10 s f = unsafePerformIO $ do
60 let sd = sizeOf (0::Double) 60 let sd = sizeOf (0::Double)
61 let si = sizeOf (0::Int) 61 let si = sizeOf (0::Int)
62 p <- mallocBytes (2*sd + si) 62 p <- mallocBytes (2*sd + si)
63 f p // check s [] 63 f p // check s
64 val <- peekByteOff p 0 64 val <- peekByteOff p 0
65 err <- peekByteOff p sd 65 err <- peekByteOff p sd
66 expo <- peekByteOff p (2*sd) 66 expo <- peekByteOff p (2*sd)
diff --git a/lib/Numeric/GSL/Vector.hs b/lib/Numeric/GSL/Vector.hs
index d94b377..65f3a2e 100644
--- a/lib/Numeric/GSL/Vector.hs
+++ b/lib/Numeric/GSL/Vector.hs
@@ -73,24 +73,28 @@ data FunCodeS = Norm2
73 73
74toScalarAux fun code v = unsafePerformIO $ do 74toScalarAux fun code v = unsafePerformIO $ do
75 r <- createVector 1 75 r <- createVector 1
76 fun (fromEnum code) // vec v // vec r // check "toScalarAux" [v] 76 ww2 withVector v withVector r $ \v r ->
77 fun (fromEnum code) // v // r // check "toScalarAux"
77 return (r `at` 0) 78 return (r `at` 0)
78 79
79vectorMapAux fun code v = unsafePerformIO $ do 80vectorMapAux fun code v = unsafePerformIO $ do
80 r <- createVector (dim v) 81 r <- createVector (dim v)
81 fun (fromEnum code) // vec v // vec r // check "vectorMapAux" [v] 82 ww2 withVector v withVector r $ \v r ->
83 fun (fromEnum code) // v // r // check "vectorMapAux"
82 return r 84 return r
83 85
84vectorMapValAux fun code val v = unsafePerformIO $ do 86vectorMapValAux fun code val v = unsafePerformIO $ do
85 r <- createVector (dim v) 87 r <- createVector (dim v)
86 pval <- newArray [val] 88 pval <- newArray [val]
87 fun (fromEnum code) pval // vec v // vec r // check "vectorMapValAux" [v] 89 ww2 withVector v withVector r $ \v r ->
90 fun (fromEnum code) pval // v // r // check "vectorMapValAux"
88 free pval 91 free pval
89 return r 92 return r
90 93
91vectorZipAux fun code u v = unsafePerformIO $ do 94vectorZipAux fun code u v = unsafePerformIO $ do
92 r <- createVector (dim u) 95 r <- createVector (dim u)
93 fun (fromEnum code) // vec u // vec v // vec r // check "vectorZipAux" [u,v] 96 ww3 withVector u withVector v withVector r $ \u v r ->
97 fun (fromEnum code) // u // v // r // check "vectorZipAux"
94 return r 98 return r
95 99
96--------------------------------------------------------------------- 100---------------------------------------------------------------------
diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs
index 315be17..19516e3 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK.hs
+++ b/lib/Numeric/LinearAlgebra/LAPACK.hs
@@ -61,7 +61,8 @@ svdAux f st x = unsafePerformIO $ do
61 u <- createMatrix ColumnMajor r r 61 u <- createMatrix ColumnMajor r r
62 s <- createVector (min r c) 62 s <- createVector (min r c)
63 v <- createMatrix ColumnMajor c c 63 v <- createMatrix ColumnMajor c c
64 f // matf x // matf u // vec s // matf v // check st [fdat x] 64 ww4 withMatrix x withMatrix u withVector s withMatrix v $ \x u s v ->
65 f // x // u // s // v // check st
65 return (u,s,trans v) 66 return (u,s,trans v)
66 where r = rows x 67 where r = rows x
67 c = cols x 68 c = cols x
@@ -73,7 +74,8 @@ eigAux f st m
73 l <- createVector r 74 l <- createVector r
74 v <- createMatrix ColumnMajor r r 75 v <- createMatrix ColumnMajor r r
75 dummy <- createMatrix ColumnMajor 1 1 76 dummy <- createMatrix ColumnMajor 1 1
76 f // matf m // matf dummy // vec l // matf v // check st [fdat m] 77 ww4 withMatrix m withMatrix dummy withVector l withMatrix v $ \m dummy l v ->
78 f // m // dummy // l // v // check st
77 return (l,v) 79 return (l,v)
78 where r = rows m 80 where r = rows m
79 81
@@ -115,7 +117,8 @@ eigRaux m
115 l <- createVector r 117 l <- createVector r
116 v <- createMatrix ColumnMajor r r 118 v <- createMatrix ColumnMajor r r
117 dummy <- createMatrix ColumnMajor 1 1 119 dummy <- createMatrix ColumnMajor 1 1
118 dgeev // matf m // matf dummy // vec l // matf v // check "eigR" [fdat m] 120 ww4 withMatrix m withMatrix dummy withVector l withMatrix v $ \m dummy l v ->
121 dgeev // m // dummy // l // v // check "eigR"
119 return (l,v) 122 return (l,v)
120 where r = rows m 123 where r = rows m
121 124
@@ -144,7 +147,8 @@ eigS' m
144 | otherwise = unsafePerformIO $ do 147 | otherwise = unsafePerformIO $ do
145 l <- createVector r 148 l <- createVector r
146 v <- createMatrix ColumnMajor r r 149 v <- createMatrix ColumnMajor r r
147 dsyev // matf m // vec l // matf v // check "eigS" [fdat m] 150 ww3 withMatrix m withVector l withMatrix v $ \m l v ->
151 dsyev // m // l // v // check "eigS"
148 return (l,v) 152 return (l,v)
149 where r = rows m 153 where r = rows m
150 154
@@ -166,7 +170,8 @@ eigH' m
166 | otherwise = unsafePerformIO $ do 170 | otherwise = unsafePerformIO $ do
167 l <- createVector r 171 l <- createVector r
168 v <- createMatrix ColumnMajor r r 172 v <- createMatrix ColumnMajor r r
169 zheev // matf m // vec l // matf v // check "eigH" [fdat m] 173 ww3 withMatrix m withVector l withMatrix v $ \m l v ->
174 zheev // m // l // v // check "eigH"
170 return (l,v) 175 return (l,v)
171 where r = rows m 176 where r = rows m
172 177
@@ -177,7 +182,8 @@ foreign import ccall "LAPACK/lapack-aux.h linearSolveC_l" zgesv :: TCMCMCM
177linearSolveSQAux f st a b 182linearSolveSQAux f st a b
178 | n1==n2 && n1==r = unsafePerformIO $ do 183 | n1==n2 && n1==r = unsafePerformIO $ do
179 s <- createMatrix ColumnMajor r c 184 s <- createMatrix ColumnMajor r c
180 f // matf a // matf b // matf s // check st [fdat a, fdat b] 185 ww3 withMatrix a withMatrix b withMatrix s $ \a b s ->
186 f // a // b // s // check st
181 return s 187 return s
182 | otherwise = error $ st ++ " of nonsquare matrix" 188 | otherwise = error $ st ++ " of nonsquare matrix"
183 where n1 = rows a 189 where n1 = rows a
@@ -201,7 +207,8 @@ foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDC_l" zgelss :: Double ->
201 207
202linearSolveAux f st a b = unsafePerformIO $ do 208linearSolveAux f st a b = unsafePerformIO $ do
203 r <- createMatrix ColumnMajor (max m n) nrhs 209 r <- createMatrix ColumnMajor (max m n) nrhs
204 f // matf a // matf b // matf r // check st [fdat a, fdat b] 210 ww3 withMatrix a withMatrix b withMatrix r $ \a b r ->
211 f // a // b // r // check st
205 return r 212 return r
206 where m = rows a 213 where m = rows a
207 n = cols a 214 n = cols a
@@ -251,7 +258,8 @@ cholS = cholAux dpotrf "cholS" . fmat
251 258
252cholAux f st a = unsafePerformIO $ do 259cholAux f st a = unsafePerformIO $ do
253 r <- createMatrix ColumnMajor n n 260 r <- createMatrix ColumnMajor n n
254 f // matf a // matf r // check st [fdat a] 261 ww2 withMatrix a withMatrix r $ \a r ->
262 f // a // r // check st
255 return r 263 return r
256 where n = rows a 264 where n = rows a
257 265
@@ -270,8 +278,8 @@ qrC = qrAux zgeqr2 "qrC" . fmat
270qrAux f st a = unsafePerformIO $ do 278qrAux f st a = unsafePerformIO $ do
271 r <- createMatrix ColumnMajor m n 279 r <- createMatrix ColumnMajor m n
272 tau <- createVector mn 280 tau <- createVector mn
273 withForeignPtr (fptr $ fdat $ a) $ \p -> 281 ww3 withMatrix a withMatrix r withVector tau $ \ a r tau ->
274 f m n p // vec tau // matf r // check st [fdat a] 282 f // a // tau // r // check st
275 return (r,tau) 283 return (r,tau)
276 where m = rows a 284 where m = rows a
277 n = cols a 285 n = cols a
@@ -292,7 +300,8 @@ hessC = hessAux zgehrd "hessC" . fmat
292hessAux f st a = unsafePerformIO $ do 300hessAux f st a = unsafePerformIO $ do
293 r <- createMatrix ColumnMajor m n 301 r <- createMatrix ColumnMajor m n
294 tau <- createVector (mn-1) 302 tau <- createVector (mn-1)
295 f // matf a // vec tau // matf r // check st [fdat a] 303 ww3 withMatrix a withMatrix r withVector tau $ \ a r tau ->
304 f // a // tau // r // check st
296 return (r,tau) 305 return (r,tau)
297 where m = rows a 306 where m = rows a
298 n = cols a 307 n = cols a
@@ -313,7 +322,8 @@ schurC = schurAux zgees "schurC" . fmat
313schurAux f st a = unsafePerformIO $ do 322schurAux f st a = unsafePerformIO $ do
314 u <- createMatrix ColumnMajor n n 323 u <- createMatrix ColumnMajor n n
315 s <- createMatrix ColumnMajor n n 324 s <- createMatrix ColumnMajor n n
316 f // matf a // matf u // matf s // check st [fdat a] 325 ww3 withMatrix a withMatrix u withMatrix s $ \ a u s ->
326 f // a // u // s // check st
317 return (u,s) 327 return (u,s)
318 where n = rows a 328 where n = rows a
319 329