diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-06-04 19:10:28 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-06-04 19:10:28 +0000 |
commit | 7430630fa0504296b796223e01cbd417b88650ef (patch) | |
tree | c338dea8b82867a4c161fcee5817ed2ca27c7258 /lib/Data/Packed/Internal.hs | |
parent | 0a9817cc481fb09f1962eb2c272125e56a123814 (diff) |
separation of Internal
Diffstat (limited to 'lib/Data/Packed/Internal.hs')
-rw-r--r-- | lib/Data/Packed/Internal.hs | 293 |
1 files changed, 0 insertions, 293 deletions
diff --git a/lib/Data/Packed/Internal.hs b/lib/Data/Packed/Internal.hs deleted file mode 100644 index b06f044..0000000 --- a/lib/Data/Packed/Internal.hs +++ /dev/null | |||
@@ -1,293 +0,0 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} | ||
2 | ----------------------------------------------------------------------------- | ||
3 | -- | | ||
4 | -- Module : Data.Packed.Internal | ||
5 | -- Copyright : (c) Alberto Ruiz 2007 | ||
6 | -- License : GPL-style | ||
7 | -- | ||
8 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
9 | -- Stability : provisional | ||
10 | -- Portability : portable (uses FFI) | ||
11 | -- | ||
12 | -- Fundamental types | ||
13 | -- | ||
14 | ----------------------------------------------------------------------------- | ||
15 | |||
16 | module Data.Packed.Internal where | ||
17 | |||
18 | import Foreign hiding (xor) | ||
19 | import Complex | ||
20 | import Control.Monad(when) | ||
21 | import Debug.Trace | ||
22 | import Data.List(transpose,intersperse) | ||
23 | import Data.Typeable | ||
24 | import Data.Maybe(fromJust) | ||
25 | |||
26 | debug x = trace (show x) x | ||
27 | |||
28 | ---------------------------------------------------------------------- | ||
29 | instance (Storable a, RealFloat a) => Storable (Complex a) where -- | ||
30 | alignment x = alignment (realPart x) -- | ||
31 | sizeOf x = 2 * sizeOf (realPart x) -- | ||
32 | peek p = do -- | ||
33 | [re,im] <- peekArray 2 (castPtr p) -- | ||
34 | return (re :+ im) -- | ||
35 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- | ||
36 | ---------------------------------------------------------------------- | ||
37 | |||
38 | (//) :: x -> (x -> y) -> y | ||
39 | infixl 0 // | ||
40 | (//) = flip ($) | ||
41 | |||
42 | check msg ls f = do | ||
43 | err <- f | ||
44 | when (err/=0) (error msg) | ||
45 | mapM_ (touchForeignPtr . fptr) ls | ||
46 | return () | ||
47 | |||
48 | ---------------------------------------------------------------------- | ||
49 | |||
50 | data Vector t = V { dim :: Int | ||
51 | , fptr :: ForeignPtr t | ||
52 | , ptr :: Ptr t | ||
53 | } deriving Typeable | ||
54 | |||
55 | type Vc t s = Int -> Ptr t -> s | ||
56 | infixr 5 :> | ||
57 | type t :> s = Vc t s | ||
58 | |||
59 | vec :: Vector t -> (Vc t s) -> s | ||
60 | vec v f = f (dim v) (ptr v) | ||
61 | |||
62 | createVector :: Storable a => Int -> IO (Vector a) | ||
63 | createVector n = do | ||
64 | when (n <= 0) $ error ("trying to createVector of dim "++show n) | ||
65 | fp <- mallocForeignPtrArray n | ||
66 | let p = unsafeForeignPtrToPtr fp | ||
67 | --putStrLn ("\n---------> V"++show n) | ||
68 | return $ V n fp p | ||
69 | |||
70 | fromList :: Storable a => [a] -> Vector a | ||
71 | fromList l = unsafePerformIO $ do | ||
72 | v <- createVector (length l) | ||
73 | let f _ p = pokeArray p l >> return 0 | ||
74 | f // vec v // check "fromList" [] | ||
75 | return v | ||
76 | |||
77 | toList :: Storable a => Vector a -> [a] | ||
78 | toList v = unsafePerformIO $ peekArray (dim v) (ptr v) | ||
79 | |||
80 | n # l = if length l == n then fromList l else error "# with wrong size" | ||
81 | |||
82 | at' :: Storable a => Vector a -> Int -> a | ||
83 | at' v n = unsafePerformIO $ peekElemOff (ptr v) n | ||
84 | |||
85 | at :: Storable a => Vector a -> Int -> a | ||
86 | at v n | n >= 0 && n < dim v = at' v n | ||
87 | | otherwise = error "vector index out of range" | ||
88 | |||
89 | instance (Show a, Storable a) => (Show (Vector a)) where | ||
90 | show v = (show (dim v))++" # " ++ show (toList v) | ||
91 | |||
92 | ------------------------------------------------------------------------ | ||
93 | |||
94 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | ||
95 | |||
96 | -- | 2D array | ||
97 | data Matrix t = M { rows :: Int | ||
98 | , cols :: Int | ||
99 | , cmat :: Vector t | ||
100 | , fmat :: Vector t | ||
101 | , isTrans :: Bool | ||
102 | , order :: MatrixOrder | ||
103 | } deriving Typeable | ||
104 | |||
105 | xor a b = a && not b || b && not a | ||
106 | |||
107 | fortran m = order m == ColumnMajor | ||
108 | |||
109 | dat m = if fortran m `xor` isTrans m then fmat m else cmat m | ||
110 | |||
111 | pref m = if fortran m then fmat m else cmat m | ||
112 | |||
113 | trans m = m { rows = cols m | ||
114 | , cols = rows m | ||
115 | , isTrans = not (isTrans m) | ||
116 | } | ||
117 | |||
118 | type Mt t s = Int -> Int -> Ptr t -> s | ||
119 | infixr 6 ::> | ||
120 | type t ::> s = Mt t s | ||
121 | |||
122 | mat :: Matrix t -> (Mt t s) -> s | ||
123 | mat m f = f (rows m) (cols m) (ptr (dat m)) | ||
124 | |||
125 | gmat m f | fortran m = | ||
126 | if (isTrans m) | ||
127 | then f 0 (rows m) (cols m) (ptr (fmat m)) | ||
128 | else f 1 (cols m) (rows m) (ptr (fmat m)) | ||
129 | | otherwise = | ||
130 | if isTrans m | ||
131 | then f 1 (cols m) (rows m) (ptr (cmat m)) | ||
132 | else f 0 (rows m) (cols m) (ptr (cmat m)) | ||
133 | |||
134 | instance (Show a, Storable a) => (Show (Matrix a)) where | ||
135 | show m = (sizes++) . dsp . map (map show) . toLists $ m | ||
136 | where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" | ||
137 | |||
138 | partit :: Int -> [a] -> [[a]] | ||
139 | partit _ [] = [] | ||
140 | partit n l = take n l : partit n (drop n l) | ||
141 | |||
142 | toLists m = partit (cols m) . toList . cmat $ m | ||
143 | |||
144 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp | ||
145 | where | ||
146 | mt = transpose as | ||
147 | longs = map (maximum . map length) mt | ||
148 | mtp = zipWith (\a b -> map (pad a) b) longs mt | ||
149 | pad n str = replicate (n - length str) ' ' ++ str | ||
150 | unwords' = concat . intersperse ", " | ||
151 | |||
152 | matrixFromVector RowMajor c v = | ||
153 | M { rows = r | ||
154 | , cols = c | ||
155 | , cmat = v | ||
156 | , fmat = transdata c v r | ||
157 | , order = RowMajor | ||
158 | , isTrans = False | ||
159 | } where r = dim v `div` c -- TODO check mod=0 | ||
160 | |||
161 | matrixFromVector ColumnMajor c v = | ||
162 | M { rows = r | ||
163 | , cols = c | ||
164 | , fmat = v | ||
165 | , cmat = transdata c v r | ||
166 | , order = ColumnMajor | ||
167 | , isTrans = False | ||
168 | } where r = dim v `div` c -- TODO check mod=0 | ||
169 | |||
170 | createMatrix order r c = do | ||
171 | p <- createVector (r*c) | ||
172 | return (matrixFromVector order c p) | ||
173 | |||
174 | transdataG :: Storable a => Int -> Vector a -> Int -> Vector a | ||
175 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | ||
176 | |||
177 | transdataR :: Int -> Vector Double -> Int -> Vector Double | ||
178 | transdataR = transdataAux ctransR | ||
179 | |||
180 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) | ||
181 | transdataC = transdataAux ctransC | ||
182 | |||
183 | transdataAux fun c1 d c2 = unsafePerformIO $ do | ||
184 | v <- createVector (dim d) | ||
185 | let r1 = dim d `div` c1 | ||
186 | r2 = dim d `div` c2 | ||
187 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | ||
188 | --putStrLn "---> transdataAux" | ||
189 | return v | ||
190 | |||
191 | foreign import ccall safe "aux.h transR" | ||
192 | ctransR :: Double ::> Double ::> IO Int | ||
193 | foreign import ccall safe "aux.h transC" | ||
194 | ctransC :: Complex Double ::> Complex Double ::> IO Int | ||
195 | |||
196 | |||
197 | class (Storable a, Typeable a) => Field a where | ||
198 | instance (Storable a, Typeable a) => Field a where | ||
199 | |||
200 | isReal w x = typeOf (undefined :: Double) == typeOf (w x) | ||
201 | isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) | ||
202 | baseOf v = (v `at` 0) | ||
203 | |||
204 | scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b | ||
205 | scast = fromJust . cast | ||
206 | |||
207 | transdata :: Field a => Int -> Vector a -> Int -> Vector a | ||
208 | transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 | ||
209 | | isComp baseOf d = scast $ transdataC c1 (scast d) c2 | ||
210 | | otherwise = transdataG c1 d c2 | ||
211 | |||
212 | --transdata :: Storable a => Int -> Vector a -> Int -> Vector a | ||
213 | --transdata = transdataG | ||
214 | --{-# RULES "transdataR" transdata=transdataR #-} | ||
215 | --{-# RULES "transdataC" transdata=transdataC #-} | ||
216 | |||
217 | ------------------------------------------------------------------ | ||
218 | |||
219 | constantG n x = fromList (replicate n x) | ||
220 | |||
221 | constantR :: Int -> Double -> Vector Double | ||
222 | constantR = constantAux cconstantR | ||
223 | |||
224 | constantC :: Int -> Complex Double -> Vector (Complex Double) | ||
225 | constantC = constantAux cconstantC | ||
226 | |||
227 | constantAux fun n x = unsafePerformIO $ do | ||
228 | v <- createVector n | ||
229 | px <- newArray [x] | ||
230 | fun px // vec v // check "constantAux" [] | ||
231 | free px | ||
232 | return v | ||
233 | |||
234 | foreign import ccall safe "aux.h constantR" | ||
235 | cconstantR :: Ptr Double -> Double :> IO Int | ||
236 | |||
237 | foreign import ccall safe "aux.h constantC" | ||
238 | cconstantC :: Ptr (Complex Double) -> Complex Double :> IO Int | ||
239 | |||
240 | constant :: Field a => Int -> a -> Vector a | ||
241 | constant n x | isReal id x = scast $ constantR n (scast x) | ||
242 | | isComp id x = scast $ constantC n (scast x) | ||
243 | | otherwise = constantG n x | ||
244 | |||
245 | ------------------------------------------------------------------ | ||
246 | |||
247 | dotL a b = sum (zipWith (*) a b) | ||
248 | |||
249 | multiplyL a b = [[dotL x y | y <- transpose b] | x <- a] | ||
250 | |||
251 | transL m = m {rows = cols m, cols = rows m, cmat = v, fmat = cmat m} | ||
252 | where v = transdataG (cols m) (cmat m) (rows m) | ||
253 | |||
254 | ------------------------------------------------------------------ | ||
255 | |||
256 | multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) | ||
257 | |||
258 | multiplyAux order fun a b = unsafePerformIO $ do | ||
259 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | ||
260 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | ||
261 | r <- createMatrix order (rows a) (cols b) | ||
262 | fun // gmat a // gmat b // mat r // check "multiplyAux" [pref a, pref b] | ||
263 | return r | ||
264 | |||
265 | foreign import ccall safe "aux.h multiplyR" | ||
266 | cmultiplyR :: Int -> Double ::> (Int -> Double ::> (Double ::> IO Int)) | ||
267 | |||
268 | foreign import ccall safe "aux.h multiplyC" | ||
269 | cmultiplyC :: Int -> Complex Double ::> (Int -> Complex Double ::> (Complex Double ::> IO Int)) | ||
270 | |||
271 | multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a | ||
272 | multiply RowMajor a b = multiplyD RowMajor a b | ||
273 | multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b | ||
274 | |||
275 | multiplyT order a b = multiplyD order (trans b) (trans a) | ||
276 | |||
277 | multiplyD order a b | ||
278 | | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) | ||
279 | | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b) | ||
280 | | otherwise = multiplyG a b | ||
281 | |||
282 | -------------------------------------------------------------------- | ||
283 | |||
284 | data IdxTp = Covariant | Contravariant | ||
285 | |||
286 | -- | multidimensional array | ||
287 | data Tensor t = T { rank :: Int | ||
288 | , dims :: [Int] | ||
289 | , idxNm :: [String] | ||
290 | , idxTp :: [IdxTp] | ||
291 | , ten :: Vector t | ||
292 | } | ||
293 | |||