diff options
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r-- | packages/base/src/Internal/Devel.hs | 89 | ||||
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 54 | ||||
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 72 | ||||
-rw-r--r-- | packages/base/src/Internal/Sparse.hs | 4 | ||||
-rw-r--r-- | packages/base/src/Internal/Util.hs | 6 | ||||
-rw-r--r-- | packages/base/src/Internal/Vector.hs | 12 | ||||
-rw-r--r-- | packages/base/src/Internal/Vectorized.hs | 38 |
7 files changed, 153 insertions, 122 deletions
diff --git a/packages/base/src/Internal/Devel.hs b/packages/base/src/Internal/Devel.hs index b8e04ef..4be0afd 100644 --- a/packages/base/src/Internal/Devel.hs +++ b/packages/base/src/Internal/Devel.hs | |||
@@ -1,4 +1,5 @@ | |||
1 | {-# LANGUAGE TypeOperators #-} | 1 | {-# LANGUAGE TypeOperators #-} |
2 | {-# LANGUAGE TypeFamilies #-} | ||
2 | 3 | ||
3 | -- | | 4 | -- | |
4 | -- Module : Internal.Devel | 5 | -- Module : Internal.Devel |
@@ -16,68 +17,14 @@ import Foreign.C.Types ( CInt ) | |||
16 | --import Foreign.Storable.Complex () | 17 | --import Foreign.Storable.Complex () |
17 | import Foreign.Ptr(Ptr) | 18 | import Foreign.Ptr(Ptr) |
18 | import Control.Exception as E ( SomeException, catch ) | 19 | import Control.Exception as E ( SomeException, catch ) |
19 | 20 | import Internal.Vector(Vector,avec,arrvec) | |
21 | import Foreign.Storable(Storable) | ||
20 | 22 | ||
21 | -- | postfix function application (@flip ($)@) | 23 | -- | postfix function application (@flip ($)@) |
22 | (//) :: x -> (x -> y) -> y | 24 | (//) :: x -> (x -> y) -> y |
23 | infixl 0 // | 25 | infixl 0 // |
24 | (//) = flip ($) | 26 | (//) = flip ($) |
25 | 27 | ||
26 | -- hmm.. | ||
27 | ww2 w1 o1 w2 o2 f = w1 o1 $ w2 o2 . f | ||
28 | ww3 w1 o1 w2 o2 w3 o3 f = w1 o1 $ ww2 w2 o2 w3 o3 . f | ||
29 | ww4 w1 o1 w2 o2 w3 o3 w4 o4 f = w1 o1 $ ww3 w2 o2 w3 o3 w4 o4 . f | ||
30 | ww5 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 f = w1 o1 $ ww4 w2 o2 w3 o3 w4 o4 w5 o5 . f | ||
31 | ww6 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 f = w1 o1 $ ww5 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 . f | ||
32 | ww7 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 f = w1 o1 $ ww6 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 . f | ||
33 | ww8 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 f = w1 o1 $ ww7 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 . f | ||
34 | ww9 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 f = w1 o1 $ ww8 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 . f | ||
35 | ww10 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 f = w1 o1 $ ww9 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 . f | ||
36 | |||
37 | type Adapt f t r = t -> ((f -> r) -> IO()) -> IO() | ||
38 | |||
39 | type Adapt1 f t1 = Adapt f t1 (IO CInt) -> t1 -> String -> IO() | ||
40 | type Adapt2 f t1 r1 t2 = Adapt f t1 r1 -> t1 -> Adapt1 r1 t2 | ||
41 | type Adapt3 f t1 r1 t2 r2 t3 = Adapt f t1 r1 -> t1 -> Adapt2 r1 t2 r2 t3 | ||
42 | type Adapt4 f t1 r1 t2 r2 t3 r3 t4 = Adapt f t1 r1 -> t1 -> Adapt3 r1 t2 r2 t3 r3 t4 | ||
43 | type Adapt5 f t1 r1 t2 r2 t3 r3 t4 r4 t5 = Adapt f t1 r1 -> t1 -> Adapt4 r1 t2 r2 t3 r3 t4 r4 t5 | ||
44 | type Adapt6 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 = Adapt f t1 r1 -> t1 -> Adapt5 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 | ||
45 | type Adapt7 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 = Adapt f t1 r1 -> t1 -> Adapt6 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 | ||
46 | type Adapt8 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 = Adapt f t1 r1 -> t1 -> Adapt7 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 | ||
47 | type Adapt9 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 = Adapt f t1 r1 -> t1 -> Adapt8 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 | ||
48 | type Adapt10 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10 = Adapt f t1 r1 -> t1 -> Adapt9 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10 | ||
49 | |||
50 | app1 :: f -> Adapt1 f t1 | ||
51 | app2 :: f -> Adapt2 f t1 r1 t2 | ||
52 | app3 :: f -> Adapt3 f t1 r1 t2 r2 t3 | ||
53 | app4 :: f -> Adapt4 f t1 r1 t2 r2 t3 r3 t4 | ||
54 | app5 :: f -> Adapt5 f t1 r1 t2 r2 t3 r3 t4 r4 t5 | ||
55 | app6 :: f -> Adapt6 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 | ||
56 | app7 :: f -> Adapt7 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 | ||
57 | app8 :: f -> Adapt8 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 | ||
58 | app9 :: f -> Adapt9 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 | ||
59 | app10 :: f -> Adapt10 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10 | ||
60 | |||
61 | app1 f w1 o1 s = w1 o1 $ \a1 -> f // a1 // check s | ||
62 | app2 f w1 o1 w2 o2 s = ww2 w1 o1 w2 o2 $ \a1 a2 -> f // a1 // a2 // check s | ||
63 | app3 f w1 o1 w2 o2 w3 o3 s = ww3 w1 o1 w2 o2 w3 o3 $ | ||
64 | \a1 a2 a3 -> f // a1 // a2 // a3 // check s | ||
65 | app4 f w1 o1 w2 o2 w3 o3 w4 o4 s = ww4 w1 o1 w2 o2 w3 o3 w4 o4 $ | ||
66 | \a1 a2 a3 a4 -> f // a1 // a2 // a3 // a4 // check s | ||
67 | app5 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 s = ww5 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 $ | ||
68 | \a1 a2 a3 a4 a5 -> f // a1 // a2 // a3 // a4 // a5 // check s | ||
69 | app6 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 s = ww6 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 $ | ||
70 | \a1 a2 a3 a4 a5 a6 -> f // a1 // a2 // a3 // a4 // a5 // a6 // check s | ||
71 | app7 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 s = ww7 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 $ | ||
72 | \a1 a2 a3 a4 a5 a6 a7 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // check s | ||
73 | app8 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 s = ww8 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 $ | ||
74 | \a1 a2 a3 a4 a5 a6 a7 a8 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // check s | ||
75 | app9 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 s = ww9 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 $ | ||
76 | \a1 a2 a3 a4 a5 a6 a7 a8 a9 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // a9 // check s | ||
77 | app10 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 s = ww10 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 $ | ||
78 | \a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // a9 // a10 // check s | ||
79 | |||
80 | |||
81 | 28 | ||
82 | -- GSL error codes are <= 1024 | 29 | -- GSL error codes are <= 1024 |
83 | -- | error codes for the auxiliary functions required by the wrappers | 30 | -- | error codes for the auxiliary functions required by the wrappers |
@@ -104,6 +51,11 @@ check msg f = do | |||
104 | when (err/=0) $ error (msg++": "++errorCode err) | 51 | when (err/=0) $ error (msg++": "++errorCode err) |
105 | return () | 52 | return () |
106 | 53 | ||
54 | |||
55 | -- | postfix error code check | ||
56 | infixl 0 #| | ||
57 | (#|) = flip check | ||
58 | |||
107 | -- | Error capture and conversion to Maybe | 59 | -- | Error capture and conversion to Maybe |
108 | mbCatch :: IO x -> IO (Maybe x) | 60 | mbCatch :: IO x -> IO (Maybe x) |
109 | mbCatch act = E.catch (Just `fmap` act) f | 61 | mbCatch act = E.catch (Just `fmap` act) f |
@@ -124,4 +76,27 @@ type (:>) t r = CV t r | |||
124 | type (::>) t r = OM t r | 76 | type (::>) t r = OM t r |
125 | type (..>) t r = CM t r | 77 | type (..>) t r = CM t r |
126 | 78 | ||
79 | class TransArray c | ||
80 | where | ||
81 | type Trans c b | ||
82 | type TransRaw c b | ||
83 | type Elem c | ||
84 | apply :: (Trans c b) -> c -> b | ||
85 | applyRaw :: (TransRaw c b) -> c -> b | ||
86 | applyArray :: (Ptr CInt -> Ptr (Elem c) -> b) -> c -> b | ||
87 | infixl 1 `apply`, `applyRaw`, `applyArray` | ||
88 | |||
89 | instance Storable t => TransArray (Vector t) | ||
90 | where | ||
91 | type Trans (Vector t) b = CInt -> Ptr t -> b | ||
92 | type TransRaw (Vector t) b = CInt -> Ptr t -> b | ||
93 | type Elem (Vector t) = t | ||
94 | apply = avec | ||
95 | {-# INLINE apply #-} | ||
96 | applyRaw = avec | ||
97 | {-# INLINE applyRaw #-} | ||
98 | applyArray = arrvec | ||
99 | {-# INLINE applyArray #-} | ||
100 | |||
101 | |||
127 | 102 | ||
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 8df568d..3a9abbb 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -17,7 +17,7 @@ module Internal.LAPACK where | |||
17 | 17 | ||
18 | import Internal.Devel | 18 | import Internal.Devel |
19 | import Internal.Vector | 19 | import Internal.Vector |
20 | import Internal.Matrix | 20 | import Internal.Matrix hiding ((#)) |
21 | import Internal.Conversion | 21 | import Internal.Conversion |
22 | import Internal.Element | 22 | import Internal.Element |
23 | import Foreign.Ptr(nullPtr) | 23 | import Foreign.Ptr(nullPtr) |
@@ -27,6 +27,16 @@ import System.IO.Unsafe(unsafePerformIO) | |||
27 | 27 | ||
28 | ----------------------------------------------------------------------------------- | 28 | ----------------------------------------------------------------------------------- |
29 | 29 | ||
30 | infixl 1 # | ||
31 | a # b = applyRaw a b | ||
32 | {-# INLINE (#) #-} | ||
33 | |||
34 | infixl 1 #! | ||
35 | a #! b = apply a b | ||
36 | {-# INLINE (#!) #-} | ||
37 | |||
38 | ----------------------------------------------------------------------------------- | ||
39 | |||
30 | type TMMM t = t ..> t ..> t ..> Ok | 40 | type TMMM t = t ..> t ..> t ..> Ok |
31 | 41 | ||
32 | type F = Float | 42 | type F = Float |
@@ -49,7 +59,7 @@ multiplyAux f st a b = unsafePerformIO $ do | |||
49 | when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ | 59 | when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ |
50 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | 60 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) |
51 | s <- createMatrix ColumnMajor (rows a) (cols b) | 61 | s <- createMatrix ColumnMajor (rows a) (cols b) |
52 | app3 (f (isT a) (isT b)) mat (tt a) mat (tt b) mat s st | 62 | f (isT a) (isT b) # (tt a) # (tt b) # s #| st |
53 | return s | 63 | return s |
54 | 64 | ||
55 | -- | Matrix product based on BLAS's /dgemm/. | 65 | -- | Matrix product based on BLAS's /dgemm/. |
@@ -73,7 +83,7 @@ multiplyI m a b = unsafePerformIO $ do | |||
73 | when (cols a /= rows b) $ error $ | 83 | when (cols a /= rows b) $ error $ |
74 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b | 84 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b |
75 | s <- createMatrix ColumnMajor (rows a) (cols b) | 85 | s <- createMatrix ColumnMajor (rows a) (cols b) |
76 | app3 (c_multiplyI m) omat a omat b omat s "c_multiplyI" | 86 | c_multiplyI m #! a #! b #! s #|"c_multiplyI" |
77 | return s | 87 | return s |
78 | 88 | ||
79 | multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z | 89 | multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z |
@@ -81,7 +91,7 @@ multiplyL m a b = unsafePerformIO $ do | |||
81 | when (cols a /= rows b) $ error $ | 91 | when (cols a /= rows b) $ error $ |
82 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b | 92 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b |
83 | s <- createMatrix ColumnMajor (rows a) (cols b) | 93 | s <- createMatrix ColumnMajor (rows a) (cols b) |
84 | app3 (c_multiplyL m) omat a omat b omat s "c_multiplyL" | 94 | c_multiplyL m #! a #! b #! s #|"c_multiplyL" |
85 | return s | 95 | return s |
86 | 96 | ||
87 | ----------------------------------------------------------------------------- | 97 | ----------------------------------------------------------------------------- |
@@ -113,7 +123,7 @@ svdAux f st x = unsafePerformIO $ do | |||
113 | u <- createMatrix ColumnMajor r r | 123 | u <- createMatrix ColumnMajor r r |
114 | s <- createVector (min r c) | 124 | s <- createVector (min r c) |
115 | v <- createMatrix ColumnMajor c c | 125 | v <- createMatrix ColumnMajor c c |
116 | app4 f mat x mat u vec s mat v st | 126 | f # x # u # s # v #| st |
117 | return (u,s,v) | 127 | return (u,s,v) |
118 | where r = rows x | 128 | where r = rows x |
119 | c = cols x | 129 | c = cols x |
@@ -139,7 +149,7 @@ thinSVDAux f st x = unsafePerformIO $ do | |||
139 | u <- createMatrix ColumnMajor r q | 149 | u <- createMatrix ColumnMajor r q |
140 | s <- createVector q | 150 | s <- createVector q |
141 | v <- createMatrix ColumnMajor q c | 151 | v <- createMatrix ColumnMajor q c |
142 | app4 f mat x mat u vec s mat v st | 152 | f # x # u # s # v #| st |
143 | return (u,s,v) | 153 | return (u,s,v) |
144 | where r = rows x | 154 | where r = rows x |
145 | c = cols x | 155 | c = cols x |
@@ -164,7 +174,7 @@ svCd = svAux zgesdd "svCd" . fmat | |||
164 | 174 | ||
165 | svAux f st x = unsafePerformIO $ do | 175 | svAux f st x = unsafePerformIO $ do |
166 | s <- createVector q | 176 | s <- createVector q |
167 | app2 g mat x vec s st | 177 | g # x # s #| st |
168 | return s | 178 | return s |
169 | where r = rows x | 179 | where r = rows x |
170 | c = cols x | 180 | c = cols x |
@@ -183,7 +193,7 @@ rightSVC = rightSVAux zgesvd "rightSVC" . fmat | |||
183 | rightSVAux f st x = unsafePerformIO $ do | 193 | rightSVAux f st x = unsafePerformIO $ do |
184 | s <- createVector q | 194 | s <- createVector q |
185 | v <- createMatrix ColumnMajor c c | 195 | v <- createMatrix ColumnMajor c c |
186 | app3 g mat x vec s mat v st | 196 | g # x # s # v #| st |
187 | return (s,v) | 197 | return (s,v) |
188 | where r = rows x | 198 | where r = rows x |
189 | c = cols x | 199 | c = cols x |
@@ -202,7 +212,7 @@ leftSVC = leftSVAux zgesvd "leftSVC" . fmat | |||
202 | leftSVAux f st x = unsafePerformIO $ do | 212 | leftSVAux f st x = unsafePerformIO $ do |
203 | u <- createMatrix ColumnMajor r r | 213 | u <- createMatrix ColumnMajor r r |
204 | s <- createVector q | 214 | s <- createVector q |
205 | app3 g mat x mat u vec s st | 215 | g # x # u # s #| st |
206 | return (u,s) | 216 | return (u,s) |
207 | where r = rows x | 217 | where r = rows x |
208 | c = cols x | 218 | c = cols x |
@@ -219,7 +229,7 @@ foreign import ccall unsafe "eig_l_H" zheev :: CInt -> C ..> R :> C ..> Ok | |||
219 | eigAux f st m = unsafePerformIO $ do | 229 | eigAux f st m = unsafePerformIO $ do |
220 | l <- createVector r | 230 | l <- createVector r |
221 | v <- createMatrix ColumnMajor r r | 231 | v <- createMatrix ColumnMajor r r |
222 | app3 g mat m vec l mat v st | 232 | g # m # l # v #| st |
223 | return (l,v) | 233 | return (l,v) |
224 | where r = rows m | 234 | where r = rows m |
225 | g ra ca pa = f ra ca pa 0 0 nullPtr | 235 | g ra ca pa = f ra ca pa 0 0 nullPtr |
@@ -232,7 +242,7 @@ eigC = eigAux zgeev "eigC" . fmat | |||
232 | 242 | ||
233 | eigOnlyAux f st m = unsafePerformIO $ do | 243 | eigOnlyAux f st m = unsafePerformIO $ do |
234 | l <- createVector r | 244 | l <- createVector r |
235 | app2 g mat m vec l st | 245 | g # m # l #| st |
236 | return l | 246 | return l |
237 | where r = rows m | 247 | where r = rows m |
238 | g ra ca pa nl pl = f ra ca pa 0 0 nullPtr nl pl 0 0 nullPtr | 248 | g ra ca pa nl pl = f ra ca pa 0 0 nullPtr nl pl 0 0 nullPtr |
@@ -255,7 +265,7 @@ eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double) | |||
255 | eigRaux m = unsafePerformIO $ do | 265 | eigRaux m = unsafePerformIO $ do |
256 | l <- createVector r | 266 | l <- createVector r |
257 | v <- createMatrix ColumnMajor r r | 267 | v <- createMatrix ColumnMajor r r |
258 | app3 g mat m vec l mat v "eigR" | 268 | g # m # l # v #| "eigR" |
259 | return (l,v) | 269 | return (l,v) |
260 | where r = rows m | 270 | where r = rows m |
261 | g ra ca pa = dgeev ra ca pa 0 0 nullPtr | 271 | g ra ca pa = dgeev ra ca pa 0 0 nullPtr |
@@ -282,7 +292,7 @@ eigOnlyR = fixeig1 . eigOnlyAux dgeev "eigOnlyR" . fmat | |||
282 | eigSHAux f st m = unsafePerformIO $ do | 292 | eigSHAux f st m = unsafePerformIO $ do |
283 | l <- createVector r | 293 | l <- createVector r |
284 | v <- createMatrix ColumnMajor r r | 294 | v <- createMatrix ColumnMajor r r |
285 | app3 f mat m vec l mat v st | 295 | f # m # l # v #| st |
286 | return (l,v) | 296 | return (l,v) |
287 | where r = rows m | 297 | where r = rows m |
288 | 298 | ||
@@ -332,7 +342,7 @@ foreign import ccall unsafe "cholSolveC_l" zpotrs :: TMMM C | |||
332 | linearSolveSQAux g f st a b | 342 | linearSolveSQAux g f st a b |
333 | | n1==n2 && n1==r = unsafePerformIO . g $ do | 343 | | n1==n2 && n1==r = unsafePerformIO . g $ do |
334 | s <- createMatrix ColumnMajor r c | 344 | s <- createMatrix ColumnMajor r c |
335 | app3 f mat a mat b mat s st | 345 | f # a # b # s #| st |
336 | return s | 346 | return s |
337 | | otherwise = error $ st ++ " of nonsquare matrix" | 347 | | otherwise = error $ st ++ " of nonsquare matrix" |
338 | where n1 = rows a | 348 | where n1 = rows a |
@@ -371,7 +381,7 @@ foreign import ccall unsafe "linearSolveSVDC_l" zgelss :: Double -> TMMM C | |||
371 | 381 | ||
372 | linearSolveAux f st a b = unsafePerformIO $ do | 382 | linearSolveAux f st a b = unsafePerformIO $ do |
373 | r <- createMatrix ColumnMajor (max m n) nrhs | 383 | r <- createMatrix ColumnMajor (max m n) nrhs |
374 | app3 f mat a mat b mat r st | 384 | f # a # b # r #| st |
375 | return r | 385 | return r |
376 | where m = rows a | 386 | where m = rows a |
377 | n = cols a | 387 | n = cols a |
@@ -412,7 +422,7 @@ foreign import ccall unsafe "chol_l_S" dpotrf :: TMM R | |||
412 | 422 | ||
413 | cholAux f st a = do | 423 | cholAux f st a = do |
414 | r <- createMatrix ColumnMajor n n | 424 | r <- createMatrix ColumnMajor n n |
415 | app2 f mat a mat r st | 425 | f # a # r #| st |
416 | return r | 426 | return r |
417 | where n = rows a | 427 | where n = rows a |
418 | 428 | ||
@@ -450,7 +460,7 @@ qrC = qrAux zgeqr2 "qrC" . fmat | |||
450 | qrAux f st a = unsafePerformIO $ do | 460 | qrAux f st a = unsafePerformIO $ do |
451 | r <- createMatrix ColumnMajor m n | 461 | r <- createMatrix ColumnMajor m n |
452 | tau <- createVector mn | 462 | tau <- createVector mn |
453 | app3 f mat a vec tau mat r st | 463 | f # a # tau # r #| st |
454 | return (r,tau) | 464 | return (r,tau) |
455 | where | 465 | where |
456 | m = rows a | 466 | m = rows a |
@@ -469,7 +479,7 @@ qrgrC = qrgrAux zungqr "qrgrC" | |||
469 | 479 | ||
470 | qrgrAux f st n (a, tau) = unsafePerformIO $ do | 480 | qrgrAux f st n (a, tau) = unsafePerformIO $ do |
471 | res <- createMatrix ColumnMajor (rows a) n | 481 | res <- createMatrix ColumnMajor (rows a) n |
472 | app3 f mat (fmat a) vec (subVector 0 n tau') mat res st | 482 | f # (fmat a) # (subVector 0 n tau') # res #| st |
473 | return res | 483 | return res |
474 | where | 484 | where |
475 | tau' = vjoin [tau, constantD 0 n] | 485 | tau' = vjoin [tau, constantD 0 n] |
@@ -489,7 +499,7 @@ hessC = hessAux zgehrd "hessC" . fmat | |||
489 | hessAux f st a = unsafePerformIO $ do | 499 | hessAux f st a = unsafePerformIO $ do |
490 | r <- createMatrix ColumnMajor m n | 500 | r <- createMatrix ColumnMajor m n |
491 | tau <- createVector (mn-1) | 501 | tau <- createVector (mn-1) |
492 | app3 f mat a vec tau mat r st | 502 | f # a # tau # r #| st |
493 | return (r,tau) | 503 | return (r,tau) |
494 | where m = rows a | 504 | where m = rows a |
495 | n = cols a | 505 | n = cols a |
@@ -510,7 +520,7 @@ schurC = schurAux zgees "schurC" . fmat | |||
510 | schurAux f st a = unsafePerformIO $ do | 520 | schurAux f st a = unsafePerformIO $ do |
511 | u <- createMatrix ColumnMajor n n | 521 | u <- createMatrix ColumnMajor n n |
512 | s <- createMatrix ColumnMajor n n | 522 | s <- createMatrix ColumnMajor n n |
513 | app3 f mat a mat u mat s st | 523 | f # a # u # s #| st |
514 | return (u,s) | 524 | return (u,s) |
515 | where n = rows a | 525 | where n = rows a |
516 | 526 | ||
@@ -529,7 +539,7 @@ luC = luAux zgetrf "luC" . fmat | |||
529 | luAux f st a = unsafePerformIO $ do | 539 | luAux f st a = unsafePerformIO $ do |
530 | lu <- createMatrix ColumnMajor n m | 540 | lu <- createMatrix ColumnMajor n m |
531 | piv <- createVector (min n m) | 541 | piv <- createVector (min n m) |
532 | app3 f mat a vec piv mat lu st | 542 | f # a # piv # lu #| st |
533 | return (lu, map (pred.round) (toList piv)) | 543 | return (lu, map (pred.round) (toList piv)) |
534 | where n = rows a | 544 | where n = rows a |
535 | m = cols a | 545 | m = cols a |
@@ -552,7 +562,7 @@ lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b) | |||
552 | lusAux f st a piv b | 562 | lusAux f st a piv b |
553 | | n1==n2 && n2==n =unsafePerformIO $ do | 563 | | n1==n2 && n2==n =unsafePerformIO $ do |
554 | x <- createMatrix ColumnMajor n m | 564 | x <- createMatrix ColumnMajor n m |
555 | app4 f mat a vec piv' mat b mat x st | 565 | f # a # piv' # b # x #| st |
556 | return x | 566 | return x |
557 | | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" | 567 | | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" |
558 | where n1 = rows a | 568 | where n1 = rows a |
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 8f8c219..db0a609 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -3,6 +3,8 @@ | |||
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | {-# LANGUAGE BangPatterns #-} | 4 | {-# LANGUAGE BangPatterns #-} |
5 | {-# LANGUAGE TypeOperators #-} | 5 | {-# LANGUAGE TypeOperators #-} |
6 | {-# LANGUAGE TypeFamilies #-} | ||
7 | |||
6 | 8 | ||
7 | -- | | 9 | -- | |
8 | -- Module : Internal.Matrix | 10 | -- Module : Internal.Matrix |
@@ -18,7 +20,7 @@ module Internal.Matrix where | |||
18 | 20 | ||
19 | import Internal.Vector | 21 | import Internal.Vector |
20 | import Internal.Devel | 22 | import Internal.Devel |
21 | import Internal.Vectorized | 23 | import Internal.Vectorized hiding ((#)) |
22 | import Foreign.Marshal.Alloc ( free ) | 24 | import Foreign.Marshal.Alloc ( free ) |
23 | import Foreign.Marshal.Array(newArray) | 25 | import Foreign.Marshal.Array(newArray) |
24 | import Foreign.Ptr ( Ptr ) | 26 | import Foreign.Ptr ( Ptr ) |
@@ -79,8 +81,6 @@ data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int | |||
79 | -- RowMajor: preferred by C, fdat may require a transposition | 81 | -- RowMajor: preferred by C, fdat may require a transposition |
80 | -- ColumnMajor: preferred by LAPACK, cdat may require a transposition | 82 | -- ColumnMajor: preferred by LAPACK, cdat may require a transposition |
81 | 83 | ||
82 | --cdat = xdat | ||
83 | --fdat = xdat | ||
84 | 84 | ||
85 | rows :: Matrix t -> Int | 85 | rows :: Matrix t -> Int |
86 | rows = irows | 86 | rows = irows |
@@ -129,6 +129,48 @@ omat a f = | |||
129 | g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p | 129 | g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p |
130 | f m | 130 | f m |
131 | 131 | ||
132 | -------------------------------------------------------------------------------- | ||
133 | |||
134 | {-# INLINE amatr #-} | ||
135 | amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b | ||
136 | amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) | ||
137 | where | ||
138 | r = fromIntegral (rows x) | ||
139 | c = fromIntegral (cols x) | ||
140 | |||
141 | {-# INLINE amat #-} | ||
142 | amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b | ||
143 | amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) | ||
144 | where | ||
145 | r = fromIntegral (rows x) | ||
146 | c = fromIntegral (cols x) | ||
147 | sr = stepRow x | ||
148 | sc = stepCol x | ||
149 | |||
150 | {-# INLINE arrmat #-} | ||
151 | arrmat :: Storable a => (Ptr CInt -> Ptr a -> b) -> Matrix a -> b | ||
152 | arrmat f x = inlinePerformIO (unsafeWith s (\p -> unsafeWith (xdat x) (return . f p))) | ||
153 | where | ||
154 | s = fromList [fi (rows x), fi (cols x), stepRow x, stepCol x] | ||
155 | |||
156 | |||
157 | instance Storable t => TransArray (Matrix t) | ||
158 | where | ||
159 | type Elem (Matrix t) = t | ||
160 | type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b | ||
161 | type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b | ||
162 | apply = amat | ||
163 | {-# INLINE apply #-} | ||
164 | applyRaw = amatr | ||
165 | {-# INLINE applyRaw #-} | ||
166 | applyArray = arrmat | ||
167 | {-# INLINE applyArray #-} | ||
168 | |||
169 | infixl 1 # | ||
170 | a # b = apply a b | ||
171 | {-# INLINE (#) #-} | ||
172 | |||
173 | -------------------------------------------------------------------------------- | ||
132 | 174 | ||
133 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. | 175 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. |
134 | 176 | ||
@@ -139,12 +181,6 @@ fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0] | |||
139 | flatten :: Element t => Matrix t -> Vector t | 181 | flatten :: Element t => Matrix t -> Vector t |
140 | flatten = xdat . cmat | 182 | flatten = xdat . cmat |
141 | 183 | ||
142 | {- | ||
143 | type Mt t s = Int -> Int -> Ptr t -> s | ||
144 | |||
145 | infixr 6 ::> | ||
146 | type t ::> s = Mt t s | ||
147 | -} | ||
148 | 184 | ||
149 | -- | the inverse of 'Data.Packed.Matrix.fromLists' | 185 | -- | the inverse of 'Data.Packed.Matrix.fromLists' |
150 | toLists :: (Element t) => Matrix t -> [[t]] | 186 | toLists :: (Element t) => Matrix t -> [[t]] |
@@ -445,7 +481,7 @@ extractAux f m moder vr modec vc = do | |||
445 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | 481 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr |
446 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | 482 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc |
447 | r <- createMatrix RowMajor nr nc | 483 | r <- createMatrix RowMajor nr nc |
448 | app4 (f moder modec) vec vr vec vc omat m omat r "extractAux" | 484 | f moder modec # vr # vc # m # r #|"extract" |
449 | return r | 485 | return r |
450 | 486 | ||
451 | type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) | 487 | type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) |
@@ -459,7 +495,7 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z | |||
459 | 495 | ||
460 | --------------------------------------------------------------- | 496 | --------------------------------------------------------------- |
461 | 497 | ||
462 | setRectAux f i j m r = app2 (f (fi i) (fi j)) omat m omat r "setRect" | 498 | setRectAux f i j m r = f (fi i) (fi j) # m # r #|"setRect" |
463 | 499 | ||
464 | type SetRect x = I -> I -> x ::> x::> Ok | 500 | type SetRect x = I -> I -> x ::> x::> Ok |
465 | 501 | ||
@@ -474,7 +510,7 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z | |||
474 | 510 | ||
475 | sortG f v = unsafePerformIO $ do | 511 | sortG f v = unsafePerformIO $ do |
476 | r <- createVector (dim v) | 512 | r <- createVector (dim v) |
477 | app2 f vec v vec r "sortG" | 513 | f # v # r #|"sortG" |
478 | return r | 514 | return r |
479 | 515 | ||
480 | sortIdxD = sortG c_sort_indexD | 516 | sortIdxD = sortG c_sort_indexD |
@@ -501,7 +537,7 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok | |||
501 | 537 | ||
502 | compareG f u v = unsafePerformIO $ do | 538 | compareG f u v = unsafePerformIO $ do |
503 | r <- createVector (dim v) | 539 | r <- createVector (dim v) |
504 | app3 f vec u vec v vec r "compareG" | 540 | f # u # v # r #|"compareG" |
505 | return r | 541 | return r |
506 | 542 | ||
507 | compareD = compareG c_compareD | 543 | compareD = compareG c_compareD |
@@ -518,7 +554,7 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok | |||
518 | 554 | ||
519 | selectG f c u v w = unsafePerformIO $ do | 555 | selectG f c u v w = unsafePerformIO $ do |
520 | r <- createVector (dim v) | 556 | r <- createVector (dim v) |
521 | app5 f vec c vec u vec v vec w vec r "selectG" | 557 | f # c # u # v # w # r #|"selectG" |
522 | return r | 558 | return r |
523 | 559 | ||
524 | selectD = selectG c_selectD | 560 | selectD = selectG c_selectD |
@@ -541,7 +577,7 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z | |||
541 | 577 | ||
542 | remapG f i j m = unsafePerformIO $ do | 578 | remapG f i j m = unsafePerformIO $ do |
543 | r <- createMatrix RowMajor (rows i) (cols i) | 579 | r <- createMatrix RowMajor (rows i) (cols i) |
544 | app4 f omat i omat j omat m omat r "remapG" | 580 | f # i # j # m # r #|"remapG" |
545 | return r | 581 | return r |
546 | 582 | ||
547 | remapD = remapG c_remapD | 583 | remapD = remapG c_remapD |
@@ -564,7 +600,7 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z | |||
564 | 600 | ||
565 | rowOpAux f c x i1 i2 j1 j2 m = do | 601 | rowOpAux f c x i1 i2 j1 j2 m = do |
566 | px <- newArray [x] | 602 | px <- newArray [x] |
567 | app1 (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) omat m "rowOp" | 603 | f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2) # m #|"rowOp" |
568 | free px | 604 | free px |
569 | 605 | ||
570 | type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok | 606 | type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok |
@@ -580,7 +616,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | |||
580 | 616 | ||
581 | -------------------------------------------------------------------------------- | 617 | -------------------------------------------------------------------------------- |
582 | 618 | ||
583 | gemmg f u v m1 m2 m3 = app5 f vec u vec v omat m1 omat m2 omat m3 "gemmg" | 619 | gemmg f u v m1 m2 m3 = f # u # v # m1 # m2 # m3 #|"gemmg" |
584 | 620 | ||
585 | type Tgemm x = x :> I :> x ::> x ::> x ::> Ok | 621 | type Tgemm x = x :> I :> x ::> x ::> x ::> Ok |
586 | 622 | ||
@@ -608,7 +644,7 @@ saveMatrix | |||
608 | saveMatrix name format m = do | 644 | saveMatrix name format m = do |
609 | cname <- newCString name | 645 | cname <- newCString name |
610 | cformat <- newCString format | 646 | cformat <- newCString format |
611 | app1 (c_saveMatrix cname cformat) mat m "saveMatrix" | 647 | c_saveMatrix cname cformat `applyRaw` m #|"saveMatrix" |
612 | free cname | 648 | free cname |
613 | free cformat | 649 | free cformat |
614 | return () | 650 | return () |
diff --git a/packages/base/src/Internal/Sparse.hs b/packages/base/src/Internal/Sparse.hs index b365c15..eb4ee1b 100644 --- a/packages/base/src/Internal/Sparse.hs +++ b/packages/base/src/Internal/Sparse.hs | |||
@@ -145,13 +145,13 @@ gmXv :: GMatrix -> Vector Double -> Vector Double | |||
145 | gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do | 145 | gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do |
146 | dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | 146 | dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) |
147 | r <- createVector nRows | 147 | r <- createVector nRows |
148 | app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" | 148 | c_smXv # csrVals # csrCols # csrRows # v # r #|"CSRXv" |
149 | return r | 149 | return r |
150 | 150 | ||
151 | gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do | 151 | gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do |
152 | dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | 152 | dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) |
153 | r <- createVector nRows | 153 | r <- createVector nRows |
154 | app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" | 154 | c_smTXv # cscVals # cscRows # cscCols # v # r #|"CSCXv" |
155 | return r | 155 | return r |
156 | 156 | ||
157 | gmXv Diag{..} v | 157 | gmXv Diag{..} v |
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 079663d..924ca4c 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -31,7 +31,7 @@ module Internal.Util( | |||
31 | diagl, | 31 | diagl, |
32 | row, | 32 | row, |
33 | col, | 33 | col, |
34 | (&), (¦), (|||), (——), (===), (#), | 34 | (&), (¦), (|||), (——), (===), |
35 | (?), (¿), | 35 | (?), (¿), |
36 | Indexable(..), size, | 36 | Indexable(..), size, |
37 | Numeric, | 37 | Numeric, |
@@ -185,10 +185,6 @@ infixl 2 —— | |||
185 | (——) = (===) | 185 | (——) = (===) |
186 | 186 | ||
187 | 187 | ||
188 | (#) :: Matrix Double -> Matrix Double -> Matrix Double | ||
189 | infixl 2 # | ||
190 | a # b = fromBlocks [[a],[b]] | ||
191 | |||
192 | -- | create a single row real matrix from a list | 188 | -- | create a single row real matrix from a list |
193 | -- | 189 | -- |
194 | -- >>> row [2,3,1,8] | 190 | -- >>> row [2,3,1,8] |
diff --git a/packages/base/src/Internal/Vector.hs b/packages/base/src/Internal/Vector.hs index 0e9161d..e5ac440 100644 --- a/packages/base/src/Internal/Vector.hs +++ b/packages/base/src/Internal/Vector.hs | |||
@@ -14,7 +14,7 @@ module Internal.Vector( | |||
14 | I,Z,R,C, | 14 | I,Z,R,C, |
15 | fi,ti, | 15 | fi,ti, |
16 | Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith, | 16 | Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith, |
17 | createVector, vec, | 17 | createVector, vec, avec, arrvec, inlinePerformIO, |
18 | toList, dim, (@>), at', (|>), | 18 | toList, dim, (@>), at', (|>), |
19 | vjoin, subVector, takesV, idxs, | 19 | vjoin, subVector, takesV, idxs, |
20 | buildVector, | 20 | buildVector, |
@@ -75,6 +75,16 @@ vec x f = unsafeWith x $ \p -> do | |||
75 | f v | 75 | f v |
76 | {-# INLINE vec #-} | 76 | {-# INLINE vec #-} |
77 | 77 | ||
78 | {-# INLINE avec #-} | ||
79 | avec :: Storable a => (CInt -> Ptr a -> b) -> Vector a -> b | ||
80 | avec f v = inlinePerformIO (unsafeWith v (return . f (fromIntegral (Vector.length v)))) | ||
81 | infixl 1 `avec` | ||
82 | |||
83 | {-# INLINE arrvec #-} | ||
84 | arrvec :: Storable a => (Ptr CInt -> Ptr a -> b) -> Vector a -> b | ||
85 | arrvec f v = inlinePerformIO (unsafeWith (idxs [1,dim v]) (\p -> unsafeWith v (return . f p))) | ||
86 | |||
87 | |||
78 | 88 | ||
79 | -- allocates memory for a new vector | 89 | -- allocates memory for a new vector |
80 | createVector :: Storable a => Int -> IO (Vector a) | 90 | createVector :: Storable a => Int -> IO (Vector a) |
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs index 5c89ac9..03bcf90 100644 --- a/packages/base/src/Internal/Vectorized.hs +++ b/packages/base/src/Internal/Vectorized.hs | |||
@@ -1,4 +1,5 @@ | |||
1 | {-# LANGUAGE TypeOperators #-} | 1 | {-# LANGUAGE TypeOperators #-} |
2 | {-# LANGUAGE TypeFamilies #-} | ||
2 | 3 | ||
3 | ----------------------------------------------------------------------------- | 4 | ----------------------------------------------------------------------------- |
4 | -- | | 5 | -- | |
@@ -26,7 +27,9 @@ import Foreign.C.String | |||
26 | import System.IO.Unsafe(unsafePerformIO) | 27 | import System.IO.Unsafe(unsafePerformIO) |
27 | import Control.Monad(when) | 28 | import Control.Monad(when) |
28 | 29 | ||
29 | 30 | infixl 1 # | |
31 | a # b = applyRaw a b | ||
32 | {-# INLINE (#) #-} | ||
30 | 33 | ||
31 | fromei x = fromIntegral (fromEnum x) :: CInt | 34 | fromei x = fromIntegral (fromEnum x) :: CInt |
32 | 35 | ||
@@ -100,7 +103,7 @@ sumL m = sumg (c_sumL m) | |||
100 | 103 | ||
101 | sumg f x = unsafePerformIO $ do | 104 | sumg f x = unsafePerformIO $ do |
102 | r <- createVector 1 | 105 | r <- createVector 1 |
103 | app2 f vec x vec r "sum" | 106 | f # x # r #| "sum" |
104 | return $ r @> 0 | 107 | return $ r @> 0 |
105 | 108 | ||
106 | type TVV t = t :> t :> Ok | 109 | type TVV t = t :> t :> Ok |
@@ -128,14 +131,15 @@ prodQ = prodg c_prodQ | |||
128 | prodC :: Vector (Complex Double) -> Complex Double | 131 | prodC :: Vector (Complex Double) -> Complex Double |
129 | prodC = prodg c_prodC | 132 | prodC = prodg c_prodC |
130 | 133 | ||
131 | 134 | prodI :: I-> Vector I -> I | |
132 | prodI = prodg . c_prodI | 135 | prodI = prodg . c_prodI |
133 | 136 | ||
137 | prodL :: Z-> Vector Z -> Z | ||
134 | prodL = prodg . c_prodL | 138 | prodL = prodg . c_prodL |
135 | 139 | ||
136 | prodg f x = unsafePerformIO $ do | 140 | prodg f x = unsafePerformIO $ do |
137 | r <- createVector 1 | 141 | r <- createVector 1 |
138 | app2 f vec x vec r "prod" | 142 | f # x # r #| "prod" |
139 | return $ r @> 0 | 143 | return $ r @> 0 |
140 | 144 | ||
141 | 145 | ||
@@ -150,24 +154,24 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z | |||
150 | 154 | ||
151 | toScalarAux fun code v = unsafePerformIO $ do | 155 | toScalarAux fun code v = unsafePerformIO $ do |
152 | r <- createVector 1 | 156 | r <- createVector 1 |
153 | app2 (fun (fromei code)) vec v vec r "toScalarAux" | 157 | fun (fromei code) # v # r #|"toScalarAux" |
154 | return (r @> 0) | 158 | return (r @> 0) |
155 | 159 | ||
156 | vectorMapAux fun code v = unsafePerformIO $ do | 160 | vectorMapAux fun code v = unsafePerformIO $ do |
157 | r <- createVector (dim v) | 161 | r <- createVector (dim v) |
158 | app2 (fun (fromei code)) vec v vec r "vectorMapAux" | 162 | fun (fromei code) # v # r #|"vectorMapAux" |
159 | return r | 163 | return r |
160 | 164 | ||
161 | vectorMapValAux fun code val v = unsafePerformIO $ do | 165 | vectorMapValAux fun code val v = unsafePerformIO $ do |
162 | r <- createVector (dim v) | 166 | r <- createVector (dim v) |
163 | pval <- newArray [val] | 167 | pval <- newArray [val] |
164 | app2 (fun (fromei code) pval) vec v vec r "vectorMapValAux" | 168 | fun (fromei code) pval # v # r #|"vectorMapValAux" |
165 | free pval | 169 | free pval |
166 | return r | 170 | return r |
167 | 171 | ||
168 | vectorZipAux fun code u v = unsafePerformIO $ do | 172 | vectorZipAux fun code u v = unsafePerformIO $ do |
169 | r <- createVector (dim u) | 173 | r <- createVector (dim u) |
170 | app3 (fun (fromei code)) vec u vec v vec r "vectorZipAux" | 174 | fun (fromei code) # u # v # r #|"vectorZipAux" |
171 | return r | 175 | return r |
172 | 176 | ||
173 | --------------------------------------------------------------------- | 177 | --------------------------------------------------------------------- |
@@ -364,7 +368,7 @@ randomVector :: Seed | |||
364 | -> Vector Double | 368 | -> Vector Double |
365 | randomVector seed dist n = unsafePerformIO $ do | 369 | randomVector seed dist n = unsafePerformIO $ do |
366 | r <- createVector n | 370 | r <- createVector n |
367 | app1 (c_random_vector (fi seed) ((fi.fromEnum) dist)) vec r "randomVector" | 371 | c_random_vector (fi seed) ((fi.fromEnum) dist) # r #|"randomVector" |
368 | return r | 372 | return r |
369 | 373 | ||
370 | foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> Double :> Ok | 374 | foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> Double :> Ok |
@@ -373,7 +377,7 @@ foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> D | |||
373 | 377 | ||
374 | roundVector v = unsafePerformIO $ do | 378 | roundVector v = unsafePerformIO $ do |
375 | r <- createVector (dim v) | 379 | r <- createVector (dim v) |
376 | app2 c_round_vector vec v vec r "roundVector" | 380 | c_round_vector # v # r #|"roundVector" |
377 | return r | 381 | return r |
378 | 382 | ||
379 | foreign import ccall unsafe "round_vector" c_round_vector :: TVV Double | 383 | foreign import ccall unsafe "round_vector" c_round_vector :: TVV Double |
@@ -387,7 +391,7 @@ foreign import ccall unsafe "round_vector" c_round_vector :: TVV Double | |||
387 | range :: Int -> Vector I | 391 | range :: Int -> Vector I |
388 | range n = unsafePerformIO $ do | 392 | range n = unsafePerformIO $ do |
389 | r <- createVector n | 393 | r <- createVector n |
390 | app1 c_range_vector vec r "range" | 394 | c_range_vector # r #|"range" |
391 | return r | 395 | return r |
392 | 396 | ||
393 | foreign import ccall unsafe "range_vector" c_range_vector :: CInt :> Ok | 397 | foreign import ccall unsafe "range_vector" c_range_vector :: CInt :> Ok |
@@ -427,7 +431,7 @@ long2intV = tog c_long2int | |||
427 | 431 | ||
428 | tog f v = unsafePerformIO $ do | 432 | tog f v = unsafePerformIO $ do |
429 | r <- createVector (dim v) | 433 | r <- createVector (dim v) |
430 | app2 f vec v vec r "tog" | 434 | f # v # r #|"tog" |
431 | return r | 435 | return r |
432 | 436 | ||
433 | foreign import ccall unsafe "float2double" c_float2double :: Float :> Double :> Ok | 437 | foreign import ccall unsafe "float2double" c_float2double :: Float :> Double :> Ok |
@@ -446,7 +450,7 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok | |||
446 | 450 | ||
447 | stepg f v = unsafePerformIO $ do | 451 | stepg f v = unsafePerformIO $ do |
448 | r <- createVector (dim v) | 452 | r <- createVector (dim v) |
449 | app2 f vec v vec r "step" | 453 | f # v # r #|"step" |
450 | return r | 454 | return r |
451 | 455 | ||
452 | stepD :: Vector Double -> Vector Double | 456 | stepD :: Vector Double -> Vector Double |
@@ -471,7 +475,7 @@ foreign import ccall unsafe "stepL" c_stepL :: TVV Z | |||
471 | 475 | ||
472 | conjugateAux fun x = unsafePerformIO $ do | 476 | conjugateAux fun x = unsafePerformIO $ do |
473 | v <- createVector (dim x) | 477 | v <- createVector (dim x) |
474 | app2 fun vec x vec v "conjugateAux" | 478 | fun # x # v #|"conjugateAux" |
475 | return v | 479 | return v |
476 | 480 | ||
477 | conjugateQ :: Vector (Complex Float) -> Vector (Complex Float) | 481 | conjugateQ :: Vector (Complex Float) -> Vector (Complex Float) |
@@ -489,7 +493,7 @@ cloneVector v = do | |||
489 | let n = dim v | 493 | let n = dim v |
490 | r <- createVector n | 494 | r <- createVector n |
491 | let f _ s _ d = copyArray d s n >> return 0 | 495 | let f _ s _ d = copyArray d s n >> return 0 |
492 | app2 f vec v vec r "cloneVector" | 496 | f # v # r #|"cloneVector" |
493 | return r | 497 | return r |
494 | 498 | ||
495 | -------------------------------------------------------------------------------- | 499 | -------------------------------------------------------------------------------- |
@@ -497,7 +501,7 @@ cloneVector v = do | |||
497 | constantAux fun x n = unsafePerformIO $ do | 501 | constantAux fun x n = unsafePerformIO $ do |
498 | v <- createVector n | 502 | v <- createVector n |
499 | px <- newArray [x] | 503 | px <- newArray [x] |
500 | app1 (fun px) vec v "constantAux" | 504 | fun px # v #|"constantAux" |
501 | free px | 505 | free px |
502 | return v | 506 | return v |
503 | 507 | ||