summaryrefslogtreecommitdiff
path: root/lib/Numeric
diff options
context:
space:
mode:
authorVivian McPhail <haskell.vivian.mcphail@gmail.com>2010-07-05 09:21:57 +0000
committerVivian McPhail <haskell.vivian.mcphail@gmail.com>2010-07-05 09:21:57 +0000
commit21ccf5342555bd41a61ed132b09eacebf3c71feb (patch)
treebad2e548a20ea0d1dbe3813199e40b315634ac7d /lib/Numeric
parentdd054da0524abdb14d013c9f9f43272515b77b6e (diff)
added Vectors typeclass and refactored
Diffstat (limited to 'lib/Numeric')
-rw-r--r--lib/Numeric/GSL/Vector.hs70
-rw-r--r--lib/Numeric/GSL/gsl-aux.c102
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs5
-rw-r--r--lib/Numeric/LinearAlgebra/Interface.hs8
-rw-r--r--lib/Numeric/LinearAlgebra/Linear.hs62
5 files changed, 239 insertions, 8 deletions
diff --git a/lib/Numeric/GSL/Vector.hs b/lib/Numeric/GSL/Vector.hs
index d09323b..97a0f9c 100644
--- a/lib/Numeric/GSL/Vector.hs
+++ b/lib/Numeric/GSL/Vector.hs
@@ -14,6 +14,8 @@
14----------------------------------------------------------------------------- 14-----------------------------------------------------------------------------
15 15
16module Numeric.GSL.Vector ( 16module Numeric.GSL.Vector (
17 sumF, sumR, sumQ, sumC,
18 dotF, dotR, dotQ, dotC,
17 FunCodeS(..), toScalarR, toScalarF, 19 FunCodeS(..), toScalarR, toScalarF,
18 FunCodeV(..), vectorMapR, vectorMapC, vectorMapF, 20 FunCodeV(..), vectorMapR, vectorMapC, vectorMapF,
19 FunCodeSV(..), vectorMapValR, vectorMapValC, vectorMapValF, 21 FunCodeSV(..), vectorMapValR, vectorMapValC, vectorMapValF,
@@ -76,6 +78,74 @@ data FunCodeS = Norm2
76 78
77------------------------------------------------------------------ 79------------------------------------------------------------------
78 80
81-- | sum of elements
82sumF :: Vector Float -> Float
83sumF x = unsafePerformIO $ do
84 r <- createVector 1
85 app2 c_sumF vec x vec r "sumF"
86 return $ r @> 0
87
88-- | sum of elements
89sumR :: Vector Double -> Double
90sumR x = unsafePerformIO $ do
91 r <- createVector 1
92 app2 c_sumR vec x vec r "sumR"
93 return $ r @> 0
94
95-- | sum of elements
96sumQ :: Vector (Complex Float) -> Complex Float
97sumQ x = unsafePerformIO $ do
98 r <- createVector 1
99 app2 c_sumQ vec x vec r "sumQ"
100 return $ r @> 0
101
102-- | sum of elements
103sumC :: Vector (Complex Double) -> Complex Double
104sumC x = unsafePerformIO $ do
105 r <- createVector 1
106 app2 c_sumC vec x vec r "sumC"
107 return $ r @> 0
108
109foreign import ccall safe "gsl-aux.h sumF" c_sumF :: TFF
110foreign import ccall safe "gsl-aux.h sumR" c_sumR :: TVV
111foreign import ccall safe "gsl-aux.h sumQ" c_sumQ :: TQVQV
112foreign import ccall safe "gsl-aux.h sumC" c_sumC :: TCVCV
113
114-- | dot product
115dotF :: Vector Float -> Vector Float -> Float
116dotF x y = unsafePerformIO $ do
117 r <- createVector 1
118 app3 c_dotF vec x vec y vec r "dotF"
119 return $ r @> 0
120
121-- | dot product
122dotR :: Vector Double -> Vector Double -> Double
123dotR x y = unsafePerformIO $ do
124 r <- createVector 1
125 app3 c_dotR vec x vec y vec r "dotR"
126 return $ r @> 0
127
128-- | dot product
129dotQ :: Vector (Complex Float) -> Vector (Complex Float) -> Complex Float
130dotQ x y = unsafePerformIO $ do
131 r <- createVector 1
132 app3 c_dotQ vec x vec y vec r "dotQ"
133 return $ r @> 0
134
135-- | dot product
136dotC :: Vector (Complex Double) -> Vector (Complex Double) -> Complex Double
137dotC x y = unsafePerformIO $ do
138 r <- createVector 1
139 app3 c_dotC vec x vec y vec r "dotC"
140 return $ r @> 0
141
142foreign import ccall safe "gsl-aux.h dotF" c_dotF :: TFFF
143foreign import ccall safe "gsl-aux.h dotR" c_dotR :: TVVV
144foreign import ccall safe "gsl-aux.h dotQ" c_dotQ :: TQVQVQV
145foreign import ccall safe "gsl-aux.h dotC" c_dotC :: TCVCVCV
146
147------------------------------------------------------------------
148
79toScalarAux fun code v = unsafePerformIO $ do 149toScalarAux fun code v = unsafePerformIO $ do
80 r <- createVector 1 150 r <- createVector 1
81 app2 (fun (fromei code)) vec v vec r "toScalarAux" 151 app2 (fun (fromei code)) vec v vec r "toScalarAux"
diff --git a/lib/Numeric/GSL/gsl-aux.c b/lib/Numeric/GSL/gsl-aux.c
index 6bb16f0..fe33766 100644
--- a/lib/Numeric/GSL/gsl-aux.c
+++ b/lib/Numeric/GSL/gsl-aux.c
@@ -76,12 +76,12 @@
76 76
77#define FVVIEW(A) gsl_vector_float_view A = gsl_vector_float_view_array(A##p,A##n) 77#define FVVIEW(A) gsl_vector_float_view A = gsl_vector_float_view_array(A##p,A##n)
78#define FMVIEW(A) gsl_matrix_float_view A = gsl_matrix_float_view_array(A##p,A##r,A##c) 78#define FMVIEW(A) gsl_matrix_float_view A = gsl_matrix_float_view_array(A##p,A##r,A##c)
79#define QVVIEW(A) gsl_vector_float_complex_view A = gsl_vector_float_complex_view_array((float*)A##p,A##n) 79#define QVVIEW(A) gsl_vector_complex_float_view A = gsl_vector_float_complex_view_array((float*)A##p,A##n)
80#define QMVIEW(A) gsl_matrix_float_complex_view A = gsl_matrix_float_complex_view_array((float*)A##p,A##r,A##c) 80#define QMVIEW(A) gsl_matrix_complex_float_view A = gsl_matrix_float_complex_view_array((float*)A##p,A##r,A##c)
81#define KFVVIEW(A) gsl_vector_float_const_view A = gsl_vector_float_const_view_array(A##p,A##n) 81#define KFVVIEW(A) gsl_vector_float_const_view A = gsl_vector_float_const_view_array(A##p,A##n)
82#define KFMVIEW(A) gsl_matrix_float_const_view A = gsl_matrix_float_const_view_array(A##p,A##r,A##c) 82#define KFMVIEW(A) gsl_matrix_float_const_view A = gsl_matrix_float_const_view_array(A##p,A##r,A##c)
83#define KQVVIEW(A) gsl_vector_float_complex_const_view A = gsl_vector_float_complex_const_view_array((float*)A##p,A##n) 83#define KQVVIEW(A) gsl_vector_complex_float_const_view A = gsl_vector_complex_float_const_view_array((float*)A##p,A##n)
84#define KQMVIEW(A) gsl_matrix_float_complex_const_view A = gsl_matrix_float_complex_const_view_array((float*)A##p,A##r,A##c) 84#define KQMVIEW(A) gsl_matrix_complex_float_const_view A = gsl_matrix_complex_float_const_view_array((float*)A##p,A##r,A##c)
85 85
86#define V(a) (&a.vector) 86#define V(a) (&a.vector)
87#define M(a) (&a.matrix) 87#define M(a) (&a.matrix)
@@ -103,6 +103,100 @@ void no_abort_on_error() {
103} 103}
104 104
105 105
106int sumF(KFVEC(x),FVEC(r)) {
107 DEBUGMSG("sumF");
108 REQUIRES(rn==1,BAD_SIZE);
109 int i;
110 float res = 0;
111 for (i = 0; i < xn; i++) res += xp[i];
112 rp[0] = res;
113 OK
114}
115
116int sumR(KRVEC(x),RVEC(r)) {
117 DEBUGMSG("sumR");
118 REQUIRES(rn==1,BAD_SIZE);
119 int i;
120 double res = 0;
121 for (i = 0; i < xn; i++) res += xp[i];
122 rp[0] = res;
123 OK
124}
125
126int sumQ(KQVEC(x),QVEC(r)) {
127 DEBUGMSG("sumQ");
128 REQUIRES(rn==1,BAD_SIZE);
129 int i;
130 gsl_complex_float res;
131 res.dat[0] = 0;
132 res.dat[1] = 0;
133 for (i = 0; i < xn; i++) {
134 res.dat[0] += xp[i].dat[0];
135 res.dat[1] += xp[i].dat[1];
136 }
137 rp[0] = res;
138 OK
139}
140
141int sumC(KCVEC(x),CVEC(r)) {
142 DEBUGMSG("sumC");
143 REQUIRES(rn==1,BAD_SIZE);
144 int i;
145 gsl_complex res;
146 res.dat[0] = 0;
147 res.dat[1] = 0;
148 for (i = 0; i < xn; i++) {
149 res.dat[0] += xp[i].dat[0];
150 res.dat[1] += xp[i].dat[1];
151 }
152 rp[0] = res;
153 OK
154}
155
156int dotF(KFVEC(x), KFVEC(y), FVEC(r)) {
157 DEBUGMSG("dotF");
158 REQUIRES(xn==yn,BAD_SIZE);
159 REQUIRES(rn==1,BAD_SIZE);
160 DEBUGMSG("dotF");
161 KFVVIEW(x);
162 KFVVIEW(y);
163 gsl_blas_sdot(V(x),V(y),rp);
164 OK
165}
166
167int dotR(KRVEC(x), KRVEC(y), RVEC(r)) {
168 DEBUGMSG("dotR");
169 REQUIRES(xn==yn,BAD_SIZE);
170 REQUIRES(rn==1,BAD_SIZE);
171 DEBUGMSG("dotR");
172 KDVVIEW(x);
173 KDVVIEW(y);
174 gsl_blas_ddot(V(x),V(y),rp);
175 OK
176}
177
178int dotQ(KQVEC(x), KQVEC(y), QVEC(r)) {
179 DEBUGMSG("dotQ");
180 REQUIRES(xn==yn,BAD_SIZE);
181 REQUIRES(rn==1,BAD_SIZE);
182 DEBUGMSG("dotQ");
183 KQVVIEW(x);
184 KQVVIEW(y);
185 gsl_blas_cdotu(V(x),V(y),rp);
186 OK
187}
188
189int dotC(KCVEC(x), KCVEC(y), CVEC(r)) {
190 DEBUGMSG("dotC");
191 REQUIRES(xn==yn,BAD_SIZE);
192 REQUIRES(rn==1,BAD_SIZE);
193 DEBUGMSG("dotC");
194 KCVVIEW(x);
195 KCVVIEW(y);
196 gsl_blas_zdotu(V(x),V(y),rp);
197 OK
198}
199
106int toScalarR(int code, KRVEC(x), RVEC(r)) { 200int toScalarR(int code, KRVEC(x), RVEC(r)) {
107 REQUIRES(rn==1,BAD_SIZE); 201 REQUIRES(rn==1,BAD_SIZE);
108 DEBUGMSG("toScalarR"); 202 DEBUGMSG("toScalarR");
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index 55398e0..e058490 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -22,7 +22,7 @@ module Numeric.LinearAlgebra.Algorithms (
22-- * Supported types 22-- * Supported types
23 Field(), 23 Field(),
24-- * Products 24-- * Products
25 multiply, dot, 25 multiply, -- dot, moved dot to typeclass
26 outer, kronecker, 26 outer, kronecker,
27-- * Linear Systems 27-- * Linear Systems
28 linearSolve, 28 linearSolve,
@@ -707,12 +707,13 @@ luFact (l_u,perm) | r <= c = (l ,u ,p, s)
707 707
708-------------------------------------------------- 708--------------------------------------------------
709 709
710{- moved to Numeric.LinearAlgebra.Interface Vector typeclass
710-- | Euclidean inner product. 711-- | Euclidean inner product.
711dot :: (Field t) => Vector t -> Vector t -> t 712dot :: (Field t) => Vector t -> Vector t -> t
712dot u v = multiply r c @@> (0,0) 713dot u v = multiply r c @@> (0,0)
713 where r = asRow u 714 where r = asRow u
714 c = asColumn v 715 c = asColumn v
715 716-}
716 717
717{- | Outer product of two vectors. 718{- | Outer product of two vectors.
718 719
diff --git a/lib/Numeric/LinearAlgebra/Interface.hs b/lib/Numeric/LinearAlgebra/Interface.hs
index 30547d9..f8917a0 100644
--- a/lib/Numeric/LinearAlgebra/Interface.hs
+++ b/lib/Numeric/LinearAlgebra/Interface.hs
@@ -28,6 +28,9 @@ import Numeric.LinearAlgebra.Instances()
28import Data.Packed.Vector 28import Data.Packed.Vector
29import Data.Packed.Matrix 29import Data.Packed.Matrix
30import Numeric.LinearAlgebra.Algorithms 30import Numeric.LinearAlgebra.Algorithms
31import Numeric.LinearAlgebra.Linear
32
33--import Numeric.GSL.Vector
31 34
32class Mul a b c | a b -> c where 35class Mul a b c | a b -> c where
33 infixl 7 <> 36 infixl 7 <>
@@ -46,7 +49,8 @@ instance Mul Vector Matrix Vector where
46--------------------------------------------------- 49---------------------------------------------------
47 50
48-- | Dot product: @u \<.\> v = dot u v@ 51-- | Dot product: @u \<.\> v = dot u v@
49(<.>) :: (Field t) => Vector t -> Vector t -> t 52--(<.>) :: (Field t) => Vector t -> Vector t -> t
53(<.>) :: Vectors Vector t => Vector t -> Vector t -> t
50infixl 7 <.> 54infixl 7 <.>
51(<.>) = dot 55(<.>) = dot
52 56
@@ -115,3 +119,5 @@ a <|> b = joinH a b
115-- (<->) :: (Element t, Joinable a b) => a t -> b t -> Matrix t 119-- (<->) :: (Element t, Joinable a b) => a t -> b t -> Matrix t
116a <-> b = joinV a b 120a <-> b = joinV a b
117 121
122----------------------------------------------------
123
diff --git a/lib/Numeric/LinearAlgebra/Linear.hs b/lib/Numeric/LinearAlgebra/Linear.hs
index 481d72a..1651247 100644
--- a/lib/Numeric/LinearAlgebra/Linear.hs
+++ b/lib/Numeric/LinearAlgebra/Linear.hs
@@ -1,4 +1,5 @@
1{-# LANGUAGE UndecidableInstances, MultiParamTypeClasses, FlexibleInstances #-} 1{-# LANGUAGE UndecidableInstances, MultiParamTypeClasses, FlexibleInstances #-}
2{-# LANGUAGE FlexibleContexts #-}
2----------------------------------------------------------------------------- 3-----------------------------------------------------------------------------
3{- | 4{- |
4Module : Numeric.LinearAlgebra.Linear 5Module : Numeric.LinearAlgebra.Linear
@@ -15,6 +16,7 @@ Basic optimized operations on vectors and matrices.
15----------------------------------------------------------------------------- 16-----------------------------------------------------------------------------
16 17
17module Numeric.LinearAlgebra.Linear ( 18module Numeric.LinearAlgebra.Linear (
19 Vectors(..), normalise,
18 Linear(..) 20 Linear(..)
19) where 21) where
20 22
@@ -23,6 +25,64 @@ import Data.Packed.Matrix
23import Data.Complex 25import Data.Complex
24import Numeric.GSL.Vector 26import Numeric.GSL.Vector
25 27
28-- | normalise a vector to unit length
29normalise :: (Floating a, Vectors Vector a,
30 Linear Vector a, Fractional (Vector a)) => Vector a -> Vector a
31normalise v = scaleRecip (vectorSum v) v
32
33-- | basic Vector functions
34class (Num b) => Vectors a b where
35 vectorSum :: a b -> b
36 euclidean :: a b -> b
37 absSum :: a b -> b
38 vectorMin :: a b -> b
39 vectorMax :: a b -> b
40 minIdx :: a b -> Int
41 maxIdx :: a b -> Int
42 dot :: a b -> a b -> b
43
44instance Vectors Vector Float where
45 vectorSum = sumF
46 euclidean = toScalarF Norm2
47 absSum = toScalarF AbsSum
48 vectorMin = toScalarF Min
49 vectorMax = toScalarF Max
50 minIdx = round . toScalarF MinIdx
51 maxIdx = round . toScalarF MaxIdx
52 dot = dotF
53
54instance Vectors Vector Double where
55 vectorSum = sumR
56 euclidean = toScalarR Norm2
57 absSum = toScalarR AbsSum
58 vectorMin = toScalarR Min
59 vectorMax = toScalarR Max
60 minIdx = round . toScalarR MinIdx
61 maxIdx = round . toScalarR MaxIdx
62 dot = dotR
63
64instance Vectors Vector (Complex Float) where
65 vectorSum = sumQ
66 euclidean = undefined
67 absSum = undefined
68 vectorMin = undefined
69 vectorMax = undefined
70 minIdx = undefined
71 maxIdx = undefined
72 dot = dotQ
73
74instance Vectors Vector (Complex Double) where
75 vectorSum = sumC
76 euclidean = undefined
77 absSum = undefined
78 vectorMin = undefined
79 vectorMax = undefined
80 minIdx = undefined
81 maxIdx = undefined
82 dot = dotC
83
84----------------------------------------------------
85
26-- | Basic element-by-element functions. 86-- | Basic element-by-element functions.
27class (Container c e) => Linear c e where 87class (Container c e) => Linear c e where
28 -- | create a structure with a single element 88 -- | create a structure with a single element
@@ -50,7 +110,7 @@ instance Linear Vector Float where
50 sub = vectorZipF Sub 110 sub = vectorZipF Sub
51 mul = vectorZipF Mul 111 mul = vectorZipF Mul
52 divide = vectorZipF Div 112 divide = vectorZipF Div
53 equal u v = dim u == dim v && vectorFMax (vectorMapF Abs (sub u v)) == 0.0 113 equal u v = dim u == dim v && vectorMax (vectorMapF Abs (sub u v)) == 0.0
54 scalar x = fromList [x] 114 scalar x = fromList [x]
55 115
56instance Linear Vector Double where 116instance Linear Vector Double where