diff options
Diffstat (limited to 'lib/Numeric/LinearAlgebra/Algorithms.hs')
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 171 |
1 files changed, 169 insertions, 2 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index bbc5986..c7118c1 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -20,6 +20,7 @@ imported from "Numeric.LinearAlgebra.LAPACK". | |||
20 | 20 | ||
21 | module Numeric.LinearAlgebra.Algorithms ( | 21 | module Numeric.LinearAlgebra.Algorithms ( |
22 | -- * Linear Systems | 22 | -- * Linear Systems |
23 | multiply, dot, | ||
23 | linearSolve, | 24 | linearSolve, |
24 | inv, pinv, | 25 | inv, pinv, |
25 | pinvTol, det, rank, rcond, | 26 | pinvTol, det, rank, rcond, |
@@ -51,6 +52,8 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
51 | -- * Misc | 52 | -- * Misc |
52 | ctrans, | 53 | ctrans, |
53 | eps, i, | 54 | eps, i, |
55 | outer, kronecker, | ||
56 | mulH, | ||
54 | -- * Util | 57 | -- * Util |
55 | haussholder, | 58 | haussholder, |
56 | unpackQR, unpackHess, | 59 | unpackQR, unpackHess, |
@@ -60,13 +63,14 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
60 | 63 | ||
61 | import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//)) | 64 | import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//)) |
62 | import Data.Packed | 65 | import Data.Packed |
63 | import qualified Numeric.GSL.Matrix as GSL | ||
64 | import Numeric.GSL.Vector | 66 | import Numeric.GSL.Vector |
65 | import Numeric.LinearAlgebra.LAPACK as LAPACK | 67 | import Numeric.LinearAlgebra.LAPACK as LAPACK |
66 | import Complex | 68 | import Complex |
67 | import Numeric.LinearAlgebra.Linear | 69 | import Numeric.LinearAlgebra.Linear |
68 | import Data.List(foldl1') | 70 | import Data.List(foldl1') |
69 | import Data.Array | 71 | import Data.Array |
72 | import Foreign | ||
73 | import Foreign.C.Types | ||
70 | 74 | ||
71 | -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. | 75 | -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. |
72 | class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where | 76 | class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where |
@@ -105,6 +109,7 @@ class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where | |||
105 | schur :: Matrix t -> (Matrix t, Matrix t) | 109 | schur :: Matrix t -> (Matrix t, Matrix t) |
106 | -- | Conjugate transpose. | 110 | -- | Conjugate transpose. |
107 | ctrans :: Matrix t -> Matrix t | 111 | ctrans :: Matrix t -> Matrix t |
112 | multiply :: Matrix t -> Matrix t -> Matrix t | ||
108 | 113 | ||
109 | 114 | ||
110 | instance Field Double where | 115 | instance Field Double where |
@@ -116,9 +121,10 @@ instance Field Double where | |||
116 | eig = eigR | 121 | eig = eigR |
117 | eigSH' = eigS | 122 | eigSH' = eigS |
118 | cholSH = cholS | 123 | cholSH = cholS |
119 | qr = GSL.unpackQR . qrR | 124 | qr = unpackQR . qrR |
120 | hess = unpackHess hessR | 125 | hess = unpackHess hessR |
121 | schur = schurR | 126 | schur = schurR |
127 | multiply = multiplyR3 | ||
122 | 128 | ||
123 | instance Field (Complex Double) where | 129 | instance Field (Complex Double) where |
124 | svd = svdC | 130 | svd = svdC |
@@ -132,6 +138,8 @@ instance Field (Complex Double) where | |||
132 | qr = unpackQR . qrC | 138 | qr = unpackQR . qrC |
133 | hess = unpackHess hessC | 139 | hess = unpackHess hessC |
134 | schur = schurC | 140 | schur = schurC |
141 | multiply = mulCW -- workaround | ||
142 | -- multiplyC3 | ||
135 | 143 | ||
136 | -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. | 144 | -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. |
137 | -- | 145 | -- |
@@ -501,3 +509,162 @@ luFact (lu,perm) | r <= c = (l ,u ,p, s) | |||
501 | u' = takeRows c (lu |*| tu) | 509 | u' = takeRows c (lu |*| tu) |
502 | (|+|) = add | 510 | (|+|) = add |
503 | (|*|) = mul | 511 | (|*|) = mul |
512 | |||
513 | -------------------------------------------------- | ||
514 | |||
515 | -- | euclidean inner product | ||
516 | dot :: (Field t) => Vector t -> Vector t -> t | ||
517 | dot u v = multiply r c @@> (0,0) | ||
518 | where r = asRow u | ||
519 | c = asColumn v | ||
520 | |||
521 | |||
522 | {- | Outer product of two vectors. | ||
523 | |||
524 | @\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] | ||
525 | (3><3) | ||
526 | [ 5.0, 2.0, 3.0 | ||
527 | , 10.0, 4.0, 6.0 | ||
528 | , 15.0, 6.0, 9.0 ]@ | ||
529 | -} | ||
530 | outer :: (Field t) => Vector t -> Vector t -> Matrix t | ||
531 | outer u v = asColumn u `multiply` asRow v | ||
532 | |||
533 | {- | Kronecker product of two matrices. | ||
534 | |||
535 | @m1=(2><3) | ||
536 | [ 1.0, 2.0, 0.0 | ||
537 | , 0.0, -1.0, 3.0 ] | ||
538 | m2=(4><3) | ||
539 | [ 1.0, 2.0, 3.0 | ||
540 | , 4.0, 5.0, 6.0 | ||
541 | , 7.0, 8.0, 9.0 | ||
542 | , 10.0, 11.0, 12.0 ]@ | ||
543 | |||
544 | @\> kronecker m1 m2 | ||
545 | (8><9) | ||
546 | [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0 | ||
547 | , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0 | ||
548 | , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0 | ||
549 | , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0 | ||
550 | , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0 | ||
551 | , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0 | ||
552 | , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0 | ||
553 | , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@ | ||
554 | -} | ||
555 | kronecker :: (Field t) => Matrix t -> Matrix t -> Matrix t | ||
556 | kronecker a b = fromBlocks | ||
557 | . partit (cols a) | ||
558 | . map (reshape (cols b)) | ||
559 | . toRows | ||
560 | $ flatten a `outer` flatten b | ||
561 | |||
562 | --------------------------------------------------------------------- | ||
563 | -- reference multiply | ||
564 | --------------------------------------------------------------------- | ||
565 | |||
566 | mulH a b = fromLists [[ dot ai bj | bj <- toColumns b] | ai <- toRows a ] | ||
567 | where dot u v = sum $ zipWith (*) (toList u) (toList v) | ||
568 | |||
569 | ----------------------------------------------------------------------------------- | ||
570 | -- workaround | ||
571 | ----------------------------------------------------------------------------------- | ||
572 | |||
573 | mulCW a b = toComplex (rr,ri) | ||
574 | where rr = multiply ar br `sub` multiply ai bi | ||
575 | ri = multiply ar bi `add` multiply ai br | ||
576 | (ar,ai) = fromComplex a | ||
577 | (br,bi) = fromComplex b | ||
578 | |||
579 | ----------------------------------------------------------------------------------- | ||
580 | -- Direct CBLAS | ||
581 | ----------------------------------------------------------------------------------- | ||
582 | |||
583 | newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show) | ||
584 | newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show) | ||
585 | |||
586 | rowMajor, colMajor :: CBLASOrder | ||
587 | rowMajor = CBLASOrder 101 | ||
588 | colMajor = CBLASOrder 102 | ||
589 | |||
590 | noTrans, trans', conjTrans :: CBLASTrans | ||
591 | noTrans = CBLASTrans 111 | ||
592 | trans' = CBLASTrans 112 | ||
593 | conjTrans = CBLASTrans 113 | ||
594 | |||
595 | foreign import ccall "cblas.h cblas_dgemm" | ||
596 | dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double -> Ptr Double -> CInt -> IO () | ||
597 | |||
598 | |||
599 | multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double | ||
600 | multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) | ||
601 | where | ||
602 | multiply3 f st a b | ||
603 | | cols a == rows b = unsafePerformIO $ do | ||
604 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
605 | let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0 | ||
606 | app3 g mat a mat b mat s st | ||
607 | return s | ||
608 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||
609 | |||
610 | |||
611 | foreign import ccall "cblas.h cblas_zgemm" | ||
612 | zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> IO () | ||
613 | |||
614 | multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
615 | multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b) | ||
616 | where | ||
617 | multiply3 f st a b | ||
618 | | cols a == rows b = do | ||
619 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
620 | palpha <- new 1 | ||
621 | pbeta <- new 0 | ||
622 | let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0 | ||
623 | app3 g mat a mat b mat s st | ||
624 | free palpha | ||
625 | free pbeta | ||
626 | return s | ||
627 | -- if toLists s== toLists s then return s else error $ "HORROR " ++ (show (toLists s)) | ||
628 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||
629 | |||
630 | ----------------------------------------------------------------------------------- | ||
631 | -- BLAS via auxiliary C | ||
632 | ----------------------------------------------------------------------------------- | ||
633 | |||
634 | foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM | ||
635 | foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM | ||
636 | |||
637 | multiply2 f st a b | ||
638 | | cols a == rows b = unsafePerformIO $ do | ||
639 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
640 | app3 f mat a mat b mat s st | ||
641 | if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s)) | ||
642 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||
643 | |||
644 | multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double | ||
645 | multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b) | ||
646 | |||
647 | multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
648 | multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b) | ||
649 | |||
650 | ----------------------------------------------------------------------------------- | ||
651 | -- direct C multiplication | ||
652 | ----------------------------------------------------------------------------------- | ||
653 | |||
654 | foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM | ||
655 | foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM | ||
656 | |||
657 | cmultiply f st a b | ||
658 | -- | cols a == rows b = | ||
659 | = unsafePerformIO $ do | ||
660 | s <- createMatrix RowMajor (rows a) (cols b) | ||
661 | app3 f mat a mat b mat s st | ||
662 | if toLists s== toLists s then return s else error $ "BRUTAL " ++ (show (toLists s)) | ||
663 | -- return s | ||
664 | -- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||
665 | |||
666 | multiplyR :: Matrix Double -> Matrix Double -> Matrix Double | ||
667 | multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b) | ||
668 | |||
669 | multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
670 | multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b) | ||