summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-06-05 10:09:17 +0000
committerAlberto Ruiz <aruiz@um.es>2007-06-05 10:09:17 +0000
commit1fb4ea70c517050d3cbad75357a4fffbf5a40e7b (patch)
tree6107b470be9300297f9f08a4280fbf46faf6a862
parent7430630fa0504296b796223e01cbd417b88650ef (diff)
working on contraction
-rw-r--r--examples/pru.hs24
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs16
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs12
3 files changed, 34 insertions, 18 deletions
diff --git a/examples/pru.hs b/examples/pru.hs
index d6dc5d4..bddc08f 100644
--- a/examples/pru.hs
+++ b/examples/pru.hs
@@ -55,24 +55,23 @@ rd = (2><2)
55 55
56main = do 56main = do
57 print $ r |=| rd 57 print $ r |=| rd
58 print $ foldl part t [("p",1),("q",0),("r",2)]
58 print $ foldl part t [("p",1),("r",2),("q",0)] 59 print $ foldl part t [("p",1),("r",2),("q",0)]
60 print $ foldl part t $ reverse [("p",1),("r",2),("q",0)]
59 61
60t = T [(4,(Covariant,"p")),(2,(Covariant,"q")),(3,(Contravariant,"r"))] $ fromList [1..24::Double] 62t = T [(4,(Covariant,"p")),(2,(Covariant,"q")),(3,(Contravariant,"r"))] $ fromList [1..24::Double]
61 63
62 64
63findIdx name t = ((d1,d2),m) where 65findIdx name t = ((d1,d2),m) where
64 (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) 66 (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t)
65 c = product (map fst (tail d2)) 67 c = product (map fst d2)
66 m = matrixFromVector RowMajor c (ten t) 68 m = matrixFromVector RowMajor c (ten t)
67 69
68 70putFirstIdx name t = (nd,m')
69putFirstIdx name t =
70 if null d1
71 then (nd,m)
72 else (nd,m')
73 where ((d1,d2),m) = findIdx name t 71 where ((d1,d2),m) = findIdx name t
74 m' = trans $ matrixFromVector RowMajor (fst $ head d2) $ dat m 72 m' = matrixFromVector RowMajor c $ cdat $ trans m
75 nd = d2++d1 73 nd = d2++d1
74 c = dim (ten t) `div` (fst $ head d2)
76 75
77part t (name,k) = if k<0 || k>=l 76part t (name,k) = if k<0 || k>=l
78 then error $ "part "++show (name,k)++" out of range in "++show t 77 then error $ "part "++show (name,k)++" out of range in "++show t
@@ -88,8 +87,13 @@ parts t name = map f (toRows m)
88t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double] 87t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double]
89t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double] 88t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double]
90 89
91--contract1 t name1 name2 = map head $ zipWith drop [0..] (map (flip parts name2) (parts t name1)) 90contract1 t name1 name2 = map head $ zipWith drop [0..] (map (flip parts name2) (parts t name1))
91
92sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls)
92 93
93--sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) 94on f g = \x y -> f (g x) (g y)
94 95
95on f g = \x y -> f (g x) (g y) \ No newline at end of file 96contract t1 n1 t2 n2 = T (tail d1++tail d2) (cdat m)
97 where (d1,m1) = putFirstIdx n1 t1
98 (d2,m2) = putFirstIdx n2 t2
99 m = multiply RowMajor (trans m2) m1 \ No newline at end of file
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 2c57c07..a6f7f0c 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -107,13 +107,17 @@ transdataR = transdataAux ctransR
107transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) 107transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double)
108transdataC = transdataAux ctransC 108transdataC = transdataAux ctransC
109 109
110transdataAux fun c1 d c2 = unsafePerformIO $ do 110transdataAux fun c1 d c2 =
111 v <- createVector (dim d) 111 if noneed
112 let r1 = dim d `div` c1 112 then d
113 else unsafePerformIO $ do
114 v <- createVector (dim d)
115 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d]
116 putStrLn "---> transdataAux"
117 return v
118 where r1 = dim d `div` c1
113 r2 = dim d `div` c2 119 r2 = dim d `div` c2
114 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] 120 noneed = r1 == 1 || c1 == 1
115 --putStrLn "---> transdataAux"
116 return v
117 121
118foreign import ccall safe "aux.h transR" 122foreign import ccall safe "aux.h transR"
119 ctransR :: Double ::> Double ::> IO Int 123 ctransR :: Double ::> Double ::> IO Int
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index 11101a9..960e3c5 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -17,16 +17,24 @@ module Data.Packed.Internal.Tensor where
17 17
18import Data.Packed.Internal.Vector 18import Data.Packed.Internal.Vector
19import Data.Packed.Internal.Matrix 19import Data.Packed.Internal.Matrix
20 20import Foreign.Storable
21 21
22data IdxTp = Covariant | Contravariant deriving Show 22data IdxTp = Covariant | Contravariant deriving Show
23 23
24data Tensor t = T { dims :: [(Int,(IdxTp,String))] 24data Tensor t = T { dims :: [(Int,(IdxTp,String))]
25 , ten :: Vector t 25 , ten :: Vector t
26 } deriving Show 26 }
27 27
28rank = length . dims 28rank = length . dims
29 29
30outer u v = dat (multiply RowMajor r c) 30outer u v = dat (multiply RowMajor r c)
31 where r = matrixFromVector RowMajor 1 u 31 where r = matrixFromVector RowMajor 1 u
32 c = matrixFromVector RowMajor (dim v) v 32 c = matrixFromVector RowMajor (dim v) v
33
34instance (Show a,Storable a) => Show (Tensor a) where
35 show T {dims = [], ten = t} = "scalar "++show (t `at` 0)
36 show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t)
37
38
39shdims [(n,(t,name))] = name++"["++show n++"]"
40shdims (d:ds) = shdims [d] ++ "><"++ shdims ds \ No newline at end of file