diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-06-04 08:34:45 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-06-04 08:34:45 +0000 |
commit | 0a9817cc481fb09f1962eb2c272125e56a123814 (patch) | |
tree | e444abd9f1918e9a25e2b99f6c8498d0f03fcdf3 /lib/Data | |
parent | 80673221e704b451e0d9468d6dfe1a38ad676c07 (diff) |
fortran/C
Diffstat (limited to 'lib/Data')
-rw-r--r-- | lib/Data/Packed/Internal.hs | 286 | ||||
-rw-r--r-- | lib/Data/Packed/aux.c | 98 | ||||
-rw-r--r-- | lib/Data/Packed/aux.h | 25 |
3 files changed, 312 insertions, 97 deletions
diff --git a/lib/Data/Packed/Internal.hs b/lib/Data/Packed/Internal.hs index 5e19e58..b06f044 100644 --- a/lib/Data/Packed/Internal.hs +++ b/lib/Data/Packed/Internal.hs | |||
@@ -1,3 +1,4 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} | ||
1 | ----------------------------------------------------------------------------- | 2 | ----------------------------------------------------------------------------- |
2 | -- | | 3 | -- | |
3 | -- Module : Data.Packed.Internal | 4 | -- Module : Data.Packed.Internal |
@@ -14,39 +15,16 @@ | |||
14 | 15 | ||
15 | module Data.Packed.Internal where | 16 | module Data.Packed.Internal where |
16 | 17 | ||
17 | import Foreign | 18 | import Foreign hiding (xor) |
18 | import Complex | 19 | import Complex |
19 | import Control.Monad(when) | 20 | import Control.Monad(when) |
20 | import Debug.Trace | 21 | import Debug.Trace |
22 | import Data.List(transpose,intersperse) | ||
23 | import Data.Typeable | ||
24 | import Data.Maybe(fromJust) | ||
21 | 25 | ||
22 | debug x = trace (show x) x | 26 | debug x = trace (show x) x |
23 | 27 | ||
24 | -- | 1D array | ||
25 | data Vector t = V { dim :: Int | ||
26 | , fptr :: ForeignPtr t | ||
27 | , ptr :: Ptr t | ||
28 | } | ||
29 | |||
30 | data TransMode = NoTrans | Trans | ConjTrans | ||
31 | |||
32 | -- | 2D array | ||
33 | data Matrix t = M { rows :: Int | ||
34 | , cols :: Int | ||
35 | , mat :: Vector t | ||
36 | , trMode :: TransMode | ||
37 | , isCOrder :: Bool | ||
38 | } | ||
39 | |||
40 | data IdxTp = Covariant | Contravariant | ||
41 | |||
42 | -- | multidimensional array | ||
43 | data Tensor t = T { rank :: Int | ||
44 | , dims :: [Int] | ||
45 | , idxNm :: [String] | ||
46 | , idxTp :: [IdxTp] | ||
47 | , ten :: Vector t | ||
48 | } | ||
49 | |||
50 | ---------------------------------------------------------------------- | 28 | ---------------------------------------------------------------------- |
51 | instance (Storable a, RealFloat a) => Storable (Complex a) where -- | 29 | instance (Storable a, RealFloat a) => Storable (Complex a) where -- |
52 | alignment x = alignment (realPart x) -- | 30 | alignment x = alignment (realPart x) -- |
@@ -57,36 +35,36 @@ instance (Storable a, RealFloat a) => Storable (Complex a) where -- | |||
57 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- | 35 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- |
58 | ---------------------------------------------------------------------- | 36 | ---------------------------------------------------------------------- |
59 | 37 | ||
60 | |||
61 | -- f // vec a // vec b // vec res // check "vector add" [a,b] | ||
62 | |||
63 | (//) :: x -> (x -> y) -> y | 38 | (//) :: x -> (x -> y) -> y |
64 | infixl 0 // | 39 | infixl 0 // |
65 | (//) = flip ($) | 40 | (//) = flip ($) |
66 | 41 | ||
67 | vec :: Vector a -> (Int -> Ptr b -> t) -> t | ||
68 | vec v f = f (dim v) (castPtr $ ptr v) | ||
69 | |||
70 | mata :: Matrix a -> (Int-> Int -> Ptr b -> t) -> t | ||
71 | mata m f = f (rows m) (cols m) (castPtr $ ptr (mat m)) | ||
72 | |||
73 | pd2pc :: Ptr Double -> Ptr (Complex (Double)) | ||
74 | pd2pc = castPtr | ||
75 | |||
76 | pc2pd :: Ptr (Complex (Double)) -> Ptr Double | ||
77 | pc2pd = castPtr | ||
78 | |||
79 | check msg ls f = do | 42 | check msg ls f = do |
80 | err <- f | 43 | err <- f |
81 | when (err/=0) (error msg) | 44 | when (err/=0) (error msg) |
82 | mapM_ (touchForeignPtr . fptr) ls | 45 | mapM_ (touchForeignPtr . fptr) ls |
83 | return () | 46 | return () |
84 | 47 | ||
48 | ---------------------------------------------------------------------- | ||
49 | |||
50 | data Vector t = V { dim :: Int | ||
51 | , fptr :: ForeignPtr t | ||
52 | , ptr :: Ptr t | ||
53 | } deriving Typeable | ||
54 | |||
55 | type Vc t s = Int -> Ptr t -> s | ||
56 | infixr 5 :> | ||
57 | type t :> s = Vc t s | ||
58 | |||
59 | vec :: Vector t -> (Vc t s) -> s | ||
60 | vec v f = f (dim v) (ptr v) | ||
61 | |||
85 | createVector :: Storable a => Int -> IO (Vector a) | 62 | createVector :: Storable a => Int -> IO (Vector a) |
86 | createVector n = do | 63 | createVector n = do |
87 | when (n <= 0) $ error ("trying to createVector of dim "++show n) | 64 | when (n <= 0) $ error ("trying to createVector of dim "++show n) |
88 | fp <- mallocForeignPtrArray n | 65 | fp <- mallocForeignPtrArray n |
89 | let p = unsafeForeignPtrToPtr fp | 66 | let p = unsafeForeignPtrToPtr fp |
67 | --putStrLn ("\n---------> V"++show n) | ||
90 | return $ V n fp p | 68 | return $ V n fp p |
91 | 69 | ||
92 | fromList :: Storable a => [a] -> Vector a | 70 | fromList :: Storable a => [a] -> Vector a |
@@ -99,6 +77,8 @@ fromList l = unsafePerformIO $ do | |||
99 | toList :: Storable a => Vector a -> [a] | 77 | toList :: Storable a => Vector a -> [a] |
100 | toList v = unsafePerformIO $ peekArray (dim v) (ptr v) | 78 | toList v = unsafePerformIO $ peekArray (dim v) (ptr v) |
101 | 79 | ||
80 | n # l = if length l == n then fromList l else error "# with wrong size" | ||
81 | |||
102 | at' :: Storable a => Vector a -> Int -> a | 82 | at' :: Storable a => Vector a -> Int -> a |
103 | at' v n = unsafePerformIO $ peekElemOff (ptr v) n | 83 | at' v n = unsafePerformIO $ peekElemOff (ptr v) n |
104 | 84 | ||
@@ -106,42 +86,208 @@ at :: Storable a => Vector a -> Int -> a | |||
106 | at v n | n >= 0 && n < dim v = at' v n | 86 | at v n | n >= 0 && n < dim v = at' v n |
107 | | otherwise = error "vector index out of range" | 87 | | otherwise = error "vector index out of range" |
108 | 88 | ||
109 | dsv v = sizeOf (v `at` 0) | 89 | instance (Show a, Storable a) => (Show (Vector a)) where |
110 | dsm m = (dsv.mat) m | 90 | show v = (show (dim v))++" # " ++ show (toList v) |
111 | 91 | ||
112 | constant :: Storable a => Int -> a -> Vector a | 92 | ------------------------------------------------------------------------ |
113 | constant n x = unsafePerformIO $ do | ||
114 | v <- createVector n | ||
115 | let f k p | k == n = return 0 | ||
116 | | otherwise = pokeElemOff p k x >> f (k+1) p | ||
117 | const (f 0) // vec v // check "constant" [] | ||
118 | return v | ||
119 | 93 | ||
120 | instance (Show a, Storable a) => (Show (Vector a)) where | 94 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) |
121 | show v = "fromList " ++ show (toList v) | 95 | |
96 | -- | 2D array | ||
97 | data Matrix t = M { rows :: Int | ||
98 | , cols :: Int | ||
99 | , cmat :: Vector t | ||
100 | , fmat :: Vector t | ||
101 | , isTrans :: Bool | ||
102 | , order :: MatrixOrder | ||
103 | } deriving Typeable | ||
104 | |||
105 | xor a b = a && not b || b && not a | ||
106 | |||
107 | fortran m = order m == ColumnMajor | ||
108 | |||
109 | dat m = if fortran m `xor` isTrans m then fmat m else cmat m | ||
110 | |||
111 | pref m = if fortran m then fmat m else cmat m | ||
112 | |||
113 | trans m = m { rows = cols m | ||
114 | , cols = rows m | ||
115 | , isTrans = not (isTrans m) | ||
116 | } | ||
117 | |||
118 | type Mt t s = Int -> Int -> Ptr t -> s | ||
119 | infixr 6 ::> | ||
120 | type t ::> s = Mt t s | ||
121 | |||
122 | mat :: Matrix t -> (Mt t s) -> s | ||
123 | mat m f = f (rows m) (cols m) (ptr (dat m)) | ||
124 | |||
125 | gmat m f | fortran m = | ||
126 | if (isTrans m) | ||
127 | then f 0 (rows m) (cols m) (ptr (fmat m)) | ||
128 | else f 1 (cols m) (rows m) (ptr (fmat m)) | ||
129 | | otherwise = | ||
130 | if isTrans m | ||
131 | then f 1 (cols m) (rows m) (ptr (cmat m)) | ||
132 | else f 0 (rows m) (cols m) (ptr (cmat m)) | ||
122 | 133 | ||
123 | instance (Show a, Storable a) => (Show (Matrix a)) where | 134 | instance (Show a, Storable a) => (Show (Matrix a)) where |
124 | show m = "reshape "++show (cols m) ++ " $ fromList " ++ show (toList (mat m)) | 135 | show m = (sizes++) . dsp . map (map show) . toLists $ m |
136 | where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" | ||
137 | |||
138 | partit :: Int -> [a] -> [[a]] | ||
139 | partit _ [] = [] | ||
140 | partit n l = take n l : partit n (drop n l) | ||
141 | |||
142 | toLists m = partit (cols m) . toList . cmat $ m | ||
125 | 143 | ||
126 | reshape :: Storable a => Int -> Vector a -> Matrix a | 144 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp |
127 | reshape n v = M { rows = dim v `div` n | 145 | where |
128 | , cols = n | 146 | mt = transpose as |
129 | , mat = v | 147 | longs = map (maximum . map length) mt |
130 | , trMode = NoTrans | 148 | mtp = zipWith (\a b -> map (pad a) b) longs mt |
131 | , isCOrder = True | 149 | pad n str = replicate (n - length str) ' ' ++ str |
132 | } | 150 | unwords' = concat . intersperse ", " |
133 | 151 | ||
134 | createMatrix r c = do | 152 | matrixFromVector RowMajor c v = |
153 | M { rows = r | ||
154 | , cols = c | ||
155 | , cmat = v | ||
156 | , fmat = transdata c v r | ||
157 | , order = RowMajor | ||
158 | , isTrans = False | ||
159 | } where r = dim v `div` c -- TODO check mod=0 | ||
160 | |||
161 | matrixFromVector ColumnMajor c v = | ||
162 | M { rows = r | ||
163 | , cols = c | ||
164 | , fmat = v | ||
165 | , cmat = transdata c v r | ||
166 | , order = ColumnMajor | ||
167 | , isTrans = False | ||
168 | } where r = dim v `div` c -- TODO check mod=0 | ||
169 | |||
170 | createMatrix order r c = do | ||
135 | p <- createVector (r*c) | 171 | p <- createVector (r*c) |
136 | return (reshape c p) | 172 | return (matrixFromVector order c p) |
173 | |||
174 | transdataG :: Storable a => Int -> Vector a -> Int -> Vector a | ||
175 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | ||
176 | |||
177 | transdataR :: Int -> Vector Double -> Int -> Vector Double | ||
178 | transdataR = transdataAux ctransR | ||
179 | |||
180 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) | ||
181 | transdataC = transdataAux ctransC | ||
182 | |||
183 | transdataAux fun c1 d c2 = unsafePerformIO $ do | ||
184 | v <- createVector (dim d) | ||
185 | let r1 = dim d `div` c1 | ||
186 | r2 = dim d `div` c2 | ||
187 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | ||
188 | --putStrLn "---> transdataAux" | ||
189 | return v | ||
190 | |||
191 | foreign import ccall safe "aux.h transR" | ||
192 | ctransR :: Double ::> Double ::> IO Int | ||
193 | foreign import ccall safe "aux.h transC" | ||
194 | ctransC :: Complex Double ::> Complex Double ::> IO Int | ||
195 | |||
196 | |||
197 | class (Storable a, Typeable a) => Field a where | ||
198 | instance (Storable a, Typeable a) => Field a where | ||
199 | |||
200 | isReal w x = typeOf (undefined :: Double) == typeOf (w x) | ||
201 | isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) | ||
202 | baseOf v = (v `at` 0) | ||
203 | |||
204 | scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b | ||
205 | scast = fromJust . cast | ||
206 | |||
207 | transdata :: Field a => Int -> Vector a -> Int -> Vector a | ||
208 | transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 | ||
209 | | isComp baseOf d = scast $ transdataC c1 (scast d) c2 | ||
210 | | otherwise = transdataG c1 d c2 | ||
211 | |||
212 | --transdata :: Storable a => Int -> Vector a -> Int -> Vector a | ||
213 | --transdata = transdataG | ||
214 | --{-# RULES "transdataR" transdata=transdataR #-} | ||
215 | --{-# RULES "transdataC" transdata=transdataC #-} | ||
216 | |||
217 | ------------------------------------------------------------------ | ||
218 | |||
219 | constantG n x = fromList (replicate n x) | ||
220 | |||
221 | constantR :: Int -> Double -> Vector Double | ||
222 | constantR = constantAux cconstantR | ||
223 | |||
224 | constantC :: Int -> Complex Double -> Vector (Complex Double) | ||
225 | constantC = constantAux cconstantC | ||
226 | |||
227 | constantAux fun n x = unsafePerformIO $ do | ||
228 | v <- createVector n | ||
229 | px <- newArray [x] | ||
230 | fun px // vec v // check "constantAux" [] | ||
231 | free px | ||
232 | return v | ||
137 | 233 | ||
138 | type CMat s = Int -> Int -> Ptr Double -> s | 234 | foreign import ccall safe "aux.h constantR" |
139 | type CVec s = Int -> Ptr Double -> s | 235 | cconstantR :: Ptr Double -> Double :> IO Int |
140 | 236 | ||
141 | foreign import ccall safe "aux.h trans" ctrans :: Int -> CMat (CMat (IO Int)) | 237 | foreign import ccall safe "aux.h constantC" |
238 | cconstantC :: Ptr (Complex Double) -> Complex Double :> IO Int | ||
142 | 239 | ||
143 | trans :: Storable a => Matrix a -> Matrix a | 240 | constant :: Field a => Int -> a -> Vector a |
144 | trans m = unsafePerformIO $ do | 241 | constant n x | isReal id x = scast $ constantR n (scast x) |
145 | r <- createMatrix (cols m) (rows m) | 242 | | isComp id x = scast $ constantC n (scast x) |
146 | ctrans (dsm m) // mata m // mata r // check "trans" [mat m] | 243 | | otherwise = constantG n x |
244 | |||
245 | ------------------------------------------------------------------ | ||
246 | |||
247 | dotL a b = sum (zipWith (*) a b) | ||
248 | |||
249 | multiplyL a b = [[dotL x y | y <- transpose b] | x <- a] | ||
250 | |||
251 | transL m = m {rows = cols m, cols = rows m, cmat = v, fmat = cmat m} | ||
252 | where v = transdataG (cols m) (cmat m) (rows m) | ||
253 | |||
254 | ------------------------------------------------------------------ | ||
255 | |||
256 | multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) | ||
257 | |||
258 | multiplyAux order fun a b = unsafePerformIO $ do | ||
259 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | ||
260 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | ||
261 | r <- createMatrix order (rows a) (cols b) | ||
262 | fun // gmat a // gmat b // mat r // check "multiplyAux" [pref a, pref b] | ||
147 | return r | 263 | return r |
264 | |||
265 | foreign import ccall safe "aux.h multiplyR" | ||
266 | cmultiplyR :: Int -> Double ::> (Int -> Double ::> (Double ::> IO Int)) | ||
267 | |||
268 | foreign import ccall safe "aux.h multiplyC" | ||
269 | cmultiplyC :: Int -> Complex Double ::> (Int -> Complex Double ::> (Complex Double ::> IO Int)) | ||
270 | |||
271 | multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a | ||
272 | multiply RowMajor a b = multiplyD RowMajor a b | ||
273 | multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b | ||
274 | |||
275 | multiplyT order a b = multiplyD order (trans b) (trans a) | ||
276 | |||
277 | multiplyD order a b | ||
278 | | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) | ||
279 | | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b) | ||
280 | | otherwise = multiplyG a b | ||
281 | |||
282 | -------------------------------------------------------------------- | ||
283 | |||
284 | data IdxTp = Covariant | Contravariant | ||
285 | |||
286 | -- | multidimensional array | ||
287 | data Tensor t = T { rank :: Int | ||
288 | , dims :: [Int] | ||
289 | , idxNm :: [String] | ||
290 | , idxTp :: [IdxTp] | ||
291 | , ten :: Vector t | ||
292 | } | ||
293 | |||
diff --git a/lib/Data/Packed/aux.c b/lib/Data/Packed/aux.c index d772d90..da36035 100644 --- a/lib/Data/Packed/aux.c +++ b/lib/Data/Packed/aux.c | |||
@@ -48,12 +48,12 @@ | |||
48 | 48 | ||
49 | #define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n) | 49 | #define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n) |
50 | #define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c) | 50 | #define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c) |
51 | #define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array(A##p,A##n) | 51 | #define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array((double*)A##p,A##n) |
52 | #define CMVIEW(A) gsl_matrix_complex_view A = gsl_matrix_complex_view_array(A##p,A##r,A##c) | 52 | #define CMVIEW(A) gsl_matrix_complex_view A = gsl_matrix_complex_view_array((double*)A##p,A##r,A##c) |
53 | #define KDVVIEW(A) gsl_vector_const_view A = gsl_vector_const_view_array(A##p,A##n) | 53 | #define KDVVIEW(A) gsl_vector_const_view A = gsl_vector_const_view_array(A##p,A##n) |
54 | #define KDMVIEW(A) gsl_matrix_const_view A = gsl_matrix_const_view_array(A##p,A##r,A##c) | 54 | #define KDMVIEW(A) gsl_matrix_const_view A = gsl_matrix_const_view_array(A##p,A##r,A##c) |
55 | #define KCVVIEW(A) gsl_vector_complex_const_view A = gsl_vector_complex_const_view_array(A##p,A##n) | 55 | #define KCVVIEW(A) gsl_vector_complex_const_view A = gsl_vector_complex_const_view_array((double*)A##p,A##n) |
56 | #define KCMVIEW(A) gsl_matrix_complex_const_view A = gsl_matrix_complex_const_view_array(A##p,A##r,A##c) | 56 | #define KCMVIEW(A) gsl_matrix_complex_const_view A = gsl_matrix_complex_const_view_array((double*)A##p,A##r,A##c) |
57 | 57 | ||
58 | #define V(a) (&a.vector) | 58 | #define V(a) (&a.vector) |
59 | #define M(a) (&a.matrix) | 59 | #define M(a) (&a.matrix) |
@@ -65,26 +65,80 @@ | |||
65 | #define BAD_CODE 1001 | 65 | #define BAD_CODE 1001 |
66 | #define MEM 1002 | 66 | #define MEM 1002 |
67 | #define BAD_FILE 1003 | 67 | #define BAD_FILE 1003 |
68 | #define BAD_TYPE 1004 | ||
69 | 68 | ||
70 | 69 | ||
71 | int trans(int size,KMAT(x),MAT(t)) { | 70 | |
71 | int transR(KRMAT(x),RMAT(t)) { | ||
72 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | ||
73 | DEBUGMSG("transR"); | ||
74 | KDMVIEW(x); | ||
75 | DMVIEW(t); | ||
76 | int res = gsl_matrix_transpose_memcpy(M(t),M(x)); | ||
77 | CHECK(res,res); | ||
78 | OK | ||
79 | } | ||
80 | |||
81 | int transC(KCMAT(x),CMAT(t)) { | ||
72 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | 82 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); |
73 | DEBUGMSG("trans"); | 83 | DEBUGMSG("transC"); |
74 | if(size==8) { | 84 | KCMVIEW(x); |
75 | DEBUGMSG("trans double"); | 85 | CMVIEW(t); |
76 | KDMVIEW(x); | 86 | int res = gsl_matrix_complex_transpose_memcpy(M(t),M(x)); |
77 | DMVIEW(t); | 87 | CHECK(res,res); |
78 | int res = gsl_matrix_transpose_memcpy(M(t),M(x)); | 88 | OK |
79 | CHECK(res,res); | 89 | } |
80 | OK | 90 | |
81 | } else if (size==16) { | 91 | |
82 | DEBUGMSG("trans complex double"); | 92 | int constantR(double * pval, RVEC(r)) { |
83 | KCMVIEW(x); | 93 | DEBUGMSG("constantR") |
84 | CMVIEW(t); | 94 | int k; |
85 | int res = gsl_matrix_complex_transpose_memcpy(M(t),M(x)); | 95 | double val = *pval; |
86 | CHECK(res,res); | 96 | for(k=0;k<rn;k++) { |
87 | OK | 97 | rp[k]=val; |
88 | } | 98 | } |
89 | return BAD_TYPE; | 99 | OK |
100 | } | ||
101 | |||
102 | int constantC(gsl_complex* pval, CVEC(r)) { | ||
103 | DEBUGMSG("constantC") | ||
104 | int k; | ||
105 | gsl_complex val = *pval; | ||
106 | for(k=0;k<rn;k++) { | ||
107 | rp[k]=val; | ||
108 | } | ||
109 | OK | ||
110 | } | ||
111 | |||
112 | int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)) { | ||
113 | //printf("%d %d %d %d %d %d\n",ar,ac,br,bc,rr,rc); | ||
114 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
115 | DEBUGMSG("multiplyR (gsl_blas_dgemm)"); | ||
116 | KDMVIEW(a); | ||
117 | KDMVIEW(b); | ||
118 | DMVIEW(r); | ||
119 | int res = gsl_blas_dgemm( | ||
120 | ta?CblasTrans:CblasNoTrans, | ||
121 | tb?CblasTrans:CblasNoTrans, | ||
122 | 1.0, M(a), M(b), | ||
123 | 0.0, M(r)); | ||
124 | CHECK(res,res); | ||
125 | OK | ||
126 | } | ||
127 | |||
128 | int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)) { | ||
129 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
130 | DEBUGMSG("multiplyC (gsl_blas_zgemm)"); | ||
131 | KCMVIEW(a); | ||
132 | KCMVIEW(b); | ||
133 | CMVIEW(r); | ||
134 | gsl_complex alpha, beta; | ||
135 | GSL_SET_COMPLEX(&alpha,1.,0.); | ||
136 | GSL_SET_COMPLEX(&beta,0.,0.); | ||
137 | int res = gsl_blas_zgemm( | ||
138 | ta?CblasTrans:CblasNoTrans, | ||
139 | tb?CblasTrans:CblasNoTrans, | ||
140 | alpha, M(a), M(b), | ||
141 | beta, M(r)); | ||
142 | CHECK(res,res); | ||
143 | OK | ||
90 | } | 144 | } |
diff --git a/lib/Data/Packed/aux.h b/lib/Data/Packed/aux.h index c51234a..f45b55a 100644 --- a/lib/Data/Packed/aux.h +++ b/lib/Data/Packed/aux.h | |||
@@ -1,6 +1,21 @@ | |||
1 | #define VEC(A) int A##n, double*A##p | 1 | #include <gsl/gsl_complex.h> |
2 | #define MAT(A) int A##r, int A##c, double* A##p | ||
3 | #define KVEC(A) int A##n, const double*A##p | ||
4 | #define KMAT(A) int A##r, int A##c, const double* A##p | ||
5 | 2 | ||
6 | int trans(int size, KMAT(x),MAT(t)); | 3 | #define RVEC(A) int A##n, double*A##p |
4 | #define RMAT(A) int A##r, int A##c, double* A##p | ||
5 | #define KRVEC(A) int A##n, const double*A##p | ||
6 | #define KRMAT(A) int A##r, int A##c, const double* A##p | ||
7 | |||
8 | #define CVEC(A) int A##n, gsl_complex*A##p | ||
9 | #define CMAT(A) int A##r, int A##c, gsl_complex* A##p | ||
10 | #define KCVEC(A) int A##n, const gsl_complex*A##p | ||
11 | #define KCMAT(A) int A##r, int A##c, const gsl_complex* A##p | ||
12 | |||
13 | |||
14 | int transR(KRMAT(x),RMAT(t)); | ||
15 | int transC(KCMAT(x),CMAT(t)); | ||
16 | |||
17 | int constantR(double *val , RVEC(r)); | ||
18 | int constantC(gsl_complex *val, CVEC(r)); | ||
19 | |||
20 | int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)); | ||
21 | int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)); | ||