diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-06-04 19:10:28 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-06-04 19:10:28 +0000 |
commit | 7430630fa0504296b796223e01cbd417b88650ef (patch) | |
tree | c338dea8b82867a4c161fcee5817ed2ca27c7258 /lib/Data/Packed/Internal | |
parent | 0a9817cc481fb09f1962eb2c272125e56a123814 (diff) |
separation of Internal
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 187 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 32 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 164 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/aux.c | 144 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/aux.h | 21 |
5 files changed, 548 insertions, 0 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs new file mode 100644 index 0000000..2c57c07 --- /dev/null +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -0,0 +1,187 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} | ||
2 | ----------------------------------------------------------------------------- | ||
3 | -- | | ||
4 | -- Module : Data.Packed.Internal.Matrix | ||
5 | -- Copyright : (c) Alberto Ruiz 2007 | ||
6 | -- License : GPL-style | ||
7 | -- | ||
8 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
9 | -- Stability : provisional | ||
10 | -- Portability : portable (uses FFI) | ||
11 | -- | ||
12 | -- Fundamental types | ||
13 | -- | ||
14 | ----------------------------------------------------------------------------- | ||
15 | |||
16 | module Data.Packed.Internal.Matrix where | ||
17 | |||
18 | import Data.Packed.Internal.Vector | ||
19 | |||
20 | import Foreign hiding (xor) | ||
21 | import Complex | ||
22 | import Control.Monad(when) | ||
23 | import Debug.Trace | ||
24 | import Data.List(transpose,intersperse) | ||
25 | import Data.Typeable | ||
26 | import Data.Maybe(fromJust) | ||
27 | |||
28 | debug x = trace (show x) x | ||
29 | |||
30 | |||
31 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | ||
32 | |||
33 | -- | 2D array | ||
34 | data Matrix t = M { rows :: Int | ||
35 | , cols :: Int | ||
36 | , dat :: Vector t | ||
37 | , tdat :: Vector t | ||
38 | , isTrans :: Bool | ||
39 | , order :: MatrixOrder | ||
40 | } deriving Typeable | ||
41 | |||
42 | xor a b = a && not b || b && not a | ||
43 | |||
44 | fortran m = order m == ColumnMajor | ||
45 | |||
46 | cdat m = if fortran m `xor` isTrans m then tdat m else dat m | ||
47 | fdat m = if fortran m `xor` isTrans m then dat m else tdat m | ||
48 | |||
49 | trans m = m { rows = cols m | ||
50 | , cols = rows m | ||
51 | , isTrans = not (isTrans m) | ||
52 | } | ||
53 | |||
54 | type Mt t s = Int -> Int -> Ptr t -> s | ||
55 | infixr 6 ::> | ||
56 | type t ::> s = Mt t s | ||
57 | |||
58 | mat d m f = f (rows m) (cols m) (ptr (d m)) | ||
59 | |||
60 | instance (Show a, Storable a) => (Show (Matrix a)) where | ||
61 | show m = (sizes++) . dsp . map (map show) . toLists $ m | ||
62 | where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" | ||
63 | |||
64 | partit :: Int -> [a] -> [[a]] | ||
65 | partit _ [] = [] | ||
66 | partit n l = take n l : partit n (drop n l) | ||
67 | |||
68 | toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m | ||
69 | | otherwise = partit (cols m) . toList . dat $ m | ||
70 | |||
71 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp | ||
72 | where | ||
73 | mt = transpose as | ||
74 | longs = map (maximum . map length) mt | ||
75 | mtp = zipWith (\a b -> map (pad a) b) longs mt | ||
76 | pad n str = replicate (n - length str) ' ' ++ str | ||
77 | unwords' = concat . intersperse ", " | ||
78 | |||
79 | matrixFromVector RowMajor c v = | ||
80 | M { rows = r | ||
81 | , cols = c | ||
82 | , dat = v | ||
83 | , tdat = transdata c v r | ||
84 | , order = RowMajor | ||
85 | , isTrans = False | ||
86 | } where r = dim v `div` c -- TODO check mod=0 | ||
87 | |||
88 | matrixFromVector ColumnMajor c v = | ||
89 | M { rows = r | ||
90 | , cols = c | ||
91 | , dat = v | ||
92 | , tdat = transdata r v c | ||
93 | , order = ColumnMajor | ||
94 | , isTrans = False | ||
95 | } where r = dim v `div` c -- TODO check mod=0 | ||
96 | |||
97 | createMatrix order r c = do | ||
98 | p <- createVector (r*c) | ||
99 | return (matrixFromVector order c p) | ||
100 | |||
101 | transdataG :: Storable a => Int -> Vector a -> Int -> Vector a | ||
102 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | ||
103 | |||
104 | transdataR :: Int -> Vector Double -> Int -> Vector Double | ||
105 | transdataR = transdataAux ctransR | ||
106 | |||
107 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) | ||
108 | transdataC = transdataAux ctransC | ||
109 | |||
110 | transdataAux fun c1 d c2 = unsafePerformIO $ do | ||
111 | v <- createVector (dim d) | ||
112 | let r1 = dim d `div` c1 | ||
113 | r2 = dim d `div` c2 | ||
114 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | ||
115 | --putStrLn "---> transdataAux" | ||
116 | return v | ||
117 | |||
118 | foreign import ccall safe "aux.h transR" | ||
119 | ctransR :: Double ::> Double ::> IO Int | ||
120 | foreign import ccall safe "aux.h transC" | ||
121 | ctransC :: Complex Double ::> Complex Double ::> IO Int | ||
122 | |||
123 | transdata :: Field a => Int -> Vector a -> Int -> Vector a | ||
124 | transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 | ||
125 | | isComp baseOf d = scast $ transdataC c1 (scast d) c2 | ||
126 | | otherwise = transdataG c1 d c2 | ||
127 | |||
128 | --transdata :: Storable a => Int -> Vector a -> Int -> Vector a | ||
129 | --transdata = transdataG | ||
130 | --{-# RULES "transdataR" transdata=transdataR #-} | ||
131 | --{-# RULES "transdataC" transdata=transdataC #-} | ||
132 | |||
133 | -- | extracts the rows of a matrix as a list of vectors | ||
134 | toRows :: Storable t => Matrix t -> [Vector t] | ||
135 | toRows m = toRows' 0 where | ||
136 | v = cdat m | ||
137 | r = rows m | ||
138 | c = cols m | ||
139 | toRows' k | k == r*c = [] | ||
140 | | otherwise = subVector k c v : toRows' (k+c) | ||
141 | |||
142 | ------------------------------------------------------------------ | ||
143 | |||
144 | dotL a b = sum (zipWith (*) a b) | ||
145 | |||
146 | multiplyL a b = [[dotL x y | y <- transpose b] | x <- a] | ||
147 | |||
148 | transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) | ||
149 | |||
150 | multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) | ||
151 | |||
152 | ------------------------------------------------------------------ | ||
153 | |||
154 | gmatC m f | fortran m = | ||
155 | if (isTrans m) | ||
156 | then f 0 (rows m) (cols m) (ptr (dat m)) | ||
157 | else f 1 (cols m) (rows m) (ptr (dat m)) | ||
158 | | otherwise = | ||
159 | if isTrans m | ||
160 | then f 1 (cols m) (rows m) (ptr (dat m)) | ||
161 | else f 0 (rows m) (cols m) (ptr (dat m)) | ||
162 | |||
163 | |||
164 | multiplyAux order fun a b = unsafePerformIO $ do | ||
165 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | ||
166 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | ||
167 | r <- createMatrix order (rows a) (cols b) | ||
168 | fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] | ||
169 | return r | ||
170 | |||
171 | foreign import ccall safe "aux.h multiplyR" | ||
172 | cmultiplyR :: Int -> Double ::> (Int -> Double ::> (Double ::> IO Int)) | ||
173 | |||
174 | foreign import ccall safe "aux.h multiplyC" | ||
175 | cmultiplyC :: Int -> Complex Double ::> (Int -> Complex Double ::> (Complex Double ::> IO Int)) | ||
176 | |||
177 | multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a | ||
178 | multiply RowMajor a b = multiplyD RowMajor a b | ||
179 | multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b | ||
180 | |||
181 | multiplyT order a b = multiplyD order (trans b) (trans a) | ||
182 | |||
183 | multiplyD order a b | ||
184 | | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) | ||
185 | | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b) | ||
186 | | otherwise = multiplyG a b | ||
187 | |||
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs new file mode 100644 index 0000000..11101a9 --- /dev/null +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -0,0 +1,32 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} | ||
2 | ----------------------------------------------------------------------------- | ||
3 | -- | | ||
4 | -- Module : Data.Packed.Internal.Tensor | ||
5 | -- Copyright : (c) Alberto Ruiz 2007 | ||
6 | -- License : GPL-style | ||
7 | -- | ||
8 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
9 | -- Stability : provisional | ||
10 | -- Portability : portable (uses FFI) | ||
11 | -- | ||
12 | -- Fundamental types | ||
13 | -- | ||
14 | ----------------------------------------------------------------------------- | ||
15 | |||
16 | module Data.Packed.Internal.Tensor where | ||
17 | |||
18 | import Data.Packed.Internal.Vector | ||
19 | import Data.Packed.Internal.Matrix | ||
20 | |||
21 | |||
22 | data IdxTp = Covariant | Contravariant deriving Show | ||
23 | |||
24 | data Tensor t = T { dims :: [(Int,(IdxTp,String))] | ||
25 | , ten :: Vector t | ||
26 | } deriving Show | ||
27 | |||
28 | rank = length . dims | ||
29 | |||
30 | outer u v = dat (multiply RowMajor r c) | ||
31 | where r = matrixFromVector RowMajor 1 u | ||
32 | c = matrixFromVector RowMajor (dim v) v | ||
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs new file mode 100644 index 0000000..7dcefeb --- /dev/null +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -0,0 +1,164 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} | ||
2 | ----------------------------------------------------------------------------- | ||
3 | -- | | ||
4 | -- Module : Data.Packed.Internal.Vector | ||
5 | -- Copyright : (c) Alberto Ruiz 2007 | ||
6 | -- License : GPL-style | ||
7 | -- | ||
8 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
9 | -- Stability : provisional | ||
10 | -- Portability : portable (uses FFI) | ||
11 | -- | ||
12 | -- Fundamental types | ||
13 | -- | ||
14 | ----------------------------------------------------------------------------- | ||
15 | |||
16 | module Data.Packed.Internal.Vector where | ||
17 | |||
18 | import Foreign | ||
19 | import Complex | ||
20 | import Control.Monad(when) | ||
21 | import Debug.Trace | ||
22 | import Data.List(transpose,intersperse) | ||
23 | import Data.Typeable | ||
24 | import Data.Maybe(fromJust) | ||
25 | |||
26 | debug x = trace (show x) x | ||
27 | |||
28 | ---------------------------------------------------------------------- | ||
29 | instance (Storable a, RealFloat a) => Storable (Complex a) where -- | ||
30 | alignment x = alignment (realPart x) -- | ||
31 | sizeOf x = 2 * sizeOf (realPart x) -- | ||
32 | peek p = do -- | ||
33 | [re,im] <- peekArray 2 (castPtr p) -- | ||
34 | return (re :+ im) -- | ||
35 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- | ||
36 | ---------------------------------------------------------------------- | ||
37 | |||
38 | (//) :: x -> (x -> y) -> y | ||
39 | infixl 0 // | ||
40 | (//) = flip ($) | ||
41 | |||
42 | check msg ls f = do | ||
43 | err <- f | ||
44 | when (err/=0) (error msg) | ||
45 | mapM_ (touchForeignPtr . fptr) ls | ||
46 | return () | ||
47 | |||
48 | class (Storable a, Typeable a) => Field a where | ||
49 | instance (Storable a, Typeable a) => Field a where | ||
50 | |||
51 | isReal w x = typeOf (undefined :: Double) == typeOf (w x) | ||
52 | isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) | ||
53 | baseOf v = (v `at` 0) | ||
54 | |||
55 | scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b | ||
56 | scast = fromJust . cast | ||
57 | |||
58 | |||
59 | |||
60 | ---------------------------------------------------------------------- | ||
61 | |||
62 | data Vector t = V { dim :: Int | ||
63 | , fptr :: ForeignPtr t | ||
64 | , ptr :: Ptr t | ||
65 | } deriving Typeable | ||
66 | |||
67 | type Vc t s = Int -> Ptr t -> s | ||
68 | infixr 5 :> | ||
69 | type t :> s = Vc t s | ||
70 | |||
71 | vec :: Vector t -> (Vc t s) -> s | ||
72 | vec v f = f (dim v) (ptr v) | ||
73 | |||
74 | createVector :: Storable a => Int -> IO (Vector a) | ||
75 | createVector n = do | ||
76 | when (n <= 0) $ error ("trying to createVector of dim "++show n) | ||
77 | fp <- mallocForeignPtrArray n | ||
78 | let p = unsafeForeignPtrToPtr fp | ||
79 | --putStrLn ("\n---------> V"++show n) | ||
80 | return $ V n fp p | ||
81 | |||
82 | fromList :: Storable a => [a] -> Vector a | ||
83 | fromList l = unsafePerformIO $ do | ||
84 | v <- createVector (length l) | ||
85 | let f _ p = pokeArray p l >> return 0 | ||
86 | f // vec v // check "fromList" [] | ||
87 | return v | ||
88 | |||
89 | toList :: Storable a => Vector a -> [a] | ||
90 | toList v = unsafePerformIO $ peekArray (dim v) (ptr v) | ||
91 | |||
92 | n # l = if length l == n then fromList l else error "# with wrong size" | ||
93 | |||
94 | at' :: Storable a => Vector a -> Int -> a | ||
95 | at' v n = unsafePerformIO $ peekElemOff (ptr v) n | ||
96 | |||
97 | at :: Storable a => Vector a -> Int -> a | ||
98 | at v n | n >= 0 && n < dim v = at' v n | ||
99 | | otherwise = error "vector index out of range" | ||
100 | |||
101 | instance (Show a, Storable a) => (Show (Vector a)) where | ||
102 | show v = (show (dim v))++" # " ++ show (toList v) | ||
103 | |||
104 | -- | creates a Vector taking a number of consecutive toList from another Vector | ||
105 | subVector :: Storable t => Int -- ^ index of the starting element | ||
106 | -> Int -- ^ number of toList to extract | ||
107 | -> Vector t -- ^ source | ||
108 | -> Vector t -- ^ result | ||
109 | subVector k l (v@V {dim=n, ptr=p, fptr=fp}) | ||
110 | | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" | ||
111 | | otherwise = unsafePerformIO $ do | ||
112 | r <- createVector l | ||
113 | let f = copyArray (ptr r) (advancePtr p k) l >> return 0 | ||
114 | f // check "subVector" [v] | ||
115 | return r | ||
116 | |||
117 | subVector' k l (v@V {dim=n, ptr=p, fptr=fp}) | ||
118 | | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" | ||
119 | | otherwise = v {dim=l, ptr=advancePtr p k} | ||
120 | |||
121 | |||
122 | {- | ||
123 | -- | creates a new Vector by joining a list of Vectors | ||
124 | join :: Field t => [Vector t] -> Vector t | ||
125 | join [] = error "joining an empty list" | ||
126 | join as = unsafePerformIO $ do | ||
127 | let tot = sum (map size as) | ||
128 | p <- mallocForeignPtrArray tot | ||
129 | withForeignPtr p $ \p -> | ||
130 | joiner as tot p | ||
131 | return (V tot p) | ||
132 | where joiner [] _ _ = return () | ||
133 | joiner (V n b : cs) _ p = do | ||
134 | withForeignPtr b $ \b' -> copyArray p b' n | ||
135 | joiner cs 0 (advancePtr p n) | ||
136 | -} | ||
137 | |||
138 | |||
139 | constantG n x = fromList (replicate n x) | ||
140 | |||
141 | constantR :: Int -> Double -> Vector Double | ||
142 | constantR = constantAux cconstantR | ||
143 | |||
144 | constantC :: Int -> Complex Double -> Vector (Complex Double) | ||
145 | constantC = constantAux cconstantC | ||
146 | |||
147 | constantAux fun n x = unsafePerformIO $ do | ||
148 | v <- createVector n | ||
149 | px <- newArray [x] | ||
150 | fun px // vec v // check "constantAux" [] | ||
151 | free px | ||
152 | return v | ||
153 | |||
154 | foreign import ccall safe "aux.h constantR" | ||
155 | cconstantR :: Ptr Double -> Double :> IO Int | ||
156 | |||
157 | foreign import ccall safe "aux.h constantC" | ||
158 | cconstantC :: Ptr (Complex Double) -> Complex Double :> IO Int | ||
159 | |||
160 | constant :: Field a => Int -> a -> Vector a | ||
161 | constant n x | isReal id x = scast $ constantR n (scast x) | ||
162 | | isComp id x = scast $ constantC n (scast x) | ||
163 | | otherwise = constantG n x | ||
164 | |||
diff --git a/lib/Data/Packed/Internal/aux.c b/lib/Data/Packed/Internal/aux.c new file mode 100644 index 0000000..da36035 --- /dev/null +++ b/lib/Data/Packed/Internal/aux.c | |||
@@ -0,0 +1,144 @@ | |||
1 | #include "aux.h" | ||
2 | #include <gsl/gsl_blas.h> | ||
3 | #include <gsl/gsl_linalg.h> | ||
4 | #include <gsl/gsl_matrix.h> | ||
5 | #include <gsl/gsl_math.h> | ||
6 | #include <gsl/gsl_errno.h> | ||
7 | #include <gsl/gsl_fft_complex.h> | ||
8 | #include <gsl/gsl_eigen.h> | ||
9 | #include <gsl/gsl_integration.h> | ||
10 | #include <gsl/gsl_deriv.h> | ||
11 | #include <gsl/gsl_poly.h> | ||
12 | #include <gsl/gsl_multimin.h> | ||
13 | #include <gsl/gsl_complex.h> | ||
14 | #include <gsl/gsl_complex_math.h> | ||
15 | #include <string.h> | ||
16 | #include <stdio.h> | ||
17 | |||
18 | #define MACRO(B) do {B} while (0) | ||
19 | #define ERROR(CODE) MACRO(return CODE;) | ||
20 | #define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);}) | ||
21 | |||
22 | #define MIN(A,B) ((A)<(B)?(A):(B)) | ||
23 | #define MAX(A,B) ((A)>(B)?(A):(B)) | ||
24 | |||
25 | #ifdef DBG | ||
26 | #define DEBUGMSG(M) printf("GSL Wrapper "M": "); size_t t0 = time(NULL); | ||
27 | #define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;); | ||
28 | #else | ||
29 | #define DEBUGMSG(M) | ||
30 | #define OK return 0; | ||
31 | #endif | ||
32 | |||
33 | |||
34 | #define CHECK(RES,CODE) MACRO(if(RES) return CODE;) | ||
35 | |||
36 | #ifdef DBG | ||
37 | #define DEBUGMAT(MSG,X) printf(MSG" = \n"); gsl_matrix_fprintf(stdout,X,"%f"); printf("\n"); | ||
38 | #else | ||
39 | #define DEBUGMAT(MSG,X) | ||
40 | #endif | ||
41 | |||
42 | #ifdef DBG | ||
43 | #define DEBUGVEC(MSG,X) printf(MSG" = \n"); gsl_vector_fprintf(stdout,X,"%f"); printf("\n"); | ||
44 | #else | ||
45 | #define DEBUGVEC(MSG,X) | ||
46 | #endif | ||
47 | |||
48 | |||
49 | #define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n) | ||
50 | #define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c) | ||
51 | #define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array((double*)A##p,A##n) | ||
52 | #define CMVIEW(A) gsl_matrix_complex_view A = gsl_matrix_complex_view_array((double*)A##p,A##r,A##c) | ||
53 | #define KDVVIEW(A) gsl_vector_const_view A = gsl_vector_const_view_array(A##p,A##n) | ||
54 | #define KDMVIEW(A) gsl_matrix_const_view A = gsl_matrix_const_view_array(A##p,A##r,A##c) | ||
55 | #define KCVVIEW(A) gsl_vector_complex_const_view A = gsl_vector_complex_const_view_array((double*)A##p,A##n) | ||
56 | #define KCMVIEW(A) gsl_matrix_complex_const_view A = gsl_matrix_complex_const_view_array((double*)A##p,A##r,A##c) | ||
57 | |||
58 | #define V(a) (&a.vector) | ||
59 | #define M(a) (&a.matrix) | ||
60 | |||
61 | #define GCVEC(A) int A##n, gsl_complex*A##p | ||
62 | #define KGCVEC(A) int A##n, const gsl_complex*A##p | ||
63 | |||
64 | #define BAD_SIZE 1000 | ||
65 | #define BAD_CODE 1001 | ||
66 | #define MEM 1002 | ||
67 | #define BAD_FILE 1003 | ||
68 | |||
69 | |||
70 | |||
71 | int transR(KRMAT(x),RMAT(t)) { | ||
72 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | ||
73 | DEBUGMSG("transR"); | ||
74 | KDMVIEW(x); | ||
75 | DMVIEW(t); | ||
76 | int res = gsl_matrix_transpose_memcpy(M(t),M(x)); | ||
77 | CHECK(res,res); | ||
78 | OK | ||
79 | } | ||
80 | |||
81 | int transC(KCMAT(x),CMAT(t)) { | ||
82 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | ||
83 | DEBUGMSG("transC"); | ||
84 | KCMVIEW(x); | ||
85 | CMVIEW(t); | ||
86 | int res = gsl_matrix_complex_transpose_memcpy(M(t),M(x)); | ||
87 | CHECK(res,res); | ||
88 | OK | ||
89 | } | ||
90 | |||
91 | |||
92 | int constantR(double * pval, RVEC(r)) { | ||
93 | DEBUGMSG("constantR") | ||
94 | int k; | ||
95 | double val = *pval; | ||
96 | for(k=0;k<rn;k++) { | ||
97 | rp[k]=val; | ||
98 | } | ||
99 | OK | ||
100 | } | ||
101 | |||
102 | int constantC(gsl_complex* pval, CVEC(r)) { | ||
103 | DEBUGMSG("constantC") | ||
104 | int k; | ||
105 | gsl_complex val = *pval; | ||
106 | for(k=0;k<rn;k++) { | ||
107 | rp[k]=val; | ||
108 | } | ||
109 | OK | ||
110 | } | ||
111 | |||
112 | int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)) { | ||
113 | //printf("%d %d %d %d %d %d\n",ar,ac,br,bc,rr,rc); | ||
114 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
115 | DEBUGMSG("multiplyR (gsl_blas_dgemm)"); | ||
116 | KDMVIEW(a); | ||
117 | KDMVIEW(b); | ||
118 | DMVIEW(r); | ||
119 | int res = gsl_blas_dgemm( | ||
120 | ta?CblasTrans:CblasNoTrans, | ||
121 | tb?CblasTrans:CblasNoTrans, | ||
122 | 1.0, M(a), M(b), | ||
123 | 0.0, M(r)); | ||
124 | CHECK(res,res); | ||
125 | OK | ||
126 | } | ||
127 | |||
128 | int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)) { | ||
129 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
130 | DEBUGMSG("multiplyC (gsl_blas_zgemm)"); | ||
131 | KCMVIEW(a); | ||
132 | KCMVIEW(b); | ||
133 | CMVIEW(r); | ||
134 | gsl_complex alpha, beta; | ||
135 | GSL_SET_COMPLEX(&alpha,1.,0.); | ||
136 | GSL_SET_COMPLEX(&beta,0.,0.); | ||
137 | int res = gsl_blas_zgemm( | ||
138 | ta?CblasTrans:CblasNoTrans, | ||
139 | tb?CblasTrans:CblasNoTrans, | ||
140 | alpha, M(a), M(b), | ||
141 | beta, M(r)); | ||
142 | CHECK(res,res); | ||
143 | OK | ||
144 | } | ||
diff --git a/lib/Data/Packed/Internal/aux.h b/lib/Data/Packed/Internal/aux.h new file mode 100644 index 0000000..f45b55a --- /dev/null +++ b/lib/Data/Packed/Internal/aux.h | |||
@@ -0,0 +1,21 @@ | |||
1 | #include <gsl/gsl_complex.h> | ||
2 | |||
3 | #define RVEC(A) int A##n, double*A##p | ||
4 | #define RMAT(A) int A##r, int A##c, double* A##p | ||
5 | #define KRVEC(A) int A##n, const double*A##p | ||
6 | #define KRMAT(A) int A##r, int A##c, const double* A##p | ||
7 | |||
8 | #define CVEC(A) int A##n, gsl_complex*A##p | ||
9 | #define CMAT(A) int A##r, int A##c, gsl_complex* A##p | ||
10 | #define KCVEC(A) int A##n, const gsl_complex*A##p | ||
11 | #define KCMAT(A) int A##r, int A##c, const gsl_complex* A##p | ||
12 | |||
13 | |||
14 | int transR(KRMAT(x),RMAT(t)); | ||
15 | int transC(KCMAT(x),CMAT(t)); | ||
16 | |||
17 | int constantR(double *val , RVEC(r)); | ||
18 | int constantC(gsl_complex *val, CVEC(r)); | ||
19 | |||
20 | int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)); | ||
21 | int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)); | ||