diff options
Diffstat (limited to 'packages/base/src/Internal/Extract.hs')
-rw-r--r-- | packages/base/src/Internal/Extract.hs | 145 |
1 files changed, 145 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Extract.hs b/packages/base/src/Internal/Extract.hs new file mode 100644 index 0000000..84ee20f --- /dev/null +++ b/packages/base/src/Internal/Extract.hs | |||
@@ -0,0 +1,145 @@ | |||
1 | {-# LANGUAGE BangPatterns #-} | ||
2 | {-# LANGUAGE NondecreasingIndentation #-} | ||
3 | {-# LANGUAGE PatternSynonyms #-} | ||
4 | {-# LANGUAGE UnboxedTuples #-} | ||
5 | module Internal.Extract where | ||
6 | import Control.Monad | ||
7 | import Data.Complex | ||
8 | import Data.Function | ||
9 | import Data.Int | ||
10 | import Foreign.Ptr | ||
11 | import Foreign.Storable | ||
12 | |||
13 | type ConstPtr a = Ptr a | ||
14 | pattern ConstPtr a = a | ||
15 | |||
16 | extractStorable :: Storable t => | ||
17 | Int32 -- int modei | ||
18 | -> Int32 -- int modej | ||
19 | -> Int32 -- / KIVEC(i) | ||
20 | -> ConstPtr Int32 -- \ | ||
21 | -> Int32 -- / KIVEC(j) | ||
22 | -> ConstPtr Int32 -- \ | ||
23 | -> Int32 -- / | ||
24 | -> Int32 -- / | ||
25 | -> Int32 -- { KO##T##MAT(m) | ||
26 | -> Int32 -- \ | ||
27 | -> ConstPtr t -- \ | ||
28 | -> Int32 -- / | ||
29 | -> Int32 -- / | ||
30 | -> Int32 -- { O##T##MAT(r) | ||
31 | -> Int32 -- \ | ||
32 | -> Ptr t -- \ | ||
33 | -> IO Int32 | ||
34 | extractStorable modei | ||
35 | modej | ||
36 | in_ (ConstPtr ip) | ||
37 | jn (ConstPtr jp) | ||
38 | mr mc mXr mXc (ConstPtr mp) | ||
39 | rr rc rXr rXc rp = do | ||
40 | -- int i,j,si,sj,ni,nj; | ||
41 | ni <- if modei/=0 then return in_ | ||
42 | else fmap succ $ (-) <$> peekElemOff ip 1 <*> peekElemOff ip 0 | ||
43 | nj <- if modej/=0 then return jn | ||
44 | else fmap succ $ (-) <$> peekElemOff jp 1 <*> peekElemOff jp 0 | ||
45 | ($ 0) $ fix $ \iloop i -> when (i<ni) $ do | ||
46 | si <- if modei/=0 then peekElemOff ip (fromIntegral i) | ||
47 | else (+ i) <$> peek ip | ||
48 | ($ 0) $ fix $ \jloop j -> when (j<nj) $ do | ||
49 | sj <- if modej/=0 then peekElemOff jp (fromIntegral j) | ||
50 | else (+ j) <$> peek jp | ||
51 | pokeElemOff rp (fromIntegral $ i*rXr + j*rXc) | ||
52 | =<< peekElemOff mp (fromIntegral $ si*mXr + sj*mXc) | ||
53 | jloop $! succ j | ||
54 | iloop $! succ i | ||
55 | return 0 | ||
56 | |||
57 | {-# SPECIALIZE extractStorable :: | ||
58 | Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 | ||
59 | -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Double | ||
60 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Double | ||
61 | -> IO Int32 #-} | ||
62 | |||
63 | {-# SPECIALIZE extractStorable :: | ||
64 | Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 | ||
65 | -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Float | ||
66 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Float | ||
67 | -> IO Int32 #-} | ||
68 | |||
69 | {-# SPECIALIZE extractStorable :: | ||
70 | Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 | ||
71 | -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr (Complex Double) | ||
72 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr (Complex Double) | ||
73 | -> IO Int32 #-} | ||
74 | |||
75 | {-# SPECIALIZE extractStorable :: | ||
76 | Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 | ||
77 | -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr (Complex Float) | ||
78 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr (Complex Float) | ||
79 | -> IO Int32 #-} | ||
80 | |||
81 | {-# SPECIALIZE extractStorable :: | ||
82 | Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 | ||
83 | -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Int32 | ||
84 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int32 | ||
85 | -> IO Int32 #-} | ||
86 | |||
87 | {-# SPECIALIZE extractStorable :: | ||
88 | Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 | ||
89 | -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Int64 | ||
90 | -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int64 | ||
91 | -> IO Int32 #-} | ||
92 | |||
93 | {- | ||
94 | type Reorder x = CV Int32 (CV Int32 (CV Int32 (CV x (CV x (IO Int32))))) | ||
95 | |||
96 | foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double | ||
97 | foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float | ||
98 | foreign import ccall unsafe "reorderI" c_reorderI :: Reorder Int32 | ||
99 | foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) | ||
100 | foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) | ||
101 | foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z | ||
102 | -} | ||
103 | |||
104 | -- #define ERROR(CODE) MACRO(return CODE;) | ||
105 | -- #define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);}) | ||
106 | |||
107 | requires :: Monad m => Bool -> Int32 -> m Int32 -> m Int32 | ||
108 | requires cond code go = | ||
109 | if cond then go | ||
110 | else return code | ||
111 | |||
112 | pattern BAD_SIZE = 2000 | ||
113 | |||
114 | reorderStorable :: Storable a => | ||
115 | Int32 -> Ptr Int32 -- k | ||
116 | -> Int32 -> ConstPtr Int32 -- strides | ||
117 | -> Int32 -> ConstPtr Int32 -- dims | ||
118 | -> Int32 -> ConstPtr a -- v | ||
119 | -> Int32 -> Ptr a -- r | ||
120 | -> IO Int32 | ||
121 | reorderStorable kn kp stridesn stridesp dimsn dimsp vn vp rn rp = do | ||
122 | requires (kn == stridesn && stridesn == dimsn) BAD_SIZE $ do | ||
123 | let ijlloop !i !j l fin = do | ||
124 | pokeElemOff kp (fromIntegral l) 0 | ||
125 | dimspl <- peekElemOff dimsp (fromIntegral l) | ||
126 | stridespl <- peekElemOff stridesp (fromIntegral l) | ||
127 | if (l<kn) then ijlloop (i * dimspl) (j + stridespl*(dimspl - 1)) (l + 1) fin | ||
128 | else fin i j | ||
129 | ijlloop 1 0 0 $ \i j -> do | ||
130 | requires (i <= vn && j < rn) BAD_SIZE $ do | ||
131 | (\go -> go 0 0) $ fix $ \ijloop i j -> do | ||
132 | pokeElemOff rp (fromIntegral i) =<< peekElemOff vp (fromIntegral j) | ||
133 | (\go -> go (kn - 1) j) $ fix $ \lloop l !j -> do | ||
134 | kpl <- succ <$> peekElemOff kp (fromIntegral l) | ||
135 | pokeElemOff kp (fromIntegral l) kpl | ||
136 | dimspl <- peekElemOff dimsp (fromIntegral l) | ||
137 | if (kpl < dimspl) | ||
138 | then do | ||
139 | stridespl <- peekElemOff stridesp (fromIntegral l) | ||
140 | ijloop (succ i) (j + stridespl) | ||
141 | else do | ||
142 | if l == 0 then return 0 else do | ||
143 | pokeElemOff kp (fromIntegral l) 0 | ||
144 | stridespl <- peekElemOff stridesp (fromIntegral l) | ||
145 | lloop (pred l) (j - stridespl*(dimspl-1)) | ||