summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/LAPACK.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-28 20:04:02 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-28 20:04:02 +0200
commitc5ed204b8d6a36681c7ec6b227c634bfae501435 (patch)
tree1255c2ec0f276054a6908a468b569036947bb586 /packages/base/src/Internal/LAPACK.hs
parent79fa0200e1d5500f994d88e39d6fddff907a85f8 (diff)
pass copied slice (qr, hess,schur,lu)
Diffstat (limited to 'packages/base/src/Internal/LAPACK.hs')
-rw-r--r--packages/base/src/Internal/LAPACK.hs72
1 files changed, 35 insertions, 37 deletions
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index ce00c16..65deceb 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -455,29 +455,29 @@ mbCholS = unsafePerformIO . mbCatch . cholAux dpotrf "cholS"
455 455
456type TMVM t = t ::> t :> t ::> Ok 456type TMVM t = t ::> t :> t ::> Ok
457 457
458foreign import ccall unsafe "qr_l_R" dgeqr2 :: TMVM R 458foreign import ccall unsafe "qr_l_R" dgeqr2 :: R :> R ::> Ok
459foreign import ccall unsafe "qr_l_C" zgeqr2 :: TMVM C 459foreign import ccall unsafe "qr_l_C" zgeqr2 :: C :> C ::> Ok
460 460
461-- | QR factorization of a real matrix, using LAPACK's /dgeqr2/. 461-- | QR factorization of a real matrix, using LAPACK's /dgeqr2/.
462qrR :: Matrix Double -> (Matrix Double, Vector Double) 462qrR :: Matrix Double -> (Matrix Double, Vector Double)
463qrR = qrAux dgeqr2 "qrR" . fmat 463qrR = qrAux dgeqr2 "qrR"
464 464
465-- | QR factorization of a complex matrix, using LAPACK's /zgeqr2/. 465-- | QR factorization of a complex matrix, using LAPACK's /zgeqr2/.
466qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) 466qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double))
467qrC = qrAux zgeqr2 "qrC" . fmat 467qrC = qrAux zgeqr2 "qrC"
468 468
469qrAux f st a = unsafePerformIO $ do 469qrAux f st a = unsafePerformIO $ do
470 r <- createMatrix ColumnMajor m n 470 r <- copy ColumnMajor a
471 tau <- createVector mn 471 tau <- createVector mn
472 f # a # tau # r #| st 472 f # tau # r #| st
473 return (r,tau) 473 return (r,tau)
474 where 474 where
475 m = rows a 475 m = rows a
476 n = cols a 476 n = cols a
477 mn = min m n 477 mn = min m n
478 478
479foreign import ccall unsafe "c_dorgqr" dorgqr :: TMVM R 479foreign import ccall unsafe "c_dorgqr" dorgqr :: R :> R ::> Ok
480foreign import ccall unsafe "c_zungqr" zungqr :: TMVM C 480foreign import ccall unsafe "c_zungqr" zungqr :: C :> C ::> Ok
481 481
482-- | build rotation from reflectors 482-- | build rotation from reflectors
483qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double 483qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double
@@ -487,28 +487,28 @@ qrgrC :: Int -> (Matrix (Complex Double), Vector (Complex Double)) -> Matrix (Co
487qrgrC = qrgrAux zungqr "qrgrC" 487qrgrC = qrgrAux zungqr "qrgrC"
488 488
489qrgrAux f st n (a, tau) = unsafePerformIO $ do 489qrgrAux f st n (a, tau) = unsafePerformIO $ do
490 res <- createMatrix ColumnMajor (rows a) n 490 res <- copy ColumnMajor (sliceMatrix (0,0) (rows a,n) a)
491 f # (fmat a) # (subVector 0 n tau') # res #| st 491 f # (subVector 0 n tau') # res #| st
492 return res 492 return res
493 where 493 where
494 tau' = vjoin [tau, constantD 0 n] 494 tau' = vjoin [tau, constantD 0 n]
495 495
496----------------------------------------------------------------------------------- 496-----------------------------------------------------------------------------------
497foreign import ccall unsafe "hess_l_R" dgehrd :: TMVM R 497foreign import ccall unsafe "hess_l_R" dgehrd :: R :> R ::> Ok
498foreign import ccall unsafe "hess_l_C" zgehrd :: TMVM C 498foreign import ccall unsafe "hess_l_C" zgehrd :: C :> C ::> Ok
499 499
500-- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/. 500-- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/.
501hessR :: Matrix Double -> (Matrix Double, Vector Double) 501hessR :: Matrix Double -> (Matrix Double, Vector Double)
502hessR = hessAux dgehrd "hessR" . fmat 502hessR = hessAux dgehrd "hessR"
503 503
504-- | Hessenberg factorization of a square complex matrix, using LAPACK's /zgehrd/. 504-- | Hessenberg factorization of a square complex matrix, using LAPACK's /zgehrd/.
505hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) 505hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double))
506hessC = hessAux zgehrd "hessC" . fmat 506hessC = hessAux zgehrd "hessC"
507 507
508hessAux f st a = unsafePerformIO $ do 508hessAux f st a = unsafePerformIO $ do
509 r <- createMatrix ColumnMajor m n 509 r <- copy ColumnMajor a
510 tau <- createVector (mn-1) 510 tau <- createVector (mn-1)
511 f # a # tau # r #| st 511 f # tau # r #| st
512 return (r,tau) 512 return (r,tau)
513 where 513 where
514 m = rows a 514 m = rows a
@@ -516,28 +516,28 @@ hessAux f st a = unsafePerformIO $ do
516 mn = min m n 516 mn = min m n
517 517
518----------------------------------------------------------------------------------- 518-----------------------------------------------------------------------------------
519foreign import ccall unsafe "schur_l_R" dgees :: R ::> R ::> R ::> Ok 519foreign import ccall unsafe "schur_l_R" dgees :: R ::> R ::> Ok
520foreign import ccall unsafe "schur_l_C" zgees :: C ::> C ::> C ::> Ok 520foreign import ccall unsafe "schur_l_C" zgees :: C ::> C ::> Ok
521 521
522-- | Schur factorization of a square real matrix, using LAPACK's /dgees/. 522-- | Schur factorization of a square real matrix, using LAPACK's /dgees/.
523schurR :: Matrix Double -> (Matrix Double, Matrix Double) 523schurR :: Matrix Double -> (Matrix Double, Matrix Double)
524schurR = schurAux dgees "schurR" . fmat 524schurR = schurAux dgees "schurR"
525 525
526-- | Schur factorization of a square complex matrix, using LAPACK's /zgees/. 526-- | Schur factorization of a square complex matrix, using LAPACK's /zgees/.
527schurC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double)) 527schurC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double))
528schurC = schurAux zgees "schurC" . fmat 528schurC = schurAux zgees "schurC"
529 529
530schurAux f st a = unsafePerformIO $ do 530schurAux f st a = unsafePerformIO $ do
531 u <- createMatrix ColumnMajor n n 531 u <- createMatrix ColumnMajor n n
532 s <- createMatrix ColumnMajor n n 532 s <- copy ColumnMajor a
533 f # a # u # s #| st 533 f # u # s #| st
534 return (u,s) 534 return (u,s)
535 where 535 where
536 n = rows a 536 n = rows a
537 537
538----------------------------------------------------------------------------------- 538-----------------------------------------------------------------------------------
539foreign import ccall unsafe "lu_l_R" dgetrf :: TMVM R 539foreign import ccall unsafe "lu_l_R" dgetrf :: R :> R ::> Ok
540foreign import ccall unsafe "lu_l_C" zgetrf :: C ::> R :> C ::> Ok 540foreign import ccall unsafe "lu_l_C" zgetrf :: R :> C ::> Ok
541 541
542-- | LU factorization of a general real matrix, using LAPACK's /dgetrf/. 542-- | LU factorization of a general real matrix, using LAPACK's /dgetrf/.
543luR :: Matrix Double -> (Matrix Double, [Int]) 543luR :: Matrix Double -> (Matrix Double, [Int])
@@ -548,9 +548,9 @@ luC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int])
548luC = luAux zgetrf "luC" . fmat 548luC = luAux zgetrf "luC" . fmat
549 549
550luAux f st a = unsafePerformIO $ do 550luAux f st a = unsafePerformIO $ do
551 lu <- createMatrix ColumnMajor n m 551 lu <- copy ColumnMajor a
552 piv <- createVector (min n m) 552 piv <- createVector (min n m)
553 f # a # piv # lu #| st 553 f # piv # lu #| st
554 return (lu, map (pred.round) (toList piv)) 554 return (lu, map (pred.round) (toList piv))
555 where 555 where
556 n = rows a 556 n = rows a
@@ -558,10 +558,8 @@ luAux f st a = unsafePerformIO $ do
558 558
559----------------------------------------------------------------------------------- 559-----------------------------------------------------------------------------------
560 560
561type Tlus t = t ::> Double :> t ::> t ::> Ok 561foreign import ccall unsafe "luS_l_R" dgetrs :: R ::> R :> R ::> Ok
562 562foreign import ccall unsafe "luS_l_C" zgetrs :: C ::> R :> C ::> Ok
563foreign import ccall unsafe "luS_l_R" dgetrs :: Tlus R
564foreign import ccall unsafe "luS_l_C" zgetrs :: Tlus C
565 563
566-- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/. 564-- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/.
567lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double 565lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double
@@ -573,13 +571,13 @@ lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b)
573 571
574lusAux f st a piv b 572lusAux f st a piv b
575 | n1==n2 && n2==n =unsafePerformIO $ do 573 | n1==n2 && n2==n =unsafePerformIO $ do
576 x <- createMatrix ColumnMajor n m 574 x <- copy ColumnMajor b
577 f # a # piv' # b # x #| st 575 f # a # piv' # x #| st
578 return x 576 return x
579 | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" 577 | otherwise = error $ st ++ " on LU factorization of nonsquare matrix"
580 where n1 = rows a 578 where
581 n2 = cols a 579 n1 = rows a
582 n = rows b 580 n2 = cols a
583 m = cols b 581 n = rows b
584 piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double 582 piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double
585 583