summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Crayne <joe@jerkface.net>2019-08-09 19:27:40 -0400
committerJoe Crayne <joe@jerkface.net>2019-08-09 19:28:26 -0400
commit45c2c85c3fbf3173df685bf3143af50e8569e40d (patch)
tree00d4cb8e657a5b3e65b98af752fd635d9d1873ca
parentd844a145f2e8808c9f75cd99c673d5f5c8960bf2 (diff)
Internal.Specialized
-rw-r--r--packages/base/src/Internal/Specialized.hs89
1 files changed, 89 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Specialized.hs b/packages/base/src/Internal/Specialized.hs
new file mode 100644
index 0000000..ff89024
--- /dev/null
+++ b/packages/base/src/Internal/Specialized.hs
@@ -0,0 +1,89 @@
1{-# LANGUAGE BangPatterns #-}
2{-# LANGUAGE ConstraintKinds #-}
3{-# LANGUAGE TypeOperators #-}
4{-# LANGUAGE TypeFamilies #-}
5module Internal.Specialized where
6
7import Control.Monad
8import Data.Complex
9import Data.Int
10import Data.Typeable
11import Foreign.Marshal.Alloc(free,malloc)
12import Foreign.Marshal.Array(newArray,copyArray)
13import Foreign.Ptr
14import Foreign.Storable
15import Foreign.C.Types (CInt(..))
16import Data.Vector.Storable (Vector)
17import System.IO.Unsafe
18
19import Internal.Vector (createVector)
20import Internal.Devel
21
22type Element t = (Storable t, Typeable t)
23
24data Specialized a
25 = SpFloat !(a :~: Float)
26 | SpDouble !(a :~: Double)
27 | SpCFloat !(a :~: Complex Float)
28 | SpCDouble !(a :~: Complex Double)
29 | SpInt32 !(a :~: Int32)
30 | SpInt64 !(a :~: Int64)
31
32specialize :: Typeable a => a -> Maybe (Specialized a)
33specialize x = foldr1 mplus
34 [ SpDouble <$> cst x
35 , SpInt64 <$> cst x
36 , SpFloat <$> cst x
37 , SpInt32 <$> cst x
38 , SpCDouble <$> cst x
39 , SpCFloat <$> cst x
40 -- , fmap (\(CInt y) -> SpInt32 y) <$> cast x
41 ]
42 where
43 cst :: (Typeable a, Typeable b) => a -> Maybe (a :~: b)
44 cst _ = eqT
45
46-- | Supported matrix elements.
47constantD :: Typeable a => a -> Int -> Vector a
48constantD x = case specialize x of
49 Nothing -> error "constantD"
50 Just (SpDouble Refl) -> constantAux cconstantR x
51 Just (SpInt64 Refl) -> constantAux cconstantL x
52 Just (SpFloat Refl) -> constantAux cconstantF x
53 Just (SpInt32 Refl) -> constantAux cconstantI x
54 Just (SpCDouble Refl) -> constantAux cconstantC x
55 Just (SpCFloat Refl) -> constantAux cconstantQ x
56
57{-
58extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
59setRect :: Typeable a => Int -> Int -> Matrix a -> Matrix a -> IO ()
60sortI :: (Typeable a , Ord a ) => Vector a -> Vector CInt
61sortV :: (Typeable a , Ord a ) => Vector a -> Vector a
62compareV :: (Typeable a , Ord a ) => Vector a -> Vector a -> Vector CInt
63selectV :: Typeable a => Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
64remapM :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
65rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
66gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
67reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation
68-}
69
70
71( extractR , setRect , sortI , sortV , compareV , selectV , remapM , rowOp , gemm , reorderV )
72 = error "todo Element"
73
74constantAux :: (Storable a1, Storable a)
75 => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a
76constantAux fun x n = unsafePerformIO $ do
77 v <- createVector n
78 px <- newArray [x]
79 (applyRaw v id) (fun px) #|"constantAux"
80 free px
81 return v
82
83type TConst t = Ptr t -> CInt -> Ptr t -> IO CInt
84foreign import ccall unsafe "constantF" cconstantF :: TConst Float
85foreign import ccall unsafe "constantR" cconstantR :: TConst Double
86foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float)
87foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double)
88foreign import ccall unsafe "constantI" cconstantI :: TConst Int32
89foreign import ccall unsafe "constantL" cconstantL :: TConst Int64