summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal/Tensor.hs
blob: dedbb9cd6ab51b32a97d2a1b92a10beb782a5a8f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Packed.Internal.Tensor
-- Copyright   :  (c) Alberto Ruiz 2007
-- License     :  GPL-style
--
-- Maintainer  :  Alberto Ruiz <aruiz@um.es>
-- Stability   :  provisional
-- Portability :  portable (uses FFI)
--
-- basic tensor operations
--
-----------------------------------------------------------------------------

module Data.Packed.Internal.Tensor where

import Data.Packed.Internal
import Foreign.Storable
import Data.List(sort,elemIndex,nub,foldl1',foldl')
import GSL.Vector
import Data.Packed.Matrix

-- | Variance of a tensor index.
data IdxType
    = Covariant
    | Contravariant
    deriving (Show, Eq)

-- | Name used to refer to a tensor index.
type IdxName = String

-- | Description of a single tensor index.
data IdxDesc = IdxDesc
    { idxDim  :: Int      -- ^ number of values the index ranges over
    , idxType :: IdxType  -- ^ variance of the index
    , idxName :: IdxName  -- ^ name of the index
    } deriving Show

-- | A tensor: the descriptors of its indices together with all of its
-- components stored in one flat 'Vector'.
-- NOTE(review): the rest of the module reshapes this vector RowMajor, so the
-- last index presumably varies fastest — confirm.
data Tensor t = T
    { dims :: [IdxDesc]  -- ^ index descriptors
    , ten  :: Vector t   -- ^ flattened components
    }

-- | Rank of a tensor, i.e. how many indices it has.
rank :: Tensor t -> Int
rank t = length (dims t)

-- | Rank-0 tensors print as @scalar x@; every other tensor prints its index
-- structure (see 'shdims') followed by the list of its components.
instance (Show a, Storable a) => Show (Tensor a) where
    show t = case dims t of
        [] -> "scalar " ++ show (ten t `at` 0)
        ds -> "(" ++ shdims ds ++ ") " ++ show (toList (ten t))

-- | A compact textual description of a tensor's index structure: one
-- 'shdim' entry per index, joined with @"><"@ (empty for no indices).
shdims :: [IdxDesc] -> String
shdims [] = ""
shdims ds = foldr1 (\a b -> a ++ "><" ++ b) (map shdim ds)

-- | Describes a single index as its name, then @_@ (covariant) or @^@
-- (contravariant), then its dimension in brackets, e.g. @i^[3]@.
shdim :: IdxDesc -> String
shdim (IdxDesc n t name) = name ++ sym t ++ "[" ++ show n ++ "]"
    where sym Covariant     = "_"
          sym Contravariant = "^"

-- | Splits the index list of a tensor at the first index called @name@.
-- Returns @(before, named : after)@ together with the components reshaped
-- RowMajor with @product (map idxDim (named : after))@ as the column count.
-- So the rows range over the indices before @name@, and the columns range
-- over @name@ and the indices after it.
--
-- Fails with an explicit error when no index has the given name. Without
-- this check, callers would crash later on @head []@ or a pattern failure,
-- or would silently build tensors with inconsistent dims.
findIdx :: (Field t) => IdxName -> Tensor t
        -> (([IdxDesc], [IdxDesc]), Matrix t)
findIdx name t
    | null d2   = error $ "findIdx: index "++show name++" not found in "++shdims (dims t)
    | otherwise = ((d1,d2),m)
  where
    (d1,d2) = span (\d -> idxName d /= name) (dims t)
    c = product (map idxDim d2)
    m = matrixFromVector RowMajor c (ten t)

-- | Moves the named index to the front. Returns the reordered descriptors
-- together with the components as a matrix whose rows correspond to the
-- values of the named index. The descriptor order is: the named index, then
-- the indices that followed it, then those that preceded it. The matrix is
-- built RowMajor with (total size / size of the named index) columns.
putFirstIdx :: (Field t) => String -> Tensor t -> ([IdxDesc], Matrix t)
putFirstIdx name t = (front ++ back, regrouped)
    where ((back, front), asMatrix) = findIdx name t
          cols = dim (ten t) `div` idxDim (head front)
          regrouped = matrixFromVector RowMajor cols (cdat (trans asMatrix))

