diff options
author | Joe Crayne <joe@jerkface.net> | 2019-08-09 19:27:40 -0400 |
---|---|---|
committer | Joe Crayne <joe@jerkface.net> | 2019-08-09 19:28:26 -0400 |
commit | 45c2c85c3fbf3173df685bf3143af50e8569e40d (patch) | |
tree | 00d4cb8e657a5b3e65b98af752fd635d9d1873ca /packages/base/src/Internal/Specialized.hs | |
parent | d844a145f2e8808c9f75cd99c673d5f5c8960bf2 (diff) |
Internal.Specialized
Diffstat (limited to 'packages/base/src/Internal/Specialized.hs')
-rw-r--r-- | packages/base/src/Internal/Specialized.hs | 89 |
1 files changed, 89 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Specialized.hs b/packages/base/src/Internal/Specialized.hs new file mode 100644 index 0000000..ff89024 --- /dev/null +++ b/packages/base/src/Internal/Specialized.hs | |||
@@ -0,0 +1,89 @@ | |||
1 | {-# LANGUAGE BangPatterns #-} | ||
2 | {-# LANGUAGE ConstraintKinds #-} | ||
3 | {-# LANGUAGE TypeOperators #-} | ||
4 | {-# LANGUAGE TypeFamilies #-} | ||
5 | module Internal.Specialized where | ||
6 | |||
7 | import Control.Monad | ||
8 | import Data.Complex | ||
9 | import Data.Int | ||
10 | import Data.Typeable | ||
11 | import Foreign.Marshal.Alloc(free,malloc) | ||
12 | import Foreign.Marshal.Array(newArray,copyArray) | ||
13 | import Foreign.Ptr | ||
14 | import Foreign.Storable | ||
15 | import Foreign.C.Types (CInt(..)) | ||
16 | import Data.Vector.Storable (Vector) | ||
17 | import System.IO.Unsafe | ||
18 | |||
19 | import Internal.Vector (createVector) | ||
20 | import Internal.Devel | ||
21 | |||
22 | type Element t = (Storable t, Typeable t) | ||
23 | |||
24 | data Specialized a | ||
25 | = SpFloat !(a :~: Float) | ||
26 | | SpDouble !(a :~: Double) | ||
27 | | SpCFloat !(a :~: Complex Float) | ||
28 | | SpCDouble !(a :~: Complex Double) | ||
29 | | SpInt32 !(a :~: Int32) | ||
30 | | SpInt64 !(a :~: Int64) | ||
31 | |||
32 | specialize :: Typeable a => a -> Maybe (Specialized a) | ||
33 | specialize x = foldr1 mplus | ||
34 | [ SpDouble <$> cst x | ||
35 | , SpInt64 <$> cst x | ||
36 | , SpFloat <$> cst x | ||
37 | , SpInt32 <$> cst x | ||
38 | , SpCDouble <$> cst x | ||
39 | , SpCFloat <$> cst x | ||
40 | -- , fmap (\(CInt y) -> SpInt32 y) <$> cast x | ||
41 | ] | ||
42 | where | ||
43 | cst :: (Typeable a, Typeable b) => a -> Maybe (a :~: b) | ||
44 | cst _ = eqT | ||
45 | |||
46 | -- | Supported matrix elements. | ||
47 | constantD :: Typeable a => a -> Int -> Vector a | ||
48 | constantD x = case specialize x of | ||
49 | Nothing -> error "constantD" | ||
50 | Just (SpDouble Refl) -> constantAux cconstantR x | ||
51 | Just (SpInt64 Refl) -> constantAux cconstantL x | ||
52 | Just (SpFloat Refl) -> constantAux cconstantF x | ||
53 | Just (SpInt32 Refl) -> constantAux cconstantI x | ||
54 | Just (SpCDouble Refl) -> constantAux cconstantC x | ||
55 | Just (SpCFloat Refl) -> constantAux cconstantQ x | ||
56 | |||
57 | {- | ||
58 | extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) | ||
59 | setRect :: Typeable a => Int -> Int -> Matrix a -> Matrix a -> IO () | ||
60 | sortI :: (Typeable a , Ord a ) => Vector a -> Vector CInt | ||
61 | sortV :: (Typeable a , Ord a ) => Vector a -> Vector a | ||
62 | compareV :: (Typeable a , Ord a ) => Vector a -> Vector a -> Vector CInt | ||
63 | selectV :: Typeable a => Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a | ||
64 | remapM :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | ||
65 | rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | ||
66 | gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () | ||
67 | reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation | ||
68 | -} | ||
69 | |||
70 | |||
71 | ( extractR , setRect , sortI , sortV , compareV , selectV , remapM , rowOp , gemm , reorderV ) | ||
72 | = error "todo Element" | ||
73 | |||
74 | constantAux :: (Storable a1, Storable a) | ||
75 | => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a | ||
76 | constantAux fun x n = unsafePerformIO $ do | ||
77 | v <- createVector n | ||
78 | px <- newArray [x] | ||
79 | (applyRaw v id) (fun px) #|"constantAux" | ||
80 | free px | ||
81 | return v | ||
82 | |||
83 | type TConst t = Ptr t -> CInt -> Ptr t -> IO CInt | ||
84 | foreign import ccall unsafe "constantF" cconstantF :: TConst Float | ||
85 | foreign import ccall unsafe "constantR" cconstantR :: TConst Double | ||
86 | foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float) | ||
87 | foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double) | ||
88 | foreign import ccall unsafe "constantI" cconstantI :: TConst Int32 | ||
89 | foreign import ccall unsafe "constantL" cconstantL :: TConst Int64 | ||