summaryrefslogtreecommitdiff
path: root/lib/Numeric/LinearAlgebra
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2008-10-22 12:59:18 +0000
committerAlberto Ruiz <aruiz@um.es>2008-10-22 12:59:18 +0000
commitfaeaf6d261b760e628c1e63551d822d16876c0cc (patch)
tree45e3e2d1460d72e1fd037e19d4470963b75cc00e /lib/Numeric/LinearAlgebra
parent9d9b1274a522e1bf0c5dea210765a0368ebb74a5 (diff)
-Wall
Diffstat (limited to 'lib/Numeric/LinearAlgebra')
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs38
-rw-r--r--lib/Numeric/LinearAlgebra/Instances.hs7
-rw-r--r--lib/Numeric/LinearAlgebra/Linear.hs2
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs8
-rw-r--r--lib/Numeric/LinearAlgebra/Tests/Instances.hs6
-rw-r--r--lib/Numeric/LinearAlgebra/Tests/Properties.hs15
6 files changed, 38 insertions, 38 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index fbefa68..45298b5 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -168,9 +168,9 @@ square m = rows m == cols m
168 168
169-- | determinant of a square matrix, computed from the LU decomposition. 169-- | determinant of a square matrix, computed from the LU decomposition.
170det :: Field t => Matrix t -> t 170det :: Field t => Matrix t -> t
171det m | square m = s * (product $ toList $ takeDiag $ lu) 171det m | square m = s * (product $ toList $ takeDiag $ lup)
172 | otherwise = error "det of nonsquare matrix" 172 | otherwise = error "det of nonsquare matrix"
173 where (lu,perm) = luPacked m 173 where (lup,perm) = luPacked m
174 s = signlp (rows m) perm 174 s = signlp (rows m) perm
175 175
176-- | LU factorization of a general matrix using lapack's dgetrf or zgetrf. 176-- | LU factorization of a general matrix using lapack's dgetrf or zgetrf.
@@ -501,21 +501,21 @@ fixPerm r vals = (fromColumns $ elems res, sign)
501 s = toColumns (ident r) 501 s = toColumns (ident r)
502 (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals) 502 (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals)
503 503
504triang r c h v = reshape c $ fromList [el i j | i<-[0..r-1], j<-[0..c-1]] 504triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]]
505 where el i j = if j-i>=h then v else 1 - v 505 where el p q = if q-p>=h then v else 1 - v
506 506
507luFact (lu,perm) | r <= c = (l ,u ,p, s) 507luFact (l_u,perm) | r <= c = (l ,u ,p, s)
508 | otherwise = (l',u',p, s) 508 | otherwise = (l',u',p, s)
509 where 509 where
510 r = rows lu 510 r = rows l_u
511 c = cols lu 511 c = cols l_u
512 tu = triang r c 0 1 512 tu = triang r c 0 1
513 tl = triang r c 0 0 513 tl = triang r c 0 0
514 l = takeColumns r (lu |*| tl) |+| diagRect (constant 1 r) r r 514 l = takeColumns r (l_u |*| tl) |+| diagRect (constant 1 r) r r
515 u = lu |*| tu 515 u = l_u |*| tu
516 (p,s) = fixPerm r perm 516 (p,s) = fixPerm r perm
517 l' = (lu |*| tl) |+| diagRect (constant 1 c) r c 517 l' = (l_u |*| tl) |+| diagRect (constant 1 c) r c
518 u' = takeRows c (lu |*| tu) 518 u' = takeRows c (l_u |*| tu)
519 (|+|) = add 519 (|+|) = add
520 (|*|) = mul 520 (|*|) = mul
521 521
@@ -572,8 +572,8 @@ kronecker a b = fromBlocks
572-- reference multiply 572-- reference multiply
573--------------------------------------------------------------------- 573---------------------------------------------------------------------
574 574
575mulH a b = fromLists [[ dot ai bj | bj <- toColumns b] | ai <- toRows a ] 575mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ]
576 where dot u v = sum $ zipWith (*) (toList u) (toList v) 576 where doth u v = sum $ zipWith (*) (toList u) (toList v)
577 577
578----------------------------------------------------------------------------------- 578-----------------------------------------------------------------------------------
579-- workaround 579-- workaround
@@ -599,7 +599,7 @@ colMajor = CBLASOrder 102
599 599
600noTrans, trans', conjTrans :: CBLASTrans 600noTrans, trans', conjTrans :: CBLASTrans
601noTrans = CBLASTrans 111 601noTrans = CBLASTrans 111
602trans' = CBLASTrans 112 602trans' = CBLASTrans 112
603conjTrans = CBLASTrans 113 603conjTrans = CBLASTrans 113
604 604
605foreign import ccall "cblas.h cblas_dgemm" 605foreign import ccall "cblas.h cblas_dgemm"
@@ -608,12 +608,12 @@ foreign import ccall "cblas.h cblas_dgemm"
608 -> Ptr Double -> CInt -> IO () 608 -> Ptr Double -> CInt -> IO ()
609 609
610multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double 610multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double
611multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) 611multiplyR3 x y = multiply3 dgemm "cblas_dgemm" (fmat x) (fmat y)
612 where 612 where
613 multiply3 f st a b 613 multiply3 f st a b
614 | cols a == rows b = unsafePerformIO $ do 614 | cols a == rows b = unsafePerformIO $ do
615 s <- createMatrix ColumnMajor (rows a) (cols b) 615 s <- createMatrix ColumnMajor (rows a) (cols b)
616 let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0 616 let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0
617 app3 g mat a mat b mat s st 617 app3 g mat a mat b mat s st
618 return s 618 return s
619 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" 619 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
@@ -625,14 +625,14 @@ foreign import ccall "cblas.h cblas_zgemm"
625 -> Ptr (Complex Double) -> CInt -> IO () 625 -> Ptr (Complex Double) -> CInt -> IO ()
626 626
627multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 627multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
628multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b) 628multiplyC3 x y = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat x) (fmat y)
629 where 629 where
630 multiply3 f st a b 630 multiply3 f st a b
631 | cols a == rows b = do 631 | cols a == rows b = do
632 s <- createMatrix ColumnMajor (rows a) (cols b) 632 s <- createMatrix ColumnMajor (rows a) (cols b)
633 palpha <- new 1 633 palpha <- new 1
634 pbeta <- new 0 634 pbeta <- new 0
635 let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0 635 let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0
636 app3 g mat a mat b mat s st 636 app3 g mat a mat b mat s st
637 free palpha 637 free palpha
638 free pbeta 638 free pbeta
diff --git a/lib/Numeric/LinearAlgebra/Instances.hs b/lib/Numeric/LinearAlgebra/Instances.hs
index 334ccff..79a8990 100644
--- a/lib/Numeric/LinearAlgebra/Instances.hs
+++ b/lib/Numeric/LinearAlgebra/Instances.hs
@@ -22,7 +22,6 @@ module Numeric.LinearAlgebra.Instances(
22import Numeric.LinearAlgebra.Linear 22import Numeric.LinearAlgebra.Linear
23import Numeric.GSL.Vector 23import Numeric.GSL.Vector
24import Data.Packed.Matrix 24import Data.Packed.Matrix
25import Data.Packed.Vector
26import Complex 25import Complex
27import Data.List(transpose,intersperse) 26import Data.List(transpose,intersperse)
28import Foreign(Storable) 27import Foreign(Storable)
@@ -49,11 +48,11 @@ instance (Show a, Storable a) => (Show (Vector a)) where
49------------------------------------------------------------------ 48------------------------------------------------------------------
50 49
51instance (Element a, Read a) => Read (Matrix a) where 50instance (Element a, Read a) => Read (Matrix a) where
52 readsPrec _ s = [((rows><cols) . read $ listnums, rest)] 51 readsPrec _ s = [((rs><cs) . read $ listnums, rest)]
53 where (thing,rest) = breakAt ']' s 52 where (thing,rest) = breakAt ']' s
54 (dims,listnums) = breakAt ')' thing 53 (dims,listnums) = breakAt ')' thing
55 cols = read . init . fst. breakAt ')' . snd . breakAt '<' $ dims 54 cs = read . init . fst. breakAt ')' . snd . breakAt '<' $ dims
56 rows = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims 55 rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims
57 56
58instance (Element a, Read a) => Read (Vector a) where 57instance (Element a, Read a) => Read (Vector a) where
59 readsPrec _ s = [((d |>) . read $ listnums, rest)] 58 readsPrec _ s = [((d |>) . read $ listnums, rest)]
diff --git a/lib/Numeric/LinearAlgebra/Linear.hs b/lib/Numeric/LinearAlgebra/Linear.hs
index 3af9960..9bf60e2 100644
--- a/lib/Numeric/LinearAlgebra/Linear.hs
+++ b/lib/Numeric/LinearAlgebra/Linear.hs
@@ -18,8 +18,6 @@ module Numeric.LinearAlgebra.Linear (
18 Linear(..) 18 Linear(..)
19) where 19) where
20 20
21
22import Data.Packed.Internal(partit)
23import Data.Packed 21import Data.Packed
24import Numeric.GSL.Vector 22import Numeric.GSL.Vector
25import Complex 23import Complex
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index 07b9f63..7ebd1f2 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -22,11 +22,15 @@ module Numeric.LinearAlgebra.Tests(
22import Numeric.LinearAlgebra 22import Numeric.LinearAlgebra
23import Numeric.LinearAlgebra.Tests.Instances 23import Numeric.LinearAlgebra.Tests.Instances
24import Numeric.LinearAlgebra.Tests.Properties 24import Numeric.LinearAlgebra.Tests.Properties
25import Test.QuickCheck 25import Test.QuickCheck hiding (test)
26import Test.HUnit hiding ((~:),test) 26import Test.HUnit hiding ((~:),test)
27import System.Info 27import System.Info
28import Data.List(foldl1') 28import Data.List(foldl1')
29import Numeric.GSL hiding (sin,cos,exp,choose) 29import Numeric.GSL hiding (sin,cos,exp,choose)
30import Prelude hiding ((^))
31import qualified Prelude
32
33a ^ b = a Prelude.^ (b :: Int)
30 34
31qCheck n = check defaultConfig {configSize = const n} 35qCheck n = check defaultConfig {configSize = const n}
32 36
@@ -73,7 +77,7 @@ besselTest = utest "bessel_J0_e" ( abs (r-expected) < e )
73 expected = -0.17759677131433830434739701 77 expected = -0.17759677131433830434739701
74 78
75exponentialTest = utest "exp_e10_e" ( abs (v*10^e - expected) < 4E-2 ) 79exponentialTest = utest "exp_e10_e" ( abs (v*10^e - expected) < 4E-2 )
76 where (v,e,err) = exp_e10_e 30.0 80 where (v,e,_err) = exp_e10_e 30.0
77 expected = exp 30.0 81 expected = exp 30.0
78 82
79--------------------------------------------------------------------- 83---------------------------------------------------------------------
diff --git a/lib/Numeric/LinearAlgebra/Tests/Instances.hs b/lib/Numeric/LinearAlgebra/Tests/Instances.hs
index 677ad2b..4e829d2 100644
--- a/lib/Numeric/LinearAlgebra/Tests/Instances.hs
+++ b/lib/Numeric/LinearAlgebra/Tests/Instances.hs
@@ -30,9 +30,9 @@ import Control.Monad(replicateM)
30 30
31instance (Arbitrary a, RealFloat a) => Arbitrary (Complex a) where 31instance (Arbitrary a, RealFloat a) => Arbitrary (Complex a) where
32 arbitrary = do 32 arbitrary = do
33 r <- arbitrary 33 re <- arbitrary
34 i <- arbitrary 34 im <- arbitrary
35 return (r:+i) 35 return (re :+ im)
36 coarbitrary = undefined 36 coarbitrary = undefined
37 37
38chooseDim = sized $ \m -> choose (1,max 1 m) 38chooseDim = sized $ \m -> choose (1,max 1 m)
diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
index 5663b86..b5321c2 100644
--- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs
+++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
@@ -39,11 +39,10 @@ module Numeric.LinearAlgebra.Tests.Properties (
39) where 39) where
40 40
41import Numeric.LinearAlgebra 41import Numeric.LinearAlgebra
42import Numeric.LinearAlgebra.Tests.Instances(Sq(..),Her(..),Rot(..))
43import Test.QuickCheck 42import Test.QuickCheck
44import Debug.Trace 43-- import Debug.Trace
45 44
46debug x = trace (show x) x 45-- debug x = trace (show x) x
47 46
48-- relative error 47-- relative error
49dist :: (Normed t, Num t) => t -> t -> Double 48dist :: (Normed t, Num t) => t -> t -> Double
@@ -77,7 +76,7 @@ hermitian m = square m && m |~| ctrans m
77wellCond m = rcond m > 1/100 76wellCond m = rcond m > 1/100
78 77
79positiveDefinite m = minimum (toList e) > 0 78positiveDefinite m = minimum (toList e) > 0
80 where (e,v) = eigSH m 79 where (e,_v) = eigSH m
81 80
82upperTriang m = rows m == 1 || down == z 81upperTriang m = rows m == 1 || down == z
83 where down = fromList $ concat $ zipWith drop [1..] (toLists (ctrans m)) 82 where down = fromList $ concat $ zipWith drop [1..] (toLists (ctrans m))
@@ -107,8 +106,8 @@ pinvProp m = m <> p <> m |~| m
107 106
108detProp m = s d1 |~| s d2 107detProp m = s d1 |~| s d2
109 where d1 = det m 108 where d1 = det m
110 d2 = det' m * det q 109 d2 = det' * det q
111 det' m = product $ toList $ takeDiag r 110 det' = product $ toList $ takeDiag r
112 (q,r) = qr m 111 (q,r) = qr m
113 s x = fromList [x] 112 s x = fromList [x]
114 113
@@ -147,10 +146,10 @@ schurProp2 m = m |~| u <> s <> ctrans u && unitary u && upperHessenberg s -- fix
147 146
148cholProp m = m |~| ctrans c <> c && upperTriang c 147cholProp m = m |~| ctrans c <> c && upperTriang c
149 where c = chol m 148 where c = chol m
150 pos = positiveDefinite m 149 -- pos = positiveDefinite m
151 150
152expmDiagProp m = expm (logm m) :~ 7 ~: complex m 151expmDiagProp m = expm (logm m) :~ 7 ~: complex m
153 where logm m = matFunc log m 152 where logm = matFunc log
154 153
155multProp1 (a,b) = a <> b |~| mulH a b 154multProp1 (a,b) = a <> b |~| mulH a b
156 155