From cf3c788f0c44577ac1a5365e8154200b53a36409 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Tue, 27 May 2014 10:41:40 +0200 Subject: static dimensions, cont. --- packages/base/hmatrix.cabal | 3 +- packages/base/src/Data/Packed/Internal/Numeric.hs | 7 +- packages/base/src/Data/Packed/Numeric.hs | 41 ++- packages/base/src/Numeric/HMatrix.hs | 63 ++-- packages/base/src/Numeric/LinearAlgebra/Data.hs | 11 +- packages/base/src/Numeric/LinearAlgebra/Real.hs | 337 +++++++++++++++++++++ packages/base/src/Numeric/LinearAlgebra/Util.hs | 20 +- packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | 86 +++++- .../base/src/Numeric/LinearAlgebra/Util/Static.hs | 70 ----- packages/base/src/Numeric/Sparse.hs | 127 +++----- packages/gsl/src/Numeric/GSL/Fitting.hs | 2 +- packages/tests/src/Numeric/LinearAlgebra/Tests.hs | 17 +- 12 files changed, 550 insertions(+), 234 deletions(-) create mode 100644 packages/base/src/Numeric/LinearAlgebra/Real.hs delete mode 100644 packages/base/src/Numeric/LinearAlgebra/Util/Static.hs diff --git a/packages/base/hmatrix.cabal b/packages/base/hmatrix.cabal index e958de0..3ca6659 100644 --- a/packages/base/hmatrix.cabal +++ b/packages/base/hmatrix.cabal @@ -47,7 +47,7 @@ library Numeric.HMatrix Numeric.LinearAlgebra.Devel Numeric.LinearAlgebra.Data - + Numeric.LinearAlgebra.Real other-modules: Data.Packed.Internal, @@ -67,7 +67,6 @@ library Numeric.LinearAlgebra.Random Numeric.Conversion Numeric.Sparse - Numeric.LinearAlgebra.Util.Static C-sources: src/C/lapack-aux.c src/C/vector-aux.c diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs index 3c1c1d0..0205a17 100644 --- a/packages/base/src/Data/Packed/Internal/Numeric.hs +++ b/packages/base/src/Data/Packed/Internal/Numeric.hs @@ -3,6 +3,7 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE UndecidableInstances #-} ----------------------------------------------------------------------------- @@ -692,12 +693,12 @@ condV f a b l e t = f a' b' l' e' t' -------------------------------------------------------------------------------- -class Transposable t +class Transposable m mt | m -> mt, mt -> m where -- | (conjugate) transpose - tr :: t -> t + tr :: m -> mt -instance (Container Vector t) => Transposable (Matrix t) +instance (Container Vector t) => Transposable (Matrix t) (Matrix t) where tr = ctrans diff --git a/packages/base/src/Data/Packed/Numeric.hs b/packages/base/src/Data/Packed/Numeric.hs index 01cf6c5..7d88cbc 100644 --- a/packages/base/src/Data/Packed/Numeric.hs +++ b/packages/base/src/Data/Packed/Numeric.hs @@ -32,7 +32,7 @@ module Data.Packed.Numeric ( diag, ident, ctrans, -- * Generic operations - Container(..), + Container(..), Numeric, -- add, mul, sub, divide, equal, scaleRecip, addConstant, scalar, conj, scale, arctan2, cmap, atIndex, minIndex, maxIndex, minElement, maxElement, @@ -40,7 +40,7 @@ module Data.Packed.Numeric ( step, cond, find, assoc, accum, Transposable(..), Linear(..), -- * Matrix product - Product(..), udot, dot, (◇), + Product(..), udot, dot, (◇), (<·>), (#>), Mul(..), Contraction(..),(<.>), optimiseMult, @@ -96,7 +96,7 @@ linspace n (a,b) = addConstant a $ scale s $ fromList $ map fromIntegral [0 .. n -------------------------------------------------------- -{- | Matrix product, matrix - vector product, and dot product (equivalent to 'contraction') +{- Matrix product, matrix - vector product, and dot product (equivalent to 'contraction') (This operator can also be written using the unicode symbol ◇ (25c7).) @@ -138,9 +138,8 @@ For complex vectors the first argument is conjugated: >>> fromList [1,i,1-i] <.> complex a fromList [10.0 :+ 4.0,12.0 :+ 4.0,14.0 :+ 4.0,16.0 :+ 4.0] -} -infixl 7 <.> -(<.>) :: Contraction a b c => a -> b -> c -(<.>) = contraction + + class Contraction a b c | a b -> c @@ -160,6 +159,23 @@ instance (Container Vector t, Product t) => Contraction (Vector t) (Matrix t) (V instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where contraction = mXm +-------------------------------------------------------------------------------- + +infixl 7 <.> +-- | An infix synonym for 'dot' +(<.>) :: Numeric t => Vector t -> Vector t -> t +(<.>) = dot + + +infixr 8 <·>, #> +-- | dot product +(<·>) :: Numeric t => Vector t -> Vector t -> t +(<·>) = dot + + +-- | matrix-vector product +(#>) :: Numeric t => Matrix t -> Vector t -> Vector t +(#>) = mXv -------------------------------------------------------------------------------- @@ -286,3 +302,16 @@ meanCov x = (med,cov) where -------------------------------------------------------------------------------- +class ( Container Vector t + , Container Matrix t + , Konst t Int Vector + , Konst t (Int,Int) Matrix + , Product t + ) => Numeric t + +instance Numeric Double +instance Numeric (Complex Double) +instance Numeric Float +instance Numeric (Complex Float) + +-------------------------------------------------------------------------------- diff --git a/packages/base/src/Numeric/HMatrix.hs b/packages/base/src/Numeric/HMatrix.hs index d5c66fb..1c70ef6 100644 --- a/packages/base/src/Numeric/HMatrix.hs +++ b/packages/base/src/Numeric/HMatrix.hs @@ -10,16 +10,16 @@ Stability : provisional ----------------------------------------------------------------------------- module Numeric.HMatrix ( - -- * Basic types and data processing + -- * Basic types and data processing module Numeric.LinearAlgebra.Data, - + -- * Arithmetic and numeric classes -- | -- The standard numeric classes are defined elementwise: -- -- >>> fromList [1,2,3] * fromList [3,0,-2 :: Double] -- fromList [3.0,0.0,-6.0] - -- + -- -- >>> (3><3) [1..9] * ident 3 :: Matrix Double -- (3><3) -- [ 1.0, 0.0, 0.0 @@ -29,7 +29,7 @@ module Numeric.HMatrix ( -- In arithmetic operations single-element vectors and matrices -- (created from numeric literals or using 'scalar') automatically -- expand to match the dimensions of the other operand: - -- + -- -- >>> 5 + 2*ident 3 :: Matrix Double -- (3><3) -- [ 7.0, 5.0, 5.0 @@ -37,13 +37,14 @@ module Numeric.HMatrix ( -- , 5.0, 5.0, 7.0 ] -- - -- * Matrix product - (<.>), - - -- | The overloaded multiplication operators may need type annotations to remove - -- ambiguity. In those cases we can also use the specific functions 'mXm', 'mXv', and 'dot'. - -- - -- The matrix x matrix product is also implemented in the "Data.Monoid" instance, where + -- * Products + -- ** dot + (<·>), + -- ** matrix-vector + (#>),(!#>), + -- ** matrix-matrix + (<>), + -- | The matrix x matrix product is also implemented in the "Data.Monoid" instance, where -- single-element matrices (created from numeric literals or using 'scalar') -- are used for scaling. -- @@ -55,12 +56,12 @@ module Numeric.HMatrix ( -- -- 'mconcat' uses 'optimiseMult' to get the optimal association order. - - -- * Other products + + -- ** other outer, kronecker, cross, scale, sumElements, prodElements, - + -- * Linear Systems (<\>), linearSolve, @@ -70,14 +71,14 @@ module Numeric.HMatrix ( cholSolve, cgSolve, cgSolve', - + -- * Inverse and pseudoinverse inv, pinv, pinvTol, -- * Determinant and rank - rcond, rank, ranksv, + rcond, rank, ranksv, det, invlndet, - + -- * Singular value decomposition svd, fullSVD, @@ -85,7 +86,7 @@ module Numeric.HMatrix ( compactSVD, singularValues, leftSV, rightSV, - + -- * Eigensystems eig, eigSH, eigSH', eigenvalues, eigenvaluesSH, eigenvaluesSH', @@ -105,7 +106,7 @@ module Numeric.HMatrix ( -- * LU lu, luPacked, - + -- * Matrix functions expm, sqrtm, @@ -116,7 +117,7 @@ module Numeric.HMatrix ( nullVector, nullspaceSVD, null1, null1sym, - + orth, -- * Norms @@ -129,30 +130,36 @@ module Numeric.HMatrix ( -- * Random arrays - RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, - + Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, + -- * Misc - meanCov, peps, relativeError, haussholder, optimiseMult, dot, udot, mXm, mXv, smXv, (<>), (◇), Seed, checkT, + meanCov, peps, relativeError, haussholder, optimiseMult, udot, -- * Auxiliary classes - Element, Container, Product, Numeric, Contraction, LSDiv, + Element, Container, Product, Contraction(..), Numeric, LSDiv, Complexable, RealElement, RealOf, ComplexOf, SingleOf, DoubleOf, IndexOf, Field, Normed, - CGMat, Transposable, - ℕ,ℤ,ℝ,ℂ,ℝn,ℂn, 𝑖, i_C --ℍ + Transposable, + CGState(..), + Testable(..), + ℕ,ℤ,ℝ,ℂ, 𝑖, i_C --ℍ ) where import Numeric.LinearAlgebra.Data import Numeric.Matrix() import Numeric.Vector() -import Data.Packed.Numeric +import Data.Packed.Numeric hiding ((<>)) import Numeric.LinearAlgebra.Algorithms import Numeric.LinearAlgebra.Util import Numeric.LinearAlgebra.Random -import Numeric.Sparse(smXv) +import Numeric.Sparse((!#>)) import Numeric.LinearAlgebra.Util.CG +-- | matrix product +(<>) :: Numeric t => Matrix t -> Matrix t -> Matrix t +(<>) = mXm +infixr 8 <> diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs index 3128a24..3417a5e 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Data.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs @@ -49,13 +49,9 @@ module Numeric.LinearAlgebra.Data( find, maxIndex, minIndex, maxElement, minElement, atIndex, -- * Sparse - SMatrix, AssocMatrix, mkCSR, toDense, - mkDiag, - - -- * Static dimensions + GMatrix, AssocMatrix, mkSparse, toDense, + mkDiagR, dense, - Static, ddata, R, vect0, sScalar, vect2, vect3, (&), - -- * IO disp, loadMatrix, saveMatrix, @@ -79,9 +75,8 @@ module Numeric.LinearAlgebra.Data( import Data.Packed.Vector import Data.Packed.Matrix import Data.Packed.Numeric -import Numeric.LinearAlgebra.Util hiding ((&)) +import Numeric.LinearAlgebra.Util hiding ((&),(#)) import Data.Complex import Numeric.Sparse -import Numeric.LinearAlgebra.Util.Static diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs new file mode 100644 index 0000000..db15705 --- /dev/null +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs @@ -0,0 +1,337 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE EmptyDataDecls #-} +{-# LANGUAGE Rank2Types #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE GADTs #-} + + +{- | +Module : Numeric.LinearAlgebra.Real +Copyright : (c) Alberto Ruiz 2006-14 +License : BSD3 +Stability : provisional + +Experimental interface for real arrays with statically checked dimensions. + +-} + +module Numeric.LinearAlgebra.Real( + -- * Vector + R, + vec2, vec3, vec4, 𝕧, (&), + -- * Matrix + L, Sq, + 𝕞, + (#),(¦),(——), + Konst(..), + eye, + diagR, diag, + blockAt, + -- * Products + (<>),(#>),(<·>), + -- * Pretty printing + Disp(..), + -- * Misc + Dim, unDim, + module Numeric.HMatrix +) where + + +import GHC.TypeLits +import Numeric.HMatrix hiding ((<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——)) +import qualified Numeric.HMatrix as LA +import Data.Packed.ST + +newtype Dim (n :: Nat) t = Dim t + deriving Show + +unDim :: Dim n t -> t +unDim (Dim x) = x + +data Proxy :: Nat -> * + + +lift1F + :: (c t -> c t) + -> Dim n (c t) -> Dim n (c t) +lift1F f (Dim v) = Dim (f v) + +lift2F + :: (c t -> c t -> c t) + -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) +lift2F f (Dim u) (Dim v) = Dim (f u v) + + + +type R n = Dim n (Vector ℝ) + +type L m n = Dim m (Dim n (Matrix ℝ)) + + +infixl 4 & +(&) :: forall n . KnownNat n + => R n -> ℝ -> R (n+1) +Dim v & x = Dim (vjoin [v', scalar x]) + where + d = fromIntegral . natVal $ (undefined :: Proxy n) + v' | d > 1 && size v == 1 = LA.konst (v!0) d + | otherwise = v + + +-- vect0 :: R 0 +-- vect0 = Dim (fromList[]) + +𝕧 :: ℝ -> R 1 +𝕧 = Dim . scalar + + +vec2 :: ℝ -> ℝ -> R 2 +vec2 a b = Dim $ runSTVector $ do + v <- newUndefinedVector 2 + writeVector v 0 a + writeVector v 1 b + return v + +vec3 :: ℝ -> ℝ -> ℝ -> R 3 +vec3 a b c = Dim $ runSTVector $ do + v <- newUndefinedVector 3 + writeVector v 0 a + writeVector v 1 b + writeVector v 2 c + return v + + +vec4 :: ℝ -> ℝ -> ℝ -> ℝ -> R 4 +vec4 a b c d = Dim $ runSTVector $ do + v <- newUndefinedVector 4 + writeVector v 0 a + writeVector v 1 b + writeVector v 2 c + writeVector v 3 d + return v + + + + +instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) + where + (+) = lift2F (+) + (*) = lift2F (*) + (-) = lift2F (-) + abs = lift1F abs + signum = lift1F signum + negate = lift1F negate + fromInteger x = Dim (fromInteger x) + +instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) + where + (+) = (lift2F . lift2F) (+) + (*) = (lift2F . lift2F) (*) + (-) = (lift2F . lift2F) (-) + abs = (lift1F . lift1F) abs + signum = (lift1F . lift1F) signum + negate = (lift1F . lift1F) negate + fromInteger x = Dim (Dim (fromInteger x)) + +-------------------------------------------------------------------------------- + +class Konst t + where + konst :: ℝ -> t + +instance forall n. KnownNat n => Konst (R n) + where + konst x = Dim (LA.konst x d) + where + d = fromIntegral . natVal $ (undefined :: Proxy n) + +instance forall m n . (KnownNat m, KnownNat n) => Konst (L m n) + where + konst x = Dim (Dim (LA.konst x (m',n'))) + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) + n' = fromIntegral . natVal $ (undefined :: Proxy n) + +-------------------------------------------------------------------------------- + +diagR :: forall m n k . (KnownNat m, KnownNat n) => ℝ -> R k -> L m n +diagR x v = Dim (Dim (diagRect x (unDim v) m' n')) + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) + n' = fromIntegral . natVal $ (undefined :: Proxy n) + +diag :: KnownNat n => R n -> Sq n +diag = diagR 0 + +-------------------------------------------------------------------------------- + +blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n +blockAt x r c a = Dim (Dim res) + where + z = scalar x + z1 = LA.konst x (r,c) + z2 = LA.konst x (max 0 (m'-(ra+r)), max 0 (n'-(ca+c))) + ra = min (rows a) . max 0 $ m'-r + ca = min (cols a) . max 0 $ n'-c + sa = subMatrix (0,0) (ra, ca) a + m' = fromIntegral . natVal $ (undefined :: Proxy m) + n' = fromIntegral . natVal $ (undefined :: Proxy n) + res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] + +{- +matrix :: (KnownNat m, KnownNat n) => Matrix Double -> L n m +matrix = blockAt 0 0 0 +-} + +-------------------------------------------------------------------------------- + +class Disp t + where + disp :: Int -> t -> IO () + +instance Disp (L n m) + where + disp n (d2 -> a) = do + if rows a == 1 && cols a == 1 + then putStrLn $ "Const " ++ (last . words . LA.dispf n $ a) + else putStr "Dim " >> LA.disp n a + +instance Disp (R n) + where + disp n (unDim -> v) = do + let su = LA.dispf n (asRow v) + if LA.size v == 1 + then putStrLn $ "Const " ++ (last . words $ su ) + else putStr "Dim " >> putStr (tail . dropWhile (/='x') $ su) + +-------------------------------------------------------------------------------- + +infixl 3 # +(#) :: L r c -> R c -> L (r+1) c +Dim (Dim m) # Dim v = Dim (Dim (m LA.—— asRow v)) + + +𝕞 :: forall n . KnownNat n => L 0 n +𝕞 = Dim (Dim (LA.konst 0 (0,d))) + where + d = fromIntegral . natVal $ (undefined :: Proxy n) + +infixl 3 ¦ +(¦) :: L r c1 -> L r c2 -> L r (c1+c2) +Dim (Dim a) ¦ Dim (Dim b) = Dim (Dim (a LA.¦ b)) + +infixl 2 —— +(——) :: L r1 c -> L r2 c -> L (r1+r2) c +Dim (Dim a) —— Dim (Dim b) = Dim (Dim (a LA.—— b)) + + +{- + +-} + +type Sq n = L n n + +type GL = (KnownNat n, KnownNat m) => L m n +type GSq = KnownNat n => Sq n + +infixr 8 <> +(<>) :: L m k -> L k n -> L m n +(d2 -> a) <> (d2 -> b) = Dim (Dim (a LA.<> b)) + +infixr 8 #> +(#>) :: L m n -> R n -> R m +(d2 -> m) #> (unDim -> v) = Dim (m LA.#> v) + +infixr 8 <·> +(<·>) :: R n -> R n -> ℝ +(unDim -> u) <·> (unDim -> v) = udot u v + + +d2 :: forall c (n :: Nat) (n1 :: Nat). Dim n1 (Dim n c) -> c +d2 = unDim . unDim + + +instance Transposable (L m n) (L n m) + where + tr (Dim (Dim a)) = Dim (Dim (tr a)) + + +eye :: forall n . KnownNat n => Sq n +eye = Dim (Dim (ident d)) + where + d = fromIntegral . natVal $ (undefined :: Proxy n) + + +-------------------------------------------------------------------------------- + +test :: (Bool, IO ()) +test = (ok,info) + where + ok = d2 (eye :: Sq 5) == ident 5 + && d2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] + && d2 (tm :: L 3 5) == mat 5 [1..15] + && thingS == thingD + && precS == precD + + info = do + print $ u + print $ v + print (eye :: Sq 3) + print $ ((u & 5) + 1) <·> v + print (tm :: L 2 5) + print (tm <> sm :: L 2 3) + print thingS + print thingD + print precS + print precD + + u = vec2 3 5 + + v = 𝕧 2 & 4 & 7 + + mTm :: L n m -> Sq m + mTm a = tr a <> a + + tm :: GL + tm = lmat 0 [1..] + + lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n + lmat z xs = Dim . Dim . reshape n' . fromList . take (m'*n') $ xs ++ repeat z + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) + n' = fromIntegral . natVal $ (undefined :: Proxy n) + + sm :: GSq + sm = lmat 0 [1..] + + thingS = (u & 1) <·> tr q #> q #> v + where + q = tm :: L 10 3 + + thingD = vjoin [unDim u, 1] LA.<·> tr m LA.#> m LA.#> unDim v + where + m = mat 3 [1..30] + + precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v + precD = 1 + 2 * vjoin[unDim u, 6] LA.<·> LA.konst 2 (size (unDim u) +1, size (unDim v)) LA.#> unDim v + + +instance (KnownNat n', KnownNat m') => Testable (L n' m') + where + checkT _ = test + +{- +do (snd test) +fst test +-} + + + diff --git a/packages/base/src/Numeric/LinearAlgebra/Util.hs b/packages/base/src/Numeric/LinearAlgebra/Util.hs index a319785..47b1090 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util.hs @@ -32,7 +32,7 @@ module Numeric.LinearAlgebra.Util( rand, randn, cross, norm, - ℕ,ℤ,ℝ,ℂ,ℝn,ℂn,𝑖,i_C, --ℍ + ℕ,ℤ,ℝ,ℂ,𝑖,i_C, --ℍ norm_1, norm_2, norm_0, norm_Inf, norm_Frob, norm_nuclear, mnorm_1, mnorm_2, mnorm_0, mnorm_Inf, unitary, @@ -70,8 +70,8 @@ type ℝ = Double type ℕ = Int type ℤ = Int type ℂ = Complex Double -type ℝn = Vector ℝ -type ℂn = Vector ℂ +--type ℝn = Vector ℝ +--type ℂn = Vector ℂ --newtype ℍ m = H m i_C, 𝑖 :: ℂ @@ -84,7 +84,7 @@ i_C = 𝑖 fromList [1.0,2.0,3.0,4.0,5.0] -} -vect :: [ℝ] -> ℝn +vect :: [ℝ] -> Vector ℝ vect = fromList {- | create a real matrix @@ -103,18 +103,6 @@ mat mat c = reshape c . fromList - -class ( Container Vector t - , Container Matrix t - , Konst t Int Vector - , Konst t (Int,Int) Matrix - ) => Numeric t - -instance Numeric Double -instance Numeric (Complex Double) - - - {- | print a real matrix with given number of digits after the decimal point >>> disp 5 $ ident 2 / 3 diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs index 5e2ea84..50372f1 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs @@ -3,11 +3,14 @@ module Numeric.LinearAlgebra.Util.CG( cgSolve, cgSolve', - CGMat, CGState(..), R, V + CGState(..), R, V ) where import Data.Packed.Numeric +import Numeric.Sparse import Numeric.Vector() +import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..)) +import Control.Arrow((***)) {- import Util.Misc(debug, debugMat) @@ -51,7 +54,7 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx rdx = norm2 dx / max 1 (norm2 x) conjugrad - :: (Transposable m, Contraction m V V) + :: (Transposable m mt, Contraction m V V, Contraction mt V V) => Bool -> m -> V -> V -> R -> R -> [CGState] conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b @@ -82,27 +85,88 @@ takeUntil q xs = a++ take 1 b where (a,b) = break q xs -class (Transposable m, Contraction m V V) => CGMat m - cgSolve - :: CGMat m - => Bool -- ^ is symmetric - -> m -- ^ coefficient matrix + :: Bool -- ^ is symmetric + -> GMatrix -- ^ coefficient matrix -> Vector Double -- ^ right-hand side - -> Vector Double -- ^ solution + -> Vector Double -- ^ solution cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 where n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) cgSolve' - :: CGMat m - => Bool -- ^ symmetric + :: Bool -- ^ symmetric -> R -- ^ relative tolerance for the residual (e.g. 1E-4) -> R -- ^ relative tolerance for δx (e.g. 1E-3) -> Int -- ^ maximum number of iterations - -> m -- ^ coefficient matrix + -> GMatrix -- ^ coefficient matrix -> V -- ^ initial solution -> V -- ^ right-hand side -> [CGState] -- ^ solution cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es + +-------------------------------------------------------------------------------- + +instance Testable GMatrix + where + checkT _ = (ok,info) + where + sma = convo2 20 3 + x1 = vect [1..20] + x2 = vect [1..40] + sm = mkSparse sma + dm = toDense sma + + s1 = sm !#> x1 + d1 = dm #> x1 + + s2 = tr sm !#> x2 + d2 = tr dm #> x2 + + sdia = mkDiagR 40 20 (vect [1..10]) + s3 = sdia !#> x1 + s4 = tr sdia !#> x2 + ddia = diagRect 0 (vect [1..10]) 40 20 + d3 = ddia #> x1 + d4 = tr ddia #> x2 + + v = testb 40 + s5 = cgSolve False sm v + d5 = denseSolve dm v + + info = do + print sm + disp (toDense sma) + print s1; print d1 + print s2; print d2 + print s3; print d3 + print s4; print d4 + print s5; print d5 + print $ relativeError Infinity s5 d5 + + ok = s1==d1 + && s2==d2 + && s3==d3 + && s4==d4 + && relativeError Infinity s5 d5 < 1E-10 + + disp = putStr . dispf 2 + + vect = fromList :: [Double] -> Vector Double + + convomat :: Int -> Int -> AssocMatrix + convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] + + convo2 :: Int -> Int -> AssocMatrix + convo2 n k = m1 ++ m2 + where + m1 = convomat n k + m2 = map (((+n) *** id) *** id) m1 + + testb n = vect $ take n $ cycle ([0..10]++[9,8..1]) + + denseSolve a = flatten . linearSolveLS a . asColumn + + -- mkDiag v = mkDiagR (dim v) (dim v) v + diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs deleted file mode 100644 index a3f8eb0..0000000 --- a/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs +++ /dev/null @@ -1,70 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE EmptyDataDecls #-} -{-# LANGUAGE Rank2Types #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE TypeOperators #-} - -module Numeric.LinearAlgebra.Util.Static( - Static (ddata), - R, - vect0, sScalar, vect2, vect3, (&) -) where - - -import GHC.TypeLits -import Data.Packed.Numeric -import Numeric.Vector() -import Numeric.LinearAlgebra.Util(Numeric,ℝ) - -lift1F :: (Vector t -> Vector t) -> Static n (Vector t) -> Static n (Vector t) -lift1F f (Static v) = Static (f v) - -lift2F :: (Vector t -> Vector t -> Vector t) -> Static n (Vector t) -> Static n (Vector t) -> Static n (Vector t) -lift2F f (Static u) (Static v) = Static (f u v) - -newtype Static (n :: Nat) t = Static { ddata :: t } deriving Show - -type R n = Static n (Vector ℝ) - - -infixl 4 & -(&) :: R n -> ℝ -> R (n+1) -Static v & x = Static (vjoin [v, scalar x]) - -vect0 :: R 0 -vect0 = Static (fromList[]) - -sScalar :: ℝ -> R 1 -sScalar = Static . scalar - - -vect2 :: ℝ -> ℝ -> R 2 -vect2 x1 x2 = Static (fromList [x1,x2]) - -vect3 :: ℝ -> ℝ -> ℝ -> R 3 -vect3 x1 x2 x3 = Static (fromList [x1,x2,x3]) - - - - - - -instance forall n t . (KnownNat n, Num (Vector t), Numeric t )=> Num (Static n (Vector t)) - where - (+) = lift2F add - (*) = lift2F mul - (-) = lift2F sub - abs = lift1F abs - signum = lift1F signum - negate = lift1F (scale (-1)) - fromInteger x = Static (konst (fromInteger x) d) - where - d = fromIntegral . natVal $ (undefined :: Proxy n) - -data Proxy :: Nat -> * - diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs index 2df4578..4d05bdc 100644 --- a/packages/base/src/Numeric/Sparse.hs +++ b/packages/base/src/Numeric/Sparse.hs @@ -3,11 +3,11 @@ {-# LANGUAGE FlexibleInstances #-} module Numeric.Sparse( - SMatrix(..), - mkCSR, mkDiag, + GMatrix(..), + mkSparse, mkDiagR, dense, AssocMatrix, toDense, - smXv + gmXv, (!#>) )where import Data.Packed.Numeric @@ -17,8 +17,7 @@ import Control.Arrow((***)) import Control.Monad(when) import Data.List(groupBy, sort) import Foreign.C.Types(CInt(..)) -import Numeric.LinearAlgebra.Util.CG(CGMat,cgSolve) -import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..)) + import Data.Packed.Development import System.IO.Unsafe(unsafePerformIO) import Foreign(Ptr) @@ -29,7 +28,7 @@ c ~!~ msg = when c (error msg) type AssocMatrix = [((Int,Int),Double)] -data SMatrix +data GMatrix = CSR { csrVals :: Vector Double , csrCols :: Vector CInt @@ -46,14 +45,26 @@ data SMatrix } | Diag { diagVals :: Vector Double + , nRows :: Int + , nCols :: Int + } + | Dense + { gmDense :: Matrix Double , nRows :: Int , nCols :: Int } -- | Banded deriving Show -mkCSR :: AssocMatrix -> SMatrix -mkCSR sm' = CSR{..} +dense :: Matrix Double -> GMatrix +dense m = Dense{..} + where + gmDense = m + nRows = rows m + nCols = cols m + +mkSparse :: AssocMatrix -> GMatrix +mkSparse sm' = CSR{..} where sm = sort sm' rws = map ((fromList *** fromList) @@ -78,37 +89,47 @@ mkDiagR r c v nCols = c diagVals = v -mkDiag v = mkDiagR (dim v) (dim v) v - type IV t = CInt -> Ptr CInt -> t type V t = CInt -> Ptr Double -> t type SMxV = V (IV (IV (V (V (IO CInt))))) -smXv :: SMatrix -> Vector Double -> Vector Double -smXv CSR{..} v = unsafePerformIO $ do - dim v /= nCols ~!~ printf "smXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) +gmXv :: GMatrix -> Vector Double -> Vector Double +gmXv CSR{..} v = unsafePerformIO $ do + dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) r <- createVector nRows app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" return r -smXv CSC{..} v = unsafePerformIO $ do - dim v /= nCols ~!~ printf "smXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) +gmXv CSC{..} v = unsafePerformIO $ do + dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) r <- createVector nRows app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" return r -smXv Diag{..} v +gmXv Diag{..} v | dim v == nCols = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals , konst 0 (nRows - dim diagVals) ] - | otherwise = error $ printf "smXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" + | otherwise = error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" nRows nCols (dim diagVals) (dim v) +gmXv Dense{..} v + | dim v == nCols + = mXv gmDense v + | otherwise = error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d" + nRows nCols (dim v) + -instance Contraction SMatrix (Vector Double) (Vector Double) +-- | general matrix - vector product +infixr 8 !#> +(!#>) :: GMatrix -> Vector Double -> Vector Double +(!#>) = gmXv + + +instance Contraction GMatrix (Vector Double) (Vector Double) where - contraction = smXv + contraction = gmXv -------------------------------------------------------------------------------- @@ -127,75 +148,11 @@ toDense asm = assoc (r+1,c+1) 0 asm -instance Transposable SMatrix +instance Transposable GMatrix GMatrix where tr (CSR vs cs rs n m) = CSC vs cs rs m n tr (CSC vs rs cs n m) = CSR vs rs cs m n tr (Diag v n m) = Diag v m n + tr (Dense a n m) = Dense (tr a) m n -instance CGMat SMatrix -instance CGMat (Matrix Double) - --------------------------------------------------------------------------------- - -instance Testable SMatrix - where - checkT _ = (ok,info) - where - sma = convo2 20 3 - x1 = vect [1..20] - x2 = vect [1..40] - sm = mkCSR sma - dm = toDense sma - - s1 = sm ◇ x1 - d1 = dm ◇ x1 - - s2 = tr sm ◇ x2 - d2 = tr dm ◇ x2 - - sdia = mkDiagR 40 20 (vect [1..10]) - s3 = sdia ◇ x1 - s4 = tr sdia ◇ x2 - ddia = diagRect 0 (vect [1..10]) 40 20 - d3 = ddia ◇ x1 - d4 = tr ddia ◇ x2 - - v = testb 40 - s5 = cgSolve False sm v - d5 = denseSolve dm v - - info = do - print sm - disp (toDense sma) - print s1; print d1 - print s2; print d2 - print s3; print d3 - print s4; print d4 - print s5; print d5 - print $ relativeError Infinity s5 d5 - - ok = s1==d1 - && s2==d2 - && s3==d3 - && s4==d4 - && relativeError Infinity s5 d5 < 1E-10 - - disp = putStr . dispf 2 - - vect = fromList :: [Double] -> Vector Double - - convomat :: Int -> Int -> AssocMatrix - convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] - - convo2 :: Int -> Int -> AssocMatrix - convo2 n k = m1 ++ m2 - where - m1 = convomat n k - m2 = map (((+n) *** id) *** id) m1 - - testb n = vect $ take n $ cycle ([0..10]++[9,8..1]) - - denseSolve a = flatten . linearSolveLS a . asColumn - diff --git a/packages/gsl/src/Numeric/GSL/Fitting.hs b/packages/gsl/src/Numeric/GSL/Fitting.hs index 93fb281..0a92373 100644 --- a/packages/gsl/src/Numeric/GSL/Fitting.hs +++ b/packages/gsl/src/Numeric/GSL/Fitting.hs @@ -116,7 +116,7 @@ err (model,deriv) dat vsol = zip sol errs where dof = length dat - (rows cov) chi = norm2 (fromList $ cost (resMs model) dat sol) js = fromLists $ jacobian (resDs deriv) dat sol - cov = inv $ trans js <.> js + cov = inv $ trans js <> js errs = toList $ scalar c * sqrt (takeDiag cov) diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs index 3803f3b..02beb21 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs @@ -1,5 +1,7 @@ {-# LANGUAGE CPP #-} {-# OPTIONS_GHC -fno-warn-unused-imports -fno-warn-incomplete-patterns #-} +{-# LANGUAGE DataKinds #-} + ----------------------------------------------------------------------------- {- | Module : Numeric.LinearAlgebra.Tests @@ -25,7 +27,8 @@ module Numeric.LinearAlgebra.Tests( ) where import Numeric.LinearAlgebra -import Numeric.HMatrix +import Numeric.HMatrix hiding ((<>)) +import Numeric.LinearAlgebra.Real(L) import Numeric.LinearAlgebra.Util(col,row) import Data.Packed import Numeric.LinearAlgebra.LAPACK @@ -466,18 +469,23 @@ kroneckerTest = utest "kronecker" ok x = (4><2) [3,5..] b = (2><5) [0,5..] v1 = vec (a <> x <> b) - v2 = (trans b `kronecker` a) <.> vec x + v2 = (trans b `kronecker` a) <> vec x s = trans b <> b v3 = vec s - v4 = (dup 5 :: Matrix Double) <.> vech s + v4 = (dup 5 :: Matrix Double) <> vech s ok = v1 == v2 && v3 == v4 && vtrans 1 a == trans a && vtrans (rows a) a == asColumn (vec a) -------------------------------------------------------------------------------- -sparseTest = utest "sparse mul" (fst $ checkT (undefined :: SMatrix)) +sparseTest = utest "sparse" (fst $ checkT (undefined :: GMatrix)) + +-------------------------------------------------------------------------------- +staticTest = utest "static" (fst $ checkT (undefined :: L 3 5)) + +-------------------------------------------------------------------------------- -- | All tests must pass with a maximum dimension of about 20 -- (some tests may fail with bigger sizes due to precision loss). @@ -655,6 +663,7 @@ runTests n = do , convolutionTest , kroneckerTest , sparseTest + , staticTest ] when (errors c + failures c > 0) exitFailure return () -- cgit v1.2.3