summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-11-12 12:24:12 +0000
committerAlberto Ruiz <aruiz@um.es>2007-11-12 12:24:12 +0000
commit25d7892ac78f0f1a4fda538dd35430ebff02baaa (patch)
tree170572a869a5d73cd09bdf39b17fbb37b6e451fd /lib/Data/Packed/Internal
parent33a9909d0d59f468039597c405306b8d5fa9e008 (diff)
withMatrix
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r--lib/Data/Packed/Internal/Common.hs5
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs41
-rw-r--r--lib/Data/Packed/Internal/Vector.hs34
3 files changed, 55 insertions, 25 deletions
diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs
index 2b3ec28..c3a733c 100644
--- a/lib/Data/Packed/Internal/Common.hs
+++ b/lib/Data/Packed/Internal/Common.hs
@@ -60,6 +60,11 @@ common f = commonval . map f where
60infixl 0 // 60infixl 0 //
61(//) = flip ($) 61(//) = flip ($)
62 62
63-- hmm..
64ww2 w1 o1 w2 o2 f = w1 o1 $ \a1 -> w2 o2 $ \a2 -> f a1 a2
65ww3 w1 o1 w2 o2 w3 o3 f = w1 o1 $ \a1 -> ww2 w2 o2 w3 o3 (f a1)
66ww4 w1 o1 w2 o2 w3 o3 w4 o4 f = w1 o1 $ \a1 -> ww3 w2 o2 w3 o3 w4 o4 (f a1)
67
63-- GSL error codes are <= 1024 68-- GSL error codes are <= 1024
64-- | error codes for the auxiliary functions required by the wrappers 69-- | error codes for the auxiliary functions required by the wrappers
65errorCode :: Int -> String 70errorCode :: Int -> String
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 5617996..90a96b5 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -79,9 +79,20 @@ cmat MF {rows = r, cols = c, fdat = d } = MC {rows = r, cols = c, cdat = transda
79fmat m@MF{} = m 79fmat m@MF{} = m
80fmat MC {rows = r, cols = c, cdat = d } = MF {rows = r, cols = c, fdat = transdata c d r} 80fmat MC {rows = r, cols = c, cdat = d } = MF {rows = r, cols = c, fdat = transdata c d r}
81 81
82matc m f = f (rows m) (cols m) (ptr (cdat m)) 82--matc m f = f (rows m) (cols m) (ptr (cdat m))
83matf m f = f (rows m) (cols m) (ptr (fdat m)) 83--matf m f = f (rows m) (cols m) (ptr (fdat m))
84 84
85withMatrix MC {rows = r, cols = c, cdat = d } f =
86 withForeignPtr (fptr d) $ \p -> do
87 let m f = do
88 f r c p
89 f m
90
91withMatrix MF {rows = r, cols = c, fdat = d } f =
92 withForeignPtr (fptr d) $ \p -> do
93 let m f = do
94 f r c p
95 f m
85 96
86{- | Creates a vector by concatenation of rows 97{- | Creates a vector by concatenation of rows
87 98
@@ -236,7 +247,9 @@ transdataAux fun c1 d c2 =
236 then d 247 then d
237 else unsafePerformIO $ do 248 else unsafePerformIO $ do
238 v <- createVector (dim d) 249 v <- createVector (dim d)
239 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d,v] 250 withForeignPtr (fptr d) $ \pd ->
251 withForeignPtr (fptr v) $ \pv ->
252 fun r1 c1 pd r2 c2 pv // check "transdataAux"
240 -- putStrLn $ "---> transdataAux" ++ show (toList d) ++ show (toList v) 253 -- putStrLn $ "---> transdataAux" ++ show (toList d) ++ show (toList v)
241 return v 254 return v
242 where r1 = dim d `div` c1 255 where r1 = dim d `div` c1
@@ -250,8 +263,8 @@ foreign import ccall safe "auxi.h transC"
250 263
251------------------------------------------------------------------ 264------------------------------------------------------------------
252 265
253gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d) 266gmatC MF {rows = r, cols = c, fdat = d} p f = f 1 c r p
254gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d) 267gmatC MC {rows = r, cols = c, cdat = d} p f = f 0 r c p
255 268
256dtt MC { cdat = d } = d 269dtt MC { cdat = d } = d
257dtt MF { fdat = d } = d 270dtt MF { fdat = d } = d
@@ -260,7 +273,9 @@ multiplyAux fun a b = unsafePerformIO $ do
260 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ 273 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++
261 show (rows a,cols a) ++ " x " ++ show (rows b, cols b) 274 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
262 r <- createMatrix RowMajor (rows a) (cols b) 275 r <- createMatrix RowMajor (rows a) (cols b)
263 fun // gmatC a // gmatC b // matc r // check "multiplyAux" [dtt a, dtt b, cdat r] 276 withForeignPtr (fptr (dtt a)) $ \pa -> withForeignPtr (fptr (dtt b)) $ \pb ->
277 withMatrix r $ \r ->
278 fun // gmatC a pa // gmatC b pb // r // check "multiplyAux"
264 return r 279 return r
265 280
266multiplyR = multiplyAux cmultiplyR 281multiplyR = multiplyAux cmultiplyR
@@ -293,7 +308,8 @@ subMatrixR :: (Int,Int) -> (Int,Int) -> Matrix Double -> Matrix Double
293subMatrixR (r0,c0) (rt,ct) x' = unsafePerformIO $ do 308subMatrixR (r0,c0) (rt,ct) x' = unsafePerformIO $ do
294 r <- createMatrix RowMajor rt ct 309 r <- createMatrix RowMajor rt ct
295 let x = cmat x' 310 let x = cmat x'
296 c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // matc x // matc r // check "subMatrixR" [cdat x] 311 ww2 withMatrix x withMatrix r $ \x r ->
312 c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // x // r // check "subMatrixR"
297 return r 313 return r
298foreign import ccall "auxi.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM 314foreign import ccall "auxi.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM
299 315
@@ -317,8 +333,9 @@ subMatrix = subMatrixD
317 333
318diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do 334diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do
319 m <- createMatrix RowMajor n n 335 m <- createMatrix RowMajor n n
320 fun // vec v // matc m // check msg [v] 336 ww2 withVector v withMatrix m $ \v m ->
321 return m -- {tdat = dat m} 337 fun // v // m // check msg
338 return m
322 339
323-- | diagonal matrix from a real vector 340-- | diagonal matrix from a real vector
324diagR :: Vector Double -> Matrix Double 341diagR :: Vector Double -> Matrix Double
@@ -339,7 +356,8 @@ diag = diagD
339constantAux fun x n = unsafePerformIO $ do 356constantAux fun x n = unsafePerformIO $ do
340 v <- createVector n 357 v <- createVector n
341 px <- newArray [x] 358 px <- newArray [x]
342 fun px // vec v // check "constantAux" [] 359 withVector v $ \v ->
360 fun px // v // check "constantAux"
343 free px 361 free px
344 return v 362 return v
345 363
@@ -385,7 +403,8 @@ fromFile :: FilePath -> (Int,Int) -> IO (Matrix Double)
385fromFile filename (r,c) = do 403fromFile filename (r,c) = do
386 charname <- newCString filename 404 charname <- newCString filename
387 res <- createMatrix RowMajor r c 405 res <- createMatrix RowMajor r c
388 c_gslReadMatrix charname // matc res // check "gslReadMatrix" [] 406 withMatrix res $ \res ->
407 c_gslReadMatrix charname // res // check "gslReadMatrix"
389 --free charname -- TO DO: free the auxiliary CString 408 --free charname -- TO DO: free the auxiliary CString
390 return res 409 return res
391foreign import ccall "auxi.h matrix_fscanf" c_gslReadMatrix:: Ptr CChar -> TM 410foreign import ccall "auxi.h matrix_fscanf" c_gslReadMatrix:: Ptr CChar -> TM
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index dc86484..386ebb5 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -31,11 +31,11 @@ data Vector t = V { dim :: Int -- ^ number of elements
31 , fptr :: ForeignPtr t -- ^ foreign pointer to the memory block 31 , fptr :: ForeignPtr t -- ^ foreign pointer to the memory block
32 } 32 }
33 33
34ptr (V _ fptr) = unsafeForeignPtrToPtr fptr 34--ptr (V _ fptr) = unsafeForeignPtrToPtr fptr
35 35
36-- | check the error code and touch foreign ptr of vector arguments (if any) 36-- | check the error code
37check :: String -> [Vector a] -> IO Int -> IO () 37check :: String -> IO Int -> IO ()
38check msg ls f = do 38check msg f = do
39 err <- f 39 err <- f
40 when (err/=0) $ if err > 1024 40 when (err/=0) $ if err > 1024
41 then (error (msg++": "++errorCode err)) -- our errors 41 then (error (msg++": "++errorCode err)) -- our errors
@@ -43,7 +43,6 @@ check msg ls f = do
43 ps <- gsl_strerror err 43 ps <- gsl_strerror err
44 s <- peekCString ps 44 s <- peekCString ps
45 error (msg++": "++s) 45 error (msg++": "++s)
46 mapM_ (touchForeignPtr . fptr) ls
47 return () 46 return ()
48 47
49-- | description of GSL error codes 48-- | description of GSL error codes
@@ -55,9 +54,14 @@ type Vc t s = Int -> Ptr t -> s
55-- infixr 5 :> 54-- infixr 5 :>
56-- type t :> s = Vc t s 55-- type t :> s = Vc t s
57 56
58-- | adaptation of our vectors to be admitted by foreign functions: @f \/\/ vec v@ 57--- | adaptation of our vectors to be admitted by foreign functions: @f \/\/ vec v@
59vec :: Vector t -> (Vc t s) -> s 58--vec :: Vector t -> (Vc t s) -> s
60vec v f = f (dim v) (ptr v) 59--vec v f = f (dim v) (ptr v)
60
61withVector (V n fp) f = withForeignPtr fp $ \p -> do
62 let v f = do
63 f n p
64 f v
61 65
62-- | allocates memory for a new vector 66-- | allocates memory for a new vector
63createVector :: Storable a => Int -> IO (Vector a) 67createVector :: Storable a => Int -> IO (Vector a)
@@ -76,7 +80,8 @@ fromList :: Storable a => [a] -> Vector a
76fromList l = unsafePerformIO $ do 80fromList l = unsafePerformIO $ do
77 v <- createVector (length l) 81 v <- createVector (length l)
78 let f _ p = pokeArray p l >> return 0 82 let f _ p = pokeArray p l >> return 0
79 f // vec v // check "fromList" [] 83 withVector v $ \v ->
84 f // v // check "fromList"
80 return v 85 return v
81 86
82safeRead v = unsafePerformIO . withForeignPtr (fptr v) 87safeRead v = unsafePerformIO . withForeignPtr (fptr v)
@@ -118,8 +123,9 @@ subVector k l (v@V {dim=n})
118 | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" 123 | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range"
119 | otherwise = unsafePerformIO $ do 124 | otherwise = unsafePerformIO $ do
120 r <- createVector l 125 r <- createVector l
121 let f = copyArray (ptr r) (advancePtr (ptr v) k) l >> return 0 126 let f _ s _ d = copyArray d (advancePtr s k) l >> return 0
122 f // check "subVector" [v,r] 127 ww2 withVector v withVector r $ \v r ->
128 f // v // r // check "subVector"
123 return r 129 return r
124 130
125{- | Reads a vector position: 131{- | Reads a vector position:
@@ -144,12 +150,12 @@ join [] = error "joining zero vectors"
144join as = unsafePerformIO $ do 150join as = unsafePerformIO $ do
145 let tot = sum (map dim as) 151 let tot = sum (map dim as)
146 r@V {fptr = p} <- createVector tot 152 r@V {fptr = p} <- createVector tot
147 withForeignPtr p $ \_ -> 153 withForeignPtr p $ \ptr ->
148 joiner as tot (ptr r) 154 joiner as tot ptr
149 return r 155 return r
150 where joiner [] _ _ = return () 156 where joiner [] _ _ = return ()
151 joiner (r@V {dim = n, fptr = b} : cs) _ p = do 157 joiner (r@V {dim = n, fptr = b} : cs) _ p = do
152 withForeignPtr b $ \_ -> copyArray p (ptr r) n 158 withForeignPtr b $ \pb -> copyArray p pb n
153 joiner cs 0 (advancePtr p n) 159 joiner cs 0 (advancePtr p n)
154 160
155 161