summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-08 10:09:39 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-08 10:09:39 +0200
commite2cb1eff0a954a83e0661ea1e7f70a47ed54e893 (patch)
treef1b214ba3cb8f29f1b17156e7bb5ef72d3f53d39 /packages/base/src
parentccb56d051ce92879a54fcd218bfeac48523b0de0 (diff)
modular C matrix product
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c38
-rw-r--r--packages/base/src/Internal/C/vector-aux.c32
-rw-r--r--packages/base/src/Internal/LAPACK.hs16
-rw-r--r--packages/base/src/Internal/Modular.hs26
-rw-r--r--packages/base/src/Internal/Numeric.hs12
-rw-r--r--packages/base/src/Internal/Vectorized.hs21
6 files changed, 85 insertions, 60 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index 7da6f6a..1601bef 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -1290,29 +1290,25 @@ int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) {
1290} 1290}
1291 1291
1292 1292
1293int multiplyI(KOIMAT(a), KOIMAT(b), OIMAT(r)) { 1293#define MULT_IMP_VER(OP) \
1294 { TRAV(r,i,j) { 1294 { TRAV(r,i,j) { \
1295 int k; 1295 int k; \
1296 AT(r,i,j) = 0; 1296 AT(r,i,j) = 0; \
1297 for (k=0;k<ac;k++) { 1297 for (k=0;k<ac;k++) { \
1298 AT(r,i,j) += AT(a,i,k) * AT(b,k,j); 1298 OP \
1299 } 1299 } \
1300 } 1300 } \
1301 } 1301 }
1302 OK
1303}
1304 1302
1305int multiplyL(KOLMAT(a), KOLMAT(b), OLMAT(r)) { 1303#define MULT_IMP { \
1306 { TRAV(r,i,j) { 1304 if (m==1) { \
1307 int k; 1305 MULT_IMP_VER( AT(r,i,j) += AT(a,i,k) * AT(b,k,j); ) \
1308 AT(r,i,j) = 0; 1306 } else { \
1309 for (k=0;k<ac;k++) { 1307 MULT_IMP_VER( AT(r,i,j) = (AT(r,i,j) + (AT(a,i,k) * AT(b,k,j)) % m) % m ; ) \
1310 AT(r,i,j) += AT(a,i,k) * AT(b,k,j); 1308 } OK }
1311 } 1309
1312 } 1310int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP
1313 } 1311int multiplyL(int32_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP
1314 OK
1315}
1316 1312
1317 1313
1318////////////////// sparse matrix-product /////////////////////////////////////// 1314////////////////// sparse matrix-product ///////////////////////////////////////
diff --git a/packages/base/src/Internal/C/vector-aux.c b/packages/base/src/Internal/C/vector-aux.c
index 70e46bc..580aa1c 100644
--- a/packages/base/src/Internal/C/vector-aux.c
+++ b/packages/base/src/Internal/C/vector-aux.c
@@ -58,20 +58,28 @@ int sumR(KDVEC(x),DVEC(r)) {
58 OK 58 OK
59} 59}
60 60
61int sumI(KIVEC(x),IVEC(r)) { 61int sumI(int m, KIVEC(x),IVEC(r)) {
62 REQUIRES(rn==1,BAD_SIZE); 62 REQUIRES(rn==1,BAD_SIZE);
63 int i; 63 int i;
64 int res = 0; 64 int res = 0;
65 for (i = 0; i < xn; i++) res += xp[i]; 65 if (m==1) {
66 for (i = 0; i < xn; i++) res += xp[i];
67 } else {
68 for (i = 0; i < xn; i++) res = (res + xp[i]) % m;
69 }
66 rp[0] = res; 70 rp[0] = res;
67 OK 71 OK
68} 72}
69 73
70int sumL(KLVEC(x),LVEC(r)) { 74int sumL(int32_t m, KLVEC(x),LVEC(r)) {
71 REQUIRES(rn==1,BAD_SIZE); 75 REQUIRES(rn==1,BAD_SIZE);
72 int i; 76 int i;
73 int res = 0; 77 int res = 0;
74 for (i = 0; i < xn; i++) res += xp[i]; 78 if (m==1) {
79 for (i = 0; i < xn; i++) res += xp[i];
80 } else {
81 for (i = 0; i < xn; i++) res = (res + xp[i]) % m;
82 }
75 rp[0] = res; 83 rp[0] = res;
76 OK 84 OK
77} 85}
@@ -127,20 +135,28 @@ int prodR(KDVEC(x),DVEC(r)) {
127 OK 135 OK
128} 136}
129 137
130int prodI(KIVEC(x),IVEC(r)) { 138int prodI(int m, KIVEC(x),IVEC(r)) {
131 REQUIRES(rn==1,BAD_SIZE); 139 REQUIRES(rn==1,BAD_SIZE);
132 int i; 140 int i;
133 int res = 1; 141 int res = 1;
134 for (i = 0; i < xn; i++) res *= xp[i]; 142 if (m==1) {
143 for (i = 0; i < xn; i++) res *= xp[i];
144 } else {
145 for (i = 0; i < xn; i++) res = (res * xp[i]) % m;
146 }
135 rp[0] = res; 147 rp[0] = res;
136 OK 148 OK
137} 149}
138 150
139int prodL(KLVEC(x),LVEC(r)) { 151int prodL(int32_t m, KLVEC(x),LVEC(r)) {
140 REQUIRES(rn==1,BAD_SIZE); 152 REQUIRES(rn==1,BAD_SIZE);
141 int i; 153 int i;
142 int res = 1; 154 int res = 1;
143 for (i = 0; i < xn; i++) res *= xp[i]; 155 if (m==1) {
156 for (i = 0; i < xn; i++) res *= xp[i];
157 } else {
158 for (i = 0; i < xn; i++) res = (res * xp[i]) % m;
159 }
144 rp[0] = res; 160 rp[0] = res;
145 OK 161 OK
146} 162}
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index 469b0d5..8df568d 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -36,8 +36,8 @@ foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM R
36foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C 36foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C
37foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F 37foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F
38foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q 38foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q
39foreign import ccall unsafe "multiplyI" c_multiplyI :: CInt ::> CInt ::> CInt ::> Ok 39foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> I ::> I ::> I ::> Ok
40foreign import ccall unsafe "multiplyL" c_multiplyL :: Z ::> Z ::> Z ::> Ok 40foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> Z ::> Z ::> Z ::> Ok
41 41
42isT Matrix{order = ColumnMajor} = 0 42isT Matrix{order = ColumnMajor} = 0
43isT Matrix{order = RowMajor} = 1 43isT Matrix{order = RowMajor} = 1
@@ -68,20 +68,20 @@ multiplyF a b = multiplyAux sgemmc "sgemmc" a b
68multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float) 68multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float)
69multiplyQ a b = multiplyAux cgemmc "cgemmc" a b 69multiplyQ a b = multiplyAux cgemmc "cgemmc" a b
70 70
71multiplyI :: Matrix CInt -> Matrix CInt -> Matrix CInt 71multiplyI :: I -> Matrix CInt -> Matrix CInt -> Matrix CInt
72multiplyI a b = unsafePerformIO $ do 72multiplyI m a b = unsafePerformIO $ do
73 when (cols a /= rows b) $ error $ 73 when (cols a /= rows b) $ error $
74 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b 74 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b
75 s <- createMatrix ColumnMajor (rows a) (cols b) 75 s <- createMatrix ColumnMajor (rows a) (cols b)
76 app3 c_multiplyI omat a omat b omat s "c_multiplyI" 76 app3 (c_multiplyI m) omat a omat b omat s "c_multiplyI"
77 return s 77 return s
78 78
79multiplyL :: Matrix Z -> Matrix Z -> Matrix Z 79multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z
80multiplyL a b = unsafePerformIO $ do 80multiplyL m a b = unsafePerformIO $ do
81 when (cols a /= rows b) $ error $ 81 when (cols a /= rows b) $ error $
82 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b 82 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b
83 s <- createMatrix ColumnMajor (rows a) (cols b) 83 s <- createMatrix ColumnMajor (rows a) (cols b)
84 app3 c_multiplyL omat a omat b omat s "c_multiplyL" 84 app3 (c_multiplyL m) omat a omat b omat s "c_multiplyL"
85 return s 85 return s
86 86
87----------------------------------------------------------------------------- 87-----------------------------------------------------------------------------
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index 36ffb57..0274607 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -30,6 +30,8 @@ import Internal.Matrix hiding (mat,size)
30import Internal.Numeric 30import Internal.Numeric
31import Internal.Element 31import Internal.Element
32import Internal.Container 32import Internal.Container
33import Internal.Vectorized (prodI,sumI)
34import Internal.LAPACK (multiplyI)
33import Internal.Util(Indexable(..),gaussElim) 35import Internal.Util(Indexable(..),gaussElim)
34import GHC.TypeLits 36import GHC.TypeLits
35import Data.Proxy(Proxy) 37import Data.Proxy(Proxy)
@@ -145,8 +147,12 @@ instance forall m . KnownNat m => Container Vector (F m)
145 maxIndex' = maxIndex . f2i 147 maxIndex' = maxIndex . f2i
146 minElement' = Mod . minElement . f2i 148 minElement' = Mod . minElement . f2i
147 maxElement' = Mod . maxElement . f2i 149 maxElement' = Mod . maxElement . f2i
148 sumElements' = fromIntegral . sumElements . f2i -- FIXME 150 sumElements' = fromIntegral . sumI m' . f2i
149 prodElements' = fromIntegral . sumElements . f2i -- FIXME 151 where
152 m' = fromIntegral . natVal $ (undefined :: Proxy m)
153 prodElements' = fromIntegral . prodI m' . f2i
154 where
155 m' = fromIntegral . natVal $ (undefined :: Proxy m)
150 step' = i2f . step . f2i 156 step' = i2f . step . f2i
151 find' = findV 157 find' = findV
152 assoc' = assocV 158 assoc' = assocV
@@ -170,14 +176,14 @@ instance Indexable (Vector (F m)) (F m)
170 176
171type instance RealOf (F n) = I 177type instance RealOf (F n) = I
172 178
173
174instance KnownNat m => Product (F m) where 179instance KnownNat m => Product (F m) where
175 norm2 = undefined 180 norm2 = undefined
176 absSum = undefined 181 absSum = undefined
177 norm1 = undefined 182 norm1 = undefined
178 normInf = undefined 183 normInf = undefined
179 multiply = lift2 multiply -- FIXME 184 multiply = lift2 (multiplyI m')
180 185 where
186 m' = fromIntegral . natVal $ (undefined :: Proxy m)
181 187
182instance KnownNat m => Numeric (F m) 188instance KnownNat m => Numeric (F m)
183 189
@@ -236,6 +242,9 @@ test = (ok, info)
236 ad = fromInt a :: Matrix Double 242 ad = fromInt a :: Matrix Double
237 bd = fromInt b :: Matrix Double 243 bd = fromInt b :: Matrix Double
238 244
245 g = (3><3) (repeat (40000)) :: Matrix I
246 gm = fromInt g :: Matrix (F 100000)
247
239 info = do 248 info = do
240 print v 249 print v
241 print m 250 print m
@@ -247,10 +256,17 @@ test = (ok, info)
247 256
248 print $ am <> gaussElim am bm - bm 257 print $ am <> gaussElim am bm - bm
249 print $ ad <> gaussElim ad bd - bd 258 print $ ad <> gaussElim ad bd - bd
259
260 print g
261 print $ g <> g
262 print gm
263 print $ gm <> gm
250 264
251 ok = and 265 ok = and
252 [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) 266 [ toInt (m #> v) == cmod 11 (toInt m #> toInt v )
253 , am <> gaussElim am bm == bm 267 , am <> gaussElim am bm == bm
268 , prodElements (konst (9:: F 10) (12::Int)) == product (replicate 12 (9:: F 10))
269 , gm <> gm == konst 0 (3,3)
254 ] 270 ]
255 271
256 272
diff --git a/packages/base/src/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs
index eb744d1..2ef96bf 100644
--- a/packages/base/src/Internal/Numeric.hs
+++ b/packages/base/src/Internal/Numeric.hs
@@ -113,8 +113,8 @@ instance Container Vector I
113 maxIndex' = emptyErrorV "maxIndex" (fromIntegral . toScalarI MaxIdx) 113 maxIndex' = emptyErrorV "maxIndex" (fromIntegral . toScalarI MaxIdx)
114 minElement' = emptyErrorV "minElement" (toScalarI Min) 114 minElement' = emptyErrorV "minElement" (toScalarI Min)
115 maxElement' = emptyErrorV "maxElement" (toScalarI Max) 115 maxElement' = emptyErrorV "maxElement" (toScalarI Max)
116 sumElements' = sumI 116 sumElements' = sumI 1
117 prodElements' = prodI 117 prodElements' = prodI 1
118 step' = stepI 118 step' = stepI
119 find' = findV 119 find' = findV
120 assoc' = assocV 120 assoc' = assocV
@@ -152,8 +152,8 @@ instance Container Vector Z
152 maxIndex' = emptyErrorV "maxIndex" (fromIntegral . toScalarL MaxIdx) 152 maxIndex' = emptyErrorV "maxIndex" (fromIntegral . toScalarL MaxIdx)
153 minElement' = emptyErrorV "minElement" (toScalarL Min) 153 minElement' = emptyErrorV "minElement" (toScalarL Min)
154 maxElement' = emptyErrorV "maxElement" (toScalarL Max) 154 maxElement' = emptyErrorV "maxElement" (toScalarL Max)
155 sumElements' = sumL 155 sumElements' = sumL 1
156 prodElements' = prodL 156 prodElements' = prodL 1
157 step' = stepL 157 step' = stepL
158 find' = findV 158 find' = findV
159 assoc' = assocV 159 assoc' = assocV
@@ -596,14 +596,14 @@ instance Product I where
596 absSum = emptyVal (sumElements . vectorMapI Abs) 596 absSum = emptyVal (sumElements . vectorMapI Abs)
597 norm1 = absSum 597 norm1 = absSum
598 normInf = emptyVal (maxElement . vectorMapI Abs) 598 normInf = emptyVal (maxElement . vectorMapI Abs)
599 multiply = emptyMul multiplyI 599 multiply = emptyMul (multiplyI 1)
600 600
601instance Product Z where 601instance Product Z where
602 norm2 = undefined 602 norm2 = undefined
603 absSum = emptyVal (sumElements . vectorMapL Abs) 603 absSum = emptyVal (sumElements . vectorMapL Abs)
604 norm1 = absSum 604 norm1 = absSum
605 normInf = emptyVal (maxElement . vectorMapL Abs) 605 normInf = emptyVal (maxElement . vectorMapL Abs)
606 multiply = emptyMul multiplyL 606 multiply = emptyMul (multiplyL 1)
607 607
608 608
609emptyMul m a b 609emptyMul m a b
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs
index ff51494..5c89ac9 100644
--- a/packages/base/src/Internal/Vectorized.hs
+++ b/packages/base/src/Internal/Vectorized.hs
@@ -94,11 +94,9 @@ sumQ = sumg c_sumQ
94sumC :: Vector (Complex Double) -> Complex Double 94sumC :: Vector (Complex Double) -> Complex Double
95sumC = sumg c_sumC 95sumC = sumg c_sumC
96 96
97-- | sum of elements 97sumI m = sumg (c_sumI m)
98sumI :: Vector CInt -> CInt
99sumI = sumg c_sumI
100 98
101sumL = sumg c_sumL 99sumL m = sumg (c_sumL m)
102 100
103sumg f x = unsafePerformIO $ do 101sumg f x = unsafePerformIO $ do
104 r <- createVector 1 102 r <- createVector 1
@@ -111,8 +109,8 @@ foreign import ccall unsafe "sumF" c_sumF :: TVV Float
111foreign import ccall unsafe "sumR" c_sumR :: TVV Double 109foreign import ccall unsafe "sumR" c_sumR :: TVV Double
112foreign import ccall unsafe "sumQ" c_sumQ :: TVV (Complex Float) 110foreign import ccall unsafe "sumQ" c_sumQ :: TVV (Complex Float)
113foreign import ccall unsafe "sumC" c_sumC :: TVV (Complex Double) 111foreign import ccall unsafe "sumC" c_sumC :: TVV (Complex Double)
114foreign import ccall unsafe "sumI" c_sumI :: TVV CInt 112foreign import ccall unsafe "sumI" c_sumI :: I -> TVV I
115foreign import ccall unsafe "sumL" c_sumL :: TVV Z 113foreign import ccall unsafe "sumL" c_sumL :: Z -> TVV Z
116 114
117-- | product of elements 115-- | product of elements
118prodF :: Vector Float -> Float 116prodF :: Vector Float -> Float
@@ -130,11 +128,10 @@ prodQ = prodg c_prodQ
130prodC :: Vector (Complex Double) -> Complex Double 128prodC :: Vector (Complex Double) -> Complex Double
131prodC = prodg c_prodC 129prodC = prodg c_prodC
132 130
133-- | product of elements
134prodI :: Vector CInt -> CInt
135prodI = prodg c_prodI
136 131
137prodL = prodg c_prodL 132prodI = prodg . c_prodI
133
134prodL = prodg . c_prodL
138 135
139prodg f x = unsafePerformIO $ do 136prodg f x = unsafePerformIO $ do
140 r <- createVector 1 137 r <- createVector 1
@@ -146,8 +143,8 @@ foreign import ccall unsafe "prodF" c_prodF :: TVV Float
146foreign import ccall unsafe "prodR" c_prodR :: TVV Double 143foreign import ccall unsafe "prodR" c_prodR :: TVV Double
147foreign import ccall unsafe "prodQ" c_prodQ :: TVV (Complex Float) 144foreign import ccall unsafe "prodQ" c_prodQ :: TVV (Complex Float)
148foreign import ccall unsafe "prodC" c_prodC :: TVV (Complex Double) 145foreign import ccall unsafe "prodC" c_prodC :: TVV (Complex Double)
149foreign import ccall unsafe "prodI" c_prodI :: TVV (CInt) 146foreign import ccall unsafe "prodI" c_prodI :: I -> TVV I
150foreign import ccall unsafe "prodL" c_prodL :: TVV Z 147foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z
151 148
152------------------------------------------------------------------ 149------------------------------------------------------------------
153 150