diff options
Diffstat (limited to 'lib/Numeric')
-rw-r--r-- | lib/Numeric/GSL/Vector.hs | 2 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 38 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Instances.hs | 7 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Linear.hs | 2 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests.hs | 8 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests/Instances.hs | 6 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests/Properties.hs | 15 |
7 files changed, 39 insertions, 39 deletions
diff --git a/lib/Numeric/GSL/Vector.hs b/lib/Numeric/GSL/Vector.hs index 9e06783..92cda87 100644 --- a/lib/Numeric/GSL/Vector.hs +++ b/lib/Numeric/GSL/Vector.hs | |||
@@ -27,7 +27,7 @@ import Complex | |||
27 | import Foreign | 27 | import Foreign |
28 | import Foreign.C.Types(CInt) | 28 | import Foreign.C.Types(CInt) |
29 | 29 | ||
30 | fromei x = fromIntegral (fromEnum x) | 30 | fromei x = fromIntegral (fromEnum x) :: CInt |
31 | 31 | ||
32 | data FunCodeV = Sin | 32 | data FunCodeV = Sin |
33 | | Cos | 33 | | Cos |
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index fbefa68..45298b5 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -168,9 +168,9 @@ square m = rows m == cols m | |||
168 | 168 | ||
169 | -- | determinant of a square matrix, computed from the LU decomposition. | 169 | -- | determinant of a square matrix, computed from the LU decomposition. |
170 | det :: Field t => Matrix t -> t | 170 | det :: Field t => Matrix t -> t |
171 | det m | square m = s * (product $ toList $ takeDiag $ lu) | 171 | det m | square m = s * (product $ toList $ takeDiag $ lup) |
172 | | otherwise = error "det of nonsquare matrix" | 172 | | otherwise = error "det of nonsquare matrix" |
173 | where (lu,perm) = luPacked m | 173 | where (lup,perm) = luPacked m |
174 | s = signlp (rows m) perm | 174 | s = signlp (rows m) perm |
175 | 175 | ||
176 | -- | LU factorization of a general matrix using lapack's dgetrf or zgetrf. | 176 | -- | LU factorization of a general matrix using lapack's dgetrf or zgetrf. |
@@ -501,21 +501,21 @@ fixPerm r vals = (fromColumns $ elems res, sign) | |||
501 | s = toColumns (ident r) | 501 | s = toColumns (ident r) |
502 | (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals) | 502 | (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals) |
503 | 503 | ||
504 | triang r c h v = reshape c $ fromList [el i j | i<-[0..r-1], j<-[0..c-1]] | 504 | triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]] |
505 | where el i j = if j-i>=h then v else 1 - v | 505 | where el p q = if q-p>=h then v else 1 - v |
506 | 506 | ||
507 | luFact (lu,perm) | r <= c = (l ,u ,p, s) | 507 | luFact (l_u,perm) | r <= c = (l ,u ,p, s) |
508 | | otherwise = (l',u',p, s) | 508 | | otherwise = (l',u',p, s) |
509 | where | 509 | where |
510 | r = rows lu | 510 | r = rows l_u |
511 | c = cols lu | 511 | c = cols l_u |
512 | tu = triang r c 0 1 | 512 | tu = triang r c 0 1 |
513 | tl = triang r c 0 0 | 513 | tl = triang r c 0 0 |
514 | l = takeColumns r (lu |*| tl) |+| diagRect (constant 1 r) r r | 514 | l = takeColumns r (l_u |*| tl) |+| diagRect (constant 1 r) r r |
515 | u = lu |*| tu | 515 | u = l_u |*| tu |
516 | (p,s) = fixPerm r perm | 516 | (p,s) = fixPerm r perm |
517 | l' = (lu |*| tl) |+| diagRect (constant 1 c) r c | 517 | l' = (l_u |*| tl) |+| diagRect (constant 1 c) r c |
518 | u' = takeRows c (lu |*| tu) | 518 | u' = takeRows c (l_u |*| tu) |
519 | (|+|) = add | 519 | (|+|) = add |
520 | (|*|) = mul | 520 | (|*|) = mul |
521 | 521 | ||
@@ -572,8 +572,8 @@ kronecker a b = fromBlocks | |||
572 | -- reference multiply | 572 | -- reference multiply |
573 | --------------------------------------------------------------------- | 573 | --------------------------------------------------------------------- |
574 | 574 | ||
575 | mulH a b = fromLists [[ dot ai bj | bj <- toColumns b] | ai <- toRows a ] | 575 | mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ] |
576 | where dot u v = sum $ zipWith (*) (toList u) (toList v) | 576 | where doth u v = sum $ zipWith (*) (toList u) (toList v) |
577 | 577 | ||
578 | ----------------------------------------------------------------------------------- | 578 | ----------------------------------------------------------------------------------- |
579 | -- workaround | 579 | -- workaround |
@@ -599,7 +599,7 @@ colMajor = CBLASOrder 102 | |||
599 | 599 | ||
600 | noTrans, trans', conjTrans :: CBLASTrans | 600 | noTrans, trans', conjTrans :: CBLASTrans |
601 | noTrans = CBLASTrans 111 | 601 | noTrans = CBLASTrans 111 |
602 | trans' = CBLASTrans 112 | 602 | trans' = CBLASTrans 112 |
603 | conjTrans = CBLASTrans 113 | 603 | conjTrans = CBLASTrans 113 |
604 | 604 | ||
605 | foreign import ccall "cblas.h cblas_dgemm" | 605 | foreign import ccall "cblas.h cblas_dgemm" |
@@ -608,12 +608,12 @@ foreign import ccall "cblas.h cblas_dgemm" | |||
608 | -> Ptr Double -> CInt -> IO () | 608 | -> Ptr Double -> CInt -> IO () |
609 | 609 | ||
610 | multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double | 610 | multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double |
611 | multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) | 611 | multiplyR3 x y = multiply3 dgemm "cblas_dgemm" (fmat x) (fmat y) |
612 | where | 612 | where |
613 | multiply3 f st a b | 613 | multiply3 f st a b |
614 | | cols a == rows b = unsafePerformIO $ do | 614 | | cols a == rows b = unsafePerformIO $ do |
615 | s <- createMatrix ColumnMajor (rows a) (cols b) | 615 | s <- createMatrix ColumnMajor (rows a) (cols b) |
616 | let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0 | 616 | let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0 |
617 | app3 g mat a mat b mat s st | 617 | app3 g mat a mat b mat s st |
618 | return s | 618 | return s |
619 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | 619 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" |
@@ -625,14 +625,14 @@ foreign import ccall "cblas.h cblas_zgemm" | |||
625 | -> Ptr (Complex Double) -> CInt -> IO () | 625 | -> Ptr (Complex Double) -> CInt -> IO () |
626 | 626 | ||
627 | multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 627 | multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
628 | multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b) | 628 | multiplyC3 x y = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat x) (fmat y) |
629 | where | 629 | where |
630 | multiply3 f st a b | 630 | multiply3 f st a b |
631 | | cols a == rows b = do | 631 | | cols a == rows b = do |
632 | s <- createMatrix ColumnMajor (rows a) (cols b) | 632 | s <- createMatrix ColumnMajor (rows a) (cols b) |
633 | palpha <- new 1 | 633 | palpha <- new 1 |
634 | pbeta <- new 0 | 634 | pbeta <- new 0 |
635 | let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0 | 635 | let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0 |
636 | app3 g mat a mat b mat s st | 636 | app3 g mat a mat b mat s st |
637 | free palpha | 637 | free palpha |
638 | free pbeta | 638 | free pbeta |
diff --git a/lib/Numeric/LinearAlgebra/Instances.hs b/lib/Numeric/LinearAlgebra/Instances.hs index 334ccff..79a8990 100644 --- a/lib/Numeric/LinearAlgebra/Instances.hs +++ b/lib/Numeric/LinearAlgebra/Instances.hs | |||
@@ -22,7 +22,6 @@ module Numeric.LinearAlgebra.Instances( | |||
22 | import Numeric.LinearAlgebra.Linear | 22 | import Numeric.LinearAlgebra.Linear |
23 | import Numeric.GSL.Vector | 23 | import Numeric.GSL.Vector |
24 | import Data.Packed.Matrix | 24 | import Data.Packed.Matrix |
25 | import Data.Packed.Vector | ||
26 | import Complex | 25 | import Complex |
27 | import Data.List(transpose,intersperse) | 26 | import Data.List(transpose,intersperse) |
28 | import Foreign(Storable) | 27 | import Foreign(Storable) |
@@ -49,11 +48,11 @@ instance (Show a, Storable a) => (Show (Vector a)) where | |||
49 | ------------------------------------------------------------------ | 48 | ------------------------------------------------------------------ |
50 | 49 | ||
51 | instance (Element a, Read a) => Read (Matrix a) where | 50 | instance (Element a, Read a) => Read (Matrix a) where |
52 | readsPrec _ s = [((rows><cols) . read $ listnums, rest)] | 51 | readsPrec _ s = [((rs><cs) . read $ listnums, rest)] |
53 | where (thing,rest) = breakAt ']' s | 52 | where (thing,rest) = breakAt ']' s |
54 | (dims,listnums) = breakAt ')' thing | 53 | (dims,listnums) = breakAt ')' thing |
55 | cols = read . init . fst. breakAt ')' . snd . breakAt '<' $ dims | 54 | cs = read . init . fst. breakAt ')' . snd . breakAt '<' $ dims |
56 | rows = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims | 55 | rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims |
57 | 56 | ||
58 | instance (Element a, Read a) => Read (Vector a) where | 57 | instance (Element a, Read a) => Read (Vector a) where |
59 | readsPrec _ s = [((d |>) . read $ listnums, rest)] | 58 | readsPrec _ s = [((d |>) . read $ listnums, rest)] |
diff --git a/lib/Numeric/LinearAlgebra/Linear.hs b/lib/Numeric/LinearAlgebra/Linear.hs index 3af9960..9bf60e2 100644 --- a/lib/Numeric/LinearAlgebra/Linear.hs +++ b/lib/Numeric/LinearAlgebra/Linear.hs | |||
@@ -18,8 +18,6 @@ module Numeric.LinearAlgebra.Linear ( | |||
18 | Linear(..) | 18 | Linear(..) |
19 | ) where | 19 | ) where |
20 | 20 | ||
21 | |||
22 | import Data.Packed.Internal(partit) | ||
23 | import Data.Packed | 21 | import Data.Packed |
24 | import Numeric.GSL.Vector | 22 | import Numeric.GSL.Vector |
25 | import Complex | 23 | import Complex |
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 07b9f63..7ebd1f2 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs | |||
@@ -22,11 +22,15 @@ module Numeric.LinearAlgebra.Tests( | |||
22 | import Numeric.LinearAlgebra | 22 | import Numeric.LinearAlgebra |
23 | import Numeric.LinearAlgebra.Tests.Instances | 23 | import Numeric.LinearAlgebra.Tests.Instances |
24 | import Numeric.LinearAlgebra.Tests.Properties | 24 | import Numeric.LinearAlgebra.Tests.Properties |
25 | import Test.QuickCheck | 25 | import Test.QuickCheck hiding (test) |
26 | import Test.HUnit hiding ((~:),test) | 26 | import Test.HUnit hiding ((~:),test) |
27 | import System.Info | 27 | import System.Info |
28 | import Data.List(foldl1') | 28 | import Data.List(foldl1') |
29 | import Numeric.GSL hiding (sin,cos,exp,choose) | 29 | import Numeric.GSL hiding (sin,cos,exp,choose) |
30 | import Prelude hiding ((^)) | ||
31 | import qualified Prelude | ||
32 | |||
33 | a ^ b = a Prelude.^ (b :: Int) | ||
30 | 34 | ||
31 | qCheck n = check defaultConfig {configSize = const n} | 35 | qCheck n = check defaultConfig {configSize = const n} |
32 | 36 | ||
@@ -73,7 +77,7 @@ besselTest = utest "bessel_J0_e" ( abs (r-expected) < e ) | |||
73 | expected = -0.17759677131433830434739701 | 77 | expected = -0.17759677131433830434739701 |
74 | 78 | ||
75 | exponentialTest = utest "exp_e10_e" ( abs (v*10^e - expected) < 4E-2 ) | 79 | exponentialTest = utest "exp_e10_e" ( abs (v*10^e - expected) < 4E-2 ) |
76 | where (v,e,err) = exp_e10_e 30.0 | 80 | where (v,e,_err) = exp_e10_e 30.0 |
77 | expected = exp 30.0 | 81 | expected = exp 30.0 |
78 | 82 | ||
79 | --------------------------------------------------------------------- | 83 | --------------------------------------------------------------------- |
diff --git a/lib/Numeric/LinearAlgebra/Tests/Instances.hs b/lib/Numeric/LinearAlgebra/Tests/Instances.hs index 677ad2b..4e829d2 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Instances.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Instances.hs | |||
@@ -30,9 +30,9 @@ import Control.Monad(replicateM) | |||
30 | 30 | ||
31 | instance (Arbitrary a, RealFloat a) => Arbitrary (Complex a) where | 31 | instance (Arbitrary a, RealFloat a) => Arbitrary (Complex a) where |
32 | arbitrary = do | 32 | arbitrary = do |
33 | r <- arbitrary | 33 | re <- arbitrary |
34 | i <- arbitrary | 34 | im <- arbitrary |
35 | return (r:+i) | 35 | return (re :+ im) |
36 | coarbitrary = undefined | 36 | coarbitrary = undefined |
37 | 37 | ||
38 | chooseDim = sized $ \m -> choose (1,max 1 m) | 38 | chooseDim = sized $ \m -> choose (1,max 1 m) |
diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs index 5663b86..b5321c2 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs | |||
@@ -39,11 +39,10 @@ module Numeric.LinearAlgebra.Tests.Properties ( | |||
39 | ) where | 39 | ) where |
40 | 40 | ||
41 | import Numeric.LinearAlgebra | 41 | import Numeric.LinearAlgebra |
42 | import Numeric.LinearAlgebra.Tests.Instances(Sq(..),Her(..),Rot(..)) | ||
43 | import Test.QuickCheck | 42 | import Test.QuickCheck |
44 | import Debug.Trace | 43 | -- import Debug.Trace |
45 | 44 | ||
46 | debug x = trace (show x) x | 45 | -- debug x = trace (show x) x |
47 | 46 | ||
48 | -- relative error | 47 | -- relative error |
49 | dist :: (Normed t, Num t) => t -> t -> Double | 48 | dist :: (Normed t, Num t) => t -> t -> Double |
@@ -77,7 +76,7 @@ hermitian m = square m && m |~| ctrans m | |||
77 | wellCond m = rcond m > 1/100 | 76 | wellCond m = rcond m > 1/100 |
78 | 77 | ||
79 | positiveDefinite m = minimum (toList e) > 0 | 78 | positiveDefinite m = minimum (toList e) > 0 |
80 | where (e,v) = eigSH m | 79 | where (e,_v) = eigSH m |
81 | 80 | ||
82 | upperTriang m = rows m == 1 || down == z | 81 | upperTriang m = rows m == 1 || down == z |
83 | where down = fromList $ concat $ zipWith drop [1..] (toLists (ctrans m)) | 82 | where down = fromList $ concat $ zipWith drop [1..] (toLists (ctrans m)) |
@@ -107,8 +106,8 @@ pinvProp m = m <> p <> m |~| m | |||
107 | 106 | ||
108 | detProp m = s d1 |~| s d2 | 107 | detProp m = s d1 |~| s d2 |
109 | where d1 = det m | 108 | where d1 = det m |
110 | d2 = det' m * det q | 109 | d2 = det' * det q |
111 | det' m = product $ toList $ takeDiag r | 110 | det' = product $ toList $ takeDiag r |
112 | (q,r) = qr m | 111 | (q,r) = qr m |
113 | s x = fromList [x] | 112 | s x = fromList [x] |
114 | 113 | ||
@@ -147,10 +146,10 @@ schurProp2 m = m |~| u <> s <> ctrans u && unitary u && upperHessenberg s -- fix | |||
147 | 146 | ||
148 | cholProp m = m |~| ctrans c <> c && upperTriang c | 147 | cholProp m = m |~| ctrans c <> c && upperTriang c |
149 | where c = chol m | 148 | where c = chol m |
150 | pos = positiveDefinite m | 149 | -- pos = positiveDefinite m |
151 | 150 | ||
152 | expmDiagProp m = expm (logm m) :~ 7 ~: complex m | 151 | expmDiagProp m = expm (logm m) :~ 7 ~: complex m |
153 | where logm m = matFunc log m | 152 | where logm = matFunc log |
154 | 153 | ||
155 | multProp1 (a,b) = a <> b |~| mulH a b | 154 | multProp1 (a,b) = a <> b |~| mulH a b |
156 | 155 | ||