diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-06-05 10:09:17 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-06-05 10:09:17 +0000 |
commit | 1fb4ea70c517050d3cbad75357a4fffbf5a40e7b (patch) | |
tree | 6107b470be9300297f9f08a4280fbf46faf6a862 | |
parent | 7430630fa0504296b796223e01cbd417b88650ef (diff) |
working on contraction
-rw-r--r-- | examples/pru.hs | 24 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 16 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 12 |
3 files changed, 34 insertions, 18 deletions
diff --git a/examples/pru.hs b/examples/pru.hs index d6dc5d4..bddc08f 100644 --- a/examples/pru.hs +++ b/examples/pru.hs | |||
@@ -55,24 +55,23 @@ rd = (2><2) | |||
55 | 55 | ||
56 | main = do | 56 | main = do |
57 | print $ r |=| rd | 57 | print $ r |=| rd |
58 | print $ foldl part t [("p",1),("q",0),("r",2)] | ||
58 | print $ foldl part t [("p",1),("r",2),("q",0)] | 59 | print $ foldl part t [("p",1),("r",2),("q",0)] |
60 | print $ foldl part t $ reverse [("p",1),("r",2),("q",0)] | ||
59 | 61 | ||
60 | t = T [(4,(Covariant,"p")),(2,(Covariant,"q")),(3,(Contravariant,"r"))] $ fromList [1..24::Double] | 62 | t = T [(4,(Covariant,"p")),(2,(Covariant,"q")),(3,(Contravariant,"r"))] $ fromList [1..24::Double] |
61 | 63 | ||
62 | 64 | ||
63 | findIdx name t = ((d1,d2),m) where | 65 | findIdx name t = ((d1,d2),m) where |
64 | (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) | 66 | (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) |
65 | c = product (map fst (tail d2)) | 67 | c = product (map fst d2) |
66 | m = matrixFromVector RowMajor c (ten t) | 68 | m = matrixFromVector RowMajor c (ten t) |
67 | 69 | ||
68 | 70 | putFirstIdx name t = (nd,m') | |
69 | putFirstIdx name t = | ||
70 | if null d1 | ||
71 | then (nd,m) | ||
72 | else (nd,m') | ||
73 | where ((d1,d2),m) = findIdx name t | 71 | where ((d1,d2),m) = findIdx name t |
74 | m' = trans $ matrixFromVector RowMajor (fst $ head d2) $ dat m | 72 | m' = matrixFromVector RowMajor c $ cdat $ trans m |
75 | nd = d2++d1 | 73 | nd = d2++d1 |
74 | c = dim (ten t) `div` (fst $ head d2) | ||
76 | 75 | ||
77 | part t (name,k) = if k<0 || k>=l | 76 | part t (name,k) = if k<0 || k>=l |
78 | then error $ "part "++show (name,k)++" out of range in "++show t | 77 | then error $ "part "++show (name,k)++" out of range in "++show t |
@@ -88,8 +87,13 @@ parts t name = map f (toRows m) | |||
88 | t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double] | 87 | t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double] |
89 | t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double] | 88 | t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double] |
90 | 89 | ||
91 | --contract1 t name1 name2 = map head $ zipWith drop [0..] (map (flip parts name2) (parts t name1)) | 90 | contract1 t name1 name2 = map head $ zipWith drop [0..] (map (flip parts name2) (parts t name1)) |
91 | |||
92 | sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) | ||
92 | 93 | ||
93 | --sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) | 94 | on f g = \x y -> f (g x) (g y) |
94 | 95 | ||
95 | on f g = \x y -> f (g x) (g y) \ No newline at end of file | 96 | contract t1 n1 t2 n2 = T (tail d1++tail d2) (cdat m) |
97 | where (d1,m1) = putFirstIdx n1 t1 | ||
98 | (d2,m2) = putFirstIdx n2 t2 | ||
99 | m = multiply RowMajor (trans m2) m1 \ No newline at end of file | ||
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 2c57c07..a6f7f0c 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -107,13 +107,17 @@ transdataR = transdataAux ctransR | |||
107 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) | 107 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) |
108 | transdataC = transdataAux ctransC | 108 | transdataC = transdataAux ctransC |
109 | 109 | ||
110 | transdataAux fun c1 d c2 = unsafePerformIO $ do | 110 | transdataAux fun c1 d c2 = |
111 | v <- createVector (dim d) | 111 | if noneed |
112 | let r1 = dim d `div` c1 | 112 | then d |
113 | else unsafePerformIO $ do | ||
114 | v <- createVector (dim d) | ||
115 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | ||
116 | putStrLn "---> transdataAux" | ||
117 | return v | ||
118 | where r1 = dim d `div` c1 | ||
113 | r2 = dim d `div` c2 | 119 | r2 = dim d `div` c2 |
114 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | 120 | noneed = r1 == 1 || c1 == 1 |
115 | --putStrLn "---> transdataAux" | ||
116 | return v | ||
117 | 121 | ||
118 | foreign import ccall safe "aux.h transR" | 122 | foreign import ccall safe "aux.h transR" |
119 | ctransR :: Double ::> Double ::> IO Int | 123 | ctransR :: Double ::> Double ::> IO Int |
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 11101a9..960e3c5 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -17,16 +17,24 @@ module Data.Packed.Internal.Tensor where | |||
17 | 17 | ||
18 | import Data.Packed.Internal.Vector | 18 | import Data.Packed.Internal.Vector |
19 | import Data.Packed.Internal.Matrix | 19 | import Data.Packed.Internal.Matrix |
20 | 20 | import Foreign.Storable | |
21 | 21 | ||
22 | data IdxTp = Covariant | Contravariant deriving Show | 22 | data IdxTp = Covariant | Contravariant deriving Show |
23 | 23 | ||
24 | data Tensor t = T { dims :: [(Int,(IdxTp,String))] | 24 | data Tensor t = T { dims :: [(Int,(IdxTp,String))] |
25 | , ten :: Vector t | 25 | , ten :: Vector t |
26 | } deriving Show | 26 | } |
27 | 27 | ||
28 | rank = length . dims | 28 | rank = length . dims |
29 | 29 | ||
30 | outer u v = dat (multiply RowMajor r c) | 30 | outer u v = dat (multiply RowMajor r c) |
31 | where r = matrixFromVector RowMajor 1 u | 31 | where r = matrixFromVector RowMajor 1 u |
32 | c = matrixFromVector RowMajor (dim v) v | 32 | c = matrixFromVector RowMajor (dim v) v |
33 | |||
34 | instance (Show a,Storable a) => Show (Tensor a) where | ||
35 | show T {dims = [], ten = t} = "scalar "++show (t `at` 0) | ||
36 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) | ||
37 | |||
38 | |||
39 | shdims [(n,(t,name))] = name++"["++show n++"]" | ||
40 | shdims (d:ds) = shdims [d] ++ "><"++ shdims ds \ No newline at end of file | ||