diff options
Diffstat (limited to 'src/Data/Primitive/Struct.hs')
-rw-r--r-- | src/Data/Primitive/Struct.hs | 112 |
1 files changed, 112 insertions, 0 deletions
diff --git a/src/Data/Primitive/Struct.hs b/src/Data/Primitive/Struct.hs new file mode 100644 index 0000000..154b750 --- /dev/null +++ b/src/Data/Primitive/Struct.hs | |||
@@ -0,0 +1,112 @@ | |||
1 | {-# LANGUAGE AllowAmbiguousTypes #-} | ||
2 | {-# LANGUAGE CPP #-} | ||
3 | {-# LANGUAGE DataKinds #-} | ||
4 | {-# LANGUAGE FlexibleContexts #-} | ||
5 | {-# LANGUAGE FlexibleInstances #-} | ||
6 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | {-# LANGUAGE TypeApplications #-} | ||
9 | {-# LANGUAGE TypeFamilies #-} | ||
10 | {-# LANGUAGE TypeOperators #-} | ||
11 | module Data.Primitive.Struct where | ||
12 | |||
13 | import Control.Monad.Primitive | ||
14 | import Data.Primitive.ByteArray | ||
15 | import Data.Primitive.ByteArray.Util | ||
16 | import Data.Primitive.Types | ||
17 | import Data.Tagged | ||
18 | import Data.Typeable | ||
19 | import Foreign.Ptr | ||
20 | import Foreign.Storable | ||
21 | import GHC.TypeLits | ||
22 | |||
23 | newtype Field tag typ n = Field (Offset n) | ||
24 | |||
25 | data Struct m base tag = Struct | ||
26 | { structOffset :: !(Offset base) | ||
27 | , structArray :: !(MutableByteArray (PrimState m)) | ||
28 | } | ||
29 | |||
30 | newStruct :: forall tag m. (KnownNat (SizeOf tag), PrimMonad m) => m (Struct m 0 tag) | ||
31 | newStruct = Struct (Offset 0) <$> newPinnedByteArray (fromIntegral sz) | ||
32 | where | ||
33 | sz = natVal (Proxy :: Proxy (SizeOf tag)) | ||
34 | |||
35 | newtype Nested tag subtag n = Nested (Offset n) | ||
36 | |||
37 | class IsStruct m p where | ||
38 | type BaseOffset p :: Nat | ||
39 | type NestedStruct m p (offset::Nat) subtag | ||
40 | |||
41 | setField :: ( Prim a | ||
42 | #if __GLASGOW_HASKELL__ >= 802 | ||
43 | , IsMultipleOf ((BaseOffset p) + k) (SizeOf a) | ||
44 | #endif | ||
45 | ) => p tag -> Field tag a k -> a -> m () | ||
46 | |||
47 | getField :: ( Prim a | ||
48 | #if __GLASGOW_HASKELL__ >= 802 | ||
49 | , IsMultipleOf ((BaseOffset p) + k) (SizeOf a) | ||
50 | #endif | ||
51 | ) => p tag -> Field tag a k -> m a | ||
52 | |||
53 | nestedField :: p tag -> Field tag subtag k -> proxy m -> NestedStruct m p k subtag | ||
54 | |||
55 | class IsField (lbl::Symbol) tag where | ||
56 | type FieldOffset lbl tag :: Nat | ||
57 | type FieldType lbl tag | ||
58 | field :: p tag -> Field tag (FieldType lbl tag) (FieldOffset lbl tag) | ||
59 | |||
60 | set :: forall lbl tag m p. (IsField lbl tag, IsStruct m p, Prim (FieldType lbl tag), | ||
61 | #if __GLASGOW_HASKELL__ >= 802 | ||
62 | IsMultipleOf (BaseOffset p + FieldOffset lbl tag) (SizeOf (FieldType lbl tag)) | ||
63 | #endif | ||
64 | ) => | ||
65 | p tag -> FieldType lbl tag -> m () | ||
66 | set p a = setField p (field @lbl p) a | ||
67 | {-# INLINE set #-} | ||
68 | |||
69 | get :: forall lbl tag m p. (IsField lbl tag, IsStruct m p, Prim (FieldType lbl tag), | ||
70 | #if __GLASGOW_HASKELL__ >= 802 | ||
71 | IsMultipleOf (BaseOffset p + FieldOffset lbl tag) (SizeOf (FieldType lbl tag)) | ||
72 | #endif | ||
73 | ) => | ||
74 | p tag -> m (FieldType lbl tag) | ||
75 | get p = getField p (field @lbl p) | ||
76 | {-# INLINE get #-} | ||
77 | |||
78 | |||
79 | modify :: forall lbl tag m p. | ||
80 | ( Monad m | ||
81 | , IsField lbl tag | ||
82 | , IsStruct m p | ||
83 | , Prim (FieldType lbl tag) | ||
84 | #if __GLASGOW_HASKELL__ >= 802 | ||
85 | , IsMultipleOf (BaseOffset p + FieldOffset lbl tag) (SizeOf (FieldType lbl tag)) | ||
86 | #endif | ||
87 | ) => p tag -> (FieldType lbl tag -> FieldType lbl tag) -> m () | ||
88 | modify p f = get @lbl p >>= set @lbl p . f | ||
89 | |||
90 | nested :: forall lbl m p tag. (IsField lbl tag, IsStruct m p) => | ||
91 | p tag | ||
92 | -> NestedStruct m p (FieldOffset lbl tag) (FieldType lbl tag) | ||
93 | nested p = nestedField p (field @lbl p) (Proxy @m) | ||
94 | |||
95 | instance PrimMonad m => IsStruct m (Struct m base) where | ||
96 | type BaseOffset (Struct m base) = base | ||
97 | type NestedStruct m (Struct m base) j t = Struct m (base + j) t | ||
98 | setField (Struct o c) (Field field) value = writeAtByte c (o +. field) value | ||
99 | getField (Struct o c) (Field field) = readAtByte c (o +. field) | ||
100 | nestedField (Struct base ary) (Field offset) _ = Struct (base +. offset) ary | ||
101 | |||
102 | instance IsStruct IO Ptr where | ||
103 | type BaseOffset Ptr = 0 | ||
104 | type NestedStruct IO Ptr j t = Ptr t | ||
105 | setField ptr (Field (Offset o)) value = poke (ptr `plusPtr` o) $ PrimStorable value | ||
106 | getField ptr (Field (Offset o)) = getPrimStorable <$> peek (ptr `plusPtr` o) | ||
107 | nestedField ptr (Field (Offset offset)) _ = castPtr (plusPtr ptr offset) | ||
108 | |||
109 | withPointer :: Struct IO base tag -> (Ptr tag -> IO x) -> IO x | ||
110 | withPointer (Struct (Offset off) ary) f = do | ||
111 | x <- f (ptr (mutableByteArrayContents ary) `plusPtr` off) | ||
112 | seq ary $ return x | ||