diff options
Diffstat (limited to 'lib/Numeric/LinearAlgebra')
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 171 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Interface.hs | 4 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 74 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 6 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Linear.hs | 54 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests.hs | 5 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests/Instances.hs | 16 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests/Properties.hs | 6 |
8 files changed, 279 insertions, 57 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index bbc5986..c7118c1 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -20,6 +20,7 @@ imported from "Numeric.LinearAlgebra.LAPACK". | |||
20 | 20 | ||
21 | module Numeric.LinearAlgebra.Algorithms ( | 21 | module Numeric.LinearAlgebra.Algorithms ( |
22 | -- * Linear Systems | 22 | -- * Linear Systems |
23 | multiply, dot, | ||
23 | linearSolve, | 24 | linearSolve, |
24 | inv, pinv, | 25 | inv, pinv, |
25 | pinvTol, det, rank, rcond, | 26 | pinvTol, det, rank, rcond, |
@@ -51,6 +52,8 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
51 | -- * Misc | 52 | -- * Misc |
52 | ctrans, | 53 | ctrans, |
53 | eps, i, | 54 | eps, i, |
55 | outer, kronecker, | ||
56 | mulH, | ||
54 | -- * Util | 57 | -- * Util |
55 | haussholder, | 58 | haussholder, |
56 | unpackQR, unpackHess, | 59 | unpackQR, unpackHess, |
@@ -60,13 +63,14 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
60 | 63 | ||
61 | import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//)) | 64 | import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//)) |
62 | import Data.Packed | 65 | import Data.Packed |
63 | import qualified Numeric.GSL.Matrix as GSL | ||
64 | import Numeric.GSL.Vector | 66 | import Numeric.GSL.Vector |
65 | import Numeric.LinearAlgebra.LAPACK as LAPACK | 67 | import Numeric.LinearAlgebra.LAPACK as LAPACK |
66 | import Complex | 68 | import Complex |
67 | import Numeric.LinearAlgebra.Linear | 69 | import Numeric.LinearAlgebra.Linear |
68 | import Data.List(foldl1') | 70 | import Data.List(foldl1') |
69 | import Data.Array | 71 | import Data.Array |
72 | import Foreign | ||
73 | import Foreign.C.Types | ||
70 | 74 | ||
71 | -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. | 75 | -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. |
72 | class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where | 76 | class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where |
@@ -105,6 +109,7 @@ class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where | |||
105 | schur :: Matrix t -> (Matrix t, Matrix t) | 109 | schur :: Matrix t -> (Matrix t, Matrix t) |
106 | -- | Conjugate transpose. | 110 | -- | Conjugate transpose. |
107 | ctrans :: Matrix t -> Matrix t | 111 | ctrans :: Matrix t -> Matrix t |
112 | multiply :: Matrix t -> Matrix t -> Matrix t | ||
108 | 113 | ||
109 | 114 | ||
110 | instance Field Double where | 115 | instance Field Double where |
@@ -116,9 +121,10 @@ instance Field Double where | |||
116 | eig = eigR | 121 | eig = eigR |
117 | eigSH' = eigS | 122 | eigSH' = eigS |
118 | cholSH = cholS | 123 | cholSH = cholS |
119 | qr = GSL.unpackQR . qrR | 124 | qr = unpackQR . qrR |
120 | hess = unpackHess hessR | 125 | hess = unpackHess hessR |
121 | schur = schurR | 126 | schur = schurR |
127 | multiply = multiplyR3 | ||
122 | 128 | ||
123 | instance Field (Complex Double) where | 129 | instance Field (Complex Double) where |
124 | svd = svdC | 130 | svd = svdC |
@@ -132,6 +138,8 @@ instance Field (Complex Double) where | |||
132 | qr = unpackQR . qrC | 138 | qr = unpackQR . qrC |
133 | hess = unpackHess hessC | 139 | hess = unpackHess hessC |
134 | schur = schurC | 140 | schur = schurC |
141 | multiply = mulCW -- workaround | ||
142 | -- multiplyC3 | ||
135 | 143 | ||
136 | -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. | 144 | -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. |
137 | -- | 145 | -- |
@@ -501,3 +509,162 @@ luFact (lu,perm) | r <= c = (l ,u ,p, s) | |||
501 | u' = takeRows c (lu |*| tu) | 509 | u' = takeRows c (lu |*| tu) |
502 | (|+|) = add | 510 | (|+|) = add |
503 | (|*|) = mul | 511 | (|*|) = mul |
512 | |||
513 | -------------------------------------------------- | ||
514 | |||
515 | -- | euclidean inner product | ||
516 | dot :: (Field t) => Vector t -> Vector t -> t | ||
517 | dot u v = multiply r c @@> (0,0) | ||
518 | where r = asRow u | ||
519 | c = asColumn v | ||
520 | |||
521 | |||
522 | {- | Outer product of two vectors. | ||
523 | |||
524 | @\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] | ||
525 | (3><3) | ||
526 | [ 5.0, 2.0, 3.0 | ||
527 | , 10.0, 4.0, 6.0 | ||
528 | , 15.0, 6.0, 9.0 ]@ | ||
529 | -} | ||
530 | outer :: (Field t) => Vector t -> Vector t -> Matrix t | ||
531 | outer u v = asColumn u `multiply` asRow v | ||
532 | |||
533 | {- | Kronecker product of two matrices. | ||
534 | |||
535 | @m1=(2><3) | ||
536 | [ 1.0, 2.0, 0.0 | ||
537 | , 0.0, -1.0, 3.0 ] | ||
538 | m2=(4><3) | ||
539 | [ 1.0, 2.0, 3.0 | ||
540 | , 4.0, 5.0, 6.0 | ||
541 | , 7.0, 8.0, 9.0 | ||
542 | , 10.0, 11.0, 12.0 ]@ | ||
543 | |||
544 | @\> kronecker m1 m2 | ||
545 | (8><9) | ||
546 | [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0 | ||
547 | , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0 | ||
548 | , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0 | ||
549 | , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0 | ||
550 | , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0 | ||
551 | , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0 | ||
552 | , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0 | ||
553 | , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@ | ||
554 | -} | ||
555 | kronecker :: (Field t) => Matrix t -> Matrix t -> Matrix t | ||
556 | kronecker a b = fromBlocks | ||
557 | . partit (cols a) | ||
558 | . map (reshape (cols b)) | ||
559 | . toRows | ||
560 | $ flatten a `outer` flatten b | ||
561 | |||
562 | --------------------------------------------------------------------- | ||
563 | -- reference multiply | ||
564 | --------------------------------------------------------------------- | ||
565 | |||
566 | mulH a b = fromLists [[ dot ai bj | bj <- toColumns b] | ai <- toRows a ] | ||
567 | where dot u v = sum $ zipWith (*) (toList u) (toList v) | ||
568 | |||
569 | ----------------------------------------------------------------------------------- | ||
570 | -- workaround | ||
571 | ----------------------------------------------------------------------------------- | ||
572 | |||
573 | mulCW a b = toComplex (rr,ri) | ||
574 | where rr = multiply ar br `sub` multiply ai bi | ||
575 | ri = multiply ar bi `add` multiply ai br | ||
576 | (ar,ai) = fromComplex a | ||
577 | (br,bi) = fromComplex b | ||
578 | |||
579 | ----------------------------------------------------------------------------------- | ||
580 | -- Direct CBLAS | ||
581 | ----------------------------------------------------------------------------------- | ||
582 | |||
583 | newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show) | ||
584 | newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show) | ||
585 | |||
586 | rowMajor, colMajor :: CBLASOrder | ||
587 | rowMajor = CBLASOrder 101 | ||
588 | colMajor = CBLASOrder 102 | ||
589 | |||
590 | noTrans, trans', conjTrans :: CBLASTrans | ||
591 | noTrans = CBLASTrans 111 | ||
592 | trans' = CBLASTrans 112 | ||
593 | conjTrans = CBLASTrans 113 | ||
594 | |||
595 | foreign import ccall "cblas.h cblas_dgemm" | ||
596 | dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double -> Ptr Double -> CInt -> IO () | ||
597 | |||
598 | |||
599 | multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double | ||
600 | multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) | ||
601 | where | ||
602 | multiply3 f st a b | ||
603 | | cols a == rows b = unsafePerformIO $ do | ||
604 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
605 | let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0 | ||
606 | app3 g mat a mat b mat s st | ||
607 | return s | ||
608 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||
609 | |||
610 | |||
611 | foreign import ccall "cblas.h cblas_zgemm" | ||
612 | zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> IO () | ||
613 | |||
614 | multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
615 | multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b) | ||
616 | where | ||
617 | multiply3 f st a b | ||
618 | | cols a == rows b = do | ||
619 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
620 | palpha <- new 1 | ||
621 | pbeta <- new 0 | ||
622 | let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0 | ||
623 | app3 g mat a mat b mat s st | ||
624 | free palpha | ||
625 | free pbeta | ||
626 | return s | ||
627 | -- if toLists s== toLists s then return s else error $ "HORROR " ++ (show (toLists s)) | ||
628 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||
629 | |||
630 | ----------------------------------------------------------------------------------- | ||
631 | -- BLAS via auxiliary C | ||
632 | ----------------------------------------------------------------------------------- | ||
633 | |||
634 | foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM | ||
635 | foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM | ||
636 | |||
637 | multiply2 f st a b | ||
638 | | cols a == rows b = unsafePerformIO $ do | ||
639 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
640 | app3 f mat a mat b mat s st | ||
641 | if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s)) | ||
642 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||
643 | |||
644 | multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double | ||
645 | multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b) | ||
646 | |||
647 | multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
648 | multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b) | ||
649 | |||
650 | ----------------------------------------------------------------------------------- | ||
651 | -- direct C multiplication | ||
652 | ----------------------------------------------------------------------------------- | ||
653 | |||
654 | foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM | ||
655 | foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM | ||
656 | |||
657 | cmultiply f st a b | ||
658 | -- | cols a == rows b = | ||
659 | = unsafePerformIO $ do | ||
660 | s <- createMatrix RowMajor (rows a) (cols b) | ||
661 | app3 f mat a mat b mat s st | ||
662 | if toLists s== toLists s then return s else error $ "BRUTAL " ++ (show (toLists s)) | ||
663 | -- return s | ||
664 | -- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||
665 | |||
666 | multiplyR :: Matrix Double -> Matrix Double -> Matrix Double | ||
667 | multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b) | ||
668 | |||
669 | multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
670 | multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b) | ||
diff --git a/lib/Numeric/LinearAlgebra/Interface.hs b/lib/Numeric/LinearAlgebra/Interface.hs index 4a9b309..0ae9698 100644 --- a/lib/Numeric/LinearAlgebra/Interface.hs +++ b/lib/Numeric/LinearAlgebra/Interface.hs | |||
@@ -29,7 +29,7 @@ import Numeric.LinearAlgebra.Algorithms | |||
29 | class Mul a b c | a b -> c where | 29 | class Mul a b c | a b -> c where |
30 | infixl 7 <> | 30 | infixl 7 <> |
31 | -- | matrix product | 31 | -- | matrix product |
32 | (<>) :: Element t => a t -> b t -> c t | 32 | (<>) :: Field t => a t -> b t -> c t |
33 | 33 | ||
34 | instance Mul Matrix Matrix Matrix where | 34 | instance Mul Matrix Matrix Matrix where |
35 | (<>) = multiply | 35 | (<>) = multiply |
@@ -43,7 +43,7 @@ instance Mul Vector Matrix Vector where | |||
43 | --------------------------------------------------- | 43 | --------------------------------------------------- |
44 | 44 | ||
45 | -- | @u \<.\> v = dot u v@ | 45 | -- | @u \<.\> v = dot u v@ |
46 | (<.>) :: (Element t) => Vector t -> Vector t -> t | 46 | (<.>) :: (Field t) => Vector t -> Vector t -> t |
47 | infixl 7 <.> | 47 | infixl 7 <.> |
48 | (<.>) = dot | 48 | (<.>) = dot |
49 | 49 | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 310f6ee..0dccea2 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | |||
@@ -814,3 +814,77 @@ int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) { | |||
814 | free(auxipiv); | 814 | free(auxipiv); |
815 | OK | 815 | OK |
816 | } | 816 | } |
817 | |||
818 | //////////////////////////////////////////////////////////// | ||
819 | |||
820 | int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)) { | ||
821 | REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
822 | int i,j,k; | ||
823 | for (i=0;i<ar;i++) { | ||
824 | for(j=0;j<bc;j++) { | ||
825 | double temp = 0; | ||
826 | for(k=0;k<ac;k++) { | ||
827 | temp += ap[i*ac+k]*bp[k*bc+j]; | ||
828 | } | ||
829 | rp[i*rc+j] = temp; | ||
830 | } | ||
831 | } | ||
832 | OK | ||
833 | } | ||
834 | |||
835 | int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)) { | ||
836 | REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
837 | int i,j,k; | ||
838 | for (i=0;i<ar;i++) { | ||
839 | for(j=0;j<bc;j++) { | ||
840 | doublecomplex temp = {0,0}; | ||
841 | for(k=0;k<ac;k++) { | ||
842 | doublecomplex aik = ((doublecomplex*)ap)[i*ac+k]; | ||
843 | doublecomplex bkj = ((doublecomplex*)bp)[k*bc+j]; | ||
844 | //double w = aik.r+aik.i+bkj.r+bkj.i; | ||
845 | //if (w>w) exit(1); | ||
846 | //printf("%d",w>w); | ||
847 | temp.r += aik.r * bkj.r - aik.i * bkj.i; | ||
848 | temp.i += aik.r * bkj.i + aik.i * bkj.r; | ||
849 | //printf("%f %f %f %f \n",aik.r,aik.i,bkj.r,bkj.i); | ||
850 | //printf("%f %f %f \n",w,temp.r,temp.i); | ||
851 | |||
852 | } | ||
853 | ((doublecomplex*)rp)[i*rc+j] = temp; | ||
854 | //printf("%f %f\n",temp.r,temp.i); | ||
855 | } | ||
856 | } | ||
857 | OK | ||
858 | } | ||
859 | |||
860 | void dgemm_(char *, char *, integer *, integer *, integer *, | ||
861 | double *, const double *, integer *, const double *, | ||
862 | integer *, double *, double *, integer *); | ||
863 | |||
864 | void zgemm_(char *, char *, integer *, integer *, integer *, | ||
865 | doublecomplex *, const doublecomplex *, integer *, const doublecomplex *, | ||
866 | integer *, doublecomplex *, doublecomplex *, integer *); | ||
867 | |||
868 | |||
869 | int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)) { | ||
870 | REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
871 | double alpha = 1; | ||
872 | double beta = 0; | ||
873 | integer m = ar; | ||
874 | integer n = bc; | ||
875 | integer k = ac; | ||
876 | int i,j; | ||
877 | dgemm_("N","N",&m,&n,&k,&alpha,ap,&m,bp,&k,&beta,rp,&m); | ||
878 | OK | ||
879 | } | ||
880 | |||
881 | int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)) { | ||
882 | REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
883 | integer m = ar; | ||
884 | integer n = bc; | ||
885 | integer k = ac; | ||
886 | doublecomplex alpha = {1,0}; | ||
887 | doublecomplex beta = {0,0}; | ||
888 | zgemm_("N","N",&m,&n,&k,&alpha,(doublecomplex*)ap,&m,(doublecomplex*)bp,&k,&beta,(doublecomplex*)rp,&m); | ||
889 | OK | ||
890 | } | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h index 79e52be..c0361a6 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | |||
@@ -84,3 +84,9 @@ int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)); | |||
84 | 84 | ||
85 | int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)); | 85 | int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)); |
86 | int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)); | 86 | int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)); |
87 | |||
88 | int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)); | ||
89 | int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)); | ||
90 | |||
91 | int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)); | ||
92 | int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)); | ||
diff --git a/lib/Numeric/LinearAlgebra/Linear.hs b/lib/Numeric/LinearAlgebra/Linear.hs index 0ddbb55..1bf8b04 100644 --- a/lib/Numeric/LinearAlgebra/Linear.hs +++ b/lib/Numeric/LinearAlgebra/Linear.hs | |||
@@ -15,12 +15,11 @@ Basic optimized operations on vectors and matrices. | |||
15 | ----------------------------------------------------------------------------- | 15 | ----------------------------------------------------------------------------- |
16 | 16 | ||
17 | module Numeric.LinearAlgebra.Linear ( | 17 | module Numeric.LinearAlgebra.Linear ( |
18 | Linear(..), | 18 | Linear(..) |
19 | multiply, dot, outer, kronecker | ||
20 | ) where | 19 | ) where |
21 | 20 | ||
22 | 21 | ||
23 | import Data.Packed.Internal(multiply,partit) | 22 | import Data.Packed.Internal(partit) |
24 | import Data.Packed | 23 | import Data.Packed |
25 | import Numeric.GSL.Vector | 24 | import Numeric.GSL.Vector |
26 | import Complex | 25 | import Complex |
@@ -69,52 +68,3 @@ instance (Linear Vector a, Container Matrix a) => (Linear Matrix a) where | |||
69 | mul = liftMatrix2 mul | 68 | mul = liftMatrix2 mul |
70 | divide = liftMatrix2 divide | 69 | divide = liftMatrix2 divide |
71 | equal a b = cols a == cols b && flatten a `equal` flatten b | 70 | equal a b = cols a == cols b && flatten a `equal` flatten b |
72 | |||
73 | -------------------------------------------------- | ||
74 | |||
75 | -- | euclidean inner product | ||
76 | dot :: (Element t) => Vector t -> Vector t -> t | ||
77 | dot u v = multiply r c @@> (0,0) | ||
78 | where r = asRow u | ||
79 | c = asColumn v | ||
80 | |||
81 | |||
82 | {- | Outer product of two vectors. | ||
83 | |||
84 | @\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] | ||
85 | (3><3) | ||
86 | [ 5.0, 2.0, 3.0 | ||
87 | , 10.0, 4.0, 6.0 | ||
88 | , 15.0, 6.0, 9.0 ]@ | ||
89 | -} | ||
90 | outer :: (Element t) => Vector t -> Vector t -> Matrix t | ||
91 | outer u v = asColumn u `multiply` asRow v | ||
92 | |||
93 | {- | Kronecker product of two matrices. | ||
94 | |||
95 | @m1=(2><3) | ||
96 | [ 1.0, 2.0, 0.0 | ||
97 | , 0.0, -1.0, 3.0 ] | ||
98 | m2=(4><3) | ||
99 | [ 1.0, 2.0, 3.0 | ||
100 | , 4.0, 5.0, 6.0 | ||
101 | , 7.0, 8.0, 9.0 | ||
102 | , 10.0, 11.0, 12.0 ]@ | ||
103 | |||
104 | @\> kronecker m1 m2 | ||
105 | (8><9) | ||
106 | [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0 | ||
107 | , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0 | ||
108 | , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0 | ||
109 | , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0 | ||
110 | , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0 | ||
111 | , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0 | ||
112 | , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0 | ||
113 | , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@ | ||
114 | -} | ||
115 | kronecker :: (Element t) => Matrix t -> Matrix t -> Matrix t | ||
116 | kronecker a b = fromBlocks | ||
117 | . partit (cols a) | ||
118 | . map (reshape (cols b)) | ||
119 | . toRows | ||
120 | $ flatten a `outer` flatten b | ||
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 7b28075..07b9f63 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs | |||
@@ -123,6 +123,11 @@ runTests :: Int -- ^ maximum dimension | |||
123 | runTests n = do | 123 | runTests n = do |
124 | setErrorHandlerOff | 124 | setErrorHandlerOff |
125 | let test p = qCheck n p | 125 | let test p = qCheck n p |
126 | putStrLn "------ mult" | ||
127 | test (multProp1 . rConsist) | ||
128 | test (multProp1 . cConsist) | ||
129 | test (multProp2 . rConsist) | ||
130 | test (multProp2 . cConsist) | ||
126 | putStrLn "------ lu" | 131 | putStrLn "------ lu" |
127 | test (luProp . rM) | 132 | test (luProp . rM) |
128 | test (luProp . cM) | 133 | test (luProp . cM) |
diff --git a/lib/Numeric/LinearAlgebra/Tests/Instances.hs b/lib/Numeric/LinearAlgebra/Tests/Instances.hs index af486c8..e7fecf2 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Instances.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Instances.hs | |||
@@ -20,6 +20,7 @@ module Numeric.LinearAlgebra.Tests.Instances( | |||
20 | WC(..), rWC,cWC, | 20 | WC(..), rWC,cWC, |
21 | SqWC(..), rSqWC, cSqWC, | 21 | SqWC(..), rSqWC, cSqWC, |
22 | PosDef(..), rPosDef, cPosDef, | 22 | PosDef(..), rPosDef, cPosDef, |
23 | Consistent(..), rConsist, cConsist, | ||
23 | RM,CM, rM,cM | 24 | RM,CM, rM,cM |
24 | ) where | 25 | ) where |
25 | 26 | ||
@@ -116,6 +117,19 @@ instance (Field a, Arbitrary a) => Arbitrary (PosDef a) where | |||
116 | return $ PosDef (0.5 .* p + 0.5 .* ctrans p) | 117 | return $ PosDef (0.5 .* p + 0.5 .* ctrans p) |
117 | coarbitrary = undefined | 118 | coarbitrary = undefined |
118 | 119 | ||
120 | -- a pair of matrices that can be multiplied | ||
121 | newtype (Consistent a) = Consistent (Matrix a, Matrix a) deriving Show | ||
122 | instance (Field a, Arbitrary a) => Arbitrary (Consistent a) where | ||
123 | arbitrary = do | ||
124 | n <- chooseDim | ||
125 | k <- chooseDim | ||
126 | m <- chooseDim | ||
127 | la <- vector (n*k) | ||
128 | lb <- vector (k*m) | ||
129 | return $ Consistent ((n><k) la, (k><m) lb) | ||
130 | coarbitrary = undefined | ||
131 | |||
132 | |||
119 | type RM = Matrix Double | 133 | type RM = Matrix Double |
120 | type CM = Matrix (Complex Double) | 134 | type CM = Matrix (Complex Double) |
121 | 135 | ||
@@ -140,3 +154,5 @@ cSqWC (SqWC m) = m :: CM | |||
140 | rPosDef (PosDef m) = m :: RM | 154 | rPosDef (PosDef m) = m :: RM |
141 | cPosDef (PosDef m) = m :: CM | 155 | cPosDef (PosDef m) = m :: CM |
142 | 156 | ||
157 | rConsist (Consistent (a,b)) = (a,b::RM) | ||
158 | cConsist (Consistent (a,b)) = (a,b::CM) | ||
diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs index 55e9a1b..5663b86 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs | |||
@@ -34,7 +34,8 @@ module Numeric.LinearAlgebra.Tests.Properties ( | |||
34 | hessProp, | 34 | hessProp, |
35 | schurProp1, schurProp2, | 35 | schurProp1, schurProp2, |
36 | cholProp, | 36 | cholProp, |
37 | expmDiagProp | 37 | expmDiagProp, |
38 | multProp1, multProp2 | ||
38 | ) where | 39 | ) where |
39 | 40 | ||
40 | import Numeric.LinearAlgebra | 41 | import Numeric.LinearAlgebra |
@@ -151,3 +152,6 @@ cholProp m = m |~| ctrans c <> c && upperTriang c | |||
151 | expmDiagProp m = expm (logm m) :~ 7 ~: complex m | 152 | expmDiagProp m = expm (logm m) :~ 7 ~: complex m |
152 | where logm m = matFunc log m | 153 | where logm m = matFunc log m |
153 | 154 | ||
155 | multProp1 (a,b) = a <> b |~| mulH a b | ||
156 | |||
157 | multProp2 (a,b) = trans (a <> b) |~| trans b <> trans a | ||