From 0a9817cc481fb09f1962eb2c272125e56a123814 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 4 Jun 2007 08:34:45 +0000 Subject: fortran/C --- lib/Data/Packed/Internal.hs | 286 +++++++++++++++++++++++++++++++++----------- lib/Data/Packed/aux.c | 98 +++++++++++---- lib/Data/Packed/aux.h | 25 +++- 3 files changed, 312 insertions(+), 97 deletions(-) (limited to 'lib') diff --git a/lib/Data/Packed/Internal.hs b/lib/Data/Packed/Internal.hs index 5e19e58..b06f044 100644 --- a/lib/Data/Packed/Internal.hs +++ b/lib/Data/Packed/Internal.hs @@ -1,3 +1,4 @@ +{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} ----------------------------------------------------------------------------- -- | -- Module : Data.Packed.Internal @@ -14,39 +15,16 @@ module Data.Packed.Internal where -import Foreign +import Foreign hiding (xor) import Complex import Control.Monad(when) import Debug.Trace +import Data.List(transpose,intersperse) +import Data.Typeable +import Data.Maybe(fromJust) debug x = trace (show x) x --- | 1D array -data Vector t = V { dim :: Int - , fptr :: ForeignPtr t - , ptr :: Ptr t - } - -data TransMode = NoTrans | Trans | ConjTrans - --- | 2D array -data Matrix t = M { rows :: Int - , cols :: Int - , mat :: Vector t - , trMode :: TransMode - , isCOrder :: Bool - } - -data IdxTp = Covariant | Contravariant - --- | multidimensional array -data Tensor t = T { rank :: Int - , dims :: [Int] - , idxNm :: [String] - , idxTp :: [IdxTp] - , ten :: Vector t - } - ---------------------------------------------------------------------- instance (Storable a, RealFloat a) => Storable (Complex a) where -- alignment x = alignment (realPart x) -- @@ -57,36 +35,36 @@ instance (Storable a, RealFloat a) => Storable (Complex a) where -- poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- ---------------------------------------------------------------------- - --- f // vec a // vec b // vec res // check "vector add" [a,b] - (//) :: x -> (x -> y) -> y infixl 0 // (//) = flip ($) -vec :: Vector a -> (Int -> Ptr b -> t) -> t -vec v f = f (dim v) (castPtr $ ptr v) - -mata :: Matrix a -> (Int-> Int -> Ptr b -> t) -> t -mata m f = f (rows m) (cols m) (castPtr $ ptr (mat m)) - -pd2pc :: Ptr Double -> Ptr (Complex (Double)) -pd2pc = castPtr - -pc2pd :: Ptr (Complex (Double)) -> Ptr Double -pc2pd = castPtr - check msg ls f = do err <- f when (err/=0) (error msg) mapM_ (touchForeignPtr . fptr) ls return () +---------------------------------------------------------------------- + +data Vector t = V { dim :: Int + , fptr :: ForeignPtr t + , ptr :: Ptr t + } deriving Typeable + +type Vc t s = Int -> Ptr t -> s +infixr 5 :> +type t :> s = Vc t s + +vec :: Vector t -> (Vc t s) -> s +vec v f = f (dim v) (ptr v) + createVector :: Storable a => Int -> IO (Vector a) createVector n = do when (n <= 0) $ error ("trying to createVector of dim "++show n) fp <- mallocForeignPtrArray n let p = unsafeForeignPtrToPtr fp + --putStrLn ("\n---------> V"++show n) return $ V n fp p fromList :: Storable a => [a] -> Vector a @@ -99,6 +77,8 @@ fromList l = unsafePerformIO $ do toList :: Storable a => Vector a -> [a] toList v = unsafePerformIO $ peekArray (dim v) (ptr v) +n # l = if length l == n then fromList l else error "# with wrong size" + at' :: Storable a => Vector a -> Int -> a at' v n = unsafePerformIO $ peekElemOff (ptr v) n @@ -106,42 +86,208 @@ at :: Storable a => Vector a -> Int -> a at v n | n >= 0 && n < dim v = at' v n | otherwise = error "vector index out of range" -dsv v = sizeOf (v `at` 0) -dsm m = (dsv.mat) m +instance (Show a, Storable a) => (Show (Vector a)) where + show v = (show (dim v))++" # " ++ show (toList v) -constant :: Storable a => Int -> a -> Vector a -constant n x = unsafePerformIO $ do - v <- createVector n - let f k p | k == n = return 0 - | otherwise = pokeElemOff p k x >> f (k+1) p - const (f 0) // vec v // check "constant" [] - return v +------------------------------------------------------------------------ -instance (Show a, Storable a) => (Show (Vector a)) where - show v = "fromList " ++ show (toList v) +data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) + +-- | 2D array +data Matrix t = M { rows :: Int + , cols :: Int + , cmat :: Vector t + , fmat :: Vector t + , isTrans :: Bool + , order :: MatrixOrder + } deriving Typeable + +xor a b = a && not b || b && not a + +fortran m = order m == ColumnMajor + +dat m = if fortran m `xor` isTrans m then fmat m else cmat m + +pref m = if fortran m then fmat m else cmat m + +trans m = m { rows = cols m + , cols = rows m + , isTrans = not (isTrans m) + } + +type Mt t s = Int -> Int -> Ptr t -> s +infixr 6 ::> +type t ::> s = Mt t s + +mat :: Matrix t -> (Mt t s) -> s +mat m f = f (rows m) (cols m) (ptr (dat m)) + +gmat m f | fortran m = + if (isTrans m) + then f 0 (rows m) (cols m) (ptr (fmat m)) + else f 1 (cols m) (rows m) (ptr (fmat m)) + | otherwise = + if isTrans m + then f 1 (cols m) (rows m) (ptr (cmat m)) + else f 0 (rows m) (cols m) (ptr (cmat m)) instance (Show a, Storable a) => (Show (Matrix a)) where - show m = "reshape "++show (cols m) ++ " $ fromList " ++ show (toList (mat m)) + show m = (sizes++) . dsp . map (map show) . toLists $ m + where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" + +partit :: Int -> [a] -> [[a]] +partit _ [] = [] +partit n l = take n l : partit n (drop n l) + +toLists m = partit (cols m) . toList . cmat $ m -reshape :: Storable a => Int -> Vector a -> Matrix a -reshape n v = M { rows = dim v `div` n - , cols = n - , mat = v - , trMode = NoTrans - , isCOrder = True - } +dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp + where + mt = transpose as + longs = map (maximum . map length) mt + mtp = zipWith (\a b -> map (pad a) b) longs mt + pad n str = replicate (n - length str) ' ' ++ str + unwords' = concat . intersperse ", " -createMatrix r c = do +matrixFromVector RowMajor c v = + M { rows = r + , cols = c + , cmat = v + , fmat = transdata c v r + , order = RowMajor + , isTrans = False + } where r = dim v `div` c -- TODO check mod=0 + +matrixFromVector ColumnMajor c v = + M { rows = r + , cols = c + , fmat = v + , cmat = transdata c v r + , order = ColumnMajor + , isTrans = False + } where r = dim v `div` c -- TODO check mod=0 + +createMatrix order r c = do p <- createVector (r*c) - return (reshape c p) + return (matrixFromVector order c p) + +transdataG :: Storable a => Int -> Vector a -> Int -> Vector a +transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d + +transdataR :: Int -> Vector Double -> Int -> Vector Double +transdataR = transdataAux ctransR + +transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) +transdataC = transdataAux ctransC + +transdataAux fun c1 d c2 = unsafePerformIO $ do + v <- createVector (dim d) + let r1 = dim d `div` c1 + r2 = dim d `div` c2 + fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] + --putStrLn "---> transdataAux" + return v + +foreign import ccall safe "aux.h transR" + ctransR :: Double ::> Double ::> IO Int +foreign import ccall safe "aux.h transC" + ctransC :: Complex Double ::> Complex Double ::> IO Int + + +class (Storable a, Typeable a) => Field a where +instance (Storable a, Typeable a) => Field a where + +isReal w x = typeOf (undefined :: Double) == typeOf (w x) +isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) +baseOf v = (v `at` 0) + +scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b +scast = fromJust . cast + +transdata :: Field a => Int -> Vector a -> Int -> Vector a +transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 + | isComp baseOf d = scast $ transdataC c1 (scast d) c2 + | otherwise = transdataG c1 d c2 + +--transdata :: Storable a => Int -> Vector a -> Int -> Vector a +--transdata = transdataG +--{-# RULES "transdataR" transdata=transdataR #-} +--{-# RULES "transdataC" transdata=transdataC #-} + +------------------------------------------------------------------ + +constantG n x = fromList (replicate n x) + +constantR :: Int -> Double -> Vector Double +constantR = constantAux cconstantR + +constantC :: Int -> Complex Double -> Vector (Complex Double) +constantC = constantAux cconstantC + +constantAux fun n x = unsafePerformIO $ do + v <- createVector n + px <- newArray [x] + fun px // vec v // check "constantAux" [] + free px + return v -type CMat s = Int -> Int -> Ptr Double -> s -type CVec s = Int -> Ptr Double -> s +foreign import ccall safe "aux.h constantR" + cconstantR :: Ptr Double -> Double :> IO Int -foreign import ccall safe "aux.h trans" ctrans :: Int -> CMat (CMat (IO Int)) +foreign import ccall safe "aux.h constantC" + cconstantC :: Ptr (Complex Double) -> Complex Double :> IO Int -trans :: Storable a => Matrix a -> Matrix a -trans m = unsafePerformIO $ do - r <- createMatrix (cols m) (rows m) - ctrans (dsm m) // mata m // mata r // check "trans" [mat m] +constant :: Field a => Int -> a -> Vector a +constant n x | isReal id x = scast $ constantR n (scast x) + | isComp id x = scast $ constantC n (scast x) + | otherwise = constantG n x + +------------------------------------------------------------------ + +dotL a b = sum (zipWith (*) a b) + +multiplyL a b = [[dotL x y | y <- transpose b] | x <- a] + +transL m = m {rows = cols m, cols = rows m, cmat = v, fmat = cmat m} + where v = transdataG (cols m) (cmat m) (rows m) + +------------------------------------------------------------------ + +multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) + +multiplyAux order fun a b = unsafePerformIO $ do + when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ + show (rows a,cols a) ++ " x " ++ show (rows b, cols b) + r <- createMatrix order (rows a) (cols b) + fun // gmat a // gmat b // mat r // check "multiplyAux" [pref a, pref b] return r + +foreign import ccall safe "aux.h multiplyR" + cmultiplyR :: Int -> Double ::> (Int -> Double ::> (Double ::> IO Int)) + +foreign import ccall safe "aux.h multiplyC" + cmultiplyC :: Int -> Complex Double ::> (Int -> Complex Double ::> (Complex Double ::> IO Int)) + +multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a +multiply RowMajor a b = multiplyD RowMajor a b +multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b + +multiplyT order a b = multiplyD order (trans b) (trans a) + +multiplyD order a b + | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) + | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b) + | otherwise = multiplyG a b + +-------------------------------------------------------------------- + +data IdxTp = Covariant | Contravariant + +-- | multidimensional array +data Tensor t = T { rank :: Int + , dims :: [Int] + , idxNm :: [String] + , idxTp :: [IdxTp] + , ten :: Vector t + } + diff --git a/lib/Data/Packed/aux.c b/lib/Data/Packed/aux.c index d772d90..da36035 100644 --- a/lib/Data/Packed/aux.c +++ b/lib/Data/Packed/aux.c @@ -48,12 +48,12 @@ #define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n) #define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c) -#define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array(A##p,A##n) -#define CMVIEW(A) gsl_matrix_complex_view A = gsl_matrix_complex_view_array(A##p,A##r,A##c) +#define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array((double*)A##p,A##n) +#define CMVIEW(A) gsl_matrix_complex_view A = gsl_matrix_complex_view_array((double*)A##p,A##r,A##c) #define KDVVIEW(A) gsl_vector_const_view A = gsl_vector_const_view_array(A##p,A##n) #define KDMVIEW(A) gsl_matrix_const_view A = gsl_matrix_const_view_array(A##p,A##r,A##c) -#define KCVVIEW(A) gsl_vector_complex_const_view A = gsl_vector_complex_const_view_array(A##p,A##n) -#define KCMVIEW(A) gsl_matrix_complex_const_view A = gsl_matrix_complex_const_view_array(A##p,A##r,A##c) +#define KCVVIEW(A) gsl_vector_complex_const_view A = gsl_vector_complex_const_view_array((double*)A##p,A##n) +#define KCMVIEW(A) gsl_matrix_complex_const_view A = gsl_matrix_complex_const_view_array((double*)A##p,A##r,A##c) #define V(a) (&a.vector) #define M(a) (&a.matrix) @@ -65,26 +65,80 @@ #define BAD_CODE 1001 #define MEM 1002 #define BAD_FILE 1003 -#define BAD_TYPE 1004 -int trans(int size,KMAT(x),MAT(t)) { + +int transR(KRMAT(x),RMAT(t)) { + REQUIRES(xr==tc && xc==tr,BAD_SIZE); + DEBUGMSG("transR"); + KDMVIEW(x); + DMVIEW(t); + int res = gsl_matrix_transpose_memcpy(M(t),M(x)); + CHECK(res,res); + OK +} + +int transC(KCMAT(x),CMAT(t)) { REQUIRES(xr==tc && xc==tr,BAD_SIZE); - DEBUGMSG("trans"); - if(size==8) { - DEBUGMSG("trans double"); - KDMVIEW(x); - DMVIEW(t); - int res = gsl_matrix_transpose_memcpy(M(t),M(x)); - CHECK(res,res); - OK - } else if (size==16) { - DEBUGMSG("trans complex double"); - KCMVIEW(x); - CMVIEW(t); - int res = gsl_matrix_complex_transpose_memcpy(M(t),M(x)); - CHECK(res,res); - OK + DEBUGMSG("transC"); + KCMVIEW(x); + CMVIEW(t); + int res = gsl_matrix_complex_transpose_memcpy(M(t),M(x)); + CHECK(res,res); + OK +} + + +int constantR(double * pval, RVEC(r)) { + DEBUGMSG("constantR") + int k; + double val = *pval; + for(k=0;k -int trans(int size, KMAT(x),MAT(t)); +#define RVEC(A) int A##n, double*A##p +#define RMAT(A) int A##r, int A##c, double* A##p +#define KRVEC(A) int A##n, const double*A##p +#define KRMAT(A) int A##r, int A##c, const double* A##p + +#define CVEC(A) int A##n, gsl_complex*A##p +#define CMAT(A) int A##r, int A##c, gsl_complex* A##p +#define KCVEC(A) int A##n, const gsl_complex*A##p +#define KCMAT(A) int A##r, int A##c, const gsl_complex* A##p + + +int transR(KRMAT(x),RMAT(t)); +int transC(KCMAT(x),CMAT(t)); + +int constantR(double *val , RVEC(r)); +int constantC(gsl_complex *val, CVEC(r)); + +int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)); +int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)); -- cgit v1.2.3