summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal.hs
blob: c8ad8d70adbe98d9d8d8a783d478c4ca3fd838cc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Packed.Internal
-- Copyright   :  (c) Alberto Ruiz 2007
-- License     :  GPL-style
--
-- Maintainer  :  Alberto Ruiz <aruiz@um.es>
-- Stability   :  provisional
-- Portability :  portable (uses FFI)
--
-- Fundamental types
--
-----------------------------------------------------------------------------

module Data.Packed.Internal where

import Foreign
import Complex
import Control.Monad(when)
import Debug.Trace

-- | Debugging aid: emit @show x@ through 'trace' and return @x@ unchanged.
debug :: Show a => a -> a
debug x = trace (show x) x

-- | 1D array: 'dim' elements of type @t@ stored in a foreign buffer.
data Vector t = V
    { dim  :: Int           -- ^ number of elements in the buffer
    , fptr :: ForeignPtr t  -- ^ managed handle that keeps the buffer alive
    , ptr  :: Ptr t         -- ^ raw address of the buffer; only safe to
                            --   dereference while 'fptr' is still reachable
    }

-- | Transposition flag carried by a 'Matrix'.
-- NOTE(review): presumably mirrors the BLAS op flags (N / T / C) — confirm.
data TransMode
    = NoTrans    -- ^ no transposition
    | Trans      -- ^ transposed
    | ConjTrans  -- ^ conjugate-transposed

-- | 2D array whose elements live in a single 'Vector'.
data Matrix t = M
    { rows     :: Int        -- ^ number of rows
    , cols     :: Int        -- ^ number of columns
    , mat      :: Vector t   -- ^ element storage
    , trMode   :: TransMode  -- ^ transposition flag ('reshape' sets 'NoTrans')
    , isCOrder :: Bool       -- ^ storage-order flag ('reshape' sets 'True');
                             --   presumably 'True' means row-major (C order)
                             --   — TODO confirm
    }

-- | Variance of a single 'Tensor' index.
data IdxTp
    = Covariant
    | Contravariant

-- | Multidimensional array whose elements live in a flat 'Vector'.
data Tensor t = T
    { rank  :: Int       -- ^ number of indices (presumably equal to
                         --   @length dims@; not enforced here)
    , dims  :: [Int]     -- ^ presumably the extent of each index
    , idxNm :: [String]  -- ^ presumably a name for each index
    , idxTp :: [IdxTp]   -- ^ variance of each index
    , ten   :: Vector t  -- ^ element storage
    }

----------------------------------------------------------------------
instance (Storable a, RealFloat a) => Storable (Complex a) where
    -- A complex value is stored as two consecutive components, real part
    -- first, so it takes twice the space of one component and shares its
    -- alignment.  The argument of sizeOf/alignment is never evaluated.
    sizeOf z    = sizeOf (realPart z) * 2
    alignment z = alignment (realPart z)
    peek q = do
        let comp = castPtr q
        r <- peekElemOff comp 0
        i <- peekElemOff comp 1
        return (r :+ i)
    poke q (r :+ i) = do
        let comp = castPtr q
        pokeElemOff comp 0 r
        pokeElemOff comp 1 i
----------------------------------------------------------------------


-- | Reverse application: @x // f@ is @f x@.  Being left-associative at the
-- lowest precedence, it lets an FFI call be assembled left to right:
--
-- > f // vec a // vec b // vec res // check "vector add" [a,b]
--
-- which is @check "vector add" [a,b] (vec res (vec b (vec a f)))@, i.e. @f@
-- receives the size and pointer of @a@, then @b@, then @res@, and the
-- resulting status action is handed to 'check'.
infixl 0 //

(//) :: x -> (x -> y) -> y
x // f = f x

-- | Feed a vector's element count and raw pointer (cast to whatever element
-- type the consumer expects) to a continuation.  The pointer is not kept
-- alive by this function; callers pass the vector to 'check' for that.
vec :: Vector a -> (Int -> Ptr b -> t) -> t
vec v k = k n p
  where
    n = dim v
    p = castPtr (ptr v)

-- | Run an action (typically an FFI call) that reports failure through a
-- numeric status code.
--
-- A nonzero status aborts with @error msg@.  Otherwise the 'ForeignPtr' of
-- every vector in @ls@ is touched.  Because the touches come after the
-- action, those buffers stay reachable while the action runs, so raw
-- pointers obtained from them (e.g. through 'vec') remain valid for it.
check :: (Eq e, Num e) => String -> [Vector t] -> IO e -> IO ()
check msg ls f = do
    err <- f
    when (err /= 0) (error msg)
    mapM_ (touchForeignPtr . fptr) ls

