summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal/Tensor.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-06-29 08:09:46 +0000
committerAlberto Ruiz <aruiz@um.es>2007-06-29 08:09:46 +0000
commite36c04dca536caa42b41a4280d3f21375219970d (patch)
tree18aef1b4607ee19c88cd740dbc2e7a65d783bd2f /lib/Data/Packed/Internal/Tensor.hs
parenta749785e839d14fadc47ab4c6e94afdd167bdd21 (diff)
tensor refactoring
Diffstat (limited to 'lib/Data/Packed/Internal/Tensor.hs')
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs197
1 files changed, 160 insertions, 37 deletions
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index 8296935..dedbb9c 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -14,12 +14,11 @@
14 14
15module Data.Packed.Internal.Tensor where 15module Data.Packed.Internal.Tensor where
16 16
17import Data.Packed.Internal.Common 17import Data.Packed.Internal
18import Data.Packed.Internal.Vector
19import Data.Packed.Internal.Matrix
20import Foreign.Storable 18import Foreign.Storable
21import Data.List(sort,elemIndex,nub,foldl1') 19import Data.List(sort,elemIndex,nub,foldl1',foldl')
22import GSL.Vector 20import GSL.Vector
21import Data.Packed.Matrix
23 22
24data IdxType = Covariant | Contravariant deriving (Show,Eq) 23data IdxType = Covariant | Contravariant deriving (Show,Eq)
25 24
@@ -27,12 +26,13 @@ type IdxName = String
27 26
28data IdxDesc = IdxDesc { idxDim :: Int, 27data IdxDesc = IdxDesc { idxDim :: Int,
29 idxType :: IdxType, 28 idxType :: IdxType,
30 idxName :: IdxName } 29 idxName :: IdxName } deriving Show
31 30
32data Tensor t = T { dims :: [IdxDesc] 31data Tensor t = T { dims :: [IdxDesc]
33 , ten :: Vector t 32 , ten :: Vector t
34 } 33 }
35 34
35-- | tensor rank (number of indices)
36rank :: Tensor t -> Int 36rank :: Tensor t -> Int
37rank = length . dims 37rank = length . dims
38 38
@@ -40,14 +40,16 @@ instance (Show a,Storable a) => Show (Tensor a) where
40 show T {dims = [], ten = t} = "scalar "++show (t `at` 0) 40 show T {dims = [], ten = t} = "scalar "++show (t `at` 0)
41 show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) 41 show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t)
42 42
43 43-- | a nice description of the tensor structure
44shdims :: [IdxDesc] -> String 44shdims :: [IdxDesc] -> String
45shdims [IdxDesc n t name] = name ++ sym t ++"["++show n++"]" 45shdims [] = ""
46shdims [d] = shdim d
47shdims (d:ds) = shdim d ++ "><"++ shdims ds
48shdim (IdxDesc n t name) = name ++ sym t ++"["++show n++"]"
46 where sym Covariant = "_" 49 where sym Covariant = "_"
47 sym Contravariant = "^" 50 sym Contravariant = "^"
48shdims (d:ds) = shdims [d] ++ "><"++ shdims ds
49
50 51
52-- | express the tensor as a matrix with the given index in columns
51findIdx :: (Field t) => IdxName -> Tensor t 53findIdx :: (Field t) => IdxName -> Tensor t
52 -> (([IdxDesc], [IdxDesc]), Matrix t) 54 -> (([IdxDesc], [IdxDesc]), Matrix t)
53findIdx name t = ((d1,d2),m) where 55findIdx name t = ((d1,d2),m) where
@@ -55,6 +57,7 @@ findIdx name t = ((d1,d2),m) where
55 c = product (map idxDim d2) 57 c = product (map idxDim d2)
56 m = matrixFromVector RowMajor c (ten t) 58 m = matrixFromVector RowMajor c (ten t)
57 59
60-- | express the tensor as a matrix with the given index in rows
58putFirstIdx :: (Field t) => String -> Tensor t -> ([IdxDesc], Matrix t) 61putFirstIdx :: (Field t) => String -> Tensor t -> ([IdxDesc], Matrix t)
59putFirstIdx name t = (nd,m') 62putFirstIdx name t = (nd,m')
60 where ((d1,d2),m) = findIdx name t 63 where ((d1,d2),m) = findIdx name t
@@ -62,6 +65,7 @@ putFirstIdx name t = (nd,m')
62 nd = d2++d1 65 nd = d2++d1
63 c = dim (ten t) `div` (idxDim $ head d2) 66 c = dim (ten t) `div` (idxDim $ head d2)
64 67
68-- | extracts a given part of a tensor
65part :: (Field t) => Tensor t -> (IdxName, Int) -> Tensor t 69part :: (Field t) => Tensor t -> (IdxName, Int) -> Tensor t
66part t (name,k) = if k<0 || k>=l 70part t (name,k) = if k<0 || k>=l
67 then error $ "part "++show (name,k)++" out of range" -- in "++show t 71 then error $ "part "++show (name,k)++" out of range" -- in "++show t
@@ -69,32 +73,70 @@ part t (name,k) = if k<0 || k>=l
69 where (d:ds,m) = putFirstIdx name t 73 where (d:ds,m) = putFirstIdx name t
70 l = idxDim d 74 l = idxDim d
71 75
76-- | creates a list with all parts of a tensor along a given index
72parts :: (Field t) => Tensor t -> IdxName -> [Tensor t] 77parts :: (Field t) => Tensor t -> IdxName -> [Tensor t]
73parts t name = map f (toRows m) 78parts t name = map f (toRows m)
74 where (d:ds,m) = putFirstIdx name t 79 where (d:ds,m) = putFirstIdx name t
75 l = idxDim d 80 l = idxDim d
76 f t = T {dims=ds, ten=t} 81 f t = T {dims=ds, ten=t}
77 82
78concatRename :: [IdxDesc] -> [IdxDesc] -> [IdxDesc] 83-- | tensor product without without any contractions
79concatRename l1 l2 = l1 ++ map ren l2 where 84rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t
80 ren idx = if {- s `elem` fs -} True then idx {idxName = idxName idx ++ "'"} else idx 85rawProduct (T d1 v1) (T d2 v2) = T (d1++d2) (outer' v1 v2)
81 fs = map idxName l1
82 86
83--prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t 87-- | contraction of the product of two tensors
84prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer' v1 v2) 88contraction2 :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
85 89contraction2 t1 n1 t2 n2 =
86--contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
87contraction t1 n1 t2 n2 =
88 if compatIdx t1 n1 t2 n2 90 if compatIdx t1 n1 t2 n2
89 then T (concatRename (tail d1) (tail d2)) (cdat m) 91 then T (tail d1 ++ tail d2) (cdat m)
90 else error "wrong contraction'" 92 else error "wrong contraction2"
91 where (d1,m1) = putFirstIdx n1 t1 93 where (d1,m1) = putFirstIdx n1 t1
92 (d2,m2) = putFirstIdx n2 t2 94 (d2,m2) = putFirstIdx n2 t2
93 m = multiply RowMajor (trans m1) m2 95 m = multiply RowMajor (trans m1) m2
94 96
95--sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t] 97-- | contraction of a tensor along two given indices
96--sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) 98contraction1 :: (Field t, Num t) => Tensor t -> IdxName -> IdxName -> Tensor t
97--addT ts = T (dims (head ts)) (fromList $ sumT ts) 99contraction1 t name1 name2 =
100 if compatIdx t name1 t name2
101 then addT y
102 else error $ "wrong contraction1: "++(shdims$dims$t)++" "++name1++" "++name2
103 where d = dims (head y)
104 x = (map (flip parts name2) (parts t name1))
105 y = map head $ zipWith drop [0..] x
106
107-- | contraction of a tensor along a repeated index
108contraction1c :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t
109contraction1c t n = contraction1 renamed n' n
110 where n' = n++"'" -- hmmm
111 renamed = withIdx t auxnames
112 auxnames = h ++ (n':r)
113 (h,_:r) = break (==n) (map idxName (dims t))
114
115-- | alternative and inefficient version of contraction2
116contraction2' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
117contraction2' t1 n1 t2 n2 =
118 if compatIdx t1 n1 t2 n2
119 then contraction1 (rawProduct t1 t2) n1 n2
120 else error "wrong contraction'"
121
122-- | applies a sequence of contractions
123contractions t pairs = foldl' contract1b t pairs
124 where contract1b t (n1,n2) = contraction1 t n1 n2
125
126-- | applies a sequence of contractions of same index
127contractionsC t is = foldl' contraction1c t is
128
129
130-- | applies a contraction on the first indices of the tensors
131contractionF t1 t2 = contraction2 t1 n1 t2 n2
132 where n1 = fn t1
133 n2 = fn t2
134 fn = idxName . head . dims
135
136-- | computes all compatible contractions of the product of two tensors that would arise if the index names were equal
137possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t]
138possibleContractions t1 t2 = [ contraction2 t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ]
139
98 140
99liftTensor f (T d v) = T d (f v) 141liftTensor f (T d v) = T d (f v)
100 142
@@ -106,18 +148,7 @@ liftTensor2 f (T d1 v1) (T d2 v2) | compat d1 d2 = T d1 (f v1 v2)
106a |+| b = liftTensor2 add a b 148a |+| b = liftTensor2 add a b
107addT l = foldl1' (|+|) l 149addT l = foldl1' (|+|) l
108 150
109--contract1 :: (Num t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t 151-- | index transposition to a desired order. You can specify only a subset of the indices, which will be moved to the front of indices list
110contract1 t name1 name2 = addT y
111 where d = dims (head y)
112 x = (map (flip parts name2) (parts t name1))
113 y = map head $ zipWith drop [0..] x
114
115--contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
116contraction' t1 n1 t2 n2 =
117 if compatIdx t1 n1 t2 n2
118 then contract1 (prod t1 t2) n1 (n2++"'")
119 else error "wrong contraction'"
120
121tridx :: (Field t) => [IdxName] -> Tensor t -> Tensor t 152tridx :: (Field t) => [IdxName] -> Tensor t -> Tensor t
122tridx [] t = t 153tridx [] t = t
123tridx (name:rest) t = T (d:ds) (join ts) where 154tridx (name:rest) t = T (d:ds) (join ts) where
@@ -142,8 +173,6 @@ names t = sort $ map idxName (dims t)
142normal :: (Field t) => Tensor t -> Tensor t 173normal :: (Field t) => Tensor t -> Tensor t
143normal t = tridx (names t) t 174normal t = tridx (names t) t
144 175
145possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t]
146possibleContractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ]
147 176
148-- sent to Haskell-Cafe by Sebastian Sylvan 177-- sent to Haskell-Cafe by Sebastian Sylvan
149perms :: [t] -> [[t]] 178perms :: [t] -> [[t]]
@@ -161,3 +190,97 @@ signature :: (Num t, Ord a) => [a] -> t
161signature l | length (nub l) < length l = 0 190signature l | length (nub l) < length l = 0
162 | even (interchanges l) = 1 191 | even (interchanges l) = 1
163 | otherwise = -1 192 | otherwise = -1
193
194scalar x = T [] (fromList [x])
195tensorFromVector (tp,nm) v = T {dims = [IdxDesc (dim v) tp nm]
196 , ten = v}
197tensorFromMatrix (tpr,nmr) (tpc,nmc) m = T {dims = [IdxDesc (rows m) tpr nmr,IdxDesc (cols m) tpc nmc]
198 , ten = cdat m}
199
200tvector n v = tensorFromVector (Contravariant,n) v
201tcovector n v = tensorFromVector (Covariant,n) v
202
203
204antisym t = T (dims t) (ten (antisym' (auxrename t)))
205 where
206 scsig t = scalar (signature (nms t)) `rawProduct` t
207 where nms = map idxName . dims
208
209 antisym' t = addT $ map (scsig . flip tridx t) (perms (names t))
210
211 auxrename (T d v) = T d' v
212 where d' = [IdxDesc n c (show (pos q)) | IdxDesc n c q <- d]
213 pos n = i where Just i = elemIndex n nms
214 nms = map idxName d
215
216
217norper t = rawProduct t (scalar (recip $ fromIntegral $ fact (rank t)))
218antinorper t = rawProduct t (scalar (fromIntegral $ fact (rank t)))
219
220wedge a b = antisym (rawProduct (norper a) (norper b))
221
222a /\ b = wedge a b
223
224normAT t = sqrt $ innerAT t t
225
226innerAT t1 t2 = dot (ten t1) (ten t2) / fromIntegral (fact $ rank t1)
227
228fact n = product [1..n]
229
230leviCivita n = antisym $ foldl1 rawProduct $ zipWith tcovector (map show [1,2..]) (toRows (ident n))
231
232-- | obtains de dual of the exterior product of a list of X?
233dualV vs = foldl' contractionF (leviCivita n) vs
234 where n = idxDim . head . dims . head $ vs
235
236-- | raises or lowers all the indices of a tensor (with euclidean metric)
237raise (T d v) = T (map raise' d) v
238 where raise' idx@IdxDesc {idxType = Covariant } = idx {idxType = Contravariant}
239 raise' idx@IdxDesc {idxType = Contravariant } = idx {idxType = Covariant}
240-- | raises or lowers all the indices of a tensor with a given an (inverse) metric
241raiseWith = undefined
242
243-- | obtains the dual of a multivector
244dualMV t = rawProduct (contractions lct ds) x
245 where
246 lc = leviCivita n
247 lct = rawProduct lc t
248 nms1 = map idxName (dims lc)
249 nms2 = map idxName (dims t)
250 ds = zip nms1 nms2
251 n = idxDim . head . dims $ t
252 x = scalar (recip $ fromIntegral $ fact (rank t))
253
254
255-- | shows only the relevant components of an antisymmetric tensor
256niceAS t = filter ((/=0.0).fst) $ zip vals base
257 where vals = map ((`at` 0).ten.foldl' partF t) (map (map pred) base)
258 base = asBase r n
259 r = length (dims t)
260 n = idxDim . head . dims $ t
261 partF t i = part t (name,i) where name = idxName . head . dims $ t
262 asBase r n = filter (\x-> (x==nub x && x==sort x)) $ sequence $ replicate r [1..n]
263
264-- | renames specified indices of a tensor (repeated indices get the same name)
265idxRename (T d v) l = T (map (ir l) d) v
266 where ir l i = case lookup (idxName i) l of
267 Nothing -> i
268 Just r -> i {idxName = r}
269
270-- | renames all the indices in the current order (repeated indices may get different names)
271withIdx (T d v) l = T d' v
272 where d' = zipWith f d l
273 f i n = i {idxName=n}
274
275desiredContractions2 t1 t2 = [ (n1,n2) | n1 <- names t1, n2 <- names t2, n1==n2]
276
277desiredContractions1 t = [ n1 | (a,n1) <- x , (b,n2) <- x, a/=b, n1==n2]
278 where x = zip [0..] (names t)
279
280okContract t1 t2 = r where
281 t1r = contractionsC t1 (desiredContractions1 t1)
282 t2r = contractionsC t2 (desiredContractions1 t2)
283 cs = desiredContractions2 t1r t2r
284 r = case cs of
285 [] -> rawProduct t1r t2r
286 (n1,n2):as -> contractionsC (contraction2 t1r n1 t2r n2) (map fst as)