summaryrefslogtreecommitdiff
path: root/lib/Numeric/LinearAlgebra/Algorithms.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric/LinearAlgebra/Algorithms.hs')
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs38
1 files changed, 19 insertions, 19 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index fbefa68..45298b5 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -168,9 +168,9 @@ square m = rows m == cols m
168 168
169-- | determinant of a square matrix, computed from the LU decomposition. 169-- | determinant of a square matrix, computed from the LU decomposition.
170det :: Field t => Matrix t -> t 170det :: Field t => Matrix t -> t
171det m | square m = s * (product $ toList $ takeDiag $ lu) 171det m | square m = s * (product $ toList $ takeDiag $ lup)
172 | otherwise = error "det of nonsquare matrix" 172 | otherwise = error "det of nonsquare matrix"
173 where (lu,perm) = luPacked m 173 where (lup,perm) = luPacked m
174 s = signlp (rows m) perm 174 s = signlp (rows m) perm
175 175
176-- | LU factorization of a general matrix using lapack's dgetrf or zgetrf. 176-- | LU factorization of a general matrix using lapack's dgetrf or zgetrf.
@@ -501,21 +501,21 @@ fixPerm r vals = (fromColumns $ elems res, sign)
501 s = toColumns (ident r) 501 s = toColumns (ident r)
502 (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals) 502 (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals)
503 503
504triang r c h v = reshape c $ fromList [el i j | i<-[0..r-1], j<-[0..c-1]] 504triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]]
505 where el i j = if j-i>=h then v else 1 - v 505 where el p q = if q-p>=h then v else 1 - v
506 506
507luFact (lu,perm) | r <= c = (l ,u ,p, s) 507luFact (l_u,perm) | r <= c = (l ,u ,p, s)
508 | otherwise = (l',u',p, s) 508 | otherwise = (l',u',p, s)
509 where 509 where
510 r = rows lu 510 r = rows l_u
511 c = cols lu 511 c = cols l_u
512 tu = triang r c 0 1 512 tu = triang r c 0 1
513 tl = triang r c 0 0 513 tl = triang r c 0 0
514 l = takeColumns r (lu |*| tl) |+| diagRect (constant 1 r) r r 514 l = takeColumns r (l_u |*| tl) |+| diagRect (constant 1 r) r r
515 u = lu |*| tu 515 u = l_u |*| tu
516 (p,s) = fixPerm r perm 516 (p,s) = fixPerm r perm
517 l' = (lu |*| tl) |+| diagRect (constant 1 c) r c 517 l' = (l_u |*| tl) |+| diagRect (constant 1 c) r c
518 u' = takeRows c (lu |*| tu) 518 u' = takeRows c (l_u |*| tu)
519 (|+|) = add 519 (|+|) = add
520 (|*|) = mul 520 (|*|) = mul
521 521
@@ -572,8 +572,8 @@ kronecker a b = fromBlocks
572-- reference multiply 572-- reference multiply
573--------------------------------------------------------------------- 573---------------------------------------------------------------------
574 574
575mulH a b = fromLists [[ dot ai bj | bj <- toColumns b] | ai <- toRows a ] 575mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ]
576 where dot u v = sum $ zipWith (*) (toList u) (toList v) 576 where doth u v = sum $ zipWith (*) (toList u) (toList v)
577 577
578----------------------------------------------------------------------------------- 578-----------------------------------------------------------------------------------
579-- workaround 579-- workaround
@@ -599,7 +599,7 @@ colMajor = CBLASOrder 102
599 599
600noTrans, trans', conjTrans :: CBLASTrans 600noTrans, trans', conjTrans :: CBLASTrans
601noTrans = CBLASTrans 111 601noTrans = CBLASTrans 111
602trans' = CBLASTrans 112 602trans' = CBLASTrans 112
603conjTrans = CBLASTrans 113 603conjTrans = CBLASTrans 113
604 604
605foreign import ccall "cblas.h cblas_dgemm" 605foreign import ccall "cblas.h cblas_dgemm"
@@ -608,12 +608,12 @@ foreign import ccall "cblas.h cblas_dgemm"
608 -> Ptr Double -> CInt -> IO () 608 -> Ptr Double -> CInt -> IO ()
609 609
610multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double 610multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double
611multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) 611multiplyR3 x y = multiply3 dgemm "cblas_dgemm" (fmat x) (fmat y)
612 where 612 where
613 multiply3 f st a b 613 multiply3 f st a b
614 | cols a == rows b = unsafePerformIO $ do 614 | cols a == rows b = unsafePerformIO $ do
615 s <- createMatrix ColumnMajor (rows a) (cols b) 615 s <- createMatrix ColumnMajor (rows a) (cols b)
616 let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0 616 let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0
617 app3 g mat a mat b mat s st 617 app3 g mat a mat b mat s st
618 return s 618 return s
619 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" 619 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
@@ -625,14 +625,14 @@ foreign import ccall "cblas.h cblas_zgemm"
625 -> Ptr (Complex Double) -> CInt -> IO () 625 -> Ptr (Complex Double) -> CInt -> IO ()
626 626
627multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 627multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
628multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b) 628multiplyC3 x y = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat x) (fmat y)
629 where 629 where
630 multiply3 f st a b 630 multiply3 f st a b
631 | cols a == rows b = do 631 | cols a == rows b = do
632 s <- createMatrix ColumnMajor (rows a) (cols b) 632 s <- createMatrix ColumnMajor (rows a) (cols b)
633 palpha <- new 1 633 palpha <- new 1
634 pbeta <- new 0 634 pbeta <- new 0
635 let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0 635 let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0
636 app3 g mat a mat b mat s st 636 app3 g mat a mat b mat s st
637 free palpha 637 free palpha
638 free pbeta 638 free pbeta