summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/Data/Packed/Internal.hs286
-rw-r--r--lib/Data/Packed/aux.c98
-rw-r--r--lib/Data/Packed/aux.h25
3 files changed, 312 insertions, 97 deletions
diff --git a/lib/Data/Packed/Internal.hs b/lib/Data/Packed/Internal.hs
index 5e19e58..b06f044 100644
--- a/lib/Data/Packed/Internal.hs
+++ b/lib/Data/Packed/Internal.hs
@@ -1,3 +1,4 @@
1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
1----------------------------------------------------------------------------- 2-----------------------------------------------------------------------------
2-- | 3-- |
3-- Module : Data.Packed.Internal 4-- Module : Data.Packed.Internal
@@ -14,39 +15,16 @@
14 15
15module Data.Packed.Internal where 16module Data.Packed.Internal where
16 17
17import Foreign 18import Foreign hiding (xor)
18import Complex 19import Complex
19import Control.Monad(when) 20import Control.Monad(when)
20import Debug.Trace 21import Debug.Trace
22import Data.List(transpose,intersperse)
23import Data.Typeable
24import Data.Maybe(fromJust)
21 25
22debug x = trace (show x) x 26debug x = trace (show x) x
23 27
24-- | 1D array
25data Vector t = V { dim :: Int
26 , fptr :: ForeignPtr t
27 , ptr :: Ptr t
28 }
29
30data TransMode = NoTrans | Trans | ConjTrans
31
32-- | 2D array
33data Matrix t = M { rows :: Int
34 , cols :: Int
35 , mat :: Vector t
36 , trMode :: TransMode
37 , isCOrder :: Bool
38 }
39
40data IdxTp = Covariant | Contravariant
41
42-- | multidimensional array
43data Tensor t = T { rank :: Int
44 , dims :: [Int]
45 , idxNm :: [String]
46 , idxTp :: [IdxTp]
47 , ten :: Vector t
48 }
49
50---------------------------------------------------------------------- 28----------------------------------------------------------------------
51instance (Storable a, RealFloat a) => Storable (Complex a) where -- 29instance (Storable a, RealFloat a) => Storable (Complex a) where --
52 alignment x = alignment (realPart x) -- 30 alignment x = alignment (realPart x) --
@@ -57,36 +35,36 @@ instance (Storable a, RealFloat a) => Storable (Complex a) where --
57 poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- 35 poke p (a :+ b) = pokeArray (castPtr p) [a,b] --
58---------------------------------------------------------------------- 36----------------------------------------------------------------------
59 37
60
61-- f // vec a // vec b // vec res // check "vector add" [a,b]
62
63(//) :: x -> (x -> y) -> y 38(//) :: x -> (x -> y) -> y
64infixl 0 // 39infixl 0 //
65(//) = flip ($) 40(//) = flip ($)
66 41
67vec :: Vector a -> (Int -> Ptr b -> t) -> t
68vec v f = f (dim v) (castPtr $ ptr v)
69
70mata :: Matrix a -> (Int-> Int -> Ptr b -> t) -> t
71mata m f = f (rows m) (cols m) (castPtr $ ptr (mat m))
72
73pd2pc :: Ptr Double -> Ptr (Complex (Double))
74pd2pc = castPtr
75
76pc2pd :: Ptr (Complex (Double)) -> Ptr Double
77pc2pd = castPtr
78
79check msg ls f = do 42check msg ls f = do
80 err <- f 43 err <- f
81 when (err/=0) (error msg) 44 when (err/=0) (error msg)
82 mapM_ (touchForeignPtr . fptr) ls 45 mapM_ (touchForeignPtr . fptr) ls
83 return () 46 return ()
84 47
48----------------------------------------------------------------------
49
50data Vector t = V { dim :: Int
51 , fptr :: ForeignPtr t
52 , ptr :: Ptr t
53 } deriving Typeable
54
55type Vc t s = Int -> Ptr t -> s
56infixr 5 :>
57type t :> s = Vc t s
58
59vec :: Vector t -> (Vc t s) -> s
60vec v f = f (dim v) (ptr v)
61
85createVector :: Storable a => Int -> IO (Vector a) 62createVector :: Storable a => Int -> IO (Vector a)
86createVector n = do 63createVector n = do
87 when (n <= 0) $ error ("trying to createVector of dim "++show n) 64 when (n <= 0) $ error ("trying to createVector of dim "++show n)
88 fp <- mallocForeignPtrArray n 65 fp <- mallocForeignPtrArray n
89 let p = unsafeForeignPtrToPtr fp 66 let p = unsafeForeignPtrToPtr fp
67 --putStrLn ("\n---------> V"++show n)
90 return $ V n fp p 68 return $ V n fp p
91 69
92fromList :: Storable a => [a] -> Vector a 70fromList :: Storable a => [a] -> Vector a
@@ -99,6 +77,8 @@ fromList l = unsafePerformIO $ do
99toList :: Storable a => Vector a -> [a] 77toList :: Storable a => Vector a -> [a]
100toList v = unsafePerformIO $ peekArray (dim v) (ptr v) 78toList v = unsafePerformIO $ peekArray (dim v) (ptr v)
101 79
80n # l = if length l == n then fromList l else error "# with wrong size"
81
102at' :: Storable a => Vector a -> Int -> a 82at' :: Storable a => Vector a -> Int -> a
103at' v n = unsafePerformIO $ peekElemOff (ptr v) n 83at' v n = unsafePerformIO $ peekElemOff (ptr v) n
104 84
@@ -106,42 +86,208 @@ at :: Storable a => Vector a -> Int -> a
106at v n | n >= 0 && n < dim v = at' v n 86at v n | n >= 0 && n < dim v = at' v n
107 | otherwise = error "vector index out of range" 87 | otherwise = error "vector index out of range"
108 88
109dsv v = sizeOf (v `at` 0) 89instance (Show a, Storable a) => (Show (Vector a)) where
110dsm m = (dsv.mat) m 90 show v = (show (dim v))++" # " ++ show (toList v)
111 91
112constant :: Storable a => Int -> a -> Vector a 92------------------------------------------------------------------------
113constant n x = unsafePerformIO $ do
114 v <- createVector n
115 let f k p | k == n = return 0
116 | otherwise = pokeElemOff p k x >> f (k+1) p
117 const (f 0) // vec v // check "constant" []
118 return v
119 93
120instance (Show a, Storable a) => (Show (Vector a)) where 94data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
121 show v = "fromList " ++ show (toList v) 95
96-- | 2D array
97data Matrix t = M { rows :: Int
98 , cols :: Int
99 , cmat :: Vector t
100 , fmat :: Vector t
101 , isTrans :: Bool
102 , order :: MatrixOrder
103 } deriving Typeable
104
105xor a b = a && not b || b && not a
106
107fortran m = order m == ColumnMajor
108
109dat m = if fortran m `xor` isTrans m then fmat m else cmat m
110
111pref m = if fortran m then fmat m else cmat m
112
113trans m = m { rows = cols m
114 , cols = rows m
115 , isTrans = not (isTrans m)
116 }
117
118type Mt t s = Int -> Int -> Ptr t -> s
119infixr 6 ::>
120type t ::> s = Mt t s
121
122mat :: Matrix t -> (Mt t s) -> s
123mat m f = f (rows m) (cols m) (ptr (dat m))
124
125gmat m f | fortran m =
126 if (isTrans m)
127 then f 0 (rows m) (cols m) (ptr (fmat m))
128 else f 1 (cols m) (rows m) (ptr (fmat m))
129 | otherwise =
130 if isTrans m
131 then f 1 (cols m) (rows m) (ptr (cmat m))
132 else f 0 (rows m) (cols m) (ptr (cmat m))
122 133
123instance (Show a, Storable a) => (Show (Matrix a)) where 134instance (Show a, Storable a) => (Show (Matrix a)) where
124 show m = "reshape "++show (cols m) ++ " $ fromList " ++ show (toList (mat m)) 135 show m = (sizes++) . dsp . map (map show) . toLists $ m
136 where sizes = "("++show (rows m)++"><"++show (cols m)++")\n"
137
138partit :: Int -> [a] -> [[a]]
139partit _ [] = []
140partit n l = take n l : partit n (drop n l)
141
142toLists m = partit (cols m) . toList . cmat $ m
125 143
126reshape :: Storable a => Int -> Vector a -> Matrix a 144dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp
127reshape n v = M { rows = dim v `div` n 145 where
128 , cols = n 146 mt = transpose as
129 , mat = v 147 longs = map (maximum . map length) mt
130 , trMode = NoTrans 148 mtp = zipWith (\a b -> map (pad a) b) longs mt
131 , isCOrder = True 149 pad n str = replicate (n - length str) ' ' ++ str
132 } 150 unwords' = concat . intersperse ", "
133 151
134createMatrix r c = do 152matrixFromVector RowMajor c v =
153 M { rows = r
154 , cols = c
155 , cmat = v
156 , fmat = transdata c v r
157 , order = RowMajor
158 , isTrans = False
159 } where r = dim v `div` c -- TODO check mod=0
160
161matrixFromVector ColumnMajor c v =
162 M { rows = r
163 , cols = c
164 , fmat = v
165 , cmat = transdata c v r
166 , order = ColumnMajor
167 , isTrans = False
168 } where r = dim v `div` c -- TODO check mod=0
169
170createMatrix order r c = do
135 p <- createVector (r*c) 171 p <- createVector (r*c)
136 return (reshape c p) 172 return (matrixFromVector order c p)
173
174transdataG :: Storable a => Int -> Vector a -> Int -> Vector a
175transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d
176
177transdataR :: Int -> Vector Double -> Int -> Vector Double
178transdataR = transdataAux ctransR
179
180transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double)
181transdataC = transdataAux ctransC
182
183transdataAux fun c1 d c2 = unsafePerformIO $ do
184 v <- createVector (dim d)
185 let r1 = dim d `div` c1
186 r2 = dim d `div` c2
187 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d]
188 --putStrLn "---> transdataAux"
189 return v
190
191foreign import ccall safe "aux.h transR"
192 ctransR :: Double ::> Double ::> IO Int
193foreign import ccall safe "aux.h transC"
194 ctransC :: Complex Double ::> Complex Double ::> IO Int
195
196
197class (Storable a, Typeable a) => Field a where
198instance (Storable a, Typeable a) => Field a where
199
200isReal w x = typeOf (undefined :: Double) == typeOf (w x)
201isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x)
202baseOf v = (v `at` 0)
203
204scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b
205scast = fromJust . cast
206
207transdata :: Field a => Int -> Vector a -> Int -> Vector a
208transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2
209 | isComp baseOf d = scast $ transdataC c1 (scast d) c2
210 | otherwise = transdataG c1 d c2
211
212--transdata :: Storable a => Int -> Vector a -> Int -> Vector a
213--transdata = transdataG
214--{-# RULES "transdataR" transdata=transdataR #-}
215--{-# RULES "transdataC" transdata=transdataC #-}
216
217------------------------------------------------------------------
218
219constantG n x = fromList (replicate n x)
220
221constantR :: Int -> Double -> Vector Double
222constantR = constantAux cconstantR
223
224constantC :: Int -> Complex Double -> Vector (Complex Double)
225constantC = constantAux cconstantC
226
227constantAux fun n x = unsafePerformIO $ do
228 v <- createVector n
229 px <- newArray [x]
230 fun px // vec v // check "constantAux" []
231 free px
232 return v
137 233
138type CMat s = Int -> Int -> Ptr Double -> s 234foreign import ccall safe "aux.h constantR"
139type CVec s = Int -> Ptr Double -> s 235 cconstantR :: Ptr Double -> Double :> IO Int
140 236
141foreign import ccall safe "aux.h trans" ctrans :: Int -> CMat (CMat (IO Int)) 237foreign import ccall safe "aux.h constantC"
238 cconstantC :: Ptr (Complex Double) -> Complex Double :> IO Int
142 239
143trans :: Storable a => Matrix a -> Matrix a 240constant :: Field a => Int -> a -> Vector a
144trans m = unsafePerformIO $ do 241constant n x | isReal id x = scast $ constantR n (scast x)
145 r <- createMatrix (cols m) (rows m) 242 | isComp id x = scast $ constantC n (scast x)
146 ctrans (dsm m) // mata m // mata r // check "trans" [mat m] 243 | otherwise = constantG n x
244
245------------------------------------------------------------------
246
247dotL a b = sum (zipWith (*) a b)
248
249multiplyL a b = [[dotL x y | y <- transpose b] | x <- a]
250
251transL m = m {rows = cols m, cols = rows m, cmat = v, fmat = cmat m}
252 where v = transdataG (cols m) (cmat m) (rows m)
253
254------------------------------------------------------------------
255
256multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b)
257
258multiplyAux order fun a b = unsafePerformIO $ do
259 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++
260 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
261 r <- createMatrix order (rows a) (cols b)
262 fun // gmat a // gmat b // mat r // check "multiplyAux" [pref a, pref b]
147 return r 263 return r
264
265foreign import ccall safe "aux.h multiplyR"
266 cmultiplyR :: Int -> Double ::> (Int -> Double ::> (Double ::> IO Int))
267
268foreign import ccall safe "aux.h multiplyC"
269 cmultiplyC :: Int -> Complex Double ::> (Int -> Complex Double ::> (Complex Double ::> IO Int))
270
271multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a
272multiply RowMajor a b = multiplyD RowMajor a b
273multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b
274
275multiplyT order a b = multiplyD order (trans b) (trans a)
276
277multiplyD order a b
278 | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b)
279 | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b)
280 | otherwise = multiplyG a b
281
282--------------------------------------------------------------------
283
284data IdxTp = Covariant | Contravariant
285
286-- | multidimensional array
287data Tensor t = T { rank :: Int
288 , dims :: [Int]
289 , idxNm :: [String]
290 , idxTp :: [IdxTp]
291 , ten :: Vector t
292 }
293
diff --git a/lib/Data/Packed/aux.c b/lib/Data/Packed/aux.c
index d772d90..da36035 100644
--- a/lib/Data/Packed/aux.c
+++ b/lib/Data/Packed/aux.c
@@ -48,12 +48,12 @@
48 48
49#define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n) 49#define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n)
50#define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c) 50#define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c)
51#define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array(A##p,A##n) 51#define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array((double*)A##p,A##n)
52#define CMVIEW(A) gsl_matrix_complex_view A = gsl_matrix_complex_view_array(A##p,A##r,A##c) 52#define CMVIEW(A) gsl_matrix_complex_view A = gsl_matrix_complex_view_array((double*)A##p,A##r,A##c)
53#define KDVVIEW(A) gsl_vector_const_view A = gsl_vector_const_view_array(A##p,A##n) 53#define KDVVIEW(A) gsl_vector_const_view A = gsl_vector_const_view_array(A##p,A##n)
54#define KDMVIEW(A) gsl_matrix_const_view A = gsl_matrix_const_view_array(A##p,A##r,A##c) 54#define KDMVIEW(A) gsl_matrix_const_view A = gsl_matrix_const_view_array(A##p,A##r,A##c)
55#define KCVVIEW(A) gsl_vector_complex_const_view A = gsl_vector_complex_const_view_array(A##p,A##n) 55#define KCVVIEW(A) gsl_vector_complex_const_view A = gsl_vector_complex_const_view_array((double*)A##p,A##n)
56#define KCMVIEW(A) gsl_matrix_complex_const_view A = gsl_matrix_complex_const_view_array(A##p,A##r,A##c) 56#define KCMVIEW(A) gsl_matrix_complex_const_view A = gsl_matrix_complex_const_view_array((double*)A##p,A##r,A##c)
57 57
58#define V(a) (&a.vector) 58#define V(a) (&a.vector)
59#define M(a) (&a.matrix) 59#define M(a) (&a.matrix)
@@ -65,26 +65,80 @@
65#define BAD_CODE 1001 65#define BAD_CODE 1001
66#define MEM 1002 66#define MEM 1002
67#define BAD_FILE 1003 67#define BAD_FILE 1003
68#define BAD_TYPE 1004
69 68
70 69
71int trans(int size,KMAT(x),MAT(t)) { 70
71int transR(KRMAT(x),RMAT(t)) {
72 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
73 DEBUGMSG("transR");
74 KDMVIEW(x);
75 DMVIEW(t);
76 int res = gsl_matrix_transpose_memcpy(M(t),M(x));
77 CHECK(res,res);
78 OK
79}
80
81int transC(KCMAT(x),CMAT(t)) {
72 REQUIRES(xr==tc && xc==tr,BAD_SIZE); 82 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
73 DEBUGMSG("trans"); 83 DEBUGMSG("transC");
74 if(size==8) { 84 KCMVIEW(x);
75 DEBUGMSG("trans double"); 85 CMVIEW(t);
76 KDMVIEW(x); 86 int res = gsl_matrix_complex_transpose_memcpy(M(t),M(x));
77 DMVIEW(t); 87 CHECK(res,res);
78 int res = gsl_matrix_transpose_memcpy(M(t),M(x)); 88 OK
79 CHECK(res,res); 89}
80 OK 90
81 } else if (size==16) { 91
82 DEBUGMSG("trans complex double"); 92int constantR(double * pval, RVEC(r)) {
83 KCMVIEW(x); 93 DEBUGMSG("constantR")
84 CMVIEW(t); 94 int k;
85 int res = gsl_matrix_complex_transpose_memcpy(M(t),M(x)); 95 double val = *pval;
86 CHECK(res,res); 96 for(k=0;k<rn;k++) {
87 OK 97 rp[k]=val;
88 } 98 }
89 return BAD_TYPE; 99 OK
100}
101
102int constantC(gsl_complex* pval, CVEC(r)) {
103 DEBUGMSG("constantC")
104 int k;
105 gsl_complex val = *pval;
106 for(k=0;k<rn;k++) {
107 rp[k]=val;
108 }
109 OK
110}
111
112int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)) {
113 //printf("%d %d %d %d %d %d\n",ar,ac,br,bc,rr,rc);
114 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
115 DEBUGMSG("multiplyR (gsl_blas_dgemm)");
116 KDMVIEW(a);
117 KDMVIEW(b);
118 DMVIEW(r);
119 int res = gsl_blas_dgemm(
120 ta?CblasTrans:CblasNoTrans,
121 tb?CblasTrans:CblasNoTrans,
122 1.0, M(a), M(b),
123 0.0, M(r));
124 CHECK(res,res);
125 OK
126}
127
128int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)) {
129 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
130 DEBUGMSG("multiplyC (gsl_blas_zgemm)");
131 KCMVIEW(a);
132 KCMVIEW(b);
133 CMVIEW(r);
134 gsl_complex alpha, beta;
135 GSL_SET_COMPLEX(&alpha,1.,0.);
136 GSL_SET_COMPLEX(&beta,0.,0.);
137 int res = gsl_blas_zgemm(
138 ta?CblasTrans:CblasNoTrans,
139 tb?CblasTrans:CblasNoTrans,
140 alpha, M(a), M(b),
141 beta, M(r));
142 CHECK(res,res);
143 OK
90} 144}
diff --git a/lib/Data/Packed/aux.h b/lib/Data/Packed/aux.h
index c51234a..f45b55a 100644
--- a/lib/Data/Packed/aux.h
+++ b/lib/Data/Packed/aux.h
@@ -1,6 +1,21 @@
1#define VEC(A) int A##n, double*A##p 1#include <gsl/gsl_complex.h>
2#define MAT(A) int A##r, int A##c, double* A##p
3#define KVEC(A) int A##n, const double*A##p
4#define KMAT(A) int A##r, int A##c, const double* A##p
5 2
6int trans(int size, KMAT(x),MAT(t)); 3#define RVEC(A) int A##n, double*A##p
4#define RMAT(A) int A##r, int A##c, double* A##p
5#define KRVEC(A) int A##n, const double*A##p
6#define KRMAT(A) int A##r, int A##c, const double* A##p
7
8#define CVEC(A) int A##n, gsl_complex*A##p
9#define CMAT(A) int A##r, int A##c, gsl_complex* A##p
10#define KCVEC(A) int A##n, const gsl_complex*A##p
11#define KCMAT(A) int A##r, int A##c, const gsl_complex* A##p
12
13
14int transR(KRMAT(x),RMAT(t));
15int transC(KCMAT(x),CMAT(t));
16
17int constantR(double *val , RVEC(r));
18int constantC(gsl_complex *val, CVEC(r));
19
20int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r));
21int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r));