summaryrefslogtreecommitdiff
path: root/lib/Numeric/ContainerBoot.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric/ContainerBoot.hs')
-rw-r--r--lib/Numeric/ContainerBoot.hs77
1 files changed, 48 insertions, 29 deletions
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs
index a333489..6445e04 100644
--- a/lib/Numeric/ContainerBoot.hs
+++ b/lib/Numeric/ContainerBoot.hs
@@ -25,7 +25,7 @@ module Numeric.ContainerBoot (
25 -- * Generic operations 25 -- * Generic operations
26 Container(..), 26 Container(..),
27 -- * Matrix product and related functions 27 -- * Matrix product and related functions
28 Product(..), 28 Product(..), udot,
29 mXm,mXv,vXm, 29 mXm,mXv,vXm,
30 outer, kronecker, 30 outer, kronecker,
31 -- * Element conversion 31 -- * Element conversion
@@ -315,7 +315,7 @@ instance (Container Vector a) => Container Matrix a where
315 equal a b = cols a == cols b && flatten a `equal` flatten b 315 equal a b = cols a == cols b && flatten a `equal` flatten b
316 arctan2 = liftMatrix2 arctan2 316 arctan2 = liftMatrix2 arctan2
317 scalar x = (1><1) [x] 317 scalar x = (1><1) [x]
318 konst' v (r,c) = reshape c (konst' v (r*c)) 318 konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c))
319 build' = buildM 319 build' = buildM
320 conj = liftMatrix conj 320 conj = liftMatrix conj
321 cmap f = liftMatrix (mapVector f) 321 cmap f = liftMatrix (mapVector f)
@@ -339,11 +339,9 @@ instance (Container Vector a) => Container Matrix a where
339---------------------------------------------------- 339----------------------------------------------------
340 340
341-- | Matrix product and related functions 341-- | Matrix product and related functions
342class Element e => Product e where 342class (Num e, Element e) => Product e where
343 -- | matrix product 343 -- | matrix product
344 multiply :: Matrix e -> Matrix e -> Matrix e 344 multiply :: Matrix e -> Matrix e -> Matrix e
345 -- | (unconjugated) dot product
346 udot :: Vector e -> Vector e -> e
347 -- | sum of absolute value of elements (differs in complex case from @norm1@) 345 -- | sum of absolute value of elements (differs in complex case from @norm1@)
348 absSum :: Vector e -> RealOf e 346 absSum :: Vector e -> RealOf e
349 -- | sum of absolute value of elements 347 -- | sum of absolute value of elements
@@ -354,36 +352,57 @@ class Element e => Product e where
354 normInf :: Vector e -> RealOf e 352 normInf :: Vector e -> RealOf e
355 353
356instance Product Float where 354instance Product Float where
357 norm2 = toScalarF Norm2 355 norm2 = emptyVal (toScalarF Norm2)
358 absSum = toScalarF AbsSum 356 absSum = emptyVal (toScalarF AbsSum)
359 udot = dotF 357 norm1 = emptyVal (toScalarF AbsSum)
360 norm1 = toScalarF AbsSum 358 normInf = emptyVal (maxElement . vectorMapF Abs)
361 normInf = maxElement . vectorMapF Abs 359 multiply = emptyMul multiplyF
362 multiply = multiplyF
363 360
364instance Product Double where 361instance Product Double where
365 norm2 = toScalarR Norm2 362 norm2 = emptyVal (toScalarR Norm2)
366 absSum = toScalarR AbsSum 363 absSum = emptyVal (toScalarR AbsSum)
367 udot = dotR 364 norm1 = emptyVal (toScalarR AbsSum)
368 norm1 = toScalarR AbsSum 365 normInf = emptyVal (maxElement . vectorMapR Abs)
369 normInf = maxElement . vectorMapR Abs 366 multiply = emptyMul multiplyR
370 multiply = multiplyR
371 367
372instance Product (Complex Float) where 368instance Product (Complex Float) where
373 norm2 = toScalarQ Norm2 369 norm2 = emptyVal (toScalarQ Norm2)
374 absSum = toScalarQ AbsSum 370 absSum = emptyVal (toScalarQ AbsSum)
375 udot = dotQ 371 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs)
376 norm1 = sumElements . fst . fromComplex . vectorMapQ Abs 372 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs)
377 normInf = maxElement . fst . fromComplex . vectorMapQ Abs 373 multiply = emptyMul multiplyQ
378 multiply = multiplyQ
379 374
380instance Product (Complex Double) where 375instance Product (Complex Double) where
381 norm2 = toScalarC Norm2 376 norm2 = emptyVal (toScalarC Norm2)
382 absSum = toScalarC AbsSum 377 absSum = emptyVal (toScalarC AbsSum)
383 udot = dotC 378 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs)
384 norm1 = sumElements . fst . fromComplex . vectorMapC Abs 379 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs)
385 normInf = maxElement . fst . fromComplex . vectorMapC Abs 380 multiply = emptyMul multiplyC
386 multiply = multiplyC 381
382emptyMul m a b
383 | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c)
384 | otherwise = m a b
385 where
386 r = rows a
387 x1 = cols a
388 x2 = rows b
389 c = cols b
390
391emptyVal f v =
392 if dim v > 0
393 then f v
394 else 0
395
396
397-- FIXME remove unused C wrappers
398-- | (unconjugated) dot product
399udot :: Product e => Vector e -> Vector e -> e
400udot u v
401 | dim u == dim v = val (asRow u `multiply` asColumn v)
402 | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product"
403 where
404 val m | dim u > 0 = m@@>(0,0)
405 | otherwise = 0
387 406
388---------------------------------------------------------- 407----------------------------------------------------------
389 408