summaryrefslogtreecommitdiff
path: root/lib/Numeric/ContainerBoot.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric/ContainerBoot.hs')
-rw-r--r--lib/Numeric/ContainerBoot.hs62
1 files changed, 35 insertions, 27 deletions
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs
index 6445e04..ea4262c 100644
--- a/lib/Numeric/ContainerBoot.hs
+++ b/lib/Numeric/ContainerBoot.hs
@@ -45,7 +45,7 @@ import Numeric.Conversion
45import Data.Packed.Internal 45import Data.Packed.Internal
46import Numeric.GSL.Vector 46import Numeric.GSL.Vector
47import Data.Complex 47import Data.Complex
48import Control.Monad(ap) 48import Control.Applicative((<*>))
49 49
50import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ) 50import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ)
51 51
@@ -206,10 +206,10 @@ instance Container Vector Float where
206 conj = id 206 conj = id
207 cmap = mapVector 207 cmap = mapVector
208 atIndex = (@>) 208 atIndex = (@>)
209 minIndex = round . toScalarF MinIdx 209 minIndex = emptyErrorV "minIndex" (round . toScalarF MinIdx)
210 maxIndex = round . toScalarF MaxIdx 210 maxIndex = emptyErrorV "maxIndex" (round . toScalarF MaxIdx)
211 minElement = toScalarF Min 211 minElement = emptyErrorV "minElement" (toScalarF Min)
212 maxElement = toScalarF Max 212 maxElement = emptyErrorV "maxElement" (toScalarF Max)
213 sumElements = sumF 213 sumElements = sumF
214 prodElements = prodF 214 prodElements = prodF
215 step = stepF 215 step = stepF
@@ -234,10 +234,10 @@ instance Container Vector Double where
234 conj = id 234 conj = id
235 cmap = mapVector 235 cmap = mapVector
236 atIndex = (@>) 236 atIndex = (@>)
237 minIndex = round . toScalarR MinIdx 237 minIndex = emptyErrorV "minIndex" (round . toScalarR MinIdx)
238 maxIndex = round . toScalarR MaxIdx 238 maxIndex = emptyErrorV "maxIndex" (round . toScalarR MaxIdx)
239 minElement = toScalarR Min 239 minElement = emptyErrorV "minElement" (toScalarR Min)
240 maxElement = toScalarR Max 240 maxElement = emptyErrorV "maxElement" (toScalarR Max)
241 sumElements = sumR 241 sumElements = sumR
242 prodElements = prodR 242 prodElements = prodR
243 step = stepD 243 step = stepD
@@ -262,10 +262,10 @@ instance Container Vector (Complex Double) where
262 conj = conjugateC 262 conj = conjugateC
263 cmap = mapVector 263 cmap = mapVector
264 atIndex = (@>) 264 atIndex = (@>)
265 minIndex = minIndex . fst . fromComplex . (zipVectorWith (*) `ap` mapVector conjugate) 265 minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj))
266 maxIndex = maxIndex . fst . fromComplex . (zipVectorWith (*) `ap` mapVector conjugate) 266 maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj))
267 minElement = ap (@>) minIndex 267 minElement = emptyErrorV "minElement" (atIndex <*> minIndex)
268 maxElement = ap (@>) maxIndex 268 maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex)
269 sumElements = sumC 269 sumElements = sumC
270 prodElements = prodC 270 prodElements = prodC
271 step = undefined -- cannot match 271 step = undefined -- cannot match
@@ -290,10 +290,10 @@ instance Container Vector (Complex Float) where
290 conj = conjugateQ 290 conj = conjugateQ
291 cmap = mapVector 291 cmap = mapVector
292 atIndex = (@>) 292 atIndex = (@>)
293 minIndex = minIndex . fst . fromComplex . (zipVectorWith (*) `ap` mapVector conjugate) 293 minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj))
294 maxIndex = maxIndex . fst . fromComplex . (zipVectorWith (*) `ap` mapVector conjugate) 294 maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj))
295 minElement = ap (@>) minIndex 295 minElement = emptyErrorV "minElement" (atIndex <*> minIndex)
296 maxElement = ap (@>) maxIndex 296 maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex)
297 sumElements = sumQ 297 sumElements = sumQ
298 prodElements = prodQ 298 prodElements = prodQ
299 step = undefined -- cannot match 299 step = undefined -- cannot match
@@ -320,14 +320,12 @@ instance (Container Vector a) => Container Matrix a where
320 conj = liftMatrix conj 320 conj = liftMatrix conj
321 cmap f = liftMatrix (mapVector f) 321 cmap f = liftMatrix (mapVector f)
322 atIndex = (@@>) 322 atIndex = (@@>)
323 minIndex m = let (r,c) = (rows m,cols m) 323 minIndex = emptyErrorM "minIndex of Matrix" $
324 i = (minIndex $ flatten m) 324 \m -> divMod (minIndex $ flatten m) (cols m)
325 in (i `div` c,i `mod` c) 325 maxIndex = emptyErrorM "maxIndex of Matrix" $
326 maxIndex m = let (r,c) = (rows m,cols m) 326 \m -> divMod (maxIndex $ flatten m) (cols m)
327 i = (maxIndex $ flatten m) 327 minElement = emptyErrorM "minElement of Matrix" (atIndex <*> minIndex)
328 in (i `div` c,i `mod` c) 328 maxElement = emptyErrorM "maxElement of Matrix" (atIndex <*> maxIndex)
329 minElement = ap (@@>) minIndex
330 maxElement = ap (@@>) maxIndex
331 sumElements = sumElements . flatten 329 sumElements = sumElements . flatten
332 prodElements = prodElements . flatten 330 prodElements = prodElements . flatten
333 step = liftMatrix step 331 step = liftMatrix step
@@ -336,6 +334,17 @@ instance (Container Vector a) => Container Matrix a where
336 accum = accumM 334 accum = accumM
337 cond = condM 335 cond = condM
338 336
337
338emptyErrorV msg f v =
339 if dim v > 0
340 then f v
341 else error $ msg ++ " of Vector with dim = 0"
342
343emptyErrorM msg f m =
344 if rows m > 0 && cols m > 0
345 then f m
346 else error $ msg++" "++shSize m
347
339---------------------------------------------------- 348----------------------------------------------------
340 349
341-- | Matrix product and related functions 350-- | Matrix product and related functions
@@ -393,7 +402,6 @@ emptyVal f v =
393 then f v 402 then f v
394 else 0 403 else 0
395 404
396
397-- FIXME remove unused C wrappers 405-- FIXME remove unused C wrappers
398-- | (unconjugated) dot product 406-- | (unconjugated) dot product
399udot :: Product e => Vector e -> Vector e -> e 407udot :: Product e => Vector e -> Vector e -> e
@@ -592,7 +600,7 @@ accumM m0 f xs = ST.runSTMatrix $ do
592 600
593---------------------------------------------------------------------- 601----------------------------------------------------------------------
594 602
595condM a b l e t = reshape (cols a'') $ cond a' b' l' e' t' 603condM a b l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cond a' b' l' e' t'
596 where 604 where
597 args@(a'':_) = conformMs [a,b,l,e,t] 605 args@(a'':_) = conformMs [a,b,l,e,t]
598 [a', b', l', e', t'] = map flatten args 606 [a', b', l', e', t'] = map flatten args