diff options
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/Internal/C/lapack-aux.c | 87 | ||||
-rw-r--r-- | packages/base/src/Internal/C/lapack-aux.h | 1 | ||||
-rw-r--r-- | packages/base/src/Internal/C/vector-aux.c | 7 | ||||
-rw-r--r-- | packages/base/src/Internal/Element.hs | 3 | ||||
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 32 | ||||
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 43 | ||||
-rw-r--r-- | packages/base/src/Internal/ST.hs | 26 | ||||
-rw-r--r-- | packages/base/src/Internal/Util.hs | 80 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Devel.hs | 1 |
10 files changed, 262 insertions, 20 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index dcce1c5..e42889d 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c | |||
@@ -4,6 +4,12 @@ | |||
4 | #include <math.h> | 4 | #include <math.h> |
5 | #include <time.h> | 5 | #include <time.h> |
6 | #include <inttypes.h> | 6 | #include <inttypes.h> |
7 | #include <complex.h> | ||
8 | |||
9 | typedef double complex TCD; | ||
10 | typedef float complex TCF; | ||
11 | |||
12 | #undef complex | ||
7 | 13 | ||
8 | #include "lapack-aux.h" | 14 | #include "lapack-aux.h" |
9 | 15 | ||
@@ -46,6 +52,10 @@ | |||
46 | #define NODEFPOS 2006 | 52 | #define NODEFPOS 2006 |
47 | #define NOSPRTD 2007 | 53 | #define NOSPRTD 2007 |
48 | 54 | ||
55 | inline int mod (int a, int b); | ||
56 | |||
57 | inline int64_t mod_l (int64_t a, int64_t b); | ||
58 | |||
49 | //--------------------------------------- | 59 | //--------------------------------------- |
50 | void asm_finit() { | 60 | void asm_finit() { |
51 | #ifdef i386 | 61 | #ifdef i386 |
@@ -1310,6 +1320,83 @@ int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) { | |||
1310 | int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP | 1320 | int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP |
1311 | int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP | 1321 | int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP |
1312 | 1322 | ||
1323 | /////////////////////////////// inplace row ops //////////////////////////////// | ||
1324 | |||
1325 | #define AXPY_IMP { \ | ||
1326 | int j; \ | ||
1327 | for(j=j1; j<=j2; j++) { \ | ||
1328 | AT(r,i2,j) += a*AT(r,i1,j); \ | ||
1329 | } OK } | ||
1330 | |||
1331 | #define AXPY_MOD_IMP(M) { \ | ||
1332 | int j; \ | ||
1333 | for(j=j1; j<=j2; j++) { \ | ||
1334 | AT(r,i2,j) = M(AT(r,i2,j) + M(a*AT(r,i1,j), m) , m); \ | ||
1335 | } OK } | ||
1336 | |||
1337 | |||
1338 | #define SCAL_IMP { \ | ||
1339 | int i,j; \ | ||
1340 | for(i=i1; i<=i2; i++) { \ | ||
1341 | for(j=j1; j<=j2; j++) { \ | ||
1342 | AT(r,i,j) = a*AT(r,i,j); \ | ||
1343 | } \ | ||
1344 | } OK } | ||
1345 | |||
1346 | #define SCAL_MOD_IMP(M) { \ | ||
1347 | int i,j; \ | ||
1348 | for(i=i1; i<=i2; i++) { \ | ||
1349 | for(j=j1; j<=j2; j++) { \ | ||
1350 | AT(r,i,j) = M(a*AT(r,i,j) , m); \ | ||
1351 | } \ | ||
1352 | } OK } | ||
1353 | |||
1354 | |||
1355 | #define SWAP_IMP(T) { \ | ||
1356 | T aux; \ | ||
1357 | int k; \ | ||
1358 | if (i1 != i2) { \ | ||
1359 | for (k=j1; k<=j2; k++) { \ | ||
1360 | aux = AT(r,i1,k); \ | ||
1361 | AT(r,i1,k) = AT(r,i2,k); \ | ||
1362 | AT(r,i2,k) = aux; \ | ||
1363 | } \ | ||
1364 | } OK } | ||
1365 | |||
1366 | |||
1367 | #define ROWOP_IMP(T) { \ | ||
1368 | T a = *pa; \ | ||
1369 | switch(code) { \ | ||
1370 | case 0: AXPY_IMP \ | ||
1371 | case 1: SCAL_IMP \ | ||
1372 | case 2: SWAP_IMP(T) \ | ||
1373 | default: ERROR(BAD_CODE); \ | ||
1374 | } \ | ||
1375 | } | ||
1376 | |||
1377 | #define ROWOP_MOD_IMP(T,M) { \ | ||
1378 | T a = *pa; \ | ||
1379 | switch(code) { \ | ||
1380 | case 0: AXPY_MOD_IMP(M) \ | ||
1381 | case 1: SCAL_MOD_IMP(M) \ | ||
1382 | case 2: SWAP_IMP(T) \ | ||
1383 | default: ERROR(BAD_CODE); \ | ||
1384 | } \ | ||
1385 | } | ||
1386 | |||
1387 | |||
1388 | #define ROWOP(T) int rowop_##T(int code, T* pa, int i1, int i2, int j1, int j2, MATG(T,r)) ROWOP_IMP(T) | ||
1389 | |||
1390 | #define ROWOP_MOD(T,M) int rowop_mod_##T(T m, int code, T* pa, int i1, int i2, int j1, int j2, MATG(T,r)) ROWOP_MOD_IMP(T,M) | ||
1391 | |||
1392 | ROWOP(double) | ||
1393 | ROWOP(float) | ||
1394 | ROWOP(TCD) | ||
1395 | ROWOP(TCF) | ||
1396 | ROWOP(int32_t) | ||
1397 | ROWOP(int64_t) | ||
1398 | ROWOP_MOD(int32_t,mod) | ||
1399 | ROWOP_MOD(int64_t,mod_l) | ||
1313 | 1400 | ||
1314 | ////////////////// sparse matrix-product /////////////////////////////////////// | 1401 | ////////////////// sparse matrix-product /////////////////////////////////////// |
1315 | 1402 | ||
diff --git a/packages/base/src/Internal/C/lapack-aux.h b/packages/base/src/Internal/C/lapack-aux.h index 1549bb5..e4d95bc 100644 --- a/packages/base/src/Internal/C/lapack-aux.h +++ b/packages/base/src/Internal/C/lapack-aux.h | |||
@@ -59,6 +59,7 @@ typedef short ftnlen; | |||
59 | #define OQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, complex* A##p | 59 | #define OQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, complex* A##p |
60 | #define OCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, doublecomplex* A##p | 60 | #define OCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, doublecomplex* A##p |
61 | 61 | ||
62 | #define MATG(T,A) int A##r, int A##c, int A##Xr, int A##Xc, T* A##p | ||
62 | 63 | ||
63 | #define KIVEC(A) int A##n, const int*A##p | 64 | #define KIVEC(A) int A##n, const int*A##p |
64 | #define KLVEC(A) int A##n, const int64_t*A##p | 65 | #define KLVEC(A) int A##n, const int64_t*A##p |
diff --git a/packages/base/src/Internal/C/vector-aux.c b/packages/base/src/Internal/C/vector-aux.c index c161556..5528a9d 100644 --- a/packages/base/src/Internal/C/vector-aux.c +++ b/packages/base/src/Internal/C/vector-aux.c | |||
@@ -716,6 +716,7 @@ int mapValF(int code, float* pval, KFVEC(x), FVEC(r)) { | |||
716 | } | 716 | } |
717 | } | 717 | } |
718 | 718 | ||
719 | inline | ||
719 | int mod (int a, int b) { | 720 | int mod (int a, int b) { |
720 | int m = a % b; | 721 | int m = a % b; |
721 | if (b>0) { | 722 | if (b>0) { |
@@ -741,7 +742,7 @@ int mapValI(int code, int* pval, KIVEC(x), IVEC(r)) { | |||
741 | } | 742 | } |
742 | } | 743 | } |
743 | 744 | ||
744 | 745 | inline | |
745 | int64_t mod_l (int64_t a, int64_t b) { | 746 | int64_t mod_l (int64_t a, int64_t b) { |
746 | int64_t m = a % b; | 747 | int64_t m = a % b; |
747 | if (b>0) { | 748 | if (b>0) { |
@@ -1230,7 +1231,7 @@ int round_vector_i(KDVEC(v),IVEC(r)) { | |||
1230 | int mod_vector(int m, KIVEC(v), IVEC(r)) { | 1231 | int mod_vector(int m, KIVEC(v), IVEC(r)) { |
1231 | int k; | 1232 | int k; |
1232 | for(k=0; k<vn; k++) { | 1233 | for(k=0; k<vn; k++) { |
1233 | rp[k] = vp[k] % m; | 1234 | rp[k] = mod(vp[k],m); |
1234 | } | 1235 | } |
1235 | OK | 1236 | OK |
1236 | } | 1237 | } |
@@ -1266,7 +1267,7 @@ int round_vector_l(KDVEC(v),LVEC(r)) { | |||
1266 | int mod_vector_l(int64_t m, KLVEC(v), LVEC(r)) { | 1267 | int mod_vector_l(int64_t m, KLVEC(v), LVEC(r)) { |
1267 | int k; | 1268 | int k; |
1268 | for(k=0; k<vn; k++) { | 1269 | for(k=0; k<vn; k++) { |
1269 | rp[k] = vp[k] % m; | 1270 | rp[k] = mod_l(vp[k],m); |
1270 | } | 1271 | } |
1271 | OK | 1272 | OK |
1272 | } | 1273 | } |
diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs index 55bff67..4007491 100644 --- a/packages/base/src/Internal/Element.hs +++ b/packages/base/src/Internal/Element.hs | |||
@@ -30,6 +30,7 @@ import Text.Printf | |||
30 | import Data.List(transpose,intersperse) | 30 | import Data.List(transpose,intersperse) |
31 | import Data.List.Split(chunksOf) | 31 | import Data.List.Split(chunksOf) |
32 | import Foreign.Storable(Storable) | 32 | import Foreign.Storable(Storable) |
33 | import System.IO.Unsafe(unsafePerformIO) | ||
33 | import Control.Monad(liftM) | 34 | import Control.Monad(liftM) |
34 | 35 | ||
35 | ------------------------------------------------------------------- | 36 | ------------------------------------------------------------------- |
@@ -147,7 +148,7 @@ m ?? (e, TakeLast n) = m ?? (e, Drop (cols m - n)) | |||
147 | m ?? (DropLast n, e) = m ?? (Take (rows m - n), e) | 148 | m ?? (DropLast n, e) = m ?? (Take (rows m - n), e) |
148 | m ?? (e, DropLast n) = m ?? (e, Take (cols m - n)) | 149 | m ?? (e, DropLast n) = m ?? (e, Take (cols m - n)) |
149 | 150 | ||
150 | m ?? (er,ec) = extractR m moder rs modec cs | 151 | m ?? (er,ec) = unsafePerformIO $ extractR m moder rs modec cs |
151 | where | 152 | where |
152 | (moder,rs) = mkExt (rows m) er | 153 | (moder,rs) = mkExt (rows m) er |
153 | (modec,cs) = mkExt (cols m) ec | 154 | (modec,cs) = mkExt (cols m) ec |
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 8de06ce..fa1aad6 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -20,6 +20,7 @@ import Internal.Vector | |||
20 | import Internal.Devel | 20 | import Internal.Devel |
21 | import Internal.Vectorized | 21 | import Internal.Vectorized |
22 | import Foreign.Marshal.Alloc ( free ) | 22 | import Foreign.Marshal.Alloc ( free ) |
23 | import Foreign.Marshal.Array(newArray) | ||
23 | import Foreign.Ptr ( Ptr ) | 24 | import Foreign.Ptr ( Ptr ) |
24 | import Foreign.Storable ( Storable ) | 25 | import Foreign.Storable ( Storable ) |
25 | import Data.Complex ( Complex ) | 26 | import Data.Complex ( Complex ) |
@@ -273,12 +274,13 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 | |||
273 | class (Storable a) => Element a where | 274 | class (Storable a) => Element a where |
274 | transdata :: Int -> Vector a -> Int -> Vector a | 275 | transdata :: Int -> Vector a -> Int -> Vector a |
275 | constantD :: a -> Int -> Vector a | 276 | constantD :: a -> Int -> Vector a |
276 | extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> Matrix a | 277 | extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) |
277 | sortI :: Ord a => Vector a -> Vector CInt | 278 | sortI :: Ord a => Vector a -> Vector CInt |
278 | sortV :: Ord a => Vector a -> Vector a | 279 | sortV :: Ord a => Vector a -> Vector a |
279 | compareV :: Ord a => Vector a -> Vector a -> Vector CInt | 280 | compareV :: Ord a => Vector a -> Vector a -> Vector CInt |
280 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a | 281 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a |
281 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | 282 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a |
283 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | ||
282 | 284 | ||
283 | 285 | ||
284 | instance Element Float where | 286 | instance Element Float where |
@@ -290,6 +292,7 @@ instance Element Float where | |||
290 | compareV = compareF | 292 | compareV = compareF |
291 | selectV = selectF | 293 | selectV = selectF |
292 | remapM = remapF | 294 | remapM = remapF |
295 | rowOp = rowOpAux c_rowOpF | ||
293 | 296 | ||
294 | instance Element Double where | 297 | instance Element Double where |
295 | transdata = transdataAux ctransR | 298 | transdata = transdataAux ctransR |
@@ -300,6 +303,7 @@ instance Element Double where | |||
300 | compareV = compareD | 303 | compareV = compareD |
301 | selectV = selectD | 304 | selectV = selectD |
302 | remapM = remapD | 305 | remapM = remapD |
306 | rowOp = rowOpAux c_rowOpD | ||
303 | 307 | ||
304 | 308 | ||
305 | instance Element (Complex Float) where | 309 | instance Element (Complex Float) where |
@@ -311,6 +315,7 @@ instance Element (Complex Float) where | |||
311 | compareV = undefined | 315 | compareV = undefined |
312 | selectV = selectQ | 316 | selectV = selectQ |
313 | remapM = remapQ | 317 | remapM = remapQ |
318 | rowOp = rowOpAux c_rowOpQ | ||
314 | 319 | ||
315 | 320 | ||
316 | instance Element (Complex Double) where | 321 | instance Element (Complex Double) where |
@@ -322,6 +327,7 @@ instance Element (Complex Double) where | |||
322 | compareV = undefined | 327 | compareV = undefined |
323 | selectV = selectC | 328 | selectV = selectC |
324 | remapM = remapC | 329 | remapM = remapC |
330 | rowOp = rowOpAux c_rowOpC | ||
325 | 331 | ||
326 | instance Element (CInt) where | 332 | instance Element (CInt) where |
327 | transdata = transdataAux ctransI | 333 | transdata = transdataAux ctransI |
@@ -332,6 +338,7 @@ instance Element (CInt) where | |||
332 | compareV = compareI | 338 | compareV = compareI |
333 | selectV = selectI | 339 | selectV = selectI |
334 | remapM = remapI | 340 | remapM = remapI |
341 | rowOp = rowOpAux c_rowOpI | ||
335 | 342 | ||
336 | instance Element Z where | 343 | instance Element Z where |
337 | transdata = transdataAux ctransL | 344 | transdata = transdataAux ctransL |
@@ -342,6 +349,7 @@ instance Element Z where | |||
342 | compareV = compareL | 349 | compareV = compareL |
343 | selectV = selectL | 350 | selectV = selectL |
344 | remapM = remapL | 351 | remapM = remapL |
352 | rowOp = rowOpAux c_rowOpL | ||
345 | 353 | ||
346 | ------------------------------------------------------------------- | 354 | ------------------------------------------------------------------- |
347 | 355 | ||
@@ -379,7 +387,7 @@ subMatrix :: Element a | |||
379 | -> Matrix a -- ^ result | 387 | -> Matrix a -- ^ result |
380 | subMatrix (r0,c0) (rt,ct) m | 388 | subMatrix (r0,c0) (rt,ct) m |
381 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && | 389 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && |
382 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) | 390 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) |
383 | | otherwise = error $ "wrong subMatrix "++ | 391 | | otherwise = error $ "wrong subMatrix "++ |
384 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) | 392 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) |
385 | 393 | ||
@@ -430,7 +438,7 @@ instance (Storable t, NFData t) => NFData (Matrix t) | |||
430 | 438 | ||
431 | --------------------------------------------------------------- | 439 | --------------------------------------------------------------- |
432 | 440 | ||
433 | extractAux f m moder vr modec vc = unsafePerformIO $ do | 441 | extractAux f m moder vr modec vc = do |
434 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | 442 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr |
435 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | 443 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc |
436 | r <- createMatrix RowMajor nr nc | 444 | r <- createMatrix RowMajor nr nc |
@@ -538,6 +546,24 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z | |||
538 | 546 | ||
539 | -------------------------------------------------------------------------------- | 547 | -------------------------------------------------------------------------------- |
540 | 548 | ||
549 | rowOpAux f c x i1 i2 j1 j2 m = do | ||
550 | px <- newArray [x] | ||
551 | app1 (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) omat m "rowOp" | ||
552 | free px | ||
553 | |||
554 | type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok | ||
555 | |||
556 | foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R | ||
557 | foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float | ||
558 | foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C | ||
559 | foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float) | ||
560 | foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I | ||
561 | foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z | ||
562 | foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I | ||
563 | foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | ||
564 | |||
565 | -------------------------------------------------------------------------------- | ||
566 | |||
541 | foreign import ccall unsafe "saveMatrix" c_saveMatrix | 567 | foreign import ccall unsafe "saveMatrix" c_saveMatrix |
542 | :: CString -> CString -> Double ..> Ok | 568 | :: CString -> CString -> Double ..> Ok |
543 | 569 | ||
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 1289a21..824fc57 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -111,18 +111,46 @@ instance forall n t . (Integral t, KnownNat n) => Num (Mod n t) | |||
111 | fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) | 111 | fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) |
112 | 112 | ||
113 | 113 | ||
114 | instance KnownNat m => Element (Mod m I) | ||
115 | where | ||
116 | transdata n v m = i2f (transdata n (f2i v) m) | ||
117 | constantD x n = i2f (constantD (unMod x) n) | ||
118 | extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js | ||
119 | sortI = sortI . f2i | ||
120 | sortV = i2f . sortV . f2i | ||
121 | compareV u v = compareV (f2i u) (f2i v) | ||
122 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) | ||
123 | remapM i j m = i2fM (remap i j (f2iM m)) | ||
124 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) | ||
125 | where | ||
126 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
114 | 127 | ||
115 | instance (Ord t, Element t) => Element (Mod n t) | 128 | instance KnownNat m => Element (Mod m Z) |
116 | where | 129 | where |
117 | transdata n v m = i2f (transdata n (f2i v) m) | 130 | transdata n v m = i2f (transdata n (f2i v) m) |
118 | constantD x n = i2f (constantD (unMod x) n) | 131 | constantD x n = i2f (constantD (unMod x) n) |
119 | extractR m mi is mj js = i2fM (extractR (f2iM m) mi is mj js) | 132 | extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js |
120 | sortI = sortI . f2i | 133 | sortI = sortI . f2i |
121 | sortV = i2f . sortV . f2i | 134 | sortV = i2f . sortV . f2i |
122 | compareV u v = compareV (f2i u) (f2i v) | 135 | compareV u v = compareV (f2i u) (f2i v) |
123 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) | 136 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) |
124 | remapM i j m = i2fM (remap i j (f2iM m)) | 137 | remapM i j m = i2fM (remap i j (f2iM m)) |
138 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) | ||
139 | where | ||
140 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
125 | 141 | ||
142 | {- | ||
143 | instance (Ord t, Element t) => Element (Mod m t) | ||
144 | where | ||
145 | transdata n v m = i2f (transdata n (f2i v) m) | ||
146 | constantD x n = i2f (constantD (unMod x) n) | ||
147 | extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js | ||
148 | sortI = sortI . f2i | ||
149 | sortV = i2f . sortV . f2i | ||
150 | compareV u v = compareV (f2i u) (f2i v) | ||
151 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) | ||
152 | remapM i j m = i2fM (remap i j (f2iM m)) | ||
153 | -} | ||
126 | 154 | ||
127 | instance forall m . KnownNat m => Container Vector (Mod m I) | 155 | instance forall m . KnownNat m => Container Vector (Mod m I) |
128 | where | 156 | where |
@@ -205,12 +233,10 @@ instance forall m . KnownNat m => Container Vector (Mod m Z) | |||
205 | toZ' = f2i | 233 | toZ' = f2i |
206 | 234 | ||
207 | 235 | ||
208 | |||
209 | instance (Storable t, Indexable (Vector t) t) => Indexable (Vector (Mod m t)) (Mod m t) | 236 | instance (Storable t, Indexable (Vector t) t) => Indexable (Vector (Mod m t)) (Mod m t) |
210 | where | 237 | where |
211 | (!) = (@>) | 238 | (!) = (@>) |
212 | 239 | ||
213 | |||
214 | type instance RealOf (Mod n I) = I | 240 | type instance RealOf (Mod n I) = I |
215 | type instance RealOf (Mod n Z) = Z | 241 | type instance RealOf (Mod n Z) = Z |
216 | 242 | ||
@@ -270,6 +296,15 @@ instance forall m . KnownNat m => Num (Vector (Mod m I)) | |||
270 | negate = lift1 negate | 296 | negate = lift1 negate |
271 | fromInteger x = fromInt (fromInteger x) | 297 | fromInteger x = fromInt (fromInteger x) |
272 | 298 | ||
299 | instance forall m . KnownNat m => Num (Vector (Mod m Z)) | ||
300 | where | ||
301 | (+) = lift2 (+) | ||
302 | (*) = lift2 (*) | ||
303 | (-) = lift2 (-) | ||
304 | abs = lift1 abs | ||
305 | signum = lift1 signum | ||
306 | negate = lift1 negate | ||
307 | fromInteger x = fromZ (fromInteger x) | ||
273 | 308 | ||
274 | -------------------------------------------------------------------------------- | 309 | -------------------------------------------------------------------------------- |
275 | 310 | ||
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index ae75a1b..107d3c3 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -1,5 +1,6 @@ | |||
1 | {-# LANGUAGE Rank2Types #-} | 1 | {-# LANGUAGE Rank2Types #-} |
2 | {-# LANGUAGE BangPatterns #-} | 2 | {-# LANGUAGE BangPatterns #-} |
3 | |||
3 | ----------------------------------------------------------------------------- | 4 | ----------------------------------------------------------------------------- |
4 | -- | | 5 | -- | |
5 | -- Module : Internal.ST | 6 | -- Module : Internal.ST |
@@ -20,6 +21,8 @@ module Internal.ST ( | |||
20 | -- * Mutable Matrices | 21 | -- * Mutable Matrices |
21 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, | 22 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, |
22 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, | 23 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, |
24 | axpy, scal, swap, extractRect, | ||
25 | mutable, | ||
23 | -- * Unsafe functions | 26 | -- * Unsafe functions |
24 | newUndefinedVector, | 27 | newUndefinedVector, |
25 | unsafeReadVector, unsafeWriteVector, | 28 | unsafeReadVector, unsafeWriteVector, |
@@ -34,8 +37,6 @@ import Internal.Matrix | |||
34 | import Internal.Vectorized | 37 | import Internal.Vectorized |
35 | import Control.Monad.ST(ST, runST) | 38 | import Control.Monad.ST(ST, runST) |
36 | import Foreign.Storable(Storable, peekElemOff, pokeElemOff) | 39 | import Foreign.Storable(Storable, peekElemOff, pokeElemOff) |
37 | |||
38 | |||
39 | import Control.Monad.ST.Unsafe(unsafeIOToST) | 40 | import Control.Monad.ST.Unsafe(unsafeIOToST) |
40 | 41 | ||
41 | {-# INLINE ioReadV #-} | 42 | {-# INLINE ioReadV #-} |
@@ -144,6 +145,7 @@ liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x | |||
144 | unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) | 145 | unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) |
145 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | 146 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x |
146 | 147 | ||
148 | |||
147 | freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) | 149 | freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) |
148 | freezeMatrix m = liftSTMatrix id m | 150 | freezeMatrix m = liftSTMatrix id m |
149 | 151 | ||
@@ -171,3 +173,23 @@ newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c | |||
171 | newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t) | 173 | newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t) |
172 | newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) | 174 | newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) |
173 | 175 | ||
176 | -------------------------------------------------------------------------------- | ||
177 | |||
178 | rowOpST :: Element t => Int -> t -> Int -> Int -> Int -> Int -> STMatrix s t -> ST s () | ||
179 | rowOpST c x i1 i2 j1 j2 (STMatrix m) = unsafeIOToST (rowOp c x i1 i2 j1 j2 m) | ||
180 | |||
181 | axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m) | ||
182 | scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m) | ||
183 | swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m) | ||
184 | |||
185 | extractRect (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) | ||
186 | |||
187 | -------------------------------------------------------------------------------- | ||
188 | |||
189 | mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) | ||
190 | mutable f a = runST $ do | ||
191 | x <- thawMatrix a | ||
192 | info <- f (rows a, cols a) x | ||
193 | r <- unsafeFreezeMatrix x | ||
194 | return (r,info) | ||
195 | |||
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index b1fb800..7a556e9 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -54,7 +54,7 @@ module Internal.Util( | |||
54 | -- ** 2D | 54 | -- ** 2D |
55 | corr2, conv2, separable, | 55 | corr2, conv2, separable, |
56 | block2x2,block3x3,view1,unView1,foldMatrix, | 56 | block2x2,block3x3,view1,unView1,foldMatrix, |
57 | gaussElim | 57 | gaussElim_1, gaussElim_2, gaussElim |
58 | ) where | 58 | ) where |
59 | 59 | ||
60 | import Internal.Vector | 60 | import Internal.Vector |
@@ -64,17 +64,19 @@ import Internal.Element | |||
64 | import Internal.Container | 64 | import Internal.Container |
65 | import Internal.Vectorized | 65 | import Internal.Vectorized |
66 | import Internal.IO | 66 | import Internal.IO |
67 | import Internal.Algorithms hiding (i,Normed) | 67 | import Internal.Algorithms hiding (i,Normed,swap) |
68 | import Numeric.Matrix() | 68 | import Numeric.Matrix() |
69 | import Numeric.Vector() | 69 | import Numeric.Vector() |
70 | import Internal.Random | 70 | import Internal.Random |
71 | import Internal.Convolution | 71 | import Internal.Convolution |
72 | import Control.Monad(when) | 72 | import Control.Monad(when,forM_) |
73 | import Text.Printf | 73 | import Text.Printf |
74 | import Data.List.Split(splitOn) | 74 | import Data.List.Split(splitOn) |
75 | import Data.List(intercalate,) | 75 | import Data.List(intercalate,sortBy) |
76 | import Control.Arrow((&&&)) | 76 | import Control.Arrow((&&&)) |
77 | import Data.Complex | 77 | import Data.Complex |
78 | import Data.Function(on) | ||
79 | import Internal.ST | ||
78 | 80 | ||
79 | type ℝ = Double | 81 | type ℝ = Double |
80 | type ℕ = Int | 82 | type ℕ = Int |
@@ -359,6 +361,10 @@ instance Indexable (Vector I) I | |||
359 | where | 361 | where |
360 | (!) = (@>) | 362 | (!) = (@>) |
361 | 363 | ||
364 | instance Indexable (Vector Z) Z | ||
365 | where | ||
366 | (!) = (@>) | ||
367 | |||
362 | instance Indexable (Vector (Complex Double)) (Complex Double) | 368 | instance Indexable (Vector (Complex Double)) (Complex Double) |
363 | where | 369 | where |
364 | (!) = (@>) | 370 | (!) = (@>) |
@@ -550,11 +556,11 @@ down g a = foldMatrix g f a | |||
550 | -- | 556 | -- |
551 | -- @a <> gaussElim a b = b@ | 557 | -- @a <> gaussElim a b = b@ |
552 | -- | 558 | -- |
553 | gaussElim | 559 | gaussElim_2 |
554 | :: (Eq t, Fractional t, Num (Vector t), Numeric t) | 560 | :: (Eq t, Fractional t, Num (Vector t), Numeric t) |
555 | => Matrix t -> Matrix t -> Matrix t | 561 | => Matrix t -> Matrix t -> Matrix t |
556 | 562 | ||
557 | gaussElim a b = flipudrl r | 563 | gaussElim_2 a b = flipudrl r |
558 | where | 564 | where |
559 | flipudrl = flipud . fliprl | 565 | flipudrl = flipud . fliprl |
560 | splitColsAt n = (takeColumns n &&& dropColumns n) | 566 | splitColsAt n = (takeColumns n &&& dropColumns n) |
@@ -564,6 +570,68 @@ gaussElim a b = flipudrl r | |||
564 | 570 | ||
565 | -------------------------------------------------------------------------------- | 571 | -------------------------------------------------------------------------------- |
566 | 572 | ||
573 | gaussElim_1 | ||
574 | :: (Fractional t, Num (Vector t), Ord t, Indexable (Vector t) t, Numeric t) | ||
575 | => Matrix t -> Matrix t -> Matrix t | ||
576 | |||
577 | gaussElim_1 x y = dropColumns (rows x) (flipud $ fromRows s2) | ||
578 | where | ||
579 | rs = toRows $ fromBlocks [[x , y]] | ||
580 | s1 = fromRows $ pivotDown (rows x) 0 rs -- interesting | ||
581 | s2 = pivotUp (rows x-1) (toRows $ flipud s1) | ||
582 | |||
583 | pivotDown t n xs | ||
584 | | t == n = [] | ||
585 | | otherwise = y : pivotDown t (n+1) ys | ||
586 | where | ||
587 | y:ys = redu (pivot n xs) | ||
588 | |||
589 | pivot k = (const k &&& id) | ||
590 | . sortBy (flip compare `on` (abs. (!k))) | ||
591 | |||
592 | redu (k,x:zs) | ||
593 | | p == 0 = error "gauss: singular!" -- FIXME | ||
594 | | otherwise = u : map f zs | ||
595 | where | ||
596 | p = x!k | ||
597 | u = scale (recip (x!k)) x | ||
598 | f z = z - scale (z!k) u | ||
599 | redu (_,[]) = [] | ||
600 | |||
601 | |||
602 | pivotUp n xs | ||
603 | | n == -1 = [] | ||
604 | | otherwise = y : pivotUp (n-1) ys | ||
605 | where | ||
606 | y:ys = redu' (n,xs) | ||
607 | |||
608 | redu' (k,x:zs) = u : map f zs | ||
609 | where | ||
610 | u = x | ||
611 | f z = z - scale (z!k) u | ||
612 | redu' (_,[]) = [] | ||
613 | |||
614 | -------------------------------------------------------------------------------- | ||
615 | |||
616 | gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (fromBlocks [[a,b]]) | ||
617 | |||
618 | gaussST (r,_) x = do | ||
619 | let n = r-1 | ||
620 | forM_ [0..n] $ \i -> do | ||
621 | c <- maxIndex . abs . flatten <$> extractRect x i n i i | ||
622 | swap x i (i+c) | ||
623 | a <- readMatrix x i i | ||
624 | scal x (recip a) i | ||
625 | forM_ [i+1..n] $ \j -> do | ||
626 | b <- readMatrix x j i | ||
627 | axpy x (-b) i j | ||
628 | forM_ [n,n-1..1] $ \i -> do | ||
629 | forM_ [i-1,i-2..0] $ \j -> do | ||
630 | b <- readMatrix x j i | ||
631 | axpy x (-b) i j | ||
632 | |||
633 | -------------------------------------------------------------------------------- | ||
634 | |||
567 | instance Testable (Matrix I) where | 635 | instance Testable (Matrix I) where |
568 | checkT _ = test | 636 | checkT _ = test |
569 | 637 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 56e5053..c97f415 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs | |||
@@ -134,7 +134,7 @@ module Numeric.LinearAlgebra ( | |||
134 | Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, | 134 | Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, |
135 | 135 | ||
136 | -- * Misc | 136 | -- * Misc |
137 | meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, gaussElim, | 137 | meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, gaussElim, gaussElim_1, gaussElim_2, |
138 | ℝ,ℂ,iC, | 138 | ℝ,ℂ,iC, |
139 | -- * Auxiliary classes | 139 | -- * Auxiliary classes |
140 | Element, Container, Product, Numeric, LSDiv, | 140 | Element, Container, Product, Numeric, LSDiv, |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs index 1a70663..84763fe 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs | |||
@@ -44,6 +44,7 @@ module Numeric.LinearAlgebra.Devel( | |||
44 | -- ** Mutable Matrices | 44 | -- ** Mutable Matrices |
45 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, | 45 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, |
46 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, | 46 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, |
47 | axpy,scal,swap, extractRect, mutable, | ||
47 | -- ** Unsafe functions | 48 | -- ** Unsafe functions |
48 | newUndefinedVector, | 49 | newUndefinedVector, |
49 | unsafeReadVector, unsafeWriteVector, | 50 | unsafeReadVector, unsafeWriteVector, |