summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal/Tensor.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-09-08 09:46:33 +0000
committerAlberto Ruiz <aruiz@um.es>2007-09-08 09:46:33 +0000
commit34380f2b5d7b048a4d68197f16a8db0e53742030 (patch)
tree444aff88cda5c247d49bac0d294d8cfb9ef7bf23 /lib/Data/Packed/Internal/Tensor.hs
parent0c38c1b0e122a56ea98c494e60ba90afe2688664 (diff)
type classes
Diffstat (limited to 'lib/Data/Packed/Internal/Tensor.hs')
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs37
1 files changed, 21 insertions, 16 deletions
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index 34132d8..6876685 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -1,3 +1,5 @@
1{-# OPTIONS_GHC -fglasgow-exts #-}
2
1----------------------------------------------------------------------------- 3-----------------------------------------------------------------------------
2-- | 4-- |
3-- Module : Data.Packed.Internal.Tensor 5-- Module : Data.Packed.Internal.Tensor
@@ -19,6 +21,8 @@ import Foreign.Storable
19import Data.List(sort,elemIndex,nub,foldl1',foldl') 21import Data.List(sort,elemIndex,nub,foldl1',foldl')
20import GSL.Vector 22import GSL.Vector
21import Data.Packed.Matrix 23import Data.Packed.Matrix
24import Data.Packed.Vector
25import LinearAlgebra.Linear
22 26
23data IdxType = Covariant | Contravariant deriving (Show,Eq) 27data IdxType = Covariant | Contravariant deriving (Show,Eq)
24 28
@@ -171,6 +175,7 @@ compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where
171 = t1 /= t2 && n1 == n2 175 = t1 /= t2 && n1 == n2
172 176
173 177
178outer' u v = dat (outer u v)
174 179
175-- | tensor product without without any contractions 180-- | tensor product without without any contractions
176rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t 181rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t
@@ -187,7 +192,7 @@ contraction2 t1 n1 t2 n2 =
187 m = multiply RowMajor (trans m1) m2 192 m = multiply RowMajor (trans m1) m2
188 193
189-- | contraction of a tensor along two given indices 194-- | contraction of a tensor along two given indices
190contraction1 :: (Field t, Num t) => Tensor t -> IdxName -> IdxName -> Tensor t 195contraction1 :: (Linear Vector t) => Tensor t -> IdxName -> IdxName -> Tensor t
191contraction1 t name1 name2 = 196contraction1 t name1 name2 =
192 if compatIdx t name1 t name2 197 if compatIdx t name1 t name2
193 then sumT y 198 then sumT y
@@ -197,7 +202,7 @@ contraction1 t name1 name2 =
197 y = map head $ zipWith drop [0..] x 202 y = map head $ zipWith drop [0..] x
198 203
199-- | contraction of a tensor along a repeated index 204-- | contraction of a tensor along a repeated index
200contraction1c :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t 205contraction1c :: (Linear Vector t) => Tensor t -> IdxName -> Tensor t
201contraction1c t n = contraction1 renamed n' n 206contraction1c t n = contraction1 renamed n' n
202 where n' = n++"'" -- hmmm 207 where n' = n++"'" -- hmmm
203 renamed = withIdx t auxnames 208 renamed = withIdx t auxnames
@@ -205,31 +210,31 @@ contraction1c t n = contraction1 renamed n' n
205 (h,_:r) = break (==n) (map idxName (dims t)) 210 (h,_:r) = break (==n) (map idxName (dims t))
206 211
207-- | alternative and inefficient version of contraction2 212-- | alternative and inefficient version of contraction2
208contraction2' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t 213contraction2' :: (Linear Vector t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
209contraction2' t1 n1 t2 n2 = 214contraction2' t1 n1 t2 n2 =
210 if compatIdx t1 n1 t2 n2 215 if compatIdx t1 n1 t2 n2
211 then contraction1 (rawProduct t1 t2) n1 n2 216 then contraction1 (rawProduct t1 t2) n1 n2
212 else error "wrong contraction'" 217 else error "wrong contraction'"
213 218
214-- | applies a sequence of contractions 219-- | applies a sequence of contractions
215contractions :: (Field t, Num t) => Tensor t -> [(IdxName, IdxName)] -> Tensor t 220contractions :: (Linear Vector t) => Tensor t -> [(IdxName, IdxName)] -> Tensor t
216contractions t pairs = foldl' contract1b t pairs 221contractions t pairs = foldl' contract1b t pairs
217 where contract1b t (n1,n2) = contraction1 t n1 n2 222 where contract1b t (n1,n2) = contraction1 t n1 n2
218 223
219-- | applies a sequence of contractions of same index 224-- | applies a sequence of contractions of same index
220contractionsC :: (Field t, Num t) => Tensor t -> [IdxName] -> Tensor t 225contractionsC :: (Linear Vector t) => Tensor t -> [IdxName] -> Tensor t
221contractionsC t is = foldl' contraction1c t is 226contractionsC t is = foldl' contraction1c t is
222 227
223 228
224-- | applies a contraction on the first indices of the tensors 229-- | applies a contraction on the first indices of the tensors
225contractionF :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t 230contractionF :: (Linear Vector t) => Tensor t -> Tensor t -> Tensor t
226contractionF t1 t2 = contraction2 t1 n1 t2 n2 231contractionF t1 t2 = contraction2 t1 n1 t2 n2
227 where n1 = fn t1 232 where n1 = fn t1
228 n2 = fn t2 233 n2 = fn t2
229 fn = idxName . head . dims 234 fn = idxName . head . dims
230 235
231-- | computes all compatible contractions of the product of two tensors that would arise if the index names were equal 236-- | computes all compatible contractions of the product of two tensors that would arise if the index names were equal
232possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] 237possibleContractions :: (Linear Vector t) => Tensor t -> Tensor t -> [Tensor t]
233possibleContractions t1 t2 = [ contraction2 t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] 238possibleContractions t1 t2 = [ contraction2 t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ]
234 239
235 240
@@ -242,7 +247,7 @@ desiredContractions1 t = [ n1 | (a,n1) <- x , (b,n2) <- x, a/=b, n1==n2]
242 where x = zip [0..] (names t) 247 where x = zip [0..] (names t)
243 248
244-- | tensor product with the convention that repeated indices are contracted. 249-- | tensor product with the convention that repeated indices are contracted.
245mulT :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t 250mulT :: (Linear Vector t) => Tensor t -> Tensor t -> Tensor t
246mulT t1 t2 = r where 251mulT t1 t2 = r where
247 t1r = contractionsC t1 (desiredContractions1 t1) 252 t1r = contractionsC t1 (desiredContractions1 t1)
248 t2r = contractionsC t2 (desiredContractions1 t2) 253 t2r = contractionsC t2 (desiredContractions1 t2)
@@ -254,10 +259,10 @@ mulT t1 t2 = r where
254----------------------------------------------------------------- 259-----------------------------------------------------------------
255 260
256-- | tensor addition (for tensors with the same structure) 261-- | tensor addition (for tensors with the same structure)
257addT :: (Num a, Field a) => Tensor a -> Tensor a -> Tensor a 262addT :: (Linear Vector a) => Tensor a -> Tensor a -> Tensor a
258addT a b = liftTensor2 add a b 263addT a b = liftTensor2 add a b
259 264
260sumT :: (Field a, Num a) => [Tensor a] -> Tensor a 265sumT :: (Linear Vector a) => [Tensor a] -> Tensor a
261sumT l = foldl1' addT l 266sumT l = foldl1' addT l
262 267
263----------------------------------------------------------------- 268-----------------------------------------------------------------
@@ -281,19 +286,19 @@ signature l | length (nub l) < length l = 0
281 | otherwise = -1 286 | otherwise = -1
282 287
283 288
284sym :: (Field t, Num t) => Tensor t -> Tensor t 289sym :: (Linear Vector t) => Tensor t -> Tensor t
285sym t = T (dims t) (ten (sym' (withIdx t seqind))) 290sym t = T (dims t) (ten (sym' (withIdx t seqind)))
286 where sym' t = sumT $ map (flip tridx t) (perms (names t)) 291 where sym' t = sumT $ map (flip tridx t) (perms (names t))
287 where nms = map idxName . dims 292 where nms = map idxName . dims
288 293
289antisym :: (Field t, Num t) => Tensor t -> Tensor t 294antisym :: (Linear Vector t) => Tensor t -> Tensor t
290antisym t = T (dims t) (ten (antisym' (withIdx t seqind))) 295antisym t = T (dims t) (ten (antisym' (withIdx t seqind)))
291 where antisym' t = sumT $ map (scsig . flip tridx t) (perms (names t)) 296 where antisym' t = sumT $ map (scsig . flip tridx t) (perms (names t))
292 scsig t = scalar (signature (nms t)) `rawProduct` t 297 scsig t = scalar (signature (nms t)) `rawProduct` t
293 where nms = map idxName . dims 298 where nms = map idxName . dims
294 299
295-- | the wedge product of two tensors (implemented as the antisymmetrization of the ordinary tensor product). 300-- | the wedge product of two tensors (implemented as the antisymmetrization of the ordinary tensor product).
296wedge :: (Field t, Fractional t) => Tensor t -> Tensor t -> Tensor t 301wedge :: (Linear Vector t, Fractional t) => Tensor t -> Tensor t -> Tensor t
297wedge a b = antisym (rawProduct (norper a) (norper b)) 302wedge a b = antisym (rawProduct (norper a) (norper b))
298 where norper t = rawProduct t (scalar (recip $ fromIntegral $ fact (rank t))) 303 where norper t = rawProduct t (scalar (recip $ fromIntegral $ fact (rank t)))
299 304
@@ -313,19 +318,19 @@ seqind :: [String]
313seqind = map show [1..] 318seqind = map show [1..]
314 319
315-- | completely antisymmetric covariant tensor of dimension n 320-- | completely antisymmetric covariant tensor of dimension n
316leviCivita :: (Field t, Num t) => Int -> Tensor t 321leviCivita :: (Linear Vector t) => Int -> Tensor t
317leviCivita n = antisym $ foldl1 rawProduct $ zipWith withIdx auxbase seqind' 322leviCivita n = antisym $ foldl1 rawProduct $ zipWith withIdx auxbase seqind'
318 where auxbase = map tc (toRows (ident n)) 323 where auxbase = map tc (toRows (ident n))
319 tc = tensorFromVector Covariant 324 tc = tensorFromVector Covariant
320 325
321-- | contraction of leviCivita with a list of vectors (and raise with euclidean metric) 326-- | contraction of leviCivita with a list of vectors (and raise with euclidean metric)
322innerLevi :: (Num t, Field t) => [Tensor t] -> Tensor t 327innerLevi :: (Linear Vector t) => [Tensor t] -> Tensor t
323innerLevi vs = raise $ foldl' contractionF (leviCivita n) vs 328innerLevi vs = raise $ foldl' contractionF (leviCivita n) vs
324 where n = idxDim . head . dims . head $ vs 329 where n = idxDim . head . dims . head $ vs
325 330
326 331
327-- | obtains the dual of a multivector (with euclidean metric) 332-- | obtains the dual of a multivector (with euclidean metric)
328dual :: (Field t, Fractional t) => Tensor t -> Tensor t 333dual :: (Linear Vector t, Fractional t) => Tensor t -> Tensor t
329dual t = raise $ leviCivita n `mulT` withIdx t seqind `rawProduct` x 334dual t = raise $ leviCivita n `mulT` withIdx t seqind `rawProduct` x
330 where n = idxDim . head . dims $ t 335 where n = idxDim . head . dims $ t
331 x = scalar (recip $ fromIntegral $ fact (rank t)) 336 x = scalar (recip $ fromIntegral $ fact (rank t))