summaryrefslogtreecommitdiff
path: root/src/Data/Primitive/Struct.hs
blob: 9175cc5ace36d4904fe1923621ef6d3b2f2d2c9c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
{-# LANGUAGE AllowAmbiguousTypes   #-}
{-# LANGUAGE CPP                   #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
module Data.Primitive.Struct where

import Control.Monad.Primitive
import Data.Primitive.ByteArray
import Data.Primitive.ByteArray.Util
import Data.Primitive.Types
import Data.Tagged
import Data.Typeable
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import GHC.TypeLits

-- | A record field descriptor.
--
--   Type parameters:
--
--   * @tag@: type constructor identifying the struct; typically an
--     uninhabited data type that is only used at the type level.
--   * @typ@: type of the value stored in this field.
--   * @n@: byte offset of the field within the struct (where the field
--     value is located), as a type-level natural number.
--
--   The constructor argument is the 'Offset' corresponding to @n@.
--
--   To associate a field name with a field, see the 'IsField' type class.
newtype Field tag typ n = Field (Offset n)

-- | Structs in the garbage-collected heap, modifiable in either IO or ST.
--
--   Type parameters:
--
--   * @m@: the monad, typically IO or ST, but any instance of
--     'Control.Monad.Primitive.PrimMonad' will do.
--   * @base@: type-level byte offset of the struct within 'structArray'.
--   * @tag@: type constructor identifying the struct; it lets GHC infer
--     the matching 'IsField' instances.
data Struct m base tag = Struct
    { structOffset :: !(Offset base)
    -- ^ 'Offset' corresponding to the @base@ type parameter
    , structArray  :: !(MutableByteArray (PrimState m))
    -- ^ mutable array where the field values are stored
    }

-- | Allocate a new struct of type @tag@ at offset 0 of a freshly allocated
--   pinned byte array of @SizeOf tag@ bytes.
--
--   'newPinnedByteArray' does not zero its memory, so every field holds an
--   unspecified value until it is written.  Because the array is pinned,
--   the result is suitable for 'withPointer'.
newStruct :: forall tag m. (KnownNat (SizeOf tag), PrimMonad m) => m (Struct m 0 tag)
newStruct = do
    bytes <- newPinnedByteArray byteCount
    return (Struct (Offset 0) bytes)
  where
    -- Size of the struct, read off the type-level natural.
    byteCount = fromIntegral (natVal (Proxy :: Proxy (SizeOf tag)))

-- | Descriptor for a sub-struct tagged @subtag@, located @n@ bytes into
--   the struct tagged @tag@.
--
--   NOTE(review): this type is not referenced elsewhere in the visible
--   module ('nested' and 'nestedField' work with 'Field' instead); confirm
--   whether it is still needed.
newtype Nested tag subtag n = Nested (Offset n)

-- | Common interface for the struct representations in this module
--   ('Struct', 'Ptr' and 'ForeignStruct').
--
--   @m@ is the monad the accessors run in, and @p@ is the representation,
--   applied to the tag of the struct.
class IsStruct m p where
    -- | Type-level byte offset at which the struct starts.  It appears in
    --   the 'IsMultipleOf' constraints of 'setField' and 'getField'.
    type BaseOffset p :: Nat
    -- | Representation of a sub-struct tagged @subtag@ that starts
    --   @offset@ bytes into a struct represented by @p@.  This is the
    --   result type of 'nestedField'.
    type NestedStruct m p (offset::Nat) subtag

    -- | Write a value to the given field.
    --
    --   On GHC 8.2 and later the field must also satisfy
    --   @IsMultipleOf (BaseOffset p + k) (SizeOf a)@.  This is presumably
    --   an alignment check; the class is defined outside this module.
    setField :: ( Prim a
#if __GLASGOW_HASKELL__ >= 802
           , IsMultipleOf ((BaseOffset p) + k) (SizeOf a)
#endif
           ) => p tag -> Field tag a k -> a -> m ()

    -- | Read the value of the given field (same constraints as 'setField').
    getField :: ( Prim a
#if __GLASGOW_HASKELL__ >= 802
           , IsMultipleOf ((BaseOffset p) + k) (SizeOf a)
#endif
           ) => p tag -> Field tag a k -> m a

    -- | View a field of type @subtag@ as a nested struct.
    --
    --   The instances in this module ignore the value of the @proxy m@
    --   argument; it only serves to fix the monad @m@.
    nestedField :: p tag -> Field tag subtag k -> proxy m -> NestedStruct m p k subtag

-- | Associates a type-level field name @lbl@ with a 'Field' of the struct
--   identified by @tag@.
--
--   @lbl@ does not occur in the type of 'field', so it is chosen with a
--   type application, as in @field \@lbl p@ (see 'get' and 'set').
class IsField (lbl::Symbol) tag where
    -- | Byte offset of the field within the struct.
    type FieldOffset lbl tag :: Nat
    -- | Type of the value stored in the field.
    type FieldType lbl tag
    -- | The field descriptor.  The @p tag@ argument fixes @tag@; callers in
    --   this module pass the struct value itself.
    field :: p tag -> Field tag (FieldType lbl tag) (FieldOffset lbl tag)

-- | Write the field named @lbl@, which is selected with a type application
--   (e.g. @set \@"x" s v@).
--
--   The CPP-guarded constraint uses a leading comma, so the constraint
--   tuple stays well-formed when that line is compiled out (GHC < 8.2).
set :: forall lbl tag m p.
    ( IsField lbl tag
    , IsStruct m p
    , Prim (FieldType lbl tag)
#if __GLASGOW_HASKELL__ >= 802
    , IsMultipleOf (BaseOffset p + FieldOffset lbl tag) (SizeOf (FieldType lbl tag))
#endif
    ) => p tag -> FieldType lbl tag -> m ()
set p a = setField p (field @lbl p) a
{-# INLINE set #-}

-- | Read the field named @lbl@, which is selected with a type application
--   (e.g. @get \@"x" s@).
--
--   The CPP-guarded constraint uses a leading comma, so the constraint
--   tuple stays well-formed when that line is compiled out (GHC < 8.2).
get :: forall lbl tag m p.
    ( IsField lbl tag
    , IsStruct m p
    , Prim (FieldType lbl tag)
#if __GLASGOW_HASKELL__ >= 802
    , IsMultipleOf (BaseOffset p + FieldOffset lbl tag) (SizeOf (FieldType lbl tag))
#endif
    ) => p tag -> m (FieldType lbl tag)
get p = getField p (field @lbl p)
{-# INLINE get #-}


-- | Apply a function to the field named @lbl@ (selected with a type
--   application, e.g. @modify \@"x" s f@).
--
--   The current value is read with 'get' and the result is written back
--   with 'set'.  These are two separate operations, not one atomic update.
modify :: forall lbl tag m p.
    ( Monad m
    , IsField lbl tag
    , IsStruct m p
    , Prim (FieldType lbl tag)
#if __GLASGOW_HASKELL__ >= 802
    , IsMultipleOf (BaseOffset p + FieldOffset lbl tag) (SizeOf (FieldType lbl tag))
#endif
    ) => p tag -> (FieldType lbl tag -> FieldType lbl tag) -> m ()
modify p f = get @lbl p >>= set @lbl p . f
-- Inlined like 'get' and 'set' so call sites can specialise the
-- class dictionaries.
{-# INLINE modify #-}

-- | View the field named @lbl@ (selected with a type application) as a
--   nested struct.  The result uses the representation given by
--   'NestedStruct' for @p@, placed at the offset of the field.
nested :: forall lbl m p tag. (IsField lbl tag, IsStruct m p) =>
                p tag
                -> NestedStruct m p (FieldOffset lbl tag) (FieldType lbl tag)
nested p = nestedField p (field @lbl p) (Proxy @m)
-- Inlined like the other field accessors ('get', 'set').
{-# INLINE nested #-}

-- | Structs stored in a 'MutableByteArray' at type-level offset @base@.
--   Each access adds the offset of the struct to the offset of the field
--   with '+.', so nested structs keep tracking their absolute position.
instance PrimMonad m => IsStruct m (Struct m base) where
    type BaseOffset (Struct m base) = base
    type NestedStruct m (Struct m base) j t = Struct m (base + j) t

    setField (Struct start bytes) (Field fieldOff) x =
        writeAtByte bytes (start +. fieldOff) x

    getField (Struct start bytes) (Field fieldOff) =
        readAtByte bytes (start +. fieldOff)

    nestedField (Struct start bytes) (Field fieldOff) _ =
        Struct (start +. fieldOff) bytes

-- | Structs addressed through a raw 'Ptr'.  The struct starts at the
--   pointer itself ('BaseOffset' is 0), and values are transferred through
--   'Storable' using the 'PrimStorable' wrapper.
--
--   NOTE(review): since 'BaseOffset' is 0, the 'IsMultipleOf' constraint
--   only sees the field offset; the pointer itself is assumed to be
--   suitably aligned. Confirm at call sites.
instance IsStruct IO Ptr where
    type BaseOffset Ptr = 0
    type NestedStruct IO Ptr j t = Ptr t

    setField start (Field (Offset off)) x =
        poke (start `plusPtr` off) (PrimStorable x)

    getField start (Field (Offset off)) =
        fmap getPrimStorable (peek (start `plusPtr` off))

    -- 'plusPtr' is polymorphic in its result, so no cast is needed.
    nestedField start (Field (Offset off)) _ = plusPtr start off

-- | Run an action with a raw pointer to the start of the struct.
--
--   The pointer comes from 'mutableByteArrayContents', so it is only valid
--   when the backing array is pinned, as arrays created by 'newStruct' are.
--   The pointer must not escape the action.
withPointer :: Struct IO base tag -> (Ptr tag -> IO x) -> IO x
withPointer (Struct (Offset off) ary) f = do
#if !MIN_VERSION_primitive(0,7,0)
    x <- f (ptr (mutableByteArrayContents ary) `plusPtr` off)
#else
    x <- f ((mutableByteArrayContents ary) `plusPtr` off)
#endif
    -- Keep the array alive until the action has finished with the raw
    -- pointer.  A plain @seq ary@ is not enough: @ary@ comes from a strict
    -- field and is already evaluated, so GHC may drop the seq entirely,
    -- leaving the array collectable while the pointer is still in use.
    touch ary
    return x

-- | A struct reached through a 'ForeignPtr' plus a runtime byte offset.
--   Unlike 'Struct', the offset is only known at the term level.
data ForeignStruct tag = ForeignStruct
    { fsPtr    :: !(ForeignPtr tag)
    -- ^ pointer to the storage containing the struct
    , fsOffset :: !Int
    -- ^ byte offset from 'fsPtr' at which the struct starts
    }

-- | Structs behind a 'ForeignPtr'.  Each access runs inside
--   'withForeignPtr' and addresses the field at the runtime offset of the
--   struct plus the offset of the field.
--
--   NOTE(review): 'BaseOffset' is a custom type error.  On GHC 8.2 and
--   later the constraints of 'setField' and 'getField' mention
--   @BaseOffset p@, so using them on a 'ForeignStruct' will likely trigger
--   this error at the call site. Confirm whether that is intended.
instance IsStruct IO ForeignStruct where
    type BaseOffset ForeignStruct = TypeError (Text "ForeignStruct has no type-level offset information.")
    type NestedStruct IO ForeignStruct j t = ForeignStruct t

    setField (ForeignStruct fp start) (Field (Offset off)) x =
        withForeignPtr fp $ \raw ->
            poke (raw `plusPtr` (off + start)) (PrimStorable x)

    getField (ForeignStruct fp start) (Field (Offset off)) =
        withForeignPtr fp $ \raw ->
            fmap getPrimStorable (peek (raw `plusPtr` (off + start)))

    nestedField (ForeignStruct fp start) (Field (Offset off)) _ =
        ForeignStruct (castForeignPtr fp) (start + off)