diff options
Diffstat (limited to 'packages/base/src/Internal/Algorithms.hs')
-rw-r--r-- | packages/base/src/Internal/Algorithms.hs | 138 |
1 files changed, 105 insertions, 33 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs index 3d25491..d2f17f4 100644 --- a/packages/base/src/Internal/Algorithms.hs +++ b/packages/base/src/Internal/Algorithms.hs | |||
@@ -4,6 +4,12 @@ | |||
4 | {-# LANGUAGE UndecidableInstances #-} | 4 | {-# LANGUAGE UndecidableInstances #-} |
5 | {-# LANGUAGE TypeFamilies #-} | 5 | {-# LANGUAGE TypeFamilies #-} |
6 | 6 | ||
7 | {-# LANGUAGE FlexibleContexts, FlexibleInstances #-} | ||
8 | {-# LANGUAGE CPP #-} | ||
9 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
10 | {-# LANGUAGE UndecidableInstances #-} | ||
11 | {-# LANGUAGE TypeFamilies #-} | ||
12 | |||
7 | ----------------------------------------------------------------------------- | 13 | ----------------------------------------------------------------------------- |
8 | {- | | 14 | {- | |
9 | Module : Internal.Algorithms | 15 | Module : Internal.Algorithms |
@@ -32,6 +38,7 @@ import Data.List(foldl1') | |||
32 | import qualified Data.Array as A | 38 | import qualified Data.Array as A |
33 | import Internal.ST | 39 | import Internal.ST |
34 | import Internal.Vectorized(range) | 40 | import Internal.Vectorized(range) |
41 | import Control.DeepSeq | ||
35 | 42 | ||
36 | {- | Generic linear algebra functions for double precision real and complex matrices. | 43 | {- | Generic linear algebra functions for double precision real and complex matrices. |
37 | 44 | ||
@@ -43,6 +50,10 @@ class (Numeric t, | |||
43 | Normed Matrix t, | 50 | Normed Matrix t, |
44 | Normed Vector t, | 51 | Normed Vector t, |
45 | Floating t, | 52 | Floating t, |
53 | Linear t Vector, | ||
54 | Linear t Matrix, | ||
55 | Additive (Vector t), | ||
56 | Additive (Matrix t), | ||
46 | RealOf t ~ Double) => Field t where | 57 | RealOf t ~ Double) => Field t where |
47 | svd' :: Matrix t -> (Matrix t, Vector Double, Matrix t) | 58 | svd' :: Matrix t -> (Matrix t, Vector Double, Matrix t) |
48 | thinSVD' :: Matrix t -> (Matrix t, Vector Double, Matrix t) | 59 | thinSVD' :: Matrix t -> (Matrix t, Vector Double, Matrix t) |
@@ -306,25 +317,38 @@ leftSV m | vertical m = let (u,s,_) = svd m in (u,s) | |||
306 | 317 | ||
307 | -------------------------------------------------------------- | 318 | -------------------------------------------------------------- |
308 | 319 | ||
320 | -- | LU decomposition of a matrix in a compact format. | ||
321 | data LU t = LU (Matrix t) [Int] deriving Show | ||
322 | |||
323 | instance (NFData t, Numeric t) => NFData (LU t) | ||
324 | where | ||
325 | rnf (LU m _) = rnf m | ||
326 | |||
309 | -- | Obtains the LU decomposition of a matrix in a compact data structure suitable for 'luSolve'. | 327 | -- | Obtains the LU decomposition of a matrix in a compact data structure suitable for 'luSolve'. |
310 | luPacked :: Field t => Matrix t -> (Matrix t, [Int]) | 328 | luPacked :: Field t => Matrix t -> LU t |
311 | luPacked = {-# SCC "luPacked" #-} luPacked' | 329 | luPacked x = {-# SCC "luPacked" #-} LU m p |
330 | where | ||
331 | (m,p) = luPacked' x | ||
312 | 332 | ||
313 | -- | Solution of a linear system (for several right hand sides) from the precomputed LU factorization obtained by 'luPacked'. | 333 | -- | Solution of a linear system (for several right hand sides) from the precomputed LU factorization obtained by 'luPacked'. |
314 | luSolve :: Field t => (Matrix t, [Int]) -> Matrix t -> Matrix t | 334 | luSolve :: Field t => LU t -> Matrix t -> Matrix t |
315 | luSolve = {-# SCC "luSolve" #-} luSolve' | 335 | luSolve (LU m p) = {-# SCC "luSolve" #-} luSolve' (m,p) |
316 | 336 | ||
317 | -- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'. | 337 | -- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'. |
318 | -- It is similar to 'luSolve' . 'luPacked', but @linearSolve@ raises an error if called on a singular system. | 338 | -- It is similar to 'luSolve' . 'luPacked', but @linearSolve@ raises an error if called on a singular system. |
319 | linearSolve :: Field t => Matrix t -> Matrix t -> Matrix t | 339 | linearSolve :: Field t => Matrix t -> Matrix t -> Matrix t |
320 | linearSolve = {-# SCC "linearSolve" #-} linearSolve' | 340 | linearSolve = {-# SCC "linearSolve" #-} linearSolve' |
321 | 341 | ||
322 | -- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, returning Nothing for a singular system. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'. | 342 | -- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, returning Nothing for a singular system. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'. |
323 | mbLinearSolve :: Field t => Matrix t -> Matrix t -> Maybe (Matrix t) | 343 | mbLinearSolve :: Field t => Matrix t -> Matrix t -> Maybe (Matrix t) |
324 | mbLinearSolve = {-# SCC "linearSolve" #-} mbLinearSolve' | 344 | mbLinearSolve = {-# SCC "linearSolve" #-} mbLinearSolve' |
325 | 345 | ||
326 | -- | Solve a symmetric or Hermitian positive definite linear system using a precomputed Cholesky decomposition obtained by 'chol'. | 346 | -- | Solve a symmetric or Hermitian positive definite linear system using a precomputed Cholesky decomposition obtained by 'chol'. |
327 | cholSolve :: Field t => Matrix t -> Matrix t -> Matrix t | 347 | cholSolve |
348 | :: Field t | ||
349 | => Matrix t -- ^ Cholesky decomposition of the coefficient matrix | ||
350 | -> Matrix t -- ^ right hand sides | ||
351 | -> Matrix t -- ^ solution | ||
328 | cholSolve = {-# SCC "cholSolve" #-} cholSolve' | 352 | cholSolve = {-# SCC "cholSolve" #-} cholSolve' |
329 | 353 | ||
330 | -- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. | 354 | -- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. |
@@ -338,20 +362,28 @@ linearSolveLS = {-# SCC "linearSolveLS" #-} linearSolveLS' | |||
338 | 362 | ||
339 | -------------------------------------------------------------------------------- | 363 | -------------------------------------------------------------------------------- |
340 | 364 | ||
365 | -- | LDL decomposition of a complex Hermitian or real symmetric matrix in a compact format. | ||
366 | data LDL t = LDL (Matrix t) [Int] deriving Show | ||
367 | |||
368 | instance (NFData t, Numeric t) => NFData (LDL t) | ||
369 | where | ||
370 | rnf (LDL m _) = rnf m | ||
371 | |||
341 | -- | Similar to 'ldlPacked', without checking that the input matrix is hermitian or symmetric. It works with the lower triangular part. | 372 | -- | Similar to 'ldlPacked', without checking that the input matrix is hermitian or symmetric. It works with the lower triangular part. |
342 | ldlPackedSH :: Field t => Matrix t -> (Matrix t, [Int]) | 373 | ldlPackedSH :: Field t => Matrix t -> LDL t |
343 | ldlPackedSH = {-# SCC "ldlPacked" #-} ldlPacked' | 374 | ldlPackedSH x = {-# SCC "ldlPacked" #-} LDL m p |
375 | where | ||
376 | (m,p) = ldlPacked' x | ||
344 | 377 | ||
345 | -- | Obtains the LDL decomposition of a matrix in a compact data structure suitable for 'ldlSolve'. | 378 | -- | Obtains the LDL decomposition of a matrix in a compact data structure suitable for 'ldlSolve'. |
346 | ldlPacked :: Field t => Matrix t -> (Matrix t, [Int]) | 379 | ldlPacked :: Field t => Her t -> LDL t |
347 | ldlPacked m | 380 | ldlPacked (Her m) = ldlPackedSH m |
348 | | exactHermitian m = {-# SCC "ldlPacked" #-} ldlPackedSH m | ||
349 | | otherwise = error "ldlPacked requires complex Hermitian or real symmetrix matrix" | ||
350 | |||
351 | 381 | ||
352 | -- | Solution of a linear system (for several right hand sides) from the precomputed LDL factorization obtained by 'ldlPacked'. | 382 | -- | Solution of a linear system (for several right hand sides) from a precomputed LDL factorization obtained by 'ldlPacked'. |
353 | ldlSolve :: Field t => (Matrix t, [Int]) -> Matrix t -> Matrix t | 383 | -- |
354 | ldlSolve = {-# SCC "ldlSolve" #-} ldlSolve' | 384 | -- Note: this can be slower than the general solver based on the LU decomposition. |
385 | ldlSolve :: Field t => LDL t -> Matrix t -> Matrix t | ||
386 | ldlSolve (LDL m p) = {-# SCC "ldlSolve" #-} ldlSolve' (m,p) | ||
355 | 387 | ||
356 | -------------------------------------------------------------- | 388 | -------------------------------------------------------------- |
357 | 389 | ||
@@ -429,14 +461,12 @@ fromList [11.344814282762075,0.17091518882717918,-0.5157294715892575] | |||
429 | 3.000 5.000 6.000 | 461 | 3.000 5.000 6.000 |
430 | 462 | ||
431 | -} | 463 | -} |
432 | eigSH :: Field t => Matrix t -> (Vector Double, Matrix t) | 464 | eigSH :: Field t => Her t -> (Vector Double, Matrix t) |
433 | eigSH m | exactHermitian m = eigSH' m | 465 | eigSH (Her m) = eigSH' m |
434 | | otherwise = error "eigSH requires complex hermitian or real symmetric matrix" | ||
435 | 466 | ||
436 | -- | Eigenvalues (in descending order) of a complex hermitian or real symmetric matrix. | 467 | -- | Eigenvalues (in descending order) of a complex hermitian or real symmetric matrix. |
437 | eigenvaluesSH :: Field t => Matrix t -> Vector Double | 468 | eigenvaluesSH :: Field t => Her t -> Vector Double |
438 | eigenvaluesSH m | exactHermitian m = eigenvaluesSH' m | 469 | eigenvaluesSH (Her m) = eigenvaluesSH' m |
439 | | otherwise = error "eigenvaluesSH requires complex hermitian or real symmetric matrix" | ||
440 | 470 | ||
441 | -------------------------------------------------------------- | 471 | -------------------------------------------------------------- |
442 | 472 | ||
@@ -490,14 +520,18 @@ mbCholSH = {-# SCC "mbCholSH" #-} mbCholSH' | |||
490 | 520 | ||
491 | -- | Similar to 'chol', without checking that the input matrix is hermitian or symmetric. It works with the upper triangular part. | 521 | -- | Similar to 'chol', without checking that the input matrix is hermitian or symmetric. It works with the upper triangular part. |
492 | cholSH :: Field t => Matrix t -> Matrix t | 522 | cholSH :: Field t => Matrix t -> Matrix t |
493 | cholSH = {-# SCC "cholSH" #-} cholSH' | 523 | cholSH = cholSH' |
494 | 524 | ||
495 | -- | Cholesky factorization of a positive definite hermitian or symmetric matrix. | 525 | -- | Cholesky factorization of a positive definite hermitian or symmetric matrix. |
496 | -- | 526 | -- |
497 | -- If @c = chol m@ then @c@ is upper triangular and @m == tr c \<> c@. | 527 | -- If @c = chol m@ then @c@ is upper triangular and @m == tr c \<> c@. |
498 | chol :: Field t => Matrix t -> Matrix t | 528 | chol :: Field t => Her t -> Matrix t |
499 | chol m | exactHermitian m = cholSH m | 529 | chol (Her m) = {-# SCC "chol" #-} cholSH' m |
500 | | otherwise = error "chol requires positive definite complex hermitian or real symmetric matrix" | 530 | |
531 | -- | Similar to 'chol', but instead of an error (e.g., caused by a matrix not positive definite) it returns 'Nothing'. | ||
532 | mbChol :: Field t => Her t -> Maybe (Matrix t) | ||
533 | mbChol (Her m) = {-# SCC "mbChol" #-} mbCholSH' m | ||
534 | |||
501 | 535 | ||
502 | 536 | ||
503 | -- | Joint computation of inverse and logarithm of determinant of a square matrix. | 537 | -- | Joint computation of inverse and logarithm of determinant of a square matrix. |
@@ -507,7 +541,7 @@ invlndet :: Field t | |||
507 | invlndet m | square m = (im,(ladm,sdm)) | 541 | invlndet m | square m = (im,(ladm,sdm)) |
508 | | otherwise = error $ "invlndet of nonsquare "++ shSize m ++ " matrix" | 542 | | otherwise = error $ "invlndet of nonsquare "++ shSize m ++ " matrix" |
509 | where | 543 | where |
510 | lp@(lup,perm) = luPacked m | 544 | lp@(LU lup perm) = luPacked m |
511 | s = signlp (rows m) perm | 545 | s = signlp (rows m) perm |
512 | dg = toList $ takeDiag $ lup | 546 | dg = toList $ takeDiag $ lup |
513 | ladm = sum $ map (log.abs) dg | 547 | ladm = sum $ map (log.abs) dg |
@@ -519,8 +553,9 @@ invlndet m | square m = (im,(ladm,sdm)) | |||
519 | det :: Field t => Matrix t -> t | 553 | det :: Field t => Matrix t -> t |
520 | det m | square m = {-# SCC "det" #-} s * (product $ toList $ takeDiag $ lup) | 554 | det m | square m = {-# SCC "det" #-} s * (product $ toList $ takeDiag $ lup) |
521 | | otherwise = error $ "det of nonsquare "++ shSize m ++ " matrix" | 555 | | otherwise = error $ "det of nonsquare "++ shSize m ++ " matrix" |
522 | where (lup,perm) = luPacked m | 556 | where |
523 | s = signlp (rows m) perm | 557 | LU lup perm = luPacked m |
558 | s = signlp (rows m) perm | ||
524 | 559 | ||
525 | -- | Explicit LU factorization of a general matrix. | 560 | -- | Explicit LU factorization of a general matrix. |
526 | -- | 561 | -- |
@@ -720,7 +755,7 @@ diagonalize m = if rank v == n | |||
720 | else Nothing | 755 | else Nothing |
721 | where n = rows m | 756 | where n = rows m |
722 | (l,v) = if exactHermitian m | 757 | (l,v) = if exactHermitian m |
723 | then let (l',v') = eigSH m in (real l', v') | 758 | then let (l',v') = eigSH (trustSym m) in (real l', v') |
724 | else eig m | 759 | else eig m |
725 | 760 | ||
726 | -- | Generic matrix functions for diagonalizable matrices. For instance: | 761 | -- | Generic matrix functions for diagonalizable matrices. For instance: |
@@ -835,8 +870,9 @@ fixPerm' s = res $ mutable f s0 | |||
835 | triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]] | 870 | triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]] |
836 | where el p q = if q-p>=h then v else 1 - v | 871 | where el p q = if q-p>=h then v else 1 - v |
837 | 872 | ||
838 | luFact (l_u,perm) | r <= c = (l ,u ,p, s) | 873 | luFact (LU l_u perm) |
839 | | otherwise = (l',u',p, s) | 874 | | r <= c = (l ,u ,p, s) |
875 | | otherwise = (l',u',p, s) | ||
840 | where | 876 | where |
841 | r = rows l_u | 877 | r = rows l_u |
842 | c = cols l_u | 878 | c = cols l_u |
@@ -929,7 +965,13 @@ relativeError norm a b = r | |||
929 | ---------------------------------------------------------------------- | 965 | ---------------------------------------------------------------------- |
930 | 966 | ||
931 | -- | Generalized symmetric positive definite eigensystem Av = lBv, | 967 | -- | Generalized symmetric positive definite eigensystem Av = lBv, |
932 | -- for A and B symmetric, B positive definite (conditions not checked). | 968 | -- for A and B symmetric, B positive definite. |
969 | geigSH :: Field t | ||
970 | => Her t -- ^ A | ||
971 | -> Her t -- ^ B | ||
972 | -> (Vector Double, Matrix t) | ||
973 | geigSH (Her a) (Her b) = geigSH' a b | ||
974 | |||
933 | geigSH' :: Field t | 975 | geigSH' :: Field t |
934 | => Matrix t -- ^ A | 976 | => Matrix t -- ^ A |
935 | -> Matrix t -- ^ B | 977 | -> Matrix t -- ^ B |
@@ -943,3 +985,33 @@ geigSH' a b = (l,v') | |||
943 | v' = iu <> v | 985 | v' = iu <> v |
944 | (<>) = mXm | 986 | (<>) = mXm |
945 | 987 | ||
988 | -------------------------------------------------------------------------------- | ||
989 | |||
990 | -- | A matrix that, by construction, it is known to be complex Hermitian or real symmetric. | ||
991 | -- | ||
992 | -- It can be created using 'sym', 'xTx', or 'trustSym', and the matrix can be extracted using 'her'. | ||
993 | data Her t = Her (Matrix t) deriving Show | ||
994 | |||
995 | -- | Extract the general matrix from a 'Her' structure, forgetting its symmetric or Hermitian property. | ||
996 | her :: Her t -> Matrix t | ||
997 | her (Her x) = x | ||
998 | |||
999 | -- | Compute the complex Hermitian or real symmetric part of a square matrix (@(x + tr x)/2@). | ||
1000 | sym :: Field t => Matrix t -> Her t | ||
1001 | sym x = Her (scale 0.5 (tr x `add` x)) | ||
1002 | |||
1003 | -- | Compute the contraction @tr x <> x@ of a general matrix. | ||
1004 | xTx :: Numeric t => Matrix t -> Her t | ||
1005 | xTx x = Her (tr x `mXm` x) | ||
1006 | |||
1007 | instance Field t => Linear t Her where | ||
1008 | scale x (Her m) = Her (scale x m) | ||
1009 | |||
1010 | instance Field t => Additive (Her t) where | ||
1011 | add (Her a) (Her b) = Her (a `add` b) | ||
1012 | |||
1013 | -- | At your own risk, declare that a matrix is complex Hermitian or real symmetric | ||
1014 | -- for usage in 'chol', 'eigSH', etc. Only a triangular part of the matrix will be used. | ||
1015 | trustSym :: Matrix t -> Her t | ||
1016 | trustSym x = (Her x) | ||
1017 | |||