diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-09-08 09:46:33 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-09-08 09:46:33 +0000 |
commit | 34380f2b5d7b048a4d68197f16a8db0e53742030 (patch) | |
tree | 444aff88cda5c247d49bac0d294d8cfb9ef7bf23 /lib/Data/Packed/Internal/Tensor.hs | |
parent | 0c38c1b0e122a56ea98c494e60ba90afe2688664 (diff) |
type classes
Diffstat (limited to 'lib/Data/Packed/Internal/Tensor.hs')
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 37 |
1 files changed, 21 insertions, 16 deletions
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 34132d8..6876685 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -1,3 +1,5 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts #-} | ||
2 | |||
1 | ----------------------------------------------------------------------------- | 3 | ----------------------------------------------------------------------------- |
2 | -- | | 4 | -- | |
3 | -- Module : Data.Packed.Internal.Tensor | 5 | -- Module : Data.Packed.Internal.Tensor |
@@ -19,6 +21,8 @@ import Foreign.Storable | |||
19 | import Data.List(sort,elemIndex,nub,foldl1',foldl') | 21 | import Data.List(sort,elemIndex,nub,foldl1',foldl') |
20 | import GSL.Vector | 22 | import GSL.Vector |
21 | import Data.Packed.Matrix | 23 | import Data.Packed.Matrix |
24 | import Data.Packed.Vector | ||
25 | import LinearAlgebra.Linear | ||
22 | 26 | ||
23 | data IdxType = Covariant | Contravariant deriving (Show,Eq) | 27 | data IdxType = Covariant | Contravariant deriving (Show,Eq) |
24 | 28 | ||
@@ -171,6 +175,7 @@ compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where | |||
171 | = t1 /= t2 && n1 == n2 | 175 | = t1 /= t2 && n1 == n2 |
172 | 176 | ||
173 | 177 | ||
178 | outer' u v = dat (outer u v) | ||
174 | 179 | ||
175 | -- | tensor product without without any contractions | 180 | -- | tensor product without without any contractions |
176 | rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t | 181 | rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t |
@@ -187,7 +192,7 @@ contraction2 t1 n1 t2 n2 = | |||
187 | m = multiply RowMajor (trans m1) m2 | 192 | m = multiply RowMajor (trans m1) m2 |
188 | 193 | ||
189 | -- | contraction of a tensor along two given indices | 194 | -- | contraction of a tensor along two given indices |
190 | contraction1 :: (Field t, Num t) => Tensor t -> IdxName -> IdxName -> Tensor t | 195 | contraction1 :: (Linear Vector t) => Tensor t -> IdxName -> IdxName -> Tensor t |
191 | contraction1 t name1 name2 = | 196 | contraction1 t name1 name2 = |
192 | if compatIdx t name1 t name2 | 197 | if compatIdx t name1 t name2 |
193 | then sumT y | 198 | then sumT y |
@@ -197,7 +202,7 @@ contraction1 t name1 name2 = | |||
197 | y = map head $ zipWith drop [0..] x | 202 | y = map head $ zipWith drop [0..] x |
198 | 203 | ||
199 | -- | contraction of a tensor along a repeated index | 204 | -- | contraction of a tensor along a repeated index |
200 | contraction1c :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t | 205 | contraction1c :: (Linear Vector t) => Tensor t -> IdxName -> Tensor t |
201 | contraction1c t n = contraction1 renamed n' n | 206 | contraction1c t n = contraction1 renamed n' n |
202 | where n' = n++"'" -- hmmm | 207 | where n' = n++"'" -- hmmm |
203 | renamed = withIdx t auxnames | 208 | renamed = withIdx t auxnames |
@@ -205,31 +210,31 @@ contraction1c t n = contraction1 renamed n' n | |||
205 | (h,_:r) = break (==n) (map idxName (dims t)) | 210 | (h,_:r) = break (==n) (map idxName (dims t)) |
206 | 211 | ||
207 | -- | alternative and inefficient version of contraction2 | 212 | -- | alternative and inefficient version of contraction2 |
208 | contraction2' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t | 213 | contraction2' :: (Linear Vector t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t |
209 | contraction2' t1 n1 t2 n2 = | 214 | contraction2' t1 n1 t2 n2 = |
210 | if compatIdx t1 n1 t2 n2 | 215 | if compatIdx t1 n1 t2 n2 |
211 | then contraction1 (rawProduct t1 t2) n1 n2 | 216 | then contraction1 (rawProduct t1 t2) n1 n2 |
212 | else error "wrong contraction'" | 217 | else error "wrong contraction'" |
213 | 218 | ||
214 | -- | applies a sequence of contractions | 219 | -- | applies a sequence of contractions |
215 | contractions :: (Field t, Num t) => Tensor t -> [(IdxName, IdxName)] -> Tensor t | 220 | contractions :: (Linear Vector t) => Tensor t -> [(IdxName, IdxName)] -> Tensor t |
216 | contractions t pairs = foldl' contract1b t pairs | 221 | contractions t pairs = foldl' contract1b t pairs |
217 | where contract1b t (n1,n2) = contraction1 t n1 n2 | 222 | where contract1b t (n1,n2) = contraction1 t n1 n2 |
218 | 223 | ||
219 | -- | applies a sequence of contractions of same index | 224 | -- | applies a sequence of contractions of same index |
220 | contractionsC :: (Field t, Num t) => Tensor t -> [IdxName] -> Tensor t | 225 | contractionsC :: (Linear Vector t) => Tensor t -> [IdxName] -> Tensor t |
221 | contractionsC t is = foldl' contraction1c t is | 226 | contractionsC t is = foldl' contraction1c t is |
222 | 227 | ||
223 | 228 | ||
224 | -- | applies a contraction on the first indices of the tensors | 229 | -- | applies a contraction on the first indices of the tensors |
225 | contractionF :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t | 230 | contractionF :: (Linear Vector t) => Tensor t -> Tensor t -> Tensor t |
226 | contractionF t1 t2 = contraction2 t1 n1 t2 n2 | 231 | contractionF t1 t2 = contraction2 t1 n1 t2 n2 |
227 | where n1 = fn t1 | 232 | where n1 = fn t1 |
228 | n2 = fn t2 | 233 | n2 = fn t2 |
229 | fn = idxName . head . dims | 234 | fn = idxName . head . dims |
230 | 235 | ||
231 | -- | computes all compatible contractions of the product of two tensors that would arise if the index names were equal | 236 | -- | computes all compatible contractions of the product of two tensors that would arise if the index names were equal |
232 | possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] | 237 | possibleContractions :: (Linear Vector t) => Tensor t -> Tensor t -> [Tensor t] |
233 | possibleContractions t1 t2 = [ contraction2 t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] | 238 | possibleContractions t1 t2 = [ contraction2 t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] |
234 | 239 | ||
235 | 240 | ||
@@ -242,7 +247,7 @@ desiredContractions1 t = [ n1 | (a,n1) <- x , (b,n2) <- x, a/=b, n1==n2] | |||
242 | where x = zip [0..] (names t) | 247 | where x = zip [0..] (names t) |
243 | 248 | ||
244 | -- | tensor product with the convention that repeated indices are contracted. | 249 | -- | tensor product with the convention that repeated indices are contracted. |
245 | mulT :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t | 250 | mulT :: (Linear Vector t) => Tensor t -> Tensor t -> Tensor t |
246 | mulT t1 t2 = r where | 251 | mulT t1 t2 = r where |
247 | t1r = contractionsC t1 (desiredContractions1 t1) | 252 | t1r = contractionsC t1 (desiredContractions1 t1) |
248 | t2r = contractionsC t2 (desiredContractions1 t2) | 253 | t2r = contractionsC t2 (desiredContractions1 t2) |
@@ -254,10 +259,10 @@ mulT t1 t2 = r where | |||
254 | ----------------------------------------------------------------- | 259 | ----------------------------------------------------------------- |
255 | 260 | ||
256 | -- | tensor addition (for tensors with the same structure) | 261 | -- | tensor addition (for tensors with the same structure) |
257 | addT :: (Num a, Field a) => Tensor a -> Tensor a -> Tensor a | 262 | addT :: (Linear Vector a) => Tensor a -> Tensor a -> Tensor a |
258 | addT a b = liftTensor2 add a b | 263 | addT a b = liftTensor2 add a b |
259 | 264 | ||
260 | sumT :: (Field a, Num a) => [Tensor a] -> Tensor a | 265 | sumT :: (Linear Vector a) => [Tensor a] -> Tensor a |
261 | sumT l = foldl1' addT l | 266 | sumT l = foldl1' addT l |
262 | 267 | ||
263 | ----------------------------------------------------------------- | 268 | ----------------------------------------------------------------- |
@@ -281,19 +286,19 @@ signature l | length (nub l) < length l = 0 | |||
281 | | otherwise = -1 | 286 | | otherwise = -1 |
282 | 287 | ||
283 | 288 | ||
284 | sym :: (Field t, Num t) => Tensor t -> Tensor t | 289 | sym :: (Linear Vector t) => Tensor t -> Tensor t |
285 | sym t = T (dims t) (ten (sym' (withIdx t seqind))) | 290 | sym t = T (dims t) (ten (sym' (withIdx t seqind))) |
286 | where sym' t = sumT $ map (flip tridx t) (perms (names t)) | 291 | where sym' t = sumT $ map (flip tridx t) (perms (names t)) |
287 | where nms = map idxName . dims | 292 | where nms = map idxName . dims |
288 | 293 | ||
289 | antisym :: (Field t, Num t) => Tensor t -> Tensor t | 294 | antisym :: (Linear Vector t) => Tensor t -> Tensor t |
290 | antisym t = T (dims t) (ten (antisym' (withIdx t seqind))) | 295 | antisym t = T (dims t) (ten (antisym' (withIdx t seqind))) |
291 | where antisym' t = sumT $ map (scsig . flip tridx t) (perms (names t)) | 296 | where antisym' t = sumT $ map (scsig . flip tridx t) (perms (names t)) |
292 | scsig t = scalar (signature (nms t)) `rawProduct` t | 297 | scsig t = scalar (signature (nms t)) `rawProduct` t |
293 | where nms = map idxName . dims | 298 | where nms = map idxName . dims |
294 | 299 | ||
295 | -- | the wedge product of two tensors (implemented as the antisymmetrization of the ordinary tensor product). | 300 | -- | the wedge product of two tensors (implemented as the antisymmetrization of the ordinary tensor product). |
296 | wedge :: (Field t, Fractional t) => Tensor t -> Tensor t -> Tensor t | 301 | wedge :: (Linear Vector t, Fractional t) => Tensor t -> Tensor t -> Tensor t |
297 | wedge a b = antisym (rawProduct (norper a) (norper b)) | 302 | wedge a b = antisym (rawProduct (norper a) (norper b)) |
298 | where norper t = rawProduct t (scalar (recip $ fromIntegral $ fact (rank t))) | 303 | where norper t = rawProduct t (scalar (recip $ fromIntegral $ fact (rank t))) |
299 | 304 | ||
@@ -313,19 +318,19 @@ seqind :: [String] | |||
313 | seqind = map show [1..] | 318 | seqind = map show [1..] |
314 | 319 | ||
315 | -- | completely antisymmetric covariant tensor of dimension n | 320 | -- | completely antisymmetric covariant tensor of dimension n |
316 | leviCivita :: (Field t, Num t) => Int -> Tensor t | 321 | leviCivita :: (Linear Vector t) => Int -> Tensor t |
317 | leviCivita n = antisym $ foldl1 rawProduct $ zipWith withIdx auxbase seqind' | 322 | leviCivita n = antisym $ foldl1 rawProduct $ zipWith withIdx auxbase seqind' |
318 | where auxbase = map tc (toRows (ident n)) | 323 | where auxbase = map tc (toRows (ident n)) |
319 | tc = tensorFromVector Covariant | 324 | tc = tensorFromVector Covariant |
320 | 325 | ||
321 | -- | contraction of leviCivita with a list of vectors (and raise with euclidean metric) | 326 | -- | contraction of leviCivita with a list of vectors (and raise with euclidean metric) |
322 | innerLevi :: (Num t, Field t) => [Tensor t] -> Tensor t | 327 | innerLevi :: (Linear Vector t) => [Tensor t] -> Tensor t |
323 | innerLevi vs = raise $ foldl' contractionF (leviCivita n) vs | 328 | innerLevi vs = raise $ foldl' contractionF (leviCivita n) vs |
324 | where n = idxDim . head . dims . head $ vs | 329 | where n = idxDim . head . dims . head $ vs |
325 | 330 | ||
326 | 331 | ||
327 | -- | obtains the dual of a multivector (with euclidean metric) | 332 | -- | obtains the dual of a multivector (with euclidean metric) |
328 | dual :: (Field t, Fractional t) => Tensor t -> Tensor t | 333 | dual :: (Linear Vector t, Fractional t) => Tensor t -> Tensor t |
329 | dual t = raise $ leviCivita n `mulT` withIdx t seqind `rawProduct` x | 334 | dual t = raise $ leviCivita n `mulT` withIdx t seqind `rawProduct` x |
330 | where n = idxDim . head . dims $ t | 335 | where n = idxDim . head . dims $ t |
331 | x = scalar (recip $ fromIntegral $ fact (rank t)) | 336 | x = scalar (recip $ fromIntegral $ fact (rank t)) |