summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r--packages/base/src/Internal/Numeric.hs879
1 files changed, 879 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs
new file mode 100644
index 0000000..86a4a4c
--- /dev/null
+++ b/packages/base/src/Internal/Numeric.hs
@@ -0,0 +1,879 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE TypeFamilies #-}
3{-# LANGUAGE FlexibleContexts #-}
4{-# LANGUAGE FlexibleInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-}
6{-# LANGUAGE FunctionalDependencies #-}
7{-# LANGUAGE UndecidableInstances #-}
8
9-----------------------------------------------------------------------------
10-- |
11-- Module : Data.Packed.Internal.Numeric
12-- Copyright : (c) Alberto Ruiz 2010-14
13-- License : BSD3
14-- Maintainer : Alberto Ruiz
15-- Stability : provisional
16--
17-----------------------------------------------------------------------------
18
19module Internal.Numeric where
20
21import Internal.Tools
22import Internal.Vector
23import Internal.Matrix
24import Internal.Element
25import Internal.ST as ST
26import Internal.Conversion
27import Internal.Vectorized
28import Internal.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ,multiplyI)
29import Data.Vector.Storable(fromList)
30import Text.Printf(printf)
31
32-------------------------------------------------------------------
33
34type family IndexOf (c :: * -> *)
35
36type instance IndexOf Vector = Int
37type instance IndexOf Matrix = (Int,Int)
38
39type family ArgOf (c :: * -> *) a
40
41type instance ArgOf Vector a = a -> a
42type instance ArgOf Matrix a = a -> a -> a
43
44--------------------------------------------------------------------------
45
46data Extractor
47 = All
48 | Range Int Int Int
49 | Pos (Vector I)
50 | PosCyc (Vector I)
51 | Take Int
52 | TakeLast Int
53 | Drop Int
54 | DropLast Int
55 deriving Show
56
57
58--
59infixl 9 ??
60(??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t
61
62
63extractError m e = error $ printf "can't extract %s from matrix %dx%d" (show e) (rows m) (cols m)
64
65m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e)
66m ?? (e,Range a s b) | s /= 1 = m ?? (e, Pos (idxs [a,a+s .. b]))
67
68m ?? e@(Range a _ b,_) | a < 0 || b >= rows m = extractError m e
69m ?? e@(_,Range a _ b) | a < 0 || b >= cols m = extractError m e
70
71m ?? e@(Pos vs,_) | minElement vs < 0 || maxElement vs >= fromIntegral (rows m) = extractError m e
72m ?? e@(_,Pos vs) | minElement vs < 0 || maxElement vs >= fromIntegral (cols m) = extractError m e
73
74m ?? (All,All) = m
75
76m ?? (Range a _ b,e) | a > b = m ?? (Take 0,e)
77m ?? (e,Range a _ b) | a > b = m ?? (e,Take 0)
78
79m ?? (Take n,e)
80 | n <= 0 = (0><cols m) [] ?? (All,e)
81 | n >= rows m = m ?? (All,e)
82
83m ?? (e,Take n)
84 | n <= 0 = (rows m><0) [] ?? (e,All)
85 | n >= cols m = m ?? (e,All)
86
87m ?? (Drop n,e)
88 | n <= 0 = m ?? (All,e)
89 | n >= rows m = (0><cols m) [] ?? (All,e)
90
91m ?? (e,Drop n)
92 | n <= 0 = m ?? (e,All)
93 | n >= cols m = (rows m><0) [] ?? (e,All)
94
95m ?? (TakeLast n, e) = m ?? (Drop (rows m - n), e)
96m ?? (e, TakeLast n) = m ?? (e, Drop (cols m - n))
97
98m ?? (DropLast n, e) = m ?? (Take (rows m - n), e)
99m ?? (e, DropLast n) = m ?? (e, Take (cols m - n))
100
101m ?? (er,ec) = extractR m moder rs modec cs
102 where
103 (moder,rs) = mkExt (rows m) er
104 (modec,cs) = mkExt (cols m) ec
105 ran a b = (0, idxs [a,b])
106 pos ks = (1, ks)
107 mkExt _ (Pos ks) = pos ks
108 mkExt n (PosCyc ks)
109 | n == 0 = mkExt n (Take 0)
110 | otherwise = pos (cmod n ks)
111 mkExt _ (Range mn _ mx) = ran mn mx
112 mkExt _ (Take k) = ran 0 (k-1)
113 mkExt n (Drop k) = ran k (n-1)
114 mkExt n _ = ran 0 (n-1) -- All
115
116-------------------------------------------------------------------
117
118
119-- | Basic element-by-element functions for numeric containers
120class Element e => Container c e
121 where
122 conj' :: c e -> c e
123 size' :: c e -> IndexOf c
124 scalar' :: e -> c e
125 scale' :: e -> c e -> c e
126 addConstant :: e -> c e -> c e
127 add :: c e -> c e -> c e
128 sub :: c e -> c e -> c e
129 -- | element by element multiplication
130 mul :: c e -> c e -> c e
131 equal :: c e -> c e -> Bool
132 cmap' :: (Element b) => (e -> b) -> c e -> c b
133 konst' :: e -> IndexOf c -> c e
134 build' :: IndexOf c -> (ArgOf c e) -> c e
135 atIndex' :: c e -> IndexOf c -> e
136 minIndex' :: c e -> IndexOf c
137 maxIndex' :: c e -> IndexOf c
138 minElement' :: c e -> e
139 maxElement' :: c e -> e
140 sumElements' :: c e -> e
141 prodElements' :: c e -> e
142 step' :: Ord e => c e -> c e
143 ccompare' :: Ord e => c e -> c e -> c I
144 cselect' :: c I -> c e -> c e -> c e -> c e
145 find' :: (e -> Bool) -> c e -> [IndexOf c]
146 assoc' :: IndexOf c -- ^ size
147 -> e -- ^ default value
148 -> [(IndexOf c, e)] -- ^ association list
149 -> c e -- ^ result
150 accum' :: c e -- ^ initial structure
151 -> (e -> e -> e) -- ^ update function
152 -> [(IndexOf c, e)] -- ^ association list
153 -> c e -- ^ result
154
155 -- | scale the element by element reciprocal of the object:
156 --
157 -- @scaleRecip 2 (fromList [5,i]) == 2 |> [0.4 :+ 0.0,0.0 :+ (-2.0)]@
158 scaleRecip :: Fractional e => e -> c e -> c e
159 -- | element by element division
160 divide :: Fractional e => c e -> c e -> c e
161 --
162 -- element by element inverse tangent
163 arctan2' :: Fractional e => c e -> c e -> c e
164 cmod' :: Integral e => e -> c e -> c e
165 fromInt' :: c I -> c e
166 toInt' :: c e -> c I
167
168
169--------------------------------------------------------------------------
170
171instance Container Vector I
172 where
173 conj' = id
174 size' = dim
175 scale' = vectorMapValI Scale
176 addConstant = vectorMapValI AddConstant
177 add = vectorZipI Add
178 sub = vectorZipI Sub
179 mul = vectorZipI Mul
180 equal u v = dim u == dim v && maxElement' (vectorMapI Abs (sub u v)) == 0
181 scalar' x = fromList [x]
182 konst' = constantD
183 build' = buildV
184 cmap' = mapVector
185 atIndex' = (@>)
186 minIndex' = emptyErrorV "minIndex" (fromIntegral . toScalarI MinIdx)
187 maxIndex' = emptyErrorV "maxIndex" (fromIntegral . toScalarI MaxIdx)
188 minElement' = emptyErrorV "minElement" (toScalarI Min)
189 maxElement' = emptyErrorV "maxElement" (toScalarI Max)
190 sumElements' = sumI
191 prodElements' = prodI
192 step' = stepI
193 find' = findV
194 assoc' = assocV
195 accum' = accumV
196 ccompare' = compareCV compareV
197 cselect' = selectCV selectV
198 scaleRecip = undefined -- cannot match
199 divide = undefined
200 arctan2' = undefined
201 cmod' m x
202 | m /= 0 = vectorMapValI ModVS m x
203 | otherwise = error $ "cmod 0 on vector of size "++(show $ dim x)
204 fromInt' = id
205 toInt' = id
206
207instance Container Vector Float
208 where
209 conj' = id
210 size' = dim
211 scale' = vectorMapValF Scale
212 addConstant = vectorMapValF AddConstant
213 add = vectorZipF Add
214 sub = vectorZipF Sub
215 mul = vectorZipF Mul
216 equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0
217 scalar' x = fromList [x]
218 konst' = constantD
219 build' = buildV
220 cmap' = mapVector
221 atIndex' = (@>)
222 minIndex' = emptyErrorV "minIndex" (round . toScalarF MinIdx)
223 maxIndex' = emptyErrorV "maxIndex" (round . toScalarF MaxIdx)
224 minElement' = emptyErrorV "minElement" (toScalarF Min)
225 maxElement' = emptyErrorV "maxElement" (toScalarF Max)
226 sumElements' = sumF
227 prodElements' = prodF
228 step' = stepF
229 find' = findV
230 assoc' = assocV
231 accum' = accumV
232 ccompare' = compareCV compareV
233 cselect' = selectCV selectV
234 scaleRecip = vectorMapValF Recip
235 divide = vectorZipF Div
236 arctan2' = vectorZipF ATan2
237 cmod' = undefined
238 fromInt' = int2floatV
239 toInt' = float2IntV
240
241
242
243instance Container Vector Double
244 where
245 conj' = id
246 size' = dim
247 scale' = vectorMapValR Scale
248 addConstant = vectorMapValR AddConstant
249 add = vectorZipR Add
250 sub = vectorZipR Sub
251 mul = vectorZipR Mul
252 equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0
253 scalar' x = fromList [x]
254 konst' = constantD
255 build' = buildV
256 cmap' = mapVector
257 atIndex' = (@>)
258 minIndex' = emptyErrorV "minIndex" (round . toScalarR MinIdx)
259 maxIndex' = emptyErrorV "maxIndex" (round . toScalarR MaxIdx)
260 minElement' = emptyErrorV "minElement" (toScalarR Min)
261 maxElement' = emptyErrorV "maxElement" (toScalarR Max)
262 sumElements' = sumR
263 prodElements' = prodR
264 step' = stepD
265 find' = findV
266 assoc' = assocV
267 accum' = accumV
268 ccompare' = compareCV compareV
269 cselect' = selectCV selectV
270 scaleRecip = vectorMapValR Recip
271 divide = vectorZipR Div
272 arctan2' = vectorZipR ATan2
273 cmod' = undefined
274 fromInt' = int2DoubleV
275 toInt' = double2IntV
276
277
278instance Container Vector (Complex Double)
279 where
280 conj' = conjugateC
281 size' = dim
282 scale' = vectorMapValC Scale
283 addConstant = vectorMapValC AddConstant
284 add = vectorZipC Add
285 sub = vectorZipC Sub
286 mul = vectorZipC Mul
287 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
288 scalar' x = fromList [x]
289 konst' = constantD
290 build' = buildV
291 cmap' = mapVector
292 atIndex' = (@>)
293 minIndex' = emptyErrorV "minIndex" (minIndex' . fst . fromComplex . (mul <*> conj'))
294 maxIndex' = emptyErrorV "maxIndex" (maxIndex' . fst . fromComplex . (mul <*> conj'))
295 minElement' = emptyErrorV "minElement" (atIndex' <*> minIndex')
296 maxElement' = emptyErrorV "maxElement" (atIndex' <*> maxIndex')
297 sumElements' = sumC
298 prodElements' = prodC
299 step' = undefined -- cannot match
300 find' = findV
301 assoc' = assocV
302 accum' = accumV
303 ccompare' = undefined -- cannot match
304 cselect' = selectCV selectV
305 scaleRecip = vectorMapValC Recip
306 divide = vectorZipC Div
307 arctan2' = vectorZipC ATan2
308 cmod' = undefined
309 fromInt' = complex . int2DoubleV
310 toInt' = toInt' . fst . fromComplex
311
312instance Container Vector (Complex Float)
313 where
314 conj' = conjugateQ
315 size' = dim
316 scale' = vectorMapValQ Scale
317 addConstant = vectorMapValQ AddConstant
318 add = vectorZipQ Add
319 sub = vectorZipQ Sub
320 mul = vectorZipQ Mul
321 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
322 scalar' x = fromList [x]
323 konst' = constantD
324 build' = buildV
325 cmap' = mapVector
326 atIndex' = (@>)
327 minIndex' = emptyErrorV "minIndex" (minIndex' . fst . fromComplex . (mul <*> conj'))
328 maxIndex' = emptyErrorV "maxIndex" (maxIndex' . fst . fromComplex . (mul <*> conj'))
329 minElement' = emptyErrorV "minElement" (atIndex' <*> minIndex')
330 maxElement' = emptyErrorV "maxElement" (atIndex' <*> maxIndex')
331 sumElements' = sumQ
332 prodElements' = prodQ
333 step' = undefined -- cannot match
334 find' = findV
335 assoc' = assocV
336 accum' = accumV
337 ccompare' = undefined -- cannot match
338 cselect' = selectCV selectV
339 scaleRecip = vectorMapValQ Recip
340 divide = vectorZipQ Div
341 arctan2' = vectorZipQ ATan2
342 cmod' = undefined
343 fromInt' = complex . int2floatV
344 toInt' = toInt' . fst . fromComplex
345
346---------------------------------------------------------------
347
348instance (Num a, Element a, Container Vector a) => Container Matrix a
349 where
350 conj' = liftMatrix conj'
351 size' = size
352 scale' x = liftMatrix (scale' x)
353 addConstant x = liftMatrix (addConstant x)
354 add = liftMatrix2 add
355 sub = liftMatrix2 sub
356 mul = liftMatrix2 mul
357 equal a b = cols a == cols b && flatten a `equal` flatten b
358 scalar' x = (1><1) [x]
359 konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c))
360 build' = buildM
361 cmap' f = liftMatrix (mapVector f)
362 atIndex' = (@@>)
363 minIndex' = emptyErrorM "minIndex of Matrix" $
364 \m -> divMod (minIndex' $ flatten m) (cols m)
365 maxIndex' = emptyErrorM "maxIndex of Matrix" $
366 \m -> divMod (maxIndex' $ flatten m) (cols m)
367 minElement' = emptyErrorM "minElement of Matrix" (atIndex' <*> minIndex')
368 maxElement' = emptyErrorM "maxElement of Matrix" (atIndex' <*> maxIndex')
369 sumElements' = sumElements' . flatten
370 prodElements' = prodElements' . flatten
371 step' = liftMatrix step'
372 find' = findM
373 assoc' = assocM
374 accum' = accumM
375 ccompare' = compareM
376 cselect' = selectM
377 scaleRecip x = liftMatrix (scaleRecip x)
378 divide = liftMatrix2 divide
379 arctan2' = liftMatrix2 arctan2'
380 cmod' m x
381 | m /= 0 = liftMatrix (cmod' m) x
382 | otherwise = error $ "cmod 0 on matrix "++shSize x
383 fromInt' = liftMatrix fromInt'
384 toInt' = liftMatrix toInt'
385
386
387emptyErrorV msg f v =
388 if dim v > 0
389 then f v
390 else error $ msg ++ " of empty Vector"
391
392emptyErrorM msg f m =
393 if rows m > 0 && cols m > 0
394 then f m
395 else error $ msg++" "++shSize m
396
397--------------------------------------------------------------------------------
398
399-- | create a structure with a single element
400--
401-- >>> let v = fromList [1..3::Double]
402-- >>> v / scalar (norm2 v)
403-- fromList [0.2672612419124244,0.5345224838248488,0.8017837257372732]
404--
405scalar :: Container c e => e -> c e
406scalar = scalar'
407
408-- | complex conjugate
409conj :: Container c e => c e -> c e
410conj = conj'
411
412-- | multiplication by scalar
413scale :: Container c e => e -> c e -> c e
414scale = scale'
415
416arctan2 :: (Fractional e, Container c e) => c e -> c e -> c e
417arctan2 = arctan2'
418
419-- | 'mod' for integer arrays
420--
421-- >>> cmod 3 (range 5)
422-- fromList [0,1,2,0,1]
423cmod :: (Integral e, Container c e) => Int -> c e -> c e
424cmod m = cmod' (fromIntegral m)
425
426-- |
427-- >>>fromInt ((2><2) [0..3]) :: Matrix (Complex Double)
428-- (2><2)
429-- [ 0.0 :+ 0.0, 1.0 :+ 0.0
430-- , 2.0 :+ 0.0, 3.0 :+ 0.0 ]
431--
432fromInt :: (Container c e) => c I -> c e
433fromInt = fromInt'
434
435toInt :: (Container c e) => c e -> c I
436toInt = toInt'
437
438
439-- | like 'fmap' (cannot implement instance Functor because of Element class constraint)
440cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b
441cmap = cmap'
442
443-- | generic indexing function
444--
445-- >>> vector [1,2,3] `atIndex` 1
446-- 2.0
447--
448-- >>> matrix 3 [0..8] `atIndex` (2,0)
449-- 6.0
450--
451atIndex :: Container c e => c e -> IndexOf c -> e
452atIndex = atIndex'
453
454-- | index of minimum element
455minIndex :: Container c e => c e -> IndexOf c
456minIndex = minIndex'
457
458-- | index of maximum element
459maxIndex :: Container c e => c e -> IndexOf c
460maxIndex = maxIndex'
461
462-- | value of minimum element
463minElement :: Container c e => c e -> e
464minElement = minElement'
465
466-- | value of maximum element
467maxElement :: Container c e => c e -> e
468maxElement = maxElement'
469
470-- | the sum of elements
471sumElements :: Container c e => c e -> e
472sumElements = sumElements'
473
474-- | the product of elements
475prodElements :: Container c e => c e -> e
476prodElements = prodElements'
477
478
479-- | A more efficient implementation of @cmap (\\x -> if x>0 then 1 else 0)@
480--
481-- >>> step $ linspace 5 (-1,1::Double)
482-- 5 |> [0.0,0.0,0.0,1.0,1.0]
483--
484step
485 :: (Ord e, Container c e)
486 => c e
487 -> c e
488step = step'
489
490
491-- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@.
492--
493-- Arguments with any dimension = 1 are automatically expanded:
494--
495-- >>> cond ((1><4)[1..]) ((3><1)[1..]) 0 100 ((3><4)[1..]) :: Matrix Double
496-- (3><4)
497-- [ 100.0, 2.0, 3.0, 4.0
498-- , 0.0, 100.0, 7.0, 8.0
499-- , 0.0, 0.0, 100.0, 12.0 ]
500--
501cond
502 :: (Ord e, Container c e)
503 => c e -- ^ a
504 -> c e -- ^ b
505 -> c e -- ^ l
506 -> c e -- ^ e
507 -> c e -- ^ g
508 -> c e -- ^ result
509cond a b l e g = cselect' (ccompare' a b) l e g
510
511
512-- | Find index of elements which satisfy a predicate
513--
514-- >>> find (>0) (ident 3 :: Matrix Double)
515-- [(0,0),(1,1),(2,2)]
516--
517find
518 :: Container c e
519 => (e -> Bool)
520 -> c e
521 -> [IndexOf c]
522find = find'
523
524
525-- | Create a structure from an association list
526--
527-- >>> assoc 5 0 [(3,7),(1,4)] :: Vector Double
528-- fromList [0.0,4.0,0.0,7.0,0.0]
529--
530-- >>> assoc (2,3) 0 [((0,2),7),((1,0),2*i-3)] :: Matrix (Complex Double)
531-- (2><3)
532-- [ 0.0 :+ 0.0, 0.0 :+ 0.0, 7.0 :+ 0.0
533-- , (-3.0) :+ 2.0, 0.0 :+ 0.0, 0.0 :+ 0.0 ]
534--
535assoc
536 :: Container c e
537 => IndexOf c -- ^ size
538 -> e -- ^ default value
539 -> [(IndexOf c, e)] -- ^ association list
540 -> c e -- ^ result
541assoc = assoc'
542
543
544-- | Modify a structure using an update function
545--
546-- >>> accum (ident 5) (+) [((1,1),5),((0,3),3)] :: Matrix Double
547-- (5><5)
548-- [ 1.0, 0.0, 0.0, 3.0, 0.0
549-- , 0.0, 6.0, 0.0, 0.0, 0.0
550-- , 0.0, 0.0, 1.0, 0.0, 0.0
551-- , 0.0, 0.0, 0.0, 1.0, 0.0
552-- , 0.0, 0.0, 0.0, 0.0, 1.0 ]
553--
554-- computation of histogram:
555--
556-- >>> accum (konst 0 7) (+) (map (flip (,) 1) [4,5,4,1,5,2,5]) :: Vector Double
557-- fromList [0.0,1.0,1.0,0.0,2.0,3.0,0.0]
558--
559accum
560 :: Container c e
561 => c e -- ^ initial structure
562 -> (e -> e -> e) -- ^ update function
563 -> [(IndexOf c, e)] -- ^ association list
564 -> c e -- ^ result
565accum = accum'
566
567
568--------------------------------------------------------------------------------
569
570-- | Matrix product and related functions
571class (Num e, Element e) => Product e where
572 -- | matrix product
573 multiply :: Matrix e -> Matrix e -> Matrix e
574 -- | sum of absolute value of elements (differs in complex case from @norm1@)
575 absSum :: Vector e -> RealOf e
576 -- | sum of absolute value of elements
577 norm1 :: Vector e -> RealOf e
578 -- | euclidean norm
579 norm2 :: Floating e => Vector e -> RealOf e
580 -- | element of maximum magnitude
581 normInf :: Vector e -> RealOf e
582
583instance Product Float where
584 norm2 = emptyVal (toScalarF Norm2)
585 absSum = emptyVal (toScalarF AbsSum)
586 norm1 = emptyVal (toScalarF AbsSum)
587 normInf = emptyVal (maxElement . vectorMapF Abs)
588 multiply = emptyMul multiplyF
589
590instance Product Double where
591 norm2 = emptyVal (toScalarR Norm2)
592 absSum = emptyVal (toScalarR AbsSum)
593 norm1 = emptyVal (toScalarR AbsSum)
594 normInf = emptyVal (maxElement . vectorMapR Abs)
595 multiply = emptyMul multiplyR
596
597instance Product (Complex Float) where
598 norm2 = emptyVal (toScalarQ Norm2)
599 absSum = emptyVal (toScalarQ AbsSum)
600 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs)
601 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs)
602 multiply = emptyMul multiplyQ
603
604instance Product (Complex Double) where
605 norm2 = emptyVal (toScalarC Norm2)
606 absSum = emptyVal (toScalarC AbsSum)
607 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs)
608 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs)
609 multiply = emptyMul multiplyC
610
611instance Product I where
612 norm2 = undefined
613 absSum = emptyVal (sumElements . vectorMapI Abs)
614 norm1 = absSum
615 normInf = emptyVal (maxElement . vectorMapI Abs)
616 multiply = emptyMul multiplyI
617
618
619emptyMul m a b
620 | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c)
621 | otherwise = m a b
622 where
623 r = rows a
624 x1 = cols a
625 x2 = rows b
626 c = cols b
627
628emptyVal f v =
629 if dim v > 0
630 then f v
631 else 0
632
633-- FIXME remove unused C wrappers
634-- | unconjugated dot product
635udot :: Product e => Vector e -> Vector e -> e
636udot u v
637 | dim u == dim v = val (asRow u `multiply` asColumn v)
638 | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product"
639 where
640 val m | dim u > 0 = m@@>(0,0)
641 | otherwise = 0
642
643----------------------------------------------------------
644
645-- synonym for matrix product
646mXm :: Product t => Matrix t -> Matrix t -> Matrix t
647mXm = multiply
648
649-- matrix - vector product
650mXv :: Product t => Matrix t -> Vector t -> Vector t
651mXv m v = flatten $ m `mXm` (asColumn v)
652
653-- vector - matrix product
654vXm :: Product t => Vector t -> Matrix t -> Vector t
655vXm v m = flatten $ (asRow v) `mXm` m
656
657{- | Outer product of two vectors.
658
659>>> fromList [1,2,3] `outer` fromList [5,2,3]
660(3><3)
661 [ 5.0, 2.0, 3.0
662 , 10.0, 4.0, 6.0
663 , 15.0, 6.0, 9.0 ]
664
665-}
666outer :: (Product t) => Vector t -> Vector t -> Matrix t
667outer u v = asColumn u `multiply` asRow v
668
669{- | Kronecker product of two matrices.
670
671@m1=(2><3)
672 [ 1.0, 2.0, 0.0
673 , 0.0, -1.0, 3.0 ]
674m2=(4><3)
675 [ 1.0, 2.0, 3.0
676 , 4.0, 5.0, 6.0
677 , 7.0, 8.0, 9.0
678 , 10.0, 11.0, 12.0 ]@
679
680>>> kronecker m1 m2
681(8><9)
682 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
683 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
684 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
685 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
686 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
687 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
688 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
689 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]
690
691-}
692kronecker :: (Product t) => Matrix t -> Matrix t -> Matrix t
693kronecker a b = fromBlocks
694 . splitEvery (cols a)
695 . map (reshape (cols b))
696 . toRows
697 $ flatten a `outer` flatten b
698
699-------------------------------------------------------------------
700
701
702class Convert t where
703 real :: Complexable c => c (RealOf t) -> c t
704 complex :: Complexable c => c t -> c (ComplexOf t)
705 single :: Complexable c => c t -> c (SingleOf t)
706 double :: Complexable c => c t -> c (DoubleOf t)
707 toComplex :: (Complexable c, RealElement t) => (c t, c t) -> c (Complex t)
708 fromComplex :: (Complexable c, RealElement t) => c (Complex t) -> (c t, c t)
709
710
711instance Convert Double where
712 real = id
713 complex = comp'
714 single = single'
715 double = id
716 toComplex = toComplex'
717 fromComplex = fromComplex'
718
719instance Convert Float where
720 real = id
721 complex = comp'
722 single = id
723 double = double'
724 toComplex = toComplex'
725 fromComplex = fromComplex'
726
727instance Convert (Complex Double) where
728 real = comp'
729 complex = id
730 single = single'
731 double = id
732 toComplex = toComplex'
733 fromComplex = fromComplex'
734
735instance Convert (Complex Float) where
736 real = comp'
737 complex = id
738 single = id
739 double = double'
740 toComplex = toComplex'
741 fromComplex = fromComplex'
742
743-------------------------------------------------------------------
744
745type family RealOf x
746
747type instance RealOf Double = Double
748type instance RealOf (Complex Double) = Double
749
750type instance RealOf Float = Float
751type instance RealOf (Complex Float) = Float
752
753type instance RealOf I = I
754
755type family ComplexOf x
756
757type instance ComplexOf Double = Complex Double
758type instance ComplexOf (Complex Double) = Complex Double
759
760type instance ComplexOf Float = Complex Float
761type instance ComplexOf (Complex Float) = Complex Float
762
763type family SingleOf x
764
765type instance SingleOf Double = Float
766type instance SingleOf Float = Float
767
768type instance SingleOf (Complex a) = Complex (SingleOf a)
769
770type family DoubleOf x
771
772type instance DoubleOf Double = Double
773type instance DoubleOf Float = Double
774
775type instance DoubleOf (Complex a) = Complex (DoubleOf a)
776
777type family ElementOf c
778
779type instance ElementOf (Vector a) = a
780type instance ElementOf (Matrix a) = a
781
782------------------------------------------------------------
783
784buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ]
785 where rs = map fromIntegral [0 .. (rc-1)]
786 cs = map fromIntegral [0 .. (cc-1)]
787
788buildV n f = fromList [f k | k <- ks]
789 where ks = map fromIntegral [0 .. (n-1)]
790
791--------------------------------------------------------
792-- | conjugate transpose
793ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e
794ctrans = liftMatrix conj' . trans
795
796-- | Creates a square matrix with a given diagonal.
797diag :: (Num a, Element a) => Vector a -> Matrix a
798diag v = diagRect 0 v n n where n = dim v
799
800-- | creates the identity matrix of given dimension
801ident :: (Num a, Element a) => Int -> Matrix a
802ident n = diag (constantD 1 n)
803
804--------------------------------------------------------
805
806findV p x = foldVectorWithIndex g [] x where
807 g k z l = if p z then k:l else l
808
809findM p x = map ((`divMod` cols x)) $ findV p (flatten x)
810
811assocV n z xs = ST.runSTVector $ do
812 v <- ST.newVector z n
813 mapM_ (\(k,x) -> ST.writeVector v k x) xs
814 return v
815
816assocM (r,c) z xs = ST.runSTMatrix $ do
817 m <- ST.newMatrix z r c
818 mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs
819 return m
820
821accumV v0 f xs = ST.runSTVector $ do
822 v <- ST.thawVector v0
823 mapM_ (\(k,x) -> ST.modifyVector v k (f x)) xs
824 return v
825
826accumM m0 f xs = ST.runSTMatrix $ do
827 m <- ST.thawMatrix m0
828 mapM_ (\((i,j),x) -> ST.modifyMatrix m i j (f x)) xs
829 return m
830
831----------------------------------------------------------------------
832
833compareM a b = matrixFromVector RowMajor (rows a'') (cols a'') $ ccompare' a' b'
834 where
835 args@(a'':_) = conformMs [a,b]
836 [a', b'] = map flatten args
837
838compareCV f a b = f a' b'
839 where
840 [a', b'] = conformVs [a,b]
841
842selectM c l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cselect' (toInt c') l' e' t'
843 where
844 args@(a'':_) = conformMs [fromInt c,l,e,t]
845 [c', l', e', t'] = map flatten args
846
847selectCV f c l e t = f (toInt c') l' e' t'
848 where
849 [c', l', e', t'] = conformVs [fromInt c,l,e,t]
850
851--------------------------------------------------------------------------------
852
853class Transposable m mt | m -> mt, mt -> m
854 where
855 -- | conjugate transpose
856 tr :: m -> mt
857 -- | transpose
858 tr' :: m -> mt
859
860instance (Container Vector t) => Transposable (Matrix t) (Matrix t)
861 where
862 tr = ctrans
863 tr' = trans
864
865class Linear t v
866 where
867 scalarL :: t -> v
868 addL :: v -> v -> v
869 scaleL :: t -> v -> v
870
871
872class Testable t
873 where
874 checkT :: t -> (Bool, IO())
875 ioCheckT :: t -> IO (Bool, IO())
876 ioCheckT = return . checkT
877
878--------------------------------------------------------------------------------
879