diff options
Diffstat (limited to 'lib/Data/Packed/Internal.hs')
-rw-r--r-- | lib/Data/Packed/Internal.hs | 286 |
1 files changed, 216 insertions, 70 deletions
diff --git a/lib/Data/Packed/Internal.hs b/lib/Data/Packed/Internal.hs index 5e19e58..b06f044 100644 --- a/lib/Data/Packed/Internal.hs +++ b/lib/Data/Packed/Internal.hs | |||
@@ -1,3 +1,4 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} | ||
1 | ----------------------------------------------------------------------------- | 2 | ----------------------------------------------------------------------------- |
2 | -- | | 3 | -- | |
3 | -- Module : Data.Packed.Internal | 4 | -- Module : Data.Packed.Internal |
@@ -14,39 +15,16 @@ | |||
14 | 15 | ||
15 | module Data.Packed.Internal where | 16 | module Data.Packed.Internal where |
16 | 17 | ||
17 | import Foreign | 18 | import Foreign hiding (xor) |
18 | import Complex | 19 | import Complex |
19 | import Control.Monad(when) | 20 | import Control.Monad(when) |
20 | import Debug.Trace | 21 | import Debug.Trace |
22 | import Data.List(transpose,intersperse) | ||
23 | import Data.Typeable | ||
24 | import Data.Maybe(fromJust) | ||
21 | 25 | ||
22 | debug x = trace (show x) x | 26 | debug x = trace (show x) x |
23 | 27 | ||
24 | -- | 1D array | ||
25 | data Vector t = V { dim :: Int | ||
26 | , fptr :: ForeignPtr t | ||
27 | , ptr :: Ptr t | ||
28 | } | ||
29 | |||
30 | data TransMode = NoTrans | Trans | ConjTrans | ||
31 | |||
32 | -- | 2D array | ||
33 | data Matrix t = M { rows :: Int | ||
34 | , cols :: Int | ||
35 | , mat :: Vector t | ||
36 | , trMode :: TransMode | ||
37 | , isCOrder :: Bool | ||
38 | } | ||
39 | |||
40 | data IdxTp = Covariant | Contravariant | ||
41 | |||
42 | -- | multidimensional array | ||
43 | data Tensor t = T { rank :: Int | ||
44 | , dims :: [Int] | ||
45 | , idxNm :: [String] | ||
46 | , idxTp :: [IdxTp] | ||
47 | , ten :: Vector t | ||
48 | } | ||
49 | |||
50 | ---------------------------------------------------------------------- | 28 | ---------------------------------------------------------------------- |
51 | instance (Storable a, RealFloat a) => Storable (Complex a) where -- | 29 | instance (Storable a, RealFloat a) => Storable (Complex a) where -- |
52 | alignment x = alignment (realPart x) -- | 30 | alignment x = alignment (realPart x) -- |
@@ -57,36 +35,36 @@ instance (Storable a, RealFloat a) => Storable (Complex a) where -- | |||
57 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- | 35 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- |
58 | ---------------------------------------------------------------------- | 36 | ---------------------------------------------------------------------- |
59 | 37 | ||
60 | |||
61 | -- f // vec a // vec b // vec res // check "vector add" [a,b] | ||
62 | |||
63 | (//) :: x -> (x -> y) -> y | 38 | (//) :: x -> (x -> y) -> y |
64 | infixl 0 // | 39 | infixl 0 // |
65 | (//) = flip ($) | 40 | (//) = flip ($) |
66 | 41 | ||
67 | vec :: Vector a -> (Int -> Ptr b -> t) -> t | ||
68 | vec v f = f (dim v) (castPtr $ ptr v) | ||
69 | |||
70 | mata :: Matrix a -> (Int-> Int -> Ptr b -> t) -> t | ||
71 | mata m f = f (rows m) (cols m) (castPtr $ ptr (mat m)) | ||
72 | |||
73 | pd2pc :: Ptr Double -> Ptr (Complex (Double)) | ||
74 | pd2pc = castPtr | ||
75 | |||
76 | pc2pd :: Ptr (Complex (Double)) -> Ptr Double | ||
77 | pc2pd = castPtr | ||
78 | |||
79 | check msg ls f = do | 42 | check msg ls f = do |
80 | err <- f | 43 | err <- f |
81 | when (err/=0) (error msg) | 44 | when (err/=0) (error msg) |
82 | mapM_ (touchForeignPtr . fptr) ls | 45 | mapM_ (touchForeignPtr . fptr) ls |
83 | return () | 46 | return () |
84 | 47 | ||
48 | ---------------------------------------------------------------------- | ||
49 | |||
50 | data Vector t = V { dim :: Int | ||
51 | , fptr :: ForeignPtr t | ||
52 | , ptr :: Ptr t | ||
53 | } deriving Typeable | ||
54 | |||
55 | type Vc t s = Int -> Ptr t -> s | ||
56 | infixr 5 :> | ||
57 | type t :> s = Vc t s | ||
58 | |||
59 | vec :: Vector t -> (Vc t s) -> s | ||
60 | vec v f = f (dim v) (ptr v) | ||
61 | |||
85 | createVector :: Storable a => Int -> IO (Vector a) | 62 | createVector :: Storable a => Int -> IO (Vector a) |
86 | createVector n = do | 63 | createVector n = do |
87 | when (n <= 0) $ error ("trying to createVector of dim "++show n) | 64 | when (n <= 0) $ error ("trying to createVector of dim "++show n) |
88 | fp <- mallocForeignPtrArray n | 65 | fp <- mallocForeignPtrArray n |
89 | let p = unsafeForeignPtrToPtr fp | 66 | let p = unsafeForeignPtrToPtr fp |
67 | --putStrLn ("\n---------> V"++show n) | ||
90 | return $ V n fp p | 68 | return $ V n fp p |
91 | 69 | ||
92 | fromList :: Storable a => [a] -> Vector a | 70 | fromList :: Storable a => [a] -> Vector a |
@@ -99,6 +77,8 @@ fromList l = unsafePerformIO $ do | |||
99 | toList :: Storable a => Vector a -> [a] | 77 | toList :: Storable a => Vector a -> [a] |
100 | toList v = unsafePerformIO $ peekArray (dim v) (ptr v) | 78 | toList v = unsafePerformIO $ peekArray (dim v) (ptr v) |
101 | 79 | ||
80 | n # l = if length l == n then fromList l else error "# with wrong size" | ||
81 | |||
102 | at' :: Storable a => Vector a -> Int -> a | 82 | at' :: Storable a => Vector a -> Int -> a |
103 | at' v n = unsafePerformIO $ peekElemOff (ptr v) n | 83 | at' v n = unsafePerformIO $ peekElemOff (ptr v) n |
104 | 84 | ||
@@ -106,42 +86,208 @@ at :: Storable a => Vector a -> Int -> a | |||
106 | at v n | n >= 0 && n < dim v = at' v n | 86 | at v n | n >= 0 && n < dim v = at' v n |
107 | | otherwise = error "vector index out of range" | 87 | | otherwise = error "vector index out of range" |
108 | 88 | ||
109 | dsv v = sizeOf (v `at` 0) | 89 | instance (Show a, Storable a) => (Show (Vector a)) where |
110 | dsm m = (dsv.mat) m | 90 | show v = (show (dim v))++" # " ++ show (toList v) |
111 | 91 | ||
112 | constant :: Storable a => Int -> a -> Vector a | 92 | ------------------------------------------------------------------------ |
113 | constant n x = unsafePerformIO $ do | ||
114 | v <- createVector n | ||
115 | let f k p | k == n = return 0 | ||
116 | | otherwise = pokeElemOff p k x >> f (k+1) p | ||
117 | const (f 0) // vec v // check "constant" [] | ||
118 | return v | ||
119 | 93 | ||
120 | instance (Show a, Storable a) => (Show (Vector a)) where | 94 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) |
121 | show v = "fromList " ++ show (toList v) | 95 | |
96 | -- | 2D array | ||
97 | data Matrix t = M { rows :: Int | ||
98 | , cols :: Int | ||
99 | , cmat :: Vector t | ||
100 | , fmat :: Vector t | ||
101 | , isTrans :: Bool | ||
102 | , order :: MatrixOrder | ||
103 | } deriving Typeable | ||
104 | |||
105 | xor a b = a && not b || b && not a | ||
106 | |||
107 | fortran m = order m == ColumnMajor | ||
108 | |||
109 | dat m = if fortran m `xor` isTrans m then fmat m else cmat m | ||
110 | |||
111 | pref m = if fortran m then fmat m else cmat m | ||
112 | |||
113 | trans m = m { rows = cols m | ||
114 | , cols = rows m | ||
115 | , isTrans = not (isTrans m) | ||
116 | } | ||
117 | |||
118 | type Mt t s = Int -> Int -> Ptr t -> s | ||
119 | infixr 6 ::> | ||
120 | type t ::> s = Mt t s | ||
121 | |||
122 | mat :: Matrix t -> (Mt t s) -> s | ||
123 | mat m f = f (rows m) (cols m) (ptr (dat m)) | ||
124 | |||
125 | gmat m f | fortran m = | ||
126 | if (isTrans m) | ||
127 | then f 0 (rows m) (cols m) (ptr (fmat m)) | ||
128 | else f 1 (cols m) (rows m) (ptr (fmat m)) | ||
129 | | otherwise = | ||
130 | if isTrans m | ||
131 | then f 1 (cols m) (rows m) (ptr (cmat m)) | ||
132 | else f 0 (rows m) (cols m) (ptr (cmat m)) | ||
122 | 133 | ||
123 | instance (Show a, Storable a) => (Show (Matrix a)) where | 134 | instance (Show a, Storable a) => (Show (Matrix a)) where |
124 | show m = "reshape "++show (cols m) ++ " $ fromList " ++ show (toList (mat m)) | 135 | show m = (sizes++) . dsp . map (map show) . toLists $ m |
136 | where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" | ||
137 | |||
138 | partit :: Int -> [a] -> [[a]] | ||
139 | partit _ [] = [] | ||
140 | partit n l = take n l : partit n (drop n l) | ||
141 | |||
142 | toLists m = partit (cols m) . toList . cmat $ m | ||
125 | 143 | ||
126 | reshape :: Storable a => Int -> Vector a -> Matrix a | 144 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp |
127 | reshape n v = M { rows = dim v `div` n | 145 | where |
128 | , cols = n | 146 | mt = transpose as |
129 | , mat = v | 147 | longs = map (maximum . map length) mt |
130 | , trMode = NoTrans | 148 | mtp = zipWith (\a b -> map (pad a) b) longs mt |
131 | , isCOrder = True | 149 | pad n str = replicate (n - length str) ' ' ++ str |
132 | } | 150 | unwords' = concat . intersperse ", " |
133 | 151 | ||
134 | createMatrix r c = do | 152 | matrixFromVector RowMajor c v = |
153 | M { rows = r | ||
154 | , cols = c | ||
155 | , cmat = v | ||
156 | , fmat = transdata c v r | ||
157 | , order = RowMajor | ||
158 | , isTrans = False | ||
159 | } where r = dim v `div` c -- TODO check mod=0 | ||
160 | |||
161 | matrixFromVector ColumnMajor c v = | ||
162 | M { rows = r | ||
163 | , cols = c | ||
164 | , fmat = v | ||
165 | , cmat = transdata c v r | ||
166 | , order = ColumnMajor | ||
167 | , isTrans = False | ||
168 | } where r = dim v `div` c -- TODO check mod=0 | ||
169 | |||
170 | createMatrix order r c = do | ||
135 | p <- createVector (r*c) | 171 | p <- createVector (r*c) |
136 | return (reshape c p) | 172 | return (matrixFromVector order c p) |
173 | |||
174 | transdataG :: Storable a => Int -> Vector a -> Int -> Vector a | ||
175 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | ||
176 | |||
177 | transdataR :: Int -> Vector Double -> Int -> Vector Double | ||
178 | transdataR = transdataAux ctransR | ||
179 | |||
180 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) | ||
181 | transdataC = transdataAux ctransC | ||
182 | |||
183 | transdataAux fun c1 d c2 = unsafePerformIO $ do | ||
184 | v <- createVector (dim d) | ||
185 | let r1 = dim d `div` c1 | ||
186 | r2 = dim d `div` c2 | ||
187 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | ||
188 | --putStrLn "---> transdataAux" | ||
189 | return v | ||
190 | |||
191 | foreign import ccall safe "aux.h transR" | ||
192 | ctransR :: Double ::> Double ::> IO Int | ||
193 | foreign import ccall safe "aux.h transC" | ||
194 | ctransC :: Complex Double ::> Complex Double ::> IO Int | ||
195 | |||
196 | |||
197 | class (Storable a, Typeable a) => Field a where | ||
198 | instance (Storable a, Typeable a) => Field a where | ||
199 | |||
200 | isReal w x = typeOf (undefined :: Double) == typeOf (w x) | ||
201 | isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) | ||
202 | baseOf v = (v `at` 0) | ||
203 | |||
204 | scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b | ||
205 | scast = fromJust . cast | ||
206 | |||
207 | transdata :: Field a => Int -> Vector a -> Int -> Vector a | ||
208 | transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 | ||
209 | | isComp baseOf d = scast $ transdataC c1 (scast d) c2 | ||
210 | | otherwise = transdataG c1 d c2 | ||
211 | |||
212 | --transdata :: Storable a => Int -> Vector a -> Int -> Vector a | ||
213 | --transdata = transdataG | ||
214 | --{-# RULES "transdataR" transdata=transdataR #-} | ||
215 | --{-# RULES "transdataC" transdata=transdataC #-} | ||
216 | |||
217 | ------------------------------------------------------------------ | ||
218 | |||
219 | constantG n x = fromList (replicate n x) | ||
220 | |||
221 | constantR :: Int -> Double -> Vector Double | ||
222 | constantR = constantAux cconstantR | ||
223 | |||
224 | constantC :: Int -> Complex Double -> Vector (Complex Double) | ||
225 | constantC = constantAux cconstantC | ||
226 | |||
227 | constantAux fun n x = unsafePerformIO $ do | ||
228 | v <- createVector n | ||
229 | px <- newArray [x] | ||
230 | fun px // vec v // check "constantAux" [] | ||
231 | free px | ||
232 | return v | ||
137 | 233 | ||
138 | type CMat s = Int -> Int -> Ptr Double -> s | 234 | foreign import ccall safe "aux.h constantR" |
139 | type CVec s = Int -> Ptr Double -> s | 235 | cconstantR :: Ptr Double -> Double :> IO Int |
140 | 236 | ||
141 | foreign import ccall safe "aux.h trans" ctrans :: Int -> CMat (CMat (IO Int)) | 237 | foreign import ccall safe "aux.h constantC" |
238 | cconstantC :: Ptr (Complex Double) -> Complex Double :> IO Int | ||
142 | 239 | ||
143 | trans :: Storable a => Matrix a -> Matrix a | 240 | constant :: Field a => Int -> a -> Vector a |
144 | trans m = unsafePerformIO $ do | 241 | constant n x | isReal id x = scast $ constantR n (scast x) |
145 | r <- createMatrix (cols m) (rows m) | 242 | | isComp id x = scast $ constantC n (scast x) |
146 | ctrans (dsm m) // mata m // mata r // check "trans" [mat m] | 243 | | otherwise = constantG n x |
244 | |||
245 | ------------------------------------------------------------------ | ||
246 | |||
247 | dotL a b = sum (zipWith (*) a b) | ||
248 | |||
249 | multiplyL a b = [[dotL x y | y <- transpose b] | x <- a] | ||
250 | |||
251 | transL m = m {rows = cols m, cols = rows m, cmat = v, fmat = cmat m} | ||
252 | where v = transdataG (cols m) (cmat m) (rows m) | ||
253 | |||
254 | ------------------------------------------------------------------ | ||
255 | |||
256 | multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) | ||
257 | |||
258 | multiplyAux order fun a b = unsafePerformIO $ do | ||
259 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | ||
260 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | ||
261 | r <- createMatrix order (rows a) (cols b) | ||
262 | fun // gmat a // gmat b // mat r // check "multiplyAux" [pref a, pref b] | ||
147 | return r | 263 | return r |
264 | |||
265 | foreign import ccall safe "aux.h multiplyR" | ||
266 | cmultiplyR :: Int -> Double ::> (Int -> Double ::> (Double ::> IO Int)) | ||
267 | |||
268 | foreign import ccall safe "aux.h multiplyC" | ||
269 | cmultiplyC :: Int -> Complex Double ::> (Int -> Complex Double ::> (Complex Double ::> IO Int)) | ||
270 | |||
271 | multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a | ||
272 | multiply RowMajor a b = multiplyD RowMajor a b | ||
273 | multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b | ||
274 | |||
275 | multiplyT order a b = multiplyD order (trans b) (trans a) | ||
276 | |||
277 | multiplyD order a b | ||
278 | | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) | ||
279 | | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b) | ||
280 | | otherwise = multiplyG a b | ||
281 | |||
282 | -------------------------------------------------------------------- | ||
283 | |||
284 | data IdxTp = Covariant | Contravariant | ||
285 | |||
286 | -- | multidimensional array | ||
287 | data Tensor t = T { rank :: Int | ||
288 | , dims :: [Int] | ||
289 | , idxNm :: [String] | ||
290 | , idxTp :: [IdxTp] | ||
291 | , ten :: Vector t | ||
292 | } | ||
293 | |||