From 21ccf5342555bd41a61ed132b09eacebf3c71feb Mon Sep 17 00:00:00 2001 From: Vivian McPhail Date: Mon, 5 Jul 2010 09:21:57 +0000 Subject: added Vectors typeclass and refactored --- lib/Numeric/GSL/Vector.hs | 70 ++++++++++++++++++++++ lib/Numeric/GSL/gsl-aux.c | 102 ++++++++++++++++++++++++++++++-- lib/Numeric/LinearAlgebra/Algorithms.hs | 5 +- lib/Numeric/LinearAlgebra/Interface.hs | 8 ++- lib/Numeric/LinearAlgebra/Linear.hs | 62 ++++++++++++++++++- 5 files changed, 239 insertions(+), 8 deletions(-) (limited to 'lib/Numeric') diff --git a/lib/Numeric/GSL/Vector.hs b/lib/Numeric/GSL/Vector.hs index d09323b..97a0f9c 100644 --- a/lib/Numeric/GSL/Vector.hs +++ b/lib/Numeric/GSL/Vector.hs @@ -14,6 +14,8 @@ ----------------------------------------------------------------------------- module Numeric.GSL.Vector ( + sumF, sumR, sumQ, sumC, + dotF, dotR, dotQ, dotC, FunCodeS(..), toScalarR, toScalarF, FunCodeV(..), vectorMapR, vectorMapC, vectorMapF, FunCodeSV(..), vectorMapValR, vectorMapValC, vectorMapValF, @@ -76,6 +78,74 @@ data FunCodeS = Norm2 ------------------------------------------------------------------ +-- | sum of elements +sumF :: Vector Float -> Float +sumF x = unsafePerformIO $ do + r <- createVector 1 + app2 c_sumF vec x vec r "sumF" + return $ r @> 0 + +-- | sum of elements +sumR :: Vector Double -> Double +sumR x = unsafePerformIO $ do + r <- createVector 1 + app2 c_sumR vec x vec r "sumR" + return $ r @> 0 + +-- | sum of elements +sumQ :: Vector (Complex Float) -> Complex Float +sumQ x = unsafePerformIO $ do + r <- createVector 1 + app2 c_sumQ vec x vec r "sumQ" + return $ r @> 0 + +-- | sum of elements +sumC :: Vector (Complex Double) -> Complex Double +sumC x = unsafePerformIO $ do + r <- createVector 1 + app2 c_sumC vec x vec r "sumC" + return $ r @> 0 + +foreign import ccall safe "gsl-aux.h sumF" c_sumF :: TFF +foreign import ccall safe "gsl-aux.h sumR" c_sumR :: TVV +foreign import ccall safe "gsl-aux.h sumQ" c_sumQ :: TQVQV +foreign import ccall safe "gsl-aux.h sumC" c_sumC :: TCVCV + +-- | dot product +dotF :: Vector Float -> Vector Float -> Float +dotF x y = unsafePerformIO $ do + r <- createVector 1 + app3 c_dotF vec x vec y vec r "dotF" + return $ r @> 0 + +-- | dot product +dotR :: Vector Double -> Vector Double -> Double +dotR x y = unsafePerformIO $ do + r <- createVector 1 + app3 c_dotR vec x vec y vec r "dotR" + return $ r @> 0 + +-- | dot product +dotQ :: Vector (Complex Float) -> Vector (Complex Float) -> Complex Float +dotQ x y = unsafePerformIO $ do + r <- createVector 1 + app3 c_dotQ vec x vec y vec r "dotQ" + return $ r @> 0 + +-- | dot product +dotC :: Vector (Complex Double) -> Vector (Complex Double) -> Complex Double +dotC x y = unsafePerformIO $ do + r <- createVector 1 + app3 c_dotC vec x vec y vec r "dotC" + return $ r @> 0 + +foreign import ccall safe "gsl-aux.h dotF" c_dotF :: TFFF +foreign import ccall safe "gsl-aux.h dotR" c_dotR :: TVVV +foreign import ccall safe "gsl-aux.h dotQ" c_dotQ :: TQVQVQV +foreign import ccall safe "gsl-aux.h dotC" c_dotC :: TCVCVCV + +------------------------------------------------------------------ + toScalarAux fun code v = unsafePerformIO $ do r <- createVector 1 app2 (fun (fromei code)) vec v vec r "toScalarAux" diff --git a/lib/Numeric/GSL/gsl-aux.c b/lib/Numeric/GSL/gsl-aux.c index 6bb16f0..fe33766 100644 --- a/lib/Numeric/GSL/gsl-aux.c +++ b/lib/Numeric/GSL/gsl-aux.c @@ -76,12 +76,12 @@ #define FVVIEW(A) gsl_vector_float_view A = gsl_vector_float_view_array(A##p,A##n) #define FMVIEW(A) gsl_matrix_float_view A = gsl_matrix_float_view_array(A##p,A##r,A##c) -#define QVVIEW(A) gsl_vector_float_complex_view A = gsl_vector_float_complex_view_array((float*)A##p,A##n) -#define QMVIEW(A) gsl_matrix_float_complex_view A = gsl_matrix_float_complex_view_array((float*)A##p,A##r,A##c) +#define QVVIEW(A) gsl_vector_complex_float_view A = gsl_vector_float_complex_view_array((float*)A##p,A##n) +#define QMVIEW(A) gsl_matrix_complex_float_view A = gsl_matrix_float_complex_view_array((float*)A##p,A##r,A##c) #define KFVVIEW(A) gsl_vector_float_const_view A = gsl_vector_float_const_view_array(A##p,A##n) #define KFMVIEW(A) gsl_matrix_float_const_view A = gsl_matrix_float_const_view_array(A##p,A##r,A##c) -#define KQVVIEW(A) gsl_vector_float_complex_const_view A = gsl_vector_float_complex_const_view_array((float*)A##p,A##n) -#define KQMVIEW(A) gsl_matrix_float_complex_const_view A = gsl_matrix_float_complex_const_view_array((float*)A##p,A##r,A##c) +#define KQVVIEW(A) gsl_vector_complex_float_const_view A = gsl_vector_complex_float_const_view_array((float*)A##p,A##n) +#define KQMVIEW(A) gsl_matrix_complex_float_const_view A = gsl_matrix_complex_float_const_view_array((float*)A##p,A##r,A##c) #define V(a) (&a.vector) #define M(a) (&a.matrix) @@ -103,6 +103,100 @@ void no_abort_on_error() { } +int sumF(KFVEC(x),FVEC(r)) { + DEBUGMSG("sumF"); + REQUIRES(rn==1,BAD_SIZE); + int i; + float res = 0; + for (i = 0; i < xn; i++) res += xp[i]; + rp[0] = res; + OK +} + +int sumR(KRVEC(x),RVEC(r)) { + DEBUGMSG("sumR"); + REQUIRES(rn==1,BAD_SIZE); + int i; + double res = 0; + for (i = 0; i < xn; i++) res += xp[i]; + rp[0] = res; + OK +} + +int sumQ(KQVEC(x),QVEC(r)) { + DEBUGMSG("sumQ"); + REQUIRES(rn==1,BAD_SIZE); + int i; + gsl_complex_float res; + res.dat[0] = 0; + res.dat[1] = 0; + for (i = 0; i < xn; i++) { + res.dat[0] += xp[i].dat[0]; + res.dat[1] += xp[i].dat[1]; + } + rp[0] = res; + OK +} + +int sumC(KCVEC(x),CVEC(r)) { + DEBUGMSG("sumC"); + REQUIRES(rn==1,BAD_SIZE); + int i; + gsl_complex res; + res.dat[0] = 0; + res.dat[1] = 0; + for (i = 0; i < xn; i++) { + res.dat[0] += xp[i].dat[0]; + res.dat[1] += xp[i].dat[1]; + } + rp[0] = res; + OK +} + +int dotF(KFVEC(x), KFVEC(y), FVEC(r)) { + DEBUGMSG("dotF"); + REQUIRES(xn==yn,BAD_SIZE); + REQUIRES(rn==1,BAD_SIZE); + DEBUGMSG("dotF"); + KFVVIEW(x); + KFVVIEW(y); + gsl_blas_sdot(V(x),V(y),rp); + OK +} + +int dotR(KRVEC(x), KRVEC(y), RVEC(r)) { + DEBUGMSG("dotR"); + REQUIRES(xn==yn,BAD_SIZE); + REQUIRES(rn==1,BAD_SIZE); + DEBUGMSG("dotR"); + KDVVIEW(x); + KDVVIEW(y); + gsl_blas_ddot(V(x),V(y),rp); + OK +} + +int dotQ(KQVEC(x), KQVEC(y), QVEC(r)) { + DEBUGMSG("dotQ"); + REQUIRES(xn==yn,BAD_SIZE); + REQUIRES(rn==1,BAD_SIZE); + DEBUGMSG("dotQ"); + KQVVIEW(x); + KQVVIEW(y); + gsl_blas_cdotu(V(x),V(y),rp); + OK +} + +int dotC(KCVEC(x), KCVEC(y), CVEC(r)) { + DEBUGMSG("dotC"); + REQUIRES(xn==yn,BAD_SIZE); + REQUIRES(rn==1,BAD_SIZE); + DEBUGMSG("dotC"); + KCVVIEW(x); + KCVVIEW(y); + gsl_blas_zdotu(V(x),V(y),rp); + OK +} + int toScalarR(int code, KRVEC(x), RVEC(r)) { REQUIRES(rn==1,BAD_SIZE); DEBUGMSG("toScalarR"); diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 55398e0..e058490 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs @@ -22,7 +22,7 @@ module Numeric.LinearAlgebra.Algorithms ( -- * Supported types Field(), -- * Products - multiply, dot, + multiply, -- dot, moved dot to typeclass outer, kronecker, -- * Linear Systems linearSolve, @@ -707,12 +707,13 @@ luFact (l_u,perm) | r <= c = (l ,u ,p, s) -------------------------------------------------- +{- moved to Numeric.LinearAlgebra.Interface Vector typeclass -- | Euclidean inner product. dot :: (Field t) => Vector t -> Vector t -> t dot u v = multiply r c @@> (0,0) where r = asRow u c = asColumn v - +-} {- | Outer product of two vectors. diff --git a/lib/Numeric/LinearAlgebra/Interface.hs b/lib/Numeric/LinearAlgebra/Interface.hs index 30547d9..f8917a0 100644 --- a/lib/Numeric/LinearAlgebra/Interface.hs +++ b/lib/Numeric/LinearAlgebra/Interface.hs @@ -28,6 +28,9 @@ import Numeric.LinearAlgebra.Instances() import Data.Packed.Vector import Data.Packed.Matrix import Numeric.LinearAlgebra.Algorithms +import Numeric.LinearAlgebra.Linear + +--import Numeric.GSL.Vector class Mul a b c | a b -> c where infixl 7 <> @@ -46,7 +49,8 @@ instance Mul Vector Matrix Vector where --------------------------------------------------- -- | Dot product: @u \<.\> v = dot u v@ -(<.>) :: (Field t) => Vector t -> Vector t -> t +--(<.>) :: (Field t) => Vector t -> Vector t -> t +(<.>) :: Vectors Vector t => Vector t -> Vector t -> t infixl 7 <.> (<.>) = dot @@ -115,3 +119,5 @@ a <|> b = joinH a b -- (<->) :: (Element t, Joinable a b) => a t -> b t -> Matrix t a <-> b = joinV a b +---------------------------------------------------- + diff --git a/lib/Numeric/LinearAlgebra/Linear.hs b/lib/Numeric/LinearAlgebra/Linear.hs index 481d72a..1651247 100644 --- a/lib/Numeric/LinearAlgebra/Linear.hs +++ b/lib/Numeric/LinearAlgebra/Linear.hs @@ -1,4 +1,5 @@ {-# LANGUAGE UndecidableInstances, MultiParamTypeClasses, FlexibleInstances #-} +{-# LANGUAGE FlexibleContexts #-} ----------------------------------------------------------------------------- {- | Module : Numeric.LinearAlgebra.Linear @@ -15,6 +16,7 @@ Basic optimized operations on vectors and matrices. ----------------------------------------------------------------------------- module Numeric.LinearAlgebra.Linear ( + Vectors(..), normalise, Linear(..) ) where @@ -23,6 +25,64 @@ import Data.Packed.Matrix import Data.Complex import Numeric.GSL.Vector +-- | normalise a vector to unit length +normalise :: (Floating a, Vectors Vector a, + Linear Vector a, Fractional (Vector a)) => Vector a -> Vector a +normalise v = scaleRecip (vectorSum v) v + +-- | basic Vector functions +class (Num b) => Vectors a b where + vectorSum :: a b -> b + euclidean :: a b -> b + absSum :: a b -> b + vectorMin :: a b -> b + vectorMax :: a b -> b + minIdx :: a b -> Int + maxIdx :: a b -> Int + dot :: a b -> a b -> b + +instance Vectors Vector Float where + vectorSum = sumF + euclidean = toScalarF Norm2 + absSum = toScalarF AbsSum + vectorMin = toScalarF Min + vectorMax = toScalarF Max + minIdx = round . toScalarF MinIdx + maxIdx = round . toScalarF MaxIdx + dot = dotF + +instance Vectors Vector Double where + vectorSum = sumR + euclidean = toScalarR Norm2 + absSum = toScalarR AbsSum + vectorMin = toScalarR Min + vectorMax = toScalarR Max + minIdx = round . toScalarR MinIdx + maxIdx = round . toScalarR MaxIdx + dot = dotR + +instance Vectors Vector (Complex Float) where + vectorSum = sumQ + euclidean = undefined + absSum = undefined + vectorMin = undefined + vectorMax = undefined + minIdx = undefined + maxIdx = undefined + dot = dotQ + +instance Vectors Vector (Complex Double) where + vectorSum = sumC + euclidean = undefined + absSum = undefined + vectorMin = undefined + vectorMax = undefined + minIdx = undefined + maxIdx = undefined + dot = dotC + +---------------------------------------------------- + -- | Basic element-by-element functions. class (Container c e) => Linear c e where -- | create a structure with a single element @@ -50,7 +110,7 @@ instance Linear Vector Float where sub = vectorZipF Sub mul = vectorZipF Mul divide = vectorZipF Div - equal u v = dim u == dim v && vectorFMax (vectorMapF Abs (sub u v)) == 0.0 + equal u v = dim u == dim v && vectorMax (vectorMapF Abs (sub u v)) == 0.0 scalar x = fromList [x] instance Linear Vector Double where -- cgit v1.2.3