summaryrefslogtreecommitdiff
path: root/src/Data/Primitive/Struct.hs
blob: 705e65dab3c6a97c486d80d4b57ddd3008d867f2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
{-# LANGUAGE AllowAmbiguousTypes   #-}
{-# LANGUAGE CPP                   #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
module Data.Primitive.Struct where

import Control.Monad.Primitive
import Data.Primitive.ByteArray
import Data.Primitive.ByteArray.Util
import Data.Primitive.Types
import Data.Tagged
import Data.Typeable
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import GHC.TypeLits

-- | Descriptor for a field of struct layout @tag@ that holds a value of
-- type @typ@ and is located @n@ bytes from the start of the struct.  The
-- offset is carried both in the type (@n@) and as a runtime 'Offset'.
newtype Field tag typ n = Field (Offset n)

-- | A mutable struct of layout @tag@ stored in a 'MutableByteArray' whose
-- state token is @PrimState m@.
--
-- @base@ is the struct's byte offset within the array, tracked at the type
-- level.  It is 0 for structs made by 'newStruct'. It is increased by the
-- field offset when a sub-struct is taken with 'nested'.
data Struct m base tag = Struct
    { structOffset :: !(Offset base)
      -- ^ Byte offset of the struct within 'structArray' (the runtime copy
      -- of @base@).
    , structArray  :: !(MutableByteArray (PrimState m))
      -- ^ Backing storage; nested structs share the enclosing struct's array.
    }

-- | Allocate a fresh struct of layout @tag@ at offset 0 of a new pinned
-- byte array that is @SizeOf tag@ bytes long.
--
-- The array is pinned, so 'withPointer' can hand out a stable raw pointer
-- to it.
--
-- NOTE(review): 'newPinnedByteArray' does not zero its memory, so field
-- contents are presumably undefined until written. Confirm before relying
-- on any initial value.
newStruct :: forall tag m. (KnownNat (SizeOf tag), PrimMonad m) => m (Struct m 0 tag)
newStruct = do
    storage <- newPinnedByteArray byteCount
    return (Struct (Offset 0) storage)
  where
    -- Total size of the layout, read off the type-level 'SizeOf'.
    byteCount = fromIntegral (natVal (Proxy @(SizeOf tag)))

-- | An offset tagged with an outer struct layout @tag@ and an inner layout
-- @subtag@.
--
-- NOTE(review): nothing in this module constructs or consumes 'Nested'.
-- Nested access goes through 'nestedField' with a plain 'Field'. It may be
-- unused or used elsewhere; confirm before changing.
newtype Nested tag subtag n = Nested (Offset n)

-- | A storage back-end that can hold a struct of layout @tag@ as a value of
-- type @p tag@, with field reads and writes running in @m@.
--
-- This module provides instances for 'Struct' (a 'MutableByteArray' plus a
-- byte offset), raw 'Ptr', and 'ForeignStruct'.
class IsStruct m p where
    -- | Type-level byte offset of the struct within its storage.  On GHC
    -- >= 8.2, 'setField' and 'getField' require this offset plus the field
    -- offset to be a multiple of the field type's 'SizeOf' ('IsMultipleOf').
    type BaseOffset p :: Nat
    -- | The value used to access a sub-struct of layout @subtag@ that sits
    -- at byte offset @offset@ inside a struct of this back-end.
    type NestedStruct m p (offset::Nat) subtag

    -- | Write a value of primitive type @a@ into the given field.
    setField :: ( Prim a
#if __GLASGOW_HASKELL__ >= 802
           , IsMultipleOf ((BaseOffset p) + k) (SizeOf a)
#endif
           ) => p tag -> Field tag a k -> a -> m ()

    -- | Read a value of primitive type @a@ from the given field.
    getField :: ( Prim a
#if __GLASGOW_HASKELL__ >= 802
           , IsMultipleOf ((BaseOffset p) + k) (SizeOf a)
#endif
           ) => p tag -> Field tag a k -> m a

    -- | View a field of layout @subtag@ as a nested struct.  In the
    -- instances in this module the result refers to the same storage, and
    -- the proxy argument is ignored; it only serves to fix @m@.
    nestedField :: p tag -> Field tag subtag k -> proxy m -> NestedStruct m p k subtag

-- | Static description of the field labelled @lbl@ in struct layout @tag@:
-- its byte offset and the type of value it stores.  'set', 'get', 'modify'
-- and 'nested' select the label with a type application.
class IsField (lbl::Symbol) tag where
    -- | Byte offset of the field from the start of the struct.
    type FieldOffset lbl tag :: Nat
    -- | Type of the value stored in the field.
    type FieldType lbl tag
    -- | The 'Field' descriptor for this label.  The @p tag@ argument
    -- presumably serves only to fix @tag@ at call sites. NOTE(review):
    -- confirm that instances do not inspect it.
    field :: p tag -> Field tag (FieldType lbl tag) (FieldOffset lbl tag)

-- | Write the field whose label is @lbl@ (supplied by type application).
--
-- The comma separating the 'IsMultipleOf' constraint sits inside the CPP
-- block (leading-comma style). When that constraint is compiled out on
-- GHC < 8.2, the constraint tuple is then still well-formed instead of
-- ending in a trailing comma.
set :: forall lbl tag m p.
    ( IsField lbl tag
    , IsStruct m p
    , Prim (FieldType lbl tag)
#if __GLASGOW_HASKELL__ >= 802
    , IsMultipleOf (BaseOffset p + FieldOffset lbl tag) (SizeOf (FieldType lbl tag))
#endif
    ) => p tag -> FieldType lbl tag -> m ()
set p a = setField p (field @lbl p) a
{-# INLINE set #-}

-- | Read the field whose label is @lbl@ (supplied by type application).
--
-- The comma separating the 'IsMultipleOf' constraint sits inside the CPP
-- block (leading-comma style). When that constraint is compiled out on
-- GHC < 8.2, the constraint tuple is then still well-formed instead of
-- ending in a trailing comma.
get :: forall lbl tag m p.
    ( IsField lbl tag
    , IsStruct m p
    , Prim (FieldType lbl tag)
#if __GLASGOW_HASKELL__ >= 802
    , IsMultipleOf (BaseOffset p + FieldOffset lbl tag) (SizeOf (FieldType lbl tag))
#endif
    ) => p tag -> m (FieldType lbl tag)
get p = getField p (field @lbl p)
{-# INLINE get #-}


-- | Apply a pure function to the field whose label is @lbl@: read it with
-- 'get', then write the result back with 'set'.
--
-- The read and the write are two separate operations; nothing here makes
-- the update atomic with respect to other writers of the same storage.
modify :: forall lbl tag m p.
    ( Monad m
    , IsField lbl tag
    , IsStruct m p
    , Prim (FieldType lbl tag)
#if __GLASGOW_HASKELL__ >= 802
    , IsMultipleOf (BaseOffset p + FieldOffset lbl tag) (SizeOf (FieldType lbl tag))
#endif
    ) => p tag -> (FieldType lbl tag -> FieldType lbl tag) -> m ()
modify p f = get @lbl p >>= set @lbl p . f
-- Inlined like 'set' and 'get' so the class dictionaries can be
-- specialised away at call sites in other modules.
{-# INLINE modify #-}

-- | View the field whose label is @lbl@ as a struct in its own right, with
-- layout @FieldType lbl tag@.  The representation of the result is chosen
-- by the storage back-end through 'NestedStruct'.  Use a type application
-- for the label and, as needed, for the monad @m@.
nested :: forall lbl m p tag. (IsField lbl tag, IsStruct m p) =>
                p tag
                -> NestedStruct m p (FieldOffset lbl tag) (FieldType lbl tag)
nested p = nestedField p (field @lbl p) (Proxy @m)
-- Inlined like 'set' and 'get' so the class dictionaries can be
-- specialised away at call sites in other modules.
{-# INLINE nested #-}

-- | Byte-array-backed structs.  A field lives at the struct's own offset
-- plus the field's offset inside the shared 'MutableByteArray'. A nested
-- struct reuses the same array with the combined offset.
instance PrimMonad m => IsStruct m (Struct m base) where
    type BaseOffset (Struct m base) = base
    type NestedStruct m (Struct m base) j t = Struct m (base + j) t

    setField (Struct structOff bytes) (Field fieldOff) value =
        writeAtByte bytes (structOff +. fieldOff) value

    getField (Struct structOff bytes) (Field fieldOff) =
        readAtByte bytes (structOff +. fieldOff)

    nestedField (Struct structOff bytes) (Field fieldOff) _ =
        Struct (structOff +. fieldOff) bytes

-- | Raw pointers as struct storage.  The pointer is treated as the start of
-- the struct, so the type-level base offset is 0 and field offsets are
-- added to it directly.
--
-- NOTE(review): because 'BaseOffset' is 0, the static 'IsMultipleOf' check
-- only covers the field offset. The alignment of the pointer itself is
-- presumably the caller's responsibility.
instance IsStruct IO Ptr where
    type BaseOffset Ptr = 0
    type NestedStruct IO Ptr j t = Ptr t

    setField addr (Field (Offset off)) value =
        poke (addr `plusPtr` off) (PrimStorable value)

    getField addr (Field (Offset off)) =
        fmap getPrimStorable (peek (addr `plusPtr` off))

    nestedField addr (Field (Offset off)) _ =
        castPtr (addr `plusPtr` off)

-- | Run an action with a raw pointer to the start of the struct.
--
-- The pointer comes from 'mutableByteArrayContents', so it is only valid
-- for a pinned array (as allocated by 'newStruct'). It must not escape the
-- action.
--
-- 'touch' after the action is what keeps the array reachable for as long
-- as the pointer is in use. A 'seq' on the array is not sufficient: the
-- field is strict and already evaluated, so the 'seq' can be optimised
-- away, and the garbage collector could then free the array while the
-- action is still dereferencing the pointer.
withPointer :: Struct IO base tag -> (Ptr tag -> IO x) -> IO x
withPointer (Struct (Offset off) ary) f = do
    x <- f (ptr (mutableByteArrayContents ary) `plusPtr` off)
    touch ary
    return x

-- | A struct of layout @tag@ located inside memory owned by a 'ForeignPtr'.
-- Unlike 'Struct', the struct's byte offset within that memory is known
-- only at runtime ('fsOffset'), so there is no type-level base offset.
data ForeignStruct tag = ForeignStruct
    { fsPtr    :: !(ForeignPtr tag)
      -- ^ Owner of the underlying memory. Each field access runs inside
      -- 'withForeignPtr' on it.
    , fsOffset :: !Int
      -- ^ Byte offset of this struct from the start of 'fsPtr'.
    }

-- | 'ForeignPtr'-backed structs.  Every access runs inside 'withForeignPtr',
-- and the field address is the foreign pointer plus the field offset plus
-- the runtime struct offset.
instance IsStruct IO ForeignStruct where
    -- NOTE(review): on GHC >= 8.2 the 'IsMultipleOf' constraint of
    -- 'setField'/'getField' mentions 'BaseOffset', so calling them on a
    -- 'ForeignStruct' presumably reports this 'TypeError'. Only
    -- 'nestedField' looks usable there. Confirm whether that is intended.
    type BaseOffset ForeignStruct = TypeError (Text "ForeignStruct has no type-level offset information.")
    type NestedStruct IO ForeignStruct j t = ForeignStruct t

    setField (ForeignStruct owner base) (Field (Offset off)) val =
        withForeignPtr owner $ \start ->
            poke (castPtr ((start `plusPtr` off) `plusPtr` base)) (PrimStorable val)

    getField (ForeignStruct owner base) (Field (Offset off)) =
        withForeignPtr owner $ \start ->
            fmap getPrimStorable (peek (castPtr ((start `plusPtr` off) `plusPtr` base)))

    nestedField (ForeignStruct owner base) (Field (Offset off)) _ =
        ForeignStruct (castForeignPtr owner) (base + off)