diff options
author | Alberto Ruiz <aruiz@um.es> | 2017-03-22 13:53:01 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-03-22 13:53:01 +0100 |
commit | db8efa9f0d46ee21f0dacdfe35c0d966d91d751d (patch) | |
tree | cfda55f02230554e9e6ecd78039d1b6ab3857672 /packages/base/src | |
parent | ddae74c9f73a1d7fcb8ad00bb74ee77ac3d01086 (diff) | |
parent | 49d718705d205d62aea2762445f95735a671f305 (diff) |
Merge pull request #224 from idontgetoutmuch/master
Add tridiagonal solver and tests for it and triagonal solver.
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/Internal/Algorithms.hs | 76 | ||||
-rw-r--r-- | packages/base/src/Internal/C/lapack-aux.c | 72 | ||||
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 22 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra.hs | 2 |
4 files changed, 170 insertions, 2 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs index 23a5e13..99c9e34 100644 --- a/packages/base/src/Internal/Algorithms.hs +++ b/packages/base/src/Internal/Algorithms.hs | |||
@@ -62,6 +62,7 @@ class (Numeric t, | |||
62 | linearSolve' :: Matrix t -> Matrix t -> Matrix t | 62 | linearSolve' :: Matrix t -> Matrix t -> Matrix t |
63 | cholSolve' :: Matrix t -> Matrix t -> Matrix t | 63 | cholSolve' :: Matrix t -> Matrix t -> Matrix t |
64 | triSolve' :: UpLo -> Matrix t -> Matrix t -> Matrix t | 64 | triSolve' :: UpLo -> Matrix t -> Matrix t -> Matrix t |
65 | triDiagSolve' :: Vector t -> Vector t -> Vector t -> Matrix t -> Matrix t | ||
65 | ldlPacked' :: Matrix t -> (Matrix t, [Int]) | 66 | ldlPacked' :: Matrix t -> (Matrix t, [Int]) |
66 | ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t | 67 | ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t |
67 | linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t | 68 | linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t |
@@ -88,6 +89,7 @@ instance Field Double where | |||
88 | mbLinearSolve' = mbLinearSolveR | 89 | mbLinearSolve' = mbLinearSolveR |
89 | cholSolve' = cholSolveR | 90 | cholSolve' = cholSolveR |
90 | triSolve' = triSolveR | 91 | triSolve' = triSolveR |
92 | triDiagSolve' = triDiagSolveR | ||
91 | linearSolveLS' = linearSolveLSR | 93 | linearSolveLS' = linearSolveLSR |
92 | linearSolveSVD' = linearSolveSVDR Nothing | 94 | linearSolveSVD' = linearSolveSVDR Nothing |
93 | eig' = eigR | 95 | eig' = eigR |
@@ -118,6 +120,7 @@ instance Field (Complex Double) where | |||
118 | mbLinearSolve' = mbLinearSolveC | 120 | mbLinearSolve' = mbLinearSolveC |
119 | cholSolve' = cholSolveC | 121 | cholSolve' = cholSolveC |
120 | triSolve' = triSolveC | 122 | triSolve' = triSolveC |
123 | triDiagSolve' = triDiagSolveC | ||
121 | linearSolveLS' = linearSolveLSC | 124 | linearSolveLS' = linearSolveLSC |
122 | linearSolveSVD' = linearSolveSVDC Nothing | 125 | linearSolveSVD' = linearSolveSVDC Nothing |
123 | eig' = eigC | 126 | eig' = eigC |
@@ -356,10 +359,79 @@ cholSolve | |||
356 | -> Matrix t -- ^ solution | 359 | -> Matrix t -- ^ solution |
357 | cholSolve = {-# SCC "cholSolve" #-} cholSolve' | 360 | cholSolve = {-# SCC "cholSolve" #-} cholSolve' |
358 | 361 | ||
359 | -- | Solve a triangular linear system. | 362 | -- | Solve a triangular linear system. If `Upper` is specified then |
360 | triSolve :: Field t => UpLo -> Matrix t -> Matrix t -> Matrix t | 363 | -- all elements below the diagonal are ignored; if `Lower` is |
364 | -- specified then all elements above the diagonal are ignored. | ||
365 | triSolve | ||
366 | :: Field t | ||
367 | => UpLo -- ^ `Lower` or `Upper` | ||
368 | -> Matrix t -- ^ coefficient matrix | ||
369 | -> Matrix t -- ^ right hand sides | ||
370 | -> Matrix t -- ^ solution | ||
361 | triSolve = {-# SCC "triSolve" #-} triSolve' | 371 | triSolve = {-# SCC "triSolve" #-} triSolve' |
362 | 372 | ||
373 | -- | Solve a tridiagonal linear system. Suppose you wish to solve \(Ax = b\) where | ||
374 | -- | ||
375 | -- \[ | ||
376 | -- A = | ||
377 | -- \begin{bmatrix} | ||
378 | -- 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 | ||
379 | -- \\ 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 | ||
380 | -- \\ 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 | ||
381 | -- \\ 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0 | ||
382 | -- \\ 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0 | ||
383 | -- \\ 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0 | ||
384 | -- \\ 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0 | ||
385 | -- \\ 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0 | ||
386 | -- \\ 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0 | ||
387 | -- \end{bmatrix} | ||
388 | -- \quad | ||
389 | -- b = | ||
390 | -- \begin{bmatrix} | ||
391 | -- 1.0 & 1.0 & 1.0 | ||
392 | -- \\ 1.0 & -1.0 & 2.0 | ||
393 | -- \\ 1.0 & 1.0 & 3.0 | ||
394 | -- \\ 1.0 & -1.0 & 4.0 | ||
395 | -- \\ 1.0 & 1.0 & 5.0 | ||
396 | -- \\ 1.0 & -1.0 & 6.0 | ||
397 | -- \\ 1.0 & 1.0 & 7.0 | ||
398 | -- \\ 1.0 & -1.0 & 8.0 | ||
399 | -- \\ 1.0 & 1.0 & 9.0 | ||
400 | -- \end{bmatrix} | ||
401 | -- \] | ||
402 | -- | ||
403 | -- then | ||
404 | -- | ||
405 | -- @ | ||
406 | -- dL = fromList [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0] | ||
407 | -- d = fromList [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] | ||
408 | -- dU = fromList [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0] | ||
409 | -- | ||
410 | -- b = (9><3) | ||
411 | -- [ | ||
412 | -- 1.0, 1.0, 1.0, | ||
413 | -- 1.0, -1.0, 2.0, | ||
414 | -- 1.0, 1.0, 3.0, | ||
415 | -- 1.0, -1.0, 4.0, | ||
416 | -- 1.0, 1.0, 5.0, | ||
417 | -- 1.0, -1.0, 6.0, | ||
418 | -- 1.0, 1.0, 7.0, | ||
419 | -- 1.0, -1.0, 8.0, | ||
420 | -- 1.0, 1.0, 9.0 | ||
421 | -- ] | ||
422 | -- | ||
423 | -- x = triDiagSolve dL d dU b | ||
424 | -- @ | ||
425 | -- | ||
426 | triDiagSolve | ||
427 | :: Field t | ||
428 | => Vector t -- ^ lower diagonal: \(n - 1\) elements | ||
429 | -> Vector t -- ^ diagonal: \(n\) elements | ||
430 | -> Vector t -- ^ upper diagonal: \(n - 1\) elements | ||
431 | -> Matrix t -- ^ right hand sides | ||
432 | -> Matrix t -- ^ solution | ||
433 | triDiagSolve = {-# SCC "triDiagSolve" #-} triDiagSolve' | ||
434 | |||
363 | -- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. | 435 | -- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. |
364 | linearSolveSVD :: Field t => Matrix t -> Matrix t -> Matrix t | 436 | linearSolveSVD :: Field t => Matrix t -> Matrix t -> Matrix t |
365 | linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD' | 437 | linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD' |
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index 4a8129c..5018a45 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c | |||
@@ -668,6 +668,78 @@ int triSolveC_l_l(KOCMAT(a),OCMAT(b)) { | |||
668 | OK | 668 | OK |
669 | } | 669 | } |
670 | 670 | ||
671 | //////// tridiagonal real linear system //////////// | ||
672 | |||
673 | int dgttrf_(integer *n, | ||
674 | doublereal *dl, doublereal *d, doublereal *du, doublereal *du2, | ||
675 | integer *ipiv, | ||
676 | integer *info); | ||
677 | |||
678 | int dgttrs_(char *trans, integer *n, integer *nrhs, | ||
679 | doublereal *dl, doublereal *d, doublereal *du, doublereal *du2, | ||
680 | integer *ipiv, doublereal *b, integer *ldb, | ||
681 | integer *info); | ||
682 | |||
683 | int triDiagSolveR_l(DVEC(dl), DVEC(d), DVEC(du), ODMAT(b)) { | ||
684 | integer n = dn; | ||
685 | integer nhrs = bc; | ||
686 | REQUIRES(n >= 1 && dln == dn - 1 && dun == dn - 1 && br == n, BAD_SIZE); | ||
687 | DEBUGMSG("triDiagSolveR_l"); | ||
688 | integer res; | ||
689 | integer* ipiv = (integer*)malloc(n*sizeof(integer)); | ||
690 | double* du2 = (double*)malloc((n - 2)*sizeof(double)); | ||
691 | dgttrf_ (&n, | ||
692 | dlp, dp, dup, du2, | ||
693 | ipiv, | ||
694 | &res); | ||
695 | CHECK(res,res); | ||
696 | dgttrs_ ("N", | ||
697 | &n,&nhrs, | ||
698 | dlp, dp, dup, du2, | ||
699 | ipiv, bp, &n, | ||
700 | &res); | ||
701 | CHECK(res,res); | ||
702 | free(ipiv); | ||
703 | free(du2); | ||
704 | OK | ||
705 | } | ||
706 | |||
707 | //////// tridiagonal complex linear system //////////// | ||
708 | |||
709 | int zgttrf_(integer *n, | ||
710 | doublecomplex *dl, doublecomplex *d, doublecomplex *du, doublecomplex *du2, | ||
711 | integer *ipiv, | ||
712 | integer *info); | ||
713 | |||
714 | int zgttrs_(char *trans, integer *n, integer *nrhs, | ||
715 | doublecomplex *dl, doublecomplex *d, doublecomplex *du, doublecomplex *du2, | ||
716 | integer *ipiv, doublecomplex *b, integer *ldb, | ||
717 | integer *info); | ||
718 | |||
719 | int triDiagSolveC_l(CVEC(dl), CVEC(d), CVEC(du), OCMAT(b)) { | ||
720 | integer n = dn; | ||
721 | integer nhrs = bc; | ||
722 | REQUIRES(n >= 1 && dln == dn - 1 && dun == dn - 1 && br == n, BAD_SIZE); | ||
723 | DEBUGMSG("triDiagSolveC_l"); | ||
724 | integer res; | ||
725 | integer* ipiv = (integer*)malloc(n*sizeof(integer)); | ||
726 | doublecomplex* du2 = (doublecomplex*)malloc((n - 2)*sizeof(doublecomplex)); | ||
727 | zgttrf_ (&n, | ||
728 | dlp, dp, dup, du2, | ||
729 | ipiv, | ||
730 | &res); | ||
731 | CHECK(res,res); | ||
732 | zgttrs_ ("N", | ||
733 | &n,&nhrs, | ||
734 | dlp, dp, dup, du2, | ||
735 | ipiv, bp, &n, | ||
736 | &res); | ||
737 | CHECK(res,res); | ||
738 | free(ipiv); | ||
739 | free(du2); | ||
740 | OK | ||
741 | } | ||
742 | |||
671 | //////////////////// least squares real linear system //////////// | 743 | //////////////////// least squares real linear system //////////// |
672 | 744 | ||
673 | int dgels_(char *trans, integer *m, integer *n, integer * | 745 | int dgels_(char *trans, integer *m, integer *n, integer * |
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index b4dd5cf..e306454 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -436,6 +436,28 @@ triSolveC :: UpLo -> Matrix (Complex Double) -> Matrix (Complex Double) -> Matri | |||
436 | triSolveC Lower a b = linearSolveTRAux2 id ztrtrs_l "triSolveC" (fmat a) b | 436 | triSolveC Lower a b = linearSolveTRAux2 id ztrtrs_l "triSolveC" (fmat a) b |
437 | triSolveC Upper a b = linearSolveTRAux2 id ztrtrs_u "triSolveC" (fmat a) b | 437 | triSolveC Upper a b = linearSolveTRAux2 id ztrtrs_u "triSolveC" (fmat a) b |
438 | 438 | ||
439 | -------------------------------------------------------------------------------- | ||
440 | foreign import ccall unsafe "triDiagSolveR_l" dgttrs :: R :> R :> R :> R ::> Ok | ||
441 | foreign import ccall unsafe "triDiagSolveC_l" zgttrs :: C :> C :> C :> C ::> Ok | ||
442 | |||
443 | linearSolveGTAux2 g f st dl d du b | ||
444 | | ndl == nd - 1 && | ||
445 | ndu == nd - 1 && | ||
446 | nd == r = unsafePerformIO . g $ do | ||
447 | s <- copy ColumnMajor b | ||
448 | (dl # d # du #! s) f #| st | ||
449 | return s | ||
450 | | otherwise = error $ st ++ " of nonsquare matrix" | ||
451 | where | ||
452 | ndl = dim dl | ||
453 | nd = dim d | ||
454 | ndu = dim du | ||
455 | r = rows b | ||
456 | |||
457 | -- | Solves a tridiagonal system of linear equations. | ||
458 | triDiagSolveR dl d du b = linearSolveGTAux2 id dgttrs "triDiagSolveR" dl d du b | ||
459 | triDiagSolveC dl d du b = linearSolveGTAux2 id zgttrs "triDiagSolveC" dl d du b | ||
460 | |||
439 | ----------------------------------------------------------------------------------- | 461 | ----------------------------------------------------------------------------------- |
440 | 462 | ||
441 | foreign import ccall unsafe "linearSolveLSR_l" dgels :: R ::> R ::> Ok | 463 | foreign import ccall unsafe "linearSolveLSR_l" dgels :: R ::> R ::> Ok |
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 869330c..fd100e0 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs | |||
@@ -97,6 +97,8 @@ module Numeric.LinearAlgebra ( | |||
97 | -- ** Triangular | 97 | -- ** Triangular |
98 | UpLo(..), | 98 | UpLo(..), |
99 | triSolve, | 99 | triSolve, |
100 | -- ** Tridiagonal | ||
101 | triDiagSolve, | ||
100 | -- ** Sparse | 102 | -- ** Sparse |
101 | cgSolve, | 103 | cgSolve, |
102 | cgSolve', | 104 | cgSolve', |