diff options
author | Alberto Ruiz <aruiz@um.es> | 2008-10-22 12:59:18 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2008-10-22 12:59:18 +0000 |
commit | faeaf6d261b760e628c1e63551d822d16876c0cc (patch) | |
tree | 45e3e2d1460d72e1fd037e19d4470963b75cc00e /lib/Numeric/LinearAlgebra/Algorithms.hs | |
parent | 9d9b1274a522e1bf0c5dea210765a0368ebb74a5 (diff) |
-Wall
Diffstat (limited to 'lib/Numeric/LinearAlgebra/Algorithms.hs')
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 38 |
1 files changed, 19 insertions, 19 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index fbefa68..45298b5 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -168,9 +168,9 @@ square m = rows m == cols m | |||
168 | 168 | ||
169 | -- | determinant of a square matrix, computed from the LU decomposition. | 169 | -- | determinant of a square matrix, computed from the LU decomposition. |
170 | det :: Field t => Matrix t -> t | 170 | det :: Field t => Matrix t -> t |
171 | det m | square m = s * (product $ toList $ takeDiag $ lu) | 171 | det m | square m = s * (product $ toList $ takeDiag $ lup) |
172 | | otherwise = error "det of nonsquare matrix" | 172 | | otherwise = error "det of nonsquare matrix" |
173 | where (lu,perm) = luPacked m | 173 | where (lup,perm) = luPacked m |
174 | s = signlp (rows m) perm | 174 | s = signlp (rows m) perm |
175 | 175 | ||
176 | -- | LU factorization of a general matrix using lapack's dgetrf or zgetrf. | 176 | -- | LU factorization of a general matrix using lapack's dgetrf or zgetrf. |
@@ -501,21 +501,21 @@ fixPerm r vals = (fromColumns $ elems res, sign) | |||
501 | s = toColumns (ident r) | 501 | s = toColumns (ident r) |
502 | (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals) | 502 | (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals) |
503 | 503 | ||
504 | triang r c h v = reshape c $ fromList [el i j | i<-[0..r-1], j<-[0..c-1]] | 504 | triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]] |
505 | where el i j = if j-i>=h then v else 1 - v | 505 | where el p q = if q-p>=h then v else 1 - v |
506 | 506 | ||
507 | luFact (lu,perm) | r <= c = (l ,u ,p, s) | 507 | luFact (l_u,perm) | r <= c = (l ,u ,p, s) |
508 | | otherwise = (l',u',p, s) | 508 | | otherwise = (l',u',p, s) |
509 | where | 509 | where |
510 | r = rows lu | 510 | r = rows l_u |
511 | c = cols lu | 511 | c = cols l_u |
512 | tu = triang r c 0 1 | 512 | tu = triang r c 0 1 |
513 | tl = triang r c 0 0 | 513 | tl = triang r c 0 0 |
514 | l = takeColumns r (lu |*| tl) |+| diagRect (constant 1 r) r r | 514 | l = takeColumns r (l_u |*| tl) |+| diagRect (constant 1 r) r r |
515 | u = lu |*| tu | 515 | u = l_u |*| tu |
516 | (p,s) = fixPerm r perm | 516 | (p,s) = fixPerm r perm |
517 | l' = (lu |*| tl) |+| diagRect (constant 1 c) r c | 517 | l' = (l_u |*| tl) |+| diagRect (constant 1 c) r c |
518 | u' = takeRows c (lu |*| tu) | 518 | u' = takeRows c (l_u |*| tu) |
519 | (|+|) = add | 519 | (|+|) = add |
520 | (|*|) = mul | 520 | (|*|) = mul |
521 | 521 | ||
@@ -572,8 +572,8 @@ kronecker a b = fromBlocks | |||
572 | -- reference multiply | 572 | -- reference multiply |
573 | --------------------------------------------------------------------- | 573 | --------------------------------------------------------------------- |
574 | 574 | ||
575 | mulH a b = fromLists [[ dot ai bj | bj <- toColumns b] | ai <- toRows a ] | 575 | mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ] |
576 | where dot u v = sum $ zipWith (*) (toList u) (toList v) | 576 | where doth u v = sum $ zipWith (*) (toList u) (toList v) |
577 | 577 | ||
578 | ----------------------------------------------------------------------------------- | 578 | ----------------------------------------------------------------------------------- |
579 | -- workaround | 579 | -- workaround |
@@ -599,7 +599,7 @@ colMajor = CBLASOrder 102 | |||
599 | 599 | ||
600 | noTrans, trans', conjTrans :: CBLASTrans | 600 | noTrans, trans', conjTrans :: CBLASTrans |
601 | noTrans = CBLASTrans 111 | 601 | noTrans = CBLASTrans 111 |
602 | trans' = CBLASTrans 112 | 602 | trans' = CBLASTrans 112 |
603 | conjTrans = CBLASTrans 113 | 603 | conjTrans = CBLASTrans 113 |
604 | 604 | ||
605 | foreign import ccall "cblas.h cblas_dgemm" | 605 | foreign import ccall "cblas.h cblas_dgemm" |
@@ -608,12 +608,12 @@ foreign import ccall "cblas.h cblas_dgemm" | |||
608 | -> Ptr Double -> CInt -> IO () | 608 | -> Ptr Double -> CInt -> IO () |
609 | 609 | ||
610 | multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double | 610 | multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double |
611 | multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) | 611 | multiplyR3 x y = multiply3 dgemm "cblas_dgemm" (fmat x) (fmat y) |
612 | where | 612 | where |
613 | multiply3 f st a b | 613 | multiply3 f st a b |
614 | | cols a == rows b = unsafePerformIO $ do | 614 | | cols a == rows b = unsafePerformIO $ do |
615 | s <- createMatrix ColumnMajor (rows a) (cols b) | 615 | s <- createMatrix ColumnMajor (rows a) (cols b) |
616 | let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0 | 616 | let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0 |
617 | app3 g mat a mat b mat s st | 617 | app3 g mat a mat b mat s st |
618 | return s | 618 | return s |
619 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | 619 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" |
@@ -625,14 +625,14 @@ foreign import ccall "cblas.h cblas_zgemm" | |||
625 | -> Ptr (Complex Double) -> CInt -> IO () | 625 | -> Ptr (Complex Double) -> CInt -> IO () |
626 | 626 | ||
627 | multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 627 | multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
628 | multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b) | 628 | multiplyC3 x y = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat x) (fmat y) |
629 | where | 629 | where |
630 | multiply3 f st a b | 630 | multiply3 f st a b |
631 | | cols a == rows b = do | 631 | | cols a == rows b = do |
632 | s <- createMatrix ColumnMajor (rows a) (cols b) | 632 | s <- createMatrix ColumnMajor (rows a) (cols b) |
633 | palpha <- new 1 | 633 | palpha <- new 1 |
634 | pbeta <- new 0 | 634 | pbeta <- new 0 |
635 | let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0 | 635 | let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0 |
636 | app3 g mat a mat b mat s st | 636 | app3 g mat a mat b mat s st |
637 | free palpha | 637 | free palpha |
638 | free pbeta | 638 | free pbeta |