1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
|
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
module Internal.Specialized where
import Control.Monad
import Data.Complex
import Data.Int
import Data.Typeable
import Foreign.C.Types (CInt(..))
import Foreign.Marshal.Alloc(free,malloc)
import Foreign.Marshal.Array(newArray,copyArray)
import qualified Foreign.Marshal.Utils as Marshal
import Foreign.Ptr
import Foreign.Storable
import System.IO.Unsafe

import Data.Vector.Storable (Vector)

import Internal.Devel
import Internal.Vector (createVector)
-- | Supported matrix elements: types that can be stored in a C buffer
-- ('Storable') and whose concrete type can be inspected at runtime
-- ('Typeable').
type Element t = (Typeable t, Storable t)
-- | Runtime evidence that the element type @a@ is one of a fixed set of
-- concrete numeric types. Matching a constructor against 'Refl' brings the
-- equality @a ~ T@ into scope, so type-specific code can be selected.
data Specialized a
    = SpFloat   !(a :~: Float)            -- ^ @a@ is 'Float'
    | SpDouble  !(a :~: Double)           -- ^ @a@ is 'Double'
    | SpCFloat  !(a :~: Complex Float)    -- ^ @a@ is @'Complex' 'Float'@
    | SpCDouble !(a :~: Complex Double)   -- ^ @a@ is @'Complex' 'Double'@
    | SpInt32   !(a :~: Int32)            -- ^ @a@ is 'Int32'
    | SpInt64   !(a :~: Int64)            -- ^ @a@ is 'Int64'
-- | Work out which supported concrete type @a@ is by testing its runtime
-- type against each candidate in turn.
--
-- The argument is used only for its type; its value is never inspected.
-- Returns 'Nothing' when @a@ is none of the types listed in 'Specialized'.
specialize :: Typeable a => a -> Maybe (Specialized a)
specialize x = msum candidates
  where
    candidates =
        [ SpDouble  <$> witness x
        , SpInt64   <$> witness x
        , SpFloat   <$> witness x
        , SpInt32   <$> witness x
        , SpCDouble <$> witness x
        , SpCFloat  <$> witness x
        -- Disabled CInt case, kept for reference. CInt is a distinct type
        -- from Int32, so 'eqT' cannot produce an Int32 witness for it:
        -- , fmap (\(CInt y) -> SpInt32 y) <$> cast x
        ]

    -- Ask 'eqT' whether the (ignored) argument's type equals @b@.
    witness :: (Typeable a, Typeable b) => a -> Maybe (a :~: b)
    witness _ = eqT
-- | @constantD x n@ builds a vector of length @n@ whose elements are all @x@.
--
-- The element type is resolved at runtime with 'specialize'. The fill is then
-- delegated to the matching C routine through 'constantAux'. Only the types
-- covered by 'Specialized' are supported. Any other element type raises
-- 'error' with a message that names the offending type.
constantD :: Typeable a => a -> Int -> Vector a
constantD x = case specialize x of
    -- typeOf reads only the Typeable dictionary; x itself is not forced.
    Nothing -> error ("constantD: unsupported element type " ++ show (typeOf x))
    Just (SpDouble  Refl) -> constantAux cconstantR x
    Just (SpInt64   Refl) -> constantAux cconstantL x
    Just (SpFloat   Refl) -> constantAux cconstantF x
    Just (SpInt32   Refl) -> constantAux cconstantI x
    Just (SpCDouble Refl) -> constantAux cconstantC x
    Just (SpCFloat  Refl) -> constantAux cconstantQ x
-- Placeholders for operations that are not yet implemented for generic
-- 'Element' types. Each binding is bottom: forcing any of them raises
-- @error "todo Element"@.
--
-- Intended signatures, not yet enabled:
--
-- > extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
-- > setRect  :: Typeable a => Int -> Int -> Matrix a -> Matrix a -> IO ()
-- > sortI    :: (Typeable a , Ord a ) => Vector a -> Vector CInt
-- > sortV    :: (Typeable a , Ord a ) => Vector a -> Vector a
-- > compareV :: (Typeable a , Ord a ) => Vector a -> Vector a -> Vector CInt
-- > selectV  :: Typeable a => Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
-- > remapM   :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
-- > rowOp    :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
-- > gemm     :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
-- > reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation
extractR = error "todo Element"
setRect = error "todo Element"
sortI = error "todo Element"
sortV = error "todo Element"
compareV = error "todo Element"
selectV = error "todo Element"
remapM = error "todo Element"
rowOp = error "todo Element"
gemm = error "todo Element"
reorderV = error "todo Element"
-- | @constantAux fun x n@ builds a length-@n@ vector in which every element
-- is @x@, using the C fill routine @fun@.
--
-- @fun@ receives a pointer to a single stored copy of @x@. Its remaining
-- arguments are supplied by 'applyRaw' for the output vector (presumably the
-- vector's length and data pointer; confirm against "Internal.Devel"). A
-- failing return code is reported through '#|' with the tag "constantAux".
--
-- 'unsafePerformIO' is used because the result depends only on the
-- arguments. The output vector is created here by 'createVector' and is not
-- touched again after it is returned.
constantAux :: (Storable a1, Storable a)
    => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a
constantAux fun x n = unsafePerformIO $ do
    v <- createVector n
    -- 'Marshal.with' stores x in a temporary buffer that is released even
    -- when the C call or the '#|' check throws. The previous newArray/free
    -- pair skipped 'free' on that path and leaked the buffer.
    Marshal.with x (\px -> applyRaw v id (fun px) #| "constantAux")
    return v
-- | Shape of the C routines that fill a buffer with a constant value.
-- As called from 'constantAux', the first pointer refers to one stored
-- element holding the fill value. The CInt and second Ptr are supplied by
-- 'applyRaw' for the output vector (presumably its length and data
-- pointer; confirm in Internal.Devel). The CInt result is checked with
-- '#|'; presumably 0 means success.
type TConst t = Ptr t -> CInt -> Ptr t -> IO CInt
-- One C symbol per supported element type, by suffix:
-- F = Float, R = Double, Q = Complex Float, C = Complex Double,
-- I = Int32, L = Int64.
-- All are declared 'unsafe', so the C code must not call back into Haskell.
foreign import ccall unsafe "constantF" cconstantF :: TConst Float
foreign import ccall unsafe "constantR" cconstantR :: TConst Double
foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float)
foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double)
foreign import ccall unsafe "constantI" cconstantI :: TConst Int32
foreign import ccall unsafe "constantL" cconstantL :: TConst Int64
|