diff options
Diffstat (limited to 'lib/Numeric/ContainerBoot.hs')
-rw-r--r-- | lib/Numeric/ContainerBoot.hs | 77 |
1 files changed, 48 insertions, 29 deletions
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index a333489..6445e04 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs | |||
@@ -25,7 +25,7 @@ module Numeric.ContainerBoot ( | |||
25 | -- * Generic operations | 25 | -- * Generic operations |
26 | Container(..), | 26 | Container(..), |
27 | -- * Matrix product and related functions | 27 | -- * Matrix product and related functions |
28 | Product(..), | 28 | Product(..), udot, |
29 | mXm,mXv,vXm, | 29 | mXm,mXv,vXm, |
30 | outer, kronecker, | 30 | outer, kronecker, |
31 | -- * Element conversion | 31 | -- * Element conversion |
@@ -315,7 +315,7 @@ instance (Container Vector a) => Container Matrix a where | |||
315 | equal a b = cols a == cols b && flatten a `equal` flatten b | 315 | equal a b = cols a == cols b && flatten a `equal` flatten b |
316 | arctan2 = liftMatrix2 arctan2 | 316 | arctan2 = liftMatrix2 arctan2 |
317 | scalar x = (1><1) [x] | 317 | scalar x = (1><1) [x] |
318 | konst' v (r,c) = reshape c (konst' v (r*c)) | 318 | konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c)) |
319 | build' = buildM | 319 | build' = buildM |
320 | conj = liftMatrix conj | 320 | conj = liftMatrix conj |
321 | cmap f = liftMatrix (mapVector f) | 321 | cmap f = liftMatrix (mapVector f) |
@@ -339,11 +339,9 @@ instance (Container Vector a) => Container Matrix a where | |||
339 | ---------------------------------------------------- | 339 | ---------------------------------------------------- |
340 | 340 | ||
341 | -- | Matrix product and related functions | 341 | -- | Matrix product and related functions |
342 | class Element e => Product e where | 342 | class (Num e, Element e) => Product e where |
343 | -- | matrix product | 343 | -- | matrix product |
344 | multiply :: Matrix e -> Matrix e -> Matrix e | 344 | multiply :: Matrix e -> Matrix e -> Matrix e |
345 | -- | (unconjugated) dot product | ||
346 | udot :: Vector e -> Vector e -> e | ||
347 | -- | sum of absolute value of elements (differs in complex case from @norm1@) | 345 | -- | sum of absolute value of elements (differs in complex case from @norm1@) |
348 | absSum :: Vector e -> RealOf e | 346 | absSum :: Vector e -> RealOf e |
349 | -- | sum of absolute value of elements | 347 | -- | sum of absolute value of elements |
@@ -354,36 +352,57 @@ class Element e => Product e where | |||
354 | normInf :: Vector e -> RealOf e | 352 | normInf :: Vector e -> RealOf e |
355 | 353 | ||
356 | instance Product Float where | 354 | instance Product Float where |
357 | norm2 = toScalarF Norm2 | 355 | norm2 = emptyVal (toScalarF Norm2) |
358 | absSum = toScalarF AbsSum | 356 | absSum = emptyVal (toScalarF AbsSum) |
359 | udot = dotF | 357 | norm1 = emptyVal (toScalarF AbsSum) |
360 | norm1 = toScalarF AbsSum | 358 | normInf = emptyVal (maxElement . vectorMapF Abs) |
361 | normInf = maxElement . vectorMapF Abs | 359 | multiply = emptyMul multiplyF |
362 | multiply = multiplyF | ||
363 | 360 | ||
364 | instance Product Double where | 361 | instance Product Double where |
365 | norm2 = toScalarR Norm2 | 362 | norm2 = emptyVal (toScalarR Norm2) |
366 | absSum = toScalarR AbsSum | 363 | absSum = emptyVal (toScalarR AbsSum) |
367 | udot = dotR | 364 | norm1 = emptyVal (toScalarR AbsSum) |
368 | norm1 = toScalarR AbsSum | 365 | normInf = emptyVal (maxElement . vectorMapR Abs) |
369 | normInf = maxElement . vectorMapR Abs | 366 | multiply = emptyMul multiplyR |
370 | multiply = multiplyR | ||
371 | 367 | ||
372 | instance Product (Complex Float) where | 368 | instance Product (Complex Float) where |
373 | norm2 = toScalarQ Norm2 | 369 | norm2 = emptyVal (toScalarQ Norm2) |
374 | absSum = toScalarQ AbsSum | 370 | absSum = emptyVal (toScalarQ AbsSum) |
375 | udot = dotQ | 371 | norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs) |
376 | norm1 = sumElements . fst . fromComplex . vectorMapQ Abs | 372 | normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs) |
377 | normInf = maxElement . fst . fromComplex . vectorMapQ Abs | 373 | multiply = emptyMul multiplyQ |
378 | multiply = multiplyQ | ||
379 | 374 | ||
380 | instance Product (Complex Double) where | 375 | instance Product (Complex Double) where |
381 | norm2 = toScalarC Norm2 | 376 | norm2 = emptyVal (toScalarC Norm2) |
382 | absSum = toScalarC AbsSum | 377 | absSum = emptyVal (toScalarC AbsSum) |
383 | udot = dotC | 378 | norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs) |
384 | norm1 = sumElements . fst . fromComplex . vectorMapC Abs | 379 | normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs) |
385 | normInf = maxElement . fst . fromComplex . vectorMapC Abs | 380 | multiply = emptyMul multiplyC |
386 | multiply = multiplyC | 381 | |
382 | emptyMul m a b | ||
383 | | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c) | ||
384 | | otherwise = m a b | ||
385 | where | ||
386 | r = rows a | ||
387 | x1 = cols a | ||
388 | x2 = rows b | ||
389 | c = cols b | ||
390 | |||
391 | emptyVal f v = | ||
392 | if dim v > 0 | ||
393 | then f v | ||
394 | else 0 | ||
395 | |||
396 | |||
397 | -- FIXME remove unused C wrappers | ||
398 | -- | (unconjugated) dot product | ||
399 | udot :: Product e => Vector e -> Vector e -> e | ||
400 | udot u v | ||
401 | | dim u == dim v = val (asRow u `multiply` asColumn v) | ||
402 | | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product" | ||
403 | where | ||
404 | val m | dim u > 0 = m@@>(0,0) | ||
405 | | otherwise = 0 | ||
387 | 406 | ||
388 | ---------------------------------------------------------- | 407 | ---------------------------------------------------------- |
389 | 408 | ||