diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 16 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 12 |
2 files changed, 20 insertions, 8 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 2c57c07..a6f7f0c 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -107,13 +107,17 @@ transdataR = transdataAux ctransR | |||
107 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) | 107 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) |
108 | transdataC = transdataAux ctransC | 108 | transdataC = transdataAux ctransC |
109 | 109 | ||
110 | transdataAux fun c1 d c2 = unsafePerformIO $ do | 110 | transdataAux fun c1 d c2 = |
111 | v <- createVector (dim d) | 111 | if noneed |
112 | let r1 = dim d `div` c1 | 112 | then d |
113 | else unsafePerformIO $ do | ||
114 | v <- createVector (dim d) | ||
115 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | ||
116 | putStrLn "---> transdataAux" | ||
117 | return v | ||
118 | where r1 = dim d `div` c1 | ||
113 | r2 = dim d `div` c2 | 119 | r2 = dim d `div` c2 |
114 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | 120 | noneed = r1 == 1 || c1 == 1 |
115 | --putStrLn "---> transdataAux" | ||
116 | return v | ||
117 | 121 | ||
118 | foreign import ccall safe "aux.h transR" | 122 | foreign import ccall safe "aux.h transR" |
119 | ctransR :: Double ::> Double ::> IO Int | 123 | ctransR :: Double ::> Double ::> IO Int |
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 11101a9..960e3c5 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -17,16 +17,24 @@ module Data.Packed.Internal.Tensor where | |||
17 | 17 | ||
18 | import Data.Packed.Internal.Vector | 18 | import Data.Packed.Internal.Vector |
19 | import Data.Packed.Internal.Matrix | 19 | import Data.Packed.Internal.Matrix |
20 | 20 | import Foreign.Storable | |
21 | 21 | ||
22 | data IdxTp = Covariant | Contravariant deriving Show | 22 | data IdxTp = Covariant | Contravariant deriving Show |
23 | 23 | ||
24 | data Tensor t = T { dims :: [(Int,(IdxTp,String))] | 24 | data Tensor t = T { dims :: [(Int,(IdxTp,String))] |
25 | , ten :: Vector t | 25 | , ten :: Vector t |
26 | } deriving Show | 26 | } |
27 | 27 | ||
28 | rank = length . dims | 28 | rank = length . dims |
29 | 29 | ||
30 | outer u v = dat (multiply RowMajor r c) | 30 | outer u v = dat (multiply RowMajor r c) |
31 | where r = matrixFromVector RowMajor 1 u | 31 | where r = matrixFromVector RowMajor 1 u |
32 | c = matrixFromVector RowMajor (dim v) v | 32 | c = matrixFromVector RowMajor (dim v) v |
33 | |||
34 | instance (Show a,Storable a) => Show (Tensor a) where | ||
35 | show T {dims = [], ten = t} = "scalar "++show (t `at` 0) | ||
36 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) | ||
37 | |||
38 | |||
39 | shdims [(n,(t,name))] = name++"["++show n++"]" | ||
40 | shdims (d:ds) = shdims [d] ++ "><"++ shdims ds \ No newline at end of file | ||