diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-05-28 12:22:24 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-05-28 12:22:24 +0000 |
commit | 80673221e704b451e0d9468d6dfe1a38ad676c07 (patch) | |
tree | 1cba10d54d457ebfda1ec7810149664818834027 /lib/Data | |
parent | c3a1c3ed7c1be6f255ff3bd4f8ec6d2dd2a29b66 (diff) |
common trans
Diffstat (limited to 'lib/Data')
-rw-r--r-- | lib/Data/Packed/Internal.hs | 27 | ||||
-rw-r--r-- | lib/Data/Packed/aux.c | 90 | ||||
-rw-r--r-- | lib/Data/Packed/aux.h | 6 |
3 files changed, 123 insertions, 0 deletions
diff --git a/lib/Data/Packed/Internal.hs b/lib/Data/Packed/Internal.hs index c8ad8d7..5e19e58 100644 --- a/lib/Data/Packed/Internal.hs +++ b/lib/Data/Packed/Internal.hs | |||
@@ -67,6 +67,15 @@ infixl 0 // | |||
67 | vec :: Vector a -> (Int -> Ptr b -> t) -> t | 67 | vec :: Vector a -> (Int -> Ptr b -> t) -> t |
68 | vec v f = f (dim v) (castPtr $ ptr v) | 68 | vec v f = f (dim v) (castPtr $ ptr v) |
69 | 69 | ||
70 | mata :: Matrix a -> (Int-> Int -> Ptr b -> t) -> t | ||
71 | mata m f = f (rows m) (cols m) (castPtr $ ptr (mat m)) | ||
72 | |||
73 | pd2pc :: Ptr Double -> Ptr (Complex (Double)) | ||
74 | pd2pc = castPtr | ||
75 | |||
76 | pc2pd :: Ptr (Complex (Double)) -> Ptr Double | ||
77 | pc2pd = castPtr | ||
78 | |||
70 | check msg ls f = do | 79 | check msg ls f = do |
71 | err <- f | 80 | err <- f |
72 | when (err/=0) (error msg) | 81 | when (err/=0) (error msg) |
@@ -97,6 +106,9 @@ at :: Storable a => Vector a -> Int -> a | |||
97 | at v n | n >= 0 && n < dim v = at' v n | 106 | at v n | n >= 0 && n < dim v = at' v n |
98 | | otherwise = error "vector index out of range" | 107 | | otherwise = error "vector index out of range" |
99 | 108 | ||
109 | dsv v = sizeOf (v `at` 0) | ||
110 | dsm m = (dsv.mat) m | ||
111 | |||
100 | constant :: Storable a => Int -> a -> Vector a | 112 | constant :: Storable a => Int -> a -> Vector a |
101 | constant n x = unsafePerformIO $ do | 113 | constant n x = unsafePerformIO $ do |
102 | v <- createVector n | 114 | v <- createVector n |
@@ -118,3 +130,18 @@ reshape n v = M { rows = dim v `div` n | |||
118 | , trMode = NoTrans | 130 | , trMode = NoTrans |
119 | , isCOrder = True | 131 | , isCOrder = True |
120 | } | 132 | } |
133 | |||
134 | createMatrix r c = do | ||
135 | p <- createVector (r*c) | ||
136 | return (reshape c p) | ||
137 | |||
138 | type CMat s = Int -> Int -> Ptr Double -> s | ||
139 | type CVec s = Int -> Ptr Double -> s | ||
140 | |||
141 | foreign import ccall safe "aux.h trans" ctrans :: Int -> CMat (CMat (IO Int)) | ||
142 | |||
143 | trans :: Storable a => Matrix a -> Matrix a | ||
144 | trans m = unsafePerformIO $ do | ||
145 | r <- createMatrix (cols m) (rows m) | ||
146 | ctrans (dsm m) // mata m // mata r // check "trans" [mat m] | ||
147 | return r | ||
diff --git a/lib/Data/Packed/aux.c b/lib/Data/Packed/aux.c new file mode 100644 index 0000000..d772d90 --- /dev/null +++ b/lib/Data/Packed/aux.c | |||
@@ -0,0 +1,90 @@ | |||
1 | #include "aux.h" | ||
2 | #include <gsl/gsl_blas.h> | ||
3 | #include <gsl/gsl_linalg.h> | ||
4 | #include <gsl/gsl_matrix.h> | ||
5 | #include <gsl/gsl_math.h> | ||
6 | #include <gsl/gsl_errno.h> | ||
7 | #include <gsl/gsl_fft_complex.h> | ||
8 | #include <gsl/gsl_eigen.h> | ||
9 | #include <gsl/gsl_integration.h> | ||
10 | #include <gsl/gsl_deriv.h> | ||
11 | #include <gsl/gsl_poly.h> | ||
12 | #include <gsl/gsl_multimin.h> | ||
13 | #include <gsl/gsl_complex.h> | ||
14 | #include <gsl/gsl_complex_math.h> | ||
15 | #include <string.h> | ||
16 | #include <stdio.h> | ||
17 | |||
18 | #define MACRO(B) do {B} while (0) | ||
19 | #define ERROR(CODE) MACRO(return CODE;) | ||
20 | #define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);}) | ||
21 | |||
22 | #define MIN(A,B) ((A)<(B)?(A):(B)) | ||
23 | #define MAX(A,B) ((A)>(B)?(A):(B)) | ||
24 | |||
25 | #ifdef DBG | ||
26 | #define DEBUGMSG(M) printf("GSL Wrapper "M": "); size_t t0 = time(NULL); | ||
27 | #define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;); | ||
28 | #else | ||
29 | #define DEBUGMSG(M) | ||
30 | #define OK return 0; | ||
31 | #endif | ||
32 | |||
33 | |||
34 | #define CHECK(RES,CODE) MACRO(if(RES) return CODE;) | ||
35 | |||
36 | #ifdef DBG | ||
37 | #define DEBUGMAT(MSG,X) printf(MSG" = \n"); gsl_matrix_fprintf(stdout,X,"%f"); printf("\n"); | ||
38 | #else | ||
39 | #define DEBUGMAT(MSG,X) | ||
40 | #endif | ||
41 | |||
42 | #ifdef DBG | ||
43 | #define DEBUGVEC(MSG,X) printf(MSG" = \n"); gsl_vector_fprintf(stdout,X,"%f"); printf("\n"); | ||
44 | #else | ||
45 | #define DEBUGVEC(MSG,X) | ||
46 | #endif | ||
47 | |||
48 | |||
49 | #define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n) | ||
50 | #define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c) | ||
51 | #define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array(A##p,A##n) | ||
52 | #define CMVIEW(A) gsl_matrix_complex_view A = gsl_matrix_complex_view_array(A##p,A##r,A##c) | ||
53 | #define KDVVIEW(A) gsl_vector_const_view A = gsl_vector_const_view_array(A##p,A##n) | ||
54 | #define KDMVIEW(A) gsl_matrix_const_view A = gsl_matrix_const_view_array(A##p,A##r,A##c) | ||
55 | #define KCVVIEW(A) gsl_vector_complex_const_view A = gsl_vector_complex_const_view_array(A##p,A##n) | ||
56 | #define KCMVIEW(A) gsl_matrix_complex_const_view A = gsl_matrix_complex_const_view_array(A##p,A##r,A##c) | ||
57 | |||
58 | #define V(a) (&a.vector) | ||
59 | #define M(a) (&a.matrix) | ||
60 | |||
61 | #define GCVEC(A) int A##n, gsl_complex*A##p | ||
62 | #define KGCVEC(A) int A##n, const gsl_complex*A##p | ||
63 | |||
64 | #define BAD_SIZE 1000 | ||
65 | #define BAD_CODE 1001 | ||
66 | #define MEM 1002 | ||
67 | #define BAD_FILE 1003 | ||
68 | #define BAD_TYPE 1004 | ||
69 | |||
70 | |||
71 | int trans(int size,KMAT(x),MAT(t)) { | ||
72 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | ||
73 | DEBUGMSG("trans"); | ||
74 | if(size==8) { | ||
75 | DEBUGMSG("trans double"); | ||
76 | KDMVIEW(x); | ||
77 | DMVIEW(t); | ||
78 | int res = gsl_matrix_transpose_memcpy(M(t),M(x)); | ||
79 | CHECK(res,res); | ||
80 | OK | ||
81 | } else if (size==16) { | ||
82 | DEBUGMSG("trans complex double"); | ||
83 | KCMVIEW(x); | ||
84 | CMVIEW(t); | ||
85 | int res = gsl_matrix_complex_transpose_memcpy(M(t),M(x)); | ||
86 | CHECK(res,res); | ||
87 | OK | ||
88 | } | ||
89 | return BAD_TYPE; | ||
90 | } | ||
diff --git a/lib/Data/Packed/aux.h b/lib/Data/Packed/aux.h new file mode 100644 index 0000000..c51234a --- /dev/null +++ b/lib/Data/Packed/aux.h | |||
@@ -0,0 +1,6 @@ | |||
1 | #define VEC(A) int A##n, double*A##p | ||
2 | #define MAT(A) int A##r, int A##c, double* A##p | ||
3 | #define KVEC(A) int A##n, const double*A##p | ||
4 | #define KMAT(A) int A##r, int A##c, const double* A##p | ||
5 | |||
6 | int trans(int size, KMAT(x),MAT(t)); | ||