diff options
author | Alberto Ruiz <aruiz@um.es> | 2014-05-04 21:08:51 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2014-05-04 21:08:51 +0200 |
commit | 4078cf44c98b42960be27843782f6983bb66017f (patch) | |
tree | bee20c3c811a98247aab99738991ab4b2bcc2312 /lib/Numeric | |
parent | ae104ebd5891c84f9c8b4a40501fefdeeb1280c4 (diff) |
allow empty arrays
Diffstat (limited to 'lib/Numeric')
-rw-r--r-- | lib/Numeric/Container.hs | 2 | ||||
-rw-r--r-- | lib/Numeric/ContainerBoot.hs | 77 | ||||
-rw-r--r-- | lib/Numeric/GSL/Vector.hs | 3 |
3 files changed, 51 insertions, 31 deletions
diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs index a71fdfe..b145a26 100644 --- a/lib/Numeric/Container.hs +++ b/lib/Numeric/Container.hs | |||
@@ -36,7 +36,7 @@ module Numeric.Container ( | |||
36 | -- * Generic operations | 36 | -- * Generic operations |
37 | Container(..), | 37 | Container(..), |
38 | -- * Matrix product | 38 | -- * Matrix product |
39 | Product(..), | 39 | Product(..), udot, |
40 | Mul(..), | 40 | Mul(..), |
41 | Contraction(..), mmul, | 41 | Contraction(..), mmul, |
42 | optimiseMult, | 42 | optimiseMult, |
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index a333489..6445e04 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs | |||
@@ -25,7 +25,7 @@ module Numeric.ContainerBoot ( | |||
25 | -- * Generic operations | 25 | -- * Generic operations |
26 | Container(..), | 26 | Container(..), |
27 | -- * Matrix product and related functions | 27 | -- * Matrix product and related functions |
28 | Product(..), | 28 | Product(..), udot, |
29 | mXm,mXv,vXm, | 29 | mXm,mXv,vXm, |
30 | outer, kronecker, | 30 | outer, kronecker, |
31 | -- * Element conversion | 31 | -- * Element conversion |
@@ -315,7 +315,7 @@ instance (Container Vector a) => Container Matrix a where | |||
315 | equal a b = cols a == cols b && flatten a `equal` flatten b | 315 | equal a b = cols a == cols b && flatten a `equal` flatten b |
316 | arctan2 = liftMatrix2 arctan2 | 316 | arctan2 = liftMatrix2 arctan2 |
317 | scalar x = (1><1) [x] | 317 | scalar x = (1><1) [x] |
318 | konst' v (r,c) = reshape c (konst' v (r*c)) | 318 | konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c)) |
319 | build' = buildM | 319 | build' = buildM |
320 | conj = liftMatrix conj | 320 | conj = liftMatrix conj |
321 | cmap f = liftMatrix (mapVector f) | 321 | cmap f = liftMatrix (mapVector f) |
@@ -339,11 +339,9 @@ instance (Container Vector a) => Container Matrix a where | |||
339 | ---------------------------------------------------- | 339 | ---------------------------------------------------- |
340 | 340 | ||
341 | -- | Matrix product and related functions | 341 | -- | Matrix product and related functions |
342 | class Element e => Product e where | 342 | class (Num e, Element e) => Product e where |
343 | -- | matrix product | 343 | -- | matrix product |
344 | multiply :: Matrix e -> Matrix e -> Matrix e | 344 | multiply :: Matrix e -> Matrix e -> Matrix e |
345 | -- | (unconjugated) dot product | ||
346 | udot :: Vector e -> Vector e -> e | ||
347 | -- | sum of absolute value of elements (differs in complex case from @norm1@) | 345 | -- | sum of absolute value of elements (differs in complex case from @norm1@) |
348 | absSum :: Vector e -> RealOf e | 346 | absSum :: Vector e -> RealOf e |
349 | -- | sum of absolute value of elements | 347 | -- | sum of absolute value of elements |
@@ -354,36 +352,57 @@ class Element e => Product e where | |||
354 | normInf :: Vector e -> RealOf e | 352 | normInf :: Vector e -> RealOf e |
355 | 353 | ||
356 | instance Product Float where | 354 | instance Product Float where |
357 | norm2 = toScalarF Norm2 | 355 | norm2 = emptyVal (toScalarF Norm2) |
358 | absSum = toScalarF AbsSum | 356 | absSum = emptyVal (toScalarF AbsSum) |
359 | udot = dotF | 357 | norm1 = emptyVal (toScalarF AbsSum) |
360 | norm1 = toScalarF AbsSum | 358 | normInf = emptyVal (maxElement . vectorMapF Abs) |
361 | normInf = maxElement . vectorMapF Abs | 359 | multiply = emptyMul multiplyF |
362 | multiply = multiplyF | ||
363 | 360 | ||
364 | instance Product Double where | 361 | instance Product Double where |
365 | norm2 = toScalarR Norm2 | 362 | norm2 = emptyVal (toScalarR Norm2) |
366 | absSum = toScalarR AbsSum | 363 | absSum = emptyVal (toScalarR AbsSum) |
367 | udot = dotR | 364 | norm1 = emptyVal (toScalarR AbsSum) |
368 | norm1 = toScalarR AbsSum | 365 | normInf = emptyVal (maxElement . vectorMapR Abs) |
369 | normInf = maxElement . vectorMapR Abs | 366 | multiply = emptyMul multiplyR |
370 | multiply = multiplyR | ||
371 | 367 | ||
372 | instance Product (Complex Float) where | 368 | instance Product (Complex Float) where |
373 | norm2 = toScalarQ Norm2 | 369 | norm2 = emptyVal (toScalarQ Norm2) |
374 | absSum = toScalarQ AbsSum | 370 | absSum = emptyVal (toScalarQ AbsSum) |
375 | udot = dotQ | 371 | norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs) |
376 | norm1 = sumElements . fst . fromComplex . vectorMapQ Abs | 372 | normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs) |
377 | normInf = maxElement . fst . fromComplex . vectorMapQ Abs | 373 | multiply = emptyMul multiplyQ |
378 | multiply = multiplyQ | ||
379 | 374 | ||
380 | instance Product (Complex Double) where | 375 | instance Product (Complex Double) where |
381 | norm2 = toScalarC Norm2 | 376 | norm2 = emptyVal (toScalarC Norm2) |
382 | absSum = toScalarC AbsSum | 377 | absSum = emptyVal (toScalarC AbsSum) |
383 | udot = dotC | 378 | norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs) |
384 | norm1 = sumElements . fst . fromComplex . vectorMapC Abs | 379 | normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs) |
385 | normInf = maxElement . fst . fromComplex . vectorMapC Abs | 380 | multiply = emptyMul multiplyC |
386 | multiply = multiplyC | 381 | |
382 | emptyMul m a b | ||
383 | | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c) | ||
384 | | otherwise = m a b | ||
385 | where | ||
386 | r = rows a | ||
387 | x1 = cols a | ||
388 | x2 = rows b | ||
389 | c = cols b | ||
390 | |||
391 | emptyVal f v = | ||
392 | if dim v > 0 | ||
393 | then f v | ||
394 | else 0 | ||
395 | |||
396 | |||
397 | -- FIXME remove unused C wrappers | ||
398 | -- | (unconjugated) dot product | ||
399 | udot :: Product e => Vector e -> Vector e -> e | ||
400 | udot u v | ||
401 | | dim u == dim v = val (asRow u `multiply` asColumn v) | ||
402 | | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product" | ||
403 | where | ||
404 | val m | dim u > 0 = m@@>(0,0) | ||
405 | | otherwise = 0 | ||
387 | 406 | ||
388 | ---------------------------------------------------------- | 407 | ---------------------------------------------------------- |
389 | 408 | ||
diff --git a/lib/Numeric/GSL/Vector.hs b/lib/Numeric/GSL/Vector.hs index db34041..6204b8e 100644 --- a/lib/Numeric/GSL/Vector.hs +++ b/lib/Numeric/GSL/Vector.hs | |||
@@ -33,6 +33,7 @@ import Foreign.Marshal.Array(newArray) | |||
33 | import Foreign.Ptr(Ptr) | 33 | import Foreign.Ptr(Ptr) |
34 | import Foreign.C.Types | 34 | import Foreign.C.Types |
35 | import System.IO.Unsafe(unsafePerformIO) | 35 | import System.IO.Unsafe(unsafePerformIO) |
36 | import Control.Monad(when) | ||
36 | 37 | ||
37 | fromei x = fromIntegral (fromEnum x) :: CInt | 38 | fromei x = fromIntegral (fromEnum x) :: CInt |
38 | 39 | ||
@@ -201,7 +202,7 @@ vectorMapValAux fun code val v = unsafePerformIO $ do | |||
201 | 202 | ||
202 | vectorZipAux fun code u v = unsafePerformIO $ do | 203 | vectorZipAux fun code u v = unsafePerformIO $ do |
203 | r <- createVector (dim u) | 204 | r <- createVector (dim u) |
204 | app3 (fun (fromei code)) vec u vec v vec r "vectorZipAux" | 205 | when (dim u > 0) $ app3 (fun (fromei code)) vec u vec v vec r "vectorZipAux" |
205 | return r | 206 | return r |
206 | 207 | ||
207 | --------------------------------------------------------------------- | 208 | --------------------------------------------------------------------- |