From 8bdb87764762ef43b186bcc04caa404928df22fa Mon Sep 17 00:00:00 2001 From: Sidharth Kapur Date: Mon, 1 Feb 2016 17:40:40 -0600 Subject: some work (will probably undo this commit later) --- packages/tests/src/Numeric/LinearAlgebra/Tests.hs | 10 +++++ .../src/Numeric/LinearAlgebra/Tests/Instances.hs | 46 +++++++++++++++++++--- .../src/Numeric/LinearAlgebra/Tests/Properties.hs | 44 ++++++++++++++++++++- 3 files changed, 94 insertions(+), 6 deletions(-) (limited to 'packages/tests/src') diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs index d9cc3b6..043ebf3 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs @@ -26,6 +26,7 @@ module Numeric.LinearAlgebra.Tests( utest, runTests, runBenchmarks + , binaryTests -- , findNaN --, runBigTests ) where @@ -743,6 +744,15 @@ makeUnitary v | realPart n > 1 = v / scalar n | otherwise = v where n = sqrt (v `dot` v) +binaryTests :: IO () +binaryTests = do + let test :: forall t . T.Testable t => t -> IO () + test = qCheck 100 + test vectorBinaryRoundtripProp + test staticVectorBinaryRoundtripProp + qCheck 30 matrixBinaryRoundtripProp + qCheck 30 staticMatrixBinaryRoundtripProp + -- -- | Some additional tests on big matrices. They take a few minutes. -- runBigTests :: IO () -- runBigTests = undefined diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs index 3d5441d..23d7e6f 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE FlexibleContexts, UndecidableInstances, FlexibleInstances #-} +{-# LANGUAGE FlexibleContexts, UndecidableInstances, FlexibleInstances, ScopedTypeVariables #-} ----------------------------------------------------------------------------- {- | Module : Numeric.LinearAlgebra.Tests.Instances @@ -29,6 +29,10 @@ import Numeric.LinearAlgebra.HMatrix hiding (vector) import Control.Monad(replicateM) import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink) +import GHC.TypeLits +import Data.Proxy (Proxy) +import qualified Numeric.LinearAlgebra.Static as Static + shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]] shrinkListElementwise [] = [] @@ -40,14 +44,27 @@ shrinkPair (a,b) = [ (a,x) | x <- shrink b ] ++ [ (x,b) | x <- shrink a ] chooseDim = sized $ \m -> choose (1,max 1 m) -instance (Field a, Arbitrary a) => Arbitrary (Vector a) where +instance (Field a, Arbitrary a) => Arbitrary (Vector a) where arbitrary = do m <- chooseDim l <- vector m return $ fromList l -- shrink any one of the components shrink = map fromList . shrinkListElementwise . toList -instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where +instance KnownNat n => Arbitrary (Static.R n) where + arbitrary = do + l <- vector n + return (Static.fromList l) + + where proxy :: Proxy n + proxy = proxy + + n :: Int + n = fromIntegral (natVal proxy) + + shrink v = [] + +instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where arbitrary = do m <- chooseDim n <- chooseDim @@ -57,9 +74,28 @@ instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where -- shrink any one of the components shrink a = map (rows a >< cols a) . shrinkListElementwise - . concat . toLists + . concat . toLists $ a +instance (KnownNat n, KnownNat m) => Arbitrary (Static.L m n) where + arbitrary = do + l <- vector (m * n) + return (Static.fromList l) + + where proxyM :: Proxy m + proxyM = proxyM + + proxyN :: Proxy n + proxyN = proxyN + + m :: Int + m = fromIntegral (natVal proxyM) + + n :: Int + n = fromIntegral (natVal proxyN) + + shrink mat = [] + -- a square matrix newtype (Sq a) = Sq (Matrix a) deriving Show instance (Element a, Arbitrary a) => Arbitrary (Sq a) where @@ -121,7 +157,7 @@ instance (ArbitraryField a, Numeric a) => Arbitrary (SqWC a) where -- a positive definite square matrix (the eigenvalues are between 0 and 100) newtype (PosDef a) = PosDef (Matrix a) deriving Show -instance (Numeric a, ArbitraryField a, Num (Vector a)) +instance (Numeric a, ArbitraryField a, Num (Vector a)) => Arbitrary (PosDef a) where arbitrary = do m <- arbitrary diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs index 046644f..0de9f37 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs @@ -1,5 +1,6 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE DataKinds #-} ----------------------------------------------------------------------------- {- | @@ -39,12 +40,25 @@ module Numeric.LinearAlgebra.Tests.Properties ( expmDiagProp, multProp1, multProp2, subProp, - linearSolveProp, linearSolvePropH, linearSolveProp2 + linearSolveProp, linearSolvePropH, linearSolveProp2, + + -- Binary properties + vectorBinaryRoundtripProp + , staticVectorBinaryRoundtripProp + , matrixBinaryRoundtripProp + , staticMatrixBinaryRoundtripProp + , staticVectorBinaryFailProp ) where import Numeric.LinearAlgebra.HMatrix hiding (Testable,unitary) +import qualified Numeric.LinearAlgebra.Static as Static import Test.QuickCheck +import Data.Binary +import Data.Binary.Get (runGet) +import Data.Either (isLeft) +import Debug.Trace (traceShowId) + (~=) :: Double -> Double -> Bool a ~= b = abs (a - b) < 1e-10 @@ -275,3 +289,31 @@ linearSolveProp2 f (a,x) = not wc `trivial` (not wc || a <> f a b |~| b) subProp m = m == (conj . tr . fromColumns . toRows) m +------------------------------------------------------------------ + +vectorBinaryRoundtripProp :: Vector Double -> Bool +vectorBinaryRoundtripProp vec = decode (encode vec) == vec + +staticVectorBinaryRoundtripProp :: Static.R 5 -> Bool +staticVectorBinaryRoundtripProp vec = + let + decoded = decode (encode vec) :: Static.R 500 + in + Static.extract decoded == Static.extract vec + +matrixBinaryRoundtripProp :: Matrix Double -> Bool +matrixBinaryRoundtripProp mat = decode (encode mat) == mat + +staticMatrixBinaryRoundtripProp :: Static.L 100 200 -> Bool +staticMatrixBinaryRoundtripProp mat = + let + decoded = decode (encode mat) :: Static.L 100 200 + in + (Static.extract decoded) == (Static.extract mat) + +staticVectorBinaryFailProp :: Static.R 20 -> Bool +staticVectorBinaryFailProp vec = + let + decoded = runGet get (encode vec) :: Either String (Static.R 50) + in + isLeft decoded -- cgit v1.2.3