summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-06-06 17:40:09 +0000
committerAlberto Ruiz <aruiz@um.es>2007-06-06 17:40:09 +0000
commite7c03c1ab4de85e7a700d2eafaebd37f4607c51f (patch)
tree4512d18907d88d0390671fcde4e8886d30cd0492
parenta4254a0b9bfbd720efbe42b86aa50107a74d56c7 (diff)
working on tensor contractions
-rw-r--r--examples/pru.hs76
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs17
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs73
-rw-r--r--lib/Data/Packed/Internal/Vector.hs2
4 files changed, 133 insertions, 35 deletions
diff --git a/examples/pru.hs b/examples/pru.hs
index bddc08f..a935d93 100644
--- a/examples/pru.hs
+++ b/examples/pru.hs
@@ -7,7 +7,7 @@ import Data.Packed.Internal.Tensor
7 7
8import Complex 8import Complex
9import Numeric(showGFloat) 9import Numeric(showGFloat)
10import Data.List(transpose,intersperse) 10import Data.List(transpose,intersperse,sort)
11import Foreign.Storable 11import Foreign.Storable
12 12
13r >< c = f where 13r >< c = f where
@@ -22,8 +22,6 @@ r >|< c = f where
22 ++show (dim v) ++"in ("++show r++"><"++show c++")" 22 ++show (dim v) ++"in ("++show r++"><"++show c++")"
23 where v = fromList l 23 where v = fromList l
24 24
25
26
27vr = fromList [1..15::Double] 25vr = fromList [1..15::Double]
28vc = fromList (map (\x->x :+ (x+1)) [1..15::Double]) 26vc = fromList (map (\x->x :+ (x+1)) [1..15::Double])
29 27
@@ -62,38 +60,56 @@ main = do
62t = T [(4,(Covariant,"p")),(2,(Covariant,"q")),(3,(Contravariant,"r"))] $ fromList [1..24::Double] 60t = T [(4,(Covariant,"p")),(2,(Covariant,"q")),(3,(Contravariant,"r"))] $ fromList [1..24::Double]
63 61
64 62
65findIdx name t = ((d1,d2),m) where
66 (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t)
67 c = product (map fst d2)
68 m = matrixFromVector RowMajor c (ten t)
69 63
70putFirstIdx name t = (nd,m') 64t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double]
71 where ((d1,d2),m) = findIdx name t 65t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double]
72 m' = matrixFromVector RowMajor c $ cdat $ trans m
73 nd = d2++d1
74 c = dim (ten t) `div` (fst $ head d2)
75 66
76part t (name,k) = if k<0 || k>=l
77 then error $ "part "++show (name,k)++" out of range in "++show t
78 else T {dims = ds, ten = toRows m !! k}
79 where (d:ds,m) = putFirstIdx name t
80 (l,_) = d
81 67
82parts t name = map f (toRows m)
83 where (d:ds,m) = putFirstIdx name t
84 (l,_) = d
85 f t = T {dims=ds, ten=t}
86 68
87t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double]
88t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double]
89 69
90contract1 t name1 name2 = map head $ zipWith drop [0..] (map (flip parts name2) (parts t name1))
91 70
92sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls)
93 71
94on f g = \x y -> f (g x) (g y) 72delta i j | i==j = 1
73 | otherwise = 0
74
75e i n = fromList [ delta k i | k <- [1..n]]
76
77ident n = fromRows [ e i n | i <- [1..n]]
78
79diag l = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]]
80 where c = length l
81
82tensorFromVector idx v = T {dims = [(dim v,idx)], ten = v}
83tensorFromMatrix idxr idxc m = T {dims = [(rows m,idxr),(cols m,idxc)], ten = cdat m}
84
85td = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ diag [1..4] :: Tensor Double
86
87tn = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double
88tt = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double
89
90tq = T [(3,(Covariant,"p")),(2,(Covariant,"q")),(2,(Covariant,"r"))] $ fromList [11 .. 22] :: Tensor Double
91
92r1 = contraction tt "j" tq "p"
93r1' = contraction' tt "j" tq "p"
94
95pru = do
96 mapM_ (putStrLn.shdims.dims.normal) (contractions t1 t2)
97 let t1 = contraction tt "i" tq "q"
98 print $ normal t1
99 print $ foldl part t1 [("j",0),("p'",1),("r'",1)]
100 let t2 = contraction' tt "i" tq "q"
101 print $ normal t2
102 print $ foldl part t2 [("j",0),("p'",1),("r'",1)]
103 let t1 = contraction tq "q" tt "i"
104 print $ normal t1
105 print $ foldl part t1 [("j'",0),("p",1),("r",1)]
106 let t2 = contraction' tq "q" tt "i"
107 print $ normal t2
108 print $ foldl part t2 [("j'",0),("p",1),("r",1)]
109
110
111names t = sort $ map (snd.snd) (dims t)
112
113normal t = tridx (names t) t
95 114
96contract t1 n1 t2 n2 = T (tail d1++tail d2) (cdat m) 115contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ]
97 where (d1,m1) = putFirstIdx n1 t1
98 (d2,m2) = putFirstIdx n2 t2
99 m = multiply RowMajor (trans m2) m1 \ No newline at end of file
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index a2a70dd..ec6657a 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -65,6 +65,15 @@ partit :: Int -> [a] -> [[a]]
65partit _ [] = [] 65partit _ [] = []
66partit n l = take n l : partit n (drop n l) 66partit n l = take n l : partit n (drop n l)
67 67
68-- | obtains the common value of a property of a list
69common :: (Eq a) => (b->a) -> [b] -> Maybe a
70common f = commonval . map f where
71 commonval :: (Eq a) => [a] -> Maybe a
72 commonval [] = Nothing
73 commonval [a] = Just a
74 commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing
75
76
68toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m 77toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m
69 | otherwise = partit (cols m) . toList . dat $ m 78 | otherwise = partit (cols m) . toList . dat $ m
70 79
@@ -115,7 +124,7 @@ transdataAux fun c1 d c2 =
115 else unsafePerformIO $ do 124 else unsafePerformIO $ do
116 v <- createVector (dim d) 125 v <- createVector (dim d)
117 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] 126 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d]
118 putStrLn "---> transdataAux" 127 --putStrLn "---> transdataAux"
119 return v 128 return v
120 where r1 = dim d `div` c1 129 where r1 = dim d `div` c1
121 r2 = dim d `div` c2 130 r2 = dim d `div` c2
@@ -136,6 +145,12 @@ transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2
136--{-# RULES "transdataR" transdata=transdataR #-} 145--{-# RULES "transdataR" transdata=transdataR #-}
137--{-# RULES "transdataC" transdata=transdataC #-} 146--{-# RULES "transdataC" transdata=transdataC #-}
138 147
148-- | creates a Matrix from a list of vectors
149fromRows :: Field t => [Vector t] -> Matrix t
150fromRows vs = case common dim vs of
151 Nothing -> error "fromRows applied to [] or to vectors with different sizes"
152 Just c -> reshape c (join vs)
153
139-- | extracts the rows of a matrix as a list of vectors 154-- | extracts the rows of a matrix as a list of vectors
140toRows :: Storable t => Matrix t -> [Vector t] 155toRows :: Storable t => Matrix t -> [Vector t]
141toRows m = toRows' 0 where 156toRows m = toRows' 0 where
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index 960e3c5..b66d6b8 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -1,4 +1,4 @@
1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} 1--{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
2----------------------------------------------------------------------------- 2-----------------------------------------------------------------------------
3-- | 3-- |
4-- Module : Data.Packed.Internal.Tensor 4-- Module : Data.Packed.Internal.Tensor
@@ -19,7 +19,7 @@ import Data.Packed.Internal.Vector
19import Data.Packed.Internal.Matrix 19import Data.Packed.Internal.Matrix
20import Foreign.Storable 20import Foreign.Storable
21 21
22data IdxTp = Covariant | Contravariant deriving Show 22data IdxTp = Covariant | Contravariant deriving (Show,Eq)
23 23
24data Tensor t = T { dims :: [(Int,(IdxTp,String))] 24data Tensor t = T { dims :: [(Int,(IdxTp,String))]
25 , ten :: Vector t 25 , ten :: Vector t
@@ -36,5 +36,70 @@ instance (Show a,Storable a) => Show (Tensor a) where
36 show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) 36 show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t)
37 37
38 38
39shdims [(n,(t,name))] = name++"["++show n++"]" 39shdims [(n,(t,name))] = name ++ sym t ++"["++show n++"]"
40shdims (d:ds) = shdims [d] ++ "><"++ shdims ds \ No newline at end of file 40 where sym Covariant = "_"
41 sym Contravariant = "^"
42shdims (d:ds) = shdims [d] ++ "><"++ shdims ds
43
44
45
46findIdx name t = ((d1,d2),m) where
47 (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t)
48 c = product (map fst d2)
49 m = matrixFromVector RowMajor c (ten t)
50
51putFirstIdx name t = (nd,m')
52 where ((d1,d2),m) = findIdx name t
53 m' = matrixFromVector RowMajor c $ cdat $ trans m
54 nd = d2++d1
55 c = dim (ten t) `div` (fst $ head d2)
56
57part t (name,k) = if k<0 || k>=l
58 then error $ "part "++show (name,k)++" out of range in "++show t
59 else T {dims = ds, ten = toRows m !! k}
60 where (d:ds,m) = putFirstIdx name t
61 (l,_) = d
62
63parts t name = map f (toRows m)
64 where (d:ds,m) = putFirstIdx name t
65 (l,_) = d
66 f t = T {dims=ds, ten=t}
67
68concatRename l1 l2 = l1 ++ map ren l2 where
69 ren (n,(t,s)) = if {- s `elem` fs -} True then (n,(t,s++"'")) else (n,(t,s))
70 fs = map (snd.snd) l1
71
72prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer v1 v2)
73
74contraction t1 n1 t2 n2 =
75 if compatIdx t1 n1 t2 n2
76 then T (concatRename (tail d1) (tail d2)) (cdat m)
77 else error "wrong contraction'"
78 where (d1,m1) = putFirstIdx n1 t1
79 (d2,m2) = putFirstIdx n2 t2
80 m = multiply RowMajor (trans m1) m2
81
82sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls)
83
84contract1 t name1 name2 = T d $ fromList $ sumT y
85 where d = dims (head y)
86 x = (map (flip parts name2) (parts t name1))
87 y = map head $ zipWith drop [0..] x
88
89contraction' t1 n1 t2 n2 =
90 if compatIdx t1 n1 t2 n2
91 then contract1 (prod t1 t2) n1 (n2++"'")
92 else error "wrong contraction'"
93
94tridx [] t = t
95tridx (name:rest) t = T (d:ds) (join ts) where
96 ((_,d:_),_) = findIdx name t
97 ps = map (tridx rest) (parts t name)
98 ts = map ten ps
99 ds = dims (head ps)
100
101compatIdxAux (n1,(t1,_)) (n2, (t2,_)) = t1 /= t2 && n1 == n2
102
103compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where
104 d1 = head $ snd $ fst $ findIdx n1 t1
105 d2 = head $ snd $ fst $ findIdx n2 t2
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index 6ed9339..36d5df7 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -35,6 +35,8 @@ instance (Storable a, RealFloat a) => Storable (Complex a) where --
35 poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- 35 poke p (a :+ b) = pokeArray (castPtr p) [a,b] --
36---------------------------------------------------------------------- 36----------------------------------------------------------------------
37 37
38on f g = \x y -> f (g x) (g y)
39
38(//) :: x -> (x -> y) -> y 40(//) :: x -> (x -> y) -> y
39infixl 0 // 41infixl 0 //
40(//) = flip ($) 42(//) = flip ($)