summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-06-05 10:09:17 +0000
committerAlberto Ruiz <aruiz@um.es>2007-06-05 10:09:17 +0000
commit1fb4ea70c517050d3cbad75357a4fffbf5a40e7b (patch)
tree6107b470be9300297f9f08a4280fbf46faf6a862 /lib/Data/Packed/Internal
parent7430630fa0504296b796223e01cbd417b88650ef (diff)
working on contraction
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs16
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs12
2 files changed, 20 insertions, 8 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 2c57c07..a6f7f0c 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -107,13 +107,17 @@ transdataR = transdataAux ctransR
107transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) 107transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double)
108transdataC = transdataAux ctransC 108transdataC = transdataAux ctransC
109 109
110transdataAux fun c1 d c2 = unsafePerformIO $ do 110transdataAux fun c1 d c2 =
111 v <- createVector (dim d) 111 if noneed
112 let r1 = dim d `div` c1 112 then d
113 else unsafePerformIO $ do
114 v <- createVector (dim d)
115 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d]
116 putStrLn "---> transdataAux"
117 return v
118 where r1 = dim d `div` c1
113 r2 = dim d `div` c2 119 r2 = dim d `div` c2
114 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] 120 noneed = r1 == 1 || c1 == 1
115 --putStrLn "---> transdataAux"
116 return v
117 121
118foreign import ccall safe "aux.h transR" 122foreign import ccall safe "aux.h transR"
119 ctransR :: Double ::> Double ::> IO Int 123 ctransR :: Double ::> Double ::> IO Int
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index 11101a9..960e3c5 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -17,16 +17,24 @@ module Data.Packed.Internal.Tensor where
17 17
18import Data.Packed.Internal.Vector 18import Data.Packed.Internal.Vector
19import Data.Packed.Internal.Matrix 19import Data.Packed.Internal.Matrix
20 20import Foreign.Storable
21 21
22data IdxTp = Covariant | Contravariant deriving Show 22data IdxTp = Covariant | Contravariant deriving Show
23 23
24data Tensor t = T { dims :: [(Int,(IdxTp,String))] 24data Tensor t = T { dims :: [(Int,(IdxTp,String))]
25 , ten :: Vector t 25 , ten :: Vector t
26 } deriving Show 26 }
27 27
28rank = length . dims 28rank = length . dims
29 29
30outer u v = dat (multiply RowMajor r c) 30outer u v = dat (multiply RowMajor r c)
31 where r = matrixFromVector RowMajor 1 u 31 where r = matrixFromVector RowMajor 1 u
32 c = matrixFromVector RowMajor (dim v) v 32 c = matrixFromVector RowMajor (dim v) v
33
34instance (Show a,Storable a) => Show (Tensor a) where
35 show T {dims = [], ten = t} = "scalar "++show (t `at` 0)
36 show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t)
37
38
39shdims [(n,(t,name))] = name++"["++show n++"]"
40shdims (d:ds) = shdims [d] ++ "><"++ shdims ds \ No newline at end of file