diff options
Diffstat (limited to 'packages/base/src/Internal/LAPACK.hs')
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 72 |
1 files changed, 35 insertions, 37 deletions
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index ce00c16..65deceb 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -455,29 +455,29 @@ mbCholS = unsafePerformIO . mbCatch . cholAux dpotrf "cholS" | |||
455 | 455 | ||
456 | type TMVM t = t ::> t :> t ::> Ok | 456 | type TMVM t = t ::> t :> t ::> Ok |
457 | 457 | ||
458 | foreign import ccall unsafe "qr_l_R" dgeqr2 :: TMVM R | 458 | foreign import ccall unsafe "qr_l_R" dgeqr2 :: R :> R ::> Ok |
459 | foreign import ccall unsafe "qr_l_C" zgeqr2 :: TMVM C | 459 | foreign import ccall unsafe "qr_l_C" zgeqr2 :: C :> C ::> Ok |
460 | 460 | ||
461 | -- | QR factorization of a real matrix, using LAPACK's /dgeqr2/. | 461 | -- | QR factorization of a real matrix, using LAPACK's /dgeqr2/. |
462 | qrR :: Matrix Double -> (Matrix Double, Vector Double) | 462 | qrR :: Matrix Double -> (Matrix Double, Vector Double) |
463 | qrR = qrAux dgeqr2 "qrR" . fmat | 463 | qrR = qrAux dgeqr2 "qrR" |
464 | 464 | ||
465 | -- | QR factorization of a complex matrix, using LAPACK's /zgeqr2/. | 465 | -- | QR factorization of a complex matrix, using LAPACK's /zgeqr2/. |
466 | qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) | 466 | qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) |
467 | qrC = qrAux zgeqr2 "qrC" . fmat | 467 | qrC = qrAux zgeqr2 "qrC" |
468 | 468 | ||
469 | qrAux f st a = unsafePerformIO $ do | 469 | qrAux f st a = unsafePerformIO $ do |
470 | r <- createMatrix ColumnMajor m n | 470 | r <- copy ColumnMajor a |
471 | tau <- createVector mn | 471 | tau <- createVector mn |
472 | f # a # tau # r #| st | 472 | f # tau # r #| st |
473 | return (r,tau) | 473 | return (r,tau) |
474 | where | 474 | where |
475 | m = rows a | 475 | m = rows a |
476 | n = cols a | 476 | n = cols a |
477 | mn = min m n | 477 | mn = min m n |
478 | 478 | ||
479 | foreign import ccall unsafe "c_dorgqr" dorgqr :: TMVM R | 479 | foreign import ccall unsafe "c_dorgqr" dorgqr :: R :> R ::> Ok |
480 | foreign import ccall unsafe "c_zungqr" zungqr :: TMVM C | 480 | foreign import ccall unsafe "c_zungqr" zungqr :: C :> C ::> Ok |
481 | 481 | ||
482 | -- | build rotation from reflectors | 482 | -- | build rotation from reflectors |
483 | qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double | 483 | qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double |
@@ -487,28 +487,28 @@ qrgrC :: Int -> (Matrix (Complex Double), Vector (Complex Double)) -> Matrix (Co | |||
487 | qrgrC = qrgrAux zungqr "qrgrC" | 487 | qrgrC = qrgrAux zungqr "qrgrC" |
488 | 488 | ||
489 | qrgrAux f st n (a, tau) = unsafePerformIO $ do | 489 | qrgrAux f st n (a, tau) = unsafePerformIO $ do |
490 | res <- createMatrix ColumnMajor (rows a) n | 490 | res <- copy ColumnMajor (sliceMatrix (0,0) (rows a,n) a) |
491 | f # (fmat a) # (subVector 0 n tau') # res #| st | 491 | f # (subVector 0 n tau') # res #| st |
492 | return res | 492 | return res |
493 | where | 493 | where |
494 | tau' = vjoin [tau, constantD 0 n] | 494 | tau' = vjoin [tau, constantD 0 n] |
495 | 495 | ||
496 | ----------------------------------------------------------------------------------- | 496 | ----------------------------------------------------------------------------------- |
497 | foreign import ccall unsafe "hess_l_R" dgehrd :: TMVM R | 497 | foreign import ccall unsafe "hess_l_R" dgehrd :: R :> R ::> Ok |
498 | foreign import ccall unsafe "hess_l_C" zgehrd :: TMVM C | 498 | foreign import ccall unsafe "hess_l_C" zgehrd :: C :> C ::> Ok |
499 | 499 | ||
500 | -- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/. | 500 | -- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/. |
501 | hessR :: Matrix Double -> (Matrix Double, Vector Double) | 501 | hessR :: Matrix Double -> (Matrix Double, Vector Double) |
502 | hessR = hessAux dgehrd "hessR" . fmat | 502 | hessR = hessAux dgehrd "hessR" |
503 | 503 | ||
504 | -- | Hessenberg factorization of a square complex matrix, using LAPACK's /zgehrd/. | 504 | -- | Hessenberg factorization of a square complex matrix, using LAPACK's /zgehrd/. |
505 | hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) | 505 | hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) |
506 | hessC = hessAux zgehrd "hessC" . fmat | 506 | hessC = hessAux zgehrd "hessC" |
507 | 507 | ||
508 | hessAux f st a = unsafePerformIO $ do | 508 | hessAux f st a = unsafePerformIO $ do |
509 | r <- createMatrix ColumnMajor m n | 509 | r <- copy ColumnMajor a |
510 | tau <- createVector (mn-1) | 510 | tau <- createVector (mn-1) |
511 | f # a # tau # r #| st | 511 | f # tau # r #| st |
512 | return (r,tau) | 512 | return (r,tau) |
513 | where | 513 | where |
514 | m = rows a | 514 | m = rows a |
@@ -516,28 +516,28 @@ hessAux f st a = unsafePerformIO $ do | |||
516 | mn = min m n | 516 | mn = min m n |
517 | 517 | ||
518 | ----------------------------------------------------------------------------------- | 518 | ----------------------------------------------------------------------------------- |
519 | foreign import ccall unsafe "schur_l_R" dgees :: R ::> R ::> R ::> Ok | 519 | foreign import ccall unsafe "schur_l_R" dgees :: R ::> R ::> Ok |
520 | foreign import ccall unsafe "schur_l_C" zgees :: C ::> C ::> C ::> Ok | 520 | foreign import ccall unsafe "schur_l_C" zgees :: C ::> C ::> Ok |
521 | 521 | ||
522 | -- | Schur factorization of a square real matrix, using LAPACK's /dgees/. | 522 | -- | Schur factorization of a square real matrix, using LAPACK's /dgees/. |
523 | schurR :: Matrix Double -> (Matrix Double, Matrix Double) | 523 | schurR :: Matrix Double -> (Matrix Double, Matrix Double) |
524 | schurR = schurAux dgees "schurR" . fmat | 524 | schurR = schurAux dgees "schurR" |
525 | 525 | ||
526 | -- | Schur factorization of a square complex matrix, using LAPACK's /zgees/. | 526 | -- | Schur factorization of a square complex matrix, using LAPACK's /zgees/. |
527 | schurC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double)) | 527 | schurC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double)) |
528 | schurC = schurAux zgees "schurC" . fmat | 528 | schurC = schurAux zgees "schurC" |
529 | 529 | ||
530 | schurAux f st a = unsafePerformIO $ do | 530 | schurAux f st a = unsafePerformIO $ do |
531 | u <- createMatrix ColumnMajor n n | 531 | u <- createMatrix ColumnMajor n n |
532 | s <- createMatrix ColumnMajor n n | 532 | s <- copy ColumnMajor a |
533 | f # a # u # s #| st | 533 | f # u # s #| st |
534 | return (u,s) | 534 | return (u,s) |
535 | where | 535 | where |
536 | n = rows a | 536 | n = rows a |
537 | 537 | ||
538 | ----------------------------------------------------------------------------------- | 538 | ----------------------------------------------------------------------------------- |
539 | foreign import ccall unsafe "lu_l_R" dgetrf :: TMVM R | 539 | foreign import ccall unsafe "lu_l_R" dgetrf :: R :> R ::> Ok |
540 | foreign import ccall unsafe "lu_l_C" zgetrf :: C ::> R :> C ::> Ok | 540 | foreign import ccall unsafe "lu_l_C" zgetrf :: R :> C ::> Ok |
541 | 541 | ||
542 | -- | LU factorization of a general real matrix, using LAPACK's /dgetrf/. | 542 | -- | LU factorization of a general real matrix, using LAPACK's /dgetrf/. |
543 | luR :: Matrix Double -> (Matrix Double, [Int]) | 543 | luR :: Matrix Double -> (Matrix Double, [Int]) |
@@ -548,9 +548,9 @@ luC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int]) | |||
548 | luC = luAux zgetrf "luC" . fmat | 548 | luC = luAux zgetrf "luC" . fmat |
549 | 549 | ||
550 | luAux f st a = unsafePerformIO $ do | 550 | luAux f st a = unsafePerformIO $ do |
551 | lu <- createMatrix ColumnMajor n m | 551 | lu <- copy ColumnMajor a |
552 | piv <- createVector (min n m) | 552 | piv <- createVector (min n m) |
553 | f # a # piv # lu #| st | 553 | f # piv # lu #| st |
554 | return (lu, map (pred.round) (toList piv)) | 554 | return (lu, map (pred.round) (toList piv)) |
555 | where | 555 | where |
556 | n = rows a | 556 | n = rows a |
@@ -558,10 +558,8 @@ luAux f st a = unsafePerformIO $ do | |||
558 | 558 | ||
559 | ----------------------------------------------------------------------------------- | 559 | ----------------------------------------------------------------------------------- |
560 | 560 | ||
561 | type Tlus t = t ::> Double :> t ::> t ::> Ok | 561 | foreign import ccall unsafe "luS_l_R" dgetrs :: R ::> R :> R ::> Ok |
562 | 562 | foreign import ccall unsafe "luS_l_C" zgetrs :: C ::> R :> C ::> Ok | |
563 | foreign import ccall unsafe "luS_l_R" dgetrs :: Tlus R | ||
564 | foreign import ccall unsafe "luS_l_C" zgetrs :: Tlus C | ||
565 | 563 | ||
566 | -- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/. | 564 | -- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/. |
567 | lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double | 565 | lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double |
@@ -573,13 +571,13 @@ lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b) | |||
573 | 571 | ||
574 | lusAux f st a piv b | 572 | lusAux f st a piv b |
575 | | n1==n2 && n2==n =unsafePerformIO $ do | 573 | | n1==n2 && n2==n =unsafePerformIO $ do |
576 | x <- createMatrix ColumnMajor n m | 574 | x <- copy ColumnMajor b |
577 | f # a # piv' # b # x #| st | 575 | f # a # piv' # x #| st |
578 | return x | 576 | return x |
579 | | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" | 577 | | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" |
580 | where n1 = rows a | 578 | where |
581 | n2 = cols a | 579 | n1 = rows a |
582 | n = rows b | 580 | n2 = cols a |
583 | m = cols b | 581 | n = rows b |
584 | piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double | 582 | piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double |
585 | 583 | ||