summaryrefslogtreecommitdiff
path: root/packages/base/src/Data/Packed/Internal/Numeric.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Data/Packed/Internal/Numeric.hs')
-rw-r--r--packages/base/src/Data/Packed/Internal/Numeric.hs607
1 files changed, 607 insertions, 0 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs
new file mode 100644
index 0000000..81a8083
--- /dev/null
+++ b/packages/base/src/Data/Packed/Internal/Numeric.hs
@@ -0,0 +1,607 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE TypeFamilies #-}
3{-# LANGUAGE FlexibleContexts #-}
4{-# LANGUAGE FlexibleInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-}
6{-# LANGUAGE UndecidableInstances #-}
7
8-----------------------------------------------------------------------------
9-- |
10-- Module : Data.Packed.Internal.Numeric
11-- Copyright : (c) Alberto Ruiz 2010-14
12-- License : BSD3
13-- Maintainer : Alberto Ruiz
14-- Stability : provisional
15--
16-----------------------------------------------------------------------------
17
18module Data.Packed.Internal.Numeric (
19 -- * Basic functions
20 ident, diag, ctrans,
21 -- * Generic operations
22 Container(..),
23 -- * Matrix product and related functions
24 Product(..), udot,
25 mXm,mXv,vXm,
26 outer, kronecker,
27 -- * Element conversion
28 Convert(..),
29 Complexable(),
30 RealElement(),
31
32 RealOf, ComplexOf, SingleOf, DoubleOf,
33
34 IndexOf,
35 module Data.Complex
36) where
37
38import Data.Packed
39import Data.Packed.ST as ST
40import Numeric.Conversion
41import Data.Packed.Development
42import Numeric.Vectorized
43import Data.Complex
44import Control.Applicative((<*>))
45
46import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ)
47
48-------------------------------------------------------------------
49
50type family IndexOf (c :: * -> *)
51
52type instance IndexOf Vector = Int
53type instance IndexOf Matrix = (Int,Int)
54
55type family ArgOf (c :: * -> *) a
56
57type instance ArgOf Vector a = a -> a
58type instance ArgOf Matrix a = a -> a -> a
59
60-------------------------------------------------------------------
61
62-- | Basic element-by-element functions for numeric containers
63class (Complexable c, Fractional e, Element e) => Container c e where
64 -- | create a structure with a single element
65 --
66 -- >>> let v = fromList [1..3::Double]
67 -- >>> v / scalar (norm2 v)
68 -- fromList [0.2672612419124244,0.5345224838248488,0.8017837257372732]
69 --
70 scalar :: e -> c e
71 -- | complex conjugate
72 conj :: c e -> c e
73 scale :: e -> c e -> c e
74 -- | scale the element by element reciprocal of the object:
75 --
76 -- @scaleRecip 2 (fromList [5,i]) == 2 |> [0.4 :+ 0.0,0.0 :+ (-2.0)]@
77 scaleRecip :: e -> c e -> c e
78 addConstant :: e -> c e -> c e
79 add :: c e -> c e -> c e
80 sub :: c e -> c e -> c e
81 -- | element by element multiplication
82 mul :: c e -> c e -> c e
83 -- | element by element division
84 divide :: c e -> c e -> c e
85 equal :: c e -> c e -> Bool
86 --
87 -- element by element inverse tangent
88 arctan2 :: c e -> c e -> c e
89 --
90 -- | cannot implement instance Functor because of Element class constraint
91 cmap :: (Element b) => (e -> b) -> c e -> c b
92 -- | constant structure of given size
93 konst' :: e -> IndexOf c -> c e
94 -- | create a structure using a function
95 --
96 -- Hilbert matrix of order N:
97 --
98 -- @hilb n = build' (n,n) (\\i j -> 1/(i+j+1))@
99 build' :: IndexOf c -> (ArgOf c e) -> c e
100 -- | indexing function
101 atIndex :: c e -> IndexOf c -> e
102 -- | index of min element
103 minIndex :: c e -> IndexOf c
104 -- | index of max element
105 maxIndex :: c e -> IndexOf c
106 -- | value of min element
107 minElement :: c e -> e
108 -- | value of max element
109 maxElement :: c e -> e
110 -- the C functions sumX/prodX are twice as fast as using foldVector
111 -- | the sum of elements (faster than using @fold@)
112 sumElements :: c e -> e
113 -- | the product of elements (faster than using @fold@)
114 prodElements :: c e -> e
115
116 -- | A more efficient implementation of @cmap (\\x -> if x>0 then 1 else 0)@
117 --
118 -- >>> step $ linspace 5 (-1,1::Double)
119 -- 5 |> [0.0,0.0,0.0,1.0,1.0]
120 --
121
122 step :: RealElement e => c e -> c e
123
124 -- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@.
125 --
126 -- Arguments with any dimension = 1 are automatically expanded:
127 --
128 -- >>> cond ((1><4)[1..]) ((3><1)[1..]) 0 100 ((3><4)[1..]) :: Matrix Double
129 -- (3><4)
130 -- [ 100.0, 2.0, 3.0, 4.0
131 -- , 0.0, 100.0, 7.0, 8.0
132 -- , 0.0, 0.0, 100.0, 12.0 ]
133 --
134
135 cond :: RealElement e
136 => c e -- ^ a
137 -> c e -- ^ b
138 -> c e -- ^ l
139 -> c e -- ^ e
140 -> c e -- ^ g
141 -> c e -- ^ result
142
143 -- | Find index of elements which satisfy a predicate
144 --
145 -- >>> find (>0) (ident 3 :: Matrix Double)
146 -- [(0,0),(1,1),(2,2)]
147 --
148
149 find :: (e -> Bool) -> c e -> [IndexOf c]
150
151 -- | Create a structure from an association list
152 --
153 -- >>> assoc 5 0 [(3,7),(1,4)] :: Vector Double
154 -- fromList [0.0,4.0,0.0,7.0,0.0]
155 --
156 -- >>> assoc (2,3) 0 [((0,2),7),((1,0),2*i-3)] :: Matrix (Complex Double)
157 -- (2><3)
158 -- [ 0.0 :+ 0.0, 0.0 :+ 0.0, 7.0 :+ 0.0
159 -- , (-3.0) :+ 2.0, 0.0 :+ 0.0, 0.0 :+ 0.0 ]
160 --
161 assoc :: IndexOf c -- ^ size
162 -> e -- ^ default value
163 -> [(IndexOf c, e)] -- ^ association list
164 -> c e -- ^ result
165
166 -- | Modify a structure using an update function
167 --
168 -- >>> accum (ident 5) (+) [((1,1),5),((0,3),3)] :: Matrix Double
169 -- (5><5)
170 -- [ 1.0, 0.0, 0.0, 3.0, 0.0
171 -- , 0.0, 6.0, 0.0, 0.0, 0.0
172 -- , 0.0, 0.0, 1.0, 0.0, 0.0
173 -- , 0.0, 0.0, 0.0, 1.0, 0.0
174 -- , 0.0, 0.0, 0.0, 0.0, 1.0 ]
175 --
176 -- computation of histogram:
177 --
178 -- >>> accum (konst 0 7) (+) (map (flip (,) 1) [4,5,4,1,5,2,5]) :: Vector Double
179 -- fromList [0.0,1.0,1.0,0.0,2.0,3.0,0.0]
180 --
181
182 accum :: c e -- ^ initial structure
183 -> (e -> e -> e) -- ^ update function
184 -> [(IndexOf c, e)] -- ^ association list
185 -> c e -- ^ result
186
187--------------------------------------------------------------------------
188
189instance Container Vector Float where
190 scale = vectorMapValF Scale
191 scaleRecip = vectorMapValF Recip
192 addConstant = vectorMapValF AddConstant
193 add = vectorZipF Add
194 sub = vectorZipF Sub
195 mul = vectorZipF Mul
196 divide = vectorZipF Div
197 equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0
198 arctan2 = vectorZipF ATan2
199 scalar x = fromList [x]
200 konst' = constant
201 build' = buildV
202 conj = id
203 cmap = mapVector
204 atIndex = (@>)
205 minIndex = emptyErrorV "minIndex" (round . toScalarF MinIdx)
206 maxIndex = emptyErrorV "maxIndex" (round . toScalarF MaxIdx)
207 minElement = emptyErrorV "minElement" (toScalarF Min)
208 maxElement = emptyErrorV "maxElement" (toScalarF Max)
209 sumElements = sumF
210 prodElements = prodF
211 step = stepF
212 find = findV
213 assoc = assocV
214 accum = accumV
215 cond = condV condF
216
217instance Container Vector Double where
218 scale = vectorMapValR Scale
219 scaleRecip = vectorMapValR Recip
220 addConstant = vectorMapValR AddConstant
221 add = vectorZipR Add
222 sub = vectorZipR Sub
223 mul = vectorZipR Mul
224 divide = vectorZipR Div
225 equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0
226 arctan2 = vectorZipR ATan2
227 scalar x = fromList [x]
228 konst' = constant
229 build' = buildV
230 conj = id
231 cmap = mapVector
232 atIndex = (@>)
233 minIndex = emptyErrorV "minIndex" (round . toScalarR MinIdx)
234 maxIndex = emptyErrorV "maxIndex" (round . toScalarR MaxIdx)
235 minElement = emptyErrorV "minElement" (toScalarR Min)
236 maxElement = emptyErrorV "maxElement" (toScalarR Max)
237 sumElements = sumR
238 prodElements = prodR
239 step = stepD
240 find = findV
241 assoc = assocV
242 accum = accumV
243 cond = condV condD
244
245instance Container Vector (Complex Double) where
246 scale = vectorMapValC Scale
247 scaleRecip = vectorMapValC Recip
248 addConstant = vectorMapValC AddConstant
249 add = vectorZipC Add
250 sub = vectorZipC Sub
251 mul = vectorZipC Mul
252 divide = vectorZipC Div
253 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
254 arctan2 = vectorZipC ATan2
255 scalar x = fromList [x]
256 konst' = constant
257 build' = buildV
258 conj = conjugateC
259 cmap = mapVector
260 atIndex = (@>)
261 minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj))
262 maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj))
263 minElement = emptyErrorV "minElement" (atIndex <*> minIndex)
264 maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex)
265 sumElements = sumC
266 prodElements = prodC
267 step = undefined -- cannot match
268 find = findV
269 assoc = assocV
270 accum = accumV
271 cond = undefined -- cannot match
272
273instance Container Vector (Complex Float) where
274 scale = vectorMapValQ Scale
275 scaleRecip = vectorMapValQ Recip
276 addConstant = vectorMapValQ AddConstant
277 add = vectorZipQ Add
278 sub = vectorZipQ Sub
279 mul = vectorZipQ Mul
280 divide = vectorZipQ Div
281 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
282 arctan2 = vectorZipQ ATan2
283 scalar x = fromList [x]
284 konst' = constant
285 build' = buildV
286 conj = conjugateQ
287 cmap = mapVector
288 atIndex = (@>)
289 minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj))
290 maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj))
291 minElement = emptyErrorV "minElement" (atIndex <*> minIndex)
292 maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex)
293 sumElements = sumQ
294 prodElements = prodQ
295 step = undefined -- cannot match
296 find = findV
297 assoc = assocV
298 accum = accumV
299 cond = undefined -- cannot match
300
301---------------------------------------------------------------
302
303instance (Container Vector a) => Container Matrix a where
304 scale x = liftMatrix (scale x)
305 scaleRecip x = liftMatrix (scaleRecip x)
306 addConstant x = liftMatrix (addConstant x)
307 add = liftMatrix2 add
308 sub = liftMatrix2 sub
309 mul = liftMatrix2 mul
310 divide = liftMatrix2 divide
311 equal a b = cols a == cols b && flatten a `equal` flatten b
312 arctan2 = liftMatrix2 arctan2
313 scalar x = (1><1) [x]
314 konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c))
315 build' = buildM
316 conj = liftMatrix conj
317 cmap f = liftMatrix (mapVector f)
318 atIndex = (@@>)
319 minIndex = emptyErrorM "minIndex of Matrix" $
320 \m -> divMod (minIndex $ flatten m) (cols m)
321 maxIndex = emptyErrorM "maxIndex of Matrix" $
322 \m -> divMod (maxIndex $ flatten m) (cols m)
323 minElement = emptyErrorM "minElement of Matrix" (atIndex <*> minIndex)
324 maxElement = emptyErrorM "maxElement of Matrix" (atIndex <*> maxIndex)
325 sumElements = sumElements . flatten
326 prodElements = prodElements . flatten
327 step = liftMatrix step
328 find = findM
329 assoc = assocM
330 accum = accumM
331 cond = condM
332
333
334emptyErrorV msg f v =
335 if dim v > 0
336 then f v
337 else error $ msg ++ " of Vector with dim = 0"
338
339emptyErrorM msg f m =
340 if rows m > 0 && cols m > 0
341 then f m
342 else error $ msg++" "++shSize m
343
344----------------------------------------------------
345
346-- | Matrix product and related functions
347class (Num e, Element e) => Product e where
348 -- | matrix product
349 multiply :: Matrix e -> Matrix e -> Matrix e
350 -- | sum of absolute value of elements (differs in complex case from @norm1@)
351 absSum :: Vector e -> RealOf e
352 -- | sum of absolute value of elements
353 norm1 :: Vector e -> RealOf e
354 -- | euclidean norm
355 norm2 :: Vector e -> RealOf e
356 -- | element of maximum magnitude
357 normInf :: Vector e -> RealOf e
358
359instance Product Float where
360 norm2 = emptyVal (toScalarF Norm2)
361 absSum = emptyVal (toScalarF AbsSum)
362 norm1 = emptyVal (toScalarF AbsSum)
363 normInf = emptyVal (maxElement . vectorMapF Abs)
364 multiply = emptyMul multiplyF
365
366instance Product Double where
367 norm2 = emptyVal (toScalarR Norm2)
368 absSum = emptyVal (toScalarR AbsSum)
369 norm1 = emptyVal (toScalarR AbsSum)
370 normInf = emptyVal (maxElement . vectorMapR Abs)
371 multiply = emptyMul multiplyR
372
373instance Product (Complex Float) where
374 norm2 = emptyVal (toScalarQ Norm2)
375 absSum = emptyVal (toScalarQ AbsSum)
376 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs)
377 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs)
378 multiply = emptyMul multiplyQ
379
380instance Product (Complex Double) where
381 norm2 = emptyVal (toScalarC Norm2)
382 absSum = emptyVal (toScalarC AbsSum)
383 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs)
384 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs)
385 multiply = emptyMul multiplyC
386
387emptyMul m a b
388 | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c)
389 | otherwise = m a b
390 where
391 r = rows a
392 x1 = cols a
393 x2 = rows b
394 c = cols b
395
396emptyVal f v =
397 if dim v > 0
398 then f v
399 else 0
400
401-- FIXME remove unused C wrappers
402-- | unconjugated dot product
403udot :: Product e => Vector e -> Vector e -> e
404udot u v
405 | dim u == dim v = val (asRow u `multiply` asColumn v)
406 | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product"
407 where
408 val m | dim u > 0 = m@@>(0,0)
409 | otherwise = 0
410
411----------------------------------------------------------
412
413-- synonym for matrix product
414mXm :: Product t => Matrix t -> Matrix t -> Matrix t
415mXm = multiply
416
417-- matrix - vector product
418mXv :: Product t => Matrix t -> Vector t -> Vector t
419mXv m v = flatten $ m `mXm` (asColumn v)
420
421-- vector - matrix product
422vXm :: Product t => Vector t -> Matrix t -> Vector t
423vXm v m = flatten $ (asRow v) `mXm` m
424
425{- | Outer product of two vectors.
426
427>>> fromList [1,2,3] `outer` fromList [5,2,3]
428(3><3)
429 [ 5.0, 2.0, 3.0
430 , 10.0, 4.0, 6.0
431 , 15.0, 6.0, 9.0 ]
432
433-}
434outer :: (Product t) => Vector t -> Vector t -> Matrix t
435outer u v = asColumn u `multiply` asRow v
436
437{- | Kronecker product of two matrices.
438
439@m1=(2><3)
440 [ 1.0, 2.0, 0.0
441 , 0.0, -1.0, 3.0 ]
442m2=(4><3)
443 [ 1.0, 2.0, 3.0
444 , 4.0, 5.0, 6.0
445 , 7.0, 8.0, 9.0
446 , 10.0, 11.0, 12.0 ]@
447
448>>> kronecker m1 m2
449(8><9)
450 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
451 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
452 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
453 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
454 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
455 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
456 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
457 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]
458
459-}
460kronecker :: (Product t) => Matrix t -> Matrix t -> Matrix t
461kronecker a b = fromBlocks
462 . splitEvery (cols a)
463 . map (reshape (cols b))
464 . toRows
465 $ flatten a `outer` flatten b
466
467-------------------------------------------------------------------
468
469
470class Convert t where
471 real :: Container c t => c (RealOf t) -> c t
472 complex :: Container c t => c t -> c (ComplexOf t)
473 single :: Container c t => c t -> c (SingleOf t)
474 double :: Container c t => c t -> c (DoubleOf t)
475 toComplex :: (Container c t, RealElement t) => (c t, c t) -> c (Complex t)
476 fromComplex :: (Container c t, RealElement t) => c (Complex t) -> (c t, c t)
477
478
479instance Convert Double where
480 real = id
481 complex = comp'
482 single = single'
483 double = id
484 toComplex = toComplex'
485 fromComplex = fromComplex'
486
487instance Convert Float where
488 real = id
489 complex = comp'
490 single = id
491 double = double'
492 toComplex = toComplex'
493 fromComplex = fromComplex'
494
495instance Convert (Complex Double) where
496 real = comp'
497 complex = id
498 single = single'
499 double = id
500 toComplex = toComplex'
501 fromComplex = fromComplex'
502
503instance Convert (Complex Float) where
504 real = comp'
505 complex = id
506 single = id
507 double = double'
508 toComplex = toComplex'
509 fromComplex = fromComplex'
510
511-------------------------------------------------------------------
512
513type family RealOf x
514
515type instance RealOf Double = Double
516type instance RealOf (Complex Double) = Double
517
518type instance RealOf Float = Float
519type instance RealOf (Complex Float) = Float
520
521type family ComplexOf x
522
523type instance ComplexOf Double = Complex Double
524type instance ComplexOf (Complex Double) = Complex Double
525
526type instance ComplexOf Float = Complex Float
527type instance ComplexOf (Complex Float) = Complex Float
528
529type family SingleOf x
530
531type instance SingleOf Double = Float
532type instance SingleOf Float = Float
533
534type instance SingleOf (Complex a) = Complex (SingleOf a)
535
536type family DoubleOf x
537
538type instance DoubleOf Double = Double
539type instance DoubleOf Float = Double
540
541type instance DoubleOf (Complex a) = Complex (DoubleOf a)
542
543type family ElementOf c
544
545type instance ElementOf (Vector a) = a
546type instance ElementOf (Matrix a) = a
547
548------------------------------------------------------------
549
550buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ]
551 where rs = map fromIntegral [0 .. (rc-1)]
552 cs = map fromIntegral [0 .. (cc-1)]
553
554buildV n f = fromList [f k | k <- ks]
555 where ks = map fromIntegral [0 .. (n-1)]
556
557--------------------------------------------------------
558-- | conjugate transpose
559ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e
560ctrans = liftMatrix conj . trans
561
562-- | Creates a square matrix with a given diagonal.
563diag :: (Num a, Element a) => Vector a -> Matrix a
564diag v = diagRect 0 v n n where n = dim v
565
566-- | creates the identity matrix of given dimension
567ident :: (Num a, Element a) => Int -> Matrix a
568ident n = diag (constant 1 n)
569
570--------------------------------------------------------
571
572findV p x = foldVectorWithIndex g [] x where
573 g k z l = if p z then k:l else l
574
575findM p x = map ((`divMod` cols x)) $ findV p (flatten x)
576
577assocV n z xs = ST.runSTVector $ do
578 v <- ST.newVector z n
579 mapM_ (\(k,x) -> ST.writeVector v k x) xs
580 return v
581
582assocM (r,c) z xs = ST.runSTMatrix $ do
583 m <- ST.newMatrix z r c
584 mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs
585 return m
586
587accumV v0 f xs = ST.runSTVector $ do
588 v <- ST.thawVector v0
589 mapM_ (\(k,x) -> ST.modifyVector v k (f x)) xs
590 return v
591
592accumM m0 f xs = ST.runSTMatrix $ do
593 m <- ST.thawMatrix m0
594 mapM_ (\((i,j),x) -> ST.modifyMatrix m i j (f x)) xs
595 return m
596
597----------------------------------------------------------------------
598
599condM a b l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cond a' b' l' e' t'
600 where
601 args@(a'':_) = conformMs [a,b,l,e,t]
602 [a', b', l', e', t'] = map flatten args
603
604condV f a b l e t = f a' b' l' e' t'
605 where
606 [a', b', l', e', t'] = conformVs [a,b,l,e,t]
607