diff options
Diffstat (limited to 'packages')
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Matrix.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Numeric.hs | 439 | ||||
-rw-r--r-- | packages/base/src/Numeric/Container.hs | 8 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra.hs | 21 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Devel.hs | 10 | ||||
-rw-r--r-- | packages/base/src/Numeric/Matrix.hs | 4 |
6 files changed, 282 insertions, 202 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs index 9b831cc..91a9466 100644 --- a/packages/base/src/Data/Packed/Internal/Matrix.hs +++ b/packages/base/src/Data/Packed/Internal/Matrix.hs | |||
@@ -80,7 +80,7 @@ data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | |||
80 | 80 | ||
81 | transOrder RowMajor = ColumnMajor | 81 | transOrder RowMajor = ColumnMajor |
82 | transOrder ColumnMajor = RowMajor | 82 | transOrder ColumnMajor = RowMajor |
83 | {- | Matrix representation suitable for GSL and LAPACK computations. | 83 | {- | Matrix representation suitable for BLAS\/LAPACK computations. |
84 | 84 | ||
85 | The elements are stored in a continuous memory array. | 85 | The elements are stored in a continuous memory array. |
86 | 86 | ||
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs index 3528e96..9cd18df 100644 --- a/packages/base/src/Data/Packed/Internal/Numeric.hs +++ b/packages/base/src/Data/Packed/Internal/Numeric.hs | |||
@@ -20,6 +20,10 @@ module Data.Packed.Internal.Numeric ( | |||
20 | ident, diag, ctrans, | 20 | ident, diag, ctrans, |
21 | -- * Generic operations | 21 | -- * Generic operations |
22 | Container(..), | 22 | Container(..), |
23 | scalar, conj, scale, arctan2, cmap, | ||
24 | atIndex, minIndex, maxIndex, minElement, maxElement, | ||
25 | sumElements, prodElements, | ||
26 | step, cond, find, assoc, accum, | ||
23 | Transposable(..), Linear(..), Testable(..), | 27 | Transposable(..), Linear(..), Testable(..), |
24 | -- * Matrix product and related functions | 28 | -- * Matrix product and related functions |
25 | Product(..), udot, | 29 | Product(..), udot, |
@@ -62,16 +66,9 @@ type instance ArgOf Matrix a = a -> a -> a | |||
62 | 66 | ||
63 | -- | Basic element-by-element functions for numeric containers | 67 | -- | Basic element-by-element functions for numeric containers |
64 | class (Complexable c, Fractional e, Element e) => Container c e where | 68 | class (Complexable c, Fractional e, Element e) => Container c e where |
65 | -- | create a structure with a single element | 69 | scalar' :: e -> c e |
66 | -- | 70 | conj' :: c e -> c e |
67 | -- >>> let v = fromList [1..3::Double] | 71 | scale' :: e -> c e -> c e |
68 | -- >>> v / scalar (norm2 v) | ||
69 | -- fromList [0.2672612419124244,0.5345224838248488,0.8017837257372732] | ||
70 | -- | ||
71 | scalar :: e -> c e | ||
72 | -- | complex conjugate | ||
73 | conj :: c e -> c e | ||
74 | scale :: e -> c e -> c e | ||
75 | -- | scale the element by element reciprocal of the object: | 72 | -- | scale the element by element reciprocal of the object: |
76 | -- | 73 | -- |
77 | -- @scaleRecip 2 (fromList [5,i]) == 2 |> [0.4 :+ 0.0,0.0 :+ (-2.0)]@ | 74 | -- @scaleRecip 2 (fromList [5,i]) == 2 |> [0.4 :+ 0.0,0.0 :+ (-2.0)]@ |
@@ -86,101 +83,31 @@ class (Complexable c, Fractional e, Element e) => Container c e where | |||
86 | equal :: c e -> c e -> Bool | 83 | equal :: c e -> c e -> Bool |
87 | -- | 84 | -- |
88 | -- element by element inverse tangent | 85 | -- element by element inverse tangent |
89 | arctan2 :: c e -> c e -> c e | 86 | arctan2' :: c e -> c e -> c e |
90 | -- | 87 | cmap' :: (Element b) => (e -> b) -> c e -> c b |
91 | -- | cannot implement instance Functor because of Element class constraint | ||
92 | cmap :: (Element b) => (e -> b) -> c e -> c b | ||
93 | -- | constant structure of given size | ||
94 | konst' :: e -> IndexOf c -> c e | 88 | konst' :: e -> IndexOf c -> c e |
95 | -- | create a structure using a function | ||
96 | -- | ||
97 | -- Hilbert matrix of order N: | ||
98 | -- | ||
99 | -- @hilb n = build' (n,n) (\\i j -> 1/(i+j+1))@ | ||
100 | build' :: IndexOf c -> (ArgOf c e) -> c e | 89 | build' :: IndexOf c -> (ArgOf c e) -> c e |
101 | -- | indexing function | 90 | atIndex' :: c e -> IndexOf c -> e |
102 | atIndex :: c e -> IndexOf c -> e | 91 | minIndex' :: c e -> IndexOf c |
103 | -- | index of min element | 92 | maxIndex' :: c e -> IndexOf c |
104 | minIndex :: c e -> IndexOf c | 93 | minElement' :: c e -> e |
105 | -- | index of max element | 94 | maxElement' :: c e -> e |
106 | maxIndex :: c e -> IndexOf c | 95 | sumElements' :: c e -> e |
107 | -- | value of min element | 96 | prodElements' :: c e -> e |
108 | minElement :: c e -> e | 97 | step' :: RealElement e => c e -> c e |
109 | -- | value of max element | 98 | cond' :: RealElement e |
110 | maxElement :: c e -> e | ||
111 | -- the C functions sumX/prodX are twice as fast as using foldVector | ||
112 | -- | the sum of elements (faster than using @fold@) | ||
113 | sumElements :: c e -> e | ||
114 | -- | the product of elements (faster than using @fold@) | ||
115 | prodElements :: c e -> e | ||
116 | |||
117 | -- | A more efficient implementation of @cmap (\\x -> if x>0 then 1 else 0)@ | ||
118 | -- | ||
119 | -- >>> step $ linspace 5 (-1,1::Double) | ||
120 | -- 5 |> [0.0,0.0,0.0,1.0,1.0] | ||
121 | -- | ||
122 | |||
123 | step :: RealElement e => c e -> c e | ||
124 | |||
125 | -- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@. | ||
126 | -- | ||
127 | -- Arguments with any dimension = 1 are automatically expanded: | ||
128 | -- | ||
129 | -- >>> cond ((1><4)[1..]) ((3><1)[1..]) 0 100 ((3><4)[1..]) :: Matrix Double | ||
130 | -- (3><4) | ||
131 | -- [ 100.0, 2.0, 3.0, 4.0 | ||
132 | -- , 0.0, 100.0, 7.0, 8.0 | ||
133 | -- , 0.0, 0.0, 100.0, 12.0 ] | ||
134 | -- | ||
135 | |||
136 | cond :: RealElement e | ||
137 | => c e -- ^ a | 99 | => c e -- ^ a |
138 | -> c e -- ^ b | 100 | -> c e -- ^ b |
139 | -> c e -- ^ l | 101 | -> c e -- ^ l |
140 | -> c e -- ^ e | 102 | -> c e -- ^ e |
141 | -> c e -- ^ g | 103 | -> c e -- ^ g |
142 | -> c e -- ^ result | 104 | -> c e -- ^ result |
143 | 105 | find' :: (e -> Bool) -> c e -> [IndexOf c] | |
144 | -- | Find index of elements which satisfy a predicate | 106 | assoc' :: IndexOf c -- ^ size |
145 | -- | ||
146 | -- >>> find (>0) (ident 3 :: Matrix Double) | ||
147 | -- [(0,0),(1,1),(2,2)] | ||
148 | -- | ||
149 | |||
150 | find :: (e -> Bool) -> c e -> [IndexOf c] | ||
151 | |||
152 | -- | Create a structure from an association list | ||
153 | -- | ||
154 | -- >>> assoc 5 0 [(3,7),(1,4)] :: Vector Double | ||
155 | -- fromList [0.0,4.0,0.0,7.0,0.0] | ||
156 | -- | ||
157 | -- >>> assoc (2,3) 0 [((0,2),7),((1,0),2*i-3)] :: Matrix (Complex Double) | ||
158 | -- (2><3) | ||
159 | -- [ 0.0 :+ 0.0, 0.0 :+ 0.0, 7.0 :+ 0.0 | ||
160 | -- , (-3.0) :+ 2.0, 0.0 :+ 0.0, 0.0 :+ 0.0 ] | ||
161 | -- | ||
162 | assoc :: IndexOf c -- ^ size | ||
163 | -> e -- ^ default value | 107 | -> e -- ^ default value |
164 | -> [(IndexOf c, e)] -- ^ association list | 108 | -> [(IndexOf c, e)] -- ^ association list |
165 | -> c e -- ^ result | 109 | -> c e -- ^ result |
166 | 110 | accum' :: c e -- ^ initial structure | |
167 | -- | Modify a structure using an update function | ||
168 | -- | ||
169 | -- >>> accum (ident 5) (+) [((1,1),5),((0,3),3)] :: Matrix Double | ||
170 | -- (5><5) | ||
171 | -- [ 1.0, 0.0, 0.0, 3.0, 0.0 | ||
172 | -- , 0.0, 6.0, 0.0, 0.0, 0.0 | ||
173 | -- , 0.0, 0.0, 1.0, 0.0, 0.0 | ||
174 | -- , 0.0, 0.0, 0.0, 1.0, 0.0 | ||
175 | -- , 0.0, 0.0, 0.0, 0.0, 1.0 ] | ||
176 | -- | ||
177 | -- computation of histogram: | ||
178 | -- | ||
179 | -- >>> accum (konst 0 7) (+) (map (flip (,) 1) [4,5,4,1,5,2,5]) :: Vector Double | ||
180 | -- fromList [0.0,1.0,1.0,0.0,2.0,3.0,0.0] | ||
181 | -- | ||
182 | |||
183 | accum :: c e -- ^ initial structure | ||
184 | -> (e -> e -> e) -- ^ update function | 111 | -> (e -> e -> e) -- ^ update function |
185 | -> [(IndexOf c, e)] -- ^ association list | 112 | -> [(IndexOf c, e)] -- ^ association list |
186 | -> c e -- ^ result | 113 | -> c e -- ^ result |
@@ -188,7 +115,7 @@ class (Complexable c, Fractional e, Element e) => Container c e where | |||
188 | -------------------------------------------------------------------------- | 115 | -------------------------------------------------------------------------- |
189 | 116 | ||
190 | instance Container Vector Float where | 117 | instance Container Vector Float where |
191 | scale = vectorMapValF Scale | 118 | scale' = vectorMapValF Scale |
192 | scaleRecip = vectorMapValF Recip | 119 | scaleRecip = vectorMapValF Recip |
193 | addConstant = vectorMapValF AddConstant | 120 | addConstant = vectorMapValF AddConstant |
194 | add = vectorZipF Add | 121 | add = vectorZipF Add |
@@ -196,27 +123,27 @@ instance Container Vector Float where | |||
196 | mul = vectorZipF Mul | 123 | mul = vectorZipF Mul |
197 | divide = vectorZipF Div | 124 | divide = vectorZipF Div |
198 | equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0 | 125 | equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0 |
199 | arctan2 = vectorZipF ATan2 | 126 | arctan2' = vectorZipF ATan2 |
200 | scalar x = fromList [x] | 127 | scalar' x = fromList [x] |
201 | konst' = constant | 128 | konst' = constant |
202 | build' = buildV | 129 | build' = buildV |
203 | conj = id | 130 | conj' = id |
204 | cmap = mapVector | 131 | cmap' = mapVector |
205 | atIndex = (@>) | 132 | atIndex' = (@>) |
206 | minIndex = emptyErrorV "minIndex" (round . toScalarF MinIdx) | 133 | minIndex' = emptyErrorV "minIndex" (round . toScalarF MinIdx) |
207 | maxIndex = emptyErrorV "maxIndex" (round . toScalarF MaxIdx) | 134 | maxIndex' = emptyErrorV "maxIndex" (round . toScalarF MaxIdx) |
208 | minElement = emptyErrorV "minElement" (toScalarF Min) | 135 | minElement' = emptyErrorV "minElement" (toScalarF Min) |
209 | maxElement = emptyErrorV "maxElement" (toScalarF Max) | 136 | maxElement' = emptyErrorV "maxElement" (toScalarF Max) |
210 | sumElements = sumF | 137 | sumElements' = sumF |
211 | prodElements = prodF | 138 | prodElements' = prodF |
212 | step = stepF | 139 | step' = stepF |
213 | find = findV | 140 | find' = findV |
214 | assoc = assocV | 141 | assoc' = assocV |
215 | accum = accumV | 142 | accum' = accumV |
216 | cond = condV condF | 143 | cond' = condV condF |
217 | 144 | ||
218 | instance Container Vector Double where | 145 | instance Container Vector Double where |
219 | scale = vectorMapValR Scale | 146 | scale' = vectorMapValR Scale |
220 | scaleRecip = vectorMapValR Recip | 147 | scaleRecip = vectorMapValR Recip |
221 | addConstant = vectorMapValR AddConstant | 148 | addConstant = vectorMapValR AddConstant |
222 | add = vectorZipR Add | 149 | add = vectorZipR Add |
@@ -224,27 +151,27 @@ instance Container Vector Double where | |||
224 | mul = vectorZipR Mul | 151 | mul = vectorZipR Mul |
225 | divide = vectorZipR Div | 152 | divide = vectorZipR Div |
226 | equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0 | 153 | equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0 |
227 | arctan2 = vectorZipR ATan2 | 154 | arctan2' = vectorZipR ATan2 |
228 | scalar x = fromList [x] | 155 | scalar' x = fromList [x] |
229 | konst' = constant | 156 | konst' = constant |
230 | build' = buildV | 157 | build' = buildV |
231 | conj = id | 158 | conj' = id |
232 | cmap = mapVector | 159 | cmap' = mapVector |
233 | atIndex = (@>) | 160 | atIndex' = (@>) |
234 | minIndex = emptyErrorV "minIndex" (round . toScalarR MinIdx) | 161 | minIndex' = emptyErrorV "minIndex" (round . toScalarR MinIdx) |
235 | maxIndex = emptyErrorV "maxIndex" (round . toScalarR MaxIdx) | 162 | maxIndex' = emptyErrorV "maxIndex" (round . toScalarR MaxIdx) |
236 | minElement = emptyErrorV "minElement" (toScalarR Min) | 163 | minElement' = emptyErrorV "minElement" (toScalarR Min) |
237 | maxElement = emptyErrorV "maxElement" (toScalarR Max) | 164 | maxElement' = emptyErrorV "maxElement" (toScalarR Max) |
238 | sumElements = sumR | 165 | sumElements' = sumR |
239 | prodElements = prodR | 166 | prodElements' = prodR |
240 | step = stepD | 167 | step' = stepD |
241 | find = findV | 168 | find' = findV |
242 | assoc = assocV | 169 | assoc' = assocV |
243 | accum = accumV | 170 | accum' = accumV |
244 | cond = condV condD | 171 | cond' = condV condD |
245 | 172 | ||
246 | instance Container Vector (Complex Double) where | 173 | instance Container Vector (Complex Double) where |
247 | scale = vectorMapValC Scale | 174 | scale' = vectorMapValC Scale |
248 | scaleRecip = vectorMapValC Recip | 175 | scaleRecip = vectorMapValC Recip |
249 | addConstant = vectorMapValC AddConstant | 176 | addConstant = vectorMapValC AddConstant |
250 | add = vectorZipC Add | 177 | add = vectorZipC Add |
@@ -252,27 +179,27 @@ instance Container Vector (Complex Double) where | |||
252 | mul = vectorZipC Mul | 179 | mul = vectorZipC Mul |
253 | divide = vectorZipC Div | 180 | divide = vectorZipC Div |
254 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 | 181 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 |
255 | arctan2 = vectorZipC ATan2 | 182 | arctan2' = vectorZipC ATan2 |
256 | scalar x = fromList [x] | 183 | scalar' x = fromList [x] |
257 | konst' = constant | 184 | konst' = constant |
258 | build' = buildV | 185 | build' = buildV |
259 | conj = conjugateC | 186 | conj' = conjugateC |
260 | cmap = mapVector | 187 | cmap' = mapVector |
261 | atIndex = (@>) | 188 | atIndex' = (@>) |
262 | minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj)) | 189 | minIndex' = emptyErrorV "minIndex" (minIndex' . fst . fromComplex . (mul <*> conj')) |
263 | maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj)) | 190 | maxIndex' = emptyErrorV "maxIndex" (maxIndex' . fst . fromComplex . (mul <*> conj')) |
264 | minElement = emptyErrorV "minElement" (atIndex <*> minIndex) | 191 | minElement' = emptyErrorV "minElement" (atIndex' <*> minIndex') |
265 | maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex) | 192 | maxElement' = emptyErrorV "maxElement" (atIndex' <*> maxIndex') |
266 | sumElements = sumC | 193 | sumElements' = sumC |
267 | prodElements = prodC | 194 | prodElements' = prodC |
268 | step = undefined -- cannot match | 195 | step' = undefined -- cannot match |
269 | find = findV | 196 | find' = findV |
270 | assoc = assocV | 197 | assoc' = assocV |
271 | accum = accumV | 198 | accum' = accumV |
272 | cond = undefined -- cannot match | 199 | cond' = undefined -- cannot match |
273 | 200 | ||
274 | instance Container Vector (Complex Float) where | 201 | instance Container Vector (Complex Float) where |
275 | scale = vectorMapValQ Scale | 202 | scale' = vectorMapValQ Scale |
276 | scaleRecip = vectorMapValQ Recip | 203 | scaleRecip = vectorMapValQ Recip |
277 | addConstant = vectorMapValQ AddConstant | 204 | addConstant = vectorMapValQ AddConstant |
278 | add = vectorZipQ Add | 205 | add = vectorZipQ Add |
@@ -280,29 +207,29 @@ instance Container Vector (Complex Float) where | |||
280 | mul = vectorZipQ Mul | 207 | mul = vectorZipQ Mul |
281 | divide = vectorZipQ Div | 208 | divide = vectorZipQ Div |
282 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 | 209 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 |
283 | arctan2 = vectorZipQ ATan2 | 210 | arctan2' = vectorZipQ ATan2 |
284 | scalar x = fromList [x] | 211 | scalar' x = fromList [x] |
285 | konst' = constant | 212 | konst' = constant |
286 | build' = buildV | 213 | build' = buildV |
287 | conj = conjugateQ | 214 | conj' = conjugateQ |
288 | cmap = mapVector | 215 | cmap' = mapVector |
289 | atIndex = (@>) | 216 | atIndex' = (@>) |
290 | minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj)) | 217 | minIndex' = emptyErrorV "minIndex" (minIndex' . fst . fromComplex . (mul <*> conj')) |
291 | maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj)) | 218 | maxIndex' = emptyErrorV "maxIndex" (maxIndex' . fst . fromComplex . (mul <*> conj')) |
292 | minElement = emptyErrorV "minElement" (atIndex <*> minIndex) | 219 | minElement' = emptyErrorV "minElement" (atIndex' <*> minIndex') |
293 | maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex) | 220 | maxElement' = emptyErrorV "maxElement" (atIndex' <*> maxIndex') |
294 | sumElements = sumQ | 221 | sumElements' = sumQ |
295 | prodElements = prodQ | 222 | prodElements' = prodQ |
296 | step = undefined -- cannot match | 223 | step' = undefined -- cannot match |
297 | find = findV | 224 | find' = findV |
298 | assoc = assocV | 225 | assoc' = assocV |
299 | accum = accumV | 226 | accum' = accumV |
300 | cond = undefined -- cannot match | 227 | cond' = undefined -- cannot match |
301 | 228 | ||
302 | --------------------------------------------------------------- | 229 | --------------------------------------------------------------- |
303 | 230 | ||
304 | instance (Container Vector a) => Container Matrix a where | 231 | instance (Container Vector a) => Container Matrix a where |
305 | scale x = liftMatrix (scale x) | 232 | scale' x = liftMatrix (scale' x) |
306 | scaleRecip x = liftMatrix (scaleRecip x) | 233 | scaleRecip x = liftMatrix (scaleRecip x) |
307 | addConstant x = liftMatrix (addConstant x) | 234 | addConstant x = liftMatrix (addConstant x) |
308 | add = liftMatrix2 add | 235 | add = liftMatrix2 add |
@@ -310,26 +237,26 @@ instance (Container Vector a) => Container Matrix a where | |||
310 | mul = liftMatrix2 mul | 237 | mul = liftMatrix2 mul |
311 | divide = liftMatrix2 divide | 238 | divide = liftMatrix2 divide |
312 | equal a b = cols a == cols b && flatten a `equal` flatten b | 239 | equal a b = cols a == cols b && flatten a `equal` flatten b |
313 | arctan2 = liftMatrix2 arctan2 | 240 | arctan2' = liftMatrix2 arctan2' |
314 | scalar x = (1><1) [x] | 241 | scalar' x = (1><1) [x] |
315 | konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c)) | 242 | konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c)) |
316 | build' = buildM | 243 | build' = buildM |
317 | conj = liftMatrix conj | 244 | conj' = liftMatrix conj' |
318 | cmap f = liftMatrix (mapVector f) | 245 | cmap' f = liftMatrix (mapVector f) |
319 | atIndex = (@@>) | 246 | atIndex' = (@@>) |
320 | minIndex = emptyErrorM "minIndex of Matrix" $ | 247 | minIndex' = emptyErrorM "minIndex of Matrix" $ |
321 | \m -> divMod (minIndex $ flatten m) (cols m) | 248 | \m -> divMod (minIndex' $ flatten m) (cols m) |
322 | maxIndex = emptyErrorM "maxIndex of Matrix" $ | 249 | maxIndex' = emptyErrorM "maxIndex of Matrix" $ |
323 | \m -> divMod (maxIndex $ flatten m) (cols m) | 250 | \m -> divMod (maxIndex' $ flatten m) (cols m) |
324 | minElement = emptyErrorM "minElement of Matrix" (atIndex <*> minIndex) | 251 | minElement' = emptyErrorM "minElement of Matrix" (atIndex' <*> minIndex') |
325 | maxElement = emptyErrorM "maxElement of Matrix" (atIndex <*> maxIndex) | 252 | maxElement' = emptyErrorM "maxElement of Matrix" (atIndex' <*> maxIndex') |
326 | sumElements = sumElements . flatten | 253 | sumElements' = sumElements . flatten |
327 | prodElements = prodElements . flatten | 254 | prodElements' = prodElements . flatten |
328 | step = liftMatrix step | 255 | step' = liftMatrix step |
329 | find = findM | 256 | find' = findM |
330 | assoc = assocM | 257 | assoc' = assocM |
331 | accum = accumM | 258 | accum' = accumM |
332 | cond = condM | 259 | cond' = condM |
333 | 260 | ||
334 | 261 | ||
335 | emptyErrorV msg f v = | 262 | emptyErrorV msg f v = |
@@ -342,7 +269,151 @@ emptyErrorM msg f m = | |||
342 | then f m | 269 | then f m |
343 | else error $ msg++" "++shSize m | 270 | else error $ msg++" "++shSize m |
344 | 271 | ||
345 | ---------------------------------------------------- | 272 | -------------------------------------------------------------------------------- |
273 | |||
274 | -- | create a structure with a single element | ||
275 | -- | ||
276 | -- >>> let v = fromList [1..3::Double] | ||
277 | -- >>> v / scalar (norm2 v) | ||
278 | -- fromList [0.2672612419124244,0.5345224838248488,0.8017837257372732] | ||
279 | -- | ||
280 | scalar :: Container c e => e -> c e | ||
281 | scalar = scalar' | ||
282 | |||
283 | -- | complex conjugate | ||
284 | conj :: Container c e => c e -> c e | ||
285 | conj = conj' | ||
286 | |||
287 | -- | multiplication by scalar | ||
288 | scale :: Container c e => e -> c e -> c e | ||
289 | scale = scale' | ||
290 | |||
291 | arctan2 :: Container c e => c e -> c e -> c e | ||
292 | arctan2 = arctan2' | ||
293 | |||
294 | -- | like 'fmap' (cannot implement instance Functor because of Element class constraint) | ||
295 | cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b | ||
296 | cmap = cmap' | ||
297 | |||
298 | -- | indexing function | ||
299 | atIndex :: Container c e => c e -> IndexOf c -> e | ||
300 | atIndex = atIndex' | ||
301 | |||
302 | -- | index of minimum element | ||
303 | minIndex :: Container c e => c e -> IndexOf c | ||
304 | minIndex = minIndex' | ||
305 | |||
306 | -- | index of maximum element | ||
307 | maxIndex :: Container c e => c e -> IndexOf c | ||
308 | maxIndex = maxIndex' | ||
309 | |||
310 | -- | value of minimum element | ||
311 | minElement :: Container c e => c e -> e | ||
312 | minElement = minElement' | ||
313 | |||
314 | -- | value of maximum element | ||
315 | maxElement :: Container c e => c e -> e | ||
316 | maxElement = maxElement' | ||
317 | |||
318 | -- | the sum of elements | ||
319 | sumElements :: Container c e => c e -> e | ||
320 | sumElements = sumElements' | ||
321 | |||
322 | -- | the product of elements | ||
323 | prodElements :: Container c e => c e -> e | ||
324 | prodElements = prodElements' | ||
325 | |||
326 | |||
327 | -- | A more efficient implementation of @cmap (\\x -> if x>0 then 1 else 0)@ | ||
328 | -- | ||
329 | -- >>> step $ linspace 5 (-1,1::Double) | ||
330 | -- 5 |> [0.0,0.0,0.0,1.0,1.0] | ||
331 | -- | ||
332 | step | ||
333 | :: (RealElement e, Container c e) | ||
334 | => c e | ||
335 | -> c e | ||
336 | step = step' | ||
337 | |||
338 | |||
339 | -- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@. | ||
340 | -- | ||
341 | -- Arguments with any dimension = 1 are automatically expanded: | ||
342 | -- | ||
343 | -- >>> cond ((1><4)[1..]) ((3><1)[1..]) 0 100 ((3><4)[1..]) :: Matrix Double | ||
344 | -- (3><4) | ||
345 | -- [ 100.0, 2.0, 3.0, 4.0 | ||
346 | -- , 0.0, 100.0, 7.0, 8.0 | ||
347 | -- , 0.0, 0.0, 100.0, 12.0 ] | ||
348 | -- | ||
349 | cond | ||
350 | :: (RealElement e, Container c e) | ||
351 | => c e -- ^ a | ||
352 | -> c e -- ^ b | ||
353 | -> c e -- ^ l | ||
354 | -> c e -- ^ e | ||
355 | -> c e -- ^ g | ||
356 | -> c e -- ^ result | ||
357 | cond = cond' | ||
358 | |||
359 | |||
360 | -- | Find index of elements which satisfy a predicate | ||
361 | -- | ||
362 | -- >>> find (>0) (ident 3 :: Matrix Double) | ||
363 | -- [(0,0),(1,1),(2,2)] | ||
364 | -- | ||
365 | find | ||
366 | :: Container c e | ||
367 | => (e -> Bool) | ||
368 | -> c e | ||
369 | -> [IndexOf c] | ||
370 | find = find' | ||
371 | |||
372 | |||
373 | -- | Create a structure from an association list | ||
374 | -- | ||
375 | -- >>> assoc 5 0 [(3,7),(1,4)] :: Vector Double | ||
376 | -- fromList [0.0,4.0,0.0,7.0,0.0] | ||
377 | -- | ||
378 | -- >>> assoc (2,3) 0 [((0,2),7),((1,0),2*i-3)] :: Matrix (Complex Double) | ||
379 | -- (2><3) | ||
380 | -- [ 0.0 :+ 0.0, 0.0 :+ 0.0, 7.0 :+ 0.0 | ||
381 | -- , (-3.0) :+ 2.0, 0.0 :+ 0.0, 0.0 :+ 0.0 ] | ||
382 | -- | ||
383 | assoc | ||
384 | :: Container c e | ||
385 | => IndexOf c -- ^ size | ||
386 | -> e -- ^ default value | ||
387 | -> [(IndexOf c, e)] -- ^ association list | ||
388 | -> c e -- ^ result | ||
389 | assoc = assoc' | ||
390 | |||
391 | |||
392 | -- | Modify a structure using an update function | ||
393 | -- | ||
394 | -- >>> accum (ident 5) (+) [((1,1),5),((0,3),3)] :: Matrix Double | ||
395 | -- (5><5) | ||
396 | -- [ 1.0, 0.0, 0.0, 3.0, 0.0 | ||
397 | -- , 0.0, 6.0, 0.0, 0.0, 0.0 | ||
398 | -- , 0.0, 0.0, 1.0, 0.0, 0.0 | ||
399 | -- , 0.0, 0.0, 0.0, 1.0, 0.0 | ||
400 | -- , 0.0, 0.0, 0.0, 0.0, 1.0 ] | ||
401 | -- | ||
402 | -- computation of histogram: | ||
403 | -- | ||
404 | -- >>> accum (konst 0 7) (+) (map (flip (,) 1) [4,5,4,1,5,2,5]) :: Vector Double | ||
405 | -- fromList [0.0,1.0,1.0,0.0,2.0,3.0,0.0] | ||
406 | -- | ||
407 | accum | ||
408 | :: Container c e | ||
409 | => c e -- ^ initial structure | ||
410 | -> (e -> e -> e) -- ^ update function | ||
411 | -> [(IndexOf c, e)] -- ^ association list | ||
412 | -> c e -- ^ result | ||
413 | accum = accum' | ||
414 | |||
415 | |||
416 | -------------------------------------------------------------------------------- | ||
346 | 417 | ||
347 | -- | Matrix product and related functions | 418 | -- | Matrix product and related functions |
348 | class (Num e, Element e) => Product e where | 419 | class (Num e, Element e) => Product e where |
@@ -558,7 +629,7 @@ buildV n f = fromList [f k | k <- ks] | |||
558 | -------------------------------------------------------- | 629 | -------------------------------------------------------- |
559 | -- | conjugate transpose | 630 | -- | conjugate transpose |
560 | ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e | 631 | ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e |
561 | ctrans = liftMatrix conj . trans | 632 | ctrans = liftMatrix conj' . trans |
562 | 633 | ||
563 | -- | Creates a square matrix with a given diagonal. | 634 | -- | Creates a square matrix with a given diagonal. |
564 | diag :: (Num a, Element a) => Vector a -> Matrix a | 635 | diag :: (Num a, Element a) => Vector a -> Matrix a |
diff --git a/packages/base/src/Numeric/Container.hs b/packages/base/src/Numeric/Container.hs index 0633640..067c5fa 100644 --- a/packages/base/src/Numeric/Container.hs +++ b/packages/base/src/Numeric/Container.hs | |||
@@ -32,7 +32,13 @@ module Numeric.Container ( | |||
32 | diag, ident, | 32 | diag, ident, |
33 | ctrans, | 33 | ctrans, |
34 | -- * Generic operations | 34 | -- * Generic operations |
35 | Container(..), Transposable(..), Linear(..), | 35 | Container, |
36 | add, mul, sub, divide, equal, scaleRecip, addConstant, | ||
37 | scalar, conj, scale, arctan2, cmap, | ||
38 | atIndex, minIndex, maxIndex, minElement, maxElement, | ||
39 | sumElements, prodElements, | ||
40 | step, cond, find, assoc, accum, | ||
41 | Transposable(..), Linear(..), | ||
36 | -- * Matrix product | 42 | -- * Matrix product |
37 | Product(..), udot, dot, (◇), | 43 | Product(..), udot, dot, (◇), |
38 | Mul(..), | 44 | Mul(..), |
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 549ebd0..9e9151e 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs | |||
@@ -13,7 +13,9 @@ module Numeric.LinearAlgebra ( | |||
13 | -- * Basic types and data processing | 13 | -- * Basic types and data processing |
14 | module Numeric.LinearAlgebra.Data, | 14 | module Numeric.LinearAlgebra.Data, |
15 | 15 | ||
16 | -- | The standard numeric classes are defined elementwise: | 16 | -- * Arithmetic and numeric classes |
17 | -- | | ||
18 | -- The standard numeric classes are defined elementwise: | ||
17 | -- | 19 | -- |
18 | -- >>> fromList [1,2,3] * fromList [3,0,-2 :: Double] | 20 | -- >>> fromList [1,2,3] * fromList [3,0,-2 :: Double] |
19 | -- fromList [3.0,0.0,-6.0] | 21 | -- fromList [3.0,0.0,-6.0] |
@@ -38,7 +40,7 @@ module Numeric.LinearAlgebra ( | |||
38 | -- * Matrix product | 40 | -- * Matrix product |
39 | (<.>), | 41 | (<.>), |
40 | 42 | ||
41 | -- | The overloaded multiplication operator may need type annotations to remove | 43 | -- | The overloaded multiplication operators may need type annotations to remove |
42 | -- ambiguity. In those cases we can also use the specific functions 'mXm', 'mXv', and 'dot'. | 44 | -- ambiguity. In those cases we can also use the specific functions 'mXm', 'mXv', and 'dot'. |
43 | -- | 45 | -- |
44 | -- The matrix x matrix product is also implemented in the "Data.Monoid" instance, where | 46 | -- The matrix x matrix product is also implemented in the "Data.Monoid" instance, where |
@@ -66,6 +68,7 @@ module Numeric.LinearAlgebra ( | |||
66 | linearSolveSVD, | 68 | linearSolveSVD, |
67 | luSolve, | 69 | luSolve, |
68 | cholSolve, | 70 | cholSolve, |
71 | cgSolve, | ||
69 | 72 | ||
70 | -- * Inverse and pseudoinverse | 73 | -- * Inverse and pseudoinverse |
71 | inv, pinv, pinvTol, | 74 | inv, pinv, pinvTol, |
@@ -126,7 +129,15 @@ module Numeric.LinearAlgebra ( | |||
126 | RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, | 129 | RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, |
127 | 130 | ||
128 | -- * Misc | 131 | -- * Misc |
129 | meanCov, peps, relativeError, haussholder, optimiseMult, dot, udot, mXm, mXv, smXv, (<>), (◇), Seed, checkT | 132 | meanCov, peps, relativeError, haussholder, optimiseMult, dot, udot, mXm, mXv, smXv, (<>), (◇), Seed, checkT, |
133 | -- * Auxiliary classes | ||
134 | Element, Container, Product, Contraction, LSDiv, | ||
135 | Complexable(), RealElement(), | ||
136 | RealOf, ComplexOf, SingleOf, DoubleOf, | ||
137 | IndexOf, | ||
138 | Field, Normed, | ||
139 | CGMat, Transposable | ||
140 | |||
130 | ) where | 141 | ) where |
131 | 142 | ||
132 | import Numeric.LinearAlgebra.Data | 143 | import Numeric.LinearAlgebra.Data |
@@ -137,5 +148,7 @@ import Numeric.Container | |||
137 | import Numeric.LinearAlgebra.Algorithms | 148 | import Numeric.LinearAlgebra.Algorithms |
138 | import Numeric.LinearAlgebra.Util | 149 | import Numeric.LinearAlgebra.Util |
139 | import Numeric.LinearAlgebra.Random | 150 | import Numeric.LinearAlgebra.Random |
140 | import Data.Packed.Internal.Sparse(smXv) | 151 | import Numeric.Sparse(smXv) |
152 | import Numeric.LinearAlgebra.Util.CG(cgSolve) | ||
153 | import Numeric.LinearAlgebra.Util.CG(CGMat) | ||
141 | 154 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs index c41db2d..ca9e53a 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs | |||
@@ -49,20 +49,10 @@ module Numeric.LinearAlgebra.Devel( | |||
49 | mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_, | 49 | mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_, |
50 | liftMatrix, liftMatrix2, liftMatrix2Auto, | 50 | liftMatrix, liftMatrix2, liftMatrix2Auto, |
51 | 51 | ||
52 | -- * Auxiliary classes | ||
53 | Element, Container, Product, Contraction, LSDiv, | ||
54 | Complexable(), RealElement(), | ||
55 | RealOf, ComplexOf, SingleOf, DoubleOf, | ||
56 | IndexOf, | ||
57 | Field, Normed | ||
58 | ) where | 52 | ) where |
59 | 53 | ||
60 | import Data.Packed.Foreign | 54 | import Data.Packed.Foreign |
61 | import Data.Packed.Development | 55 | import Data.Packed.Development |
62 | import Data.Packed.ST | 56 | import Data.Packed.ST |
63 | import Numeric.Container(Container,Contraction,LSDiv,Product, | ||
64 | Complexable(),RealElement(), | ||
65 | RealOf, ComplexOf, SingleOf, DoubleOf, IndexOf) | ||
66 | import Data.Packed | 57 | import Data.Packed |
67 | import Numeric.LinearAlgebra.Algorithms(Field,Normed) | ||
68 | 58 | ||
diff --git a/packages/base/src/Numeric/Matrix.hs b/packages/base/src/Numeric/Matrix.hs index 719b591..a9022c6 100644 --- a/packages/base/src/Numeric/Matrix.hs +++ b/packages/base/src/Numeric/Matrix.hs | |||
@@ -90,8 +90,8 @@ instance (Container Vector t, Eq t, Num (Vector t), Product t) => M.Monoid (Matr | |||
90 | mconcat xs = work (partition isScalar xs) | 90 | mconcat xs = work (partition isScalar xs) |
91 | where | 91 | where |
92 | work (ss,[]) = product ss | 92 | work (ss,[]) = product ss |
93 | work (ss,ms) = scale' (product ss) (optimiseMult ms) | 93 | work (ss,ms) = scl (product ss) (optimiseMult ms) |
94 | scale' x m | 94 | scl x m |
95 | | isScalar x && x00 == 1 = m | 95 | | isScalar x && x00 == 1 = m |
96 | | otherwise = scale x00 m | 96 | | otherwise = scale x00 m |
97 | where | 97 | where |