diff options
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r-- | packages/base/src/Internal/C/lapack-aux.c | 38 | ||||
-rw-r--r-- | packages/base/src/Internal/C/vector-aux.c | 32 | ||||
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 16 | ||||
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 26 | ||||
-rw-r--r-- | packages/base/src/Internal/Numeric.hs | 12 | ||||
-rw-r--r-- | packages/base/src/Internal/Vectorized.hs | 21 |
6 files changed, 85 insertions, 60 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index 7da6f6a..1601bef 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c | |||
@@ -1290,29 +1290,25 @@ int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) { | |||
1290 | } | 1290 | } |
1291 | 1291 | ||
1292 | 1292 | ||
1293 | int multiplyI(KOIMAT(a), KOIMAT(b), OIMAT(r)) { | 1293 | #define MULT_IMP_VER(OP) \ |
1294 | { TRAV(r,i,j) { | 1294 | { TRAV(r,i,j) { \ |
1295 | int k; | 1295 | int k; \ |
1296 | AT(r,i,j) = 0; | 1296 | AT(r,i,j) = 0; \ |
1297 | for (k=0;k<ac;k++) { | 1297 | for (k=0;k<ac;k++) { \ |
1298 | AT(r,i,j) += AT(a,i,k) * AT(b,k,j); | 1298 | OP \ |
1299 | } | 1299 | } \ |
1300 | } | 1300 | } \ |
1301 | } | 1301 | } |
1302 | OK | ||
1303 | } | ||
1304 | 1302 | ||
1305 | int multiplyL(KOLMAT(a), KOLMAT(b), OLMAT(r)) { | 1303 | #define MULT_IMP { \ |
1306 | { TRAV(r,i,j) { | 1304 | if (m==1) { \ |
1307 | int k; | 1305 | MULT_IMP_VER( AT(r,i,j) += AT(a,i,k) * AT(b,k,j); ) \ |
1308 | AT(r,i,j) = 0; | 1306 | } else { \ |
1309 | for (k=0;k<ac;k++) { | 1307 | MULT_IMP_VER( AT(r,i,j) = (AT(r,i,j) + (AT(a,i,k) * AT(b,k,j)) % m) % m ; ) \ |
1310 | AT(r,i,j) += AT(a,i,k) * AT(b,k,j); | 1308 | } OK } |
1311 | } | 1309 | |
1312 | } | 1310 | int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP |
1313 | } | 1311 | int multiplyL(int32_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP |
1314 | OK | ||
1315 | } | ||
1316 | 1312 | ||
1317 | 1313 | ||
1318 | ////////////////// sparse matrix-product /////////////////////////////////////// | 1314 | ////////////////// sparse matrix-product /////////////////////////////////////// |
diff --git a/packages/base/src/Internal/C/vector-aux.c b/packages/base/src/Internal/C/vector-aux.c index 70e46bc..580aa1c 100644 --- a/packages/base/src/Internal/C/vector-aux.c +++ b/packages/base/src/Internal/C/vector-aux.c | |||
@@ -58,20 +58,28 @@ int sumR(KDVEC(x),DVEC(r)) { | |||
58 | OK | 58 | OK |
59 | } | 59 | } |
60 | 60 | ||
61 | int sumI(KIVEC(x),IVEC(r)) { | 61 | int sumI(int m, KIVEC(x),IVEC(r)) { |
62 | REQUIRES(rn==1,BAD_SIZE); | 62 | REQUIRES(rn==1,BAD_SIZE); |
63 | int i; | 63 | int i; |
64 | int res = 0; | 64 | int res = 0; |
65 | for (i = 0; i < xn; i++) res += xp[i]; | 65 | if (m==1) { |
66 | for (i = 0; i < xn; i++) res += xp[i]; | ||
67 | } else { | ||
68 | for (i = 0; i < xn; i++) res = (res + xp[i]) % m; | ||
69 | } | ||
66 | rp[0] = res; | 70 | rp[0] = res; |
67 | OK | 71 | OK |
68 | } | 72 | } |
69 | 73 | ||
70 | int sumL(KLVEC(x),LVEC(r)) { | 74 | int sumL(int32_t m, KLVEC(x),LVEC(r)) { |
71 | REQUIRES(rn==1,BAD_SIZE); | 75 | REQUIRES(rn==1,BAD_SIZE); |
72 | int i; | 76 | int i; |
73 | int res = 0; | 77 | int res = 0; |
74 | for (i = 0; i < xn; i++) res += xp[i]; | 78 | if (m==1) { |
79 | for (i = 0; i < xn; i++) res += xp[i]; | ||
80 | } else { | ||
81 | for (i = 0; i < xn; i++) res = (res + xp[i]) % m; | ||
82 | } | ||
75 | rp[0] = res; | 83 | rp[0] = res; |
76 | OK | 84 | OK |
77 | } | 85 | } |
@@ -127,20 +135,28 @@ int prodR(KDVEC(x),DVEC(r)) { | |||
127 | OK | 135 | OK |
128 | } | 136 | } |
129 | 137 | ||
130 | int prodI(KIVEC(x),IVEC(r)) { | 138 | int prodI(int m, KIVEC(x),IVEC(r)) { |
131 | REQUIRES(rn==1,BAD_SIZE); | 139 | REQUIRES(rn==1,BAD_SIZE); |
132 | int i; | 140 | int i; |
133 | int res = 1; | 141 | int res = 1; |
134 | for (i = 0; i < xn; i++) res *= xp[i]; | 142 | if (m==1) { |
143 | for (i = 0; i < xn; i++) res *= xp[i]; | ||
144 | } else { | ||
145 | for (i = 0; i < xn; i++) res = (res * xp[i]) % m; | ||
146 | } | ||
135 | rp[0] = res; | 147 | rp[0] = res; |
136 | OK | 148 | OK |
137 | } | 149 | } |
138 | 150 | ||
139 | int prodL(KLVEC(x),LVEC(r)) { | 151 | int prodL(int32_t m, KLVEC(x),LVEC(r)) { |
140 | REQUIRES(rn==1,BAD_SIZE); | 152 | REQUIRES(rn==1,BAD_SIZE); |
141 | int i; | 153 | int i; |
142 | int res = 1; | 154 | int res = 1; |
143 | for (i = 0; i < xn; i++) res *= xp[i]; | 155 | if (m==1) { |
156 | for (i = 0; i < xn; i++) res *= xp[i]; | ||
157 | } else { | ||
158 | for (i = 0; i < xn; i++) res = (res * xp[i]) % m; | ||
159 | } | ||
144 | rp[0] = res; | 160 | rp[0] = res; |
145 | OK | 161 | OK |
146 | } | 162 | } |
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 469b0d5..8df568d 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -36,8 +36,8 @@ foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM R | |||
36 | foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C | 36 | foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C |
37 | foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F | 37 | foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F |
38 | foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q | 38 | foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q |
39 | foreign import ccall unsafe "multiplyI" c_multiplyI :: CInt ::> CInt ::> CInt ::> Ok | 39 | foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> I ::> I ::> I ::> Ok |
40 | foreign import ccall unsafe "multiplyL" c_multiplyL :: Z ::> Z ::> Z ::> Ok | 40 | foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> Z ::> Z ::> Z ::> Ok |
41 | 41 | ||
42 | isT Matrix{order = ColumnMajor} = 0 | 42 | isT Matrix{order = ColumnMajor} = 0 |
43 | isT Matrix{order = RowMajor} = 1 | 43 | isT Matrix{order = RowMajor} = 1 |
@@ -68,20 +68,20 @@ multiplyF a b = multiplyAux sgemmc "sgemmc" a b | |||
68 | multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float) | 68 | multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float) |
69 | multiplyQ a b = multiplyAux cgemmc "cgemmc" a b | 69 | multiplyQ a b = multiplyAux cgemmc "cgemmc" a b |
70 | 70 | ||
71 | multiplyI :: Matrix CInt -> Matrix CInt -> Matrix CInt | 71 | multiplyI :: I -> Matrix CInt -> Matrix CInt -> Matrix CInt |
72 | multiplyI a b = unsafePerformIO $ do | 72 | multiplyI m a b = unsafePerformIO $ do |
73 | when (cols a /= rows b) $ error $ | 73 | when (cols a /= rows b) $ error $ |
74 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b | 74 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b |
75 | s <- createMatrix ColumnMajor (rows a) (cols b) | 75 | s <- createMatrix ColumnMajor (rows a) (cols b) |
76 | app3 c_multiplyI omat a omat b omat s "c_multiplyI" | 76 | app3 (c_multiplyI m) omat a omat b omat s "c_multiplyI" |
77 | return s | 77 | return s |
78 | 78 | ||
79 | multiplyL :: Matrix Z -> Matrix Z -> Matrix Z | 79 | multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z |
80 | multiplyL a b = unsafePerformIO $ do | 80 | multiplyL m a b = unsafePerformIO $ do |
81 | when (cols a /= rows b) $ error $ | 81 | when (cols a /= rows b) $ error $ |
82 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b | 82 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b |
83 | s <- createMatrix ColumnMajor (rows a) (cols b) | 83 | s <- createMatrix ColumnMajor (rows a) (cols b) |
84 | app3 c_multiplyL omat a omat b omat s "c_multiplyL" | 84 | app3 (c_multiplyL m) omat a omat b omat s "c_multiplyL" |
85 | return s | 85 | return s |
86 | 86 | ||
87 | ----------------------------------------------------------------------------- | 87 | ----------------------------------------------------------------------------- |
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 36ffb57..0274607 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -30,6 +30,8 @@ import Internal.Matrix hiding (mat,size) | |||
30 | import Internal.Numeric | 30 | import Internal.Numeric |
31 | import Internal.Element | 31 | import Internal.Element |
32 | import Internal.Container | 32 | import Internal.Container |
33 | import Internal.Vectorized (prodI,sumI) | ||
34 | import Internal.LAPACK (multiplyI) | ||
33 | import Internal.Util(Indexable(..),gaussElim) | 35 | import Internal.Util(Indexable(..),gaussElim) |
34 | import GHC.TypeLits | 36 | import GHC.TypeLits |
35 | import Data.Proxy(Proxy) | 37 | import Data.Proxy(Proxy) |
@@ -145,8 +147,12 @@ instance forall m . KnownNat m => Container Vector (F m) | |||
145 | maxIndex' = maxIndex . f2i | 147 | maxIndex' = maxIndex . f2i |
146 | minElement' = Mod . minElement . f2i | 148 | minElement' = Mod . minElement . f2i |
147 | maxElement' = Mod . maxElement . f2i | 149 | maxElement' = Mod . maxElement . f2i |
148 | sumElements' = fromIntegral . sumElements . f2i -- FIXME | 150 | sumElements' = fromIntegral . sumI m' . f2i |
149 | prodElements' = fromIntegral . sumElements . f2i -- FIXME | 151 | where |
152 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
153 | prodElements' = fromIntegral . prodI m' . f2i | ||
154 | where | ||
155 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
150 | step' = i2f . step . f2i | 156 | step' = i2f . step . f2i |
151 | find' = findV | 157 | find' = findV |
152 | assoc' = assocV | 158 | assoc' = assocV |
@@ -170,14 +176,14 @@ instance Indexable (Vector (F m)) (F m) | |||
170 | 176 | ||
171 | type instance RealOf (F n) = I | 177 | type instance RealOf (F n) = I |
172 | 178 | ||
173 | |||
174 | instance KnownNat m => Product (F m) where | 179 | instance KnownNat m => Product (F m) where |
175 | norm2 = undefined | 180 | norm2 = undefined |
176 | absSum = undefined | 181 | absSum = undefined |
177 | norm1 = undefined | 182 | norm1 = undefined |
178 | normInf = undefined | 183 | normInf = undefined |
179 | multiply = lift2 multiply -- FIXME | 184 | multiply = lift2 (multiplyI m') |
180 | 185 | where | |
186 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
181 | 187 | ||
182 | instance KnownNat m => Numeric (F m) | 188 | instance KnownNat m => Numeric (F m) |
183 | 189 | ||
@@ -236,6 +242,9 @@ test = (ok, info) | |||
236 | ad = fromInt a :: Matrix Double | 242 | ad = fromInt a :: Matrix Double |
237 | bd = fromInt b :: Matrix Double | 243 | bd = fromInt b :: Matrix Double |
238 | 244 | ||
245 | g = (3><3) (repeat (40000)) :: Matrix I | ||
246 | gm = fromInt g :: Matrix (F 100000) | ||
247 | |||
239 | info = do | 248 | info = do |
240 | print v | 249 | print v |
241 | print m | 250 | print m |
@@ -247,10 +256,17 @@ test = (ok, info) | |||
247 | 256 | ||
248 | print $ am <> gaussElim am bm - bm | 257 | print $ am <> gaussElim am bm - bm |
249 | print $ ad <> gaussElim ad bd - bd | 258 | print $ ad <> gaussElim ad bd - bd |
259 | |||
260 | print g | ||
261 | print $ g <> g | ||
262 | print gm | ||
263 | print $ gm <> gm | ||
250 | 264 | ||
251 | ok = and | 265 | ok = and |
252 | [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) | 266 | [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) |
253 | , am <> gaussElim am bm == bm | 267 | , am <> gaussElim am bm == bm |
268 | , prodElements (konst (9:: F 10) (12::Int)) == product (replicate 12 (9:: F 10)) | ||
269 | , gm <> gm == konst 0 (3,3) | ||
254 | ] | 270 | ] |
255 | 271 | ||
256 | 272 | ||
diff --git a/packages/base/src/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs index eb744d1..2ef96bf 100644 --- a/packages/base/src/Internal/Numeric.hs +++ b/packages/base/src/Internal/Numeric.hs | |||
@@ -113,8 +113,8 @@ instance Container Vector I | |||
113 | maxIndex' = emptyErrorV "maxIndex" (fromIntegral . toScalarI MaxIdx) | 113 | maxIndex' = emptyErrorV "maxIndex" (fromIntegral . toScalarI MaxIdx) |
114 | minElement' = emptyErrorV "minElement" (toScalarI Min) | 114 | minElement' = emptyErrorV "minElement" (toScalarI Min) |
115 | maxElement' = emptyErrorV "maxElement" (toScalarI Max) | 115 | maxElement' = emptyErrorV "maxElement" (toScalarI Max) |
116 | sumElements' = sumI | 116 | sumElements' = sumI 1 |
117 | prodElements' = prodI | 117 | prodElements' = prodI 1 |
118 | step' = stepI | 118 | step' = stepI |
119 | find' = findV | 119 | find' = findV |
120 | assoc' = assocV | 120 | assoc' = assocV |
@@ -152,8 +152,8 @@ instance Container Vector Z | |||
152 | maxIndex' = emptyErrorV "maxIndex" (fromIntegral . toScalarL MaxIdx) | 152 | maxIndex' = emptyErrorV "maxIndex" (fromIntegral . toScalarL MaxIdx) |
153 | minElement' = emptyErrorV "minElement" (toScalarL Min) | 153 | minElement' = emptyErrorV "minElement" (toScalarL Min) |
154 | maxElement' = emptyErrorV "maxElement" (toScalarL Max) | 154 | maxElement' = emptyErrorV "maxElement" (toScalarL Max) |
155 | sumElements' = sumL | 155 | sumElements' = sumL 1 |
156 | prodElements' = prodL | 156 | prodElements' = prodL 1 |
157 | step' = stepL | 157 | step' = stepL |
158 | find' = findV | 158 | find' = findV |
159 | assoc' = assocV | 159 | assoc' = assocV |
@@ -596,14 +596,14 @@ instance Product I where | |||
596 | absSum = emptyVal (sumElements . vectorMapI Abs) | 596 | absSum = emptyVal (sumElements . vectorMapI Abs) |
597 | norm1 = absSum | 597 | norm1 = absSum |
598 | normInf = emptyVal (maxElement . vectorMapI Abs) | 598 | normInf = emptyVal (maxElement . vectorMapI Abs) |
599 | multiply = emptyMul multiplyI | 599 | multiply = emptyMul (multiplyI 1) |
600 | 600 | ||
601 | instance Product Z where | 601 | instance Product Z where |
602 | norm2 = undefined | 602 | norm2 = undefined |
603 | absSum = emptyVal (sumElements . vectorMapL Abs) | 603 | absSum = emptyVal (sumElements . vectorMapL Abs) |
604 | norm1 = absSum | 604 | norm1 = absSum |
605 | normInf = emptyVal (maxElement . vectorMapL Abs) | 605 | normInf = emptyVal (maxElement . vectorMapL Abs) |
606 | multiply = emptyMul multiplyL | 606 | multiply = emptyMul (multiplyL 1) |
607 | 607 | ||
608 | 608 | ||
609 | emptyMul m a b | 609 | emptyMul m a b |
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs index ff51494..5c89ac9 100644 --- a/packages/base/src/Internal/Vectorized.hs +++ b/packages/base/src/Internal/Vectorized.hs | |||
@@ -94,11 +94,9 @@ sumQ = sumg c_sumQ | |||
94 | sumC :: Vector (Complex Double) -> Complex Double | 94 | sumC :: Vector (Complex Double) -> Complex Double |
95 | sumC = sumg c_sumC | 95 | sumC = sumg c_sumC |
96 | 96 | ||
97 | -- | sum of elements | 97 | sumI m = sumg (c_sumI m) |
98 | sumI :: Vector CInt -> CInt | ||
99 | sumI = sumg c_sumI | ||
100 | 98 | ||
101 | sumL = sumg c_sumL | 99 | sumL m = sumg (c_sumL m) |
102 | 100 | ||
103 | sumg f x = unsafePerformIO $ do | 101 | sumg f x = unsafePerformIO $ do |
104 | r <- createVector 1 | 102 | r <- createVector 1 |
@@ -111,8 +109,8 @@ foreign import ccall unsafe "sumF" c_sumF :: TVV Float | |||
111 | foreign import ccall unsafe "sumR" c_sumR :: TVV Double | 109 | foreign import ccall unsafe "sumR" c_sumR :: TVV Double |
112 | foreign import ccall unsafe "sumQ" c_sumQ :: TVV (Complex Float) | 110 | foreign import ccall unsafe "sumQ" c_sumQ :: TVV (Complex Float) |
113 | foreign import ccall unsafe "sumC" c_sumC :: TVV (Complex Double) | 111 | foreign import ccall unsafe "sumC" c_sumC :: TVV (Complex Double) |
114 | foreign import ccall unsafe "sumI" c_sumI :: TVV CInt | 112 | foreign import ccall unsafe "sumI" c_sumI :: I -> TVV I |
115 | foreign import ccall unsafe "sumL" c_sumL :: TVV Z | 113 | foreign import ccall unsafe "sumL" c_sumL :: Z -> TVV Z |
116 | 114 | ||
117 | -- | product of elements | 115 | -- | product of elements |
118 | prodF :: Vector Float -> Float | 116 | prodF :: Vector Float -> Float |
@@ -130,11 +128,10 @@ prodQ = prodg c_prodQ | |||
130 | prodC :: Vector (Complex Double) -> Complex Double | 128 | prodC :: Vector (Complex Double) -> Complex Double |
131 | prodC = prodg c_prodC | 129 | prodC = prodg c_prodC |
132 | 130 | ||
133 | -- | product of elements | ||
134 | prodI :: Vector CInt -> CInt | ||
135 | prodI = prodg c_prodI | ||
136 | 131 | ||
137 | prodL = prodg c_prodL | 132 | prodI = prodg . c_prodI |
133 | |||
134 | prodL = prodg . c_prodL | ||
138 | 135 | ||
139 | prodg f x = unsafePerformIO $ do | 136 | prodg f x = unsafePerformIO $ do |
140 | r <- createVector 1 | 137 | r <- createVector 1 |
@@ -146,8 +143,8 @@ foreign import ccall unsafe "prodF" c_prodF :: TVV Float | |||
146 | foreign import ccall unsafe "prodR" c_prodR :: TVV Double | 143 | foreign import ccall unsafe "prodR" c_prodR :: TVV Double |
147 | foreign import ccall unsafe "prodQ" c_prodQ :: TVV (Complex Float) | 144 | foreign import ccall unsafe "prodQ" c_prodQ :: TVV (Complex Float) |
148 | foreign import ccall unsafe "prodC" c_prodC :: TVV (Complex Double) | 145 | foreign import ccall unsafe "prodC" c_prodC :: TVV (Complex Double) |
149 | foreign import ccall unsafe "prodI" c_prodI :: TVV (CInt) | 146 | foreign import ccall unsafe "prodI" c_prodI :: I -> TVV I |
150 | foreign import ccall unsafe "prodL" c_prodL :: TVV Z | 147 | foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z |
151 | 148 | ||
152 | ------------------------------------------------------------------ | 149 | ------------------------------------------------------------------ |
153 | 150 | ||