-- | Extracts the slice of a tensor where the named index takes the value
-- @k@ (0-based). The named index is removed from the result. The remaining
-- indices are ordered as by 'putFirstIdx': those that followed the named
-- index come first. Errors when @k@ is outside @[0, dim)@.
part :: (Field t) => Tensor t -> (IdxName, Int) -> Tensor t
part t (name,k)
    | k < 0 || k >= size = error $ "part "++show (name,k)++" out of range"
    | otherwise          = T {dims = rest, ten = toRows m !! k}
  where (d:rest, m) = putFirstIdx name t
        size = idxDim d

-- | Splits a tensor into all of its parts along the named index, one per
-- value of that index. The k-th element is built exactly like
-- @'part' t (name, k)@, without the range check. The named index is
-- removed from every part.
parts :: (Field t) => Tensor t -> IdxName -> [Tensor t]
parts t name = map slice (toRows m)
    where (_:ds,m) = putFirstIdx name t
          slice v = T {dims=ds, ten=v}

-- | Tensor (outer) product without any contractions. The index lists are
-- concatenated (left operand first), and the components are combined with
-- 'outer''.
rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t
rawProduct T {dims = da, ten = va} T {dims = db, ten = vb} =
    T {dims = da ++ db, ten = outer' va vb}

-- | Contraction of the product of two tensors over index @n1@ of @t1@ and
-- index @n2@ of @t2@. Both must pass 'compatIdx': same dimension, opposite
-- variance. The result carries the remaining indices of @t1@ followed by
-- those of @t2@, each ordered as by 'putFirstIdx'.
contraction2 :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
contraction2 t1 n1 t2 n2
    | compatIdx t1 n1 t2 n2 = T (tail da ++ tail db) (cdat prod)
    | otherwise             = error "wrong contraction2"
  where
    (da, ma) = putFirstIdx n1 t1
    (db, mb) = putFirstIdx n2 t2
    -- rows of ma / mb are indexed by the contracted index, so the matrix
    -- product (trans ma) * mb sums over it
    prod = multiply RowMajor (trans ma) mb

-- | Contraction of a tensor along two of its own indices. The indices must
-- pass 'compatIdx': same dimension, opposite variance. The slices where
-- @name1 = name2 = i@ are summed over @i@.
contraction1 :: (Field t, Num t) => Tensor t -> IdxName -> IdxName -> Tensor t
contraction1 t name1 name2 =
    if compatIdx t name1 t name2
        then addT y
        else error $ "wrong contraction1: "++(shdims$dims$t)++" "++name1++" "++name2
    where -- x !! i !! j is the part of t with name1 = i and name2 = j
          x = (map (flip parts name2) (parts t name1))
          -- keep only the diagonal entries x !! i !! i
          y = map head $ zipWith drop [0..] x

-- | Contraction of a tensor along a repeated index name. The first index
-- called @n@ is temporarily renamed to an auxiliary name, then contracted
-- against the next index called @n@ with 'contraction1'.
contraction1c :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t
contraction1c t n = contraction1 renamed n' n
    where existing = map idxName (dims t)
          -- The auxiliary name must not clash with any existing index name:
          -- contraction1 looks indices up by first occurrence, so a clash
          -- would make it contract the wrong index.
          n' = until (`notElem` existing) (++ "'") (n ++ "'")
          renamed = withIdx t auxnames
          auxnames = h ++ (n':r)
          (h,_:r) = break (==n) existing

-- | Alternative (and inefficient) version of 'contraction2'. It forms the
-- full product first and then contracts it with 'contraction1'.
-- NOTE(review): when @n1 == n2@, both lookups in the product resolve to the
-- same (first) index, so contraction1's compatibility check fails.
contraction2' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
contraction2' t1 n1 t2 n2
    | compatIdx t1 n1 t2 n2 = contraction1 (rawProduct t1 t2) n1 n2
    | otherwise             = error "wrong contraction'"

-- | Applies a sequence of 'contraction1' steps, left to right. Each step is
-- given as a pair of index names.
contractions t pairs = foldl' (\acc (a, b) -> contraction1 acc a b) t pairs

-- | Applies 'contraction1c' for each listed repeated index name, left to
-- right.
contractionsC t0 repeated = foldl' contraction1c t0 repeated


-- | Contracts the first index of @t1@ against the first index of @t2@
-- (see 'contraction2').
contractionF t1 t2 = contraction2 t1 (firstName t1) t2 (firstName t2)
    where firstName = idxName . head . dims

-- | All single contractions of the product of two tensors that would arise
-- if the index names were equal. Every pair of one index from each tensor
-- that passes 'compatIdx' is contracted with 'contraction2'.
possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t]
possibleContractions t1 t2 = map contract (filter compatible candidates)
  where
    candidates = [ (a, b) | a <- names t1, b <- names t2 ]
    compatible (a, b) = compatIdx t1 a t2 b
    contract (a, b) = contraction2 t1 a t2 b


-- | Applies a function to the component vector of a tensor, keeping its
-- index descriptors unchanged.
liftTensor :: (Vector a -> Vector b) -> Tensor a -> Tensor b
liftTensor f (T d v) = T d (f v)

-- | Combines the component vectors of two tensors with a binary function.
-- The result keeps the index descriptors of the first tensor. Only the
-- ranks are checked; index names, order and variance may differ. antisym
-- relies on this when it adds index-permuted copies of a tensor.
liftTensor2 :: (Vector a -> Vector b -> Vector c) -> Tensor a -> Tensor b -> Tensor c
liftTensor2 f (T d1 v1) (T d2 v2) | compat d1 d2 = T d1 (f v1 v2)
                                  | otherwise = error "liftTensor2 with incompatible tensors"
    where compat a b = length a == length b


-- | Sum of two tensors of the same rank: the component vectors are combined
-- with 'add', and the left operand's index descriptors are kept.
(|+|) t1 t2 = liftTensor2 add t1 t2

-- | Sum of a non-empty list of tensors, folded strictly from the left.
addT ts = foldl1' (|+|) ts

-- | Index transposition: moves the listed indices, in the given order, to
-- the front of the index list. Indices that are not listed keep their
-- relative order after them.
tridx :: (Field t) => [IdxName] -> Tensor t -> Tensor t
tridx [] t = t
tridx (name:rest) t = T (d : dims (head sub)) (join (map ten sub))
  where
    ((_, d:_), _) = findIdx name t
    -- every part along the current index, with the rest recursively reordered
    sub = map (tridx rest) (parts t name)

-- | Two index descriptors can be contracted when they have opposite
-- variance and the same dimension.
compatIdxAux :: IdxDesc -> IdxDesc -> Bool
compatIdxAux a b = idxType a /= idxType b && idxDim a == idxDim b

-- | Whether index @n1@ of @t1@ and index @n2@ of @t2@ can be contracted
-- with each other (see 'compatIdxAux').
compatIdx :: (Field t1, Field t) => Tensor t1 -> IdxName -> Tensor t -> IdxName -> Bool
compatIdx ta na tb nb = compatIdxAux (descOf na ta) (descOf nb tb)
  where descOf n t = head (snd (fst (findIdx n t)))

-- | The index names of a tensor, sorted.
names :: Tensor t -> [IdxName]
names = sort . map idxName . dims

-- | Reorders the indices of a tensor so that their names are sorted.
normal :: (Field t) => Tensor t -> Tensor t
normal t = tridx sortedNames t
    where sortedNames = names t


-- sent to Haskell-Cafe by Sebastian Sylvan
-- | All permutations of a list, in the order produced by 'selections'.
-- The empty list has exactly one permutation, @[]@.
perms :: [t] -> [[t]]
-- Base case: a single (empty) permutation. This also gives perms [x] == [[x]]
-- and keeps antisym working for rank-0 tensors.
perms [] = [[]]
perms xs = [y:ps | (y,ys) <- selections xs, ps <- perms ys]

-- | Every way of picking one element out of a list, paired with the
-- remaining elements in their original order.
selections :: [a] -> [(a, [a])]
selections []     = []
selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs]

