diff options
author | Alberto Ruiz <aruiz@um.es> | 2015-06-05 16:44:52 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2015-06-05 16:44:52 +0200 |
commit | 64df799c68817054705a99e9ee02723603fae29e (patch) | |
tree | bf1f5b04eb9984c230d295905570330c026337e1 /packages/base/src/Data | |
parent | 11d7c37dc8b314338bc6382d80e74aaec2bb5620 (diff) |
move internal numeric
Diffstat (limited to 'packages/base/src/Data')
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Numeric.hs | 928 |
1 files changed, 0 insertions, 928 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs deleted file mode 100644 index a03159d..0000000 --- a/packages/base/src/Data/Packed/Internal/Numeric.hs +++ /dev/null | |||
@@ -1,928 +0,0 @@ | |||
1 | {-# LANGUAGE CPP #-} | ||
2 | {-# LANGUAGE TypeFamilies #-} | ||
3 | {-# LANGUAGE FlexibleContexts #-} | ||
4 | {-# LANGUAGE FlexibleInstances #-} | ||
5 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
6 | {-# LANGUAGE FunctionalDependencies #-} | ||
7 | {-# LANGUAGE UndecidableInstances #-} | ||
8 | |||
9 | ----------------------------------------------------------------------------- | ||
10 | -- | | ||
11 | -- Module : Data.Packed.Internal.Numeric | ||
12 | -- Copyright : (c) Alberto Ruiz 2010-14 | ||
13 | -- License : BSD3 | ||
14 | -- Maintainer : Alberto Ruiz | ||
15 | -- Stability : provisional | ||
16 | -- | ||
17 | ----------------------------------------------------------------------------- | ||
18 | |||
19 | module Data.Packed.Internal.Numeric ( | ||
20 | -- * Basic functions | ||
21 | ident, diag, ctrans, | ||
22 | -- * Generic operations | ||
23 | Container(..), | ||
24 | scalar, conj, scale, arctan2, cmap, cmod, | ||
25 | atIndex, minIndex, maxIndex, minElement, maxElement, | ||
26 | sumElements, prodElements, | ||
27 | step, cond, find, assoc, accum, findV, assocV, accumV, | ||
28 | Transposable(..), Linear(..), Testable(..), | ||
29 | -- * Matrix product and related functions | ||
30 | Product(..), udot, | ||
31 | mXm,mXv,vXm, | ||
32 | outer, kronecker, | ||
33 | -- * sorting | ||
34 | sortV, sortI, | ||
35 | -- * Element conversion | ||
36 | Convert(..), | ||
37 | Complexable(), | ||
38 | RealElement(), | ||
39 | roundVector, fromInt, toInt, | ||
40 | RealOf, ComplexOf, SingleOf, DoubleOf, | ||
41 | IndexOf, | ||
42 | I, Extractor(..), (??), range, idxs, remapM, | ||
43 | module Data.Complex | ||
44 | ) where | ||
45 | |||
46 | import Data.Packed | ||
47 | import Data.Packed.ST as ST | ||
48 | import Numeric.Conversion | ||
49 | import Data.Packed.Development | ||
50 | import Numeric.Vectorized | ||
51 | import Data.Complex | ||
52 | import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ,multiplyI) | ||
53 | import Data.Packed.Internal | ||
54 | import Text.Printf(printf) | ||
55 | |||
56 | ------------------------------------------------------------------- | ||
57 | |||
58 | type family IndexOf (c :: * -> *) | ||
59 | |||
60 | type instance IndexOf Vector = Int | ||
61 | type instance IndexOf Matrix = (Int,Int) | ||
62 | |||
63 | type family ArgOf (c :: * -> *) a | ||
64 | |||
65 | type instance ArgOf Vector a = a -> a | ||
66 | type instance ArgOf Matrix a = a -> a -> a | ||
67 | |||
68 | -------------------------------------------------------------------------- | ||
69 | |||
70 | data Extractor | ||
71 | = All | ||
72 | | Range Int Int Int | ||
73 | | Pos (Vector I) | ||
74 | | PosCyc (Vector I) | ||
75 | | Take Int | ||
76 | | TakeLast Int | ||
77 | | Drop Int | ||
78 | | DropLast Int | ||
79 | deriving Show | ||
80 | |||
81 | -- | Create a vector of indexes, useful for matrix extraction using '??' | ||
82 | idxs :: [Int] -> Vector I | ||
83 | idxs js = fromList (map fromIntegral js) :: Vector I | ||
84 | |||
85 | -- | ||
86 | infixl 9 ?? | ||
87 | (??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t | ||
88 | |||
89 | |||
90 | extractError m e = error $ printf "can't extract %s from matrix %dx%d" (show e) (rows m) (cols m) | ||
91 | |||
92 | m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e) | ||
93 | m ?? (e,Range a s b) | s /= 1 = m ?? (e, Pos (idxs [a,a+s .. b])) | ||
94 | |||
95 | m ?? e@(Range a _ b,_) | a < 0 || b >= rows m = extractError m e | ||
96 | m ?? e@(_,Range a _ b) | a < 0 || b >= cols m = extractError m e | ||
97 | |||
98 | m ?? e@(Pos vs,_) | minElement vs < 0 || maxElement vs >= fromIntegral (rows m) = extractError m e | ||
99 | m ?? e@(_,Pos vs) | minElement vs < 0 || maxElement vs >= fromIntegral (cols m) = extractError m e | ||
100 | |||
101 | m ?? (All,All) = m | ||
102 | |||
103 | m ?? (Range a _ b,e) | a > b = m ?? (Take 0,e) | ||
104 | m ?? (e,Range a _ b) | a > b = m ?? (e,Take 0) | ||
105 | |||
106 | m ?? (Take n,e) | ||
107 | | n <= 0 = (0><cols m) [] ?? (All,e) | ||
108 | | n >= rows m = m ?? (All,e) | ||
109 | |||
110 | m ?? (e,Take n) | ||
111 | | n <= 0 = (rows m><0) [] ?? (e,All) | ||
112 | | n >= cols m = m ?? (e,All) | ||
113 | |||
114 | m ?? (Drop n,e) | ||
115 | | n <= 0 = m ?? (All,e) | ||
116 | | n >= rows m = (0><cols m) [] ?? (All,e) | ||
117 | |||
118 | m ?? (e,Drop n) | ||
119 | | n <= 0 = m ?? (e,All) | ||
120 | | n >= cols m = (rows m><0) [] ?? (e,All) | ||
121 | |||
122 | m ?? (TakeLast n, e) = m ?? (Drop (rows m - n), e) | ||
123 | m ?? (e, TakeLast n) = m ?? (e, Drop (cols m - n)) | ||
124 | |||
125 | m ?? (DropLast n, e) = m ?? (Take (rows m - n), e) | ||
126 | m ?? (e, DropLast n) = m ?? (e, Take (cols m - n)) | ||
127 | |||
128 | m ?? (er,ec) = extractR m moder rs modec cs | ||
129 | where | ||
130 | (moder,rs) = mkExt (rows m) er | ||
131 | (modec,cs) = mkExt (cols m) ec | ||
132 | ran a b = (0, idxs [a,b]) | ||
133 | pos ks = (1, ks) | ||
134 | mkExt _ (Pos ks) = pos ks | ||
135 | mkExt n (PosCyc ks) | ||
136 | | n == 0 = mkExt n (Take 0) | ||
137 | | otherwise = pos (cmod n ks) | ||
138 | mkExt _ (Range mn _ mx) = ran mn mx | ||
139 | mkExt _ (Take k) = ran 0 (k-1) | ||
140 | mkExt n (Drop k) = ran k (n-1) | ||
141 | mkExt n _ = ran 0 (n-1) -- All | ||
142 | |||
143 | ------------------------------------------------------------------- | ||
144 | |||
145 | |||
146 | -- | Basic element-by-element functions for numeric containers | ||
147 | class Element e => Container c e | ||
148 | where | ||
149 | conj' :: c e -> c e | ||
150 | size' :: c e -> IndexOf c | ||
151 | scalar' :: e -> c e | ||
152 | scale' :: e -> c e -> c e | ||
153 | addConstant :: e -> c e -> c e | ||
154 | add :: c e -> c e -> c e | ||
155 | sub :: c e -> c e -> c e | ||
156 | -- | element by element multiplication | ||
157 | mul :: c e -> c e -> c e | ||
158 | equal :: c e -> c e -> Bool | ||
159 | cmap' :: (Element b) => (e -> b) -> c e -> c b | ||
160 | konst' :: e -> IndexOf c -> c e | ||
161 | build' :: IndexOf c -> (ArgOf c e) -> c e | ||
162 | atIndex' :: c e -> IndexOf c -> e | ||
163 | minIndex' :: c e -> IndexOf c | ||
164 | maxIndex' :: c e -> IndexOf c | ||
165 | minElement' :: c e -> e | ||
166 | maxElement' :: c e -> e | ||
167 | sumElements' :: c e -> e | ||
168 | prodElements' :: c e -> e | ||
169 | step' :: Ord e => c e -> c e | ||
170 | cond' :: Ord e | ||
171 | => c e -- ^ a | ||
172 | -> c e -- ^ b | ||
173 | -> c e -- ^ l | ||
174 | -> c e -- ^ e | ||
175 | -> c e -- ^ g | ||
176 | -> c e -- ^ result | ||
177 | ccompare' :: Ord e => c e -> c e -> c I | ||
178 | cselect' :: c I -> c e -> c e -> c e -> c e | ||
179 | find' :: (e -> Bool) -> c e -> [IndexOf c] | ||
180 | assoc' :: IndexOf c -- ^ size | ||
181 | -> e -- ^ default value | ||
182 | -> [(IndexOf c, e)] -- ^ association list | ||
183 | -> c e -- ^ result | ||
184 | accum' :: c e -- ^ initial structure | ||
185 | -> (e -> e -> e) -- ^ update function | ||
186 | -> [(IndexOf c, e)] -- ^ association list | ||
187 | -> c e -- ^ result | ||
188 | |||
189 | -- | scale the element by element reciprocal of the object: | ||
190 | -- | ||
191 | -- @scaleRecip 2 (fromList [5,i]) == 2 |> [0.4 :+ 0.0,0.0 :+ (-2.0)]@ | ||
192 | scaleRecip :: Fractional e => e -> c e -> c e | ||
193 | -- | element by element division | ||
194 | divide :: Fractional e => c e -> c e -> c e | ||
195 | -- | ||
196 | -- element by element inverse tangent | ||
197 | arctan2' :: Fractional e => c e -> c e -> c e | ||
198 | cmod' :: Integral e => e -> c e -> c e | ||
199 | fromInt' :: c I -> c e | ||
200 | toInt' :: c e -> c I | ||
201 | |||
202 | |||
203 | -------------------------------------------------------------------------- | ||
204 | |||
205 | instance Container Vector I | ||
206 | where | ||
207 | conj' = id | ||
208 | size' = dim | ||
209 | scale' = vectorMapValI Scale | ||
210 | addConstant = vectorMapValI AddConstant | ||
211 | add = vectorZipI Add | ||
212 | sub = vectorZipI Sub | ||
213 | mul = vectorZipI Mul | ||
214 | equal u v = dim u == dim v && maxElement' (vectorMapI Abs (sub u v)) == 0 | ||
215 | scalar' x = fromList [x] | ||
216 | konst' = constantD | ||
217 | build' = buildV | ||
218 | cmap' = mapVector | ||
219 | atIndex' = (@>) | ||
220 | minIndex' = emptyErrorV "minIndex" (fromIntegral . toScalarI MinIdx) | ||
221 | maxIndex' = emptyErrorV "maxIndex" (fromIntegral . toScalarI MaxIdx) | ||
222 | minElement' = emptyErrorV "minElement" (toScalarI Min) | ||
223 | maxElement' = emptyErrorV "maxElement" (toScalarI Max) | ||
224 | sumElements' = sumI | ||
225 | prodElements' = prodI | ||
226 | step' = stepI | ||
227 | find' = findV | ||
228 | assoc' = assocV | ||
229 | accum' = accumV | ||
230 | cond' = condV condI | ||
231 | ccompare' = compareCV compareV | ||
232 | cselect' = selectCV selectV | ||
233 | scaleRecip = undefined -- cannot match | ||
234 | divide = undefined | ||
235 | arctan2' = undefined | ||
236 | cmod' m x | ||
237 | | m /= 0 = vectorMapValI ModVS m x | ||
238 | | otherwise = error $ "cmod 0 on vector of size "++(show $ dim x) | ||
239 | fromInt' = id | ||
240 | toInt' = id | ||
241 | |||
242 | instance Container Vector Float | ||
243 | where | ||
244 | conj' = id | ||
245 | size' = dim | ||
246 | scale' = vectorMapValF Scale | ||
247 | addConstant = vectorMapValF AddConstant | ||
248 | add = vectorZipF Add | ||
249 | sub = vectorZipF Sub | ||
250 | mul = vectorZipF Mul | ||
251 | equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0 | ||
252 | scalar' x = fromList [x] | ||
253 | konst' = constantD | ||
254 | build' = buildV | ||
255 | cmap' = mapVector | ||
256 | atIndex' = (@>) | ||
257 | minIndex' = emptyErrorV "minIndex" (round . toScalarF MinIdx) | ||
258 | maxIndex' = emptyErrorV "maxIndex" (round . toScalarF MaxIdx) | ||
259 | minElement' = emptyErrorV "minElement" (toScalarF Min) | ||
260 | maxElement' = emptyErrorV "maxElement" (toScalarF Max) | ||
261 | sumElements' = sumF | ||
262 | prodElements' = prodF | ||
263 | step' = stepF | ||
264 | find' = findV | ||
265 | assoc' = assocV | ||
266 | accum' = accumV | ||
267 | cond' = condV condF | ||
268 | ccompare' = compareCV compareV | ||
269 | cselect' = selectCV selectV | ||
270 | scaleRecip = vectorMapValF Recip | ||
271 | divide = vectorZipF Div | ||
272 | arctan2' = vectorZipF ATan2 | ||
273 | cmod' = undefined | ||
274 | fromInt' = int2floatV | ||
275 | toInt' = float2IntV | ||
276 | |||
277 | |||
278 | |||
279 | instance Container Vector Double | ||
280 | where | ||
281 | conj' = id | ||
282 | size' = dim | ||
283 | scale' = vectorMapValR Scale | ||
284 | addConstant = vectorMapValR AddConstant | ||
285 | add = vectorZipR Add | ||
286 | sub = vectorZipR Sub | ||
287 | mul = vectorZipR Mul | ||
288 | equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0 | ||
289 | scalar' x = fromList [x] | ||
290 | konst' = constantD | ||
291 | build' = buildV | ||
292 | cmap' = mapVector | ||
293 | atIndex' = (@>) | ||
294 | minIndex' = emptyErrorV "minIndex" (round . toScalarR MinIdx) | ||
295 | maxIndex' = emptyErrorV "maxIndex" (round . toScalarR MaxIdx) | ||
296 | minElement' = emptyErrorV "minElement" (toScalarR Min) | ||
297 | maxElement' = emptyErrorV "maxElement" (toScalarR Max) | ||
298 | sumElements' = sumR | ||
299 | prodElements' = prodR | ||
300 | step' = stepD | ||
301 | find' = findV | ||
302 | assoc' = assocV | ||
303 | accum' = accumV | ||
304 | cond' = condV condD | ||
305 | ccompare' = compareCV compareV | ||
306 | cselect' = selectCV selectV | ||
307 | scaleRecip = vectorMapValR Recip | ||
308 | divide = vectorZipR Div | ||
309 | arctan2' = vectorZipR ATan2 | ||
310 | cmod' = undefined | ||
311 | fromInt' = int2DoubleV | ||
312 | toInt' = double2IntV | ||
313 | |||
314 | |||
315 | instance Container Vector (Complex Double) | ||
316 | where | ||
317 | conj' = conjugateC | ||
318 | size' = dim | ||
319 | scale' = vectorMapValC Scale | ||
320 | addConstant = vectorMapValC AddConstant | ||
321 | add = vectorZipC Add | ||
322 | sub = vectorZipC Sub | ||
323 | mul = vectorZipC Mul | ||
324 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 | ||
325 | scalar' x = fromList [x] | ||
326 | konst' = constantD | ||
327 | build' = buildV | ||
328 | cmap' = mapVector | ||
329 | atIndex' = (@>) | ||
330 | minIndex' = emptyErrorV "minIndex" (minIndex' . fst . fromComplex . (mul <*> conj')) | ||
331 | maxIndex' = emptyErrorV "maxIndex" (maxIndex' . fst . fromComplex . (mul <*> conj')) | ||
332 | minElement' = emptyErrorV "minElement" (atIndex' <*> minIndex') | ||
333 | maxElement' = emptyErrorV "maxElement" (atIndex' <*> maxIndex') | ||
334 | sumElements' = sumC | ||
335 | prodElements' = prodC | ||
336 | step' = undefined -- cannot match | ||
337 | find' = findV | ||
338 | assoc' = assocV | ||
339 | accum' = accumV | ||
340 | cond' = undefined -- cannot match | ||
341 | ccompare' = undefined | ||
342 | cselect' = selectCV selectV | ||
343 | scaleRecip = vectorMapValC Recip | ||
344 | divide = vectorZipC Div | ||
345 | arctan2' = vectorZipC ATan2 | ||
346 | cmod' = undefined | ||
347 | fromInt' = complex . int2DoubleV | ||
348 | toInt' = toInt' . fst . fromComplex | ||
349 | |||
350 | instance Container Vector (Complex Float) | ||
351 | where | ||
352 | conj' = conjugateQ | ||
353 | size' = dim | ||
354 | scale' = vectorMapValQ Scale | ||
355 | addConstant = vectorMapValQ AddConstant | ||
356 | add = vectorZipQ Add | ||
357 | sub = vectorZipQ Sub | ||
358 | mul = vectorZipQ Mul | ||
359 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 | ||
360 | scalar' x = fromList [x] | ||
361 | konst' = constantD | ||
362 | build' = buildV | ||
363 | cmap' = mapVector | ||
364 | atIndex' = (@>) | ||
365 | minIndex' = emptyErrorV "minIndex" (minIndex' . fst . fromComplex . (mul <*> conj')) | ||
366 | maxIndex' = emptyErrorV "maxIndex" (maxIndex' . fst . fromComplex . (mul <*> conj')) | ||
367 | minElement' = emptyErrorV "minElement" (atIndex' <*> minIndex') | ||
368 | maxElement' = emptyErrorV "maxElement" (atIndex' <*> maxIndex') | ||
369 | sumElements' = sumQ | ||
370 | prodElements' = prodQ | ||
371 | step' = undefined -- cannot match | ||
372 | find' = findV | ||
373 | assoc' = assocV | ||
374 | accum' = accumV | ||
375 | cond' = undefined -- cannot match | ||
376 | ccompare' = undefined | ||
377 | cselect' = selectCV selectV | ||
378 | scaleRecip = vectorMapValQ Recip | ||
379 | divide = vectorZipQ Div | ||
380 | arctan2' = vectorZipQ ATan2 | ||
381 | cmod' = undefined | ||
382 | fromInt' = complex . int2floatV | ||
383 | toInt' = toInt' . fst . fromComplex | ||
384 | |||
385 | --------------------------------------------------------------- | ||
386 | |||
387 | instance (Num a, Element a, Container Vector a) => Container Matrix a | ||
388 | where | ||
389 | conj' = liftMatrix conj' | ||
390 | size' = size | ||
391 | scale' x = liftMatrix (scale' x) | ||
392 | addConstant x = liftMatrix (addConstant x) | ||
393 | add = liftMatrix2 add | ||
394 | sub = liftMatrix2 sub | ||
395 | mul = liftMatrix2 mul | ||
396 | equal a b = cols a == cols b && flatten a `equal` flatten b | ||
397 | scalar' x = (1><1) [x] | ||
398 | konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c)) | ||
399 | build' = buildM | ||
400 | cmap' f = liftMatrix (mapVector f) | ||
401 | atIndex' = (@@>) | ||
402 | minIndex' = emptyErrorM "minIndex of Matrix" $ | ||
403 | \m -> divMod (minIndex' $ flatten m) (cols m) | ||
404 | maxIndex' = emptyErrorM "maxIndex of Matrix" $ | ||
405 | \m -> divMod (maxIndex' $ flatten m) (cols m) | ||
406 | minElement' = emptyErrorM "minElement of Matrix" (atIndex' <*> minIndex') | ||
407 | maxElement' = emptyErrorM "maxElement of Matrix" (atIndex' <*> maxIndex') | ||
408 | sumElements' = sumElements' . flatten | ||
409 | prodElements' = prodElements' . flatten | ||
410 | step' = liftMatrix step' | ||
411 | find' = findM | ||
412 | assoc' = assocM | ||
413 | accum' = accumM | ||
414 | cond' = condM | ||
415 | ccompare' = compareM | ||
416 | cselect' = selectM | ||
417 | scaleRecip x = liftMatrix (scaleRecip x) | ||
418 | divide = liftMatrix2 divide | ||
419 | arctan2' = liftMatrix2 arctan2' | ||
420 | cmod' m x | ||
421 | | m /= 0 = liftMatrix (cmod' m) x | ||
422 | | otherwise = error $ "cmod 0 on matrix "++shSize x | ||
423 | fromInt' = liftMatrix fromInt' | ||
424 | toInt' = liftMatrix toInt' | ||
425 | |||
426 | |||
427 | emptyErrorV msg f v = | ||
428 | if dim v > 0 | ||
429 | then f v | ||
430 | else error $ msg ++ " of empty Vector" | ||
431 | |||
432 | emptyErrorM msg f m = | ||
433 | if rows m > 0 && cols m > 0 | ||
434 | then f m | ||
435 | else error $ msg++" "++shSize m | ||
436 | |||
437 | -------------------------------------------------------------------------------- | ||
438 | |||
439 | -- | create a structure with a single element | ||
440 | -- | ||
441 | -- >>> let v = fromList [1..3::Double] | ||
442 | -- >>> v / scalar (norm2 v) | ||
443 | -- fromList [0.2672612419124244,0.5345224838248488,0.8017837257372732] | ||
444 | -- | ||
445 | scalar :: Container c e => e -> c e | ||
446 | scalar = scalar' | ||
447 | |||
448 | -- | complex conjugate | ||
449 | conj :: Container c e => c e -> c e | ||
450 | conj = conj' | ||
451 | |||
452 | -- | multiplication by scalar | ||
453 | scale :: Container c e => e -> c e -> c e | ||
454 | scale = scale' | ||
455 | |||
456 | arctan2 :: (Fractional e, Container c e) => c e -> c e -> c e | ||
457 | arctan2 = arctan2' | ||
458 | |||
459 | -- | 'mod' for integer arrays | ||
460 | -- | ||
461 | -- >>> cmod 3 (range 5) | ||
462 | -- fromList [0,1,2,0,1] | ||
463 | cmod :: (Integral e, Container c e) => Int -> c e -> c e | ||
464 | cmod m = cmod' (fromIntegral m) | ||
465 | |||
466 | -- | | ||
467 | -- >>>fromInt ((2><2) [0..3]) :: Matrix (Complex Double) | ||
468 | -- (2><2) | ||
469 | -- [ 0.0 :+ 0.0, 1.0 :+ 0.0 | ||
470 | -- , 2.0 :+ 0.0, 3.0 :+ 0.0 ] | ||
471 | -- | ||
472 | fromInt :: (Container c e) => c I -> c e | ||
473 | fromInt = fromInt' | ||
474 | |||
475 | toInt :: (Container c e) => c e -> c I | ||
476 | toInt = toInt' | ||
477 | |||
478 | |||
479 | -- | like 'fmap' (cannot implement instance Functor because of Element class constraint) | ||
480 | cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b | ||
481 | cmap = cmap' | ||
482 | |||
483 | -- | generic indexing function | ||
484 | -- | ||
485 | -- >>> vector [1,2,3] `atIndex` 1 | ||
486 | -- 2.0 | ||
487 | -- | ||
488 | -- >>> matrix 3 [0..8] `atIndex` (2,0) | ||
489 | -- 6.0 | ||
490 | -- | ||
491 | atIndex :: Container c e => c e -> IndexOf c -> e | ||
492 | atIndex = atIndex' | ||
493 | |||
494 | -- | index of minimum element | ||
495 | minIndex :: Container c e => c e -> IndexOf c | ||
496 | minIndex = minIndex' | ||
497 | |||
498 | -- | index of maximum element | ||
499 | maxIndex :: Container c e => c e -> IndexOf c | ||
500 | maxIndex = maxIndex' | ||
501 | |||
502 | -- | value of minimum element | ||
503 | minElement :: Container c e => c e -> e | ||
504 | minElement = minElement' | ||
505 | |||
506 | -- | value of maximum element | ||
507 | maxElement :: Container c e => c e -> e | ||
508 | maxElement = maxElement' | ||
509 | |||
510 | -- | the sum of elements | ||
511 | sumElements :: Container c e => c e -> e | ||
512 | sumElements = sumElements' | ||
513 | |||
514 | -- | the product of elements | ||
515 | prodElements :: Container c e => c e -> e | ||
516 | prodElements = prodElements' | ||
517 | |||
518 | |||
519 | -- | A more efficient implementation of @cmap (\\x -> if x>0 then 1 else 0)@ | ||
520 | -- | ||
521 | -- >>> step $ linspace 5 (-1,1::Double) | ||
522 | -- 5 |> [0.0,0.0,0.0,1.0,1.0] | ||
523 | -- | ||
524 | step | ||
525 | :: (Ord e, Container c e) | ||
526 | => c e | ||
527 | -> c e | ||
528 | step = step' | ||
529 | |||
530 | |||
531 | -- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@. | ||
532 | -- | ||
533 | -- Arguments with any dimension = 1 are automatically expanded: | ||
534 | -- | ||
535 | -- >>> cond ((1><4)[1..]) ((3><1)[1..]) 0 100 ((3><4)[1..]) :: Matrix Double | ||
536 | -- (3><4) | ||
537 | -- [ 100.0, 2.0, 3.0, 4.0 | ||
538 | -- , 0.0, 100.0, 7.0, 8.0 | ||
539 | -- , 0.0, 0.0, 100.0, 12.0 ] | ||
540 | -- | ||
541 | cond | ||
542 | :: (Ord e, Container c e) | ||
543 | => c e -- ^ a | ||
544 | -> c e -- ^ b | ||
545 | -> c e -- ^ l | ||
546 | -> c e -- ^ e | ||
547 | -> c e -- ^ g | ||
548 | -> c e -- ^ result | ||
549 | cond = cond' | ||
550 | |||
551 | |||
552 | -- | Find index of elements which satisfy a predicate | ||
553 | -- | ||
554 | -- >>> find (>0) (ident 3 :: Matrix Double) | ||
555 | -- [(0,0),(1,1),(2,2)] | ||
556 | -- | ||
557 | find | ||
558 | :: Container c e | ||
559 | => (e -> Bool) | ||
560 | -> c e | ||
561 | -> [IndexOf c] | ||
562 | find = find' | ||
563 | |||
564 | |||
565 | -- | Create a structure from an association list | ||
566 | -- | ||
567 | -- >>> assoc 5 0 [(3,7),(1,4)] :: Vector Double | ||
568 | -- fromList [0.0,4.0,0.0,7.0,0.0] | ||
569 | -- | ||
570 | -- >>> assoc (2,3) 0 [((0,2),7),((1,0),2*i-3)] :: Matrix (Complex Double) | ||
571 | -- (2><3) | ||
572 | -- [ 0.0 :+ 0.0, 0.0 :+ 0.0, 7.0 :+ 0.0 | ||
573 | -- , (-3.0) :+ 2.0, 0.0 :+ 0.0, 0.0 :+ 0.0 ] | ||
574 | -- | ||
575 | assoc | ||
576 | :: Container c e | ||
577 | => IndexOf c -- ^ size | ||
578 | -> e -- ^ default value | ||
579 | -> [(IndexOf c, e)] -- ^ association list | ||
580 | -> c e -- ^ result | ||
581 | assoc = assoc' | ||
582 | |||
583 | |||
584 | -- | Modify a structure using an update function | ||
585 | -- | ||
586 | -- >>> accum (ident 5) (+) [((1,1),5),((0,3),3)] :: Matrix Double | ||
587 | -- (5><5) | ||
588 | -- [ 1.0, 0.0, 0.0, 3.0, 0.0 | ||
589 | -- , 0.0, 6.0, 0.0, 0.0, 0.0 | ||
590 | -- , 0.0, 0.0, 1.0, 0.0, 0.0 | ||
591 | -- , 0.0, 0.0, 0.0, 1.0, 0.0 | ||
592 | -- , 0.0, 0.0, 0.0, 0.0, 1.0 ] | ||
593 | -- | ||
594 | -- computation of histogram: | ||
595 | -- | ||
596 | -- >>> accum (konst 0 7) (+) (map (flip (,) 1) [4,5,4,1,5,2,5]) :: Vector Double | ||
597 | -- fromList [0.0,1.0,1.0,0.0,2.0,3.0,0.0] | ||
598 | -- | ||
599 | accum | ||
600 | :: Container c e | ||
601 | => c e -- ^ initial structure | ||
602 | -> (e -> e -> e) -- ^ update function | ||
603 | -> [(IndexOf c, e)] -- ^ association list | ||
604 | -> c e -- ^ result | ||
605 | accum = accum' | ||
606 | |||
607 | |||
608 | -------------------------------------------------------------------------------- | ||
609 | |||
610 | -- | Matrix product and related functions | ||
611 | class (Num e, Element e) => Product e where | ||
612 | -- | matrix product | ||
613 | multiply :: Matrix e -> Matrix e -> Matrix e | ||
614 | -- | sum of absolute value of elements (differs in complex case from @norm1@) | ||
615 | absSum :: Vector e -> RealOf e | ||
616 | -- | sum of absolute value of elements | ||
617 | norm1 :: Vector e -> RealOf e | ||
618 | -- | euclidean norm | ||
619 | norm2 :: Floating e => Vector e -> RealOf e | ||
620 | -- | element of maximum magnitude | ||
621 | normInf :: Vector e -> RealOf e | ||
622 | |||
623 | instance Product Float where | ||
624 | norm2 = emptyVal (toScalarF Norm2) | ||
625 | absSum = emptyVal (toScalarF AbsSum) | ||
626 | norm1 = emptyVal (toScalarF AbsSum) | ||
627 | normInf = emptyVal (maxElement . vectorMapF Abs) | ||
628 | multiply = emptyMul multiplyF | ||
629 | |||
630 | instance Product Double where | ||
631 | norm2 = emptyVal (toScalarR Norm2) | ||
632 | absSum = emptyVal (toScalarR AbsSum) | ||
633 | norm1 = emptyVal (toScalarR AbsSum) | ||
634 | normInf = emptyVal (maxElement . vectorMapR Abs) | ||
635 | multiply = emptyMul multiplyR | ||
636 | |||
637 | instance Product (Complex Float) where | ||
638 | norm2 = emptyVal (toScalarQ Norm2) | ||
639 | absSum = emptyVal (toScalarQ AbsSum) | ||
640 | norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs) | ||
641 | normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs) | ||
642 | multiply = emptyMul multiplyQ | ||
643 | |||
644 | instance Product (Complex Double) where | ||
645 | norm2 = emptyVal (toScalarC Norm2) | ||
646 | absSum = emptyVal (toScalarC AbsSum) | ||
647 | norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs) | ||
648 | normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs) | ||
649 | multiply = emptyMul multiplyC | ||
650 | |||
651 | instance Product I where | ||
652 | norm2 = undefined | ||
653 | absSum = emptyVal (sumElements . vectorMapI Abs) | ||
654 | norm1 = absSum | ||
655 | normInf = emptyVal (maxElement . vectorMapI Abs) | ||
656 | multiply = emptyMul multiplyI | ||
657 | |||
658 | |||
659 | emptyMul m a b | ||
660 | | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c) | ||
661 | | otherwise = m a b | ||
662 | where | ||
663 | r = rows a | ||
664 | x1 = cols a | ||
665 | x2 = rows b | ||
666 | c = cols b | ||
667 | |||
668 | emptyVal f v = | ||
669 | if dim v > 0 | ||
670 | then f v | ||
671 | else 0 | ||
672 | |||
673 | -- FIXME remove unused C wrappers | ||
674 | -- | unconjugated dot product | ||
675 | udot :: Product e => Vector e -> Vector e -> e | ||
676 | udot u v | ||
677 | | dim u == dim v = val (asRow u `multiply` asColumn v) | ||
678 | | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product" | ||
679 | where | ||
680 | val m | dim u > 0 = m@@>(0,0) | ||
681 | | otherwise = 0 | ||
682 | |||
683 | ---------------------------------------------------------- | ||
684 | |||
685 | -- synonym for matrix product | ||
686 | mXm :: Product t => Matrix t -> Matrix t -> Matrix t | ||
687 | mXm = multiply | ||
688 | |||
689 | -- matrix - vector product | ||
690 | mXv :: Product t => Matrix t -> Vector t -> Vector t | ||
691 | mXv m v = flatten $ m `mXm` (asColumn v) | ||
692 | |||
693 | -- vector - matrix product | ||
694 | vXm :: Product t => Vector t -> Matrix t -> Vector t | ||
695 | vXm v m = flatten $ (asRow v) `mXm` m | ||
696 | |||
697 | {- | Outer product of two vectors. | ||
698 | |||
699 | >>> fromList [1,2,3] `outer` fromList [5,2,3] | ||
700 | (3><3) | ||
701 | [ 5.0, 2.0, 3.0 | ||
702 | , 10.0, 4.0, 6.0 | ||
703 | , 15.0, 6.0, 9.0 ] | ||
704 | |||
705 | -} | ||
706 | outer :: (Product t) => Vector t -> Vector t -> Matrix t | ||
707 | outer u v = asColumn u `multiply` asRow v | ||
708 | |||
709 | {- | Kronecker product of two matrices. | ||
710 | |||
711 | @m1=(2><3) | ||
712 | [ 1.0, 2.0, 0.0 | ||
713 | , 0.0, -1.0, 3.0 ] | ||
714 | m2=(4><3) | ||
715 | [ 1.0, 2.0, 3.0 | ||
716 | , 4.0, 5.0, 6.0 | ||
717 | , 7.0, 8.0, 9.0 | ||
718 | , 10.0, 11.0, 12.0 ]@ | ||
719 | |||
720 | >>> kronecker m1 m2 | ||
721 | (8><9) | ||
722 | [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0 | ||
723 | , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0 | ||
724 | , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0 | ||
725 | , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0 | ||
726 | , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0 | ||
727 | , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0 | ||
728 | , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0 | ||
729 | , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ] | ||
730 | |||
731 | -} | ||
732 | kronecker :: (Product t) => Matrix t -> Matrix t -> Matrix t | ||
733 | kronecker a b = fromBlocks | ||
734 | . splitEvery (cols a) | ||
735 | . map (reshape (cols b)) | ||
736 | . toRows | ||
737 | $ flatten a `outer` flatten b | ||
738 | |||
739 | ------------------------------------------------------------------- | ||
740 | |||
741 | |||
742 | class Convert t where | ||
743 | real :: Complexable c => c (RealOf t) -> c t | ||
744 | complex :: Complexable c => c t -> c (ComplexOf t) | ||
745 | single :: Complexable c => c t -> c (SingleOf t) | ||
746 | double :: Complexable c => c t -> c (DoubleOf t) | ||
747 | toComplex :: (Complexable c, RealElement t) => (c t, c t) -> c (Complex t) | ||
748 | fromComplex :: (Complexable c, RealElement t) => c (Complex t) -> (c t, c t) | ||
749 | |||
750 | |||
751 | instance Convert Double where | ||
752 | real = id | ||
753 | complex = comp' | ||
754 | single = single' | ||
755 | double = id | ||
756 | toComplex = toComplex' | ||
757 | fromComplex = fromComplex' | ||
758 | |||
759 | instance Convert Float where | ||
760 | real = id | ||
761 | complex = comp' | ||
762 | single = id | ||
763 | double = double' | ||
764 | toComplex = toComplex' | ||
765 | fromComplex = fromComplex' | ||
766 | |||
767 | instance Convert (Complex Double) where | ||
768 | real = comp' | ||
769 | complex = id | ||
770 | single = single' | ||
771 | double = id | ||
772 | toComplex = toComplex' | ||
773 | fromComplex = fromComplex' | ||
774 | |||
775 | instance Convert (Complex Float) where | ||
776 | real = comp' | ||
777 | complex = id | ||
778 | single = id | ||
779 | double = double' | ||
780 | toComplex = toComplex' | ||
781 | fromComplex = fromComplex' | ||
782 | |||
783 | ------------------------------------------------------------------- | ||
784 | |||
785 | type family RealOf x | ||
786 | |||
787 | type instance RealOf Double = Double | ||
788 | type instance RealOf (Complex Double) = Double | ||
789 | |||
790 | type instance RealOf Float = Float | ||
791 | type instance RealOf (Complex Float) = Float | ||
792 | |||
793 | type instance RealOf I = I | ||
794 | |||
795 | type family ComplexOf x | ||
796 | |||
797 | type instance ComplexOf Double = Complex Double | ||
798 | type instance ComplexOf (Complex Double) = Complex Double | ||
799 | |||
800 | type instance ComplexOf Float = Complex Float | ||
801 | type instance ComplexOf (Complex Float) = Complex Float | ||
802 | |||
803 | type family SingleOf x | ||
804 | |||
805 | type instance SingleOf Double = Float | ||
806 | type instance SingleOf Float = Float | ||
807 | |||
808 | type instance SingleOf (Complex a) = Complex (SingleOf a) | ||
809 | |||
810 | type family DoubleOf x | ||
811 | |||
812 | type instance DoubleOf Double = Double | ||
813 | type instance DoubleOf Float = Double | ||
814 | |||
815 | type instance DoubleOf (Complex a) = Complex (DoubleOf a) | ||
816 | |||
817 | type family ElementOf c | ||
818 | |||
819 | type instance ElementOf (Vector a) = a | ||
820 | type instance ElementOf (Matrix a) = a | ||
821 | |||
822 | ------------------------------------------------------------ | ||
823 | |||
824 | buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ] | ||
825 | where rs = map fromIntegral [0 .. (rc-1)] | ||
826 | cs = map fromIntegral [0 .. (cc-1)] | ||
827 | |||
828 | buildV n f = fromList [f k | k <- ks] | ||
829 | where ks = map fromIntegral [0 .. (n-1)] | ||
830 | |||
831 | -------------------------------------------------------- | ||
832 | -- | conjugate transpose | ||
833 | ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e | ||
834 | ctrans = liftMatrix conj' . trans | ||
835 | |||
836 | -- | Creates a square matrix with a given diagonal. | ||
837 | diag :: (Num a, Element a) => Vector a -> Matrix a | ||
838 | diag v = diagRect 0 v n n where n = dim v | ||
839 | |||
840 | -- | creates the identity matrix of given dimension | ||
841 | ident :: (Num a, Element a) => Int -> Matrix a | ||
842 | ident n = diag (constantD 1 n) | ||
843 | |||
844 | -------------------------------------------------------- | ||
845 | |||
846 | findV p x = foldVectorWithIndex g [] x where | ||
847 | g k z l = if p z then k:l else l | ||
848 | |||
849 | findM p x = map ((`divMod` cols x)) $ findV p (flatten x) | ||
850 | |||
851 | assocV n z xs = ST.runSTVector $ do | ||
852 | v <- ST.newVector z n | ||
853 | mapM_ (\(k,x) -> ST.writeVector v k x) xs | ||
854 | return v | ||
855 | |||
856 | assocM (r,c) z xs = ST.runSTMatrix $ do | ||
857 | m <- ST.newMatrix z r c | ||
858 | mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs | ||
859 | return m | ||
860 | |||
861 | accumV v0 f xs = ST.runSTVector $ do | ||
862 | v <- ST.thawVector v0 | ||
863 | mapM_ (\(k,x) -> ST.modifyVector v k (f x)) xs | ||
864 | return v | ||
865 | |||
866 | accumM m0 f xs = ST.runSTMatrix $ do | ||
867 | m <- ST.thawMatrix m0 | ||
868 | mapM_ (\((i,j),x) -> ST.modifyMatrix m i j (f x)) xs | ||
869 | return m | ||
870 | |||
871 | ---------------------------------------------------------------------- | ||
872 | |||
873 | condM a b l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cond' a' b' l' e' t' | ||
874 | where | ||
875 | args@(a'':_) = conformMs [a,b,l,e,t] | ||
876 | [a', b', l', e', t'] = map flatten args | ||
877 | |||
878 | condV f a b l e t = f a' b' l' e' t' | ||
879 | where | ||
880 | [a', b', l', e', t'] = conformVs [a,b,l,e,t] | ||
881 | |||
882 | compareM a b = matrixFromVector RowMajor (rows a'') (cols a'') $ ccompare' a' b' | ||
883 | where | ||
884 | args@(a'':_) = conformMs [a,b] | ||
885 | [a', b'] = map flatten args | ||
886 | |||
887 | compareCV f a b = f a' b' | ||
888 | where | ||
889 | [a', b'] = conformVs [a,b] | ||
890 | |||
891 | selectM c l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cselect' (toInt c') l' e' t' | ||
892 | where | ||
893 | args@(a'':_) = conformMs [fromInt c,l,e,t] | ||
894 | [c', l', e', t'] = map flatten args | ||
895 | |||
896 | selectCV f c l e t = f (toInt c') l' e' t' | ||
897 | where | ||
898 | [c', l', e', t'] = conformVs [fromInt c,l,e,t] | ||
899 | |||
900 | -------------------------------------------------------------------------------- | ||
901 | |||
902 | class Transposable m mt | m -> mt, mt -> m | ||
903 | where | ||
904 | -- | conjugate transpose | ||
905 | tr :: m -> mt | ||
906 | -- | transpose | ||
907 | tr' :: m -> mt | ||
908 | |||
909 | instance (Container Vector t) => Transposable (Matrix t) (Matrix t) | ||
910 | where | ||
911 | tr = ctrans | ||
912 | tr' = trans | ||
913 | |||
914 | class Linear t v | ||
915 | where | ||
916 | scalarL :: t -> v | ||
917 | addL :: v -> v -> v | ||
918 | scaleL :: t -> v -> v | ||
919 | |||
920 | |||
921 | class Testable t | ||
922 | where | ||
923 | checkT :: t -> (Bool, IO()) | ||
924 | ioCheckT :: t -> IO (Bool, IO()) | ||
925 | ioCheckT = return . checkT | ||
926 | |||
927 | -------------------------------------------------------------------------------- | ||
928 | |||