summaryrefslogtreecommitdiff
path: root/lib/Numeric/LinearAlgebra
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric/LinearAlgebra')
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs171
-rw-r--r--lib/Numeric/LinearAlgebra/Interface.hs4
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c74
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h6
-rw-r--r--lib/Numeric/LinearAlgebra/Linear.hs54
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs5
-rw-r--r--lib/Numeric/LinearAlgebra/Tests/Instances.hs16
-rw-r--r--lib/Numeric/LinearAlgebra/Tests/Properties.hs6
8 files changed, 279 insertions, 57 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index bbc5986..c7118c1 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -20,6 +20,7 @@ imported from "Numeric.LinearAlgebra.LAPACK".
20 20
21module Numeric.LinearAlgebra.Algorithms ( 21module Numeric.LinearAlgebra.Algorithms (
22-- * Linear Systems 22-- * Linear Systems
23 multiply, dot,
23 linearSolve, 24 linearSolve,
24 inv, pinv, 25 inv, pinv,
25 pinvTol, det, rank, rcond, 26 pinvTol, det, rank, rcond,
@@ -51,6 +52,8 @@ module Numeric.LinearAlgebra.Algorithms (
51-- * Misc 52-- * Misc
52 ctrans, 53 ctrans,
53 eps, i, 54 eps, i,
55 outer, kronecker,
56 mulH,
54-- * Util 57-- * Util
55 haussholder, 58 haussholder,
56 unpackQR, unpackHess, 59 unpackQR, unpackHess,
@@ -60,13 +63,14 @@ module Numeric.LinearAlgebra.Algorithms (
60 63
61import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//)) 64import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//))
62import Data.Packed 65import Data.Packed
63import qualified Numeric.GSL.Matrix as GSL
64import Numeric.GSL.Vector 66import Numeric.GSL.Vector
65import Numeric.LinearAlgebra.LAPACK as LAPACK 67import Numeric.LinearAlgebra.LAPACK as LAPACK
66import Complex 68import Complex
67import Numeric.LinearAlgebra.Linear 69import Numeric.LinearAlgebra.Linear
68import Data.List(foldl1') 70import Data.List(foldl1')
69import Data.Array 71import Data.Array
72import Foreign
73import Foreign.C.Types
70 74
71-- | Auxiliary typeclass used to define generic computations for both real and complex matrices. 75-- | Auxiliary typeclass used to define generic computations for both real and complex matrices.
72class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where 76class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where
@@ -105,6 +109,7 @@ class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where
105 schur :: Matrix t -> (Matrix t, Matrix t) 109 schur :: Matrix t -> (Matrix t, Matrix t)
106 -- | Conjugate transpose. 110 -- | Conjugate transpose.
107 ctrans :: Matrix t -> Matrix t 111 ctrans :: Matrix t -> Matrix t
112 multiply :: Matrix t -> Matrix t -> Matrix t
108 113
109 114
110instance Field Double where 115instance Field Double where
@@ -116,9 +121,10 @@ instance Field Double where
116 eig = eigR 121 eig = eigR
117 eigSH' = eigS 122 eigSH' = eigS
118 cholSH = cholS 123 cholSH = cholS
119 qr = GSL.unpackQR . qrR 124 qr = unpackQR . qrR
120 hess = unpackHess hessR 125 hess = unpackHess hessR
121 schur = schurR 126 schur = schurR
127 multiply = multiplyR3
122 128
123instance Field (Complex Double) where 129instance Field (Complex Double) where
124 svd = svdC 130 svd = svdC
@@ -132,6 +138,8 @@ instance Field (Complex Double) where
132 qr = unpackQR . qrC 138 qr = unpackQR . qrC
133 hess = unpackHess hessC 139 hess = unpackHess hessC
134 schur = schurC 140 schur = schurC
141 multiply = mulCW -- workaround
142 -- multiplyC3
135 143
136-- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. 144-- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev.
137-- 145--
@@ -501,3 +509,162 @@ luFact (lu,perm) | r <= c = (l ,u ,p, s)
501 u' = takeRows c (lu |*| tu) 509 u' = takeRows c (lu |*| tu)
502 (|+|) = add 510 (|+|) = add
503 (|*|) = mul 511 (|*|) = mul
512
513--------------------------------------------------
514
515-- | euclidean inner product
516dot :: (Field t) => Vector t -> Vector t -> t
517dot u v = multiply r c @@> (0,0)
518 where r = asRow u
519 c = asColumn v
520
521
522{- | Outer product of two vectors.
523
524@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3]
525(3><3)
526 [ 5.0, 2.0, 3.0
527 , 10.0, 4.0, 6.0
528 , 15.0, 6.0, 9.0 ]@
529-}
530outer :: (Field t) => Vector t -> Vector t -> Matrix t
531outer u v = asColumn u `multiply` asRow v
532
533{- | Kronecker product of two matrices.
534
535@m1=(2><3)
536 [ 1.0, 2.0, 0.0
537 , 0.0, -1.0, 3.0 ]
538m2=(4><3)
539 [ 1.0, 2.0, 3.0
540 , 4.0, 5.0, 6.0
541 , 7.0, 8.0, 9.0
542 , 10.0, 11.0, 12.0 ]@
543
544@\> kronecker m1 m2
545(8><9)
546 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
547 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
548 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
549 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
550 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
551 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
552 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
553 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@
554-}
555kronecker :: (Field t) => Matrix t -> Matrix t -> Matrix t
556kronecker a b = fromBlocks
557 . partit (cols a)
558 . map (reshape (cols b))
559 . toRows
560 $ flatten a `outer` flatten b
561
562---------------------------------------------------------------------
563-- reference multiply
564---------------------------------------------------------------------
565
566mulH a b = fromLists [[ dot ai bj | bj <- toColumns b] | ai <- toRows a ]
567 where dot u v = sum $ zipWith (*) (toList u) (toList v)
568
569-----------------------------------------------------------------------------------
570-- workaround
571-----------------------------------------------------------------------------------
572
573mulCW a b = toComplex (rr,ri)
574 where rr = multiply ar br `sub` multiply ai bi
575 ri = multiply ar bi `add` multiply ai br
576 (ar,ai) = fromComplex a
577 (br,bi) = fromComplex b
578
579-----------------------------------------------------------------------------------
580-- Direct CBLAS
581-----------------------------------------------------------------------------------
582
583newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show)
584newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show)
585
586rowMajor, colMajor :: CBLASOrder
587rowMajor = CBLASOrder 101
588colMajor = CBLASOrder 102
589
590noTrans, trans', conjTrans :: CBLASTrans
591noTrans = CBLASTrans 111
592trans' = CBLASTrans 112
593conjTrans = CBLASTrans 113
594
595foreign import ccall "cblas.h cblas_dgemm"
596 dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double -> Ptr Double -> CInt -> IO ()
597
598
599multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double
600multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b)
601 where
602 multiply3 f st a b
603 | cols a == rows b = unsafePerformIO $ do
604 s <- createMatrix ColumnMajor (rows a) (cols b)
605 let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0
606 app3 g mat a mat b mat s st
607 return s
608 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
609
610
611foreign import ccall "cblas.h cblas_zgemm"
612 zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> IO ()
613
614multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
615multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b)
616 where
617 multiply3 f st a b
618 | cols a == rows b = do
619 s <- createMatrix ColumnMajor (rows a) (cols b)
620 palpha <- new 1
621 pbeta <- new 0
622 let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0
623 app3 g mat a mat b mat s st
624 free palpha
625 free pbeta
626 return s
627 -- if toLists s== toLists s then return s else error $ "HORROR " ++ (show (toLists s))
628 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
629
630-----------------------------------------------------------------------------------
631-- BLAS via auxiliary C
632-----------------------------------------------------------------------------------
633
634foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM
635foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM
636
637multiply2 f st a b
638 | cols a == rows b = unsafePerformIO $ do
639 s <- createMatrix ColumnMajor (rows a) (cols b)
640 app3 f mat a mat b mat s st
641 if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s))
642 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
643
644multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double
645multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b)
646
647multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
648multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b)
649
650-----------------------------------------------------------------------------------
651-- direct C multiplication
652-----------------------------------------------------------------------------------
653
654foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM
655foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM
656
657cmultiply f st a b
658-- | cols a == rows b =
659 = unsafePerformIO $ do
660 s <- createMatrix RowMajor (rows a) (cols b)
661 app3 f mat a mat b mat s st
662 if toLists s== toLists s then return s else error $ "BRUTAL " ++ (show (toLists s))
663 -- return s
664-- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
665
666multiplyR :: Matrix Double -> Matrix Double -> Matrix Double
667multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b)
668
669multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
670multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b)
diff --git a/lib/Numeric/LinearAlgebra/Interface.hs b/lib/Numeric/LinearAlgebra/Interface.hs
index 4a9b309..0ae9698 100644
--- a/lib/Numeric/LinearAlgebra/Interface.hs
+++ b/lib/Numeric/LinearAlgebra/Interface.hs
@@ -29,7 +29,7 @@ import Numeric.LinearAlgebra.Algorithms
29class Mul a b c | a b -> c where 29class Mul a b c | a b -> c where
30 infixl 7 <> 30 infixl 7 <>
31 -- | matrix product 31 -- | matrix product
32 (<>) :: Element t => a t -> b t -> c t 32 (<>) :: Field t => a t -> b t -> c t
33 33
34instance Mul Matrix Matrix Matrix where 34instance Mul Matrix Matrix Matrix where
35 (<>) = multiply 35 (<>) = multiply
@@ -43,7 +43,7 @@ instance Mul Vector Matrix Vector where
43--------------------------------------------------- 43---------------------------------------------------
44 44
45-- | @u \<.\> v = dot u v@ 45-- | @u \<.\> v = dot u v@
46(<.>) :: (Element t) => Vector t -> Vector t -> t 46(<.>) :: (Field t) => Vector t -> Vector t -> t
47infixl 7 <.> 47infixl 7 <.>
48(<.>) = dot 48(<.>) = dot
49 49
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
index 310f6ee..0dccea2 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
@@ -814,3 +814,77 @@ int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) {
814 free(auxipiv); 814 free(auxipiv);
815 OK 815 OK
816} 816}
817
818////////////////////////////////////////////////////////////
819
820int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)) {
821 REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
822 int i,j,k;
823 for (i=0;i<ar;i++) {
824 for(j=0;j<bc;j++) {
825 double temp = 0;
826 for(k=0;k<ac;k++) {
827 temp += ap[i*ac+k]*bp[k*bc+j];
828 }
829 rp[i*rc+j] = temp;
830 }
831 }
832 OK
833}
834
835int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)) {
836 REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
837 int i,j,k;
838 for (i=0;i<ar;i++) {
839 for(j=0;j<bc;j++) {
840 doublecomplex temp = {0,0};
841 for(k=0;k<ac;k++) {
842 doublecomplex aik = ((doublecomplex*)ap)[i*ac+k];
843 doublecomplex bkj = ((doublecomplex*)bp)[k*bc+j];
844 //double w = aik.r+aik.i+bkj.r+bkj.i;
845 //if (w>w) exit(1);
846 //printf("%d",w>w);
847 temp.r += aik.r * bkj.r - aik.i * bkj.i;
848 temp.i += aik.r * bkj.i + aik.i * bkj.r;
849 //printf("%f %f %f %f \n",aik.r,aik.i,bkj.r,bkj.i);
850 //printf("%f %f %f \n",w,temp.r,temp.i);
851
852 }
853 ((doublecomplex*)rp)[i*rc+j] = temp;
854 //printf("%f %f\n",temp.r,temp.i);
855 }
856 }
857 OK
858}
859
860void dgemm_(char *, char *, integer *, integer *, integer *,
861 double *, const double *, integer *, const double *,
862 integer *, double *, double *, integer *);
863
864void zgemm_(char *, char *, integer *, integer *, integer *,
865 doublecomplex *, const doublecomplex *, integer *, const doublecomplex *,
866 integer *, doublecomplex *, doublecomplex *, integer *);
867
868
869int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)) {
870 REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
871 double alpha = 1;
872 double beta = 0;
873 integer m = ar;
874 integer n = bc;
875 integer k = ac;
876 int i,j;
877 dgemm_("N","N",&m,&n,&k,&alpha,ap,&m,bp,&k,&beta,rp,&m);
878 OK
879}
880
881int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)) {
882 REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
883 integer m = ar;
884 integer n = bc;
885 integer k = ac;
886 doublecomplex alpha = {1,0};
887 doublecomplex beta = {0,0};
888 zgemm_("N","N",&m,&n,&k,&alpha,(doublecomplex*)ap,&m,(doublecomplex*)bp,&k,&beta,(doublecomplex*)rp,&m);
889 OK
890}
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
index 79e52be..c0361a6 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
@@ -84,3 +84,9 @@ int schur_l_C(KCMAT(a), CMAT(u), CMAT(s));
84 84
85int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)); 85int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r));
86int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)); 86int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r));
87
88int multiplyR(KDMAT(a),KDMAT(b),DMAT(r));
89int multiplyC(KCMAT(a),KCMAT(b),CMAT(r));
90
91int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r));
92int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r));
diff --git a/lib/Numeric/LinearAlgebra/Linear.hs b/lib/Numeric/LinearAlgebra/Linear.hs
index 0ddbb55..1bf8b04 100644
--- a/lib/Numeric/LinearAlgebra/Linear.hs
+++ b/lib/Numeric/LinearAlgebra/Linear.hs
@@ -15,12 +15,11 @@ Basic optimized operations on vectors and matrices.
15----------------------------------------------------------------------------- 15-----------------------------------------------------------------------------
16 16
17module Numeric.LinearAlgebra.Linear ( 17module Numeric.LinearAlgebra.Linear (
18 Linear(..), 18 Linear(..)
19 multiply, dot, outer, kronecker
20) where 19) where
21 20
22 21
23import Data.Packed.Internal(multiply,partit) 22import Data.Packed.Internal(partit)
24import Data.Packed 23import Data.Packed
25import Numeric.GSL.Vector 24import Numeric.GSL.Vector
26import Complex 25import Complex
@@ -69,52 +68,3 @@ instance (Linear Vector a, Container Matrix a) => (Linear Matrix a) where
69 mul = liftMatrix2 mul 68 mul = liftMatrix2 mul
70 divide = liftMatrix2 divide 69 divide = liftMatrix2 divide
71 equal a b = cols a == cols b && flatten a `equal` flatten b 70 equal a b = cols a == cols b && flatten a `equal` flatten b
72
73--------------------------------------------------
74
75-- | euclidean inner product
76dot :: (Element t) => Vector t -> Vector t -> t
77dot u v = multiply r c @@> (0,0)
78 where r = asRow u
79 c = asColumn v
80
81
82{- | Outer product of two vectors.
83
84@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3]
85(3><3)
86 [ 5.0, 2.0, 3.0
87 , 10.0, 4.0, 6.0
88 , 15.0, 6.0, 9.0 ]@
89-}
90outer :: (Element t) => Vector t -> Vector t -> Matrix t
91outer u v = asColumn u `multiply` asRow v
92
93{- | Kronecker product of two matrices.
94
95@m1=(2><3)
96 [ 1.0, 2.0, 0.0
97 , 0.0, -1.0, 3.0 ]
98m2=(4><3)
99 [ 1.0, 2.0, 3.0
100 , 4.0, 5.0, 6.0
101 , 7.0, 8.0, 9.0
102 , 10.0, 11.0, 12.0 ]@
103
104@\> kronecker m1 m2
105(8><9)
106 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
107 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
108 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
109 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
110 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
111 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
112 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
113 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@
114-}
115kronecker :: (Element t) => Matrix t -> Matrix t -> Matrix t
116kronecker a b = fromBlocks
117 . partit (cols a)
118 . map (reshape (cols b))
119 . toRows
120 $ flatten a `outer` flatten b
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index 7b28075..07b9f63 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -123,6 +123,11 @@ runTests :: Int -- ^ maximum dimension
123runTests n = do 123runTests n = do
124 setErrorHandlerOff 124 setErrorHandlerOff
125 let test p = qCheck n p 125 let test p = qCheck n p
126 putStrLn "------ mult"
127 test (multProp1 . rConsist)
128 test (multProp1 . cConsist)
129 test (multProp2 . rConsist)
130 test (multProp2 . cConsist)
126 putStrLn "------ lu" 131 putStrLn "------ lu"
127 test (luProp . rM) 132 test (luProp . rM)
128 test (luProp . cM) 133 test (luProp . cM)
diff --git a/lib/Numeric/LinearAlgebra/Tests/Instances.hs b/lib/Numeric/LinearAlgebra/Tests/Instances.hs
index af486c8..e7fecf2 100644
--- a/lib/Numeric/LinearAlgebra/Tests/Instances.hs
+++ b/lib/Numeric/LinearAlgebra/Tests/Instances.hs
@@ -20,6 +20,7 @@ module Numeric.LinearAlgebra.Tests.Instances(
20 WC(..), rWC,cWC, 20 WC(..), rWC,cWC,
21 SqWC(..), rSqWC, cSqWC, 21 SqWC(..), rSqWC, cSqWC,
22 PosDef(..), rPosDef, cPosDef, 22 PosDef(..), rPosDef, cPosDef,
23 Consistent(..), rConsist, cConsist,
23 RM,CM, rM,cM 24 RM,CM, rM,cM
24) where 25) where
25 26
@@ -116,6 +117,19 @@ instance (Field a, Arbitrary a) => Arbitrary (PosDef a) where
116 return $ PosDef (0.5 .* p + 0.5 .* ctrans p) 117 return $ PosDef (0.5 .* p + 0.5 .* ctrans p)
117 coarbitrary = undefined 118 coarbitrary = undefined
118 119
120-- a pair of matrices that can be multiplied
121newtype (Consistent a) = Consistent (Matrix a, Matrix a) deriving Show
122instance (Field a, Arbitrary a) => Arbitrary (Consistent a) where
123 arbitrary = do
124 n <- chooseDim
125 k <- chooseDim
126 m <- chooseDim
127 la <- vector (n*k)
128 lb <- vector (k*m)
129 return $ Consistent ((n><k) la, (k><m) lb)
130 coarbitrary = undefined
131
132
119type RM = Matrix Double 133type RM = Matrix Double
120type CM = Matrix (Complex Double) 134type CM = Matrix (Complex Double)
121 135
@@ -140,3 +154,5 @@ cSqWC (SqWC m) = m :: CM
140rPosDef (PosDef m) = m :: RM 154rPosDef (PosDef m) = m :: RM
141cPosDef (PosDef m) = m :: CM 155cPosDef (PosDef m) = m :: CM
142 156
157rConsist (Consistent (a,b)) = (a,b::RM)
158cConsist (Consistent (a,b)) = (a,b::CM)
diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
index 55e9a1b..5663b86 100644
--- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs
+++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
@@ -34,7 +34,8 @@ module Numeric.LinearAlgebra.Tests.Properties (
34 hessProp, 34 hessProp,
35 schurProp1, schurProp2, 35 schurProp1, schurProp2,
36 cholProp, 36 cholProp,
37 expmDiagProp 37 expmDiagProp,
38 multProp1, multProp2
38) where 39) where
39 40
40import Numeric.LinearAlgebra 41import Numeric.LinearAlgebra
@@ -151,3 +152,6 @@ cholProp m = m |~| ctrans c <> c && upperTriang c
151expmDiagProp m = expm (logm m) :~ 7 ~: complex m 152expmDiagProp m = expm (logm m) :~ 7 ~: complex m
152 where logm m = matFunc log m 153 where logm m = matFunc log m
153 154
155multProp1 (a,b) = a <> b |~| mulH a b
156
157multProp2 (a,b) = trans (a <> b) |~| trans b <> trans a