summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2017-03-22 13:53:01 +0100
committerGitHub <noreply@github.com>2017-03-22 13:53:01 +0100
commitdb8efa9f0d46ee21f0dacdfe35c0d966d91d751d (patch)
treecfda55f02230554e9e6ecd78039d1b6ab3857672
parentddae74c9f73a1d7fcb8ad00bb74ee77ac3d01086 (diff)
parent49d718705d205d62aea2762445f95735a671f305 (diff)
Merge pull request #224 from idontgetoutmuch/master
Add tridiagonal solver and tests for it and triagonal solver.
-rw-r--r--packages/base/src/Internal/Algorithms.hs76
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c72
-rw-r--r--packages/base/src/Internal/LAPACK.hs22
-rw-r--r--packages/base/src/Numeric/LinearAlgebra.hs2
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests.hs107
5 files changed, 277 insertions, 2 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs
index 23a5e13..99c9e34 100644
--- a/packages/base/src/Internal/Algorithms.hs
+++ b/packages/base/src/Internal/Algorithms.hs
@@ -62,6 +62,7 @@ class (Numeric t,
62 linearSolve' :: Matrix t -> Matrix t -> Matrix t 62 linearSolve' :: Matrix t -> Matrix t -> Matrix t
63 cholSolve' :: Matrix t -> Matrix t -> Matrix t 63 cholSolve' :: Matrix t -> Matrix t -> Matrix t
64 triSolve' :: UpLo -> Matrix t -> Matrix t -> Matrix t 64 triSolve' :: UpLo -> Matrix t -> Matrix t -> Matrix t
65 triDiagSolve' :: Vector t -> Vector t -> Vector t -> Matrix t -> Matrix t
65 ldlPacked' :: Matrix t -> (Matrix t, [Int]) 66 ldlPacked' :: Matrix t -> (Matrix t, [Int])
66 ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t 67 ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t
67 linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t 68 linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t
@@ -88,6 +89,7 @@ instance Field Double where
88 mbLinearSolve' = mbLinearSolveR 89 mbLinearSolve' = mbLinearSolveR
89 cholSolve' = cholSolveR 90 cholSolve' = cholSolveR
90 triSolve' = triSolveR 91 triSolve' = triSolveR
92 triDiagSolve' = triDiagSolveR
91 linearSolveLS' = linearSolveLSR 93 linearSolveLS' = linearSolveLSR
92 linearSolveSVD' = linearSolveSVDR Nothing 94 linearSolveSVD' = linearSolveSVDR Nothing
93 eig' = eigR 95 eig' = eigR
@@ -118,6 +120,7 @@ instance Field (Complex Double) where
118 mbLinearSolve' = mbLinearSolveC 120 mbLinearSolve' = mbLinearSolveC
119 cholSolve' = cholSolveC 121 cholSolve' = cholSolveC
120 triSolve' = triSolveC 122 triSolve' = triSolveC
123 triDiagSolve' = triDiagSolveC
121 linearSolveLS' = linearSolveLSC 124 linearSolveLS' = linearSolveLSC
122 linearSolveSVD' = linearSolveSVDC Nothing 125 linearSolveSVD' = linearSolveSVDC Nothing
123 eig' = eigC 126 eig' = eigC
@@ -356,10 +359,79 @@ cholSolve
356 -> Matrix t -- ^ solution 359 -> Matrix t -- ^ solution
357cholSolve = {-# SCC "cholSolve" #-} cholSolve' 360cholSolve = {-# SCC "cholSolve" #-} cholSolve'
358 361
359-- | Solve a triangular linear system. 362-- | Solve a triangular linear system. If `Upper` is specified then
360triSolve :: Field t => UpLo -> Matrix t -> Matrix t -> Matrix t 363-- all elements below the diagonal are ignored; if `Lower` is
364-- specified then all elements above the diagonal are ignored.
365triSolve
366 :: Field t
367 => UpLo -- ^ `Lower` or `Upper`
368 -> Matrix t -- ^ coefficient matrix
369 -> Matrix t -- ^ right hand sides
370 -> Matrix t -- ^ solution
361triSolve = {-# SCC "triSolve" #-} triSolve' 371triSolve = {-# SCC "triSolve" #-} triSolve'
362 372
373-- | Solve a tridiagonal linear system. Suppose you wish to solve \(Ax = b\) where
374--
375-- \[
376-- A =
377-- \begin{bmatrix}
378-- 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0
379-- \\ 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0
380-- \\ 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0
381-- \\ 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0
382-- \\ 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0
383-- \\ 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0
384-- \\ 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0
385-- \\ 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0
386-- \\ 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0
387-- \end{bmatrix}
388-- \quad
389-- b =
390-- \begin{bmatrix}
391-- 1.0 & 1.0 & 1.0
392-- \\ 1.0 & -1.0 & 2.0
393-- \\ 1.0 & 1.0 & 3.0
394-- \\ 1.0 & -1.0 & 4.0
395-- \\ 1.0 & 1.0 & 5.0
396-- \\ 1.0 & -1.0 & 6.0
397-- \\ 1.0 & 1.0 & 7.0
398-- \\ 1.0 & -1.0 & 8.0
399-- \\ 1.0 & 1.0 & 9.0
400-- \end{bmatrix}
401-- \]
402--
403-- then
404--
405-- @
406-- dL = fromList [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0]
407-- d = fromList [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
408-- dU = fromList [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0]
409--
410-- b = (9><3)
411-- [
412-- 1.0, 1.0, 1.0,
413-- 1.0, -1.0, 2.0,
414-- 1.0, 1.0, 3.0,
415-- 1.0, -1.0, 4.0,
416-- 1.0, 1.0, 5.0,
417-- 1.0, -1.0, 6.0,
418-- 1.0, 1.0, 7.0,
419-- 1.0, -1.0, 8.0,
420-- 1.0, 1.0, 9.0
421-- ]
422--
423-- x = triDiagSolve dL d dU b
424-- @
425--
426triDiagSolve
427 :: Field t
428 => Vector t -- ^ lower diagonal: \(n - 1\) elements
429 -> Vector t -- ^ diagonal: \(n\) elements
430 -> Vector t -- ^ upper diagonal: \(n - 1\) elements
431 -> Matrix t -- ^ right hand sides
432 -> Matrix t -- ^ solution
433triDiagSolve = {-# SCC "triDiagSolve" #-} triDiagSolve'
434
363-- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. 435-- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value.
364linearSolveSVD :: Field t => Matrix t -> Matrix t -> Matrix t 436linearSolveSVD :: Field t => Matrix t -> Matrix t -> Matrix t
365linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD' 437linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD'
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index 4a8129c..5018a45 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -668,6 +668,78 @@ int triSolveC_l_l(KOCMAT(a),OCMAT(b)) {
668 OK 668 OK
669} 669}
670 670
671//////// tridiagonal real linear system ////////////
672
673int dgttrf_(integer *n,
674 doublereal *dl, doublereal *d, doublereal *du, doublereal *du2,
675 integer *ipiv,
676 integer *info);
677
678int dgttrs_(char *trans, integer *n, integer *nrhs,
679 doublereal *dl, doublereal *d, doublereal *du, doublereal *du2,
680 integer *ipiv, doublereal *b, integer *ldb,
681 integer *info);
682
683int triDiagSolveR_l(DVEC(dl), DVEC(d), DVEC(du), ODMAT(b)) {
684 integer n = dn;
685 integer nhrs = bc;
686 REQUIRES(n >= 1 && dln == dn - 1 && dun == dn - 1 && br == n, BAD_SIZE);
687 DEBUGMSG("triDiagSolveR_l");
688 integer res;
689 integer* ipiv = (integer*)malloc(n*sizeof(integer));
690 double* du2 = (double*)malloc((n - 2)*sizeof(double));
691 dgttrf_ (&n,
692 dlp, dp, dup, du2,
693 ipiv,
694 &res);
695 CHECK(res,res);
696 dgttrs_ ("N",
697 &n,&nhrs,
698 dlp, dp, dup, du2,
699 ipiv, bp, &n,
700 &res);
701 CHECK(res,res);
702 free(ipiv);
703 free(du2);
704 OK
705}
706
707//////// tridiagonal complex linear system ////////////
708
709int zgttrf_(integer *n,
710 doublecomplex *dl, doublecomplex *d, doublecomplex *du, doublecomplex *du2,
711 integer *ipiv,
712 integer *info);
713
714int zgttrs_(char *trans, integer *n, integer *nrhs,
715 doublecomplex *dl, doublecomplex *d, doublecomplex *du, doublecomplex *du2,
716 integer *ipiv, doublecomplex *b, integer *ldb,
717 integer *info);
718
719int triDiagSolveC_l(CVEC(dl), CVEC(d), CVEC(du), OCMAT(b)) {
720 integer n = dn;
721 integer nhrs = bc;
722 REQUIRES(n >= 1 && dln == dn - 1 && dun == dn - 1 && br == n, BAD_SIZE);
723 DEBUGMSG("triDiagSolveC_l");
724 integer res;
725 integer* ipiv = (integer*)malloc(n*sizeof(integer));
726 doublecomplex* du2 = (doublecomplex*)malloc((n - 2)*sizeof(doublecomplex));
727 zgttrf_ (&n,
728 dlp, dp, dup, du2,
729 ipiv,
730 &res);
731 CHECK(res,res);
732 zgttrs_ ("N",
733 &n,&nhrs,
734 dlp, dp, dup, du2,
735 ipiv, bp, &n,
736 &res);
737 CHECK(res,res);
738 free(ipiv);
739 free(du2);
740 OK
741}
742
671//////////////////// least squares real linear system //////////// 743//////////////////// least squares real linear system ////////////
672 744
673int dgels_(char *trans, integer *m, integer *n, integer * 745int dgels_(char *trans, integer *m, integer *n, integer *
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index b4dd5cf..e306454 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -436,6 +436,28 @@ triSolveC :: UpLo -> Matrix (Complex Double) -> Matrix (Complex Double) -> Matri
436triSolveC Lower a b = linearSolveTRAux2 id ztrtrs_l "triSolveC" (fmat a) b 436triSolveC Lower a b = linearSolveTRAux2 id ztrtrs_l "triSolveC" (fmat a) b
437triSolveC Upper a b = linearSolveTRAux2 id ztrtrs_u "triSolveC" (fmat a) b 437triSolveC Upper a b = linearSolveTRAux2 id ztrtrs_u "triSolveC" (fmat a) b
438 438
439--------------------------------------------------------------------------------
440foreign import ccall unsafe "triDiagSolveR_l" dgttrs :: R :> R :> R :> R ::> Ok
441foreign import ccall unsafe "triDiagSolveC_l" zgttrs :: C :> C :> C :> C ::> Ok
442
443linearSolveGTAux2 g f st dl d du b
444 | ndl == nd - 1 &&
445 ndu == nd - 1 &&
446 nd == r = unsafePerformIO . g $ do
447 s <- copy ColumnMajor b
448 (dl # d # du #! s) f #| st
449 return s
450 | otherwise = error $ st ++ " of nonsquare matrix"
451 where
452 ndl = dim dl
453 nd = dim d
454 ndu = dim du
455 r = rows b
456
457-- | Solves a tridiagonal system of linear equations.
458triDiagSolveR dl d du b = linearSolveGTAux2 id dgttrs "triDiagSolveR" dl d du b
459triDiagSolveC dl d du b = linearSolveGTAux2 id zgttrs "triDiagSolveC" dl d du b
460
439----------------------------------------------------------------------------------- 461-----------------------------------------------------------------------------------
440 462
441foreign import ccall unsafe "linearSolveLSR_l" dgels :: R ::> R ::> Ok 463foreign import ccall unsafe "linearSolveLSR_l" dgels :: R ::> R ::> Ok
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs
index 869330c..fd100e0 100644
--- a/packages/base/src/Numeric/LinearAlgebra.hs
+++ b/packages/base/src/Numeric/LinearAlgebra.hs
@@ -97,6 +97,8 @@ module Numeric.LinearAlgebra (
97 -- ** Triangular 97 -- ** Triangular
98 UpLo(..), 98 UpLo(..),
99 triSolve, 99 triSolve,
100 -- ** Tridiagonal
101 triDiagSolve,
100 -- ** Sparse 102 -- ** Sparse
101 cgSolve, 103 cgSolve,
102 cgSolve', 104 cgSolve',
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
index 043ebf3..55a5f74 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
@@ -131,6 +131,111 @@ mbCholTest = utest "mbCholTest" (ok1 && ok2) where
131 ok1 = mbChol (trustSym m1) == Nothing 131 ok1 = mbChol (trustSym m1) == Nothing
132 ok2 = mbChol (trustSym m2) == Just (chol $ trustSym m2) 132 ok2 = mbChol (trustSym m2) == Just (chol $ trustSym m2)
133 133
134-----------------------------------------------------
135
136triTest = utest "triTest" ok1 where
137
138 a :: Matrix R
139 a = (4><4)
140 [
141 4.30, 0.00, 0.00, 0.00,
142 -3.96, -4.87, 0.00, 0.00,
143 0.40, 0.31, -8.02, 0.00,
144 -0.27, 0.07, -5.95, 0.12
145 ]
146
147 w :: Matrix R
148 w = (4><2)
149 [
150 -12.90, -21.50,
151 16.75, 14.93,
152 -17.55, 6.33,
153 -11.04, 8.09
154 ]
155
156 v :: Matrix R
157 v = triSolve Lower a w
158
159 e :: Matrix R
160 e = (4><2)
161 [
162 -3.0000, -5.0000,
163 -1.0000, 1.0000,
164 2.0000, -1.0000,
165 1.0000, 6.0000
166 ]
167
168 ok1 = (maximum $ map abs $ concat $ toLists $ e - v) <= 1e-14
169
170-----------------------------------------------------
171
172triDiagTest = utest "triDiagTest" (ok1 && ok2) where
173
174 dL, d, dU :: Vector Double
175 dL = fromList [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0]
176 d = fromList [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
177 dU = fromList [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0]
178
179 b :: Matrix R
180 b = (9><3)
181 [
182 1.0, 1.0, 1.0,
183 1.0, -1.0, 2.0,
184 1.0, 1.0, 3.0,
185 1.0, -1.0, 4.0,
186 1.0, 1.0, 5.0,
187 1.0, -1.0, 6.0,
188 1.0, 1.0, 7.0,
189 1.0, -1.0, 8.0,
190 1.0, 1.0, 9.0
191 ]
192
193 y :: Matrix R
194 y = (9><9)
195 [
196 1.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
197 3.0, 1.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
198 0.0, 3.0, 1.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0,
199 0.0, 0.0, 3.0, 1.0, 4.0, 0.0, 0.0, 0.0, 0.0,
200 0.0, 0.0, 0.0, 3.0, 1.0, 4.0, 0.0, 0.0, 0.0,
201 0.0, 0.0, 0.0, 0.0, 3.0, 1.0, 4.0, 0.0, 0.0,
202 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 1.0, 4.0, 0.0,
203 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 1.0, 4.0,
204 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 1.0
205 ]
206
207 x :: Matrix R
208 x = triDiagSolve dL d dU b
209
210 z :: Matrix C
211 z = (4><4)
212 [
213 1.0 :+ 1.0, 4.0 :+ 4.0, 0.0 :+ 0.0, 0.0 :+ 0.0,
214 3.0 :+ 3.0, 1.0 :+ 1.0, 4.0 :+ 4.0, 0.0 :+ 0.0,
215 0.0 :+ 0.0, 3.0 :+ 3.0, 1.0 :+ 1.0, 4.0 :+ 4.0,
216 0.0 :+ 0.0, 0.0 :+ 0.0, 3.0 :+ 3.0, 1.0 :+ 1.0
217 ]
218
219 zDL, zD, zDu :: Vector C
220 zDL = fromList [3.0 :+ 3.0, 3.0 :+ 3.0, 3.0 :+ 3.0]
221 zD = fromList [1.0 :+ 1.0, 1.0 :+ 1.0, 1.0 :+ 1.0, 1.0 :+ 1.0]
222 zDu = fromList [4.0 :+ 4.0, 4.0 :+ 4.0, 4.0 :+ 4.0]
223
224 zB :: Matrix C
225 zB = (4><3)
226 [
227 1.0 :+ 1.0, 1.0 :+ 1.0, 1.0 :+ (-1.0),
228 1.0 :+ 1.0, (-1.0) :+ (-1.0), 1.0 :+ (-1.0),
229 1.0 :+ 1.0, 1.0 :+ 1.0, 1.0 :+ (-1.0),
230 1.0 :+ 1.0, (-1.0) :+ (-1.0), 1.0 :+ (-1.0)
231 ]
232
233 u :: Matrix C
234 u = triDiagSolve zDL zD zDu zB
235
236 ok1 = (maximum $ map abs $ concat $ toLists $ b - (y <> x)) <= 1e-15
237 ok2 = (maximum $ map magnitude $ concat $ toLists $ zB - (z <> u)) <= 1e-15
238
134--------------------------------------------------------------------- 239---------------------------------------------------------------------
135 240
136randomTestGaussian = (unSym c) :~3~: unSym (snd (meanCov dat)) 241randomTestGaussian = (unSym c) :~3~: unSym (snd (meanCov dat))
@@ -715,6 +820,8 @@ runTests n = do
715 && rank ((2><3)[1,0,0,1,7*peps,0::Double]) == 2 820 && rank ((2><3)[1,0,0,1,7*peps,0::Double]) == 2
716 , utest "block" $ fromBlocks [[ident 3,0],[0,ident 4]] == (ident 7 :: CM) 821 , utest "block" $ fromBlocks [[ident 3,0],[0,ident 4]] == (ident 7 :: CM)
717 , mbCholTest 822 , mbCholTest
823 , triTest
824 , triDiagTest
718 , utest "offset" offsetTest 825 , utest "offset" offsetTest
719 , normsVTest 826 , normsVTest
720 , normsMTest 827 , normsMTest