summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-06-04 19:10:28 +0000
committerAlberto Ruiz <aruiz@um.es>2007-06-04 19:10:28 +0000
commit7430630fa0504296b796223e01cbd417b88650ef (patch)
treec338dea8b82867a4c161fcee5817ed2ca27c7258 /lib/Data/Packed/Internal.hs
parent0a9817cc481fb09f1962eb2c272125e56a123814 (diff)
separation of Internal
Diffstat (limited to 'lib/Data/Packed/Internal.hs')
-rw-r--r--lib/Data/Packed/Internal.hs293
1 files changed, 0 insertions, 293 deletions
diff --git a/lib/Data/Packed/Internal.hs b/lib/Data/Packed/Internal.hs
deleted file mode 100644
index b06f044..0000000
--- a/lib/Data/Packed/Internal.hs
+++ /dev/null
@@ -1,293 +0,0 @@
1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
2-----------------------------------------------------------------------------
3-- |
4-- Module : Data.Packed.Internal
5-- Copyright : (c) Alberto Ruiz 2007
6-- License : GPL-style
7--
8-- Maintainer : Alberto Ruiz <aruiz@um.es>
9-- Stability : provisional
10-- Portability : portable (uses FFI)
11--
12-- Fundamental types
13--
14-----------------------------------------------------------------------------
15
16module Data.Packed.Internal where
17
18import Foreign hiding (xor)
19import Complex
20import Control.Monad(when)
21import Debug.Trace
22import Data.List(transpose,intersperse)
23import Data.Typeable
24import Data.Maybe(fromJust)
25
26debug x = trace (show x) x
27
28----------------------------------------------------------------------
29instance (Storable a, RealFloat a) => Storable (Complex a) where --
30 alignment x = alignment (realPart x) --
31 sizeOf x = 2 * sizeOf (realPart x) --
32 peek p = do --
33 [re,im] <- peekArray 2 (castPtr p) --
34 return (re :+ im) --
35 poke p (a :+ b) = pokeArray (castPtr p) [a,b] --
36----------------------------------------------------------------------
37
38(//) :: x -> (x -> y) -> y
39infixl 0 //
40(//) = flip ($)
41
42check msg ls f = do
43 err <- f
44 when (err/=0) (error msg)
45 mapM_ (touchForeignPtr . fptr) ls
46 return ()
47
48----------------------------------------------------------------------
49
50data Vector t = V { dim :: Int
51 , fptr :: ForeignPtr t
52 , ptr :: Ptr t
53 } deriving Typeable
54
55type Vc t s = Int -> Ptr t -> s
56infixr 5 :>
57type t :> s = Vc t s
58
59vec :: Vector t -> (Vc t s) -> s
60vec v f = f (dim v) (ptr v)
61
62createVector :: Storable a => Int -> IO (Vector a)
63createVector n = do
64 when (n <= 0) $ error ("trying to createVector of dim "++show n)
65 fp <- mallocForeignPtrArray n
66 let p = unsafeForeignPtrToPtr fp
67 --putStrLn ("\n---------> V"++show n)
68 return $ V n fp p
69
70fromList :: Storable a => [a] -> Vector a
71fromList l = unsafePerformIO $ do
72 v <- createVector (length l)
73 let f _ p = pokeArray p l >> return 0
74 f // vec v // check "fromList" []
75 return v
76
77toList :: Storable a => Vector a -> [a]
78toList v = unsafePerformIO $ peekArray (dim v) (ptr v)
79
80n # l = if length l == n then fromList l else error "# with wrong size"
81
82at' :: Storable a => Vector a -> Int -> a
83at' v n = unsafePerformIO $ peekElemOff (ptr v) n
84
85at :: Storable a => Vector a -> Int -> a
86at v n | n >= 0 && n < dim v = at' v n
87 | otherwise = error "vector index out of range"
88
89instance (Show a, Storable a) => (Show (Vector a)) where
90 show v = (show (dim v))++" # " ++ show (toList v)
91
92------------------------------------------------------------------------
93
94data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
95
96-- | 2D array
97data Matrix t = M { rows :: Int
98 , cols :: Int
99 , cmat :: Vector t
100 , fmat :: Vector t
101 , isTrans :: Bool
102 , order :: MatrixOrder
103 } deriving Typeable
104
105xor a b = a && not b || b && not a
106
107fortran m = order m == ColumnMajor
108
109dat m = if fortran m `xor` isTrans m then fmat m else cmat m
110
111pref m = if fortran m then fmat m else cmat m
112
113trans m = m { rows = cols m
114 , cols = rows m
115 , isTrans = not (isTrans m)
116 }
117
118type Mt t s = Int -> Int -> Ptr t -> s
119infixr 6 ::>
120type t ::> s = Mt t s
121
122mat :: Matrix t -> (Mt t s) -> s
123mat m f = f (rows m) (cols m) (ptr (dat m))
124
125gmat m f | fortran m =
126 if (isTrans m)
127 then f 0 (rows m) (cols m) (ptr (fmat m))
128 else f 1 (cols m) (rows m) (ptr (fmat m))
129 | otherwise =
130 if isTrans m
131 then f 1 (cols m) (rows m) (ptr (cmat m))
132 else f 0 (rows m) (cols m) (ptr (cmat m))
133
134instance (Show a, Storable a) => (Show (Matrix a)) where
135 show m = (sizes++) . dsp . map (map show) . toLists $ m
136 where sizes = "("++show (rows m)++"><"++show (cols m)++")\n"
137
138partit :: Int -> [a] -> [[a]]
139partit _ [] = []
140partit n l = take n l : partit n (drop n l)
141
142toLists m = partit (cols m) . toList . cmat $ m
143
144dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp
145 where
146 mt = transpose as
147 longs = map (maximum . map length) mt
148 mtp = zipWith (\a b -> map (pad a) b) longs mt
149 pad n str = replicate (n - length str) ' ' ++ str
150 unwords' = concat . intersperse ", "
151
152matrixFromVector RowMajor c v =
153 M { rows = r
154 , cols = c
155 , cmat = v
156 , fmat = transdata c v r
157 , order = RowMajor
158 , isTrans = False
159 } where r = dim v `div` c -- TODO check mod=0
160
161matrixFromVector ColumnMajor c v =
162 M { rows = r
163 , cols = c
164 , fmat = v
165 , cmat = transdata c v r
166 , order = ColumnMajor
167 , isTrans = False
168 } where r = dim v `div` c -- TODO check mod=0
169
170createMatrix order r c = do
171 p <- createVector (r*c)
172 return (matrixFromVector order c p)
173
174transdataG :: Storable a => Int -> Vector a -> Int -> Vector a
175transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d
176
177transdataR :: Int -> Vector Double -> Int -> Vector Double
178transdataR = transdataAux ctransR
179
180transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double)
181transdataC = transdataAux ctransC
182
183transdataAux fun c1 d c2 = unsafePerformIO $ do
184 v <- createVector (dim d)
185 let r1 = dim d `div` c1
186 r2 = dim d `div` c2
187 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d]
188 --putStrLn "---> transdataAux"
189 return v
190
191foreign import ccall safe "aux.h transR"
192 ctransR :: Double ::> Double ::> IO Int
193foreign import ccall safe "aux.h transC"
194 ctransC :: Complex Double ::> Complex Double ::> IO Int
195
196
197class (Storable a, Typeable a) => Field a where
198instance (Storable a, Typeable a) => Field a where
199
200isReal w x = typeOf (undefined :: Double) == typeOf (w x)
201isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x)
202baseOf v = (v `at` 0)
203
204scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b
205scast = fromJust . cast
206
207transdata :: Field a => Int -> Vector a -> Int -> Vector a
208transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2
209 | isComp baseOf d = scast $ transdataC c1 (scast d) c2
210 | otherwise = transdataG c1 d c2
211
212--transdata :: Storable a => Int -> Vector a -> Int -> Vector a
213--transdata = transdataG
214--{-# RULES "transdataR" transdata=transdataR #-}
215--{-# RULES "transdataC" transdata=transdataC #-}
216
217------------------------------------------------------------------
218
219constantG n x = fromList (replicate n x)
220
221constantR :: Int -> Double -> Vector Double
222constantR = constantAux cconstantR
223
224constantC :: Int -> Complex Double -> Vector (Complex Double)
225constantC = constantAux cconstantC
226
227constantAux fun n x = unsafePerformIO $ do
228 v <- createVector n
229 px <- newArray [x]
230 fun px // vec v // check "constantAux" []
231 free px
232 return v
233
234foreign import ccall safe "aux.h constantR"
235 cconstantR :: Ptr Double -> Double :> IO Int
236
237foreign import ccall safe "aux.h constantC"
238 cconstantC :: Ptr (Complex Double) -> Complex Double :> IO Int
239
240constant :: Field a => Int -> a -> Vector a
241constant n x | isReal id x = scast $ constantR n (scast x)
242 | isComp id x = scast $ constantC n (scast x)
243 | otherwise = constantG n x
244
245------------------------------------------------------------------
246
247dotL a b = sum (zipWith (*) a b)
248
249multiplyL a b = [[dotL x y | y <- transpose b] | x <- a]
250
251transL m = m {rows = cols m, cols = rows m, cmat = v, fmat = cmat m}
252 where v = transdataG (cols m) (cmat m) (rows m)
253
254------------------------------------------------------------------
255
256multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b)
257
258multiplyAux order fun a b = unsafePerformIO $ do
259 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++
260 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
261 r <- createMatrix order (rows a) (cols b)
262 fun // gmat a // gmat b // mat r // check "multiplyAux" [pref a, pref b]
263 return r
264
265foreign import ccall safe "aux.h multiplyR"
266 cmultiplyR :: Int -> Double ::> (Int -> Double ::> (Double ::> IO Int))
267
268foreign import ccall safe "aux.h multiplyC"
269 cmultiplyC :: Int -> Complex Double ::> (Int -> Complex Double ::> (Complex Double ::> IO Int))
270
271multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a
272multiply RowMajor a b = multiplyD RowMajor a b
273multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b
274
275multiplyT order a b = multiplyD order (trans b) (trans a)
276
277multiplyD order a b
278 | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b)
279 | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b)
280 | otherwise = multiplyG a b
281
282--------------------------------------------------------------------
283
284data IdxTp = Covariant | Contravariant
285
286-- | multidimensional array
287data Tensor t = T { rank :: Int
288 , dims :: [Int]
289 , idxNm :: [String]
290 , idxTp :: [IdxTp]
291 , ten :: Vector t
292 }
293