diff options
Diffstat (limited to 'lib/Numeric/ContainerBoot.hs')
-rw-r--r-- | lib/Numeric/ContainerBoot.hs | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index 992a501..5a8e243 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs | |||
@@ -45,6 +45,7 @@ module Numeric.ContainerBoot ( | |||
45 | ) where | 45 | ) where |
46 | 46 | ||
47 | import Data.Packed | 47 | import Data.Packed |
48 | import Data.Packed.ST as ST | ||
48 | import Numeric.Conversion | 49 | import Numeric.Conversion |
49 | import Data.Packed.Internal | 50 | import Data.Packed.Internal |
50 | import Numeric.GSL.Vector | 51 | import Numeric.GSL.Vector |
@@ -120,6 +121,12 @@ class (Complexable c, Fractional e, Element e) => Container c e where | |||
120 | sumElements :: c e -> e | 121 | sumElements :: c e -> e |
121 | -- | the product of elements (faster than using @fold@) | 122 | -- | the product of elements (faster than using @fold@) |
122 | prodElements :: c e -> e | 123 | prodElements :: c e -> e |
124 | -- | map (if x_i>0 then 1.0 else 0.0) | ||
125 | step :: RealFloat e => c e -> c e | ||
126 | -- | find index of elements which satisfy a predicate | ||
127 | find :: (e -> Bool) -> c e -> [IndexOf c] | ||
128 | -- | create a structure from an association list | ||
129 | assoc :: IndexOf c -> e -> [(IndexOf c, e)] -> c e | ||
123 | 130 | ||
124 | -------------------------------------------------------------------------- | 131 | -------------------------------------------------------------------------- |
125 | 132 | ||
@@ -145,6 +152,9 @@ instance Container Vector Float where | |||
145 | maxElement = toScalarF Max | 152 | maxElement = toScalarF Max |
146 | sumElements = sumF | 153 | sumElements = sumF |
147 | prodElements = prodF | 154 | prodElements = prodF |
155 | step = stepF | ||
156 | find = findV | ||
157 | assoc = assocV | ||
148 | 158 | ||
149 | instance Container Vector Double where | 159 | instance Container Vector Double where |
150 | scale = vectorMapValR Scale | 160 | scale = vectorMapValR Scale |
@@ -168,6 +178,9 @@ instance Container Vector Double where | |||
168 | maxElement = toScalarR Max | 178 | maxElement = toScalarR Max |
169 | sumElements = sumR | 179 | sumElements = sumR |
170 | prodElements = prodR | 180 | prodElements = prodR |
181 | step = stepD | ||
182 | find = findV | ||
183 | assoc = assocV | ||
171 | 184 | ||
172 | instance Container Vector (Complex Double) where | 185 | instance Container Vector (Complex Double) where |
173 | scale = vectorMapValC Scale | 186 | scale = vectorMapValC Scale |
@@ -191,6 +204,9 @@ instance Container Vector (Complex Double) where | |||
191 | maxElement = ap (@>) maxIndex | 204 | maxElement = ap (@>) maxIndex |
192 | sumElements = sumC | 205 | sumElements = sumC |
193 | prodElements = prodC | 206 | prodElements = prodC |
207 | step = undefined -- cannot match | ||
208 | find = findV | ||
209 | assoc = assocV | ||
194 | 210 | ||
195 | instance Container Vector (Complex Float) where | 211 | instance Container Vector (Complex Float) where |
196 | scale = vectorMapValQ Scale | 212 | scale = vectorMapValQ Scale |
@@ -214,6 +230,9 @@ instance Container Vector (Complex Float) where | |||
214 | maxElement = ap (@>) maxIndex | 230 | maxElement = ap (@>) maxIndex |
215 | sumElements = sumQ | 231 | sumElements = sumQ |
216 | prodElements = prodQ | 232 | prodElements = prodQ |
233 | step = undefined -- cannot match | ||
234 | find = findV | ||
235 | assoc = assocV | ||
217 | 236 | ||
218 | --------------------------------------------------------------- | 237 | --------------------------------------------------------------- |
219 | 238 | ||
@@ -243,6 +262,9 @@ instance (Container Vector a) => Container Matrix a where | |||
243 | maxElement = ap (@@>) maxIndex | 262 | maxElement = ap (@@>) maxIndex |
244 | sumElements = sumElements . flatten | 263 | sumElements = sumElements . flatten |
245 | prodElements = prodElements . flatten | 264 | prodElements = prodElements . flatten |
265 | step = liftMatrix step | ||
266 | find = findM | ||
267 | assoc = assocM | ||
246 | 268 | ||
247 | ---------------------------------------------------- | 269 | ---------------------------------------------------- |
248 | 270 | ||
@@ -580,3 +602,21 @@ diag v = diagRect 0 v n n where n = dim v | |||
580 | -- | creates the identity matrix of given dimension | 602 | -- | creates the identity matrix of given dimension |
581 | ident :: (Num a, Element a) => Int -> Matrix a | 603 | ident :: (Num a, Element a) => Int -> Matrix a |
582 | ident n = diag (constantD 1 n) | 604 | ident n = diag (constantD 1 n) |
605 | |||
606 | -------------------------------------------------------- | ||
607 | |||
608 | findV p x = foldVectorWithIndex g [] x where | ||
609 | g k z l = if p z then k:l else l | ||
610 | |||
611 | findM p x = map ((`divMod` cols x)) $ findV p (flatten x) | ||
612 | |||
613 | assocV n z xs = ST.runSTVector $ do | ||
614 | v <- ST.newVector z n | ||
615 | mapM_ (\(k,x) -> ST.writeVector v k x) xs | ||
616 | return v | ||
617 | |||
618 | assocM (r,c) z xs = ST.runSTMatrix $ do | ||
619 | m <- ST.newMatrix z r c | ||
620 | mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs | ||
621 | return m | ||
622 | |||