diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-09-08 09:46:33 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-09-08 09:46:33 +0000 |
commit | 34380f2b5d7b048a4d68197f16a8db0e53742030 (patch) | |
tree | 444aff88cda5c247d49bac0d294d8cfb9ef7bf23 /lib/Data | |
parent | 0c38c1b0e122a56ea98c494e60ba90afe2688664 (diff) |
type classes
Diffstat (limited to 'lib/Data')
-rw-r--r-- | lib/Data/Packed/Internal/Common.hs | 18 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 204 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 37 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 42 | ||||
-rw-r--r-- | lib/Data/Packed/Matrix.hs | 15 | ||||
-rw-r--r-- | lib/Data/Packed/Vector.hs | 12 |
6 files changed, 184 insertions, 144 deletions
diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs index 1bfed6d..1212968 100644 --- a/lib/Data/Packed/Internal/Common.hs +++ b/lib/Data/Packed/Internal/Common.hs | |||
@@ -28,7 +28,7 @@ debug x = trace (show x) x | |||
28 | data Vector t = V { dim :: Int | 28 | data Vector t = V { dim :: Int |
29 | , fptr :: ForeignPtr t | 29 | , fptr :: ForeignPtr t |
30 | , ptr :: Ptr t | 30 | , ptr :: Ptr t |
31 | } deriving Typeable | 31 | } -- deriving Typeable |
32 | 32 | ||
33 | ---------------------------------------------------------------------- | 33 | ---------------------------------------------------------------------- |
34 | instance (Storable a, RealFloat a) => Storable (Complex a) where -- | 34 | instance (Storable a, RealFloat a) => Storable (Complex a) where -- |
@@ -78,17 +78,17 @@ check msg ls f = do | |||
78 | mapM_ (touchForeignPtr . fptr) ls | 78 | mapM_ (touchForeignPtr . fptr) ls |
79 | return () | 79 | return () |
80 | 80 | ||
81 | class (Storable a, Typeable a) => Field a | 81 | --class (Storable a, Typeable a) => Field a |
82 | instance (Storable a, Typeable a) => Field a | 82 | --instance (Storable a, Typeable a) => Field a |
83 | 83 | ||
84 | isReal :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool | 84 | --isReal :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool |
85 | isReal w x = typeOf (undefined :: Double) == typeOf (w x) | 85 | --isReal w x = typeOf (undefined :: Double) == typeOf (w x) |
86 | 86 | ||
87 | isComp :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool | 87 | --isComp :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool |
88 | isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) | 88 | --isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) |
89 | 89 | ||
90 | scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b | 90 | --scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b |
91 | scast = fromJust . cast | 91 | --scast = fromJust . cast |
92 | 92 | ||
93 | {- | conversion of Haskell functions into function pointers that can be used in the C side | 93 | {- | conversion of Haskell functions into function pointers that can be used in the C side |
94 | -} | 94 | -} |
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 9895393..48652f3 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -1,4 +1,4 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts #-} | 1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-overlapping-instances #-} |
2 | ----------------------------------------------------------------------------- | 2 | ----------------------------------------------------------------------------- |
3 | -- | | 3 | -- | |
4 | -- Module : Data.Packed.Internal.Matrix | 4 | -- Module : Data.Packed.Internal.Matrix |
@@ -22,9 +22,65 @@ import Foreign hiding (xor) | |||
22 | import Complex | 22 | import Complex |
23 | import Control.Monad(when) | 23 | import Control.Monad(when) |
24 | import Data.List(transpose,intersperse) | 24 | import Data.List(transpose,intersperse) |
25 | import Data.Typeable | 25 | --import Data.Typeable |
26 | import Data.Maybe(fromJust) | 26 | import Data.Maybe(fromJust) |
27 | 27 | ||
28 | ---------------------------------------------------------------- | ||
29 | |||
30 | class Storable a => Field a where | ||
31 | constant :: a -> Int -> Vector a | ||
32 | transdata :: Int -> Vector a -> Int -> Vector a | ||
33 | multiplyD :: MatrixOrder -> Matrix a -> Matrix a -> Matrix a | ||
34 | subMatrix :: (Int,Int) -- ^ (r0,c0) starting position | ||
35 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | ||
36 | -> Matrix a -> Matrix a | ||
37 | diag :: Vector a -> Matrix a | ||
38 | |||
39 | |||
40 | instance Field Double where | ||
41 | constant = constantR | ||
42 | transdata = transdataR | ||
43 | multiplyD = multiplyR | ||
44 | subMatrix = subMatrixR | ||
45 | diag = diagR | ||
46 | |||
47 | instance Field (Complex Double) where | ||
48 | constant = constantC | ||
49 | transdata = transdataC | ||
50 | multiplyD = multiplyC | ||
51 | subMatrix = subMatrixC | ||
52 | diag = diagC | ||
53 | |||
54 | ----------------------------------------------------------------- | ||
55 | |||
56 | transdataR :: Int -> Vector Double -> Int -> Vector Double | ||
57 | transdataR = transdataAux ctransR | ||
58 | |||
59 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) | ||
60 | transdataC = transdataAux ctransC | ||
61 | |||
62 | transdataAux fun c1 d c2 = | ||
63 | if noneed | ||
64 | then d | ||
65 | else unsafePerformIO $ do | ||
66 | v <- createVector (dim d) | ||
67 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | ||
68 | --putStrLn "---> transdataAux" | ||
69 | return v | ||
70 | where r1 = dim d `div` c1 | ||
71 | r2 = dim d `div` c2 | ||
72 | noneed = r1 == 1 || c1 == 1 | ||
73 | |||
74 | foreign import ccall safe "aux.h transR" | ||
75 | ctransR :: TMM -- Double ::> Double ::> IO Int | ||
76 | foreign import ccall safe "aux.h transC" | ||
77 | ctransC :: TCMCM -- Complex Double ::> Complex Double ::> IO Int | ||
78 | |||
79 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | ||
80 | |||
81 | |||
82 | |||
83 | |||
28 | 84 | ||
29 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | 85 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) |
30 | 86 | ||
@@ -34,9 +90,18 @@ data Matrix t = M { rows :: Int | |||
34 | , tdat :: Vector t | 90 | , tdat :: Vector t |
35 | , isTrans :: Bool | 91 | , isTrans :: Bool |
36 | , order :: MatrixOrder | 92 | , order :: MatrixOrder |
37 | } deriving Typeable | 93 | } -- deriving Typeable |
38 | 94 | ||
39 | 95 | ||
96 | data NMat t = MC { rws, cls :: Int, dtc :: Vector t} | ||
97 | | MF { rws, cls :: Int, dtf :: Vector t} | ||
98 | | Tr (NMat t) | ||
99 | |||
100 | ntrans (Tr m) = m | ||
101 | ntrans m = Tr m | ||
102 | |||
103 | viewC m@MC{} = m | ||
104 | viewF m@MF{} = m | ||
40 | 105 | ||
41 | fortran m = order m == ColumnMajor | 106 | fortran m = order m == ColumnMajor |
42 | 107 | ||
@@ -78,7 +143,11 @@ matrixFromVector RowMajor c v = | |||
78 | , tdat = transdata c v r | 143 | , tdat = transdata c v r |
79 | , order = RowMajor | 144 | , order = RowMajor |
80 | , isTrans = False | 145 | , isTrans = False |
81 | } where r = dim v `div` c -- TODO check mod=0 | 146 | } where (d,m) = dim v `divMod` c |
147 | r | m==0 = d | ||
148 | | otherwise = error "matrixFromVector" | ||
149 | |||
150 | -- r = dim v `div` c -- TODO check mod=0 | ||
82 | 151 | ||
83 | matrixFromVector ColumnMajor c v = | 152 | matrixFromVector ColumnMajor c v = |
84 | M { rows = r | 153 | M { rows = r |
@@ -87,7 +156,9 @@ matrixFromVector ColumnMajor c v = | |||
87 | , tdat = transdata r v c | 156 | , tdat = transdata r v c |
88 | , order = ColumnMajor | 157 | , order = ColumnMajor |
89 | , isTrans = False | 158 | , isTrans = False |
90 | } where r = dim v `div` c -- TODO check mod=0 | 159 | } where (d,m) = dim v `divMod` c |
160 | r | m==0 = d | ||
161 | | otherwise = error "matrixFromVector" | ||
91 | 162 | ||
92 | createMatrix order r c = do | 163 | createMatrix order r c = do |
93 | p <- createVector (r*c) | 164 | p <- createVector (r*c) |
@@ -102,48 +173,11 @@ createMatrix order r c = do | |||
102 | , 9.0, 10.0, 11.0, 12.0 ]@ | 173 | , 9.0, 10.0, 11.0, 12.0 ]@ |
103 | 174 | ||
104 | -} | 175 | -} |
105 | reshape :: (Field t) => Int -> Vector t -> Matrix t | 176 | reshape :: Field t => Int -> Vector t -> Matrix t |
106 | reshape c v = matrixFromVector RowMajor c v | 177 | reshape c v = matrixFromVector RowMajor c v |
107 | 178 | ||
108 | singleton x = reshape 1 (fromList [x]) | 179 | singleton x = reshape 1 (fromList [x]) |
109 | 180 | ||
110 | transdataG :: Storable a => Int -> Vector a -> Int -> Vector a | ||
111 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | ||
112 | |||
113 | transdataR :: Int -> Vector Double -> Int -> Vector Double | ||
114 | transdataR = transdataAux ctransR | ||
115 | |||
116 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) | ||
117 | transdataC = transdataAux ctransC | ||
118 | |||
119 | transdataAux fun c1 d c2 = | ||
120 | if noneed | ||
121 | then d | ||
122 | else unsafePerformIO $ do | ||
123 | v <- createVector (dim d) | ||
124 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | ||
125 | --putStrLn "---> transdataAux" | ||
126 | return v | ||
127 | where r1 = dim d `div` c1 | ||
128 | r2 = dim d `div` c2 | ||
129 | noneed = r1 == 1 || c1 == 1 | ||
130 | |||
131 | foreign import ccall safe "aux.h transR" | ||
132 | ctransR :: TMM -- Double ::> Double ::> IO Int | ||
133 | foreign import ccall safe "aux.h transC" | ||
134 | ctransC :: TCMCM -- Complex Double ::> Complex Double ::> IO Int | ||
135 | |||
136 | transdata :: Field a => Int -> Vector a -> Int -> Vector a | ||
137 | transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 | ||
138 | | isComp baseOf d = scast $ transdataC c1 (scast d) c2 | ||
139 | | otherwise = transdataG c1 d c2 | ||
140 | |||
141 | --transdata :: Storable a => Int -> Vector a -> Int -> Vector a | ||
142 | --transdata = transdataG | ||
143 | --{-# RULES "transdataR" transdata=transdataR #-} | ||
144 | --{-# RULES "transdataC" transdata=transdataC #-} | ||
145 | |||
146 | ----------------------------------------------------------------- | ||
147 | liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | 181 | liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b |
148 | liftMatrix f m = reshape (cols m) (f (cdat m)) | 182 | liftMatrix f m = reshape (cols m) (f (cdat m)) |
149 | 183 | ||
@@ -163,7 +197,7 @@ multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a] | |||
163 | Nothing -> False | 197 | Nothing -> False |
164 | Just c -> c == length b | 198 | Just c -> c == length b |
165 | 199 | ||
166 | transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) | 200 | transL m = matrixFromVector RowMajor (rows m) $ transdata (cols m) (cdat m) (rows m) |
167 | 201 | ||
168 | multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) | 202 | multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) |
169 | 203 | ||
@@ -179,7 +213,7 @@ gmatC m f | fortran m = | |||
179 | else f 0 (rows m) (cols m) (ptr (dat m)) | 213 | else f 0 (rows m) (cols m) (ptr (dat m)) |
180 | 214 | ||
181 | 215 | ||
182 | multiplyAux order fun a b = unsafePerformIO $ do | 216 | multiplyAux fun order a b = unsafePerformIO $ do |
183 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | 217 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ |
184 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | 218 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) |
185 | r <- createMatrix order (rows a) (cols b) | 219 | r <- createMatrix order (rows a) (cols b) |
@@ -198,37 +232,14 @@ foreign import ccall safe "aux.h multiplyC" | |||
198 | -> Int -> Int -> Ptr (Complex Double) | 232 | -> Int -> Int -> Ptr (Complex Double) |
199 | -> IO Int | 233 | -> IO Int |
200 | 234 | ||
201 | multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a | 235 | multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a |
202 | multiply RowMajor a b = multiplyD RowMajor a b | 236 | multiply RowMajor a b = multiplyD RowMajor a b |
203 | multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor} | 237 | multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor} |
204 | where m = multiplyD RowMajor (trans b) (trans a) | 238 | where m = multiplyD RowMajor (trans b) (trans a) |
205 | 239 | ||
206 | multiplyD order a b | ||
207 | | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) | ||
208 | | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b) | ||
209 | | otherwise = multiplyG a b | ||
210 | |||
211 | ---------------------------------------------------------------------- | ||
212 | 240 | ||
213 | outer' u v = dat (outer u v) | 241 | multiplyR = multiplyAux cmultiplyR |
214 | 242 | multiplyC = multiplyAux cmultiplyC | |
215 | {- | Outer product of two vectors. | ||
216 | |||
217 | @\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] | ||
218 | (3><3) | ||
219 | [ 5.0, 2.0, 3.0 | ||
220 | , 10.0, 4.0, 6.0 | ||
221 | , 15.0, 6.0, 9.0 ]@ | ||
222 | -} | ||
223 | outer :: (Num t, Field t) => Vector t -> Vector t -> Matrix t | ||
224 | outer u v = multiply RowMajor r c | ||
225 | where r = matrixFromVector RowMajor 1 u | ||
226 | c = matrixFromVector RowMajor (dim v) v | ||
227 | |||
228 | dot :: (Field t, Num t) => Vector t -> Vector t -> t | ||
229 | dot u v = dat (multiply RowMajor r c) `at` 0 | ||
230 | where r = matrixFromVector RowMajor (dim u) u | ||
231 | c = matrixFromVector RowMajor 1 v | ||
232 | 243 | ||
233 | ---------------------------------------------------------------------- | 244 | ---------------------------------------------------------------------- |
234 | 245 | ||
@@ -251,14 +262,14 @@ subMatrixC (r0,c0) (rt,ct) x = | |||
251 | subMatrixR (r0,2*c0) (rt,2*ct) . | 262 | subMatrixR (r0,2*c0) (rt,2*ct) . |
252 | reshape (2*cols x) . asReal . cdat $ x | 263 | reshape (2*cols x) . asReal . cdat $ x |
253 | 264 | ||
254 | subMatrix :: (Field a) | 265 | --subMatrix :: (Field a) |
255 | => (Int,Int) -- ^ (r0,c0) starting position | 266 | -- => (Int,Int) -- ^ (r0,c0) starting position |
256 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | 267 | -- -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix |
257 | -> Matrix a -> Matrix a | 268 | -- -> Matrix a -> Matrix a |
258 | subMatrix st sz m | 269 | --subMatrix st sz m |
259 | | isReal (baseOf.dat) m = scast $ subMatrixR st sz (scast m) | 270 | -- | isReal (baseOf.dat) m = scast $ subMatrixR st sz (scast m) |
260 | | isComp (baseOf.dat) m = scast $ subMatrixC st sz (scast m) | 271 | -- | isComp (baseOf.dat) m = scast $ subMatrixC st sz (scast m) |
261 | | otherwise = subMatrixG st sz m | 272 | -- | otherwise = subMatrixG st sz m |
262 | 273 | ||
263 | subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) | 274 | subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) |
264 | where subList s n = take n . drop s | 275 | where subList s n = take n . drop s |
@@ -281,11 +292,11 @@ diagC = diagAux c_diagC "diagC" | |||
281 | foreign import ccall "aux.h diagC" c_diagC :: TCVCM | 292 | foreign import ccall "aux.h diagC" c_diagC :: TCVCM |
282 | 293 | ||
283 | -- | diagonal matrix from a vector | 294 | -- | diagonal matrix from a vector |
284 | diag :: (Num a, Field a) => Vector a -> Matrix a | 295 | --diag :: (Num a, Field a) => Vector a -> Matrix a |
285 | diag v | 296 | --diag v |
286 | | isReal (baseOf) v = scast $ diagR (scast v) | 297 | -- | isReal (baseOf) v = scast $ diagR (scast v) |
287 | | isComp (baseOf) v = scast $ diagC (scast v) | 298 | -- | isComp (baseOf) v = scast $ diagC (scast v) |
288 | | otherwise = diagG v | 299 | -- | otherwise = diagG v |
289 | 300 | ||
290 | diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] | 301 | diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] |
291 | where c = dim v | 302 | where c = dim v |
@@ -313,13 +324,34 @@ fromColumns :: Field t => [Vector t] -> Matrix t | |||
313 | fromColumns m = trans . fromRows $ m | 324 | fromColumns m = trans . fromRows $ m |
314 | 325 | ||
315 | -- | Creates a list of vectors from the columns of a matrix | 326 | -- | Creates a list of vectors from the columns of a matrix |
316 | toColumns :: Field t => Matrix t -> [Vector t] | 327 | toColumns :: Storable t => Matrix t -> [Vector t] |
317 | toColumns m = toRows . trans $ m | 328 | toColumns m = toRows . trans $ m |
318 | 329 | ||
319 | 330 | ||
320 | -- | Reads a matrix position. | 331 | -- | Reads a matrix position. |
321 | (@@>) :: Field t => Matrix t -> (Int,Int) -> t | 332 | (@@>) :: Storable t => Matrix t -> (Int,Int) -> t |
322 | infixl 9 @@> | 333 | infixl 9 @@> |
323 | m@M {rows = r, cols = c} @@> (i,j) | 334 | m@M {rows = r, cols = c} @@> (i,j) |
324 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | 335 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" |
325 | | otherwise = cdat m `at` (i*c+j) | 336 | | otherwise = cdat m `at` (i*c+j) |
337 | |||
338 | ------------------------------------------------------------------ | ||
339 | |||
340 | constantR :: Double -> Int -> Vector Double | ||
341 | constantR = constantAux cconstantR | ||
342 | |||
343 | constantC :: Complex Double -> Int -> Vector (Complex Double) | ||
344 | constantC = constantAux cconstantC | ||
345 | |||
346 | constantAux fun x n = unsafePerformIO $ do | ||
347 | v <- createVector n | ||
348 | px <- newArray [x] | ||
349 | fun px // vec v // check "constantAux" [] | ||
350 | free px | ||
351 | return v | ||
352 | |||
353 | foreign import ccall safe "aux.h constantR" | ||
354 | cconstantR :: Ptr Double -> TV -- Double :> IO Int | ||
355 | |||
356 | foreign import ccall safe "aux.h constantC" | ||
357 | cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int | ||
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 34132d8..6876685 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -1,3 +1,5 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts #-} | ||
2 | |||
1 | ----------------------------------------------------------------------------- | 3 | ----------------------------------------------------------------------------- |
2 | -- | | 4 | -- | |
3 | -- Module : Data.Packed.Internal.Tensor | 5 | -- Module : Data.Packed.Internal.Tensor |
@@ -19,6 +21,8 @@ import Foreign.Storable | |||
19 | import Data.List(sort,elemIndex,nub,foldl1',foldl') | 21 | import Data.List(sort,elemIndex,nub,foldl1',foldl') |
20 | import GSL.Vector | 22 | import GSL.Vector |
21 | import Data.Packed.Matrix | 23 | import Data.Packed.Matrix |
24 | import Data.Packed.Vector | ||
25 | import LinearAlgebra.Linear | ||
22 | 26 | ||
23 | data IdxType = Covariant | Contravariant deriving (Show,Eq) | 27 | data IdxType = Covariant | Contravariant deriving (Show,Eq) |
24 | 28 | ||
@@ -171,6 +175,7 @@ compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where | |||
171 | = t1 /= t2 && n1 == n2 | 175 | = t1 /= t2 && n1 == n2 |
172 | 176 | ||
173 | 177 | ||
178 | outer' u v = dat (outer u v) | ||
174 | 179 | ||
175 | -- | tensor product without without any contractions | 180 | -- | tensor product without without any contractions |
176 | rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t | 181 | rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t |
@@ -187,7 +192,7 @@ contraction2 t1 n1 t2 n2 = | |||
187 | m = multiply RowMajor (trans m1) m2 | 192 | m = multiply RowMajor (trans m1) m2 |
188 | 193 | ||
189 | -- | contraction of a tensor along two given indices | 194 | -- | contraction of a tensor along two given indices |
190 | contraction1 :: (Field t, Num t) => Tensor t -> IdxName -> IdxName -> Tensor t | 195 | contraction1 :: (Linear Vector t) => Tensor t -> IdxName -> IdxName -> Tensor t |
191 | contraction1 t name1 name2 = | 196 | contraction1 t name1 name2 = |
192 | if compatIdx t name1 t name2 | 197 | if compatIdx t name1 t name2 |
193 | then sumT y | 198 | then sumT y |
@@ -197,7 +202,7 @@ contraction1 t name1 name2 = | |||
197 | y = map head $ zipWith drop [0..] x | 202 | y = map head $ zipWith drop [0..] x |
198 | 203 | ||
199 | -- | contraction of a tensor along a repeated index | 204 | -- | contraction of a tensor along a repeated index |
200 | contraction1c :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t | 205 | contraction1c :: (Linear Vector t) => Tensor t -> IdxName -> Tensor t |
201 | contraction1c t n = contraction1 renamed n' n | 206 | contraction1c t n = contraction1 renamed n' n |
202 | where n' = n++"'" -- hmmm | 207 | where n' = n++"'" -- hmmm |
203 | renamed = withIdx t auxnames | 208 | renamed = withIdx t auxnames |
@@ -205,31 +210,31 @@ contraction1c t n = contraction1 renamed n' n | |||
205 | (h,_:r) = break (==n) (map idxName (dims t)) | 210 | (h,_:r) = break (==n) (map idxName (dims t)) |
206 | 211 | ||
207 | -- | alternative and inefficient version of contraction2 | 212 | -- | alternative and inefficient version of contraction2 |
208 | contraction2' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t | 213 | contraction2' :: (Linear Vector t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t |
209 | contraction2' t1 n1 t2 n2 = | 214 | contraction2' t1 n1 t2 n2 = |
210 | if compatIdx t1 n1 t2 n2 | 215 | if compatIdx t1 n1 t2 n2 |
211 | then contraction1 (rawProduct t1 t2) n1 n2 | 216 | then contraction1 (rawProduct t1 t2) n1 n2 |
212 | else error "wrong contraction'" | 217 | else error "wrong contraction'" |
213 | 218 | ||
214 | -- | applies a sequence of contractions | 219 | -- | applies a sequence of contractions |
215 | contractions :: (Field t, Num t) => Tensor t -> [(IdxName, IdxName)] -> Tensor t | 220 | contractions :: (Linear Vector t) => Tensor t -> [(IdxName, IdxName)] -> Tensor t |
216 | contractions t pairs = foldl' contract1b t pairs | 221 | contractions t pairs = foldl' contract1b t pairs |
217 | where contract1b t (n1,n2) = contraction1 t n1 n2 | 222 | where contract1b t (n1,n2) = contraction1 t n1 n2 |
218 | 223 | ||
219 | -- | applies a sequence of contractions of same index | 224 | -- | applies a sequence of contractions of same index |
220 | contractionsC :: (Field t, Num t) => Tensor t -> [IdxName] -> Tensor t | 225 | contractionsC :: (Linear Vector t) => Tensor t -> [IdxName] -> Tensor t |
221 | contractionsC t is = foldl' contraction1c t is | 226 | contractionsC t is = foldl' contraction1c t is |
222 | 227 | ||
223 | 228 | ||
224 | -- | applies a contraction on the first indices of the tensors | 229 | -- | applies a contraction on the first indices of the tensors |
225 | contractionF :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t | 230 | contractionF :: (Linear Vector t) => Tensor t -> Tensor t -> Tensor t |
226 | contractionF t1 t2 = contraction2 t1 n1 t2 n2 | 231 | contractionF t1 t2 = contraction2 t1 n1 t2 n2 |
227 | where n1 = fn t1 | 232 | where n1 = fn t1 |
228 | n2 = fn t2 | 233 | n2 = fn t2 |
229 | fn = idxName . head . dims | 234 | fn = idxName . head . dims |
230 | 235 | ||
231 | -- | computes all compatible contractions of the product of two tensors that would arise if the index names were equal | 236 | -- | computes all compatible contractions of the product of two tensors that would arise if the index names were equal |
232 | possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] | 237 | possibleContractions :: (Linear Vector t) => Tensor t -> Tensor t -> [Tensor t] |
233 | possibleContractions t1 t2 = [ contraction2 t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] | 238 | possibleContractions t1 t2 = [ contraction2 t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] |
234 | 239 | ||
235 | 240 | ||
@@ -242,7 +247,7 @@ desiredContractions1 t = [ n1 | (a,n1) <- x , (b,n2) <- x, a/=b, n1==n2] | |||
242 | where x = zip [0..] (names t) | 247 | where x = zip [0..] (names t) |
243 | 248 | ||
244 | -- | tensor product with the convention that repeated indices are contracted. | 249 | -- | tensor product with the convention that repeated indices are contracted. |
245 | mulT :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t | 250 | mulT :: (Linear Vector t) => Tensor t -> Tensor t -> Tensor t |
246 | mulT t1 t2 = r where | 251 | mulT t1 t2 = r where |
247 | t1r = contractionsC t1 (desiredContractions1 t1) | 252 | t1r = contractionsC t1 (desiredContractions1 t1) |
248 | t2r = contractionsC t2 (desiredContractions1 t2) | 253 | t2r = contractionsC t2 (desiredContractions1 t2) |
@@ -254,10 +259,10 @@ mulT t1 t2 = r where | |||
254 | ----------------------------------------------------------------- | 259 | ----------------------------------------------------------------- |
255 | 260 | ||
256 | -- | tensor addition (for tensors with the same structure) | 261 | -- | tensor addition (for tensors with the same structure) |
257 | addT :: (Num a, Field a) => Tensor a -> Tensor a -> Tensor a | 262 | addT :: (Linear Vector a) => Tensor a -> Tensor a -> Tensor a |
258 | addT a b = liftTensor2 add a b | 263 | addT a b = liftTensor2 add a b |
259 | 264 | ||
260 | sumT :: (Field a, Num a) => [Tensor a] -> Tensor a | 265 | sumT :: (Linear Vector a) => [Tensor a] -> Tensor a |
261 | sumT l = foldl1' addT l | 266 | sumT l = foldl1' addT l |
262 | 267 | ||
263 | ----------------------------------------------------------------- | 268 | ----------------------------------------------------------------- |
@@ -281,19 +286,19 @@ signature l | length (nub l) < length l = 0 | |||
281 | | otherwise = -1 | 286 | | otherwise = -1 |
282 | 287 | ||
283 | 288 | ||
284 | sym :: (Field t, Num t) => Tensor t -> Tensor t | 289 | sym :: (Linear Vector t) => Tensor t -> Tensor t |
285 | sym t = T (dims t) (ten (sym' (withIdx t seqind))) | 290 | sym t = T (dims t) (ten (sym' (withIdx t seqind))) |
286 | where sym' t = sumT $ map (flip tridx t) (perms (names t)) | 291 | where sym' t = sumT $ map (flip tridx t) (perms (names t)) |
287 | where nms = map idxName . dims | 292 | where nms = map idxName . dims |
288 | 293 | ||
289 | antisym :: (Field t, Num t) => Tensor t -> Tensor t | 294 | antisym :: (Linear Vector t) => Tensor t -> Tensor t |
290 | antisym t = T (dims t) (ten (antisym' (withIdx t seqind))) | 295 | antisym t = T (dims t) (ten (antisym' (withIdx t seqind))) |
291 | where antisym' t = sumT $ map (scsig . flip tridx t) (perms (names t)) | 296 | where antisym' t = sumT $ map (scsig . flip tridx t) (perms (names t)) |
292 | scsig t = scalar (signature (nms t)) `rawProduct` t | 297 | scsig t = scalar (signature (nms t)) `rawProduct` t |
293 | where nms = map idxName . dims | 298 | where nms = map idxName . dims |
294 | 299 | ||
295 | -- | the wedge product of two tensors (implemented as the antisymmetrization of the ordinary tensor product). | 300 | -- | the wedge product of two tensors (implemented as the antisymmetrization of the ordinary tensor product). |
296 | wedge :: (Field t, Fractional t) => Tensor t -> Tensor t -> Tensor t | 301 | wedge :: (Linear Vector t, Fractional t) => Tensor t -> Tensor t -> Tensor t |
297 | wedge a b = antisym (rawProduct (norper a) (norper b)) | 302 | wedge a b = antisym (rawProduct (norper a) (norper b)) |
298 | where norper t = rawProduct t (scalar (recip $ fromIntegral $ fact (rank t))) | 303 | where norper t = rawProduct t (scalar (recip $ fromIntegral $ fact (rank t))) |
299 | 304 | ||
@@ -313,19 +318,19 @@ seqind :: [String] | |||
313 | seqind = map show [1..] | 318 | seqind = map show [1..] |
314 | 319 | ||
315 | -- | completely antisymmetric covariant tensor of dimension n | 320 | -- | completely antisymmetric covariant tensor of dimension n |
316 | leviCivita :: (Field t, Num t) => Int -> Tensor t | 321 | leviCivita :: (Linear Vector t) => Int -> Tensor t |
317 | leviCivita n = antisym $ foldl1 rawProduct $ zipWith withIdx auxbase seqind' | 322 | leviCivita n = antisym $ foldl1 rawProduct $ zipWith withIdx auxbase seqind' |
318 | where auxbase = map tc (toRows (ident n)) | 323 | where auxbase = map tc (toRows (ident n)) |
319 | tc = tensorFromVector Covariant | 324 | tc = tensorFromVector Covariant |
320 | 325 | ||
321 | -- | contraction of leviCivita with a list of vectors (and raise with euclidean metric) | 326 | -- | contraction of leviCivita with a list of vectors (and raise with euclidean metric) |
322 | innerLevi :: (Num t, Field t) => [Tensor t] -> Tensor t | 327 | innerLevi :: (Linear Vector t) => [Tensor t] -> Tensor t |
323 | innerLevi vs = raise $ foldl' contractionF (leviCivita n) vs | 328 | innerLevi vs = raise $ foldl' contractionF (leviCivita n) vs |
324 | where n = idxDim . head . dims . head $ vs | 329 | where n = idxDim . head . dims . head $ vs |
325 | 330 | ||
326 | 331 | ||
327 | -- | obtains the dual of a multivector (with euclidean metric) | 332 | -- | obtains the dual of a multivector (with euclidean metric) |
328 | dual :: (Field t, Fractional t) => Tensor t -> Tensor t | 333 | dual :: (Linear Vector t, Fractional t) => Tensor t -> Tensor t |
329 | dual t = raise $ leviCivita n `mulT` withIdx t seqind `rawProduct` x | 334 | dual t = raise $ leviCivita n `mulT` withIdx t seqind `rawProduct` x |
330 | where n = idxDim . head . dims $ t | 335 | where n = idxDim . head . dims $ t |
331 | x = scalar (recip $ fromIntegral $ fact (rank t)) | 336 | x = scalar (recip $ fromIntegral $ fact (rank t)) |
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index ab93577..f2646a4 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -1,4 +1,4 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts #-} | 1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} |
2 | ----------------------------------------------------------------------------- | 2 | ----------------------------------------------------------------------------- |
3 | -- | | 3 | -- | |
4 | -- Module : Data.Packed.Internal.Vector | 4 | -- Module : Data.Packed.Internal.Vector |
@@ -19,6 +19,8 @@ import Data.Packed.Internal.Common | |||
19 | import Foreign | 19 | import Foreign |
20 | import Complex | 20 | import Complex |
21 | import Control.Monad(when) | 21 | import Control.Monad(when) |
22 | import Data.List(transpose) | ||
23 | import Debug.Trace(trace) | ||
22 | 24 | ||
23 | type Vc t s = Int -> Ptr t -> s | 25 | type Vc t s = Int -> Ptr t -> s |
24 | -- not yet admitted by my haddock version | 26 | -- not yet admitted by my haddock version |
@@ -28,7 +30,7 @@ type Vc t s = Int -> Ptr t -> s | |||
28 | vec :: Vector t -> (Vc t s) -> s | 30 | vec :: Vector t -> (Vc t s) -> s |
29 | vec v f = f (dim v) (ptr v) | 31 | vec v f = f (dim v) (ptr v) |
30 | 32 | ||
31 | baseOf v = (v `at` 0) | 33 | --baseOf v = (v `at` 0) |
32 | 34 | ||
33 | createVector :: Storable a => Int -> IO (Vector a) | 35 | createVector :: Storable a => Int -> IO (Vector a) |
34 | createVector n = do | 36 | createVector n = do |
@@ -78,9 +80,16 @@ subVector' k l (v@V {dim=n, ptr=p, fptr=fp}) | |||
78 | | otherwise = v {dim=l, ptr=advancePtr p k} | 80 | | otherwise = v {dim=l, ptr=advancePtr p k} |
79 | 81 | ||
80 | 82 | ||
83 | -- | Reads a vector position. | ||
84 | (@>) :: Storable t => Vector t -> Int -> t | ||
85 | infixl 9 @> | ||
86 | (@>) = at | ||
87 | |||
88 | |||
89 | |||
81 | 90 | ||
82 | -- | creates a new Vector by joining a list of Vectors | 91 | -- | creates a new Vector by joining a list of Vectors |
83 | join :: Field t => [Vector t] -> Vector t | 92 | join :: Storable t => [Vector t] -> Vector t |
84 | join [] = error "joining zero vectors" | 93 | join [] = error "joining zero vectors" |
85 | join as = unsafePerformIO $ do | 94 | join as = unsafePerformIO $ do |
86 | let tot = sum (map dim as) | 95 | let tot = sum (map dim as) |
@@ -103,34 +112,11 @@ asComplex :: Vector Double -> Vector (Complex Double) | |||
103 | asComplex v = V { dim = dim v `div` 2, fptr = castForeignPtr (fptr v), ptr = castPtr (ptr v) } | 112 | asComplex v = V { dim = dim v `div` 2, fptr = castForeignPtr (fptr v), ptr = castPtr (ptr v) } |
104 | 113 | ||
105 | 114 | ||
106 | constantG x n = fromList (replicate n x) | 115 | ---------------------------------------------------------------- |
107 | |||
108 | constantR :: Double -> Int -> Vector Double | ||
109 | constantR = constantAux cconstantR | ||
110 | |||
111 | constantC :: Complex Double -> Int -> Vector (Complex Double) | ||
112 | constantC = constantAux cconstantC | ||
113 | |||
114 | constantAux fun x n = unsafePerformIO $ do | ||
115 | v <- createVector n | ||
116 | px <- newArray [x] | ||
117 | fun px // vec v // check "constantAux" [] | ||
118 | free px | ||
119 | return v | ||
120 | |||
121 | foreign import ccall safe "aux.h constantR" | ||
122 | cconstantR :: Ptr Double -> TV -- Double :> IO Int | ||
123 | |||
124 | foreign import ccall safe "aux.h constantC" | ||
125 | cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int | ||
126 | |||
127 | constant :: Field a => a -> Int -> Vector a | ||
128 | constant x n | isReal id x = scast $ constantR (scast x) n | ||
129 | | isComp id x = scast $ constantC (scast x) n | ||
130 | | otherwise = constantG x n | ||
131 | 116 | ||
132 | liftVector :: (Storable a, Storable b) => (a-> b) -> Vector a -> Vector b | 117 | liftVector :: (Storable a, Storable b) => (a-> b) -> Vector a -> Vector b |
133 | liftVector f = fromList . map f . toList | 118 | liftVector f = fromList . map f . toList |
134 | 119 | ||
135 | liftVector2 :: (Storable a, Storable b, Storable c) => (a-> b -> c) -> Vector a -> Vector b -> Vector c | 120 | liftVector2 :: (Storable a, Storable b, Storable c) => (a-> b -> c) -> Vector a -> Vector b -> Vector c |
136 | liftVector2 f u v = fromList $ zipWith f (toList u) (toList v) | 121 | liftVector2 f u v = fromList $ zipWith f (toList u) (toList v) |
122 | |||
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index 2033dc7..2e8cb3d 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs | |||
@@ -134,3 +134,18 @@ asRow v = reshape (dim v) v | |||
134 | 134 | ||
135 | asColumn :: Field a => Vector a -> Matrix a | 135 | asColumn :: Field a => Vector a -> Matrix a |
136 | asColumn v = reshape 1 v | 136 | asColumn v = reshape 1 v |
137 | |||
138 | ------------------------------------------------ | ||
139 | |||
140 | {- | Outer product of two vectors. | ||
141 | |||
142 | @\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] | ||
143 | (3><3) | ||
144 | [ 5.0, 2.0, 3.0 | ||
145 | , 10.0, 4.0, 6.0 | ||
146 | , 15.0, 6.0, 9.0 ]@ | ||
147 | -} | ||
148 | outer :: (Num t, Field t) => Vector t -> Vector t -> Matrix t | ||
149 | outer u v = multiply RowMajor r c | ||
150 | where r = matrixFromVector RowMajor 1 u | ||
151 | c = matrixFromVector RowMajor (dim v) v | ||
diff --git a/lib/Data/Packed/Vector.hs b/lib/Data/Packed/Vector.hs index 27ba6a3..867b77b 100644 --- a/lib/Data/Packed/Vector.hs +++ b/lib/Data/Packed/Vector.hs | |||
@@ -27,7 +27,7 @@ module Data.Packed.Vector ( | |||
27 | 27 | ||
28 | import Data.Packed.Internal | 28 | import Data.Packed.Internal |
29 | import Complex | 29 | import Complex |
30 | import GSL.Vector | 30 | --import GSL.Vector |
31 | 31 | ||
32 | -- | creates a complex vector from vectors with real and imaginary parts | 32 | -- | creates a complex vector from vectors with real and imaginary parts |
33 | toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) | 33 | toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) |
@@ -50,7 +50,9 @@ linspace :: Int -> (Double, Double) -> Vector Double | |||
50 | linspace n (a,b) = fromList [a::Double,a+delta .. b] | 50 | linspace n (a,b) = fromList [a::Double,a+delta .. b] |
51 | where delta = (b-a)/(fromIntegral n -1) | 51 | where delta = (b-a)/(fromIntegral n -1) |
52 | 52 | ||
53 | -- | Reads a vector position. | 53 | |
54 | (@>) :: Field t => Vector t -> Int -> t | 54 | dot :: (Field t) => Vector t -> Vector t -> t |
55 | infixl 9 @> | 55 | dot u v = dat (multiply RowMajor r c) `at` 0 |
56 | (@>) = at | 56 | where r = matrixFromVector RowMajor (dim u) u |
57 | c = matrixFromVector RowMajor 1 v | ||
58 | |||