summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal/Vector.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-06-04 19:10:28 +0000
committerAlberto Ruiz <aruiz@um.es>2007-06-04 19:10:28 +0000
commit7430630fa0504296b796223e01cbd417b88650ef (patch)
treec338dea8b82867a4c161fcee5817ed2ca27c7258 /lib/Data/Packed/Internal/Vector.hs
parent0a9817cc481fb09f1962eb2c272125e56a123814 (diff)
separation of Internal
Diffstat (limited to 'lib/Data/Packed/Internal/Vector.hs')
-rw-r--r--lib/Data/Packed/Internal/Vector.hs164
1 files changed, 164 insertions, 0 deletions
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
new file mode 100644
index 0000000..7dcefeb
--- /dev/null
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -0,0 +1,164 @@
1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
2-----------------------------------------------------------------------------
3-- |
4-- Module : Data.Packed.Internal.Vector
5-- Copyright : (c) Alberto Ruiz 2007
6-- License : GPL-style
7--
8-- Maintainer : Alberto Ruiz <aruiz@um.es>
9-- Stability : provisional
10-- Portability : portable (uses FFI)
11--
12-- Fundamental types
13--
14-----------------------------------------------------------------------------
15
16module Data.Packed.Internal.Vector where
17
18import Foreign
19import Complex
20import Control.Monad(when)
21import Debug.Trace
22import Data.List(transpose,intersperse)
23import Data.Typeable
24import Data.Maybe(fromJust)
25
26debug x = trace (show x) x
27
28----------------------------------------------------------------------
29instance (Storable a, RealFloat a) => Storable (Complex a) where --
30 alignment x = alignment (realPart x) --
31 sizeOf x = 2 * sizeOf (realPart x) --
32 peek p = do --
33 [re,im] <- peekArray 2 (castPtr p) --
34 return (re :+ im) --
35 poke p (a :+ b) = pokeArray (castPtr p) [a,b] --
36----------------------------------------------------------------------
37
38(//) :: x -> (x -> y) -> y
39infixl 0 //
40(//) = flip ($)
41
42check msg ls f = do
43 err <- f
44 when (err/=0) (error msg)
45 mapM_ (touchForeignPtr . fptr) ls
46 return ()
47
48class (Storable a, Typeable a) => Field a where
49instance (Storable a, Typeable a) => Field a where
50
51isReal w x = typeOf (undefined :: Double) == typeOf (w x)
52isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x)
53baseOf v = (v `at` 0)
54
55scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b
56scast = fromJust . cast
57
58
59
60----------------------------------------------------------------------
61
62data Vector t = V { dim :: Int
63 , fptr :: ForeignPtr t
64 , ptr :: Ptr t
65 } deriving Typeable
66
67type Vc t s = Int -> Ptr t -> s
68infixr 5 :>
69type t :> s = Vc t s
70
71vec :: Vector t -> (Vc t s) -> s
72vec v f = f (dim v) (ptr v)
73
74createVector :: Storable a => Int -> IO (Vector a)
75createVector n = do
76 when (n <= 0) $ error ("trying to createVector of dim "++show n)
77 fp <- mallocForeignPtrArray n
78 let p = unsafeForeignPtrToPtr fp
79 --putStrLn ("\n---------> V"++show n)
80 return $ V n fp p
81
82fromList :: Storable a => [a] -> Vector a
83fromList l = unsafePerformIO $ do
84 v <- createVector (length l)
85 let f _ p = pokeArray p l >> return 0
86 f // vec v // check "fromList" []
87 return v
88
89toList :: Storable a => Vector a -> [a]
90toList v = unsafePerformIO $ peekArray (dim v) (ptr v)
91
92n # l = if length l == n then fromList l else error "# with wrong size"
93
94at' :: Storable a => Vector a -> Int -> a
95at' v n = unsafePerformIO $ peekElemOff (ptr v) n
96
97at :: Storable a => Vector a -> Int -> a
98at v n | n >= 0 && n < dim v = at' v n
99 | otherwise = error "vector index out of range"
100
101instance (Show a, Storable a) => (Show (Vector a)) where
102 show v = (show (dim v))++" # " ++ show (toList v)
103
104-- | creates a Vector taking a number of consecutive toList from another Vector
105subVector :: Storable t => Int -- ^ index of the starting element
106 -> Int -- ^ number of toList to extract
107 -> Vector t -- ^ source
108 -> Vector t -- ^ result
109subVector k l (v@V {dim=n, ptr=p, fptr=fp})
110 | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range"
111 | otherwise = unsafePerformIO $ do
112 r <- createVector l
113 let f = copyArray (ptr r) (advancePtr p k) l >> return 0
114 f // check "subVector" [v]
115 return r
116
117subVector' k l (v@V {dim=n, ptr=p, fptr=fp})
118 | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range"
119 | otherwise = v {dim=l, ptr=advancePtr p k}
120
121
122{-
123-- | creates a new Vector by joining a list of Vectors
124join :: Field t => [Vector t] -> Vector t
125join [] = error "joining an empty list"
126join as = unsafePerformIO $ do
127 let tot = sum (map size as)
128 p <- mallocForeignPtrArray tot
129 withForeignPtr p $ \p ->
130 joiner as tot p
131 return (V tot p)
132 where joiner [] _ _ = return ()
133 joiner (V n b : cs) _ p = do
134 withForeignPtr b $ \b' -> copyArray p b' n
135 joiner cs 0 (advancePtr p n)
136-}
137
138
139constantG n x = fromList (replicate n x)
140
141constantR :: Int -> Double -> Vector Double
142constantR = constantAux cconstantR
143
144constantC :: Int -> Complex Double -> Vector (Complex Double)
145constantC = constantAux cconstantC
146
147constantAux fun n x = unsafePerformIO $ do
148 v <- createVector n
149 px <- newArray [x]
150 fun px // vec v // check "constantAux" []
151 free px
152 return v
153
154foreign import ccall safe "aux.h constantR"
155 cconstantR :: Ptr Double -> Double :> IO Int
156
157foreign import ccall safe "aux.h constantC"
158 cconstantC :: Ptr (Complex Double) -> Complex Double :> IO Int
159
160constant :: Field a => Int -> a -> Vector a
161constant n x | isReal id x = scast $ constantR n (scast x)
162 | isComp id x = scast $ constantC n (scast x)
163 | otherwise = constantG n x
164