From e2cb1eff0a954a83e0661ea1e7f70a47ed54e893 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 8 Jun 2015 10:09:39 +0200 Subject: modular C matrix product --- packages/base/src/Internal/C/lapack-aux.c | 38 ++++++++++++++----------------- packages/base/src/Internal/C/vector-aux.c | 32 +++++++++++++++++++------- packages/base/src/Internal/LAPACK.hs | 16 ++++++------- packages/base/src/Internal/Modular.hs | 26 +++++++++++++++++---- packages/base/src/Internal/Numeric.hs | 12 +++++----- packages/base/src/Internal/Vectorized.hs | 21 ++++++++--------- 6 files changed, 85 insertions(+), 60 deletions(-) diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index 7da6f6a..1601bef 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c @@ -1290,29 +1290,25 @@ int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) { } -int multiplyI(KOIMAT(a), KOIMAT(b), OIMAT(r)) { - { TRAV(r,i,j) { - int k; - AT(r,i,j) = 0; - for (k=0;k CInt -> TMMM R foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q -foreign import ccall unsafe "multiplyI" c_multiplyI :: CInt ::> CInt ::> CInt ::> Ok -foreign import ccall unsafe "multiplyL" c_multiplyL :: Z ::> Z ::> Z ::> Ok +foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> I ::> I ::> I ::> Ok +foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> Z ::> Z ::> Z ::> Ok isT Matrix{order = ColumnMajor} = 0 isT Matrix{order = RowMajor} = 1 @@ -68,20 +68,20 @@ multiplyF a b = multiplyAux sgemmc "sgemmc" a b multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float) multiplyQ a b = multiplyAux cgemmc "cgemmc" a b -multiplyI :: Matrix CInt -> Matrix CInt -> Matrix CInt -multiplyI a b = unsafePerformIO $ do +multiplyI :: I -> Matrix CInt -> Matrix CInt -> Matrix CInt +multiplyI m a b = unsafePerformIO $ do when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b s <- createMatrix ColumnMajor (rows a) (cols b) - app3 c_multiplyI omat a omat b omat s "c_multiplyI" + app3 (c_multiplyI m) omat a omat b omat s "c_multiplyI" return s -multiplyL :: Matrix Z -> Matrix Z -> Matrix Z -multiplyL a b = unsafePerformIO $ do +multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z +multiplyL m a b = unsafePerformIO $ do when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b s <- createMatrix ColumnMajor (rows a) (cols b) - app3 c_multiplyL omat a omat b omat s "c_multiplyL" + app3 (c_multiplyL m) omat a omat b omat s "c_multiplyL" return s ----------------------------------------------------------------------------- diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 36ffb57..0274607 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs @@ -30,6 +30,8 @@ import Internal.Matrix hiding (mat,size) import Internal.Numeric import Internal.Element import Internal.Container +import Internal.Vectorized (prodI,sumI) +import Internal.LAPACK (multiplyI) import Internal.Util(Indexable(..),gaussElim) import GHC.TypeLits import Data.Proxy(Proxy) @@ -145,8 +147,12 @@ instance forall m . KnownNat m => Container Vector (F m) maxIndex' = maxIndex . f2i minElement' = Mod . minElement . f2i maxElement' = Mod . maxElement . f2i - sumElements' = fromIntegral . sumElements . f2i -- FIXME - prodElements' = fromIntegral . sumElements . f2i -- FIXME + sumElements' = fromIntegral . sumI m' . f2i + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) + prodElements' = fromIntegral . prodI m' . f2i + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) step' = i2f . step . f2i find' = findV assoc' = assocV @@ -170,14 +176,14 @@ instance Indexable (Vector (F m)) (F m) type instance RealOf (F n) = I - instance KnownNat m => Product (F m) where norm2 = undefined absSum = undefined norm1 = undefined normInf = undefined - multiply = lift2 multiply -- FIXME - + multiply = lift2 (multiplyI m') + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) instance KnownNat m => Numeric (F m) @@ -236,6 +242,9 @@ test = (ok, info) ad = fromInt a :: Matrix Double bd = fromInt b :: Matrix Double + g = (3><3) (repeat (40000)) :: Matrix I + gm = fromInt g :: Matrix (F 100000) + info = do print v print m @@ -247,10 +256,17 @@ test = (ok, info) print $ am <> gaussElim am bm - bm print $ ad <> gaussElim ad bd - bd + + print g + print $ g <> g + print gm + print $ gm <> gm ok = and [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) , am <> gaussElim am bm == bm + , prodElements (konst (9:: F 10) (12::Int)) == product (replicate 12 (9:: F 10)) + , gm <> gm == konst 0 (3,3) ] diff --git a/packages/base/src/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs index eb744d1..2ef96bf 100644 --- a/packages/base/src/Internal/Numeric.hs +++ b/packages/base/src/Internal/Numeric.hs @@ -113,8 +113,8 @@ instance Container Vector I maxIndex' = emptyErrorV "maxIndex" (fromIntegral . toScalarI MaxIdx) minElement' = emptyErrorV "minElement" (toScalarI Min) maxElement' = emptyErrorV "maxElement" (toScalarI Max) - sumElements' = sumI - prodElements' = prodI + sumElements' = sumI 1 + prodElements' = prodI 1 step' = stepI find' = findV assoc' = assocV @@ -152,8 +152,8 @@ instance Container Vector Z maxIndex' = emptyErrorV "maxIndex" (fromIntegral . toScalarL MaxIdx) minElement' = emptyErrorV "minElement" (toScalarL Min) maxElement' = emptyErrorV "maxElement" (toScalarL Max) - sumElements' = sumL - prodElements' = prodL + sumElements' = sumL 1 + prodElements' = prodL 1 step' = stepL find' = findV assoc' = assocV @@ -596,14 +596,14 @@ instance Product I where absSum = emptyVal (sumElements . vectorMapI Abs) norm1 = absSum normInf = emptyVal (maxElement . vectorMapI Abs) - multiply = emptyMul multiplyI + multiply = emptyMul (multiplyI 1) instance Product Z where norm2 = undefined absSum = emptyVal (sumElements . vectorMapL Abs) norm1 = absSum normInf = emptyVal (maxElement . vectorMapL Abs) - multiply = emptyMul multiplyL + multiply = emptyMul (multiplyL 1) emptyMul m a b diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs index ff51494..5c89ac9 100644 --- a/packages/base/src/Internal/Vectorized.hs +++ b/packages/base/src/Internal/Vectorized.hs @@ -94,11 +94,9 @@ sumQ = sumg c_sumQ sumC :: Vector (Complex Double) -> Complex Double sumC = sumg c_sumC --- | sum of elements -sumI :: Vector CInt -> CInt -sumI = sumg c_sumI +sumI m = sumg (c_sumI m) -sumL = sumg c_sumL +sumL m = sumg (c_sumL m) sumg f x = unsafePerformIO $ do r <- createVector 1 @@ -111,8 +109,8 @@ foreign import ccall unsafe "sumF" c_sumF :: TVV Float foreign import ccall unsafe "sumR" c_sumR :: TVV Double foreign import ccall unsafe "sumQ" c_sumQ :: TVV (Complex Float) foreign import ccall unsafe "sumC" c_sumC :: TVV (Complex Double) -foreign import ccall unsafe "sumI" c_sumI :: TVV CInt -foreign import ccall unsafe "sumL" c_sumL :: TVV Z +foreign import ccall unsafe "sumI" c_sumI :: I -> TVV I +foreign import ccall unsafe "sumL" c_sumL :: Z -> TVV Z -- | product of elements prodF :: Vector Float -> Float @@ -130,11 +128,10 @@ prodQ = prodg c_prodQ prodC :: Vector (Complex Double) -> Complex Double prodC = prodg c_prodC --- | product of elements -prodI :: Vector CInt -> CInt -prodI = prodg c_prodI -prodL = prodg c_prodL +prodI = prodg . c_prodI + +prodL = prodg . c_prodL prodg f x = unsafePerformIO $ do r <- createVector 1 @@ -146,8 +143,8 @@ foreign import ccall unsafe "prodF" c_prodF :: TVV Float foreign import ccall unsafe "prodR" c_prodR :: TVV Double foreign import ccall unsafe "prodQ" c_prodQ :: TVV (Complex Float) foreign import ccall unsafe "prodC" c_prodC :: TVV (Complex Double) -foreign import ccall unsafe "prodI" c_prodI :: TVV (CInt) -foreign import ccall unsafe "prodL" c_prodL :: TVV Z +foreign import ccall unsafe "prodI" c_prodI :: I -> TVV I +foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z ------------------------------------------------------------------ -- cgit v1.2.3