-- | Number of inversions of a list: pairs of positions @i < j@ whose
-- elements are out of order (@l !! i > l !! j@). Its parity is the parity
-- of the permutation that sorts the list.
interchanges :: (Ord a) => [a] -> Int
interchanges [] = 0
-- each element contributes the later elements smaller than itself; counting
-- directly avoids elemIndex lookups, which miscounted repeated elements
interchanges (x:xs) = length (filter (< x) xs) + interchanges xs

-- | Sign of a list seen as a permutation: 0 if it has repeated elements,
-- otherwise 1 or -1 according to whether its inversion count
-- ('interchanges') is even or odd.
signature :: (Num t, Ord a) => [a] -> t
signature l
    | hasRepeats = 0
    | otherwise  = if even (interchanges l) then 1 else -1
  where hasRepeats = length (nub l) < length l

-- | Rank-0 tensor holding a single value.
scalar x = T {dims = [], ten = fromList [x]}
-- | Wraps a vector as a rank-1 tensor whose single index has the given
-- variance and name.
tensorFromVector (tp, nm) v = T [idx] v
    where idx = IdxDesc {idxDim = dim v, idxType = tp, idxName = nm}
-- | Wraps a matrix as a rank-2 tensor. The first index ranges over the rows
-- and the second over the columns, each with the given variance and name.
-- The components are taken from 'cdat' of the matrix.
tensorFromMatrix (tpr, nmr) (tpc, nmc) m = T [rowIdx, colIdx] (cdat m)
    where rowIdx = IdxDesc (rows m) tpr nmr
          colIdx = IdxDesc (cols m) tpc nmc

