diff options
author | Alberto Ruiz <aruiz@um.es> | 2014-05-06 08:50:50 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2014-05-06 08:50:50 +0200 |
commit | c9914d694d3b86ece46fa0c76e0466c6cd394d14 (patch) | |
tree | 7fa1c5a95b204912f5d560c843ae6045ee8d2780 /lib/Numeric/ContainerBoot.hs | |
parent | 4078cf44c98b42960be27843782f6983bb66017f (diff) |
extend conformability to empty arrays
Diffstat (limited to 'lib/Numeric/ContainerBoot.hs')
-rw-r--r-- | lib/Numeric/ContainerBoot.hs | 62 |
1 files changed, 35 insertions, 27 deletions
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index 6445e04..ea4262c 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs | |||
@@ -45,7 +45,7 @@ import Numeric.Conversion | |||
45 | import Data.Packed.Internal | 45 | import Data.Packed.Internal |
46 | import Numeric.GSL.Vector | 46 | import Numeric.GSL.Vector |
47 | import Data.Complex | 47 | import Data.Complex |
48 | import Control.Monad(ap) | 48 | import Control.Applicative((<*>)) |
49 | 49 | ||
50 | import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ) | 50 | import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ) |
51 | 51 | ||
@@ -206,10 +206,10 @@ instance Container Vector Float where | |||
206 | conj = id | 206 | conj = id |
207 | cmap = mapVector | 207 | cmap = mapVector |
208 | atIndex = (@>) | 208 | atIndex = (@>) |
209 | minIndex = round . toScalarF MinIdx | 209 | minIndex = emptyErrorV "minIndex" (round . toScalarF MinIdx) |
210 | maxIndex = round . toScalarF MaxIdx | 210 | maxIndex = emptyErrorV "maxIndex" (round . toScalarF MaxIdx) |
211 | minElement = toScalarF Min | 211 | minElement = emptyErrorV "minElement" (toScalarF Min) |
212 | maxElement = toScalarF Max | 212 | maxElement = emptyErrorV "maxElement" (toScalarF Max) |
213 | sumElements = sumF | 213 | sumElements = sumF |
214 | prodElements = prodF | 214 | prodElements = prodF |
215 | step = stepF | 215 | step = stepF |
@@ -234,10 +234,10 @@ instance Container Vector Double where | |||
234 | conj = id | 234 | conj = id |
235 | cmap = mapVector | 235 | cmap = mapVector |
236 | atIndex = (@>) | 236 | atIndex = (@>) |
237 | minIndex = round . toScalarR MinIdx | 237 | minIndex = emptyErrorV "minIndex" (round . toScalarR MinIdx) |
238 | maxIndex = round . toScalarR MaxIdx | 238 | maxIndex = emptyErrorV "maxIndex" (round . toScalarR MaxIdx) |
239 | minElement = toScalarR Min | 239 | minElement = emptyErrorV "minElement" (toScalarR Min) |
240 | maxElement = toScalarR Max | 240 | maxElement = emptyErrorV "maxElement" (toScalarR Max) |
241 | sumElements = sumR | 241 | sumElements = sumR |
242 | prodElements = prodR | 242 | prodElements = prodR |
243 | step = stepD | 243 | step = stepD |
@@ -262,10 +262,10 @@ instance Container Vector (Complex Double) where | |||
262 | conj = conjugateC | 262 | conj = conjugateC |
263 | cmap = mapVector | 263 | cmap = mapVector |
264 | atIndex = (@>) | 264 | atIndex = (@>) |
265 | minIndex = minIndex . fst . fromComplex . (zipVectorWith (*) `ap` mapVector conjugate) | 265 | minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj)) |
266 | maxIndex = maxIndex . fst . fromComplex . (zipVectorWith (*) `ap` mapVector conjugate) | 266 | maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj)) |
267 | minElement = ap (@>) minIndex | 267 | minElement = emptyErrorV "minElement" (atIndex <*> minIndex) |
268 | maxElement = ap (@>) maxIndex | 268 | maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex) |
269 | sumElements = sumC | 269 | sumElements = sumC |
270 | prodElements = prodC | 270 | prodElements = prodC |
271 | step = undefined -- cannot match | 271 | step = undefined -- cannot match |
@@ -290,10 +290,10 @@ instance Container Vector (Complex Float) where | |||
290 | conj = conjugateQ | 290 | conj = conjugateQ |
291 | cmap = mapVector | 291 | cmap = mapVector |
292 | atIndex = (@>) | 292 | atIndex = (@>) |
293 | minIndex = minIndex . fst . fromComplex . (zipVectorWith (*) `ap` mapVector conjugate) | 293 | minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj)) |
294 | maxIndex = maxIndex . fst . fromComplex . (zipVectorWith (*) `ap` mapVector conjugate) | 294 | maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj)) |
295 | minElement = ap (@>) minIndex | 295 | minElement = emptyErrorV "minElement" (atIndex <*> minIndex) |
296 | maxElement = ap (@>) maxIndex | 296 | maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex) |
297 | sumElements = sumQ | 297 | sumElements = sumQ |
298 | prodElements = prodQ | 298 | prodElements = prodQ |
299 | step = undefined -- cannot match | 299 | step = undefined -- cannot match |
@@ -320,14 +320,12 @@ instance (Container Vector a) => Container Matrix a where | |||
320 | conj = liftMatrix conj | 320 | conj = liftMatrix conj |
321 | cmap f = liftMatrix (mapVector f) | 321 | cmap f = liftMatrix (mapVector f) |
322 | atIndex = (@@>) | 322 | atIndex = (@@>) |
323 | minIndex m = let (r,c) = (rows m,cols m) | 323 | minIndex = emptyErrorM "minIndex of Matrix" $ |
324 | i = (minIndex $ flatten m) | 324 | \m -> divMod (minIndex $ flatten m) (cols m) |
325 | in (i `div` c,i `mod` c) | 325 | maxIndex = emptyErrorM "maxIndex of Matrix" $ |
326 | maxIndex m = let (r,c) = (rows m,cols m) | 326 | \m -> divMod (maxIndex $ flatten m) (cols m) |
327 | i = (maxIndex $ flatten m) | 327 | minElement = emptyErrorM "minElement of Matrix" (atIndex <*> minIndex) |
328 | in (i `div` c,i `mod` c) | 328 | maxElement = emptyErrorM "maxElement of Matrix" (atIndex <*> maxIndex) |
329 | minElement = ap (@@>) minIndex | ||
330 | maxElement = ap (@@>) maxIndex | ||
331 | sumElements = sumElements . flatten | 329 | sumElements = sumElements . flatten |
332 | prodElements = prodElements . flatten | 330 | prodElements = prodElements . flatten |
333 | step = liftMatrix step | 331 | step = liftMatrix step |
@@ -336,6 +334,17 @@ instance (Container Vector a) => Container Matrix a where | |||
336 | accum = accumM | 334 | accum = accumM |
337 | cond = condM | 335 | cond = condM |
338 | 336 | ||
337 | |||
338 | emptyErrorV msg f v = | ||
339 | if dim v > 0 | ||
340 | then f v | ||
341 | else error $ msg ++ " of Vector with dim = 0" | ||
342 | |||
343 | emptyErrorM msg f m = | ||
344 | if rows m > 0 && cols m > 0 | ||
345 | then f m | ||
346 | else error $ msg++" "++shSize m | ||
347 | |||
339 | ---------------------------------------------------- | 348 | ---------------------------------------------------- |
340 | 349 | ||
341 | -- | Matrix product and related functions | 350 | -- | Matrix product and related functions |
@@ -393,7 +402,6 @@ emptyVal f v = | |||
393 | then f v | 402 | then f v |
394 | else 0 | 403 | else 0 |
395 | 404 | ||
396 | |||
397 | -- FIXME remove unused C wrappers | 405 | -- FIXME remove unused C wrappers |
398 | -- | (unconjugated) dot product | 406 | -- | (unconjugated) dot product |
399 | udot :: Product e => Vector e -> Vector e -> e | 407 | udot :: Product e => Vector e -> Vector e -> e |
@@ -592,7 +600,7 @@ accumM m0 f xs = ST.runSTMatrix $ do | |||
592 | 600 | ||
593 | ---------------------------------------------------------------------- | 601 | ---------------------------------------------------------------------- |
594 | 602 | ||
595 | condM a b l e t = reshape (cols a'') $ cond a' b' l' e' t' | 603 | condM a b l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cond a' b' l' e' t' |
596 | where | 604 | where |
597 | args@(a'':_) = conformMs [a,b,l,e,t] | 605 | args@(a'':_) = conformMs [a,b,l,e,t] |
598 | [a', b', l', e', t'] = map flatten args | 606 | [a', b', l', e', t'] = map flatten args |