summaryrefslogtreecommitdiff
path: root/lib/Numeric
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric')
-rw-r--r--lib/Numeric/Container.hs2
-rw-r--r--lib/Numeric/ContainerBoot.hs77
-rw-r--r--lib/Numeric/GSL/Vector.hs3
3 files changed, 51 insertions, 31 deletions
diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs
index a71fdfe..b145a26 100644
--- a/lib/Numeric/Container.hs
+++ b/lib/Numeric/Container.hs
@@ -36,7 +36,7 @@ module Numeric.Container (
36 -- * Generic operations 36 -- * Generic operations
37 Container(..), 37 Container(..),
38 -- * Matrix product 38 -- * Matrix product
39 Product(..), 39 Product(..), udot,
40 Mul(..), 40 Mul(..),
41 Contraction(..), mmul, 41 Contraction(..), mmul,
42 optimiseMult, 42 optimiseMult,
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs
index a333489..6445e04 100644
--- a/lib/Numeric/ContainerBoot.hs
+++ b/lib/Numeric/ContainerBoot.hs
@@ -25,7 +25,7 @@ module Numeric.ContainerBoot (
25 -- * Generic operations 25 -- * Generic operations
26 Container(..), 26 Container(..),
27 -- * Matrix product and related functions 27 -- * Matrix product and related functions
28 Product(..), 28 Product(..), udot,
29 mXm,mXv,vXm, 29 mXm,mXv,vXm,
30 outer, kronecker, 30 outer, kronecker,
31 -- * Element conversion 31 -- * Element conversion
@@ -315,7 +315,7 @@ instance (Container Vector a) => Container Matrix a where
315 equal a b = cols a == cols b && flatten a `equal` flatten b 315 equal a b = cols a == cols b && flatten a `equal` flatten b
316 arctan2 = liftMatrix2 arctan2 316 arctan2 = liftMatrix2 arctan2
317 scalar x = (1><1) [x] 317 scalar x = (1><1) [x]
318 konst' v (r,c) = reshape c (konst' v (r*c)) 318 konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c))
319 build' = buildM 319 build' = buildM
320 conj = liftMatrix conj 320 conj = liftMatrix conj
321 cmap f = liftMatrix (mapVector f) 321 cmap f = liftMatrix (mapVector f)
@@ -339,11 +339,9 @@ instance (Container Vector a) => Container Matrix a where
339---------------------------------------------------- 339----------------------------------------------------
340 340
341-- | Matrix product and related functions 341-- | Matrix product and related functions
342class Element e => Product e where 342class (Num e, Element e) => Product e where
343 -- | matrix product 343 -- | matrix product
344 multiply :: Matrix e -> Matrix e -> Matrix e 344 multiply :: Matrix e -> Matrix e -> Matrix e
345 -- | (unconjugated) dot product
346 udot :: Vector e -> Vector e -> e
347 -- | sum of absolute value of elements (differs in complex case from @norm1@) 345 -- | sum of absolute value of elements (differs in complex case from @norm1@)
348 absSum :: Vector e -> RealOf e 346 absSum :: Vector e -> RealOf e
349 -- | sum of absolute value of elements 347 -- | sum of absolute value of elements
@@ -354,36 +352,57 @@ class Element e => Product e where
354 normInf :: Vector e -> RealOf e 352 normInf :: Vector e -> RealOf e
355 353
356instance Product Float where 354instance Product Float where
357 norm2 = toScalarF Norm2 355 norm2 = emptyVal (toScalarF Norm2)
358 absSum = toScalarF AbsSum 356 absSum = emptyVal (toScalarF AbsSum)
359 udot = dotF 357 norm1 = emptyVal (toScalarF AbsSum)
360 norm1 = toScalarF AbsSum 358 normInf = emptyVal (maxElement . vectorMapF Abs)
361 normInf = maxElement . vectorMapF Abs 359 multiply = emptyMul multiplyF
362 multiply = multiplyF
363 360
364instance Product Double where 361instance Product Double where
365 norm2 = toScalarR Norm2 362 norm2 = emptyVal (toScalarR Norm2)
366 absSum = toScalarR AbsSum 363 absSum = emptyVal (toScalarR AbsSum)
367 udot = dotR 364 norm1 = emptyVal (toScalarR AbsSum)
368 norm1 = toScalarR AbsSum 365 normInf = emptyVal (maxElement . vectorMapR Abs)
369 normInf = maxElement . vectorMapR Abs 366 multiply = emptyMul multiplyR
370 multiply = multiplyR
371 367
372instance Product (Complex Float) where 368instance Product (Complex Float) where
373 norm2 = toScalarQ Norm2 369 norm2 = emptyVal (toScalarQ Norm2)
374 absSum = toScalarQ AbsSum 370 absSum = emptyVal (toScalarQ AbsSum)
375 udot = dotQ 371 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs)
376 norm1 = sumElements . fst . fromComplex . vectorMapQ Abs 372 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs)
377 normInf = maxElement . fst . fromComplex . vectorMapQ Abs 373 multiply = emptyMul multiplyQ
378 multiply = multiplyQ
379 374
380instance Product (Complex Double) where 375instance Product (Complex Double) where
381 norm2 = toScalarC Norm2 376 norm2 = emptyVal (toScalarC Norm2)
382 absSum = toScalarC AbsSum 377 absSum = emptyVal (toScalarC AbsSum)
383 udot = dotC 378 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs)
384 norm1 = sumElements . fst . fromComplex . vectorMapC Abs 379 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs)
385 normInf = maxElement . fst . fromComplex . vectorMapC Abs 380 multiply = emptyMul multiplyC
386 multiply = multiplyC 381
382emptyMul m a b
383 | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c)
384 | otherwise = m a b
385 where
386 r = rows a
387 x1 = cols a
388 x2 = rows b
389 c = cols b
390
391emptyVal f v =
392 if dim v > 0
393 then f v
394 else 0
395
396
397-- FIXME remove unused C wrappers
398-- | (unconjugated) dot product
399udot :: Product e => Vector e -> Vector e -> e
400udot u v
401 | dim u == dim v = val (asRow u `multiply` asColumn v)
402 | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product"
403 where
404 val m | dim u > 0 = m@@>(0,0)
405 | otherwise = 0
387 406
388---------------------------------------------------------- 407----------------------------------------------------------
389 408
diff --git a/lib/Numeric/GSL/Vector.hs b/lib/Numeric/GSL/Vector.hs
index db34041..6204b8e 100644
--- a/lib/Numeric/GSL/Vector.hs
+++ b/lib/Numeric/GSL/Vector.hs
@@ -33,6 +33,7 @@ import Foreign.Marshal.Array(newArray)
33import Foreign.Ptr(Ptr) 33import Foreign.Ptr(Ptr)
34import Foreign.C.Types 34import Foreign.C.Types
35import System.IO.Unsafe(unsafePerformIO) 35import System.IO.Unsafe(unsafePerformIO)
36import Control.Monad(when)
36 37
37fromei x = fromIntegral (fromEnum x) :: CInt 38fromei x = fromIntegral (fromEnum x) :: CInt
38 39
@@ -201,7 +202,7 @@ vectorMapValAux fun code val v = unsafePerformIO $ do
201 202
202vectorZipAux fun code u v = unsafePerformIO $ do 203vectorZipAux fun code u v = unsafePerformIO $ do
203 r <- createVector (dim u) 204 r <- createVector (dim u)
204 app3 (fun (fromei code)) vec u vec v vec r "vectorZipAux" 205 when (dim u > 0) $ app3 (fun (fromei code)) vec u vec v vec r "vectorZipAux"
205 return r 206 return r
206 207
207--------------------------------------------------------------------- 208---------------------------------------------------------------------