diff options
Diffstat (limited to 'packages/base/src/Data/Packed')
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Numeric.hs | 607 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Numeric.hs | 680 |
2 files changed, 765 insertions, 522 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs new file mode 100644 index 0000000..81a8083 --- /dev/null +++ b/packages/base/src/Data/Packed/Internal/Numeric.hs | |||
@@ -0,0 +1,607 @@ | |||
1 | {-# LANGUAGE CPP #-} | ||
2 | {-# LANGUAGE TypeFamilies #-} | ||
3 | {-# LANGUAGE FlexibleContexts #-} | ||
4 | {-# LANGUAGE FlexibleInstances #-} | ||
5 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
6 | {-# LANGUAGE UndecidableInstances #-} | ||
7 | |||
8 | ----------------------------------------------------------------------------- | ||
9 | -- | | ||
10 | -- Module : Data.Packed.Internal.Numeric | ||
11 | -- Copyright : (c) Alberto Ruiz 2010-14 | ||
12 | -- License : BSD3 | ||
13 | -- Maintainer : Alberto Ruiz | ||
14 | -- Stability : provisional | ||
15 | -- | ||
16 | ----------------------------------------------------------------------------- | ||
17 | |||
18 | module Data.Packed.Internal.Numeric ( | ||
19 | -- * Basic functions | ||
20 | ident, diag, ctrans, | ||
21 | -- * Generic operations | ||
22 | Container(..), | ||
23 | -- * Matrix product and related functions | ||
24 | Product(..), udot, | ||
25 | mXm,mXv,vXm, | ||
26 | outer, kronecker, | ||
27 | -- * Element conversion | ||
28 | Convert(..), | ||
29 | Complexable(), | ||
30 | RealElement(), | ||
31 | |||
32 | RealOf, ComplexOf, SingleOf, DoubleOf, | ||
33 | |||
34 | IndexOf, | ||
35 | module Data.Complex | ||
36 | ) where | ||
37 | |||
38 | import Data.Packed | ||
39 | import Data.Packed.ST as ST | ||
40 | import Numeric.Conversion | ||
41 | import Data.Packed.Development | ||
42 | import Numeric.Vectorized | ||
43 | import Data.Complex | ||
44 | import Control.Applicative((<*>)) | ||
45 | |||
46 | import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ) | ||
47 | |||
48 | ------------------------------------------------------------------- | ||
49 | |||
50 | type family IndexOf (c :: * -> *) | ||
51 | |||
52 | type instance IndexOf Vector = Int | ||
53 | type instance IndexOf Matrix = (Int,Int) | ||
54 | |||
55 | type family ArgOf (c :: * -> *) a | ||
56 | |||
57 | type instance ArgOf Vector a = a -> a | ||
58 | type instance ArgOf Matrix a = a -> a -> a | ||
59 | |||
60 | ------------------------------------------------------------------- | ||
61 | |||
62 | -- | Basic element-by-element functions for numeric containers | ||
63 | class (Complexable c, Fractional e, Element e) => Container c e where | ||
64 | -- | create a structure with a single element | ||
65 | -- | ||
66 | -- >>> let v = fromList [1..3::Double] | ||
67 | -- >>> v / scalar (norm2 v) | ||
68 | -- fromList [0.2672612419124244,0.5345224838248488,0.8017837257372732] | ||
69 | -- | ||
70 | scalar :: e -> c e | ||
71 | -- | complex conjugate | ||
72 | conj :: c e -> c e | ||
73 | scale :: e -> c e -> c e | ||
74 | -- | scale the element by element reciprocal of the object: | ||
75 | -- | ||
76 | -- @scaleRecip 2 (fromList [5,i]) == 2 |> [0.4 :+ 0.0,0.0 :+ (-2.0)]@ | ||
77 | scaleRecip :: e -> c e -> c e | ||
78 | addConstant :: e -> c e -> c e | ||
79 | add :: c e -> c e -> c e | ||
80 | sub :: c e -> c e -> c e | ||
81 | -- | element by element multiplication | ||
82 | mul :: c e -> c e -> c e | ||
83 | -- | element by element division | ||
84 | divide :: c e -> c e -> c e | ||
85 | equal :: c e -> c e -> Bool | ||
86 | -- | ||
87 | -- element by element inverse tangent | ||
88 | arctan2 :: c e -> c e -> c e | ||
89 | -- | ||
90 | -- | cannot implement instance Functor because of Element class constraint | ||
91 | cmap :: (Element b) => (e -> b) -> c e -> c b | ||
92 | -- | constant structure of given size | ||
93 | konst' :: e -> IndexOf c -> c e | ||
94 | -- | create a structure using a function | ||
95 | -- | ||
96 | -- Hilbert matrix of order N: | ||
97 | -- | ||
98 | -- @hilb n = build' (n,n) (\\i j -> 1/(i+j+1))@ | ||
99 | build' :: IndexOf c -> (ArgOf c e) -> c e | ||
100 | -- | indexing function | ||
101 | atIndex :: c e -> IndexOf c -> e | ||
102 | -- | index of min element | ||
103 | minIndex :: c e -> IndexOf c | ||
104 | -- | index of max element | ||
105 | maxIndex :: c e -> IndexOf c | ||
106 | -- | value of min element | ||
107 | minElement :: c e -> e | ||
108 | -- | value of max element | ||
109 | maxElement :: c e -> e | ||
110 | -- the C functions sumX/prodX are twice as fast as using foldVector | ||
111 | -- | the sum of elements (faster than using @fold@) | ||
112 | sumElements :: c e -> e | ||
113 | -- | the product of elements (faster than using @fold@) | ||
114 | prodElements :: c e -> e | ||
115 | |||
116 | -- | A more efficient implementation of @cmap (\\x -> if x>0 then 1 else 0)@ | ||
117 | -- | ||
118 | -- >>> step $ linspace 5 (-1,1::Double) | ||
119 | -- 5 |> [0.0,0.0,0.0,1.0,1.0] | ||
120 | -- | ||
121 | |||
122 | step :: RealElement e => c e -> c e | ||
123 | |||
124 | -- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@. | ||
125 | -- | ||
126 | -- Arguments with any dimension = 1 are automatically expanded: | ||
127 | -- | ||
128 | -- >>> cond ((1><4)[1..]) ((3><1)[1..]) 0 100 ((3><4)[1..]) :: Matrix Double | ||
129 | -- (3><4) | ||
130 | -- [ 100.0, 2.0, 3.0, 4.0 | ||
131 | -- , 0.0, 100.0, 7.0, 8.0 | ||
132 | -- , 0.0, 0.0, 100.0, 12.0 ] | ||
133 | -- | ||
134 | |||
135 | cond :: RealElement e | ||
136 | => c e -- ^ a | ||
137 | -> c e -- ^ b | ||
138 | -> c e -- ^ l | ||
139 | -> c e -- ^ e | ||
140 | -> c e -- ^ g | ||
141 | -> c e -- ^ result | ||
142 | |||
143 | -- | Find index of elements which satisfy a predicate | ||
144 | -- | ||
145 | -- >>> find (>0) (ident 3 :: Matrix Double) | ||
146 | -- [(0,0),(1,1),(2,2)] | ||
147 | -- | ||
148 | |||
149 | find :: (e -> Bool) -> c e -> [IndexOf c] | ||
150 | |||
151 | -- | Create a structure from an association list | ||
152 | -- | ||
153 | -- >>> assoc 5 0 [(3,7),(1,4)] :: Vector Double | ||
154 | -- fromList [0.0,4.0,0.0,7.0,0.0] | ||
155 | -- | ||
156 | -- >>> assoc (2,3) 0 [((0,2),7),((1,0),2*i-3)] :: Matrix (Complex Double) | ||
157 | -- (2><3) | ||
158 | -- [ 0.0 :+ 0.0, 0.0 :+ 0.0, 7.0 :+ 0.0 | ||
159 | -- , (-3.0) :+ 2.0, 0.0 :+ 0.0, 0.0 :+ 0.0 ] | ||
160 | -- | ||
161 | assoc :: IndexOf c -- ^ size | ||
162 | -> e -- ^ default value | ||
163 | -> [(IndexOf c, e)] -- ^ association list | ||
164 | -> c e -- ^ result | ||
165 | |||
166 | -- | Modify a structure using an update function | ||
167 | -- | ||
168 | -- >>> accum (ident 5) (+) [((1,1),5),((0,3),3)] :: Matrix Double | ||
169 | -- (5><5) | ||
170 | -- [ 1.0, 0.0, 0.0, 3.0, 0.0 | ||
171 | -- , 0.0, 6.0, 0.0, 0.0, 0.0 | ||
172 | -- , 0.0, 0.0, 1.0, 0.0, 0.0 | ||
173 | -- , 0.0, 0.0, 0.0, 1.0, 0.0 | ||
174 | -- , 0.0, 0.0, 0.0, 0.0, 1.0 ] | ||
175 | -- | ||
176 | -- computation of histogram: | ||
177 | -- | ||
178 | -- >>> accum (konst 0 7) (+) (map (flip (,) 1) [4,5,4,1,5,2,5]) :: Vector Double | ||
179 | -- fromList [0.0,1.0,1.0,0.0,2.0,3.0,0.0] | ||
180 | -- | ||
181 | |||
182 | accum :: c e -- ^ initial structure | ||
183 | -> (e -> e -> e) -- ^ update function | ||
184 | -> [(IndexOf c, e)] -- ^ association list | ||
185 | -> c e -- ^ result | ||
186 | |||
187 | -------------------------------------------------------------------------- | ||
188 | |||
189 | instance Container Vector Float where | ||
190 | scale = vectorMapValF Scale | ||
191 | scaleRecip = vectorMapValF Recip | ||
192 | addConstant = vectorMapValF AddConstant | ||
193 | add = vectorZipF Add | ||
194 | sub = vectorZipF Sub | ||
195 | mul = vectorZipF Mul | ||
196 | divide = vectorZipF Div | ||
197 | equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0 | ||
198 | arctan2 = vectorZipF ATan2 | ||
199 | scalar x = fromList [x] | ||
200 | konst' = constant | ||
201 | build' = buildV | ||
202 | conj = id | ||
203 | cmap = mapVector | ||
204 | atIndex = (@>) | ||
205 | minIndex = emptyErrorV "minIndex" (round . toScalarF MinIdx) | ||
206 | maxIndex = emptyErrorV "maxIndex" (round . toScalarF MaxIdx) | ||
207 | minElement = emptyErrorV "minElement" (toScalarF Min) | ||
208 | maxElement = emptyErrorV "maxElement" (toScalarF Max) | ||
209 | sumElements = sumF | ||
210 | prodElements = prodF | ||
211 | step = stepF | ||
212 | find = findV | ||
213 | assoc = assocV | ||
214 | accum = accumV | ||
215 | cond = condV condF | ||
216 | |||
217 | instance Container Vector Double where | ||
218 | scale = vectorMapValR Scale | ||
219 | scaleRecip = vectorMapValR Recip | ||
220 | addConstant = vectorMapValR AddConstant | ||
221 | add = vectorZipR Add | ||
222 | sub = vectorZipR Sub | ||
223 | mul = vectorZipR Mul | ||
224 | divide = vectorZipR Div | ||
225 | equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0 | ||
226 | arctan2 = vectorZipR ATan2 | ||
227 | scalar x = fromList [x] | ||
228 | konst' = constant | ||
229 | build' = buildV | ||
230 | conj = id | ||
231 | cmap = mapVector | ||
232 | atIndex = (@>) | ||
233 | minIndex = emptyErrorV "minIndex" (round . toScalarR MinIdx) | ||
234 | maxIndex = emptyErrorV "maxIndex" (round . toScalarR MaxIdx) | ||
235 | minElement = emptyErrorV "minElement" (toScalarR Min) | ||
236 | maxElement = emptyErrorV "maxElement" (toScalarR Max) | ||
237 | sumElements = sumR | ||
238 | prodElements = prodR | ||
239 | step = stepD | ||
240 | find = findV | ||
241 | assoc = assocV | ||
242 | accum = accumV | ||
243 | cond = condV condD | ||
244 | |||
245 | instance Container Vector (Complex Double) where | ||
246 | scale = vectorMapValC Scale | ||
247 | scaleRecip = vectorMapValC Recip | ||
248 | addConstant = vectorMapValC AddConstant | ||
249 | add = vectorZipC Add | ||
250 | sub = vectorZipC Sub | ||
251 | mul = vectorZipC Mul | ||
252 | divide = vectorZipC Div | ||
253 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 | ||
254 | arctan2 = vectorZipC ATan2 | ||
255 | scalar x = fromList [x] | ||
256 | konst' = constant | ||
257 | build' = buildV | ||
258 | conj = conjugateC | ||
259 | cmap = mapVector | ||
260 | atIndex = (@>) | ||
261 | minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj)) | ||
262 | maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj)) | ||
263 | minElement = emptyErrorV "minElement" (atIndex <*> minIndex) | ||
264 | maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex) | ||
265 | sumElements = sumC | ||
266 | prodElements = prodC | ||
267 | step = undefined -- cannot match | ||
268 | find = findV | ||
269 | assoc = assocV | ||
270 | accum = accumV | ||
271 | cond = undefined -- cannot match | ||
272 | |||
273 | instance Container Vector (Complex Float) where | ||
274 | scale = vectorMapValQ Scale | ||
275 | scaleRecip = vectorMapValQ Recip | ||
276 | addConstant = vectorMapValQ AddConstant | ||
277 | add = vectorZipQ Add | ||
278 | sub = vectorZipQ Sub | ||
279 | mul = vectorZipQ Mul | ||
280 | divide = vectorZipQ Div | ||
281 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 | ||
282 | arctan2 = vectorZipQ ATan2 | ||
283 | scalar x = fromList [x] | ||
284 | konst' = constant | ||
285 | build' = buildV | ||
286 | conj = conjugateQ | ||
287 | cmap = mapVector | ||
288 | atIndex = (@>) | ||
289 | minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj)) | ||
290 | maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj)) | ||
291 | minElement = emptyErrorV "minElement" (atIndex <*> minIndex) | ||
292 | maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex) | ||
293 | sumElements = sumQ | ||
294 | prodElements = prodQ | ||
295 | step = undefined -- cannot match | ||
296 | find = findV | ||
297 | assoc = assocV | ||
298 | accum = accumV | ||
299 | cond = undefined -- cannot match | ||
300 | |||
301 | --------------------------------------------------------------- | ||
302 | |||
303 | instance (Container Vector a) => Container Matrix a where | ||
304 | scale x = liftMatrix (scale x) | ||
305 | scaleRecip x = liftMatrix (scaleRecip x) | ||
306 | addConstant x = liftMatrix (addConstant x) | ||
307 | add = liftMatrix2 add | ||
308 | sub = liftMatrix2 sub | ||
309 | mul = liftMatrix2 mul | ||
310 | divide = liftMatrix2 divide | ||
311 | equal a b = cols a == cols b && flatten a `equal` flatten b | ||
312 | arctan2 = liftMatrix2 arctan2 | ||
313 | scalar x = (1><1) [x] | ||
314 | konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c)) | ||
315 | build' = buildM | ||
316 | conj = liftMatrix conj | ||
317 | cmap f = liftMatrix (mapVector f) | ||
318 | atIndex = (@@>) | ||
319 | minIndex = emptyErrorM "minIndex of Matrix" $ | ||
320 | \m -> divMod (minIndex $ flatten m) (cols m) | ||
321 | maxIndex = emptyErrorM "maxIndex of Matrix" $ | ||
322 | \m -> divMod (maxIndex $ flatten m) (cols m) | ||
323 | minElement = emptyErrorM "minElement of Matrix" (atIndex <*> minIndex) | ||
324 | maxElement = emptyErrorM "maxElement of Matrix" (atIndex <*> maxIndex) | ||
325 | sumElements = sumElements . flatten | ||
326 | prodElements = prodElements . flatten | ||
327 | step = liftMatrix step | ||
328 | find = findM | ||
329 | assoc = assocM | ||
330 | accum = accumM | ||
331 | cond = condM | ||
332 | |||
333 | |||
334 | emptyErrorV msg f v = | ||
335 | if dim v > 0 | ||
336 | then f v | ||
337 | else error $ msg ++ " of Vector with dim = 0" | ||
338 | |||
339 | emptyErrorM msg f m = | ||
340 | if rows m > 0 && cols m > 0 | ||
341 | then f m | ||
342 | else error $ msg++" "++shSize m | ||
343 | |||
344 | ---------------------------------------------------- | ||
345 | |||
346 | -- | Matrix product and related functions | ||
347 | class (Num e, Element e) => Product e where | ||
348 | -- | matrix product | ||
349 | multiply :: Matrix e -> Matrix e -> Matrix e | ||
350 | -- | sum of absolute value of elements (differs in complex case from @norm1@) | ||
351 | absSum :: Vector e -> RealOf e | ||
352 | -- | sum of absolute value of elements | ||
353 | norm1 :: Vector e -> RealOf e | ||
354 | -- | euclidean norm | ||
355 | norm2 :: Vector e -> RealOf e | ||
356 | -- | element of maximum magnitude | ||
357 | normInf :: Vector e -> RealOf e | ||
358 | |||
359 | instance Product Float where | ||
360 | norm2 = emptyVal (toScalarF Norm2) | ||
361 | absSum = emptyVal (toScalarF AbsSum) | ||
362 | norm1 = emptyVal (toScalarF AbsSum) | ||
363 | normInf = emptyVal (maxElement . vectorMapF Abs) | ||
364 | multiply = emptyMul multiplyF | ||
365 | |||
366 | instance Product Double where | ||
367 | norm2 = emptyVal (toScalarR Norm2) | ||
368 | absSum = emptyVal (toScalarR AbsSum) | ||
369 | norm1 = emptyVal (toScalarR AbsSum) | ||
370 | normInf = emptyVal (maxElement . vectorMapR Abs) | ||
371 | multiply = emptyMul multiplyR | ||
372 | |||
373 | instance Product (Complex Float) where | ||
374 | norm2 = emptyVal (toScalarQ Norm2) | ||
375 | absSum = emptyVal (toScalarQ AbsSum) | ||
376 | norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs) | ||
377 | normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs) | ||
378 | multiply = emptyMul multiplyQ | ||
379 | |||
380 | instance Product (Complex Double) where | ||
381 | norm2 = emptyVal (toScalarC Norm2) | ||
382 | absSum = emptyVal (toScalarC AbsSum) | ||
383 | norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs) | ||
384 | normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs) | ||
385 | multiply = emptyMul multiplyC | ||
386 | |||
387 | emptyMul m a b | ||
388 | | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c) | ||
389 | | otherwise = m a b | ||
390 | where | ||
391 | r = rows a | ||
392 | x1 = cols a | ||
393 | x2 = rows b | ||
394 | c = cols b | ||
395 | |||
396 | emptyVal f v = | ||
397 | if dim v > 0 | ||
398 | then f v | ||
399 | else 0 | ||
400 | |||
401 | -- FIXME remove unused C wrappers | ||
402 | -- | unconjugated dot product | ||
403 | udot :: Product e => Vector e -> Vector e -> e | ||
404 | udot u v | ||
405 | | dim u == dim v = val (asRow u `multiply` asColumn v) | ||
406 | | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product" | ||
407 | where | ||
408 | val m | dim u > 0 = m@@>(0,0) | ||
409 | | otherwise = 0 | ||
410 | |||
411 | ---------------------------------------------------------- | ||
412 | |||
413 | -- synonym for matrix product | ||
414 | mXm :: Product t => Matrix t -> Matrix t -> Matrix t | ||
415 | mXm = multiply | ||
416 | |||
417 | -- matrix - vector product | ||
418 | mXv :: Product t => Matrix t -> Vector t -> Vector t | ||
419 | mXv m v = flatten $ m `mXm` (asColumn v) | ||
420 | |||
421 | -- vector - matrix product | ||
422 | vXm :: Product t => Vector t -> Matrix t -> Vector t | ||
423 | vXm v m = flatten $ (asRow v) `mXm` m | ||
424 | |||
425 | {- | Outer product of two vectors. | ||
426 | |||
427 | >>> fromList [1,2,3] `outer` fromList [5,2,3] | ||
428 | (3><3) | ||
429 | [ 5.0, 2.0, 3.0 | ||
430 | , 10.0, 4.0, 6.0 | ||
431 | , 15.0, 6.0, 9.0 ] | ||
432 | |||
433 | -} | ||
434 | outer :: (Product t) => Vector t -> Vector t -> Matrix t | ||
435 | outer u v = asColumn u `multiply` asRow v | ||
436 | |||
437 | {- | Kronecker product of two matrices. | ||
438 | |||
439 | @m1=(2><3) | ||
440 | [ 1.0, 2.0, 0.0 | ||
441 | , 0.0, -1.0, 3.0 ] | ||
442 | m2=(4><3) | ||
443 | [ 1.0, 2.0, 3.0 | ||
444 | , 4.0, 5.0, 6.0 | ||
445 | , 7.0, 8.0, 9.0 | ||
446 | , 10.0, 11.0, 12.0 ]@ | ||
447 | |||
448 | >>> kronecker m1 m2 | ||
449 | (8><9) | ||
450 | [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0 | ||
451 | , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0 | ||
452 | , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0 | ||
453 | , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0 | ||
454 | , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0 | ||
455 | , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0 | ||
456 | , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0 | ||
457 | , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ] | ||
458 | |||
459 | -} | ||
460 | kronecker :: (Product t) => Matrix t -> Matrix t -> Matrix t | ||
461 | kronecker a b = fromBlocks | ||
462 | . splitEvery (cols a) | ||
463 | . map (reshape (cols b)) | ||
464 | . toRows | ||
465 | $ flatten a `outer` flatten b | ||
466 | |||
467 | ------------------------------------------------------------------- | ||
468 | |||
469 | |||
470 | class Convert t where | ||
471 | real :: Container c t => c (RealOf t) -> c t | ||
472 | complex :: Container c t => c t -> c (ComplexOf t) | ||
473 | single :: Container c t => c t -> c (SingleOf t) | ||
474 | double :: Container c t => c t -> c (DoubleOf t) | ||
475 | toComplex :: (Container c t, RealElement t) => (c t, c t) -> c (Complex t) | ||
476 | fromComplex :: (Container c t, RealElement t) => c (Complex t) -> (c t, c t) | ||
477 | |||
478 | |||
479 | instance Convert Double where | ||
480 | real = id | ||
481 | complex = comp' | ||
482 | single = single' | ||
483 | double = id | ||
484 | toComplex = toComplex' | ||
485 | fromComplex = fromComplex' | ||
486 | |||
487 | instance Convert Float where | ||
488 | real = id | ||
489 | complex = comp' | ||
490 | single = id | ||
491 | double = double' | ||
492 | toComplex = toComplex' | ||
493 | fromComplex = fromComplex' | ||
494 | |||
495 | instance Convert (Complex Double) where | ||
496 | real = comp' | ||
497 | complex = id | ||
498 | single = single' | ||
499 | double = id | ||
500 | toComplex = toComplex' | ||
501 | fromComplex = fromComplex' | ||
502 | |||
503 | instance Convert (Complex Float) where | ||
504 | real = comp' | ||
505 | complex = id | ||
506 | single = id | ||
507 | double = double' | ||
508 | toComplex = toComplex' | ||
509 | fromComplex = fromComplex' | ||
510 | |||
511 | ------------------------------------------------------------------- | ||
512 | |||
513 | type family RealOf x | ||
514 | |||
515 | type instance RealOf Double = Double | ||
516 | type instance RealOf (Complex Double) = Double | ||
517 | |||
518 | type instance RealOf Float = Float | ||
519 | type instance RealOf (Complex Float) = Float | ||
520 | |||
521 | type family ComplexOf x | ||
522 | |||
523 | type instance ComplexOf Double = Complex Double | ||
524 | type instance ComplexOf (Complex Double) = Complex Double | ||
525 | |||
526 | type instance ComplexOf Float = Complex Float | ||
527 | type instance ComplexOf (Complex Float) = Complex Float | ||
528 | |||
529 | type family SingleOf x | ||
530 | |||
531 | type instance SingleOf Double = Float | ||
532 | type instance SingleOf Float = Float | ||
533 | |||
534 | type instance SingleOf (Complex a) = Complex (SingleOf a) | ||
535 | |||
536 | type family DoubleOf x | ||
537 | |||
538 | type instance DoubleOf Double = Double | ||
539 | type instance DoubleOf Float = Double | ||
540 | |||
541 | type instance DoubleOf (Complex a) = Complex (DoubleOf a) | ||
542 | |||
543 | type family ElementOf c | ||
544 | |||
545 | type instance ElementOf (Vector a) = a | ||
546 | type instance ElementOf (Matrix a) = a | ||
547 | |||
548 | ------------------------------------------------------------ | ||
549 | |||
550 | buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ] | ||
551 | where rs = map fromIntegral [0 .. (rc-1)] | ||
552 | cs = map fromIntegral [0 .. (cc-1)] | ||
553 | |||
554 | buildV n f = fromList [f k | k <- ks] | ||
555 | where ks = map fromIntegral [0 .. (n-1)] | ||
556 | |||
557 | -------------------------------------------------------- | ||
558 | -- | conjugate transpose | ||
559 | ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e | ||
560 | ctrans = liftMatrix conj . trans | ||
561 | |||
562 | -- | Creates a square matrix with a given diagonal. | ||
563 | diag :: (Num a, Element a) => Vector a -> Matrix a | ||
564 | diag v = diagRect 0 v n n where n = dim v | ||
565 | |||
566 | -- | creates the identity matrix of given dimension | ||
567 | ident :: (Num a, Element a) => Int -> Matrix a | ||
568 | ident n = diag (constant 1 n) | ||
569 | |||
570 | -------------------------------------------------------- | ||
571 | |||
572 | findV p x = foldVectorWithIndex g [] x where | ||
573 | g k z l = if p z then k:l else l | ||
574 | |||
575 | findM p x = map ((`divMod` cols x)) $ findV p (flatten x) | ||
576 | |||
577 | assocV n z xs = ST.runSTVector $ do | ||
578 | v <- ST.newVector z n | ||
579 | mapM_ (\(k,x) -> ST.writeVector v k x) xs | ||
580 | return v | ||
581 | |||
582 | assocM (r,c) z xs = ST.runSTMatrix $ do | ||
583 | m <- ST.newMatrix z r c | ||
584 | mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs | ||
585 | return m | ||
586 | |||
587 | accumV v0 f xs = ST.runSTVector $ do | ||
588 | v <- ST.thawVector v0 | ||
589 | mapM_ (\(k,x) -> ST.modifyVector v k (f x)) xs | ||
590 | return v | ||
591 | |||
592 | accumM m0 f xs = ST.runSTMatrix $ do | ||
593 | m <- ST.thawMatrix m0 | ||
594 | mapM_ (\((i,j),x) -> ST.modifyMatrix m i j (f x)) xs | ||
595 | return m | ||
596 | |||
597 | ---------------------------------------------------------------------- | ||
598 | |||
599 | condM a b l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cond a' b' l' e' t' | ||
600 | where | ||
601 | args@(a'':_) = conformMs [a,b,l,e,t] | ||
602 | [a', b', l', e', t'] = map flatten args | ||
603 | |||
604 | condV f a b l e t = f a' b' l' e' t' | ||
605 | where | ||
606 | [a', b', l', e', t'] = conformVs [a,b,l,e,t] | ||
607 | |||
diff --git a/packages/base/src/Data/Packed/Numeric.hs b/packages/base/src/Data/Packed/Numeric.hs index c13e91d..6036e8c 100644 --- a/packages/base/src/Data/Packed/Numeric.hs +++ b/packages/base/src/Data/Packed/Numeric.hs | |||
@@ -1,8 +1,8 @@ | |||
1 | {-# LANGUAGE CPP #-} | ||
2 | {-# LANGUAGE TypeFamilies #-} | 1 | {-# LANGUAGE TypeFamilies #-} |
3 | {-# LANGUAGE FlexibleContexts #-} | 2 | {-# LANGUAGE FlexibleContexts #-} |
4 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
5 | {-# LANGUAGE MultiParamTypeClasses #-} | 4 | {-# LANGUAGE MultiParamTypeClasses #-} |
5 | {-# LANGUAGE FunctionalDependencies #-} | ||
6 | {-# LANGUAGE UndecidableInstances #-} | 6 | {-# LANGUAGE UndecidableInstances #-} |
7 | 7 | ||
8 | ----------------------------------------------------------------------------- | 8 | ----------------------------------------------------------------------------- |
@@ -13,16 +13,32 @@ | |||
13 | -- Maintainer : Alberto Ruiz | 13 | -- Maintainer : Alberto Ruiz |
14 | -- Stability : provisional | 14 | -- Stability : provisional |
15 | -- | 15 | -- |
16 | -- Basic numeric operations on 'Vector' and 'Matrix', including conversion routines. | ||
17 | -- | ||
18 | -- The 'Container' class is used to define optimized generic functions which work | ||
19 | -- on 'Vector' and 'Matrix' with real or complex elements. | ||
20 | -- | ||
21 | -- Some of these functions are also available in the instances of the standard | ||
22 | -- numeric Haskell classes provided by "Numeric.LinearAlgebra". | ||
23 | -- | ||
16 | ----------------------------------------------------------------------------- | 24 | ----------------------------------------------------------------------------- |
25 | {-# OPTIONS_HADDOCK hide #-} | ||
17 | 26 | ||
18 | module Data.Packed.Numeric ( | 27 | module Data.Packed.Numeric ( |
19 | -- * Basic functions | 28 | -- * Basic functions |
20 | ident, diag, ctrans, | 29 | module Data.Packed, |
30 | konst, build, | ||
31 | linspace, | ||
32 | diag, ident, | ||
33 | ctrans, | ||
21 | -- * Generic operations | 34 | -- * Generic operations |
22 | Container(..), | 35 | Container(..), |
23 | -- * Matrix product and related functions | 36 | -- * Matrix product |
24 | Product(..), udot, | 37 | Product(..), udot, dot, (◇), |
25 | mXm,mXv,vXm, | 38 | Mul(..), |
39 | Contraction(..), | ||
40 | optimiseMult, | ||
41 | mXm,mXv,vXm,LSDiv(..), | ||
26 | outer, kronecker, | 42 | outer, kronecker, |
27 | -- * Element conversion | 43 | -- * Element conversion |
28 | Convert(..), | 44 | Convert(..), |
@@ -32,576 +48,196 @@ module Data.Packed.Numeric ( | |||
32 | RealOf, ComplexOf, SingleOf, DoubleOf, | 48 | RealOf, ComplexOf, SingleOf, DoubleOf, |
33 | 49 | ||
34 | IndexOf, | 50 | IndexOf, |
35 | module Data.Complex | 51 | module Data.Complex, |
52 | -- * IO | ||
53 | module Data.Packed.IO | ||
36 | ) where | 54 | ) where |
37 | 55 | ||
38 | import Data.Packed | 56 | import Data.Packed hiding (stepD, stepF, condD, condF, conjugateC, conjugateQ) |
39 | import Data.Packed.ST as ST | 57 | import Data.Packed.Internal.Numeric |
40 | import Numeric.Conversion | ||
41 | import Data.Packed.Development | ||
42 | import Numeric.Vectorized | ||
43 | import Data.Complex | 58 | import Data.Complex |
44 | import Control.Applicative((<*>)) | 59 | import Numeric.LinearAlgebra.Algorithms(Field,linearSolveSVD) |
45 | 60 | import Data.Monoid(Monoid(mconcat)) | |
46 | import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ) | 61 | import Data.Packed.IO |
47 | 62 | ||
48 | ------------------------------------------------------------------- | 63 | ------------------------------------------------------------------ |
49 | 64 | ||
50 | type family IndexOf (c :: * -> *) | 65 | {- | Creates a real vector containing a range of values: |
51 | 66 | ||
52 | type instance IndexOf Vector = Int | 67 | >>> linspace 5 (-3,7::Double) |
53 | type instance IndexOf Matrix = (Int,Int) | 68 | fromList [-3.0,-0.5,2.0,4.5,7.0]@ |
54 | 69 | ||
55 | type family ArgOf (c :: * -> *) a | 70 | >>> linspace 5 (8,2+i) :: Vector (Complex Double) |
71 | fromList [8.0 :+ 0.0,6.5 :+ 0.25,5.0 :+ 0.5,3.5 :+ 0.75,2.0 :+ 1.0] | ||
56 | 72 | ||
57 | type instance ArgOf Vector a = a -> a | 73 | Logarithmic spacing can be defined as follows: |
58 | type instance ArgOf Matrix a = a -> a -> a | ||
59 | 74 | ||
60 | ------------------------------------------------------------------- | 75 | @logspace n (a,b) = 10 ** linspace n (a,b)@ |
61 | 76 | -} | |
62 | -- | Basic element-by-element functions for numeric containers | 77 | linspace :: (Container Vector e) => Int -> (e, e) -> Vector e |
63 | class (Complexable c, Fractional e, Element e) => Container c e where | 78 | linspace 0 (a,b) = fromList[(a+b)/2] |
64 | -- | create a structure with a single element | 79 | linspace n (a,b) = addConstant a $ scale s $ fromList $ map fromIntegral [0 .. n-1] |
65 | -- | 80 | where s = (b-a)/fromIntegral (n-1) |
66 | -- >>> let v = fromList [1..3::Double] | ||
67 | -- >>> v / scalar (norm2 v) | ||
68 | -- fromList [0.2672612419124244,0.5345224838248488,0.8017837257372732] | ||
69 | -- | ||
70 | scalar :: e -> c e | ||
71 | -- | complex conjugate | ||
72 | conj :: c e -> c e | ||
73 | scale :: e -> c e -> c e | ||
74 | -- | scale the element by element reciprocal of the object: | ||
75 | -- | ||
76 | -- @scaleRecip 2 (fromList [5,i]) == 2 |> [0.4 :+ 0.0,0.0 :+ (-2.0)]@ | ||
77 | scaleRecip :: e -> c e -> c e | ||
78 | addConstant :: e -> c e -> c e | ||
79 | add :: c e -> c e -> c e | ||
80 | sub :: c e -> c e -> c e | ||
81 | -- | element by element multiplication | ||
82 | mul :: c e -> c e -> c e | ||
83 | -- | element by element division | ||
84 | divide :: c e -> c e -> c e | ||
85 | equal :: c e -> c e -> Bool | ||
86 | -- | ||
87 | -- element by element inverse tangent | ||
88 | arctan2 :: c e -> c e -> c e | ||
89 | -- | ||
90 | -- | cannot implement instance Functor because of Element class constraint | ||
91 | cmap :: (Element b) => (e -> b) -> c e -> c b | ||
92 | -- | constant structure of given size | ||
93 | konst' :: e -> IndexOf c -> c e | ||
94 | -- | create a structure using a function | ||
95 | -- | ||
96 | -- Hilbert matrix of order N: | ||
97 | -- | ||
98 | -- @hilb n = build' (n,n) (\\i j -> 1/(i+j+1))@ | ||
99 | build' :: IndexOf c -> (ArgOf c e) -> c e | ||
100 | -- | indexing function | ||
101 | atIndex :: c e -> IndexOf c -> e | ||
102 | -- | index of min element | ||
103 | minIndex :: c e -> IndexOf c | ||
104 | -- | index of max element | ||
105 | maxIndex :: c e -> IndexOf c | ||
106 | -- | value of min element | ||
107 | minElement :: c e -> e | ||
108 | -- | value of max element | ||
109 | maxElement :: c e -> e | ||
110 | -- the C functions sumX/prodX are twice as fast as using foldVector | ||
111 | -- | the sum of elements (faster than using @fold@) | ||
112 | sumElements :: c e -> e | ||
113 | -- | the product of elements (faster than using @fold@) | ||
114 | prodElements :: c e -> e | ||
115 | |||
116 | -- | A more efficient implementation of @cmap (\\x -> if x>0 then 1 else 0)@ | ||
117 | -- | ||
118 | -- >>> step $ linspace 5 (-1,1::Double) | ||
119 | -- 5 |> [0.0,0.0,0.0,1.0,1.0] | ||
120 | -- | ||
121 | |||
122 | step :: RealElement e => c e -> c e | ||
123 | |||
124 | -- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@. | ||
125 | -- | ||
126 | -- Arguments with any dimension = 1 are automatically expanded: | ||
127 | -- | ||
128 | -- >>> cond ((1><4)[1..]) ((3><1)[1..]) 0 100 ((3><4)[1..]) :: Matrix Double | ||
129 | -- (3><4) | ||
130 | -- [ 100.0, 2.0, 3.0, 4.0 | ||
131 | -- , 0.0, 100.0, 7.0, 8.0 | ||
132 | -- , 0.0, 0.0, 100.0, 12.0 ] | ||
133 | -- | ||
134 | |||
135 | cond :: RealElement e | ||
136 | => c e -- ^ a | ||
137 | -> c e -- ^ b | ||
138 | -> c e -- ^ l | ||
139 | -> c e -- ^ e | ||
140 | -> c e -- ^ g | ||
141 | -> c e -- ^ result | ||
142 | |||
143 | -- | Find index of elements which satisfy a predicate | ||
144 | -- | ||
145 | -- >>> find (>0) (ident 3 :: Matrix Double) | ||
146 | -- [(0,0),(1,1),(2,2)] | ||
147 | -- | ||
148 | |||
149 | find :: (e -> Bool) -> c e -> [IndexOf c] | ||
150 | 81 | ||
151 | -- | Create a structure from an association list | 82 | -------------------------------------------------------- |
152 | -- | ||
153 | -- >>> assoc 5 0 [(3,7),(1,4)] :: Vector Double | ||
154 | -- fromList [0.0,4.0,0.0,7.0,0.0] | ||
155 | -- | ||
156 | -- >>> assoc (2,3) 0 [((0,2),7),((1,0),2*i-3)] :: Matrix (Complex Double) | ||
157 | -- (2><3) | ||
158 | -- [ 0.0 :+ 0.0, 0.0 :+ 0.0, 7.0 :+ 0.0 | ||
159 | -- , (-3.0) :+ 2.0, 0.0 :+ 0.0, 0.0 :+ 0.0 ] | ||
160 | -- | ||
161 | assoc :: IndexOf c -- ^ size | ||
162 | -> e -- ^ default value | ||
163 | -> [(IndexOf c, e)] -- ^ association list | ||
164 | -> c e -- ^ result | ||
165 | 83 | ||
166 | -- | Modify a structure using an update function | 84 | class Contraction a b c | a b -> c |
167 | -- | ||
168 | -- >>> accum (ident 5) (+) [((1,1),5),((0,3),3)] :: Matrix Double | ||
169 | -- (5><5) | ||
170 | -- [ 1.0, 0.0, 0.0, 3.0, 0.0 | ||
171 | -- , 0.0, 6.0, 0.0, 0.0, 0.0 | ||
172 | -- , 0.0, 0.0, 1.0, 0.0, 0.0 | ||
173 | -- , 0.0, 0.0, 0.0, 1.0, 0.0 | ||
174 | -- , 0.0, 0.0, 0.0, 0.0, 1.0 ] | ||
175 | -- | ||
176 | -- computation of histogram: | ||
177 | -- | ||
178 | -- >>> accum (konst 0 7) (+) (map (flip (,) 1) [4,5,4,1,5,2,5]) :: Vector Double | ||
179 | -- fromList [0.0,1.0,1.0,0.0,2.0,3.0,0.0] | ||
180 | -- | ||
181 | |||
182 | accum :: c e -- ^ initial structure | ||
183 | -> (e -> e -> e) -- ^ update function | ||
184 | -> [(IndexOf c, e)] -- ^ association list | ||
185 | -> c e -- ^ result | ||
186 | |||
187 | -------------------------------------------------------------------------- | ||
188 | |||
189 | instance Container Vector Float where | ||
190 | scale = vectorMapValF Scale | ||
191 | scaleRecip = vectorMapValF Recip | ||
192 | addConstant = vectorMapValF AddConstant | ||
193 | add = vectorZipF Add | ||
194 | sub = vectorZipF Sub | ||
195 | mul = vectorZipF Mul | ||
196 | divide = vectorZipF Div | ||
197 | equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0 | ||
198 | arctan2 = vectorZipF ATan2 | ||
199 | scalar x = fromList [x] | ||
200 | konst' = constant | ||
201 | build' = buildV | ||
202 | conj = id | ||
203 | cmap = mapVector | ||
204 | atIndex = (@>) | ||
205 | minIndex = emptyErrorV "minIndex" (round . toScalarF MinIdx) | ||
206 | maxIndex = emptyErrorV "maxIndex" (round . toScalarF MaxIdx) | ||
207 | minElement = emptyErrorV "minElement" (toScalarF Min) | ||
208 | maxElement = emptyErrorV "maxElement" (toScalarF Max) | ||
209 | sumElements = sumF | ||
210 | prodElements = prodF | ||
211 | step = stepF | ||
212 | find = findV | ||
213 | assoc = assocV | ||
214 | accum = accumV | ||
215 | cond = condV condF | ||
216 | |||
217 | instance Container Vector Double where | ||
218 | scale = vectorMapValR Scale | ||
219 | scaleRecip = vectorMapValR Recip | ||
220 | addConstant = vectorMapValR AddConstant | ||
221 | add = vectorZipR Add | ||
222 | sub = vectorZipR Sub | ||
223 | mul = vectorZipR Mul | ||
224 | divide = vectorZipR Div | ||
225 | equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0 | ||
226 | arctan2 = vectorZipR ATan2 | ||
227 | scalar x = fromList [x] | ||
228 | konst' = constant | ||
229 | build' = buildV | ||
230 | conj = id | ||
231 | cmap = mapVector | ||
232 | atIndex = (@>) | ||
233 | minIndex = emptyErrorV "minIndex" (round . toScalarR MinIdx) | ||
234 | maxIndex = emptyErrorV "maxIndex" (round . toScalarR MaxIdx) | ||
235 | minElement = emptyErrorV "minElement" (toScalarR Min) | ||
236 | maxElement = emptyErrorV "maxElement" (toScalarR Max) | ||
237 | sumElements = sumR | ||
238 | prodElements = prodR | ||
239 | step = stepD | ||
240 | find = findV | ||
241 | assoc = assocV | ||
242 | accum = accumV | ||
243 | cond = condV condD | ||
244 | |||
245 | instance Container Vector (Complex Double) where | ||
246 | scale = vectorMapValC Scale | ||
247 | scaleRecip = vectorMapValC Recip | ||
248 | addConstant = vectorMapValC AddConstant | ||
249 | add = vectorZipC Add | ||
250 | sub = vectorZipC Sub | ||
251 | mul = vectorZipC Mul | ||
252 | divide = vectorZipC Div | ||
253 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 | ||
254 | arctan2 = vectorZipC ATan2 | ||
255 | scalar x = fromList [x] | ||
256 | konst' = constant | ||
257 | build' = buildV | ||
258 | conj = conjugateC | ||
259 | cmap = mapVector | ||
260 | atIndex = (@>) | ||
261 | minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj)) | ||
262 | maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj)) | ||
263 | minElement = emptyErrorV "minElement" (atIndex <*> minIndex) | ||
264 | maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex) | ||
265 | sumElements = sumC | ||
266 | prodElements = prodC | ||
267 | step = undefined -- cannot match | ||
268 | find = findV | ||
269 | assoc = assocV | ||
270 | accum = accumV | ||
271 | cond = undefined -- cannot match | ||
272 | |||
273 | instance Container Vector (Complex Float) where | ||
274 | scale = vectorMapValQ Scale | ||
275 | scaleRecip = vectorMapValQ Recip | ||
276 | addConstant = vectorMapValQ AddConstant | ||
277 | add = vectorZipQ Add | ||
278 | sub = vectorZipQ Sub | ||
279 | mul = vectorZipQ Mul | ||
280 | divide = vectorZipQ Div | ||
281 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 | ||
282 | arctan2 = vectorZipQ ATan2 | ||
283 | scalar x = fromList [x] | ||
284 | konst' = constant | ||
285 | build' = buildV | ||
286 | conj = conjugateQ | ||
287 | cmap = mapVector | ||
288 | atIndex = (@>) | ||
289 | minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj)) | ||
290 | maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj)) | ||
291 | minElement = emptyErrorV "minElement" (atIndex <*> minIndex) | ||
292 | maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex) | ||
293 | sumElements = sumQ | ||
294 | prodElements = prodQ | ||
295 | step = undefined -- cannot match | ||
296 | find = findV | ||
297 | assoc = assocV | ||
298 | accum = accumV | ||
299 | cond = undefined -- cannot match | ||
300 | |||
301 | --------------------------------------------------------------- | ||
302 | |||
303 | instance (Container Vector a) => Container Matrix a where | ||
304 | scale x = liftMatrix (scale x) | ||
305 | scaleRecip x = liftMatrix (scaleRecip x) | ||
306 | addConstant x = liftMatrix (addConstant x) | ||
307 | add = liftMatrix2 add | ||
308 | sub = liftMatrix2 sub | ||
309 | mul = liftMatrix2 mul | ||
310 | divide = liftMatrix2 divide | ||
311 | equal a b = cols a == cols b && flatten a `equal` flatten b | ||
312 | arctan2 = liftMatrix2 arctan2 | ||
313 | scalar x = (1><1) [x] | ||
314 | konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c)) | ||
315 | build' = buildM | ||
316 | conj = liftMatrix conj | ||
317 | cmap f = liftMatrix (mapVector f) | ||
318 | atIndex = (@@>) | ||
319 | minIndex = emptyErrorM "minIndex of Matrix" $ | ||
320 | \m -> divMod (minIndex $ flatten m) (cols m) | ||
321 | maxIndex = emptyErrorM "maxIndex of Matrix" $ | ||
322 | \m -> divMod (maxIndex $ flatten m) (cols m) | ||
323 | minElement = emptyErrorM "minElement of Matrix" (atIndex <*> minIndex) | ||
324 | maxElement = emptyErrorM "maxElement of Matrix" (atIndex <*> maxIndex) | ||
325 | sumElements = sumElements . flatten | ||
326 | prodElements = prodElements . flatten | ||
327 | step = liftMatrix step | ||
328 | find = findM | ||
329 | assoc = assocM | ||
330 | accum = accumM | ||
331 | cond = condM | ||
332 | |||
333 | |||
334 | emptyErrorV msg f v = | ||
335 | if dim v > 0 | ||
336 | then f v | ||
337 | else error $ msg ++ " of Vector with dim = 0" | ||
338 | |||
339 | emptyErrorM msg f m = | ||
340 | if rows m > 0 && cols m > 0 | ||
341 | then f m | ||
342 | else error $ msg++" "++shSize m | ||
343 | |||
344 | ---------------------------------------------------- | ||
345 | |||
346 | -- | Matrix product and related functions | ||
347 | class (Num e, Element e) => Product e where | ||
348 | -- | matrix product | ||
349 | multiply :: Matrix e -> Matrix e -> Matrix e | ||
350 | -- | sum of absolute value of elements (differs in complex case from @norm1@) | ||
351 | absSum :: Vector e -> RealOf e | ||
352 | -- | sum of absolute value of elements | ||
353 | norm1 :: Vector e -> RealOf e | ||
354 | -- | euclidean norm | ||
355 | norm2 :: Vector e -> RealOf e | ||
356 | -- | element of maximum magnitude | ||
357 | normInf :: Vector e -> RealOf e | ||
358 | |||
359 | instance Product Float where | ||
360 | norm2 = emptyVal (toScalarF Norm2) | ||
361 | absSum = emptyVal (toScalarF AbsSum) | ||
362 | norm1 = emptyVal (toScalarF AbsSum) | ||
363 | normInf = emptyVal (maxElement . vectorMapF Abs) | ||
364 | multiply = emptyMul multiplyF | ||
365 | |||
366 | instance Product Double where | ||
367 | norm2 = emptyVal (toScalarR Norm2) | ||
368 | absSum = emptyVal (toScalarR AbsSum) | ||
369 | norm1 = emptyVal (toScalarR AbsSum) | ||
370 | normInf = emptyVal (maxElement . vectorMapR Abs) | ||
371 | multiply = emptyMul multiplyR | ||
372 | |||
373 | instance Product (Complex Float) where | ||
374 | norm2 = emptyVal (toScalarQ Norm2) | ||
375 | absSum = emptyVal (toScalarQ AbsSum) | ||
376 | norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs) | ||
377 | normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs) | ||
378 | multiply = emptyMul multiplyQ | ||
379 | |||
380 | instance Product (Complex Double) where | ||
381 | norm2 = emptyVal (toScalarC Norm2) | ||
382 | absSum = emptyVal (toScalarC AbsSum) | ||
383 | norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs) | ||
384 | normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs) | ||
385 | multiply = emptyMul multiplyC | ||
386 | |||
387 | emptyMul m a b | ||
388 | | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c) | ||
389 | | otherwise = m a b | ||
390 | where | ||
391 | r = rows a | ||
392 | x1 = cols a | ||
393 | x2 = rows b | ||
394 | c = cols b | ||
395 | |||
396 | emptyVal f v = | ||
397 | if dim v > 0 | ||
398 | then f v | ||
399 | else 0 | ||
400 | |||
401 | -- FIXME remove unused C wrappers | ||
402 | -- | unconjugated dot product | ||
403 | udot :: Product e => Vector e -> Vector e -> e | ||
404 | udot u v | ||
405 | | dim u == dim v = val (asRow u `multiply` asColumn v) | ||
406 | | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product" | ||
407 | where | 85 | where |
408 | val m | dim u > 0 = m@@>(0,0) | 86 | infixl 7 <.> |
409 | | otherwise = 0 | 87 | {- | Matrix product, matrix vector product, and dot product |
410 | 88 | ||
411 | ---------------------------------------------------------- | 89 | Examples: |
412 | 90 | ||
413 | -- synonym for matrix product | 91 | >>> let a = (3><4) [1..] :: Matrix Double |
414 | mXm :: Product t => Matrix t -> Matrix t -> Matrix t | 92 | >>> let v = fromList [1,0,2,-1] :: Vector Double |
415 | mXm = multiply | 93 | >>> let u = fromList [1,2,3] :: Vector Double |
416 | 94 | ||
417 | -- matrix - vector product | 95 | >>> a |
418 | mXv :: Product t => Matrix t -> Vector t -> Vector t | 96 | (3><4) |
419 | mXv m v = flatten $ m `mXm` (asColumn v) | 97 | [ 1.0, 2.0, 3.0, 4.0 |
98 | , 5.0, 6.0, 7.0, 8.0 | ||
99 | , 9.0, 10.0, 11.0, 12.0 ] | ||
420 | 100 | ||
421 | -- vector - matrix product | 101 | matrix × matrix: |
422 | vXm :: Product t => Vector t -> Matrix t -> Vector t | ||
423 | vXm v m = flatten $ (asRow v) `mXm` m | ||
424 | 102 | ||
425 | {- | Outer product of two vectors. | 103 | >>> disp 2 (a <.> trans a) |
104 | 3x3 | ||
105 | 30 70 110 | ||
106 | 70 174 278 | ||
107 | 110 278 446 | ||
426 | 108 | ||
427 | >>> fromList [1,2,3] `outer` fromList [5,2,3] | 109 | matrix × vector: |
428 | (3><3) | ||
429 | [ 5.0, 2.0, 3.0 | ||
430 | , 10.0, 4.0, 6.0 | ||
431 | , 15.0, 6.0, 9.0 ] | ||
432 | |||
433 | -} | ||
434 | outer :: (Product t) => Vector t -> Vector t -> Matrix t | ||
435 | outer u v = asColumn u `multiply` asRow v | ||
436 | |||
437 | {- | Kronecker product of two matrices. | ||
438 | |||
439 | @m1=(2><3) | ||
440 | [ 1.0, 2.0, 0.0 | ||
441 | , 0.0, -1.0, 3.0 ] | ||
442 | m2=(4><3) | ||
443 | [ 1.0, 2.0, 3.0 | ||
444 | , 4.0, 5.0, 6.0 | ||
445 | , 7.0, 8.0, 9.0 | ||
446 | , 10.0, 11.0, 12.0 ]@ | ||
447 | |||
448 | >>> kronecker m1 m2 | ||
449 | (8><9) | ||
450 | [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0 | ||
451 | , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0 | ||
452 | , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0 | ||
453 | , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0 | ||
454 | , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0 | ||
455 | , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0 | ||
456 | , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0 | ||
457 | , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ] | ||
458 | |||
459 | -} | ||
460 | kronecker :: (Product t) => Matrix t -> Matrix t -> Matrix t | ||
461 | kronecker a b = fromBlocks | ||
462 | . splitEvery (cols a) | ||
463 | . map (reshape (cols b)) | ||
464 | . toRows | ||
465 | $ flatten a `outer` flatten b | ||
466 | 110 | ||
467 | ------------------------------------------------------------------- | 111 | >>> a <.> v |
112 | fromList [3.0,11.0,19.0] | ||
468 | 113 | ||
114 | dot product: | ||
469 | 115 | ||
470 | class Convert t where | 116 | >>> u <.> fromList[3,2,1::Double] |
471 | real :: Container c t => c (RealOf t) -> c t | 117 | 10 |
472 | complex :: Container c t => c t -> c (ComplexOf t) | ||
473 | single :: Container c t => c t -> c (SingleOf t) | ||
474 | double :: Container c t => c t -> c (DoubleOf t) | ||
475 | toComplex :: (Container c t, RealElement t) => (c t, c t) -> c (Complex t) | ||
476 | fromComplex :: (Container c t, RealElement t) => c (Complex t) -> (c t, c t) | ||
477 | 118 | ||
119 | For complex vectors the first argument is conjugated: | ||
478 | 120 | ||
479 | instance Convert Double where | 121 | >>> fromList [1,i] <.> fromList[2*i+1,3] |
480 | real = id | 122 | 1.0 :+ (-1.0) |
481 | complex = comp' | ||
482 | single = single' | ||
483 | double = id | ||
484 | toComplex = toComplex' | ||
485 | fromComplex = fromComplex' | ||
486 | 123 | ||
487 | instance Convert Float where | 124 | >>> fromList [1,i,1-i] <.> complex a |
488 | real = id | 125 | fromList [10.0 :+ 4.0,12.0 :+ 4.0,14.0 :+ 4.0,16.0 :+ 4.0] |
489 | complex = comp' | ||
490 | single = id | ||
491 | double = double' | ||
492 | toComplex = toComplex' | ||
493 | fromComplex = fromComplex' | ||
494 | 126 | ||
495 | instance Convert (Complex Double) where | 127 | -} |
496 | real = comp' | 128 | (<.>) :: a -> b -> c |
497 | complex = id | ||
498 | single = single' | ||
499 | double = id | ||
500 | toComplex = toComplex' | ||
501 | fromComplex = fromComplex' | ||
502 | |||
503 | instance Convert (Complex Float) where | ||
504 | real = comp' | ||
505 | complex = id | ||
506 | single = id | ||
507 | double = double' | ||
508 | toComplex = toComplex' | ||
509 | fromComplex = fromComplex' | ||
510 | |||
511 | ------------------------------------------------------------------- | ||
512 | 129 | ||
513 | type family RealOf x | ||
514 | 130 | ||
515 | type instance RealOf Double = Double | 131 | instance (Product t, Container Vector t) => Contraction (Vector t) (Vector t) t where |
516 | type instance RealOf (Complex Double) = Double | 132 | u <.> v = conj u `udot` v |
517 | 133 | ||
518 | type instance RealOf Float = Float | 134 | instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where |
519 | type instance RealOf (Complex Float) = Float | 135 | (<.>) = mXv |
520 | 136 | ||
521 | type family ComplexOf x | 137 | instance (Container Vector t, Product t) => Contraction (Vector t) (Matrix t) (Vector t) where |
138 | (<.>) v m = (conj v) `vXm` m | ||
522 | 139 | ||
523 | type instance ComplexOf Double = Complex Double | 140 | instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where |
524 | type instance ComplexOf (Complex Double) = Complex Double | 141 | (<.>) = mXm |
525 | 142 | ||
526 | type instance ComplexOf Float = Complex Float | ||
527 | type instance ComplexOf (Complex Float) = Complex Float | ||
528 | 143 | ||
529 | type family SingleOf x | 144 | -------------------------------------------------------------------------------- |
530 | 145 | ||
531 | type instance SingleOf Double = Float | 146 | class Mul a b c | a b -> c where |
532 | type instance SingleOf Float = Float | 147 | infixl 7 <> |
148 | -- | Matrix-matrix, matrix-vector, and vector-matrix products. | ||
149 | (<>) :: Product t => a t -> b t -> c t | ||
533 | 150 | ||
534 | type instance SingleOf (Complex a) = Complex (SingleOf a) | 151 | instance Mul Matrix Matrix Matrix where |
152 | (<>) = mXm | ||
535 | 153 | ||
536 | type family DoubleOf x | 154 | instance Mul Matrix Vector Vector where |
155 | (<>) m v = flatten $ m <> asColumn v | ||
537 | 156 | ||
538 | type instance DoubleOf Double = Double | 157 | instance Mul Vector Matrix Vector where |
539 | type instance DoubleOf Float = Double | 158 | (<>) v m = flatten $ asRow v <> m |
540 | 159 | ||
541 | type instance DoubleOf (Complex a) = Complex (DoubleOf a) | 160 | -------------------------------------------------------------------------------- |
542 | 161 | ||
543 | type family ElementOf c | 162 | class LSDiv c where |
163 | infixl 7 <\> | ||
164 | -- | least squares solution of a linear system, similar to the \\ operator of Matlab\/Octave (based on linearSolveSVD) | ||
165 | (<\>) :: Field t => Matrix t -> c t -> c t | ||
544 | 166 | ||
545 | type instance ElementOf (Vector a) = a | 167 | instance LSDiv Vector where |
546 | type instance ElementOf (Matrix a) = a | 168 | m <\> v = flatten (linearSolveSVD m (reshape 1 v)) |
547 | 169 | ||
548 | ------------------------------------------------------------ | 170 | instance LSDiv Matrix where |
171 | (<\>) = linearSolveSVD | ||
549 | 172 | ||
550 | buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ] | 173 | -------------------------------------------------------------------------------- |
551 | where rs = map fromIntegral [0 .. (rc-1)] | ||
552 | cs = map fromIntegral [0 .. (cc-1)] | ||
553 | 174 | ||
554 | buildV n f = fromList [f k | k <- ks] | 175 | class Konst e d c | d -> c, c -> d |
555 | where ks = map fromIntegral [0 .. (n-1)] | 176 | where |
177 | -- | | ||
178 | -- >>> konst 7 3 :: Vector Float | ||
179 | -- fromList [7.0,7.0,7.0] | ||
180 | -- | ||
181 | -- >>> konst i (3::Int,4::Int) | ||
182 | -- (3><4) | ||
183 | -- [ 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0 | ||
184 | -- , 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0 | ||
185 | -- , 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0 ] | ||
186 | -- | ||
187 | konst :: e -> d -> c e | ||
556 | 188 | ||
557 | -------------------------------------------------------- | 189 | instance Container Vector e => Konst e Int Vector |
558 | -- | conjugate transpose | 190 | where |
559 | ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e | 191 | konst = konst' |
560 | ctrans = liftMatrix conj . trans | ||
561 | 192 | ||
562 | -- | Creates a square matrix with a given diagonal. | 193 | instance Container Vector e => Konst e (Int,Int) Matrix |
563 | diag :: (Num a, Element a) => Vector a -> Matrix a | 194 | where |
564 | diag v = diagRect 0 v n n where n = dim v | 195 | konst = konst' |
565 | 196 | ||
566 | -- | creates the identity matrix of given dimension | 197 | -------------------------------------------------------------------------------- |
567 | ident :: (Num a, Element a) => Int -> Matrix a | ||
568 | ident n = diag (constant 1 n) | ||
569 | 198 | ||
570 | -------------------------------------------------------- | 199 | class Build d f c e | d -> c, c -> d, f -> e, f -> d, f -> c, c e -> f, d e -> f |
200 | where | ||
201 | -- | | ||
202 | -- >>> build 5 (**2) :: Vector Double | ||
203 | -- fromList [0.0,1.0,4.0,9.0,16.0] | ||
204 | -- | ||
205 | -- Hilbert matrix of order N: | ||
206 | -- | ||
207 | -- >>> let hilb n = build (n,n) (\i j -> 1/(i+j+1)) :: Matrix Double | ||
208 | -- >>> putStr . dispf 2 $ hilb 3 | ||
209 | -- 3x3 | ||
210 | -- 1.00 0.50 0.33 | ||
211 | -- 0.50 0.33 0.25 | ||
212 | -- 0.33 0.25 0.20 | ||
213 | -- | ||
214 | build :: d -> f -> c e | ||
571 | 215 | ||
572 | findV p x = foldVectorWithIndex g [] x where | 216 | instance Container Vector e => Build Int (e -> e) Vector e |
573 | g k z l = if p z then k:l else l | 217 | where |
218 | build = build' | ||
574 | 219 | ||
575 | findM p x = map ((`divMod` cols x)) $ findV p (flatten x) | 220 | instance Container Matrix e => Build (Int,Int) (e -> e -> e) Matrix e |
221 | where | ||
222 | build = build' | ||
576 | 223 | ||
577 | assocV n z xs = ST.runSTVector $ do | 224 | -------------------------------------------------------------------------------- |
578 | v <- ST.newVector z n | ||
579 | mapM_ (\(k,x) -> ST.writeVector v k x) xs | ||
580 | return v | ||
581 | 225 | ||
582 | assocM (r,c) z xs = ST.runSTMatrix $ do | 226 | {- | alternative operator for '(\<.\>)' |
583 | m <- ST.newMatrix z r c | ||
584 | mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs | ||
585 | return m | ||
586 | 227 | ||
587 | accumV v0 f xs = ST.runSTVector $ do | 228 | x25c7, white diamond |
588 | v <- ST.thawVector v0 | ||
589 | mapM_ (\(k,x) -> ST.modifyVector v k (f x)) xs | ||
590 | return v | ||
591 | 229 | ||
592 | accumM m0 f xs = ST.runSTMatrix $ do | 230 | -} |
593 | m <- ST.thawMatrix m0 | 231 | (◇) :: Contraction a b c => a -> b -> c |
594 | mapM_ (\((i,j),x) -> ST.modifyMatrix m i j (f x)) xs | 232 | infixl 7 ◇ |
595 | return m | 233 | (◇) = (<.>) |
596 | 234 | ||
597 | ---------------------------------------------------------------------- | 235 | -- | dot product: @cdot u v = 'udot' ('conj' u) v@ |
236 | dot :: (Container Vector t, Product t) => Vector t -> Vector t -> t | ||
237 | dot u v = udot (conj u) v | ||
598 | 238 | ||
599 | condM a b l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cond a' b' l' e' t' | 239 | -------------------------------------------------------------------------------- |
600 | where | ||
601 | args@(a'':_) = conformMs [a,b,l,e,t] | ||
602 | [a', b', l', e', t'] = map flatten args | ||
603 | 240 | ||
604 | condV f a b l e t = f a' b' l' e' t' | 241 | optimiseMult :: Monoid (Matrix t) => [Matrix t] -> Matrix t |
605 | where | 242 | optimiseMult = mconcat |
606 | [a', b', l', e', t'] = conformVs [a,b,l,e,t] | ||
607 | 243 | ||