summaryrefslogtreecommitdiff
path: root/packages/base/src/Data/Packed
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Data/Packed')
-rw-r--r--packages/base/src/Data/Packed/Numeric.hs608
1 files changed, 608 insertions, 0 deletions
diff --git a/packages/base/src/Data/Packed/Numeric.hs b/packages/base/src/Data/Packed/Numeric.hs
new file mode 100644
index 0000000..4892089
--- /dev/null
+++ b/packages/base/src/Data/Packed/Numeric.hs
@@ -0,0 +1,608 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE TypeFamilies #-}
3{-# LANGUAGE FlexibleContexts #-}
4{-# LANGUAGE FlexibleInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-}
6{-# LANGUAGE UndecidableInstances #-}
7
8-----------------------------------------------------------------------------
9-- |
10-- Module : Data.Packed.Numeric
11-- Copyright : (c) Alberto Ruiz 2010-14
12-- License : BSD3
13--
14-- Maintainer : Alberto Ruiz
15-- Stability : provisional
16--
17-----------------------------------------------------------------------------
18
19module Data.Packed.Numeric (
20 -- * Basic functions
21 ident, diag, ctrans,
22 -- * Generic operations
23 Container(..),
24 -- * Matrix product and related functions
25 Product(..), udot,
26 mXm,mXv,vXm,
27 outer, kronecker,
28 -- * Element conversion
29 Convert(..),
30 Complexable(),
31 RealElement(),
32
33 RealOf, ComplexOf, SingleOf, DoubleOf,
34
35 IndexOf,
36 module Data.Complex
37) where
38
39import Data.Packed
40import Data.Packed.ST as ST
41import Numeric.Conversion
42import Data.Packed.Development
43import Numeric.Vectorized
44import Data.Complex
45import Control.Applicative((<*>))
46
47import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ)
48
49-------------------------------------------------------------------
50
51type family IndexOf (c :: * -> *)
52
53type instance IndexOf Vector = Int
54type instance IndexOf Matrix = (Int,Int)
55
56type family ArgOf (c :: * -> *) a
57
58type instance ArgOf Vector a = a -> a
59type instance ArgOf Matrix a = a -> a -> a
60
61-------------------------------------------------------------------
62
63-- | Basic element-by-element functions for numeric containers
64class (Complexable c, Fractional e, Element e) => Container c e where
65 -- | create a structure with a single element
66 --
67 -- >>> let v = fromList [1..3::Double]
68 -- >>> v / scalar (norm2 v)
69 -- fromList [0.2672612419124244,0.5345224838248488,0.8017837257372732]
70 --
71 scalar :: e -> c e
72 -- | complex conjugate
73 conj :: c e -> c e
74 scale :: e -> c e -> c e
75 -- | scale the element by element reciprocal of the object:
76 --
77 -- @scaleRecip 2 (fromList [5,i]) == 2 |> [0.4 :+ 0.0,0.0 :+ (-2.0)]@
78 scaleRecip :: e -> c e -> c e
79 addConstant :: e -> c e -> c e
80 add :: c e -> c e -> c e
81 sub :: c e -> c e -> c e
82 -- | element by element multiplication
83 mul :: c e -> c e -> c e
84 -- | element by element division
85 divide :: c e -> c e -> c e
86 equal :: c e -> c e -> Bool
87 --
88 -- element by element inverse tangent
89 arctan2 :: c e -> c e -> c e
90 --
91 -- | cannot implement instance Functor because of Element class constraint
92 cmap :: (Element b) => (e -> b) -> c e -> c b
93 -- | constant structure of given size
94 konst' :: e -> IndexOf c -> c e
95 -- | create a structure using a function
96 --
97 -- Hilbert matrix of order N:
98 --
99 -- @hilb n = build' (n,n) (\\i j -> 1/(i+j+1))@
100 build' :: IndexOf c -> (ArgOf c e) -> c e
101 -- | indexing function
102 atIndex :: c e -> IndexOf c -> e
103 -- | index of min element
104 minIndex :: c e -> IndexOf c
105 -- | index of max element
106 maxIndex :: c e -> IndexOf c
107 -- | value of min element
108 minElement :: c e -> e
109 -- | value of max element
110 maxElement :: c e -> e
111 -- the C functions sumX/prodX are twice as fast as using foldVector
112 -- | the sum of elements (faster than using @fold@)
113 sumElements :: c e -> e
114 -- | the product of elements (faster than using @fold@)
115 prodElements :: c e -> e
116
117 -- | A more efficient implementation of @cmap (\\x -> if x>0 then 1 else 0)@
118 --
119 -- >>> step $ linspace 5 (-1,1::Double)
120 -- 5 |> [0.0,0.0,0.0,1.0,1.0]
121 --
122
123 step :: RealElement e => c e -> c e
124
125 -- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@.
126 --
127 -- Arguments with any dimension = 1 are automatically expanded:
128 --
129 -- >>> cond ((1><4)[1..]) ((3><1)[1..]) 0 100 ((3><4)[1..]) :: Matrix Double
130 -- (3><4)
131 -- [ 100.0, 2.0, 3.0, 4.0
132 -- , 0.0, 100.0, 7.0, 8.0
133 -- , 0.0, 0.0, 100.0, 12.0 ]
134 --
135
136 cond :: RealElement e
137 => c e -- ^ a
138 -> c e -- ^ b
139 -> c e -- ^ l
140 -> c e -- ^ e
141 -> c e -- ^ g
142 -> c e -- ^ result
143
144 -- | Find index of elements which satisfy a predicate
145 --
146 -- >>> find (>0) (ident 3 :: Matrix Double)
147 -- [(0,0),(1,1),(2,2)]
148 --
149
150 find :: (e -> Bool) -> c e -> [IndexOf c]
151
152 -- | Create a structure from an association list
153 --
154 -- >>> assoc 5 0 [(3,7),(1,4)] :: Vector Double
155 -- fromList [0.0,4.0,0.0,7.0,0.0]
156 --
157 -- >>> assoc (2,3) 0 [((0,2),7),((1,0),2*i-3)] :: Matrix (Complex Double)
158 -- (2><3)
159 -- [ 0.0 :+ 0.0, 0.0 :+ 0.0, 7.0 :+ 0.0
160 -- , (-3.0) :+ 2.0, 0.0 :+ 0.0, 0.0 :+ 0.0 ]
161 --
162 assoc :: IndexOf c -- ^ size
163 -> e -- ^ default value
164 -> [(IndexOf c, e)] -- ^ association list
165 -> c e -- ^ result
166
167 -- | Modify a structure using an update function
168 --
169 -- >>> accum (ident 5) (+) [((1,1),5),((0,3),3)] :: Matrix Double
170 -- (5><5)
171 -- [ 1.0, 0.0, 0.0, 3.0, 0.0
172 -- , 0.0, 6.0, 0.0, 0.0, 0.0
173 -- , 0.0, 0.0, 1.0, 0.0, 0.0
174 -- , 0.0, 0.0, 0.0, 1.0, 0.0
175 -- , 0.0, 0.0, 0.0, 0.0, 1.0 ]
176 --
177 -- computation of histogram:
178 --
179 -- >>> accum (konst 0 7) (+) (map (flip (,) 1) [4,5,4,1,5,2,5]) :: Vector Double
180 -- fromList [0.0,1.0,1.0,0.0,2.0,3.0,0.0]
181 --
182
183 accum :: c e -- ^ initial structure
184 -> (e -> e -> e) -- ^ update function
185 -> [(IndexOf c, e)] -- ^ association list
186 -> c e -- ^ result
187
188--------------------------------------------------------------------------
189
190instance Container Vector Float where
191 scale = vectorMapValF Scale
192 scaleRecip = vectorMapValF Recip
193 addConstant = vectorMapValF AddConstant
194 add = vectorZipF Add
195 sub = vectorZipF Sub
196 mul = vectorZipF Mul
197 divide = vectorZipF Div
198 equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0
199 arctan2 = vectorZipF ATan2
200 scalar x = fromList [x]
201 konst' = constant
202 build' = buildV
203 conj = id
204 cmap = mapVector
205 atIndex = (@>)
206 minIndex = emptyErrorV "minIndex" (round . toScalarF MinIdx)
207 maxIndex = emptyErrorV "maxIndex" (round . toScalarF MaxIdx)
208 minElement = emptyErrorV "minElement" (toScalarF Min)
209 maxElement = emptyErrorV "maxElement" (toScalarF Max)
210 sumElements = sumF
211 prodElements = prodF
212 step = stepF
213 find = findV
214 assoc = assocV
215 accum = accumV
216 cond = condV condF
217
218instance Container Vector Double where
219 scale = vectorMapValR Scale
220 scaleRecip = vectorMapValR Recip
221 addConstant = vectorMapValR AddConstant
222 add = vectorZipR Add
223 sub = vectorZipR Sub
224 mul = vectorZipR Mul
225 divide = vectorZipR Div
226 equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0
227 arctan2 = vectorZipR ATan2
228 scalar x = fromList [x]
229 konst' = constant
230 build' = buildV
231 conj = id
232 cmap = mapVector
233 atIndex = (@>)
234 minIndex = emptyErrorV "minIndex" (round . toScalarR MinIdx)
235 maxIndex = emptyErrorV "maxIndex" (round . toScalarR MaxIdx)
236 minElement = emptyErrorV "minElement" (toScalarR Min)
237 maxElement = emptyErrorV "maxElement" (toScalarR Max)
238 sumElements = sumR
239 prodElements = prodR
240 step = stepD
241 find = findV
242 assoc = assocV
243 accum = accumV
244 cond = condV condD
245
246instance Container Vector (Complex Double) where
247 scale = vectorMapValC Scale
248 scaleRecip = vectorMapValC Recip
249 addConstant = vectorMapValC AddConstant
250 add = vectorZipC Add
251 sub = vectorZipC Sub
252 mul = vectorZipC Mul
253 divide = vectorZipC Div
254 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
255 arctan2 = vectorZipC ATan2
256 scalar x = fromList [x]
257 konst' = constant
258 build' = buildV
259 conj = conjugateC
260 cmap = mapVector
261 atIndex = (@>)
262 minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj))
263 maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj))
264 minElement = emptyErrorV "minElement" (atIndex <*> minIndex)
265 maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex)
266 sumElements = sumC
267 prodElements = prodC
268 step = undefined -- cannot match
269 find = findV
270 assoc = assocV
271 accum = accumV
272 cond = undefined -- cannot match
273
274instance Container Vector (Complex Float) where
275 scale = vectorMapValQ Scale
276 scaleRecip = vectorMapValQ Recip
277 addConstant = vectorMapValQ AddConstant
278 add = vectorZipQ Add
279 sub = vectorZipQ Sub
280 mul = vectorZipQ Mul
281 divide = vectorZipQ Div
282 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
283 arctan2 = vectorZipQ ATan2
284 scalar x = fromList [x]
285 konst' = constant
286 build' = buildV
287 conj = conjugateQ
288 cmap = mapVector
289 atIndex = (@>)
290 minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj))
291 maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj))
292 minElement = emptyErrorV "minElement" (atIndex <*> minIndex)
293 maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex)
294 sumElements = sumQ
295 prodElements = prodQ
296 step = undefined -- cannot match
297 find = findV
298 assoc = assocV
299 accum = accumV
300 cond = undefined -- cannot match
301
302---------------------------------------------------------------
303
304instance (Container Vector a) => Container Matrix a where
305 scale x = liftMatrix (scale x)
306 scaleRecip x = liftMatrix (scaleRecip x)
307 addConstant x = liftMatrix (addConstant x)
308 add = liftMatrix2 add
309 sub = liftMatrix2 sub
310 mul = liftMatrix2 mul
311 divide = liftMatrix2 divide
312 equal a b = cols a == cols b && flatten a `equal` flatten b
313 arctan2 = liftMatrix2 arctan2
314 scalar x = (1><1) [x]
315 konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c))
316 build' = buildM
317 conj = liftMatrix conj
318 cmap f = liftMatrix (mapVector f)
319 atIndex = (@@>)
320 minIndex = emptyErrorM "minIndex of Matrix" $
321 \m -> divMod (minIndex $ flatten m) (cols m)
322 maxIndex = emptyErrorM "maxIndex of Matrix" $
323 \m -> divMod (maxIndex $ flatten m) (cols m)
324 minElement = emptyErrorM "minElement of Matrix" (atIndex <*> minIndex)
325 maxElement = emptyErrorM "maxElement of Matrix" (atIndex <*> maxIndex)
326 sumElements = sumElements . flatten
327 prodElements = prodElements . flatten
328 step = liftMatrix step
329 find = findM
330 assoc = assocM
331 accum = accumM
332 cond = condM
333
334
335emptyErrorV msg f v =
336 if dim v > 0
337 then f v
338 else error $ msg ++ " of Vector with dim = 0"
339
340emptyErrorM msg f m =
341 if rows m > 0 && cols m > 0
342 then f m
343 else error $ msg++" "++shSize m
344
345----------------------------------------------------
346
347-- | Matrix product and related functions
348class (Num e, Element e) => Product e where
349 -- | matrix product
350 multiply :: Matrix e -> Matrix e -> Matrix e
351 -- | sum of absolute value of elements (differs in complex case from @norm1@)
352 absSum :: Vector e -> RealOf e
353 -- | sum of absolute value of elements
354 norm1 :: Vector e -> RealOf e
355 -- | euclidean norm
356 norm2 :: Vector e -> RealOf e
357 -- | element of maximum magnitude
358 normInf :: Vector e -> RealOf e
359
360instance Product Float where
361 norm2 = emptyVal (toScalarF Norm2)
362 absSum = emptyVal (toScalarF AbsSum)
363 norm1 = emptyVal (toScalarF AbsSum)
364 normInf = emptyVal (maxElement . vectorMapF Abs)
365 multiply = emptyMul multiplyF
366
367instance Product Double where
368 norm2 = emptyVal (toScalarR Norm2)
369 absSum = emptyVal (toScalarR AbsSum)
370 norm1 = emptyVal (toScalarR AbsSum)
371 normInf = emptyVal (maxElement . vectorMapR Abs)
372 multiply = emptyMul multiplyR
373
374instance Product (Complex Float) where
375 norm2 = emptyVal (toScalarQ Norm2)
376 absSum = emptyVal (toScalarQ AbsSum)
377 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs)
378 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs)
379 multiply = emptyMul multiplyQ
380
381instance Product (Complex Double) where
382 norm2 = emptyVal (toScalarC Norm2)
383 absSum = emptyVal (toScalarC AbsSum)
384 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs)
385 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs)
386 multiply = emptyMul multiplyC
387
388emptyMul m a b
389 | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c)
390 | otherwise = m a b
391 where
392 r = rows a
393 x1 = cols a
394 x2 = rows b
395 c = cols b
396
397emptyVal f v =
398 if dim v > 0
399 then f v
400 else 0
401
402-- FIXME remove unused C wrappers
403-- | unconjugated dot product
404udot :: Product e => Vector e -> Vector e -> e
405udot u v
406 | dim u == dim v = val (asRow u `multiply` asColumn v)
407 | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product"
408 where
409 val m | dim u > 0 = m@@>(0,0)
410 | otherwise = 0
411
412----------------------------------------------------------
413
414-- synonym for matrix product
415mXm :: Product t => Matrix t -> Matrix t -> Matrix t
416mXm = multiply
417
418-- matrix - vector product
419mXv :: Product t => Matrix t -> Vector t -> Vector t
420mXv m v = flatten $ m `mXm` (asColumn v)
421
422-- vector - matrix product
423vXm :: Product t => Vector t -> Matrix t -> Vector t
424vXm v m = flatten $ (asRow v) `mXm` m
425
426{- | Outer product of two vectors.
427
428>>> fromList [1,2,3] `outer` fromList [5,2,3]
429(3><3)
430 [ 5.0, 2.0, 3.0
431 , 10.0, 4.0, 6.0
432 , 15.0, 6.0, 9.0 ]
433
434-}
435outer :: (Product t) => Vector t -> Vector t -> Matrix t
436outer u v = asColumn u `multiply` asRow v
437
438{- | Kronecker product of two matrices.
439
440@m1=(2><3)
441 [ 1.0, 2.0, 0.0
442 , 0.0, -1.0, 3.0 ]
443m2=(4><3)
444 [ 1.0, 2.0, 3.0
445 , 4.0, 5.0, 6.0
446 , 7.0, 8.0, 9.0
447 , 10.0, 11.0, 12.0 ]@
448
449>>> kronecker m1 m2
450(8><9)
451 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
452 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
453 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
454 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
455 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
456 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
457 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
458 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]
459
460-}
461kronecker :: (Product t) => Matrix t -> Matrix t -> Matrix t
462kronecker a b = fromBlocks
463 . splitEvery (cols a)
464 . map (reshape (cols b))
465 . toRows
466 $ flatten a `outer` flatten b
467
468-------------------------------------------------------------------
469
470
471class Convert t where
472 real :: Container c t => c (RealOf t) -> c t
473 complex :: Container c t => c t -> c (ComplexOf t)
474 single :: Container c t => c t -> c (SingleOf t)
475 double :: Container c t => c t -> c (DoubleOf t)
476 toComplex :: (Container c t, RealElement t) => (c t, c t) -> c (Complex t)
477 fromComplex :: (Container c t, RealElement t) => c (Complex t) -> (c t, c t)
478
479
480instance Convert Double where
481 real = id
482 complex = comp'
483 single = single'
484 double = id
485 toComplex = toComplex'
486 fromComplex = fromComplex'
487
488instance Convert Float where
489 real = id
490 complex = comp'
491 single = id
492 double = double'
493 toComplex = toComplex'
494 fromComplex = fromComplex'
495
496instance Convert (Complex Double) where
497 real = comp'
498 complex = id
499 single = single'
500 double = id
501 toComplex = toComplex'
502 fromComplex = fromComplex'
503
504instance Convert (Complex Float) where
505 real = comp'
506 complex = id
507 single = id
508 double = double'
509 toComplex = toComplex'
510 fromComplex = fromComplex'
511
512-------------------------------------------------------------------
513
514type family RealOf x
515
516type instance RealOf Double = Double
517type instance RealOf (Complex Double) = Double
518
519type instance RealOf Float = Float
520type instance RealOf (Complex Float) = Float
521
522type family ComplexOf x
523
524type instance ComplexOf Double = Complex Double
525type instance ComplexOf (Complex Double) = Complex Double
526
527type instance ComplexOf Float = Complex Float
528type instance ComplexOf (Complex Float) = Complex Float
529
530type family SingleOf x
531
532type instance SingleOf Double = Float
533type instance SingleOf Float = Float
534
535type instance SingleOf (Complex a) = Complex (SingleOf a)
536
537type family DoubleOf x
538
539type instance DoubleOf Double = Double
540type instance DoubleOf Float = Double
541
542type instance DoubleOf (Complex a) = Complex (DoubleOf a)
543
544type family ElementOf c
545
546type instance ElementOf (Vector a) = a
547type instance ElementOf (Matrix a) = a
548
549------------------------------------------------------------
550
551buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ]
552 where rs = map fromIntegral [0 .. (rc-1)]
553 cs = map fromIntegral [0 .. (cc-1)]
554
555buildV n f = fromList [f k | k <- ks]
556 where ks = map fromIntegral [0 .. (n-1)]
557
558--------------------------------------------------------
559-- | conjugate transpose
560ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e
561ctrans = liftMatrix conj . trans
562
563-- | Creates a square matrix with a given diagonal.
564diag :: (Num a, Element a) => Vector a -> Matrix a
565diag v = diagRect 0 v n n where n = dim v
566
567-- | creates the identity matrix of given dimension
568ident :: (Num a, Element a) => Int -> Matrix a
569ident n = diag (constant 1 n)
570
571--------------------------------------------------------
572
573findV p x = foldVectorWithIndex g [] x where
574 g k z l = if p z then k:l else l
575
576findM p x = map ((`divMod` cols x)) $ findV p (flatten x)
577
578assocV n z xs = ST.runSTVector $ do
579 v <- ST.newVector z n
580 mapM_ (\(k,x) -> ST.writeVector v k x) xs
581 return v
582
583assocM (r,c) z xs = ST.runSTMatrix $ do
584 m <- ST.newMatrix z r c
585 mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs
586 return m
587
588accumV v0 f xs = ST.runSTVector $ do
589 v <- ST.thawVector v0
590 mapM_ (\(k,x) -> ST.modifyVector v k (f x)) xs
591 return v
592
593accumM m0 f xs = ST.runSTMatrix $ do
594 m <- ST.thawMatrix m0
595 mapM_ (\((i,j),x) -> ST.modifyMatrix m i j (f x)) xs
596 return m
597
598----------------------------------------------------------------------
599
600condM a b l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cond a' b' l' e' t'
601 where
602 args@(a'':_) = conformMs [a,b,l,e,t]
603 [a', b', l', e', t'] = map flatten args
604
605condV f a b l e t = f a' b' l' e' t'
606 where
607 [a', b', l', e', t'] = conformVs [a,b,l,e,t]
608