summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2010-08-26 17:49:45 +0000
committerAlberto Ruiz <aruiz@um.es>2010-08-26 17:49:45 +0000
commit6058e1b17c005be1ea95ebb7d98d9fd15bb538d2 (patch)
treec4277e00c2c92a0ed8f3750255154fa8e2b6fe2d
parentf541d7dbdc8338b1dd1c0538751d837a16740bd8 (diff)
Float matrix product
-rw-r--r--CHANGES2
-rw-r--r--lib/Data/Packed/Internal/Signatures.hs4
-rw-r--r--lib/Data/Packed/Internal/Vector.hs20
-rw-r--r--lib/Data/Packed/Matrix.hs66
-rw-r--r--lib/Graphics/Plot.hs1
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs81
-rw-r--r--lib/Numeric/LinearAlgebra/Interface.hs2
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK.hs18
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c81
-rw-r--r--lib/Numeric/LinearAlgebra/Linear.hs90
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs16
-rw-r--r--lib/Numeric/LinearAlgebra/Tests/Properties.hs6
12 files changed, 264 insertions, 123 deletions
diff --git a/CHANGES b/CHANGES
index 20c5f0e..8d09c9f 100644
--- a/CHANGES
+++ b/CHANGES
@@ -5,7 +5,7 @@
5 5
6- Vectors typeclass 6- Vectors typeclass
7 7
8- Initial support for Vector Float and Vector (Complex Float) 8- Support for Float and Complex Float elements (excluding LAPACK computations)
9 9
10- Binary instances for Vector and Matrix 10- Binary instances for Vector and Matrix
11 11
diff --git a/lib/Data/Packed/Internal/Signatures.hs b/lib/Data/Packed/Internal/Signatures.hs
index 8c1c5f6..b81efa4 100644
--- a/lib/Data/Packed/Internal/Signatures.hs
+++ b/lib/Data/Packed/Internal/Signatures.hs
@@ -24,12 +24,15 @@ type PQ = Ptr (Complex Float) --
24type PC = Ptr (Complex Double) -- 24type PC = Ptr (Complex Double) --
25type TF = CInt -> PF -> IO CInt -- 25type TF = CInt -> PF -> IO CInt --
26type TFF = CInt -> PF -> TF -- 26type TFF = CInt -> PF -> TF --
27type TFV = CInt -> PF -> TV --
28type TVF = CInt -> PD -> TF --
27type TFFF = CInt -> PF -> TFF -- 29type TFFF = CInt -> PF -> TFF --
28type TV = CInt -> PD -> IO CInt -- 30type TV = CInt -> PD -> IO CInt --
29type TVV = CInt -> PD -> TV -- 31type TVV = CInt -> PD -> TV --
30type TVVV = CInt -> PD -> TVV -- 32type TVVV = CInt -> PD -> TVV --
31type TFM = CInt -> CInt -> PF -> IO CInt -- 33type TFM = CInt -> CInt -> PF -> IO CInt --
32type TFMFM = CInt -> CInt -> PF -> TFM -- 34type TFMFM = CInt -> CInt -> PF -> TFM --
35type TFMFMFM = CInt -> CInt -> PF -> TFMFM --
33type TM = CInt -> CInt -> PD -> IO CInt -- 36type TM = CInt -> CInt -> PD -> IO CInt --
34type TMM = CInt -> CInt -> PD -> TM -- 37type TMM = CInt -> CInt -> PD -> TM --
35type TVMM = CInt -> PD -> TMM -- 38type TVMM = CInt -> PD -> TMM --
@@ -61,6 +64,7 @@ type TQVQVQV = CInt -> PQ -> TQVQV --
61type TQVF = CInt -> PQ -> TF -- 64type TQVF = CInt -> PQ -> TF --
62type TQM = CInt -> CInt -> PQ -> IO CInt -- 65type TQM = CInt -> CInt -> PQ -> IO CInt --
63type TQMQM = CInt -> CInt -> PQ -> TQM -- 66type TQMQM = CInt -> CInt -> PQ -> TQM --
67type TQMQMQM = CInt -> CInt -> PQ -> TQMQM --
64type TCMCV = CInt -> CInt -> PC -> TCV -- 68type TCMCV = CInt -> CInt -> PC -> TCV --
65type TVCV = CInt -> PD -> TCV -- 69type TVCV = CInt -> PD -> TCV --
66type TCVM = CInt -> PC -> TM -- 70type TCVM = CInt -> PC -> TM --
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index ac2d0d7..c8cc2c2 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -21,7 +21,7 @@ module Data.Packed.Internal.Vector (
21 mapVectorM, mapVectorM_, 21 mapVectorM, mapVectorM_,
22 foldVector, foldVectorG, foldLoop, 22 foldVector, foldVectorG, foldLoop,
23 createVector, vec, 23 createVector, vec,
24 asComplex, asReal, 24 asComplex, asReal, float2DoubleV, double2FloatV,
25 fwriteVector, freadVector, fprintfVector, fscanfVector, 25 fwriteVector, freadVector, fprintfVector, fscanfVector,
26 cloneVector, 26 cloneVector,
27 unsafeToForeignPtr, 27 unsafeToForeignPtr,
@@ -274,6 +274,24 @@ asComplex :: (RealFloat a, Storable a) => Vector a -> Vector (Complex a)
274asComplex v = unsafeFromForeignPtr (castForeignPtr fp) (i `div` 2) (n `div` 2) 274asComplex v = unsafeFromForeignPtr (castForeignPtr fp) (i `div` 2) (n `div` 2)
275 where (fp,i,n) = unsafeToForeignPtr v 275 where (fp,i,n) = unsafeToForeignPtr v
276 276
277---------------------------------------------------------------
278
279float2DoubleV :: Vector Float -> Vector Double
280float2DoubleV v = unsafePerformIO $ do
281 r <- createVector (dim v)
282 app2 c_float2double vec v vec r "float2double"
283 return r
284
285double2FloatV :: Vector Double -> Vector Float
286double2FloatV v = unsafePerformIO $ do
287 r <- createVector (dim v)
288 app2 c_double2float vec v vec r "double2float2"
289 return r
290
291
292foreign import ccall "float2double" c_float2double:: TFV
293foreign import ccall "double2float" c_double2float:: TVF
294
277---------------------------------------------------------------- 295----------------------------------------------------------------
278 296
279cloneVector :: Storable t => Vector t -> IO (Vector t) 297cloneVector :: Storable t => Vector t -> IO (Vector t)
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs
index 8aa1693..8694249 100644
--- a/lib/Data/Packed/Matrix.hs
+++ b/lib/Data/Packed/Matrix.hs
@@ -1,6 +1,10 @@
1{-# LANGUAGE TypeFamilies #-} 1{-# LANGUAGE TypeFamilies #-}
2{-# LANGUAGE FlexibleContexts #-} 2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-} 3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-}
6
7
4----------------------------------------------------------------------------- 8-----------------------------------------------------------------------------
5-- | 9-- |
6-- Module : Data.Packed.Matrix 10-- Module : Data.Packed.Matrix
@@ -16,8 +20,9 @@
16----------------------------------------------------------------------------- 20-----------------------------------------------------------------------------
17 21
18module Data.Packed.Matrix ( 22module Data.Packed.Matrix (
19 Element, Scalar, Container(..), Convert(..), 23 Element, RealElement, Container(..),
20 RealOf, ComplexOf, SingleOf, DoubleOf, ElementOf, AutoReal(..), 24 Convert(..), RealOf, ComplexOf, SingleOf, DoubleOf, ElementOf,
25 AutoReal(..),
21 Matrix,rows,cols, 26 Matrix,rows,cols,
22 (><), 27 (><),
23 trans, 28 trans,
@@ -51,7 +56,7 @@ import Data.Binary
51import Foreign.Storable 56import Foreign.Storable
52import Control.Monad(replicateM) 57import Control.Monad(replicateM)
53import Control.Arrow((***)) 58import Control.Arrow((***))
54import GHC.Float(double2Float,float2Double) 59--import GHC.Float(double2Float,float2Double)
55 60
56------------------------------------------------------------------- 61-------------------------------------------------------------------
57 62
@@ -468,17 +473,32 @@ toBlocksEvery r c m = toBlocks rs cs m where
468 473
469-- | conversion utilities 474-- | conversion utilities
470 475
471class (Element t, Element (Complex t), Fractional t, RealFloat t) => Scalar t 476class (Element t, Element (Complex t), RealFloat t) => RealElement t
477
478instance RealElement Double
479instance RealElement Float
480
481class (Element s, Element d) => Prec s d | s -> d, d -> s where
482 double2FloatG :: Vector d -> Vector s
483 float2DoubleG :: Vector s -> Vector d
484
485instance Prec Float Double where
486 double2FloatG = double2FloatV
487 float2DoubleG = float2DoubleV
488
489instance Prec (Complex Float) (Complex Double) where
490 double2FloatG = asComplex . double2FloatV . asReal
491 float2DoubleG = asComplex . float2DoubleV . asReal
472 492
473instance Scalar Double
474instance Scalar Float
475 493
476class Container c where 494class Container c where
477 toComplex :: (Scalar e) => (c e, c e) -> c (Complex e) 495 toComplex :: (RealElement e) => (c e, c e) -> c (Complex e)
478 fromComplex :: (Scalar e) => c (Complex e) -> (c e, c e) 496 fromComplex :: (RealElement e) => c (Complex e) -> (c e, c e)
479 comp :: (Scalar e) => c e -> c (Complex e) 497 comp :: (RealElement e) => c e -> c (Complex e)
480 conj :: (Scalar e) => c (Complex e) -> c (Complex e) 498 conj :: (RealElement e) => c (Complex e) -> c (Complex e)
481 cmap :: (Element a, Element b) => (a -> b) -> c a -> c b 499 cmap :: (Element a, Element b) => (a -> b) -> c a -> c b
500 single :: Prec a b => c b -> c a
501 double :: Prec a b => c a -> c b
482 502
483instance Container Vector where 503instance Container Vector where
484 toComplex = toComplexV 504 toComplex = toComplexV
@@ -486,6 +506,8 @@ instance Container Vector where
486 comp v = toComplex (v,constantD 0 (dim v)) 506 comp v = toComplex (v,constantD 0 (dim v))
487 conj = conjV 507 conj = conjV
488 cmap = mapVector 508 cmap = mapVector
509 single = double2FloatG
510 double = float2DoubleG
489 511
490instance Container Matrix where 512instance Container Matrix where
491 toComplex = uncurry $ liftMatrix2 $ curry toComplex 513 toComplex = uncurry $ liftMatrix2 $ curry toComplex
@@ -494,6 +516,8 @@ instance Container Matrix where
494 comp = liftMatrix comp 516 comp = liftMatrix comp
495 conj = liftMatrix conj 517 conj = liftMatrix conj
496 cmap f = liftMatrix (cmap f) 518 cmap f = liftMatrix (cmap f)
519 single = liftMatrix single
520 double = liftMatrix double
497 521
498------------------------------------------------------------------- 522-------------------------------------------------------------------
499 523
@@ -534,38 +558,40 @@ type instance ElementOf (Matrix a) = a
534 558
535------------------------------------------------------------------- 559-------------------------------------------------------------------
536 560
561-- | generic conversion functions
537class Convert t where 562class Convert t where
538 real' :: Container c => c (RealOf t) -> c t 563 real' :: Container c => c (RealOf t) -> c t
539 complex' :: Container c => c t -> c (ComplexOf t) 564 complex' :: Container c => c t -> c (ComplexOf t)
540 single :: Container c => c t -> c (SingleOf t) 565 single' :: Container c => c t -> c (SingleOf t)
541 double :: Container c => c t -> c (DoubleOf t) 566 double' :: Container c => c t -> c (DoubleOf t)
542 567
543instance Convert Double where 568instance Convert Double where
544 real' = id 569 real' = id
545 complex' = comp 570 complex' = comp
546 single = cmap double2Float 571 single' = single
547 double = id 572 double' = id
548 573
549instance Convert Float where 574instance Convert Float where
550 real' = id 575 real' = id
551 complex' = comp 576 complex' = comp
552 single = id 577 single' = id
553 double = cmap float2Double 578 double' = double
554 579
555instance Convert (Complex Double) where 580instance Convert (Complex Double) where
556 real' = comp 581 real' = comp
557 complex' = id 582 complex' = id
558 single = toComplex . (single *** single) . fromComplex 583 single' = single
559 double = id 584 double' = id
560 585
561instance Convert (Complex Float) where 586instance Convert (Complex Float) where
562 real' = comp 587 real' = comp
563 complex' = id 588 complex' = id
564 single = id 589 single' = id
565 double = toComplex . (double *** double) . fromComplex 590 double' = double
566 591
567------------------------------------------------------------------- 592-------------------------------------------------------------------
568 593
594-- | to be replaced by Convert
569class AutoReal t where 595class AutoReal t where
570 real :: Container c => c Double -> c t 596 real :: Container c => c Double -> c t
571 complex :: Container c => c t -> c (Complex Double) 597 complex :: Container c => c t -> c (Complex Double)
diff --git a/lib/Graphics/Plot.hs b/lib/Graphics/Plot.hs
index b2acc15..2dc0553 100644
--- a/lib/Graphics/Plot.hs
+++ b/lib/Graphics/Plot.hs
@@ -29,7 +29,6 @@ module Graphics.Plot(
29) where 29) where
30 30
31import Data.Packed 31import Data.Packed
32import Numeric.LinearAlgebra(outer)
33import Numeric.LinearAlgebra.Linear 32import Numeric.LinearAlgebra.Linear
34import Data.List(intersperse) 33import Data.List(intersperse)
35import System.Process (system) 34import System.Process (system)
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index f4b7ee9..8962c60 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -21,9 +21,6 @@ imported from "Numeric.LinearAlgebra.LAPACK".
21module Numeric.LinearAlgebra.Algorithms ( 21module Numeric.LinearAlgebra.Algorithms (
22-- * Supported types 22-- * Supported types
23 Field(), 23 Field(),
24-- * Products
25 multiply, -- dot, moved dot to typeclass
26 outer, kronecker,
27-- * Linear Systems 24-- * Linear Systems
28 linearSolve, 25 linearSolve,
29 luSolve, 26 luSolve,
@@ -64,7 +61,6 @@ module Numeric.LinearAlgebra.Algorithms (
64-- * Norms 61-- * Norms
65 Normed(..), NormType(..), 62 Normed(..), NormType(..),
66-- * Misc 63-- * Misc
67 ctrans,
68 eps, i, 64 eps, i,
69-- * Util 65-- * Util
70 haussholder, 66 haussholder,
@@ -86,7 +82,7 @@ import Data.List(foldl1')
86import Data.Array 82import Data.Array
87 83
88-- | Auxiliary typeclass used to define generic computations for both real and complex matrices. 84-- | Auxiliary typeclass used to define generic computations for both real and complex matrices.
89class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where 85class (Prod t, Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where
90 svd' :: Matrix t -> (Matrix t, Vector Double, Matrix t) 86 svd' :: Matrix t -> (Matrix t, Vector Double, Matrix t)
91 thinSVD' :: Matrix t -> (Matrix t, Vector Double, Matrix t) 87 thinSVD' :: Matrix t -> (Matrix t, Vector Double, Matrix t)
92 sv' :: Matrix t -> Vector Double 88 sv' :: Matrix t -> Vector Double
@@ -105,8 +101,6 @@ class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where
105 qr' :: Matrix t -> (Matrix t, Matrix t) 101 qr' :: Matrix t -> (Matrix t, Matrix t)
106 hess' :: Matrix t -> (Matrix t, Matrix t) 102 hess' :: Matrix t -> (Matrix t, Matrix t)
107 schur' :: Matrix t -> (Matrix t, Matrix t) 103 schur' :: Matrix t -> (Matrix t, Matrix t)
108 ctrans' :: Matrix t -> Matrix t
109 multiply' :: Matrix t -> Matrix t -> Matrix t
110 104
111 105
112instance Field Double where 106instance Field Double where
@@ -119,7 +113,6 @@ instance Field Double where
119 cholSolve' = cholSolveR 113 cholSolve' = cholSolveR
120 linearSolveLS' = linearSolveLSR 114 linearSolveLS' = linearSolveLSR
121 linearSolveSVD' = linearSolveSVDR Nothing 115 linearSolveSVD' = linearSolveSVDR Nothing
122 ctrans' = trans
123 eig' = eigR 116 eig' = eigR
124 eigSH'' = eigS 117 eigSH'' = eigS
125 eigOnly = eigOnlyR 118 eigOnly = eigOnlyR
@@ -129,7 +122,6 @@ instance Field Double where
129 qr' = unpackQR . qrR 122 qr' = unpackQR . qrR
130 hess' = unpackHess hessR 123 hess' = unpackHess hessR
131 schur' = schurR 124 schur' = schurR
132 multiply' = multiplyR
133 125
134instance Field (Complex Double) where 126instance Field (Complex Double) where
135#ifdef NOZGESDD 127#ifdef NOZGESDD
@@ -146,7 +138,6 @@ instance Field (Complex Double) where
146 cholSolve' = cholSolveC 138 cholSolve' = cholSolveC
147 linearSolveLS' = linearSolveLSC 139 linearSolveLS' = linearSolveLSC
148 linearSolveSVD' = linearSolveSVDC Nothing 140 linearSolveSVD' = linearSolveSVDC Nothing
149 ctrans' = conj . trans
150 eig' = eigC 141 eig' = eigC
151 eigOnly = eigOnlyC 142 eigOnly = eigOnlyC
152 eigSH'' = eigH 143 eigSH'' = eigH
@@ -156,7 +147,6 @@ instance Field (Complex Double) where
156 qr' = unpackQR . qrC 147 qr' = unpackQR . qrC
157 hess' = unpackHess hessC 148 hess' = unpackHess hessC
158 schur' = schurC 149 schur' = schurC
159 multiply' = multiplyC
160 150
161-------------------------------------------------------------- 151--------------------------------------------------------------
162 152
@@ -324,13 +314,6 @@ hess = hess'
324schur :: Field t => Matrix t -> (Matrix t, Matrix t) 314schur :: Field t => Matrix t -> (Matrix t, Matrix t)
325schur = schur' 315schur = schur'
326 316
327-- | Generic conjugate transpose.
328ctrans :: Field t => Matrix t -> Matrix t
329ctrans = ctrans'
330
331-- | Matrix product.
332multiply :: Field t => Matrix t -> Matrix t -> Matrix t
333multiply = {-# SCC "multiply" #-} multiply'
334 317
335-- | Similar to 'cholSH', but instead of an error (e.g., caused by a matrix not positive definite) it returns 'Nothing'. 318-- | Similar to 'cholSH', but instead of an error (e.g., caused by a matrix not positive definite) it returns 'Nothing'.
336mbCholSH :: Field t => Matrix t -> Maybe (Matrix t) 319mbCholSH :: Field t => Matrix t -> Maybe (Matrix t)
@@ -404,20 +387,6 @@ peps x = 2.0**(fromIntegral $ 1-floatDigits x)
404i :: Complex Double 387i :: Complex Double
405i = 0:+1 388i = 0:+1
406 389
407
408-- matrix product
409mXm :: (Num t, Field t) => Matrix t -> Matrix t -> Matrix t
410mXm = multiply
411
412-- matrix - vector product
413mXv :: (Num t, Field t) => Matrix t -> Vector t -> Vector t
414mXv m v = flatten $ m `mXm` (asColumn v)
415
416-- vector - matrix product
417vXm :: (Num t, Field t) => Vector t -> Matrix t -> Vector t
418vXm v m = flatten $ (asRow v) `mXm` m
419
420
421--------------------------------------------------------------------------- 390---------------------------------------------------------------------------
422 391
423norm2 :: Vector Double -> Double 392norm2 :: Vector Double -> Double
@@ -723,51 +692,3 @@ luFact (l_u,perm) | r <= c = (l ,u ,p, s)
723 (|*|) = mul 692 (|*|) = mul
724 693
725-------------------------------------------------- 694--------------------------------------------------
726
727{- moved to Numeric.LinearAlgebra.Interface Vector typeclass
728-- | Euclidean inner product.
729dot :: (Field t) => Vector t -> Vector t -> t
730dot u v = multiply r c @@> (0,0)
731 where r = asRow u
732 c = asColumn v
733-}
734
735{- | Outer product of two vectors.
736
737@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3]
738(3><3)
739 [ 5.0, 2.0, 3.0
740 , 10.0, 4.0, 6.0
741 , 15.0, 6.0, 9.0 ]@
742-}
743outer :: (Field t) => Vector t -> Vector t -> Matrix t
744outer u v = asColumn u `multiply` asRow v
745
746{- | Kronecker product of two matrices.
747
748@m1=(2><3)
749 [ 1.0, 2.0, 0.0
750 , 0.0, -1.0, 3.0 ]
751m2=(4><3)
752 [ 1.0, 2.0, 3.0
753 , 4.0, 5.0, 6.0
754 , 7.0, 8.0, 9.0
755 , 10.0, 11.0, 12.0 ]@
756
757@\> kronecker m1 m2
758(8><9)
759 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
760 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
761 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
762 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
763 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
764 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
765 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
766 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@
767-}
768kronecker :: (Field t) => Matrix t -> Matrix t -> Matrix t
769kronecker a b = fromBlocks
770 . splitEvery (cols a)
771 . map (reshape (cols b))
772 . toRows
773 $ flatten a `outer` flatten b
diff --git a/lib/Numeric/LinearAlgebra/Interface.hs b/lib/Numeric/LinearAlgebra/Interface.hs
index f8917a0..6df782f 100644
--- a/lib/Numeric/LinearAlgebra/Interface.hs
+++ b/lib/Numeric/LinearAlgebra/Interface.hs
@@ -35,7 +35,7 @@ import Numeric.LinearAlgebra.Linear
35class Mul a b c | a b -> c where 35class Mul a b c | a b -> c where
36 infixl 7 <> 36 infixl 7 <>
37 -- | Matrix-matrix, matrix-vector, and vector-matrix products. 37 -- | Matrix-matrix, matrix-vector, and vector-matrix products.
38 (<>) :: Field t => a t -> b t -> c t 38 (<>) :: Prod t => a t -> b t -> c t
39 39
40instance Mul Matrix Matrix Matrix where 40instance Mul Matrix Matrix Matrix where
41 (<>) = multiply 41 (<>) = multiply
diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs
index 7f057ba..eec3035 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK.hs
+++ b/lib/Numeric/LinearAlgebra/LAPACK.hs
@@ -14,7 +14,7 @@
14 14
15module Numeric.LinearAlgebra.LAPACK ( 15module Numeric.LinearAlgebra.LAPACK (
16 -- * Matrix product 16 -- * Matrix product
17 multiplyR, multiplyC, 17 multiplyR, multiplyC, multiplyF, multiplyQ,
18 -- * Linear systems 18 -- * Linear systems
19 linearSolveR, linearSolveC, 19 linearSolveR, linearSolveC,
20 lusR, lusC, 20 lusR, lusC,
@@ -51,8 +51,10 @@ import Control.Monad(when)
51 51
52----------------------------------------------------------------------------------- 52-----------------------------------------------------------------------------------
53 53
54foreign import ccall "LAPACK/lapack-aux.h multiplyR" dgemmc :: CInt -> CInt -> TMMM 54foreign import ccall "multiplyR" dgemmc :: CInt -> CInt -> TMMM
55foreign import ccall "LAPACK/lapack-aux.h multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM 55foreign import ccall "multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM
56foreign import ccall "multiplyF" sgemmc :: CInt -> CInt -> TFMFMFM
57foreign import ccall "multiplyQ" cgemmc :: CInt -> CInt -> TQMQMQM
56 58
57isT MF{} = 0 59isT MF{} = 0
58isT MC{} = 1 60isT MC{} = 1
@@ -69,12 +71,20 @@ multiplyAux f st a b = unsafePerformIO $ do
69 71
70-- | Matrix product based on BLAS's /dgemm/. 72-- | Matrix product based on BLAS's /dgemm/.
71multiplyR :: Matrix Double -> Matrix Double -> Matrix Double 73multiplyR :: Matrix Double -> Matrix Double -> Matrix Double
72multiplyR a b = multiplyAux dgemmc "dgemmc" a b 74multiplyR a b = {-# SCC "multiplyR" #-} multiplyAux dgemmc "dgemmc" a b
73 75
74-- | Matrix product based on BLAS's /zgemm/. 76-- | Matrix product based on BLAS's /zgemm/.
75multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 77multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
76multiplyC a b = multiplyAux zgemmc "zgemmc" a b 78multiplyC a b = multiplyAux zgemmc "zgemmc" a b
77 79
80-- | Matrix product based on BLAS's /sgemm/.
81multiplyF :: Matrix Float -> Matrix Float -> Matrix Float
82multiplyF a b = multiplyAux sgemmc "sgemmc" a b
83
84-- | Matrix product based on BLAS's /cgemm/.
85multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float)
86multiplyQ a b = multiplyAux cgemmc "cgemmc" a b
87
78----------------------------------------------------------------------------- 88-----------------------------------------------------------------------------
79foreign import ccall "svd_l_R" dgesvd :: TMMVM 89foreign import ccall "svd_l_R" dgesvd :: TMMVM
80foreign import ccall "svd_l_C" zgesvd :: TCMCMVCM 90foreign import ccall "svd_l_C" zgesvd :: TCMCMVCM
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
index 7a40991..9e44431 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
@@ -11,15 +11,25 @@
11 11
12#define MIN(A,B) ((A)<(B)?(A):(B)) 12#define MIN(A,B) ((A)<(B)?(A):(B))
13#define MAX(A,B) ((A)>(B)?(A):(B)) 13#define MAX(A,B) ((A)>(B)?(A):(B))
14 14
15// #define DBGL
16
15#ifdef DBGL 17#ifdef DBGL
16#define DEBUGMSG(M) printf("LAPACK Wrapper "M"\n: "); size_t t0 = time(NULL); 18#define DEBUGMSG(M) printf("\nLAPACK "M"\n");
17#define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;);
18#else 19#else
19#define DEBUGMSG(M) 20#define DEBUGMSG(M)
20#define OK return 0;
21#endif 21#endif
22 22
23#define OK return 0;
24
25// #ifdef DBGL
26// #define DEBUGMSG(M) printf("LAPACK Wrapper "M"\n: "); size_t t0 = time(NULL);
27// #define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;);
28// #else
29// #define DEBUGMSG(M)
30// #define OK return 0;
31// #endif
32
23#define TRACEMAT(M) {int q; printf(" %d x %d: ",M##r,M##c); \ 33#define TRACEMAT(M) {int q; printf(" %d x %d: ",M##r,M##c); \
24 for(q=0;q<M##r*M##c;q++) printf("%.1f ",M##p[q]); printf("\n");} 34 for(q=0;q<M##r*M##c;q++) printf("%.1f ",M##p[q]); printf("\n");}
25 35
@@ -1004,6 +1014,7 @@ void dgemm_(char *, char *, integer *, integer *, integer *,
1004 1014
1005int multiplyR(int ta, int tb, KDMAT(a),KDMAT(b),DMAT(r)) { 1015int multiplyR(int ta, int tb, KDMAT(a),KDMAT(b),DMAT(r)) {
1006 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); 1016 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1017 DEBUGMSG("dgemm_");
1007 integer m = ta?ac:ar; 1018 integer m = ta?ac:ar;
1008 integer n = tb?br:bc; 1019 integer n = tb?br:bc;
1009 integer k = ta?ar:ac; 1020 integer k = ta?ar:ac;
@@ -1022,6 +1033,7 @@ void zgemm_(char *, char *, integer *, integer *, integer *,
1022 1033
1023int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)) { 1034int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)) {
1024 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); 1035 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1036 DEBUGMSG("zgemm_");
1025 integer m = ta?ac:ar; 1037 integer m = ta?ac:ar;
1026 integer n = tb?br:bc; 1038 integer n = tb?br:bc;
1027 integer k = ta?ar:ac; 1039 integer k = ta?ar:ac;
@@ -1037,6 +1049,47 @@ int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)) {
1037 OK 1049 OK
1038} 1050}
1039 1051
1052void sgemm_(char *, char *, integer *, integer *, integer *,
1053 float *, const float *, integer *, const float *,
1054 integer *, float *, float *, integer *);
1055
1056int multiplyF(int ta, int tb, KFMAT(a),KFMAT(b),FMAT(r)) {
1057 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1058 DEBUGMSG("sgemm_");
1059 integer m = ta?ac:ar;
1060 integer n = tb?br:bc;
1061 integer k = ta?ar:ac;
1062 integer lda = ar;
1063 integer ldb = br;
1064 integer ldc = rr;
1065 float alpha = 1;
1066 float beta = 0;
1067 sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc);
1068 OK
1069}
1070
1071void cgemm_(char *, char *, integer *, integer *, integer *,
1072 complex *, const complex *, integer *, const complex *,
1073 integer *, complex *, complex *, integer *);
1074
1075int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) {
1076 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1077 DEBUGMSG("cgemm_");
1078 integer m = ta?ac:ar;
1079 integer n = tb?br:bc;
1080 integer k = ta?ar:ac;
1081 integer lda = ar;
1082 integer ldb = br;
1083 integer ldc = rr;
1084 complex alpha = {1,0};
1085 complex beta = {0,0};
1086 cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,
1087 (complex*)ap,&lda,
1088 (complex*)bp,&ldb,&beta,
1089 (complex*)rp,&ldc);
1090 OK
1091}
1092
1040//////////////////// transpose ///////////////////////// 1093//////////////////// transpose /////////////////////////
1041 1094
1042int transF(KFMAT(x),FMAT(t)) { 1095int transF(KFMAT(x),FMAT(t)) {
@@ -1128,3 +1181,23 @@ int constantC(doublecomplex* pval, CVEC(r)) {
1128 } 1181 }
1129 OK 1182 OK
1130} 1183}
1184
1185//////////////////// float-double conversion /////////////////////////
1186
1187int float2double(FVEC(x),DVEC(y)) {
1188 DEBUGMSG("float2double")
1189 int k;
1190 for(k=0;k<xn;k++) {
1191 yp[k]=xp[k];
1192 }
1193 OK
1194}
1195
1196int double2float(DVEC(x),FVEC(y)) {
1197 DEBUGMSG("double2float")
1198 int k;
1199 for(k=0;k<xn;k++) {
1200 yp[k]=xp[k];
1201 }
1202 OK
1203}
diff --git a/lib/Numeric/LinearAlgebra/Linear.hs b/lib/Numeric/LinearAlgebra/Linear.hs
index 51e93fb..ae48245 100644
--- a/lib/Numeric/LinearAlgebra/Linear.hs
+++ b/lib/Numeric/LinearAlgebra/Linear.hs
@@ -19,15 +19,19 @@ module Numeric.LinearAlgebra.Linear (
19 -- * Linear Algebra Typeclasses 19 -- * Linear Algebra Typeclasses
20 Vectors(..), 20 Vectors(..),
21 Linear(..), 21 Linear(..),
22 -- * Products
23 Prod(..),
24 mXm,mXv,vXm, mulH,
25 outer, kronecker,
22 -- * Creation of numeric vectors 26 -- * Creation of numeric vectors
23 constant, linspace 27 constant, linspace
24) where 28) where
25 29
26import Data.Packed.Internal.Vector 30import Data.Packed.Internal
27import Data.Packed.Internal.Matrix
28import Data.Packed.Matrix 31import Data.Packed.Matrix
29import Data.Complex 32import Data.Complex
30import Numeric.GSL.Vector 33import Numeric.GSL.Vector
34import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ)
31 35
32import Control.Monad(ap) 36import Control.Monad(ap)
33 37
@@ -86,7 +90,7 @@ instance Vectors Vector (Complex Double) where
86---------------------------------------------------- 90----------------------------------------------------
87 91
88-- | Basic element-by-element functions. 92-- | Basic element-by-element functions.
89class (Element e, AutoReal e, Convert e, Container c) => Linear c e where 93class (Element e, AutoReal e, Container c) => Linear c e where
90 -- | create a structure with a single element 94 -- | create a structure with a single element
91 scalar :: e -> c e 95 scalar :: e -> c e
92 scale :: e -> c e -> c e 96 scale :: e -> c e -> c e
@@ -184,3 +188,83 @@ linspace :: (Enum e, Linear Vector e) => Int -> (e, e) -> Vector e
184linspace n (a,b) = addConstant a $ scale s $ fromList [0 .. fromIntegral n-1] 188linspace n (a,b) = addConstant a $ scale s $ fromList [0 .. fromIntegral n-1]
185 where s = (b-a)/fromIntegral (n-1) 189 where s = (b-a)/fromIntegral (n-1)
186 190
191----------------------------------------------------
192
193-- reference multiply
194mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ]
195 where doth u v = sum $ zipWith (*) (toList u) (toList v)
196
197class Element t => Prod t where
198 multiply :: Matrix t -> Matrix t -> Matrix t
199 multiply = mulH
200 ctrans :: Matrix t -> Matrix t
201
202instance Prod Double where
203 multiply = multiplyR
204 ctrans = trans
205
206instance Prod (Complex Double) where
207 multiply = multiplyC
208 ctrans = conj . trans
209
210instance Prod Float where
211 multiply = multiplyF
212 ctrans = trans
213
214instance Prod (Complex Float) where
215 multiply = multiplyQ
216 ctrans = conj . trans
217
218----------------------------------------------------------
219
220-- synonym for matrix product
221mXm :: Prod t => Matrix t -> Matrix t -> Matrix t
222mXm = multiply
223
224-- matrix - vector product
225mXv :: Prod t => Matrix t -> Vector t -> Vector t
226mXv m v = flatten $ m `mXm` (asColumn v)
227
228-- vector - matrix product
229vXm :: Prod t => Vector t -> Matrix t -> Vector t
230vXm v m = flatten $ (asRow v) `mXm` m
231
232{- | Outer product of two vectors.
233
234@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3]
235(3><3)
236 [ 5.0, 2.0, 3.0
237 , 10.0, 4.0, 6.0
238 , 15.0, 6.0, 9.0 ]@
239-}
240outer :: (Prod t) => Vector t -> Vector t -> Matrix t
241outer u v = asColumn u `multiply` asRow v
242
243{- | Kronecker product of two matrices.
244
245@m1=(2><3)
246 [ 1.0, 2.0, 0.0
247 , 0.0, -1.0, 3.0 ]
248m2=(4><3)
249 [ 1.0, 2.0, 3.0
250 , 4.0, 5.0, 6.0
251 , 7.0, 8.0, 9.0
252 , 10.0, 11.0, 12.0 ]@
253
254@\> kronecker m1 m2
255(8><9)
256 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
257 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
258 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
259 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
260 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
261 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
262 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
263 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@
264-}
265kronecker :: (Prod t) => Matrix t -> Matrix t -> Matrix t
266kronecker a b = fromBlocks
267 . splitEvery (cols a)
268 . map (reshape (cols b))
269 . toRows
270 $ flatten a `outer` flatten b
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index e3b6e1f..91f6742 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -34,6 +34,7 @@ import qualified Prelude
34import System.CPUTime 34import System.CPUTime
35import Text.Printf 35import Text.Printf
36import Data.Packed.Development(unsafeFromForeignPtr,unsafeToForeignPtr) 36import Data.Packed.Development(unsafeFromForeignPtr,unsafeToForeignPtr)
37import Control.Arrow((***))
37 38
38#include "Tests/quickCheckCompat.h" 39#include "Tests/quickCheckCompat.h"
39 40
@@ -224,11 +225,16 @@ runTests :: Int -- ^ maximum dimension
224runTests n = do 225runTests n = do
225 setErrorHandlerOff 226 setErrorHandlerOff
226 let test p = qCheck n p 227 let test p = qCheck n p
227 putStrLn "------ mult" 228 putStrLn "------ mult Double"
228 test (multProp1 . rConsist) 229 test (multProp1 10 . rConsist)
229 test (multProp1 . cConsist) 230 test (multProp1 10 . cConsist)
230 test (multProp2 . rConsist) 231 test (multProp2 10 . rConsist)
231 test (multProp2 . cConsist) 232 test (multProp2 10 . cConsist)
233 putStrLn "------ mult Float"
234 test (multProp1 6 . (single *** single) . rConsist)
235 test (multProp1 6 . (single *** single) . cConsist)
236 test (multProp2 6 . (single *** single) . rConsist)
237 test (multProp2 6 . (single *** single) . cConsist)
232 putStrLn "------ sub-trans" 238 putStrLn "------ sub-trans"
233 test (subProp . rM) 239 test (subProp . rM)
234 test (subProp . cM) 240 test (subProp . cM)
diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
index d29e19a..f7a948e 100644
--- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs
+++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
@@ -42,7 +42,7 @@ module Numeric.LinearAlgebra.Tests.Properties (
42 linearSolveProp, linearSolveProp2 42 linearSolveProp, linearSolveProp2
43) where 43) where
44 44
45import Numeric.LinearAlgebra 45import Numeric.LinearAlgebra hiding (mulH)
46import Numeric.LinearAlgebra.LAPACK 46import Numeric.LinearAlgebra.LAPACK
47import Debug.Trace 47import Debug.Trace
48#include "quickCheckCompat.h" 48#include "quickCheckCompat.h"
@@ -237,9 +237,9 @@ expmDiagProp m = expm (logm m) :~ 7 ~: complex m
237mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ] 237mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ]
238 where doth u v = sum $ zipWith (*) (toList u) (toList v) 238 where doth u v = sum $ zipWith (*) (toList u) (toList v)
239 239
240multProp1 (a,b) = a <> b |~| mulH a b 240multProp1 p (a,b) = (a <> b) :~p~: (mulH a b)
241 241
242multProp2 (a,b) = ctrans (a <> b) |~| ctrans b <> ctrans a 242multProp2 p (a,b) = (ctrans (a <> b)) :~p~: (ctrans b <> ctrans a)
243 243
244linearSolveProp f m = f m m |~| ident (rows m) 244linearSolveProp f m = f m m |~| ident (rows m)
245 245