diff options
Diffstat (limited to 'lib/Data')
-rw-r--r-- | lib/Data/Packed/Internal.hs | 120 | ||||
-rw-r--r-- | lib/Data/Packed/Matrix.hs | 1 | ||||
-rw-r--r-- | lib/Data/Packed/Tensor.hs | 1 | ||||
-rw-r--r-- | lib/Data/Packed/Vector.hs | 1 |
4 files changed, 123 insertions, 0 deletions
diff --git a/lib/Data/Packed/Internal.hs b/lib/Data/Packed/Internal.hs new file mode 100644 index 0000000..c8ad8d7 --- /dev/null +++ b/lib/Data/Packed/Internal.hs | |||
@@ -0,0 +1,120 @@ | |||
1 | ----------------------------------------------------------------------------- | ||
2 | -- | | ||
3 | -- Module : Data.Packed.Internal | ||
4 | -- Copyright : (c) Alberto Ruiz 2007 | ||
5 | -- License : GPL-style | ||
6 | -- | ||
7 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
8 | -- Stability : provisional | ||
9 | -- Portability : portable (uses FFI) | ||
10 | -- | ||
11 | -- Fundamental types | ||
12 | -- | ||
13 | ----------------------------------------------------------------------------- | ||
14 | |||
15 | module Data.Packed.Internal where | ||
16 | |||
17 | import Foreign | ||
18 | import Complex | ||
19 | import Control.Monad(when) | ||
20 | import Debug.Trace | ||
21 | |||
22 | debug x = trace (show x) x | ||
23 | |||
24 | -- | 1D array | ||
25 | data Vector t = V { dim :: Int | ||
26 | , fptr :: ForeignPtr t | ||
27 | , ptr :: Ptr t | ||
28 | } | ||
29 | |||
30 | data TransMode = NoTrans | Trans | ConjTrans | ||
31 | |||
32 | -- | 2D array | ||
33 | data Matrix t = M { rows :: Int | ||
34 | , cols :: Int | ||
35 | , mat :: Vector t | ||
36 | , trMode :: TransMode | ||
37 | , isCOrder :: Bool | ||
38 | } | ||
39 | |||
40 | data IdxTp = Covariant | Contravariant | ||
41 | |||
42 | -- | multidimensional array | ||
43 | data Tensor t = T { rank :: Int | ||
44 | , dims :: [Int] | ||
45 | , idxNm :: [String] | ||
46 | , idxTp :: [IdxTp] | ||
47 | , ten :: Vector t | ||
48 | } | ||
49 | |||
50 | ---------------------------------------------------------------------- | ||
51 | instance (Storable a, RealFloat a) => Storable (Complex a) where -- | ||
52 | alignment x = alignment (realPart x) -- | ||
53 | sizeOf x = 2 * sizeOf (realPart x) -- | ||
54 | peek p = do -- | ||
55 | [re,im] <- peekArray 2 (castPtr p) -- | ||
56 | return (re :+ im) -- | ||
57 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- | ||
58 | ---------------------------------------------------------------------- | ||
59 | |||
60 | |||
61 | -- f // vec a // vec b // vec res // check "vector add" [a,b] | ||
62 | |||
63 | (//) :: x -> (x -> y) -> y | ||
64 | infixl 0 // | ||
65 | (//) = flip ($) | ||
66 | |||
67 | vec :: Vector a -> (Int -> Ptr b -> t) -> t | ||
68 | vec v f = f (dim v) (castPtr $ ptr v) | ||
69 | |||
70 | check msg ls f = do | ||
71 | err <- f | ||
72 | when (err/=0) (error msg) | ||
73 | mapM_ (touchForeignPtr . fptr) ls | ||
74 | return () | ||
75 | |||
76 | createVector :: Storable a => Int -> IO (Vector a) | ||
77 | createVector n = do | ||
78 | when (n <= 0) $ error ("trying to createVector of dim "++show n) | ||
79 | fp <- mallocForeignPtrArray n | ||
80 | let p = unsafeForeignPtrToPtr fp | ||
81 | return $ V n fp p | ||
82 | |||
83 | fromList :: Storable a => [a] -> Vector a | ||
84 | fromList l = unsafePerformIO $ do | ||
85 | v <- createVector (length l) | ||
86 | let f _ p = pokeArray p l >> return 0 | ||
87 | f // vec v // check "fromList" [] | ||
88 | return v | ||
89 | |||
90 | toList :: Storable a => Vector a -> [a] | ||
91 | toList v = unsafePerformIO $ peekArray (dim v) (ptr v) | ||
92 | |||
93 | at' :: Storable a => Vector a -> Int -> a | ||
94 | at' v n = unsafePerformIO $ peekElemOff (ptr v) n | ||
95 | |||
96 | at :: Storable a => Vector a -> Int -> a | ||
97 | at v n | n >= 0 && n < dim v = at' v n | ||
98 | | otherwise = error "vector index out of range" | ||
99 | |||
100 | constant :: Storable a => Int -> a -> Vector a | ||
101 | constant n x = unsafePerformIO $ do | ||
102 | v <- createVector n | ||
103 | let f k p | k == n = return 0 | ||
104 | | otherwise = pokeElemOff p k x >> f (k+1) p | ||
105 | const (f 0) // vec v // check "constant" [] | ||
106 | return v | ||
107 | |||
108 | instance (Show a, Storable a) => (Show (Vector a)) where | ||
109 | show v = "fromList " ++ show (toList v) | ||
110 | |||
111 | instance (Show a, Storable a) => (Show (Matrix a)) where | ||
112 | show m = "reshape "++show (cols m) ++ " $ fromList " ++ show (toList (mat m)) | ||
113 | |||
114 | reshape :: Storable a => Int -> Vector a -> Matrix a | ||
115 | reshape n v = M { rows = dim v `div` n | ||
116 | , cols = n | ||
117 | , mat = v | ||
118 | , trMode = NoTrans | ||
119 | , isCOrder = True | ||
120 | } | ||
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs new file mode 100644 index 0000000..8d1c8b6 --- /dev/null +++ b/lib/Data/Packed/Matrix.hs | |||
@@ -0,0 +1 @@ | |||
diff --git a/lib/Data/Packed/Tensor.hs b/lib/Data/Packed/Tensor.hs new file mode 100644 index 0000000..8d1c8b6 --- /dev/null +++ b/lib/Data/Packed/Tensor.hs | |||
@@ -0,0 +1 @@ | |||
diff --git a/lib/Data/Packed/Vector.hs b/lib/Data/Packed/Vector.hs new file mode 100644 index 0000000..8d1c8b6 --- /dev/null +++ b/lib/Data/Packed/Vector.hs | |||
@@ -0,0 +1 @@ | |||