diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-06-04 19:10:28 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-06-04 19:10:28 +0000 |
commit | 7430630fa0504296b796223e01cbd417b88650ef (patch) | |
tree | c338dea8b82867a4c161fcee5817ed2ca27c7258 /lib/Data/Packed/Internal/Vector.hs | |
parent | 0a9817cc481fb09f1962eb2c272125e56a123814 (diff) |
separation of Internal
Diffstat (limited to 'lib/Data/Packed/Internal/Vector.hs')
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 164 |
1 files changed, 164 insertions, 0 deletions
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs new file mode 100644 index 0000000..7dcefeb --- /dev/null +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -0,0 +1,164 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} | ||
2 | ----------------------------------------------------------------------------- | ||
3 | -- | | ||
4 | -- Module : Data.Packed.Internal.Vector | ||
5 | -- Copyright : (c) Alberto Ruiz 2007 | ||
6 | -- License : GPL-style | ||
7 | -- | ||
8 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
9 | -- Stability : provisional | ||
10 | -- Portability : portable (uses FFI) | ||
11 | -- | ||
12 | -- Fundamental types | ||
13 | -- | ||
14 | ----------------------------------------------------------------------------- | ||
15 | |||
16 | module Data.Packed.Internal.Vector where | ||
17 | |||
18 | import Foreign | ||
19 | import Complex | ||
20 | import Control.Monad(when) | ||
21 | import Debug.Trace | ||
22 | import Data.List(transpose,intersperse) | ||
23 | import Data.Typeable | ||
24 | import Data.Maybe(fromJust) | ||
25 | |||
26 | debug x = trace (show x) x | ||
27 | |||
28 | ---------------------------------------------------------------------- | ||
29 | instance (Storable a, RealFloat a) => Storable (Complex a) where -- | ||
30 | alignment x = alignment (realPart x) -- | ||
31 | sizeOf x = 2 * sizeOf (realPart x) -- | ||
32 | peek p = do -- | ||
33 | [re,im] <- peekArray 2 (castPtr p) -- | ||
34 | return (re :+ im) -- | ||
35 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- | ||
36 | ---------------------------------------------------------------------- | ||
37 | |||
38 | (//) :: x -> (x -> y) -> y | ||
39 | infixl 0 // | ||
40 | (//) = flip ($) | ||
41 | |||
42 | check msg ls f = do | ||
43 | err <- f | ||
44 | when (err/=0) (error msg) | ||
45 | mapM_ (touchForeignPtr . fptr) ls | ||
46 | return () | ||
47 | |||
48 | class (Storable a, Typeable a) => Field a where | ||
49 | instance (Storable a, Typeable a) => Field a where | ||
50 | |||
51 | isReal w x = typeOf (undefined :: Double) == typeOf (w x) | ||
52 | isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) | ||
53 | baseOf v = (v `at` 0) | ||
54 | |||
55 | scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b | ||
56 | scast = fromJust . cast | ||
57 | |||
58 | |||
59 | |||
60 | ---------------------------------------------------------------------- | ||
61 | |||
62 | data Vector t = V { dim :: Int | ||
63 | , fptr :: ForeignPtr t | ||
64 | , ptr :: Ptr t | ||
65 | } deriving Typeable | ||
66 | |||
67 | type Vc t s = Int -> Ptr t -> s | ||
68 | infixr 5 :> | ||
69 | type t :> s = Vc t s | ||
70 | |||
71 | vec :: Vector t -> (Vc t s) -> s | ||
72 | vec v f = f (dim v) (ptr v) | ||
73 | |||
74 | createVector :: Storable a => Int -> IO (Vector a) | ||
75 | createVector n = do | ||
76 | when (n <= 0) $ error ("trying to createVector of dim "++show n) | ||
77 | fp <- mallocForeignPtrArray n | ||
78 | let p = unsafeForeignPtrToPtr fp | ||
79 | --putStrLn ("\n---------> V"++show n) | ||
80 | return $ V n fp p | ||
81 | |||
82 | fromList :: Storable a => [a] -> Vector a | ||
83 | fromList l = unsafePerformIO $ do | ||
84 | v <- createVector (length l) | ||
85 | let f _ p = pokeArray p l >> return 0 | ||
86 | f // vec v // check "fromList" [] | ||
87 | return v | ||
88 | |||
89 | toList :: Storable a => Vector a -> [a] | ||
90 | toList v = unsafePerformIO $ peekArray (dim v) (ptr v) | ||
91 | |||
92 | n # l = if length l == n then fromList l else error "# with wrong size" | ||
93 | |||
94 | at' :: Storable a => Vector a -> Int -> a | ||
95 | at' v n = unsafePerformIO $ peekElemOff (ptr v) n | ||
96 | |||
97 | at :: Storable a => Vector a -> Int -> a | ||
98 | at v n | n >= 0 && n < dim v = at' v n | ||
99 | | otherwise = error "vector index out of range" | ||
100 | |||
101 | instance (Show a, Storable a) => (Show (Vector a)) where | ||
102 | show v = (show (dim v))++" # " ++ show (toList v) | ||
103 | |||
104 | -- | creates a Vector taking a number of consecutive toList from another Vector | ||
105 | subVector :: Storable t => Int -- ^ index of the starting element | ||
106 | -> Int -- ^ number of toList to extract | ||
107 | -> Vector t -- ^ source | ||
108 | -> Vector t -- ^ result | ||
109 | subVector k l (v@V {dim=n, ptr=p, fptr=fp}) | ||
110 | | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" | ||
111 | | otherwise = unsafePerformIO $ do | ||
112 | r <- createVector l | ||
113 | let f = copyArray (ptr r) (advancePtr p k) l >> return 0 | ||
114 | f // check "subVector" [v] | ||
115 | return r | ||
116 | |||
117 | subVector' k l (v@V {dim=n, ptr=p, fptr=fp}) | ||
118 | | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" | ||
119 | | otherwise = v {dim=l, ptr=advancePtr p k} | ||
120 | |||
121 | |||
122 | {- | ||
123 | -- | creates a new Vector by joining a list of Vectors | ||
124 | join :: Field t => [Vector t] -> Vector t | ||
125 | join [] = error "joining an empty list" | ||
126 | join as = unsafePerformIO $ do | ||
127 | let tot = sum (map size as) | ||
128 | p <- mallocForeignPtrArray tot | ||
129 | withForeignPtr p $ \p -> | ||
130 | joiner as tot p | ||
131 | return (V tot p) | ||
132 | where joiner [] _ _ = return () | ||
133 | joiner (V n b : cs) _ p = do | ||
134 | withForeignPtr b $ \b' -> copyArray p b' n | ||
135 | joiner cs 0 (advancePtr p n) | ||
136 | -} | ||
137 | |||
138 | |||
139 | constantG n x = fromList (replicate n x) | ||
140 | |||
141 | constantR :: Int -> Double -> Vector Double | ||
142 | constantR = constantAux cconstantR | ||
143 | |||
144 | constantC :: Int -> Complex Double -> Vector (Complex Double) | ||
145 | constantC = constantAux cconstantC | ||
146 | |||
147 | constantAux fun n x = unsafePerformIO $ do | ||
148 | v <- createVector n | ||
149 | px <- newArray [x] | ||
150 | fun px // vec v // check "constantAux" [] | ||
151 | free px | ||
152 | return v | ||
153 | |||
154 | foreign import ccall safe "aux.h constantR" | ||
155 | cconstantR :: Ptr Double -> Double :> IO Int | ||
156 | |||
157 | foreign import ccall safe "aux.h constantC" | ||
158 | cconstantC :: Ptr (Complex Double) -> Complex Double :> IO Int | ||
159 | |||
160 | constant :: Field a => Int -> a -> Vector a | ||
161 | constant n x | isReal id x = scast $ constantR n (scast x) | ||
162 | | isComp id x = scast $ constantC n (scast x) | ||
163 | | otherwise = constantG n x | ||
164 | |||