summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-07-25 10:01:40 +0000
committerAlberto Ruiz <aruiz@um.es>2007-07-25 10:01:40 +0000
commit34b094b7589bf400114d802549fcba3ce1481683 (patch)
treea09b974ad7080bb7fb6800dfc7f57d26842cd934 /lib/Data/Packed/Internal
parent3a058b3707eecaac8ee3d964baf3e1ea1faabf51 (diff)
tensor refactoring
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs249
1 files changed, 159 insertions, 90 deletions
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index 4430ebc..34132d8 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -8,7 +8,7 @@
8-- Stability : provisional 8-- Stability : provisional
9-- Portability : portable (uses FFI) 9-- Portability : portable (uses FFI)
10-- 10--
11-- basic tensor operations 11-- support for basic tensor operations
12-- 12--
13----------------------------------------------------------------------------- 13-----------------------------------------------------------------------------
14 14
@@ -26,28 +26,83 @@ type IdxName = String
26 26
27data IdxDesc = IdxDesc { idxDim :: Int, 27data IdxDesc = IdxDesc { idxDim :: Int,
28 idxType :: IdxType, 28 idxType :: IdxType,
29 idxName :: IdxName } deriving Show 29 idxName :: IdxName } deriving Eq
30
31instance Show IdxDesc where
32 show (IdxDesc n t name) = name ++ sym t ++"["++show n++"]"
33 where sym Covariant = "_"
34 sym Contravariant = "^"
35
30 36
31data Tensor t = T { dims :: [IdxDesc] 37data Tensor t = T { dims :: [IdxDesc]
32 , ten :: Vector t 38 , ten :: Vector t
33 } 39 }
34 40
35-- | tensor rank (number of indices) 41-- | returns the coordinates of a tensor in row - major order
36rank :: Tensor t -> Int 42coords :: Tensor t -> Vector t
37rank = length . dims 43coords = ten
38 44
39instance (Show a,Storable a) => Show (Tensor a) where 45instance (Show a, Field a) => Show (Tensor a) where
40 show T {dims = [], ten = t} = "scalar "++show (t `at` 0) 46 show T {dims = [], ten = t} = "scalar "++show (t `at` 0)
41 show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) 47 show t = "("++shdims (dims t) ++") "++ showdt t
48
49asMatrix t = reshape (idxDim $ dims t!!1) (ten t)
50
51showdt t | rank t == 1 = show (toList (ten t))
52 | rank t == 2 = ('\n':) . dsp . map (map show) . toLists $ asMatrix $ t
53 | otherwise = concatMap showdt $ parts t (head (names t))
42 54
43-- | a nice description of the tensor structure 55-- | a nice description of the tensor structure
44shdims :: [IdxDesc] -> String 56shdims :: [IdxDesc] -> String
45shdims [] = "" 57shdims [] = ""
46shdims [d] = shdim d 58shdims [d] = show d
47shdims (d:ds) = shdim d ++ "><"++ shdims ds 59shdims (d:ds) = show d ++ "><"++ shdims ds
48shdim (IdxDesc n t name) = name ++ sym t ++"["++show n++"]" 60
49 where sym Covariant = "_" 61-- | tensor rank (number of indices)
50 sym Contravariant = "^" 62rank :: Tensor t -> Int
63rank = length . dims
64
65names :: Tensor t -> [IdxName]
66names t = map idxName (dims t)
67
68-- | number of contravariant and covariant indices
69structure :: Tensor t -> (Int,Int)
70structure t = (rank t - n, n) where
71 n = length $ filter isCov (dims t)
72 isCov d = idxType d == Covariant
73
74-- | creates a rank-zero tensor from a scalar
75scalar :: Storable t => t -> Tensor t
76scalar x = T [] (fromList [x])
77
78-- | Creates a tensor from a signed list of dimensions (positive = contravariant, negative = covariant) and a Vector containing the coordinates in row major order.
79tensor :: [Int] -> Vector a -> Tensor a
80tensor dssig vec = T d v `withIdx` seqind where
81 n = product (map abs dssig)
82 v = if dim vec == n then vec else error "wrong arguments for tensor"
83 d = map cr dssig
84 cr n | n > 0 = IdxDesc {idxName = "", idxDim = n, idxType = Contravariant}
85 | n < 0 = IdxDesc {idxName = "", idxDim = -n, idxType = Covariant }
86
87
88tensorFromVector :: IdxType -> Vector t -> Tensor t
89tensorFromVector tp v = T {dims = [IdxDesc (dim v) tp "1"], ten = v}
90
91tensorFromMatrix :: IdxType -> IdxType -> Matrix t -> Tensor t
92tensorFromMatrix tpr tpc m = T {dims = [IdxDesc (rows m) tpr "1",IdxDesc (cols m) tpc "2"]
93 , ten = cdat m}
94
95
96liftTensor :: (Vector a -> Vector b) -> Tensor a -> Tensor b
97liftTensor f (T d v) = T d (f v)
98
99liftTensor2 :: (Vector a -> Vector b -> Vector c) -> Tensor a -> Tensor b -> Tensor c
100liftTensor2 f (T d1 v1) (T d2 v2) | compat d1 d2 = T d1 (f v1 v2)
101 | otherwise = error "liftTensor2 with incompatible tensors"
102 where compat a b = length a == length b
103
104
105
51 106
52-- | express the tensor as a matrix with the given index in columns 107-- | express the tensor as a matrix with the given index in columns
53findIdx :: (Field t) => IdxName -> Tensor t 108findIdx :: (Field t) => IdxName -> Tensor t
@@ -65,6 +120,32 @@ putFirstIdx name t = (nd,m')
65 nd = d2++d1 120 nd = d2++d1
66 c = dim (ten t) `div` (idxDim $ head d2) 121 c = dim (ten t) `div` (idxDim $ head d2)
67 122
123
124-- | renames all the indices in the current order (repeated indices may get different names)
125withIdx :: Tensor t -> [IdxName] -> Tensor t
126withIdx (T d v) l = T d' v
127 where d' = zipWith f d l
128 f i n = i {idxName=n}
129
130
131-- | raises or lowers all the indices of a tensor (with euclidean metric)
132raise :: Tensor t -> Tensor t
133raise (T d v) = T (map raise' d) v
134 where raise' idx@IdxDesc {idxType = Covariant } = idx {idxType = Contravariant}
135 raise' idx@IdxDesc {idxType = Contravariant } = idx {idxType = Covariant}
136
137
138-- | index transposition to a desired order. You can specify only a subset of the indices, which will be moved to the front of indices list
139tridx :: (Field t) => [IdxName] -> Tensor t -> Tensor t
140tridx [] t = t
141tridx (name:rest) t = T (d:ds) (join ts) where
142 ((_,d:_),_) = findIdx name t
143 ps = map (tridx rest) (parts t name)
144 ts = map ten ps
145 ds = dims (head ps)
146
147
148
68-- | extracts a given part of a tensor 149-- | extracts a given part of a tensor
69part :: (Field t) => Tensor t -> (IdxName, Int) -> Tensor t 150part :: (Field t) => Tensor t -> (IdxName, Int) -> Tensor t
70part t (name,k) = if k<0 || k>=l 151part t (name,k) = if k<0 || k>=l
@@ -80,6 +161,17 @@ parts t name = map f (toRows m)
80 l = idxDim d 161 l = idxDim d
81 f t = T {dims=ds, ten=t} 162 f t = T {dims=ds, ten=t}
82 163
164
165compatIdx :: (Field t1, Field t) => Tensor t1 -> IdxName -> Tensor t -> IdxName -> Bool
166compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where
167 d1 = head $ snd $ fst $ findIdx n1 t1
168 d2 = head $ snd $ fst $ findIdx n2 t2
169 compatIdxAux IdxDesc {idxDim = n1, idxType = t1}
170 IdxDesc {idxDim = n2, idxType = t2}
171 = t1 /= t2 && n1 == n2
172
173
174
83-- | tensor product without without any contractions 175-- | tensor product without without any contractions
84rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t 176rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t
85rawProduct (T d1 v1) (T d2 v2) = T (d1++d2) (outer' v1 v2) 177rawProduct (T d1 v1) (T d2 v2) = T (d1++d2) (outer' v1 v2)
@@ -98,7 +190,7 @@ contraction2 t1 n1 t2 n2 =
98contraction1 :: (Field t, Num t) => Tensor t -> IdxName -> IdxName -> Tensor t 190contraction1 :: (Field t, Num t) => Tensor t -> IdxName -> IdxName -> Tensor t
99contraction1 t name1 name2 = 191contraction1 t name1 name2 =
100 if compatIdx t name1 t name2 192 if compatIdx t name1 t name2
101 then addT y 193 then sumT y
102 else error $ "wrong contraction1: "++(shdims$dims$t)++" "++name1++" "++name2 194 else error $ "wrong contraction1: "++(shdims$dims$t)++" "++name1++" "++name2
103 where d = dims (head y) 195 where d = dims (head y)
104 x = (map (flip parts name2) (parts t name1)) 196 x = (map (flip parts name2) (parts t name1))
@@ -120,14 +212,17 @@ contraction2' t1 n1 t2 n2 =
120 else error "wrong contraction'" 212 else error "wrong contraction'"
121 213
122-- | applies a sequence of contractions 214-- | applies a sequence of contractions
215contractions :: (Field t, Num t) => Tensor t -> [(IdxName, IdxName)] -> Tensor t
123contractions t pairs = foldl' contract1b t pairs 216contractions t pairs = foldl' contract1b t pairs
124 where contract1b t (n1,n2) = contraction1 t n1 n2 217 where contract1b t (n1,n2) = contraction1 t n1 n2
125 218
126-- | applies a sequence of contractions of same index 219-- | applies a sequence of contractions of same index
220contractionsC :: (Field t, Num t) => Tensor t -> [IdxName] -> Tensor t
127contractionsC t is = foldl' contraction1c t is 221contractionsC t is = foldl' contraction1c t is
128 222
129 223
130-- | applies a contraction on the first indices of the tensors 224-- | applies a contraction on the first indices of the tensors
225contractionF :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t
131contractionF t1 t2 = contraction2 t1 n1 t2 n2 226contractionF t1 t2 = contraction2 t1 n1 t2 n2
132 where n1 = fn t1 227 where n1 = fn t1
133 n2 = fn t2 228 n2 = fn t2
@@ -138,45 +233,42 @@ possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t]
138possibleContractions t1 t2 = [ contraction2 t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] 233possibleContractions t1 t2 = [ contraction2 t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ]
139 234
140 235
141liftTensor f (T d v) = T d (f v)
142
143liftTensor2 f (T d1 v1) (T d2 v2) | compat d1 d2 = T d1 (f v1 v2)
144 | otherwise = error "liftTensor2 with incompatible tensors"
145 where compat a b = length a == length b
146 236
237desiredContractions2 :: Tensor t -> Tensor t1 -> [(IdxName, IdxName)]
238desiredContractions2 t1 t2 = [ (n1,n2) | n1 <- names t1, n2 <- names t2, n1==n2]
147 239
148a |+| b = liftTensor2 add a b 240desiredContractions1 :: Tensor t -> [IdxName]
149addT l = foldl1' (|+|) l 241desiredContractions1 t = [ n1 | (a,n1) <- x , (b,n2) <- x, a/=b, n1==n2]
242 where x = zip [0..] (names t)
150 243
151-- | index transposition to a desired order. You can specify only a subset of the indices, which will be moved to the front of indices list 244-- | tensor product with the convention that repeated indices are contracted.
152tridx :: (Field t) => [IdxName] -> Tensor t -> Tensor t 245mulT :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t
153tridx [] t = t 246mulT t1 t2 = r where
154tridx (name:rest) t = T (d:ds) (join ts) where 247 t1r = contractionsC t1 (desiredContractions1 t1)
155 ((_,d:_),_) = findIdx name t 248 t2r = contractionsC t2 (desiredContractions1 t2)
156 ps = map (tridx rest) (parts t name) 249 cs = desiredContractions2 t1r t2r
157 ts = map ten ps 250 r = case cs of
158 ds = dims (head ps) 251 [] -> rawProduct t1r t2r
252 (n1,n2):as -> contractionsC (contraction2 t1r n1 t2r n2) (map fst as)
159 253
160compatIdxAux :: IdxDesc -> IdxDesc -> Bool 254-----------------------------------------------------------------
161compatIdxAux IdxDesc {idxDim = n1, idxType = t1}
162 IdxDesc {idxDim = n2, idxType = t2}
163 = t1 /= t2 && n1 == n2
164 255
165compatIdx :: (Field t1, Field t) => Tensor t1 -> IdxName -> Tensor t -> IdxName -> Bool 256-- | tensor addition (for tensors with the same structure)
166compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where 257addT :: (Num a, Field a) => Tensor a -> Tensor a -> Tensor a
167 d1 = head $ snd $ fst $ findIdx n1 t1 258addT a b = liftTensor2 add a b
168 d2 = head $ snd $ fst $ findIdx n2 t2
169 259
170names :: Tensor t -> [IdxName] 260sumT :: (Field a, Num a) => [Tensor a] -> Tensor a
171names t = map idxName (dims t) 261sumT l = foldl1' addT l
172 262
263-----------------------------------------------------------------
173 264
174-- sent to Haskell-Cafe by Sebastian Sylvan 265-- sent to Haskell-Cafe by Sebastian Sylvan
175perms :: [t] -> [[t]] 266perms :: [t] -> [[t]]
176perms [x] = [[x]] 267perms [x] = [[x]]
177perms xs = [y:ps | (y,ys) <- selections xs , ps <- perms ys] 268perms xs = [y:ps | (y,ys) <- selections xs , ps <- perms ys]
178selections [] = [] 269 where
179selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs] 270 selections [] = []
271 selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs]
180 272
181interchanges :: (Ord a) => [a] -> Int 273interchanges :: (Ord a) => [a] -> Int
182interchanges ls = sum (map (count ls) ls) 274interchanges ls = sum (map (count ls) ls)
@@ -188,58 +280,59 @@ signature l | length (nub l) < length l = 0
188 | even (interchanges l) = 1 280 | even (interchanges l) = 1
189 | otherwise = -1 281 | otherwise = -1
190 282
191scalar x = T [] (fromList [x])
192tensorFromVector tp v = T {dims = [IdxDesc (dim v) tp "1"], ten = v}
193tensorFromMatrix tpr tpc m = T {dims = [IdxDesc (rows m) tpr "1",IdxDesc (cols m) tpc "2"]
194 , ten = cdat m}
195 283
196tvector v = tensorFromVector Contravariant v 284sym :: (Field t, Num t) => Tensor t -> Tensor t
197tcovector v = tensorFromVector Covariant v 285sym t = T (dims t) (ten (sym' (withIdx t seqind)))
286 where sym' t = sumT $ map (flip tridx t) (perms (names t))
287 where nms = map idxName . dims
198 288
289antisym :: (Field t, Num t) => Tensor t -> Tensor t
199antisym t = T (dims t) (ten (antisym' (withIdx t seqind))) 290antisym t = T (dims t) (ten (antisym' (withIdx t seqind)))
200 where antisym' t = addT $ map (scsig . flip tridx t) (perms (names t)) 291 where antisym' t = sumT $ map (scsig . flip tridx t) (perms (names t))
201 scsig t = scalar (signature (nms t)) `rawProduct` t 292 scsig t = scalar (signature (nms t)) `rawProduct` t
202 where nms = map idxName . dims 293 where nms = map idxName . dims
203 294
204norper t = rawProduct t (scalar (recip $ fromIntegral $ fact (rank t))) 295-- | the wedge product of two tensors (implemented as the antisymmetrization of the ordinary tensor product).
205antinorper t = rawProduct t (scalar (fromIntegral $ fact (rank t))) 296wedge :: (Field t, Fractional t) => Tensor t -> Tensor t -> Tensor t
206
207wedge a b = antisym (rawProduct (norper a) (norper b)) 297wedge a b = antisym (rawProduct (norper a) (norper b))
298 where norper t = rawProduct t (scalar (recip $ fromIntegral $ fact (rank t)))
208 299
209normAT t = sqrt $ innerAT t t 300-- antinorper t = rawProduct t (scalar (fromIntegral $ fact (rank t)))
210 301
302-- | The euclidean inner product of two completely antisymmetric tensors
303innerAT :: (Fractional t, Field t) => Tensor t -> Tensor t -> t
211innerAT t1 t2 = dot (ten t1) (ten t2) / fromIntegral (fact $ rank t1) 304innerAT t1 t2 = dot (ten t1) (ten t2) / fromIntegral (fact $ rank t1)
212 305
306fact :: (Num t, Enum t) => t -> t
213fact n = product [1..n] 307fact n = product [1..n]
214 308
309seqind' :: [[String]]
215seqind' = map return seqind 310seqind' = map return seqind
311
312seqind :: [String]
216seqind = map show [1..] 313seqind = map show [1..]
217 314
315-- | completely antisymmetric covariant tensor of dimension n
316leviCivita :: (Field t, Num t) => Int -> Tensor t
218leviCivita n = antisym $ foldl1 rawProduct $ zipWith withIdx auxbase seqind' 317leviCivita n = antisym $ foldl1 rawProduct $ zipWith withIdx auxbase seqind'
219 where auxbase = map tcovector (toRows (ident n)) 318 where auxbase = map tc (toRows (ident n))
319 tc = tensorFromVector Covariant
220 320
221-- | obtains de dual of the exterior product of a list of X? 321-- | contraction of leviCivita with a list of vectors (and raise with euclidean metric)
222dualV vs = foldl' contractionF (leviCivita n) vs 322innerLevi :: (Num t, Field t) => [Tensor t] -> Tensor t
323innerLevi vs = raise $ foldl' contractionF (leviCivita n) vs
223 where n = idxDim . head . dims . head $ vs 324 where n = idxDim . head . dims . head $ vs
224 325
225-- | raises or lowers all the indices of a tensor (with euclidean metric)
226raise (T d v) = T (map raise' d) v
227 where raise' idx@IdxDesc {idxType = Covariant } = idx {idxType = Contravariant}
228 raise' idx@IdxDesc {idxType = Contravariant } = idx {idxType = Covariant}
229-- | raises or lowers all the indices of a tensor with a given an (inverse) metric
230raiseWith = undefined
231 326
232dualg f t = f (leviCivita n) `okContract` withIdx t seqind `rawProduct` x 327-- | obtains the dual of a multivector (with euclidean metric)
328dual :: (Field t, Fractional t) => Tensor t -> Tensor t
329dual t = raise $ leviCivita n `mulT` withIdx t seqind `rawProduct` x
233 where n = idxDim . head . dims $ t 330 where n = idxDim . head . dims $ t
234 x = scalar (recip $ fromIntegral $ fact (rank t)) 331 x = scalar (recip $ fromIntegral $ fact (rank t))
235 332
236-- | obtains the dual of a multivector
237dual t = dualg id t
238
239-- | obtains the dual of a multicovector (with euclidean metric)
240codual t = dualg raise t
241 333
242-- | shows only the relevant components of an antisymmetric tensor 334-- | shows only the relevant components of an antisymmetric tensor
335niceAS :: (Field t, Fractional t) => Tensor t -> [(t, [Int])]
243niceAS t = filter ((/=0.0).fst) $ zip vals base 336niceAS t = filter ((/=0.0).fst) $ zip vals base
244 where vals = map ((`at` 0).ten.foldl' partF t) (map (map pred) base) 337 where vals = map ((`at` 0).ten.foldl' partF t) (map (map pred) base)
245 base = asBase r n 338 base = asBase r n
@@ -247,27 +340,3 @@ niceAS t = filter ((/=0.0).fst) $ zip vals base
247 n = idxDim . head . dims $ t 340 n = idxDim . head . dims $ t
248 partF t i = part t (name,i) where name = idxName . head . dims $ t 341 partF t i = part t (name,i) where name = idxName . head . dims $ t
249 asBase r n = filter (\x-> (x==nub x && x==sort x)) $ sequence $ replicate r [1..n] 342 asBase r n = filter (\x-> (x==nub x && x==sort x)) $ sequence $ replicate r [1..n]
250
251-- | renames specified indices of a tensor (repeated indices get the same name)
252idxRename (T d v) l = T (map (ir l) d) v
253 where ir l i = case lookup (idxName i) l of
254 Nothing -> i
255 Just r -> i {idxName = r}
256
257-- | renames all the indices in the current order (repeated indices may get different names)
258withIdx (T d v) l = T d' v
259 where d' = zipWith f d l
260 f i n = i {idxName=n}
261
262desiredContractions2 t1 t2 = [ (n1,n2) | n1 <- names t1, n2 <- names t2, n1==n2]
263
264desiredContractions1 t = [ n1 | (a,n1) <- x , (b,n2) <- x, a/=b, n1==n2]
265 where x = zip [0..] (names t)
266
267okContract t1 t2 = r where
268 t1r = contractionsC t1 (desiredContractions1 t1)
269 t2r = contractionsC t2 (desiredContractions1 t2)
270 cs = desiredContractions2 t1r t2r
271 r = case cs of
272 [] -> rawProduct t1r t2r
273 (n1,n2):as -> contractionsC (contraction2 t1r n1 t2r n2) (map fst as)