summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Extract.hs
blob: 84ee20faf0d7ab68195250ba29b871c0034c942a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
{-# LANGUAGE BangPatterns             #-}
{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE PatternSynonyms          #-}
{-# LANGUAGE UnboxedTuples            #-}
module Internal.Extract where 
import Control.Monad
import Data.Complex
import Data.Function
import Data.Int
import Foreign.Ptr
import Foreign.Storable

-- | Pointer type used for input arguments (@ip@, @jp@, @mp@ in
-- 'extractStorable'), while output arguments use plain 'Ptr'. It is only a
-- type synonym for 'Ptr', so nothing prevents writing through it.
type ConstPtr a = Ptr a
-- | Identity pattern synonym. Arguments can be written as @(ConstPtr p)@ as
-- if 'ConstPtr' were a constructor; both matching and construction pass the
-- value through unchanged.
-- NOTE(review): presumably a compatibility shim mirroring the newtype
-- @Foreign.C.ConstPtr@ from newer @base@ — confirm.
pattern ConstPtr a = a

-- | Copy a selection of rows and columns of the source matrix @m@ into the
-- result matrix @r@.
--
-- Each of the two selectors works in one of two modes. The row selector is
-- @modei@ / @in_@ / @ip@; the column selector is @modej@ / @jn@ / @jp@.
--
-- * @mode /= 0@: the selector vector holds explicit source indices, and its
--   length (@in_@ or @jn@) is the number of rows/columns extracted.
-- * @mode == 0@: only the first two entries are read. They are the inclusive
--   bounds @[lo, hi]@ of a contiguous range, and the length argument is
--   ignored.
--
-- Element @(i, j)@ of a matrix lives at offset @i*xr + j*xc@ from its data
-- pointer, where @xr@ / @xc@ are the row / column strides. Offsets are
-- computed in 'Int' so large matrices cannot overflow 32-bit arithmetic.
-- Neither the selected indices nor the declared sizes @mr@, @mc@, @rr@,
-- @rc@ are checked; the caller must keep every access in bounds.
-- Always returns 0.
extractStorable :: Storable t =>
                Int32             -- modei: row selector mode (0 = range)
                -> Int32          -- modej: column selector mode (0 = range)
                -> Int32          -- / KIVEC(i): row selector length
                -> ConstPtr Int32 -- \          row selector data
                -> Int32          -- / KIVEC(j): column selector length
                -> ConstPtr Int32 -- \          column selector data
                -> Int32          --   /  rows of m (unused)
                -> Int32          --  /   columns of m (unused)
                -> Int32          -- {  KO##T##MAT(m): row stride of m
                -> Int32          --  \   column stride of m
                -> ConstPtr t     --   \  data of m
                -> Int32          --   /  rows of r (unused)
                -> Int32          --  /   columns of r (unused)
                -> Int32          -- {  O##T##MAT(r): row stride of r
                -> Int32          --  \   column stride of r
                -> Ptr t          --   \  data of r
                -> IO Int32
extractStorable modei
                modej
                in_ (ConstPtr ip)
                jn (ConstPtr jp)
                _mr _mc mXr mXc (ConstPtr mp)
                _rr _rc rXr rXc rp = do
    -- Number of rows / columns to extract.
    ni <- selectionLength modei in_ ip
    nj <- selectionLength modej jn jp
    -- Source row / column of the k-th selected row / column.
    srcRow <- selectionIndex modei ip
    srcCol <- selectionIndex modej jp
    let rowLoop i = when (i < ni) $ do
            si <- srcRow i
            let colLoop j = when (j < nj) $ do
                    sj <- srcCol j
                    pokeElemOff rp (offset i j rXr rXc)
                        =<< peekElemOff mp (offset si sj mXr mXc)
                    colLoop $! j + 1
            colLoop 0
            rowLoop $! i + 1
    rowLoop 0
    return 0
  where
    -- Explicit mode: the given length. Range mode: hi - lo + 1.
    selectionLength mode n p
        | mode /= 0 = return n
        | otherwise = do
            lo <- peekElemOff p 0
            hi <- peekElemOff p 1
            return (hi - lo + 1)
    -- Explicit mode: look the index up in the selector vector.
    -- Range mode: add k to lo, which is read once here rather than on
    -- every loop iteration.
    selectionIndex mode p
        | mode /= 0 = return (\k -> peekElemOff p (fromIntegral k))
        | otherwise = do
            lo <- peek p
            return (\k -> return (lo + k))
    -- Linear element offset. Widened to Int before multiplying so that
    -- row * stride cannot wrap around in Int32 for large matrices.
    offset :: Int32 -> Int32 -> Int32 -> Int32 -> Int
    offset row col xr xc =
        fromIntegral row * fromIntegral xr + fromIntegral col * fromIntegral xc

-- Monomorphic copies of 'extractStorable' for the element types Double,
-- Float, Complex Double, Complex Float, Int32 and Int64. GHC compiles a
-- version of the copy loop for each, with the 'Storable' dictionary
-- resolved statically instead of being passed at run time.
{-# SPECIALIZE extractStorable ::
                Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32
                -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Double
                -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Double
                -> IO Int32 #-}

{-# SPECIALIZE extractStorable ::
                Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32
                -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Float
                -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Float
                -> IO Int32 #-}

{-# SPECIALIZE extractStorable ::
                Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32
                -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr (Complex Double)
                -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr (Complex Double)
                -> IO Int32 #-}

{-# SPECIALIZE extractStorable ::
                Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32
                -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr (Complex Float)
                -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr (Complex Float)
                -> IO Int32 #-}

{-# SPECIALIZE extractStorable ::
                Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32
                -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Int32
                -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int32
                -> IO Int32 #-}

{-# SPECIALIZE extractStorable ::
                Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32
                -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Int64
                -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int64
                -> IO Int32 #-}

{-
type Reorder x = CV Int32 (CV Int32 (CV Int32 (CV x (CV x (IO Int32)))))

foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double
foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float
foreign import ccall unsafe "reorderI" c_reorderI :: Reorder Int32
foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double)
foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float)
foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z
-}

-- #define ERROR(CODE) MACRO(return CODE;)
-- #define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);})

-- | Guard an action on a precondition. If @cond@ is false, @go@ is skipped
-- and the error @code@ is returned instead; otherwise @go@ runs and its
-- result is returned. Counterpart of the C @REQUIRES(COND, CODE)@ macro.
requires :: Monad m => Bool -> Int32 -> m Int32 -> m Int32
requires cond code go
    | not cond  = return code
    | otherwise = go

-- | Error code 2000. 'reorderStorable' returns it (through 'requires') when
-- one of its size checks fails. With no signature, the pattern works at any
-- numeric type with 'Eq'; in this module it is used at 'Int32'.
pattern BAD_SIZE = 2000

-- | Copy the elements of the strided array @v@ into the contiguous array
-- @r@.
--
-- @dims@ and @strides@ describe @kn@ axes. Output element number @n@ is
-- read from @v@ at offset @sum [idx_l * strides_l]@, where @idx@ is the
-- multi-index of @n@ in row-major order over @dims@ (the last axis varies
-- fastest). @k@ is caller-supplied scratch of length @kn@ that holds the
-- running multi-index; its contents are overwritten.
--
-- Returns 'BAD_SIZE' without copying in three cases: @k@, @strides@ and
-- @dims@ differ in length; @r@ is shorter than @product dims@; or the
-- largest offset read from @v@ is not below @v@'s length. Otherwise returns
-- 0. The offset check assumes non-negative strides. A shape containing a
-- zero copies nothing. @kn == 0@ copies the single element @v[0]@.
reorderStorable :: Storable a =>
              Int32 -> Ptr Int32 -- k: scratch multi-index (length, data)
              -> Int32 -> ConstPtr Int32 -- strides
              -> Int32 -> ConstPtr Int32 -- dims
              -> Int32 -> ConstPtr a -- v: source elements
              -> Int32 -> Ptr a -- r: destination elements
              -> IO Int32
reorderStorable kn kp stridesn stridesp dimsn dimsp vn vp rn rp =
    requires (kn == stridesn && stridesn == dimsn) BAD_SIZE $ do
        (total, maxOff) <- prepare 0 1 0
        -- r receives `total` elements; v is read at offsets up to maxOff.
        requires (total <= rn && maxOff < vn) BAD_SIZE $
            if total == 0
                then return 0 -- empty shape: nothing to copy
                else copy 0 0
  where
    -- Zero the counter kp, and accumulate over axes 0 .. kn-1 the number of
    -- output elements and the largest offset read from v. The bound is
    -- tested before index l is touched, so kp/dims/strides are never
    -- accessed at index kn.
    prepare :: Int32 -> Int32 -> Int32 -> IO (Int32, Int32)
    prepare l !total !maxOff
        | l >= kn = return (total, maxOff)
        | otherwise = do
            pokeElemOff kp (fromIntegral l) 0
            d <- peekElemOff dimsp (fromIntegral l)
            s <- peekElemOff stridesp (fromIntegral l)
            prepare (l + 1) (total * d) (maxOff + s * (d - 1))

    -- Copy output element i from offset j of v, then step the counter.
    copy :: Int -> Int32 -> IO Int32
    copy !i !j = do
        pokeElemOff rp i =<< peekElemOff vp (fromIntegral j)
        advance (kn - 1) i j

    -- Odometer step. Bump axis l of kp. If it overflows, reset it to 0,
    -- undo that axis's contribution to j, and carry into axis l-1.
    -- Finishes with 0 once axis 0 overflows, i.e. after the last element.
    advance :: Int32 -> Int -> Int32 -> IO Int32
    advance l i !j
        | l < 0 = return 0 -- kn == 0: the single element is already copied
        | otherwise = do
            k <- succ <$> peekElemOff kp (fromIntegral l)
            pokeElemOff kp (fromIntegral l) k
            d <- peekElemOff dimsp (fromIntegral l)
            s <- peekElemOff stridesp (fromIntegral l)
            if k < d
                then copy (i + 1) (j + s)
                else if l == 0
                    then return 0
                    else do
                        pokeElemOff kp (fromIntegral l) 0
                        advance (l - 1) i (j - s * (d - 1))