-- | Rank-1 tensor with a single contravariant index called @n@.
tvector n = tensorFromVector (Contravariant, n)
-- | Rank-1 tensor with a single covariant index called @n@.
tcovector n = tensorFromVector (Covariant, n)


-- | Antisymmetrisation of a tensor. Sums the tensor over every reordering
-- of its indices, each term weighted by the signature of that reordering.
-- No 1/k! normalisation is applied here; wedge applies norper beforehand.
-- The result keeps the index descriptors of @t@; only the components change.
antisym t = T (dims t) (ten (antisym' (auxrename t)))
    where
        -- multiplies a tensor by the signature (1, -1 or 0) of its current
        -- sequence of index names
        scsig t = scalar (signature (nms t)) `rawProduct` t
            where nms = map idxName . dims

        -- sums the signed tensors obtained by moving the indices into every
        -- permutation (via tridx) of the sorted index names
        antisym' t = addT $ map (scsig . flip tridx t) (perms (names t))

        -- renames every index to the decimal position of the first index
        -- sharing its name, so the names encode the original index order
        -- NOTE(review): these positions are compared as strings (names sorts
        -- them, signature compares them), so from rank 11 on "10" sorts
        -- before "2" — confirm the sign convention still holds there.
        auxrename (T d v) = T d' v
            where d' = [IdxDesc n c (show (pos q)) | IdxDesc n c q <- d]
                  pos n = i where Just i = elemIndex n nms
                  nms = map idxName d


-- | Divides a tensor by (rank t)!.
norper t = t `rawProduct` scalar (recip (fromIntegral (fact (rank t))))
-- | Multiplies a tensor by (rank t)!, undoing 'norper'.
antinorper t = t `rawProduct` scalar (fromIntegral (fact (rank t)))

-- | Exterior (wedge) product: both factors are divided by the factorial of
-- their rank, multiplied, and the product is antisymmetrised.
wedge a b = antisym (norper a `rawProduct` norper b)

-- | Infix synonym for 'wedge'.
(/\) a b = wedge a b

-- | Norm induced by 'innerAT'.
normAT t = sqrt (innerAT t t)

-- | Inner product of two antisymmetric tensors: the 'dot' of their
-- component vectors divided by (rank t1)!.
innerAT t1 t2 = raw / scale
    where raw   = dot (ten t1) (ten t2)
          scale = fromIntegral (fact (rank t1))

-- | Factorial: the product of @[1..n]@, which is 1 for @n <= 0@.
fact n = foldl' (*) 1 [1..n]

-- | Levi-Civita tensor of dimension @n@: the antisymmetrised product of the
-- @n@ standard basis covectors (rows of @ident n@). Its covariant indices
-- are named "1", "2", ..., "n".
leviCivita n = antisym (foldl1 rawProduct covectors)
    where labels    = map show [1 :: Integer ..]
          covectors = zipWith tcovector labels (toRows (ident n))

-- | Dual of the exterior product of a list of tensors (presumably vectors).
-- Starting from the Levi-Civita tensor, each tensor is contracted in turn
-- on its first index against the current first index (see 'contractionF').
-- The dimension is the size of the first index of the first tensor.
dualV vs = foldl' contractionF levi vs
    where levi      = leviCivita dimension
          dimension = idxDim (head (dims (head vs)))

-- | Flips the variance of every index of a tensor and leaves the components
-- untouched: covariant becomes contravariant and vice versa. With a
-- Euclidean metric this raises or lowers all the indices.
raise (T d v) = T (map toggle d) v
    where toggle i = case idxType i of
              Covariant     -> i {idxType = Contravariant}
              Contravariant -> i {idxType = Covariant}
-- | Raises or lowers all the indices of a tensor with a given (inverse)
-- metric.
--
-- Not implemented yet. Using it fails with an explicit message rather than
-- an anonymous 'undefined'.
raiseWith = error "raiseWith: not implemented"

-- | Dual of a multivector. The Levi-Civita tensor of the same dimension is
-- multiplied by @t@, and its indices are contracted pairwise, in order,
-- against the indices of @t@. The result is scaled by 1 / (rank t)!.
dualMV t = contractions (rawProduct levi t) pairing `rawProduct` normaliser
    where
        n          = idxDim . head . dims $ t
        levi       = leviCivita n
        pairing    = zip (map idxName (dims levi)) (map idxName (dims t))
        normaliser = scalar (recip (fromIntegral (fact (rank t))))


-- | Lists the nonzero components of an antisymmetric tensor whose index
-- tuples are strictly increasing. Each entry is paired with its 1-based
-- index tuple.
niceAS t = [ (v, b) | b <- base, let v = component (map pred b), v /= 0.0 ]
    where
        r = length (dims t)
        n = idxDim . head . dims $ t
        -- all strictly increasing r-tuples drawn from [1..n]
        base = filter (\x -> x == nub x && x == sort x) (sequence (replicate r [1..n]))
        -- fixes the leading index repeatedly with the given 0-based values
        component ks = ten (foldl' sliceFirst t ks) `at` 0
        sliceFirst s i = part s (idxName (head (dims s)), i)

-- | Renames indices according to an association list from old to new names.
-- Every index whose name is a key is renamed, so repeated indices get the
-- same new name. Other indices are unchanged.
idxRename (T d v) l = T (map rename d) v
    where rename i = maybe i (\new -> i {idxName = new}) (lookup (idxName i) l)

-- | Renames the indices of a tensor positionally: the i-th index gets the
-- i-th name of @l@, so repeated indices may get different names. Indices
-- beyond the end of @l@ keep their current names. This keeps the index list
-- consistent with the (unchanged) component vector; zipWith alone would
-- silently drop them.
withIdx (T d v) l = T d' v
    where d' = zipWith f d l ++ drop (length l) d
          f i n = i {idxName=n}

-- | Pairs of equal index names shared by two tensors, in 'names' order.
desiredContractions2 t1 t2 = filter (uncurry (==)) [ (a, b) | a <- names t1, b <- names t2 ]

-- | Index names that occur twice in a tensor, i.e. its repeated indices
-- that should be contracted with 'contraction1c'.
--
-- Pairs are taken with @a < b@, so each repeated pair is reported once.
-- Listing both orders would make contractionsC contract the same name
-- twice, and the second attempt fails because the name is already gone.
desiredContractions1 t = [ n1 | (a,n1) <- x , (b,n2) <- x, a < b, n1==n2]
    where x = zip [0 :: Int ..] (names t)

-- | Contracts two tensors by index name. First, each tensor's own repeated
-- indices are contracted. Then the first name shared by both results is
-- contracted with 'contraction2', and the remaining shared names are
-- contracted within the product. With no shared names, the result is the
-- plain 'rawProduct'.
okContract t1 t2 =
    case desiredContractions2 t1r t2r of
        []           -> rawProduct t1r t2r
        (n1,n2):rest -> contractionsC (contraction2 t1r n1 t2r n2) (map fst rest)
  where
    t1r = selfContract t1
    t2r = selfContract t2
    selfContract t = contractionsC t (desiredContractions1 t)