summaryrefslogtreecommitdiff
path: root/packages
diff options
context:
space:
mode:
Diffstat (limited to 'packages')
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c104
-rw-r--r--packages/base/src/Internal/Matrix.hs8
-rw-r--r--packages/base/src/Internal/Modular.hs4
-rw-r--r--packages/base/src/Internal/ST.hs9
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests.hs151
5 files changed, 202 insertions, 74 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index ca60846..30689bf 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -1093,16 +1093,15 @@ void dgemm_(char *, char *, integer *, integer *, integer *,
1093 integer *, double *, double *, integer *); 1093 integer *, double *, double *, integer *);
1094 1094
1095int multiplyR(int ta, int tb, KODMAT(a),KODMAT(b),ODMAT(r)) { 1095int multiplyR(int ta, int tb, KODMAT(a),KODMAT(b),ODMAT(r)) {
1096 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1097 DEBUGMSG("dgemm_"); 1096 DEBUGMSG("dgemm_");
1098 CHECKNANR(a,"NaN multR Input\n") 1097 CHECKNANR(a,"NaN multR Input\n")
1099 CHECKNANR(b,"NaN multR Input\n") 1098 CHECKNANR(b,"NaN multR Input\n")
1100 integer m = ta?ac:ar; 1099 integer m = ta?ac:ar;
1101 integer n = tb?br:bc; 1100 integer n = tb?br:bc;
1102 integer k = ta?ar:ac; 1101 integer k = ta?ar:ac;
1103 integer lda = ar; 1102 integer lda = aXc;
1104 integer ldb = br; 1103 integer ldb = bXc;
1105 integer ldc = rr; 1104 integer ldc = rXc;
1106 double alpha = 1; 1105 double alpha = 1;
1107 double beta = 0; 1106 double beta = 0;
1108 dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); 1107 dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc);
@@ -1115,16 +1114,15 @@ void zgemm_(char *, char *, integer *, integer *, integer *,
1115 integer *, doublecomplex *, doublecomplex *, integer *); 1114 integer *, doublecomplex *, doublecomplex *, integer *);
1116 1115
1117int multiplyC(int ta, int tb, KOCMAT(a),KOCMAT(b),OCMAT(r)) { 1116int multiplyC(int ta, int tb, KOCMAT(a),KOCMAT(b),OCMAT(r)) {
1118 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1119 DEBUGMSG("zgemm_"); 1117 DEBUGMSG("zgemm_");
1120 CHECKNANC(a,"NaN multC Input\n") 1118 CHECKNANC(a,"NaN multC Input\n")
1121 CHECKNANC(b,"NaN multC Input\n") 1119 CHECKNANC(b,"NaN multC Input\n")
1122 integer m = ta?ac:ar; 1120 integer m = ta?ac:ar;
1123 integer n = tb?br:bc; 1121 integer n = tb?br:bc;
1124 integer k = ta?ar:ac; 1122 integer k = ta?ar:ac;
1125 integer lda = ar; 1123 integer lda = aXc;
1126 integer ldb = br; 1124 integer ldb = bXc;
1127 integer ldc = rr; 1125 integer ldc = rXc;
1128 doublecomplex alpha = {1,0}; 1126 doublecomplex alpha = {1,0};
1129 doublecomplex beta = {0,0}; 1127 doublecomplex beta = {0,0};
1130 zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, 1128 zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,
@@ -1140,14 +1138,13 @@ void sgemm_(char *, char *, integer *, integer *, integer *,
1140 integer *, float *, float *, integer *); 1138 integer *, float *, float *, integer *);
1141 1139
1142int multiplyF(int ta, int tb, KOFMAT(a),KOFMAT(b),OFMAT(r)) { 1140int multiplyF(int ta, int tb, KOFMAT(a),KOFMAT(b),OFMAT(r)) {
1143 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1144 DEBUGMSG("sgemm_"); 1141 DEBUGMSG("sgemm_");
1145 integer m = ta?ac:ar; 1142 integer m = ta?ac:ar;
1146 integer n = tb?br:bc; 1143 integer n = tb?br:bc;
1147 integer k = ta?ar:ac; 1144 integer k = ta?ar:ac;
1148 integer lda = ar; 1145 integer lda = aXc;
1149 integer ldb = br; 1146 integer ldb = bXc;
1150 integer ldc = rr; 1147 integer ldc = rXc;
1151 float alpha = 1; 1148 float alpha = 1;
1152 float beta = 0; 1149 float beta = 0;
1153 sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); 1150 sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc);
@@ -1159,14 +1156,13 @@ void cgemm_(char *, char *, integer *, integer *, integer *,
1159 integer *, complex *, complex *, integer *); 1156 integer *, complex *, complex *, integer *);
1160 1157
1161int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) { 1158int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) {
1162 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1163 DEBUGMSG("cgemm_"); 1159 DEBUGMSG("cgemm_");
1164 integer m = ta?ac:ar; 1160 integer m = ta?ac:ar;
1165 integer n = tb?br:bc; 1161 integer n = tb?br:bc;
1166 integer k = ta?ar:ac; 1162 integer k = ta?ar:ac;
1167 integer lda = ar; 1163 integer lda = aXc;
1168 integer ldb = br; 1164 integer ldb = bXc;
1169 integer ldc = rr; 1165 integer ldc = rXc;
1170 complex alpha = {1,0}; 1166 complex alpha = {1,0};
1171 complex beta = {0,0}; 1167 complex beta = {0,0};
1172 cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, 1168 cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,
@@ -1187,15 +1183,15 @@ int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) {
1187 } \ 1183 } \
1188 } 1184 }
1189 1185
1190#define MULT_IMP { \ 1186#define MULT_IMP(M) { \
1191 if (m==1) { \ 1187 if (m==1) { \
1192 MULT_IMP_VER( AT(r,i,j) += AT(a,i,k) * AT(b,k,j); ) \ 1188 MULT_IMP_VER( AT(r,i,j) += AT(a,i,k) * AT(b,k,j); ) \
1193 } else { \ 1189 } else { \
1194 MULT_IMP_VER( AT(r,i,j) = (AT(r,i,j) + (AT(a,i,k) * AT(b,k,j)) % m) % m ; ) \ 1190 MULT_IMP_VER( AT(r,i,j) = M(AT(r,i,j) + M(AT(a,i,k) * AT(b,k,j), m) , m) ; ) \
1195 } OK } 1191 } OK }
1196 1192
1197int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP 1193int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP(mod)
1198int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP 1194int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP(mod_l)
1199 1195
1200/////////////////////////////// inplace row ops //////////////////////////////// 1196/////////////////////////////// inplace row ops ////////////////////////////////
1201 1197
@@ -1277,27 +1273,19 @@ ROWOP_MOD(int64_t,mod_l)
1277 1273
1278/////////////////////////////// inplace GEMM //////////////////////////////// 1274/////////////////////////////// inplace GEMM ////////////////////////////////
1279 1275
1280#define GEMM(T) int gemm_##T(VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \ 1276#define GEMM(T) int gemm_##T(VECG(T,c),MATG(T,a),MATG(T,b),MATG(T,r)) { \
1281 T a = cp[0], b = cp[1]; \ 1277 T a = cp[0], b = cp[1]; \
1282 int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \ 1278 T t; \
1283 int r1b = pp[4], c1b = pp[6] ; \ 1279 int k; \
1284 int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \ 1280 { TRAV(r,i,j) { \
1285 int dra = r1a - r1r; \ 1281 t = 0; \
1286 int dcb = c1b-c1r; \ 1282 for(k=0; k<ac; k++) { \
1287 int nk = c2a-c1a+1; \ 1283 t += AT(a,i,k) * AT(b,k,j); \
1288 int i,j,k; \ 1284 } \
1289 T t; \ 1285 AT(r,i,j) = b*AT(r,i,j) + a*t; \
1290 for (i=r1r; i<=r2r; i++) { \ 1286 } \
1291 for (j=c1r; j<=c2r; j++) { \ 1287 } OK }
1292 t = 0; \ 1288
1293 for(k=0; k<nk; k++) { \
1294 t += AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb); \
1295 } \
1296 AT(r,i,j) = b*AT(r,i,j) + a*t; \
1297 } \
1298 } \
1299 OK \
1300}
1301 1289
1302GEMM(double) 1290GEMM(double)
1303GEMM(float) 1291GEMM(float)
@@ -1306,27 +1294,19 @@ GEMM(TCF)
1306GEMM(int32_t) 1294GEMM(int32_t)
1307GEMM(int64_t) 1295GEMM(int64_t)
1308 1296
1309#define GEMM_MOD(T,M) int gemm_mod_##T(T m, VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \ 1297#define GEMM_MOD(T,M) int gemm_mod_##T(T m, VECG(T,c),MATG(T,a),MATG(T,b),MATG(T,r)) { \
1310 T a = cp[0], b = cp[1]; \ 1298 T a = cp[0], b = cp[1]; \
1311 int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \ 1299 int k; \
1312 int r1b = pp[4], c1b = pp[6] ; \ 1300 T t; \
1313 int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \ 1301 { TRAV(r,i,j) { \
1314 int dra = r1a - r1r; \ 1302 t = 0; \
1315 int dcb = c1b-c1r; \ 1303 for(k=0; k<ac; k++) { \
1316 int nk = c2a-c1a+1; \ 1304 t = M(t+M(AT(a,i,k) * AT(b,k,j))); \
1317 int i,j,k; \ 1305 } \
1318 T t; \ 1306 AT(r,i,j) = M(M(b*AT(r,i,j)) + M(a*t)); \
1319 for (i=r1r; i<=r2r; i++) { \ 1307 } \
1320 for (j=c1r; j<=c2r; j++) { \ 1308 } OK }
1321 t = 0; \ 1309
1322 for(k=0; k<nk; k++) { \
1323 t = M(t+M(AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb))); \
1324 } \
1325 AT(r,i,j) = M(M(b*AT(r,i,j)) + M(a*t)); \
1326 } \
1327 } \
1328 OK \
1329}
1330 1310
1331#define MOD32(X) mod(X,m) 1311#define MOD32(X) mod(X,m)
1332#define MOD64(X) mod_l(X,m) 1312#define MOD64(X) mod_l(X,m)
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 8597dcb..a789cae 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -226,6 +226,8 @@ atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m))
226 226
227------------------------------------------------------------------ 227------------------------------------------------------------------
228 228
229matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 }
230matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d }
229matrixFromVector o r c v 231matrixFromVector o r c v
230 | r * c == dim v = m 232 | r * c == dim v = m
231 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m 233 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
@@ -280,7 +282,7 @@ class (Storable a) => Element a where
280 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a 282 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
281 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a 283 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
282 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () 284 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
283 gemm :: Vector a -> Vector I -> Matrix a -> Matrix a -> Matrix a -> IO () 285 gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
284 286
285 287
286instance Element Float where 288instance Element Float where
@@ -569,9 +571,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
569 571
570-------------------------------------------------------------------------------- 572--------------------------------------------------------------------------------
571 573
572gemmg f u v m1 m2 m3 = f # u # v # m1 # m2 # m3 #|"gemmg" 574gemmg f v m1 m2 m3 = f # v # m1 # m2 # m3 #|"gemmg"
573 575
574type Tgemm x = x :> I :> x ::> x ::> x ::> Ok 576type Tgemm x = x :> x ::> x ::> x ::> Ok
575 577
576foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R 578foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R
577foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float 579foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index 54d9cb8..8fa2747 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -137,7 +137,7 @@ instance KnownNat m => Element (Mod m I)
137 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) 137 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x)
138 where 138 where
139 m' = fromIntegral . natVal $ (undefined :: Proxy m) 139 m' = fromIntegral . natVal $ (undefined :: Proxy m)
140 gemm u p a b c = gemmg (c_gemmMI m') (f2i u) p (f2iM a) (f2iM b) (f2iM c) 140 gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c)
141 where 141 where
142 m' = fromIntegral . natVal $ (undefined :: Proxy m) 142 m' = fromIntegral . natVal $ (undefined :: Proxy m)
143 143
@@ -154,7 +154,7 @@ instance KnownNat m => Element (Mod m Z)
154 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) 154 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x)
155 where 155 where
156 m' = fromIntegral . natVal $ (undefined :: Proxy m) 156 m' = fromIntegral . natVal $ (undefined :: Proxy m)
157 gemm u p a b c = gemmg (c_gemmML m') (f2i u) p (f2iM a) (f2iM b) (f2iM c) 157 gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c)
158 where 158 where
159 m' = fromIntegral . natVal $ (undefined :: Proxy m) 159 m' = fromIntegral . natVal $ (undefined :: Proxy m)
160 160
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 91c2a11..62dfddf 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -231,14 +231,13 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[
231-- | r0 c0 height width 231-- | r0 c0 height width
232data Slice s t = Slice (STMatrix s t) Int Int Int Int 232data Slice s t = Slice (STMatrix s t) Int Int Int Int
233 233
234slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1]) 234slice (Slice (STMatrix m) r0 c0 nr nc) = sliceMatrix (r0,c0) (nr,nc) m
235 235
236gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () 236gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s ()
237gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res 237gemmm beta (slice->r) alpha (slice->a) (slice->b) = res
238 where 238 where
239 res = unsafeIOToST (gemm u v a b r) 239 res = unsafeIOToST (gemm v a b r)
240 u = fromList [alpha,beta] 240 v = fromList [alpha,beta]
241 v = vjoin[pa,pb,pr]
242 241
243 242
244mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 243mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
index b226c9f..79cb769 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
@@ -4,6 +4,8 @@
4{-# LANGUAGE TypeFamilies #-} 4{-# LANGUAGE TypeFamilies #-}
5{-# LANGUAGE FlexibleContexts #-} 5{-# LANGUAGE FlexibleContexts #-}
6{-# LANGUAGE RankNTypes #-} 6{-# LANGUAGE RankNTypes #-}
7{-# LANGUAGE TypeOperators #-}
8{-# LANGUAGE ViewPatterns #-}
7 9
8----------------------------------------------------------------------------- 10-----------------------------------------------------------------------------
9{- | 11{- |
@@ -76,7 +78,7 @@ detTest1 = det m == 26
76 && det mc == 38 :+ (-3) 78 && det mc == 38 :+ (-3)
77 && det (feye 2) == -1 79 && det (feye 2) == -1
78 where 80 where
79 m = (3><3) 81 m = (3><3)
80 [ 1, 2, 3 82 [ 1, 2, 3
81 , 4, 5, 7 83 , 4, 5, 7
82 , 2, 8, 4 :: Double 84 , 2, 8, 4 :: Double
@@ -357,7 +359,7 @@ accumTest = utest "accum" ok
357 ,0,1,7 359 ,0,1,7
358 ,0,0,4] 360 ,0,0,4]
359 && 361 &&
360 toList (flatten x) == [1,0,0,0,1,0,0,0,1] 362 toList (flatten x) == [1,0,0,0,1,0,0,0,1]
361 363
362-------------------------------------------------------------------------------- 364--------------------------------------------------------------------------------
363 365
@@ -400,6 +402,150 @@ indexProp g f x = a1 == g a2 && a2 == a3 && b1 == g b2 && b2 == b3
400 402
401-------------------------------------------------------------------------------- 403--------------------------------------------------------------------------------
402 404
405sliceTest = utest "slice test" $ and
406 [ testSlice chol (gen 5 :: Matrix R)
407 , testSlice chol (gen 5 :: Matrix C)
408 , testSlice qr (rec :: Matrix R)
409 , testSlice qr (rec :: Matrix C)
410 , testSlice hess (agen 5 :: Matrix R)
411 , testSlice hess (agen 5 :: Matrix C)
412 , testSlice schur (agen 5 :: Matrix R)
413 , testSlice schur (agen 5 :: Matrix C)
414 , testSlice lu (agen 5 :: Matrix R)
415 , testSlice lu (agen 5 :: Matrix C)
416 , testSlice (luSolve (luPacked (agen 5 :: Matrix R))) (agen 5)
417 , testSlice (luSolve (luPacked (agen 5 :: Matrix C))) (agen 5)
418 , test_lus (agen 5 :: Matrix R)
419 , test_lus (agen 5 :: Matrix C)
420
421 , testSlice eig (agen 5 :: Matrix R)
422 , testSlice eig (agen 5 :: Matrix C)
423 , testSlice eigSH (gen 5 :: Matrix R)
424 , testSlice eigSH (gen 5 :: Matrix C)
425 , testSlice eigenvalues (agen 5 :: Matrix R)
426 , testSlice eigenvalues (agen 5 :: Matrix C)
427 , testSlice eigenvaluesSH (gen 5 :: Matrix R)
428 , testSlice eigenvaluesSH (gen 5 :: Matrix C)
429
430 , testSlice svd (rec :: Matrix R)
431 , testSlice thinSVD (rec :: Matrix R)
432 , testSlice compactSVD (rec :: Matrix R)
433 , testSlice leftSV (rec :: Matrix R)
434 , testSlice rightSV (rec :: Matrix R)
435 , testSlice singularValues (rec :: Matrix R)
436
437 , testSlice svd (rec :: Matrix C)
438 , testSlice thinSVD (rec :: Matrix C)
439 , testSlice compactSVD (rec :: Matrix C)
440 , testSlice leftSV (rec :: Matrix C)
441 , testSlice rightSV (rec :: Matrix C)
442 , testSlice singularValues (rec :: Matrix C)
443
444 , testSlice (linearSolve (agen 5:: Matrix R)) (agen 5)
445 , testSlice (flip linearSolve (agen 5:: Matrix R)) (agen 5)
446
447 , testSlice (linearSolve (agen 5:: Matrix C)) (agen 5)
448 , testSlice (flip linearSolve (agen 5:: Matrix C)) (agen 5)
449
450 , testSlice (linearSolveLS (ogen 5:: Matrix R)) (ogen 5)
451 , testSlice (flip linearSolveLS (ogen 5:: Matrix R)) (ogen 5)
452
453 , testSlice (linearSolveLS (ogen 5:: Matrix C)) (ogen 5)
454 , testSlice (flip linearSolveLS (ogen 5:: Matrix C)) (ogen 5)
455
456 , testSlice (linearSolveSVD (ogen 5:: Matrix R)) (ogen 5)
457 , testSlice (flip linearSolveSVD (ogen 5:: Matrix R)) (ogen 5)
458
459 , testSlice (linearSolveSVD (ogen 5:: Matrix C)) (ogen 5)
460 , testSlice (flip linearSolveSVD (ogen 5:: Matrix C)) (ogen 5)
461
462 , testSlice (linearSolveLS (ugen 5:: Matrix R)) (ugen 5)
463 , testSlice (flip linearSolveLS (ugen 5:: Matrix R)) (ugen 5)
464
465 , testSlice (linearSolveLS (ugen 5:: Matrix C)) (ugen 5)
466 , testSlice (flip linearSolveLS (ugen 5:: Matrix C)) (ugen 5)
467
468 , testSlice (linearSolveSVD (ugen 5:: Matrix R)) (ugen 5)
469 , testSlice (flip linearSolveSVD (ugen 5:: Matrix R)) (ugen 5)
470
471 , testSlice (linearSolveSVD (ugen 5:: Matrix C)) (ugen 5)
472 , testSlice (flip linearSolveSVD (ugen 5:: Matrix C)) (ugen 5)
473
474 , testSlice ((<>) (ogen 5:: Matrix R)) (gen 5)
475 , testSlice (flip (<>) (gen 5:: Matrix R)) (ogen 5)
476 , testSlice ((<>) (ogen 5:: Matrix C)) (gen 5)
477 , testSlice (flip (<>) (gen 5:: Matrix C)) (ogen 5)
478 , testSlice ((<>) (ogen 5:: Matrix Float)) (gen 5)
479 , testSlice (flip (<>) (gen 5:: Matrix Float)) (ogen 5)
480 , testSlice ((<>) (ogen 5:: Matrix (Complex Float))) (gen 5)
481 , testSlice (flip (<>) (gen 5:: Matrix (Complex Float))) (ogen 5)
482 , testSlice ((<>) (ogen 5:: Matrix I)) (gen 5)
483 , testSlice (flip (<>) (gen 5:: Matrix I)) (ogen 5)
484 , testSlice ((<>) (ogen 5:: Matrix Z)) (gen 5)
485 , testSlice (flip (<>) (gen 5:: Matrix Z)) (ogen 5)
486
487 , testSlice ((<>) (ogen 5:: Matrix (I ./. 7))) (gen 5)
488 , testSlice (flip (<>) (gen 5:: Matrix (I ./. 7))) (ogen 5)
489 , testSlice ((<>) (ogen 5:: Matrix (Z ./. 7))) (gen 5)
490 , testSlice (flip (<>) (gen 5:: Matrix (Z ./. 7))) (ogen 5)
491
492 , testSlice (flip cholSolve (agen 5:: Matrix R)) (chol $ gen 5)
493 , testSlice (flip cholSolve (agen 5:: Matrix C)) (chol $ gen 5)
494 , testSlice (cholSolve (chol $ gen 5:: Matrix R)) (agen 5)
495 , testSlice (cholSolve (chol $ gen 5:: Matrix C)) (agen 5)
496
497 , ok_qrgr (rec :: Matrix R)
498 , ok_qrgr (rec :: Matrix C)
499 , testSlice (test_qrgr 4 tau1) qrr1
500 , testSlice (test_qrgr 4 tau2) qrr2
501 ]
502 where
503 (qrr1,tau1) = qrRaw (rec :: Matrix R)
504 (qrr2,tau2) = qrRaw (rec :: Matrix C)
505
506 test_qrgr n t x = qrgr n (x,t)
507
508 ok_qrgr x = simeq 1E-15 q q'
509 where
510 (q,_) = qr x
511 atau = qrRaw x
512 q' = qrgr (rows q) atau
513
514 simeq eps a b = not $ magnit eps (norm_1 $ flatten (a-b))
515
516 test_lus m = testSlice f lup
517 where
518 f x = luSolve (x,p) m
519 (lup,p) = luPacked m
520
521 gen :: Numeric t => Int -> Matrix t
522 gen n = diagRect 1 (konst 5 n) n n
523
524 agen :: (Numeric t, Num (Vector t))=> Int -> Matrix t
525 agen n = gen n + fromInt ((n><n)[0..])
526
527 ogen :: (Numeric t, Num (Vector t))=> Int -> Matrix t
528 ogen n = gen n === gen n
529
530 ugen :: (Numeric t, Num (Vector t))=> Int -> Matrix t
531 ugen n = takeRows 3 (gen n)
532
533
534 rec :: Numeric t => Matrix t
535 rec = subMatrix (0,0) (4,5) (gen 5)
536
537 testSlice f x@(size->sz@(r,c)) = all (==f x) (map f (g y1 ++ g y2))
538 where
539 subm = sliceMatrix
540 g y = [ subm (a*r,b*c) sz y | a <-[0..2], b <- [0..2]]
541 h z = fromBlocks (replicate 3 (replicate 3 z))
542 y1 = h x
543 y2 = (tr . h . tr) x
544
545
546
547--------------------------------------------------------------------------------
548
403-- | All tests must pass with a maximum dimension of about 20 549-- | All tests must pass with a maximum dimension of about 20
404-- (some tests may fail with bigger sizes due to precision loss). 550-- (some tests may fail with bigger sizes due to precision loss).
405runTests :: Int -- ^ maximum dimension 551runTests :: Int -- ^ maximum dimension
@@ -578,6 +724,7 @@ runTests n = do
578 , staticTest 724 , staticTest
579 , intTest 725 , intTest
580 , modularTest 726 , modularTest
727 , sliceTest
581 ] 728 ]
582 when (errors c + failures c > 0) exitFailure 729 when (errors c + failures c > 0) exitFailure
583 return () 730 return ()