summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-06-04 19:10:28 +0000
committerAlberto Ruiz <aruiz@um.es>2007-06-04 19:10:28 +0000
commit7430630fa0504296b796223e01cbd417b88650ef (patch)
treec338dea8b82867a4c161fcee5817ed2ca27c7258 /lib/Data/Packed/Internal
parent0a9817cc481fb09f1962eb2c272125e56a123814 (diff)
separation of Internal
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs187
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs32
-rw-r--r--lib/Data/Packed/Internal/Vector.hs164
-rw-r--r--lib/Data/Packed/Internal/aux.c144
-rw-r--r--lib/Data/Packed/Internal/aux.h21
5 files changed, 548 insertions, 0 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
new file mode 100644
index 0000000..2c57c07
--- /dev/null
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -0,0 +1,187 @@
1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
2-----------------------------------------------------------------------------
3-- |
4-- Module : Data.Packed.Internal.Matrix
5-- Copyright : (c) Alberto Ruiz 2007
6-- License : GPL-style
7--
8-- Maintainer : Alberto Ruiz <aruiz@um.es>
9-- Stability : provisional
10-- Portability : portable (uses FFI)
11--
12-- Fundamental types
13--
14-----------------------------------------------------------------------------
15
16module Data.Packed.Internal.Matrix where
17
18import Data.Packed.Internal.Vector
19
20import Foreign hiding (xor)
21import Complex
22import Control.Monad(when)
23import Debug.Trace
24import Data.List(transpose,intersperse)
25import Data.Typeable
26import Data.Maybe(fromJust)
27
28debug x = trace (show x) x
29
30
31data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
32
33-- | 2D array
34data Matrix t = M { rows :: Int
35 , cols :: Int
36 , dat :: Vector t
37 , tdat :: Vector t
38 , isTrans :: Bool
39 , order :: MatrixOrder
40 } deriving Typeable
41
42xor a b = a && not b || b && not a
43
44fortran m = order m == ColumnMajor
45
46cdat m = if fortran m `xor` isTrans m then tdat m else dat m
47fdat m = if fortran m `xor` isTrans m then dat m else tdat m
48
49trans m = m { rows = cols m
50 , cols = rows m
51 , isTrans = not (isTrans m)
52 }
53
54type Mt t s = Int -> Int -> Ptr t -> s
55infixr 6 ::>
56type t ::> s = Mt t s
57
58mat d m f = f (rows m) (cols m) (ptr (d m))
59
60instance (Show a, Storable a) => (Show (Matrix a)) where
61 show m = (sizes++) . dsp . map (map show) . toLists $ m
62 where sizes = "("++show (rows m)++"><"++show (cols m)++")\n"
63
64partit :: Int -> [a] -> [[a]]
65partit _ [] = []
66partit n l = take n l : partit n (drop n l)
67
68toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m
69 | otherwise = partit (cols m) . toList . dat $ m
70
71dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp
72 where
73 mt = transpose as
74 longs = map (maximum . map length) mt
75 mtp = zipWith (\a b -> map (pad a) b) longs mt
76 pad n str = replicate (n - length str) ' ' ++ str
77 unwords' = concat . intersperse ", "
78
79matrixFromVector RowMajor c v =
80 M { rows = r
81 , cols = c
82 , dat = v
83 , tdat = transdata c v r
84 , order = RowMajor
85 , isTrans = False
86 } where r = dim v `div` c -- TODO check mod=0
87
88matrixFromVector ColumnMajor c v =
89 M { rows = r
90 , cols = c
91 , dat = v
92 , tdat = transdata r v c
93 , order = ColumnMajor
94 , isTrans = False
95 } where r = dim v `div` c -- TODO check mod=0
96
97createMatrix order r c = do
98 p <- createVector (r*c)
99 return (matrixFromVector order c p)
100
101transdataG :: Storable a => Int -> Vector a -> Int -> Vector a
102transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d
103
104transdataR :: Int -> Vector Double -> Int -> Vector Double
105transdataR = transdataAux ctransR
106
107transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double)
108transdataC = transdataAux ctransC
109
110transdataAux fun c1 d c2 = unsafePerformIO $ do
111 v <- createVector (dim d)
112 let r1 = dim d `div` c1
113 r2 = dim d `div` c2
114 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d]
115 --putStrLn "---> transdataAux"
116 return v
117
118foreign import ccall safe "aux.h transR"
119 ctransR :: Double ::> Double ::> IO Int
120foreign import ccall safe "aux.h transC"
121 ctransC :: Complex Double ::> Complex Double ::> IO Int
122
123transdata :: Field a => Int -> Vector a -> Int -> Vector a
124transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2
125 | isComp baseOf d = scast $ transdataC c1 (scast d) c2
126 | otherwise = transdataG c1 d c2
127
128--transdata :: Storable a => Int -> Vector a -> Int -> Vector a
129--transdata = transdataG
130--{-# RULES "transdataR" transdata=transdataR #-}
131--{-# RULES "transdataC" transdata=transdataC #-}
132
133-- | extracts the rows of a matrix as a list of vectors
134toRows :: Storable t => Matrix t -> [Vector t]
135toRows m = toRows' 0 where
136 v = cdat m
137 r = rows m
138 c = cols m
139 toRows' k | k == r*c = []
140 | otherwise = subVector k c v : toRows' (k+c)
141
142------------------------------------------------------------------
143
144dotL a b = sum (zipWith (*) a b)
145
146multiplyL a b = [[dotL x y | y <- transpose b] | x <- a]
147
148transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m)
149
150multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b)
151
152------------------------------------------------------------------
153
154gmatC m f | fortran m =
155 if (isTrans m)
156 then f 0 (rows m) (cols m) (ptr (dat m))
157 else f 1 (cols m) (rows m) (ptr (dat m))
158 | otherwise =
159 if isTrans m
160 then f 1 (cols m) (rows m) (ptr (dat m))
161 else f 0 (rows m) (cols m) (ptr (dat m))
162
163
164multiplyAux order fun a b = unsafePerformIO $ do
165 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++
166 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
167 r <- createMatrix order (rows a) (cols b)
168 fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b]
169 return r
170
171foreign import ccall safe "aux.h multiplyR"
172 cmultiplyR :: Int -> Double ::> (Int -> Double ::> (Double ::> IO Int))
173
174foreign import ccall safe "aux.h multiplyC"
175 cmultiplyC :: Int -> Complex Double ::> (Int -> Complex Double ::> (Complex Double ::> IO Int))
176
177multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a
178multiply RowMajor a b = multiplyD RowMajor a b
179multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b
180
181multiplyT order a b = multiplyD order (trans b) (trans a)
182
183multiplyD order a b
184 | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b)
185 | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b)
186 | otherwise = multiplyG a b
187
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
new file mode 100644
index 0000000..11101a9
--- /dev/null
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -0,0 +1,32 @@
1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
2-----------------------------------------------------------------------------
3-- |
4-- Module : Data.Packed.Internal.Tensor
5-- Copyright : (c) Alberto Ruiz 2007
6-- License : GPL-style
7--
8-- Maintainer : Alberto Ruiz <aruiz@um.es>
9-- Stability : provisional
10-- Portability : portable (uses FFI)
11--
12-- Fundamental types
13--
14-----------------------------------------------------------------------------
15
16module Data.Packed.Internal.Tensor where
17
18import Data.Packed.Internal.Vector
19import Data.Packed.Internal.Matrix
20
21
22data IdxTp = Covariant | Contravariant deriving Show
23
24data Tensor t = T { dims :: [(Int,(IdxTp,String))]
25 , ten :: Vector t
26 } deriving Show
27
28rank = length . dims
29
30outer u v = dat (multiply RowMajor r c)
31 where r = matrixFromVector RowMajor 1 u
32 c = matrixFromVector RowMajor (dim v) v
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
new file mode 100644
index 0000000..7dcefeb
--- /dev/null
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -0,0 +1,164 @@
1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
2-----------------------------------------------------------------------------
3-- |
4-- Module : Data.Packed.Internal.Vector
5-- Copyright : (c) Alberto Ruiz 2007
6-- License : GPL-style
7--
8-- Maintainer : Alberto Ruiz <aruiz@um.es>
9-- Stability : provisional
10-- Portability : portable (uses FFI)
11--
12-- Fundamental types
13--
14-----------------------------------------------------------------------------
15
16module Data.Packed.Internal.Vector where
17
18import Foreign
19import Complex
20import Control.Monad(when)
21import Debug.Trace
22import Data.List(transpose,intersperse)
23import Data.Typeable
24import Data.Maybe(fromJust)
25
26debug x = trace (show x) x
27
28----------------------------------------------------------------------
29instance (Storable a, RealFloat a) => Storable (Complex a) where --
30 alignment x = alignment (realPart x) --
31 sizeOf x = 2 * sizeOf (realPart x) --
32 peek p = do --
33 [re,im] <- peekArray 2 (castPtr p) --
34 return (re :+ im) --
35 poke p (a :+ b) = pokeArray (castPtr p) [a,b] --
36----------------------------------------------------------------------
37
38(//) :: x -> (x -> y) -> y
39infixl 0 //
40(//) = flip ($)
41
42check msg ls f = do
43 err <- f
44 when (err/=0) (error msg)
45 mapM_ (touchForeignPtr . fptr) ls
46 return ()
47
48class (Storable a, Typeable a) => Field a where
49instance (Storable a, Typeable a) => Field a where
50
51isReal w x = typeOf (undefined :: Double) == typeOf (w x)
52isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x)
53baseOf v = (v `at` 0)
54
55scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b
56scast = fromJust . cast
57
58
59
60----------------------------------------------------------------------
61
62data Vector t = V { dim :: Int
63 , fptr :: ForeignPtr t
64 , ptr :: Ptr t
65 } deriving Typeable
66
67type Vc t s = Int -> Ptr t -> s
68infixr 5 :>
69type t :> s = Vc t s
70
71vec :: Vector t -> (Vc t s) -> s
72vec v f = f (dim v) (ptr v)
73
74createVector :: Storable a => Int -> IO (Vector a)
75createVector n = do
76 when (n <= 0) $ error ("trying to createVector of dim "++show n)
77 fp <- mallocForeignPtrArray n
78 let p = unsafeForeignPtrToPtr fp
79 --putStrLn ("\n---------> V"++show n)
80 return $ V n fp p
81
82fromList :: Storable a => [a] -> Vector a
83fromList l = unsafePerformIO $ do
84 v <- createVector (length l)
85 let f _ p = pokeArray p l >> return 0
86 f // vec v // check "fromList" []
87 return v
88
89toList :: Storable a => Vector a -> [a]
90toList v = unsafePerformIO $ peekArray (dim v) (ptr v)
91
92n # l = if length l == n then fromList l else error "# with wrong size"
93
94at' :: Storable a => Vector a -> Int -> a
95at' v n = unsafePerformIO $ peekElemOff (ptr v) n
96
97at :: Storable a => Vector a -> Int -> a
98at v n | n >= 0 && n < dim v = at' v n
99 | otherwise = error "vector index out of range"
100
101instance (Show a, Storable a) => (Show (Vector a)) where
102 show v = (show (dim v))++" # " ++ show (toList v)
103
104-- | creates a Vector taking a number of consecutive toList from another Vector
105subVector :: Storable t => Int -- ^ index of the starting element
106 -> Int -- ^ number of toList to extract
107 -> Vector t -- ^ source
108 -> Vector t -- ^ result
109subVector k l (v@V {dim=n, ptr=p, fptr=fp})
110 | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range"
111 | otherwise = unsafePerformIO $ do
112 r <- createVector l
113 let f = copyArray (ptr r) (advancePtr p k) l >> return 0
114 f // check "subVector" [v]
115 return r
116
117subVector' k l (v@V {dim=n, ptr=p, fptr=fp})
118 | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range"
119 | otherwise = v {dim=l, ptr=advancePtr p k}
120
121
122{-
123-- | creates a new Vector by joining a list of Vectors
124join :: Field t => [Vector t] -> Vector t
125join [] = error "joining an empty list"
126join as = unsafePerformIO $ do
127 let tot = sum (map size as)
128 p <- mallocForeignPtrArray tot
129 withForeignPtr p $ \p ->
130 joiner as tot p
131 return (V tot p)
132 where joiner [] _ _ = return ()
133 joiner (V n b : cs) _ p = do
134 withForeignPtr b $ \b' -> copyArray p b' n
135 joiner cs 0 (advancePtr p n)
136-}
137
138
139constantG n x = fromList (replicate n x)
140
141constantR :: Int -> Double -> Vector Double
142constantR = constantAux cconstantR
143
144constantC :: Int -> Complex Double -> Vector (Complex Double)
145constantC = constantAux cconstantC
146
147constantAux fun n x = unsafePerformIO $ do
148 v <- createVector n
149 px <- newArray [x]
150 fun px // vec v // check "constantAux" []
151 free px
152 return v
153
154foreign import ccall safe "aux.h constantR"
155 cconstantR :: Ptr Double -> Double :> IO Int
156
157foreign import ccall safe "aux.h constantC"
158 cconstantC :: Ptr (Complex Double) -> Complex Double :> IO Int
159
160constant :: Field a => Int -> a -> Vector a
161constant n x | isReal id x = scast $ constantR n (scast x)
162 | isComp id x = scast $ constantC n (scast x)
163 | otherwise = constantG n x
164
diff --git a/lib/Data/Packed/Internal/aux.c b/lib/Data/Packed/Internal/aux.c
new file mode 100644
index 0000000..da36035
--- /dev/null
+++ b/lib/Data/Packed/Internal/aux.c
@@ -0,0 +1,144 @@
1#include "aux.h"
2#include <gsl/gsl_blas.h>
3#include <gsl/gsl_linalg.h>
4#include <gsl/gsl_matrix.h>
5#include <gsl/gsl_math.h>
6#include <gsl/gsl_errno.h>
7#include <gsl/gsl_fft_complex.h>
8#include <gsl/gsl_eigen.h>
9#include <gsl/gsl_integration.h>
10#include <gsl/gsl_deriv.h>
11#include <gsl/gsl_poly.h>
12#include <gsl/gsl_multimin.h>
13#include <gsl/gsl_complex.h>
14#include <gsl/gsl_complex_math.h>
15#include <string.h>
16#include <stdio.h>
17
18#define MACRO(B) do {B} while (0)
19#define ERROR(CODE) MACRO(return CODE;)
20#define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);})
21
22#define MIN(A,B) ((A)<(B)?(A):(B))
23#define MAX(A,B) ((A)>(B)?(A):(B))
24
25#ifdef DBG
26#define DEBUGMSG(M) printf("GSL Wrapper "M": "); size_t t0 = time(NULL);
27#define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;);
28#else
29#define DEBUGMSG(M)
30#define OK return 0;
31#endif
32
33
34#define CHECK(RES,CODE) MACRO(if(RES) return CODE;)
35
36#ifdef DBG
37#define DEBUGMAT(MSG,X) printf(MSG" = \n"); gsl_matrix_fprintf(stdout,X,"%f"); printf("\n");
38#else
39#define DEBUGMAT(MSG,X)
40#endif
41
42#ifdef DBG
43#define DEBUGVEC(MSG,X) printf(MSG" = \n"); gsl_vector_fprintf(stdout,X,"%f"); printf("\n");
44#else
45#define DEBUGVEC(MSG,X)
46#endif
47
48
49#define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n)
50#define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c)
51#define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array((double*)A##p,A##n)
52#define CMVIEW(A) gsl_matrix_complex_view A = gsl_matrix_complex_view_array((double*)A##p,A##r,A##c)
53#define KDVVIEW(A) gsl_vector_const_view A = gsl_vector_const_view_array(A##p,A##n)
54#define KDMVIEW(A) gsl_matrix_const_view A = gsl_matrix_const_view_array(A##p,A##r,A##c)
55#define KCVVIEW(A) gsl_vector_complex_const_view A = gsl_vector_complex_const_view_array((double*)A##p,A##n)
56#define KCMVIEW(A) gsl_matrix_complex_const_view A = gsl_matrix_complex_const_view_array((double*)A##p,A##r,A##c)
57
58#define V(a) (&a.vector)
59#define M(a) (&a.matrix)
60
61#define GCVEC(A) int A##n, gsl_complex*A##p
62#define KGCVEC(A) int A##n, const gsl_complex*A##p
63
64#define BAD_SIZE 1000
65#define BAD_CODE 1001
66#define MEM 1002
67#define BAD_FILE 1003
68
69
70
71int transR(KRMAT(x),RMAT(t)) {
72 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
73 DEBUGMSG("transR");
74 KDMVIEW(x);
75 DMVIEW(t);
76 int res = gsl_matrix_transpose_memcpy(M(t),M(x));
77 CHECK(res,res);
78 OK
79}
80
81int transC(KCMAT(x),CMAT(t)) {
82 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
83 DEBUGMSG("transC");
84 KCMVIEW(x);
85 CMVIEW(t);
86 int res = gsl_matrix_complex_transpose_memcpy(M(t),M(x));
87 CHECK(res,res);
88 OK
89}
90
91
92int constantR(double * pval, RVEC(r)) {
93 DEBUGMSG("constantR")
94 int k;
95 double val = *pval;
96 for(k=0;k<rn;k++) {
97 rp[k]=val;
98 }
99 OK
100}
101
102int constantC(gsl_complex* pval, CVEC(r)) {
103 DEBUGMSG("constantC")
104 int k;
105 gsl_complex val = *pval;
106 for(k=0;k<rn;k++) {
107 rp[k]=val;
108 }
109 OK
110}
111
112int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)) {
113 //printf("%d %d %d %d %d %d\n",ar,ac,br,bc,rr,rc);
114 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
115 DEBUGMSG("multiplyR (gsl_blas_dgemm)");
116 KDMVIEW(a);
117 KDMVIEW(b);
118 DMVIEW(r);
119 int res = gsl_blas_dgemm(
120 ta?CblasTrans:CblasNoTrans,
121 tb?CblasTrans:CblasNoTrans,
122 1.0, M(a), M(b),
123 0.0, M(r));
124 CHECK(res,res);
125 OK
126}
127
128int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)) {
129 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
130 DEBUGMSG("multiplyC (gsl_blas_zgemm)");
131 KCMVIEW(a);
132 KCMVIEW(b);
133 CMVIEW(r);
134 gsl_complex alpha, beta;
135 GSL_SET_COMPLEX(&alpha,1.,0.);
136 GSL_SET_COMPLEX(&beta,0.,0.);
137 int res = gsl_blas_zgemm(
138 ta?CblasTrans:CblasNoTrans,
139 tb?CblasTrans:CblasNoTrans,
140 alpha, M(a), M(b),
141 beta, M(r));
142 CHECK(res,res);
143 OK
144}
diff --git a/lib/Data/Packed/Internal/aux.h b/lib/Data/Packed/Internal/aux.h
new file mode 100644
index 0000000..f45b55a
--- /dev/null
+++ b/lib/Data/Packed/Internal/aux.h
@@ -0,0 +1,21 @@
1#include <gsl/gsl_complex.h>
2
3#define RVEC(A) int A##n, double*A##p
4#define RMAT(A) int A##r, int A##c, double* A##p
5#define KRVEC(A) int A##n, const double*A##p
6#define KRMAT(A) int A##r, int A##c, const double* A##p
7
8#define CVEC(A) int A##n, gsl_complex*A##p
9#define CMAT(A) int A##r, int A##c, gsl_complex* A##p
10#define KCVEC(A) int A##n, const gsl_complex*A##p
11#define KCMAT(A) int A##r, int A##c, const gsl_complex* A##p
12
13
14int transR(KRMAT(x),RMAT(t));
15int transC(KCMAT(x),CMAT(t));
16
17int constantR(double *val , RVEC(r));
18int constantC(gsl_complex *val, CVEC(r));
19
20int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r));
21int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r));