summaryrefslogtreecommitdiff
path: root/src/Data/Primitive/Struct.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Primitive/Struct.hs')
-rw-r--r--src/Data/Primitive/Struct.hs112
1 files changed, 112 insertions, 0 deletions
diff --git a/src/Data/Primitive/Struct.hs b/src/Data/Primitive/Struct.hs
new file mode 100644
index 0000000..154b750
--- /dev/null
+++ b/src/Data/Primitive/Struct.hs
@@ -0,0 +1,112 @@
1{-# LANGUAGE AllowAmbiguousTypes #-}
2{-# LANGUAGE CPP #-}
3{-# LANGUAGE DataKinds #-}
4{-# LANGUAGE FlexibleContexts #-}
5{-# LANGUAGE FlexibleInstances #-}
6{-# LANGUAGE MultiParamTypeClasses #-}
7{-# LANGUAGE ScopedTypeVariables #-}
8{-# LANGUAGE TypeApplications #-}
9{-# LANGUAGE TypeFamilies #-}
10{-# LANGUAGE TypeOperators #-}
11module Data.Primitive.Struct where
12
13import Control.Monad.Primitive
14import Data.Primitive.ByteArray
15import Data.Primitive.ByteArray.Util
16import Data.Primitive.Types
17import Data.Tagged
18import Data.Typeable
19import Foreign.Ptr
20import Foreign.Storable
21import GHC.TypeLits
22
23newtype Field tag typ n = Field (Offset n)
24
25data Struct m base tag = Struct
26 { structOffset :: !(Offset base)
27 , structArray :: !(MutableByteArray (PrimState m))
28 }
29
30newStruct :: forall tag m. (KnownNat (SizeOf tag), PrimMonad m) => m (Struct m 0 tag)
31newStruct = Struct (Offset 0) <$> newPinnedByteArray (fromIntegral sz)
32 where
33 sz = natVal (Proxy :: Proxy (SizeOf tag))
34
35newtype Nested tag subtag n = Nested (Offset n)
36
37class IsStruct m p where
38 type BaseOffset p :: Nat
39 type NestedStruct m p (offset::Nat) subtag
40
41 setField :: ( Prim a
42#if __GLASGOW_HASKELL__ >= 802
43 , IsMultipleOf ((BaseOffset p) + k) (SizeOf a)
44#endif
45 ) => p tag -> Field tag a k -> a -> m ()
46
47 getField :: ( Prim a
48#if __GLASGOW_HASKELL__ >= 802
49 , IsMultipleOf ((BaseOffset p) + k) (SizeOf a)
50#endif
51 ) => p tag -> Field tag a k -> m a
52
53 nestedField :: p tag -> Field tag subtag k -> proxy m -> NestedStruct m p k subtag
54
55class IsField (lbl::Symbol) tag where
56 type FieldOffset lbl tag :: Nat
57 type FieldType lbl tag
58 field :: p tag -> Field tag (FieldType lbl tag) (FieldOffset lbl tag)
59
60set :: forall lbl tag m p. (IsField lbl tag, IsStruct m p, Prim (FieldType lbl tag),
61#if __GLASGOW_HASKELL__ >= 802
62 IsMultipleOf (BaseOffset p + FieldOffset lbl tag) (SizeOf (FieldType lbl tag))
63#endif
64 ) =>
65 p tag -> FieldType lbl tag -> m ()
66set p a = setField p (field @lbl p) a
67{-# INLINE set #-}
68
69get :: forall lbl tag m p. (IsField lbl tag, IsStruct m p, Prim (FieldType lbl tag),
70#if __GLASGOW_HASKELL__ >= 802
71 IsMultipleOf (BaseOffset p + FieldOffset lbl tag) (SizeOf (FieldType lbl tag))
72#endif
73 ) =>
74 p tag -> m (FieldType lbl tag)
75get p = getField p (field @lbl p)
76{-# INLINE get #-}
77
78
79modify :: forall lbl tag m p.
80 ( Monad m
81 , IsField lbl tag
82 , IsStruct m p
83 , Prim (FieldType lbl tag)
84#if __GLASGOW_HASKELL__ >= 802
85 , IsMultipleOf (BaseOffset p + FieldOffset lbl tag) (SizeOf (FieldType lbl tag))
86#endif
87 ) => p tag -> (FieldType lbl tag -> FieldType lbl tag) -> m ()
88modify p f = get @lbl p >>= set @lbl p . f
89
90nested :: forall lbl m p tag. (IsField lbl tag, IsStruct m p) =>
91 p tag
92 -> NestedStruct m p (FieldOffset lbl tag) (FieldType lbl tag)
93nested p = nestedField p (field @lbl p) (Proxy @m)
94
95instance PrimMonad m => IsStruct m (Struct m base) where
96 type BaseOffset (Struct m base) = base
97 type NestedStruct m (Struct m base) j t = Struct m (base + j) t
98 setField (Struct o c) (Field field) value = writeAtByte c (o +. field) value
99 getField (Struct o c) (Field field) = readAtByte c (o +. field)
100 nestedField (Struct base ary) (Field offset) _ = Struct (base +. offset) ary
101
102instance IsStruct IO Ptr where
103 type BaseOffset Ptr = 0
104 type NestedStruct IO Ptr j t = Ptr t
105 setField ptr (Field (Offset o)) value = poke (ptr `plusPtr` o) $ PrimStorable value
106 getField ptr (Field (Offset o)) = getPrimStorable <$> peek (ptr `plusPtr` o)
107 nestedField ptr (Field (Offset offset)) _ = castPtr (plusPtr ptr offset)
108
109withPointer :: Struct IO base tag -> (Ptr tag -> IO x) -> IO x
110withPointer (Struct (Offset off) ary) f = do
111 x <- f (ptr (mutableByteArrayContents ary) `plusPtr` off)
112 seq ary $ return x