diff options
Diffstat (limited to 'packages/base/src/Internal/Container.hs')
-rw-r--r-- | packages/base/src/Internal/Container.hs | 96 |
1 files changed, 93 insertions, 3 deletions
diff --git a/packages/base/src/Internal/Container.hs b/packages/base/src/Internal/Container.hs index 41b8214..0f2e7d5 100644 --- a/packages/base/src/Internal/Container.hs +++ b/packages/base/src/Internal/Container.hs | |||
@@ -4,6 +4,8 @@ | |||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | 4 | {-# LANGUAGE MultiParamTypeClasses #-} |
5 | {-# LANGUAGE FunctionalDependencies #-} | 5 | {-# LANGUAGE FunctionalDependencies #-} |
6 | {-# LANGUAGE UndecidableInstances #-} | 6 | {-# LANGUAGE UndecidableInstances #-} |
7 | {-# LANGUAGE PatternSynonyms #-} | ||
8 | {-# LANGUAGE ScopedTypeVariables #-} | ||
7 | 9 | ||
8 | {-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-} | 10 | {-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-} |
9 | 11 | ||
@@ -30,8 +32,15 @@ module Internal.Container where | |||
30 | import Internal.Vector | 32 | import Internal.Vector |
31 | import Internal.Matrix | 33 | import Internal.Matrix |
32 | import Internal.Element | 34 | import Internal.Element |
35 | import Internal.Extract(requires,pattern BAD_SIZE) | ||
33 | import Internal.Numeric | 36 | import Internal.Numeric |
34 | import Internal.Algorithms(Field,linearSolveSVD,Herm,mTm) | 37 | import Internal.Algorithms(Field,linearSolveSVD,Herm,mTm) |
38 | import Control.Monad(when) | ||
39 | import Data.Function | ||
40 | import Data.Int | ||
41 | import Foreign.Ptr | ||
42 | import Foreign.Storable | ||
43 | import Foreign.Marshal.Array | ||
35 | #if MIN_VERSION_base(4,11,0) | 44 | #if MIN_VERSION_base(4,11,0) |
36 | import Prelude hiding ((<>)) | 45 | import Prelude hiding ((<>)) |
37 | #endif | 46 | #endif |
@@ -227,7 +236,7 @@ meanCov x = (med,cov) where | |||
227 | 236 | ||
228 | -------------------------------------------------------------------------------- | 237 | -------------------------------------------------------------------------------- |
229 | 238 | ||
230 | sortVector :: (Ord t, Element t) => Vector t -> Vector t | 239 | sortVector :: (Ord t, Storable t) => Vector t -> Vector t |
231 | sortVector = sortV | 240 | sortVector = sortV |
232 | 241 | ||
233 | {- | | 242 | {- | |
@@ -248,7 +257,7 @@ sortVector = sortV | |||
248 | -2.20 0.11 -1.58 -0.01 0.19 -0.29 1.04 1.06 -2.09 -0.75 | 257 | -2.20 0.11 -1.58 -0.01 0.19 -0.29 1.04 1.06 -2.09 -0.75 |
249 | 258 | ||
250 | -} | 259 | -} |
251 | sortIndex :: (Ord t, Element t) => Vector t -> Vector I | 260 | sortIndex :: (Ord t, Storable t) => Vector t -> Vector I |
252 | sortIndex = sortI | 261 | sortIndex = sortI |
253 | 262 | ||
254 | ccompare :: (Ord t, Container c t) => c t -> c t -> c I | 263 | ccompare :: (Ord t, Container c t) => c t -> c t -> c I |
@@ -296,10 +305,91 @@ The indexes are autoconformable. | |||
296 | , 10, 16, 22 ] | 305 | , 10, 16, 22 ] |
297 | 306 | ||
298 | -} | 307 | -} |
299 | remap :: Element t => Matrix I -> Matrix I -> Matrix t -> Matrix t | 308 | remap :: Storable t => Matrix I -> Matrix I -> Matrix t -> Matrix t |
300 | remap i j m | 309 | remap i j m |
301 | | minElement i >= 0 && maxElement i < fromIntegral (rows m) && | 310 | | minElement i >= 0 && maxElement i < fromIntegral (rows m) && |
302 | minElement j >= 0 && maxElement j < fromIntegral (cols m) = remapM i' j' m | 311 | minElement j >= 0 && maxElement j < fromIntegral (cols m) = remapM i' j' m |
303 | | otherwise = error $ "out of range index in remap" | 312 | | otherwise = error $ "out of range index in remap" |
304 | where | 313 | where |
305 | [i',j'] = conformMs [i,j] | 314 | [i',j'] = conformMs [i,j] |
315 | |||
316 | sortI :: (Storable a, Ord a) => Vector a -> Vector Int32 | ||
317 | sortI = sortG sort_index | ||
318 | |||
319 | type C_Compare a = Ptr a -> Ptr a -> IO Int32 | ||
320 | |||
321 | foreign import ccall "wrapper" wrapCompare :: C_Compare a -> IO (FunPtr (C_Compare a)) | ||
322 | |||
323 | foreign import ccall "qsort" | ||
324 | c_qsort :: Ptr a -- ^ base | ||
325 | -> Word -- ^ nmemb | ||
326 | -> Word -- ^ size | ||
327 | -> FunPtr (C_Compare a) -- ^ compar | ||
328 | -> IO () | ||
329 | |||
330 | sizeOfElem :: forall a. Storable a => Ptr a -> Int | ||
331 | sizeOfElem _ = sizeOf (undefined :: a) | ||
332 | |||
333 | sort_index :: (Storable a, Ord a) => | ||
334 | Int32 -> Ptr a | ||
335 | -> Int32 -> Ptr Int32 | ||
336 | -> IO Int32 | ||
337 | sort_index vn vp rn rp = do | ||
338 | requires (vn == rn) BAD_SIZE $ do | ||
339 | comp <- wrapCompare $ \ap bp -> do | ||
340 | a <- peekElemOff vp . fromIntegral =<< peek (ap :: Ptr Int32) | ||
341 | b <- peekElemOff vp . fromIntegral =<< peek bp | ||
342 | return $ case compare a b of | ||
343 | LT -> -1 | ||
344 | GT -> 1 | ||
345 | EQ -> 0 | ||
346 | sequence_ [ pokeElemOff rp (fromIntegral i) i | i <- [0 .. rn-1] ] | ||
347 | c_qsort rp (fromIntegral rn) 4 comp | ||
348 | freeHaskellFunPtr comp | ||
349 | return 0 | ||
350 | |||
351 | sortV :: (Storable a, Ord a) => Vector a -> Vector a | ||
352 | sortV = sortG sortStorable | ||
353 | |||
354 | sortStorable :: (Storable a, Ord a) => | ||
355 | Int32 -> Ptr a | ||
356 | -> Int32 -> Ptr a | ||
357 | -> IO Int32 | ||
358 | sortStorable vn vp rn rp = do | ||
359 | requires (vn == rn) BAD_SIZE $ do | ||
360 | copyArray rp vp (fromIntegral vn * sizeOfElem vp) | ||
361 | comp <- wrapCompare $ \ap bp -> do | ||
362 | a <- peek ap | ||
363 | b <- peek bp | ||
364 | return $ case compare a b of | ||
365 | LT -> -1 | ||
366 | GT -> 1 | ||
367 | EQ -> 0 | ||
368 | c_qsort rp (fromIntegral rn) (fromIntegral $ sizeOfElem rp) comp | ||
369 | freeHaskellFunPtr comp | ||
370 | return 0 | ||
371 | |||
372 | remapM :: Storable a => Matrix Int32 -> Matrix Int32 -> Matrix a -> Matrix a | ||
373 | remapM = remapG remapStorable | ||
374 | |||
375 | remapStorable :: Storable a => | ||
376 | Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int32 -- i | ||
377 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int32 -- j | ||
378 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -- m | ||
379 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -- r | ||
380 | -> IO Int32 | ||
381 | remapStorable ir ic iXr iXc ip | ||
382 | jr jc jXr jXc jp | ||
383 | mr mc mXr mXc mp | ||
384 | rr rc rXr rXc rp = do | ||
385 | requires (ir==jr && ic==jc && ir==rr && ic==rc) BAD_SIZE $ do | ||
386 | ($ 0) $ fix $ \aloop a -> when (a<rr) $ do | ||
387 | ($ 0) $ fix $ \bloop b -> when (b<rc) $ do | ||
388 | iab <- peekElemOff ip (fromIntegral $ iXr*a + iXc*b) | ||
389 | jab <- peekElemOff jp (fromIntegral $ jXr*a + jXc*b) | ||
390 | when (0 <= iab && iab < mr && 0 <= jab && jab < mc) $ | ||
391 | pokeElemOff rp (fromIntegral $ rXr*a + rXc*b) | ||
392 | =<< peekElemOff mp (fromIntegral $ mXr*iab + mXc*jab) | ||
393 | bloop (succ b) | ||
394 | aloop (succ a) | ||
395 | return 0 | ||