diff options
Diffstat (limited to 'lib/Data/Packed')
-rw-r--r-- | lib/Data/Packed/Internal.hs | 293 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 187 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 32 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 164 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/aux.c (renamed from lib/Data/Packed/aux.c) | 0 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/aux.h (renamed from lib/Data/Packed/aux.h) | 0 |
6 files changed, 383 insertions, 293 deletions
diff --git a/lib/Data/Packed/Internal.hs b/lib/Data/Packed/Internal.hs deleted file mode 100644 index b06f044..0000000 --- a/lib/Data/Packed/Internal.hs +++ /dev/null | |||
@@ -1,293 +0,0 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} | ||
2 | ----------------------------------------------------------------------------- | ||
3 | -- | | ||
4 | -- Module : Data.Packed.Internal | ||
5 | -- Copyright : (c) Alberto Ruiz 2007 | ||
6 | -- License : GPL-style | ||
7 | -- | ||
8 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
9 | -- Stability : provisional | ||
10 | -- Portability : portable (uses FFI) | ||
11 | -- | ||
12 | -- Fundamental types | ||
13 | -- | ||
14 | ----------------------------------------------------------------------------- | ||
15 | |||
16 | module Data.Packed.Internal where | ||
17 | |||
18 | import Foreign hiding (xor) | ||
19 | import Complex | ||
20 | import Control.Monad(when) | ||
21 | import Debug.Trace | ||
22 | import Data.List(transpose,intersperse) | ||
23 | import Data.Typeable | ||
24 | import Data.Maybe(fromJust) | ||
25 | |||
26 | debug x = trace (show x) x | ||
27 | |||
28 | ---------------------------------------------------------------------- | ||
29 | instance (Storable a, RealFloat a) => Storable (Complex a) where -- | ||
30 | alignment x = alignment (realPart x) -- | ||
31 | sizeOf x = 2 * sizeOf (realPart x) -- | ||
32 | peek p = do -- | ||
33 | [re,im] <- peekArray 2 (castPtr p) -- | ||
34 | return (re :+ im) -- | ||
35 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- | ||
36 | ---------------------------------------------------------------------- | ||
37 | |||
38 | (//) :: x -> (x -> y) -> y | ||
39 | infixl 0 // | ||
40 | (//) = flip ($) | ||
41 | |||
42 | check msg ls f = do | ||
43 | err <- f | ||
44 | when (err/=0) (error msg) | ||
45 | mapM_ (touchForeignPtr . fptr) ls | ||
46 | return () | ||
47 | |||
48 | ---------------------------------------------------------------------- | ||
49 | |||
50 | data Vector t = V { dim :: Int | ||
51 | , fptr :: ForeignPtr t | ||
52 | , ptr :: Ptr t | ||
53 | } deriving Typeable | ||
54 | |||
55 | type Vc t s = Int -> Ptr t -> s | ||
56 | infixr 5 :> | ||
57 | type t :> s = Vc t s | ||
58 | |||
59 | vec :: Vector t -> (Vc t s) -> s | ||
60 | vec v f = f (dim v) (ptr v) | ||
61 | |||
62 | createVector :: Storable a => Int -> IO (Vector a) | ||
63 | createVector n = do | ||
64 | when (n <= 0) $ error ("trying to createVector of dim "++show n) | ||
65 | fp <- mallocForeignPtrArray n | ||
66 | let p = unsafeForeignPtrToPtr fp | ||
67 | --putStrLn ("\n---------> V"++show n) | ||
68 | return $ V n fp p | ||
69 | |||
70 | fromList :: Storable a => [a] -> Vector a | ||
71 | fromList l = unsafePerformIO $ do | ||
72 | v <- createVector (length l) | ||
73 | let f _ p = pokeArray p l >> return 0 | ||
74 | f // vec v // check "fromList" [] | ||
75 | return v | ||
76 | |||
77 | toList :: Storable a => Vector a -> [a] | ||
78 | toList v = unsafePerformIO $ peekArray (dim v) (ptr v) | ||
79 | |||
80 | n # l = if length l == n then fromList l else error "# with wrong size" | ||
81 | |||
82 | at' :: Storable a => Vector a -> Int -> a | ||
83 | at' v n = unsafePerformIO $ peekElemOff (ptr v) n | ||
84 | |||
85 | at :: Storable a => Vector a -> Int -> a | ||
86 | at v n | n >= 0 && n < dim v = at' v n | ||
87 | | otherwise = error "vector index out of range" | ||
88 | |||
89 | instance (Show a, Storable a) => (Show (Vector a)) where | ||
90 | show v = (show (dim v))++" # " ++ show (toList v) | ||
91 | |||
92 | ------------------------------------------------------------------------ | ||
93 | |||
94 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | ||
95 | |||
96 | -- | 2D array | ||
97 | data Matrix t = M { rows :: Int | ||
98 | , cols :: Int | ||
99 | , cmat :: Vector t | ||
100 | , fmat :: Vector t | ||
101 | , isTrans :: Bool | ||
102 | , order :: MatrixOrder | ||
103 | } deriving Typeable | ||
104 | |||
105 | xor a b = a && not b || b && not a | ||
106 | |||
107 | fortran m = order m == ColumnMajor | ||
108 | |||
109 | dat m = if fortran m `xor` isTrans m then fmat m else cmat m | ||
110 | |||
111 | pref m = if fortran m then fmat m else cmat m | ||
112 | |||
113 | trans m = m { rows = cols m | ||
114 | , cols = rows m | ||
115 | , isTrans = not (isTrans m) | ||
116 | } | ||
117 | |||
118 | type Mt t s = Int -> Int -> Ptr t -> s | ||
119 | infixr 6 ::> | ||
120 | type t ::> s = Mt t s | ||
121 | |||
122 | mat :: Matrix t -> (Mt t s) -> s | ||
123 | mat m f = f (rows m) (cols m) (ptr (dat m)) | ||
124 | |||
125 | gmat m f | fortran m = | ||
126 | if (isTrans m) | ||
127 | then f 0 (rows m) (cols m) (ptr (fmat m)) | ||
128 | else f 1 (cols m) (rows m) (ptr (fmat m)) | ||
129 | | otherwise = | ||
130 | if isTrans m | ||
131 | then f 1 (cols m) (rows m) (ptr (cmat m)) | ||
132 | else f 0 (rows m) (cols m) (ptr (cmat m)) | ||
133 | |||
134 | instance (Show a, Storable a) => (Show (Matrix a)) where | ||
135 | show m = (sizes++) . dsp . map (map show) . toLists $ m | ||
136 | where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" | ||
137 | |||
138 | partit :: Int -> [a] -> [[a]] | ||
139 | partit _ [] = [] | ||
140 | partit n l = take n l : partit n (drop n l) | ||
141 | |||
142 | toLists m = partit (cols m) . toList . cmat $ m | ||
143 | |||
144 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp | ||
145 | where | ||
146 | mt = transpose as | ||
147 | longs = map (maximum . map length) mt | ||
148 | mtp = zipWith (\a b -> map (pad a) b) longs mt | ||
149 | pad n str = replicate (n - length str) ' ' ++ str | ||
150 | unwords' = concat . intersperse ", " | ||
151 | |||
152 | matrixFromVector RowMajor c v = | ||
153 | M { rows = r | ||
154 | , cols = c | ||
155 | , cmat = v | ||
156 | , fmat = transdata c v r | ||
157 | , order = RowMajor | ||
158 | , isTrans = False | ||
159 | } where r = dim v `div` c -- TODO check mod=0 | ||
160 | |||
161 | matrixFromVector ColumnMajor c v = | ||
162 | M { rows = r | ||
163 | , cols = c | ||
164 | , fmat = v | ||
165 | , cmat = transdata c v r | ||
166 | , order = ColumnMajor | ||
167 | , isTrans = False | ||
168 | } where r = dim v `div` c -- TODO check mod=0 | ||
169 | |||
170 | createMatrix order r c = do | ||
171 | p <- createVector (r*c) | ||
172 | return (matrixFromVector order c p) | ||
173 | |||
174 | transdataG :: Storable a => Int -> Vector a -> Int -> Vector a | ||
175 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | ||
176 | |||
177 | transdataR :: Int -> Vector Double -> Int -> Vector Double | ||
178 | transdataR = transdataAux ctransR | ||
179 | |||
180 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) | ||
181 | transdataC = transdataAux ctransC | ||
182 | |||
183 | transdataAux fun c1 d c2 = unsafePerformIO $ do | ||
184 | v <- createVector (dim d) | ||
185 | let r1 = dim d `div` c1 | ||
186 | r2 = dim d `div` c2 | ||
187 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | ||
188 | --putStrLn "---> transdataAux" | ||
189 | return v | ||
190 | |||
191 | foreign import ccall safe "aux.h transR" | ||
192 | ctransR :: Double ::> Double ::> IO Int | ||
193 | foreign import ccall safe "aux.h transC" | ||
194 | ctransC :: Complex Double ::> Complex Double ::> IO Int | ||
195 | |||
196 | |||
197 | class (Storable a, Typeable a) => Field a where | ||
198 | instance (Storable a, Typeable a) => Field a where | ||
199 | |||
200 | isReal w x = typeOf (undefined :: Double) == typeOf (w x) | ||
201 | isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) | ||
202 | baseOf v = (v `at` 0) | ||
203 | |||
204 | scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b | ||
205 | scast = fromJust . cast | ||
206 | |||
207 | transdata :: Field a => Int -> Vector a -> Int -> Vector a | ||
208 | transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 | ||
209 | | isComp baseOf d = scast $ transdataC c1 (scast d) c2 | ||
210 | | otherwise = transdataG c1 d c2 | ||
211 | |||
212 | --transdata :: Storable a => Int -> Vector a -> Int -> Vector a | ||
213 | --transdata = transdataG | ||
214 | --{-# RULES "transdataR" transdata=transdataR #-} | ||
215 | --{-# RULES "transdataC" transdata=transdataC #-} | ||
216 | |||
217 | ------------------------------------------------------------------ | ||
218 | |||
219 | constantG n x = fromList (replicate n x) | ||
220 | |||
221 | constantR :: Int -> Double -> Vector Double | ||
222 | constantR = constantAux cconstantR | ||
223 | |||
224 | constantC :: Int -> Complex Double -> Vector (Complex Double) | ||
225 | constantC = constantAux cconstantC | ||
226 | |||
227 | constantAux fun n x = unsafePerformIO $ do | ||
228 | v <- createVector n | ||
229 | px <- newArray [x] | ||
230 | fun px // vec v // check "constantAux" [] | ||
231 | free px | ||
232 | return v | ||
233 | |||
234 | foreign import ccall safe "aux.h constantR" | ||
235 | cconstantR :: Ptr Double -> Double :> IO Int | ||
236 | |||
237 | foreign import ccall safe "aux.h constantC" | ||
238 | cconstantC :: Ptr (Complex Double) -> Complex Double :> IO Int | ||
239 | |||
240 | constant :: Field a => Int -> a -> Vector a | ||
241 | constant n x | isReal id x = scast $ constantR n (scast x) | ||
242 | | isComp id x = scast $ constantC n (scast x) | ||
243 | | otherwise = constantG n x | ||
244 | |||
245 | ------------------------------------------------------------------ | ||
246 | |||
247 | dotL a b = sum (zipWith (*) a b) | ||
248 | |||
249 | multiplyL a b = [[dotL x y | y <- transpose b] | x <- a] | ||
250 | |||
251 | transL m = m {rows = cols m, cols = rows m, cmat = v, fmat = cmat m} | ||
252 | where v = transdataG (cols m) (cmat m) (rows m) | ||
253 | |||
254 | ------------------------------------------------------------------ | ||
255 | |||
256 | multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) | ||
257 | |||
258 | multiplyAux order fun a b = unsafePerformIO $ do | ||
259 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | ||
260 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | ||
261 | r <- createMatrix order (rows a) (cols b) | ||
262 | fun // gmat a // gmat b // mat r // check "multiplyAux" [pref a, pref b] | ||
263 | return r | ||
264 | |||
265 | foreign import ccall safe "aux.h multiplyR" | ||
266 | cmultiplyR :: Int -> Double ::> (Int -> Double ::> (Double ::> IO Int)) | ||
267 | |||
268 | foreign import ccall safe "aux.h multiplyC" | ||
269 | cmultiplyC :: Int -> Complex Double ::> (Int -> Complex Double ::> (Complex Double ::> IO Int)) | ||
270 | |||
271 | multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a | ||
272 | multiply RowMajor a b = multiplyD RowMajor a b | ||
273 | multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b | ||
274 | |||
275 | multiplyT order a b = multiplyD order (trans b) (trans a) | ||
276 | |||
277 | multiplyD order a b | ||
278 | | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) | ||
279 | | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b) | ||
280 | | otherwise = multiplyG a b | ||
281 | |||
282 | -------------------------------------------------------------------- | ||
283 | |||
284 | data IdxTp = Covariant | Contravariant | ||
285 | |||
286 | -- | multidimensional array | ||
287 | data Tensor t = T { rank :: Int | ||
288 | , dims :: [Int] | ||
289 | , idxNm :: [String] | ||
290 | , idxTp :: [IdxTp] | ||
291 | , ten :: Vector t | ||
292 | } | ||
293 | |||
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs new file mode 100644 index 0000000..2c57c07 --- /dev/null +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -0,0 +1,187 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} | ||
2 | ----------------------------------------------------------------------------- | ||
3 | -- | | ||
4 | -- Module : Data.Packed.Internal.Matrix | ||
5 | -- Copyright : (c) Alberto Ruiz 2007 | ||
6 | -- License : GPL-style | ||
7 | -- | ||
8 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
9 | -- Stability : provisional | ||
10 | -- Portability : portable (uses FFI) | ||
11 | -- | ||
12 | -- Fundamental types | ||
13 | -- | ||
14 | ----------------------------------------------------------------------------- | ||
15 | |||
16 | module Data.Packed.Internal.Matrix where | ||
17 | |||
18 | import Data.Packed.Internal.Vector | ||
19 | |||
20 | import Foreign hiding (xor) | ||
21 | import Complex | ||
22 | import Control.Monad(when) | ||
23 | import Debug.Trace | ||
24 | import Data.List(transpose,intersperse) | ||
25 | import Data.Typeable | ||
26 | import Data.Maybe(fromJust) | ||
27 | |||
28 | debug x = trace (show x) x | ||
29 | |||
30 | |||
31 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | ||
32 | |||
33 | -- | 2D array | ||
34 | data Matrix t = M { rows :: Int | ||
35 | , cols :: Int | ||
36 | , dat :: Vector t | ||
37 | , tdat :: Vector t | ||
38 | , isTrans :: Bool | ||
39 | , order :: MatrixOrder | ||
40 | } deriving Typeable | ||
41 | |||
42 | xor a b = a && not b || b && not a | ||
43 | |||
44 | fortran m = order m == ColumnMajor | ||
45 | |||
46 | cdat m = if fortran m `xor` isTrans m then tdat m else dat m | ||
47 | fdat m = if fortran m `xor` isTrans m then dat m else tdat m | ||
48 | |||
49 | trans m = m { rows = cols m | ||
50 | , cols = rows m | ||
51 | , isTrans = not (isTrans m) | ||
52 | } | ||
53 | |||
54 | type Mt t s = Int -> Int -> Ptr t -> s | ||
55 | infixr 6 ::> | ||
56 | type t ::> s = Mt t s | ||
57 | |||
58 | mat d m f = f (rows m) (cols m) (ptr (d m)) | ||
59 | |||
60 | instance (Show a, Storable a) => (Show (Matrix a)) where | ||
61 | show m = (sizes++) . dsp . map (map show) . toLists $ m | ||
62 | where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" | ||
63 | |||
64 | partit :: Int -> [a] -> [[a]] | ||
65 | partit _ [] = [] | ||
66 | partit n l = take n l : partit n (drop n l) | ||
67 | |||
68 | toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m | ||
69 | | otherwise = partit (cols m) . toList . dat $ m | ||
70 | |||
71 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp | ||
72 | where | ||
73 | mt = transpose as | ||
74 | longs = map (maximum . map length) mt | ||
75 | mtp = zipWith (\a b -> map (pad a) b) longs mt | ||
76 | pad n str = replicate (n - length str) ' ' ++ str | ||
77 | unwords' = concat . intersperse ", " | ||
78 | |||
79 | matrixFromVector RowMajor c v = | ||
80 | M { rows = r | ||
81 | , cols = c | ||
82 | , dat = v | ||
83 | , tdat = transdata c v r | ||
84 | , order = RowMajor | ||
85 | , isTrans = False | ||
86 | } where r = dim v `div` c -- TODO check mod=0 | ||
87 | |||
88 | matrixFromVector ColumnMajor c v = | ||
89 | M { rows = r | ||
90 | , cols = c | ||
91 | , dat = v | ||
92 | , tdat = transdata r v c | ||
93 | , order = ColumnMajor | ||
94 | , isTrans = False | ||
95 | } where r = dim v `div` c -- TODO check mod=0 | ||
96 | |||
97 | createMatrix order r c = do | ||
98 | p <- createVector (r*c) | ||
99 | return (matrixFromVector order c p) | ||
100 | |||
101 | transdataG :: Storable a => Int -> Vector a -> Int -> Vector a | ||
102 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | ||
103 | |||
104 | transdataR :: Int -> Vector Double -> Int -> Vector Double | ||
105 | transdataR = transdataAux ctransR | ||
106 | |||
107 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) | ||
108 | transdataC = transdataAux ctransC | ||
109 | |||
110 | transdataAux fun c1 d c2 = unsafePerformIO $ do | ||
111 | v <- createVector (dim d) | ||
112 | let r1 = dim d `div` c1 | ||
113 | r2 = dim d `div` c2 | ||
114 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | ||
115 | --putStrLn "---> transdataAux" | ||
116 | return v | ||
117 | |||
118 | foreign import ccall safe "aux.h transR" | ||
119 | ctransR :: Double ::> Double ::> IO Int | ||
120 | foreign import ccall safe "aux.h transC" | ||
121 | ctransC :: Complex Double ::> Complex Double ::> IO Int | ||
122 | |||
123 | transdata :: Field a => Int -> Vector a -> Int -> Vector a | ||
124 | transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 | ||
125 | | isComp baseOf d = scast $ transdataC c1 (scast d) c2 | ||
126 | | otherwise = transdataG c1 d c2 | ||
127 | |||
128 | --transdata :: Storable a => Int -> Vector a -> Int -> Vector a | ||
129 | --transdata = transdataG | ||
130 | --{-# RULES "transdataR" transdata=transdataR #-} | ||
131 | --{-# RULES "transdataC" transdata=transdataC #-} | ||
132 | |||
133 | -- | extracts the rows of a matrix as a list of vectors | ||
134 | toRows :: Storable t => Matrix t -> [Vector t] | ||
135 | toRows m = toRows' 0 where | ||
136 | v = cdat m | ||
137 | r = rows m | ||
138 | c = cols m | ||
139 | toRows' k | k == r*c = [] | ||
140 | | otherwise = subVector k c v : toRows' (k+c) | ||
141 | |||
142 | ------------------------------------------------------------------ | ||
143 | |||
144 | dotL a b = sum (zipWith (*) a b) | ||
145 | |||
146 | multiplyL a b = [[dotL x y | y <- transpose b] | x <- a] | ||
147 | |||
148 | transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) | ||
149 | |||
150 | multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) | ||
151 | |||
152 | ------------------------------------------------------------------ | ||
153 | |||
154 | gmatC m f | fortran m = | ||
155 | if (isTrans m) | ||
156 | then f 0 (rows m) (cols m) (ptr (dat m)) | ||
157 | else f 1 (cols m) (rows m) (ptr (dat m)) | ||
158 | | otherwise = | ||
159 | if isTrans m | ||
160 | then f 1 (cols m) (rows m) (ptr (dat m)) | ||
161 | else f 0 (rows m) (cols m) (ptr (dat m)) | ||
162 | |||
163 | |||
164 | multiplyAux order fun a b = unsafePerformIO $ do | ||
165 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | ||
166 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | ||
167 | r <- createMatrix order (rows a) (cols b) | ||
168 | fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] | ||
169 | return r | ||
170 | |||
171 | foreign import ccall safe "aux.h multiplyR" | ||
172 | cmultiplyR :: Int -> Double ::> (Int -> Double ::> (Double ::> IO Int)) | ||
173 | |||
174 | foreign import ccall safe "aux.h multiplyC" | ||
175 | cmultiplyC :: Int -> Complex Double ::> (Int -> Complex Double ::> (Complex Double ::> IO Int)) | ||
176 | |||
177 | multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a | ||
178 | multiply RowMajor a b = multiplyD RowMajor a b | ||
179 | multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b | ||
180 | |||
181 | multiplyT order a b = multiplyD order (trans b) (trans a) | ||
182 | |||
183 | multiplyD order a b | ||
184 | | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) | ||
185 | | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b) | ||
186 | | otherwise = multiplyG a b | ||
187 | |||
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs new file mode 100644 index 0000000..11101a9 --- /dev/null +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -0,0 +1,32 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} | ||
2 | ----------------------------------------------------------------------------- | ||
3 | -- | | ||
4 | -- Module : Data.Packed.Internal.Tensor | ||
5 | -- Copyright : (c) Alberto Ruiz 2007 | ||
6 | -- License : GPL-style | ||
7 | -- | ||
8 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
9 | -- Stability : provisional | ||
10 | -- Portability : portable (uses FFI) | ||
11 | -- | ||
12 | -- Fundamental types | ||
13 | -- | ||
14 | ----------------------------------------------------------------------------- | ||
15 | |||
16 | module Data.Packed.Internal.Tensor where | ||
17 | |||
18 | import Data.Packed.Internal.Vector | ||
19 | import Data.Packed.Internal.Matrix | ||
20 | |||
21 | |||
22 | data IdxTp = Covariant | Contravariant deriving Show | ||
23 | |||
24 | data Tensor t = T { dims :: [(Int,(IdxTp,String))] | ||
25 | , ten :: Vector t | ||
26 | } deriving Show | ||
27 | |||
28 | rank = length . dims | ||
29 | |||
30 | outer u v = dat (multiply RowMajor r c) | ||
31 | where r = matrixFromVector RowMajor 1 u | ||
32 | c = matrixFromVector RowMajor (dim v) v | ||
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs new file mode 100644 index 0000000..7dcefeb --- /dev/null +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -0,0 +1,164 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} | ||
2 | ----------------------------------------------------------------------------- | ||
3 | -- | | ||
4 | -- Module : Data.Packed.Internal.Vector | ||
5 | -- Copyright : (c) Alberto Ruiz 2007 | ||
6 | -- License : GPL-style | ||
7 | -- | ||
8 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
9 | -- Stability : provisional | ||
10 | -- Portability : portable (uses FFI) | ||
11 | -- | ||
12 | -- Fundamental types | ||
13 | -- | ||
14 | ----------------------------------------------------------------------------- | ||
15 | |||
16 | module Data.Packed.Internal.Vector where | ||
17 | |||
18 | import Foreign | ||
19 | import Complex | ||
20 | import Control.Monad(when) | ||
21 | import Debug.Trace | ||
22 | import Data.List(transpose,intersperse) | ||
23 | import Data.Typeable | ||
24 | import Data.Maybe(fromJust) | ||
25 | |||
26 | debug x = trace (show x) x | ||
27 | |||
28 | ---------------------------------------------------------------------- | ||
29 | instance (Storable a, RealFloat a) => Storable (Complex a) where -- | ||
30 | alignment x = alignment (realPart x) -- | ||
31 | sizeOf x = 2 * sizeOf (realPart x) -- | ||
32 | peek p = do -- | ||
33 | [re,im] <- peekArray 2 (castPtr p) -- | ||
34 | return (re :+ im) -- | ||
35 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- | ||
36 | ---------------------------------------------------------------------- | ||
37 | |||
38 | (//) :: x -> (x -> y) -> y | ||
39 | infixl 0 // | ||
40 | (//) = flip ($) | ||
41 | |||
42 | check msg ls f = do | ||
43 | err <- f | ||
44 | when (err/=0) (error msg) | ||
45 | mapM_ (touchForeignPtr . fptr) ls | ||
46 | return () | ||
47 | |||
48 | class (Storable a, Typeable a) => Field a where | ||
49 | instance (Storable a, Typeable a) => Field a where | ||
50 | |||
51 | isReal w x = typeOf (undefined :: Double) == typeOf (w x) | ||
52 | isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) | ||
53 | baseOf v = (v `at` 0) | ||
54 | |||
55 | scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b | ||
56 | scast = fromJust . cast | ||
57 | |||
58 | |||
59 | |||
60 | ---------------------------------------------------------------------- | ||
61 | |||
62 | data Vector t = V { dim :: Int | ||
63 | , fptr :: ForeignPtr t | ||
64 | , ptr :: Ptr t | ||
65 | } deriving Typeable | ||
66 | |||
67 | type Vc t s = Int -> Ptr t -> s | ||
68 | infixr 5 :> | ||
69 | type t :> s = Vc t s | ||
70 | |||
71 | vec :: Vector t -> (Vc t s) -> s | ||
72 | vec v f = f (dim v) (ptr v) | ||
73 | |||
74 | createVector :: Storable a => Int -> IO (Vector a) | ||
75 | createVector n = do | ||
76 | when (n <= 0) $ error ("trying to createVector of dim "++show n) | ||
77 | fp <- mallocForeignPtrArray n | ||
78 | let p = unsafeForeignPtrToPtr fp | ||
79 | --putStrLn ("\n---------> V"++show n) | ||
80 | return $ V n fp p | ||
81 | |||
82 | fromList :: Storable a => [a] -> Vector a | ||
83 | fromList l = unsafePerformIO $ do | ||
84 | v <- createVector (length l) | ||
85 | let f _ p = pokeArray p l >> return 0 | ||
86 | f // vec v // check "fromList" [] | ||
87 | return v | ||
88 | |||
89 | toList :: Storable a => Vector a -> [a] | ||
90 | toList v = unsafePerformIO $ peekArray (dim v) (ptr v) | ||
91 | |||
92 | n # l = if length l == n then fromList l else error "# with wrong size" | ||
93 | |||
94 | at' :: Storable a => Vector a -> Int -> a | ||
95 | at' v n = unsafePerformIO $ peekElemOff (ptr v) n | ||
96 | |||
97 | at :: Storable a => Vector a -> Int -> a | ||
98 | at v n | n >= 0 && n < dim v = at' v n | ||
99 | | otherwise = error "vector index out of range" | ||
100 | |||
101 | instance (Show a, Storable a) => (Show (Vector a)) where | ||
102 | show v = (show (dim v))++" # " ++ show (toList v) | ||
103 | |||
104 | -- | creates a Vector taking a number of consecutive toList from another Vector | ||
105 | subVector :: Storable t => Int -- ^ index of the starting element | ||
106 | -> Int -- ^ number of toList to extract | ||
107 | -> Vector t -- ^ source | ||
108 | -> Vector t -- ^ result | ||
109 | subVector k l (v@V {dim=n, ptr=p, fptr=fp}) | ||
110 | | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" | ||
111 | | otherwise = unsafePerformIO $ do | ||
112 | r <- createVector l | ||
113 | let f = copyArray (ptr r) (advancePtr p k) l >> return 0 | ||
114 | f // check "subVector" [v] | ||
115 | return r | ||
116 | |||
117 | subVector' k l (v@V {dim=n, ptr=p, fptr=fp}) | ||
118 | | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" | ||
119 | | otherwise = v {dim=l, ptr=advancePtr p k} | ||
120 | |||
121 | |||
122 | {- | ||
123 | -- | creates a new Vector by joining a list of Vectors | ||
124 | join :: Field t => [Vector t] -> Vector t | ||
125 | join [] = error "joining an empty list" | ||
126 | join as = unsafePerformIO $ do | ||
127 | let tot = sum (map size as) | ||
128 | p <- mallocForeignPtrArray tot | ||
129 | withForeignPtr p $ \p -> | ||
130 | joiner as tot p | ||
131 | return (V tot p) | ||
132 | where joiner [] _ _ = return () | ||
133 | joiner (V n b : cs) _ p = do | ||
134 | withForeignPtr b $ \b' -> copyArray p b' n | ||
135 | joiner cs 0 (advancePtr p n) | ||
136 | -} | ||
137 | |||
138 | |||
139 | constantG n x = fromList (replicate n x) | ||
140 | |||
141 | constantR :: Int -> Double -> Vector Double | ||
142 | constantR = constantAux cconstantR | ||
143 | |||
144 | constantC :: Int -> Complex Double -> Vector (Complex Double) | ||
145 | constantC = constantAux cconstantC | ||
146 | |||
147 | constantAux fun n x = unsafePerformIO $ do | ||
148 | v <- createVector n | ||
149 | px <- newArray [x] | ||
150 | fun px // vec v // check "constantAux" [] | ||
151 | free px | ||
152 | return v | ||
153 | |||
154 | foreign import ccall safe "aux.h constantR" | ||
155 | cconstantR :: Ptr Double -> Double :> IO Int | ||
156 | |||
157 | foreign import ccall safe "aux.h constantC" | ||
158 | cconstantC :: Ptr (Complex Double) -> Complex Double :> IO Int | ||
159 | |||
160 | constant :: Field a => Int -> a -> Vector a | ||
161 | constant n x | isReal id x = scast $ constantR n (scast x) | ||
162 | | isComp id x = scast $ constantC n (scast x) | ||
163 | | otherwise = constantG n x | ||
164 | |||
diff --git a/lib/Data/Packed/aux.c b/lib/Data/Packed/Internal/aux.c index da36035..da36035 100644 --- a/lib/Data/Packed/aux.c +++ b/lib/Data/Packed/Internal/aux.c | |||
diff --git a/lib/Data/Packed/aux.h b/lib/Data/Packed/Internal/aux.h index f45b55a..f45b55a 100644 --- a/lib/Data/Packed/aux.h +++ b/lib/Data/Packed/Internal/aux.h | |||