diff options
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r-- | lib/Data/Packed/Internal/Common.hs | 5 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 41 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 34 |
3 files changed, 55 insertions, 25 deletions
diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs index 2b3ec28..c3a733c 100644 --- a/lib/Data/Packed/Internal/Common.hs +++ b/lib/Data/Packed/Internal/Common.hs | |||
@@ -60,6 +60,11 @@ common f = commonval . map f where | |||
60 | infixl 0 // | 60 | infixl 0 // |
61 | (//) = flip ($) | 61 | (//) = flip ($) |
62 | 62 | ||
63 | -- hmm.. | ||
64 | ww2 w1 o1 w2 o2 f = w1 o1 $ \a1 -> w2 o2 $ \a2 -> f a1 a2 | ||
65 | ww3 w1 o1 w2 o2 w3 o3 f = w1 o1 $ \a1 -> ww2 w2 o2 w3 o3 (f a1) | ||
66 | ww4 w1 o1 w2 o2 w3 o3 w4 o4 f = w1 o1 $ \a1 -> ww3 w2 o2 w3 o3 w4 o4 (f a1) | ||
67 | |||
63 | -- GSL error codes are <= 1024 | 68 | -- GSL error codes are <= 1024 |
64 | -- | error codes for the auxiliary functions required by the wrappers | 69 | -- | error codes for the auxiliary functions required by the wrappers |
65 | errorCode :: Int -> String | 70 | errorCode :: Int -> String |
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 5617996..90a96b5 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -79,9 +79,20 @@ cmat MF {rows = r, cols = c, fdat = d } = MC {rows = r, cols = c, cdat = transda | |||
79 | fmat m@MF{} = m | 79 | fmat m@MF{} = m |
80 | fmat MC {rows = r, cols = c, cdat = d } = MF {rows = r, cols = c, fdat = transdata c d r} | 80 | fmat MC {rows = r, cols = c, cdat = d } = MF {rows = r, cols = c, fdat = transdata c d r} |
81 | 81 | ||
82 | matc m f = f (rows m) (cols m) (ptr (cdat m)) | 82 | --matc m f = f (rows m) (cols m) (ptr (cdat m)) |
83 | matf m f = f (rows m) (cols m) (ptr (fdat m)) | 83 | --matf m f = f (rows m) (cols m) (ptr (fdat m)) |
84 | 84 | ||
85 | withMatrix MC {rows = r, cols = c, cdat = d } f = | ||
86 | withForeignPtr (fptr d) $ \p -> do | ||
87 | let m f = do | ||
88 | f r c p | ||
89 | f m | ||
90 | |||
91 | withMatrix MF {rows = r, cols = c, fdat = d } f = | ||
92 | withForeignPtr (fptr d) $ \p -> do | ||
93 | let m f = do | ||
94 | f r c p | ||
95 | f m | ||
85 | 96 | ||
86 | {- | Creates a vector by concatenation of rows | 97 | {- | Creates a vector by concatenation of rows |
87 | 98 | ||
@@ -236,7 +247,9 @@ transdataAux fun c1 d c2 = | |||
236 | then d | 247 | then d |
237 | else unsafePerformIO $ do | 248 | else unsafePerformIO $ do |
238 | v <- createVector (dim d) | 249 | v <- createVector (dim d) |
239 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d,v] | 250 | withForeignPtr (fptr d) $ \pd -> |
251 | withForeignPtr (fptr v) $ \pv -> | ||
252 | fun r1 c1 pd r2 c2 pv // check "transdataAux" | ||
240 | -- putStrLn $ "---> transdataAux" ++ show (toList d) ++ show (toList v) | 253 | -- putStrLn $ "---> transdataAux" ++ show (toList d) ++ show (toList v) |
241 | return v | 254 | return v |
242 | where r1 = dim d `div` c1 | 255 | where r1 = dim d `div` c1 |
@@ -250,8 +263,8 @@ foreign import ccall safe "auxi.h transC" | |||
250 | 263 | ||
251 | ------------------------------------------------------------------ | 264 | ------------------------------------------------------------------ |
252 | 265 | ||
253 | gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d) | 266 | gmatC MF {rows = r, cols = c, fdat = d} p f = f 1 c r p |
254 | gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d) | 267 | gmatC MC {rows = r, cols = c, cdat = d} p f = f 0 r c p |
255 | 268 | ||
256 | dtt MC { cdat = d } = d | 269 | dtt MC { cdat = d } = d |
257 | dtt MF { fdat = d } = d | 270 | dtt MF { fdat = d } = d |
@@ -260,7 +273,9 @@ multiplyAux fun a b = unsafePerformIO $ do | |||
260 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | 273 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ |
261 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | 274 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) |
262 | r <- createMatrix RowMajor (rows a) (cols b) | 275 | r <- createMatrix RowMajor (rows a) (cols b) |
263 | fun // gmatC a // gmatC b // matc r // check "multiplyAux" [dtt a, dtt b, cdat r] | 276 | withForeignPtr (fptr (dtt a)) $ \pa -> withForeignPtr (fptr (dtt b)) $ \pb -> |
277 | withMatrix r $ \r -> | ||
278 | fun // gmatC a pa // gmatC b pb // r // check "multiplyAux" | ||
264 | return r | 279 | return r |
265 | 280 | ||
266 | multiplyR = multiplyAux cmultiplyR | 281 | multiplyR = multiplyAux cmultiplyR |
@@ -293,7 +308,8 @@ subMatrixR :: (Int,Int) -> (Int,Int) -> Matrix Double -> Matrix Double | |||
293 | subMatrixR (r0,c0) (rt,ct) x' = unsafePerformIO $ do | 308 | subMatrixR (r0,c0) (rt,ct) x' = unsafePerformIO $ do |
294 | r <- createMatrix RowMajor rt ct | 309 | r <- createMatrix RowMajor rt ct |
295 | let x = cmat x' | 310 | let x = cmat x' |
296 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // matc x // matc r // check "subMatrixR" [cdat x] | 311 | ww2 withMatrix x withMatrix r $ \x r -> |
312 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // x // r // check "subMatrixR" | ||
297 | return r | 313 | return r |
298 | foreign import ccall "auxi.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM | 314 | foreign import ccall "auxi.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM |
299 | 315 | ||
@@ -317,8 +333,9 @@ subMatrix = subMatrixD | |||
317 | 333 | ||
318 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do | 334 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do |
319 | m <- createMatrix RowMajor n n | 335 | m <- createMatrix RowMajor n n |
320 | fun // vec v // matc m // check msg [v] | 336 | ww2 withVector v withMatrix m $ \v m -> |
321 | return m -- {tdat = dat m} | 337 | fun // v // m // check msg |
338 | return m | ||
322 | 339 | ||
323 | -- | diagonal matrix from a real vector | 340 | -- | diagonal matrix from a real vector |
324 | diagR :: Vector Double -> Matrix Double | 341 | diagR :: Vector Double -> Matrix Double |
@@ -339,7 +356,8 @@ diag = diagD | |||
339 | constantAux fun x n = unsafePerformIO $ do | 356 | constantAux fun x n = unsafePerformIO $ do |
340 | v <- createVector n | 357 | v <- createVector n |
341 | px <- newArray [x] | 358 | px <- newArray [x] |
342 | fun px // vec v // check "constantAux" [] | 359 | withVector v $ \v -> |
360 | fun px // v // check "constantAux" | ||
343 | free px | 361 | free px |
344 | return v | 362 | return v |
345 | 363 | ||
@@ -385,7 +403,8 @@ fromFile :: FilePath -> (Int,Int) -> IO (Matrix Double) | |||
385 | fromFile filename (r,c) = do | 403 | fromFile filename (r,c) = do |
386 | charname <- newCString filename | 404 | charname <- newCString filename |
387 | res <- createMatrix RowMajor r c | 405 | res <- createMatrix RowMajor r c |
388 | c_gslReadMatrix charname // matc res // check "gslReadMatrix" [] | 406 | withMatrix res $ \res -> |
407 | c_gslReadMatrix charname // res // check "gslReadMatrix" | ||
389 | --free charname -- TO DO: free the auxiliary CString | 408 | --free charname -- TO DO: free the auxiliary CString |
390 | return res | 409 | return res |
391 | foreign import ccall "auxi.h matrix_fscanf" c_gslReadMatrix:: Ptr CChar -> TM | 410 | foreign import ccall "auxi.h matrix_fscanf" c_gslReadMatrix:: Ptr CChar -> TM |
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index dc86484..386ebb5 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -31,11 +31,11 @@ data Vector t = V { dim :: Int -- ^ number of elements | |||
31 | , fptr :: ForeignPtr t -- ^ foreign pointer to the memory block | 31 | , fptr :: ForeignPtr t -- ^ foreign pointer to the memory block |
32 | } | 32 | } |
33 | 33 | ||
34 | ptr (V _ fptr) = unsafeForeignPtrToPtr fptr | 34 | --ptr (V _ fptr) = unsafeForeignPtrToPtr fptr |
35 | 35 | ||
36 | -- | check the error code and touch foreign ptr of vector arguments (if any) | 36 | -- | check the error code |
37 | check :: String -> [Vector a] -> IO Int -> IO () | 37 | check :: String -> IO Int -> IO () |
38 | check msg ls f = do | 38 | check msg f = do |
39 | err <- f | 39 | err <- f |
40 | when (err/=0) $ if err > 1024 | 40 | when (err/=0) $ if err > 1024 |
41 | then (error (msg++": "++errorCode err)) -- our errors | 41 | then (error (msg++": "++errorCode err)) -- our errors |
@@ -43,7 +43,6 @@ check msg ls f = do | |||
43 | ps <- gsl_strerror err | 43 | ps <- gsl_strerror err |
44 | s <- peekCString ps | 44 | s <- peekCString ps |
45 | error (msg++": "++s) | 45 | error (msg++": "++s) |
46 | mapM_ (touchForeignPtr . fptr) ls | ||
47 | return () | 46 | return () |
48 | 47 | ||
49 | -- | description of GSL error codes | 48 | -- | description of GSL error codes |
@@ -55,9 +54,14 @@ type Vc t s = Int -> Ptr t -> s | |||
55 | -- infixr 5 :> | 54 | -- infixr 5 :> |
56 | -- type t :> s = Vc t s | 55 | -- type t :> s = Vc t s |
57 | 56 | ||
58 | -- | adaptation of our vectors to be admitted by foreign functions: @f \/\/ vec v@ | 57 | --- | adaptation of our vectors to be admitted by foreign functions: @f \/\/ vec v@ |
59 | vec :: Vector t -> (Vc t s) -> s | 58 | --vec :: Vector t -> (Vc t s) -> s |
60 | vec v f = f (dim v) (ptr v) | 59 | --vec v f = f (dim v) (ptr v) |
60 | |||
61 | withVector (V n fp) f = withForeignPtr fp $ \p -> do | ||
62 | let v f = do | ||
63 | f n p | ||
64 | f v | ||
61 | 65 | ||
62 | -- | allocates memory for a new vector | 66 | -- | allocates memory for a new vector |
63 | createVector :: Storable a => Int -> IO (Vector a) | 67 | createVector :: Storable a => Int -> IO (Vector a) |
@@ -76,7 +80,8 @@ fromList :: Storable a => [a] -> Vector a | |||
76 | fromList l = unsafePerformIO $ do | 80 | fromList l = unsafePerformIO $ do |
77 | v <- createVector (length l) | 81 | v <- createVector (length l) |
78 | let f _ p = pokeArray p l >> return 0 | 82 | let f _ p = pokeArray p l >> return 0 |
79 | f // vec v // check "fromList" [] | 83 | withVector v $ \v -> |
84 | f // v // check "fromList" | ||
80 | return v | 85 | return v |
81 | 86 | ||
82 | safeRead v = unsafePerformIO . withForeignPtr (fptr v) | 87 | safeRead v = unsafePerformIO . withForeignPtr (fptr v) |
@@ -118,8 +123,9 @@ subVector k l (v@V {dim=n}) | |||
118 | | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" | 123 | | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" |
119 | | otherwise = unsafePerformIO $ do | 124 | | otherwise = unsafePerformIO $ do |
120 | r <- createVector l | 125 | r <- createVector l |
121 | let f = copyArray (ptr r) (advancePtr (ptr v) k) l >> return 0 | 126 | let f _ s _ d = copyArray d (advancePtr s k) l >> return 0 |
122 | f // check "subVector" [v,r] | 127 | ww2 withVector v withVector r $ \v r -> |
128 | f // v // r // check "subVector" | ||
123 | return r | 129 | return r |
124 | 130 | ||
125 | {- | Reads a vector position: | 131 | {- | Reads a vector position: |
@@ -144,12 +150,12 @@ join [] = error "joining zero vectors" | |||
144 | join as = unsafePerformIO $ do | 150 | join as = unsafePerformIO $ do |
145 | let tot = sum (map dim as) | 151 | let tot = sum (map dim as) |
146 | r@V {fptr = p} <- createVector tot | 152 | r@V {fptr = p} <- createVector tot |
147 | withForeignPtr p $ \_ -> | 153 | withForeignPtr p $ \ptr -> |
148 | joiner as tot (ptr r) | 154 | joiner as tot ptr |
149 | return r | 155 | return r |
150 | where joiner [] _ _ = return () | 156 | where joiner [] _ _ = return () |
151 | joiner (r@V {dim = n, fptr = b} : cs) _ p = do | 157 | joiner (r@V {dim = n, fptr = b} : cs) _ p = do |
152 | withForeignPtr b $ \_ -> copyArray p (ptr r) n | 158 | withForeignPtr b $ \pb -> copyArray p pb n |
153 | joiner cs 0 (advancePtr p n) | 159 | joiner cs 0 (advancePtr p n) |
154 | 160 | ||
155 | 161 | ||