summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/base/src/Internal/Matrix.hs1
-rw-r--r--packages/base/src/Internal/Static.hs24
-rw-r--r--packages/tests/hmatrix-tests.cabal1
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests.hs10
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs46
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs44
6 files changed, 116 insertions, 10 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 6efbe5f..c47c625 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -5,6 +5,7 @@
5{-# LANGUAGE TypeOperators #-} 5{-# LANGUAGE TypeOperators #-}
6{-# LANGUAGE TypeFamilies #-} 6{-# LANGUAGE TypeFamilies #-}
7{-# LANGUAGE ViewPatterns #-} 7{-# LANGUAGE ViewPatterns #-}
8{-# LANGUAGE DeriveGeneric #-}
8{-# LANGUAGE ConstrainedClassMethods #-} 9{-# LANGUAGE ConstrainedClassMethods #-}
9 10
10-- | 11-- |
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs
index 2c31097..1e9a5a3 100644
--- a/packages/base/src/Internal/Static.hs
+++ b/packages/base/src/Internal/Static.hs
@@ -13,6 +13,7 @@
13{-# LANGUAGE TypeOperators #-} 13{-# LANGUAGE TypeOperators #-}
14{-# LANGUAGE ViewPatterns #-} 14{-# LANGUAGE ViewPatterns #-}
15{-# LANGUAGE BangPatterns #-} 15{-# LANGUAGE BangPatterns #-}
16{-# LANGUAGE DeriveGeneric #-}
16 17
17{- | 18{- |
18Module : Internal.Static 19Module : Internal.Static
@@ -34,6 +35,8 @@ import Control.DeepSeq
34import Data.Proxy(Proxy) 35import Data.Proxy(Proxy)
35import Foreign.Storable(Storable) 36import Foreign.Storable(Storable)
36import Text.Printf 37import Text.Printf
38import Data.Binary
39import GHC.Generics (Generic)
37 40
38-------------------------------------------------------------------------------- 41--------------------------------------------------------------------------------
39 42
@@ -41,7 +44,14 @@ type ℝ = Double
41type ℂ = Complex Double 44type ℂ = Complex Double
42 45
43newtype Dim (n :: Nat) t = Dim t 46newtype Dim (n :: Nat) t = Dim t
44 deriving Show 47 deriving (Show, Generic)
48
49instance Binary a => Binary (Complex a)
50 where
51 put (r :+ i) = put (r, i)
52 get = (\(r,i) -> r :+ i) <$> get
53
54instance (Binary a) => Binary (Dim n a)
45 55
46lift1F 56lift1F
47 :: (c t -> c t) 57 :: (c t -> c t)
@@ -59,15 +69,21 @@ instance NFData t => NFData (Dim n t) where
59-------------------------------------------------------------------------------- 69--------------------------------------------------------------------------------
60 70
61newtype R n = R (Dim n (Vector ℝ)) 71newtype R n = R (Dim n (Vector ℝ))
62 deriving (Num,Fractional,Floating) 72 deriving (Num,Fractional,Floating,Generic)
63 73
64newtype C n = C (Dim n (Vector ℂ)) 74newtype C n = C (Dim n (Vector ℂ))
65 deriving (Num,Fractional,Floating) 75 deriving (Num,Fractional,Floating,Generic)
66 76
67newtype L m n = L (Dim m (Dim n (Matrix ℝ))) 77newtype L m n = L (Dim m (Dim n (Matrix ℝ)))
78 deriving (Generic)
68 79
69newtype M m n = M (Dim m (Dim n (Matrix ℂ))) 80newtype M m n = M (Dim m (Dim n (Matrix ℂ)))
81 deriving (Generic)
70 82
83instance (KnownNat n) => Binary (R n)
84instance (KnownNat n) => Binary (C n)
85instance (KnownNat m, KnownNat n) => Binary (L m n)
86instance (KnownNat m, KnownNat n) => Binary (M m n)
71 87
72mkR :: Vector ℝ -> R n 88mkR :: Vector ℝ -> R n
73mkR = R . Dim 89mkR = R . Dim
diff --git a/packages/tests/hmatrix-tests.cabal b/packages/tests/hmatrix-tests.cabal
index d4c87aa..00f3a38 100644
--- a/packages/tests/hmatrix-tests.cabal
+++ b/packages/tests/hmatrix-tests.cabal
@@ -29,6 +29,7 @@ library
29 Build-Depends: base >= 4 && < 5, deepseq, 29 Build-Depends: base >= 4 && < 5, deepseq,
30 QuickCheck >= 2, HUnit, random, 30 QuickCheck >= 2, HUnit, random,
31 hmatrix >= 0.18 31 hmatrix >= 0.18
32 , binary
32 if flag(gsl) 33 if flag(gsl)
33 Build-Depends: hmatrix-gsl >= 0.18 34 Build-Depends: hmatrix-gsl >= 0.18
34 35
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
index d9cc3b6..043ebf3 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
@@ -26,6 +26,7 @@ module Numeric.LinearAlgebra.Tests(
26 utest, 26 utest,
27 runTests, 27 runTests,
28 runBenchmarks 28 runBenchmarks
29 , binaryTests
29-- , findNaN 30-- , findNaN
30--, runBigTests 31--, runBigTests
31) where 32) where
@@ -743,6 +744,15 @@ makeUnitary v | realPart n > 1 = v / scalar n
743 | otherwise = v 744 | otherwise = v
744 where n = sqrt (v `dot` v) 745 where n = sqrt (v `dot` v)
745 746
747binaryTests :: IO ()
748binaryTests = do
749 let test :: forall t . T.Testable t => t -> IO ()
750 test = qCheck 100
751 test vectorBinaryRoundtripProp
752 test staticVectorBinaryRoundtripProp
753 qCheck 30 matrixBinaryRoundtripProp
754 qCheck 30 staticMatrixBinaryRoundtripProp
755
746-- -- | Some additional tests on big matrices. They take a few minutes. 756-- -- | Some additional tests on big matrices. They take a few minutes.
747-- runBigTests :: IO () 757-- runBigTests :: IO ()
748-- runBigTests = undefined 758-- runBigTests = undefined
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
index 3d5441d..23d7e6f 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
@@ -1,4 +1,4 @@
1{-# LANGUAGE FlexibleContexts, UndecidableInstances, FlexibleInstances #-} 1{-# LANGUAGE FlexibleContexts, UndecidableInstances, FlexibleInstances, ScopedTypeVariables #-}
2----------------------------------------------------------------------------- 2-----------------------------------------------------------------------------
3{- | 3{- |
4Module : Numeric.LinearAlgebra.Tests.Instances 4Module : Numeric.LinearAlgebra.Tests.Instances
@@ -29,6 +29,10 @@ import Numeric.LinearAlgebra.HMatrix hiding (vector)
29import Control.Monad(replicateM) 29import Control.Monad(replicateM)
30import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink) 30import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink)
31 31
32import GHC.TypeLits
33import Data.Proxy (Proxy)
34import qualified Numeric.LinearAlgebra.Static as Static
35
32 36
33shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]] 37shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]]
34shrinkListElementwise [] = [] 38shrinkListElementwise [] = []
@@ -40,14 +44,27 @@ shrinkPair (a,b) = [ (a,x) | x <- shrink b ] ++ [ (x,b) | x <- shrink a ]
40 44
41chooseDim = sized $ \m -> choose (1,max 1 m) 45chooseDim = sized $ \m -> choose (1,max 1 m)
42 46
43instance (Field a, Arbitrary a) => Arbitrary (Vector a) where 47instance (Field a, Arbitrary a) => Arbitrary (Vector a) where
44 arbitrary = do m <- chooseDim 48 arbitrary = do m <- chooseDim
45 l <- vector m 49 l <- vector m
46 return $ fromList l 50 return $ fromList l
47 -- shrink any one of the components 51 -- shrink any one of the components
48 shrink = map fromList . shrinkListElementwise . toList 52 shrink = map fromList . shrinkListElementwise . toList
49 53
50instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where 54instance KnownNat n => Arbitrary (Static.R n) where
55 arbitrary = do
56 l <- vector n
57 return (Static.fromList l)
58
59 where proxy :: Proxy n
60 proxy = proxy
61
62 n :: Int
63 n = fromIntegral (natVal proxy)
64
65 shrink v = []
66
67instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where
51 arbitrary = do 68 arbitrary = do
52 m <- chooseDim 69 m <- chooseDim
53 n <- chooseDim 70 n <- chooseDim
@@ -57,9 +74,28 @@ instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where
57 -- shrink any one of the components 74 -- shrink any one of the components
58 shrink a = map (rows a >< cols a) 75 shrink a = map (rows a >< cols a)
59 . shrinkListElementwise 76 . shrinkListElementwise
60 . concat . toLists 77 . concat . toLists
61 $ a 78 $ a
62 79
80instance (KnownNat n, KnownNat m) => Arbitrary (Static.L m n) where
81 arbitrary = do
82 l <- vector (m * n)
83 return (Static.fromList l)
84
85 where proxyM :: Proxy m
86 proxyM = proxyM
87
88 proxyN :: Proxy n
89 proxyN = proxyN
90
91 m :: Int
92 m = fromIntegral (natVal proxyM)
93
94 n :: Int
95 n = fromIntegral (natVal proxyN)
96
97 shrink mat = []
98
63-- a square matrix 99-- a square matrix
64newtype (Sq a) = Sq (Matrix a) deriving Show 100newtype (Sq a) = Sq (Matrix a) deriving Show
65instance (Element a, Arbitrary a) => Arbitrary (Sq a) where 101instance (Element a, Arbitrary a) => Arbitrary (Sq a) where
@@ -121,7 +157,7 @@ instance (ArbitraryField a, Numeric a) => Arbitrary (SqWC a) where
121 157
122-- a positive definite square matrix (the eigenvalues are between 0 and 100) 158-- a positive definite square matrix (the eigenvalues are between 0 and 100)
123newtype (PosDef a) = PosDef (Matrix a) deriving Show 159newtype (PosDef a) = PosDef (Matrix a) deriving Show
124instance (Numeric a, ArbitraryField a, Num (Vector a)) 160instance (Numeric a, ArbitraryField a, Num (Vector a))
125 => Arbitrary (PosDef a) where 161 => Arbitrary (PosDef a) where
126 arbitrary = do 162 arbitrary = do
127 m <- arbitrary 163 m <- arbitrary
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs
index 046644f..0de9f37 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs
@@ -1,5 +1,6 @@
1{-# LANGUAGE FlexibleContexts #-} 1{-# LANGUAGE FlexibleContexts #-}
2{-# LANGUAGE TypeFamilies #-} 2{-# LANGUAGE TypeFamilies #-}
3{-# LANGUAGE DataKinds #-}
3 4
4----------------------------------------------------------------------------- 5-----------------------------------------------------------------------------
5{- | 6{- |
@@ -39,12 +40,25 @@ module Numeric.LinearAlgebra.Tests.Properties (
39 expmDiagProp, 40 expmDiagProp,
40 multProp1, multProp2, 41 multProp1, multProp2,
41 subProp, 42 subProp,
42 linearSolveProp, linearSolvePropH, linearSolveProp2 43 linearSolveProp, linearSolvePropH, linearSolveProp2,
44
45 -- Binary properties
46 vectorBinaryRoundtripProp
47 , staticVectorBinaryRoundtripProp
48 , matrixBinaryRoundtripProp
49 , staticMatrixBinaryRoundtripProp
50 , staticVectorBinaryFailProp
43) where 51) where
44 52
45import Numeric.LinearAlgebra.HMatrix hiding (Testable,unitary) 53import Numeric.LinearAlgebra.HMatrix hiding (Testable,unitary)
54import qualified Numeric.LinearAlgebra.Static as Static
46import Test.QuickCheck 55import Test.QuickCheck
47 56
57import Data.Binary
58import Data.Binary.Get (runGet)
59import Data.Either (isLeft)
60import Debug.Trace (traceShowId)
61
48(~=) :: Double -> Double -> Bool 62(~=) :: Double -> Double -> Bool
49a ~= b = abs (a - b) < 1e-10 63a ~= b = abs (a - b) < 1e-10
50 64
@@ -275,3 +289,31 @@ linearSolveProp2 f (a,x) = not wc `trivial` (not wc || a <> f a b |~| b)
275 289
276subProp m = m == (conj . tr . fromColumns . toRows) m 290subProp m = m == (conj . tr . fromColumns . toRows) m
277 291
292------------------------------------------------------------------
293
294vectorBinaryRoundtripProp :: Vector Double -> Bool
295vectorBinaryRoundtripProp vec = decode (encode vec) == vec
296
297staticVectorBinaryRoundtripProp :: Static.R 5 -> Bool
298staticVectorBinaryRoundtripProp vec =
299 let
300 decoded = decode (encode vec) :: Static.R 500
301 in
302 Static.extract decoded == Static.extract vec
303
304matrixBinaryRoundtripProp :: Matrix Double -> Bool
305matrixBinaryRoundtripProp mat = decode (encode mat) == mat
306
307staticMatrixBinaryRoundtripProp :: Static.L 100 200 -> Bool
308staticMatrixBinaryRoundtripProp mat =
309 let
310 decoded = decode (encode mat) :: Static.L 100 200
311 in
312 (Static.extract decoded) == (Static.extract mat)
313
314staticVectorBinaryFailProp :: Static.R 20 -> Bool
315staticVectorBinaryFailProp vec =
316 let
317 decoded = runGet get (encode vec) :: Either String (Static.R 50)
318 in
319 isLeft decoded