From 25d7892ac78f0f1a4fda538dd35430ebff02baaa Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 12 Nov 2007 12:24:12 +0000 Subject: withMatrix --- lib/Data/Packed/Internal/Common.hs | 5 +++++ lib/Data/Packed/Internal/Matrix.hs | 41 ++++++++++++++++++++++++++++---------- lib/Data/Packed/Internal/Vector.hs | 34 ++++++++++++++++++------------- 3 files changed, 55 insertions(+), 25 deletions(-) (limited to 'lib/Data/Packed') diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs index 2b3ec28..c3a733c 100644 --- a/lib/Data/Packed/Internal/Common.hs +++ b/lib/Data/Packed/Internal/Common.hs @@ -60,6 +60,11 @@ common f = commonval . map f where infixl 0 // (//) = flip ($) +-- hmm.. +ww2 w1 o1 w2 o2 f = w1 o1 $ \a1 -> w2 o2 $ \a2 -> f a1 a2 +ww3 w1 o1 w2 o2 w3 o3 f = w1 o1 $ \a1 -> ww2 w2 o2 w3 o3 (f a1) +ww4 w1 o1 w2 o2 w3 o3 w4 o4 f = w1 o1 $ \a1 -> ww3 w2 o2 w3 o3 w4 o4 (f a1) + -- GSL error codes are <= 1024 -- | error codes for the auxiliary functions required by the wrappers errorCode :: Int -> String diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 5617996..90a96b5 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs @@ -79,9 +79,20 @@ cmat MF {rows = r, cols = c, fdat = d } = MC {rows = r, cols = c, cdat = transda fmat m@MF{} = m fmat MC {rows = r, cols = c, cdat = d } = MF {rows = r, cols = c, fdat = transdata c d r} -matc m f = f (rows m) (cols m) (ptr (cdat m)) -matf m f = f (rows m) (cols m) (ptr (fdat m)) +--matc m f = f (rows m) (cols m) (ptr (cdat m)) +--matf m f = f (rows m) (cols m) (ptr (fdat m)) +withMatrix MC {rows = r, cols = c, cdat = d } f = + withForeignPtr (fptr d) $ \p -> do + let m f = do + f r c p + f m + +withMatrix MF {rows = r, cols = c, fdat = d } f = + withForeignPtr (fptr d) $ \p -> do + let m f = do + f r c p + f m {- | Creates a vector by concatenation of rows @@ -236,7 +247,9 @@ transdataAux fun c1 d c2 = then d else unsafePerformIO $ do v <- createVector (dim d) - fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d,v] + withForeignPtr (fptr d) $ \pd -> + withForeignPtr (fptr v) $ \pv -> + fun r1 c1 pd r2 c2 pv // check "transdataAux" -- putStrLn $ "---> transdataAux" ++ show (toList d) ++ show (toList v) return v where r1 = dim d `div` c1 @@ -250,8 +263,8 @@ foreign import ccall safe "auxi.h transC" ------------------------------------------------------------------ -gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d) -gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d) +gmatC MF {rows = r, cols = c, fdat = d} p f = f 1 c r p +gmatC MC {rows = r, cols = c, cdat = d} p f = f 0 r c p dtt MC { cdat = d } = d dtt MF { fdat = d } = d @@ -260,7 +273,9 @@ multiplyAux fun a b = unsafePerformIO $ do when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ show (rows a,cols a) ++ " x " ++ show (rows b, cols b) r <- createMatrix RowMajor (rows a) (cols b) - fun // gmatC a // gmatC b // matc r // check "multiplyAux" [dtt a, dtt b, cdat r] + withForeignPtr (fptr (dtt a)) $ \pa -> withForeignPtr (fptr (dtt b)) $ \pb -> + withMatrix r $ \r -> + fun // gmatC a pa // gmatC b pb // r // check "multiplyAux" return r multiplyR = multiplyAux cmultiplyR @@ -293,7 +308,8 @@ subMatrixR :: (Int,Int) -> (Int,Int) -> Matrix Double -> Matrix Double subMatrixR (r0,c0) (rt,ct) x' = unsafePerformIO $ do r <- createMatrix RowMajor rt ct let x = cmat x' - c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // matc x // matc r // check "subMatrixR" [cdat x] + ww2 withMatrix x withMatrix r $ \x r -> + c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // x // r // check "subMatrixR" return r foreign import ccall "auxi.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM @@ -317,8 +333,9 @@ subMatrix = subMatrixD diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do m <- createMatrix RowMajor n n - fun // vec v // matc m // check msg [v] - return m -- {tdat = dat m} + ww2 withVector v withMatrix m $ \v m -> + fun // v // m // check msg + return m -- | diagonal matrix from a real vector diagR :: Vector Double -> Matrix Double @@ -339,7 +356,8 @@ diag = diagD constantAux fun x n = unsafePerformIO $ do v <- createVector n px <- newArray [x] - fun px // vec v // check "constantAux" [] + withVector v $ \v -> + fun px // v // check "constantAux" free px return v @@ -385,7 +403,8 @@ fromFile :: FilePath -> (Int,Int) -> IO (Matrix Double) fromFile filename (r,c) = do charname <- newCString filename res <- createMatrix RowMajor r c - c_gslReadMatrix charname // matc res // check "gslReadMatrix" [] + withMatrix res $ \res -> + c_gslReadMatrix charname // res // check "gslReadMatrix" --free charname -- TO DO: free the auxiliary CString return res foreign import ccall "auxi.h matrix_fscanf" c_gslReadMatrix:: Ptr CChar -> TM diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index dc86484..386ebb5 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs @@ -31,11 +31,11 @@ data Vector t = V { dim :: Int -- ^ number of elements , fptr :: ForeignPtr t -- ^ foreign pointer to the memory block } -ptr (V _ fptr) = unsafeForeignPtrToPtr fptr +--ptr (V _ fptr) = unsafeForeignPtrToPtr fptr --- | check the error code and touch foreign ptr of vector arguments (if any) -check :: String -> [Vector a] -> IO Int -> IO () -check msg ls f = do +-- | check the error code +check :: String -> IO Int -> IO () +check msg f = do err <- f when (err/=0) $ if err > 1024 then (error (msg++": "++errorCode err)) -- our errors @@ -43,7 +43,6 @@ check msg ls f = do ps <- gsl_strerror err s <- peekCString ps error (msg++": "++s) - mapM_ (touchForeignPtr . fptr) ls return () -- | description of GSL error codes @@ -55,9 +54,14 @@ type Vc t s = Int -> Ptr t -> s -- infixr 5 :> -- type t :> s = Vc t s --- | adaptation of our vectors to be admitted by foreign functions: @f \/\/ vec v@ -vec :: Vector t -> (Vc t s) -> s -vec v f = f (dim v) (ptr v) +--- | adaptation of our vectors to be admitted by foreign functions: @f \/\/ vec v@ +--vec :: Vector t -> (Vc t s) -> s +--vec v f = f (dim v) (ptr v) + +withVector (V n fp) f = withForeignPtr fp $ \p -> do + let v f = do + f n p + f v -- | allocates memory for a new vector createVector :: Storable a => Int -> IO (Vector a) @@ -76,7 +80,8 @@ fromList :: Storable a => [a] -> Vector a fromList l = unsafePerformIO $ do v <- createVector (length l) let f _ p = pokeArray p l >> return 0 - f // vec v // check "fromList" [] + withVector v $ \v -> + f // v // check "fromList" return v safeRead v = unsafePerformIO . withForeignPtr (fptr v) @@ -118,8 +123,9 @@ subVector k l (v@V {dim=n}) | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" | otherwise = unsafePerformIO $ do r <- createVector l - let f = copyArray (ptr r) (advancePtr (ptr v) k) l >> return 0 - f // check "subVector" [v,r] + let f _ s _ d = copyArray d (advancePtr s k) l >> return 0 + ww2 withVector v withVector r $ \v r -> + f // v // r // check "subVector" return r {- | Reads a vector position: @@ -144,12 +150,12 @@ join [] = error "joining zero vectors" join as = unsafePerformIO $ do let tot = sum (map dim as) r@V {fptr = p} <- createVector tot - withForeignPtr p $ \_ -> - joiner as tot (ptr r) + withForeignPtr p $ \ptr -> + joiner as tot ptr return r where joiner [] _ _ = return () joiner (r@V {dim = n, fptr = b} : cs) _ p = do - withForeignPtr b $ \_ -> copyArray p (ptr r) n + withForeignPtr b $ \pb -> copyArray p pb n joiner cs 0 (advancePtr p n) -- cgit v1.2.3