diff options
Diffstat (limited to 'lib/Data/Packed/Internal/Vector.hs')
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 34 |
1 files changed, 20 insertions, 14 deletions
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index dc86484..386ebb5 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -31,11 +31,11 @@ data Vector t = V { dim :: Int -- ^ number of elements | |||
31 | , fptr :: ForeignPtr t -- ^ foreign pointer to the memory block | 31 | , fptr :: ForeignPtr t -- ^ foreign pointer to the memory block |
32 | } | 32 | } |
33 | 33 | ||
34 | ptr (V _ fptr) = unsafeForeignPtrToPtr fptr | 34 | --ptr (V _ fptr) = unsafeForeignPtrToPtr fptr |
35 | 35 | ||
36 | -- | check the error code and touch foreign ptr of vector arguments (if any) | 36 | -- | check the error code |
37 | check :: String -> [Vector a] -> IO Int -> IO () | 37 | check :: String -> IO Int -> IO () |
38 | check msg ls f = do | 38 | check msg f = do |
39 | err <- f | 39 | err <- f |
40 | when (err/=0) $ if err > 1024 | 40 | when (err/=0) $ if err > 1024 |
41 | then (error (msg++": "++errorCode err)) -- our errors | 41 | then (error (msg++": "++errorCode err)) -- our errors |
@@ -43,7 +43,6 @@ check msg ls f = do | |||
43 | ps <- gsl_strerror err | 43 | ps <- gsl_strerror err |
44 | s <- peekCString ps | 44 | s <- peekCString ps |
45 | error (msg++": "++s) | 45 | error (msg++": "++s) |
46 | mapM_ (touchForeignPtr . fptr) ls | ||
47 | return () | 46 | return () |
48 | 47 | ||
49 | -- | description of GSL error codes | 48 | -- | description of GSL error codes |
@@ -55,9 +54,14 @@ type Vc t s = Int -> Ptr t -> s | |||
55 | -- infixr 5 :> | 54 | -- infixr 5 :> |
56 | -- type t :> s = Vc t s | 55 | -- type t :> s = Vc t s |
57 | 56 | ||
58 | -- | adaptation of our vectors to be admitted by foreign functions: @f \/\/ vec v@ | 57 | --- | adaptation of our vectors to be admitted by foreign functions: @f \/\/ vec v@ |
59 | vec :: Vector t -> (Vc t s) -> s | 58 | --vec :: Vector t -> (Vc t s) -> s |
60 | vec v f = f (dim v) (ptr v) | 59 | --vec v f = f (dim v) (ptr v) |
60 | |||
61 | withVector (V n fp) f = withForeignPtr fp $ \p -> do | ||
62 | let v f = do | ||
63 | f n p | ||
64 | f v | ||
61 | 65 | ||
62 | -- | allocates memory for a new vector | 66 | -- | allocates memory for a new vector |
63 | createVector :: Storable a => Int -> IO (Vector a) | 67 | createVector :: Storable a => Int -> IO (Vector a) |
@@ -76,7 +80,8 @@ fromList :: Storable a => [a] -> Vector a | |||
76 | fromList l = unsafePerformIO $ do | 80 | fromList l = unsafePerformIO $ do |
77 | v <- createVector (length l) | 81 | v <- createVector (length l) |
78 | let f _ p = pokeArray p l >> return 0 | 82 | let f _ p = pokeArray p l >> return 0 |
79 | f // vec v // check "fromList" [] | 83 | withVector v $ \v -> |
84 | f // v // check "fromList" | ||
80 | return v | 85 | return v |
81 | 86 | ||
82 | safeRead v = unsafePerformIO . withForeignPtr (fptr v) | 87 | safeRead v = unsafePerformIO . withForeignPtr (fptr v) |
@@ -118,8 +123,9 @@ subVector k l (v@V {dim=n}) | |||
118 | | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" | 123 | | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" |
119 | | otherwise = unsafePerformIO $ do | 124 | | otherwise = unsafePerformIO $ do |
120 | r <- createVector l | 125 | r <- createVector l |
121 | let f = copyArray (ptr r) (advancePtr (ptr v) k) l >> return 0 | 126 | let f _ s _ d = copyArray d (advancePtr s k) l >> return 0 |
122 | f // check "subVector" [v,r] | 127 | ww2 withVector v withVector r $ \v r -> |
128 | f // v // r // check "subVector" | ||
123 | return r | 129 | return r |
124 | 130 | ||
125 | {- | Reads a vector position: | 131 | {- | Reads a vector position: |
@@ -144,12 +150,12 @@ join [] = error "joining zero vectors" | |||
144 | join as = unsafePerformIO $ do | 150 | join as = unsafePerformIO $ do |
145 | let tot = sum (map dim as) | 151 | let tot = sum (map dim as) |
146 | r@V {fptr = p} <- createVector tot | 152 | r@V {fptr = p} <- createVector tot |
147 | withForeignPtr p $ \_ -> | 153 | withForeignPtr p $ \ptr -> |
148 | joiner as tot (ptr r) | 154 | joiner as tot ptr |
149 | return r | 155 | return r |
150 | where joiner [] _ _ = return () | 156 | where joiner [] _ _ = return () |
151 | joiner (r@V {dim = n, fptr = b} : cs) _ p = do | 157 | joiner (r@V {dim = n, fptr = b} : cs) _ p = do |
152 | withForeignPtr b $ \_ -> copyArray p (ptr r) n | 158 | withForeignPtr b $ \pb -> copyArray p pb n |
153 | joiner cs 0 (advancePtr p n) | 159 | joiner cs 0 (advancePtr p n) |
154 | 160 | ||
155 | 161 | ||