-- | Allocate a vector with room for @n@ elements (contents are not
-- initialised here).  The buffer is owned by the 'ForeignPtr' returned by
-- 'mallocForeignPtrArray'; 'ptr' is a raw view of that same buffer.
-- Running the action fails with 'error' unless @n@ is positive.
createVector :: Storable a => Int -> IO (Vector a)
createVector n = do
    if n > 0
        then return ()
        else error ("trying to createVector of dim "++show n)
    buffer <- mallocForeignPtrArray n
    return V { dim = n, fptr = buffer, ptr = unsafeForeignPtrToPtr buffer }

-- | Build a vector holding the elements of a finite list, in order.
-- The list must be non-empty: 'createVector' rejects a size of 0.
fromList :: Storable a => [a] -> Vector a
fromList xs = unsafePerformIO $ do
    out <- createVector (length xs)
    let fill _ dst = pokeArray dst xs >> return 0
    check "fromList" [] (vec out fill)
    return out

-- | Copy all 'dim' elements of a vector into a Haskell list.
toList :: Storable a => Vector a -> [a]
toList v = unsafePerformIO $ do
    xs <- peekArray (dim v) (ptr v)
    -- 'ptr v' is a bare address that does not keep the buffer alive.
    -- Touching the owning ForeignPtr after the copy stops the GC from
    -- reclaiming the buffer while peekArray is still reading it.
    touchForeignPtr (fptr v)
    return xs

-- | Unchecked element access: read the element at offset @n@ without any
-- bounds test (see 'at' for the checked version).
at' :: Storable a => Vector a -> Int -> a
at' v n = unsafePerformIO $ do
    x <- peekElemOff (ptr v) n
    -- Keep the owning ForeignPtr reachable until the read has completed;
    -- 'ptr v' alone does not prevent the buffer from being freed.
    touchForeignPtr (fptr v)
    return x

-- | Bounds-checked element access; valid indices are @0 .. dim v - 1@.
at :: Storable a => Vector a -> Int -> a
at v n
    | n < 0 || n >= dim v = error "vector index out of range"
    | otherwise           = at' v n

-- | A vector of @n@ copies of @x@.  @n@ must be positive ('createVector'
-- rejects anything else).
constant :: Storable a => Int -> a -> Vector a
constant n x = unsafePerformIO $ do
    out <- createVector n
    let fillAll _ dst = mapM_ (\k -> pokeElemOff dst k x) [0 .. n - 1] >> return 0
    check "constant" [] (vec out fillAll)
    return out

-- | Shown as the expression that rebuilds it, e.g. @fromList [1.0,2.0]@.
-- Parenthesised when it appears as an argument of an application
-- (precedence above 10), so @show (Just v)@ stays well-formed.
instance (Show a, Storable a) => (Show (Vector a)) where
    showsPrec d v = showParen (d > 10) $
        showString "fromList " . shows (toList v)

-- | Shown as @reshape c $ fromList [...]@, where @c@ is the column count and
-- the list holds the elements of the underlying 'Vector'.  The output
-- contains @$@ (precedence 0), so it is parenthesised in any context tighter
-- than top level, e.g. inside @Just@.
-- NOTE(review): 'trMode' and 'isCOrder' are not shown, so matrices that
-- differ only in those fields print identically.
instance (Show a, Storable a) => (Show (Matrix a)) where
    showsPrec d m = showParen (d > 0) $
          showString "reshape "
        . shows (cols m)
        . showString " $ fromList "
        . shows (toList (mat m))

-- | View a vector as a matrix with @n@ columns and @dim v `div` n@ rows.
-- The vector is shared, not copied; the result has 'trMode' 'NoTrans' and
-- 'isCOrder' 'True'.
--
-- Fails with 'error' if @n@ is not positive, or if @n@ does not divide
-- @dim v@ (previously the trailing elements were silently ignored).
reshape :: Storable a => Int -> Vector a -> Matrix a
reshape n v
    | n <= 0 =
        error ("reshape: number of columns must be positive, got " ++ show n)
    | r /= 0 =
        error ("reshape: vector of dim " ++ show (dim v)
               ++ " cannot be split into rows of " ++ show n)
    | otherwise = M { rows     = q
                    , cols     = n
                    , mat      = v
                    , trMode   = NoTrans
                    , isCOrder = True
                    }
  where
    -- Only forced after the n <= 0 guard, so divMod never sees n == 0.
    (q, r) = dim v `divMod` n