summaryrefslogtreecommitdiff
path: root/packages/base/src/Data/Packed/Internal
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Data/Packed/Internal')
-rw-r--r--packages/base/src/Data/Packed/Internal/Numeric.hs928
1 files changed, 0 insertions, 928 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs
deleted file mode 100644
index a03159d..0000000
--- a/packages/base/src/Data/Packed/Internal/Numeric.hs
+++ /dev/null
@@ -1,928 +0,0 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE TypeFamilies #-}
3{-# LANGUAGE FlexibleContexts #-}
4{-# LANGUAGE FlexibleInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-}
6{-# LANGUAGE FunctionalDependencies #-}
7{-# LANGUAGE UndecidableInstances #-}
8
9-----------------------------------------------------------------------------
10-- |
11-- Module : Data.Packed.Internal.Numeric
12-- Copyright : (c) Alberto Ruiz 2010-14
13-- License : BSD3
14-- Maintainer : Alberto Ruiz
15-- Stability : provisional
16--
17-----------------------------------------------------------------------------
18
19module Data.Packed.Internal.Numeric (
20 -- * Basic functions
21 ident, diag, ctrans,
22 -- * Generic operations
23 Container(..),
24 scalar, conj, scale, arctan2, cmap, cmod,
25 atIndex, minIndex, maxIndex, minElement, maxElement,
26 sumElements, prodElements,
27 step, cond, find, assoc, accum, findV, assocV, accumV,
28 Transposable(..), Linear(..), Testable(..),
29 -- * Matrix product and related functions
30 Product(..), udot,
31 mXm,mXv,vXm,
32 outer, kronecker,
33 -- * sorting
34 sortV, sortI,
35 -- * Element conversion
36 Convert(..),
37 Complexable(),
38 RealElement(),
39 roundVector, fromInt, toInt,
40 RealOf, ComplexOf, SingleOf, DoubleOf,
41 IndexOf,
42 I, Extractor(..), (??), range, idxs, remapM,
43 module Data.Complex
44) where
45
46import Data.Packed
47import Data.Packed.ST as ST
48import Numeric.Conversion
49import Data.Packed.Development
50import Numeric.Vectorized
51import Data.Complex
52import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ,multiplyI)
53import Data.Packed.Internal
54import Text.Printf(printf)
55
56-------------------------------------------------------------------
57
58type family IndexOf (c :: * -> *)
59
60type instance IndexOf Vector = Int
61type instance IndexOf Matrix = (Int,Int)
62
63type family ArgOf (c :: * -> *) a
64
65type instance ArgOf Vector a = a -> a
66type instance ArgOf Matrix a = a -> a -> a
67
68--------------------------------------------------------------------------
69
70data Extractor
71 = All
72 | Range Int Int Int
73 | Pos (Vector I)
74 | PosCyc (Vector I)
75 | Take Int
76 | TakeLast Int
77 | Drop Int
78 | DropLast Int
79 deriving Show
80
81-- | Create a vector of indexes, useful for matrix extraction using '??'
82idxs :: [Int] -> Vector I
83idxs js = fromList (map fromIntegral js) :: Vector I
84
85--
86infixl 9 ??
87(??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t
88
89
90extractError m e = error $ printf "can't extract %s from matrix %dx%d" (show e) (rows m) (cols m)
91
92m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e)
93m ?? (e,Range a s b) | s /= 1 = m ?? (e, Pos (idxs [a,a+s .. b]))
94
95m ?? e@(Range a _ b,_) | a < 0 || b >= rows m = extractError m e
96m ?? e@(_,Range a _ b) | a < 0 || b >= cols m = extractError m e
97
98m ?? e@(Pos vs,_) | minElement vs < 0 || maxElement vs >= fromIntegral (rows m) = extractError m e
99m ?? e@(_,Pos vs) | minElement vs < 0 || maxElement vs >= fromIntegral (cols m) = extractError m e
100
101m ?? (All,All) = m
102
103m ?? (Range a _ b,e) | a > b = m ?? (Take 0,e)
104m ?? (e,Range a _ b) | a > b = m ?? (e,Take 0)
105
106m ?? (Take n,e)
107 | n <= 0 = (0><cols m) [] ?? (All,e)
108 | n >= rows m = m ?? (All,e)
109
110m ?? (e,Take n)
111 | n <= 0 = (rows m><0) [] ?? (e,All)
112 | n >= cols m = m ?? (e,All)
113
114m ?? (Drop n,e)
115 | n <= 0 = m ?? (All,e)
116 | n >= rows m = (0><cols m) [] ?? (All,e)
117
118m ?? (e,Drop n)
119 | n <= 0 = m ?? (e,All)
120 | n >= cols m = (rows m><0) [] ?? (e,All)
121
122m ?? (TakeLast n, e) = m ?? (Drop (rows m - n), e)
123m ?? (e, TakeLast n) = m ?? (e, Drop (cols m - n))
124
125m ?? (DropLast n, e) = m ?? (Take (rows m - n), e)
126m ?? (e, DropLast n) = m ?? (e, Take (cols m - n))
127
128m ?? (er,ec) = extractR m moder rs modec cs
129 where
130 (moder,rs) = mkExt (rows m) er
131 (modec,cs) = mkExt (cols m) ec
132 ran a b = (0, idxs [a,b])
133 pos ks = (1, ks)
134 mkExt _ (Pos ks) = pos ks
135 mkExt n (PosCyc ks)
136 | n == 0 = mkExt n (Take 0)
137 | otherwise = pos (cmod n ks)
138 mkExt _ (Range mn _ mx) = ran mn mx
139 mkExt _ (Take k) = ran 0 (k-1)
140 mkExt n (Drop k) = ran k (n-1)
141 mkExt n _ = ran 0 (n-1) -- All
142
143-------------------------------------------------------------------
144
145
146-- | Basic element-by-element functions for numeric containers
147class Element e => Container c e
148 where
149 conj' :: c e -> c e
150 size' :: c e -> IndexOf c
151 scalar' :: e -> c e
152 scale' :: e -> c e -> c e
153 addConstant :: e -> c e -> c e
154 add :: c e -> c e -> c e
155 sub :: c e -> c e -> c e
156 -- | element by element multiplication
157 mul :: c e -> c e -> c e
158 equal :: c e -> c e -> Bool
159 cmap' :: (Element b) => (e -> b) -> c e -> c b
160 konst' :: e -> IndexOf c -> c e
161 build' :: IndexOf c -> (ArgOf c e) -> c e
162 atIndex' :: c e -> IndexOf c -> e
163 minIndex' :: c e -> IndexOf c
164 maxIndex' :: c e -> IndexOf c
165 minElement' :: c e -> e
166 maxElement' :: c e -> e
167 sumElements' :: c e -> e
168 prodElements' :: c e -> e
169 step' :: Ord e => c e -> c e
170 cond' :: Ord e
171 => c e -- ^ a
172 -> c e -- ^ b
173 -> c e -- ^ l
174 -> c e -- ^ e
175 -> c e -- ^ g
176 -> c e -- ^ result
177 ccompare' :: Ord e => c e -> c e -> c I
178 cselect' :: c I -> c e -> c e -> c e -> c e
179 find' :: (e -> Bool) -> c e -> [IndexOf c]
180 assoc' :: IndexOf c -- ^ size
181 -> e -- ^ default value
182 -> [(IndexOf c, e)] -- ^ association list
183 -> c e -- ^ result
184 accum' :: c e -- ^ initial structure
185 -> (e -> e -> e) -- ^ update function
186 -> [(IndexOf c, e)] -- ^ association list
187 -> c e -- ^ result
188
189 -- | scale the element by element reciprocal of the object:
190 --
191 -- @scaleRecip 2 (fromList [5,i]) == 2 |> [0.4 :+ 0.0,0.0 :+ (-2.0)]@
192 scaleRecip :: Fractional e => e -> c e -> c e
193 -- | element by element division
194 divide :: Fractional e => c e -> c e -> c e
195 --
196 -- element by element inverse tangent
197 arctan2' :: Fractional e => c e -> c e -> c e
198 cmod' :: Integral e => e -> c e -> c e
199 fromInt' :: c I -> c e
200 toInt' :: c e -> c I
201
202
203--------------------------------------------------------------------------
204
205instance Container Vector I
206 where
207 conj' = id
208 size' = dim
209 scale' = vectorMapValI Scale
210 addConstant = vectorMapValI AddConstant
211 add = vectorZipI Add
212 sub = vectorZipI Sub
213 mul = vectorZipI Mul
214 equal u v = dim u == dim v && maxElement' (vectorMapI Abs (sub u v)) == 0
215 scalar' x = fromList [x]
216 konst' = constantD
217 build' = buildV
218 cmap' = mapVector
219 atIndex' = (@>)
220 minIndex' = emptyErrorV "minIndex" (fromIntegral . toScalarI MinIdx)
221 maxIndex' = emptyErrorV "maxIndex" (fromIntegral . toScalarI MaxIdx)
222 minElement' = emptyErrorV "minElement" (toScalarI Min)
223 maxElement' = emptyErrorV "maxElement" (toScalarI Max)
224 sumElements' = sumI
225 prodElements' = prodI
226 step' = stepI
227 find' = findV
228 assoc' = assocV
229 accum' = accumV
230 cond' = condV condI
231 ccompare' = compareCV compareV
232 cselect' = selectCV selectV
233 scaleRecip = undefined -- cannot match
234 divide = undefined
235 arctan2' = undefined
236 cmod' m x
237 | m /= 0 = vectorMapValI ModVS m x
238 | otherwise = error $ "cmod 0 on vector of size "++(show $ dim x)
239 fromInt' = id
240 toInt' = id
241
242instance Container Vector Float
243 where
244 conj' = id
245 size' = dim
246 scale' = vectorMapValF Scale
247 addConstant = vectorMapValF AddConstant
248 add = vectorZipF Add
249 sub = vectorZipF Sub
250 mul = vectorZipF Mul
251 equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0
252 scalar' x = fromList [x]
253 konst' = constantD
254 build' = buildV
255 cmap' = mapVector
256 atIndex' = (@>)
257 minIndex' = emptyErrorV "minIndex" (round . toScalarF MinIdx)
258 maxIndex' = emptyErrorV "maxIndex" (round . toScalarF MaxIdx)
259 minElement' = emptyErrorV "minElement" (toScalarF Min)
260 maxElement' = emptyErrorV "maxElement" (toScalarF Max)
261 sumElements' = sumF
262 prodElements' = prodF
263 step' = stepF
264 find' = findV
265 assoc' = assocV
266 accum' = accumV
267 cond' = condV condF
268 ccompare' = compareCV compareV
269 cselect' = selectCV selectV
270 scaleRecip = vectorMapValF Recip
271 divide = vectorZipF Div
272 arctan2' = vectorZipF ATan2
273 cmod' = undefined
274 fromInt' = int2floatV
275 toInt' = float2IntV
276
277
278
279instance Container Vector Double
280 where
281 conj' = id
282 size' = dim
283 scale' = vectorMapValR Scale
284 addConstant = vectorMapValR AddConstant
285 add = vectorZipR Add
286 sub = vectorZipR Sub
287 mul = vectorZipR Mul
288 equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0
289 scalar' x = fromList [x]
290 konst' = constantD
291 build' = buildV
292 cmap' = mapVector
293 atIndex' = (@>)
294 minIndex' = emptyErrorV "minIndex" (round . toScalarR MinIdx)
295 maxIndex' = emptyErrorV "maxIndex" (round . toScalarR MaxIdx)
296 minElement' = emptyErrorV "minElement" (toScalarR Min)
297 maxElement' = emptyErrorV "maxElement" (toScalarR Max)
298 sumElements' = sumR
299 prodElements' = prodR
300 step' = stepD
301 find' = findV
302 assoc' = assocV
303 accum' = accumV
304 cond' = condV condD
305 ccompare' = compareCV compareV
306 cselect' = selectCV selectV
307 scaleRecip = vectorMapValR Recip
308 divide = vectorZipR Div
309 arctan2' = vectorZipR ATan2
310 cmod' = undefined
311 fromInt' = int2DoubleV
312 toInt' = double2IntV
313
314
315instance Container Vector (Complex Double)
316 where
317 conj' = conjugateC
318 size' = dim
319 scale' = vectorMapValC Scale
320 addConstant = vectorMapValC AddConstant
321 add = vectorZipC Add
322 sub = vectorZipC Sub
323 mul = vectorZipC Mul
324 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
325 scalar' x = fromList [x]
326 konst' = constantD
327 build' = buildV
328 cmap' = mapVector
329 atIndex' = (@>)
330 minIndex' = emptyErrorV "minIndex" (minIndex' . fst . fromComplex . (mul <*> conj'))
331 maxIndex' = emptyErrorV "maxIndex" (maxIndex' . fst . fromComplex . (mul <*> conj'))
332 minElement' = emptyErrorV "minElement" (atIndex' <*> minIndex')
333 maxElement' = emptyErrorV "maxElement" (atIndex' <*> maxIndex')
334 sumElements' = sumC
335 prodElements' = prodC
336 step' = undefined -- cannot match
337 find' = findV
338 assoc' = assocV
339 accum' = accumV
340 cond' = undefined -- cannot match
341 ccompare' = undefined
342 cselect' = selectCV selectV
343 scaleRecip = vectorMapValC Recip
344 divide = vectorZipC Div
345 arctan2' = vectorZipC ATan2
346 cmod' = undefined
347 fromInt' = complex . int2DoubleV
348 toInt' = toInt' . fst . fromComplex
349
350instance Container Vector (Complex Float)
351 where
352 conj' = conjugateQ
353 size' = dim
354 scale' = vectorMapValQ Scale
355 addConstant = vectorMapValQ AddConstant
356 add = vectorZipQ Add
357 sub = vectorZipQ Sub
358 mul = vectorZipQ Mul
359 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
360 scalar' x = fromList [x]
361 konst' = constantD
362 build' = buildV
363 cmap' = mapVector
364 atIndex' = (@>)
365 minIndex' = emptyErrorV "minIndex" (minIndex' . fst . fromComplex . (mul <*> conj'))
366 maxIndex' = emptyErrorV "maxIndex" (maxIndex' . fst . fromComplex . (mul <*> conj'))
367 minElement' = emptyErrorV "minElement" (atIndex' <*> minIndex')
368 maxElement' = emptyErrorV "maxElement" (atIndex' <*> maxIndex')
369 sumElements' = sumQ
370 prodElements' = prodQ
371 step' = undefined -- cannot match
372 find' = findV
373 assoc' = assocV
374 accum' = accumV
375 cond' = undefined -- cannot match
376 ccompare' = undefined
377 cselect' = selectCV selectV
378 scaleRecip = vectorMapValQ Recip
379 divide = vectorZipQ Div
380 arctan2' = vectorZipQ ATan2
381 cmod' = undefined
382 fromInt' = complex . int2floatV
383 toInt' = toInt' . fst . fromComplex
384
385---------------------------------------------------------------
386
387instance (Num a, Element a, Container Vector a) => Container Matrix a
388 where
389 conj' = liftMatrix conj'
390 size' = size
391 scale' x = liftMatrix (scale' x)
392 addConstant x = liftMatrix (addConstant x)
393 add = liftMatrix2 add
394 sub = liftMatrix2 sub
395 mul = liftMatrix2 mul
396 equal a b = cols a == cols b && flatten a `equal` flatten b
397 scalar' x = (1><1) [x]
398 konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c))
399 build' = buildM
400 cmap' f = liftMatrix (mapVector f)
401 atIndex' = (@@>)
402 minIndex' = emptyErrorM "minIndex of Matrix" $
403 \m -> divMod (minIndex' $ flatten m) (cols m)
404 maxIndex' = emptyErrorM "maxIndex of Matrix" $
405 \m -> divMod (maxIndex' $ flatten m) (cols m)
406 minElement' = emptyErrorM "minElement of Matrix" (atIndex' <*> minIndex')
407 maxElement' = emptyErrorM "maxElement of Matrix" (atIndex' <*> maxIndex')
408 sumElements' = sumElements' . flatten
409 prodElements' = prodElements' . flatten
410 step' = liftMatrix step'
411 find' = findM
412 assoc' = assocM
413 accum' = accumM
414 cond' = condM
415 ccompare' = compareM
416 cselect' = selectM
417 scaleRecip x = liftMatrix (scaleRecip x)
418 divide = liftMatrix2 divide
419 arctan2' = liftMatrix2 arctan2'
420 cmod' m x
421 | m /= 0 = liftMatrix (cmod' m) x
422 | otherwise = error $ "cmod 0 on matrix "++shSize x
423 fromInt' = liftMatrix fromInt'
424 toInt' = liftMatrix toInt'
425
426
427emptyErrorV msg f v =
428 if dim v > 0
429 then f v
430 else error $ msg ++ " of empty Vector"
431
432emptyErrorM msg f m =
433 if rows m > 0 && cols m > 0
434 then f m
435 else error $ msg++" "++shSize m
436
437--------------------------------------------------------------------------------
438
439-- | create a structure with a single element
440--
441-- >>> let v = fromList [1..3::Double]
442-- >>> v / scalar (norm2 v)
443-- fromList [0.2672612419124244,0.5345224838248488,0.8017837257372732]
444--
445scalar :: Container c e => e -> c e
446scalar = scalar'
447
448-- | complex conjugate
449conj :: Container c e => c e -> c e
450conj = conj'
451
452-- | multiplication by scalar
453scale :: Container c e => e -> c e -> c e
454scale = scale'
455
456arctan2 :: (Fractional e, Container c e) => c e -> c e -> c e
457arctan2 = arctan2'
458
459-- | 'mod' for integer arrays
460--
461-- >>> cmod 3 (range 5)
462-- fromList [0,1,2,0,1]
463cmod :: (Integral e, Container c e) => Int -> c e -> c e
464cmod m = cmod' (fromIntegral m)
465
466-- |
467-- >>>fromInt ((2><2) [0..3]) :: Matrix (Complex Double)
468-- (2><2)
469-- [ 0.0 :+ 0.0, 1.0 :+ 0.0
470-- , 2.0 :+ 0.0, 3.0 :+ 0.0 ]
471--
472fromInt :: (Container c e) => c I -> c e
473fromInt = fromInt'
474
475toInt :: (Container c e) => c e -> c I
476toInt = toInt'
477
478
479-- | like 'fmap' (cannot implement instance Functor because of Element class constraint)
480cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b
481cmap = cmap'
482
483-- | generic indexing function
484--
485-- >>> vector [1,2,3] `atIndex` 1
486-- 2.0
487--
488-- >>> matrix 3 [0..8] `atIndex` (2,0)
489-- 6.0
490--
491atIndex :: Container c e => c e -> IndexOf c -> e
492atIndex = atIndex'
493
494-- | index of minimum element
495minIndex :: Container c e => c e -> IndexOf c
496minIndex = minIndex'
497
498-- | index of maximum element
499maxIndex :: Container c e => c e -> IndexOf c
500maxIndex = maxIndex'
501
502-- | value of minimum element
503minElement :: Container c e => c e -> e
504minElement = minElement'
505
506-- | value of maximum element
507maxElement :: Container c e => c e -> e
508maxElement = maxElement'
509
510-- | the sum of elements
511sumElements :: Container c e => c e -> e
512sumElements = sumElements'
513
514-- | the product of elements
515prodElements :: Container c e => c e -> e
516prodElements = prodElements'
517
518
519-- | A more efficient implementation of @cmap (\\x -> if x>0 then 1 else 0)@
520--
521-- >>> step $ linspace 5 (-1,1::Double)
522-- 5 |> [0.0,0.0,0.0,1.0,1.0]
523--
524step
525 :: (Ord e, Container c e)
526 => c e
527 -> c e
528step = step'
529
530
531-- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@.
532--
533-- Arguments with any dimension = 1 are automatically expanded:
534--
535-- >>> cond ((1><4)[1..]) ((3><1)[1..]) 0 100 ((3><4)[1..]) :: Matrix Double
536-- (3><4)
537-- [ 100.0, 2.0, 3.0, 4.0
538-- , 0.0, 100.0, 7.0, 8.0
539-- , 0.0, 0.0, 100.0, 12.0 ]
540--
541cond
542 :: (Ord e, Container c e)
543 => c e -- ^ a
544 -> c e -- ^ b
545 -> c e -- ^ l
546 -> c e -- ^ e
547 -> c e -- ^ g
548 -> c e -- ^ result
549cond = cond'
550
551
552-- | Find index of elements which satisfy a predicate
553--
554-- >>> find (>0) (ident 3 :: Matrix Double)
555-- [(0,0),(1,1),(2,2)]
556--
557find
558 :: Container c e
559 => (e -> Bool)
560 -> c e
561 -> [IndexOf c]
562find = find'
563
564
565-- | Create a structure from an association list
566--
567-- >>> assoc 5 0 [(3,7),(1,4)] :: Vector Double
568-- fromList [0.0,4.0,0.0,7.0,0.0]
569--
570-- >>> assoc (2,3) 0 [((0,2),7),((1,0),2*i-3)] :: Matrix (Complex Double)
571-- (2><3)
572-- [ 0.0 :+ 0.0, 0.0 :+ 0.0, 7.0 :+ 0.0
573-- , (-3.0) :+ 2.0, 0.0 :+ 0.0, 0.0 :+ 0.0 ]
574--
575assoc
576 :: Container c e
577 => IndexOf c -- ^ size
578 -> e -- ^ default value
579 -> [(IndexOf c, e)] -- ^ association list
580 -> c e -- ^ result
581assoc = assoc'
582
583
584-- | Modify a structure using an update function
585--
586-- >>> accum (ident 5) (+) [((1,1),5),((0,3),3)] :: Matrix Double
587-- (5><5)
588-- [ 1.0, 0.0, 0.0, 3.0, 0.0
589-- , 0.0, 6.0, 0.0, 0.0, 0.0
590-- , 0.0, 0.0, 1.0, 0.0, 0.0
591-- , 0.0, 0.0, 0.0, 1.0, 0.0
592-- , 0.0, 0.0, 0.0, 0.0, 1.0 ]
593--
594-- computation of histogram:
595--
596-- >>> accum (konst 0 7) (+) (map (flip (,) 1) [4,5,4,1,5,2,5]) :: Vector Double
597-- fromList [0.0,1.0,1.0,0.0,2.0,3.0,0.0]
598--
599accum
600 :: Container c e
601 => c e -- ^ initial structure
602 -> (e -> e -> e) -- ^ update function
603 -> [(IndexOf c, e)] -- ^ association list
604 -> c e -- ^ result
605accum = accum'
606
607
608--------------------------------------------------------------------------------
609
610-- | Matrix product and related functions
611class (Num e, Element e) => Product e where
612 -- | matrix product
613 multiply :: Matrix e -> Matrix e -> Matrix e
614 -- | sum of absolute value of elements (differs in complex case from @norm1@)
615 absSum :: Vector e -> RealOf e
616 -- | sum of absolute value of elements
617 norm1 :: Vector e -> RealOf e
618 -- | euclidean norm
619 norm2 :: Floating e => Vector e -> RealOf e
620 -- | element of maximum magnitude
621 normInf :: Vector e -> RealOf e
622
623instance Product Float where
624 norm2 = emptyVal (toScalarF Norm2)
625 absSum = emptyVal (toScalarF AbsSum)
626 norm1 = emptyVal (toScalarF AbsSum)
627 normInf = emptyVal (maxElement . vectorMapF Abs)
628 multiply = emptyMul multiplyF
629
630instance Product Double where
631 norm2 = emptyVal (toScalarR Norm2)
632 absSum = emptyVal (toScalarR AbsSum)
633 norm1 = emptyVal (toScalarR AbsSum)
634 normInf = emptyVal (maxElement . vectorMapR Abs)
635 multiply = emptyMul multiplyR
636
637instance Product (Complex Float) where
638 norm2 = emptyVal (toScalarQ Norm2)
639 absSum = emptyVal (toScalarQ AbsSum)
640 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs)
641 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs)
642 multiply = emptyMul multiplyQ
643
644instance Product (Complex Double) where
645 norm2 = emptyVal (toScalarC Norm2)
646 absSum = emptyVal (toScalarC AbsSum)
647 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs)
648 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs)
649 multiply = emptyMul multiplyC
650
651instance Product I where
652 norm2 = undefined
653 absSum = emptyVal (sumElements . vectorMapI Abs)
654 norm1 = absSum
655 normInf = emptyVal (maxElement . vectorMapI Abs)
656 multiply = emptyMul multiplyI
657
658
659emptyMul m a b
660 | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c)
661 | otherwise = m a b
662 where
663 r = rows a
664 x1 = cols a
665 x2 = rows b
666 c = cols b
667
668emptyVal f v =
669 if dim v > 0
670 then f v
671 else 0
672
673-- FIXME remove unused C wrappers
674-- | unconjugated dot product
675udot :: Product e => Vector e -> Vector e -> e
676udot u v
677 | dim u == dim v = val (asRow u `multiply` asColumn v)
678 | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product"
679 where
680 val m | dim u > 0 = m@@>(0,0)
681 | otherwise = 0
682
683----------------------------------------------------------
684
685-- synonym for matrix product
686mXm :: Product t => Matrix t -> Matrix t -> Matrix t
687mXm = multiply
688
689-- matrix - vector product
690mXv :: Product t => Matrix t -> Vector t -> Vector t
691mXv m v = flatten $ m `mXm` (asColumn v)
692
693-- vector - matrix product
694vXm :: Product t => Vector t -> Matrix t -> Vector t
695vXm v m = flatten $ (asRow v) `mXm` m
696
697{- | Outer product of two vectors.
698
699>>> fromList [1,2,3] `outer` fromList [5,2,3]
700(3><3)
701 [ 5.0, 2.0, 3.0
702 , 10.0, 4.0, 6.0
703 , 15.0, 6.0, 9.0 ]
704
705-}
706outer :: (Product t) => Vector t -> Vector t -> Matrix t
707outer u v = asColumn u `multiply` asRow v
708
709{- | Kronecker product of two matrices.
710
711@m1=(2><3)
712 [ 1.0, 2.0, 0.0
713 , 0.0, -1.0, 3.0 ]
714m2=(4><3)
715 [ 1.0, 2.0, 3.0
716 , 4.0, 5.0, 6.0
717 , 7.0, 8.0, 9.0
718 , 10.0, 11.0, 12.0 ]@
719
720>>> kronecker m1 m2
721(8><9)
722 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
723 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
724 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
725 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
726 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
727 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
728 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
729 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]
730
731-}
732kronecker :: (Product t) => Matrix t -> Matrix t -> Matrix t
733kronecker a b = fromBlocks
734 . splitEvery (cols a)
735 . map (reshape (cols b))
736 . toRows
737 $ flatten a `outer` flatten b
738
739-------------------------------------------------------------------
740
741
742class Convert t where
743 real :: Complexable c => c (RealOf t) -> c t
744 complex :: Complexable c => c t -> c (ComplexOf t)
745 single :: Complexable c => c t -> c (SingleOf t)
746 double :: Complexable c => c t -> c (DoubleOf t)
747 toComplex :: (Complexable c, RealElement t) => (c t, c t) -> c (Complex t)
748 fromComplex :: (Complexable c, RealElement t) => c (Complex t) -> (c t, c t)
749
750
751instance Convert Double where
752 real = id
753 complex = comp'
754 single = single'
755 double = id
756 toComplex = toComplex'
757 fromComplex = fromComplex'
758
759instance Convert Float where
760 real = id
761 complex = comp'
762 single = id
763 double = double'
764 toComplex = toComplex'
765 fromComplex = fromComplex'
766
767instance Convert (Complex Double) where
768 real = comp'
769 complex = id
770 single = single'
771 double = id
772 toComplex = toComplex'
773 fromComplex = fromComplex'
774
775instance Convert (Complex Float) where
776 real = comp'
777 complex = id
778 single = id
779 double = double'
780 toComplex = toComplex'
781 fromComplex = fromComplex'
782
783-------------------------------------------------------------------
784
785type family RealOf x
786
787type instance RealOf Double = Double
788type instance RealOf (Complex Double) = Double
789
790type instance RealOf Float = Float
791type instance RealOf (Complex Float) = Float
792
793type instance RealOf I = I
794
795type family ComplexOf x
796
797type instance ComplexOf Double = Complex Double
798type instance ComplexOf (Complex Double) = Complex Double
799
800type instance ComplexOf Float = Complex Float
801type instance ComplexOf (Complex Float) = Complex Float
802
803type family SingleOf x
804
805type instance SingleOf Double = Float
806type instance SingleOf Float = Float
807
808type instance SingleOf (Complex a) = Complex (SingleOf a)
809
810type family DoubleOf x
811
812type instance DoubleOf Double = Double
813type instance DoubleOf Float = Double
814
815type instance DoubleOf (Complex a) = Complex (DoubleOf a)
816
817type family ElementOf c
818
819type instance ElementOf (Vector a) = a
820type instance ElementOf (Matrix a) = a
821
822------------------------------------------------------------
823
824buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ]
825 where rs = map fromIntegral [0 .. (rc-1)]
826 cs = map fromIntegral [0 .. (cc-1)]
827
828buildV n f = fromList [f k | k <- ks]
829 where ks = map fromIntegral [0 .. (n-1)]
830
831--------------------------------------------------------
832-- | conjugate transpose
833ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e
834ctrans = liftMatrix conj' . trans
835
836-- | Creates a square matrix with a given diagonal.
837diag :: (Num a, Element a) => Vector a -> Matrix a
838diag v = diagRect 0 v n n where n = dim v
839
840-- | creates the identity matrix of given dimension
841ident :: (Num a, Element a) => Int -> Matrix a
842ident n = diag (constantD 1 n)
843
844--------------------------------------------------------
845
846findV p x = foldVectorWithIndex g [] x where
847 g k z l = if p z then k:l else l
848
849findM p x = map ((`divMod` cols x)) $ findV p (flatten x)
850
851assocV n z xs = ST.runSTVector $ do
852 v <- ST.newVector z n
853 mapM_ (\(k,x) -> ST.writeVector v k x) xs
854 return v
855
856assocM (r,c) z xs = ST.runSTMatrix $ do
857 m <- ST.newMatrix z r c
858 mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs
859 return m
860
861accumV v0 f xs = ST.runSTVector $ do
862 v <- ST.thawVector v0
863 mapM_ (\(k,x) -> ST.modifyVector v k (f x)) xs
864 return v
865
866accumM m0 f xs = ST.runSTMatrix $ do
867 m <- ST.thawMatrix m0
868 mapM_ (\((i,j),x) -> ST.modifyMatrix m i j (f x)) xs
869 return m
870
871----------------------------------------------------------------------
872
873condM a b l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cond' a' b' l' e' t'
874 where
875 args@(a'':_) = conformMs [a,b,l,e,t]
876 [a', b', l', e', t'] = map flatten args
877
878condV f a b l e t = f a' b' l' e' t'
879 where
880 [a', b', l', e', t'] = conformVs [a,b,l,e,t]
881
882compareM a b = matrixFromVector RowMajor (rows a'') (cols a'') $ ccompare' a' b'
883 where
884 args@(a'':_) = conformMs [a,b]
885 [a', b'] = map flatten args
886
887compareCV f a b = f a' b'
888 where
889 [a', b'] = conformVs [a,b]
890
891selectM c l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cselect' (toInt c') l' e' t'
892 where
893 args@(a'':_) = conformMs [fromInt c,l,e,t]
894 [c', l', e', t'] = map flatten args
895
896selectCV f c l e t = f (toInt c') l' e' t'
897 where
898 [c', l', e', t'] = conformVs [fromInt c,l,e,t]
899
900--------------------------------------------------------------------------------
901
902class Transposable m mt | m -> mt, mt -> m
903 where
904 -- | conjugate transpose
905 tr :: m -> mt
906 -- | transpose
907 tr' :: m -> mt
908
909instance (Container Vector t) => Transposable (Matrix t) (Matrix t)
910 where
911 tr = ctrans
912 tr' = trans
913
914class Linear t v
915 where
916 scalarL :: t -> v
917 addL :: v -> v -> v
918 scaleL :: t -> v -> v
919
920
921class Testable t
922 where
923 checkT :: t -> (Bool, IO())
924 ioCheckT :: t -> IO (Bool, IO())
925 ioCheckT = return . checkT
926
927--------------------------------------------------------------------------------
928