summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Specialized.hs
blob: ff89024ae2c054b6844e838990de590a65548bab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
module Internal.Specialized where

import Control.Monad
import Data.Complex
import Data.Int
import Data.Typeable
import Foreign.C.Types (CInt(..))
import Foreign.Marshal.Alloc(free,malloc)
import Foreign.Marshal.Array(newArray,copyArray)
import Foreign.Marshal.Utils (with)
import Foreign.Ptr
import Foreign.Storable
import System.IO.Unsafe

import Data.Vector.Storable (Vector)

import Internal.Vector (createVector)
import Internal.Devel

-- | Supported matrix elements: a type must be 'Storable' (so its values can
-- live in a storable 'Vector' buffer) and 'Typeable' (so 'specialize' can
-- discover at runtime which native implementation applies).
type Element t = (Typeable t, Storable t)

-- | Runtime evidence that the element type @a@ is one of the types that has a
-- native implementation. Each constructor wraps a type-equality proof;
-- pattern matching on its 'Refl' refines @a@ to the concrete type, which is
-- how 'constantD' selects the matching C routine.
data Specialized a
    = SpDouble  !(a :~: Double)          -- ^ 'Double'
    | SpInt64   !(a :~: Int64)           -- ^ 'Int64'
    | SpFloat   !(a :~: Float)           -- ^ 'Float'
    | SpInt32   !(a :~: Int32)           -- ^ 'Int32'
    | SpCDouble !(a :~: Complex Double)  -- ^ 'Complex' 'Double'
    | SpCFloat  !(a :~: Complex Float)   -- ^ 'Complex' 'Float'

-- | Work out which natively supported element type @a@ is.
--
-- Yields the matching 'Specialized' witness, or 'Nothing' when @a@ is none of
-- the six supported types. The argument value is never inspected; it only
-- fixes the type @a@. At most one alternative can succeed, since the candidate
-- types are pairwise distinct.
specialize :: Typeable a => a -> Maybe (Specialized a)
specialize _ = msum
    [ SpDouble  <$> eqT
    , SpInt64   <$> eqT
    , SpFloat   <$> eqT
    , SpInt32   <$> eqT
    , SpCDouble <$> eqT
    , SpCFloat  <$> eqT
    -- Possible CInt support, currently disabled:
    -- , fmap (\(CInt y) -> SpInt32 y) <$> cast x
    ]

-- | @constantD x n@ builds a vector of length @n@ from the element @x@,
-- presumably with every element equal to @x@ (it delegates to the C
-- @constant*@ routines through 'constantAux').
--
-- The element type is resolved at runtime with 'specialize'. Only 'Double',
-- 'Int64', 'Float', 'Int32', 'Complex' 'Double' and 'Complex' 'Float' are
-- supported. Any other type is a runtime 'error' whose message names the
-- offending type.
constantD  :: Typeable a => a -> Int -> Vector a
constantD x = case specialize x of
    -- 'typeOf' only consults the Typeable dictionary, so @x@ is not forced.
    Nothing -> error $ "constantD: unsupported element type " ++ show (typeOf x)
    Just (SpDouble  Refl) -> constantAux cconstantR x
    Just (SpInt64   Refl) -> constantAux cconstantL x
    Just (SpFloat   Refl) -> constantAux cconstantF x
    Just (SpInt32   Refl) -> constantAux cconstantI x
    Just (SpCDouble Refl) -> constantAux cconstantC x
    Just (SpCFloat  Refl) -> constantAux cconstantQ x

{-
extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
setRect  :: Typeable a => Int -> Int -> Matrix a -> Matrix a -> IO ()
sortI    :: (Typeable a , Ord a ) => Vector a -> Vector CInt
sortV    :: (Typeable a , Ord a ) => Vector a -> Vector a
compareV :: (Typeable a , Ord a ) => Vector a -> Vector a -> Vector CInt
selectV  :: Typeable a => Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
remapM   :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
rowOp    :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
gemm     :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation
-}


-- Placeholders for operations not yet ported to the 'Element' constraint
-- (their intended signatures are listed in the block comment above).
-- Each is a bare 'error': forcing any of them fails with "todo Element".
extractR = error "todo Element"
setRect  = error "todo Element"
sortI    = error "todo Element"
sortV    = error "todo Element"
compareV = error "todo Element"
selectV  = error "todo Element"
remapM   = error "todo Element"
rowOp    = error "todo Element"
gemm     = error "todo Element"
reorderV = error "todo Element"

-- | Allocate a vector of length @n@ and let the C routine @fun@ fill it,
-- passing it a pointer to a temporary cell that holds @x@.
--
-- The temporary cell is managed by 'with', so it is released even if the
-- status check '#|' aborts the action. The earlier 'newArray' / 'free' pair
-- leaked the cell on that path.
--
-- 'unsafePerformIO' is used because the vector is freshly created here and
-- nobody else can observe it before it is returned. NOTE(review): this
-- assumes @fun@ only writes into the supplied buffer and has no other
-- side effects.
constantAux :: (Storable a1, Storable a)
            => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a
constantAux fun x n = unsafePerformIO $ do
    v <- createVector n
    -- The pointer is valid only inside the callback, which matches the
    -- original lifetime (it was freed right after the call).
    with x $ \px -> applyRaw v id (fun px) #| "constantAux"
    return v

-- | Shape shared by the C fill routines imported below, as 'constantAux'
-- invokes them. The first argument is a pointer to the fill value. The
-- remaining 'CInt' and pointer are supplied by 'applyRaw' for the
-- destination vector (presumably its length and data pointer; confirm
-- against the C sources). The returned 'CInt' is a status code that
-- 'constantAux' checks with '#|'.
type TConst t = Ptr t -> CInt -> Ptr t -> IO CInt
-- C symbol suffix per element type (the same mapping 'constantD' dispatches on):
-- F = Float, R = Double, Q = Complex Float, C = Complex Double,
-- I = Int32, L = Int64.
-- Imported as 'unsafe' calls, so the C routines must not call back into Haskell.
foreign import ccall unsafe "constantF" cconstantF :: TConst Float
foreign import ccall unsafe "constantR" cconstantR :: TConst Double
foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float)
foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double)
foreign import ccall unsafe "constantI" cconstantI :: TConst Int32
foreign import ccall unsafe "constantL" cconstantL :: TConst Int64