summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Container.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Container.hs')
-rw-r--r--packages/base/src/Internal/Container.hs96
1 files changed, 93 insertions, 3 deletions
diff --git a/packages/base/src/Internal/Container.hs b/packages/base/src/Internal/Container.hs
index 41b8214..0f2e7d5 100644
--- a/packages/base/src/Internal/Container.hs
+++ b/packages/base/src/Internal/Container.hs
@@ -4,6 +4,8 @@
4{-# LANGUAGE MultiParamTypeClasses #-} 4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-} 5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE UndecidableInstances #-} 6{-# LANGUAGE UndecidableInstances #-}
7{-# LANGUAGE PatternSynonyms #-}
8{-# LANGUAGE ScopedTypeVariables #-}
7 9
8{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-} 10{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-}
9 11
@@ -30,8 +32,15 @@ module Internal.Container where
30import Internal.Vector 32import Internal.Vector
31import Internal.Matrix 33import Internal.Matrix
32import Internal.Element 34import Internal.Element
35import Internal.Extract(requires,pattern BAD_SIZE)
33import Internal.Numeric 36import Internal.Numeric
34import Internal.Algorithms(Field,linearSolveSVD,Herm,mTm) 37import Internal.Algorithms(Field,linearSolveSVD,Herm,mTm)
38import Control.Monad(when)
39import Data.Function
40import Data.Int
41import Foreign.Ptr
42import Foreign.Storable
43import Foreign.Marshal.Array
35#if MIN_VERSION_base(4,11,0) 44#if MIN_VERSION_base(4,11,0)
36import Prelude hiding ((<>)) 45import Prelude hiding ((<>))
37#endif 46#endif
@@ -227,7 +236,7 @@ meanCov x = (med,cov) where
227 236
228-------------------------------------------------------------------------------- 237--------------------------------------------------------------------------------
229 238
230sortVector :: (Ord t, Element t) => Vector t -> Vector t 239sortVector :: (Ord t, Storable t) => Vector t -> Vector t
231sortVector = sortV 240sortVector = sortV
232 241
233{- | 242{- |
@@ -248,7 +257,7 @@ sortVector = sortV
248-2.20 0.11 -1.58 -0.01 0.19 -0.29 1.04 1.06 -2.09 -0.75 257-2.20 0.11 -1.58 -0.01 0.19 -0.29 1.04 1.06 -2.09 -0.75
249 258
250-} 259-}
251sortIndex :: (Ord t, Element t) => Vector t -> Vector I 260sortIndex :: (Ord t, Storable t) => Vector t -> Vector I
252sortIndex = sortI 261sortIndex = sortI
253 262
254ccompare :: (Ord t, Container c t) => c t -> c t -> c I 263ccompare :: (Ord t, Container c t) => c t -> c t -> c I
@@ -296,10 +305,91 @@ The indexes are autoconformable.
296 , 10, 16, 22 ] 305 , 10, 16, 22 ]
297 306
298-} 307-}
299remap :: Element t => Matrix I -> Matrix I -> Matrix t -> Matrix t 308remap :: Storable t => Matrix I -> Matrix I -> Matrix t -> Matrix t
300remap i j m 309remap i j m
301 | minElement i >= 0 && maxElement i < fromIntegral (rows m) && 310 | minElement i >= 0 && maxElement i < fromIntegral (rows m) &&
302 minElement j >= 0 && maxElement j < fromIntegral (cols m) = remapM i' j' m 311 minElement j >= 0 && maxElement j < fromIntegral (cols m) = remapM i' j' m
303 | otherwise = error $ "out of range index in remap" 312 | otherwise = error $ "out of range index in remap"
304 where 313 where
305 [i',j'] = conformMs [i,j] 314 [i',j'] = conformMs [i,j]
315
316sortI :: (Storable a, Ord a) => Vector a -> Vector Int32
317sortI = sortG sort_index
318
319type C_Compare a = Ptr a -> Ptr a -> IO Int32
320
321foreign import ccall "wrapper" wrapCompare :: C_Compare a -> IO (FunPtr (C_Compare a))
322
323foreign import ccall "qsort"
324 c_qsort :: Ptr a -- ^ base
325 -> Word -- ^ nmemb
326 -> Word -- ^ size
327 -> FunPtr (C_Compare a) -- ^ compar
328 -> IO ()
329
330sizeOfElem :: forall a. Storable a => Ptr a -> Int
331sizeOfElem _ = sizeOf (undefined :: a)
332
333sort_index :: (Storable a, Ord a) =>
334 Int32 -> Ptr a
335 -> Int32 -> Ptr Int32
336 -> IO Int32
337sort_index vn vp rn rp = do
338 requires (vn == rn) BAD_SIZE $ do
339 comp <- wrapCompare $ \ap bp -> do
340 a <- peekElemOff vp . fromIntegral =<< peek (ap :: Ptr Int32)
341 b <- peekElemOff vp . fromIntegral =<< peek bp
342 return $ case compare a b of
343 LT -> -1
344 GT -> 1
345 EQ -> 0
346 sequence_ [ pokeElemOff rp (fromIntegral i) i | i <- [0 .. rn-1] ]
347 c_qsort rp (fromIntegral rn) 4 comp
348 freeHaskellFunPtr comp
349 return 0
350
351sortV :: (Storable a, Ord a) => Vector a -> Vector a
352sortV = sortG sortStorable
353
354sortStorable :: (Storable a, Ord a) =>
355 Int32 -> Ptr a
356 -> Int32 -> Ptr a
357 -> IO Int32
358sortStorable vn vp rn rp = do
359 requires (vn == rn) BAD_SIZE $ do
360 copyArray rp vp (fromIntegral vn * sizeOfElem vp)
361 comp <- wrapCompare $ \ap bp -> do
362 a <- peek ap
363 b <- peek bp
364 return $ case compare a b of
365 LT -> -1
366 GT -> 1
367 EQ -> 0
368 c_qsort rp (fromIntegral rn) (fromIntegral $ sizeOfElem rp) comp
369 freeHaskellFunPtr comp
370 return 0
371
372remapM :: Storable a => Matrix Int32 -> Matrix Int32 -> Matrix a -> Matrix a
373remapM = remapG remapStorable
374
375remapStorable :: Storable a =>
376 Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int32 -- i
377 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int32 -- j
378 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -- m
379 -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -- r
380 -> IO Int32
381remapStorable ir ic iXr iXc ip
382 jr jc jXr jXc jp
383 mr mc mXr mXc mp
384 rr rc rXr rXc rp = do
385 requires (ir==jr && ic==jc && ir==rr && ic==rc) BAD_SIZE $ do
386 ($ 0) $ fix $ \aloop a -> when (a<rr) $ do
387 ($ 0) $ fix $ \bloop b -> when (b<rc) $ do
388 iab <- peekElemOff ip (fromIntegral $ iXr*a + iXc*b)
389 jab <- peekElemOff jp (fromIntegral $ jXr*a + jXc*b)
390 when (0 <= iab && iab < mr && 0 <= jab && jab < mc) $
391 pokeElemOff rp (fromIntegral $ rXr*a + rXc*b)
392 =<< peekElemOff mp (fromIntegral $ mXr*iab + mXc*jab)
393 bloop (succ b)
394 aloop (succ a)
395 return 0