diff options
-rw-r--r-- | packages/base/src/Internal/Algorithms.hs | 138 | ||||
-rw-r--r-- | packages/base/src/Internal/CG.hs | 20 | ||||
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 10 | ||||
-rw-r--r-- | packages/base/src/Internal/Numeric.hs | 50 | ||||
-rw-r--r-- | packages/base/src/Internal/Util.hs | 14 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra.hs | 47 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/HMatrix.hs | 3 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static.hs | 12 | ||||
-rw-r--r-- | packages/tests/src/Numeric/LinearAlgebra/Tests.hs | 48 | ||||
-rw-r--r-- | packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs | 18 | ||||
-rw-r--r-- | packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs | 12 |
11 files changed, 239 insertions, 133 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs index 3d25491..d2f17f4 100644 --- a/packages/base/src/Internal/Algorithms.hs +++ b/packages/base/src/Internal/Algorithms.hs | |||
@@ -4,6 +4,12 @@ | |||
4 | {-# LANGUAGE UndecidableInstances #-} | 4 | {-# LANGUAGE UndecidableInstances #-} |
5 | {-# LANGUAGE TypeFamilies #-} | 5 | {-# LANGUAGE TypeFamilies #-} |
6 | 6 | ||
7 | {-# LANGUAGE FlexibleContexts, FlexibleInstances #-} | ||
8 | {-# LANGUAGE CPP #-} | ||
9 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
10 | {-# LANGUAGE UndecidableInstances #-} | ||
11 | {-# LANGUAGE TypeFamilies #-} | ||
12 | |||
7 | ----------------------------------------------------------------------------- | 13 | ----------------------------------------------------------------------------- |
8 | {- | | 14 | {- | |
9 | Module : Internal.Algorithms | 15 | Module : Internal.Algorithms |
@@ -32,6 +38,7 @@ import Data.List(foldl1') | |||
32 | import qualified Data.Array as A | 38 | import qualified Data.Array as A |
33 | import Internal.ST | 39 | import Internal.ST |
34 | import Internal.Vectorized(range) | 40 | import Internal.Vectorized(range) |
41 | import Control.DeepSeq | ||
35 | 42 | ||
36 | {- | Generic linear algebra functions for double precision real and complex matrices. | 43 | {- | Generic linear algebra functions for double precision real and complex matrices. |
37 | 44 | ||
@@ -43,6 +50,10 @@ class (Numeric t, | |||
43 | Normed Matrix t, | 50 | Normed Matrix t, |
44 | Normed Vector t, | 51 | Normed Vector t, |
45 | Floating t, | 52 | Floating t, |
53 | Linear t Vector, | ||
54 | Linear t Matrix, | ||
55 | Additive (Vector t), | ||
56 | Additive (Matrix t), | ||
46 | RealOf t ~ Double) => Field t where | 57 | RealOf t ~ Double) => Field t where |
47 | svd' :: Matrix t -> (Matrix t, Vector Double, Matrix t) | 58 | svd' :: Matrix t -> (Matrix t, Vector Double, Matrix t) |
48 | thinSVD' :: Matrix t -> (Matrix t, Vector Double, Matrix t) | 59 | thinSVD' :: Matrix t -> (Matrix t, Vector Double, Matrix t) |
@@ -306,25 +317,38 @@ leftSV m | vertical m = let (u,s,_) = svd m in (u,s) | |||
306 | 317 | ||
307 | -------------------------------------------------------------- | 318 | -------------------------------------------------------------- |
308 | 319 | ||
320 | -- | LU decomposition of a matrix in a compact format. | ||
321 | data LU t = LU (Matrix t) [Int] deriving Show | ||
322 | |||
323 | instance (NFData t, Numeric t) => NFData (LU t) | ||
324 | where | ||
325 | rnf (LU m _) = rnf m | ||
326 | |||
309 | -- | Obtains the LU decomposition of a matrix in a compact data structure suitable for 'luSolve'. | 327 | -- | Obtains the LU decomposition of a matrix in a compact data structure suitable for 'luSolve'. |
310 | luPacked :: Field t => Matrix t -> (Matrix t, [Int]) | 328 | luPacked :: Field t => Matrix t -> LU t |
311 | luPacked = {-# SCC "luPacked" #-} luPacked' | 329 | luPacked x = {-# SCC "luPacked" #-} LU m p |
330 | where | ||
331 | (m,p) = luPacked' x | ||
312 | 332 | ||
313 | -- | Solution of a linear system (for several right hand sides) from the precomputed LU factorization obtained by 'luPacked'. | 333 | -- | Solution of a linear system (for several right hand sides) from the precomputed LU factorization obtained by 'luPacked'. |
314 | luSolve :: Field t => (Matrix t, [Int]) -> Matrix t -> Matrix t | 334 | luSolve :: Field t => LU t -> Matrix t -> Matrix t |
315 | luSolve = {-# SCC "luSolve" #-} luSolve' | 335 | luSolve (LU m p) = {-# SCC "luSolve" #-} luSolve' (m,p) |
316 | 336 | ||
317 | -- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'. | 337 | -- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'. |
318 | -- It is similar to 'luSolve' . 'luPacked', but @linearSolve@ raises an error if called on a singular system. | 338 | -- It is similar to 'luSolve' . 'luPacked', but @linearSolve@ raises an error if called on a singular system. |
319 | linearSolve :: Field t => Matrix t -> Matrix t -> Matrix t | 339 | linearSolve :: Field t => Matrix t -> Matrix t -> Matrix t |
320 | linearSolve = {-# SCC "linearSolve" #-} linearSolve' | 340 | linearSolve = {-# SCC "linearSolve" #-} linearSolve' |
321 | 341 | ||
322 | -- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, returning Nothing for a singular system. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'. | 342 | -- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, returning Nothing for a singular system. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'. |
323 | mbLinearSolve :: Field t => Matrix t -> Matrix t -> Maybe (Matrix t) | 343 | mbLinearSolve :: Field t => Matrix t -> Matrix t -> Maybe (Matrix t) |
324 | mbLinearSolve = {-# SCC "linearSolve" #-} mbLinearSolve' | 344 | mbLinearSolve = {-# SCC "linearSolve" #-} mbLinearSolve' |
325 | 345 | ||
326 | -- | Solve a symmetric or Hermitian positive definite linear system using a precomputed Cholesky decomposition obtained by 'chol'. | 346 | -- | Solve a symmetric or Hermitian positive definite linear system using a precomputed Cholesky decomposition obtained by 'chol'. |
327 | cholSolve :: Field t => Matrix t -> Matrix t -> Matrix t | 347 | cholSolve |
348 | :: Field t | ||
349 | => Matrix t -- ^ Cholesky decomposition of the coefficient matrix | ||
350 | -> Matrix t -- ^ right hand sides | ||
351 | -> Matrix t -- ^ solution | ||
328 | cholSolve = {-# SCC "cholSolve" #-} cholSolve' | 352 | cholSolve = {-# SCC "cholSolve" #-} cholSolve' |
329 | 353 | ||
330 | -- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. | 354 | -- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. |
@@ -338,20 +362,28 @@ linearSolveLS = {-# SCC "linearSolveLS" #-} linearSolveLS' | |||
338 | 362 | ||
339 | -------------------------------------------------------------------------------- | 363 | -------------------------------------------------------------------------------- |
340 | 364 | ||
365 | -- | LDL decomposition of a complex Hermitian or real symmetric matrix in a compact format. | ||
366 | data LDL t = LDL (Matrix t) [Int] deriving Show | ||
367 | |||
368 | instance (NFData t, Numeric t) => NFData (LDL t) | ||
369 | where | ||
370 | rnf (LDL m _) = rnf m | ||
371 | |||
341 | -- | Similar to 'ldlPacked', without checking that the input matrix is hermitian or symmetric. It works with the lower triangular part. | 372 | -- | Similar to 'ldlPacked', without checking that the input matrix is hermitian or symmetric. It works with the lower triangular part. |
342 | ldlPackedSH :: Field t => Matrix t -> (Matrix t, [Int]) | 373 | ldlPackedSH :: Field t => Matrix t -> LDL t |
343 | ldlPackedSH = {-# SCC "ldlPacked" #-} ldlPacked' | 374 | ldlPackedSH x = {-# SCC "ldlPacked" #-} LDL m p |
375 | where | ||
376 | (m,p) = ldlPacked' x | ||
344 | 377 | ||
345 | -- | Obtains the LDL decomposition of a matrix in a compact data structure suitable for 'ldlSolve'. | 378 | -- | Obtains the LDL decomposition of a matrix in a compact data structure suitable for 'ldlSolve'. |
346 | ldlPacked :: Field t => Matrix t -> (Matrix t, [Int]) | 379 | ldlPacked :: Field t => Her t -> LDL t |
347 | ldlPacked m | 380 | ldlPacked (Her m) = ldlPackedSH m |
348 | | exactHermitian m = {-# SCC "ldlPacked" #-} ldlPackedSH m | ||
349 | | otherwise = error "ldlPacked requires complex Hermitian or real symmetrix matrix" | ||
350 | |||
351 | 381 | ||
352 | -- | Solution of a linear system (for several right hand sides) from the precomputed LDL factorization obtained by 'ldlPacked'. | 382 | -- | Solution of a linear system (for several right hand sides) from a precomputed LDL factorization obtained by 'ldlPacked'. |
353 | ldlSolve :: Field t => (Matrix t, [Int]) -> Matrix t -> Matrix t | 383 | -- |
354 | ldlSolve = {-# SCC "ldlSolve" #-} ldlSolve' | 384 | -- Note: this can be slower than the general solver based on the LU decomposition. |
385 | ldlSolve :: Field t => LDL t -> Matrix t -> Matrix t | ||
386 | ldlSolve (LDL m p) = {-# SCC "ldlSolve" #-} ldlSolve' (m,p) | ||
355 | 387 | ||
356 | -------------------------------------------------------------- | 388 | -------------------------------------------------------------- |
357 | 389 | ||
@@ -429,14 +461,12 @@ fromList [11.344814282762075,0.17091518882717918,-0.5157294715892575] | |||
429 | 3.000 5.000 6.000 | 461 | 3.000 5.000 6.000 |
430 | 462 | ||
431 | -} | 463 | -} |
432 | eigSH :: Field t => Matrix t -> (Vector Double, Matrix t) | 464 | eigSH :: Field t => Her t -> (Vector Double, Matrix t) |
433 | eigSH m | exactHermitian m = eigSH' m | 465 | eigSH (Her m) = eigSH' m |
434 | | otherwise = error "eigSH requires complex hermitian or real symmetric matrix" | ||
435 | 466 | ||
436 | -- | Eigenvalues (in descending order) of a complex hermitian or real symmetric matrix. | 467 | -- | Eigenvalues (in descending order) of a complex hermitian or real symmetric matrix. |
437 | eigenvaluesSH :: Field t => Matrix t -> Vector Double | 468 | eigenvaluesSH :: Field t => Her t -> Vector Double |
438 | eigenvaluesSH m | exactHermitian m = eigenvaluesSH' m | 469 | eigenvaluesSH (Her m) = eigenvaluesSH' m |
439 | | otherwise = error "eigenvaluesSH requires complex hermitian or real symmetric matrix" | ||
440 | 470 | ||
441 | -------------------------------------------------------------- | 471 | -------------------------------------------------------------- |
442 | 472 | ||
@@ -490,14 +520,18 @@ mbCholSH = {-# SCC "mbCholSH" #-} mbCholSH' | |||
490 | 520 | ||
491 | -- | Similar to 'chol', without checking that the input matrix is hermitian or symmetric. It works with the upper triangular part. | 521 | -- | Similar to 'chol', without checking that the input matrix is hermitian or symmetric. It works with the upper triangular part. |
492 | cholSH :: Field t => Matrix t -> Matrix t | 522 | cholSH :: Field t => Matrix t -> Matrix t |
493 | cholSH = {-# SCC "cholSH" #-} cholSH' | 523 | cholSH = cholSH' |
494 | 524 | ||
495 | -- | Cholesky factorization of a positive definite hermitian or symmetric matrix. | 525 | -- | Cholesky factorization of a positive definite hermitian or symmetric matrix. |
496 | -- | 526 | -- |
497 | -- If @c = chol m@ then @c@ is upper triangular and @m == tr c \<> c@. | 527 | -- If @c = chol m@ then @c@ is upper triangular and @m == tr c \<> c@. |
498 | chol :: Field t => Matrix t -> Matrix t | 528 | chol :: Field t => Her t -> Matrix t |
499 | chol m | exactHermitian m = cholSH m | 529 | chol (Her m) = {-# SCC "chol" #-} cholSH' m |
500 | | otherwise = error "chol requires positive definite complex hermitian or real symmetric matrix" | 530 | |
531 | -- | Similar to 'chol', but instead of an error (e.g., caused by a matrix not positive definite) it returns 'Nothing'. | ||
532 | mbChol :: Field t => Her t -> Maybe (Matrix t) | ||
533 | mbChol (Her m) = {-# SCC "mbChol" #-} mbCholSH' m | ||
534 | |||
501 | 535 | ||
502 | 536 | ||
503 | -- | Joint computation of inverse and logarithm of determinant of a square matrix. | 537 | -- | Joint computation of inverse and logarithm of determinant of a square matrix. |
@@ -507,7 +541,7 @@ invlndet :: Field t | |||
507 | invlndet m | square m = (im,(ladm,sdm)) | 541 | invlndet m | square m = (im,(ladm,sdm)) |
508 | | otherwise = error $ "invlndet of nonsquare "++ shSize m ++ " matrix" | 542 | | otherwise = error $ "invlndet of nonsquare "++ shSize m ++ " matrix" |
509 | where | 543 | where |
510 | lp@(lup,perm) = luPacked m | 544 | lp@(LU lup perm) = luPacked m |
511 | s = signlp (rows m) perm | 545 | s = signlp (rows m) perm |
512 | dg = toList $ takeDiag $ lup | 546 | dg = toList $ takeDiag $ lup |
513 | ladm = sum $ map (log.abs) dg | 547 | ladm = sum $ map (log.abs) dg |
@@ -519,8 +553,9 @@ invlndet m | square m = (im,(ladm,sdm)) | |||
519 | det :: Field t => Matrix t -> t | 553 | det :: Field t => Matrix t -> t |
520 | det m | square m = {-# SCC "det" #-} s * (product $ toList $ takeDiag $ lup) | 554 | det m | square m = {-# SCC "det" #-} s * (product $ toList $ takeDiag $ lup) |
521 | | otherwise = error $ "det of nonsquare "++ shSize m ++ " matrix" | 555 | | otherwise = error $ "det of nonsquare "++ shSize m ++ " matrix" |
522 | where (lup,perm) = luPacked m | 556 | where |
523 | s = signlp (rows m) perm | 557 | LU lup perm = luPacked m |
558 | s = signlp (rows m) perm | ||
524 | 559 | ||
525 | -- | Explicit LU factorization of a general matrix. | 560 | -- | Explicit LU factorization of a general matrix. |
526 | -- | 561 | -- |
@@ -720,7 +755,7 @@ diagonalize m = if rank v == n | |||
720 | else Nothing | 755 | else Nothing |
721 | where n = rows m | 756 | where n = rows m |
722 | (l,v) = if exactHermitian m | 757 | (l,v) = if exactHermitian m |
723 | then let (l',v') = eigSH m in (real l', v') | 758 | then let (l',v') = eigSH (trustSym m) in (real l', v') |
724 | else eig m | 759 | else eig m |
725 | 760 | ||
726 | -- | Generic matrix functions for diagonalizable matrices. For instance: | 761 | -- | Generic matrix functions for diagonalizable matrices. For instance: |
@@ -835,8 +870,9 @@ fixPerm' s = res $ mutable f s0 | |||
835 | triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]] | 870 | triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]] |
836 | where el p q = if q-p>=h then v else 1 - v | 871 | where el p q = if q-p>=h then v else 1 - v |
837 | 872 | ||
838 | luFact (l_u,perm) | r <= c = (l ,u ,p, s) | 873 | luFact (LU l_u perm) |
839 | | otherwise = (l',u',p, s) | 874 | | r <= c = (l ,u ,p, s) |
875 | | otherwise = (l',u',p, s) | ||
840 | where | 876 | where |
841 | r = rows l_u | 877 | r = rows l_u |
842 | c = cols l_u | 878 | c = cols l_u |
@@ -929,7 +965,13 @@ relativeError norm a b = r | |||
929 | ---------------------------------------------------------------------- | 965 | ---------------------------------------------------------------------- |
930 | 966 | ||
931 | -- | Generalized symmetric positive definite eigensystem Av = lBv, | 967 | -- | Generalized symmetric positive definite eigensystem Av = lBv, |
932 | -- for A and B symmetric, B positive definite (conditions not checked). | 968 | -- for A and B symmetric, B positive definite. |
969 | geigSH :: Field t | ||
970 | => Her t -- ^ A | ||
971 | -> Her t -- ^ B | ||
972 | -> (Vector Double, Matrix t) | ||
973 | geigSH (Her a) (Her b) = geigSH' a b | ||
974 | |||
933 | geigSH' :: Field t | 975 | geigSH' :: Field t |
934 | => Matrix t -- ^ A | 976 | => Matrix t -- ^ A |
935 | -> Matrix t -- ^ B | 977 | -> Matrix t -- ^ B |
@@ -943,3 +985,33 @@ geigSH' a b = (l,v') | |||
943 | v' = iu <> v | 985 | v' = iu <> v |
944 | (<>) = mXm | 986 | (<>) = mXm |
945 | 987 | ||
988 | -------------------------------------------------------------------------------- | ||
989 | |||
990 | -- | A matrix that, by construction, it is known to be complex Hermitian or real symmetric. | ||
991 | -- | ||
992 | -- It can be created using 'sym', 'xTx', or 'trustSym', and the matrix can be extracted using 'her'. | ||
993 | data Her t = Her (Matrix t) deriving Show | ||
994 | |||
995 | -- | Extract the general matrix from a 'Her' structure, forgetting its symmetric or Hermitian property. | ||
996 | her :: Her t -> Matrix t | ||
997 | her (Her x) = x | ||
998 | |||
999 | -- | Compute the complex Hermitian or real symmetric part of a square matrix (@(x + tr x)/2@). | ||
1000 | sym :: Field t => Matrix t -> Her t | ||
1001 | sym x = Her (scale 0.5 (tr x `add` x)) | ||
1002 | |||
1003 | -- | Compute the contraction @tr x <> x@ of a general matrix. | ||
1004 | xTx :: Numeric t => Matrix t -> Her t | ||
1005 | xTx x = Her (tr x `mXm` x) | ||
1006 | |||
1007 | instance Field t => Linear t Her where | ||
1008 | scale x (Her m) = Her (scale x m) | ||
1009 | |||
1010 | instance Field t => Additive (Her t) where | ||
1011 | add (Her a) (Her b) = Her (a `add` b) | ||
1012 | |||
1013 | -- | At your own risk, declare that a matrix is complex Hermitian or real symmetric | ||
1014 | -- for usage in 'chol', 'eigSH', etc. Only a triangular part of the matrix will be used. | ||
1015 | trustSym :: Matrix t -> Her t | ||
1016 | trustSym x = (Her x) | ||
1017 | |||
diff --git a/packages/base/src/Internal/CG.hs b/packages/base/src/Internal/CG.hs index f0142cd..cc10ad8 100644 --- a/packages/base/src/Internal/CG.hs +++ b/packages/base/src/Internal/CG.hs | |||
@@ -32,11 +32,11 @@ v /// b = debugMat b 2 asRow v | |||
32 | type V = Vector R | 32 | type V = Vector R |
33 | 33 | ||
34 | data CGState = CGState | 34 | data CGState = CGState |
35 | { cgp :: V -- ^ conjugate gradient | 35 | { cgp :: Vector R -- ^ conjugate gradient |
36 | , cgr :: V -- ^ residual | 36 | , cgr :: Vector R -- ^ residual |
37 | , cgr2 :: R -- ^ squared norm of residual | 37 | , cgr2 :: R -- ^ squared norm of residual |
38 | , cgx :: V -- ^ current solution | 38 | , cgx :: Vector R -- ^ current solution |
39 | , cgdx :: R -- ^ normalized size of correction | 39 | , cgdx :: R -- ^ normalized size of correction |
40 | } | 40 | } |
41 | 41 | ||
42 | cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState | 42 | cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState |
@@ -89,23 +89,25 @@ takeUntil q xs = a++ take 1 b | |||
89 | where | 89 | where |
90 | (a,b) = break q xs | 90 | (a,b) = break q xs |
91 | 91 | ||
92 | -- | Solve a sparse linear system using the conjugate gradient method with default parameters. | ||
92 | cgSolve | 93 | cgSolve |
93 | :: Bool -- ^ is symmetric | 94 | :: Bool -- ^ is symmetric |
94 | -> GMatrix -- ^ coefficient matrix | 95 | -> GMatrix -- ^ coefficient matrix |
95 | -> Vector Double -- ^ right-hand side | 96 | -> Vector R -- ^ right-hand side |
96 | -> Vector Double -- ^ solution | 97 | -> Vector R -- ^ solution |
97 | cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 | 98 | cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 |
98 | where | 99 | where |
99 | n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) | 100 | n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) |
100 | 101 | ||
102 | -- | Solve a sparse linear system using the conjugate gradient method with default parameters. | ||
101 | cgSolve' | 103 | cgSolve' |
102 | :: Bool -- ^ symmetric | 104 | :: Bool -- ^ symmetric |
103 | -> R -- ^ relative tolerance for the residual (e.g. 1E-4) | 105 | -> R -- ^ relative tolerance for the residual (e.g. 1E-4) |
104 | -> R -- ^ relative tolerance for δx (e.g. 1E-3) | 106 | -> R -- ^ relative tolerance for δx (e.g. 1E-3) |
105 | -> Int -- ^ maximum number of iterations | 107 | -> Int -- ^ maximum number of iterations |
106 | -> GMatrix -- ^ coefficient matrix | 108 | -> GMatrix -- ^ coefficient matrix |
107 | -> V -- ^ initial solution | 109 | -> Vector R -- ^ initial solution |
108 | -> V -- ^ right-hand side | 110 | -> Vector R -- ^ right-hand side |
109 | -> [CGState] -- ^ solution | 111 | -> [CGState] -- ^ solution |
110 | cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es | 112 | cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es |
111 | 113 | ||
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 64ed2bb..a3421a8 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -33,7 +33,7 @@ import Internal.Element | |||
33 | import Internal.Container | 33 | import Internal.Container |
34 | import Internal.Vectorized (prodI,sumI,prodL,sumL) | 34 | import Internal.Vectorized (prodI,sumI,prodL,sumL) |
35 | import Internal.LAPACK (multiplyI, multiplyL) | 35 | import Internal.LAPACK (multiplyI, multiplyL) |
36 | import Internal.Algorithms(luFact) | 36 | import Internal.Algorithms(luFact,LU(..)) |
37 | import Internal.Util(Normed(..),Indexable(..), | 37 | import Internal.Util(Normed(..),Indexable(..), |
38 | gaussElim, gaussElim_1, gaussElim_2, | 38 | gaussElim, gaussElim_1, gaussElim_2, |
39 | luST, luSolve', luPacked', magnit, invershur) | 39 | luST, luSolve', luPacked', magnit, invershur) |
@@ -169,7 +169,7 @@ instance forall m . KnownNat m => Container Vector (Mod m I) | |||
169 | size' = dim | 169 | size' = dim |
170 | scale' s x = vmod (scale (unMod s) (f2i x)) | 170 | scale' s x = vmod (scale (unMod s) (f2i x)) |
171 | addConstant c x = vmod (addConstant (unMod c) (f2i x)) | 171 | addConstant c x = vmod (addConstant (unMod c) (f2i x)) |
172 | add a b = vmod (add (f2i a) (f2i b)) | 172 | add' a b = vmod (add' (f2i a) (f2i b)) |
173 | sub a b = vmod (sub (f2i a) (f2i b)) | 173 | sub a b = vmod (sub (f2i a) (f2i b)) |
174 | mul a b = vmod (mul (f2i a) (f2i b)) | 174 | mul a b = vmod (mul (f2i a) (f2i b)) |
175 | equal u v = equal (f2i u) (f2i v) | 175 | equal u v = equal (f2i u) (f2i v) |
@@ -209,7 +209,7 @@ instance forall m . KnownNat m => Container Vector (Mod m Z) | |||
209 | size' = dim | 209 | size' = dim |
210 | scale' s x = vmod (scale (unMod s) (f2i x)) | 210 | scale' s x = vmod (scale (unMod s) (f2i x)) |
211 | addConstant c x = vmod (addConstant (unMod c) (f2i x)) | 211 | addConstant c x = vmod (addConstant (unMod c) (f2i x)) |
212 | add a b = vmod (add (f2i a) (f2i b)) | 212 | add' a b = vmod (add' (f2i a) (f2i b)) |
213 | sub a b = vmod (sub (f2i a) (f2i b)) | 213 | sub a b = vmod (sub (f2i a) (f2i b)) |
214 | mul a b = vmod (mul (f2i a) (f2i b)) | 214 | mul a b = vmod (mul (f2i a) (f2i b)) |
215 | equal u v = equal (f2i u) (f2i v) | 215 | equal u v = equal (f2i u) (f2i v) |
@@ -371,7 +371,9 @@ test = (ok, info) | |||
371 | 371 | ||
372 | checkLU okf t = norm_Inf $ flatten (l <> u <> p - t) | 372 | checkLU okf t = norm_Inf $ flatten (l <> u <> p - t) |
373 | where | 373 | where |
374 | (l,u,p,_ :: Int) = luFact $ mutable (luST okf) t | 374 | (l,u,p,_ :: Int) = luFact (LU x' p') |
375 | where | ||
376 | (x',p') = mutable (luST okf) t | ||
375 | 377 | ||
376 | checkSolve aa = norm_Inf $ flatten (aa <> x - bb) | 378 | checkSolve aa = norm_Inf $ flatten (aa <> x - bb) |
377 | where | 379 | where |
diff --git a/packages/base/src/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs index a8ae2bb..e8c7440 100644 --- a/packages/base/src/Internal/Numeric.hs +++ b/packages/base/src/Internal/Numeric.hs | |||
@@ -49,7 +49,7 @@ class Element e => Container c e | |||
49 | scalar' :: e -> c e | 49 | scalar' :: e -> c e |
50 | scale' :: e -> c e -> c e | 50 | scale' :: e -> c e -> c e |
51 | addConstant :: e -> c e -> c e | 51 | addConstant :: e -> c e -> c e |
52 | add :: c e -> c e -> c e | 52 | add' :: c e -> c e -> c e |
53 | sub :: c e -> c e -> c e | 53 | sub :: c e -> c e -> c e |
54 | -- | element by element multiplication | 54 | -- | element by element multiplication |
55 | mul :: c e -> c e -> c e | 55 | mul :: c e -> c e -> c e |
@@ -100,7 +100,7 @@ instance Container Vector I | |||
100 | size' = dim | 100 | size' = dim |
101 | scale' = vectorMapValI Scale | 101 | scale' = vectorMapValI Scale |
102 | addConstant = vectorMapValI AddConstant | 102 | addConstant = vectorMapValI AddConstant |
103 | add = vectorZipI Add | 103 | add' = vectorZipI Add |
104 | sub = vectorZipI Sub | 104 | sub = vectorZipI Sub |
105 | mul = vectorZipI Mul | 105 | mul = vectorZipI Mul |
106 | equal u v = dim u == dim v && maxElement' (vectorMapI Abs (sub u v)) == 0 | 106 | equal u v = dim u == dim v && maxElement' (vectorMapI Abs (sub u v)) == 0 |
@@ -139,7 +139,7 @@ instance Container Vector Z | |||
139 | size' = dim | 139 | size' = dim |
140 | scale' = vectorMapValL Scale | 140 | scale' = vectorMapValL Scale |
141 | addConstant = vectorMapValL AddConstant | 141 | addConstant = vectorMapValL AddConstant |
142 | add = vectorZipL Add | 142 | add' = vectorZipL Add |
143 | sub = vectorZipL Sub | 143 | sub = vectorZipL Sub |
144 | mul = vectorZipL Mul | 144 | mul = vectorZipL Mul |
145 | equal u v = dim u == dim v && maxElement' (vectorMapL Abs (sub u v)) == 0 | 145 | equal u v = dim u == dim v && maxElement' (vectorMapL Abs (sub u v)) == 0 |
@@ -179,7 +179,7 @@ instance Container Vector Float | |||
179 | size' = dim | 179 | size' = dim |
180 | scale' = vectorMapValF Scale | 180 | scale' = vectorMapValF Scale |
181 | addConstant = vectorMapValF AddConstant | 181 | addConstant = vectorMapValF AddConstant |
182 | add = vectorZipF Add | 182 | add' = vectorZipF Add |
183 | sub = vectorZipF Sub | 183 | sub = vectorZipF Sub |
184 | mul = vectorZipF Mul | 184 | mul = vectorZipF Mul |
185 | equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0 | 185 | equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0 |
@@ -216,7 +216,7 @@ instance Container Vector Double | |||
216 | size' = dim | 216 | size' = dim |
217 | scale' = vectorMapValR Scale | 217 | scale' = vectorMapValR Scale |
218 | addConstant = vectorMapValR AddConstant | 218 | addConstant = vectorMapValR AddConstant |
219 | add = vectorZipR Add | 219 | add' = vectorZipR Add |
220 | sub = vectorZipR Sub | 220 | sub = vectorZipR Sub |
221 | mul = vectorZipR Mul | 221 | mul = vectorZipR Mul |
222 | equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0 | 222 | equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0 |
@@ -253,7 +253,7 @@ instance Container Vector (Complex Double) | |||
253 | size' = dim | 253 | size' = dim |
254 | scale' = vectorMapValC Scale | 254 | scale' = vectorMapValC Scale |
255 | addConstant = vectorMapValC AddConstant | 255 | addConstant = vectorMapValC AddConstant |
256 | add = vectorZipC Add | 256 | add' = vectorZipC Add |
257 | sub = vectorZipC Sub | 257 | sub = vectorZipC Sub |
258 | mul = vectorZipC Mul | 258 | mul = vectorZipC Mul |
259 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 | 259 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 |
@@ -289,7 +289,7 @@ instance Container Vector (Complex Float) | |||
289 | size' = dim | 289 | size' = dim |
290 | scale' = vectorMapValQ Scale | 290 | scale' = vectorMapValQ Scale |
291 | addConstant = vectorMapValQ AddConstant | 291 | addConstant = vectorMapValQ AddConstant |
292 | add = vectorZipQ Add | 292 | add' = vectorZipQ Add |
293 | sub = vectorZipQ Sub | 293 | sub = vectorZipQ Sub |
294 | mul = vectorZipQ Mul | 294 | mul = vectorZipQ Mul |
295 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 | 295 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 |
@@ -327,7 +327,7 @@ instance (Num a, Element a, Container Vector a) => Container Matrix a | |||
327 | size' = size | 327 | size' = size |
328 | scale' x = liftMatrix (scale' x) | 328 | scale' x = liftMatrix (scale' x) |
329 | addConstant x = liftMatrix (addConstant x) | 329 | addConstant x = liftMatrix (addConstant x) |
330 | add = liftMatrix2 add | 330 | add' = liftMatrix2 add' |
331 | sub = liftMatrix2 sub | 331 | sub = liftMatrix2 sub |
332 | mul = liftMatrix2 mul | 332 | mul = liftMatrix2 mul |
333 | equal a b = cols a == cols b && flatten a `equal` flatten b | 333 | equal a b = cols a == cols b && flatten a `equal` flatten b |
@@ -387,9 +387,6 @@ scalar = scalar' | |||
387 | conj :: Container c e => c e -> c e | 387 | conj :: Container c e => c e -> c e |
388 | conj = conj' | 388 | conj = conj' |
389 | 389 | ||
390 | -- | multiplication by scalar | ||
391 | scale :: Container c e => e -> c e -> c e | ||
392 | scale = scale' | ||
393 | 390 | ||
394 | arctan2 :: (Fractional e, Container c e) => c e -> c e -> c e | 391 | arctan2 :: (Fractional e, Container c e) => c e -> c e -> c e |
395 | arctan2 = arctan2' | 392 | arctan2 = arctan2' |
@@ -581,6 +578,10 @@ class ( Container Vector t | |||
581 | , Konst t (Int,Int) Matrix | 578 | , Konst t (Int,Int) Matrix |
582 | , CTrans t | 579 | , CTrans t |
583 | , Product t | 580 | , Product t |
581 | , Additive (Vector t) | ||
582 | , Additive (Matrix t) | ||
583 | , Linear t Vector | ||
584 | , Linear t Matrix | ||
584 | ) => Numeric t | 585 | ) => Numeric t |
585 | 586 | ||
586 | instance Numeric Double | 587 | instance Numeric Double |
@@ -912,11 +913,30 @@ instance (CTrans t, Container Vector t) => Transposable (Matrix t) (Matrix t) | |||
912 | tr = ctrans | 913 | tr = ctrans |
913 | tr' = trans | 914 | tr' = trans |
914 | 915 | ||
915 | class Linear t v | 916 | class Additive c |
916 | where | 917 | where |
917 | scalarL :: t -> v | 918 | add :: c -> c -> c |
918 | addL :: v -> v -> v | 919 | |
919 | scaleL :: t -> v -> v | 920 | class Linear t c |
921 | where | ||
922 | scale :: t -> c t -> c t | ||
923 | |||
924 | |||
925 | instance Container Vector t => Linear t Vector | ||
926 | where | ||
927 | scale = scale' | ||
928 | |||
929 | instance Container Matrix t => Linear t Matrix | ||
930 | where | ||
931 | scale = scale' | ||
932 | |||
933 | instance Container Vector t => Additive (Vector t) | ||
934 | where | ||
935 | add = add' | ||
936 | |||
937 | instance Container Matrix t => Additive (Matrix t) | ||
938 | where | ||
939 | add = add' | ||
920 | 940 | ||
921 | 941 | ||
922 | class Testable t | 942 | class Testable t |
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 4123e6c..36b7855 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -458,12 +458,12 @@ rowOuters a b = a' * b' | |||
458 | -------------------------------------------------------------------------------- | 458 | -------------------------------------------------------------------------------- |
459 | 459 | ||
460 | -- | solution of overconstrained homogeneous linear system | 460 | -- | solution of overconstrained homogeneous linear system |
461 | null1 :: Matrix Double -> Vector Double | 461 | null1 :: Matrix R -> Vector R |
462 | null1 = last . toColumns . snd . rightSV | 462 | null1 = last . toColumns . snd . rightSV |
463 | 463 | ||
464 | -- | solution of overconstrained homogeneous symmetric linear system | 464 | -- | solution of overconstrained homogeneous symmetric linear system |
465 | null1sym :: Matrix Double -> Vector Double | 465 | null1sym :: Her R -> Vector R |
466 | null1sym = last . toColumns . snd . eigSH' | 466 | null1sym = last . toColumns . snd . eigSH |
467 | 467 | ||
468 | -------------------------------------------------------------------------------- | 468 | -------------------------------------------------------------------------------- |
469 | 469 | ||
@@ -712,7 +712,9 @@ luST ok (r,_) x = do | |||
712 | , 0, 0, 0, 0, 1 ] | 712 | , 0, 0, 0, 0, 1 ] |
713 | 713 | ||
714 | -} | 714 | -} |
715 | luPacked' x = mutable (luST (magnit 0)) x | 715 | luPacked' x = LU m p |
716 | where | ||
717 | (m,p) = mutable (luST (magnit 0)) x | ||
716 | 718 | ||
717 | -------------------------------------------------------------------------------- | 719 | -------------------------------------------------------------------------------- |
718 | 720 | ||
@@ -782,7 +784,7 @@ forwSust' lup rhs = foldl' f (rhs?[]) ls | |||
782 | (b - l<>x) | 784 | (b - l<>x) |
783 | 785 | ||
784 | 786 | ||
785 | luSolve'' (lup,p) b = backSust' lup (forwSust' lup pb) | 787 | luSolve'' (LU lup p) b = backSust' lup (forwSust' lup pb) |
786 | where | 788 | where |
787 | pb = b ?? (Pos (fixPerm' p), All) | 789 | pb = b ?? (Pos (fixPerm' p), All) |
788 | 790 | ||
@@ -827,7 +829,7 @@ backSust lup rhs = fst $ mutable f rhs | |||
827 | , 7, 10, 6 ] | 829 | , 7, 10, 6 ] |
828 | 830 | ||
829 | -} | 831 | -} |
830 | luSolve' (lup,p) b = backSust lup (forwSust lup pb) | 832 | luSolve' (LU lup p) b = backSust lup (forwSust lup pb) |
831 | where | 833 | where |
832 | pb = b ?? (Pos (fixPerm' p), All) | 834 | pb = b ?? (Pos (fixPerm' p), All) |
833 | 835 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 9a924e0..7be2600 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs | |||
@@ -53,11 +53,11 @@ module Numeric.LinearAlgebra ( | |||
53 | -- | 53 | -- |
54 | 54 | ||
55 | -- * Products | 55 | -- * Products |
56 | -- ** dot | 56 | -- ** Dot |
57 | dot, (<.>), | 57 | dot, (<.>), |
58 | -- ** matrix-vector | 58 | -- ** Matrix-vector |
59 | (#>), (<#), (!#>), | 59 | (#>), (<#), (!#>), |
60 | -- ** matrix-matrix | 60 | -- ** Matrix-matrix |
61 | (<>), | 61 | (<>), |
62 | -- | The matrix product is also implemented in the "Data.Monoid" instance, where | 62 | -- | The matrix product is also implemented in the "Data.Monoid" instance, where |
63 | -- single-element matrices (created from numeric literals or using 'scalar') | 63 | -- single-element matrices (created from numeric literals or using 'scalar') |
@@ -73,20 +73,25 @@ module Numeric.LinearAlgebra ( | |||
73 | -- 'mconcat' uses 'optimiseMult' to get the optimal association order. | 73 | -- 'mconcat' uses 'optimiseMult' to get the optimal association order. |
74 | 74 | ||
75 | 75 | ||
76 | -- ** other | 76 | -- ** Other |
77 | outer, kronecker, cross, | 77 | outer, kronecker, cross, |
78 | scale, | 78 | scale, add, |
79 | sumElements, prodElements, | 79 | sumElements, prodElements, |
80 | 80 | ||
81 | -- * Linear systems | 81 | -- * Linear systems |
82 | -- ** General | ||
82 | (<\>), | 83 | (<\>), |
83 | linearSolve, | ||
84 | linearSolveLS, | 84 | linearSolveLS, |
85 | linearSolveSVD, | 85 | linearSolveSVD, |
86 | luSolve, | 86 | -- ** Determined |
87 | luSolve', | 87 | linearSolve, |
88 | luSolve, luPacked, | ||
89 | luSolve', luPacked', | ||
90 | -- ** Symmetric indefinite | ||
91 | ldlSolve, ldlPacked, | ||
92 | -- ** Positive definite | ||
88 | cholSolve, | 93 | cholSolve, |
89 | ldlSolve, | 94 | -- ** Sparse |
90 | cgSolve, | 95 | cgSolve, |
91 | cgSolve', | 96 | cgSolve', |
92 | 97 | ||
@@ -113,21 +118,18 @@ module Numeric.LinearAlgebra ( | |||
113 | leftSV, rightSV, | 118 | leftSV, rightSV, |
114 | 119 | ||
115 | -- * Eigendecomposition | 120 | -- * Eigendecomposition |
116 | eig, eigSH, eigSH', | 121 | eig, eigSH, |
117 | eigenvalues, eigenvaluesSH, eigenvaluesSH', | 122 | eigenvalues, eigenvaluesSH, |
118 | geigSH', | 123 | geigSH, |
119 | 124 | ||
120 | -- * QR | 125 | -- * QR |
121 | qr, rq, qrRaw, qrgr, | 126 | qr, rq, qrRaw, qrgr, |
122 | 127 | ||
123 | -- * Cholesky | 128 | -- * Cholesky |
124 | chol, cholSH, mbCholSH, | 129 | chol, mbChol, |
125 | 130 | ||
126 | -- * LU | 131 | -- * LU |
127 | lu, luPacked, luPacked', luFact, | 132 | lu, luFact, |
128 | |||
129 | -- * LDL | ||
130 | ldlPacked, ldlPackedSH, | ||
131 | 133 | ||
132 | -- * Hessenberg | 134 | -- * Hessenberg |
133 | hess, | 135 | hess, |
@@ -150,14 +152,16 @@ module Numeric.LinearAlgebra ( | |||
150 | -- * Misc | 152 | -- * Misc |
151 | meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, magnit, | 153 | meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, magnit, |
152 | haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, | 154 | haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, |
153 | iC, | 155 | iC, sym, xTx, trustSym, her, |
154 | -- * Auxiliary classes | 156 | -- * Auxiliary classes |
155 | Element, Container, Product, Numeric, LSDiv, | 157 | Element, Container, Product, Numeric, LSDiv, Her, |
156 | Complexable, RealElement, | 158 | Complexable, RealElement, |
157 | RealOf, ComplexOf, SingleOf, DoubleOf, | 159 | RealOf, ComplexOf, SingleOf, DoubleOf, |
158 | IndexOf, | 160 | IndexOf, |
159 | Field, | 161 | Field, Linear(), Additive(), |
160 | Transposable, | 162 | Transposable, |
163 | LU(..), | ||
164 | LDL(..), | ||
161 | CGState(..), | 165 | CGState(..), |
162 | Testable(..) | 166 | Testable(..) |
163 | ) where | 167 | ) where |
@@ -169,7 +173,7 @@ import Numeric.Vector() | |||
169 | import Internal.Matrix | 173 | import Internal.Matrix |
170 | import Internal.Container hiding ((<>)) | 174 | import Internal.Container hiding ((<>)) |
171 | import Internal.Numeric hiding (mul) | 175 | import Internal.Numeric hiding (mul) |
172 | import Internal.Algorithms hiding (linearSolve,Normed,orth,luPacked',linearSolve',luSolve') | 176 | import Internal.Algorithms hiding (linearSolve,Normed,orth,luPacked',linearSolve',luSolve',ldlPacked') |
173 | import qualified Internal.Algorithms as A | 177 | import qualified Internal.Algorithms as A |
174 | import Internal.Util | 178 | import Internal.Util |
175 | import Internal.Random | 179 | import Internal.Random |
@@ -246,4 +250,3 @@ nullspace m = nullspaceSVD (Left (1*eps)) m (rightSV m) | |||
246 | -- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'. | 250 | -- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'. |
247 | orth m = orthSVD (Left (1*eps)) m (leftSV m) | 251 | orth m = orthSVD (Left (1*eps)) m (leftSV m) |
248 | 252 | ||
249 | |||
diff --git a/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs b/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs index bac1c0c..5ce529c 100644 --- a/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs +++ b/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs | |||
@@ -13,11 +13,12 @@ compatibility with previous version, to be removed | |||
13 | 13 | ||
14 | module Numeric.LinearAlgebra.HMatrix ( | 14 | module Numeric.LinearAlgebra.HMatrix ( |
15 | module Numeric.LinearAlgebra, | 15 | module Numeric.LinearAlgebra, |
16 | (¦),(——),ℝ,ℂ,(<·>),app,mul | 16 | (¦),(——),ℝ,ℂ,(<·>),app,mul, cholSH, mbCholSH, eigSH', eigenvaluesSH', geigSH' |
17 | ) where | 17 | ) where |
18 | 18 | ||
19 | import Numeric.LinearAlgebra | 19 | import Numeric.LinearAlgebra |
20 | import Internal.Util | 20 | import Internal.Util |
21 | import Internal.Algorithms(cholSH, mbCholSH, eigSH', eigenvaluesSH', geigSH') | ||
21 | 22 | ||
22 | infixr 8 <·> | 23 | infixr 8 <·> |
23 | (<·>) :: Numeric t => Vector t -> Vector t -> t | 24 | (<·>) :: Numeric t => Vector t -> Vector t -> t |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index 0dab0e6..ded69fa 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -63,9 +63,9 @@ import GHC.TypeLits | |||
63 | import Numeric.LinearAlgebra hiding ( | 63 | import Numeric.LinearAlgebra hiding ( |
64 | (<>),(#>),(<.>),Konst(..),diag, disp,(===),(|||), | 64 | (<>),(#>),(<.>),Konst(..),diag, disp,(===),(|||), |
65 | row,col,vector,matrix,linspace,toRows,toColumns, | 65 | row,col,vector,matrix,linspace,toRows,toColumns, |
66 | (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH', | 66 | (<\>),fromList,takeDiag,svd,eig,eigSH, |
67 | eigenvalues,eigenvaluesSH,eigenvaluesSH',build, | 67 | eigenvalues,eigenvaluesSH,build, |
68 | qr,size,dot,chol,range,R,C) | 68 | qr,size,dot,chol,range,R,C,Her,her,sym) |
69 | import qualified Numeric.LinearAlgebra as LA | 69 | import qualified Numeric.LinearAlgebra as LA |
70 | import Data.Proxy(Proxy) | 70 | import Data.Proxy(Proxy) |
71 | import Internal.Static | 71 | import Internal.Static |
@@ -292,10 +292,10 @@ her m = Her $ (m + LA.tr m)/2 | |||
292 | 292 | ||
293 | instance KnownNat n => Eigen (Sym n) (R n) (L n n) | 293 | instance KnownNat n => Eigen (Sym n) (R n) (L n n) |
294 | where | 294 | where |
295 | eigenvalues (Sym (extract -> m)) = mkR . LA.eigenvaluesSH' $ m | 295 | eigenvalues (Sym (extract -> m)) = mkR . LA.eigenvaluesSH . LA.trustSym $ m |
296 | eigensystem (Sym (extract -> m)) = (mkR l, mkL v) | 296 | eigensystem (Sym (extract -> m)) = (mkR l, mkL v) |
297 | where | 297 | where |
298 | (l,v) = LA.eigSH' m | 298 | (l,v) = LA.eigSH . LA.trustSym $ m |
299 | 299 | ||
300 | instance KnownNat n => Eigen (Sq n) (C n) (M n n) | 300 | instance KnownNat n => Eigen (Sq n) (C n) (M n n) |
301 | where | 301 | where |
@@ -305,7 +305,7 @@ instance KnownNat n => Eigen (Sq n) (C n) (M n n) | |||
305 | (l,v) = LA.eig m | 305 | (l,v) = LA.eig m |
306 | 306 | ||
307 | chol :: KnownNat n => Sym n -> Sq n | 307 | chol :: KnownNat n => Sym n -> Sq n |
308 | chol (extract . unSym -> m) = mkL $ LA.cholSH m | 308 | chol (extract . unSym -> m) = mkL $ LA.chol $ LA.trustSym m |
309 | 309 | ||
310 | -------------------------------------------------------------------------------- | 310 | -------------------------------------------------------------------------------- |
311 | 311 | ||
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs index 2ff1580..30480d7 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs | |||
@@ -127,8 +127,8 @@ expmTest2 = expm nd2 :~15~: (2><2) | |||
127 | mbCholTest = utest "mbCholTest" (ok1 && ok2) where | 127 | mbCholTest = utest "mbCholTest" (ok1 && ok2) where |
128 | m1 = (2><2) [2,5,5,8 :: Double] | 128 | m1 = (2><2) [2,5,5,8 :: Double] |
129 | m2 = (2><2) [3,5,5,9 :: Complex Double] | 129 | m2 = (2><2) [3,5,5,9 :: Complex Double] |
130 | ok1 = mbCholSH m1 == Nothing | 130 | ok1 = mbChol (trustSym m1) == Nothing |
131 | ok2 = mbCholSH m2 == Just (chol m2) | 131 | ok2 = mbChol (trustSym m2) == Just (chol $ trustSym m2) |
132 | 132 | ||
133 | --------------------------------------------------------------------- | 133 | --------------------------------------------------------------------- |
134 | 134 | ||
@@ -403,8 +403,8 @@ indexProp g f x = a1 == g a2 && a2 == a3 && b1 == g b2 && b2 == b3 | |||
403 | -------------------------------------------------------------------------------- | 403 | -------------------------------------------------------------------------------- |
404 | 404 | ||
405 | sliceTest = utest "slice test" $ and | 405 | sliceTest = utest "slice test" $ and |
406 | [ testSlice chol (gen 5 :: Matrix R) | 406 | [ testSlice (chol . trustSym) (gen 5 :: Matrix R) |
407 | , testSlice chol (gen 5 :: Matrix C) | 407 | , testSlice (chol . trustSym) (gen 5 :: Matrix C) |
408 | , testSlice qr (rec :: Matrix R) | 408 | , testSlice qr (rec :: Matrix R) |
409 | , testSlice qr (rec :: Matrix C) | 409 | , testSlice qr (rec :: Matrix C) |
410 | , testSlice hess (agen 5 :: Matrix R) | 410 | , testSlice hess (agen 5 :: Matrix R) |
@@ -420,12 +420,12 @@ sliceTest = utest "slice test" $ and | |||
420 | 420 | ||
421 | , testSlice eig (agen 5 :: Matrix R) | 421 | , testSlice eig (agen 5 :: Matrix R) |
422 | , testSlice eig (agen 5 :: Matrix C) | 422 | , testSlice eig (agen 5 :: Matrix C) |
423 | , testSlice eigSH (gen 5 :: Matrix R) | 423 | , testSlice (eigSH . trustSym) (gen 5 :: Matrix R) |
424 | , testSlice eigSH (gen 5 :: Matrix C) | 424 | , testSlice (eigSH . trustSym) (gen 5 :: Matrix C) |
425 | , testSlice eigenvalues (agen 5 :: Matrix R) | 425 | , testSlice eigenvalues (agen 5 :: Matrix R) |
426 | , testSlice eigenvalues (agen 5 :: Matrix C) | 426 | , testSlice eigenvalues (agen 5 :: Matrix C) |
427 | , testSlice eigenvaluesSH (gen 5 :: Matrix R) | 427 | , testSlice (eigenvaluesSH . trustSym) (gen 5 :: Matrix R) |
428 | , testSlice eigenvaluesSH (gen 5 :: Matrix C) | 428 | , testSlice (eigenvaluesSH . trustSym) (gen 5 :: Matrix C) |
429 | 429 | ||
430 | , testSlice svd (rec :: Matrix R) | 430 | , testSlice svd (rec :: Matrix R) |
431 | , testSlice thinSVD (rec :: Matrix R) | 431 | , testSlice thinSVD (rec :: Matrix R) |
@@ -489,10 +489,10 @@ sliceTest = utest "slice test" $ and | |||
489 | , testSlice ((<>) (ogen 5:: Matrix (Z ./. 7))) (gen 5) | 489 | , testSlice ((<>) (ogen 5:: Matrix (Z ./. 7))) (gen 5) |
490 | , testSlice (flip (<>) (gen 5:: Matrix (Z ./. 7))) (ogen 5) | 490 | , testSlice (flip (<>) (gen 5:: Matrix (Z ./. 7))) (ogen 5) |
491 | 491 | ||
492 | , testSlice (flip cholSolve (agen 5:: Matrix R)) (chol $ gen 5) | 492 | , testSlice (flip cholSolve (agen 5:: Matrix R)) (chol $ trustSym $ gen 5) |
493 | , testSlice (flip cholSolve (agen 5:: Matrix C)) (chol $ gen 5) | 493 | , testSlice (flip cholSolve (agen 5:: Matrix C)) (chol $ trustSym $ gen 5) |
494 | , testSlice (cholSolve (chol $ gen 5:: Matrix R)) (agen 5) | 494 | , testSlice (cholSolve (chol $ trustSym $ gen 5:: Matrix R)) (agen 5) |
495 | , testSlice (cholSolve (chol $ gen 5:: Matrix C)) (agen 5) | 495 | , testSlice (cholSolve (chol $ trustSym $ gen 5:: Matrix C)) (agen 5) |
496 | 496 | ||
497 | , ok_qrgr (rec :: Matrix R) | 497 | , ok_qrgr (rec :: Matrix R) |
498 | , ok_qrgr (rec :: Matrix C) | 498 | , ok_qrgr (rec :: Matrix C) |
@@ -515,8 +515,8 @@ sliceTest = utest "slice test" $ and | |||
515 | 515 | ||
516 | test_lus m = testSlice f lup | 516 | test_lus m = testSlice f lup |
517 | where | 517 | where |
518 | f x = luSolve (x,p) m | 518 | f x = luSolve (LU x p) m |
519 | (lup,p) = luPacked m | 519 | (LU lup p) = luPacked m |
520 | 520 | ||
521 | gen :: Numeric t => Int -> Matrix t | 521 | gen :: Numeric t => Int -> Matrix t |
522 | gen n = diagRect 1 (konst 5 n) n n | 522 | gen n = diagRect 1 (konst 5 n) n n |
@@ -588,11 +588,11 @@ runTests n = do | |||
588 | test (linearSolveProp (luSolve.luPacked) . rSqWC) | 588 | test (linearSolveProp (luSolve.luPacked) . rSqWC) |
589 | test (linearSolveProp (luSolve.luPacked) . cSqWC) | 589 | test (linearSolveProp (luSolve.luPacked) . cSqWC) |
590 | putStrLn "------ ldlSolve" | 590 | putStrLn "------ ldlSolve" |
591 | test (linearSolveProp (ldlSolve.ldlPacked) . rSymWC) | 591 | test (linearSolvePropH (ldlSolve.ldlPacked) . rSymWC) |
592 | test (linearSolveProp (ldlSolve.ldlPacked) . cSymWC) | 592 | test (linearSolvePropH (ldlSolve.ldlPacked) . cSymWC) |
593 | putStrLn "------ cholSolve" | 593 | putStrLn "------ cholSolve" |
594 | test (linearSolveProp (cholSolve.chol) . rPosDef) | 594 | test (linearSolveProp (cholSolve.chol.trustSym) . rPosDef) |
595 | test (linearSolveProp (cholSolve.chol) . cPosDef) | 595 | test (linearSolveProp (cholSolve.chol.trustSym) . cPosDef) |
596 | putStrLn "------ luSolveLS" | 596 | putStrLn "------ luSolveLS" |
597 | test (linearSolveProp linearSolveLS . rSqWC) | 597 | test (linearSolveProp linearSolveLS . rSqWC) |
598 | test (linearSolveProp linearSolveLS . cSqWC) | 598 | test (linearSolveProp linearSolveLS . cSqWC) |
@@ -865,8 +865,8 @@ eigBench = do | |||
865 | let m = reshape 1000 (randomVector 777 Uniform (1000*1000)) | 865 | let m = reshape 1000 (randomVector 777 Uniform (1000*1000)) |
866 | s = m + tr m | 866 | s = m + tr m |
867 | m `seq` s `seq` putStrLn "" | 867 | m `seq` s `seq` putStrLn "" |
868 | time "eigenvalues symmetric 1000x1000" (eigenvaluesSH' m) | 868 | time "eigenvalues symmetric 1000x1000" (eigenvaluesSH (trustSym m)) |
869 | time "eigenvectors symmetric 1000x1000" (snd $ eigSH' m) | 869 | time "eigenvectors symmetric 1000x1000" (snd $ eigSH (trustSym m)) |
870 | time "eigenvalues general 1000x1000" (eigenvalues m) | 870 | time "eigenvalues general 1000x1000" (eigenvalues m) |
871 | time "eigenvectors general 1000x1000" (snd $ eig m) | 871 | time "eigenvectors general 1000x1000" (snd $ eig m) |
872 | 872 | ||
@@ -893,12 +893,14 @@ solveBenchN n = do | |||
893 | time ("svd solve " ++ show n) (linearSolveSVD a b) | 893 | time ("svd solve " ++ show n) (linearSolveSVD a b) |
894 | time (" ls solve " ++ show n) (linearSolveLS a b) | 894 | time (" ls solve " ++ show n) (linearSolveLS a b) |
895 | time (" solve " ++ show n) (linearSolve a b) | 895 | time (" solve " ++ show n) (linearSolve a b) |
896 | time ("cholSolve " ++ show n) (cholSolve (chol a) b) | 896 | -- time (" LU solve " ++ show n) (luSolve (luPacked a) b) |
897 | time ("LDL solve " ++ show n) (ldlSolve (ldlPacked (trustSym a)) b) | ||
898 | time ("cholSolve " ++ show n) (cholSolve (chol $ trustSym a) b) | ||
897 | 899 | ||
898 | solveBench = do | 900 | solveBench = do |
899 | solveBenchN 500 | 901 | solveBenchN 500 |
900 | solveBenchN 1000 | 902 | solveBenchN 1000 |
901 | -- solveBenchN 1500 | 903 | solveBenchN 1500 |
902 | 904 | ||
903 | -------------------------------- | 905 | -------------------------------- |
904 | 906 | ||
@@ -906,7 +908,7 @@ cholBenchN n = do | |||
906 | let x = uniformSample 777 (2*n) (replicate n (-1,1)) | 908 | let x = uniformSample 777 (2*n) (replicate n (-1,1)) |
907 | a = tr x <> x | 909 | a = tr x <> x |
908 | a `seq` putStr "" | 910 | a `seq` putStr "" |
909 | time ("chol " ++ show n) (chol a) | 911 | time ("chol " ++ show n) (chol $ trustSym a) |
910 | 912 | ||
911 | cholBench = do | 913 | cholBench = do |
912 | putStrLn "" | 914 | putStrLn "" |
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs index 7c54535..4704989 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs | |||
@@ -14,7 +14,7 @@ Arbitrary instances for vectors, matrices. | |||
14 | module Numeric.LinearAlgebra.Tests.Instances( | 14 | module Numeric.LinearAlgebra.Tests.Instances( |
15 | Sq(..), rSq,cSq, | 15 | Sq(..), rSq,cSq, |
16 | Rot(..), rRot,cRot, | 16 | Rot(..), rRot,cRot, |
17 | Her(..), rHer,cHer, | 17 | rHer,cHer, |
18 | WC(..), rWC,cWC, | 18 | WC(..), rWC,cWC, |
19 | SqWC(..), rSqWC, cSqWC, rSymWC, cSymWC, | 19 | SqWC(..), rSqWC, cSqWC, rSymWC, cSymWC, |
20 | PosDef(..), rPosDef, cPosDef, | 20 | PosDef(..), rPosDef, cPosDef, |
@@ -81,12 +81,12 @@ instance (Field a, Arbitrary a) => Arbitrary (Rot a) where | |||
81 | 81 | ||
82 | 82 | ||
83 | -- a complex hermitian or real symmetric matrix | 83 | -- a complex hermitian or real symmetric matrix |
84 | newtype (Her a) = Her (Matrix a) deriving Show | 84 | --newtype (Her a) = Her (Matrix a) deriving Show |
85 | instance (Field a, Arbitrary a, Num (Vector a)) => Arbitrary (Her a) where | 85 | instance (Field a, Arbitrary a, Num (Vector a)) => Arbitrary (Her a) where |
86 | arbitrary = do | 86 | arbitrary = do |
87 | Sq m <- arbitrary | 87 | Sq m <- arbitrary |
88 | let m' = m/2 | 88 | let m' = m/2 |
89 | return $ Her (m' + tr m') | 89 | return $ sym m' |
90 | 90 | ||
91 | 91 | ||
92 | class (Field a, Arbitrary a, Element (RealOf a), Random (RealOf a)) => ArbitraryField a | 92 | class (Field a, Arbitrary a, Element (RealOf a), Random (RealOf a)) => ArbitraryField a |
@@ -125,9 +125,9 @@ newtype (PosDef a) = PosDef (Matrix a) deriving Show | |||
125 | instance (Numeric a, ArbitraryField a, Num (Vector a)) | 125 | instance (Numeric a, ArbitraryField a, Num (Vector a)) |
126 | => Arbitrary (PosDef a) where | 126 | => Arbitrary (PosDef a) where |
127 | arbitrary = do | 127 | arbitrary = do |
128 | Her m <- arbitrary | 128 | m <- arbitrary |
129 | let (_,v) = eigSH m | 129 | let (_,v) = eigSH m |
130 | n = rows m | 130 | n = rows (her m) |
131 | l <- replicateM n (choose (0,100)) | 131 | l <- replicateM n (choose (0,100)) |
132 | let s = diag (fromList l) | 132 | let s = diag (fromList l) |
133 | p = v <> real s <> tr v | 133 | p = v <> real s <> tr v |
@@ -161,8 +161,8 @@ fM m = m :: FM | |||
161 | zM m = m :: ZM | 161 | zM m = m :: ZM |
162 | 162 | ||
163 | 163 | ||
164 | rHer (Her m) = m :: RM | 164 | rHer m = her m :: RM |
165 | cHer (Her m) = m :: CM | 165 | cHer m = her m :: CM |
166 | 166 | ||
167 | rRot (Rot m) = m :: RM | 167 | rRot (Rot m) = m :: RM |
168 | cRot (Rot m) = m :: CM | 168 | cRot (Rot m) = m :: CM |
@@ -176,8 +176,8 @@ cWC (WC m) = m :: CM | |||
176 | rSqWC (SqWC m) = m :: RM | 176 | rSqWC (SqWC m) = m :: RM |
177 | cSqWC (SqWC m) = m :: CM | 177 | cSqWC (SqWC m) = m :: CM |
178 | 178 | ||
179 | rSymWC (SqWC m) = m + tr m :: RM | 179 | rSymWC (SqWC m) = sym m :: Her R |
180 | cSymWC (SqWC m) = m + tr m :: CM | 180 | cSymWC (SqWC m) = sym m :: Her C |
181 | 181 | ||
182 | rPosDef (PosDef m) = m :: RM | 182 | rPosDef (PosDef m) = m :: RM |
183 | cPosDef (PosDef m) = m :: CM | 183 | cPosDef (PosDef m) = m :: CM |
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs index 207a303..2ac3588 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs | |||
@@ -39,7 +39,7 @@ module Numeric.LinearAlgebra.Tests.Properties ( | |||
39 | expmDiagProp, | 39 | expmDiagProp, |
40 | multProp1, multProp2, | 40 | multProp1, multProp2, |
41 | subProp, | 41 | subProp, |
42 | linearSolveProp, linearSolveProp2 | 42 | linearSolveProp, linearSolvePropH, linearSolveProp2 |
43 | ) where | 43 | ) where |
44 | 44 | ||
45 | import Numeric.LinearAlgebra.HMatrix hiding (Testable,unitary) | 45 | import Numeric.LinearAlgebra.HMatrix hiding (Testable,unitary) |
@@ -209,11 +209,11 @@ eigProp m = complex m <> v |~| v <> diag s | |||
209 | eigSHProp m = m <> v |~| v <> real (diag s) | 209 | eigSHProp m = m <> v |~| v <> real (diag s) |
210 | && unitary v | 210 | && unitary v |
211 | && m |~| v <> real (diag s) <> tr v | 211 | && m |~| v <> real (diag s) <> tr v |
212 | where (s, v) = eigSH m | 212 | where (s, v) = eigSH' m |
213 | 213 | ||
214 | eigProp2 m = fst (eig m) |~| eigenvalues m | 214 | eigProp2 m = fst (eig m) |~| eigenvalues m |
215 | 215 | ||
216 | eigSHProp2 m = fst (eigSH m) |~| eigenvaluesSH m | 216 | eigSHProp2 m = fst (eigSH' m) |~| eigenvaluesSH' m |
217 | 217 | ||
218 | ------------------------------------------------------------------ | 218 | ------------------------------------------------------------------ |
219 | 219 | ||
@@ -246,9 +246,9 @@ schurProp2 m = m |~| u <> s <> tr u && unitary u && upperHessenberg s -- fixme | |||
246 | where (u,s) = schur m | 246 | where (u,s) = schur m |
247 | 247 | ||
248 | cholProp m = m |~| tr c <> c && upperTriang c | 248 | cholProp m = m |~| tr c <> c && upperTriang c |
249 | where c = chol m | 249 | where c = chol (trustSym m) |
250 | 250 | ||
251 | exactProp m = chol m == chol (m+0) | 251 | exactProp m = chol (trustSym m) == chol (trustSym (m+0)) |
252 | 252 | ||
253 | expmDiagProp m = expm (logm m) :~ 7 ~: complex m | 253 | expmDiagProp m = expm (logm m) :~ 7 ~: complex m |
254 | where logm = matFunc log | 254 | where logm = matFunc log |
@@ -263,6 +263,8 @@ multProp2 p (a,b) = (tr (a <> b)) :~p~: (tr b <> tr a) | |||
263 | 263 | ||
264 | linearSolveProp f m = f m m |~| ident (rows m) | 264 | linearSolveProp f m = f m m |~| ident (rows m) |
265 | 265 | ||
266 | linearSolvePropH f m = f m (her m) |~| ident (rows (her m)) | ||
267 | |||
266 | linearSolveProp2 f (a,x) = not wc `trivial` (not wc || a <> f a b |~| b) | 268 | linearSolveProp2 f (a,x) = not wc `trivial` (not wc || a <> f a b |~| b) |
267 | where q = min (rows a) (cols a) | 269 | where q = min (rows a) (cols a) |
268 | b = a <> x | 270 | b = a <> x |