From 365e2435e71de10ebe849acac5a107b6f43817c4 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Sat, 24 May 2014 20:33:05 +0200 Subject: initial support for static dimension checking --- packages/base/hmatrix.cabal | 1 + packages/base/src/Numeric/HMatrix.hs | 2 +- packages/base/src/Numeric/LinearAlgebra/Data.hs | 16 +++-- packages/base/src/Numeric/LinearAlgebra/Util.hs | 14 +++++ .../base/src/Numeric/LinearAlgebra/Util/Static.hs | 70 ++++++++++++++++++++++ 5 files changed, 97 insertions(+), 6 deletions(-) create mode 100644 packages/base/src/Numeric/LinearAlgebra/Util/Static.hs diff --git a/packages/base/hmatrix.cabal b/packages/base/hmatrix.cabal index 01e3c26..e958de0 100644 --- a/packages/base/hmatrix.cabal +++ b/packages/base/hmatrix.cabal @@ -67,6 +67,7 @@ library Numeric.LinearAlgebra.Random Numeric.Conversion Numeric.Sparse + Numeric.LinearAlgebra.Util.Static C-sources: src/C/lapack-aux.c src/C/vector-aux.c diff --git a/packages/base/src/Numeric/HMatrix.hs b/packages/base/src/Numeric/HMatrix.hs index a56c3d2..d5c66fb 100644 --- a/packages/base/src/Numeric/HMatrix.hs +++ b/packages/base/src/Numeric/HMatrix.hs @@ -134,7 +134,7 @@ module Numeric.HMatrix ( -- * Misc meanCov, peps, relativeError, haussholder, optimiseMult, dot, udot, mXm, mXv, smXv, (<>), (◇), Seed, checkT, -- * Auxiliary classes - Element, Container, Product, Contraction, LSDiv, + Element, Container, Product, Numeric, Contraction, LSDiv, Complexable, RealElement, RealOf, ComplexOf, SingleOf, DoubleOf, IndexOf, diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs index 89bebbe..3128a24 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Data.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs @@ -48,16 +48,20 @@ module Numeric.LinearAlgebra.Data( -- * Find elements find, maxIndex, minIndex, maxElement, minElement, atIndex, + -- * Sparse + SMatrix, AssocMatrix, mkCSR, toDense, + mkDiag, + + -- * Static dimensions + + Static, ddata, R, vect0, sScalar, vect2, vect3, (&), + -- * IO disp, loadMatrix, saveMatrix, latexFormat, dispf, disps, dispcf, format, - -- * Sparse - SMatrix, AssocMatrix, mkCSR, toDense, - mkDiag, - -- * Conversion Convert(..), @@ -75,7 +79,9 @@ module Numeric.LinearAlgebra.Data( import Data.Packed.Vector import Data.Packed.Matrix import Data.Packed.Numeric -import Numeric.LinearAlgebra.Util +import Numeric.LinearAlgebra.Util hiding ((&)) import Data.Complex import Numeric.Sparse +import Numeric.LinearAlgebra.Util.Static + diff --git a/packages/base/src/Numeric/LinearAlgebra/Util.hs b/packages/base/src/Numeric/LinearAlgebra/Util.hs index a7d6946..a319785 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util.hs @@ -28,6 +28,7 @@ module Numeric.LinearAlgebra.Util( (&), (¦), (——), (#), (?), (¿), Indexable(..), size, + Numeric, rand, randn, cross, norm, @@ -101,6 +102,19 @@ mat -> Matrix ℝ mat c = reshape c . fromList + + +class ( Container Vector t + , Container Matrix t + , Konst t Int Vector + , Konst t (Int,Int) Matrix + ) => Numeric t + +instance Numeric Double +instance Numeric (Complex Double) + + + {- | print a real matrix with given number of digits after the decimal point >>> disp 5 $ ident 2 / 3 diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs new file mode 100644 index 0000000..a3f8eb0 --- /dev/null +++ b/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs @@ -0,0 +1,70 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE EmptyDataDecls #-} +{-# LANGUAGE Rank2Types #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeOperators #-} + +module Numeric.LinearAlgebra.Util.Static( + Static (ddata), + R, + vect0, sScalar, vect2, vect3, (&) +) where + + +import GHC.TypeLits +import Data.Packed.Numeric +import Numeric.Vector() +import Numeric.LinearAlgebra.Util(Numeric,ℝ) + +lift1F :: (Vector t -> Vector t) -> Static n (Vector t) -> Static n (Vector t) +lift1F f (Static v) = Static (f v) + +lift2F :: (Vector t -> Vector t -> Vector t) -> Static n (Vector t) -> Static n (Vector t) -> Static n (Vector t) +lift2F f (Static u) (Static v) = Static (f u v) + +newtype Static (n :: Nat) t = Static { ddata :: t } deriving Show + +type R n = Static n (Vector ℝ) + + +infixl 4 & +(&) :: R n -> ℝ -> R (n+1) +Static v & x = Static (vjoin [v, scalar x]) + +vect0 :: R 0 +vect0 = Static (fromList[]) + +sScalar :: ℝ -> R 1 +sScalar = Static . scalar + + +vect2 :: ℝ -> ℝ -> R 2 +vect2 x1 x2 = Static (fromList [x1,x2]) + +vect3 :: ℝ -> ℝ -> ℝ -> R 3 +vect3 x1 x2 x3 = Static (fromList [x1,x2,x3]) + + + + + + +instance forall n t . (KnownNat n, Num (Vector t), Numeric t )=> Num (Static n (Vector t)) + where + (+) = lift2F add + (*) = lift2F mul + (-) = lift2F sub + abs = lift1F abs + signum = lift1F signum + negate = lift1F (scale (-1)) + fromInteger x = Static (konst (fromInteger x) d) + where + d = fromIntegral . natVal $ (undefined :: Proxy n) + +data Proxy :: Nat -> * + -- cgit v1.2.3