summaryrefslogtreecommitdiff
path: root/packages/base/src/Data/Packed
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Data/Packed')
-rw-r--r--packages/base/src/Data/Packed/Internal/Numeric.hs607
-rw-r--r--packages/base/src/Data/Packed/Numeric.hs680
2 files changed, 765 insertions, 522 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs
new file mode 100644
index 0000000..81a8083
--- /dev/null
+++ b/packages/base/src/Data/Packed/Internal/Numeric.hs
@@ -0,0 +1,607 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE TypeFamilies #-}
3{-# LANGUAGE FlexibleContexts #-}
4{-# LANGUAGE FlexibleInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-}
6{-# LANGUAGE UndecidableInstances #-}
7
8-----------------------------------------------------------------------------
9-- |
10-- Module : Data.Packed.Internal.Numeric
11-- Copyright : (c) Alberto Ruiz 2010-14
12-- License : BSD3
13-- Maintainer : Alberto Ruiz
14-- Stability : provisional
15--
16-----------------------------------------------------------------------------
17
18module Data.Packed.Internal.Numeric (
19 -- * Basic functions
20 ident, diag, ctrans,
21 -- * Generic operations
22 Container(..),
23 -- * Matrix product and related functions
24 Product(..), udot,
25 mXm,mXv,vXm,
26 outer, kronecker,
27 -- * Element conversion
28 Convert(..),
29 Complexable(),
30 RealElement(),
31
32 RealOf, ComplexOf, SingleOf, DoubleOf,
33
34 IndexOf,
35 module Data.Complex
36) where
37
38import Data.Packed
39import Data.Packed.ST as ST
40import Numeric.Conversion
41import Data.Packed.Development
42import Numeric.Vectorized
43import Data.Complex
44import Control.Applicative((<*>))
45
46import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ)
47
48-------------------------------------------------------------------
49
50type family IndexOf (c :: * -> *)
51
52type instance IndexOf Vector = Int
53type instance IndexOf Matrix = (Int,Int)
54
55type family ArgOf (c :: * -> *) a
56
57type instance ArgOf Vector a = a -> a
58type instance ArgOf Matrix a = a -> a -> a
59
60-------------------------------------------------------------------
61
62-- | Basic element-by-element functions for numeric containers
63class (Complexable c, Fractional e, Element e) => Container c e where
64 -- | create a structure with a single element
65 --
66 -- >>> let v = fromList [1..3::Double]
67 -- >>> v / scalar (norm2 v)
68 -- fromList [0.2672612419124244,0.5345224838248488,0.8017837257372732]
69 --
70 scalar :: e -> c e
71 -- | complex conjugate
72 conj :: c e -> c e
73 scale :: e -> c e -> c e
74 -- | scale the element by element reciprocal of the object:
75 --
76 -- @scaleRecip 2 (fromList [5,i]) == 2 |> [0.4 :+ 0.0,0.0 :+ (-2.0)]@
77 scaleRecip :: e -> c e -> c e
78 addConstant :: e -> c e -> c e
79 add :: c e -> c e -> c e
80 sub :: c e -> c e -> c e
81 -- | element by element multiplication
82 mul :: c e -> c e -> c e
83 -- | element by element division
84 divide :: c e -> c e -> c e
85 equal :: c e -> c e -> Bool
86 --
87 -- element by element inverse tangent
88 arctan2 :: c e -> c e -> c e
89 --
90 -- | cannot implement instance Functor because of Element class constraint
91 cmap :: (Element b) => (e -> b) -> c e -> c b
92 -- | constant structure of given size
93 konst' :: e -> IndexOf c -> c e
94 -- | create a structure using a function
95 --
96 -- Hilbert matrix of order N:
97 --
98 -- @hilb n = build' (n,n) (\\i j -> 1/(i+j+1))@
99 build' :: IndexOf c -> (ArgOf c e) -> c e
100 -- | indexing function
101 atIndex :: c e -> IndexOf c -> e
102 -- | index of min element
103 minIndex :: c e -> IndexOf c
104 -- | index of max element
105 maxIndex :: c e -> IndexOf c
106 -- | value of min element
107 minElement :: c e -> e
108 -- | value of max element
109 maxElement :: c e -> e
110 -- the C functions sumX/prodX are twice as fast as using foldVector
111 -- | the sum of elements (faster than using @fold@)
112 sumElements :: c e -> e
113 -- | the product of elements (faster than using @fold@)
114 prodElements :: c e -> e
115
116 -- | A more efficient implementation of @cmap (\\x -> if x>0 then 1 else 0)@
117 --
118 -- >>> step $ linspace 5 (-1,1::Double)
119 -- 5 |> [0.0,0.0,0.0,1.0,1.0]
120 --
121
122 step :: RealElement e => c e -> c e
123
124 -- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@.
125 --
126 -- Arguments with any dimension = 1 are automatically expanded:
127 --
128 -- >>> cond ((1><4)[1..]) ((3><1)[1..]) 0 100 ((3><4)[1..]) :: Matrix Double
129 -- (3><4)
130 -- [ 100.0, 2.0, 3.0, 4.0
131 -- , 0.0, 100.0, 7.0, 8.0
132 -- , 0.0, 0.0, 100.0, 12.0 ]
133 --
134
135 cond :: RealElement e
136 => c e -- ^ a
137 -> c e -- ^ b
138 -> c e -- ^ l
139 -> c e -- ^ e
140 -> c e -- ^ g
141 -> c e -- ^ result
142
143 -- | Find index of elements which satisfy a predicate
144 --
145 -- >>> find (>0) (ident 3 :: Matrix Double)
146 -- [(0,0),(1,1),(2,2)]
147 --
148
149 find :: (e -> Bool) -> c e -> [IndexOf c]
150
151 -- | Create a structure from an association list
152 --
153 -- >>> assoc 5 0 [(3,7),(1,4)] :: Vector Double
154 -- fromList [0.0,4.0,0.0,7.0,0.0]
155 --
156 -- >>> assoc (2,3) 0 [((0,2),7),((1,0),2*i-3)] :: Matrix (Complex Double)
157 -- (2><3)
158 -- [ 0.0 :+ 0.0, 0.0 :+ 0.0, 7.0 :+ 0.0
159 -- , (-3.0) :+ 2.0, 0.0 :+ 0.0, 0.0 :+ 0.0 ]
160 --
161 assoc :: IndexOf c -- ^ size
162 -> e -- ^ default value
163 -> [(IndexOf c, e)] -- ^ association list
164 -> c e -- ^ result
165
166 -- | Modify a structure using an update function
167 --
168 -- >>> accum (ident 5) (+) [((1,1),5),((0,3),3)] :: Matrix Double
169 -- (5><5)
170 -- [ 1.0, 0.0, 0.0, 3.0, 0.0
171 -- , 0.0, 6.0, 0.0, 0.0, 0.0
172 -- , 0.0, 0.0, 1.0, 0.0, 0.0
173 -- , 0.0, 0.0, 0.0, 1.0, 0.0
174 -- , 0.0, 0.0, 0.0, 0.0, 1.0 ]
175 --
176 -- computation of histogram:
177 --
178 -- >>> accum (konst 0 7) (+) (map (flip (,) 1) [4,5,4,1,5,2,5]) :: Vector Double
179 -- fromList [0.0,1.0,1.0,0.0,2.0,3.0,0.0]
180 --
181
182 accum :: c e -- ^ initial structure
183 -> (e -> e -> e) -- ^ update function
184 -> [(IndexOf c, e)] -- ^ association list
185 -> c e -- ^ result
186
187--------------------------------------------------------------------------
188
189instance Container Vector Float where
190 scale = vectorMapValF Scale
191 scaleRecip = vectorMapValF Recip
192 addConstant = vectorMapValF AddConstant
193 add = vectorZipF Add
194 sub = vectorZipF Sub
195 mul = vectorZipF Mul
196 divide = vectorZipF Div
197 equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0
198 arctan2 = vectorZipF ATan2
199 scalar x = fromList [x]
200 konst' = constant
201 build' = buildV
202 conj = id
203 cmap = mapVector
204 atIndex = (@>)
205 minIndex = emptyErrorV "minIndex" (round . toScalarF MinIdx)
206 maxIndex = emptyErrorV "maxIndex" (round . toScalarF MaxIdx)
207 minElement = emptyErrorV "minElement" (toScalarF Min)
208 maxElement = emptyErrorV "maxElement" (toScalarF Max)
209 sumElements = sumF
210 prodElements = prodF
211 step = stepF
212 find = findV
213 assoc = assocV
214 accum = accumV
215 cond = condV condF
216
217instance Container Vector Double where
218 scale = vectorMapValR Scale
219 scaleRecip = vectorMapValR Recip
220 addConstant = vectorMapValR AddConstant
221 add = vectorZipR Add
222 sub = vectorZipR Sub
223 mul = vectorZipR Mul
224 divide = vectorZipR Div
225 equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0
226 arctan2 = vectorZipR ATan2
227 scalar x = fromList [x]
228 konst' = constant
229 build' = buildV
230 conj = id
231 cmap = mapVector
232 atIndex = (@>)
233 minIndex = emptyErrorV "minIndex" (round . toScalarR MinIdx)
234 maxIndex = emptyErrorV "maxIndex" (round . toScalarR MaxIdx)
235 minElement = emptyErrorV "minElement" (toScalarR Min)
236 maxElement = emptyErrorV "maxElement" (toScalarR Max)
237 sumElements = sumR
238 prodElements = prodR
239 step = stepD
240 find = findV
241 assoc = assocV
242 accum = accumV
243 cond = condV condD
244
245instance Container Vector (Complex Double) where
246 scale = vectorMapValC Scale
247 scaleRecip = vectorMapValC Recip
248 addConstant = vectorMapValC AddConstant
249 add = vectorZipC Add
250 sub = vectorZipC Sub
251 mul = vectorZipC Mul
252 divide = vectorZipC Div
253 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
254 arctan2 = vectorZipC ATan2
255 scalar x = fromList [x]
256 konst' = constant
257 build' = buildV
258 conj = conjugateC
259 cmap = mapVector
260 atIndex = (@>)
261 minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj))
262 maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj))
263 minElement = emptyErrorV "minElement" (atIndex <*> minIndex)
264 maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex)
265 sumElements = sumC
266 prodElements = prodC
267 step = undefined -- cannot match
268 find = findV
269 assoc = assocV
270 accum = accumV
271 cond = undefined -- cannot match
272
273instance Container Vector (Complex Float) where
274 scale = vectorMapValQ Scale
275 scaleRecip = vectorMapValQ Recip
276 addConstant = vectorMapValQ AddConstant
277 add = vectorZipQ Add
278 sub = vectorZipQ Sub
279 mul = vectorZipQ Mul
280 divide = vectorZipQ Div
281 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
282 arctan2 = vectorZipQ ATan2
283 scalar x = fromList [x]
284 konst' = constant
285 build' = buildV
286 conj = conjugateQ
287 cmap = mapVector
288 atIndex = (@>)
289 minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj))
290 maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj))
291 minElement = emptyErrorV "minElement" (atIndex <*> minIndex)
292 maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex)
293 sumElements = sumQ
294 prodElements = prodQ
295 step = undefined -- cannot match
296 find = findV
297 assoc = assocV
298 accum = accumV
299 cond = undefined -- cannot match
300
301---------------------------------------------------------------
302
303instance (Container Vector a) => Container Matrix a where
304 scale x = liftMatrix (scale x)
305 scaleRecip x = liftMatrix (scaleRecip x)
306 addConstant x = liftMatrix (addConstant x)
307 add = liftMatrix2 add
308 sub = liftMatrix2 sub
309 mul = liftMatrix2 mul
310 divide = liftMatrix2 divide
311 equal a b = cols a == cols b && flatten a `equal` flatten b
312 arctan2 = liftMatrix2 arctan2
313 scalar x = (1><1) [x]
314 konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c))
315 build' = buildM
316 conj = liftMatrix conj
317 cmap f = liftMatrix (mapVector f)
318 atIndex = (@@>)
319 minIndex = emptyErrorM "minIndex of Matrix" $
320 \m -> divMod (minIndex $ flatten m) (cols m)
321 maxIndex = emptyErrorM "maxIndex of Matrix" $
322 \m -> divMod (maxIndex $ flatten m) (cols m)
323 minElement = emptyErrorM "minElement of Matrix" (atIndex <*> minIndex)
324 maxElement = emptyErrorM "maxElement of Matrix" (atIndex <*> maxIndex)
325 sumElements = sumElements . flatten
326 prodElements = prodElements . flatten
327 step = liftMatrix step
328 find = findM
329 assoc = assocM
330 accum = accumM
331 cond = condM
332
333
334emptyErrorV msg f v =
335 if dim v > 0
336 then f v
337 else error $ msg ++ " of Vector with dim = 0"
338
339emptyErrorM msg f m =
340 if rows m > 0 && cols m > 0
341 then f m
342 else error $ msg++" "++shSize m
343
344----------------------------------------------------
345
346-- | Matrix product and related functions
347class (Num e, Element e) => Product e where
348 -- | matrix product
349 multiply :: Matrix e -> Matrix e -> Matrix e
350 -- | sum of absolute value of elements (differs in complex case from @norm1@)
351 absSum :: Vector e -> RealOf e
352 -- | sum of absolute value of elements
353 norm1 :: Vector e -> RealOf e
354 -- | euclidean norm
355 norm2 :: Vector e -> RealOf e
356 -- | element of maximum magnitude
357 normInf :: Vector e -> RealOf e
358
359instance Product Float where
360 norm2 = emptyVal (toScalarF Norm2)
361 absSum = emptyVal (toScalarF AbsSum)
362 norm1 = emptyVal (toScalarF AbsSum)
363 normInf = emptyVal (maxElement . vectorMapF Abs)
364 multiply = emptyMul multiplyF
365
366instance Product Double where
367 norm2 = emptyVal (toScalarR Norm2)
368 absSum = emptyVal (toScalarR AbsSum)
369 norm1 = emptyVal (toScalarR AbsSum)
370 normInf = emptyVal (maxElement . vectorMapR Abs)
371 multiply = emptyMul multiplyR
372
373instance Product (Complex Float) where
374 norm2 = emptyVal (toScalarQ Norm2)
375 absSum = emptyVal (toScalarQ AbsSum)
376 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs)
377 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs)
378 multiply = emptyMul multiplyQ
379
380instance Product (Complex Double) where
381 norm2 = emptyVal (toScalarC Norm2)
382 absSum = emptyVal (toScalarC AbsSum)
383 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs)
384 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs)
385 multiply = emptyMul multiplyC
386
387emptyMul m a b
388 | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c)
389 | otherwise = m a b
390 where
391 r = rows a
392 x1 = cols a
393 x2 = rows b
394 c = cols b
395
396emptyVal f v =
397 if dim v > 0
398 then f v
399 else 0
400
401-- FIXME remove unused C wrappers
402-- | unconjugated dot product
403udot :: Product e => Vector e -> Vector e -> e
404udot u v
405 | dim u == dim v = val (asRow u `multiply` asColumn v)
406 | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product"
407 where
408 val m | dim u > 0 = m@@>(0,0)
409 | otherwise = 0
410
411----------------------------------------------------------
412
413-- synonym for matrix product
414mXm :: Product t => Matrix t -> Matrix t -> Matrix t
415mXm = multiply
416
417-- matrix - vector product
418mXv :: Product t => Matrix t -> Vector t -> Vector t
419mXv m v = flatten $ m `mXm` (asColumn v)
420
421-- vector - matrix product
422vXm :: Product t => Vector t -> Matrix t -> Vector t
423vXm v m = flatten $ (asRow v) `mXm` m
424
425{- | Outer product of two vectors.
426
427>>> fromList [1,2,3] `outer` fromList [5,2,3]
428(3><3)
429 [ 5.0, 2.0, 3.0
430 , 10.0, 4.0, 6.0
431 , 15.0, 6.0, 9.0 ]
432
433-}
434outer :: (Product t) => Vector t -> Vector t -> Matrix t
435outer u v = asColumn u `multiply` asRow v
436
437{- | Kronecker product of two matrices.
438
439@m1=(2><3)
440 [ 1.0, 2.0, 0.0
441 , 0.0, -1.0, 3.0 ]
442m2=(4><3)
443 [ 1.0, 2.0, 3.0
444 , 4.0, 5.0, 6.0
445 , 7.0, 8.0, 9.0
446 , 10.0, 11.0, 12.0 ]@
447
448>>> kronecker m1 m2
449(8><9)
450 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
451 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
452 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
453 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
454 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
455 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
456 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
457 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]
458
459-}
460kronecker :: (Product t) => Matrix t -> Matrix t -> Matrix t
461kronecker a b = fromBlocks
462 . splitEvery (cols a)
463 . map (reshape (cols b))
464 . toRows
465 $ flatten a `outer` flatten b
466
467-------------------------------------------------------------------
468
469
470class Convert t where
471 real :: Container c t => c (RealOf t) -> c t
472 complex :: Container c t => c t -> c (ComplexOf t)
473 single :: Container c t => c t -> c (SingleOf t)
474 double :: Container c t => c t -> c (DoubleOf t)
475 toComplex :: (Container c t, RealElement t) => (c t, c t) -> c (Complex t)
476 fromComplex :: (Container c t, RealElement t) => c (Complex t) -> (c t, c t)
477
478
479instance Convert Double where
480 real = id
481 complex = comp'
482 single = single'
483 double = id
484 toComplex = toComplex'
485 fromComplex = fromComplex'
486
487instance Convert Float where
488 real = id
489 complex = comp'
490 single = id
491 double = double'
492 toComplex = toComplex'
493 fromComplex = fromComplex'
494
495instance Convert (Complex Double) where
496 real = comp'
497 complex = id
498 single = single'
499 double = id
500 toComplex = toComplex'
501 fromComplex = fromComplex'
502
503instance Convert (Complex Float) where
504 real = comp'
505 complex = id
506 single = id
507 double = double'
508 toComplex = toComplex'
509 fromComplex = fromComplex'
510
511-------------------------------------------------------------------
512
513type family RealOf x
514
515type instance RealOf Double = Double
516type instance RealOf (Complex Double) = Double
517
518type instance RealOf Float = Float
519type instance RealOf (Complex Float) = Float
520
521type family ComplexOf x
522
523type instance ComplexOf Double = Complex Double
524type instance ComplexOf (Complex Double) = Complex Double
525
526type instance ComplexOf Float = Complex Float
527type instance ComplexOf (Complex Float) = Complex Float
528
529type family SingleOf x
530
531type instance SingleOf Double = Float
532type instance SingleOf Float = Float
533
534type instance SingleOf (Complex a) = Complex (SingleOf a)
535
536type family DoubleOf x
537
538type instance DoubleOf Double = Double
539type instance DoubleOf Float = Double
540
541type instance DoubleOf (Complex a) = Complex (DoubleOf a)
542
543type family ElementOf c
544
545type instance ElementOf (Vector a) = a
546type instance ElementOf (Matrix a) = a
547
548------------------------------------------------------------
549
550buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ]
551 where rs = map fromIntegral [0 .. (rc-1)]
552 cs = map fromIntegral [0 .. (cc-1)]
553
554buildV n f = fromList [f k | k <- ks]
555 where ks = map fromIntegral [0 .. (n-1)]
556
557--------------------------------------------------------
558-- | conjugate transpose
559ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e
560ctrans = liftMatrix conj . trans
561
562-- | Creates a square matrix with a given diagonal.
563diag :: (Num a, Element a) => Vector a -> Matrix a
564diag v = diagRect 0 v n n where n = dim v
565
566-- | creates the identity matrix of given dimension
567ident :: (Num a, Element a) => Int -> Matrix a
568ident n = diag (constant 1 n)
569
570--------------------------------------------------------
571
572findV p x = foldVectorWithIndex g [] x where
573 g k z l = if p z then k:l else l
574
575findM p x = map ((`divMod` cols x)) $ findV p (flatten x)
576
577assocV n z xs = ST.runSTVector $ do
578 v <- ST.newVector z n
579 mapM_ (\(k,x) -> ST.writeVector v k x) xs
580 return v
581
582assocM (r,c) z xs = ST.runSTMatrix $ do
583 m <- ST.newMatrix z r c
584 mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs
585 return m
586
587accumV v0 f xs = ST.runSTVector $ do
588 v <- ST.thawVector v0
589 mapM_ (\(k,x) -> ST.modifyVector v k (f x)) xs
590 return v
591
592accumM m0 f xs = ST.runSTMatrix $ do
593 m <- ST.thawMatrix m0
594 mapM_ (\((i,j),x) -> ST.modifyMatrix m i j (f x)) xs
595 return m
596
597----------------------------------------------------------------------
598
599condM a b l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cond a' b' l' e' t'
600 where
601 args@(a'':_) = conformMs [a,b,l,e,t]
602 [a', b', l', e', t'] = map flatten args
603
604condV f a b l e t = f a' b' l' e' t'
605 where
606 [a', b', l', e', t'] = conformVs [a,b,l,e,t]
607
diff --git a/packages/base/src/Data/Packed/Numeric.hs b/packages/base/src/Data/Packed/Numeric.hs
index c13e91d..6036e8c 100644
--- a/packages/base/src/Data/Packed/Numeric.hs
+++ b/packages/base/src/Data/Packed/Numeric.hs
@@ -1,8 +1,8 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE TypeFamilies #-} 1{-# LANGUAGE TypeFamilies #-}
3{-# LANGUAGE FlexibleContexts #-} 2{-# LANGUAGE FlexibleContexts #-}
4{-# LANGUAGE FlexibleInstances #-} 3{-# LANGUAGE FlexibleInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-} 4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE UndecidableInstances #-} 6{-# LANGUAGE UndecidableInstances #-}
7 7
8----------------------------------------------------------------------------- 8-----------------------------------------------------------------------------
@@ -13,16 +13,32 @@
13-- Maintainer : Alberto Ruiz 13-- Maintainer : Alberto Ruiz
14-- Stability : provisional 14-- Stability : provisional
15-- 15--
16-- Basic numeric operations on 'Vector' and 'Matrix', including conversion routines.
17--
18-- The 'Container' class is used to define optimized generic functions which work
19-- on 'Vector' and 'Matrix' with real or complex elements.
20--
21-- Some of these functions are also available in the instances of the standard
22-- numeric Haskell classes provided by "Numeric.LinearAlgebra".
23--
16----------------------------------------------------------------------------- 24-----------------------------------------------------------------------------
25{-# OPTIONS_HADDOCK hide #-}
17 26
18module Data.Packed.Numeric ( 27module Data.Packed.Numeric (
19 -- * Basic functions 28 -- * Basic functions
20 ident, diag, ctrans, 29 module Data.Packed,
30 konst, build,
31 linspace,
32 diag, ident,
33 ctrans,
21 -- * Generic operations 34 -- * Generic operations
22 Container(..), 35 Container(..),
23 -- * Matrix product and related functions 36 -- * Matrix product
24 Product(..), udot, 37 Product(..), udot, dot, (◇),
25 mXm,mXv,vXm, 38 Mul(..),
39 Contraction(..),
40 optimiseMult,
41 mXm,mXv,vXm,LSDiv(..),
26 outer, kronecker, 42 outer, kronecker,
27 -- * Element conversion 43 -- * Element conversion
28 Convert(..), 44 Convert(..),
@@ -32,576 +48,196 @@ module Data.Packed.Numeric (
32 RealOf, ComplexOf, SingleOf, DoubleOf, 48 RealOf, ComplexOf, SingleOf, DoubleOf,
33 49
34 IndexOf, 50 IndexOf,
35 module Data.Complex 51 module Data.Complex,
52 -- * IO
53 module Data.Packed.IO
36) where 54) where
37 55
38import Data.Packed 56import Data.Packed hiding (stepD, stepF, condD, condF, conjugateC, conjugateQ)
39import Data.Packed.ST as ST 57import Data.Packed.Internal.Numeric
40import Numeric.Conversion
41import Data.Packed.Development
42import Numeric.Vectorized
43import Data.Complex 58import Data.Complex
44import Control.Applicative((<*>)) 59import Numeric.LinearAlgebra.Algorithms(Field,linearSolveSVD)
45 60import Data.Monoid(Monoid(mconcat))
46import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ) 61import Data.Packed.IO
47 62
48------------------------------------------------------------------- 63------------------------------------------------------------------
49 64
50type family IndexOf (c :: * -> *) 65{- | Creates a real vector containing a range of values:
51 66
52type instance IndexOf Vector = Int 67>>> linspace 5 (-3,7::Double)
53type instance IndexOf Matrix = (Int,Int) 68fromList [-3.0,-0.5,2.0,4.5,7.0]@
54 69
55type family ArgOf (c :: * -> *) a 70>>> linspace 5 (8,2+i) :: Vector (Complex Double)
71fromList [8.0 :+ 0.0,6.5 :+ 0.25,5.0 :+ 0.5,3.5 :+ 0.75,2.0 :+ 1.0]
56 72
57type instance ArgOf Vector a = a -> a 73Logarithmic spacing can be defined as follows:
58type instance ArgOf Matrix a = a -> a -> a
59 74
60------------------------------------------------------------------- 75@logspace n (a,b) = 10 ** linspace n (a,b)@
61 76-}
62-- | Basic element-by-element functions for numeric containers 77linspace :: (Container Vector e) => Int -> (e, e) -> Vector e
63class (Complexable c, Fractional e, Element e) => Container c e where 78linspace 0 (a,b) = fromList[(a+b)/2]
64 -- | create a structure with a single element 79linspace n (a,b) = addConstant a $ scale s $ fromList $ map fromIntegral [0 .. n-1]
65 -- 80 where s = (b-a)/fromIntegral (n-1)
66 -- >>> let v = fromList [1..3::Double]
67 -- >>> v / scalar (norm2 v)
68 -- fromList [0.2672612419124244,0.5345224838248488,0.8017837257372732]
69 --
70 scalar :: e -> c e
71 -- | complex conjugate
72 conj :: c e -> c e
73 scale :: e -> c e -> c e
74 -- | scale the element by element reciprocal of the object:
75 --
76 -- @scaleRecip 2 (fromList [5,i]) == 2 |> [0.4 :+ 0.0,0.0 :+ (-2.0)]@
77 scaleRecip :: e -> c e -> c e
78 addConstant :: e -> c e -> c e
79 add :: c e -> c e -> c e
80 sub :: c e -> c e -> c e
81 -- | element by element multiplication
82 mul :: c e -> c e -> c e
83 -- | element by element division
84 divide :: c e -> c e -> c e
85 equal :: c e -> c e -> Bool
86 --
87 -- element by element inverse tangent
88 arctan2 :: c e -> c e -> c e
89 --
90 -- | cannot implement instance Functor because of Element class constraint
91 cmap :: (Element b) => (e -> b) -> c e -> c b
92 -- | constant structure of given size
93 konst' :: e -> IndexOf c -> c e
94 -- | create a structure using a function
95 --
96 -- Hilbert matrix of order N:
97 --
98 -- @hilb n = build' (n,n) (\\i j -> 1/(i+j+1))@
99 build' :: IndexOf c -> (ArgOf c e) -> c e
100 -- | indexing function
101 atIndex :: c e -> IndexOf c -> e
102 -- | index of min element
103 minIndex :: c e -> IndexOf c
104 -- | index of max element
105 maxIndex :: c e -> IndexOf c
106 -- | value of min element
107 minElement :: c e -> e
108 -- | value of max element
109 maxElement :: c e -> e
110 -- the C functions sumX/prodX are twice as fast as using foldVector
111 -- | the sum of elements (faster than using @fold@)
112 sumElements :: c e -> e
113 -- | the product of elements (faster than using @fold@)
114 prodElements :: c e -> e
115
116 -- | A more efficient implementation of @cmap (\\x -> if x>0 then 1 else 0)@
117 --
118 -- >>> step $ linspace 5 (-1,1::Double)
119 -- 5 |> [0.0,0.0,0.0,1.0,1.0]
120 --
121
122 step :: RealElement e => c e -> c e
123
124 -- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@.
125 --
126 -- Arguments with any dimension = 1 are automatically expanded:
127 --
128 -- >>> cond ((1><4)[1..]) ((3><1)[1..]) 0 100 ((3><4)[1..]) :: Matrix Double
129 -- (3><4)
130 -- [ 100.0, 2.0, 3.0, 4.0
131 -- , 0.0, 100.0, 7.0, 8.0
132 -- , 0.0, 0.0, 100.0, 12.0 ]
133 --
134
135 cond :: RealElement e
136 => c e -- ^ a
137 -> c e -- ^ b
138 -> c e -- ^ l
139 -> c e -- ^ e
140 -> c e -- ^ g
141 -> c e -- ^ result
142
143 -- | Find index of elements which satisfy a predicate
144 --
145 -- >>> find (>0) (ident 3 :: Matrix Double)
146 -- [(0,0),(1,1),(2,2)]
147 --
148
149 find :: (e -> Bool) -> c e -> [IndexOf c]
150 81
151 -- | Create a structure from an association list 82--------------------------------------------------------
152 --
153 -- >>> assoc 5 0 [(3,7),(1,4)] :: Vector Double
154 -- fromList [0.0,4.0,0.0,7.0,0.0]
155 --
156 -- >>> assoc (2,3) 0 [((0,2),7),((1,0),2*i-3)] :: Matrix (Complex Double)
157 -- (2><3)
158 -- [ 0.0 :+ 0.0, 0.0 :+ 0.0, 7.0 :+ 0.0
159 -- , (-3.0) :+ 2.0, 0.0 :+ 0.0, 0.0 :+ 0.0 ]
160 --
161 assoc :: IndexOf c -- ^ size
162 -> e -- ^ default value
163 -> [(IndexOf c, e)] -- ^ association list
164 -> c e -- ^ result
165 83
166 -- | Modify a structure using an update function 84class Contraction a b c | a b -> c
167 --
168 -- >>> accum (ident 5) (+) [((1,1),5),((0,3),3)] :: Matrix Double
169 -- (5><5)
170 -- [ 1.0, 0.0, 0.0, 3.0, 0.0
171 -- , 0.0, 6.0, 0.0, 0.0, 0.0
172 -- , 0.0, 0.0, 1.0, 0.0, 0.0
173 -- , 0.0, 0.0, 0.0, 1.0, 0.0
174 -- , 0.0, 0.0, 0.0, 0.0, 1.0 ]
175 --
176 -- computation of histogram:
177 --
178 -- >>> accum (konst 0 7) (+) (map (flip (,) 1) [4,5,4,1,5,2,5]) :: Vector Double
179 -- fromList [0.0,1.0,1.0,0.0,2.0,3.0,0.0]
180 --
181
182 accum :: c e -- ^ initial structure
183 -> (e -> e -> e) -- ^ update function
184 -> [(IndexOf c, e)] -- ^ association list
185 -> c e -- ^ result
186
187--------------------------------------------------------------------------
188
189instance Container Vector Float where
190 scale = vectorMapValF Scale
191 scaleRecip = vectorMapValF Recip
192 addConstant = vectorMapValF AddConstant
193 add = vectorZipF Add
194 sub = vectorZipF Sub
195 mul = vectorZipF Mul
196 divide = vectorZipF Div
197 equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0
198 arctan2 = vectorZipF ATan2
199 scalar x = fromList [x]
200 konst' = constant
201 build' = buildV
202 conj = id
203 cmap = mapVector
204 atIndex = (@>)
205 minIndex = emptyErrorV "minIndex" (round . toScalarF MinIdx)
206 maxIndex = emptyErrorV "maxIndex" (round . toScalarF MaxIdx)
207 minElement = emptyErrorV "minElement" (toScalarF Min)
208 maxElement = emptyErrorV "maxElement" (toScalarF Max)
209 sumElements = sumF
210 prodElements = prodF
211 step = stepF
212 find = findV
213 assoc = assocV
214 accum = accumV
215 cond = condV condF
216
217instance Container Vector Double where
218 scale = vectorMapValR Scale
219 scaleRecip = vectorMapValR Recip
220 addConstant = vectorMapValR AddConstant
221 add = vectorZipR Add
222 sub = vectorZipR Sub
223 mul = vectorZipR Mul
224 divide = vectorZipR Div
225 equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0
226 arctan2 = vectorZipR ATan2
227 scalar x = fromList [x]
228 konst' = constant
229 build' = buildV
230 conj = id
231 cmap = mapVector
232 atIndex = (@>)
233 minIndex = emptyErrorV "minIndex" (round . toScalarR MinIdx)
234 maxIndex = emptyErrorV "maxIndex" (round . toScalarR MaxIdx)
235 minElement = emptyErrorV "minElement" (toScalarR Min)
236 maxElement = emptyErrorV "maxElement" (toScalarR Max)
237 sumElements = sumR
238 prodElements = prodR
239 step = stepD
240 find = findV
241 assoc = assocV
242 accum = accumV
243 cond = condV condD
244
245instance Container Vector (Complex Double) where
246 scale = vectorMapValC Scale
247 scaleRecip = vectorMapValC Recip
248 addConstant = vectorMapValC AddConstant
249 add = vectorZipC Add
250 sub = vectorZipC Sub
251 mul = vectorZipC Mul
252 divide = vectorZipC Div
253 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
254 arctan2 = vectorZipC ATan2
255 scalar x = fromList [x]
256 konst' = constant
257 build' = buildV
258 conj = conjugateC
259 cmap = mapVector
260 atIndex = (@>)
261 minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj))
262 maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj))
263 minElement = emptyErrorV "minElement" (atIndex <*> minIndex)
264 maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex)
265 sumElements = sumC
266 prodElements = prodC
267 step = undefined -- cannot match
268 find = findV
269 assoc = assocV
270 accum = accumV
271 cond = undefined -- cannot match
272
273instance Container Vector (Complex Float) where
274 scale = vectorMapValQ Scale
275 scaleRecip = vectorMapValQ Recip
276 addConstant = vectorMapValQ AddConstant
277 add = vectorZipQ Add
278 sub = vectorZipQ Sub
279 mul = vectorZipQ Mul
280 divide = vectorZipQ Div
281 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
282 arctan2 = vectorZipQ ATan2
283 scalar x = fromList [x]
284 konst' = constant
285 build' = buildV
286 conj = conjugateQ
287 cmap = mapVector
288 atIndex = (@>)
289 minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj))
290 maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj))
291 minElement = emptyErrorV "minElement" (atIndex <*> minIndex)
292 maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex)
293 sumElements = sumQ
294 prodElements = prodQ
295 step = undefined -- cannot match
296 find = findV
297 assoc = assocV
298 accum = accumV
299 cond = undefined -- cannot match
300
301---------------------------------------------------------------
302
303instance (Container Vector a) => Container Matrix a where
304 scale x = liftMatrix (scale x)
305 scaleRecip x = liftMatrix (scaleRecip x)
306 addConstant x = liftMatrix (addConstant x)
307 add = liftMatrix2 add
308 sub = liftMatrix2 sub
309 mul = liftMatrix2 mul
310 divide = liftMatrix2 divide
311 equal a b = cols a == cols b && flatten a `equal` flatten b
312 arctan2 = liftMatrix2 arctan2
313 scalar x = (1><1) [x]
314 konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c))
315 build' = buildM
316 conj = liftMatrix conj
317 cmap f = liftMatrix (mapVector f)
318 atIndex = (@@>)
319 minIndex = emptyErrorM "minIndex of Matrix" $
320 \m -> divMod (minIndex $ flatten m) (cols m)
321 maxIndex = emptyErrorM "maxIndex of Matrix" $
322 \m -> divMod (maxIndex $ flatten m) (cols m)
323 minElement = emptyErrorM "minElement of Matrix" (atIndex <*> minIndex)
324 maxElement = emptyErrorM "maxElement of Matrix" (atIndex <*> maxIndex)
325 sumElements = sumElements . flatten
326 prodElements = prodElements . flatten
327 step = liftMatrix step
328 find = findM
329 assoc = assocM
330 accum = accumM
331 cond = condM
332
333
334emptyErrorV msg f v =
335 if dim v > 0
336 then f v
337 else error $ msg ++ " of Vector with dim = 0"
338
339emptyErrorM msg f m =
340 if rows m > 0 && cols m > 0
341 then f m
342 else error $ msg++" "++shSize m
343
344----------------------------------------------------
345
346-- | Matrix product and related functions
347class (Num e, Element e) => Product e where
348 -- | matrix product
349 multiply :: Matrix e -> Matrix e -> Matrix e
350 -- | sum of absolute value of elements (differs in complex case from @norm1@)
351 absSum :: Vector e -> RealOf e
352 -- | sum of absolute value of elements
353 norm1 :: Vector e -> RealOf e
354 -- | euclidean norm
355 norm2 :: Vector e -> RealOf e
356 -- | element of maximum magnitude
357 normInf :: Vector e -> RealOf e
358
359instance Product Float where
360 norm2 = emptyVal (toScalarF Norm2)
361 absSum = emptyVal (toScalarF AbsSum)
362 norm1 = emptyVal (toScalarF AbsSum)
363 normInf = emptyVal (maxElement . vectorMapF Abs)
364 multiply = emptyMul multiplyF
365
366instance Product Double where
367 norm2 = emptyVal (toScalarR Norm2)
368 absSum = emptyVal (toScalarR AbsSum)
369 norm1 = emptyVal (toScalarR AbsSum)
370 normInf = emptyVal (maxElement . vectorMapR Abs)
371 multiply = emptyMul multiplyR
372
373instance Product (Complex Float) where
374 norm2 = emptyVal (toScalarQ Norm2)
375 absSum = emptyVal (toScalarQ AbsSum)
376 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs)
377 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs)
378 multiply = emptyMul multiplyQ
379
380instance Product (Complex Double) where
381 norm2 = emptyVal (toScalarC Norm2)
382 absSum = emptyVal (toScalarC AbsSum)
383 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs)
384 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs)
385 multiply = emptyMul multiplyC
386
387emptyMul m a b
388 | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c)
389 | otherwise = m a b
390 where
391 r = rows a
392 x1 = cols a
393 x2 = rows b
394 c = cols b
395
396emptyVal f v =
397 if dim v > 0
398 then f v
399 else 0
400
401-- FIXME remove unused C wrappers
402-- | unconjugated dot product
403udot :: Product e => Vector e -> Vector e -> e
404udot u v
405 | dim u == dim v = val (asRow u `multiply` asColumn v)
406 | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product"
407 where 85 where
408 val m | dim u > 0 = m@@>(0,0) 86 infixl 7 <.>
409 | otherwise = 0 87 {- | Matrix product, matrix vector product, and dot product
410 88
411---------------------------------------------------------- 89Examples:
412 90
413-- synonym for matrix product 91>>> let a = (3><4) [1..] :: Matrix Double
414mXm :: Product t => Matrix t -> Matrix t -> Matrix t 92>>> let v = fromList [1,0,2,-1] :: Vector Double
415mXm = multiply 93>>> let u = fromList [1,2,3] :: Vector Double
416 94
417-- matrix - vector product 95>>> a
418mXv :: Product t => Matrix t -> Vector t -> Vector t 96(3><4)
419mXv m v = flatten $ m `mXm` (asColumn v) 97 [ 1.0, 2.0, 3.0, 4.0
98 , 5.0, 6.0, 7.0, 8.0
99 , 9.0, 10.0, 11.0, 12.0 ]
420 100
421-- vector - matrix product 101matrix × matrix:
422vXm :: Product t => Vector t -> Matrix t -> Vector t
423vXm v m = flatten $ (asRow v) `mXm` m
424 102
425{- | Outer product of two vectors. 103>>> disp 2 (a <.> trans a)
1043x3
105 30 70 110
106 70 174 278
107110 278 446
426 108
427>>> fromList [1,2,3] `outer` fromList [5,2,3] 109matrix × vector:
428(3><3)
429 [ 5.0, 2.0, 3.0
430 , 10.0, 4.0, 6.0
431 , 15.0, 6.0, 9.0 ]
432
433-}
434outer :: (Product t) => Vector t -> Vector t -> Matrix t
435outer u v = asColumn u `multiply` asRow v
436
437{- | Kronecker product of two matrices.
438
439@m1=(2><3)
440 [ 1.0, 2.0, 0.0
441 , 0.0, -1.0, 3.0 ]
442m2=(4><3)
443 [ 1.0, 2.0, 3.0
444 , 4.0, 5.0, 6.0
445 , 7.0, 8.0, 9.0
446 , 10.0, 11.0, 12.0 ]@
447
448>>> kronecker m1 m2
449(8><9)
450 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
451 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
452 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
453 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
454 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
455 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
456 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
457 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]
458
459-}
460kronecker :: (Product t) => Matrix t -> Matrix t -> Matrix t
461kronecker a b = fromBlocks
462 . splitEvery (cols a)
463 . map (reshape (cols b))
464 . toRows
465 $ flatten a `outer` flatten b
466 110
467------------------------------------------------------------------- 111>>> a <.> v
112fromList [3.0,11.0,19.0]
468 113
114dot product:
469 115
470class Convert t where 116>>> u <.> fromList[3,2,1::Double]
471 real :: Container c t => c (RealOf t) -> c t 11710
472 complex :: Container c t => c t -> c (ComplexOf t)
473 single :: Container c t => c t -> c (SingleOf t)
474 double :: Container c t => c t -> c (DoubleOf t)
475 toComplex :: (Container c t, RealElement t) => (c t, c t) -> c (Complex t)
476 fromComplex :: (Container c t, RealElement t) => c (Complex t) -> (c t, c t)
477 118
119For complex vectors the first argument is conjugated:
478 120
479instance Convert Double where 121>>> fromList [1,i] <.> fromList[2*i+1,3]
480 real = id 1221.0 :+ (-1.0)
481 complex = comp'
482 single = single'
483 double = id
484 toComplex = toComplex'
485 fromComplex = fromComplex'
486 123
487instance Convert Float where 124>>> fromList [1,i,1-i] <.> complex a
488 real = id 125fromList [10.0 :+ 4.0,12.0 :+ 4.0,14.0 :+ 4.0,16.0 :+ 4.0]
489 complex = comp'
490 single = id
491 double = double'
492 toComplex = toComplex'
493 fromComplex = fromComplex'
494 126
495instance Convert (Complex Double) where 127-}
496 real = comp' 128 (<.>) :: a -> b -> c
497 complex = id
498 single = single'
499 double = id
500 toComplex = toComplex'
501 fromComplex = fromComplex'
502
503instance Convert (Complex Float) where
504 real = comp'
505 complex = id
506 single = id
507 double = double'
508 toComplex = toComplex'
509 fromComplex = fromComplex'
510
511-------------------------------------------------------------------
512 129
513type family RealOf x
514 130
515type instance RealOf Double = Double 131instance (Product t, Container Vector t) => Contraction (Vector t) (Vector t) t where
516type instance RealOf (Complex Double) = Double 132 u <.> v = conj u `udot` v
517 133
518type instance RealOf Float = Float 134instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where
519type instance RealOf (Complex Float) = Float 135 (<.>) = mXv
520 136
521type family ComplexOf x 137instance (Container Vector t, Product t) => Contraction (Vector t) (Matrix t) (Vector t) where
138 (<.>) v m = (conj v) `vXm` m
522 139
523type instance ComplexOf Double = Complex Double 140instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where
524type instance ComplexOf (Complex Double) = Complex Double 141 (<.>) = mXm
525 142
526type instance ComplexOf Float = Complex Float
527type instance ComplexOf (Complex Float) = Complex Float
528 143
529type family SingleOf x 144--------------------------------------------------------------------------------
530 145
531type instance SingleOf Double = Float 146class Mul a b c | a b -> c where
532type instance SingleOf Float = Float 147 infixl 7 <>
148 -- | Matrix-matrix, matrix-vector, and vector-matrix products.
149 (<>) :: Product t => a t -> b t -> c t
533 150
534type instance SingleOf (Complex a) = Complex (SingleOf a) 151instance Mul Matrix Matrix Matrix where
152 (<>) = mXm
535 153
536type family DoubleOf x 154instance Mul Matrix Vector Vector where
155 (<>) m v = flatten $ m <> asColumn v
537 156
538type instance DoubleOf Double = Double 157instance Mul Vector Matrix Vector where
539type instance DoubleOf Float = Double 158 (<>) v m = flatten $ asRow v <> m
540 159
541type instance DoubleOf (Complex a) = Complex (DoubleOf a) 160--------------------------------------------------------------------------------
542 161
543type family ElementOf c 162class LSDiv c where
163 infixl 7 <\>
164 -- | least squares solution of a linear system, similar to the \\ operator of Matlab\/Octave (based on linearSolveSVD)
165 (<\>) :: Field t => Matrix t -> c t -> c t
544 166
545type instance ElementOf (Vector a) = a 167instance LSDiv Vector where
546type instance ElementOf (Matrix a) = a 168 m <\> v = flatten (linearSolveSVD m (reshape 1 v))
547 169
548------------------------------------------------------------ 170instance LSDiv Matrix where
171 (<\>) = linearSolveSVD
549 172
550buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ] 173--------------------------------------------------------------------------------
551 where rs = map fromIntegral [0 .. (rc-1)]
552 cs = map fromIntegral [0 .. (cc-1)]
553 174
554buildV n f = fromList [f k | k <- ks] 175class Konst e d c | d -> c, c -> d
555 where ks = map fromIntegral [0 .. (n-1)] 176 where
177 -- |
178 -- >>> konst 7 3 :: Vector Float
179 -- fromList [7.0,7.0,7.0]
180 --
181 -- >>> konst i (3::Int,4::Int)
182 -- (3><4)
183 -- [ 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0
184 -- , 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0
185 -- , 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0 ]
186 --
187 konst :: e -> d -> c e
556 188
557-------------------------------------------------------- 189instance Container Vector e => Konst e Int Vector
558-- | conjugate transpose 190 where
559ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e 191 konst = konst'
560ctrans = liftMatrix conj . trans
561 192
562-- | Creates a square matrix with a given diagonal. 193instance Container Vector e => Konst e (Int,Int) Matrix
563diag :: (Num a, Element a) => Vector a -> Matrix a 194 where
564diag v = diagRect 0 v n n where n = dim v 195 konst = konst'
565 196
566-- | creates the identity matrix of given dimension 197--------------------------------------------------------------------------------
567ident :: (Num a, Element a) => Int -> Matrix a
568ident n = diag (constant 1 n)
569 198
570-------------------------------------------------------- 199class Build d f c e | d -> c, c -> d, f -> e, f -> d, f -> c, c e -> f, d e -> f
200 where
201 -- |
202 -- >>> build 5 (**2) :: Vector Double
203 -- fromList [0.0,1.0,4.0,9.0,16.0]
204 --
205 -- Hilbert matrix of order N:
206 --
207 -- >>> let hilb n = build (n,n) (\i j -> 1/(i+j+1)) :: Matrix Double
208 -- >>> putStr . dispf 2 $ hilb 3
209 -- 3x3
210 -- 1.00 0.50 0.33
211 -- 0.50 0.33 0.25
212 -- 0.33 0.25 0.20
213 --
214 build :: d -> f -> c e
571 215
572findV p x = foldVectorWithIndex g [] x where 216instance Container Vector e => Build Int (e -> e) Vector e
573 g k z l = if p z then k:l else l 217 where
218 build = build'
574 219
575findM p x = map ((`divMod` cols x)) $ findV p (flatten x) 220instance Container Matrix e => Build (Int,Int) (e -> e -> e) Matrix e
221 where
222 build = build'
576 223
577assocV n z xs = ST.runSTVector $ do 224--------------------------------------------------------------------------------
578 v <- ST.newVector z n
579 mapM_ (\(k,x) -> ST.writeVector v k x) xs
580 return v
581 225
582assocM (r,c) z xs = ST.runSTMatrix $ do 226{- | alternative operator for '(\<.\>)'
583 m <- ST.newMatrix z r c
584 mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs
585 return m
586 227
587accumV v0 f xs = ST.runSTVector $ do 228x25c7, white diamond
588 v <- ST.thawVector v0
589 mapM_ (\(k,x) -> ST.modifyVector v k (f x)) xs
590 return v
591 229
592accumM m0 f xs = ST.runSTMatrix $ do 230-}
593 m <- ST.thawMatrix m0 231(◇) :: Contraction a b c => a -> b -> c
594 mapM_ (\((i,j),x) -> ST.modifyMatrix m i j (f x)) xs 232infixl 7
595 return m 233(◇) = (<.>)
596 234
597---------------------------------------------------------------------- 235-- | dot product: @cdot u v = 'udot' ('conj' u) v@
236dot :: (Container Vector t, Product t) => Vector t -> Vector t -> t
237dot u v = udot (conj u) v
598 238
599condM a b l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cond a' b' l' e' t' 239--------------------------------------------------------------------------------
600 where
601 args@(a'':_) = conformMs [a,b,l,e,t]
602 [a', b', l', e', t'] = map flatten args
603 240
604condV f a b l e t = f a' b' l' e' t' 241optimiseMult :: Monoid (Matrix t) => [Matrix t] -> Matrix t
605 where 242optimiseMult = mconcat
606 [a', b', l', e', t'] = conformVs [a,b,l,e,t]
607 243