summaryrefslogtreecommitdiff
path: root/lib/Numeric/LinearAlgebra/Algorithms.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric/LinearAlgebra/Algorithms.hs')
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs171
1 files changed, 169 insertions, 2 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index bbc5986..c7118c1 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -20,6 +20,7 @@ imported from "Numeric.LinearAlgebra.LAPACK".
20 20
21module Numeric.LinearAlgebra.Algorithms ( 21module Numeric.LinearAlgebra.Algorithms (
22-- * Linear Systems 22-- * Linear Systems
23 multiply, dot,
23 linearSolve, 24 linearSolve,
24 inv, pinv, 25 inv, pinv,
25 pinvTol, det, rank, rcond, 26 pinvTol, det, rank, rcond,
@@ -51,6 +52,8 @@ module Numeric.LinearAlgebra.Algorithms (
51-- * Misc 52-- * Misc
52 ctrans, 53 ctrans,
53 eps, i, 54 eps, i,
55 outer, kronecker,
56 mulH,
54-- * Util 57-- * Util
55 haussholder, 58 haussholder,
56 unpackQR, unpackHess, 59 unpackQR, unpackHess,
@@ -60,13 +63,14 @@ module Numeric.LinearAlgebra.Algorithms (
60 63
61import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//)) 64import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//))
62import Data.Packed 65import Data.Packed
63import qualified Numeric.GSL.Matrix as GSL
64import Numeric.GSL.Vector 66import Numeric.GSL.Vector
65import Numeric.LinearAlgebra.LAPACK as LAPACK 67import Numeric.LinearAlgebra.LAPACK as LAPACK
66import Complex 68import Complex
67import Numeric.LinearAlgebra.Linear 69import Numeric.LinearAlgebra.Linear
68import Data.List(foldl1') 70import Data.List(foldl1')
69import Data.Array 71import Data.Array
72import Foreign
73import Foreign.C.Types
70 74
71-- | Auxiliary typeclass used to define generic computations for both real and complex matrices. 75-- | Auxiliary typeclass used to define generic computations for both real and complex matrices.
72class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where 76class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where
@@ -105,6 +109,7 @@ class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where
105 schur :: Matrix t -> (Matrix t, Matrix t) 109 schur :: Matrix t -> (Matrix t, Matrix t)
106 -- | Conjugate transpose. 110 -- | Conjugate transpose.
107 ctrans :: Matrix t -> Matrix t 111 ctrans :: Matrix t -> Matrix t
112 multiply :: Matrix t -> Matrix t -> Matrix t
108 113
109 114
110instance Field Double where 115instance Field Double where
@@ -116,9 +121,10 @@ instance Field Double where
116 eig = eigR 121 eig = eigR
117 eigSH' = eigS 122 eigSH' = eigS
118 cholSH = cholS 123 cholSH = cholS
119 qr = GSL.unpackQR . qrR 124 qr = unpackQR . qrR
120 hess = unpackHess hessR 125 hess = unpackHess hessR
121 schur = schurR 126 schur = schurR
127 multiply = multiplyR3
122 128
123instance Field (Complex Double) where 129instance Field (Complex Double) where
124 svd = svdC 130 svd = svdC
@@ -132,6 +138,8 @@ instance Field (Complex Double) where
132 qr = unpackQR . qrC 138 qr = unpackQR . qrC
133 hess = unpackHess hessC 139 hess = unpackHess hessC
134 schur = schurC 140 schur = schurC
141 multiply = mulCW -- workaround
142 -- multiplyC3
135 143
136-- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. 144-- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev.
137-- 145--
@@ -501,3 +509,162 @@ luFact (lu,perm) | r <= c = (l ,u ,p, s)
501 u' = takeRows c (lu |*| tu) 509 u' = takeRows c (lu |*| tu)
502 (|+|) = add 510 (|+|) = add
503 (|*|) = mul 511 (|*|) = mul
512
513--------------------------------------------------
514
515-- | euclidean inner product
516dot :: (Field t) => Vector t -> Vector t -> t
517dot u v = multiply r c @@> (0,0)
518 where r = asRow u
519 c = asColumn v
520
521
522{- | Outer product of two vectors.
523
524@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3]
525(3><3)
526 [ 5.0, 2.0, 3.0
527 , 10.0, 4.0, 6.0
528 , 15.0, 6.0, 9.0 ]@
529-}
530outer :: (Field t) => Vector t -> Vector t -> Matrix t
531outer u v = asColumn u `multiply` asRow v
532
533{- | Kronecker product of two matrices.
534
535@m1=(2><3)
536 [ 1.0, 2.0, 0.0
537 , 0.0, -1.0, 3.0 ]
538m2=(4><3)
539 [ 1.0, 2.0, 3.0
540 , 4.0, 5.0, 6.0
541 , 7.0, 8.0, 9.0
542 , 10.0, 11.0, 12.0 ]@
543
544@\> kronecker m1 m2
545(8><9)
546 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
547 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
548 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
549 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
550 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
551 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
552 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
553 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@
554-}
555kronecker :: (Field t) => Matrix t -> Matrix t -> Matrix t
556kronecker a b = fromBlocks
557 . partit (cols a)
558 . map (reshape (cols b))
559 . toRows
560 $ flatten a `outer` flatten b
561
562---------------------------------------------------------------------
563-- reference multiply
564---------------------------------------------------------------------
565
566mulH a b = fromLists [[ dot ai bj | bj <- toColumns b] | ai <- toRows a ]
567 where dot u v = sum $ zipWith (*) (toList u) (toList v)
568
569-----------------------------------------------------------------------------------
570-- workaround
571-----------------------------------------------------------------------------------
572
573mulCW a b = toComplex (rr,ri)
574 where rr = multiply ar br `sub` multiply ai bi
575 ri = multiply ar bi `add` multiply ai br
576 (ar,ai) = fromComplex a
577 (br,bi) = fromComplex b
578
579-----------------------------------------------------------------------------------
580-- Direct CBLAS
581-----------------------------------------------------------------------------------
582
583newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show)
584newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show)
585
586rowMajor, colMajor :: CBLASOrder
587rowMajor = CBLASOrder 101
588colMajor = CBLASOrder 102
589
590noTrans, trans', conjTrans :: CBLASTrans
591noTrans = CBLASTrans 111
592trans' = CBLASTrans 112
593conjTrans = CBLASTrans 113
594
595foreign import ccall "cblas.h cblas_dgemm"
596 dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double -> Ptr Double -> CInt -> IO ()
597
598
599multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double
600multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b)
601 where
602 multiply3 f st a b
603 | cols a == rows b = unsafePerformIO $ do
604 s <- createMatrix ColumnMajor (rows a) (cols b)
605 let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0
606 app3 g mat a mat b mat s st
607 return s
608 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
609
610
611foreign import ccall "cblas.h cblas_zgemm"
612 zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> IO ()
613
614multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
615multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b)
616 where
617 multiply3 f st a b
618 | cols a == rows b = do
619 s <- createMatrix ColumnMajor (rows a) (cols b)
620 palpha <- new 1
621 pbeta <- new 0
622 let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0
623 app3 g mat a mat b mat s st
624 free palpha
625 free pbeta
626 return s
627 -- if toLists s== toLists s then return s else error $ "HORROR " ++ (show (toLists s))
628 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
629
630-----------------------------------------------------------------------------------
631-- BLAS via auxiliary C
632-----------------------------------------------------------------------------------
633
634foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM
635foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM
636
637multiply2 f st a b
638 | cols a == rows b = unsafePerformIO $ do
639 s <- createMatrix ColumnMajor (rows a) (cols b)
640 app3 f mat a mat b mat s st
641 if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s))
642 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
643
644multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double
645multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b)
646
647multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
648multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b)
649
650-----------------------------------------------------------------------------------
651-- direct C multiplication
652-----------------------------------------------------------------------------------
653
654foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM
655foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM
656
657cmultiply f st a b
658-- | cols a == rows b =
659 = unsafePerformIO $ do
660 s <- createMatrix RowMajor (rows a) (cols b)
661 app3 f mat a mat b mat s st
662 if toLists s== toLists s then return s else error $ "BRUTAL " ++ (show (toLists s))
663 -- return s
664-- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
665
666multiplyR :: Matrix Double -> Matrix Double -> Matrix Double
667multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b)
668
669multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
670multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b)