summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs32
-rw-r--r--lib/Data/Packed/Matrix.hs4
-rw-r--r--lib/Data/Packed/Tensor.hs82
-rw-r--r--lib/GSL/Vector.hs4
4 files changed, 106 insertions, 16 deletions
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index c4faf49..8296935 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -18,7 +18,8 @@ import Data.Packed.Internal.Common
18import Data.Packed.Internal.Vector 18import Data.Packed.Internal.Vector
19import Data.Packed.Internal.Matrix 19import Data.Packed.Internal.Matrix
20import Foreign.Storable 20import Foreign.Storable
21import Data.List(sort,elemIndex,nub) 21import Data.List(sort,elemIndex,nub,foldl1')
22import GSL.Vector
22 23
23data IdxType = Covariant | Contravariant deriving (Show,Eq) 24data IdxType = Covariant | Contravariant deriving (Show,Eq)
24 25
@@ -79,10 +80,10 @@ concatRename l1 l2 = l1 ++ map ren l2 where
79 ren idx = if {- s `elem` fs -} True then idx {idxName = idxName idx ++ "'"} else idx 80 ren idx = if {- s `elem` fs -} True then idx {idxName = idxName idx ++ "'"} else idx
80 fs = map idxName l1 81 fs = map idxName l1
81 82
82prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t 83--prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t
83prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer' v1 v2) 84prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer' v1 v2)
84 85
85contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t 86--contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
86contraction t1 n1 t2 n2 = 87contraction t1 n1 t2 n2 =
87 if compatIdx t1 n1 t2 n2 88 if compatIdx t1 n1 t2 n2
88 then T (concatRename (tail d1) (tail d2)) (cdat m) 89 then T (concatRename (tail d1) (tail d2)) (cdat m)
@@ -91,16 +92,27 @@ contraction t1 n1 t2 n2 =
91 (d2,m2) = putFirstIdx n2 t2 92 (d2,m2) = putFirstIdx n2 t2
92 m = multiply RowMajor (trans m1) m2 93 m = multiply RowMajor (trans m1) m2
93 94
94sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t] 95--sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t]
95sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) 96--sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls)
97--addT ts = T (dims (head ts)) (fromList $ sumT ts)
96 98
97contract1 :: (Num t, Enum t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t 99liftTensor f (T d v) = T d (f v)
98contract1 t name1 name2 = T d $ fromList $ sumT y 100
101liftTensor2 f (T d1 v1) (T d2 v2) | compat d1 d2 = T d1 (f v1 v2)
102 | otherwise = error "liftTensor2 with incompatible tensors"
103 where compat a b = length a == length b
104
105
106a |+| b = liftTensor2 add a b
107addT l = foldl1' (|+|) l
108
109--contract1 :: (Num t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t
110contract1 t name1 name2 = addT y
99 where d = dims (head y) 111 where d = dims (head y)
100 x = (map (flip parts name2) (parts t name1)) 112 x = (map (flip parts name2) (parts t name1))
101 y = map head $ zipWith drop [0..] x 113 y = map head $ zipWith drop [0..] x
102 114
103contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t 115--contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
104contraction' t1 n1 t2 n2 = 116contraction' t1 n1 t2 n2 =
105 if compatIdx t1 n1 t2 n2 117 if compatIdx t1 n1 t2 n2
106 then contract1 (prod t1 t2) n1 (n2++"'") 118 then contract1 (prod t1 t2) n1 (n2++"'")
@@ -130,8 +142,8 @@ names t = sort $ map idxName (dims t)
130normal :: (Field t) => Tensor t -> Tensor t 142normal :: (Field t) => Tensor t -> Tensor t
131normal t = tridx (names t) t 143normal t = tridx (names t) t
132 144
133contractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] 145possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t]
134contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] 146possibleContractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ]
135 147
136-- sent to Haskell-Cafe by Sebastian Sylvan 148-- sent to Haskell-Cafe by Sebastian Sylvan
137perms :: [t] -> [[t]] 149perms :: [t] -> [[t]]
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs
index 36bf32e..2033dc7 100644
--- a/lib/Data/Packed/Matrix.hs
+++ b/lib/Data/Packed/Matrix.hs
@@ -87,14 +87,14 @@ ident n = diag (constant 1 n)
87r >< c = f where 87r >< c = f where
88 f l | dim v == r*c = matrixFromVector RowMajor c v 88 f l | dim v == r*c = matrixFromVector RowMajor c v
89 | otherwise = error $ "inconsistent list size = " 89 | otherwise = error $ "inconsistent list size = "
90 ++show (dim v) ++"in ("++show r++"><"++show c++")" 90 ++show (dim v) ++" in ("++show r++"><"++show c++")"
91 where v = fromList l 91 where v = fromList l
92 92
93(>|<) :: (Field a) => Int -> Int -> [a] -> Matrix a 93(>|<) :: (Field a) => Int -> Int -> [a] -> Matrix a
94r >|< c = f where 94r >|< c = f where
95 f l | dim v == r*c = matrixFromVector ColumnMajor c v 95 f l | dim v == r*c = matrixFromVector ColumnMajor c v
96 | otherwise = error $ "inconsistent list size = " 96 | otherwise = error $ "inconsistent list size = "
97 ++show (dim v) ++"in ("++show r++"><"++show c++")" 97 ++show (dim v) ++" in ("++show r++"><"++show c++")"
98 where v = fromList l 98 where v = fromList l
99 99
100---------------------------------------------------------------- 100----------------------------------------------------------------
diff --git a/lib/Data/Packed/Tensor.hs b/lib/Data/Packed/Tensor.hs
index 75a9288..68ce9a5 100644
--- a/lib/Data/Packed/Tensor.hs
+++ b/lib/Data/Packed/Tensor.hs
@@ -12,9 +12,85 @@
12-- 12--
13----------------------------------------------------------------------------- 13-----------------------------------------------------------------------------
14 14
15module Data.Packed.Tensor ( 15module Data.Packed.Tensor where
16
17) where
18 16
17import Data.Packed.Matrix
19import Data.Packed.Internal 18import Data.Packed.Internal
20import Complex 19import Complex
20import Data.List(transpose,intersperse,sort,elemIndex,nub,foldl',foldl1')
21
22scalar x = T [] (fromList [x])
23tensorFromVector (tp,nm) v = T {dims = [IdxDesc (dim v) tp nm]
24 , ten = v}
25tensorFromMatrix (tpr,nmr) (tpc,nmc) m = T {dims = [IdxDesc (rows m) tpr nmr,IdxDesc (cols m) tpc nmc]
26 , ten = cdat m}
27
28scsig t = scalar (signature (nms t)) `prod` t
29 where nms = map idxName . dims
30
31antisym' t = addT $ map (scsig . flip tridx t) (perms (names t))
32
33
34auxrename (T d v) = T d' v
35 where d' = [IdxDesc n c (show (pos q)) | IdxDesc n c q <- d]
36 pos n = i where Just i = elemIndex n nms
37 nms = map idxName d
38
39antisym t = T (dims t) (ten (antisym' (auxrename t)))
40
41
42norper t = prod t (scalar (recip $ fromIntegral $ product [1 .. length (dims t)]))
43antinorper t = prod t (scalar (fromIntegral $ product [1 .. length (dims t)]))
44
45
46tvector n v = tensorFromVector (Contravariant,n) v
47tcovector n v = tensorFromVector (Covariant,n) v
48
49wedge a b = antisym (prod (norper a) (norper b))
50
51a /\ b = wedge a b
52
53a <*> b = normal $ prod a b
54
55normAT t = sqrt $ innerAT t t
56
57innerAT t1 t2 = dot (ten t1) (ten t2) / fromIntegral (fact $ length $ dims t1)
58
59fact n = product [1..n]
60
61leviCivita n = antisym $ foldl1 prod $ zipWith tcovector (map show [1,2..]) (toRows (ident n))
62
63contractionF t1 t2 = contraction t1 n1 t2 n2
64 where n1 = fn t1
65 n2 = fn t2
66 fn = idxName . head . dims
67
68
69dualV vs = foldl' contractionF (leviCivita n) vs
70 where n = idxDim . head . dims . head $ vs
71
72raise (T d v) = T (map raise' d) v
73 where raise' idx@IdxDesc {idxType = Covariant } = idx {idxType = Contravariant}
74 raise' idx@IdxDesc {idxType = Contravariant } = idx {idxType = Covariant}
75
76dualMV t = prod (foldl' contract1b (lc <*> t) ds) (scalar (recip $ fromIntegral $ fact (length ds)))
77 where
78 lc = leviCivita n
79 nms1 = map idxName (dims lc)
80 nms2 = map ((++"'").idxName) (dims t)
81 ds = zip nms1 nms2
82 n = idxDim . head . dims $ t
83
84contract1b t (n1,n2) = contract1 t n1 n2
85
86contractions t pairs = foldl' contract1b t pairs
87
88asBase r n = filter (\x-> (x==nub x && x==sort x)) $ sequence $ replicate r [1..n]
89
90partF t i = part t (name,i) where name = idxName . head . dims $ t
91
92niceAS t = filter ((/=0.0).fst) $ zip vals base
93 where vals = map ((`at` 0).ten.foldl' partF t) (map (map pred) base)
94 base = asBase r n
95 r = length (dims t)
96 n = idxDim . head . dims $ t
diff --git a/lib/GSL/Vector.hs b/lib/GSL/Vector.hs
index a074254..0b3c3a9 100644
--- a/lib/GSL/Vector.hs
+++ b/lib/GSL/Vector.hs
@@ -21,7 +21,9 @@ module GSL.Vector (
21 scale, addConstant, add, sub, mul, 21 scale, addConstant, add, sub, mul,
22) where 22) where
23 23
24import Data.Packed.Internal 24import Data.Packed.Internal.Common
25import Data.Packed.Internal.Vector
26
25import Complex 27import Complex
26import Foreign 28import Foreign
27 29