diff options
author | maxc01 <xingchen92@gmail.com> | 2015-10-07 13:48:26 +0800 |
---|---|---|
committer | maxc01 <xingchen92@gmail.com> | 2015-10-07 13:48:26 +0800 |
commit | a61af756ddca4544de5e4969edc73131f4fccdd1 (patch) | |
tree | 2ac1755695a42d3964208e0029e74d446f5c3bd8 | |
parent | 0840304af1564fa86a6006d648450372f301a6c8 (diff) | |
parent | c84a485f148063f6d0c23f016fe348ec94fb6b19 (diff) |
Merge pull request #1 from albertoruiz/master
sync from albetoruiz/hmatrix
104 files changed, 9059 insertions, 4325 deletions
@@ -55,6 +55,32 @@ using this method. | |||
55 | 55 | ||
56 | [winpack]: https://github.com/downloads/AlbertoRuiz/hmatrix/gsl-lapack-windows.zip | 56 | [winpack]: https://github.com/downloads/AlbertoRuiz/hmatrix/gsl-lapack-windows.zip |
57 | 57 | ||
58 | ### Alternative Windows build | ||
59 | |||
60 | 1) | ||
61 | |||
62 | > cabal update | ||
63 | |||
64 | 2) Download and unzip somewhere OpenBLAS http://www.openblas.net/ | ||
65 | |||
66 | 3) In a normal Windows cmd: | ||
67 | |||
68 | > cabal install --flags=openblas --extra-lib-dirs=C:\...\OpenBLAS\lib --extra-include-dir=C:\...\OpenBLAS\include | ||
69 | |||
70 | ### Stack-based Windows build | ||
71 | |||
72 | Similar should be build under other OSes, like Linux and OSX. | ||
73 | |||
74 | 1) | ||
75 | |||
76 | > stack setup | ||
77 | |||
78 | 2) Download and unzip somewhere OpenBLAS http://www.openblas.net/ | ||
79 | |||
80 | 3) Example in a normal Windows cmd for building hmatrix base lib: | ||
81 | |||
82 | > stack install hmatrix --flag hmatrix:openblas --extra-lib-dirs=C:\...\OpenBLAS\lib --extra-include-dir=C:\...\OpenBLAS\include | ||
83 | |||
58 | ## Tests ############################################### | 84 | ## Tests ############################################### |
59 | 85 | ||
60 | After installation we can verify that the library works as expected: | 86 | After installation we can verify that the library works as expected: |
@@ -5,9 +5,9 @@ A purely functional interface to linear algebra and other numerical algorithms, | |||
5 | 5 | ||
6 | This package includes matrix decompositions (eigensystems, singular values, Cholesky, QR, etc.), linear solvers, numeric integration, root finding, etc. | 6 | This package includes matrix decompositions (eigensystems, singular values, Cholesky, QR, etc.), linear solvers, numeric integration, root finding, etc. |
7 | 7 | ||
8 | Version 0.16 (june 2014) has [new features][changes]. | 8 | - [What's new][changes] in version 0.17 (July 2015) |
9 | 9 | ||
10 | - [Code examples (in construction)][examples] | 10 | - [Code examples][examples] |
11 | 11 | ||
12 | - Source code and documentation (Hackage) | 12 | - Source code and documentation (Hackage) |
13 | - linear algebra: [hmatrix](http://hackage.haskell.org/package/hmatrix) | 13 | - linear algebra: [hmatrix](http://hackage.haskell.org/package/hmatrix) |
diff --git a/examples/bool.hs b/examples/bool.hs index 679b8bf..ee85523 100644 --- a/examples/bool.hs +++ b/examples/bool.hs | |||
@@ -1,17 +1,25 @@ | |||
1 | -- vectorized boolean operations defined in terms of step or cond | 1 | -- vectorized boolean operations defined in terms of step or cond |
2 | 2 | ||
3 | {-# LANGUAGE FlexibleContexts #-} | ||
4 | |||
3 | import Numeric.LinearAlgebra | 5 | import Numeric.LinearAlgebra |
4 | 6 | ||
5 | infix 4 .==., ./=., .<., .<=., .>=., .>. | 7 | infix 4 .==., ./=., .<., .<=., .>=., .>. |
6 | infixr 3 .&&. | 8 | infixr 3 .&&. |
7 | infixr 2 .||. | 9 | infixr 2 .||. |
8 | 10 | ||
9 | a .<. b = step (b-a) | 11 | -- specialized for Int result |
10 | a .<=. b = cond a b 1 1 0 | 12 | cond' |
11 | a .==. b = cond a b 0 1 0 | 13 | :: (Element t, Ord t, Container c I, Container c t) |
12 | a ./=. b = cond a b 1 0 1 | 14 | => c t -> c t -> c I -> c I -> c I -> c I |
13 | a .>=. b = cond a b 0 1 1 | 15 | cond' = cond |
14 | a .>. b = step (a-b) | 16 | |
17 | a .<. b = cond' a b 1 0 0 | ||
18 | a .<=. b = cond' a b 1 1 0 | ||
19 | a .==. b = cond' a b 0 1 0 | ||
20 | a ./=. b = cond' a b 1 0 1 | ||
21 | a .>=. b = cond' a b 0 1 1 | ||
22 | a .>. b = cond' a b 0 0 1 | ||
15 | 23 | ||
16 | a .&&. b = step (a*b) | 24 | a .&&. b = step (a*b) |
17 | a .||. b = step (a+b) | 25 | a .||. b = step (a+b) |
@@ -29,26 +37,22 @@ maxEvery a b = cond a b b b a | |||
29 | 37 | ||
30 | clip a b x = cond y b y y b where y = cond x a a x x | 38 | clip a b x = cond y b y y b where y = cond x a a x x |
31 | 39 | ||
32 | disp = putStr . dispf 3 | 40 | eye n = ident n :: Matrix R |
33 | |||
34 | eye n = ident n :: Matrix Double | ||
35 | row = asRow . fromList :: [Double] -> Matrix Double | ||
36 | col = asColumn . fromList :: [Double] -> Matrix Double | ||
37 | 41 | ||
38 | m = (3><4) [1..] :: Matrix Double | 42 | m = (3><4) [1..] :: Matrix R |
39 | 43 | ||
40 | p = row [0,0,1,1] | 44 | p = fromList [0,0,1,1] :: Vector I |
41 | q = row [0,1,0,1] | 45 | q = fromList [0,1,0,1] :: Vector I |
42 | 46 | ||
43 | main = do | 47 | main = do |
44 | print $ find (>6) m | 48 | print $ find (>6) m |
45 | disp $ assoc (6,8) 7 $ zip (find (/=0) (eye 5)) [10..] | 49 | disp 3 $ assoc (6,8) 7 $ zip (find (/=0) (eye 5)) [10..] |
46 | disp $ accum (eye 5) (+) [((0,2),3), ((3,1),7), ((1,1),1)] | 50 | disp 3 $ accum (eye 5) (+) [((0,2),3), ((3,1),7), ((1,1),1)] |
47 | disp $ m .>=. 10 .||. m .<. 4 | 51 | print $ m .>=. 10 .||. m .<. 4 |
48 | (disp . fromColumns . map flatten) [p, q, p.&&.q, p .||.q, p `xor` q, p `equiv` q, p `imp` q] | 52 | (print . fromColumns) [p, q, p.&&.q, p .||.q, p `xor` q, p `equiv` q, p `imp` q] |
49 | print $ taut $ (p `imp` q ) `equiv` (no q `imp` no p) | 53 | print $ taut $ (p `imp` q ) `equiv` (no q `imp` no p) |
50 | print $ taut $ (xor p q) `equiv` (p .&&. no q .||. no p .&&. q) | 54 | print $ taut $ (xor p q) `equiv` (p .&&. no q .||. no p .&&. q) |
51 | disp $ clip 3 8 m | 55 | disp 3 $ clip 3 8 m |
52 | disp $ col [1..7] .<=. row [1..5] | 56 | print $ col [1..7] .<=. row [1..5] |
53 | disp $ cond (col [1..3]) (row [1..4]) m 50 (3*m) | 57 | print $ cond (col [1..3]) (row [1..4]) m 50 (3*m) |
54 | 58 | ||
diff --git a/examples/bool.ipynb b/examples/bool.ipynb new file mode 100644 index 0000000..abceeb4 --- /dev/null +++ b/examples/bool.ipynb | |||
@@ -0,0 +1,1152 @@ | |||
1 | { | ||
2 | "cells": [ | ||
3 | { | ||
4 | "cell_type": "markdown", | ||
5 | "metadata": {}, | ||
6 | "source": [ | ||
7 | "# vectorized boolean operations" | ||
8 | ] | ||
9 | }, | ||
10 | { | ||
11 | "cell_type": "code", | ||
12 | "execution_count": 1, | ||
13 | "metadata": { | ||
14 | "collapsed": true | ||
15 | }, | ||
16 | "outputs": [], | ||
17 | "source": [ | ||
18 | "import Numeric.LinearAlgebra\n", | ||
19 | ":ext FlexibleContexts" | ||
20 | ] | ||
21 | }, | ||
22 | { | ||
23 | "cell_type": "markdown", | ||
24 | "metadata": {}, | ||
25 | "source": [ | ||
26 | "## pretty printing" | ||
27 | ] | ||
28 | }, | ||
29 | { | ||
30 | "cell_type": "code", | ||
31 | "execution_count": 2, | ||
32 | "metadata": { | ||
33 | "collapsed": false, | ||
34 | "scrolled": true | ||
35 | }, | ||
36 | "outputs": [], | ||
37 | "source": [ | ||
38 | "import IHaskell.Display\n", | ||
39 | ":ext FlexibleInstances\n", | ||
40 | "\n", | ||
41 | "dec = 3\n", | ||
42 | "\n", | ||
43 | "dispBool = (\"\\n\"++) . format \" \" f\n", | ||
44 | " where\n", | ||
45 | " f 1 = \"\\\\top\"\n", | ||
46 | " f 0 = \"\\\\cdot\"\n", | ||
47 | "\n", | ||
48 | "instance IHaskellDisplay (Matrix I) where\n", | ||
49 | " display m = return $ Display [html (\"<p>$$\"++(latexFormat \"bmatrix\" . dispBool) m++\"$$</p>\")]\n", | ||
50 | "\n", | ||
51 | "instance IHaskellDisplay (Matrix C) where\n", | ||
52 | " display m = return $ Display [html (\"<p>$$\"++(latexFormat \"bmatrix\" . dispcf dec) m++\"$$</p>\")]\n", | ||
53 | "\n", | ||
54 | "instance IHaskellDisplay (Matrix R) where\n", | ||
55 | " display m = return $ Display [html (\"<p>$$\"++ (latexFormat \"bmatrix\" . dispf dec) m++\"$$</p>\")]" | ||
56 | ] | ||
57 | }, | ||
58 | { | ||
59 | "cell_type": "markdown", | ||
60 | "metadata": {}, | ||
61 | "source": [ | ||
62 | "## definitions" | ||
63 | ] | ||
64 | }, | ||
65 | { | ||
66 | "cell_type": "markdown", | ||
67 | "metadata": {}, | ||
68 | "source": [ | ||
69 | "vectorized operators defined in terms of `step` and `cond`" | ||
70 | ] | ||
71 | }, | ||
72 | { | ||
73 | "cell_type": "code", | ||
74 | "execution_count": 3, | ||
75 | "metadata": { | ||
76 | "collapsed": false | ||
77 | }, | ||
78 | "outputs": [], | ||
79 | "source": [ | ||
80 | "-- specialized for Int result\n", | ||
81 | "cond'\n", | ||
82 | " :: (Element t, Ord t, Container c I, Container c t)\n", | ||
83 | " => c t -> c t -> c I -> c I -> c I -> c I\n", | ||
84 | "cond' = cond" | ||
85 | ] | ||
86 | }, | ||
87 | { | ||
88 | "cell_type": "code", | ||
89 | "execution_count": 4, | ||
90 | "metadata": { | ||
91 | "collapsed": false | ||
92 | }, | ||
93 | "outputs": [], | ||
94 | "source": [ | ||
95 | "infix 4 .==., ./=., .<., .<=., .>=., .>.\n", | ||
96 | "infixr 3 .&&.\n", | ||
97 | "infixr 2 .||.\n", | ||
98 | "\n", | ||
99 | "a .<. b = cond' a b 1 0 0\n", | ||
100 | "a .<=. b = cond' a b 1 1 0\n", | ||
101 | "a .==. b = cond' a b 0 1 0\n", | ||
102 | "a ./=. b = cond' a b 1 0 1\n", | ||
103 | "a .>=. b = cond' a b 0 1 1\n", | ||
104 | "a .>. b = cond' a b 0 0 1\n", | ||
105 | "\n", | ||
106 | "a .&&. b = step (a*b)\n", | ||
107 | "a .||. b = step (a+b)\n", | ||
108 | "no a = 1-a\n", | ||
109 | "xor a b = a ./=. b\n", | ||
110 | "equiv a b = a .==. b\n", | ||
111 | "imp a b = no a .||. b" | ||
112 | ] | ||
113 | }, | ||
114 | { | ||
115 | "cell_type": "markdown", | ||
116 | "metadata": {}, | ||
117 | "source": [ | ||
118 | "other useful functions" | ||
119 | ] | ||
120 | }, | ||
121 | { | ||
122 | "cell_type": "code", | ||
123 | "execution_count": 5, | ||
124 | "metadata": { | ||
125 | "collapsed": true | ||
126 | }, | ||
127 | "outputs": [], | ||
128 | "source": [ | ||
129 | "taut x = minElement x == 1\n", | ||
130 | "\n", | ||
131 | "minEvery a b = cond a b a a b\n", | ||
132 | "maxEvery a b = cond a b b b a\n", | ||
133 | "\n", | ||
134 | "eye n = ident n :: Matrix R\n", | ||
135 | "\n", | ||
136 | "clip a b x = cond y b y y b\n", | ||
137 | " where\n", | ||
138 | " y = cond x a a x x" | ||
139 | ] | ||
140 | }, | ||
141 | { | ||
142 | "cell_type": "markdown", | ||
143 | "metadata": {}, | ||
144 | "source": [ | ||
145 | "## examples" | ||
146 | ] | ||
147 | }, | ||
148 | { | ||
149 | "cell_type": "code", | ||
150 | "execution_count": 6, | ||
151 | "metadata": { | ||
152 | "collapsed": false | ||
153 | }, | ||
154 | "outputs": [ | ||
155 | { | ||
156 | "data": { | ||
157 | "text/html": [ | ||
158 | "<style>/*\n", | ||
159 | "Custom IHaskell CSS.\n", | ||
160 | "*/\n", | ||
161 | "\n", | ||
162 | "/* Styles used for the Hoogle display in the pager */\n", | ||
163 | ".hoogle-doc {\n", | ||
164 | " display: block;\n", | ||
165 | " padding-bottom: 1.3em;\n", | ||
166 | " padding-left: 0.4em;\n", | ||
167 | "}\n", | ||
168 | ".hoogle-code {\n", | ||
169 | " display: block;\n", | ||
170 | " font-family: monospace;\n", | ||
171 | " white-space: pre;\n", | ||
172 | "}\n", | ||
173 | ".hoogle-text {\n", | ||
174 | " display: block;\n", | ||
175 | "}\n", | ||
176 | ".hoogle-name {\n", | ||
177 | " color: green;\n", | ||
178 | " font-weight: bold;\n", | ||
179 | "}\n", | ||
180 | ".hoogle-head {\n", | ||
181 | " font-weight: bold;\n", | ||
182 | "}\n", | ||
183 | ".hoogle-sub {\n", | ||
184 | " display: block;\n", | ||
185 | " margin-left: 0.4em;\n", | ||
186 | "}\n", | ||
187 | ".hoogle-package {\n", | ||
188 | " font-weight: bold;\n", | ||
189 | " font-style: italic;\n", | ||
190 | "}\n", | ||
191 | ".hoogle-module {\n", | ||
192 | " font-weight: bold;\n", | ||
193 | "}\n", | ||
194 | ".hoogle-class {\n", | ||
195 | " font-weight: bold;\n", | ||
196 | "}\n", | ||
197 | "\n", | ||
198 | "/* Styles used for basic displays */\n", | ||
199 | ".get-type {\n", | ||
200 | " color: green;\n", | ||
201 | " font-weight: bold;\n", | ||
202 | " font-family: monospace;\n", | ||
203 | " display: block;\n", | ||
204 | " white-space: pre-wrap;\n", | ||
205 | "}\n", | ||
206 | "\n", | ||
207 | ".show-type {\n", | ||
208 | " color: green;\n", | ||
209 | " font-weight: bold;\n", | ||
210 | " font-family: monospace;\n", | ||
211 | " margin-left: 1em;\n", | ||
212 | "}\n", | ||
213 | "\n", | ||
214 | ".mono {\n", | ||
215 | " font-family: monospace;\n", | ||
216 | " display: block;\n", | ||
217 | "}\n", | ||
218 | "\n", | ||
219 | ".err-msg {\n", | ||
220 | " color: red;\n", | ||
221 | " font-style: italic;\n", | ||
222 | " font-family: monospace;\n", | ||
223 | " white-space: pre;\n", | ||
224 | " display: block;\n", | ||
225 | "}\n", | ||
226 | "\n", | ||
227 | "#unshowable {\n", | ||
228 | " color: red;\n", | ||
229 | " font-weight: bold;\n", | ||
230 | "}\n", | ||
231 | "\n", | ||
232 | ".err-msg.in.collapse {\n", | ||
233 | " padding-top: 0.7em;\n", | ||
234 | "}\n", | ||
235 | "\n", | ||
236 | "/* Code that will get highlighted before it is highlighted */\n", | ||
237 | ".highlight-code {\n", | ||
238 | " white-space: pre;\n", | ||
239 | " font-family: monospace;\n", | ||
240 | "}\n", | ||
241 | "\n", | ||
242 | "/* Hlint styles */\n", | ||
243 | ".suggestion-warning { \n", | ||
244 | " font-weight: bold;\n", | ||
245 | " color: rgb(200, 130, 0);\n", | ||
246 | "}\n", | ||
247 | ".suggestion-error { \n", | ||
248 | " font-weight: bold;\n", | ||
249 | " color: red;\n", | ||
250 | "}\n", | ||
251 | ".suggestion-name {\n", | ||
252 | " font-weight: bold;\n", | ||
253 | "}\n", | ||
254 | "</style><p>$$\\begin{bmatrix}\n", | ||
255 | "\\top & \\top & \\top & \\top & \\top\n", | ||
256 | "\\\\\n", | ||
257 | "\\cdot & \\top & \\top & \\top & \\top\n", | ||
258 | "\\\\\n", | ||
259 | "\\cdot & \\cdot & \\top & \\top & \\top\n", | ||
260 | "\\\\\n", | ||
261 | "\\cdot & \\cdot & \\cdot & \\top & \\top\n", | ||
262 | "\\\\\n", | ||
263 | "\\cdot & \\cdot & \\cdot & \\cdot & \\top\n", | ||
264 | "\\\\\n", | ||
265 | "\\cdot & \\cdot & \\cdot & \\cdot & \\cdot\n", | ||
266 | "\\\\\n", | ||
267 | "\\cdot & \\cdot & \\cdot & \\cdot & \\cdot\n", | ||
268 | "\\end{bmatrix}$$</p>" | ||
269 | ] | ||
270 | }, | ||
271 | "metadata": {}, | ||
272 | "output_type": "display_data" | ||
273 | } | ||
274 | ], | ||
275 | "source": [ | ||
276 | "col [1..7] .<=. row [1..5]" | ||
277 | ] | ||
278 | }, | ||
279 | { | ||
280 | "cell_type": "code", | ||
281 | "execution_count": 7, | ||
282 | "metadata": { | ||
283 | "collapsed": true | ||
284 | }, | ||
285 | "outputs": [], | ||
286 | "source": [ | ||
287 | "m = (3><4) [1..] :: Matrix R" | ||
288 | ] | ||
289 | }, | ||
290 | { | ||
291 | "cell_type": "code", | ||
292 | "execution_count": 8, | ||
293 | "metadata": { | ||
294 | "collapsed": false | ||
295 | }, | ||
296 | "outputs": [ | ||
297 | { | ||
298 | "data": { | ||
299 | "text/html": [ | ||
300 | "<style>/*\n", | ||
301 | "Custom IHaskell CSS.\n", | ||
302 | "*/\n", | ||
303 | "\n", | ||
304 | "/* Styles used for the Hoogle display in the pager */\n", | ||
305 | ".hoogle-doc {\n", | ||
306 | " display: block;\n", | ||
307 | " padding-bottom: 1.3em;\n", | ||
308 | " padding-left: 0.4em;\n", | ||
309 | "}\n", | ||
310 | ".hoogle-code {\n", | ||
311 | " display: block;\n", | ||
312 | " font-family: monospace;\n", | ||
313 | " white-space: pre;\n", | ||
314 | "}\n", | ||
315 | ".hoogle-text {\n", | ||
316 | " display: block;\n", | ||
317 | "}\n", | ||
318 | ".hoogle-name {\n", | ||
319 | " color: green;\n", | ||
320 | " font-weight: bold;\n", | ||
321 | "}\n", | ||
322 | ".hoogle-head {\n", | ||
323 | " font-weight: bold;\n", | ||
324 | "}\n", | ||
325 | ".hoogle-sub {\n", | ||
326 | " display: block;\n", | ||
327 | " margin-left: 0.4em;\n", | ||
328 | "}\n", | ||
329 | ".hoogle-package {\n", | ||
330 | " font-weight: bold;\n", | ||
331 | " font-style: italic;\n", | ||
332 | "}\n", | ||
333 | ".hoogle-module {\n", | ||
334 | " font-weight: bold;\n", | ||
335 | "}\n", | ||
336 | ".hoogle-class {\n", | ||
337 | " font-weight: bold;\n", | ||
338 | "}\n", | ||
339 | "\n", | ||
340 | "/* Styles used for basic displays */\n", | ||
341 | ".get-type {\n", | ||
342 | " color: green;\n", | ||
343 | " font-weight: bold;\n", | ||
344 | " font-family: monospace;\n", | ||
345 | " display: block;\n", | ||
346 | " white-space: pre-wrap;\n", | ||
347 | "}\n", | ||
348 | "\n", | ||
349 | ".show-type {\n", | ||
350 | " color: green;\n", | ||
351 | " font-weight: bold;\n", | ||
352 | " font-family: monospace;\n", | ||
353 | " margin-left: 1em;\n", | ||
354 | "}\n", | ||
355 | "\n", | ||
356 | ".mono {\n", | ||
357 | " font-family: monospace;\n", | ||
358 | " display: block;\n", | ||
359 | "}\n", | ||
360 | "\n", | ||
361 | ".err-msg {\n", | ||
362 | " color: red;\n", | ||
363 | " font-style: italic;\n", | ||
364 | " font-family: monospace;\n", | ||
365 | " white-space: pre;\n", | ||
366 | " display: block;\n", | ||
367 | "}\n", | ||
368 | "\n", | ||
369 | "#unshowable {\n", | ||
370 | " color: red;\n", | ||
371 | " font-weight: bold;\n", | ||
372 | "}\n", | ||
373 | "\n", | ||
374 | ".err-msg.in.collapse {\n", | ||
375 | " padding-top: 0.7em;\n", | ||
376 | "}\n", | ||
377 | "\n", | ||
378 | "/* Code that will get highlighted before it is highlighted */\n", | ||
379 | ".highlight-code {\n", | ||
380 | " white-space: pre;\n", | ||
381 | " font-family: monospace;\n", | ||
382 | "}\n", | ||
383 | "\n", | ||
384 | "/* Hlint styles */\n", | ||
385 | ".suggestion-warning { \n", | ||
386 | " font-weight: bold;\n", | ||
387 | " color: rgb(200, 130, 0);\n", | ||
388 | "}\n", | ||
389 | ".suggestion-error { \n", | ||
390 | " font-weight: bold;\n", | ||
391 | " color: red;\n", | ||
392 | "}\n", | ||
393 | ".suggestion-name {\n", | ||
394 | " font-weight: bold;\n", | ||
395 | "}\n", | ||
396 | "</style><p>$$\\begin{bmatrix}\n", | ||
397 | "1 & 2 & 3 & 4\n", | ||
398 | "\\\\\n", | ||
399 | "5 & 6 & 7 & 8\n", | ||
400 | "\\\\\n", | ||
401 | "9 & 10 & 11 & 12\n", | ||
402 | "\\end{bmatrix}$$</p>" | ||
403 | ] | ||
404 | }, | ||
405 | "metadata": {}, | ||
406 | "output_type": "display_data" | ||
407 | } | ||
408 | ], | ||
409 | "source": [ | ||
410 | "m" | ||
411 | ] | ||
412 | }, | ||
413 | { | ||
414 | "cell_type": "code", | ||
415 | "execution_count": 9, | ||
416 | "metadata": { | ||
417 | "collapsed": false | ||
418 | }, | ||
419 | "outputs": [ | ||
420 | { | ||
421 | "data": { | ||
422 | "text/html": [ | ||
423 | "<style>/*\n", | ||
424 | "Custom IHaskell CSS.\n", | ||
425 | "*/\n", | ||
426 | "\n", | ||
427 | "/* Styles used for the Hoogle display in the pager */\n", | ||
428 | ".hoogle-doc {\n", | ||
429 | " display: block;\n", | ||
430 | " padding-bottom: 1.3em;\n", | ||
431 | " padding-left: 0.4em;\n", | ||
432 | "}\n", | ||
433 | ".hoogle-code {\n", | ||
434 | " display: block;\n", | ||
435 | " font-family: monospace;\n", | ||
436 | " white-space: pre;\n", | ||
437 | "}\n", | ||
438 | ".hoogle-text {\n", | ||
439 | " display: block;\n", | ||
440 | "}\n", | ||
441 | ".hoogle-name {\n", | ||
442 | " color: green;\n", | ||
443 | " font-weight: bold;\n", | ||
444 | "}\n", | ||
445 | ".hoogle-head {\n", | ||
446 | " font-weight: bold;\n", | ||
447 | "}\n", | ||
448 | ".hoogle-sub {\n", | ||
449 | " display: block;\n", | ||
450 | " margin-left: 0.4em;\n", | ||
451 | "}\n", | ||
452 | ".hoogle-package {\n", | ||
453 | " font-weight: bold;\n", | ||
454 | " font-style: italic;\n", | ||
455 | "}\n", | ||
456 | ".hoogle-module {\n", | ||
457 | " font-weight: bold;\n", | ||
458 | "}\n", | ||
459 | ".hoogle-class {\n", | ||
460 | " font-weight: bold;\n", | ||
461 | "}\n", | ||
462 | "\n", | ||
463 | "/* Styles used for basic displays */\n", | ||
464 | ".get-type {\n", | ||
465 | " color: green;\n", | ||
466 | " font-weight: bold;\n", | ||
467 | " font-family: monospace;\n", | ||
468 | " display: block;\n", | ||
469 | " white-space: pre-wrap;\n", | ||
470 | "}\n", | ||
471 | "\n", | ||
472 | ".show-type {\n", | ||
473 | " color: green;\n", | ||
474 | " font-weight: bold;\n", | ||
475 | " font-family: monospace;\n", | ||
476 | " margin-left: 1em;\n", | ||
477 | "}\n", | ||
478 | "\n", | ||
479 | ".mono {\n", | ||
480 | " font-family: monospace;\n", | ||
481 | " display: block;\n", | ||
482 | "}\n", | ||
483 | "\n", | ||
484 | ".err-msg {\n", | ||
485 | " color: red;\n", | ||
486 | " font-style: italic;\n", | ||
487 | " font-family: monospace;\n", | ||
488 | " white-space: pre;\n", | ||
489 | " display: block;\n", | ||
490 | "}\n", | ||
491 | "\n", | ||
492 | "#unshowable {\n", | ||
493 | " color: red;\n", | ||
494 | " font-weight: bold;\n", | ||
495 | "}\n", | ||
496 | "\n", | ||
497 | ".err-msg.in.collapse {\n", | ||
498 | " padding-top: 0.7em;\n", | ||
499 | "}\n", | ||
500 | "\n", | ||
501 | "/* Code that will get highlighted before it is highlighted */\n", | ||
502 | ".highlight-code {\n", | ||
503 | " white-space: pre;\n", | ||
504 | " font-family: monospace;\n", | ||
505 | "}\n", | ||
506 | "\n", | ||
507 | "/* Hlint styles */\n", | ||
508 | ".suggestion-warning { \n", | ||
509 | " font-weight: bold;\n", | ||
510 | " color: rgb(200, 130, 0);\n", | ||
511 | "}\n", | ||
512 | ".suggestion-error { \n", | ||
513 | " font-weight: bold;\n", | ||
514 | " color: red;\n", | ||
515 | "}\n", | ||
516 | ".suggestion-name {\n", | ||
517 | " font-weight: bold;\n", | ||
518 | "}\n", | ||
519 | "</style><p>$$\\begin{bmatrix}\n", | ||
520 | "3 & 3 & 3 & 4\n", | ||
521 | "\\\\\n", | ||
522 | "5 & 6 & 7 & 8\n", | ||
523 | "\\\\\n", | ||
524 | "8 & 8 & 8 & 8\n", | ||
525 | "\\end{bmatrix}$$</p>" | ||
526 | ] | ||
527 | }, | ||
528 | "metadata": {}, | ||
529 | "output_type": "display_data" | ||
530 | } | ||
531 | ], | ||
532 | "source": [ | ||
533 | "clip 3 8 m" | ||
534 | ] | ||
535 | }, | ||
536 | { | ||
537 | "cell_type": "code", | ||
538 | "execution_count": 10, | ||
539 | "metadata": { | ||
540 | "collapsed": false | ||
541 | }, | ||
542 | "outputs": [ | ||
543 | { | ||
544 | "data": { | ||
545 | "text/plain": [ | ||
546 | "[(1,2),(1,3),(2,0),(2,1),(2,2),(2,3)]" | ||
547 | ] | ||
548 | }, | ||
549 | "metadata": {}, | ||
550 | "output_type": "display_data" | ||
551 | } | ||
552 | ], | ||
553 | "source": [ | ||
554 | "find (>6) m" | ||
555 | ] | ||
556 | }, | ||
557 | { | ||
558 | "cell_type": "code", | ||
559 | "execution_count": 11, | ||
560 | "metadata": { | ||
561 | "collapsed": false | ||
562 | }, | ||
563 | "outputs": [ | ||
564 | { | ||
565 | "data": { | ||
566 | "text/html": [ | ||
567 | "<style>/*\n", | ||
568 | "Custom IHaskell CSS.\n", | ||
569 | "*/\n", | ||
570 | "\n", | ||
571 | "/* Styles used for the Hoogle display in the pager */\n", | ||
572 | ".hoogle-doc {\n", | ||
573 | " display: block;\n", | ||
574 | " padding-bottom: 1.3em;\n", | ||
575 | " padding-left: 0.4em;\n", | ||
576 | "}\n", | ||
577 | ".hoogle-code {\n", | ||
578 | " display: block;\n", | ||
579 | " font-family: monospace;\n", | ||
580 | " white-space: pre;\n", | ||
581 | "}\n", | ||
582 | ".hoogle-text {\n", | ||
583 | " display: block;\n", | ||
584 | "}\n", | ||
585 | ".hoogle-name {\n", | ||
586 | " color: green;\n", | ||
587 | " font-weight: bold;\n", | ||
588 | "}\n", | ||
589 | ".hoogle-head {\n", | ||
590 | " font-weight: bold;\n", | ||
591 | "}\n", | ||
592 | ".hoogle-sub {\n", | ||
593 | " display: block;\n", | ||
594 | " margin-left: 0.4em;\n", | ||
595 | "}\n", | ||
596 | ".hoogle-package {\n", | ||
597 | " font-weight: bold;\n", | ||
598 | " font-style: italic;\n", | ||
599 | "}\n", | ||
600 | ".hoogle-module {\n", | ||
601 | " font-weight: bold;\n", | ||
602 | "}\n", | ||
603 | ".hoogle-class {\n", | ||
604 | " font-weight: bold;\n", | ||
605 | "}\n", | ||
606 | "\n", | ||
607 | "/* Styles used for basic displays */\n", | ||
608 | ".get-type {\n", | ||
609 | " color: green;\n", | ||
610 | " font-weight: bold;\n", | ||
611 | " font-family: monospace;\n", | ||
612 | " display: block;\n", | ||
613 | " white-space: pre-wrap;\n", | ||
614 | "}\n", | ||
615 | "\n", | ||
616 | ".show-type {\n", | ||
617 | " color: green;\n", | ||
618 | " font-weight: bold;\n", | ||
619 | " font-family: monospace;\n", | ||
620 | " margin-left: 1em;\n", | ||
621 | "}\n", | ||
622 | "\n", | ||
623 | ".mono {\n", | ||
624 | " font-family: monospace;\n", | ||
625 | " display: block;\n", | ||
626 | "}\n", | ||
627 | "\n", | ||
628 | ".err-msg {\n", | ||
629 | " color: red;\n", | ||
630 | " font-style: italic;\n", | ||
631 | " font-family: monospace;\n", | ||
632 | " white-space: pre;\n", | ||
633 | " display: block;\n", | ||
634 | "}\n", | ||
635 | "\n", | ||
636 | "#unshowable {\n", | ||
637 | " color: red;\n", | ||
638 | " font-weight: bold;\n", | ||
639 | "}\n", | ||
640 | "\n", | ||
641 | ".err-msg.in.collapse {\n", | ||
642 | " padding-top: 0.7em;\n", | ||
643 | "}\n", | ||
644 | "\n", | ||
645 | "/* Code that will get highlighted before it is highlighted */\n", | ||
646 | ".highlight-code {\n", | ||
647 | " white-space: pre;\n", | ||
648 | " font-family: monospace;\n", | ||
649 | "}\n", | ||
650 | "\n", | ||
651 | "/* Hlint styles */\n", | ||
652 | ".suggestion-warning { \n", | ||
653 | " font-weight: bold;\n", | ||
654 | " color: rgb(200, 130, 0);\n", | ||
655 | "}\n", | ||
656 | ".suggestion-error { \n", | ||
657 | " font-weight: bold;\n", | ||
658 | " color: red;\n", | ||
659 | "}\n", | ||
660 | ".suggestion-name {\n", | ||
661 | " font-weight: bold;\n", | ||
662 | "}\n", | ||
663 | "</style><p>$$\\begin{bmatrix}\n", | ||
664 | "\\top & \\top & \\top & \\cdot\n", | ||
665 | "\\\\\n", | ||
666 | "\\cdot & \\cdot & \\cdot & \\cdot\n", | ||
667 | "\\\\\n", | ||
668 | "\\cdot & \\top & \\top & \\top\n", | ||
669 | "\\end{bmatrix}$$</p>" | ||
670 | ] | ||
671 | }, | ||
672 | "metadata": {}, | ||
673 | "output_type": "display_data" | ||
674 | } | ||
675 | ], | ||
676 | "source": [ | ||
677 | "(m .>=. 10) .||. (m .<. 4)" | ||
678 | ] | ||
679 | }, | ||
680 | { | ||
681 | "cell_type": "code", | ||
682 | "execution_count": 12, | ||
683 | "metadata": { | ||
684 | "collapsed": false | ||
685 | }, | ||
686 | "outputs": [ | ||
687 | { | ||
688 | "data": { | ||
689 | "text/html": [ | ||
690 | "<style>/*\n", | ||
691 | "Custom IHaskell CSS.\n", | ||
692 | "*/\n", | ||
693 | "\n", | ||
694 | "/* Styles used for the Hoogle display in the pager */\n", | ||
695 | ".hoogle-doc {\n", | ||
696 | " display: block;\n", | ||
697 | " padding-bottom: 1.3em;\n", | ||
698 | " padding-left: 0.4em;\n", | ||
699 | "}\n", | ||
700 | ".hoogle-code {\n", | ||
701 | " display: block;\n", | ||
702 | " font-family: monospace;\n", | ||
703 | " white-space: pre;\n", | ||
704 | "}\n", | ||
705 | ".hoogle-text {\n", | ||
706 | " display: block;\n", | ||
707 | "}\n", | ||
708 | ".hoogle-name {\n", | ||
709 | " color: green;\n", | ||
710 | " font-weight: bold;\n", | ||
711 | "}\n", | ||
712 | ".hoogle-head {\n", | ||
713 | " font-weight: bold;\n", | ||
714 | "}\n", | ||
715 | ".hoogle-sub {\n", | ||
716 | " display: block;\n", | ||
717 | " margin-left: 0.4em;\n", | ||
718 | "}\n", | ||
719 | ".hoogle-package {\n", | ||
720 | " font-weight: bold;\n", | ||
721 | " font-style: italic;\n", | ||
722 | "}\n", | ||
723 | ".hoogle-module {\n", | ||
724 | " font-weight: bold;\n", | ||
725 | "}\n", | ||
726 | ".hoogle-class {\n", | ||
727 | " font-weight: bold;\n", | ||
728 | "}\n", | ||
729 | "\n", | ||
730 | "/* Styles used for basic displays */\n", | ||
731 | ".get-type {\n", | ||
732 | " color: green;\n", | ||
733 | " font-weight: bold;\n", | ||
734 | " font-family: monospace;\n", | ||
735 | " display: block;\n", | ||
736 | " white-space: pre-wrap;\n", | ||
737 | "}\n", | ||
738 | "\n", | ||
739 | ".show-type {\n", | ||
740 | " color: green;\n", | ||
741 | " font-weight: bold;\n", | ||
742 | " font-family: monospace;\n", | ||
743 | " margin-left: 1em;\n", | ||
744 | "}\n", | ||
745 | "\n", | ||
746 | ".mono {\n", | ||
747 | " font-family: monospace;\n", | ||
748 | " display: block;\n", | ||
749 | "}\n", | ||
750 | "\n", | ||
751 | ".err-msg {\n", | ||
752 | " color: red;\n", | ||
753 | " font-style: italic;\n", | ||
754 | " font-family: monospace;\n", | ||
755 | " white-space: pre;\n", | ||
756 | " display: block;\n", | ||
757 | "}\n", | ||
758 | "\n", | ||
759 | "#unshowable {\n", | ||
760 | " color: red;\n", | ||
761 | " font-weight: bold;\n", | ||
762 | "}\n", | ||
763 | "\n", | ||
764 | ".err-msg.in.collapse {\n", | ||
765 | " padding-top: 0.7em;\n", | ||
766 | "}\n", | ||
767 | "\n", | ||
768 | "/* Code that will get highlighted before it is highlighted */\n", | ||
769 | ".highlight-code {\n", | ||
770 | " white-space: pre;\n", | ||
771 | " font-family: monospace;\n", | ||
772 | "}\n", | ||
773 | "\n", | ||
774 | "/* Hlint styles */\n", | ||
775 | ".suggestion-warning { \n", | ||
776 | " font-weight: bold;\n", | ||
777 | " color: rgb(200, 130, 0);\n", | ||
778 | "}\n", | ||
779 | ".suggestion-error { \n", | ||
780 | " font-weight: bold;\n", | ||
781 | " color: red;\n", | ||
782 | "}\n", | ||
783 | ".suggestion-name {\n", | ||
784 | " font-weight: bold;\n", | ||
785 | "}\n", | ||
786 | "</style><p>$$\\begin{bmatrix}\n", | ||
787 | "50 & 2 & 3 & 4\n", | ||
788 | "\\\\\n", | ||
789 | "15 & 50 & 7 & 8\n", | ||
790 | "\\\\\n", | ||
791 | "27 & 30 & 50 & 12\n", | ||
792 | "\\end{bmatrix}$$</p>" | ||
793 | ] | ||
794 | }, | ||
795 | "metadata": {}, | ||
796 | "output_type": "display_data" | ||
797 | } | ||
798 | ], | ||
799 | "source": [ | ||
800 | "cond (col [1..3]) (row [1..4]) m 50 (3*m)" | ||
801 | ] | ||
802 | }, | ||
803 | { | ||
804 | "cell_type": "code", | ||
805 | "execution_count": 13, | ||
806 | "metadata": { | ||
807 | "collapsed": false | ||
808 | }, | ||
809 | "outputs": [ | ||
810 | { | ||
811 | "data": { | ||
812 | "text/plain": [ | ||
813 | "(6><8)\n", | ||
814 | " [ 10, 7, 7, 7, 7, 7, 7, 7\n", | ||
815 | " , 7, 11, 7, 7, 7, 7, 7, 7\n", | ||
816 | " , 7, 7, 12, 7, 7, 7, 7, 7\n", | ||
817 | " , 7, 7, 7, 13, 7, 7, 7, 7\n", | ||
818 | " , 7, 7, 7, 7, 14, 7, 7, 7\n", | ||
819 | " , 7, 7, 7, 7, 7, 7, 7, 7 ]" | ||
820 | ] | ||
821 | }, | ||
822 | "metadata": {}, | ||
823 | "output_type": "display_data" | ||
824 | } | ||
825 | ], | ||
826 | "source": [ | ||
827 | "assoc (6,8) 7 $ zip (find (/=0) (eye 5)) [10..] :: Matrix Z" | ||
828 | ] | ||
829 | }, | ||
830 | { | ||
831 | "cell_type": "code", | ||
832 | "execution_count": 14, | ||
833 | "metadata": { | ||
834 | "collapsed": false | ||
835 | }, | ||
836 | "outputs": [ | ||
837 | { | ||
838 | "data": { | ||
839 | "text/html": [ | ||
840 | "<style>/*\n", | ||
841 | "Custom IHaskell CSS.\n", | ||
842 | "*/\n", | ||
843 | "\n", | ||
844 | "/* Styles used for the Hoogle display in the pager */\n", | ||
845 | ".hoogle-doc {\n", | ||
846 | " display: block;\n", | ||
847 | " padding-bottom: 1.3em;\n", | ||
848 | " padding-left: 0.4em;\n", | ||
849 | "}\n", | ||
850 | ".hoogle-code {\n", | ||
851 | " display: block;\n", | ||
852 | " font-family: monospace;\n", | ||
853 | " white-space: pre;\n", | ||
854 | "}\n", | ||
855 | ".hoogle-text {\n", | ||
856 | " display: block;\n", | ||
857 | "}\n", | ||
858 | ".hoogle-name {\n", | ||
859 | " color: green;\n", | ||
860 | " font-weight: bold;\n", | ||
861 | "}\n", | ||
862 | ".hoogle-head {\n", | ||
863 | " font-weight: bold;\n", | ||
864 | "}\n", | ||
865 | ".hoogle-sub {\n", | ||
866 | " display: block;\n", | ||
867 | " margin-left: 0.4em;\n", | ||
868 | "}\n", | ||
869 | ".hoogle-package {\n", | ||
870 | " font-weight: bold;\n", | ||
871 | " font-style: italic;\n", | ||
872 | "}\n", | ||
873 | ".hoogle-module {\n", | ||
874 | " font-weight: bold;\n", | ||
875 | "}\n", | ||
876 | ".hoogle-class {\n", | ||
877 | " font-weight: bold;\n", | ||
878 | "}\n", | ||
879 | "\n", | ||
880 | "/* Styles used for basic displays */\n", | ||
881 | ".get-type {\n", | ||
882 | " color: green;\n", | ||
883 | " font-weight: bold;\n", | ||
884 | " font-family: monospace;\n", | ||
885 | " display: block;\n", | ||
886 | " white-space: pre-wrap;\n", | ||
887 | "}\n", | ||
888 | "\n", | ||
889 | ".show-type {\n", | ||
890 | " color: green;\n", | ||
891 | " font-weight: bold;\n", | ||
892 | " font-family: monospace;\n", | ||
893 | " margin-left: 1em;\n", | ||
894 | "}\n", | ||
895 | "\n", | ||
896 | ".mono {\n", | ||
897 | " font-family: monospace;\n", | ||
898 | " display: block;\n", | ||
899 | "}\n", | ||
900 | "\n", | ||
901 | ".err-msg {\n", | ||
902 | " color: red;\n", | ||
903 | " font-style: italic;\n", | ||
904 | " font-family: monospace;\n", | ||
905 | " white-space: pre;\n", | ||
906 | " display: block;\n", | ||
907 | "}\n", | ||
908 | "\n", | ||
909 | "#unshowable {\n", | ||
910 | " color: red;\n", | ||
911 | " font-weight: bold;\n", | ||
912 | "}\n", | ||
913 | "\n", | ||
914 | ".err-msg.in.collapse {\n", | ||
915 | " padding-top: 0.7em;\n", | ||
916 | "}\n", | ||
917 | "\n", | ||
918 | "/* Code that will get highlighted before it is highlighted */\n", | ||
919 | ".highlight-code {\n", | ||
920 | " white-space: pre;\n", | ||
921 | " font-family: monospace;\n", | ||
922 | "}\n", | ||
923 | "\n", | ||
924 | "/* Hlint styles */\n", | ||
925 | ".suggestion-warning { \n", | ||
926 | " font-weight: bold;\n", | ||
927 | " color: rgb(200, 130, 0);\n", | ||
928 | "}\n", | ||
929 | ".suggestion-error { \n", | ||
930 | " font-weight: bold;\n", | ||
931 | " color: red;\n", | ||
932 | "}\n", | ||
933 | ".suggestion-name {\n", | ||
934 | " font-weight: bold;\n", | ||
935 | "}\n", | ||
936 | "</style><p>$$\\begin{bmatrix}\n", | ||
937 | "1 & 0 & 3 & 0 & 0\n", | ||
938 | "\\\\\n", | ||
939 | "0 & 2 & 0 & 0 & 0\n", | ||
940 | "\\\\\n", | ||
941 | "0 & 0 & 1 & 0 & 0\n", | ||
942 | "\\\\\n", | ||
943 | "0 & 7 & 0 & 1 & 0\n", | ||
944 | "\\\\\n", | ||
945 | "0 & 0 & 0 & 0 & 1\n", | ||
946 | "\\end{bmatrix}$$</p>" | ||
947 | ] | ||
948 | }, | ||
949 | "metadata": {}, | ||
950 | "output_type": "display_data" | ||
951 | } | ||
952 | ], | ||
953 | "source": [ | ||
954 | "accum (eye 5) (+) [((0,2),3), ((3,1),7), ((1,1),1)]" | ||
955 | ] | ||
956 | }, | ||
957 | { | ||
958 | "cell_type": "code", | ||
959 | "execution_count": 15, | ||
960 | "metadata": { | ||
961 | "collapsed": true | ||
962 | }, | ||
963 | "outputs": [], | ||
964 | "source": [ | ||
965 | "p = fromList [0,0,1,1] :: Vector I\n", | ||
966 | "q = fromList [0,1,0,1] :: Vector I" | ||
967 | ] | ||
968 | }, | ||
969 | { | ||
970 | "cell_type": "code", | ||
971 | "execution_count": 16, | ||
972 | "metadata": { | ||
973 | "collapsed": false | ||
974 | }, | ||
975 | "outputs": [ | ||
976 | { | ||
977 | "data": { | ||
978 | "text/html": [ | ||
979 | "<style>/*\n", | ||
980 | "Custom IHaskell CSS.\n", | ||
981 | "*/\n", | ||
982 | "\n", | ||
983 | "/* Styles used for the Hoogle display in the pager */\n", | ||
984 | ".hoogle-doc {\n", | ||
985 | " display: block;\n", | ||
986 | " padding-bottom: 1.3em;\n", | ||
987 | " padding-left: 0.4em;\n", | ||
988 | "}\n", | ||
989 | ".hoogle-code {\n", | ||
990 | " display: block;\n", | ||
991 | " font-family: monospace;\n", | ||
992 | " white-space: pre;\n", | ||
993 | "}\n", | ||
994 | ".hoogle-text {\n", | ||
995 | " display: block;\n", | ||
996 | "}\n", | ||
997 | ".hoogle-name {\n", | ||
998 | " color: green;\n", | ||
999 | " font-weight: bold;\n", | ||
1000 | "}\n", | ||
1001 | ".hoogle-head {\n", | ||
1002 | " font-weight: bold;\n", | ||
1003 | "}\n", | ||
1004 | ".hoogle-sub {\n", | ||
1005 | " display: block;\n", | ||
1006 | " margin-left: 0.4em;\n", | ||
1007 | "}\n", | ||
1008 | ".hoogle-package {\n", | ||
1009 | " font-weight: bold;\n", | ||
1010 | " font-style: italic;\n", | ||
1011 | "}\n", | ||
1012 | ".hoogle-module {\n", | ||
1013 | " font-weight: bold;\n", | ||
1014 | "}\n", | ||
1015 | ".hoogle-class {\n", | ||
1016 | " font-weight: bold;\n", | ||
1017 | "}\n", | ||
1018 | "\n", | ||
1019 | "/* Styles used for basic displays */\n", | ||
1020 | ".get-type {\n", | ||
1021 | " color: green;\n", | ||
1022 | " font-weight: bold;\n", | ||
1023 | " font-family: monospace;\n", | ||
1024 | " display: block;\n", | ||
1025 | " white-space: pre-wrap;\n", | ||
1026 | "}\n", | ||
1027 | "\n", | ||
1028 | ".show-type {\n", | ||
1029 | " color: green;\n", | ||
1030 | " font-weight: bold;\n", | ||
1031 | " font-family: monospace;\n", | ||
1032 | " margin-left: 1em;\n", | ||
1033 | "}\n", | ||
1034 | "\n", | ||
1035 | ".mono {\n", | ||
1036 | " font-family: monospace;\n", | ||
1037 | " display: block;\n", | ||
1038 | "}\n", | ||
1039 | "\n", | ||
1040 | ".err-msg {\n", | ||
1041 | " color: red;\n", | ||
1042 | " font-style: italic;\n", | ||
1043 | " font-family: monospace;\n", | ||
1044 | " white-space: pre;\n", | ||
1045 | " display: block;\n", | ||
1046 | "}\n", | ||
1047 | "\n", | ||
1048 | "#unshowable {\n", | ||
1049 | " color: red;\n", | ||
1050 | " font-weight: bold;\n", | ||
1051 | "}\n", | ||
1052 | "\n", | ||
1053 | ".err-msg.in.collapse {\n", | ||
1054 | " padding-top: 0.7em;\n", | ||
1055 | "}\n", | ||
1056 | "\n", | ||
1057 | "/* Code that will get highlighted before it is highlighted */\n", | ||
1058 | ".highlight-code {\n", | ||
1059 | " white-space: pre;\n", | ||
1060 | " font-family: monospace;\n", | ||
1061 | "}\n", | ||
1062 | "\n", | ||
1063 | "/* Hlint styles */\n", | ||
1064 | ".suggestion-warning { \n", | ||
1065 | " font-weight: bold;\n", | ||
1066 | " color: rgb(200, 130, 0);\n", | ||
1067 | "}\n", | ||
1068 | ".suggestion-error { \n", | ||
1069 | " font-weight: bold;\n", | ||
1070 | " color: red;\n", | ||
1071 | "}\n", | ||
1072 | ".suggestion-name {\n", | ||
1073 | " font-weight: bold;\n", | ||
1074 | "}\n", | ||
1075 | "</style><p>$$\\begin{bmatrix}\n", | ||
1076 | "\\cdot & \\cdot & \\cdot & \\cdot & \\cdot & \\top & \\top\n", | ||
1077 | "\\\\\n", | ||
1078 | "\\cdot & \\top & \\cdot & \\top & \\top & \\cdot & \\top\n", | ||
1079 | "\\\\\n", | ||
1080 | "\\top & \\cdot & \\cdot & \\top & \\top & \\cdot & \\cdot\n", | ||
1081 | "\\\\\n", | ||
1082 | "\\top & \\top & \\top & \\top & \\cdot & \\top & \\top\n", | ||
1083 | "\\end{bmatrix}$$</p>" | ||
1084 | ] | ||
1085 | }, | ||
1086 | "metadata": {}, | ||
1087 | "output_type": "display_data" | ||
1088 | } | ||
1089 | ], | ||
1090 | "source": [ | ||
1091 | "fromColumns [p, q, p.&&.q, p .||.q, p `xor` q, p `equiv` q, p `imp` q]" | ||
1092 | ] | ||
1093 | }, | ||
1094 | { | ||
1095 | "cell_type": "code", | ||
1096 | "execution_count": 17, | ||
1097 | "metadata": { | ||
1098 | "collapsed": false | ||
1099 | }, | ||
1100 | "outputs": [ | ||
1101 | { | ||
1102 | "data": { | ||
1103 | "text/plain": [ | ||
1104 | "True" | ||
1105 | ] | ||
1106 | }, | ||
1107 | "metadata": {}, | ||
1108 | "output_type": "display_data" | ||
1109 | } | ||
1110 | ], | ||
1111 | "source": [ | ||
1112 | "taut $ (p `imp` q ) `equiv` (no q `imp` no p)" | ||
1113 | ] | ||
1114 | }, | ||
1115 | { | ||
1116 | "cell_type": "code", | ||
1117 | "execution_count": 18, | ||
1118 | "metadata": { | ||
1119 | "collapsed": false | ||
1120 | }, | ||
1121 | "outputs": [ | ||
1122 | { | ||
1123 | "data": { | ||
1124 | "text/plain": [ | ||
1125 | "False" | ||
1126 | ] | ||
1127 | }, | ||
1128 | "metadata": {}, | ||
1129 | "output_type": "display_data" | ||
1130 | } | ||
1131 | ], | ||
1132 | "source": [ | ||
1133 | "taut $ xor p q `equiv` (p .&&. no q .||. no p .&&. q)" | ||
1134 | ] | ||
1135 | } | ||
1136 | ], | ||
1137 | "metadata": { | ||
1138 | "kernelspec": { | ||
1139 | "display_name": "Haskell", | ||
1140 | "language": "haskell", | ||
1141 | "name": "haskell" | ||
1142 | }, | ||
1143 | "language_info": { | ||
1144 | "codemirror_mode": "ihaskell", | ||
1145 | "file_extension": ".hs", | ||
1146 | "name": "haskell", | ||
1147 | "version": "7.10.1" | ||
1148 | } | ||
1149 | }, | ||
1150 | "nbformat": 4, | ||
1151 | "nbformat_minor": 0 | ||
1152 | } | ||
diff --git a/examples/devel/ej1/functions.c b/examples/devel/ej1/functions.c deleted file mode 100644 index 02a4cdd..0000000 --- a/examples/devel/ej1/functions.c +++ /dev/null | |||
@@ -1,35 +0,0 @@ | |||
1 | /* assuming row order */ | ||
2 | |||
3 | typedef struct { double r, i; } doublecomplex; | ||
4 | |||
5 | #define DVEC(A) int A##n, double*A##p | ||
6 | #define CVEC(A) int A##n, doublecomplex*A##p | ||
7 | #define DMAT(A) int A##r, int A##c, double*A##p | ||
8 | #define CMAT(A) int A##r, int A##c, doublecomplex*A##p | ||
9 | |||
10 | #define AT(M,row,col) (M##p[(row)*M##c + (col)]) | ||
11 | |||
12 | /*-----------------------------------------------------*/ | ||
13 | |||
14 | int c_scale_vector(double s, DVEC(x), DVEC(y)) { | ||
15 | int k; | ||
16 | for (k=0; k<=yn; k++) { | ||
17 | yp[k] = s*xp[k]; | ||
18 | } | ||
19 | return 0; | ||
20 | } | ||
21 | |||
22 | /*-----------------------------------------------------*/ | ||
23 | |||
24 | int c_diag(DMAT(m),DVEC(y),DMAT(z)) { | ||
25 | int i,j; | ||
26 | for (j=0; j<yn; j++) { | ||
27 | yp[j] = AT(m,j,j); | ||
28 | } | ||
29 | for (i=0; i<mr; i++) { | ||
30 | for (j=0; j<mc; j++) { | ||
31 | AT(z,i,j) = i==j?yp[i]:0; | ||
32 | } | ||
33 | } | ||
34 | return 0; | ||
35 | } | ||
diff --git a/examples/devel/ej1/wrappers.hs b/examples/devel/ej1/wrappers.hs deleted file mode 100644 index a88f74b..0000000 --- a/examples/devel/ej1/wrappers.hs +++ /dev/null | |||
@@ -1,44 +0,0 @@ | |||
1 | {-# LANGUAGE ForeignFunctionInterface #-} | ||
2 | |||
3 | -- $ ghc -O2 --make wrappers.hs functions.c | ||
4 | |||
5 | import Numeric.LinearAlgebra | ||
6 | import Data.Packed.Development | ||
7 | import Foreign(Ptr,unsafePerformIO) | ||
8 | import Foreign.C.Types(CInt) | ||
9 | |||
10 | ----------------------------------------------------- | ||
11 | |||
12 | main = do | ||
13 | print $ myScale 3.0 (fromList [1..10]) | ||
14 | print $ myDiag $ (3><5) [1..] | ||
15 | |||
16 | ----------------------------------------------------- | ||
17 | |||
18 | foreign import ccall unsafe "c_scale_vector" | ||
19 | cScaleVector :: Double -- scale | ||
20 | -> CInt -> Ptr Double -- argument | ||
21 | -> CInt -> Ptr Double -- result | ||
22 | -> IO CInt -- exit code | ||
23 | |||
24 | myScale s x = unsafePerformIO $ do | ||
25 | y <- createVector (dim x) | ||
26 | app2 (cScaleVector s) vec x vec y "cScaleVector" | ||
27 | return y | ||
28 | |||
29 | ----------------------------------------------------- | ||
30 | -- forcing row order | ||
31 | |||
32 | foreign import ccall unsafe "c_diag" | ||
33 | cDiag :: CInt -> CInt -> Ptr Double -- argument | ||
34 | -> CInt -> Ptr Double -- result1 | ||
35 | -> CInt -> CInt -> Ptr Double -- result2 | ||
36 | -> IO CInt -- exit code | ||
37 | |||
38 | myDiag m = unsafePerformIO $ do | ||
39 | y <- createVector (min r c) | ||
40 | z <- createMatrix RowMajor r c | ||
41 | app3 cDiag mat (cmat m) vec y mat z "cDiag" | ||
42 | return (y,z) | ||
43 | where r = rows m | ||
44 | c = cols m | ||
diff --git a/examples/devel/ej2/functions.c b/examples/devel/ej2/functions.c deleted file mode 100644 index 4dcd377..0000000 --- a/examples/devel/ej2/functions.c +++ /dev/null | |||
@@ -1,24 +0,0 @@ | |||
1 | /* general element order */ | ||
2 | |||
3 | typedef struct { double r, i; } doublecomplex; | ||
4 | |||
5 | #define DVEC(A) int A##n, double*A##p | ||
6 | #define CVEC(A) int A##n, doublecomplex*A##p | ||
7 | #define DMAT(A) int A##r, int A##c, double*A##p | ||
8 | #define CMAT(A) int A##r, int A##c, doublecomplex*A##p | ||
9 | |||
10 | #define AT(M,r,c) (M##p[(r)*sr+(c)*sc]) | ||
11 | |||
12 | int c_diag(int ro, DMAT(m),DVEC(y),DMAT(z)) { | ||
13 | int i,j,sr,sc; | ||
14 | if (ro==1) { sr = mc; sc = 1;} else { sr = 1; sc = mr;} | ||
15 | for (j=0; j<yn; j++) { | ||
16 | yp[j] = AT(m,j,j); | ||
17 | } | ||
18 | for (i=0; i<mr; i++) { | ||
19 | for (j=0; j<mc; j++) { | ||
20 | AT(z,i,j) = i==j?yp[i]:0; | ||
21 | } | ||
22 | } | ||
23 | return 0; | ||
24 | } | ||
diff --git a/examples/devel/ej2/wrappers.hs b/examples/devel/ej2/wrappers.hs deleted file mode 100644 index 1c02a24..0000000 --- a/examples/devel/ej2/wrappers.hs +++ /dev/null | |||
@@ -1,32 +0,0 @@ | |||
1 | {-# LANGUAGE ForeignFunctionInterface #-} | ||
2 | |||
3 | -- $ ghc -O2 --make wrappers.hs functions.c | ||
4 | |||
5 | import Numeric.LinearAlgebra | ||
6 | import Data.Packed.Development | ||
7 | import Foreign(Ptr,unsafePerformIO) | ||
8 | import Foreign.C.Types(CInt) | ||
9 | |||
10 | ----------------------------------------------------- | ||
11 | |||
12 | main = do | ||
13 | print $ myDiag $ (3><5) [1..] | ||
14 | |||
15 | ----------------------------------------------------- | ||
16 | -- arbitrary data order | ||
17 | |||
18 | foreign import ccall unsafe "c_diag" | ||
19 | cDiag :: CInt -- matrix order | ||
20 | -> CInt -> CInt -> Ptr Double -- argument | ||
21 | -> CInt -> Ptr Double -- result1 | ||
22 | -> CInt -> CInt -> Ptr Double -- result2 | ||
23 | -> IO CInt -- exit code | ||
24 | |||
25 | myDiag m = unsafePerformIO $ do | ||
26 | y <- createVector (min r c) | ||
27 | z <- createMatrix (orderOf m) r c | ||
28 | app3 (cDiag o) mat m vec y mat z "cDiag" | ||
29 | return (y,z) | ||
30 | where r = rows m | ||
31 | c = cols m | ||
32 | o = if orderOf m == RowMajor then 1 else 0 | ||
diff --git a/examples/devel/example/functions.c b/examples/devel/example/functions.c new file mode 100644 index 0000000..67d3270 --- /dev/null +++ b/examples/devel/example/functions.c | |||
@@ -0,0 +1,22 @@ | |||
1 | |||
2 | typedef struct { double r, i; } doublecomplex; | ||
3 | |||
4 | #define VEC(T,A) int A##n, T* A##p | ||
5 | #define MAT(T,A) int A##r, int A##c, int A##Xr, int A##Xc, T* A##p | ||
6 | |||
7 | #define AT(m,i,j) (m##p[(i)*m##Xr + (j)*m##Xc]) | ||
8 | #define TRAV(m,i,j) int i,j; for (i=0;i<m##r;i++) for (j=0;j<m##c;j++) | ||
9 | |||
10 | |||
11 | int c_diag(MAT(double,m), VEC(double,y), MAT(double,z)) { | ||
12 | int k; | ||
13 | for (k=0; k<yn; k++) { | ||
14 | yp[k] = AT(m,k,k); | ||
15 | } | ||
16 | { TRAV(z,i,j) { | ||
17 | AT(z,i,j) = i==j?yp[i]:0; | ||
18 | } | ||
19 | } | ||
20 | return 0; | ||
21 | } | ||
22 | |||
diff --git a/examples/devel/example/wrappers.hs b/examples/devel/example/wrappers.hs new file mode 100644 index 0000000..f4e0f0b --- /dev/null +++ b/examples/devel/example/wrappers.hs | |||
@@ -0,0 +1,45 @@ | |||
1 | {-# LANGUAGE ForeignFunctionInterface #-} | ||
2 | {-# LANGUAGE TypeOperators #-} | ||
3 | {-# LANGUAGE GADTs #-} | ||
4 | |||
5 | {- | ||
6 | $ ghc -O2 wrappers.hs functions.c | ||
7 | $ ./wrappers | ||
8 | -} | ||
9 | |||
10 | import Numeric.LinearAlgebra | ||
11 | import Numeric.LinearAlgebra.Devel | ||
12 | import System.IO.Unsafe(unsafePerformIO) | ||
13 | import Foreign.C.Types(CInt(..)) | ||
14 | import Foreign.Ptr(Ptr) | ||
15 | |||
16 | |||
17 | infixl 1 # | ||
18 | a # b = apply a b | ||
19 | {-# INLINE (#) #-} | ||
20 | |||
21 | infixr 5 :>, ::> | ||
22 | type (:>) t r = CInt -> Ptr t -> r | ||
23 | type (::>) t r = CInt -> CInt -> CInt -> CInt -> Ptr t -> r | ||
24 | type Ok = IO CInt | ||
25 | |||
26 | ----------------------------------------------------- | ||
27 | |||
28 | x = (3><5) [1..] | ||
29 | |||
30 | main = do | ||
31 | print x | ||
32 | print $ myDiag x | ||
33 | print $ myDiag (tr x) | ||
34 | |||
35 | ----------------------------------------------------- | ||
36 | foreign import ccall unsafe "c_diag" cDiag :: Double ::> Double :> Double ::> Ok | ||
37 | |||
38 | myDiag m = unsafePerformIO $ do | ||
39 | y <- createVector (min r c) | ||
40 | z <- createMatrix RowMajor r c | ||
41 | cDiag # m # y # z #| "cDiag" | ||
42 | return (y,z) | ||
43 | where | ||
44 | (r,c) = size m | ||
45 | |||
diff --git a/examples/error.hs b/examples/error.hs index 5efae7c..77467df 100644 --- a/examples/error.hs +++ b/examples/error.hs | |||
@@ -8,6 +8,7 @@ test x = catch | |||
8 | (print x) | 8 | (print x) |
9 | (\e -> putStrLn $ "captured ["++ show (e :: SomeException) ++"]") | 9 | (\e -> putStrLn $ "captured ["++ show (e :: SomeException) ++"]") |
10 | 10 | ||
11 | |||
11 | main = do | 12 | main = do |
12 | setErrorHandlerOff | 13 | setErrorHandlerOff |
13 | 14 | ||
@@ -15,7 +16,8 @@ main = do | |||
15 | test $ 5 + (fst.exp_e) 1000 | 16 | test $ 5 + (fst.exp_e) 1000 |
16 | test $ bessel_zero_Jnu_e (-0.3) 2 | 17 | test $ bessel_zero_Jnu_e (-0.3) 2 |
17 | 18 | ||
18 | test $ (linearSolve 0 4 :: Matrix Double) | 19 | test $ (inv 0 :: Matrix Double) |
19 | test $ (linearSolve 5 (sqrt (-1)) :: Matrix Double) | 20 | test $ (linearSolveLS 5 (sqrt (-1)) :: Matrix Double) |
21 | |||
22 | putStrLn "Bye" | ||
20 | 23 | ||
21 | putStrLn "Bye" \ No newline at end of file | ||
diff --git a/examples/inplace.hs b/examples/inplace.hs index 574aa44..19f9bc9 100644 --- a/examples/inplace.hs +++ b/examples/inplace.hs | |||
@@ -1,7 +1,9 @@ | |||
1 | -- some tests of the interface for pure | 1 | -- some tests of the interface for pure |
2 | -- computations with inplace updates | 2 | -- computations with inplace updates |
3 | 3 | ||
4 | import Numeric.LinearAlgebra.HMatrix | 4 | {-# LANGUAGE FlexibleContexts #-} |
5 | |||
6 | import Numeric.LinearAlgebra | ||
5 | import Numeric.LinearAlgebra.Devel | 7 | import Numeric.LinearAlgebra.Devel |
6 | 8 | ||
7 | import Data.Array.Unboxed | 9 | import Data.Array.Unboxed |
diff --git a/examples/kalman.hs b/examples/kalman.hs index 7fbe3d2..9756aa0 100644 --- a/examples/kalman.hs +++ b/examples/kalman.hs | |||
@@ -1,17 +1,15 @@ | |||
1 | import Numeric.LinearAlgebra | 1 | import Numeric.LinearAlgebra |
2 | import Graphics.Plot | 2 | import Graphics.Plot |
3 | 3 | ||
4 | vector l = fromList l :: Vector Double | 4 | f = fromLists |
5 | matrix ls = fromLists ls :: Matrix Double | 5 | [[1,0,0,0], |
6 | diagl = diag . vector | 6 | [1,1,0,0], |
7 | [0,0,1,0], | ||
8 | [0,0,0,1]] | ||
7 | 9 | ||
8 | f = matrix [[1,0,0,0], | 10 | h = fromLists |
9 | [1,1,0,0], | 11 | [[0,-1,1,0], |
10 | [0,0,1,0], | 12 | [0,-1,0,1]] |
11 | [0,0,0,1]] | ||
12 | |||
13 | h = matrix [[0,-1,1,0], | ||
14 | [0,-1,0,1]] | ||
15 | 13 | ||
16 | q = diagl [1,1,0,0] | 14 | q = diagl [1,1,0,0] |
17 | 15 | ||
@@ -25,13 +23,13 @@ type Measurement = Vector Double | |||
25 | 23 | ||
26 | kalman :: System -> State -> Measurement -> State | 24 | kalman :: System -> State -> Measurement -> State |
27 | kalman (System f h q r) (State x p) z = State x' p' where | 25 | kalman (System f h q r) (State x p) z = State x' p' where |
28 | px = f <> x -- prediction | 26 | px = f #> x -- prediction |
29 | pq = f <> p <> trans f + q -- its covariance | 27 | pq = f <> p <> tr f + q -- its covariance |
30 | y = z - h <> px -- residue | 28 | y = z - h #> px -- residue |
31 | cy = h <> pq <> trans h + r -- its covariance | 29 | cy = h <> pq <> tr h + r -- its covariance |
32 | k = pq <> trans h <> inv cy -- kalman gain | 30 | k = pq <> tr h <> inv cy -- kalman gain |
33 | x' = px + k <> y -- new state | 31 | x' = px + k #> y -- new state |
34 | p' = (ident (dim x) - k <> h) <> pq -- its covariance | 32 | p' = (ident (size x) - k <> h) <> pq -- its covariance |
35 | 33 | ||
36 | sys = System f h q r | 34 | sys = System f h q r |
37 | 35 | ||
@@ -49,3 +47,4 @@ main = do | |||
49 | print $ fromRows $ take 10 (map sX xs) | 47 | print $ fromRows $ take 10 (map sX xs) |
50 | mapM_ (print . sP) $ take 10 xs | 48 | mapM_ (print . sP) $ take 10 xs |
51 | mplot (evolution 20 (xs,des)) | 49 | mplot (evolution 20 (xs,des)) |
50 | |||
diff --git a/examples/lie.hs b/examples/lie.hs index db21ea8..4933df6 100644 --- a/examples/lie.hs +++ b/examples/lie.hs | |||
@@ -1,8 +1,8 @@ | |||
1 | -- The magic of Lie Algebra | 1 | -- The magic of Lie Algebra |
2 | 2 | ||
3 | import Numeric.LinearAlgebra | 3 | {-# LANGUAGE FlexibleContexts #-} |
4 | 4 | ||
5 | disp = putStrLn . dispf 5 | 5 | import Numeric.LinearAlgebra |
6 | 6 | ||
7 | rot1 :: Double -> Matrix Double | 7 | rot1 :: Double -> Matrix Double |
8 | rot1 a = (3><3) | 8 | rot1 a = (3><3) |
@@ -58,8 +58,8 @@ main = do | |||
58 | exact = rot3 a <> rot1 b <> rot2 c | 58 | exact = rot3 a <> rot1 b <> rot2 c |
59 | lie = scalar a * g3 |+| scalar b * g1 |+| scalar c * g2 | 59 | lie = scalar a * g3 |+| scalar b * g1 |+| scalar c * g2 |
60 | putStrLn "position in the tangent space:" | 60 | putStrLn "position in the tangent space:" |
61 | disp lie | 61 | disp 5 lie |
62 | putStrLn "exponential map back to the group (2 terms):" | 62 | putStrLn "exponential map back to the group (2 terms):" |
63 | disp (expm lie) | 63 | disp 5 (expm lie) |
64 | putStrLn "exact position:" | 64 | putStrLn "exact position:" |
65 | disp exact | 65 | disp 5 exact |
diff --git a/examples/minimize.hs b/examples/minimize.hs index 19b2cb3..c27afc2 100644 --- a/examples/minimize.hs +++ b/examples/minimize.hs | |||
@@ -20,7 +20,7 @@ partialDerivative n f v = fst (derivCentral 0.01 g (v!!n)) where | |||
20 | g x = f (concat [a,x:b]) | 20 | g x = f (concat [a,x:b]) |
21 | (a,_:b) = splitAt n v | 21 | (a,_:b) = splitAt n v |
22 | 22 | ||
23 | disp = putStrLn . format " " (printf "%.3f") | 23 | disp' = putStrLn . format " " (printf "%.3f") |
24 | 24 | ||
25 | allMethods :: (Enum a, Bounded a) => [a] | 25 | allMethods :: (Enum a, Bounded a) => [a] |
26 | allMethods = [minBound .. maxBound] | 26 | allMethods = [minBound .. maxBound] |
@@ -29,22 +29,23 @@ test method = do | |||
29 | print method | 29 | print method |
30 | let (s,p) = minimize method 1E-2 30 [1,1] f [5,7] | 30 | let (s,p) = minimize method 1E-2 30 [1,1] f [5,7] |
31 | print s | 31 | print s |
32 | disp p | 32 | disp' p |
33 | 33 | ||
34 | testD method = do | 34 | testD method = do |
35 | print method | 35 | print method |
36 | let (s,p) = minimizeD method 1E-3 30 1E-2 1E-4 f df [5,7] | 36 | let (s,p) = minimizeD method 1E-3 30 1E-2 1E-4 f df [5,7] |
37 | print s | 37 | print s |
38 | disp p | 38 | disp' p |
39 | 39 | ||
40 | testD' method = do | 40 | testD' method = do |
41 | putStrLn $ show method ++ " with estimated gradient" | 41 | putStrLn $ show method ++ " with estimated gradient" |
42 | let (s,p) = minimizeD method 1E-3 30 1E-2 1E-4 f (gradient f) [5,7] | 42 | let (s,p) = minimizeD method 1E-3 30 1E-2 1E-4 f (gradient f) [5,7] |
43 | print s | 43 | print s |
44 | disp p | 44 | disp' p |
45 | 45 | ||
46 | main = do | 46 | main = do |
47 | mapM_ test [NMSimplex, NMSimplex2] | 47 | mapM_ test [NMSimplex, NMSimplex2] |
48 | mapM_ testD allMethods | 48 | mapM_ testD allMethods |
49 | testD' ConjugateFR | 49 | testD' ConjugateFR |
50 | mplot $ drop 3 . toColumns . snd $ minimizeS f [5,7] | 50 | mplot $ drop 3 . toColumns . snd $ minimizeS f [5,7] |
51 | |||
diff --git a/examples/monadic.hs b/examples/monadic.hs index 7c6e0f4..cf8aacc 100644 --- a/examples/monadic.hs +++ b/examples/monadic.hs | |||
@@ -1,35 +1,37 @@ | |||
1 | -- monadic computations | 1 | -- monadic computations |
2 | -- (contributed by Vivian McPhail) | 2 | -- (contributed by Vivian McPhail) |
3 | 3 | ||
4 | {-# LANGUAGE FlexibleContexts #-} | ||
5 | |||
4 | import Numeric.LinearAlgebra | 6 | import Numeric.LinearAlgebra |
7 | import Numeric.LinearAlgebra.Devel | ||
5 | import Control.Monad.State.Strict | 8 | import Control.Monad.State.Strict |
6 | import Control.Monad.Maybe | 9 | import Control.Monad.Trans.Maybe |
7 | import Foreign.Storable(Storable) | 10 | import Foreign.Storable(Storable) |
8 | import System.Random(randomIO) | 11 | import System.Random(randomIO) |
9 | 12 | ||
10 | ------------------------------------------- | 13 | ------------------------------------------- |
11 | 14 | ||
12 | -- an instance of MonadIO, a monad transformer | 15 | -- an instance of MonadIO, a monad transformer |
13 | type VectorMonadT = StateT Int IO | 16 | type VectorMonadT = StateT I IO |
14 | 17 | ||
15 | test1 :: Vector Int -> IO (Vector Int) | 18 | test1 :: Vector I -> IO (Vector I) |
16 | test1 = mapVectorM $ \x -> do | 19 | test1 = mapVectorM $ \x -> do |
17 | putStr $ (show x) ++ " " | 20 | putStr $ (show x) ++ " " |
18 | return (x + 1) | 21 | return (x + 1) |
19 | 22 | ||
20 | -- we can have an arbitrary monad AND do IO | 23 | -- we can have an arbitrary monad AND do IO |
21 | addInitialM :: Vector Int -> VectorMonadT () | 24 | addInitialM :: Vector I -> VectorMonadT () |
22 | addInitialM = mapVectorM_ $ \x -> do | 25 | addInitialM = mapVectorM_ $ \x -> do |
23 | i <- get | 26 | i <- get |
24 | liftIO $ putStr $ (show $ x + i) ++ " " | 27 | liftIO $ putStr $ (show $ x + i) ++ " " |
25 | put $ x + i | 28 | put $ x + i |
26 | 29 | ||
27 | -- sum the values of the even indiced elements | 30 | -- sum the values of the even indiced elements |
28 | sumEvens :: Vector Int -> Int | 31 | sumEvens :: Vector I -> I |
29 | sumEvens = foldVectorWithIndex (\x a b -> if x `mod` 2 == 0 then a + b else b) 0 | 32 | sumEvens = foldVectorWithIndex (\x a b -> if x `mod` 2 == 0 then a + b else b) 0 |
30 | 33 | ||
31 | -- sum and print running total of evens | 34 | -- sum and print running total of evens |
32 | sumEvensAndPrint :: Vector Int -> VectorMonadT () | ||
33 | sumEvensAndPrint = mapVectorWithIndexM_ $ \ i x -> do | 35 | sumEvensAndPrint = mapVectorWithIndexM_ $ \ i x -> do |
34 | when (i `mod` 2 == 0) $ do | 36 | when (i `mod` 2 == 0) $ do |
35 | v <- get | 37 | v <- get |
@@ -38,7 +40,7 @@ sumEvensAndPrint = mapVectorWithIndexM_ $ \ i x -> do | |||
38 | liftIO $ putStr $ (show v') ++ " " | 40 | liftIO $ putStr $ (show v') ++ " " |
39 | 41 | ||
40 | 42 | ||
41 | indexPlusSum :: Vector Int -> VectorMonadT () | 43 | --indexPlusSum :: Vector I -> VectorMonadT () |
42 | indexPlusSum v' = do | 44 | indexPlusSum v' = do |
43 | let f i x = do | 45 | let f i x = do |
44 | s <- get | 46 | s <- get |
@@ -63,7 +65,7 @@ monoStep d = do | |||
63 | 65 | ||
64 | isMonotoneIncreasing :: Vector Double -> Bool | 66 | isMonotoneIncreasing :: Vector Double -> Bool |
65 | isMonotoneIncreasing v = | 67 | isMonotoneIncreasing v = |
66 | let res = evalState (runMaybeT $ (mapVectorM_ monoStep v)) (v @> 0) | 68 | let res = evalState (runMaybeT $ (mapVectorM_ monoStep v)) (v ! 0) |
67 | in case res of | 69 | in case res of |
68 | Nothing -> False | 70 | Nothing -> False |
69 | Just _ -> True | 71 | Just _ -> True |
@@ -72,8 +74,8 @@ isMonotoneIncreasing v = | |||
72 | ------------------------------------------- | 74 | ------------------------------------------- |
73 | 75 | ||
74 | -- | apply a test to successive elements of a vector, evaluates to true iff test passes for all pairs | 76 | -- | apply a test to successive elements of a vector, evaluates to true iff test passes for all pairs |
75 | successive_ :: Storable a => (a -> a -> Bool) -> Vector a -> Bool | 77 | successive_ :: (Container Vector a, Indexable (Vector a) a) => (a -> a -> Bool) -> Vector a -> Bool |
76 | successive_ t v = maybe False (\_ -> True) $ evalState (runMaybeT (mapVectorM_ step (subVector 1 (dim v - 1) v))) (v @> 0) | 78 | successive_ t v = maybe False (\_ -> True) $ evalState (runMaybeT (mapVectorM_ step (subVector 1 (size v - 1) v))) (v ! 0) |
77 | where step e = do | 79 | where step e = do |
78 | ep <- lift $ get | 80 | ep <- lift $ get |
79 | if t e ep | 81 | if t e ep |
@@ -81,8 +83,10 @@ successive_ t v = maybe False (\_ -> True) $ evalState (runMaybeT (mapVectorM_ s | |||
81 | else (fail "successive_ test failed") | 83 | else (fail "successive_ test failed") |
82 | 84 | ||
83 | -- | operate on successive elements of a vector and return the resulting vector, whose length 1 less than that of the input | 85 | -- | operate on successive elements of a vector and return the resulting vector, whose length 1 less than that of the input |
84 | successive :: (Storable a, Storable b) => (a -> a -> b) -> Vector a -> Vector b | 86 | successive |
85 | successive f v = evalState (mapVectorM step (subVector 1 (dim v - 1) v)) (v @> 0) | 87 | :: (Storable b, Container Vector s, Indexable (Vector s) s) |
88 | => (s -> s -> b) -> Vector s -> Vector b | ||
89 | successive f v = evalState (mapVectorM step (subVector 1 (size v - 1) v)) (v ! 0) | ||
86 | where step e = do | 90 | where step e = do |
87 | ep <- get | 91 | ep <- get |
88 | put e | 92 | put e |
@@ -90,7 +94,7 @@ successive f v = evalState (mapVectorM step (subVector 1 (dim v - 1) v)) (v @> 0 | |||
90 | 94 | ||
91 | ------------------------------------------- | 95 | ------------------------------------------- |
92 | 96 | ||
93 | v :: Vector Int | 97 | v :: Vector I |
94 | v = 10 |> [0..] | 98 | v = 10 |> [0..] |
95 | 99 | ||
96 | w = fromList ([1..10]++[10,9..1]) :: Vector Double | 100 | w = fromList ([1..10]++[10,9..1]) :: Vector Double |
@@ -116,3 +120,4 @@ main = do | |||
116 | print $ successive_ (>) v | 120 | print $ successive_ (>) v |
117 | print $ successive_ (>) w | 121 | print $ successive_ (>) w |
118 | print $ successive (+) v | 122 | print $ successive (+) v |
123 | |||
diff --git a/examples/multiply.hs b/examples/multiply.hs index 572961c..be8fa73 100644 --- a/examples/multiply.hs +++ b/examples/multiply.hs | |||
@@ -22,10 +22,10 @@ instance Container Vector t => Scaling t (Vector t) (Vector t) where | |||
22 | instance Container Vector t => Scaling (Vector t) t (Vector t) where | 22 | instance Container Vector t => Scaling (Vector t) t (Vector t) where |
23 | (⋅) = flip scale | 23 | (⋅) = flip scale |
24 | 24 | ||
25 | instance Container Vector t => Scaling t (Matrix t) (Matrix t) where | 25 | instance (Num t, Container Vector t) => Scaling t (Matrix t) (Matrix t) where |
26 | (⋅) = scale | 26 | (⋅) = scale |
27 | 27 | ||
28 | instance Container Vector t => Scaling (Matrix t) t (Matrix t) where | 28 | instance (Num t, Container Vector t) => Scaling (Matrix t) t (Matrix t) where |
29 | (⋅) = flip scale | 29 | (⋅) = flip scale |
30 | 30 | ||
31 | 31 | ||
@@ -42,14 +42,14 @@ class Mul a b c | a b -> c, a c -> b, b c -> a where | |||
42 | instance Product t => Mul (Vector t) (Vector t) t where | 42 | instance Product t => Mul (Vector t) (Vector t) t where |
43 | (×) = udot | 43 | (×) = udot |
44 | 44 | ||
45 | instance Product t => Mul (Matrix t) (Vector t) (Vector t) where | 45 | instance (Numeric t, Product t) => Mul (Matrix t) (Vector t) (Vector t) where |
46 | (×) = mXv | 46 | (×) = (#>) |
47 | 47 | ||
48 | instance Product t => Mul (Vector t) (Matrix t) (Vector t) where | 48 | instance (Numeric t, Product t) => Mul (Vector t) (Matrix t) (Vector t) where |
49 | (×) = vXm | 49 | (×) = (<#) |
50 | 50 | ||
51 | instance Product t => Mul (Matrix t) (Matrix t) (Matrix t) where | 51 | instance (Numeric t, Product t) => Mul (Matrix t) (Matrix t) (Matrix t) where |
52 | (×) = mXm | 52 | (×) = (<>) |
53 | 53 | ||
54 | 54 | ||
55 | --instance Scaling a b c => Contraction a b c where | 55 | --instance Scaling a b c => Contraction a b c where |
@@ -92,9 +92,9 @@ u = fromList [3,0,5] | |||
92 | w = konst 1 (2,3) :: Matrix Double | 92 | w = konst 1 (2,3) :: Matrix Double |
93 | 93 | ||
94 | main = do | 94 | main = do |
95 | print $ (scale s v <> m) `udot` v | 95 | print $ (scale s v <# m) `udot` v |
96 | print $ scale s v `udot` (m <> v) | 96 | print $ scale s v `udot` (m #> v) |
97 | print $ s * ((v <> m) `udot` v) | 97 | print $ s * ((v <# m) `udot` v) |
98 | print $ s ⋅ v × m × v | 98 | print $ s ⋅ v × m × v |
99 | print a | 99 | print a |
100 | -- print (b == c) | 100 | -- print (b == c) |
diff --git a/examples/ode.hs b/examples/ode.hs index dc6e0ec..4cf1673 100644 --- a/examples/ode.hs +++ b/examples/ode.hs | |||
@@ -43,7 +43,7 @@ vanderpol' mu = do | |||
43 | jac t (toList->[x,v]) = (2><2) [ 0 , 1 | 43 | jac t (toList->[x,v]) = (2><2) [ 0 , 1 |
44 | , -1-2*x*v*mu, mu*(1-x**2) ] | 44 | , -1-2*x*v*mu, mu*(1-x**2) ] |
45 | ts = linspace 1000 (0,50) | 45 | ts = linspace 1000 (0,50) |
46 | hi = (ts@>1 - ts@>0)/100 | 46 | hi = (ts!1 - ts!0)/100 |
47 | sol = toColumns $ odeSolveV (MSBDF jac) hi 1E-8 1E-8 (xdot mu) (fromList [1,0]) ts | 47 | sol = toColumns $ odeSolveV (MSBDF jac) hi 1E-8 1E-8 (xdot mu) (fromList [1,0]) ts |
48 | mplot sol | 48 | mplot sol |
49 | 49 | ||
diff --git a/examples/pca1.hs b/examples/pca1.hs index a11eba9..ad2214d 100644 --- a/examples/pca1.hs +++ b/examples/pca1.hs | |||
@@ -8,27 +8,25 @@ import Control.Monad(when) | |||
8 | type Vec = Vector Double | 8 | type Vec = Vector Double |
9 | type Mat = Matrix Double | 9 | type Mat = Matrix Double |
10 | 10 | ||
11 | 11 | {- | |
12 | -- Vector with the mean value of the columns of a matrix | 12 | -- Vector with the mean value of the columns of a matrix |
13 | mean a = constant (recip . fromIntegral . rows $ a) (rows a) <> a | 13 | mean a = constant (recip . fromIntegral . rows $ a) (rows a) <> a |
14 | 14 | ||
15 | -- covariance matrix of a list of observations stored as rows | 15 | -- covariance matrix of a list of observations stored as rows |
16 | cov x = (trans xc <> xc) / fromIntegral (rows x - 1) | 16 | cov x = (trans xc <> xc) / fromIntegral (rows x - 1) |
17 | where xc = x - asRow (mean x) | 17 | where xc = x - asRow (mean x) |
18 | -} | ||
18 | 19 | ||
19 | 20 | ||
20 | -- creates the compression and decompression functions from the desired number of components | 21 | -- creates the compression and decompression functions from the desired number of components |
21 | pca :: Int -> Mat -> (Vec -> Vec , Vec -> Vec) | 22 | pca :: Int -> Mat -> (Vec -> Vec , Vec -> Vec) |
22 | pca n dataSet = (encode,decode) | 23 | pca n dataSet = (encode,decode) |
23 | where | 24 | where |
24 | encode x = vp <> (x - m) | 25 | encode x = vp #> (x - m) |
25 | decode x = x <> vp + m | 26 | decode x = x <# vp + m |
26 | m = mean dataSet | 27 | (m,c) = meanCov dataSet |
27 | c = cov dataSet | 28 | (_,v) = eigSH (trustSym c) |
28 | (_,v) = eigSH' c | 29 | vp = tr $ takeColumns n v |
29 | vp = takeRows n (trans v) | ||
30 | |||
31 | norm = pnorm PNorm2 | ||
32 | 30 | ||
33 | main = do | 31 | main = do |
34 | ok <- doesFileExist ("mnist.txt") | 32 | ok <- doesFileExist ("mnist.txt") |
@@ -43,4 +41,4 @@ main = do | |||
43 | let (pe,pd) = pca 10 xs | 41 | let (pe,pd) = pca 10 xs |
44 | let y = pe x | 42 | let y = pe x |
45 | print y -- compressed version | 43 | print y -- compressed version |
46 | print $ norm (x - pd y) / norm x --reconstruction quality | 44 | print $ norm_2 (x - pd y) / norm_2 x --reconstruction quality |
diff --git a/examples/pca2.hs b/examples/pca2.hs index e7ea95f..892d382 100644 --- a/examples/pca2.hs +++ b/examples/pca2.hs | |||
@@ -1,5 +1,7 @@ | |||
1 | -- Improved PCA, including illustrative graphics | 1 | -- Improved PCA, including illustrative graphics |
2 | 2 | ||
3 | {-# LANGUAGE FlexibleContexts #-} | ||
4 | |||
3 | import Numeric.LinearAlgebra | 5 | import Numeric.LinearAlgebra |
4 | import Graphics.Plot | 6 | import Graphics.Plot |
5 | import System.Directory(doesFileExist) | 7 | import System.Directory(doesFileExist) |
@@ -10,27 +12,27 @@ type Vec = Vector Double | |||
10 | type Mat = Matrix Double | 12 | type Mat = Matrix Double |
11 | 13 | ||
12 | -- Vector with the mean value of the columns of a matrix | 14 | -- Vector with the mean value of the columns of a matrix |
13 | mean a = constant (recip . fromIntegral . rows $ a) (rows a) <> a | 15 | mean a = konst (recip . fromIntegral . rows $ a) (rows a) <# a |
14 | 16 | ||
15 | -- covariance matrix of a list of observations stored as rows | 17 | -- covariance matrix of a list of observations stored as rows |
16 | cov x = (trans xc <> xc) / fromIntegral (rows x - 1) | 18 | cov x = (mTm xc) -- / fromIntegral (rows x - 1) |
17 | where xc = x - asRow (mean x) | 19 | where xc = x - asRow (mean x) |
18 | 20 | ||
19 | 21 | ||
20 | type Stat = (Vec, [Double], Mat) | 22 | type Stat = (Vec, [Double], Mat) |
21 | -- 1st and 2nd order statistics of a dataset (mean, eigenvalues and eigenvectors of cov) | 23 | -- 1st and 2nd order statistics of a dataset (mean, eigenvalues and eigenvectors of cov) |
22 | stat :: Mat -> Stat | 24 | stat :: Mat -> Stat |
23 | stat x = (m, toList s, trans v) where | 25 | stat x = (m, toList s, tr v) where |
24 | m = mean x | 26 | m = mean x |
25 | (s,v) = eigSH' (cov x) | 27 | (s,v) = eigSH (cov x) |
26 | 28 | ||
27 | -- creates the compression and decompression functions from the desired reconstruction | 29 | -- creates the compression and decompression functions from the desired reconstruction |
28 | -- quality and the statistics of a data set | 30 | -- quality and the statistics of a data set |
29 | pca :: Double -> Stat -> (Vec -> Vec , Vec -> Vec) | 31 | pca :: Double -> Stat -> (Vec -> Vec , Vec -> Vec) |
30 | pca prec (m,s,v) = (encode,decode) | 32 | pca prec (m,s,v) = (encode,decode) |
31 | where | 33 | where |
32 | encode x = vp <> (x - m) | 34 | encode x = vp #> (x - m) |
33 | decode x = x <> vp + m | 35 | decode x = x <# vp + m |
34 | vp = takeRows n v | 36 | vp = takeRows n v |
35 | n = 1 + (length $ fst $ span (< (prec'*sum s)) $ cumSum s) | 37 | n = 1 + (length $ fst $ span (< (prec'*sum s)) $ cumSum s) |
36 | cumSum = tail . scanl (+) 0.0 | 38 | cumSum = tail . scanl (+) 0.0 |
@@ -46,7 +48,7 @@ test :: Stat -> Double -> Vec -> IO () | |||
46 | test st prec x = do | 48 | test st prec x = do |
47 | let (pe,pd) = pca prec st | 49 | let (pe,pd) = pca prec st |
48 | let y = pe x | 50 | let y = pe x |
49 | print $ dim y | 51 | print $ size y |
50 | shdigit (pd y) | 52 | shdigit (pd y) |
51 | 53 | ||
52 | main = do | 54 | main = do |
@@ -63,3 +65,4 @@ main = do | |||
63 | let st = stat xs | 65 | let st = stat xs |
64 | test st 0.90 x | 66 | test st 0.90 x |
65 | test st 0.50 x | 67 | test st 0.50 x |
68 | |||
diff --git a/examples/pinv.hs b/examples/pinv.hs index 7de50b8..6f093b4 100644 --- a/examples/pinv.hs +++ b/examples/pinv.hs | |||
@@ -1,20 +1,19 @@ | |||
1 | import Numeric.LinearAlgebra | 1 | import Numeric.LinearAlgebra |
2 | import Graphics.Plot | ||
3 | import Text.Printf(printf) | 2 | import Text.Printf(printf) |
4 | 3 | ||
5 | expand :: Int -> Vector Double -> Matrix Double | 4 | expand :: Int -> Vector R -> Matrix R |
6 | expand n x = fromColumns $ map (x^) [0 .. n] | 5 | expand n x = fromColumns $ map (x^) [0 .. n] |
7 | 6 | ||
8 | polynomialModel :: Vector Double -> Vector Double -> Int | 7 | polynomialModel :: Vector R -> Vector R -> Int |
9 | -> (Vector Double -> Vector Double) | 8 | -> (Vector R -> Vector R) |
10 | polynomialModel x y n = f where | 9 | polynomialModel x y n = f where |
11 | f z = expand n z <> ws | 10 | f z = expand n z #> ws |
12 | ws = expand n x <\> y | 11 | ws = expand n x <\> y |
13 | 12 | ||
14 | main = do | 13 | main = do |
15 | [x,y] <- (toColumns . readMatrix) `fmap` readFile "data.txt" | 14 | [x,y] <- toColumns <$> loadMatrix "data.txt" |
16 | let pol = polynomialModel x y | 15 | let pol = polynomialModel x y |
17 | let view = [x, y, pol 1 x, pol 2 x, pol 3 x] | 16 | let view = [x, y, pol 1 x, pol 2 x, pol 3 x] |
18 | putStrLn $ " x y p 1 p 2 p 3" | 17 | putStrLn $ " x y p 1 p 2 p 3" |
19 | putStrLn $ format " " (printf "%.2f") $ fromColumns view | 18 | putStrLn $ format " " (printf "%.2f") $ fromColumns view |
20 | mplot view | 19 | |
diff --git a/examples/pinv.ipynb b/examples/pinv.ipynb new file mode 100644 index 0000000..532b8d0 --- /dev/null +++ b/examples/pinv.ipynb | |||
@@ -0,0 +1,722 @@ | |||
1 | { | ||
2 | "cells": [ | ||
3 | { | ||
4 | "cell_type": "code", | ||
5 | "execution_count": 1, | ||
6 | "metadata": { | ||
7 | "collapsed": true | ||
8 | }, | ||
9 | "outputs": [], | ||
10 | "source": [ | ||
11 | "import Numeric.LinearAlgebra" | ||
12 | ] | ||
13 | }, | ||
14 | { | ||
15 | "cell_type": "code", | ||
16 | "execution_count": 2, | ||
17 | "metadata": { | ||
18 | "collapsed": true | ||
19 | }, | ||
20 | "outputs": [], | ||
21 | "source": [ | ||
22 | "import IHaskell.Display" | ||
23 | ] | ||
24 | }, | ||
25 | { | ||
26 | "cell_type": "code", | ||
27 | "execution_count": 3, | ||
28 | "metadata": { | ||
29 | "collapsed": false | ||
30 | }, | ||
31 | "outputs": [], | ||
32 | "source": [ | ||
33 | ":ext FlexibleInstances\n", | ||
34 | "\n", | ||
35 | "dec = 3\n", | ||
36 | "\n", | ||
37 | "instance IHaskellDisplay (Matrix C) where\n", | ||
38 | " display m = return $ Display [html (\"<p>$$\"++(latexFormat \"bmatrix\" . dispcf dec) m++\"$$</p>\")]\n", | ||
39 | "\n", | ||
40 | "instance IHaskellDisplay (Matrix R) where\n", | ||
41 | " display m = return $ Display [html (\"<p>$$\"++ (latexFormat \"bmatrix\" . dispf dec) m++\"$$</p>\")]" | ||
42 | ] | ||
43 | }, | ||
44 | { | ||
45 | "cell_type": "code", | ||
46 | "execution_count": 4, | ||
47 | "metadata": { | ||
48 | "collapsed": true | ||
49 | }, | ||
50 | "outputs": [], | ||
51 | "source": [ | ||
52 | "import Graphics.SVG\n", | ||
53 | "data RawSVG = RawSVG String\n", | ||
54 | "instance IHaskellDisplay RawSVG where\n", | ||
55 | " display (RawSVG s) = return $ Display [html $ \"<div style=\\\"width:600px\\\">\"++ s ++ \"</div>\"]\n", | ||
56 | "\n", | ||
57 | "lplot = RawSVG . hPlot" | ||
58 | ] | ||
59 | }, | ||
60 | { | ||
61 | "cell_type": "markdown", | ||
62 | "metadata": {}, | ||
63 | "source": [ | ||
64 | "# least squares polynomial model" | ||
65 | ] | ||
66 | }, | ||
67 | { | ||
68 | "cell_type": "code", | ||
69 | "execution_count": 5, | ||
70 | "metadata": { | ||
71 | "collapsed": false | ||
72 | }, | ||
73 | "outputs": [], | ||
74 | "source": [ | ||
75 | "expand :: Int -> Vector R -> Matrix R\n", | ||
76 | "expand n x = fromColumns $ map (x^) [0 .. n]\n", | ||
77 | "\n", | ||
78 | "polynomialModel :: Vector R -> Vector R -> Int -> (Vector R -> Vector R)\n", | ||
79 | "polynomialModel x y n = f\n", | ||
80 | " where\n", | ||
81 | " f z = expand n z #> ws\n", | ||
82 | " ws = expand n x <\\> y" | ||
83 | ] | ||
84 | }, | ||
85 | { | ||
86 | "cell_type": "code", | ||
87 | "execution_count": 6, | ||
88 | "metadata": { | ||
89 | "collapsed": true | ||
90 | }, | ||
91 | "outputs": [], | ||
92 | "source": [ | ||
93 | "[x,y] <- toColumns <$> loadMatrix \"data.txt\"" | ||
94 | ] | ||
95 | }, | ||
96 | { | ||
97 | "cell_type": "code", | ||
98 | "execution_count": 7, | ||
99 | "metadata": { | ||
100 | "collapsed": false | ||
101 | }, | ||
102 | "outputs": [ | ||
103 | { | ||
104 | "data": { | ||
105 | "text/plain": [ | ||
106 | "[0.9,2.1,3.1,4.0,4.9,6.1,7.0,7.9,9.1,10.2]" | ||
107 | ] | ||
108 | }, | ||
109 | "metadata": {}, | ||
110 | "output_type": "display_data" | ||
111 | }, | ||
112 | { | ||
113 | "data": { | ||
114 | "text/plain": [ | ||
115 | "[1.1,3.9,9.2,51.8,25.3,35.7,49.4,3.6,81.5,99.5]" | ||
116 | ] | ||
117 | }, | ||
118 | "metadata": {}, | ||
119 | "output_type": "display_data" | ||
120 | } | ||
121 | ], | ||
122 | "source": [ | ||
123 | "x\n", | ||
124 | "y" | ||
125 | ] | ||
126 | }, | ||
127 | { | ||
128 | "cell_type": "code", | ||
129 | "execution_count": 8, | ||
130 | "metadata": { | ||
131 | "collapsed": false | ||
132 | }, | ||
133 | "outputs": [ | ||
134 | { | ||
135 | "data": { | ||
136 | "text/html": [ | ||
137 | "<style>/*\n", | ||
138 | "Custom IHaskell CSS.\n", | ||
139 | "*/\n", | ||
140 | "\n", | ||
141 | "/* Styles used for the Hoogle display in the pager */\n", | ||
142 | ".hoogle-doc {\n", | ||
143 | " display: block;\n", | ||
144 | " padding-bottom: 1.3em;\n", | ||
145 | " padding-left: 0.4em;\n", | ||
146 | "}\n", | ||
147 | ".hoogle-code {\n", | ||
148 | " display: block;\n", | ||
149 | " font-family: monospace;\n", | ||
150 | " white-space: pre;\n", | ||
151 | "}\n", | ||
152 | ".hoogle-text {\n", | ||
153 | " display: block;\n", | ||
154 | "}\n", | ||
155 | ".hoogle-name {\n", | ||
156 | " color: green;\n", | ||
157 | " font-weight: bold;\n", | ||
158 | "}\n", | ||
159 | ".hoogle-head {\n", | ||
160 | " font-weight: bold;\n", | ||
161 | "}\n", | ||
162 | ".hoogle-sub {\n", | ||
163 | " display: block;\n", | ||
164 | " margin-left: 0.4em;\n", | ||
165 | "}\n", | ||
166 | ".hoogle-package {\n", | ||
167 | " font-weight: bold;\n", | ||
168 | " font-style: italic;\n", | ||
169 | "}\n", | ||
170 | ".hoogle-module {\n", | ||
171 | " font-weight: bold;\n", | ||
172 | "}\n", | ||
173 | ".hoogle-class {\n", | ||
174 | " font-weight: bold;\n", | ||
175 | "}\n", | ||
176 | "\n", | ||
177 | "/* Styles used for basic displays */\n", | ||
178 | ".get-type {\n", | ||
179 | " color: green;\n", | ||
180 | " font-weight: bold;\n", | ||
181 | " font-family: monospace;\n", | ||
182 | " display: block;\n", | ||
183 | " white-space: pre-wrap;\n", | ||
184 | "}\n", | ||
185 | "\n", | ||
186 | ".show-type {\n", | ||
187 | " color: green;\n", | ||
188 | " font-weight: bold;\n", | ||
189 | " font-family: monospace;\n", | ||
190 | " margin-left: 1em;\n", | ||
191 | "}\n", | ||
192 | "\n", | ||
193 | ".mono {\n", | ||
194 | " font-family: monospace;\n", | ||
195 | " display: block;\n", | ||
196 | "}\n", | ||
197 | "\n", | ||
198 | ".err-msg {\n", | ||
199 | " color: red;\n", | ||
200 | " font-style: italic;\n", | ||
201 | " font-family: monospace;\n", | ||
202 | " white-space: pre;\n", | ||
203 | " display: block;\n", | ||
204 | "}\n", | ||
205 | "\n", | ||
206 | "#unshowable {\n", | ||
207 | " color: red;\n", | ||
208 | " font-weight: bold;\n", | ||
209 | "}\n", | ||
210 | "\n", | ||
211 | ".err-msg.in.collapse {\n", | ||
212 | " padding-top: 0.7em;\n", | ||
213 | "}\n", | ||
214 | "\n", | ||
215 | "/* Code that will get highlighted before it is highlighted */\n", | ||
216 | ".highlight-code {\n", | ||
217 | " white-space: pre;\n", | ||
218 | " font-family: monospace;\n", | ||
219 | "}\n", | ||
220 | "\n", | ||
221 | "/* Hlint styles */\n", | ||
222 | ".suggestion-warning { \n", | ||
223 | " font-weight: bold;\n", | ||
224 | " color: rgb(200, 130, 0);\n", | ||
225 | "}\n", | ||
226 | ".suggestion-error { \n", | ||
227 | " font-weight: bold;\n", | ||
228 | " color: red;\n", | ||
229 | "}\n", | ||
230 | ".suggestion-name {\n", | ||
231 | " font-weight: bold;\n", | ||
232 | "}\n", | ||
233 | "</style><p>$$\\begin{bmatrix}\n", | ||
234 | "1.000 & 0.900 & 0.810\n", | ||
235 | "\\\\\n", | ||
236 | "1.000 & 2.100 & 4.410\n", | ||
237 | "\\\\\n", | ||
238 | "1.000 & 3.100 & 9.610\n", | ||
239 | "\\\\\n", | ||
240 | "1.000 & 4.000 & 16.000\n", | ||
241 | "\\\\\n", | ||
242 | "1.000 & 4.900 & 24.010\n", | ||
243 | "\\\\\n", | ||
244 | "1.000 & 6.100 & 37.210\n", | ||
245 | "\\\\\n", | ||
246 | "1.000 & 7.000 & 49.000\n", | ||
247 | "\\\\\n", | ||
248 | "1.000 & 7.900 & 62.410\n", | ||
249 | "\\\\\n", | ||
250 | "1.000 & 9.100 & 82.810\n", | ||
251 | "\\\\\n", | ||
252 | "1.000 & 10.200 & 104.040\n", | ||
253 | "\\end{bmatrix}$$</p>" | ||
254 | ] | ||
255 | }, | ||
256 | "metadata": {}, | ||
257 | "output_type": "display_data" | ||
258 | } | ||
259 | ], | ||
260 | "source": [ | ||
261 | "expand 2 x" | ||
262 | ] | ||
263 | }, | ||
264 | { | ||
265 | "cell_type": "code", | ||
266 | "execution_count": 9, | ||
267 | "metadata": { | ||
268 | "collapsed": true | ||
269 | }, | ||
270 | "outputs": [], | ||
271 | "source": [ | ||
272 | "pol = polynomialModel x y\n", | ||
273 | "view = [x, y, pol 1 x, pol 2 x, pol 3 x]" | ||
274 | ] | ||
275 | }, | ||
276 | { | ||
277 | "cell_type": "code", | ||
278 | "execution_count": 10, | ||
279 | "metadata": { | ||
280 | "collapsed": false | ||
281 | }, | ||
282 | "outputs": [ | ||
283 | { | ||
284 | "data": { | ||
285 | "text/plain": [ | ||
286 | " x y p 1 p 2 p 3" | ||
287 | ] | ||
288 | }, | ||
289 | "metadata": {}, | ||
290 | "output_type": "display_data" | ||
291 | }, | ||
292 | { | ||
293 | "data": { | ||
294 | "text/plain": [ | ||
295 | " 0.90 1.10 -3.41 7.70 -6.99\n", | ||
296 | " 2.10 3.90 6.83 9.80 15.97\n", | ||
297 | " 3.10 9.20 15.36 13.39 25.09\n", | ||
298 | " 4.00 51.80 23.04 18.05 28.22\n", | ||
299 | " 4.90 25.30 30.72 24.05 28.86\n", | ||
300 | " 6.10 35.70 40.96 34.16 29.68\n", | ||
301 | " 7.00 49.40 48.64 43.31 33.17\n", | ||
302 | " 7.90 3.60 56.32 53.82 41.60\n", | ||
303 | " 9.10 81.50 66.57 69.92 64.39\n", | ||
304 | "10.20 99.50 75.95 86.80 101.01" | ||
305 | ] | ||
306 | }, | ||
307 | "metadata": {}, | ||
308 | "output_type": "display_data" | ||
309 | } | ||
310 | ], | ||
311 | "source": [ | ||
312 | "import Text.Printf\n", | ||
313 | "\n", | ||
314 | "putStrLn \" x y p 1 p 2 p 3\"\n", | ||
315 | "putStrLn $ format \" \" (printf \"%.2f\") $ fromColumns view" | ||
316 | ] | ||
317 | }, | ||
318 | { | ||
319 | "cell_type": "code", | ||
320 | "execution_count": 11, | ||
321 | "metadata": { | ||
322 | "collapsed": false | ||
323 | }, | ||
324 | "outputs": [], | ||
325 | "source": [ | ||
326 | "t = linspace 100 (minElement x, maxElement x)" | ||
327 | ] | ||
328 | }, | ||
329 | { | ||
330 | "cell_type": "code", | ||
331 | "execution_count": 12, | ||
332 | "metadata": { | ||
333 | "collapsed": false | ||
334 | }, | ||
335 | "outputs": [ | ||
336 | { | ||
337 | "data": { | ||
338 | "text/html": [ | ||
339 | "<style>/*\n", | ||
340 | "Custom IHaskell CSS.\n", | ||
341 | "*/\n", | ||
342 | "\n", | ||
343 | "/* Styles used for the Hoogle display in the pager */\n", | ||
344 | ".hoogle-doc {\n", | ||
345 | " display: block;\n", | ||
346 | " padding-bottom: 1.3em;\n", | ||
347 | " padding-left: 0.4em;\n", | ||
348 | "}\n", | ||
349 | ".hoogle-code {\n", | ||
350 | " display: block;\n", | ||
351 | " font-family: monospace;\n", | ||
352 | " white-space: pre;\n", | ||
353 | "}\n", | ||
354 | ".hoogle-text {\n", | ||
355 | " display: block;\n", | ||
356 | "}\n", | ||
357 | ".hoogle-name {\n", | ||
358 | " color: green;\n", | ||
359 | " font-weight: bold;\n", | ||
360 | "}\n", | ||
361 | ".hoogle-head {\n", | ||
362 | " font-weight: bold;\n", | ||
363 | "}\n", | ||
364 | ".hoogle-sub {\n", | ||
365 | " display: block;\n", | ||
366 | " margin-left: 0.4em;\n", | ||
367 | "}\n", | ||
368 | ".hoogle-package {\n", | ||
369 | " font-weight: bold;\n", | ||
370 | " font-style: italic;\n", | ||
371 | "}\n", | ||
372 | ".hoogle-module {\n", | ||
373 | " font-weight: bold;\n", | ||
374 | "}\n", | ||
375 | ".hoogle-class {\n", | ||
376 | " font-weight: bold;\n", | ||
377 | "}\n", | ||
378 | "\n", | ||
379 | "/* Styles used for basic displays */\n", | ||
380 | ".get-type {\n", | ||
381 | " color: green;\n", | ||
382 | " font-weight: bold;\n", | ||
383 | " font-family: monospace;\n", | ||
384 | " display: block;\n", | ||
385 | " white-space: pre-wrap;\n", | ||
386 | "}\n", | ||
387 | "\n", | ||
388 | ".show-type {\n", | ||
389 | " color: green;\n", | ||
390 | " font-weight: bold;\n", | ||
391 | " font-family: monospace;\n", | ||
392 | " margin-left: 1em;\n", | ||
393 | "}\n", | ||
394 | "\n", | ||
395 | ".mono {\n", | ||
396 | " font-family: monospace;\n", | ||
397 | " display: block;\n", | ||
398 | "}\n", | ||
399 | "\n", | ||
400 | ".err-msg {\n", | ||
401 | " color: red;\n", | ||
402 | " font-style: italic;\n", | ||
403 | " font-family: monospace;\n", | ||
404 | " white-space: pre;\n", | ||
405 | " display: block;\n", | ||
406 | "}\n", | ||
407 | "\n", | ||
408 | "#unshowable {\n", | ||
409 | " color: red;\n", | ||
410 | " font-weight: bold;\n", | ||
411 | "}\n", | ||
412 | "\n", | ||
413 | ".err-msg.in.collapse {\n", | ||
414 | " padding-top: 0.7em;\n", | ||
415 | "}\n", | ||
416 | "\n", | ||
417 | "/* Code that will get highlighted before it is highlighted */\n", | ||
418 | ".highlight-code {\n", | ||
419 | " white-space: pre;\n", | ||
420 | " font-family: monospace;\n", | ||
421 | "}\n", | ||
422 | "\n", | ||
423 | "/* Hlint styles */\n", | ||
424 | ".suggestion-warning { \n", | ||
425 | " font-weight: bold;\n", | ||
426 | " color: rgb(200, 130, 0);\n", | ||
427 | "}\n", | ||
428 | ".suggestion-error { \n", | ||
429 | " font-weight: bold;\n", | ||
430 | " color: red;\n", | ||
431 | "}\n", | ||
432 | ".suggestion-name {\n", | ||
433 | " font-weight: bold;\n", | ||
434 | "}\n", | ||
435 | "</style><div style=\"width:600px\"><?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | ||
436 | "<svg xmlns='http://www.w3.org/2000/svg' version='1.1' viewBox='0 0 600 400' >\n", | ||
437 | "\n", | ||
438 | "<g style='text-anchor:middle'>\n", | ||
439 | "<text x='300.000' y='30.000' style='font-size:14.0px'> polynomial models </text>\n", | ||
440 | "\n", | ||
441 | "</g>\n", | ||
442 | "\n", | ||
443 | "<g style='fill:white; stroke:none; stroke-width:1.0'>\n", | ||
444 | "<rect x='70.000' y='50.000' width='490.000' height='300.000' />\n", | ||
445 | "\n", | ||
446 | "</g>\n", | ||
447 | "\n", | ||
448 | "<g style='stroke:black;stroke-width:0.1;stroke-dasharray:2'>\n", | ||
449 | "<path d = 'M 144.961 350.000 144.961 50.000 ' />\n", | ||
450 | "<path d = 'M 240.758 350.000 240.758 50.000 ' />\n", | ||
451 | "<path d = 'M 336.554 350.000 336.554 50.000 ' />\n", | ||
452 | "<path d = 'M 432.351 350.000 432.351 50.000 ' />\n", | ||
453 | "<path d = 'M 528.148 350.000 528.148 50.000 ' />\n", | ||
454 | "\n", | ||
455 | "<path d = 'M 70.000 317.304 560.000 317.304 ' />\n", | ||
456 | "<path d = 'M 70.000 279.118 560.000 279.118 ' />\n", | ||
457 | "<path d = 'M 70.000 240.932 560.000 240.932 ' />\n", | ||
458 | "<path d = 'M 70.000 202.745 560.000 202.745 ' />\n", | ||
459 | "<path d = 'M 70.000 164.559 560.000 164.559 ' />\n", | ||
460 | "<path d = 'M 70.000 126.373 560.000 126.373 ' />\n", | ||
461 | "<path d = 'M 70.000 88.186 560.000 88.186 ' />\n", | ||
462 | "<path d = 'M 70.000 50.000 560.000 50.000 ' />\n", | ||
463 | "\n", | ||
464 | "\n", | ||
465 | "</g>\n", | ||
466 | "\n", | ||
467 | "\n", | ||
468 | "<defs> <clipPath id='clip7000050000490000300000'>\n", | ||
469 | "<rect x='70.000' y='50.000' width='490.000' height='300.000' />\n", | ||
470 | "</clipPath> </defs>\n", | ||
471 | "<g clip-path='url(#clip7000050000490000300000)'>\n", | ||
472 | "<g style='fill:none; stroke:none; stroke-width:1.0'>\n", | ||
473 | "<path d = 'M 92.273 277.018 149.751 271.672 197.649 261.552 240.758 180.215 283.866 230.812 341.344 210.955 384.453 184.798 427.561 272.244 485.039 123.509 537.727 89.141 ' />\n", | ||
474 | "\n", | ||
475 | "</g>\n", | ||
476 | "\n", | ||
477 | "<g style='fill:red; stroke:none; stroke-width:1.0'>\n", | ||
478 | "<circle cx='92.273' cy='277.018' r='3.000' />\n", | ||
479 | "<circle cx='149.751' cy='271.672' r='3.000' />\n", | ||
480 | "<circle cx='197.649' cy='261.552' r='3.000' />\n", | ||
481 | "<circle cx='240.758' cy='180.215' r='3.000' />\n", | ||
482 | "<circle cx='283.866' cy='230.812' r='3.000' />\n", | ||
483 | "<circle cx='341.344' cy='210.955' r='3.000' />\n", | ||
484 | "<circle cx='384.453' cy='184.798' r='3.000' />\n", | ||
485 | "<circle cx='427.561' cy='272.244' r='3.000' />\n", | ||
486 | "<circle cx='485.039' cy='123.509' r='3.000' />\n", | ||
487 | "<circle cx='537.727' cy='89.141' r='3.000' />\n", | ||
488 | "\n", | ||
489 | "\n", | ||
490 | "</g>\n", | ||
491 | "\n", | ||
492 | "\n", | ||
493 | "<g style='fill:none; stroke:blue; stroke-width:1.0'>\n", | ||
494 | "<path d = 'M 92.273 285.630 96.772 284.099 101.272 282.569 105.771 281.038 110.271 279.508 114.770 277.977 119.270 276.447 123.770 274.916 128.269 273.385 132.769 271.855 137.268 270.324 141.768 268.794 146.267 267.263 150.767 265.732 155.266 264.202 159.766 262.671 164.265 261.141 168.765 259.610 173.264 258.079 177.764 256.549 182.264 255.018 186.763 253.488 191.263 251.957 195.762 250.426 200.262 248.896 204.761 247.365 209.261 245.835 213.760 244.304 218.260 242.774 222.759 241.243 227.259 239.712 231.758 238.182 236.258 236.651 240.758 235.121 245.257 233.590 249.757 232.059 254.256 230.529 258.756 228.998 263.255 227.468 267.755 225.937 272.254 224.406 276.754 222.876 281.253 221.345 285.753 219.815 290.253 218.284 294.752 216.753 299.252 215.223 303.751 213.692 308.251 212.162 312.750 210.631 317.250 209.100 321.749 207.570 326.249 206.039 330.748 204.509 335.248 202.978 339.747 201.448 344.247 199.917 348.747 198.386 353.246 196.856 357.746 195.325 362.245 193.795 366.745 192.264 371.244 190.733 375.744 189.203 380.243 187.672 384.743 186.142 389.242 184.611 393.742 183.080 398.242 181.550 402.741 180.019 407.241 178.489 411.740 176.958 416.240 175.427 420.739 173.897 425.239 172.366 429.738 170.836 434.238 169.305 438.737 167.775 443.237 166.244 447.736 164.713 452.236 163.183 456.736 161.652 461.235 160.122 465.735 158.591 470.234 157.060 474.734 155.530 479.233 153.999 483.733 152.469 488.232 150.938 492.732 149.407 497.231 147.877 501.731 146.346 506.230 144.816 510.730 143.285 515.230 141.754 519.729 140.224 524.229 138.693 528.728 137.163 533.228 135.632 537.727 134.102 ' />\n", | ||
495 | "\n", | ||
496 | "</g>\n", | ||
497 | "\n", | ||
498 | "<g style='fill:none; stroke:none; stroke-width:1.0'>\n", | ||
499 | "\n", | ||
500 | "\n", | ||
501 | "</g>\n", | ||
502 | "\n", | ||
503 | "\n", | ||
504 | "<g style='fill:none; stroke:green; stroke-width:1.0'>\n", | ||
505 | "<path d = 'M 92.273 264.421 96.772 264.272 101.272 264.094 105.771 263.889 110.271 263.655 114.770 263.393 119.270 263.103 123.770 262.786 128.269 262.440 132.769 262.066 137.268 261.663 141.768 261.233 146.267 260.775 150.767 260.288 155.266 259.774 159.766 259.231 164.265 258.661 168.765 258.062 173.264 257.435 177.764 256.780 182.264 256.097 186.763 255.386 191.263 254.647 195.762 253.880 200.262 253.084 204.761 252.261 209.261 251.409 213.760 250.530 218.260 249.622 222.759 248.686 227.259 247.722 231.758 246.730 236.258 245.710 240.758 244.662 245.257 243.586 249.757 242.482 254.256 241.349 258.756 240.189 263.255 239.000 267.755 237.784 272.254 236.539 276.754 235.266 281.253 233.966 285.753 232.637 290.253 231.280 294.752 229.894 299.252 228.481 303.751 227.040 308.251 225.571 312.750 224.073 317.250 222.548 321.749 220.994 326.249 219.412 330.748 217.802 335.248 216.165 339.747 214.499 344.247 212.805 348.747 211.082 353.246 209.332 357.746 207.554 362.245 205.747 366.745 203.913 371.244 202.050 375.744 200.160 380.243 198.241 384.743 196.294 389.242 194.319 393.742 192.316 398.242 190.285 402.741 188.226 407.241 186.139 411.740 184.024 416.240 181.880 420.739 179.709 425.239 177.509 429.738 175.281 434.238 173.026 438.737 170.742 443.237 168.430 447.736 166.090 452.236 163.722 456.736 161.326 461.235 158.902 465.735 156.449 470.234 153.969 474.734 151.460 479.233 148.924 483.733 146.359 488.232 143.766 492.732 141.145 497.231 138.496 501.731 135.819 506.230 133.114 510.730 130.381 515.230 127.620 519.729 124.831 524.229 122.013 528.728 119.168 533.228 116.294 537.727 113.392 ' />\n", | ||
506 | "\n", | ||
507 | "</g>\n", | ||
508 | "\n", | ||
509 | "<g style='fill:none; stroke:none; stroke-width:1.0'>\n", | ||
510 | "\n", | ||
511 | "\n", | ||
512 | "</g>\n", | ||
513 | "\n", | ||
514 | "\n", | ||
515 | "<g style='fill:none; stroke:brown; stroke-width:1.0'>\n", | ||
516 | "<path d = 'M 92.273 292.464 96.772 287.916 101.272 283.574 105.771 279.435 110.271 275.493 114.770 271.744 119.270 268.182 123.770 264.803 128.269 261.602 132.769 258.574 137.268 255.714 141.768 253.018 146.267 250.480 150.767 248.096 155.266 245.861 159.766 243.770 164.265 241.818 168.765 240.000 173.264 238.312 177.764 236.749 182.264 235.305 186.763 233.976 191.263 232.757 195.762 231.644 200.262 230.630 204.761 229.713 209.261 228.886 213.760 228.145 218.260 227.485 222.759 226.901 227.259 226.389 231.758 225.943 236.258 225.558 240.758 225.230 245.257 224.955 249.757 224.726 254.256 224.540 258.756 224.391 263.255 224.274 267.755 224.185 272.254 224.120 276.754 224.072 281.253 224.037 285.753 224.010 290.253 223.987 294.752 223.963 299.252 223.933 303.751 223.891 308.251 223.833 312.750 223.755 317.250 223.651 321.749 223.517 326.249 223.347 330.748 223.137 335.248 222.882 339.747 222.578 344.247 222.218 348.747 221.799 353.246 221.316 357.746 220.763 362.245 220.137 366.745 219.431 371.244 218.642 375.744 217.764 380.243 216.792 384.743 215.723 389.242 214.550 393.742 213.269 398.242 211.875 402.741 210.364 407.241 208.730 411.740 206.969 416.240 205.075 420.739 203.044 425.239 200.872 429.738 198.552 434.238 196.081 438.737 193.454 443.237 190.665 447.736 187.710 452.236 184.584 456.736 181.283 461.235 177.800 465.735 174.132 470.234 170.274 474.734 166.220 479.233 161.966 483.733 157.507 488.232 152.839 492.732 147.956 497.231 142.853 501.731 137.526 506.230 131.970 510.730 126.180 515.230 120.151 519.729 113.879 524.229 107.358 528.728 100.583 533.228 93.550 537.727 86.254 ' />\n", | ||
517 | "\n", | ||
518 | "</g>\n", | ||
519 | "\n", | ||
520 | "<g style='fill:none; stroke:none; stroke-width:1.0'>\n", | ||
521 | "\n", | ||
522 | "\n", | ||
523 | "</g>\n", | ||
524 | "\n", | ||
525 | "\n", | ||
526 | "<g style='fill:none; stroke:gray; stroke-width:1.0'>\n", | ||
527 | "<path d = 'M 92.273 277.015 96.772 108.390 101.272 6.475 105.771 -44.591 110.271 -58.019 114.770 -44.689 119.270 -13.449 123.770 28.619 128.269 75.953 132.769 124.297 137.268 170.498 141.768 212.329 146.267 248.336 150.767 277.694 155.266 300.082 159.766 315.578 164.265 324.561 168.765 327.634 173.264 325.547 177.764 319.145 182.264 309.311 186.763 296.934 191.263 282.868 195.762 267.912 200.262 252.787 204.761 238.125 209.261 224.460 213.760 212.222 218.260 201.737 222.759 193.228 227.259 186.824 231.758 182.561 236.258 180.396 240.758 180.214 245.257 181.837 249.757 185.041 254.256 189.561 258.756 195.107 263.255 201.373 267.755 208.048 272.254 214.828 276.754 221.419 281.253 227.554 285.753 232.993 290.253 237.531 294.752 241.006 299.252 243.296 303.751 244.329 308.251 244.075 312.750 242.556 317.250 239.835 321.749 236.021 326.249 231.260 330.748 225.734 335.248 219.654 339.747 213.254 344.247 206.785 348.747 200.506 353.246 194.677 357.746 189.553 362.245 185.374 366.745 182.357 371.244 180.689 375.744 180.519 380.243 181.953 384.743 185.046 389.242 189.798 393.742 196.149 398.242 203.976 402.741 213.095 407.241 223.253 411.740 234.137 416.240 245.374 420.739 256.537 425.239 267.148 429.738 276.695 434.238 284.634 438.737 290.410 443.237 293.466 447.736 293.268 452.236 289.317 456.736 281.178 461.235 268.503 465.735 251.055 470.234 228.741 474.734 201.640 479.233 170.040 483.733 134.469 488.232 95.733 492.732 54.957 497.231 13.618 501.731 -26.407 506.230 -62.807 510.730 -92.788 515.230 -113.039 519.729 -119.686 524.229 -108.253 528.728 -73.622 533.228 -9.996 537.727 89.136 ' />\n", | ||
528 | "\n", | ||
529 | "</g>\n", | ||
530 | "\n", | ||
531 | "<g style='fill:none; stroke:none; stroke-width:1.0'>\n", | ||
532 | "\n", | ||
533 | "\n", | ||
534 | "</g>\n", | ||
535 | "\n", | ||
536 | "\n", | ||
537 | "\n", | ||
538 | "\n", | ||
539 | "</g>\n", | ||
540 | "\n", | ||
541 | "<g style='fill:none; stroke:black; stroke-width:1.5'>\n", | ||
542 | "<rect x='70.000' y='50.000' width='490.000' height='300.000' />\n", | ||
543 | "\n", | ||
544 | "</g>\n", | ||
545 | "\n", | ||
546 | "<g style='text-anchor:middle'>\n", | ||
547 | "<text x='144.961' y='366.000' style='font-size:12.0px'> 2 </text>\n", | ||
548 | "<text x='240.758' y='366.000' style='font-size:12.0px'> 4 </text>\n", | ||
549 | "<text x='336.554' y='366.000' style='font-size:12.0px'> 6 </text>\n", | ||
550 | "<text x='432.351' y='366.000' style='font-size:12.0px'> 8 </text>\n", | ||
551 | "<text x='528.148' y='366.000' style='font-size:12.0px'> 10 </text>\n", | ||
552 | "\n", | ||
553 | "<text x='315.000' y='382.000' style='font-size:12.0px'> </text>\n", | ||
554 | "\n", | ||
555 | "</g>\n", | ||
556 | "\n", | ||
557 | "<g style='text-anchor:end'>\n", | ||
558 | "<text x='62.000' y='320.304' style='font-size:12.0px'> -20 </text>\n", | ||
559 | "<text x='62.000' y='282.118' style='font-size:12.0px'> 0 </text>\n", | ||
560 | "<text x='62.000' y='243.932' style='font-size:12.0px'> 20 </text>\n", | ||
561 | "<text x='62.000' y='205.745' style='font-size:12.0px'> 40 </text>\n", | ||
562 | "<text x='62.000' y='167.559' style='font-size:12.0px'> 60 </text>\n", | ||
563 | "<text x='62.000' y='129.373' style='font-size:12.0px'> 80 </text>\n", | ||
564 | "<text x='62.000' y='91.186' style='font-size:12.0px'> 100 </text>\n", | ||
565 | "<text x='62.000' y='53.000' style='font-size:12.0px'> 120 </text>\n", | ||
566 | "\n", | ||
567 | "</g>\n", | ||
568 | "\n", | ||
569 | "<g style='text-anchor:middle'>\n", | ||
570 | "<g transform='matrix(1,0,0,1,30.0,199.99999999999997)'>\n", | ||
571 | "<g transform='rotate(-90.0)'>\n", | ||
572 | "<g transform='matrix(1,0,0,1,-30.0,-199.99999999999997)'>\n", | ||
573 | "<text x='30.000' y='200.000' style='font-size:12.0px'> </text>\n", | ||
574 | "\n", | ||
575 | "</g>\n", | ||
576 | "\n", | ||
577 | "\n", | ||
578 | "</g>\n", | ||
579 | "\n", | ||
580 | "\n", | ||
581 | "</g>\n", | ||
582 | "\n", | ||
583 | "\n", | ||
584 | "</g>\n", | ||
585 | "\n", | ||
586 | "\n", | ||
587 | "<g style='stroke:black; stroke-width:1'>\n", | ||
588 | "<path d = 'M 144.961 350.000 144.961 344.000 ' />\n", | ||
589 | "<path d = 'M 240.758 350.000 240.758 344.000 ' />\n", | ||
590 | "<path d = 'M 336.554 350.000 336.554 344.000 ' />\n", | ||
591 | "<path d = 'M 432.351 350.000 432.351 344.000 ' />\n", | ||
592 | "<path d = 'M 528.148 350.000 528.148 344.000 ' />\n", | ||
593 | "\n", | ||
594 | "<path d = 'M 70.000 317.304 79.800 317.304 ' />\n", | ||
595 | "<path d = 'M 70.000 279.118 79.800 279.118 ' />\n", | ||
596 | "<path d = 'M 70.000 240.932 79.800 240.932 ' />\n", | ||
597 | "<path d = 'M 70.000 202.745 79.800 202.745 ' />\n", | ||
598 | "<path d = 'M 70.000 164.559 79.800 164.559 ' />\n", | ||
599 | "<path d = 'M 70.000 126.373 79.800 126.373 ' />\n", | ||
600 | "<path d = 'M 70.000 88.186 79.800 88.186 ' />\n", | ||
601 | "<path d = 'M 70.000 50.000 79.800 50.000 ' />\n", | ||
602 | "\n", | ||
603 | "\n", | ||
604 | "</g>\n", | ||
605 | "\n", | ||
606 | "\n", | ||
607 | "\n", | ||
608 | "<g style='fill:#fcfcff;stroke:gray'>\n", | ||
609 | "<rect x='89.500' y='44.600' width='113.600' height='122.400' />\n", | ||
610 | "\n", | ||
611 | "</g>\n", | ||
612 | "\n", | ||
613 | "<defs> <clipPath id='clip995004460030800122400'>\n", | ||
614 | "<rect x='99.500' y='44.600' width='30.800' height='122.400' />\n", | ||
615 | "</clipPath> </defs>\n", | ||
616 | "<g clip-path='url(#clip995004460030800122400)'>\n", | ||
617 | "<g style='fill:none; stroke:none; stroke-width:1.0'>\n", | ||
618 | "<path d = 'M 94.500 65.000 114.900 65.000 135.300 65.000 ' />\n", | ||
619 | "\n", | ||
620 | "</g>\n", | ||
621 | "\n", | ||
622 | "<g style='fill:red; stroke:none; stroke-width:1.0'>\n", | ||
623 | "<circle cx='94.500' cy='65.000' r='3.000' />\n", | ||
624 | "<circle cx='114.900' cy='65.000' r='3.000' />\n", | ||
625 | "<circle cx='135.300' cy='65.000' r='3.000' />\n", | ||
626 | "\n", | ||
627 | "\n", | ||
628 | "</g>\n", | ||
629 | "\n", | ||
630 | "\n", | ||
631 | "<g style='fill:none; stroke:blue; stroke-width:1.0'>\n", | ||
632 | "<path d = 'M 94.500 85.400 114.900 85.400 135.300 85.400 ' />\n", | ||
633 | "\n", | ||
634 | "</g>\n", | ||
635 | "\n", | ||
636 | "<g style='fill:none; stroke:none; stroke-width:1.0'>\n", | ||
637 | "\n", | ||
638 | "\n", | ||
639 | "</g>\n", | ||
640 | "\n", | ||
641 | "\n", | ||
642 | "<g style='fill:none; stroke:green; stroke-width:1.0'>\n", | ||
643 | "<path d = 'M 94.500 105.800 114.900 105.800 135.300 105.800 ' />\n", | ||
644 | "\n", | ||
645 | "</g>\n", | ||
646 | "\n", | ||
647 | "<g style='fill:none; stroke:none; stroke-width:1.0'>\n", | ||
648 | "\n", | ||
649 | "\n", | ||
650 | "</g>\n", | ||
651 | "\n", | ||
652 | "\n", | ||
653 | "<g style='fill:none; stroke:brown; stroke-width:1.0'>\n", | ||
654 | "<path d = 'M 94.500 126.200 114.900 126.200 135.300 126.200 ' />\n", | ||
655 | "\n", | ||
656 | "</g>\n", | ||
657 | "\n", | ||
658 | "<g style='fill:none; stroke:none; stroke-width:1.0'>\n", | ||
659 | "\n", | ||
660 | "\n", | ||
661 | "</g>\n", | ||
662 | "\n", | ||
663 | "\n", | ||
664 | "<g style='fill:none; stroke:gray; stroke-width:1.0'>\n", | ||
665 | "<path d = 'M 94.500 146.600 114.900 146.600 135.300 146.600 ' />\n", | ||
666 | "\n", | ||
667 | "</g>\n", | ||
668 | "\n", | ||
669 | "<g style='fill:none; stroke:none; stroke-width:1.0'>\n", | ||
670 | "\n", | ||
671 | "\n", | ||
672 | "</g>\n", | ||
673 | "\n", | ||
674 | "\n", | ||
675 | "\n", | ||
676 | "</g>\n", | ||
677 | "\n", | ||
678 | "<text x='140.300' y='68.000' style='font-size:12.0px'> data </text>\n", | ||
679 | "<text x='140.300' y='88.400' style='font-size:12.0px'> degree 1 </text>\n", | ||
680 | "<text x='140.300' y='108.800' style='font-size:12.0px'> degree 2 </text>\n", | ||
681 | "<text x='140.300' y='129.200' style='font-size:12.0px'> degree 3 </text>\n", | ||
682 | "<text x='140.300' y='149.600' style='font-size:12.0px'> degree 9 </text>\n", | ||
683 | "\n", | ||
684 | "\n", | ||
685 | "\n", | ||
686 | "\n", | ||
687 | "</svg>\n", | ||
688 | "</div>" | ||
689 | ] | ||
690 | }, | ||
691 | "metadata": {}, | ||
692 | "output_type": "display_data" | ||
693 | } | ||
694 | ], | ||
695 | "source": [ | ||
696 | "lplot\n", | ||
697 | " [ plotMark x y \"none\" 1 circles \"red\" 3 \"data\"\n", | ||
698 | " , plot t (pol 1 t) \"blue\" 1 \"degree 1\"\n", | ||
699 | " , plot t (pol 2 t) \"green\" 1 \"degree 2\"\n", | ||
700 | " , plot t (pol 3 t) \"brown\" 1 \"degree 3\"\n", | ||
701 | " , plot t (pol 9 t) \"gray\" 1 \"degree 9\"\n", | ||
702 | " , MarginX 0.05, Title \"polynomial models\", LegendPos 0.05 0.95, MaxY 120\n", | ||
703 | " ] " | ||
704 | ] | ||
705 | } | ||
706 | ], | ||
707 | "metadata": { | ||
708 | "kernelspec": { | ||
709 | "display_name": "Haskell", | ||
710 | "language": "haskell", | ||
711 | "name": "haskell" | ||
712 | }, | ||
713 | "language_info": { | ||
714 | "codemirror_mode": "ihaskell", | ||
715 | "file_extension": ".hs", | ||
716 | "name": "haskell", | ||
717 | "version": "7.10.1" | ||
718 | } | ||
719 | }, | ||
720 | "nbformat": 4, | ||
721 | "nbformat_minor": 0 | ||
722 | } | ||
diff --git a/examples/plot.hs b/examples/plot.hs index f950aa5..90643ed 100644 --- a/examples/plot.hs +++ b/examples/plot.hs | |||
@@ -16,5 +16,5 @@ cumdist x = 0.5 * (1+ erf (x/sqrt 2)) | |||
16 | main = do | 16 | main = do |
17 | let x = linspace 1000 (-4,4) | 17 | let x = linspace 1000 (-4,4) |
18 | mplot [f x] | 18 | mplot [f x] |
19 | mplot [x, mapVector cumdist x, mapVector gaussianPDF x] | 19 | mplot [x, cmap cumdist x, cmap gaussianPDF x] |
20 | mesh (sombrero 40) \ No newline at end of file | 20 | mesh (sombrero 40) |
diff --git a/examples/repmat.ipynb b/examples/repmat.ipynb new file mode 100644 index 0000000..afa9706 --- /dev/null +++ b/examples/repmat.ipynb | |||
@@ -0,0 +1,138 @@ | |||
1 | { | ||
2 | "cells": [ | ||
3 | { | ||
4 | "cell_type": "markdown", | ||
5 | "metadata": {}, | ||
6 | "source": [ | ||
7 | "# repmat" | ||
8 | ] | ||
9 | }, | ||
10 | { | ||
11 | "cell_type": "markdown", | ||
12 | "metadata": {}, | ||
13 | "source": [ | ||
14 | "An alternative implementation of `repmat` using the new in-place tools." | ||
15 | ] | ||
16 | }, | ||
17 | { | ||
18 | "cell_type": "code", | ||
19 | "execution_count": 1, | ||
20 | "metadata": { | ||
21 | "collapsed": false | ||
22 | }, | ||
23 | "outputs": [], | ||
24 | "source": [ | ||
25 | ":ext FlexibleContexts\n", | ||
26 | "\n", | ||
27 | "import Numeric.LinearAlgebra\n", | ||
28 | "import Numeric.LinearAlgebra.Devel" | ||
29 | ] | ||
30 | }, | ||
31 | { | ||
32 | "cell_type": "code", | ||
33 | "execution_count": 2, | ||
34 | "metadata": { | ||
35 | "collapsed": true | ||
36 | }, | ||
37 | "outputs": [], | ||
38 | "source": [ | ||
39 | "m = (3><4)[1..] :: Matrix Z" | ||
40 | ] | ||
41 | }, | ||
42 | { | ||
43 | "cell_type": "code", | ||
44 | "execution_count": 3, | ||
45 | "metadata": { | ||
46 | "collapsed": false | ||
47 | }, | ||
48 | "outputs": [ | ||
49 | { | ||
50 | "data": { | ||
51 | "text/plain": [ | ||
52 | "(3><4)\n", | ||
53 | " [ 1, 2, 3, 4\n", | ||
54 | " , 5, 6, 7, 8\n", | ||
55 | " , 9, 10, 11, 12 ]" | ||
56 | ] | ||
57 | }, | ||
58 | "metadata": {}, | ||
59 | "output_type": "display_data" | ||
60 | } | ||
61 | ], | ||
62 | "source": [ | ||
63 | "m" | ||
64 | ] | ||
65 | }, | ||
66 | { | ||
67 | "cell_type": "code", | ||
68 | "execution_count": 4, | ||
69 | "metadata": { | ||
70 | "collapsed": true | ||
71 | }, | ||
72 | "outputs": [], | ||
73 | "source": [ | ||
74 | "import Control.Monad.ST" | ||
75 | ] | ||
76 | }, | ||
77 | { | ||
78 | "cell_type": "code", | ||
79 | "execution_count": 5, | ||
80 | "metadata": { | ||
81 | "collapsed": false | ||
82 | }, | ||
83 | "outputs": [], | ||
84 | "source": [ | ||
85 | "rpmt m i j = runST $ do\n", | ||
86 | " x <- newUndefinedMatrix RowMajor dr dc\n", | ||
87 | " sequence_ [ setMatrix x a b m | a <- [0,r..dr], b <-[0,c..dc] ]\n", | ||
88 | " unsafeFreezeMatrix x\n", | ||
89 | " where\n", | ||
90 | " (r,c) = size m\n", | ||
91 | " dr = i*r\n", | ||
92 | " dc = j*c" | ||
93 | ] | ||
94 | }, | ||
95 | { | ||
96 | "cell_type": "code", | ||
97 | "execution_count": 6, | ||
98 | "metadata": { | ||
99 | "collapsed": false | ||
100 | }, | ||
101 | "outputs": [ | ||
102 | { | ||
103 | "data": { | ||
104 | "text/plain": [ | ||
105 | "(6><12)\n", | ||
106 | " [ 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4\n", | ||
107 | " , 5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8\n", | ||
108 | " , 9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12\n", | ||
109 | " , 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4\n", | ||
110 | " , 5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8\n", | ||
111 | " , 9, 10, 11, 12, 9, 10, 11, 12, 9, 10, 11, 12 ]" | ||
112 | ] | ||
113 | }, | ||
114 | "metadata": {}, | ||
115 | "output_type": "display_data" | ||
116 | } | ||
117 | ], | ||
118 | "source": [ | ||
119 | "rpmt m 2 3" | ||
120 | ] | ||
121 | } | ||
122 | ], | ||
123 | "metadata": { | ||
124 | "kernelspec": { | ||
125 | "display_name": "Haskell", | ||
126 | "language": "haskell", | ||
127 | "name": "haskell" | ||
128 | }, | ||
129 | "language_info": { | ||
130 | "codemirror_mode": "ihaskell", | ||
131 | "file_extension": ".hs", | ||
132 | "name": "haskell", | ||
133 | "version": "7.10.1" | ||
134 | } | ||
135 | }, | ||
136 | "nbformat": 4, | ||
137 | "nbformat_minor": 0 | ||
138 | } | ||
diff --git a/examples/root.hs b/examples/root.hs index 8546ff5..fa6e77a 100644 --- a/examples/root.hs +++ b/examples/root.hs | |||
@@ -9,7 +9,7 @@ test method = do | |||
9 | print method | 9 | print method |
10 | let (s,p) = root method 1E-7 30 (rosenbrock 1 10) [-10,-5] | 10 | let (s,p) = root method 1E-7 30 (rosenbrock 1 10) [-10,-5] |
11 | print s -- solution | 11 | print s -- solution |
12 | disp p -- evolution of the algorithm | 12 | disp' p -- evolution of the algorithm |
13 | 13 | ||
14 | jacobian a b [x,y] = [ [-a , 0] | 14 | jacobian a b [x,y] = [ [-a , 0] |
15 | , [-2*b*x, b] ] | 15 | , [-2*b*x, b] ] |
@@ -18,9 +18,9 @@ testJ method = do | |||
18 | print method | 18 | print method |
19 | let (s,p) = rootJ method 1E-7 30 (rosenbrock 1 10) (jacobian 1 10) [-10,-5] | 19 | let (s,p) = rootJ method 1E-7 30 (rosenbrock 1 10) (jacobian 1 10) [-10,-5] |
20 | print s | 20 | print s |
21 | disp p | 21 | disp' p |
22 | 22 | ||
23 | disp = putStrLn . format " " (printf "%.3f") | 23 | disp' = putStrLn . format " " (printf "%.3f") |
24 | 24 | ||
25 | main = do | 25 | main = do |
26 | test Hybrids | 26 | test Hybrids |
diff --git a/examples/vector.hs b/examples/vector.hs deleted file mode 100644 index f531cbd..0000000 --- a/examples/vector.hs +++ /dev/null | |||
@@ -1,31 +0,0 @@ | |||
1 | -- conversion to/from Data.Vector.Storable | ||
2 | -- from Roman Leshchinskiy "vector" package | ||
3 | -- | ||
4 | -- In the future Data.Packed.Vector will be replaced by Data.Vector.Storable | ||
5 | |||
6 | ------------------------------------------- | ||
7 | |||
8 | import Numeric.LinearAlgebra as H | ||
9 | import Data.Packed.Development(unsafeFromForeignPtr, unsafeToForeignPtr) | ||
10 | import Foreign.Storable | ||
11 | import qualified Data.Vector.Storable as V | ||
12 | |||
13 | fromVector :: Storable t => V.Vector t -> H.Vector t | ||
14 | fromVector v = unsafeFromForeignPtr p i n where | ||
15 | (p,i,n) = V.unsafeToForeignPtr v | ||
16 | |||
17 | toVector :: Storable t => H.Vector t -> V.Vector t | ||
18 | toVector v = V.unsafeFromForeignPtr p i n where | ||
19 | (p,i,n) = unsafeToForeignPtr v | ||
20 | |||
21 | ------------------------------------------- | ||
22 | |||
23 | v = V.slice 5 10 (V.fromList [1 .. 10::Double] V.++ V.replicate 10 7) | ||
24 | |||
25 | w = subVector 2 3 (linspace 5 (0,1)) :: Vector Double | ||
26 | |||
27 | main = do | ||
28 | print v | ||
29 | print $ fromVector v | ||
30 | print w | ||
31 | print $ toVector w | ||
diff --git a/packages/Makefile b/packages/Makefile index e9d8586..b00d71f 100644 --- a/packages/Makefile +++ b/packages/Makefile | |||
@@ -1,22 +1,26 @@ | |||
1 | pkgs=base gsl special glpk tests ../../hTensor ../../easyVision/packages/base | 1 | pkgs=base gsl special glpk tests ../../hTensor ../../easyVision/packages/tools ../../easyVision/packages/base |
2 | |||
3 | mkl=--extra-include-dirs=$(MKL) --extra-lib-dirs=$(MKL) | ||
4 | |||
5 | cabalcmd = \ | ||
6 | for p in $(1); do \ | ||
7 | if [ -e $$p ]; then \ | ||
8 | cd $$p; cabal $(2) ; cd -; \ | ||
9 | fi; \ | ||
10 | done; \ | ||
11 | cd sparse; \ | ||
12 | cabal $(3) $(2); cd -; | ||
13 | |||
2 | 14 | ||
3 | all: | 15 | all: |
4 | for p in $(pkgs); do \ | 16 | $(call cabalcmd, $(pkgs), install --force-reinstall --enable-documentation, $(mkl)) |
5 | if [ -e $$p ]; then \ | ||
6 | cd $$p; cabal install --force-reinstall --enable-documentation ; cd -; \ | ||
7 | fi; \ | ||
8 | done | ||
9 | cd sparse; \ | ||
10 | cabal install --extra-include-dirs=$(MKL) --extra-lib-dirs=$(MKL) \ | ||
11 | --force-reinstall --enable-documentation ; cd -; | ||
12 | 17 | ||
13 | fast: | 18 | fast: |
14 | for p in $(pkgs); do \ | 19 | $(call cabalcmd, $(pkgs), install --force-reinstall, $(mkl)) |
15 | if [ -e $$p ]; then \ | 20 | |
16 | cd $$p; cabal install --force-reinstall ; cd -; \ | 21 | clean: |
17 | fi; \ | 22 | $(call cabalcmd, $(pkgs), clean) |
18 | done | 23 | |
19 | cd sparse; \ | 24 | prof: |
20 | cabal install --extra-include-dirs=$(MKL) --extra-lib-dirs=$(MKL) \ | 25 | $(call cabalcmd, $(pkgs), install --force-reinstall --enable-library-profiling, $(mkl)) |
21 | --force-reinstall; cd -; | ||
22 | 26 | ||
diff --git a/packages/base/CHANGELOG b/packages/base/CHANGELOG index c137285..581d2ac 100644 --- a/packages/base/CHANGELOG +++ b/packages/base/CHANGELOG | |||
@@ -1,3 +1,33 @@ | |||
1 | 0.17.0.0 | ||
2 | -------- | ||
3 | |||
4 | * eigSH, chol, and other functions that work with Hermitian or symmetric matrices | ||
5 | now take a special "Herm" argument that can be created by means of "sym" | ||
6 | or "mTm". The unchecked versions of those functions have been removed and we | ||
7 | use "trustSym" to create the Herm type when the matrix is known to be Hermitian/symmetric. | ||
8 | |||
9 | * Improved matrix extraction (??) and rectangular matrix slices without data copy | ||
10 | |||
11 | * Basic support of Int32 and Int64 elements | ||
12 | |||
13 | * remap, more general cond, sortIndex | ||
14 | |||
15 | * Experimental support of type safe modular arithmetic, including linear | ||
16 | system solver and LU factorization | ||
17 | |||
18 | * Elementary row operations and inplace matrix slice products in the ST monad | ||
19 | |||
20 | * Improved development tools. | ||
21 | |||
22 | * Old compatibility modules removed, simpler organization of internal modules | ||
23 | |||
24 | * unitary, pairwiseD2, tr' | ||
25 | |||
26 | * ldlPacked, ldlSolve for indefinite symmetric systems (apparently not faster | ||
27 | than the general solver based on the LU) | ||
28 | |||
29 | * LU, LDL, and QR types for these compact decompositions. | ||
30 | |||
1 | 0.16.1.0 | 31 | 0.16.1.0 |
2 | -------- | 32 | -------- |
3 | 33 | ||
diff --git a/packages/base/THANKS.md b/packages/base/THANKS.md index a4188eb..f29775a 100644 --- a/packages/base/THANKS.md +++ b/packages/base/THANKS.md | |||
@@ -190,3 +190,17 @@ module reorganization, monadic mapVectorM, and many other improvements. | |||
190 | 190 | ||
191 | - Thomas M. DuBuisson fixed a C include file. | 191 | - Thomas M. DuBuisson fixed a C include file. |
192 | 192 | ||
193 | - Matt Peddie wrote the interfaces to the interpolation and simulated annealing modules. | ||
194 | |||
195 | - "maxc01" solved uninstallability in FreeBSD and improved urandom | ||
196 | |||
197 | - "ntfrgl" added {take,drop}Last{Rows,Columns} and odeSolveVWith with generalized step control function | ||
198 | and fixed link errors related to mod/mod_l. | ||
199 | |||
200 | - "cruegge" discovered a bug in the conjugate gradient solver for sparse symmetric systems. | ||
201 | |||
202 | - Ilan Godik and Douglas McClean helped with Windows support. | ||
203 | |||
204 | - Vassil Keremidchiev fixed the cabal options for OpenBlas, fixed several installation | ||
205 | issues, and added support for stack-based build. | ||
206 | |||
diff --git a/packages/base/hmatrix.cabal b/packages/base/hmatrix.cabal index 3895dc1..7bed0d3 100644 --- a/packages/base/hmatrix.cabal +++ b/packages/base/hmatrix.cabal | |||
@@ -1,5 +1,5 @@ | |||
1 | Name: hmatrix | 1 | Name: hmatrix |
2 | Version: 0.16.1.5 | 2 | Version: 0.17.0.1 |
3 | License: BSD3 | 3 | License: BSD3 |
4 | License-file: LICENSE | 4 | License-file: LICENSE |
5 | Author: Alberto Ruiz | 5 | Author: Alberto Ruiz |
@@ -7,17 +7,11 @@ Maintainer: Alberto Ruiz | |||
7 | Stability: provisional | 7 | Stability: provisional |
8 | Homepage: https://github.com/albertoruiz/hmatrix | 8 | Homepage: https://github.com/albertoruiz/hmatrix |
9 | Synopsis: Numeric Linear Algebra | 9 | Synopsis: Numeric Linear Algebra |
10 | Description: Linear algebra based on BLAS and LAPACK. | 10 | Description: Linear systems, matrix decompositions, and other numerical computations based on BLAS and LAPACK. |
11 | . | 11 | . |
12 | The package is organized as follows: | 12 | Standard interface: "Numeric.LinearAlgebra". |
13 | . | 13 | . |
14 | ["Numeric.LinearAlgebra.HMatrix"] Starting point and recommended import module for most applications. | 14 | Safer interface with statically checked dimensions: "Numeric.LinearAlgebra.Static". |
15 | . | ||
16 | ["Numeric.LinearAlgebra.Static"] Experimental alternative interface. | ||
17 | . | ||
18 | ["Numeric.LinearAlgebra.Devel"] Tools for extending the library. | ||
19 | . | ||
20 | (Other modules are exposed with hidden documentation for backwards compatibility.) | ||
21 | . | 15 | . |
22 | Code examples: <http://dis.um.es/~alberto/hmatrix/hmatrix.html> | 16 | Code examples: <http://dis.um.es/~alberto/hmatrix/hmatrix.html> |
23 | 17 | ||
@@ -30,16 +24,16 @@ build-type: Simple | |||
30 | 24 | ||
31 | extra-source-files: THANKS.md CHANGELOG | 25 | extra-source-files: THANKS.md CHANGELOG |
32 | 26 | ||
33 | extra-source-files: src/C/lapack-aux.h | 27 | extra-source-files: src/Internal/C/lapack-aux.h |
34 | 28 | ||
35 | flag openblas | 29 | flag openblas |
36 | description: Link with OpenBLAS (https://github.com/xianyi/OpenBLAS) optimized libraries. | 30 | description: Link with OpenBLAS (https://github.com/xianyi/OpenBLAS) optimized libraries. |
37 | default: False | 31 | default: False |
38 | manual: True | 32 | manual: True |
39 | 33 | ||
40 | library | 34 | library |
41 | 35 | ||
42 | Build-Depends: base >= 4 && < 5, | 36 | Build-Depends: base >= 4.8 && < 5, |
43 | binary, | 37 | binary, |
44 | array, | 38 | array, |
45 | deepseq, | 39 | deepseq, |
@@ -51,47 +45,38 @@ library | |||
51 | 45 | ||
52 | hs-source-dirs: src | 46 | hs-source-dirs: src |
53 | 47 | ||
54 | exposed-modules: Data.Packed, | 48 | exposed-modules: Numeric.LinearAlgebra |
55 | Data.Packed.Vector, | ||
56 | Data.Packed.Matrix, | ||
57 | Data.Packed.Foreign, | ||
58 | Data.Packed.ST, | ||
59 | Data.Packed.Development, | ||
60 | |||
61 | Numeric.LinearAlgebra | ||
62 | Numeric.LinearAlgebra.LAPACK | ||
63 | Numeric.LinearAlgebra.Algorithms | ||
64 | Numeric.Container | ||
65 | Numeric.LinearAlgebra.Util | ||
66 | |||
67 | Numeric.LinearAlgebra.Devel | 49 | Numeric.LinearAlgebra.Devel |
68 | Numeric.LinearAlgebra.Data | 50 | Numeric.LinearAlgebra.Data |
69 | Numeric.LinearAlgebra.HMatrix | 51 | Numeric.LinearAlgebra.HMatrix |
70 | Numeric.LinearAlgebra.Static | 52 | Numeric.LinearAlgebra.Static |
71 | |||
72 | |||
73 | 53 | ||
74 | other-modules: Data.Packed.Internal, | 54 | other-modules: Internal.Vector |
75 | Data.Packed.Internal.Common | 55 | Internal.Devel |
76 | Data.Packed.Internal.Signatures | 56 | Internal.Vectorized |
77 | Data.Packed.Internal.Vector | 57 | Internal.Matrix |
78 | Data.Packed.Internal.Matrix | 58 | Internal.Foreign |
79 | Data.Packed.IO | 59 | Internal.ST |
80 | Numeric.Chain | 60 | Internal.IO |
81 | Numeric.Vectorized | 61 | Internal.Element |
62 | Internal.Conversion | ||
63 | Internal.LAPACK | ||
64 | Internal.Numeric | ||
65 | Internal.Algorithms | ||
66 | Internal.Random | ||
67 | Internal.Container | ||
68 | Internal.Sparse | ||
69 | Internal.Convolution | ||
70 | Internal.Chain | ||
82 | Numeric.Vector | 71 | Numeric.Vector |
72 | Internal.CG | ||
83 | Numeric.Matrix | 73 | Numeric.Matrix |
84 | Data.Packed.Internal.Numeric | 74 | Internal.Util |
85 | Data.Packed.Numeric | 75 | Internal.Modular |
86 | Numeric.LinearAlgebra.Util.Convolution | 76 | Internal.Static |
87 | Numeric.LinearAlgebra.Util.CG | ||
88 | Numeric.LinearAlgebra.Random | ||
89 | Numeric.Conversion | ||
90 | Numeric.Sparse | ||
91 | Numeric.LinearAlgebra.Static.Internal | ||
92 | 77 | ||
93 | C-sources: src/C/lapack-aux.c | 78 | C-sources: src/Internal/C/lapack-aux.c |
94 | src/C/vector-aux.c | 79 | src/Internal/C/vector-aux.c |
95 | 80 | ||
96 | 81 | ||
97 | extensions: ForeignFunctionInterface, | 82 | extensions: ForeignFunctionInterface, |
@@ -100,18 +85,24 @@ library | |||
100 | ghc-options: -Wall | 85 | ghc-options: -Wall |
101 | -fno-warn-missing-signatures | 86 | -fno-warn-missing-signatures |
102 | -fno-warn-orphans | 87 | -fno-warn-orphans |
88 | -fprof-auto | ||
103 | 89 | ||
104 | cc-options: -O4 -msse2 -Wall | 90 | cc-options: -O4 -Wall |
105 | 91 | ||
106 | cpp-options: -DBINARY | 92 | if arch(x86_64) |
93 | cc-options: -msse2 | ||
94 | if arch(i386) | ||
95 | cc-options: -msse2 | ||
107 | 96 | ||
108 | if flag(openblas) | 97 | cpp-options: -DBINARY |
109 | extra-lib-dirs: /usr/lib/openblas/lib | ||
110 | extra-libraries: openblas | ||
111 | else | ||
112 | extra-libraries: blas lapack | ||
113 | 98 | ||
114 | if os(OSX) | 99 | if os(OSX) |
100 | if flag(openblas) | ||
101 | extra-lib-dirs: /opt/local/lib/openblas/lib | ||
102 | extra-libraries: openblas | ||
103 | else | ||
104 | extra-libraries: blas lapack | ||
105 | |||
115 | extra-lib-dirs: /opt/local/lib/ | 106 | extra-lib-dirs: /opt/local/lib/ |
116 | include-dirs: /opt/local/include/ | 107 | include-dirs: /opt/local/include/ |
117 | extra-lib-dirs: /usr/local/lib/ | 108 | extra-lib-dirs: /usr/local/lib/ |
@@ -121,14 +112,29 @@ library | |||
121 | frameworks: Accelerate | 112 | frameworks: Accelerate |
122 | 113 | ||
123 | if os(freebsd) | 114 | if os(freebsd) |
115 | if flag(openblas) | ||
116 | extra-lib-dirs: /usr/local/lib/openblas/lib | ||
117 | extra-libraries: openblas | ||
118 | else | ||
119 | extra-libraries: blas lapack | ||
120 | |||
124 | extra-lib-dirs: /usr/local/lib | 121 | extra-lib-dirs: /usr/local/lib |
125 | include-dirs: /usr/local/include | 122 | include-dirs: /usr/local/include |
126 | extra-libraries: blas lapack gfortran | 123 | extra-libraries: gfortran |
127 | 124 | ||
128 | if os(windows) | 125 | if os(windows) |
129 | extra-libraries: blas lapack | 126 | if flag(openblas) |
127 | extra-libraries: libopenblas | ||
128 | else | ||
129 | extra-libraries: blas lapack | ||
130 | 130 | ||
131 | if os(linux) | 131 | if os(linux) |
132 | if flag(openblas) | ||
133 | extra-lib-dirs: /usr/lib/openblas/lib | ||
134 | extra-libraries: openblas | ||
135 | else | ||
136 | extra-libraries: blas lapack | ||
137 | |||
132 | if arch(x86_64) | 138 | if arch(x86_64) |
133 | cc-options: -fPIC | 139 | cc-options: -fPIC |
134 | 140 | ||
diff --git a/packages/base/src/Data/Packed.hs b/packages/base/src/Data/Packed.hs deleted file mode 100644 index 129bd22..0000000 --- a/packages/base/src/Data/Packed.hs +++ /dev/null | |||
@@ -1,26 +0,0 @@ | |||
1 | ----------------------------------------------------------------------------- | ||
2 | {- | | ||
3 | Module : Data.Packed | ||
4 | Copyright : (c) Alberto Ruiz 2006-2014 | ||
5 | License : BSD3 | ||
6 | Maintainer : Alberto Ruiz | ||
7 | Stability : provisional | ||
8 | |||
9 | Types for dense 'Vector' and 'Matrix' of 'Storable' elements. | ||
10 | |||
11 | -} | ||
12 | ----------------------------------------------------------------------------- | ||
13 | {-# OPTIONS_HADDOCK hide #-} | ||
14 | |||
15 | module Data.Packed ( | ||
16 | -- * Vector | ||
17 | -- | ||
18 | -- | Vectors are @Data.Vector.Storable.Vector@ from the \"vector\" package. | ||
19 | module Data.Packed.Vector, | ||
20 | -- * Matrix | ||
21 | module Data.Packed.Matrix, | ||
22 | ) where | ||
23 | |||
24 | import Data.Packed.Vector | ||
25 | import Data.Packed.Matrix | ||
26 | |||
diff --git a/packages/base/src/Data/Packed/Development.hs b/packages/base/src/Data/Packed/Development.hs deleted file mode 100644 index 72eb16b..0000000 --- a/packages/base/src/Data/Packed/Development.hs +++ /dev/null | |||
@@ -1,32 +0,0 @@ | |||
1 | |||
2 | ----------------------------------------------------------------------------- | ||
3 | -- | | ||
4 | -- Module : Data.Packed.Development | ||
5 | -- Copyright : (c) Alberto Ruiz 2009 | ||
6 | -- License : BSD3 | ||
7 | -- Maintainer : Alberto Ruiz | ||
8 | -- Stability : provisional | ||
9 | -- Portability : portable | ||
10 | -- | ||
11 | -- The library can be easily extended with additional foreign functions | ||
12 | -- using the tools in this module. Illustrative usage examples can be found | ||
13 | -- in the @examples\/devel@ folder included in the package. | ||
14 | -- | ||
15 | ----------------------------------------------------------------------------- | ||
16 | {-# OPTIONS_HADDOCK hide #-} | ||
17 | |||
18 | module Data.Packed.Development ( | ||
19 | createVector, createMatrix, | ||
20 | vec, mat, | ||
21 | app1, app2, app3, app4, | ||
22 | app5, app6, app7, app8, app9, app10, | ||
23 | MatrixOrder(..), orderOf, cmat, fmat, | ||
24 | matrixFromVector, | ||
25 | unsafeFromForeignPtr, | ||
26 | unsafeToForeignPtr, | ||
27 | check, (//), | ||
28 | at', atM', fi | ||
29 | ) where | ||
30 | |||
31 | import Data.Packed.Internal | ||
32 | |||
diff --git a/packages/base/src/Data/Packed/Internal.hs b/packages/base/src/Data/Packed/Internal.hs deleted file mode 100644 index 59a72fc..0000000 --- a/packages/base/src/Data/Packed/Internal.hs +++ /dev/null | |||
@@ -1,24 +0,0 @@ | |||
1 | ----------------------------------------------------------------------------- | ||
2 | -- | | ||
3 | -- Module : Data.Packed.Internal | ||
4 | -- Copyright : (c) Alberto Ruiz 2007 | ||
5 | -- License : BSD3 | ||
6 | -- Maintainer : Alberto Ruiz | ||
7 | -- Stability : provisional | ||
8 | -- | ||
9 | -- Reexports all internal modules | ||
10 | -- | ||
11 | ----------------------------------------------------------------------------- | ||
12 | -- #hide | ||
13 | |||
14 | module Data.Packed.Internal ( | ||
15 | module Data.Packed.Internal.Common, | ||
16 | module Data.Packed.Internal.Signatures, | ||
17 | module Data.Packed.Internal.Vector, | ||
18 | module Data.Packed.Internal.Matrix, | ||
19 | ) where | ||
20 | |||
21 | import Data.Packed.Internal.Common | ||
22 | import Data.Packed.Internal.Signatures | ||
23 | import Data.Packed.Internal.Vector | ||
24 | import Data.Packed.Internal.Matrix | ||
diff --git a/packages/base/src/Data/Packed/Internal/Common.hs b/packages/base/src/Data/Packed/Internal/Common.hs deleted file mode 100644 index 615bbdf..0000000 --- a/packages/base/src/Data/Packed/Internal/Common.hs +++ /dev/null | |||
@@ -1,160 +0,0 @@ | |||
1 | {-# LANGUAGE CPP #-} | ||
2 | -- | | ||
3 | -- Module : Data.Packed.Internal.Common | ||
4 | -- Copyright : (c) Alberto Ruiz 2007 | ||
5 | -- License : BSD3 | ||
6 | -- Maintainer : Alberto Ruiz | ||
7 | -- Stability : provisional | ||
8 | -- | ||
9 | -- | ||
10 | -- Development utilities. | ||
11 | -- | ||
12 | |||
13 | |||
14 | module Data.Packed.Internal.Common( | ||
15 | Adapt, | ||
16 | app1, app2, app3, app4, | ||
17 | app5, app6, app7, app8, app9, app10, | ||
18 | (//), check, mbCatch, | ||
19 | splitEvery, common, compatdim, | ||
20 | fi, | ||
21 | table, | ||
22 | finit | ||
23 | ) where | ||
24 | |||
25 | import Control.Monad(when) | ||
26 | import Foreign.C.Types | ||
27 | import Foreign.Storable.Complex() | ||
28 | import Data.List(transpose,intersperse) | ||
29 | import Control.Exception as E | ||
30 | |||
31 | -- | @splitEvery 3 [1..9] == [[1,2,3],[4,5,6],[7,8,9]]@ | ||
32 | splitEvery :: Int -> [a] -> [[a]] | ||
33 | splitEvery _ [] = [] | ||
34 | splitEvery k l = take k l : splitEvery k (drop k l) | ||
35 | |||
36 | -- | obtains the common value of a property of a list | ||
37 | common :: (Eq a) => (b->a) -> [b] -> Maybe a | ||
38 | common f = commonval . map f where | ||
39 | commonval :: (Eq a) => [a] -> Maybe a | ||
40 | commonval [] = Nothing | ||
41 | commonval [a] = Just a | ||
42 | commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing | ||
43 | |||
44 | -- | common value with \"adaptable\" 1 | ||
45 | compatdim :: [Int] -> Maybe Int | ||
46 | compatdim [] = Nothing | ||
47 | compatdim [a] = Just a | ||
48 | compatdim (a:b:xs) | ||
49 | | a==b = compatdim (b:xs) | ||
50 | | a==1 = compatdim (b:xs) | ||
51 | | b==1 = compatdim (a:xs) | ||
52 | | otherwise = Nothing | ||
53 | |||
54 | -- | Formatting tool | ||
55 | table :: String -> [[String]] -> String | ||
56 | table sep as = unlines . map unwords' $ transpose mtp where | ||
57 | mt = transpose as | ||
58 | longs = map (maximum . map length) mt | ||
59 | mtp = zipWith (\a b -> map (pad a) b) longs mt | ||
60 | pad n str = replicate (n - length str) ' ' ++ str | ||
61 | unwords' = concat . intersperse sep | ||
62 | |||
63 | -- | postfix function application (@flip ($)@) | ||
64 | (//) :: x -> (x -> y) -> y | ||
65 | infixl 0 // | ||
66 | (//) = flip ($) | ||
67 | |||
68 | -- | specialized fromIntegral | ||
69 | fi :: Int -> CInt | ||
70 | fi = fromIntegral | ||
71 | |||
72 | -- hmm.. | ||
73 | ww2 w1 o1 w2 o2 f = w1 o1 $ w2 o2 . f | ||
74 | ww3 w1 o1 w2 o2 w3 o3 f = w1 o1 $ ww2 w2 o2 w3 o3 . f | ||
75 | ww4 w1 o1 w2 o2 w3 o3 w4 o4 f = w1 o1 $ ww3 w2 o2 w3 o3 w4 o4 . f | ||
76 | ww5 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 f = w1 o1 $ ww4 w2 o2 w3 o3 w4 o4 w5 o5 . f | ||
77 | ww6 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 f = w1 o1 $ ww5 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 . f | ||
78 | ww7 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 f = w1 o1 $ ww6 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 . f | ||
79 | ww8 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 f = w1 o1 $ ww7 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 . f | ||
80 | ww9 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 f = w1 o1 $ ww8 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 . f | ||
81 | ww10 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 f = w1 o1 $ ww9 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 . f | ||
82 | |||
83 | type Adapt f t r = t -> ((f -> r) -> IO()) -> IO() | ||
84 | |||
85 | type Adapt1 f t1 = Adapt f t1 (IO CInt) -> t1 -> String -> IO() | ||
86 | type Adapt2 f t1 r1 t2 = Adapt f t1 r1 -> t1 -> Adapt1 r1 t2 | ||
87 | type Adapt3 f t1 r1 t2 r2 t3 = Adapt f t1 r1 -> t1 -> Adapt2 r1 t2 r2 t3 | ||
88 | type Adapt4 f t1 r1 t2 r2 t3 r3 t4 = Adapt f t1 r1 -> t1 -> Adapt3 r1 t2 r2 t3 r3 t4 | ||
89 | type Adapt5 f t1 r1 t2 r2 t3 r3 t4 r4 t5 = Adapt f t1 r1 -> t1 -> Adapt4 r1 t2 r2 t3 r3 t4 r4 t5 | ||
90 | type Adapt6 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 = Adapt f t1 r1 -> t1 -> Adapt5 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 | ||
91 | type Adapt7 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 = Adapt f t1 r1 -> t1 -> Adapt6 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 | ||
92 | type Adapt8 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 = Adapt f t1 r1 -> t1 -> Adapt7 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 | ||
93 | type Adapt9 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 = Adapt f t1 r1 -> t1 -> Adapt8 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 | ||
94 | type Adapt10 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10 = Adapt f t1 r1 -> t1 -> Adapt9 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10 | ||
95 | |||
96 | app1 :: f -> Adapt1 f t1 | ||
97 | app2 :: f -> Adapt2 f t1 r1 t2 | ||
98 | app3 :: f -> Adapt3 f t1 r1 t2 r2 t3 | ||
99 | app4 :: f -> Adapt4 f t1 r1 t2 r2 t3 r3 t4 | ||
100 | app5 :: f -> Adapt5 f t1 r1 t2 r2 t3 r3 t4 r4 t5 | ||
101 | app6 :: f -> Adapt6 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 | ||
102 | app7 :: f -> Adapt7 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 | ||
103 | app8 :: f -> Adapt8 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 | ||
104 | app9 :: f -> Adapt9 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 | ||
105 | app10 :: f -> Adapt10 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10 | ||
106 | |||
107 | app1 f w1 o1 s = w1 o1 $ \a1 -> f // a1 // check s | ||
108 | app2 f w1 o1 w2 o2 s = ww2 w1 o1 w2 o2 $ \a1 a2 -> f // a1 // a2 // check s | ||
109 | app3 f w1 o1 w2 o2 w3 o3 s = ww3 w1 o1 w2 o2 w3 o3 $ | ||
110 | \a1 a2 a3 -> f // a1 // a2 // a3 // check s | ||
111 | app4 f w1 o1 w2 o2 w3 o3 w4 o4 s = ww4 w1 o1 w2 o2 w3 o3 w4 o4 $ | ||
112 | \a1 a2 a3 a4 -> f // a1 // a2 // a3 // a4 // check s | ||
113 | app5 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 s = ww5 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 $ | ||
114 | \a1 a2 a3 a4 a5 -> f // a1 // a2 // a3 // a4 // a5 // check s | ||
115 | app6 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 s = ww6 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 $ | ||
116 | \a1 a2 a3 a4 a5 a6 -> f // a1 // a2 // a3 // a4 // a5 // a6 // check s | ||
117 | app7 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 s = ww7 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 $ | ||
118 | \a1 a2 a3 a4 a5 a6 a7 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // check s | ||
119 | app8 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 s = ww8 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 $ | ||
120 | \a1 a2 a3 a4 a5 a6 a7 a8 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // check s | ||
121 | app9 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 s = ww9 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 $ | ||
122 | \a1 a2 a3 a4 a5 a6 a7 a8 a9 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // a9 // check s | ||
123 | app10 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 s = ww10 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 $ | ||
124 | \a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // a9 // a10 // check s | ||
125 | |||
126 | |||
127 | |||
128 | -- GSL error codes are <= 1024 | ||
129 | -- | error codes for the auxiliary functions required by the wrappers | ||
130 | errorCode :: CInt -> String | ||
131 | errorCode 2000 = "bad size" | ||
132 | errorCode 2001 = "bad function code" | ||
133 | errorCode 2002 = "memory problem" | ||
134 | errorCode 2003 = "bad file" | ||
135 | errorCode 2004 = "singular" | ||
136 | errorCode 2005 = "didn't converge" | ||
137 | errorCode 2006 = "the input matrix is not positive definite" | ||
138 | errorCode 2007 = "not yet supported in this OS" | ||
139 | errorCode n = "code "++show n | ||
140 | |||
141 | |||
142 | -- | clear the fpu | ||
143 | foreign import ccall unsafe "asm_finit" finit :: IO () | ||
144 | |||
145 | -- | check the error code | ||
146 | check :: String -> IO CInt -> IO () | ||
147 | check msg f = do | ||
148 | #if FINIT | ||
149 | finit | ||
150 | #endif | ||
151 | err <- f | ||
152 | when (err/=0) $ error (msg++": "++errorCode err) | ||
153 | return () | ||
154 | |||
155 | -- | Error capture and conversion to Maybe | ||
156 | mbCatch :: IO x -> IO (Maybe x) | ||
157 | mbCatch act = E.catch (Just `fmap` act) f | ||
158 | where f :: SomeException -> IO (Maybe x) | ||
159 | f _ = return Nothing | ||
160 | |||
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs deleted file mode 100644 index 150b978..0000000 --- a/packages/base/src/Data/Packed/Internal/Matrix.hs +++ /dev/null | |||
@@ -1,423 +0,0 @@ | |||
1 | {-# LANGUAGE ForeignFunctionInterface #-} | ||
2 | {-# LANGUAGE FlexibleContexts #-} | ||
3 | {-# LANGUAGE FlexibleInstances #-} | ||
4 | {-# LANGUAGE BangPatterns #-} | ||
5 | |||
6 | -- | | ||
7 | -- Module : Data.Packed.Internal.Matrix | ||
8 | -- Copyright : (c) Alberto Ruiz 2007 | ||
9 | -- License : BSD3 | ||
10 | -- Maintainer : Alberto Ruiz | ||
11 | -- Stability : provisional | ||
12 | -- | ||
13 | -- Internal matrix representation | ||
14 | -- | ||
15 | |||
16 | module Data.Packed.Internal.Matrix( | ||
17 | Matrix(..), rows, cols, cdat, fdat, | ||
18 | MatrixOrder(..), orderOf, | ||
19 | createMatrix, mat, | ||
20 | cmat, fmat, | ||
21 | toLists, flatten, reshape, | ||
22 | Element(..), | ||
23 | trans, | ||
24 | fromRows, toRows, fromColumns, toColumns, | ||
25 | matrixFromVector, | ||
26 | subMatrix, | ||
27 | liftMatrix, liftMatrix2, | ||
28 | (@@>), atM', | ||
29 | singleton, | ||
30 | emptyM, | ||
31 | size, shSize, conformVs, conformMs, conformVTo, conformMTo | ||
32 | ) where | ||
33 | |||
34 | import Data.Packed.Internal.Common | ||
35 | import Data.Packed.Internal.Signatures | ||
36 | import Data.Packed.Internal.Vector | ||
37 | |||
38 | import Foreign.Marshal.Alloc(alloca, free) | ||
39 | import Foreign.Marshal.Array(newArray) | ||
40 | import Foreign.Ptr(Ptr, castPtr) | ||
41 | import Foreign.Storable(Storable, peekElemOff, pokeElemOff, poke, sizeOf) | ||
42 | import Data.Complex(Complex) | ||
43 | import Foreign.C.Types | ||
44 | import System.IO.Unsafe(unsafePerformIO) | ||
45 | import Control.DeepSeq | ||
46 | |||
47 | ----------------------------------------------------------------- | ||
48 | |||
49 | {- Design considerations for the Matrix Type | ||
50 | ----------------------------------------- | ||
51 | |||
52 | - we must easily handle both row major and column major order, | ||
53 | for bindings to LAPACK and GSL/C | ||
54 | |||
55 | - we'd like to simplify redundant matrix transposes: | ||
56 | - Some of them arise from the order requirements of some functions | ||
57 | - some functions (matrix product) admit transposed arguments | ||
58 | |||
59 | - maybe we don't really need this kind of simplification: | ||
60 | - more complex code | ||
61 | - some computational overhead | ||
62 | - only appreciable gain in code with a lot of redundant transpositions | ||
63 | and cheap matrix computations | ||
64 | |||
65 | - we could carry both the matrix and its (lazily computed) transpose. | ||
66 | This may save some transpositions, but it is necessary to keep track of the | ||
67 | data which is actually computed to be used by functions like the matrix product | ||
68 | which admit both orders. | ||
69 | |||
70 | - but if we need the transposed data and it is not in the structure, we must make | ||
71 | sure that we touch the same foreignptr that is used in the computation. | ||
72 | |||
73 | - a reasonable solution is using two constructors for a matrix. Transposition just | ||
74 | "flips" the constructor. Actual data transposition is not done if followed by a | ||
75 | matrix product or another transpose. | ||
76 | |||
77 | -} | ||
78 | |||
79 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | ||
80 | |||
81 | transOrder RowMajor = ColumnMajor | ||
82 | transOrder ColumnMajor = RowMajor | ||
83 | {- | Matrix representation suitable for BLAS\/LAPACK computations. | ||
84 | |||
85 | The elements are stored in a continuous memory array. | ||
86 | |||
87 | -} | ||
88 | |||
89 | data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int | ||
90 | , icols :: {-# UNPACK #-} !Int | ||
91 | , xdat :: {-# UNPACK #-} !(Vector t) | ||
92 | , order :: !MatrixOrder } | ||
93 | -- RowMajor: preferred by C, fdat may require a transposition | ||
94 | -- ColumnMajor: preferred by LAPACK, cdat may require a transposition | ||
95 | |||
96 | cdat = xdat | ||
97 | fdat = xdat | ||
98 | |||
99 | rows :: Matrix t -> Int | ||
100 | rows = irows | ||
101 | |||
102 | cols :: Matrix t -> Int | ||
103 | cols = icols | ||
104 | |||
105 | orderOf :: Matrix t -> MatrixOrder | ||
106 | orderOf = order | ||
107 | |||
108 | |||
109 | -- | Matrix transpose. | ||
110 | trans :: Matrix t -> Matrix t | ||
111 | trans Matrix {irows = r, icols = c, xdat = d, order = o } = Matrix { irows = c, icols = r, xdat = d, order = transOrder o} | ||
112 | |||
113 | cmat :: (Element t) => Matrix t -> Matrix t | ||
114 | cmat m@Matrix{order = RowMajor} = m | ||
115 | cmat Matrix {irows = r, icols = c, xdat = d, order = ColumnMajor } = Matrix { irows = r, icols = c, xdat = transdata r d c, order = RowMajor} | ||
116 | |||
117 | fmat :: (Element t) => Matrix t -> Matrix t | ||
118 | fmat m@Matrix{order = ColumnMajor} = m | ||
119 | fmat Matrix {irows = r, icols = c, xdat = d, order = RowMajor } = Matrix { irows = r, icols = c, xdat = transdata c d r, order = ColumnMajor} | ||
120 | |||
121 | -- C-Haskell matrix adapter | ||
122 | -- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r | ||
123 | |||
124 | mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b | ||
125 | mat a f = | ||
126 | unsafeWith (xdat a) $ \p -> do | ||
127 | let m g = do | ||
128 | g (fi (rows a)) (fi (cols a)) p | ||
129 | f m | ||
130 | |||
131 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. | ||
132 | |||
133 | >>> flatten (ident 3) | ||
134 | fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0] | ||
135 | |||
136 | -} | ||
137 | flatten :: Element t => Matrix t -> Vector t | ||
138 | flatten = xdat . cmat | ||
139 | |||
140 | {- | ||
141 | type Mt t s = Int -> Int -> Ptr t -> s | ||
142 | |||
143 | infixr 6 ::> | ||
144 | type t ::> s = Mt t s | ||
145 | -} | ||
146 | |||
147 | -- | the inverse of 'Data.Packed.Matrix.fromLists' | ||
148 | toLists :: (Element t) => Matrix t -> [[t]] | ||
149 | toLists m = splitEvery (cols m) . toList . flatten $ m | ||
150 | |||
151 | -- | Create a matrix from a list of vectors. | ||
152 | -- All vectors must have the same dimension, | ||
153 | -- or dimension 1, which is are automatically expanded. | ||
154 | fromRows :: Element t => [Vector t] -> Matrix t | ||
155 | fromRows [] = emptyM 0 0 | ||
156 | fromRows vs = case compatdim (map dim vs) of | ||
157 | Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs) | ||
158 | Just 0 -> emptyM r 0 | ||
159 | Just c -> matrixFromVector RowMajor r c . vjoin . map (adapt c) $ vs | ||
160 | where | ||
161 | r = length vs | ||
162 | adapt c v | ||
163 | | c == 0 = fromList[] | ||
164 | | dim v == c = v | ||
165 | | otherwise = constantD (v@>0) c | ||
166 | |||
167 | -- | extracts the rows of a matrix as a list of vectors | ||
168 | toRows :: Element t => Matrix t -> [Vector t] | ||
169 | toRows m | ||
170 | | c == 0 = replicate r (fromList[]) | ||
171 | | otherwise = toRows' 0 | ||
172 | where | ||
173 | v = flatten m | ||
174 | r = rows m | ||
175 | c = cols m | ||
176 | toRows' k | k == r*c = [] | ||
177 | | otherwise = subVector k c v : toRows' (k+c) | ||
178 | |||
179 | -- | Creates a matrix from a list of vectors, as columns | ||
180 | fromColumns :: Element t => [Vector t] -> Matrix t | ||
181 | fromColumns m = trans . fromRows $ m | ||
182 | |||
183 | -- | Creates a list of vectors from the columns of a matrix | ||
184 | toColumns :: Element t => Matrix t -> [Vector t] | ||
185 | toColumns m = toRows . trans $ m | ||
186 | |||
187 | -- | Reads a matrix position. | ||
188 | (@@>) :: Storable t => Matrix t -> (Int,Int) -> t | ||
189 | infixl 9 @@> | ||
190 | m@Matrix {irows = r, icols = c} @@> (i,j) | ||
191 | | safe = if i<0 || i>=r || j<0 || j>=c | ||
192 | then error "matrix indexing out of range" | ||
193 | else atM' m i j | ||
194 | | otherwise = atM' m i j | ||
195 | {-# INLINE (@@>) #-} | ||
196 | |||
197 | -- Unsafe matrix access without range checking | ||
198 | atM' Matrix {icols = c, xdat = v, order = RowMajor} i j = v `at'` (i*c+j) | ||
199 | atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i) | ||
200 | {-# INLINE atM' #-} | ||
201 | |||
202 | ------------------------------------------------------------------ | ||
203 | |||
204 | matrixFromVector o r c v | ||
205 | | r * c == dim v = m | ||
206 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m | ||
207 | where | ||
208 | m = Matrix { irows = r, icols = c, xdat = v, order = o } | ||
209 | |||
210 | -- allocates memory for a new matrix | ||
211 | createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) | ||
212 | createMatrix ord r c = do | ||
213 | p <- createVector (r*c) | ||
214 | return (matrixFromVector ord r c p) | ||
215 | |||
216 | {- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@ | ||
217 | where r is the desired number of rows.) | ||
218 | |||
219 | >>> reshape 4 (fromList [1..12]) | ||
220 | (3><4) | ||
221 | [ 1.0, 2.0, 3.0, 4.0 | ||
222 | , 5.0, 6.0, 7.0, 8.0 | ||
223 | , 9.0, 10.0, 11.0, 12.0 ] | ||
224 | |||
225 | -} | ||
226 | reshape :: Storable t => Int -> Vector t -> Matrix t | ||
227 | reshape 0 v = matrixFromVector RowMajor 0 0 v | ||
228 | reshape c v = matrixFromVector RowMajor (dim v `div` c) c v | ||
229 | |||
230 | singleton x = reshape 1 (fromList [x]) | ||
231 | |||
232 | -- | application of a vector function on the flattened matrix elements | ||
233 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | ||
234 | liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d) | ||
235 | |||
236 | -- | application of a vector function on the flattened matrices elements | ||
237 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | ||
238 | liftMatrix2 f m1 m2 | ||
239 | | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" | ||
240 | | otherwise = case orderOf m1 of | ||
241 | RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2)) | ||
242 | ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2)) | ||
243 | |||
244 | |||
245 | compat :: Matrix a -> Matrix b -> Bool | ||
246 | compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 | ||
247 | |||
248 | ------------------------------------------------------------------ | ||
249 | |||
250 | {- | Supported matrix elements. | ||
251 | |||
252 | This class provides optimized internal | ||
253 | operations for selected element types. | ||
254 | It provides unoptimised defaults for any 'Storable' type, | ||
255 | so you can create instances simply as: | ||
256 | |||
257 | >instance Element Foo | ||
258 | -} | ||
259 | class (Storable a) => Element a where | ||
260 | subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position | ||
261 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | ||
262 | -> Matrix a -> Matrix a | ||
263 | subMatrixD = subMatrix' | ||
264 | transdata :: Int -> Vector a -> Int -> Vector a | ||
265 | transdata = transdataP -- transdata' | ||
266 | constantD :: a -> Int -> Vector a | ||
267 | constantD = constantP -- constant' | ||
268 | |||
269 | |||
270 | instance Element Float where | ||
271 | transdata = transdataAux ctransF | ||
272 | constantD = constantAux cconstantF | ||
273 | |||
274 | instance Element Double where | ||
275 | transdata = transdataAux ctransR | ||
276 | constantD = constantAux cconstantR | ||
277 | |||
278 | instance Element (Complex Float) where | ||
279 | transdata = transdataAux ctransQ | ||
280 | constantD = constantAux cconstantQ | ||
281 | |||
282 | instance Element (Complex Double) where | ||
283 | transdata = transdataAux ctransC | ||
284 | constantD = constantAux cconstantC | ||
285 | |||
286 | ------------------------------------------------------------------- | ||
287 | |||
288 | transdataAux fun c1 d c2 = | ||
289 | if noneed | ||
290 | then d | ||
291 | else unsafePerformIO $ do | ||
292 | v <- createVector (dim d) | ||
293 | unsafeWith d $ \pd -> | ||
294 | unsafeWith v $ \pv -> | ||
295 | fun (fi r1) (fi c1) pd (fi r2) (fi c2) pv // check "transdataAux" | ||
296 | return v | ||
297 | where r1 = dim d `div` c1 | ||
298 | r2 = dim d `div` c2 | ||
299 | noneed = dim d == 0 || r1 == 1 || c1 == 1 | ||
300 | |||
301 | transdataP :: Storable a => Int -> Vector a -> Int -> Vector a | ||
302 | transdataP c1 d c2 = | ||
303 | if noneed | ||
304 | then d | ||
305 | else unsafePerformIO $ do | ||
306 | v <- createVector (dim d) | ||
307 | unsafeWith d $ \pd -> | ||
308 | unsafeWith v $ \pv -> | ||
309 | ctransP (fi r1) (fi c1) (castPtr pd) (fi sz) (fi r2) (fi c2) (castPtr pv) (fi sz) // check "transdataP" | ||
310 | return v | ||
311 | where r1 = dim d `div` c1 | ||
312 | r2 = dim d `div` c2 | ||
313 | sz = sizeOf (d @> 0) | ||
314 | noneed = dim d == 0 || r1 == 1 || c1 == 1 | ||
315 | |||
316 | foreign import ccall unsafe "transF" ctransF :: TFMFM | ||
317 | foreign import ccall unsafe "transR" ctransR :: TMM | ||
318 | foreign import ccall unsafe "transQ" ctransQ :: TQMQM | ||
319 | foreign import ccall unsafe "transC" ctransC :: TCMCM | ||
320 | foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt | ||
321 | |||
322 | ---------------------------------------------------------------------- | ||
323 | |||
324 | constantAux fun x n = unsafePerformIO $ do | ||
325 | v <- createVector n | ||
326 | px <- newArray [x] | ||
327 | app1 (fun px) vec v "constantAux" | ||
328 | free px | ||
329 | return v | ||
330 | |||
331 | foreign import ccall unsafe "constantF" cconstantF :: Ptr Float -> TF | ||
332 | |||
333 | foreign import ccall unsafe "constantR" cconstantR :: Ptr Double -> TV | ||
334 | |||
335 | foreign import ccall unsafe "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV | ||
336 | |||
337 | foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV | ||
338 | |||
339 | constantP :: Storable a => a -> Int -> Vector a | ||
340 | constantP a n = unsafePerformIO $ do | ||
341 | let sz = sizeOf a | ||
342 | v <- createVector n | ||
343 | unsafeWith v $ \p -> do | ||
344 | alloca $ \k -> do | ||
345 | poke k a | ||
346 | cconstantP (castPtr k) (fi n) (castPtr p) (fi sz) // check "constantP" | ||
347 | return v | ||
348 | foreign import ccall unsafe "constantP" cconstantP :: Ptr () -> CInt -> Ptr () -> CInt -> IO CInt | ||
349 | |||
350 | ---------------------------------------------------------------------- | ||
351 | |||
352 | -- | Extracts a submatrix from a matrix. | ||
353 | subMatrix :: Element a | ||
354 | => (Int,Int) -- ^ (r0,c0) starting position | ||
355 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | ||
356 | -> Matrix a -- ^ input matrix | ||
357 | -> Matrix a -- ^ result | ||
358 | subMatrix (r0,c0) (rt,ct) m | ||
359 | | 0 <= r0 && 0 <= rt && r0+rt <= (rows m) && | ||
360 | 0 <= c0 && 0 <= ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m | ||
361 | | otherwise = error $ "wrong subMatrix "++ | ||
362 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) | ||
363 | |||
364 | subMatrix'' (r0,c0) (rt,ct) c v = unsafePerformIO $ do | ||
365 | w <- createVector (rt*ct) | ||
366 | unsafeWith v $ \p -> | ||
367 | unsafeWith w $ \q -> do | ||
368 | let go (-1) _ = return () | ||
369 | go !i (-1) = go (i-1) (ct-1) | ||
370 | go !i !j = do x <- peekElemOff p ((i+r0)*c+j+c0) | ||
371 | pokeElemOff q (i*ct+j) x | ||
372 | go i (j-1) | ||
373 | go (rt-1) (ct-1) | ||
374 | return w | ||
375 | |||
376 | subMatrix' (r0,c0) (rt,ct) (Matrix { icols = c, xdat = v, order = RowMajor}) = Matrix rt ct (subMatrix'' (r0,c0) (rt,ct) c v) RowMajor | ||
377 | subMatrix' (r0,c0) (rt,ct) m = trans $ subMatrix' (c0,r0) (ct,rt) (trans m) | ||
378 | |||
379 | -------------------------------------------------------------------------- | ||
380 | |||
381 | maxZ xs = if minimum xs == 0 then 0 else maximum xs | ||
382 | |||
383 | conformMs ms = map (conformMTo (r,c)) ms | ||
384 | where | ||
385 | r = maxZ (map rows ms) | ||
386 | c = maxZ (map cols ms) | ||
387 | |||
388 | |||
389 | conformVs vs = map (conformVTo n) vs | ||
390 | where | ||
391 | n = maxZ (map dim vs) | ||
392 | |||
393 | conformMTo (r,c) m | ||
394 | | size m == (r,c) = m | ||
395 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) | ||
396 | | size m == (r,1) = repCols c m | ||
397 | | size m == (1,c) = repRows r m | ||
398 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" | ||
399 | |||
400 | conformVTo n v | ||
401 | | dim v == n = v | ||
402 | | dim v == 1 = constantD (v@>0) n | ||
403 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n | ||
404 | |||
405 | repRows n x = fromRows (replicate n (flatten x)) | ||
406 | repCols n x = fromColumns (replicate n (flatten x)) | ||
407 | |||
408 | size m = (rows m, cols m) | ||
409 | |||
410 | shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" | ||
411 | |||
412 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) | ||
413 | |||
414 | ---------------------------------------------------------------------- | ||
415 | |||
416 | instance (Storable t, NFData t) => NFData (Matrix t) | ||
417 | where | ||
418 | rnf m | d > 0 = rnf (v @> 0) | ||
419 | | otherwise = () | ||
420 | where | ||
421 | d = dim v | ||
422 | v = xdat m | ||
423 | |||
diff --git a/packages/base/src/Data/Packed/Internal/Signatures.hs b/packages/base/src/Data/Packed/Internal/Signatures.hs deleted file mode 100644 index acc3070..0000000 --- a/packages/base/src/Data/Packed/Internal/Signatures.hs +++ /dev/null | |||
@@ -1,70 +0,0 @@ | |||
1 | -- | | ||
2 | -- Module : Data.Packed.Internal.Signatures | ||
3 | -- Copyright : (c) Alberto Ruiz 2009 | ||
4 | -- License : BSD3 | ||
5 | -- Maintainer : Alberto Ruiz | ||
6 | -- Stability : provisional | ||
7 | -- | ||
8 | -- Signatures of the C functions. | ||
9 | -- | ||
10 | |||
11 | |||
12 | module Data.Packed.Internal.Signatures where | ||
13 | |||
14 | import Foreign.Ptr(Ptr) | ||
15 | import Data.Complex(Complex) | ||
16 | import Foreign.C.Types(CInt) | ||
17 | |||
18 | type PF = Ptr Float -- | ||
19 | type PD = Ptr Double -- | ||
20 | type PQ = Ptr (Complex Float) -- | ||
21 | type PC = Ptr (Complex Double) -- | ||
22 | type TF = CInt -> PF -> IO CInt -- | ||
23 | type TFF = CInt -> PF -> TF -- | ||
24 | type TFV = CInt -> PF -> TV -- | ||
25 | type TVF = CInt -> PD -> TF -- | ||
26 | type TFFF = CInt -> PF -> TFF -- | ||
27 | type TV = CInt -> PD -> IO CInt -- | ||
28 | type TVV = CInt -> PD -> TV -- | ||
29 | type TVVV = CInt -> PD -> TVV -- | ||
30 | type TFM = CInt -> CInt -> PF -> IO CInt -- | ||
31 | type TFMFM = CInt -> CInt -> PF -> TFM -- | ||
32 | type TFMFMFM = CInt -> CInt -> PF -> TFMFM -- | ||
33 | type TM = CInt -> CInt -> PD -> IO CInt -- | ||
34 | type TMM = CInt -> CInt -> PD -> TM -- | ||
35 | type TVMM = CInt -> PD -> TMM -- | ||
36 | type TMVMM = CInt -> CInt -> PD -> TVMM -- | ||
37 | type TMMM = CInt -> CInt -> PD -> TMM -- | ||
38 | type TVM = CInt -> PD -> TM -- | ||
39 | type TVVM = CInt -> PD -> TVM -- | ||
40 | type TMV = CInt -> CInt -> PD -> TV -- | ||
41 | type TMMV = CInt -> CInt -> PD -> TMV -- | ||
42 | type TMVM = CInt -> CInt -> PD -> TVM -- | ||
43 | type TMMVM = CInt -> CInt -> PD -> TMVM -- | ||
44 | type TCM = CInt -> CInt -> PC -> IO CInt -- | ||
45 | type TCVCM = CInt -> PC -> TCM -- | ||
46 | type TCMCVCM = CInt -> CInt -> PC -> TCVCM -- | ||
47 | type TMCMCVCM = CInt -> CInt -> PD -> TCMCVCM -- | ||
48 | type TCMCMCVCM = CInt -> CInt -> PC -> TCMCVCM -- | ||
49 | type TCMCM = CInt -> CInt -> PC -> TCM -- | ||
50 | type TVCM = CInt -> PD -> TCM -- | ||
51 | type TCMVCM = CInt -> CInt -> PC -> TVCM -- | ||
52 | type TCMCMVCM = CInt -> CInt -> PC -> TCMVCM -- | ||
53 | type TCMCMCM = CInt -> CInt -> PC -> TCMCM -- | ||
54 | type TCV = CInt -> PC -> IO CInt -- | ||
55 | type TCVCV = CInt -> PC -> TCV -- | ||
56 | type TCVCVCV = CInt -> PC -> TCVCV -- | ||
57 | type TCVV = CInt -> PC -> TV -- | ||
58 | type TQV = CInt -> PQ -> IO CInt -- | ||
59 | type TQVQV = CInt -> PQ -> TQV -- | ||
60 | type TQVQVQV = CInt -> PQ -> TQVQV -- | ||
61 | type TQVF = CInt -> PQ -> TF -- | ||
62 | type TQM = CInt -> CInt -> PQ -> IO CInt -- | ||
63 | type TQMQM = CInt -> CInt -> PQ -> TQM -- | ||
64 | type TQMQMQM = CInt -> CInt -> PQ -> TQMQM -- | ||
65 | type TCMCV = CInt -> CInt -> PC -> TCV -- | ||
66 | type TVCV = CInt -> PD -> TCV -- | ||
67 | type TCVM = CInt -> PC -> TM -- | ||
68 | type TMCVM = CInt -> CInt -> PD -> TCVM -- | ||
69 | type TMMCVM = CInt -> CInt -> PD -> TMCVM -- | ||
70 | |||
diff --git a/packages/base/src/Data/Packed/Vector.hs b/packages/base/src/Data/Packed/Vector.hs deleted file mode 100644 index 2104f52..0000000 --- a/packages/base/src/Data/Packed/Vector.hs +++ /dev/null | |||
@@ -1,125 +0,0 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | ||
2 | {-# LANGUAGE CPP #-} | ||
3 | ----------------------------------------------------------------------------- | ||
4 | -- | | ||
5 | -- Module : Data.Packed.Vector | ||
6 | -- Copyright : (c) Alberto Ruiz 2007-10 | ||
7 | -- License : BSD3 | ||
8 | -- Maintainer : Alberto Ruiz | ||
9 | -- Stability : provisional | ||
10 | -- | ||
11 | -- 1D arrays suitable for numeric computations using external libraries. | ||
12 | -- | ||
13 | -- This module provides basic functions for manipulation of structure. | ||
14 | -- | ||
15 | ----------------------------------------------------------------------------- | ||
16 | {-# OPTIONS_HADDOCK hide #-} | ||
17 | |||
18 | module Data.Packed.Vector ( | ||
19 | Vector, | ||
20 | fromList, (|>), toList, buildVector, | ||
21 | dim, (@>), | ||
22 | subVector, takesV, vjoin, join, | ||
23 | mapVector, mapVectorWithIndex, zipVector, zipVectorWith, unzipVector, unzipVectorWith, | ||
24 | mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_, | ||
25 | foldLoop, foldVector, foldVectorG, foldVectorWithIndex, | ||
26 | toByteString, fromByteString | ||
27 | ) where | ||
28 | |||
29 | import Data.Packed.Internal.Vector | ||
30 | import Foreign.Storable | ||
31 | |||
32 | ------------------------------------------------------------------- | ||
33 | |||
34 | #ifdef BINARY | ||
35 | |||
36 | import Data.Binary | ||
37 | import Control.Monad(replicateM) | ||
38 | |||
39 | import Data.ByteString.Internal as BS | ||
40 | import Foreign.ForeignPtr(castForeignPtr) | ||
41 | import Data.Vector.Storable.Internal(updPtr) | ||
42 | import Foreign.Ptr(plusPtr) | ||
43 | |||
44 | |||
45 | -- a 64K cache, with a Double taking 13 bytes in Bytestring, | ||
46 | -- implies a chunk size of 5041 | ||
47 | chunk :: Int | ||
48 | chunk = 5000 | ||
49 | |||
50 | chunks :: Int -> [Int] | ||
51 | chunks d = let c = d `div` chunk | ||
52 | m = d `mod` chunk | ||
53 | in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk) | ||
54 | |||
55 | putVector v = mapM_ put $! toList v | ||
56 | |||
57 | getVector d = do | ||
58 | xs <- replicateM d get | ||
59 | return $! fromList xs | ||
60 | |||
61 | -------------------------------------------------------------------------------- | ||
62 | |||
63 | toByteString :: Storable t => Vector t -> ByteString | ||
64 | toByteString v = BS.PS (castForeignPtr fp) (sz*o) (sz * dim v) | ||
65 | where | ||
66 | (fp,o,_n) = unsafeToForeignPtr v | ||
67 | sz = sizeOf (v@>0) | ||
68 | |||
69 | |||
70 | fromByteString :: Storable t => ByteString -> Vector t | ||
71 | fromByteString (BS.PS fp o n) = r | ||
72 | where | ||
73 | r = unsafeFromForeignPtr (castForeignPtr (updPtr (`plusPtr` o) fp)) 0 n' | ||
74 | n' = n `div` sz | ||
75 | sz = sizeOf (r@>0) | ||
76 | |||
77 | -------------------------------------------------------------------------------- | ||
78 | |||
79 | instance (Binary a, Storable a) => Binary (Vector a) where | ||
80 | |||
81 | put v = do | ||
82 | let d = dim v | ||
83 | put d | ||
84 | mapM_ putVector $! takesV (chunks d) v | ||
85 | |||
86 | -- put = put . v2bs | ||
87 | |||
88 | get = do | ||
89 | d <- get | ||
90 | vs <- mapM getVector $ chunks d | ||
91 | return $! vjoin vs | ||
92 | |||
93 | -- get = fmap bs2v get | ||
94 | |||
95 | #endif | ||
96 | |||
97 | |||
98 | ------------------------------------------------------------------- | ||
99 | |||
100 | {- | creates a Vector of the specified length using the supplied function to | ||
101 | to map the index to the value at that index. | ||
102 | |||
103 | @> buildVector 4 fromIntegral | ||
104 | 4 |> [0.0,1.0,2.0,3.0]@ | ||
105 | |||
106 | -} | ||
107 | buildVector :: Storable a => Int -> (Int -> a) -> Vector a | ||
108 | buildVector len f = | ||
109 | fromList $ map f [0 .. (len - 1)] | ||
110 | |||
111 | |||
112 | -- | zip for Vectors | ||
113 | zipVector :: (Storable a, Storable b, Storable (a,b)) => Vector a -> Vector b -> Vector (a,b) | ||
114 | zipVector = zipVectorWith (,) | ||
115 | |||
116 | -- | unzip for Vectors | ||
117 | unzipVector :: (Storable a, Storable b, Storable (a,b)) => Vector (a,b) -> (Vector a,Vector b) | ||
118 | unzipVector = unzipVectorWith id | ||
119 | |||
120 | ------------------------------------------------------------------- | ||
121 | |||
122 | {-# DEPRECATED join "use vjoin or Data.Vector.concat" #-} | ||
123 | join :: Storable t => [Vector t] -> Vector t | ||
124 | join = vjoin | ||
125 | |||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs index 02ac6a0..c4f1a60 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs +++ b/packages/base/src/Internal/Algorithms.hs | |||
@@ -6,7 +6,7 @@ | |||
6 | 6 | ||
7 | ----------------------------------------------------------------------------- | 7 | ----------------------------------------------------------------------------- |
8 | {- | | 8 | {- | |
9 | Module : Numeric.LinearAlgebra.Algorithms | 9 | Module : Internal.Algorithms |
10 | Copyright : (c) Alberto Ruiz 2006-14 | 10 | Copyright : (c) Alberto Ruiz 2006-14 |
11 | License : BSD3 | 11 | License : BSD3 |
12 | Maintainer : Alberto Ruiz | 12 | Maintainer : Alberto Ruiz |
@@ -18,86 +18,36 @@ Specific functions for particular base types can also be explicitly | |||
18 | imported from "Numeric.LinearAlgebra.LAPACK". | 18 | imported from "Numeric.LinearAlgebra.LAPACK". |
19 | 19 | ||
20 | -} | 20 | -} |
21 | {-# OPTIONS_HADDOCK hide #-} | ||
22 | ----------------------------------------------------------------------------- | 21 | ----------------------------------------------------------------------------- |
23 | 22 | ||
24 | module Numeric.LinearAlgebra.Algorithms ( | 23 | module Internal.Algorithms where |
25 | -- * Supported types | ||
26 | Field(), | ||
27 | -- * Linear Systems | ||
28 | linearSolve, | ||
29 | mbLinearSolve, | ||
30 | luSolve, | ||
31 | cholSolve, | ||
32 | linearSolveLS, | ||
33 | linearSolveSVD, | ||
34 | inv, pinv, pinvTol, | ||
35 | det, invlndet, | ||
36 | rank, rcond, | ||
37 | -- * Matrix factorizations | ||
38 | -- ** Singular value decomposition | ||
39 | svd, | ||
40 | fullSVD, | ||
41 | thinSVD, | ||
42 | compactSVD, | ||
43 | singularValues, | ||
44 | leftSV, rightSV, | ||
45 | -- ** Eigensystems | ||
46 | eig, eigSH, eigSH', | ||
47 | eigenvalues, eigenvaluesSH, eigenvaluesSH', | ||
48 | geigSH', | ||
49 | -- ** QR | ||
50 | qr, rq, qrRaw, qrgr, | ||
51 | -- ** Cholesky | ||
52 | chol, cholSH, mbCholSH, | ||
53 | -- ** Hessenberg | ||
54 | hess, | ||
55 | -- ** Schur | ||
56 | schur, | ||
57 | -- ** LU | ||
58 | lu, luPacked, | ||
59 | -- * Matrix functions | ||
60 | expm, | ||
61 | sqrtm, | ||
62 | matFunc, | ||
63 | -- * Nullspace | ||
64 | nullspacePrec, | ||
65 | nullVector, | ||
66 | nullspaceSVD, | ||
67 | orthSVD, | ||
68 | orth, | ||
69 | -- * Norms | ||
70 | Normed(..), NormType(..), | ||
71 | relativeError', relativeError, | ||
72 | -- * Misc | ||
73 | eps, peps, i, | ||
74 | -- * Util | ||
75 | haussholder, | ||
76 | unpackQR, unpackHess, | ||
77 | ranksv | ||
78 | ) where | ||
79 | |||
80 | |||
81 | import Data.Packed | ||
82 | import Numeric.LinearAlgebra.LAPACK as LAPACK | ||
83 | import Data.List(foldl1') | ||
84 | import Data.Array | ||
85 | import Data.Packed.Internal.Numeric | ||
86 | import Data.Packed.Internal(shSize) | ||
87 | 24 | ||
25 | import Internal.Vector | ||
26 | import Internal.Matrix | ||
27 | import Internal.Element | ||
28 | import Internal.Conversion | ||
29 | import Internal.LAPACK as LAPACK | ||
30 | import Internal.Numeric | ||
31 | import Data.List(foldl1') | ||
32 | import qualified Data.Array as A | ||
33 | import Internal.ST | ||
34 | import Internal.Vectorized(range) | ||
35 | import Control.DeepSeq | ||
88 | 36 | ||
89 | {- | Generic linear algebra functions for double precision real and complex matrices. | 37 | {- | Generic linear algebra functions for double precision real and complex matrices. |
90 | 38 | ||
91 | (Single precision data can be converted using 'single' and 'double'). | 39 | (Single precision data can be converted using 'single' and 'double'). |
92 | 40 | ||
93 | -} | 41 | -} |
94 | class (Product t, | 42 | class (Numeric t, |
95 | Convert t, | 43 | Convert t, |
96 | Container Vector t, | ||
97 | Container Matrix t, | ||
98 | Normed Matrix t, | 44 | Normed Matrix t, |
99 | Normed Vector t, | 45 | Normed Vector t, |
100 | Floating t, | 46 | Floating t, |
47 | Linear t Vector, | ||
48 | Linear t Matrix, | ||
49 | Additive (Vector t), | ||
50 | Additive (Matrix t), | ||
101 | RealOf t ~ Double) => Field t where | 51 | RealOf t ~ Double) => Field t where |
102 | svd' :: Matrix t -> (Matrix t, Vector Double, Matrix t) | 52 | svd' :: Matrix t -> (Matrix t, Vector Double, Matrix t) |
103 | thinSVD' :: Matrix t -> (Matrix t, Vector Double, Matrix t) | 53 | thinSVD' :: Matrix t -> (Matrix t, Vector Double, Matrix t) |
@@ -107,6 +57,8 @@ class (Product t, | |||
107 | mbLinearSolve' :: Matrix t -> Matrix t -> Maybe (Matrix t) | 57 | mbLinearSolve' :: Matrix t -> Matrix t -> Maybe (Matrix t) |
108 | linearSolve' :: Matrix t -> Matrix t -> Matrix t | 58 | linearSolve' :: Matrix t -> Matrix t -> Matrix t |
109 | cholSolve' :: Matrix t -> Matrix t -> Matrix t | 59 | cholSolve' :: Matrix t -> Matrix t -> Matrix t |
60 | ldlPacked' :: Matrix t -> (Matrix t, [Int]) | ||
61 | ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t | ||
110 | linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t | 62 | linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t |
111 | linearSolveLS' :: Matrix t -> Matrix t -> Matrix t | 63 | linearSolveLS' :: Matrix t -> Matrix t -> Matrix t |
112 | eig' :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double)) | 64 | eig' :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double)) |
@@ -142,6 +94,8 @@ instance Field Double where | |||
142 | qrgr' = qrgrR | 94 | qrgr' = qrgrR |
143 | hess' = unpackHess hessR | 95 | hess' = unpackHess hessR |
144 | schur' = schurR | 96 | schur' = schurR |
97 | ldlPacked' = ldlR | ||
98 | ldlSolve'= uncurry ldlsR | ||
145 | 99 | ||
146 | instance Field (Complex Double) where | 100 | instance Field (Complex Double) where |
147 | #ifdef NOZGESDD | 101 | #ifdef NOZGESDD |
@@ -169,6 +123,8 @@ instance Field (Complex Double) where | |||
169 | qrgr' = qrgrC | 123 | qrgr' = qrgrC |
170 | hess' = unpackHess hessC | 124 | hess' = unpackHess hessC |
171 | schur' = schurC | 125 | schur' = schurC |
126 | ldlPacked' = ldlC | ||
127 | ldlSolve' = uncurry ldlsC | ||
172 | 128 | ||
173 | -------------------------------------------------------------- | 129 | -------------------------------------------------------------- |
174 | 130 | ||
@@ -228,7 +184,9 @@ fromList [35.18264833189422,1.4769076999800903,1.089145439970417e-15] | |||
228 | 184 | ||
229 | -} | 185 | -} |
230 | svd :: Field t => Matrix t -> (Matrix t, Vector Double, Matrix t) | 186 | svd :: Field t => Matrix t -> (Matrix t, Vector Double, Matrix t) |
231 | svd = {-# SCC "svd" #-} svd' | 187 | svd = {-# SCC "svd" #-} g . svd' |
188 | where | ||
189 | g (u,s,v) = (u,s,tr v) | ||
232 | 190 | ||
233 | {- | A version of 'svd' which returns only the @min (rows m) (cols m)@ singular vectors of @m@. | 191 | {- | A version of 'svd' which returns only the @min (rows m) (cols m)@ singular vectors of @m@. |
234 | 192 | ||
@@ -272,7 +230,10 @@ fromList [35.18264833189422,1.4769076999800903,1.089145439970417e-15] | |||
272 | 230 | ||
273 | -} | 231 | -} |
274 | thinSVD :: Field t => Matrix t -> (Matrix t, Vector Double, Matrix t) | 232 | thinSVD :: Field t => Matrix t -> (Matrix t, Vector Double, Matrix t) |
275 | thinSVD = {-# SCC "thinSVD" #-} thinSVD' | 233 | thinSVD = {-# SCC "thinSVD" #-} g . thinSVD' |
234 | where | ||
235 | g (u,s,v) = (u,s,tr v) | ||
236 | |||
276 | 237 | ||
277 | -- | Singular values only. | 238 | -- | Singular values only. |
278 | singularValues :: Field t => Matrix t -> Vector Double | 239 | singularValues :: Field t => Matrix t -> Vector Double |
@@ -350,25 +311,38 @@ leftSV m | vertical m = let (u,s,_) = svd m in (u,s) | |||
350 | 311 | ||
351 | -------------------------------------------------------------- | 312 | -------------------------------------------------------------- |
352 | 313 | ||
314 | -- | LU decomposition of a matrix in a compact format. | ||
315 | data LU t = LU (Matrix t) [Int] deriving Show | ||
316 | |||
317 | instance (NFData t, Numeric t) => NFData (LU t) | ||
318 | where | ||
319 | rnf (LU m _) = rnf m | ||
320 | |||
353 | -- | Obtains the LU decomposition of a matrix in a compact data structure suitable for 'luSolve'. | 321 | -- | Obtains the LU decomposition of a matrix in a compact data structure suitable for 'luSolve'. |
354 | luPacked :: Field t => Matrix t -> (Matrix t, [Int]) | 322 | luPacked :: Field t => Matrix t -> LU t |
355 | luPacked = {-# SCC "luPacked" #-} luPacked' | 323 | luPacked x = {-# SCC "luPacked" #-} LU m p |
324 | where | ||
325 | (m,p) = luPacked' x | ||
356 | 326 | ||
357 | -- | Solution of a linear system (for several right hand sides) from the precomputed LU factorization obtained by 'luPacked'. | 327 | -- | Solution of a linear system (for several right hand sides) from the precomputed LU factorization obtained by 'luPacked'. |
358 | luSolve :: Field t => (Matrix t, [Int]) -> Matrix t -> Matrix t | 328 | luSolve :: Field t => LU t -> Matrix t -> Matrix t |
359 | luSolve = {-# SCC "luSolve" #-} luSolve' | 329 | luSolve (LU m p) = {-# SCC "luSolve" #-} luSolve' (m,p) |
360 | 330 | ||
361 | -- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'. | 331 | -- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'. |
362 | -- It is similar to 'luSolve' . 'luPacked', but @linearSolve@ raises an error if called on a singular system. | 332 | -- It is similar to 'luSolve' . 'luPacked', but @linearSolve@ raises an error if called on a singular system. |
363 | linearSolve :: Field t => Matrix t -> Matrix t -> Matrix t | 333 | linearSolve :: Field t => Matrix t -> Matrix t -> Matrix t |
364 | linearSolve = {-# SCC "linearSolve" #-} linearSolve' | 334 | linearSolve = {-# SCC "linearSolve" #-} linearSolve' |
365 | 335 | ||
366 | -- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, returning Nothing for a singular system. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'. | 336 | -- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, returning Nothing for a singular system. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'. |
367 | mbLinearSolve :: Field t => Matrix t -> Matrix t -> Maybe (Matrix t) | 337 | mbLinearSolve :: Field t => Matrix t -> Matrix t -> Maybe (Matrix t) |
368 | mbLinearSolve = {-# SCC "linearSolve" #-} mbLinearSolve' | 338 | mbLinearSolve = {-# SCC "linearSolve" #-} mbLinearSolve' |
369 | 339 | ||
370 | -- | Solve a symmetric or Hermitian positive definite linear system using a precomputed Cholesky decomposition obtained by 'chol'. | 340 | -- | Solve a symmetric or Hermitian positive definite linear system using a precomputed Cholesky decomposition obtained by 'chol'. |
371 | cholSolve :: Field t => Matrix t -> Matrix t -> Matrix t | 341 | cholSolve |
342 | :: Field t | ||
343 | => Matrix t -- ^ Cholesky decomposition of the coefficient matrix | ||
344 | -> Matrix t -- ^ right hand sides | ||
345 | -> Matrix t -- ^ solution | ||
372 | cholSolve = {-# SCC "cholSolve" #-} cholSolve' | 346 | cholSolve = {-# SCC "cholSolve" #-} cholSolve' |
373 | 347 | ||
374 | -- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. | 348 | -- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. |
@@ -380,6 +354,31 @@ linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD' | |||
380 | linearSolveLS :: Field t => Matrix t -> Matrix t -> Matrix t | 354 | linearSolveLS :: Field t => Matrix t -> Matrix t -> Matrix t |
381 | linearSolveLS = {-# SCC "linearSolveLS" #-} linearSolveLS' | 355 | linearSolveLS = {-# SCC "linearSolveLS" #-} linearSolveLS' |
382 | 356 | ||
357 | -------------------------------------------------------------------------------- | ||
358 | |||
359 | -- | LDL decomposition of a complex Hermitian or real symmetric matrix in a compact format. | ||
360 | data LDL t = LDL (Matrix t) [Int] deriving Show | ||
361 | |||
362 | instance (NFData t, Numeric t) => NFData (LDL t) | ||
363 | where | ||
364 | rnf (LDL m _) = rnf m | ||
365 | |||
366 | -- | Similar to 'ldlPacked', without checking that the input matrix is hermitian or symmetric. It works with the lower triangular part. | ||
367 | ldlPackedSH :: Field t => Matrix t -> LDL t | ||
368 | ldlPackedSH x = {-# SCC "ldlPacked" #-} LDL m p | ||
369 | where | ||
370 | (m,p) = ldlPacked' x | ||
371 | |||
372 | -- | Obtains the LDL decomposition of a matrix in a compact data structure suitable for 'ldlSolve'. | ||
373 | ldlPacked :: Field t => Herm t -> LDL t | ||
374 | ldlPacked (Herm m) = ldlPackedSH m | ||
375 | |||
376 | -- | Solution of a linear system (for several right hand sides) from a precomputed LDL factorization obtained by 'ldlPacked'. | ||
377 | -- | ||
378 | -- Note: this can be slower than the general solver based on the LU decomposition. | ||
379 | ldlSolve :: Field t => LDL t -> Matrix t -> Matrix t | ||
380 | ldlSolve (LDL m p) = {-# SCC "ldlSolve" #-} ldlSolve' (m,p) | ||
381 | |||
383 | -------------------------------------------------------------- | 382 | -------------------------------------------------------------- |
384 | 383 | ||
385 | {- | Eigenvalues (not ordered) and eigenvectors (as columns) of a general square matrix. | 384 | {- | Eigenvalues (not ordered) and eigenvectors (as columns) of a general square matrix. |
@@ -456,28 +455,39 @@ fromList [11.344814282762075,0.17091518882717918,-0.5157294715892575] | |||
456 | 3.000 5.000 6.000 | 455 | 3.000 5.000 6.000 |
457 | 456 | ||
458 | -} | 457 | -} |
459 | eigSH :: Field t => Matrix t -> (Vector Double, Matrix t) | 458 | eigSH :: Field t => Herm t -> (Vector Double, Matrix t) |
460 | eigSH m | exactHermitian m = eigSH' m | 459 | eigSH (Herm m) = eigSH' m |
461 | | otherwise = error "eigSH requires complex hermitian or real symmetric matrix" | ||
462 | 460 | ||
463 | -- | Eigenvalues (in descending order) of a complex hermitian or real symmetric matrix. | 461 | -- | Eigenvalues (in descending order) of a complex hermitian or real symmetric matrix. |
464 | eigenvaluesSH :: Field t => Matrix t -> Vector Double | 462 | eigenvaluesSH :: Field t => Herm t -> Vector Double |
465 | eigenvaluesSH m | exactHermitian m = eigenvaluesSH' m | 463 | eigenvaluesSH (Herm m) = eigenvaluesSH' m |
466 | | otherwise = error "eigenvaluesSH requires complex hermitian or real symmetric matrix" | ||
467 | 464 | ||
468 | -------------------------------------------------------------- | 465 | -------------------------------------------------------------- |
469 | 466 | ||
467 | -- | QR decomposition of a matrix in compact form. (The orthogonal matrix is not explicitly formed.) | ||
468 | data QR t = QR (Matrix t) (Vector t) | ||
469 | |||
470 | instance (NFData t, Numeric t) => NFData (QR t) | ||
471 | where | ||
472 | rnf (QR m _) = rnf m | ||
473 | |||
474 | |||
470 | -- | QR factorization. | 475 | -- | QR factorization. |
471 | -- | 476 | -- |
472 | -- If @(q,r) = qr m@ then @m == q \<> r@, where q is unitary and r is upper triangular. | 477 | -- If @(q,r) = qr m@ then @m == q \<> r@, where q is unitary and r is upper triangular. |
473 | qr :: Field t => Matrix t -> (Matrix t, Matrix t) | 478 | qr :: Field t => Matrix t -> (Matrix t, Matrix t) |
474 | qr = {-# SCC "qr" #-} unpackQR . qr' | 479 | qr = {-# SCC "qr" #-} unpackQR . qr' |
475 | 480 | ||
476 | qrRaw m = qr' m | 481 | -- | Compute the QR decomposition of a matrix in compact form. |
482 | qrRaw :: Field t => Matrix t -> QR t | ||
483 | qrRaw m = QR x v | ||
484 | where | ||
485 | (x,v) = qr' m | ||
477 | 486 | ||
478 | {- | generate a matrix with k orthogonal columns from the output of qrRaw | 487 | -- | generate a matrix with k orthogonal columns from the compact QR decomposition obtained by 'qrRaw'. |
479 | -} | 488 | -- |
480 | qrgr n (a,t) | 489 | qrgr :: Field t => Int -> QR t -> Matrix t |
490 | qrgr n (QR a t) | ||
481 | | dim t > min (cols a) (rows a) || n < 0 || n > dim t = error "qrgr expects k <= min(rows,cols)" | 491 | | dim t > min (cols a) (rows a) || n < 0 || n > dim t = error "qrgr expects k <= min(rows,cols)" |
482 | | otherwise = qrgr' n (a,t) | 492 | | otherwise = qrgr' n (a,t) |
483 | 493 | ||
@@ -494,14 +504,14 @@ rq m = {-# SCC "rq" #-} (r,q) where | |||
494 | 504 | ||
495 | -- | Hessenberg factorization. | 505 | -- | Hessenberg factorization. |
496 | -- | 506 | -- |
497 | -- If @(p,h) = hess m@ then @m == p \<> h \<> ctrans p@, where p is unitary | 507 | -- If @(p,h) = hess m@ then @m == p \<> h \<> tr p@, where p is unitary |
498 | -- and h is in upper Hessenberg form (it has zero entries below the first subdiagonal). | 508 | -- and h is in upper Hessenberg form (it has zero entries below the first subdiagonal). |
499 | hess :: Field t => Matrix t -> (Matrix t, Matrix t) | 509 | hess :: Field t => Matrix t -> (Matrix t, Matrix t) |
500 | hess = hess' | 510 | hess = hess' |
501 | 511 | ||
502 | -- | Schur factorization. | 512 | -- | Schur factorization. |
503 | -- | 513 | -- |
504 | -- If @(u,s) = schur m@ then @m == u \<> s \<> ctrans u@, where u is unitary | 514 | -- If @(u,s) = schur m@ then @m == u \<> s \<> tr u@, where u is unitary |
505 | -- and s is a Shur matrix. A complex Schur matrix is upper triangular. A real Schur matrix is | 515 | -- and s is a Shur matrix. A complex Schur matrix is upper triangular. A real Schur matrix is |
506 | -- upper triangular in 2x2 blocks. | 516 | -- upper triangular in 2x2 blocks. |
507 | -- | 517 | -- |
@@ -517,14 +527,18 @@ mbCholSH = {-# SCC "mbCholSH" #-} mbCholSH' | |||
517 | 527 | ||
518 | -- | Similar to 'chol', without checking that the input matrix is hermitian or symmetric. It works with the upper triangular part. | 528 | -- | Similar to 'chol', without checking that the input matrix is hermitian or symmetric. It works with the upper triangular part. |
519 | cholSH :: Field t => Matrix t -> Matrix t | 529 | cholSH :: Field t => Matrix t -> Matrix t |
520 | cholSH = {-# SCC "cholSH" #-} cholSH' | 530 | cholSH = cholSH' |
521 | 531 | ||
522 | -- | Cholesky factorization of a positive definite hermitian or symmetric matrix. | 532 | -- | Cholesky factorization of a positive definite hermitian or symmetric matrix. |
523 | -- | 533 | -- |
524 | -- If @c = chol m@ then @c@ is upper triangular and @m == ctrans c \<> c@. | 534 | -- If @c = chol m@ then @c@ is upper triangular and @m == tr c \<> c@. |
525 | chol :: Field t => Matrix t -> Matrix t | 535 | chol :: Field t => Herm t -> Matrix t |
526 | chol m | exactHermitian m = cholSH m | 536 | chol (Herm m) = {-# SCC "chol" #-} cholSH' m |
527 | | otherwise = error "chol requires positive definite complex hermitian or real symmetric matrix" | 537 | |
538 | -- | Similar to 'chol', but instead of an error (e.g., caused by a matrix not positive definite) it returns 'Nothing'. | ||
539 | mbChol :: Field t => Herm t -> Maybe (Matrix t) | ||
540 | mbChol (Herm m) = {-# SCC "mbChol" #-} mbCholSH' m | ||
541 | |||
528 | 542 | ||
529 | 543 | ||
530 | -- | Joint computation of inverse and logarithm of determinant of a square matrix. | 544 | -- | Joint computation of inverse and logarithm of determinant of a square matrix. |
@@ -534,7 +548,7 @@ invlndet :: Field t | |||
534 | invlndet m | square m = (im,(ladm,sdm)) | 548 | invlndet m | square m = (im,(ladm,sdm)) |
535 | | otherwise = error $ "invlndet of nonsquare "++ shSize m ++ " matrix" | 549 | | otherwise = error $ "invlndet of nonsquare "++ shSize m ++ " matrix" |
536 | where | 550 | where |
537 | lp@(lup,perm) = luPacked m | 551 | lp@(LU lup perm) = luPacked m |
538 | s = signlp (rows m) perm | 552 | s = signlp (rows m) perm |
539 | dg = toList $ takeDiag $ lup | 553 | dg = toList $ takeDiag $ lup |
540 | ladm = sum $ map (log.abs) dg | 554 | ladm = sum $ map (log.abs) dg |
@@ -546,8 +560,9 @@ invlndet m | square m = (im,(ladm,sdm)) | |||
546 | det :: Field t => Matrix t -> t | 560 | det :: Field t => Matrix t -> t |
547 | det m | square m = {-# SCC "det" #-} s * (product $ toList $ takeDiag $ lup) | 561 | det m | square m = {-# SCC "det" #-} s * (product $ toList $ takeDiag $ lup) |
548 | | otherwise = error $ "det of nonsquare "++ shSize m ++ " matrix" | 562 | | otherwise = error $ "det of nonsquare "++ shSize m ++ " matrix" |
549 | where (lup,perm) = luPacked m | 563 | where |
550 | s = signlp (rows m) perm | 564 | LU lup perm = luPacked m |
565 | s = signlp (rows m) perm | ||
551 | 566 | ||
552 | -- | Explicit LU factorization of a general matrix. | 567 | -- | Explicit LU factorization of a general matrix. |
553 | -- | 568 | -- |
@@ -587,7 +602,7 @@ m = (3><3) [ 1, 0, 0 | |||
587 | -} | 602 | -} |
588 | 603 | ||
589 | pinvTol :: Field t => Double -> Matrix t -> Matrix t | 604 | pinvTol :: Field t => Double -> Matrix t -> Matrix t |
590 | pinvTol t m = conj v' `mXm` diag s' `mXm` ctrans u' where | 605 | pinvTol t m = v' `mXm` diag s' `mXm` ctrans u' where |
591 | (u,s,v) = thinSVD m | 606 | (u,s,v) = thinSVD m |
592 | sl@(g:_) = toList s | 607 | sl@(g:_) = toList s |
593 | s' = real . fromList . map rec $ sl | 608 | s' = real . fromList . map rec $ sl |
@@ -628,11 +643,6 @@ eps = 2.22044604925031e-16 | |||
628 | peps :: RealFloat x => x | 643 | peps :: RealFloat x => x |
629 | peps = x where x = 2.0 ** fromIntegral (1 - floatDigits x) | 644 | peps = x where x = 2.0 ** fromIntegral (1 - floatDigits x) |
630 | 645 | ||
631 | |||
632 | -- | The imaginary unit: @i = 0.0 :+ 1.0@ | ||
633 | i :: Complex Double | ||
634 | i = 0:+1 | ||
635 | |||
636 | ----------------------------------------------------------------------- | 646 | ----------------------------------------------------------------------- |
637 | 647 | ||
638 | -- | The nullspace of a matrix from its precomputed SVD decomposition. | 648 | -- | The nullspace of a matrix from its precomputed SVD decomposition. |
@@ -649,7 +659,7 @@ nullspaceSVD hint a (s,v) = vs where | |||
649 | k = case hint of | 659 | k = case hint of |
650 | Right t -> t | 660 | Right t -> t |
651 | _ -> rankSVD tol a s | 661 | _ -> rankSVD tol a s |
652 | vs = conj (dropColumns k v) | 662 | vs = dropColumns k v |
653 | 663 | ||
654 | 664 | ||
655 | -- | The nullspace of a matrix. See also 'nullspaceSVD'. | 665 | -- | The nullspace of a matrix. See also 'nullspaceSVD'. |
@@ -752,7 +762,7 @@ diagonalize m = if rank v == n | |||
752 | else Nothing | 762 | else Nothing |
753 | where n = rows m | 763 | where n = rows m |
754 | (l,v) = if exactHermitian m | 764 | (l,v) = if exactHermitian m |
755 | then let (l',v') = eigSH m in (real l', v') | 765 | then let (l',v') = eigSH (trustSym m) in (real l', v') |
756 | else eig m | 766 | else eig m |
757 | 767 | ||
758 | -- | Generic matrix functions for diagonalizable matrices. For instance: | 768 | -- | Generic matrix functions for diagonalizable matrices. For instance: |
@@ -846,19 +856,32 @@ signlp r vals = foldl f 1 (zip [0..r-1] vals) | |||
846 | where f s (a,b) | a /= b = -s | 856 | where f s (a,b) | a /= b = -s |
847 | | otherwise = s | 857 | | otherwise = s |
848 | 858 | ||
849 | swap (arr,s) (a,b) | a /= b = (arr // [(a, arr!b),(b,arr!a)],-s) | 859 | fixPerm r vals = (fromColumns $ A.elems res, sign) |
850 | | otherwise = (arr,s) | 860 | where |
851 | 861 | v = [0..r-1] | |
852 | fixPerm r vals = (fromColumns $ elems res, sign) | 862 | t = toColumns (ident r) |
853 | where v = [0..r-1] | 863 | (res,sign) = foldl swap (A.listArray (0,r-1) t, 1) (zip v vals) |
854 | s = toColumns (ident r) | 864 | swap (arr,s) (a,b) |
855 | (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals) | 865 | | a /= b = (arr A.// [(a, arr A.! b),(b,arr A.! a)],-s) |
866 | | otherwise = (arr,s) | ||
867 | |||
868 | fixPerm' :: [Int] -> Vector I | ||
869 | fixPerm' s = res $ mutable f s0 | ||
870 | where | ||
871 | s0 = reshape 1 (range (length s)) | ||
872 | res = flatten . fst | ||
873 | swap m i j = rowOper (SWAP i j AllCols) m | ||
874 | f :: (Num t, Element t) => (Int, Int) -> STMatrix s t -> ST s () -- needed because of TypeFamilies | ||
875 | f _ p = sequence_ $ zipWith (swap p) [0..] s | ||
856 | 876 | ||
857 | triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]] | 877 | triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]] |
858 | where el p q = if q-p>=h then v else 1 - v | 878 | where el p q = if q-p>=h then v else 1 - v |
859 | 879 | ||
860 | luFact (l_u,perm) | r <= c = (l ,u ,p, s) | 880 | -- | Compute the explicit LU decomposition from the compact one obtained by 'luPacked'. |
861 | | otherwise = (l',u',p, s) | 881 | luFact :: Numeric t => LU t -> (Matrix t, Matrix t, Matrix t, t) |
882 | luFact (LU l_u perm) | ||
883 | | r <= c = (l ,u ,p, s) | ||
884 | | otherwise = (l',u',p, s) | ||
862 | where | 885 | where |
863 | r = rows l_u | 886 | r = rows l_u |
864 | c = cols l_u | 887 | c = cols l_u |
@@ -935,10 +958,9 @@ relativeError' x y = dig (norm (x `sub` y) / norm x) | |||
935 | dig r = round $ -logBase 10 (realToFrac r :: Double) | 958 | dig r = round $ -logBase 10 (realToFrac r :: Double) |
936 | 959 | ||
937 | 960 | ||
938 | relativeError :: (Normed c t, Num (c t)) => NormType -> c t -> c t -> Double | 961 | relativeError :: Num a => (a -> Double) -> a -> a -> Double |
939 | relativeError t a b = realToFrac r | 962 | relativeError norm a b = r |
940 | where | 963 | where |
941 | norm = pnorm t | ||
942 | na = norm a | 964 | na = norm a |
943 | nb = norm b | 965 | nb = norm b |
944 | nab = norm (a-b) | 966 | nab = norm (a-b) |
@@ -952,7 +974,13 @@ relativeError t a b = realToFrac r | |||
952 | ---------------------------------------------------------------------- | 974 | ---------------------------------------------------------------------- |
953 | 975 | ||
954 | -- | Generalized symmetric positive definite eigensystem Av = lBv, | 976 | -- | Generalized symmetric positive definite eigensystem Av = lBv, |
955 | -- for A and B symmetric, B positive definite (conditions not checked). | 977 | -- for A and B symmetric, B positive definite. |
978 | geigSH :: Field t | ||
979 | => Herm t -- ^ A | ||
980 | -> Herm t -- ^ B | ||
981 | -> (Vector Double, Matrix t) | ||
982 | geigSH (Herm a) (Herm b) = geigSH' a b | ||
983 | |||
956 | geigSH' :: Field t | 984 | geigSH' :: Field t |
957 | => Matrix t -- ^ A | 985 | => Matrix t -- ^ A |
958 | -> Matrix t -- ^ B | 986 | -> Matrix t -- ^ B |
@@ -966,3 +994,37 @@ geigSH' a b = (l,v') | |||
966 | v' = iu <> v | 994 | v' = iu <> v |
967 | (<>) = mXm | 995 | (<>) = mXm |
968 | 996 | ||
997 | -------------------------------------------------------------------------------- | ||
998 | |||
999 | -- | A matrix that, by construction, it is known to be complex Hermitian or real symmetric. | ||
1000 | -- | ||
1001 | -- It can be created using 'sym', 'mTm', or 'trustSym', and the matrix can be extracted using 'unSym'. | ||
1002 | newtype Herm t = Herm (Matrix t) deriving Show | ||
1003 | |||
1004 | instance (NFData t, Numeric t) => NFData (Herm t) | ||
1005 | where | ||
1006 | rnf (Herm m) = rnf m | ||
1007 | |||
1008 | -- | Extract the general matrix from a 'Herm' structure, forgetting its symmetric or Hermitian property. | ||
1009 | unSym :: Herm t -> Matrix t | ||
1010 | unSym (Herm x) = x | ||
1011 | |||
1012 | -- | Compute the complex Hermitian or real symmetric part of a square matrix (@(x + tr x)/2@). | ||
1013 | sym :: Field t => Matrix t -> Herm t | ||
1014 | sym x = Herm (scale 0.5 (tr x `add` x)) | ||
1015 | |||
1016 | -- | Compute the contraction @tr x <> x@ of a general matrix. | ||
1017 | mTm :: Numeric t => Matrix t -> Herm t | ||
1018 | mTm x = Herm (tr x `mXm` x) | ||
1019 | |||
1020 | instance Field t => Linear t Herm where | ||
1021 | scale x (Herm m) = Herm (scale x m) | ||
1022 | |||
1023 | instance Field t => Additive (Herm t) where | ||
1024 | add (Herm a) (Herm b) = Herm (a `add` b) | ||
1025 | |||
1026 | -- | At your own risk, declare that a matrix is complex Hermitian or real symmetric | ||
1027 | -- for usage in 'chol', 'eigSH', etc. Only a triangular part of the matrix will be used. | ||
1028 | trustSym :: Matrix t -> Herm t | ||
1029 | trustSym x = (Herm x) | ||
1030 | |||
diff --git a/packages/base/src/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index e5e45ef..ff7ad92 100644 --- a/packages/base/src/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c | |||
@@ -3,6 +3,14 @@ | |||
3 | #include <string.h> | 3 | #include <string.h> |
4 | #include <math.h> | 4 | #include <math.h> |
5 | #include <time.h> | 5 | #include <time.h> |
6 | #include <inttypes.h> | ||
7 | #include <complex.h> | ||
8 | |||
9 | typedef double complex TCD; | ||
10 | typedef float complex TCF; | ||
11 | |||
12 | #undef complex | ||
13 | |||
6 | #include "lapack-aux.h" | 14 | #include "lapack-aux.h" |
7 | 15 | ||
8 | #define MACRO(B) do {B} while (0) | 16 | #define MACRO(B) do {B} while (0) |
@@ -30,6 +38,9 @@ | |||
30 | // #define OK return 0; | 38 | // #define OK return 0; |
31 | // #endif | 39 | // #endif |
32 | 40 | ||
41 | |||
42 | #define INFOMAT(M) printf("%dx%d %d:%d\n",M##r,M##c,M##Xr,M##Xc); | ||
43 | |||
33 | #define TRACEMAT(M) {int q; printf(" %d x %d: ",M##r,M##c); \ | 44 | #define TRACEMAT(M) {int q; printf(" %d x %d: ",M##r,M##c); \ |
34 | for(q=0;q<M##r*M##c;q++) printf("%.1f ",M##p[q]); printf("\n");} | 45 | for(q=0;q<M##r*M##c;q++) printf("%.1f ",M##p[q]); printf("\n");} |
35 | 46 | ||
@@ -44,7 +55,7 @@ | |||
44 | #define NODEFPOS 2006 | 55 | #define NODEFPOS 2006 |
45 | #define NOSPRTD 2007 | 56 | #define NOSPRTD 2007 |
46 | 57 | ||
47 | //--------------------------------------- | 58 | //////////////////////////////////////////////////////////////////////////////// |
48 | void asm_finit() { | 59 | void asm_finit() { |
49 | #ifdef i386 | 60 | #ifdef i386 |
50 | 61 | ||
@@ -66,8 +77,6 @@ void asm_finit() { | |||
66 | #endif | 77 | #endif |
67 | } | 78 | } |
68 | 79 | ||
69 | //--------------------------------------- | ||
70 | |||
71 | #if NANDEBUG | 80 | #if NANDEBUG |
72 | 81 | ||
73 | #define CHECKNANR(M,msg) \ | 82 | #define CHECKNANR(M,msg) \ |
@@ -97,16 +106,16 @@ for(k=0; k<(M##r * M##c); k++) { \ | |||
97 | #define CHECKNANR(M,msg) | 106 | #define CHECKNANR(M,msg) |
98 | #endif | 107 | #endif |
99 | 108 | ||
100 | //--------------------------------------- | ||
101 | 109 | ||
102 | //////////////////// real svd //////////////////////////////////// | 110 | //////////////////////////////////////////////////////////////////////////////// |
111 | //////////////////// real svd /////////////////////////////////////////////////// | ||
103 | 112 | ||
104 | /* Subroutine */ int dgesvd_(char *jobu, char *jobvt, integer *m, integer *n, | 113 | int dgesvd_(char *jobu, char *jobvt, integer *m, integer *n, |
105 | doublereal *a, integer *lda, doublereal *s, doublereal *u, integer * | 114 | doublereal *a, integer *lda, doublereal *s, doublereal *u, integer * |
106 | ldu, doublereal *vt, integer *ldvt, doublereal *work, integer *lwork, | 115 | ldu, doublereal *vt, integer *ldvt, doublereal *work, integer *lwork, |
107 | integer *info); | 116 | integer *info); |
108 | 117 | ||
109 | int svd_l_R(KDMAT(a),DMAT(u), DVEC(s),DMAT(v)) { | 118 | int svd_l_R(ODMAT(a),ODMAT(u), DVEC(s),ODMAT(v)) { |
110 | integer m = ar; | 119 | integer m = ar; |
111 | integer n = ac; | 120 | integer n = ac; |
112 | integer q = MIN(m,n); | 121 | integer q = MIN(m,n); |
@@ -132,15 +141,12 @@ int svd_l_R(KDMAT(a),DMAT(u), DVEC(s),DMAT(v)) { | |||
132 | } | 141 | } |
133 | } | 142 | } |
134 | DEBUGMSG("svd_l_R"); | 143 | DEBUGMSG("svd_l_R"); |
135 | double *B = (double*)malloc(m*n*sizeof(double)); | ||
136 | CHECK(!B,MEM); | ||
137 | memcpy(B,ap,m*n*sizeof(double)); | ||
138 | integer lwork = -1; | 144 | integer lwork = -1; |
139 | integer res; | 145 | integer res; |
140 | // ask for optimal lwork | 146 | // ask for optimal lwork |
141 | double ans; | 147 | double ans; |
142 | dgesvd_ (jobu,jobvt, | 148 | dgesvd_ (jobu,jobvt, |
143 | &m,&n,B,&m, | 149 | &m,&n,ap,&m, |
144 | sp, | 150 | sp, |
145 | up,&m, | 151 | up,&m, |
146 | vp,&ldvt, | 152 | vp,&ldvt, |
@@ -150,7 +156,7 @@ int svd_l_R(KDMAT(a),DMAT(u), DVEC(s),DMAT(v)) { | |||
150 | double * work = (double*)malloc(lwork*sizeof(double)); | 156 | double * work = (double*)malloc(lwork*sizeof(double)); |
151 | CHECK(!work,MEM); | 157 | CHECK(!work,MEM); |
152 | dgesvd_ (jobu,jobvt, | 158 | dgesvd_ (jobu,jobvt, |
153 | &m,&n,B,&m, | 159 | &m,&n,ap,&m, |
154 | sp, | 160 | sp, |
155 | up,&m, | 161 | up,&m, |
156 | vp,&ldvt, | 162 | vp,&ldvt, |
@@ -158,18 +164,17 @@ int svd_l_R(KDMAT(a),DMAT(u), DVEC(s),DMAT(v)) { | |||
158 | &res); | 164 | &res); |
159 | CHECK(res,res); | 165 | CHECK(res,res); |
160 | free(work); | 166 | free(work); |
161 | free(B); | ||
162 | OK | 167 | OK |
163 | } | 168 | } |
164 | 169 | ||
165 | // (alternative version) | 170 | // (alternative version) |
166 | 171 | ||
167 | /* Subroutine */ int dgesdd_(char *jobz, integer *m, integer *n, doublereal * | 172 | int dgesdd_(char *jobz, integer *m, integer *n, doublereal * |
168 | a, integer *lda, doublereal *s, doublereal *u, integer *ldu, | 173 | a, integer *lda, doublereal *s, doublereal *u, integer *ldu, |
169 | doublereal *vt, integer *ldvt, doublereal *work, integer *lwork, | 174 | doublereal *vt, integer *ldvt, doublereal *work, integer *lwork, |
170 | integer *iwork, integer *info); | 175 | integer *iwork, integer *info); |
171 | 176 | ||
172 | int svd_l_Rdd(KDMAT(a),DMAT(u), DVEC(s),DMAT(v)) { | 177 | int svd_l_Rdd(ODMAT(a),ODMAT(u), DVEC(s),ODMAT(v)) { |
173 | integer m = ar; | 178 | integer m = ar; |
174 | integer n = ac; | 179 | integer n = ac; |
175 | integer q = MIN(m,n); | 180 | integer q = MIN(m,n); |
@@ -189,37 +194,31 @@ int svd_l_Rdd(KDMAT(a),DMAT(u), DVEC(s),DMAT(v)) { | |||
189 | } | 194 | } |
190 | } | 195 | } |
191 | DEBUGMSG("svd_l_Rdd"); | 196 | DEBUGMSG("svd_l_Rdd"); |
192 | double *B = (double*)malloc(m*n*sizeof(double)); | ||
193 | CHECK(!B,MEM); | ||
194 | memcpy(B,ap,m*n*sizeof(double)); | ||
195 | integer* iwk = (integer*) malloc(8*q*sizeof(integer)); | 197 | integer* iwk = (integer*) malloc(8*q*sizeof(integer)); |
196 | CHECK(!iwk,MEM); | 198 | CHECK(!iwk,MEM); |
197 | integer lwk = -1; | 199 | integer lwk = -1; |
198 | integer res; | 200 | integer res; |
199 | // ask for optimal lwk | 201 | // ask for optimal lwk |
200 | double ans; | 202 | double ans; |
201 | dgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,&ans,&lwk,iwk,&res); | 203 | dgesdd_ (jobz,&m,&n,ap,&m,sp,up,&m,vp,&ldvt,&ans,&lwk,iwk,&res); |
202 | lwk = ans; | 204 | lwk = ans; |
203 | double * workv = (double*)malloc(lwk*sizeof(double)); | 205 | double * workv = (double*)malloc(lwk*sizeof(double)); |
204 | CHECK(!workv,MEM); | 206 | CHECK(!workv,MEM); |
205 | dgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,workv,&lwk,iwk,&res); | 207 | dgesdd_ (jobz,&m,&n,ap,&m,sp,up,&m,vp,&ldvt,workv,&lwk,iwk,&res); |
206 | CHECK(res,res); | 208 | CHECK(res,res); |
207 | free(iwk); | 209 | free(iwk); |
208 | free(workv); | 210 | free(workv); |
209 | free(B); | ||
210 | OK | 211 | OK |
211 | } | 212 | } |
212 | 213 | ||
213 | //////////////////// complex svd //////////////////////////////////// | 214 | //////////////////// complex svd //////////////////////////////////// |
214 | 215 | ||
215 | // not in clapack.h | ||
216 | |||
217 | int zgesvd_(char *jobu, char *jobvt, integer *m, integer *n, | 216 | int zgesvd_(char *jobu, char *jobvt, integer *m, integer *n, |
218 | doublecomplex *a, integer *lda, doublereal *s, doublecomplex *u, | 217 | doublecomplex *a, integer *lda, doublereal *s, doublecomplex *u, |
219 | integer *ldu, doublecomplex *vt, integer *ldvt, doublecomplex *work, | 218 | integer *ldu, doublecomplex *vt, integer *ldvt, doublecomplex *work, |
220 | integer *lwork, doublereal *rwork, integer *info); | 219 | integer *lwork, doublereal *rwork, integer *info); |
221 | 220 | ||
222 | int svd_l_C(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) { | 221 | int svd_l_C(OCMAT(a),OCMAT(u), DVEC(s),OCMAT(v)) { |
223 | integer m = ar; | 222 | integer m = ar; |
224 | integer n = ac; | 223 | integer n = ac; |
225 | integer q = MIN(m,n); | 224 | integer q = MIN(m,n); |
@@ -244,9 +243,6 @@ int svd_l_C(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) { | |||
244 | ldvt = q; | 243 | ldvt = q; |
245 | } | 244 | } |
246 | }DEBUGMSG("svd_l_C"); | 245 | }DEBUGMSG("svd_l_C"); |
247 | doublecomplex *B = (doublecomplex*)malloc(m*n*sizeof(doublecomplex)); | ||
248 | CHECK(!B,MEM); | ||
249 | memcpy(B,ap,m*n*sizeof(doublecomplex)); | ||
250 | 246 | ||
251 | double *rwork = (double*) malloc(5*q*sizeof(double)); | 247 | double *rwork = (double*) malloc(5*q*sizeof(double)); |
252 | CHECK(!rwork,MEM); | 248 | CHECK(!rwork,MEM); |
@@ -255,7 +251,7 @@ int svd_l_C(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) { | |||
255 | // ask for optimal lwork | 251 | // ask for optimal lwork |
256 | doublecomplex ans; | 252 | doublecomplex ans; |
257 | zgesvd_ (jobu,jobvt, | 253 | zgesvd_ (jobu,jobvt, |
258 | &m,&n,B,&m, | 254 | &m,&n,ap,&m, |
259 | sp, | 255 | sp, |
260 | up,&m, | 256 | up,&m, |
261 | vp,&ldvt, | 257 | vp,&ldvt, |
@@ -266,7 +262,7 @@ int svd_l_C(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) { | |||
266 | doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | 262 | doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); |
267 | CHECK(!work,MEM); | 263 | CHECK(!work,MEM); |
268 | zgesvd_ (jobu,jobvt, | 264 | zgesvd_ (jobu,jobvt, |
269 | &m,&n,B,&m, | 265 | &m,&n,ap,&m, |
270 | sp, | 266 | sp, |
271 | up,&m, | 267 | up,&m, |
272 | vp,&ldvt, | 268 | vp,&ldvt, |
@@ -276,7 +272,6 @@ int svd_l_C(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) { | |||
276 | CHECK(res,res); | 272 | CHECK(res,res); |
277 | free(work); | 273 | free(work); |
278 | free(rwork); | 274 | free(rwork); |
279 | free(B); | ||
280 | OK | 275 | OK |
281 | } | 276 | } |
282 | 277 | ||
@@ -285,8 +280,7 @@ int zgesdd_ (char *jobz, integer *m, integer *n, | |||
285 | integer *ldu, doublecomplex *vt, integer *ldvt, doublecomplex *work, | 280 | integer *ldu, doublecomplex *vt, integer *ldvt, doublecomplex *work, |
286 | integer *lwork, doublereal *rwork, integer* iwork, integer *info); | 281 | integer *lwork, doublereal *rwork, integer* iwork, integer *info); |
287 | 282 | ||
288 | int svd_l_Cdd(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) { | 283 | int svd_l_Cdd(OCMAT(a),OCMAT(u), DVEC(s),OCMAT(v)) { |
289 | //printf("entro\n"); | ||
290 | integer m = ar; | 284 | integer m = ar; |
291 | integer n = ac; | 285 | integer n = ac; |
292 | integer q = MIN(m,n); | 286 | integer q = MIN(m,n); |
@@ -306,9 +300,6 @@ int svd_l_Cdd(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) { | |||
306 | } | 300 | } |
307 | } | 301 | } |
308 | DEBUGMSG("svd_l_Cdd"); | 302 | DEBUGMSG("svd_l_Cdd"); |
309 | doublecomplex *B = (doublecomplex*)malloc(m*n*sizeof(doublecomplex)); | ||
310 | CHECK(!B,MEM); | ||
311 | memcpy(B,ap,m*n*sizeof(doublecomplex)); | ||
312 | integer* iwk = (integer*) malloc(8*q*sizeof(integer)); | 303 | integer* iwk = (integer*) malloc(8*q*sizeof(integer)); |
313 | CHECK(!iwk,MEM); | 304 | CHECK(!iwk,MEM); |
314 | int lrwk; | 305 | int lrwk; |
@@ -319,34 +310,30 @@ int svd_l_Cdd(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) { | |||
319 | } | 310 | } |
320 | double *rwk = (double*)malloc(lrwk*sizeof(double));; | 311 | double *rwk = (double*)malloc(lrwk*sizeof(double));; |
321 | CHECK(!rwk,MEM); | 312 | CHECK(!rwk,MEM); |
322 | //printf("%s %ld %d\n",jobz,q,lrwk); | ||
323 | integer lwk = -1; | 313 | integer lwk = -1; |
324 | integer res; | 314 | integer res; |
325 | // ask for optimal lwk | 315 | // ask for optimal lwk |
326 | doublecomplex ans; | 316 | doublecomplex ans; |
327 | zgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,&ans,&lwk,rwk,iwk,&res); | 317 | zgesdd_ (jobz,&m,&n,ap,&m,sp,up,&m,vp,&ldvt,&ans,&lwk,rwk,iwk,&res); |
328 | lwk = ans.r; | 318 | lwk = ans.r; |
329 | //printf("lwk = %ld\n",lwk); | ||
330 | doublecomplex * workv = (doublecomplex*)malloc(lwk*sizeof(doublecomplex)); | 319 | doublecomplex * workv = (doublecomplex*)malloc(lwk*sizeof(doublecomplex)); |
331 | CHECK(!workv,MEM); | 320 | CHECK(!workv,MEM); |
332 | zgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,workv,&lwk,rwk,iwk,&res); | 321 | zgesdd_ (jobz,&m,&n,ap,&m,sp,up,&m,vp,&ldvt,workv,&lwk,rwk,iwk,&res); |
333 | //printf("res = %ld\n",res); | ||
334 | CHECK(res,res); | 322 | CHECK(res,res); |
335 | free(workv); // printf("freed workv\n"); | 323 | free(workv); |
336 | free(rwk); // printf("freed rwk\n"); | 324 | free(rwk); |
337 | free(iwk); // printf("freed iwk\n"); | 325 | free(iwk); |
338 | free(B); // printf("freed B, salgo\n"); | ||
339 | OK | 326 | OK |
340 | } | 327 | } |
341 | 328 | ||
342 | //////////////////// general complex eigensystem //////////// | 329 | //////////////////// general complex eigensystem //////////// |
343 | 330 | ||
344 | /* Subroutine */ int zgeev_(char *jobvl, char *jobvr, integer *n, | 331 | int zgeev_(char *jobvl, char *jobvr, integer *n, |
345 | doublecomplex *a, integer *lda, doublecomplex *w, doublecomplex *vl, | 332 | doublecomplex *a, integer *lda, doublecomplex *w, doublecomplex *vl, |
346 | integer *ldvl, doublecomplex *vr, integer *ldvr, doublecomplex *work, | 333 | integer *ldvl, doublecomplex *vr, integer *ldvr, doublecomplex *work, |
347 | integer *lwork, doublereal *rwork, integer *info); | 334 | integer *lwork, doublereal *rwork, integer *info); |
348 | 335 | ||
349 | int eig_l_C(KCMAT(a), CMAT(u), CVEC(s),CMAT(v)) { | 336 | int eig_l_C(OCMAT(a), OCMAT(u), CVEC(s),OCMAT(v)) { |
350 | integer n = ar; | 337 | integer n = ar; |
351 | REQUIRES(ac==n && sn==n, BAD_SIZE); | 338 | REQUIRES(ac==n && sn==n, BAD_SIZE); |
352 | REQUIRES(up==NULL || (ur==n && uc==n), BAD_SIZE); | 339 | REQUIRES(up==NULL || (ur==n && uc==n), BAD_SIZE); |
@@ -354,18 +341,14 @@ int eig_l_C(KCMAT(a), CMAT(u), CVEC(s),CMAT(v)) { | |||
354 | REQUIRES(vp==NULL || (vr==n && vc==n), BAD_SIZE); | 341 | REQUIRES(vp==NULL || (vr==n && vc==n), BAD_SIZE); |
355 | char jobvr = vp==NULL?'N':'V'; | 342 | char jobvr = vp==NULL?'N':'V'; |
356 | DEBUGMSG("eig_l_C"); | 343 | DEBUGMSG("eig_l_C"); |
357 | doublecomplex *B = (doublecomplex*)malloc(n*n*sizeof(doublecomplex)); | ||
358 | CHECK(!B,MEM); | ||
359 | memcpy(B,ap,n*n*sizeof(doublecomplex)); | ||
360 | double *rwork = (double*) malloc(2*n*sizeof(double)); | 344 | double *rwork = (double*) malloc(2*n*sizeof(double)); |
361 | CHECK(!rwork,MEM); | 345 | CHECK(!rwork,MEM); |
362 | integer lwork = -1; | 346 | integer lwork = -1; |
363 | integer res; | 347 | integer res; |
364 | // ask for optimal lwork | 348 | // ask for optimal lwork |
365 | doublecomplex ans; | 349 | doublecomplex ans; |
366 | //printf("ask zgeev\n"); | ||
367 | zgeev_ (&jobvl,&jobvr, | 350 | zgeev_ (&jobvl,&jobvr, |
368 | &n,B,&n, | 351 | &n,ap,&n, |
369 | sp, | 352 | sp, |
370 | up,&n, | 353 | up,&n, |
371 | vp,&n, | 354 | vp,&n, |
@@ -373,12 +356,10 @@ int eig_l_C(KCMAT(a), CMAT(u), CVEC(s),CMAT(v)) { | |||
373 | rwork, | 356 | rwork, |
374 | &res); | 357 | &res); |
375 | lwork = ceil(ans.r); | 358 | lwork = ceil(ans.r); |
376 | //printf("ans = %d\n",lwork); | ||
377 | doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | 359 | doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); |
378 | CHECK(!work,MEM); | 360 | CHECK(!work,MEM); |
379 | //printf("zgeev\n"); | ||
380 | zgeev_ (&jobvl,&jobvr, | 361 | zgeev_ (&jobvl,&jobvr, |
381 | &n,B,&n, | 362 | &n,ap,&n, |
382 | sp, | 363 | sp, |
383 | up,&n, | 364 | up,&n, |
384 | vp,&n, | 365 | vp,&n, |
@@ -388,7 +369,6 @@ int eig_l_C(KCMAT(a), CMAT(u), CVEC(s),CMAT(v)) { | |||
388 | CHECK(res,res); | 369 | CHECK(res,res); |
389 | free(work); | 370 | free(work); |
390 | free(rwork); | 371 | free(rwork); |
391 | free(B); | ||
392 | OK | 372 | OK |
393 | } | 373 | } |
394 | 374 | ||
@@ -396,12 +376,12 @@ int eig_l_C(KCMAT(a), CMAT(u), CVEC(s),CMAT(v)) { | |||
396 | 376 | ||
397 | //////////////////// general real eigensystem //////////// | 377 | //////////////////// general real eigensystem //////////// |
398 | 378 | ||
399 | /* Subroutine */ int dgeev_(char *jobvl, char *jobvr, integer *n, doublereal * | 379 | int dgeev_(char *jobvl, char *jobvr, integer *n, doublereal * |
400 | a, integer *lda, doublereal *wr, doublereal *wi, doublereal *vl, | 380 | a, integer *lda, doublereal *wr, doublereal *wi, doublereal *vl, |
401 | integer *ldvl, doublereal *vr, integer *ldvr, doublereal *work, | 381 | integer *ldvl, doublereal *vr, integer *ldvr, doublereal *work, |
402 | integer *lwork, integer *info); | 382 | integer *lwork, integer *info); |
403 | 383 | ||
404 | int eig_l_R(KDMAT(a),DMAT(u), CVEC(s),DMAT(v)) { | 384 | int eig_l_R(ODMAT(a),ODMAT(u), CVEC(s),ODMAT(v)) { |
405 | integer n = ar; | 385 | integer n = ar; |
406 | REQUIRES(ac==n && sn==n, BAD_SIZE); | 386 | REQUIRES(ac==n && sn==n, BAD_SIZE); |
407 | REQUIRES(up==NULL || (ur==n && uc==n), BAD_SIZE); | 387 | REQUIRES(up==NULL || (ur==n && uc==n), BAD_SIZE); |
@@ -409,28 +389,22 @@ int eig_l_R(KDMAT(a),DMAT(u), CVEC(s),DMAT(v)) { | |||
409 | REQUIRES(vp==NULL || (vr==n && vc==n), BAD_SIZE); | 389 | REQUIRES(vp==NULL || (vr==n && vc==n), BAD_SIZE); |
410 | char jobvr = vp==NULL?'N':'V'; | 390 | char jobvr = vp==NULL?'N':'V'; |
411 | DEBUGMSG("eig_l_R"); | 391 | DEBUGMSG("eig_l_R"); |
412 | double *B = (double*)malloc(n*n*sizeof(double)); | ||
413 | CHECK(!B,MEM); | ||
414 | memcpy(B,ap,n*n*sizeof(double)); | ||
415 | integer lwork = -1; | 392 | integer lwork = -1; |
416 | integer res; | 393 | integer res; |
417 | // ask for optimal lwork | 394 | // ask for optimal lwork |
418 | double ans; | 395 | double ans; |
419 | //printf("ask dgeev\n"); | ||
420 | dgeev_ (&jobvl,&jobvr, | 396 | dgeev_ (&jobvl,&jobvr, |
421 | &n,B,&n, | 397 | &n,ap,&n, |
422 | (double*)sp, (double*)sp+n, | 398 | (double*)sp, (double*)sp+n, |
423 | up,&n, | 399 | up,&n, |
424 | vp,&n, | 400 | vp,&n, |
425 | &ans, &lwork, | 401 | &ans, &lwork, |
426 | &res); | 402 | &res); |
427 | lwork = ceil(ans); | 403 | lwork = ceil(ans); |
428 | //printf("ans = %d\n",lwork); | ||
429 | double * work = (double*)malloc(lwork*sizeof(double)); | 404 | double * work = (double*)malloc(lwork*sizeof(double)); |
430 | CHECK(!work,MEM); | 405 | CHECK(!work,MEM); |
431 | //printf("dgeev\n"); | ||
432 | dgeev_ (&jobvl,&jobvr, | 406 | dgeev_ (&jobvl,&jobvr, |
433 | &n,B,&n, | 407 | &n,ap,&n, |
434 | (double*)sp, (double*)sp+n, | 408 | (double*)sp, (double*)sp+n, |
435 | up,&n, | 409 | up,&n, |
436 | vp,&n, | 410 | vp,&n, |
@@ -438,37 +412,32 @@ int eig_l_R(KDMAT(a),DMAT(u), CVEC(s),DMAT(v)) { | |||
438 | &res); | 412 | &res); |
439 | CHECK(res,res); | 413 | CHECK(res,res); |
440 | free(work); | 414 | free(work); |
441 | free(B); | ||
442 | OK | 415 | OK |
443 | } | 416 | } |
444 | 417 | ||
445 | 418 | ||
446 | //////////////////// symmetric real eigensystem //////////// | 419 | //////////////////// symmetric real eigensystem //////////// |
447 | 420 | ||
448 | /* Subroutine */ int dsyev_(char *jobz, char *uplo, integer *n, doublereal *a, | 421 | int dsyev_(char *jobz, char *uplo, integer *n, doublereal *a, |
449 | integer *lda, doublereal *w, doublereal *work, integer *lwork, | 422 | integer *lda, doublereal *w, doublereal *work, integer *lwork, |
450 | integer *info); | 423 | integer *info); |
451 | 424 | ||
452 | int eig_l_S(int wantV,KDMAT(a),DVEC(s),DMAT(v)) { | 425 | int eig_l_S(int wantV,DVEC(s),ODMAT(v)) { |
453 | integer n = ar; | 426 | integer n = sn; |
454 | REQUIRES(ac==n && sn==n, BAD_SIZE); | ||
455 | REQUIRES(vr==n && vc==n, BAD_SIZE); | 427 | REQUIRES(vr==n && vc==n, BAD_SIZE); |
456 | char jobz = wantV?'V':'N'; | 428 | char jobz = wantV?'V':'N'; |
457 | DEBUGMSG("eig_l_S"); | 429 | DEBUGMSG("eig_l_S"); |
458 | memcpy(vp,ap,n*n*sizeof(double)); | ||
459 | integer lwork = -1; | 430 | integer lwork = -1; |
460 | char uplo = 'U'; | 431 | char uplo = 'U'; |
461 | integer res; | 432 | integer res; |
462 | // ask for optimal lwork | 433 | // ask for optimal lwork |
463 | double ans; | 434 | double ans; |
464 | //printf("ask dsyev\n"); | ||
465 | dsyev_ (&jobz,&uplo, | 435 | dsyev_ (&jobz,&uplo, |
466 | &n,vp,&n, | 436 | &n,vp,&n, |
467 | sp, | 437 | sp, |
468 | &ans, &lwork, | 438 | &ans, &lwork, |
469 | &res); | 439 | &res); |
470 | lwork = ceil(ans); | 440 | lwork = ceil(ans); |
471 | //printf("ans = %d\n",lwork); | ||
472 | double * work = (double*)malloc(lwork*sizeof(double)); | 441 | double * work = (double*)malloc(lwork*sizeof(double)); |
473 | CHECK(!work,MEM); | 442 | CHECK(!work,MEM); |
474 | dsyev_ (&jobz,&uplo, | 443 | dsyev_ (&jobz,&uplo, |
@@ -483,17 +452,15 @@ int eig_l_S(int wantV,KDMAT(a),DVEC(s),DMAT(v)) { | |||
483 | 452 | ||
484 | //////////////////// hermitian complex eigensystem //////////// | 453 | //////////////////// hermitian complex eigensystem //////////// |
485 | 454 | ||
486 | /* Subroutine */ int zheev_(char *jobz, char *uplo, integer *n, doublecomplex | 455 | int zheev_(char *jobz, char *uplo, integer *n, doublecomplex |
487 | *a, integer *lda, doublereal *w, doublecomplex *work, integer *lwork, | 456 | *a, integer *lda, doublereal *w, doublecomplex *work, integer *lwork, |
488 | doublereal *rwork, integer *info); | 457 | doublereal *rwork, integer *info); |
489 | 458 | ||
490 | int eig_l_H(int wantV,KCMAT(a),DVEC(s),CMAT(v)) { | 459 | int eig_l_H(int wantV,DVEC(s),OCMAT(v)) { |
491 | integer n = ar; | 460 | integer n = sn; |
492 | REQUIRES(ac==n && sn==n, BAD_SIZE); | ||
493 | REQUIRES(vr==n && vc==n, BAD_SIZE); | 461 | REQUIRES(vr==n && vc==n, BAD_SIZE); |
494 | char jobz = wantV?'V':'N'; | 462 | char jobz = wantV?'V':'N'; |
495 | DEBUGMSG("eig_l_H"); | 463 | DEBUGMSG("eig_l_H"); |
496 | memcpy(vp,ap,2*n*n*sizeof(double)); | ||
497 | double *rwork = (double*) malloc((3*n-2)*sizeof(double)); | 464 | double *rwork = (double*) malloc((3*n-2)*sizeof(double)); |
498 | CHECK(!rwork,MEM); | 465 | CHECK(!rwork,MEM); |
499 | integer lwork = -1; | 466 | integer lwork = -1; |
@@ -501,7 +468,6 @@ int eig_l_H(int wantV,KCMAT(a),DVEC(s),CMAT(v)) { | |||
501 | integer res; | 468 | integer res; |
502 | // ask for optimal lwork | 469 | // ask for optimal lwork |
503 | doublecomplex ans; | 470 | doublecomplex ans; |
504 | //printf("ask zheev\n"); | ||
505 | zheev_ (&jobz,&uplo, | 471 | zheev_ (&jobz,&uplo, |
506 | &n,vp,&n, | 472 | &n,vp,&n, |
507 | sp, | 473 | sp, |
@@ -509,7 +475,6 @@ int eig_l_H(int wantV,KCMAT(a),DVEC(s),CMAT(v)) { | |||
509 | rwork, | 475 | rwork, |
510 | &res); | 476 | &res); |
511 | lwork = ceil(ans.r); | 477 | lwork = ceil(ans.r); |
512 | //printf("ans = %d\n",lwork); | ||
513 | doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | 478 | doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); |
514 | CHECK(!work,MEM); | 479 | CHECK(!work,MEM); |
515 | zheev_ (&jobz,&uplo, | 480 | zheev_ (&jobz,&uplo, |
@@ -526,80 +491,72 @@ int eig_l_H(int wantV,KCMAT(a),DVEC(s),CMAT(v)) { | |||
526 | 491 | ||
527 | //////////////////// general real linear system //////////// | 492 | //////////////////// general real linear system //////////// |
528 | 493 | ||
529 | /* Subroutine */ int dgesv_(integer *n, integer *nrhs, doublereal *a, integer | 494 | int dgesv_(integer *n, integer *nrhs, doublereal *a, integer |
530 | *lda, integer *ipiv, doublereal *b, integer *ldb, integer *info); | 495 | *lda, integer *ipiv, doublereal *b, integer *ldb, integer *info); |
531 | 496 | ||
532 | int linearSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)) { | 497 | int linearSolveR_l(ODMAT(a),ODMAT(b)) { |
533 | integer n = ar; | 498 | integer n = ar; |
534 | integer nhrs = bc; | 499 | integer nhrs = bc; |
535 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | 500 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); |
536 | DEBUGMSG("linearSolveR_l"); | 501 | DEBUGMSG("linearSolveR_l"); |
537 | double*AC = (double*)malloc(n*n*sizeof(double)); | ||
538 | memcpy(AC,ap,n*n*sizeof(double)); | ||
539 | memcpy(xp,bp,n*nhrs*sizeof(double)); | ||
540 | integer * ipiv = (integer*)malloc(n*sizeof(integer)); | 502 | integer * ipiv = (integer*)malloc(n*sizeof(integer)); |
541 | integer res; | 503 | integer res; |
542 | dgesv_ (&n,&nhrs, | 504 | dgesv_ (&n,&nhrs, |
543 | AC, &n, | 505 | ap, &n, |
544 | ipiv, | 506 | ipiv, |
545 | xp, &n, | 507 | bp, &n, |
546 | &res); | 508 | &res); |
547 | if(res>0) { | 509 | if(res>0) { |
548 | return SINGULAR; | 510 | return SINGULAR; |
549 | } | 511 | } |
550 | CHECK(res,res); | 512 | CHECK(res,res); |
551 | free(ipiv); | 513 | free(ipiv); |
552 | free(AC); | ||
553 | OK | 514 | OK |
554 | } | 515 | } |
555 | 516 | ||
556 | //////////////////// general complex linear system //////////// | 517 | //////////////////// general complex linear system //////////// |
557 | 518 | ||
558 | /* Subroutine */ int zgesv_(integer *n, integer *nrhs, doublecomplex *a, | 519 | int zgesv_(integer *n, integer *nrhs, doublecomplex *a, |
559 | integer *lda, integer *ipiv, doublecomplex *b, integer *ldb, integer * | 520 | integer *lda, integer *ipiv, doublecomplex *b, integer *ldb, integer * |
560 | info); | 521 | info); |
561 | 522 | ||
562 | int linearSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)) { | 523 | int linearSolveC_l(OCMAT(a),OCMAT(b)) { |
563 | integer n = ar; | 524 | integer n = ar; |
564 | integer nhrs = bc; | 525 | integer nhrs = bc; |
565 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | 526 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); |
566 | DEBUGMSG("linearSolveC_l"); | 527 | DEBUGMSG("linearSolveC_l"); |
567 | doublecomplex*AC = (doublecomplex*)malloc(n*n*sizeof(doublecomplex)); | ||
568 | memcpy(AC,ap,n*n*sizeof(doublecomplex)); | ||
569 | memcpy(xp,bp,n*nhrs*sizeof(doublecomplex)); | ||
570 | integer * ipiv = (integer*)malloc(n*sizeof(integer)); | 528 | integer * ipiv = (integer*)malloc(n*sizeof(integer)); |
571 | integer res; | 529 | integer res; |
572 | zgesv_ (&n,&nhrs, | 530 | zgesv_ (&n,&nhrs, |
573 | AC, &n, | 531 | ap, &n, |
574 | ipiv, | 532 | ipiv, |
575 | xp, &n, | 533 | bp, &n, |
576 | &res); | 534 | &res); |
577 | if(res>0) { | 535 | if(res>0) { |
578 | return SINGULAR; | 536 | return SINGULAR; |
579 | } | 537 | } |
580 | CHECK(res,res); | 538 | CHECK(res,res); |
581 | free(ipiv); | 539 | free(ipiv); |
582 | free(AC); | ||
583 | OK | 540 | OK |
584 | } | 541 | } |
585 | 542 | ||
586 | //////// symmetric positive definite real linear system using Cholesky //////////// | 543 | //////// symmetric positive definite real linear system using Cholesky //////////// |
587 | 544 | ||
588 | /* Subroutine */ int dpotrs_(char *uplo, integer *n, integer *nrhs, | 545 | int dpotrs_(char *uplo, integer *n, integer *nrhs, |
589 | doublereal *a, integer *lda, doublereal *b, integer *ldb, integer * | 546 | doublereal *a, integer *lda, doublereal *b, integer *ldb, integer * |
590 | info); | 547 | info); |
591 | 548 | ||
592 | int cholSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)) { | 549 | int cholSolveR_l(KODMAT(a),ODMAT(b)) { |
593 | integer n = ar; | 550 | integer n = ar; |
551 | integer lda = aXc; | ||
594 | integer nhrs = bc; | 552 | integer nhrs = bc; |
595 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | 553 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); |
596 | DEBUGMSG("cholSolveR_l"); | 554 | DEBUGMSG("cholSolveR_l"); |
597 | memcpy(xp,bp,n*nhrs*sizeof(double)); | ||
598 | integer res; | 555 | integer res; |
599 | dpotrs_ ("U", | 556 | dpotrs_ ("U", |
600 | &n,&nhrs, | 557 | &n,&nhrs, |
601 | (double*)ap, &n, | 558 | (double*)ap, &lda, |
602 | xp, &n, | 559 | bp, &n, |
603 | &res); | 560 | &res); |
604 | CHECK(res,res); | 561 | CHECK(res,res); |
605 | OK | 562 | OK |
@@ -607,21 +564,21 @@ int cholSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)) { | |||
607 | 564 | ||
608 | //////// Hermitian positive definite real linear system using Cholesky //////////// | 565 | //////// Hermitian positive definite real linear system using Cholesky //////////// |
609 | 566 | ||
610 | /* Subroutine */ int zpotrs_(char *uplo, integer *n, integer *nrhs, | 567 | int zpotrs_(char *uplo, integer *n, integer *nrhs, |
611 | doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, | 568 | doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, |
612 | integer *info); | 569 | integer *info); |
613 | 570 | ||
614 | int cholSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)) { | 571 | int cholSolveC_l(KOCMAT(a),OCMAT(b)) { |
615 | integer n = ar; | 572 | integer n = ar; |
573 | integer lda = aXc; | ||
616 | integer nhrs = bc; | 574 | integer nhrs = bc; |
617 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | 575 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); |
618 | DEBUGMSG("cholSolveC_l"); | 576 | DEBUGMSG("cholSolveC_l"); |
619 | memcpy(xp,bp,n*nhrs*sizeof(doublecomplex)); | ||
620 | integer res; | 577 | integer res; |
621 | zpotrs_ ("U", | 578 | zpotrs_ ("U", |
622 | &n,&nhrs, | 579 | &n,&nhrs, |
623 | (doublecomplex*)ap, &n, | 580 | (doublecomplex*)ap, &lda, |
624 | xp, &n, | 581 | bp, &n, |
625 | &res); | 582 | &res); |
626 | CHECK(res,res); | 583 | CHECK(res,res); |
627 | OK | 584 | OK |
@@ -629,41 +586,30 @@ int cholSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)) { | |||
629 | 586 | ||
630 | //////////////////// least squares real linear system //////////// | 587 | //////////////////// least squares real linear system //////////// |
631 | 588 | ||
632 | /* Subroutine */ int dgels_(char *trans, integer *m, integer *n, integer * | 589 | int dgels_(char *trans, integer *m, integer *n, integer * |
633 | nrhs, doublereal *a, integer *lda, doublereal *b, integer *ldb, | 590 | nrhs, doublereal *a, integer *lda, doublereal *b, integer *ldb, |
634 | doublereal *work, integer *lwork, integer *info); | 591 | doublereal *work, integer *lwork, integer *info); |
635 | 592 | ||
636 | int linearSolveLSR_l(KDMAT(a),KDMAT(b),DMAT(x)) { | 593 | int linearSolveLSR_l(ODMAT(a),ODMAT(b)) { |
637 | integer m = ar; | 594 | integer m = ar; |
638 | integer n = ac; | 595 | integer n = ac; |
639 | integer nrhs = bc; | 596 | integer nrhs = bc; |
640 | integer ldb = xr; | 597 | integer ldb = bXc; |
641 | REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE); | 598 | REQUIRES(m>=1 && n>=1 && br==MAX(m,n), BAD_SIZE); |
642 | DEBUGMSG("linearSolveLSR_l"); | 599 | DEBUGMSG("linearSolveLSR_l"); |
643 | double*AC = (double*)malloc(m*n*sizeof(double)); | ||
644 | memcpy(AC,ap,m*n*sizeof(double)); | ||
645 | if (m>=n) { | ||
646 | memcpy(xp,bp,m*nrhs*sizeof(double)); | ||
647 | } else { | ||
648 | int k; | ||
649 | for(k = 0; k<nrhs; k++) { | ||
650 | memcpy(xp+ldb*k,bp+m*k,m*sizeof(double)); | ||
651 | } | ||
652 | } | ||
653 | integer res; | 600 | integer res; |
654 | integer lwork = -1; | 601 | integer lwork = -1; |
655 | double ans; | 602 | double ans; |
656 | dgels_ ("N",&m,&n,&nrhs, | 603 | dgels_ ("N",&m,&n,&nrhs, |
657 | AC,&m, | 604 | ap,&m, |
658 | xp,&ldb, | 605 | bp,&ldb, |
659 | &ans,&lwork, | 606 | &ans,&lwork, |
660 | &res); | 607 | &res); |
661 | lwork = ceil(ans); | 608 | lwork = ceil(ans); |
662 | //printf("ans = %d\n",lwork); | ||
663 | double * work = (double*)malloc(lwork*sizeof(double)); | 609 | double * work = (double*)malloc(lwork*sizeof(double)); |
664 | dgels_ ("N",&m,&n,&nrhs, | 610 | dgels_ ("N",&m,&n,&nrhs, |
665 | AC,&m, | 611 | ap,&m, |
666 | xp,&ldb, | 612 | bp,&ldb, |
667 | work,&lwork, | 613 | work,&lwork, |
668 | &res); | 614 | &res); |
669 | if(res>0) { | 615 | if(res>0) { |
@@ -671,47 +617,35 @@ int linearSolveLSR_l(KDMAT(a),KDMAT(b),DMAT(x)) { | |||
671 | } | 617 | } |
672 | CHECK(res,res); | 618 | CHECK(res,res); |
673 | free(work); | 619 | free(work); |
674 | free(AC); | ||
675 | OK | 620 | OK |
676 | } | 621 | } |
677 | 622 | ||
678 | //////////////////// least squares complex linear system //////////// | 623 | //////////////////// least squares complex linear system //////////// |
679 | 624 | ||
680 | /* Subroutine */ int zgels_(char *trans, integer *m, integer *n, integer * | 625 | int zgels_(char *trans, integer *m, integer *n, integer * |
681 | nrhs, doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, | 626 | nrhs, doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, |
682 | doublecomplex *work, integer *lwork, integer *info); | 627 | doublecomplex *work, integer *lwork, integer *info); |
683 | 628 | ||
684 | int linearSolveLSC_l(KCMAT(a),KCMAT(b),CMAT(x)) { | 629 | int linearSolveLSC_l(OCMAT(a),OCMAT(b)) { |
685 | integer m = ar; | 630 | integer m = ar; |
686 | integer n = ac; | 631 | integer n = ac; |
687 | integer nrhs = bc; | 632 | integer nrhs = bc; |
688 | integer ldb = xr; | 633 | integer ldb = bXc; |
689 | REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE); | 634 | REQUIRES(m>=1 && n>=1 && br==MAX(m,n), BAD_SIZE); |
690 | DEBUGMSG("linearSolveLSC_l"); | 635 | DEBUGMSG("linearSolveLSC_l"); |
691 | doublecomplex*AC = (doublecomplex*)malloc(m*n*sizeof(doublecomplex)); | ||
692 | memcpy(AC,ap,m*n*sizeof(doublecomplex)); | ||
693 | if (m>=n) { | ||
694 | memcpy(xp,bp,m*nrhs*sizeof(doublecomplex)); | ||
695 | } else { | ||
696 | int k; | ||
697 | for(k = 0; k<nrhs; k++) { | ||
698 | memcpy(xp+ldb*k,bp+m*k,m*sizeof(doublecomplex)); | ||
699 | } | ||
700 | } | ||
701 | integer res; | 636 | integer res; |
702 | integer lwork = -1; | 637 | integer lwork = -1; |
703 | doublecomplex ans; | 638 | doublecomplex ans; |
704 | zgels_ ("N",&m,&n,&nrhs, | 639 | zgels_ ("N",&m,&n,&nrhs, |
705 | AC,&m, | 640 | ap,&m, |
706 | xp,&ldb, | 641 | bp,&ldb, |
707 | &ans,&lwork, | 642 | &ans,&lwork, |
708 | &res); | 643 | &res); |
709 | lwork = ceil(ans.r); | 644 | lwork = ceil(ans.r); |
710 | //printf("ans = %d\n",lwork); | ||
711 | doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | 645 | doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); |
712 | zgels_ ("N",&m,&n,&nrhs, | 646 | zgels_ ("N",&m,&n,&nrhs, |
713 | AC,&m, | 647 | ap,&m, |
714 | xp,&ldb, | 648 | bp,&ldb, |
715 | work,&lwork, | 649 | work,&lwork, |
716 | &res); | 650 | &res); |
717 | if(res>0) { | 651 | if(res>0) { |
@@ -719,52 +653,40 @@ int linearSolveLSC_l(KCMAT(a),KCMAT(b),CMAT(x)) { | |||
719 | } | 653 | } |
720 | CHECK(res,res); | 654 | CHECK(res,res); |
721 | free(work); | 655 | free(work); |
722 | free(AC); | ||
723 | OK | 656 | OK |
724 | } | 657 | } |
725 | 658 | ||
726 | //////////////////// least squares real linear system using SVD //////////// | 659 | //////////////////// least squares real linear system using SVD //////////// |
727 | 660 | ||
728 | /* Subroutine */ int dgelss_(integer *m, integer *n, integer *nrhs, | 661 | int dgelss_(integer *m, integer *n, integer *nrhs, |
729 | doublereal *a, integer *lda, doublereal *b, integer *ldb, doublereal * | 662 | doublereal *a, integer *lda, doublereal *b, integer *ldb, doublereal * |
730 | s, doublereal *rcond, integer *rank, doublereal *work, integer *lwork, | 663 | s, doublereal *rcond, integer *rank, doublereal *work, integer *lwork, |
731 | integer *info); | 664 | integer *info); |
732 | 665 | ||
733 | int linearSolveSVDR_l(double rcond,KDMAT(a),KDMAT(b),DMAT(x)) { | 666 | int linearSolveSVDR_l(double rcond,ODMAT(a),ODMAT(b)) { |
734 | integer m = ar; | 667 | integer m = ar; |
735 | integer n = ac; | 668 | integer n = ac; |
736 | integer nrhs = bc; | 669 | integer nrhs = bc; |
737 | integer ldb = xr; | 670 | integer ldb = bXc; |
738 | REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE); | 671 | REQUIRES(m>=1 && n>=1 && br==MAX(m,n), BAD_SIZE); |
739 | DEBUGMSG("linearSolveSVDR_l"); | 672 | DEBUGMSG("linearSolveSVDR_l"); |
740 | double*AC = (double*)malloc(m*n*sizeof(double)); | ||
741 | double*S = (double*)malloc(MIN(m,n)*sizeof(double)); | 673 | double*S = (double*)malloc(MIN(m,n)*sizeof(double)); |
742 | memcpy(AC,ap,m*n*sizeof(double)); | ||
743 | if (m>=n) { | ||
744 | memcpy(xp,bp,m*nrhs*sizeof(double)); | ||
745 | } else { | ||
746 | int k; | ||
747 | for(k = 0; k<nrhs; k++) { | ||
748 | memcpy(xp+ldb*k,bp+m*k,m*sizeof(double)); | ||
749 | } | ||
750 | } | ||
751 | integer res; | 674 | integer res; |
752 | integer lwork = -1; | 675 | integer lwork = -1; |
753 | integer rank; | 676 | integer rank; |
754 | double ans; | 677 | double ans; |
755 | dgelss_ (&m,&n,&nrhs, | 678 | dgelss_ (&m,&n,&nrhs, |
756 | AC,&m, | 679 | ap,&m, |
757 | xp,&ldb, | 680 | bp,&ldb, |
758 | S, | 681 | S, |
759 | &rcond,&rank, | 682 | &rcond,&rank, |
760 | &ans,&lwork, | 683 | &ans,&lwork, |
761 | &res); | 684 | &res); |
762 | lwork = ceil(ans); | 685 | lwork = ceil(ans); |
763 | //printf("ans = %d\n",lwork); | ||
764 | double * work = (double*)malloc(lwork*sizeof(double)); | 686 | double * work = (double*)malloc(lwork*sizeof(double)); |
765 | dgelss_ (&m,&n,&nrhs, | 687 | dgelss_ (&m,&n,&nrhs, |
766 | AC,&m, | 688 | ap,&m, |
767 | xp,&ldb, | 689 | bp,&ldb, |
768 | S, | 690 | S, |
769 | &rcond,&rank, | 691 | &rcond,&rank, |
770 | work,&lwork, | 692 | work,&lwork, |
@@ -775,57 +697,43 @@ int linearSolveSVDR_l(double rcond,KDMAT(a),KDMAT(b),DMAT(x)) { | |||
775 | CHECK(res,res); | 697 | CHECK(res,res); |
776 | free(work); | 698 | free(work); |
777 | free(S); | 699 | free(S); |
778 | free(AC); | ||
779 | OK | 700 | OK |
780 | } | 701 | } |
781 | 702 | ||
782 | //////////////////// least squares complex linear system using SVD //////////// | 703 | //////////////////// least squares complex linear system using SVD //////////// |
783 | 704 | ||
784 | // not in clapack.h | ||
785 | |||
786 | int zgelss_(integer *m, integer *n, integer *nhrs, | 705 | int zgelss_(integer *m, integer *n, integer *nhrs, |
787 | doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, doublereal *s, | 706 | doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, doublereal *s, |
788 | doublereal *rcond, integer* rank, | 707 | doublereal *rcond, integer* rank, |
789 | doublecomplex *work, integer* lwork, doublereal* rwork, | 708 | doublecomplex *work, integer* lwork, doublereal* rwork, |
790 | integer *info); | 709 | integer *info); |
791 | 710 | ||
792 | int linearSolveSVDC_l(double rcond, KCMAT(a),KCMAT(b),CMAT(x)) { | 711 | int linearSolveSVDC_l(double rcond, OCMAT(a),OCMAT(b)) { |
793 | integer m = ar; | 712 | integer m = ar; |
794 | integer n = ac; | 713 | integer n = ac; |
795 | integer nrhs = bc; | 714 | integer nrhs = bc; |
796 | integer ldb = xr; | 715 | integer ldb = bXc; |
797 | REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE); | 716 | REQUIRES(m>=1 && n>=1 && br==MAX(m,n), BAD_SIZE); |
798 | DEBUGMSG("linearSolveSVDC_l"); | 717 | DEBUGMSG("linearSolveSVDC_l"); |
799 | doublecomplex*AC = (doublecomplex*)malloc(m*n*sizeof(doublecomplex)); | ||
800 | double*S = (double*)malloc(MIN(m,n)*sizeof(double)); | 718 | double*S = (double*)malloc(MIN(m,n)*sizeof(double)); |
801 | double*RWORK = (double*)malloc(5*MIN(m,n)*sizeof(double)); | 719 | double*RWORK = (double*)malloc(5*MIN(m,n)*sizeof(double)); |
802 | memcpy(AC,ap,m*n*sizeof(doublecomplex)); | ||
803 | if (m>=n) { | ||
804 | memcpy(xp,bp,m*nrhs*sizeof(doublecomplex)); | ||
805 | } else { | ||
806 | int k; | ||
807 | for(k = 0; k<nrhs; k++) { | ||
808 | memcpy(xp+ldb*k,bp+m*k,m*sizeof(doublecomplex)); | ||
809 | } | ||
810 | } | ||
811 | integer res; | 720 | integer res; |
812 | integer lwork = -1; | 721 | integer lwork = -1; |
813 | integer rank; | 722 | integer rank; |
814 | doublecomplex ans; | 723 | doublecomplex ans; |
815 | zgelss_ (&m,&n,&nrhs, | 724 | zgelss_ (&m,&n,&nrhs, |
816 | AC,&m, | 725 | ap,&m, |
817 | xp,&ldb, | 726 | bp,&ldb, |
818 | S, | 727 | S, |
819 | &rcond,&rank, | 728 | &rcond,&rank, |
820 | &ans,&lwork, | 729 | &ans,&lwork, |
821 | RWORK, | 730 | RWORK, |
822 | &res); | 731 | &res); |
823 | lwork = ceil(ans.r); | 732 | lwork = ceil(ans.r); |
824 | //printf("ans = %d\n",lwork); | ||
825 | doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | 733 | doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); |
826 | zgelss_ (&m,&n,&nrhs, | 734 | zgelss_ (&m,&n,&nrhs, |
827 | AC,&m, | 735 | ap,&m, |
828 | xp,&ldb, | 736 | bp,&ldb, |
829 | S, | 737 | S, |
830 | &rcond,&rank, | 738 | &rcond,&rank, |
831 | work,&lwork, | 739 | work,&lwork, |
@@ -838,20 +746,17 @@ int linearSolveSVDC_l(double rcond, KCMAT(a),KCMAT(b),CMAT(x)) { | |||
838 | free(work); | 746 | free(work); |
839 | free(RWORK); | 747 | free(RWORK); |
840 | free(S); | 748 | free(S); |
841 | free(AC); | ||
842 | OK | 749 | OK |
843 | } | 750 | } |
844 | 751 | ||
845 | //////////////////// Cholesky factorization ///////////////////////// | 752 | //////////////////// Cholesky factorization ///////////////////////// |
846 | 753 | ||
847 | /* Subroutine */ int zpotrf_(char *uplo, integer *n, doublecomplex *a, | 754 | int zpotrf_(char *uplo, integer *n, doublecomplex *a, integer *lda, integer *info); |
848 | integer *lda, integer *info); | ||
849 | 755 | ||
850 | int chol_l_H(KCMAT(a),CMAT(l)) { | 756 | int chol_l_H(OCMAT(l)) { |
851 | integer n = ar; | 757 | integer n = lr; |
852 | REQUIRES(n>=1 && ac == n && lr==n && lc==n,BAD_SIZE); | 758 | REQUIRES(n>=1 && lc == n,BAD_SIZE); |
853 | DEBUGMSG("chol_l_H"); | 759 | DEBUGMSG("chol_l_H"); |
854 | memcpy(lp,ap,n*n*sizeof(doublecomplex)); | ||
855 | char uplo = 'U'; | 760 | char uplo = 'U'; |
856 | integer res; | 761 | integer res; |
857 | zpotrf_ (&uplo,&n,lp,&n,&res); | 762 | zpotrf_ (&uplo,&n,lp,&n,&res); |
@@ -859,32 +764,30 @@ int chol_l_H(KCMAT(a),CMAT(l)) { | |||
859 | CHECK(res,res); | 764 | CHECK(res,res); |
860 | doublecomplex zero = {0.,0.}; | 765 | doublecomplex zero = {0.,0.}; |
861 | int r,c; | 766 | int r,c; |
862 | for (r=0; r<lr-1; r++) { | 767 | for (r=0; r<lr; r++) { |
863 | for(c=r+1; c<lc; c++) { | 768 | for(c=0; c<r; c++) { |
864 | lp[r*lc+c] = zero; | 769 | AT(l,r,c) = zero; |
865 | } | 770 | } |
866 | } | 771 | } |
867 | OK | 772 | OK |
868 | } | 773 | } |
869 | 774 | ||
870 | 775 | ||
871 | /* Subroutine */ int dpotrf_(char *uplo, integer *n, doublereal *a, integer * | 776 | int dpotrf_(char *uplo, integer *n, doublereal *a, integer * lda, integer *info); |
872 | lda, integer *info); | ||
873 | 777 | ||
874 | int chol_l_S(KDMAT(a),DMAT(l)) { | 778 | int chol_l_S(ODMAT(l)) { |
875 | integer n = ar; | 779 | integer n = lr; |
876 | REQUIRES(n>=1 && ac == n && lr==n && lc==n,BAD_SIZE); | 780 | REQUIRES(n>=1 && lc == n,BAD_SIZE); |
877 | DEBUGMSG("chol_l_S"); | 781 | DEBUGMSG("chol_l_S"); |
878 | memcpy(lp,ap,n*n*sizeof(double)); | ||
879 | char uplo = 'U'; | 782 | char uplo = 'U'; |
880 | integer res; | 783 | integer res; |
881 | dpotrf_ (&uplo,&n,lp,&n,&res); | 784 | dpotrf_ (&uplo,&n,lp,&n,&res); |
882 | CHECK(res>0,NODEFPOS); | 785 | CHECK(res>0,NODEFPOS); |
883 | CHECK(res,res); | 786 | CHECK(res,res); |
884 | int r,c; | 787 | int r,c; |
885 | for (r=0; r<lr-1; r++) { | 788 | for (r=0; r<lr; r++) { |
886 | for(c=r+1; c<lc; c++) { | 789 | for(c=0; c<r; c++) { |
887 | lp[r*lc+c] = 0.; | 790 | AT(l,r,c) = 0.; |
888 | } | 791 | } |
889 | } | 792 | } |
890 | OK | 793 | OK |
@@ -892,18 +795,17 @@ int chol_l_S(KDMAT(a),DMAT(l)) { | |||
892 | 795 | ||
893 | //////////////////// QR factorization ///////////////////////// | 796 | //////////////////// QR factorization ///////////////////////// |
894 | 797 | ||
895 | /* Subroutine */ int dgeqr2_(integer *m, integer *n, doublereal *a, integer * | 798 | int dgeqr2_(integer *m, integer *n, doublereal *a, integer * |
896 | lda, doublereal *tau, doublereal *work, integer *info); | 799 | lda, doublereal *tau, doublereal *work, integer *info); |
897 | 800 | ||
898 | int qr_l_R(KDMAT(a), DVEC(tau), DMAT(r)) { | 801 | int qr_l_R(DVEC(tau), ODMAT(r)) { |
899 | integer m = ar; | 802 | integer m = rr; |
900 | integer n = ac; | 803 | integer n = rc; |
901 | integer mn = MIN(m,n); | 804 | integer mn = MIN(m,n); |
902 | REQUIRES(m>=1 && n >=1 && rr== m && rc == n && taun == mn, BAD_SIZE); | 805 | REQUIRES(m>=1 && n >=1 && taun == mn, BAD_SIZE); |
903 | DEBUGMSG("qr_l_R"); | 806 | DEBUGMSG("qr_l_R"); |
904 | double *WORK = (double*)malloc(n*sizeof(double)); | 807 | double *WORK = (double*)malloc(n*sizeof(double)); |
905 | CHECK(!WORK,MEM); | 808 | CHECK(!WORK,MEM); |
906 | memcpy(rp,ap,m*n*sizeof(double)); | ||
907 | integer res; | 809 | integer res; |
908 | dgeqr2_ (&m,&n,rp,&m,taup,WORK,&res); | 810 | dgeqr2_ (&m,&n,rp,&m,taup,WORK,&res); |
909 | CHECK(res,res); | 811 | CHECK(res,res); |
@@ -911,18 +813,17 @@ int qr_l_R(KDMAT(a), DVEC(tau), DMAT(r)) { | |||
911 | OK | 813 | OK |
912 | } | 814 | } |
913 | 815 | ||
914 | /* Subroutine */ int zgeqr2_(integer *m, integer *n, doublecomplex *a, | 816 | int zgeqr2_(integer *m, integer *n, doublecomplex *a, |
915 | integer *lda, doublecomplex *tau, doublecomplex *work, integer *info); | 817 | integer *lda, doublecomplex *tau, doublecomplex *work, integer *info); |
916 | 818 | ||
917 | int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)) { | 819 | int qr_l_C(CVEC(tau), OCMAT(r)) { |
918 | integer m = ar; | 820 | integer m = rr; |
919 | integer n = ac; | 821 | integer n = rc; |
920 | integer mn = MIN(m,n); | 822 | integer mn = MIN(m,n); |
921 | REQUIRES(m>=1 && n >=1 && rr== m && rc == n && taun == mn, BAD_SIZE); | 823 | REQUIRES(m>=1 && n >=1 && taun == mn, BAD_SIZE); |
922 | DEBUGMSG("qr_l_C"); | 824 | DEBUGMSG("qr_l_C"); |
923 | doublecomplex *WORK = (doublecomplex*)malloc(n*sizeof(doublecomplex)); | 825 | doublecomplex *WORK = (doublecomplex*)malloc(n*sizeof(doublecomplex)); |
924 | CHECK(!WORK,MEM); | 826 | CHECK(!WORK,MEM); |
925 | memcpy(rp,ap,m*n*sizeof(doublecomplex)); | ||
926 | integer res; | 827 | integer res; |
927 | zgeqr2_ (&m,&n,rp,&m,taup,WORK,&res); | 828 | zgeqr2_ (&m,&n,rp,&m,taup,WORK,&res); |
928 | CHECK(res,res); | 829 | CHECK(res,res); |
@@ -930,19 +831,18 @@ int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)) { | |||
930 | OK | 831 | OK |
931 | } | 832 | } |
932 | 833 | ||
933 | /* Subroutine */ int dorgqr_(integer *m, integer *n, integer *k, doublereal * | 834 | int dorgqr_(integer *m, integer *n, integer *k, doublereal * |
934 | a, integer *lda, doublereal *tau, doublereal *work, integer *lwork, | 835 | a, integer *lda, doublereal *tau, doublereal *work, integer *lwork, |
935 | integer *info); | 836 | integer *info); |
936 | 837 | ||
937 | int c_dorgqr(KDMAT(a), KDVEC(tau), DMAT(r)) { | 838 | int c_dorgqr(KDVEC(tau), ODMAT(r)) { |
938 | integer m = ar; | 839 | integer m = rr; |
939 | integer n = MIN(ac,ar); | 840 | integer n = MIN(rc,rr); |
940 | integer k = taun; | 841 | integer k = taun; |
941 | DEBUGMSG("c_dorgqr"); | 842 | DEBUGMSG("c_dorgqr"); |
942 | integer lwork = 8*n; // FIXME | 843 | integer lwork = 8*n; // FIXME |
943 | double *WORK = (double*)malloc(lwork*sizeof(double)); | 844 | double *WORK = (double*)malloc(lwork*sizeof(double)); |
944 | CHECK(!WORK,MEM); | 845 | CHECK(!WORK,MEM); |
945 | memcpy(rp,ap,m*k*sizeof(double)); | ||
946 | integer res; | 846 | integer res; |
947 | dorgqr_ (&m,&n,&k,rp,&m,(double*)taup,WORK,&lwork,&res); | 847 | dorgqr_ (&m,&n,&k,rp,&m,(double*)taup,WORK,&lwork,&res); |
948 | CHECK(res,res); | 848 | CHECK(res,res); |
@@ -950,19 +850,18 @@ int c_dorgqr(KDMAT(a), KDVEC(tau), DMAT(r)) { | |||
950 | OK | 850 | OK |
951 | } | 851 | } |
952 | 852 | ||
953 | /* Subroutine */ int zungqr_(integer *m, integer *n, integer *k, | 853 | int zungqr_(integer *m, integer *n, integer *k, |
954 | doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * | 854 | doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * |
955 | work, integer *lwork, integer *info); | 855 | work, integer *lwork, integer *info); |
956 | 856 | ||
957 | int c_zungqr(KCMAT(a), KCVEC(tau), CMAT(r)) { | 857 | int c_zungqr(KCVEC(tau), OCMAT(r)) { |
958 | integer m = ar; | 858 | integer m = rr; |
959 | integer n = MIN(ac,ar); | 859 | integer n = MIN(rc,rr); |
960 | integer k = taun; | 860 | integer k = taun; |
961 | DEBUGMSG("z_ungqr"); | 861 | DEBUGMSG("z_ungqr"); |
962 | integer lwork = 8*n; // FIXME | 862 | integer lwork = 8*n; // FIXME |
963 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | 863 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); |
964 | CHECK(!WORK,MEM); | 864 | CHECK(!WORK,MEM); |
965 | memcpy(rp,ap,m*k*sizeof(doublecomplex)); | ||
966 | integer res; | 865 | integer res; |
967 | zungqr_ (&m,&n,&k,rp,&m,(doublecomplex*)taup,WORK,&lwork,&res); | 866 | zungqr_ (&m,&n,&k,rp,&m,(doublecomplex*)taup,WORK,&lwork,&res); |
968 | CHECK(res,res); | 867 | CHECK(res,res); |
@@ -973,20 +872,19 @@ int c_zungqr(KCMAT(a), KCVEC(tau), CMAT(r)) { | |||
973 | 872 | ||
974 | //////////////////// Hessenberg factorization ///////////////////////// | 873 | //////////////////// Hessenberg factorization ///////////////////////// |
975 | 874 | ||
976 | /* Subroutine */ int dgehrd_(integer *n, integer *ilo, integer *ihi, | 875 | int dgehrd_(integer *n, integer *ilo, integer *ihi, |
977 | doublereal *a, integer *lda, doublereal *tau, doublereal *work, | 876 | doublereal *a, integer *lda, doublereal *tau, doublereal *work, |
978 | integer *lwork, integer *info); | 877 | integer *lwork, integer *info); |
979 | 878 | ||
980 | int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)) { | 879 | int hess_l_R(DVEC(tau), ODMAT(r)) { |
981 | integer m = ar; | 880 | integer m = rr; |
982 | integer n = ac; | 881 | integer n = rc; |
983 | integer mn = MIN(m,n); | 882 | integer mn = MIN(m,n); |
984 | REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE); | 883 | REQUIRES(m>=1 && n == m && taun == mn-1, BAD_SIZE); |
985 | DEBUGMSG("hess_l_R"); | 884 | DEBUGMSG("hess_l_R"); |
986 | integer lwork = 5*n; // fixme | 885 | integer lwork = 5*n; // FIXME |
987 | double *WORK = (double*)malloc(lwork*sizeof(double)); | 886 | double *WORK = (double*)malloc(lwork*sizeof(double)); |
988 | CHECK(!WORK,MEM); | 887 | CHECK(!WORK,MEM); |
989 | memcpy(rp,ap,m*n*sizeof(double)); | ||
990 | integer res; | 888 | integer res; |
991 | integer one = 1; | 889 | integer one = 1; |
992 | dgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); | 890 | dgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); |
@@ -996,20 +894,19 @@ int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)) { | |||
996 | } | 894 | } |
997 | 895 | ||
998 | 896 | ||
999 | /* Subroutine */ int zgehrd_(integer *n, integer *ilo, integer *ihi, | 897 | int zgehrd_(integer *n, integer *ilo, integer *ihi, |
1000 | doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * | 898 | doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * |
1001 | work, integer *lwork, integer *info); | 899 | work, integer *lwork, integer *info); |
1002 | 900 | ||
1003 | int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)) { | 901 | int hess_l_C(CVEC(tau), OCMAT(r)) { |
1004 | integer m = ar; | 902 | integer m = rr; |
1005 | integer n = ac; | 903 | integer n = rc; |
1006 | integer mn = MIN(m,n); | 904 | integer mn = MIN(m,n); |
1007 | REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE); | 905 | REQUIRES(m>=1 && n == m && taun == mn-1, BAD_SIZE); |
1008 | DEBUGMSG("hess_l_C"); | 906 | DEBUGMSG("hess_l_C"); |
1009 | integer lwork = 5*n; // fixme | 907 | integer lwork = 5*n; // FIXME |
1010 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | 908 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); |
1011 | CHECK(!WORK,MEM); | 909 | CHECK(!WORK,MEM); |
1012 | memcpy(rp,ap,m*n*sizeof(doublecomplex)); | ||
1013 | integer res; | 910 | integer res; |
1014 | integer one = 1; | 911 | integer one = 1; |
1015 | zgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); | 912 | zgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); |
@@ -1020,23 +917,17 @@ int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)) { | |||
1020 | 917 | ||
1021 | //////////////////// Schur factorization ///////////////////////// | 918 | //////////////////// Schur factorization ///////////////////////// |
1022 | 919 | ||
1023 | /* Subroutine */ int dgees_(char *jobvs, char *sort, L_fp select, integer *n, | 920 | int dgees_(char *jobvs, char *sort, L_fp select, integer *n, |
1024 | doublereal *a, integer *lda, integer *sdim, doublereal *wr, | 921 | doublereal *a, integer *lda, integer *sdim, doublereal *wr, |
1025 | doublereal *wi, doublereal *vs, integer *ldvs, doublereal *work, | 922 | doublereal *wi, doublereal *vs, integer *ldvs, doublereal *work, |
1026 | integer *lwork, logical *bwork, integer *info); | 923 | integer *lwork, logical *bwork, integer *info); |
1027 | 924 | ||
1028 | int schur_l_R(KDMAT(a), DMAT(u), DMAT(s)) { | 925 | int schur_l_R(ODMAT(u), ODMAT(s)) { |
1029 | integer m = ar; | 926 | integer m = sr; |
1030 | integer n = ac; | 927 | integer n = sc; |
1031 | REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); | 928 | REQUIRES(m>=1 && n==m && ur==n && uc==n, BAD_SIZE); |
1032 | DEBUGMSG("schur_l_R"); | 929 | DEBUGMSG("schur_l_R"); |
1033 | //int k; | 930 | integer lwork = 6*n; // FIXME |
1034 | //printf("---------------------------\n"); | ||
1035 | //printf("%p: ",ap); for(k=0;k<n*n;k++) printf("%f ",ap[k]); printf("\n"); | ||
1036 | //printf("%p: ",up); for(k=0;k<n*n;k++) printf("%f ",up[k]); printf("\n"); | ||
1037 | //printf("%p: ",sp); for(k=0;k<n*n;k++) printf("%f ",sp[k]); printf("\n"); | ||
1038 | memcpy(sp,ap,n*n*sizeof(double)); | ||
1039 | integer lwork = 6*n; // fixme | ||
1040 | double *WORK = (double*)malloc(lwork*sizeof(double)); | 931 | double *WORK = (double*)malloc(lwork*sizeof(double)); |
1041 | double *WR = (double*)malloc(n*sizeof(double)); | 932 | double *WR = (double*)malloc(n*sizeof(double)); |
1042 | double *WI = (double*)malloc(n*sizeof(double)); | 933 | double *WI = (double*)malloc(n*sizeof(double)); |
@@ -1045,9 +936,6 @@ int schur_l_R(KDMAT(a), DMAT(u), DMAT(s)) { | |||
1045 | integer res; | 936 | integer res; |
1046 | integer sdim; | 937 | integer sdim; |
1047 | dgees_ ("V","N",NULL,&n,sp,&n,&sdim,WR,WI,up,&n,WORK,&lwork,BWORK,&res); | 938 | dgees_ ("V","N",NULL,&n,sp,&n,&sdim,WR,WI,up,&n,WORK,&lwork,BWORK,&res); |
1048 | //printf("%p: ",ap); for(k=0;k<n*n;k++) printf("%f ",ap[k]); printf("\n"); | ||
1049 | //printf("%p: ",up); for(k=0;k<n*n;k++) printf("%f ",up[k]); printf("\n"); | ||
1050 | //printf("%p: ",sp); for(k=0;k<n*n;k++) printf("%f ",sp[k]); printf("\n"); | ||
1051 | if(res>0) { | 939 | if(res>0) { |
1052 | return NOCONVER; | 940 | return NOCONVER; |
1053 | } | 941 | } |
@@ -1060,18 +948,17 @@ int schur_l_R(KDMAT(a), DMAT(u), DMAT(s)) { | |||
1060 | } | 948 | } |
1061 | 949 | ||
1062 | 950 | ||
1063 | /* Subroutine */ int zgees_(char *jobvs, char *sort, L_fp select, integer *n, | 951 | int zgees_(char *jobvs, char *sort, L_fp select, integer *n, |
1064 | doublecomplex *a, integer *lda, integer *sdim, doublecomplex *w, | 952 | doublecomplex *a, integer *lda, integer *sdim, doublecomplex *w, |
1065 | doublecomplex *vs, integer *ldvs, doublecomplex *work, integer *lwork, | 953 | doublecomplex *vs, integer *ldvs, doublecomplex *work, integer *lwork, |
1066 | doublereal *rwork, logical *bwork, integer *info); | 954 | doublereal *rwork, logical *bwork, integer *info); |
1067 | 955 | ||
1068 | int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)) { | 956 | int schur_l_C(OCMAT(u), OCMAT(s)) { |
1069 | integer m = ar; | 957 | integer m = sr; |
1070 | integer n = ac; | 958 | integer n = sc; |
1071 | REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); | 959 | REQUIRES(m>=1 && n==m && ur==n && uc==n, BAD_SIZE); |
1072 | DEBUGMSG("schur_l_C"); | 960 | DEBUGMSG("schur_l_C"); |
1073 | memcpy(sp,ap,n*n*sizeof(doublecomplex)); | 961 | integer lwork = 6*n; // FIXME |
1074 | integer lwork = 6*n; // fixme | ||
1075 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | 962 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); |
1076 | doublecomplex *W = (doublecomplex*)malloc(n*sizeof(doublecomplex)); | 963 | doublecomplex *W = (doublecomplex*)malloc(n*sizeof(doublecomplex)); |
1077 | // W not really required in this call | 964 | // W not really required in this call |
@@ -1094,21 +981,20 @@ int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)) { | |||
1094 | 981 | ||
1095 | //////////////////// LU factorization ///////////////////////// | 982 | //////////////////// LU factorization ///////////////////////// |
1096 | 983 | ||
1097 | /* Subroutine */ int dgetrf_(integer *m, integer *n, doublereal *a, integer * | 984 | int dgetrf_(integer *m, integer *n, doublereal *a, integer * |
1098 | lda, integer *ipiv, integer *info); | 985 | lda, integer *ipiv, integer *info); |
1099 | 986 | ||
1100 | int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)) { | 987 | int lu_l_R(DVEC(ipiv), ODMAT(r)) { |
1101 | integer m = ar; | 988 | integer m = rr; |
1102 | integer n = ac; | 989 | integer n = rc; |
1103 | integer mn = MIN(m,n); | 990 | integer mn = MIN(m,n); |
1104 | REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); | 991 | REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); |
1105 | DEBUGMSG("lu_l_R"); | 992 | DEBUGMSG("lu_l_R"); |
1106 | integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); | 993 | integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); |
1107 | memcpy(rp,ap,m*n*sizeof(double)); | ||
1108 | integer res; | 994 | integer res; |
1109 | dgetrf_ (&m,&n,rp,&m,auxipiv,&res); | 995 | dgetrf_ (&m,&n,rp,&m,auxipiv,&res); |
1110 | if(res>0) { | 996 | if(res>0) { |
1111 | res = 0; // fixme | 997 | res = 0; // FIXME |
1112 | } | 998 | } |
1113 | CHECK(res,res); | 999 | CHECK(res,res); |
1114 | int k; | 1000 | int k; |
@@ -1120,21 +1006,20 @@ int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)) { | |||
1120 | } | 1006 | } |
1121 | 1007 | ||
1122 | 1008 | ||
1123 | /* Subroutine */ int zgetrf_(integer *m, integer *n, doublecomplex *a, | 1009 | int zgetrf_(integer *m, integer *n, doublecomplex *a, |
1124 | integer *lda, integer *ipiv, integer *info); | 1010 | integer *lda, integer *ipiv, integer *info); |
1125 | 1011 | ||
1126 | int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) { | 1012 | int lu_l_C(DVEC(ipiv), OCMAT(r)) { |
1127 | integer m = ar; | 1013 | integer m = rr; |
1128 | integer n = ac; | 1014 | integer n = rc; |
1129 | integer mn = MIN(m,n); | 1015 | integer mn = MIN(m,n); |
1130 | REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); | 1016 | REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); |
1131 | DEBUGMSG("lu_l_C"); | 1017 | DEBUGMSG("lu_l_C"); |
1132 | integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); | 1018 | integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); |
1133 | memcpy(rp,ap,m*n*sizeof(doublecomplex)); | ||
1134 | integer res; | 1019 | integer res; |
1135 | zgetrf_ (&m,&n,rp,&m,auxipiv,&res); | 1020 | zgetrf_ (&m,&n,rp,&m,auxipiv,&res); |
1136 | if(res>0) { | 1021 | if(res>0) { |
1137 | res = 0; // fixme | 1022 | res = 0; // FIXME |
1138 | } | 1023 | } |
1139 | CHECK(res,res); | 1024 | CHECK(res,res); |
1140 | int k; | 1025 | int k; |
@@ -1148,13 +1033,14 @@ int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) { | |||
1148 | 1033 | ||
1149 | //////////////////// LU substitution ///////////////////////// | 1034 | //////////////////// LU substitution ///////////////////////// |
1150 | 1035 | ||
1151 | /* Subroutine */ int dgetrs_(char *trans, integer *n, integer *nrhs, | 1036 | int dgetrs_(char *trans, integer *n, integer *nrhs, |
1152 | doublereal *a, integer *lda, integer *ipiv, doublereal *b, integer * | 1037 | doublereal *a, integer *lda, integer *ipiv, doublereal *b, integer * |
1153 | ldb, integer *info); | 1038 | ldb, integer *info); |
1154 | 1039 | ||
1155 | int luS_l_R(KDMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) { | 1040 | int luS_l_R(KODMAT(a), KDVEC(ipiv), ODMAT(b)) { |
1156 | integer m = ar; | 1041 | integer m = ar; |
1157 | integer n = ac; | 1042 | integer n = ac; |
1043 | integer lda = aXc; | ||
1158 | integer mrhs = br; | 1044 | integer mrhs = br; |
1159 | integer nrhs = bc; | 1045 | integer nrhs = bc; |
1160 | 1046 | ||
@@ -1165,21 +1051,21 @@ int luS_l_R(KDMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) { | |||
1165 | auxipiv[k] = (integer)ipivp[k]; | 1051 | auxipiv[k] = (integer)ipivp[k]; |
1166 | } | 1052 | } |
1167 | integer res; | 1053 | integer res; |
1168 | memcpy(xp,bp,mrhs*nrhs*sizeof(double)); | 1054 | dgetrs_ ("N",&n,&nrhs,(/*no const (!?)*/ double*)ap,&lda,auxipiv,bp,&mrhs,&res); |
1169 | dgetrs_ ("N",&n,&nrhs,(/*no const (!?)*/ double*)ap,&m,auxipiv,xp,&mrhs,&res); | ||
1170 | CHECK(res,res); | 1055 | CHECK(res,res); |
1171 | free(auxipiv); | 1056 | free(auxipiv); |
1172 | OK | 1057 | OK |
1173 | } | 1058 | } |
1174 | 1059 | ||
1175 | 1060 | ||
1176 | /* Subroutine */ int zgetrs_(char *trans, integer *n, integer *nrhs, | 1061 | int zgetrs_(char *trans, integer *n, integer *nrhs, |
1177 | doublecomplex *a, integer *lda, integer *ipiv, doublecomplex *b, | 1062 | doublecomplex *a, integer *lda, integer *ipiv, doublecomplex *b, |
1178 | integer *ldb, integer *info); | 1063 | integer *ldb, integer *info); |
1179 | 1064 | ||
1180 | int luS_l_C(KCMAT(a), KDVEC(ipiv), KCMAT(b), CMAT(x)) { | 1065 | int luS_l_C(KOCMAT(a), KDVEC(ipiv), OCMAT(b)) { |
1181 | integer m = ar; | 1066 | integer m = ar; |
1182 | integer n = ac; | 1067 | integer n = ac; |
1068 | integer lda = aXc; | ||
1183 | integer mrhs = br; | 1069 | integer mrhs = br; |
1184 | integer nrhs = bc; | 1070 | integer nrhs = bc; |
1185 | 1071 | ||
@@ -1190,30 +1076,135 @@ int luS_l_C(KCMAT(a), KDVEC(ipiv), KCMAT(b), CMAT(x)) { | |||
1190 | auxipiv[k] = (integer)ipivp[k]; | 1076 | auxipiv[k] = (integer)ipivp[k]; |
1191 | } | 1077 | } |
1192 | integer res; | 1078 | integer res; |
1193 | memcpy(xp,bp,mrhs*nrhs*sizeof(doublecomplex)); | 1079 | zgetrs_ ("N",&n,&nrhs,(doublecomplex*)ap,&lda,auxipiv,bp,&mrhs,&res); |
1194 | zgetrs_ ("N",&n,&nrhs,(doublecomplex*)ap,&m,auxipiv,xp,&mrhs,&res); | 1080 | CHECK(res,res); |
1081 | free(auxipiv); | ||
1082 | OK | ||
1083 | } | ||
1084 | |||
1085 | |||
1086 | //////////////////// LDL factorization ///////////////////////// | ||
1087 | |||
1088 | int dsytrf_(char *uplo, integer *n, doublereal *a, integer *lda, integer *ipiv, | ||
1089 | doublereal *work, integer *lwork, integer *info); | ||
1090 | |||
1091 | int ldl_R(DVEC(ipiv), ODMAT(r)) { | ||
1092 | integer n = rr; | ||
1093 | REQUIRES(n>=1 && rc==n && ipivn == n, BAD_SIZE); | ||
1094 | DEBUGMSG("ldl_R"); | ||
1095 | integer* auxipiv = (integer*)malloc(n*sizeof(integer)); | ||
1096 | integer res; | ||
1097 | integer lda = rXc; | ||
1098 | integer lwork = -1; | ||
1099 | doublereal ans; | ||
1100 | dsytrf_ ("L",&n,rp,&lda,auxipiv,&ans,&lwork,&res); | ||
1101 | lwork = ceil(ans); | ||
1102 | doublereal* work = (doublereal*)malloc(lwork*sizeof(doublereal)); | ||
1103 | dsytrf_ ("L",&n,rp,&lda,auxipiv,work,&lwork,&res); | ||
1104 | CHECK(res,res); | ||
1105 | int k; | ||
1106 | for (k=0; k<n; k++) { | ||
1107 | ipivp[k] = auxipiv[k]; | ||
1108 | } | ||
1109 | free(auxipiv); | ||
1110 | free(work); | ||
1111 | OK | ||
1112 | } | ||
1113 | |||
1114 | |||
1115 | int zhetrf_(char *uplo, integer *n, doublecomplex *a, integer *lda, integer *ipiv, | ||
1116 | doublecomplex *work, integer *lwork, integer *info); | ||
1117 | |||
1118 | int ldl_C(DVEC(ipiv), OCMAT(r)) { | ||
1119 | integer n = rr; | ||
1120 | REQUIRES(n>=1 && rc==n && ipivn == n, BAD_SIZE); | ||
1121 | DEBUGMSG("ldl_R"); | ||
1122 | integer* auxipiv = (integer*)malloc(n*sizeof(integer)); | ||
1123 | integer res; | ||
1124 | integer lda = rXc; | ||
1125 | integer lwork = -1; | ||
1126 | doublecomplex ans; | ||
1127 | zhetrf_ ("L",&n,rp,&lda,auxipiv,&ans,&lwork,&res); | ||
1128 | lwork = ceil(ans.r); | ||
1129 | doublecomplex* work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | ||
1130 | zhetrf_ ("L",&n,rp,&lda,auxipiv,work,&lwork,&res); | ||
1195 | CHECK(res,res); | 1131 | CHECK(res,res); |
1132 | int k; | ||
1133 | for (k=0; k<n; k++) { | ||
1134 | ipivp[k] = auxipiv[k]; | ||
1135 | } | ||
1196 | free(auxipiv); | 1136 | free(auxipiv); |
1137 | free(work); | ||
1197 | OK | 1138 | OK |
1139 | |||
1198 | } | 1140 | } |
1199 | 1141 | ||
1142 | //////////////////// LDL solve ///////////////////////// | ||
1143 | |||
1144 | int dsytrs_(char *uplo, integer *n, integer *nrhs, doublereal *a, integer *lda, | ||
1145 | integer *ipiv, doublereal *b, integer *ldb, integer *info); | ||
1146 | |||
1147 | int ldl_S_R(KODMAT(a), KDVEC(ipiv), ODMAT(b)) { | ||
1148 | integer m = ar; | ||
1149 | integer n = ac; | ||
1150 | integer lda = aXc; | ||
1151 | integer mrhs = br; | ||
1152 | integer nrhs = bc; | ||
1153 | |||
1154 | REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE); | ||
1155 | integer* auxipiv = (integer*)malloc(n*sizeof(integer)); | ||
1156 | int k; | ||
1157 | for (k=0; k<n; k++) { | ||
1158 | auxipiv[k] = (integer)ipivp[k]; | ||
1159 | } | ||
1160 | integer res; | ||
1161 | dsytrs_ ("L",&n,&nrhs,(/*no const (!?)*/ double*)ap,&lda,auxipiv,bp,&mrhs,&res); | ||
1162 | CHECK(res,res); | ||
1163 | free(auxipiv); | ||
1164 | OK | ||
1165 | } | ||
1166 | |||
1167 | |||
1168 | int zhetrs_(char *uplo, integer *n, integer *nrhs, doublecomplex *a, integer *lda, | ||
1169 | integer *ipiv, doublecomplex *b, integer *ldb, integer *info); | ||
1170 | |||
1171 | int ldl_S_C(KOCMAT(a), KDVEC(ipiv), OCMAT(b)) { | ||
1172 | integer m = ar; | ||
1173 | integer n = ac; | ||
1174 | integer lda = aXc; | ||
1175 | integer mrhs = br; | ||
1176 | integer nrhs = bc; | ||
1177 | |||
1178 | REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE); | ||
1179 | integer* auxipiv = (integer*)malloc(n*sizeof(integer)); | ||
1180 | int k; | ||
1181 | for (k=0; k<n; k++) { | ||
1182 | auxipiv[k] = (integer)ipivp[k]; | ||
1183 | } | ||
1184 | integer res; | ||
1185 | zhetrs_ ("L",&n,&nrhs,(doublecomplex*)ap,&lda,auxipiv,bp,&mrhs,&res); | ||
1186 | CHECK(res,res); | ||
1187 | free(auxipiv); | ||
1188 | OK | ||
1189 | } | ||
1190 | |||
1191 | |||
1200 | //////////////////// Matrix Product ///////////////////////// | 1192 | //////////////////// Matrix Product ///////////////////////// |
1201 | 1193 | ||
1202 | void dgemm_(char *, char *, integer *, integer *, integer *, | 1194 | void dgemm_(char *, char *, integer *, integer *, integer *, |
1203 | double *, const double *, integer *, const double *, | 1195 | double *, const double *, integer *, const double *, |
1204 | integer *, double *, double *, integer *); | 1196 | integer *, double *, double *, integer *); |
1205 | 1197 | ||
1206 | int multiplyR(int ta, int tb, KDMAT(a),KDMAT(b),DMAT(r)) { | 1198 | int multiplyR(int ta, int tb, KODMAT(a),KODMAT(b),ODMAT(r)) { |
1207 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
1208 | DEBUGMSG("dgemm_"); | 1199 | DEBUGMSG("dgemm_"); |
1209 | CHECKNANR(a,"NaN multR Input\n") | 1200 | CHECKNANR(a,"NaN multR Input\n") |
1210 | CHECKNANR(b,"NaN multR Input\n") | 1201 | CHECKNANR(b,"NaN multR Input\n") |
1211 | integer m = ta?ac:ar; | 1202 | integer m = ta?ac:ar; |
1212 | integer n = tb?br:bc; | 1203 | integer n = tb?br:bc; |
1213 | integer k = ta?ar:ac; | 1204 | integer k = ta?ar:ac; |
1214 | integer lda = ar; | 1205 | integer lda = aXc; |
1215 | integer ldb = br; | 1206 | integer ldb = bXc; |
1216 | integer ldc = rr; | 1207 | integer ldc = rXc; |
1217 | double alpha = 1; | 1208 | double alpha = 1; |
1218 | double beta = 0; | 1209 | double beta = 0; |
1219 | dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); | 1210 | dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); |
@@ -1225,17 +1216,16 @@ void zgemm_(char *, char *, integer *, integer *, integer *, | |||
1225 | doublecomplex *, const doublecomplex *, integer *, const doublecomplex *, | 1216 | doublecomplex *, const doublecomplex *, integer *, const doublecomplex *, |
1226 | integer *, doublecomplex *, doublecomplex *, integer *); | 1217 | integer *, doublecomplex *, doublecomplex *, integer *); |
1227 | 1218 | ||
1228 | int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)) { | 1219 | int multiplyC(int ta, int tb, KOCMAT(a),KOCMAT(b),OCMAT(r)) { |
1229 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
1230 | DEBUGMSG("zgemm_"); | 1220 | DEBUGMSG("zgemm_"); |
1231 | CHECKNANC(a,"NaN multC Input\n") | 1221 | CHECKNANC(a,"NaN multC Input\n") |
1232 | CHECKNANC(b,"NaN multC Input\n") | 1222 | CHECKNANC(b,"NaN multC Input\n") |
1233 | integer m = ta?ac:ar; | 1223 | integer m = ta?ac:ar; |
1234 | integer n = tb?br:bc; | 1224 | integer n = tb?br:bc; |
1235 | integer k = ta?ar:ac; | 1225 | integer k = ta?ar:ac; |
1236 | integer lda = ar; | 1226 | integer lda = aXc; |
1237 | integer ldb = br; | 1227 | integer ldb = bXc; |
1238 | integer ldc = rr; | 1228 | integer ldc = rXc; |
1239 | doublecomplex alpha = {1,0}; | 1229 | doublecomplex alpha = {1,0}; |
1240 | doublecomplex beta = {0,0}; | 1230 | doublecomplex beta = {0,0}; |
1241 | zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, | 1231 | zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, |
@@ -1250,15 +1240,14 @@ void sgemm_(char *, char *, integer *, integer *, integer *, | |||
1250 | float *, const float *, integer *, const float *, | 1240 | float *, const float *, integer *, const float *, |
1251 | integer *, float *, float *, integer *); | 1241 | integer *, float *, float *, integer *); |
1252 | 1242 | ||
1253 | int multiplyF(int ta, int tb, KFMAT(a),KFMAT(b),FMAT(r)) { | 1243 | int multiplyF(int ta, int tb, KOFMAT(a),KOFMAT(b),OFMAT(r)) { |
1254 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
1255 | DEBUGMSG("sgemm_"); | 1244 | DEBUGMSG("sgemm_"); |
1256 | integer m = ta?ac:ar; | 1245 | integer m = ta?ac:ar; |
1257 | integer n = tb?br:bc; | 1246 | integer n = tb?br:bc; |
1258 | integer k = ta?ar:ac; | 1247 | integer k = ta?ar:ac; |
1259 | integer lda = ar; | 1248 | integer lda = aXc; |
1260 | integer ldb = br; | 1249 | integer ldb = bXc; |
1261 | integer ldc = rr; | 1250 | integer ldc = rXc; |
1262 | float alpha = 1; | 1251 | float alpha = 1; |
1263 | float beta = 0; | 1252 | float beta = 0; |
1264 | sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); | 1253 | sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); |
@@ -1269,15 +1258,14 @@ void cgemm_(char *, char *, integer *, integer *, integer *, | |||
1269 | complex *, const complex *, integer *, const complex *, | 1258 | complex *, const complex *, integer *, const complex *, |
1270 | integer *, complex *, complex *, integer *); | 1259 | integer *, complex *, complex *, integer *); |
1271 | 1260 | ||
1272 | int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) { | 1261 | int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) { |
1273 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
1274 | DEBUGMSG("cgemm_"); | 1262 | DEBUGMSG("cgemm_"); |
1275 | integer m = ta?ac:ar; | 1263 | integer m = ta?ac:ar; |
1276 | integer n = tb?br:bc; | 1264 | integer n = tb?br:bc; |
1277 | integer k = ta?ar:ac; | 1265 | integer k = ta?ar:ac; |
1278 | integer lda = ar; | 1266 | integer lda = aXc; |
1279 | integer ldb = br; | 1267 | integer ldb = bXc; |
1280 | integer ldc = rr; | 1268 | integer ldc = rXc; |
1281 | complex alpha = {1,0}; | 1269 | complex alpha = {1,0}; |
1282 | complex beta = {0,0}; | 1270 | complex beta = {0,0}; |
1283 | cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, | 1271 | cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, |
@@ -1287,203 +1275,270 @@ int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) { | |||
1287 | OK | 1275 | OK |
1288 | } | 1276 | } |
1289 | 1277 | ||
1290 | //////////////////// transpose ///////////////////////// | ||
1291 | 1278 | ||
1292 | int transF(KFMAT(x),FMAT(t)) { | 1279 | #define MULT_IMP_VER(OP) \ |
1293 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | 1280 | { TRAV(r,i,j) { \ |
1294 | DEBUGMSG("transF"); | 1281 | int k; \ |
1295 | int i,j; | 1282 | AT(r,i,j) = 0; \ |
1296 | for (i=0; i<tr; i++) { | 1283 | for (k=0;k<ac;k++) { \ |
1297 | for (j=0; j<tc; j++) { | 1284 | OP \ |
1298 | tp[i*tc+j] = xp[j*xc+i]; | 1285 | } \ |
1299 | } | 1286 | } \ |
1300 | } | 1287 | } |
1301 | OK | ||
1302 | } | ||
1303 | 1288 | ||
1304 | int transR(KDMAT(x),DMAT(t)) { | 1289 | #define MULT_IMP(M) { \ |
1305 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | 1290 | if (m==1) { \ |
1306 | DEBUGMSG("transR"); | 1291 | MULT_IMP_VER( AT(r,i,j) += AT(a,i,k) * AT(b,k,j); ) \ |
1307 | int i,j; | 1292 | } else { \ |
1308 | for (i=0; i<tr; i++) { | 1293 | MULT_IMP_VER( AT(r,i,j) = M(AT(r,i,j) + M(AT(a,i,k) * AT(b,k,j), m) , m) ; ) \ |
1309 | for (j=0; j<tc; j++) { | 1294 | } OK } |
1310 | tp[i*tc+j] = xp[j*xc+i]; | 1295 | |
1311 | } | 1296 | int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP(mod) |
1312 | } | 1297 | int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP(mod_l) |
1313 | OK | 1298 | |
1299 | /////////////////////////////// inplace row ops //////////////////////////////// | ||
1300 | |||
1301 | #define AXPY_IMP { \ | ||
1302 | int j; \ | ||
1303 | for(j=j1; j<=j2; j++) { \ | ||
1304 | AT(r,i2,j) += a*AT(r,i1,j); \ | ||
1305 | } OK } | ||
1306 | |||
1307 | #define AXPY_MOD_IMP(M) { \ | ||
1308 | int j; \ | ||
1309 | for(j=j1; j<=j2; j++) { \ | ||
1310 | AT(r,i2,j) = M(AT(r,i2,j) + M(a*AT(r,i1,j), m) , m); \ | ||
1311 | } OK } | ||
1312 | |||
1313 | |||
1314 | #define SCAL_IMP { \ | ||
1315 | int i,j; \ | ||
1316 | for(i=i1; i<=i2; i++) { \ | ||
1317 | for(j=j1; j<=j2; j++) { \ | ||
1318 | AT(r,i,j) = a*AT(r,i,j); \ | ||
1319 | } \ | ||
1320 | } OK } | ||
1321 | |||
1322 | #define SCAL_MOD_IMP(M) { \ | ||
1323 | int i,j; \ | ||
1324 | for(i=i1; i<=i2; i++) { \ | ||
1325 | for(j=j1; j<=j2; j++) { \ | ||
1326 | AT(r,i,j) = M(a*AT(r,i,j) , m); \ | ||
1327 | } \ | ||
1328 | } OK } | ||
1329 | |||
1330 | |||
1331 | #define SWAP_IMP(T) { \ | ||
1332 | T aux; \ | ||
1333 | int k; \ | ||
1334 | if (i1 != i2) { \ | ||
1335 | for (k=j1; k<=j2; k++) { \ | ||
1336 | aux = AT(r,i1,k); \ | ||
1337 | AT(r,i1,k) = AT(r,i2,k); \ | ||
1338 | AT(r,i2,k) = aux; \ | ||
1339 | } \ | ||
1340 | } OK } | ||
1341 | |||
1342 | |||
1343 | #define ROWOP_IMP(T) { \ | ||
1344 | T a = *pa; \ | ||
1345 | switch(code) { \ | ||
1346 | case 0: AXPY_IMP \ | ||
1347 | case 1: SCAL_IMP \ | ||
1348 | case 2: SWAP_IMP(T) \ | ||
1349 | default: ERROR(BAD_CODE); \ | ||
1350 | } \ | ||
1314 | } | 1351 | } |
1315 | 1352 | ||
1316 | int transQ(KQMAT(x),QMAT(t)) { | 1353 | #define ROWOP_MOD_IMP(T,M) { \ |
1317 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | 1354 | T a = *pa; \ |
1318 | DEBUGMSG("transQ"); | 1355 | switch(code) { \ |
1319 | int i,j; | 1356 | case 0: AXPY_MOD_IMP(M) \ |
1320 | for (i=0; i<tr; i++) { | 1357 | case 1: SCAL_MOD_IMP(M) \ |
1321 | for (j=0; j<tc; j++) { | 1358 | case 2: SWAP_IMP(T) \ |
1322 | tp[i*tc+j] = xp[j*xc+i]; | 1359 | default: ERROR(BAD_CODE); \ |
1323 | } | 1360 | } \ |
1324 | } | ||
1325 | OK | ||
1326 | } | 1361 | } |
1327 | 1362 | ||
1328 | int transC(KCMAT(x),CMAT(t)) { | ||
1329 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | ||
1330 | DEBUGMSG("transC"); | ||
1331 | int i,j; | ||
1332 | for (i=0; i<tr; i++) { | ||
1333 | for (j=0; j<tc; j++) { | ||
1334 | tp[i*tc+j] = xp[j*xc+i]; | ||
1335 | } | ||
1336 | } | ||
1337 | OK | ||
1338 | } | ||
1339 | 1363 | ||
1340 | int transP(KPMAT(x), PMAT(t)) { | 1364 | #define ROWOP(T) int rowop_##T(int code, T* pa, int i1, int i2, int j1, int j2, MATG(T,r)) ROWOP_IMP(T) |
1341 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | 1365 | |
1342 | REQUIRES(xs==ts,NOCONVER); | 1366 | #define ROWOP_MOD(T,M) int rowop_mod_##T(T m, int code, T* pa, int i1, int i2, int j1, int j2, MATG(T,r)) ROWOP_MOD_IMP(T,M) |
1343 | DEBUGMSG("transP"); | 1367 | |
1344 | int i,j; | 1368 | ROWOP(double) |
1345 | for (i=0; i<tr; i++) { | 1369 | ROWOP(float) |
1346 | for (j=0; j<tc; j++) { | 1370 | ROWOP(TCD) |
1347 | memcpy(tp+(i*tc+j)*xs,xp +(j*xc+i)*xs,xs); | 1371 | ROWOP(TCF) |
1372 | ROWOP(int32_t) | ||
1373 | ROWOP(int64_t) | ||
1374 | ROWOP_MOD(int32_t,mod) | ||
1375 | ROWOP_MOD(int64_t,mod_l) | ||
1376 | |||
1377 | /////////////////////////////// inplace GEMM //////////////////////////////// | ||
1378 | |||
1379 | #define GEMM(T) int gemm_##T(VECG(T,c),MATG(T,a),MATG(T,b),MATG(T,r)) { \ | ||
1380 | T a = cp[0], b = cp[1]; \ | ||
1381 | T t; \ | ||
1382 | int k; \ | ||
1383 | { TRAV(r,i,j) { \ | ||
1384 | t = 0; \ | ||
1385 | for(k=0; k<ac; k++) { \ | ||
1386 | t += AT(a,i,k) * AT(b,k,j); \ | ||
1387 | } \ | ||
1388 | AT(r,i,j) = b*AT(r,i,j) + a*t; \ | ||
1389 | } \ | ||
1390 | } OK } | ||
1391 | |||
1392 | |||
1393 | GEMM(double) | ||
1394 | GEMM(float) | ||
1395 | GEMM(TCD) | ||
1396 | GEMM(TCF) | ||
1397 | GEMM(int32_t) | ||
1398 | GEMM(int64_t) | ||
1399 | |||
1400 | #define GEMM_MOD(T,M) int gemm_mod_##T(T m, VECG(T,c),MATG(T,a),MATG(T,b),MATG(T,r)) { \ | ||
1401 | T a = cp[0], b = cp[1]; \ | ||
1402 | int k; \ | ||
1403 | T t; \ | ||
1404 | { TRAV(r,i,j) { \ | ||
1405 | t = 0; \ | ||
1406 | for(k=0; k<ac; k++) { \ | ||
1407 | t = M(t+M(AT(a,i,k) * AT(b,k,j))); \ | ||
1408 | } \ | ||
1409 | AT(r,i,j) = M(M(b*AT(r,i,j)) + M(a*t)); \ | ||
1410 | } \ | ||
1411 | } OK } | ||
1412 | |||
1413 | |||
1414 | #define MOD32(X) mod(X,m) | ||
1415 | #define MOD64(X) mod_l(X,m) | ||
1416 | |||
1417 | GEMM_MOD(int32_t,MOD32) | ||
1418 | GEMM_MOD(int64_t,MOD64) | ||
1419 | |||
1420 | ////////////////// sparse matrix-product /////////////////////////////////////// | ||
1421 | |||
1422 | |||
1423 | int smXv(KDVEC(vals),KIVEC(cols),KIVEC(rows),KDVEC(x),DVEC(r)) { | ||
1424 | int r, c; | ||
1425 | for (r = 0; r < rowsn - 1; r++) { | ||
1426 | rp[r] = 0; | ||
1427 | for (c = rowsp[r]; c < rowsp[r+1]; c++) { | ||
1428 | rp[r] += valsp[c-1] * xp[colsp[c-1]-1]; | ||
1348 | } | 1429 | } |
1349 | } | 1430 | } |
1350 | OK | 1431 | OK |
1351 | } | 1432 | } |
1352 | 1433 | ||
1353 | //////////////////// constant ///////////////////////// | 1434 | int smTXv(KDVEC(vals),KIVEC(cols),KIVEC(rows),KDVEC(x),DVEC(r)) { |
1354 | 1435 | int r,c; | |
1355 | int constantF(float * pval, FVEC(r)) { | 1436 | for (c = 0; c < rn; c++) { |
1356 | DEBUGMSG("constantF") | 1437 | rp[c] = 0; |
1357 | int k; | ||
1358 | double val = *pval; | ||
1359 | for(k=0;k<rn;k++) { | ||
1360 | rp[k]=val; | ||
1361 | } | ||
1362 | OK | ||
1363 | } | ||
1364 | |||
1365 | int constantR(double * pval, DVEC(r)) { | ||
1366 | DEBUGMSG("constantR") | ||
1367 | int k; | ||
1368 | double val = *pval; | ||
1369 | for(k=0;k<rn;k++) { | ||
1370 | rp[k]=val; | ||
1371 | } | 1438 | } |
1372 | OK | 1439 | for (r = 0; r < rowsn - 1; r++) { |
1373 | } | 1440 | for (c = rowsp[r]; c < rowsp[r+1]; c++) { |
1374 | 1441 | rp[colsp[c-1]-1] += valsp[c-1] * xp[r]; | |
1375 | int constantQ(complex* pval, QVEC(r)) { | 1442 | } |
1376 | DEBUGMSG("constantQ") | ||
1377 | int k; | ||
1378 | complex val = *pval; | ||
1379 | for(k=0;k<rn;k++) { | ||
1380 | rp[k]=val; | ||
1381 | } | 1443 | } |
1382 | OK | 1444 | OK |
1383 | } | 1445 | } |
1384 | 1446 | ||
1385 | int constantC(doublecomplex* pval, CVEC(r)) { | ||
1386 | DEBUGMSG("constantC") | ||
1387 | int k; | ||
1388 | doublecomplex val = *pval; | ||
1389 | for(k=0;k<rn;k++) { | ||
1390 | rp[k]=val; | ||
1391 | } | ||
1392 | OK | ||
1393 | } | ||
1394 | 1447 | ||
1395 | int constantP(void* pval, PVEC(r)) { | 1448 | //////////////////////// extract ///////////////////////////////// |
1396 | DEBUGMSG("constantP") | 1449 | |
1397 | int k; | 1450 | #define EXTRACT_IMP { \ |
1398 | for(k=0;k<rn;k++) { | 1451 | int i,j,si,sj,ni,nj; \ |
1399 | memcpy(rp+k*rs,pval,rs); | 1452 | ni = modei ? in : ip[1]-ip[0]+1; \ |
1400 | } | 1453 | nj = modej ? jn : jp[1]-jp[0]+1; \ |
1454 | \ | ||
1455 | for (i=0; i<ni; i++) { \ | ||
1456 | si = modei ? ip[i] : i+ip[0]; \ | ||
1457 | \ | ||
1458 | for (j=0; j<nj; j++) { \ | ||
1459 | sj = modej ? jp[j] : j+jp[0]; \ | ||
1460 | \ | ||
1461 | AT(r,i,j) = AT(m,si,sj); \ | ||
1462 | } \ | ||
1463 | } OK } | ||
1464 | |||
1465 | #define EXTRACT(T) int extract##T(int modei, int modej, KIVEC(i), KIVEC(j), KO##T##MAT(m), O##T##MAT(r)) EXTRACT_IMP | ||
1466 | |||
1467 | EXTRACT(D) | ||
1468 | EXTRACT(F) | ||
1469 | EXTRACT(C) | ||
1470 | EXTRACT(Q) | ||
1471 | EXTRACT(I) | ||
1472 | EXTRACT(L) | ||
1473 | |||
1474 | //////////////////////// setRect ///////////////////////////////// | ||
1475 | |||
1476 | #define SETRECT(T) \ | ||
1477 | int setRect##T(int i, int j, KO##T##MAT(m), O##T##MAT(r)) { \ | ||
1478 | { TRAV(m,a,b) { \ | ||
1479 | int x = a+i, y = b+j; \ | ||
1480 | if(x>=0 && x<rr && y>=0 && y<rc) { \ | ||
1481 | AT(r,x,y) = AT(m,a,b); \ | ||
1482 | } \ | ||
1483 | } \ | ||
1484 | } OK } | ||
1485 | |||
1486 | SETRECT(D) | ||
1487 | SETRECT(F) | ||
1488 | SETRECT(C) | ||
1489 | SETRECT(Q) | ||
1490 | SETRECT(I) | ||
1491 | SETRECT(L) | ||
1492 | |||
1493 | //////////////////////// remap ///////////////////////////////// | ||
1494 | |||
1495 | #define REMAP_IMP \ | ||
1496 | REQUIRES(ir==jr && ic==jc && ir==rr && ic==rc ,BAD_SIZE); \ | ||
1497 | { TRAV(r,a,b) { AT(r,a,b) = AT(m,AT(i,a,b),AT(j,a,b)); } \ | ||
1498 | } \ | ||
1401 | OK | 1499 | OK |
1402 | } | ||
1403 | |||
1404 | //////////////////// float-double conversion ///////////////////////// | ||
1405 | 1500 | ||
1406 | int float2double(FVEC(x),DVEC(y)) { | 1501 | int remapD(KOIMAT(i), KOIMAT(j), KODMAT(m), ODMAT(r)) { |
1407 | DEBUGMSG("float2double") | 1502 | REMAP_IMP |
1408 | int k; | ||
1409 | for(k=0;k<xn;k++) { | ||
1410 | yp[k]=xp[k]; | ||
1411 | } | ||
1412 | OK | ||
1413 | } | 1503 | } |
1414 | 1504 | ||
1415 | int double2float(DVEC(x),FVEC(y)) { | 1505 | int remapF(KOIMAT(i), KOIMAT(j), KOFMAT(m), OFMAT(r)) { |
1416 | DEBUGMSG("double2float") | 1506 | REMAP_IMP |
1417 | int k; | ||
1418 | for(k=0;k<xn;k++) { | ||
1419 | yp[k]=xp[k]; | ||
1420 | } | ||
1421 | OK | ||
1422 | } | 1507 | } |
1423 | 1508 | ||
1424 | //////////////////// conjugate ///////////////////////// | 1509 | int remapI(KOIMAT(i), KOIMAT(j), KOIMAT(m), OIMAT(r)) { |
1425 | 1510 | REMAP_IMP | |
1426 | int conjugateQ(KQVEC(x),QVEC(t)) { | ||
1427 | REQUIRES(xn==tn,BAD_SIZE); | ||
1428 | DEBUGMSG("conjugateQ"); | ||
1429 | int k; | ||
1430 | for(k=0;k<xn;k++) { | ||
1431 | tp[k].r = xp[k].r; | ||
1432 | tp[k].i = -xp[k].i; | ||
1433 | } | ||
1434 | OK | ||
1435 | } | 1511 | } |
1436 | 1512 | ||
1437 | int conjugateC(KCVEC(x),CVEC(t)) { | 1513 | int remapL(KOIMAT(i), KOIMAT(j), KOLMAT(m), OLMAT(r)) { |
1438 | REQUIRES(xn==tn,BAD_SIZE); | 1514 | REMAP_IMP |
1439 | DEBUGMSG("conjugateC"); | ||
1440 | int k; | ||
1441 | for(k=0;k<xn;k++) { | ||
1442 | tp[k].r = xp[k].r; | ||
1443 | tp[k].i = -xp[k].i; | ||
1444 | } | ||
1445 | OK | ||
1446 | } | ||
1447 | |||
1448 | //////////////////// step ///////////////////////// | ||
1449 | |||
1450 | int stepF(FVEC(x),FVEC(y)) { | ||
1451 | DEBUGMSG("stepF") | ||
1452 | int k; | ||
1453 | for(k=0;k<xn;k++) { | ||
1454 | yp[k]=xp[k]>0; | ||
1455 | } | ||
1456 | OK | ||
1457 | } | 1515 | } |
1458 | 1516 | ||
1459 | int stepD(DVEC(x),DVEC(y)) { | 1517 | int remapC(KOIMAT(i), KOIMAT(j), KOCMAT(m), OCMAT(r)) { |
1460 | DEBUGMSG("stepD") | 1518 | REMAP_IMP |
1461 | int k; | ||
1462 | for(k=0;k<xn;k++) { | ||
1463 | yp[k]=xp[k]>0; | ||
1464 | } | ||
1465 | OK | ||
1466 | } | 1519 | } |
1467 | 1520 | ||
1468 | //////////////////// cond ///////////////////////// | 1521 | int remapQ(KOIMAT(i), KOIMAT(j), KOQMAT(m), OQMAT(r)) { |
1469 | 1522 | REMAP_IMP | |
1470 | int condF(FVEC(x),FVEC(y),FVEC(lt),FVEC(eq),FVEC(gt),FVEC(r)) { | ||
1471 | REQUIRES(xn==yn && xn==ltn && xn==eqn && xn==gtn && xn==rn ,BAD_SIZE); | ||
1472 | DEBUGMSG("condF") | ||
1473 | int k; | ||
1474 | for(k=0;k<xn;k++) { | ||
1475 | rp[k] = xp[k]<yp[k]?ltp[k]:(xp[k]>yp[k]?gtp[k]:eqp[k]); | ||
1476 | } | ||
1477 | OK | ||
1478 | } | 1523 | } |
1479 | 1524 | ||
1480 | int condD(DVEC(x),DVEC(y),DVEC(lt),DVEC(eq),DVEC(gt),DVEC(r)) { | 1525 | //////////////////////////////////////////////////////////////////////////////// |
1481 | REQUIRES(xn==yn && xn==ltn && xn==eqn && xn==gtn && xn==rn ,BAD_SIZE); | 1526 | |
1482 | DEBUGMSG("condD") | 1527 | int saveMatrix(char * file, char * format, KODMAT(a)){ |
1483 | int k; | 1528 | FILE * fp; |
1484 | for(k=0;k<xn;k++) { | 1529 | fp = fopen (file, "w"); |
1485 | rp[k] = xp[k]<yp[k]?ltp[k]:(xp[k]>yp[k]?gtp[k]:eqp[k]); | 1530 | int r, c; |
1531 | for (r=0;r<ar; r++) { | ||
1532 | for (c=0; c<ac; c++) { | ||
1533 | fprintf(fp,format,AT(a,r,c)); | ||
1534 | if (c<ac-1) { | ||
1535 | fprintf(fp," "); | ||
1536 | } else { | ||
1537 | fprintf(fp,"\n"); | ||
1538 | } | ||
1539 | } | ||
1486 | } | 1540 | } |
1541 | fclose(fp); | ||
1487 | OK | 1542 | OK |
1488 | } | 1543 | } |
1489 | 1544 | ||
diff --git a/packages/base/src/C/lapack-aux.h b/packages/base/src/Internal/C/lapack-aux.h index c95a2a3..7a6fcbf 100644 --- a/packages/base/src/C/lapack-aux.h +++ b/packages/base/src/Internal/C/lapack-aux.h | |||
@@ -37,11 +37,15 @@ typedef short ftnlen; | |||
37 | /********************************************************/ | 37 | /********************************************************/ |
38 | 38 | ||
39 | #define IVEC(A) int A##n, int*A##p | 39 | #define IVEC(A) int A##n, int*A##p |
40 | #define LVEC(A) int A##n, int64_t*A##p | ||
40 | #define FVEC(A) int A##n, float*A##p | 41 | #define FVEC(A) int A##n, float*A##p |
41 | #define DVEC(A) int A##n, double*A##p | 42 | #define DVEC(A) int A##n, double*A##p |
42 | #define QVEC(A) int A##n, complex*A##p | 43 | #define QVEC(A) int A##n, complex*A##p |
43 | #define CVEC(A) int A##n, doublecomplex*A##p | 44 | #define CVEC(A) int A##n, doublecomplex*A##p |
44 | #define PVEC(A) int A##n, void* A##p, int A##s | 45 | #define PVEC(A) int A##n, void* A##p, int A##s |
46 | |||
47 | #define IMAT(A) int A##r, int A##c, int* A##p | ||
48 | #define LMAT(A) int A##r, int A##c, int64_t* A##p | ||
45 | #define FMAT(A) int A##r, int A##c, float* A##p | 49 | #define FMAT(A) int A##r, int A##c, float* A##p |
46 | #define DMAT(A) int A##r, int A##c, double* A##p | 50 | #define DMAT(A) int A##r, int A##c, double* A##p |
47 | #define QMAT(A) int A##r, int A##c, complex* A##p | 51 | #define QMAT(A) int A##r, int A##c, complex* A##p |
@@ -49,14 +53,59 @@ typedef short ftnlen; | |||
49 | #define PMAT(A) int A##r, int A##c, void* A##p, int A##s | 53 | #define PMAT(A) int A##r, int A##c, void* A##p, int A##s |
50 | 54 | ||
51 | #define KIVEC(A) int A##n, const int*A##p | 55 | #define KIVEC(A) int A##n, const int*A##p |
56 | #define KLVEC(A) int A##n, const int64_t*A##p | ||
52 | #define KFVEC(A) int A##n, const float*A##p | 57 | #define KFVEC(A) int A##n, const float*A##p |
53 | #define KDVEC(A) int A##n, const double*A##p | 58 | #define KDVEC(A) int A##n, const double*A##p |
54 | #define KQVEC(A) int A##n, const complex*A##p | 59 | #define KQVEC(A) int A##n, const complex*A##p |
55 | #define KCVEC(A) int A##n, const doublecomplex*A##p | 60 | #define KCVEC(A) int A##n, const doublecomplex*A##p |
56 | #define KPVEC(A) int A##n, const void* A##p, int A##s | 61 | #define KPVEC(A) int A##n, const void* A##p, int A##s |
62 | |||
63 | #define KIMAT(A) int A##r, int A##c, const int* A##p | ||
64 | #define KLMAT(A) int A##r, int A##c, const int64_t* A##p | ||
57 | #define KFMAT(A) int A##r, int A##c, const float* A##p | 65 | #define KFMAT(A) int A##r, int A##c, const float* A##p |
58 | #define KDMAT(A) int A##r, int A##c, const double* A##p | 66 | #define KDMAT(A) int A##r, int A##c, const double* A##p |
59 | #define KQMAT(A) int A##r, int A##c, const complex* A##p | 67 | #define KQMAT(A) int A##r, int A##c, const complex* A##p |
60 | #define KCMAT(A) int A##r, int A##c, const doublecomplex* A##p | 68 | #define KCMAT(A) int A##r, int A##c, const doublecomplex* A##p |
61 | #define KPMAT(A) int A##r, int A##c, const void* A##p, int A##s | 69 | #define KPMAT(A) int A##r, int A##c, const void* A##p, int A##s |
62 | 70 | ||
71 | #define VECG(T,A) int A##n, T* A##p | ||
72 | #define MATG(T,A) int A##r, int A##c, int A##Xr, int A##Xc, T* A##p | ||
73 | |||
74 | #define OIMAT(A) MATG(int,A) | ||
75 | #define OLMAT(A) MATG(int64_t,A) | ||
76 | #define OFMAT(A) MATG(float,A) | ||
77 | #define ODMAT(A) MATG(double,A) | ||
78 | #define OQMAT(A) MATG(complex,A) | ||
79 | #define OCMAT(A) MATG(doublecomplex,A) | ||
80 | |||
81 | #define KOIMAT(A) MATG(const int,A) | ||
82 | #define KOLMAT(A) MATG(const int64_t,A) | ||
83 | #define KOFMAT(A) MATG(const float,A) | ||
84 | #define KODMAT(A) MATG(const double,A) | ||
85 | #define KOQMAT(A) MATG(const complex,A) | ||
86 | #define KOCMAT(A) MATG(const doublecomplex,A) | ||
87 | |||
88 | #define AT(m,i,j) (m##p[(i)*m##Xr + (j)*m##Xc]) | ||
89 | #define TRAV(m,i,j) int i,j; for (i=0;i<m##r;i++) for (j=0;j<m##c;j++) | ||
90 | |||
91 | /********************************************************/ | ||
92 | |||
93 | static inline | ||
94 | int mod (int a, int b) { | ||
95 | int m = a % b; | ||
96 | if (b>0) { | ||
97 | return m >=0 ? m : m+b; | ||
98 | } else { | ||
99 | return m <=0 ? m : m+b; | ||
100 | } | ||
101 | } | ||
102 | |||
103 | static inline | ||
104 | int64_t mod_l (int64_t a, int64_t b) { | ||
105 | int64_t m = a % b; | ||
106 | if (b>0) { | ||
107 | return m >=0 ? m : m+b; | ||
108 | } else { | ||
109 | return m <=0 ? m : m+b; | ||
110 | } | ||
111 | } | ||
diff --git a/packages/base/src/C/vector-aux.c b/packages/base/src/Internal/C/vector-aux.c index abeba76..9dbf536 100644 --- a/packages/base/src/C/vector-aux.c +++ b/packages/base/src/Internal/C/vector-aux.c | |||
@@ -1,4 +1,5 @@ | |||
1 | #include <complex.h> | 1 | #include <complex.h> |
2 | #include <inttypes.h> | ||
2 | 3 | ||
3 | typedef double complex TCD; | 4 | typedef double complex TCD; |
4 | typedef float complex TCF; | 5 | typedef float complex TCF; |
@@ -46,7 +47,7 @@ int sumF(KFVEC(x),FVEC(r)) { | |||
46 | rp[0] = res; | 47 | rp[0] = res; |
47 | OK | 48 | OK |
48 | } | 49 | } |
49 | 50 | ||
50 | int sumR(KDVEC(x),DVEC(r)) { | 51 | int sumR(KDVEC(x),DVEC(r)) { |
51 | DEBUGMSG("sumR"); | 52 | DEBUGMSG("sumR"); |
52 | REQUIRES(rn==1,BAD_SIZE); | 53 | REQUIRES(rn==1,BAD_SIZE); |
@@ -57,6 +58,31 @@ int sumR(KDVEC(x),DVEC(r)) { | |||
57 | OK | 58 | OK |
58 | } | 59 | } |
59 | 60 | ||
61 | int sumI(int m, KIVEC(x),IVEC(r)) { | ||
62 | REQUIRES(rn==1,BAD_SIZE); | ||
63 | int i; | ||
64 | int res = 0; | ||
65 | if (m==1) { | ||
66 | for (i = 0; i < xn; i++) res += xp[i]; | ||
67 | } else { | ||
68 | for (i = 0; i < xn; i++) res = (res + xp[i]) % m; | ||
69 | } | ||
70 | rp[0] = res; | ||
71 | OK | ||
72 | } | ||
73 | |||
74 | int sumL(int64_t m, KLVEC(x),LVEC(r)) { | ||
75 | REQUIRES(rn==1,BAD_SIZE); | ||
76 | int i; | ||
77 | int res = 0; | ||
78 | if (m==1) { | ||
79 | for (i = 0; i < xn; i++) res += xp[i]; | ||
80 | } else { | ||
81 | for (i = 0; i < xn; i++) res = (res + xp[i]) % m; | ||
82 | } | ||
83 | rp[0] = res; | ||
84 | OK | ||
85 | } | ||
60 | 86 | ||
61 | int sumQ(KQVEC(x),QVEC(r)) { | 87 | int sumQ(KQVEC(x),QVEC(r)) { |
62 | DEBUGMSG("sumQ"); | 88 | DEBUGMSG("sumQ"); |
@@ -72,7 +98,7 @@ int sumQ(KQVEC(x),QVEC(r)) { | |||
72 | rp[0] = res; | 98 | rp[0] = res; |
73 | OK | 99 | OK |
74 | } | 100 | } |
75 | 101 | ||
76 | int sumC(KCVEC(x),CVEC(r)) { | 102 | int sumC(KCVEC(x),CVEC(r)) { |
77 | DEBUGMSG("sumC"); | 103 | DEBUGMSG("sumC"); |
78 | REQUIRES(rn==1,BAD_SIZE); | 104 | REQUIRES(rn==1,BAD_SIZE); |
@@ -98,7 +124,7 @@ int prodF(KFVEC(x),FVEC(r)) { | |||
98 | rp[0] = res; | 124 | rp[0] = res; |
99 | OK | 125 | OK |
100 | } | 126 | } |
101 | 127 | ||
102 | int prodR(KDVEC(x),DVEC(r)) { | 128 | int prodR(KDVEC(x),DVEC(r)) { |
103 | DEBUGMSG("prodR"); | 129 | DEBUGMSG("prodR"); |
104 | REQUIRES(rn==1,BAD_SIZE); | 130 | REQUIRES(rn==1,BAD_SIZE); |
@@ -109,6 +135,31 @@ int prodR(KDVEC(x),DVEC(r)) { | |||
109 | OK | 135 | OK |
110 | } | 136 | } |
111 | 137 | ||
138 | int prodI(int m, KIVEC(x),IVEC(r)) { | ||
139 | REQUIRES(rn==1,BAD_SIZE); | ||
140 | int i; | ||
141 | int res = 1; | ||
142 | if (m==1) { | ||
143 | for (i = 0; i < xn; i++) res *= xp[i]; | ||
144 | } else { | ||
145 | for (i = 0; i < xn; i++) res = (res * xp[i]) % m; | ||
146 | } | ||
147 | rp[0] = res; | ||
148 | OK | ||
149 | } | ||
150 | |||
151 | int prodL(int64_t m, KLVEC(x),LVEC(r)) { | ||
152 | REQUIRES(rn==1,BAD_SIZE); | ||
153 | int i; | ||
154 | int res = 1; | ||
155 | if (m==1) { | ||
156 | for (i = 0; i < xn; i++) res *= xp[i]; | ||
157 | } else { | ||
158 | for (i = 0; i < xn; i++) res = (res * xp[i]) % m; | ||
159 | } | ||
160 | rp[0] = res; | ||
161 | OK | ||
162 | } | ||
112 | 163 | ||
113 | int prodQ(KQVEC(x),QVEC(r)) { | 164 | int prodQ(KQVEC(x),QVEC(r)) { |
114 | DEBUGMSG("prodQ"); | 165 | DEBUGMSG("prodQ"); |
@@ -126,7 +177,7 @@ int prodQ(KQVEC(x),QVEC(r)) { | |||
126 | rp[0] = res; | 177 | rp[0] = res; |
127 | OK | 178 | OK |
128 | } | 179 | } |
129 | 180 | ||
130 | int prodC(KCVEC(x),CVEC(r)) { | 181 | int prodC(KCVEC(x),CVEC(r)) { |
131 | DEBUGMSG("prodC"); | 182 | DEBUGMSG("prodC"); |
132 | REQUIRES(rn==1,BAD_SIZE); | 183 | REQUIRES(rn==1,BAD_SIZE); |
@@ -144,7 +195,7 @@ int prodC(KCVEC(x),CVEC(r)) { | |||
144 | OK | 195 | OK |
145 | } | 196 | } |
146 | 197 | ||
147 | 198 | ||
148 | double dnrm2_(integer*, const double*, integer*); | 199 | double dnrm2_(integer*, const double*, integer*); |
149 | double dasum_(integer*, const double*, integer*); | 200 | double dasum_(integer*, const double*, integer*); |
150 | 201 | ||
@@ -170,7 +221,7 @@ double vector_min(KDVEC(x)) { | |||
170 | return r; | 221 | return r; |
171 | } | 222 | } |
172 | 223 | ||
173 | double vector_max_index(KDVEC(x)) { | 224 | int vector_max_index(KDVEC(x)) { |
174 | int k, r = 0; | 225 | int k, r = 0; |
175 | for (k = 1; k<xn; k++) { | 226 | for (k = 1; k<xn; k++) { |
176 | if(xp[k]>xp[r]) { | 227 | if(xp[k]>xp[r]) { |
@@ -180,7 +231,7 @@ double vector_max_index(KDVEC(x)) { | |||
180 | return r; | 231 | return r; |
181 | } | 232 | } |
182 | 233 | ||
183 | double vector_min_index(KDVEC(x)) { | 234 | int vector_min_index(KDVEC(x)) { |
184 | int k, r = 0; | 235 | int k, r = 0; |
185 | for (k = 1; k<xn; k++) { | 236 | for (k = 1; k<xn; k++) { |
186 | if(xp[k]<xp[r]) { | 237 | if(xp[k]<xp[r]) { |
@@ -189,8 +240,8 @@ double vector_min_index(KDVEC(x)) { | |||
189 | } | 240 | } |
190 | return r; | 241 | return r; |
191 | } | 242 | } |
192 | 243 | ||
193 | int toScalarR(int code, KDVEC(x), DVEC(r)) { | 244 | int toScalarR(int code, KDVEC(x), DVEC(r)) { |
194 | REQUIRES(rn==1,BAD_SIZE); | 245 | REQUIRES(rn==1,BAD_SIZE); |
195 | DEBUGMSG("toScalarR"); | 246 | DEBUGMSG("toScalarR"); |
196 | double res; | 247 | double res; |
@@ -235,7 +286,7 @@ float vector_min_f(KFVEC(x)) { | |||
235 | return r; | 286 | return r; |
236 | } | 287 | } |
237 | 288 | ||
238 | float vector_max_index_f(KFVEC(x)) { | 289 | int vector_max_index_f(KFVEC(x)) { |
239 | int k, r = 0; | 290 | int k, r = 0; |
240 | for (k = 1; k<xn; k++) { | 291 | for (k = 1; k<xn; k++) { |
241 | if(xp[k]>xp[r]) { | 292 | if(xp[k]>xp[r]) { |
@@ -245,7 +296,7 @@ float vector_max_index_f(KFVEC(x)) { | |||
245 | return r; | 296 | return r; |
246 | } | 297 | } |
247 | 298 | ||
248 | float vector_min_index_f(KFVEC(x)) { | 299 | int vector_min_index_f(KFVEC(x)) { |
249 | int k, r = 0; | 300 | int k, r = 0; |
250 | for (k = 1; k<xn; k++) { | 301 | for (k = 1; k<xn; k++) { |
251 | if(xp[k]<xp[r]) { | 302 | if(xp[k]<xp[r]) { |
@@ -256,7 +307,7 @@ float vector_min_index_f(KFVEC(x)) { | |||
256 | } | 307 | } |
257 | 308 | ||
258 | 309 | ||
259 | int toScalarF(int code, KFVEC(x), FVEC(r)) { | 310 | int toScalarF(int code, KFVEC(x), FVEC(r)) { |
260 | REQUIRES(rn==1,BAD_SIZE); | 311 | REQUIRES(rn==1,BAD_SIZE); |
261 | DEBUGMSG("toScalarF"); | 312 | DEBUGMSG("toScalarF"); |
262 | float res; | 313 | float res; |
@@ -275,10 +326,126 @@ int toScalarF(int code, KFVEC(x), FVEC(r)) { | |||
275 | OK | 326 | OK |
276 | } | 327 | } |
277 | 328 | ||
329 | int vector_max_i(KIVEC(x)) { | ||
330 | int r = xp[0]; | ||
331 | int k; | ||
332 | for (k = 1; k<xn; k++) { | ||
333 | if(xp[k]>r) { | ||
334 | r = xp[k]; | ||
335 | } | ||
336 | } | ||
337 | return r; | ||
338 | } | ||
339 | |||
340 | int vector_min_i(KIVEC(x)) { | ||
341 | int r = xp[0]; | ||
342 | int k; | ||
343 | for (k = 1; k<xn; k++) { | ||
344 | if(xp[k]<r) { | ||
345 | r = xp[k]; | ||
346 | } | ||
347 | } | ||
348 | return r; | ||
349 | } | ||
350 | |||
351 | int vector_max_index_i(KIVEC(x)) { | ||
352 | int k, r = 0; | ||
353 | for (k = 1; k<xn; k++) { | ||
354 | if(xp[k]>xp[r]) { | ||
355 | r = k; | ||
356 | } | ||
357 | } | ||
358 | return r; | ||
359 | } | ||
360 | |||
361 | int vector_min_index_i(KIVEC(x)) { | ||
362 | int k, r = 0; | ||
363 | for (k = 1; k<xn; k++) { | ||
364 | if(xp[k]<xp[r]) { | ||
365 | r = k; | ||
366 | } | ||
367 | } | ||
368 | return r; | ||
369 | } | ||
370 | |||
371 | |||
372 | int toScalarI(int code, KIVEC(x), IVEC(r)) { | ||
373 | REQUIRES(rn==1,BAD_SIZE); | ||
374 | int res; | ||
375 | switch(code) { | ||
376 | case 2: { res = vector_max_index_i(V(x)); break; } | ||
377 | case 3: { res = vector_max_i(V(x)); break; } | ||
378 | case 4: { res = vector_min_index_i(V(x)); break; } | ||
379 | case 5: { res = vector_min_i(V(x)); break; } | ||
380 | default: ERROR(BAD_CODE); | ||
381 | } | ||
382 | rp[0] = res; | ||
383 | OK | ||
384 | } | ||
385 | |||
386 | |||
387 | int64_t vector_max_l(KLVEC(x)) { | ||
388 | int64_t r = xp[0]; | ||
389 | int k; | ||
390 | for (k = 1; k<xn; k++) { | ||
391 | if(xp[k]>r) { | ||
392 | r = xp[k]; | ||
393 | } | ||
394 | } | ||
395 | return r; | ||
396 | } | ||
397 | |||
398 | int64_t vector_min_l(KLVEC(x)) { | ||
399 | int64_t r = xp[0]; | ||
400 | int k; | ||
401 | for (k = 1; k<xn; k++) { | ||
402 | if(xp[k]<r) { | ||
403 | r = xp[k]; | ||
404 | } | ||
405 | } | ||
406 | return r; | ||
407 | } | ||
408 | |||
409 | int vector_max_index_l(KLVEC(x)) { | ||
410 | int k, r = 0; | ||
411 | for (k = 1; k<xn; k++) { | ||
412 | if(xp[k]>xp[r]) { | ||
413 | r = k; | ||
414 | } | ||
415 | } | ||
416 | return r; | ||
417 | } | ||
418 | |||
419 | int vector_min_index_l(KLVEC(x)) { | ||
420 | int k, r = 0; | ||
421 | for (k = 1; k<xn; k++) { | ||
422 | if(xp[k]<xp[r]) { | ||
423 | r = k; | ||
424 | } | ||
425 | } | ||
426 | return r; | ||
427 | } | ||
428 | |||
429 | |||
430 | int toScalarL(int code, KLVEC(x), LVEC(r)) { | ||
431 | REQUIRES(rn==1,BAD_SIZE); | ||
432 | int64_t res; | ||
433 | switch(code) { | ||
434 | case 2: { res = vector_max_index_l(V(x)); break; } | ||
435 | case 3: { res = vector_max_l(V(x)); break; } | ||
436 | case 4: { res = vector_min_index_l(V(x)); break; } | ||
437 | case 5: { res = vector_min_l(V(x)); break; } | ||
438 | default: ERROR(BAD_CODE); | ||
439 | } | ||
440 | rp[0] = res; | ||
441 | OK | ||
442 | } | ||
443 | |||
444 | |||
278 | double dznrm2_(integer*, const doublecomplex*, integer*); | 445 | double dznrm2_(integer*, const doublecomplex*, integer*); |
279 | double dzasum_(integer*, const doublecomplex*, integer*); | 446 | double dzasum_(integer*, const doublecomplex*, integer*); |
280 | 447 | ||
281 | int toScalarC(int code, KCVEC(x), DVEC(r)) { | 448 | int toScalarC(int code, KCVEC(x), DVEC(r)) { |
282 | REQUIRES(rn==1,BAD_SIZE); | 449 | REQUIRES(rn==1,BAD_SIZE); |
283 | DEBUGMSG("toScalarC"); | 450 | DEBUGMSG("toScalarC"); |
284 | double res; | 451 | double res; |
@@ -297,7 +464,7 @@ int toScalarC(int code, KCVEC(x), DVEC(r)) { | |||
297 | double scnrm2_(integer*, const complex*, integer*); | 464 | double scnrm2_(integer*, const complex*, integer*); |
298 | double scasum_(integer*, const complex*, integer*); | 465 | double scasum_(integer*, const complex*, integer*); |
299 | 466 | ||
300 | int toScalarQ(int code, KQVEC(x), FVEC(r)) { | 467 | int toScalarQ(int code, KQVEC(x), FVEC(r)) { |
301 | REQUIRES(rn==1,BAD_SIZE); | 468 | REQUIRES(rn==1,BAD_SIZE); |
302 | DEBUGMSG("toScalarQ"); | 469 | DEBUGMSG("toScalarQ"); |
303 | float res; | 470 | float res; |
@@ -389,6 +556,29 @@ int mapF(int code, KFVEC(x), FVEC(r)) { | |||
389 | } | 556 | } |
390 | 557 | ||
391 | 558 | ||
559 | int mapI(int code, KIVEC(x), IVEC(r)) { | ||
560 | int k; | ||
561 | REQUIRES(xn == rn,BAD_SIZE); | ||
562 | switch (code) { | ||
563 | OP(3,abs) | ||
564 | OP(15,sign) | ||
565 | default: ERROR(BAD_CODE); | ||
566 | } | ||
567 | } | ||
568 | |||
569 | |||
570 | int mapL(int code, KLVEC(x), LVEC(r)) { | ||
571 | int k; | ||
572 | REQUIRES(xn == rn,BAD_SIZE); | ||
573 | switch (code) { | ||
574 | OP(3,abs) | ||
575 | OP(15,sign) | ||
576 | default: ERROR(BAD_CODE); | ||
577 | } | ||
578 | } | ||
579 | |||
580 | |||
581 | |||
392 | inline double abs_complex(doublecomplex z) { | 582 | inline double abs_complex(doublecomplex z) { |
393 | return sqrt(z.r*z.r + z.i*z.i); | 583 | return sqrt(z.r*z.r + z.i*z.i); |
394 | } | 584 | } |
@@ -526,6 +716,38 @@ int mapValF(int code, float* pval, KFVEC(x), FVEC(r)) { | |||
526 | } | 716 | } |
527 | } | 717 | } |
528 | 718 | ||
719 | int mapValI(int code, int* pval, KIVEC(x), IVEC(r)) { | ||
720 | int k; | ||
721 | int val = *pval; | ||
722 | REQUIRES(xn == rn,BAD_SIZE); | ||
723 | DEBUGMSG("mapValI"); | ||
724 | switch (code) { | ||
725 | OPV(0,val*xp[k]) | ||
726 | OPV(1,val/xp[k]) | ||
727 | OPV(2,val+xp[k]) | ||
728 | OPV(3,val-xp[k]) | ||
729 | OPV(6,mod(val,xp[k])) | ||
730 | OPV(7,mod(xp[k],val)) | ||
731 | default: ERROR(BAD_CODE); | ||
732 | } | ||
733 | } | ||
734 | |||
735 | int mapValL(int code, int64_t* pval, KLVEC(x), LVEC(r)) { | ||
736 | int k; | ||
737 | int64_t val = *pval; | ||
738 | REQUIRES(xn == rn,BAD_SIZE); | ||
739 | DEBUGMSG("mapValL"); | ||
740 | switch (code) { | ||
741 | OPV(0,val*xp[k]) | ||
742 | OPV(1,val/xp[k]) | ||
743 | OPV(2,val+xp[k]) | ||
744 | OPV(3,val-xp[k]) | ||
745 | OPV(6,mod_l(val,xp[k])) | ||
746 | OPV(7,mod_l(xp[k],val)) | ||
747 | default: ERROR(BAD_CODE); | ||
748 | } | ||
749 | } | ||
750 | |||
529 | 751 | ||
530 | 752 | ||
531 | inline doublecomplex complex_add(doublecomplex a, doublecomplex b) { | 753 | inline doublecomplex complex_add(doublecomplex a, doublecomplex b) { |
@@ -608,6 +830,33 @@ REQUIRES(an == bn && an == rn, BAD_SIZE); | |||
608 | } | 830 | } |
609 | 831 | ||
610 | 832 | ||
833 | int zipI(int code, KIVEC(a), KIVEC(b), IVEC(r)) { | ||
834 | REQUIRES(an == bn && an == rn, BAD_SIZE); | ||
835 | int k; | ||
836 | switch(code) { | ||
837 | OPZO(0,"zipI Add",+) | ||
838 | OPZO(1,"zipI Sub",-) | ||
839 | OPZO(2,"zipI Mul",*) | ||
840 | OPZO(3,"zipI Div",/) | ||
841 | OPZO(6,"zipI Mod",%) | ||
842 | default: ERROR(BAD_CODE); | ||
843 | } | ||
844 | } | ||
845 | |||
846 | |||
847 | int zipL(int code, KLVEC(a), KLVEC(b), LVEC(r)) { | ||
848 | REQUIRES(an == bn && an == rn, BAD_SIZE); | ||
849 | int k; | ||
850 | switch(code) { | ||
851 | OPZO(0,"zipI Add",+) | ||
852 | OPZO(1,"zipI Sub",-) | ||
853 | OPZO(2,"zipI Mul",*) | ||
854 | OPZO(3,"zipI Div",/) | ||
855 | OPZO(6,"zipI Mod",%) | ||
856 | default: ERROR(BAD_CODE); | ||
857 | } | ||
858 | } | ||
859 | |||
611 | 860 | ||
612 | #define OPZOb(C,msg,O) case C: {DEBUGMSG(msg) for(k=0;k<an;k++) r2p[k] = a2p[k] O b2p[k]; OK } | 861 | #define OPZOb(C,msg,O) case C: {DEBUGMSG(msg) for(k=0;k<an;k++) r2p[k] = a2p[k] O b2p[k]; OK } |
613 | #define OPZEb(C,msg,E) case C: {DEBUGMSG(msg) for(k=0;k<an;k++) r2p[k] = E(a2p[k],b2p[k]); OK } | 862 | #define OPZEb(C,msg,E) case C: {DEBUGMSG(msg) for(k=0;k<an;k++) r2p[k] = E(a2p[k],b2p[k]); OK } |
@@ -679,24 +928,6 @@ int vectorScan(char * file, int* n, double**pp){ | |||
679 | *pp = p; | 928 | *pp = p; |
680 | fclose(fp); | 929 | fclose(fp); |
681 | OK | 930 | OK |
682 | } | ||
683 | |||
684 | int saveMatrix(char * file, char * format, KDMAT(a)){ | ||
685 | FILE * fp; | ||
686 | fp = fopen (file, "w"); | ||
687 | int r, c; | ||
688 | for (r=0;r<ar; r++) { | ||
689 | for (c=0; c<ac; c++) { | ||
690 | fprintf(fp,format,ap[r*ac+c]); | ||
691 | if (c<ac-1) { | ||
692 | fprintf(fp," "); | ||
693 | } else { | ||
694 | fprintf(fp,"\n"); | ||
695 | } | ||
696 | } | ||
697 | } | ||
698 | fclose(fp); | ||
699 | OK | ||
700 | } | 931 | } |
701 | 932 | ||
702 | //////////////////////////////////////////////////////////////////////////////// | 933 | //////////////////////////////////////////////////////////////////////////////// |
@@ -708,7 +939,12 @@ int saveMatrix(char * file, char * format, KDMAT(a)){ | |||
708 | See: http://www.evanjones.ca/random-thread-safe.html | 939 | See: http://www.evanjones.ca/random-thread-safe.html |
709 | */ | 940 | */ |
710 | #pragma message "randomVector is not thread-safe in OSX and FreeBSD" | 941 | #pragma message "randomVector is not thread-safe in OSX and FreeBSD" |
942 | #endif | ||
711 | 943 | ||
944 | #if defined (__APPLE__) || (__FreeBSD__) || defined(_WIN32) || defined(WIN32) | ||
945 | /* Windows use thread-safe random | ||
946 | See: http://stackoverflow.com/questions/143108/is-windows-rand-s-thread-safe | ||
947 | */ | ||
712 | inline double urandom() { | 948 | inline double urandom() { |
713 | /* the probalility of matching will be theoretically p^3(in fact, it is not) | 949 | /* the probalility of matching will be theoretically p^3(in fact, it is not) |
714 | p is matching probalility of random(). | 950 | p is matching probalility of random(). |
@@ -754,7 +990,7 @@ int random_vector(unsigned int seed, int code, DVEC(r)) { | |||
754 | double V1,V2,S; | 990 | double V1,V2,S; |
755 | 991 | ||
756 | srandom(seed); | 992 | srandom(seed); |
757 | 993 | ||
758 | int k; | 994 | int k; |
759 | switch (code) { | 995 | switch (code) { |
760 | case 0: { // uniform | 996 | case 0: { // uniform |
@@ -816,7 +1052,7 @@ int random_vector(unsigned int seed, int code, DVEC(r)) { | |||
816 | char random_state[128]; | 1052 | char random_state[128]; |
817 | memset(&buffer, 0, sizeof(struct random_data)); | 1053 | memset(&buffer, 0, sizeof(struct random_data)); |
818 | memset(random_state, 0, sizeof(random_state)); | 1054 | memset(random_state, 0, sizeof(random_state)); |
819 | 1055 | ||
820 | initstate_r(seed,random_state,sizeof(random_state),&buffer); | 1056 | initstate_r(seed,random_state,sizeof(random_state),&buffer); |
821 | // setstate_r(random_state,&buffer); | 1057 | // setstate_r(random_state,&buffer); |
822 | // srandom_r(seed,&buffer); | 1058 | // srandom_r(seed,&buffer); |
@@ -847,43 +1083,115 @@ int random_vector(unsigned int seed, int code, DVEC(r)) { | |||
847 | 1083 | ||
848 | //////////////////////////////////////////////////////////////////////////////// | 1084 | //////////////////////////////////////////////////////////////////////////////// |
849 | 1085 | ||
850 | int smXv(KDVEC(vals),KIVEC(cols),KIVEC(rows),KDVEC(x),DVEC(r)) { | 1086 | int |
851 | int r, c; | 1087 | compare_doubles (const void *a, const void *b) { |
852 | for (r = 0; r < rowsn - 1; r++) { | 1088 | return *(double*)a > *(double*)b; |
853 | rp[r] = 0; | 1089 | } |
854 | for (c = rowsp[r]; c < rowsp[r+1]; c++) { | 1090 | |
855 | rp[r] += valsp[c-1] * xp[colsp[c-1]-1]; | 1091 | int sort_valuesD(KDVEC(v),DVEC(r)) { |
856 | } | 1092 | memcpy(rp,vp,vn*sizeof(double)); |
857 | } | 1093 | qsort(rp,rn,sizeof(double),compare_doubles); |
858 | OK | 1094 | OK |
859 | } | 1095 | } |
860 | 1096 | ||
861 | int smTXv(KDVEC(vals),KIVEC(cols),KIVEC(rows),KDVEC(x),DVEC(r)) { | 1097 | int |
862 | int r,c; | 1098 | compare_floats (const void *a, const void *b) { |
863 | for (c = 0; c < rn; c++) { | 1099 | return *(float*)a > *(float*)b; |
864 | rp[c] = 0; | 1100 | } |
865 | } | 1101 | |
866 | for (r = 0; r < rowsn - 1; r++) { | 1102 | int sort_valuesF(KFVEC(v),FVEC(r)) { |
867 | for (c = rowsp[r]; c < rowsp[r+1]; c++) { | 1103 | memcpy(rp,vp,vn*sizeof(float)); |
868 | rp[colsp[c-1]-1] += valsp[c-1] * xp[r]; | 1104 | qsort(rp,rn,sizeof(float),compare_floats); |
869 | } | ||
870 | } | ||
871 | OK | 1105 | OK |
872 | } | 1106 | } |
873 | 1107 | ||
874 | //////////////////////////////////////////////////////////////////////////////// | 1108 | int |
1109 | compare_ints(const void *a, const void *b) { | ||
1110 | return *(int*)a > *(int*)b; | ||
1111 | } | ||
1112 | |||
1113 | int sort_valuesI(KIVEC(v),IVEC(r)) { | ||
1114 | memcpy(rp,vp,vn*sizeof(int)); | ||
1115 | qsort(rp,rn,sizeof(int),compare_ints); | ||
1116 | OK | ||
1117 | } | ||
875 | 1118 | ||
876 | int | 1119 | int |
877 | compare_doubles (const void *a, const void *b) { | 1120 | compare_longs(const void *a, const void *b) { |
878 | return *(double*)a > *(double*)b; | 1121 | return *(int64_t*)a > *(int64_t*)b; |
879 | } | 1122 | } |
880 | 1123 | ||
881 | int sort_values(KDVEC(v),DVEC(r)) { | 1124 | int sort_valuesL(KLVEC(v),LVEC(r)) { |
882 | memcpy(rp,vp,vn*sizeof(double)); | 1125 | memcpy(rp,vp,vn*sizeof(int64_t)); |
883 | qsort(rp,rn,sizeof(double),compare_doubles); | 1126 | qsort(rp,rn,sizeof(int64_t),compare_ints); |
884 | OK | 1127 | OK |
885 | } | 1128 | } |
886 | 1129 | ||
1130 | |||
1131 | //////////////////////////////////////// | ||
1132 | |||
1133 | |||
1134 | #define SORTIDX_IMP(T,C) \ | ||
1135 | T* x = (T*)malloc(sizeof(T)*vn); \ | ||
1136 | int k; \ | ||
1137 | for (k=0;k<vn;k++) { \ | ||
1138 | x[k].pos = k; \ | ||
1139 | x[k].val = vp[k]; \ | ||
1140 | } \ | ||
1141 | \ | ||
1142 | qsort(x,vn,sizeof(T),C); \ | ||
1143 | \ | ||
1144 | for (k=0;k<vn;k++) { \ | ||
1145 | rp[k] = x[k].pos; \ | ||
1146 | } \ | ||
1147 | free(x); \ | ||
1148 | OK | ||
1149 | |||
1150 | |||
1151 | typedef struct DI { int pos; double val;} DI; | ||
1152 | |||
1153 | int compare_doubles_i (const void *a, const void *b) { | ||
1154 | return ((DI*)a)->val > ((DI*)b)->val; | ||
1155 | } | ||
1156 | |||
1157 | int sort_indexD(KDVEC(v),IVEC(r)) { | ||
1158 | SORTIDX_IMP(DI,compare_doubles_i) | ||
1159 | } | ||
1160 | |||
1161 | |||
1162 | typedef struct FI { int pos; float val;} FI; | ||
1163 | |||
1164 | int compare_floats_i (const void *a, const void *b) { | ||
1165 | return ((FI*)a)->val > ((FI*)b)->val; | ||
1166 | } | ||
1167 | |||
1168 | int sort_indexF(KFVEC(v),IVEC(r)) { | ||
1169 | SORTIDX_IMP(FI,compare_floats_i) | ||
1170 | } | ||
1171 | |||
1172 | |||
1173 | typedef struct II { int pos; int val;} II; | ||
1174 | |||
1175 | int compare_ints_i (const void *a, const void *b) { | ||
1176 | return ((II*)a)->val > ((II*)b)->val; | ||
1177 | } | ||
1178 | |||
1179 | int sort_indexI(KIVEC(v),IVEC(r)) { | ||
1180 | SORTIDX_IMP(II,compare_ints_i) | ||
1181 | } | ||
1182 | |||
1183 | |||
1184 | typedef struct LI { int pos; int64_t val;} LI; | ||
1185 | |||
1186 | int compare_longs_i (const void *a, const void *b) { | ||
1187 | return ((II*)a)->val > ((II*)b)->val; | ||
1188 | } | ||
1189 | |||
1190 | int sort_indexL(KLVEC(v),LVEC(r)) { | ||
1191 | SORTIDX_IMP(II,compare_longs_i) | ||
1192 | } | ||
1193 | |||
1194 | |||
887 | //////////////////////////////////////////////////////////////////////////////// | 1195 | //////////////////////////////////////////////////////////////////////////////// |
888 | 1196 | ||
889 | int round_vector(KDVEC(v),DVEC(r)) { | 1197 | int round_vector(KDVEC(v),DVEC(r)) { |
@@ -894,3 +1202,285 @@ int round_vector(KDVEC(v),DVEC(r)) { | |||
894 | OK | 1202 | OK |
895 | } | 1203 | } |
896 | 1204 | ||
1205 | //////////////////////////////////////////////////////////////////////////////// | ||
1206 | |||
1207 | int round_vector_i(KDVEC(v),IVEC(r)) { | ||
1208 | int k; | ||
1209 | for(k=0; k<vn; k++) { | ||
1210 | rp[k] = round(vp[k]); | ||
1211 | } | ||
1212 | OK | ||
1213 | } | ||
1214 | |||
1215 | |||
1216 | int mod_vector(int m, KIVEC(v), IVEC(r)) { | ||
1217 | int k; | ||
1218 | for(k=0; k<vn; k++) { | ||
1219 | rp[k] = mod(vp[k],m); | ||
1220 | } | ||
1221 | OK | ||
1222 | } | ||
1223 | |||
1224 | int div_vector(int m, KIVEC(v), IVEC(r)) { | ||
1225 | int k; | ||
1226 | for(k=0; k<vn; k++) { | ||
1227 | rp[k] = vp[k] / m; | ||
1228 | } | ||
1229 | OK | ||
1230 | } | ||
1231 | |||
1232 | int range_vector(IVEC(r)) { | ||
1233 | int k; | ||
1234 | for(k=0; k<rn; k++) { | ||
1235 | rp[k] = k; | ||
1236 | } | ||
1237 | OK | ||
1238 | } | ||
1239 | |||
1240 | /////////////////////////// | ||
1241 | |||
1242 | |||
1243 | int round_vector_l(KDVEC(v),LVEC(r)) { | ||
1244 | int k; | ||
1245 | for(k=0; k<vn; k++) { | ||
1246 | rp[k] = round(vp[k]); | ||
1247 | } | ||
1248 | OK | ||
1249 | } | ||
1250 | |||
1251 | |||
1252 | int mod_vector_l(int64_t m, KLVEC(v), LVEC(r)) { | ||
1253 | int k; | ||
1254 | for(k=0; k<vn; k++) { | ||
1255 | rp[k] = mod_l(vp[k],m); | ||
1256 | } | ||
1257 | OK | ||
1258 | } | ||
1259 | |||
1260 | int div_vector_l(int64_t m, KLVEC(v), LVEC(r)) { | ||
1261 | int k; | ||
1262 | for(k=0; k<vn; k++) { | ||
1263 | rp[k] = vp[k] / m; | ||
1264 | } | ||
1265 | OK | ||
1266 | } | ||
1267 | |||
1268 | int range_vector_l(LVEC(r)) { | ||
1269 | int k; | ||
1270 | for(k=0; k<rn; k++) { | ||
1271 | rp[k] = k; | ||
1272 | } | ||
1273 | OK | ||
1274 | } | ||
1275 | |||
1276 | |||
1277 | |||
1278 | //////////////////// constant ///////////////////////// | ||
1279 | |||
1280 | int constantF(float * pval, FVEC(r)) { | ||
1281 | DEBUGMSG("constantF") | ||
1282 | int k; | ||
1283 | double val = *pval; | ||
1284 | for(k=0;k<rn;k++) { | ||
1285 | rp[k]=val; | ||
1286 | } | ||
1287 | OK | ||
1288 | } | ||
1289 | |||
1290 | int constantR(double * pval, DVEC(r)) { | ||
1291 | DEBUGMSG("constantR") | ||
1292 | int k; | ||
1293 | double val = *pval; | ||
1294 | for(k=0;k<rn;k++) { | ||
1295 | rp[k]=val; | ||
1296 | } | ||
1297 | OK | ||
1298 | } | ||
1299 | |||
1300 | int constantQ(complex* pval, QVEC(r)) { | ||
1301 | DEBUGMSG("constantQ") | ||
1302 | int k; | ||
1303 | complex val = *pval; | ||
1304 | for(k=0;k<rn;k++) { | ||
1305 | rp[k]=val; | ||
1306 | } | ||
1307 | OK | ||
1308 | } | ||
1309 | |||
1310 | int constantC(doublecomplex* pval, CVEC(r)) { | ||
1311 | DEBUGMSG("constantC") | ||
1312 | int k; | ||
1313 | doublecomplex val = *pval; | ||
1314 | for(k=0;k<rn;k++) { | ||
1315 | rp[k]=val; | ||
1316 | } | ||
1317 | OK | ||
1318 | } | ||
1319 | |||
1320 | |||
1321 | |||
1322 | int constantI(int * pval, IVEC(r)) { | ||
1323 | DEBUGMSG("constantI") | ||
1324 | int k; | ||
1325 | int val = *pval; | ||
1326 | for(k=0;k<rn;k++) { | ||
1327 | rp[k]=val; | ||
1328 | } | ||
1329 | OK | ||
1330 | } | ||
1331 | |||
1332 | |||
1333 | |||
1334 | int constantL(int64_t * pval, LVEC(r)) { | ||
1335 | DEBUGMSG("constantL") | ||
1336 | int k; | ||
1337 | int64_t val = *pval; | ||
1338 | for(k=0;k<rn;k++) { | ||
1339 | rp[k]=val; | ||
1340 | } | ||
1341 | OK | ||
1342 | } | ||
1343 | |||
1344 | |||
1345 | //////////////////// type conversions ///////////////////////// | ||
1346 | |||
1347 | #define CONVERT_IMP { \ | ||
1348 | int k; \ | ||
1349 | for(k=0;k<xn;k++) { \ | ||
1350 | yp[k]=xp[k]; \ | ||
1351 | } \ | ||
1352 | OK } | ||
1353 | |||
1354 | int float2double(FVEC(x),DVEC(y)) CONVERT_IMP | ||
1355 | |||
1356 | int float2int(KFVEC(x),IVEC(y)) CONVERT_IMP | ||
1357 | |||
1358 | int double2float(DVEC(x),FVEC(y)) CONVERT_IMP | ||
1359 | |||
1360 | int double2int(KDVEC(x),IVEC(y)) CONVERT_IMP | ||
1361 | |||
1362 | int double2long(KDVEC(x),LVEC(y)) CONVERT_IMP | ||
1363 | |||
1364 | int int2float(KIVEC(x),FVEC(y)) CONVERT_IMP | ||
1365 | |||
1366 | int int2double(KIVEC(x),DVEC(y)) CONVERT_IMP | ||
1367 | |||
1368 | int int2long(KIVEC(x),LVEC(y)) CONVERT_IMP | ||
1369 | |||
1370 | int long2int(KLVEC(x),IVEC(y)) CONVERT_IMP | ||
1371 | |||
1372 | int long2double(KLVEC(x),DVEC(y)) CONVERT_IMP | ||
1373 | |||
1374 | |||
1375 | //////////////////// conjugate ///////////////////////// | ||
1376 | |||
1377 | int conjugateQ(KQVEC(x),QVEC(t)) { | ||
1378 | REQUIRES(xn==tn,BAD_SIZE); | ||
1379 | DEBUGMSG("conjugateQ"); | ||
1380 | int k; | ||
1381 | for(k=0;k<xn;k++) { | ||
1382 | tp[k].r = xp[k].r; | ||
1383 | tp[k].i = -xp[k].i; | ||
1384 | } | ||
1385 | OK | ||
1386 | } | ||
1387 | |||
1388 | int conjugateC(KCVEC(x),CVEC(t)) { | ||
1389 | REQUIRES(xn==tn,BAD_SIZE); | ||
1390 | DEBUGMSG("conjugateC"); | ||
1391 | int k; | ||
1392 | for(k=0;k<xn;k++) { | ||
1393 | tp[k].r = xp[k].r; | ||
1394 | tp[k].i = -xp[k].i; | ||
1395 | } | ||
1396 | OK | ||
1397 | } | ||
1398 | |||
1399 | //////////////////// step ///////////////////////// | ||
1400 | |||
1401 | #define STEP_IMP \ | ||
1402 | int k; \ | ||
1403 | for(k=0;k<xn;k++) { \ | ||
1404 | yp[k]=xp[k]>0; \ | ||
1405 | } \ | ||
1406 | OK | ||
1407 | |||
1408 | int stepF(KFVEC(x),FVEC(y)) { | ||
1409 | STEP_IMP | ||
1410 | } | ||
1411 | |||
1412 | int stepD(KDVEC(x),DVEC(y)) { | ||
1413 | STEP_IMP | ||
1414 | } | ||
1415 | |||
1416 | int stepI(KIVEC(x),IVEC(y)) { | ||
1417 | STEP_IMP | ||
1418 | } | ||
1419 | |||
1420 | int stepL(KLVEC(x),LVEC(y)) { | ||
1421 | STEP_IMP | ||
1422 | } | ||
1423 | |||
1424 | |||
1425 | //////////////////// cond ///////////////////////// | ||
1426 | |||
1427 | #define COMPARE_IMP \ | ||
1428 | REQUIRES(xn==yn && xn==rn ,BAD_SIZE); \ | ||
1429 | int k; \ | ||
1430 | for(k=0;k<xn;k++) { \ | ||
1431 | rp[k] = xp[k]<yp[k]?-1:(xp[k]>yp[k]?1:0); \ | ||
1432 | } \ | ||
1433 | OK | ||
1434 | |||
1435 | |||
1436 | int compareF(KFVEC(x),KFVEC(y),IVEC(r)) { | ||
1437 | COMPARE_IMP | ||
1438 | } | ||
1439 | |||
1440 | int compareD(KDVEC(x),KDVEC(y),IVEC(r)) { | ||
1441 | COMPARE_IMP | ||
1442 | } | ||
1443 | |||
1444 | int compareI(KIVEC(x),KIVEC(y),IVEC(r)) { | ||
1445 | COMPARE_IMP | ||
1446 | } | ||
1447 | |||
1448 | int compareL(KLVEC(x),KLVEC(y),IVEC(r)) { | ||
1449 | COMPARE_IMP | ||
1450 | } | ||
1451 | |||
1452 | |||
1453 | |||
1454 | #define CHOOSE_IMP \ | ||
1455 | REQUIRES(condn==ltn && ltn==eqn && ltn==gtn && ltn==rn ,BAD_SIZE); \ | ||
1456 | int k; \ | ||
1457 | for(k=0;k<condn;k++) { \ | ||
1458 | rp[k] = condp[k]<0?ltp[k]:(condp[k]>0?gtp[k]:eqp[k]); \ | ||
1459 | } \ | ||
1460 | OK | ||
1461 | |||
1462 | int chooseF(KIVEC(cond),KFVEC(lt),KFVEC(eq),KFVEC(gt),FVEC(r)) { | ||
1463 | CHOOSE_IMP | ||
1464 | } | ||
1465 | |||
1466 | int chooseD(KIVEC(cond),KDVEC(lt),KDVEC(eq),KDVEC(gt),DVEC(r)) { | ||
1467 | CHOOSE_IMP | ||
1468 | } | ||
1469 | |||
1470 | int chooseI(KIVEC(cond),KIVEC(lt),KIVEC(eq),KIVEC(gt),IVEC(r)) { | ||
1471 | CHOOSE_IMP | ||
1472 | } | ||
1473 | |||
1474 | int chooseL(KIVEC(cond),KLVEC(lt),KLVEC(eq),KLVEC(gt),LVEC(r)) { | ||
1475 | CHOOSE_IMP | ||
1476 | } | ||
1477 | |||
1478 | |||
1479 | int chooseC(KIVEC(cond),KCVEC(lt),KCVEC(eq),KCVEC(gt),CVEC(r)) { | ||
1480 | CHOOSE_IMP | ||
1481 | } | ||
1482 | |||
1483 | int chooseQ(KIVEC(cond),KQVEC(lt),KQVEC(eq),KQVEC(gt),QVEC(r)) { | ||
1484 | CHOOSE_IMP | ||
1485 | } | ||
1486 | |||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Internal/CG.hs index b82c74f..cc10ad8 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs +++ b/packages/base/src/Internal/CG.hs | |||
@@ -1,15 +1,20 @@ | |||
1 | {-# LANGUAGE FlexibleContexts, FlexibleInstances #-} | 1 | {-# LANGUAGE FlexibleContexts, FlexibleInstances #-} |
2 | {-# LANGUAGE RecordWildCards #-} | 2 | {-# LANGUAGE RecordWildCards #-} |
3 | 3 | ||
4 | module Numeric.LinearAlgebra.Util.CG( | 4 | module Internal.CG( |
5 | cgSolve, cgSolve', | 5 | cgSolve, cgSolve', |
6 | CGState(..), R, V | 6 | CGState(..), R, V |
7 | ) where | 7 | ) where |
8 | 8 | ||
9 | import Data.Packed.Numeric | 9 | import Internal.Vector |
10 | import Numeric.Sparse | 10 | import Internal.Matrix |
11 | import Internal.Numeric | ||
12 | import Internal.Element | ||
13 | import Internal.IO | ||
14 | import Internal.Container | ||
15 | import Internal.Sparse | ||
11 | import Numeric.Vector() | 16 | import Numeric.Vector() |
12 | import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..)) | 17 | import Internal.Algorithms(linearSolveLS, linearSolve, relativeError, pnorm, NormType(..)) |
13 | import Control.Arrow((***)) | 18 | import Control.Arrow((***)) |
14 | 19 | ||
15 | {- | 20 | {- |
@@ -24,15 +29,14 @@ infix 0 /// | |||
24 | v /// b = debugMat b 2 asRow v | 29 | v /// b = debugMat b 2 asRow v |
25 | -} | 30 | -} |
26 | 31 | ||
27 | type R = Double | ||
28 | type V = Vector R | 32 | type V = Vector R |
29 | 33 | ||
30 | data CGState = CGState | 34 | data CGState = CGState |
31 | { cgp :: V -- ^ conjugate gradient | 35 | { cgp :: Vector R -- ^ conjugate gradient |
32 | , cgr :: V -- ^ residual | 36 | , cgr :: Vector R -- ^ residual |
33 | , cgr2 :: R -- ^ squared norm of residual | 37 | , cgr2 :: R -- ^ squared norm of residual |
34 | , cgx :: V -- ^ current solution | 38 | , cgx :: Vector R -- ^ current solution |
35 | , cgdx :: R -- ^ normalized size of correction | 39 | , cgdx :: R -- ^ normalized size of correction |
36 | } | 40 | } |
37 | 41 | ||
38 | cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState | 42 | cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState |
@@ -41,13 +45,13 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx | |||
41 | ap1 = a p | 45 | ap1 = a p |
42 | ap | sym = ap1 | 46 | ap | sym = ap1 |
43 | | otherwise = at ap1 | 47 | | otherwise = at ap1 |
44 | pap | sym = p <·> ap1 | 48 | pap | sym = p <.> ap1 |
45 | | otherwise = norm2 ap1 ** 2 | 49 | | otherwise = norm2 ap1 ** 2 |
46 | alpha = r2 / pap | 50 | alpha = r2 / pap |
47 | dx = scale alpha p | 51 | dx = scale alpha p |
48 | x' = x + dx | 52 | x' = x + dx |
49 | r' = r - scale alpha ap | 53 | r' = r - scale alpha ap |
50 | r'2 = r' <·> r' | 54 | r'2 = r' <.> r' |
51 | beta = r'2 / r2 | 55 | beta = r'2 / r2 |
52 | p' = r' + scale beta p | 56 | p' = r' + scale beta p |
53 | 57 | ||
@@ -55,25 +59,26 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx | |||
55 | 59 | ||
56 | conjugrad | 60 | conjugrad |
57 | :: Bool -> GMatrix -> V -> V -> R -> R -> [CGState] | 61 | :: Bool -> GMatrix -> V -> V -> R -> R -> [CGState] |
58 | conjugrad sym a b = solveG (tr a !#>) (a !#>) (cg sym) b | 62 | conjugrad sym a b = solveG sym (tr a !#>) (a !#>) (cg sym) b |
59 | 63 | ||
60 | solveG | 64 | solveG |
61 | :: (V -> V) -> (V -> V) | 65 | :: Bool |
66 | -> (V -> V) -> (V -> V) | ||
62 | -> ((V -> V) -> (V -> V) -> CGState -> CGState) | 67 | -> ((V -> V) -> (V -> V) -> CGState -> CGState) |
63 | -> V | 68 | -> V |
64 | -> V | 69 | -> V |
65 | -> R -> R | 70 | -> R -> R |
66 | -> [CGState] | 71 | -> [CGState] |
67 | solveG mat ma meth rawb x0' ϵb ϵx | 72 | solveG sym mat ma meth rawb x0' ϵb ϵx |
68 | = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 | 73 | = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 |
69 | where | 74 | where |
70 | a = mat . ma | 75 | a = if sym then ma else mat . ma |
71 | b = mat rawb | 76 | b = if sym then rawb else mat rawb |
72 | x0 = if x0' == 0 then konst 0 (dim b) else x0' | 77 | x0 = if x0' == 0 then konst 0 (dim b) else x0' |
73 | r0 = b - a x0 | 78 | r0 = b - a x0 |
74 | r20 = r0 <·> r0 | 79 | r20 = r0 <.> r0 |
75 | p0 = r0 | 80 | p0 = r0 |
76 | nb2 = b <·> b | 81 | nb2 = b <.> b |
77 | ok CGState {..} | 82 | ok CGState {..} |
78 | = cgr2 <nb2*ϵb**2 | 83 | = cgr2 <nb2*ϵb**2 |
79 | || cgdx < ϵx | 84 | || cgdx < ϵx |
@@ -84,23 +89,25 @@ takeUntil q xs = a++ take 1 b | |||
84 | where | 89 | where |
85 | (a,b) = break q xs | 90 | (a,b) = break q xs |
86 | 91 | ||
92 | -- | Solve a sparse linear system using the conjugate gradient method with default parameters. | ||
87 | cgSolve | 93 | cgSolve |
88 | :: Bool -- ^ is symmetric | 94 | :: Bool -- ^ is symmetric |
89 | -> GMatrix -- ^ coefficient matrix | 95 | -> GMatrix -- ^ coefficient matrix |
90 | -> Vector Double -- ^ right-hand side | 96 | -> Vector R -- ^ right-hand side |
91 | -> Vector Double -- ^ solution | 97 | -> Vector R -- ^ solution |
92 | cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 | 98 | cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 |
93 | where | 99 | where |
94 | n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) | 100 | n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) |
95 | 101 | ||
102 | -- | Solve a sparse linear system using the conjugate gradient method with default parameters. | ||
96 | cgSolve' | 103 | cgSolve' |
97 | :: Bool -- ^ symmetric | 104 | :: Bool -- ^ symmetric |
98 | -> R -- ^ relative tolerance for the residual (e.g. 1E-4) | 105 | -> R -- ^ relative tolerance for the residual (e.g. 1E-4) |
99 | -> R -- ^ relative tolerance for δx (e.g. 1E-3) | 106 | -> R -- ^ relative tolerance for δx (e.g. 1E-3) |
100 | -> Int -- ^ maximum number of iterations | 107 | -> Int -- ^ maximum number of iterations |
101 | -> GMatrix -- ^ coefficient matrix | 108 | -> GMatrix -- ^ coefficient matrix |
102 | -> V -- ^ initial solution | 109 | -> Vector R -- ^ initial solution |
103 | -> V -- ^ right-hand side | 110 | -> Vector R -- ^ right-hand side |
104 | -> [CGState] -- ^ solution | 111 | -> [CGState] -- ^ solution |
105 | cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es | 112 | cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es |
106 | 113 | ||
@@ -134,6 +141,11 @@ instance Testable GMatrix | |||
134 | s5 = cgSolve False sm v | 141 | s5 = cgSolve False sm v |
135 | d5 = denseSolve dm v | 142 | d5 = denseSolve dm v |
136 | 143 | ||
144 | symassoc = [((0,0),1.0),((1,1),2.0),((0,1),0.5),((1,0),0.5)] | ||
145 | b = vect [3,4] | ||
146 | d6 = flatten $ linearSolve (toDense symassoc) (asColumn b) | ||
147 | s6 = cgSolve True (mkSparse symassoc) b | ||
148 | |||
137 | info = do | 149 | info = do |
138 | print sm | 150 | print sm |
139 | disp (toDense sma) | 151 | disp (toDense sma) |
@@ -142,13 +154,16 @@ instance Testable GMatrix | |||
142 | print s3; print d3 | 154 | print s3; print d3 |
143 | print s4; print d4 | 155 | print s4; print d4 |
144 | print s5; print d5 | 156 | print s5; print d5 |
145 | print $ relativeError Infinity s5 d5 | 157 | print $ relativeError (pnorm Infinity) s5 d5 |
158 | print s6; print d6 | ||
159 | print $ relativeError (pnorm Infinity) s6 d6 | ||
146 | 160 | ||
147 | ok = s1==d1 | 161 | ok = s1==d1 |
148 | && s2==d2 | 162 | && s2==d2 |
149 | && s3==d3 | 163 | && s3==d3 |
150 | && s4==d4 | 164 | && s4==d4 |
151 | && relativeError Infinity s5 d5 < 1E-10 | 165 | && relativeError (pnorm Infinity) s5 d5 < 1E-10 |
166 | && relativeError (pnorm Infinity) s6 d6 < 1E-10 | ||
152 | 167 | ||
153 | disp = putStr . dispf 2 | 168 | disp = putStr . dispf 2 |
154 | 169 | ||
diff --git a/packages/base/src/Numeric/Chain.hs b/packages/base/src/Internal/Chain.hs index 443bd28..f87eb02 100644 --- a/packages/base/src/Numeric/Chain.hs +++ b/packages/base/src/Internal/Chain.hs | |||
@@ -1,6 +1,8 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | ||
2 | |||
1 | ----------------------------------------------------------------------------- | 3 | ----------------------------------------------------------------------------- |
2 | -- | | 4 | -- | |
3 | -- Module : Numeric.Chain | 5 | -- Module : Internal.Chain |
4 | -- Copyright : (c) Vivian McPhail 2010 | 6 | -- Copyright : (c) Vivian McPhail 2010 |
5 | -- License : BSD3 | 7 | -- License : BSD3 |
6 | -- | 8 | -- |
@@ -14,14 +16,14 @@ | |||
14 | 16 | ||
15 | {-# LANGUAGE FlexibleContexts #-} | 17 | {-# LANGUAGE FlexibleContexts #-} |
16 | 18 | ||
17 | module Numeric.Chain ( | 19 | module Internal.Chain ( |
18 | optimiseMult, | 20 | optimiseMult, |
19 | ) where | 21 | ) where |
20 | 22 | ||
21 | import Data.Maybe | 23 | import Data.Maybe |
22 | 24 | ||
23 | import Data.Packed.Matrix | 25 | import Internal.Matrix |
24 | import Data.Packed.Internal.Numeric | 26 | import Internal.Numeric |
25 | 27 | ||
26 | import qualified Data.Array.IArray as A | 28 | import qualified Data.Array.IArray as A |
27 | 29 | ||
diff --git a/packages/base/src/Data/Packed/Numeric.hs b/packages/base/src/Internal/Container.hs index 6027f43..b08f892 100644 --- a/packages/base/src/Data/Packed/Numeric.hs +++ b/packages/base/src/Internal/Container.hs | |||
@@ -6,7 +6,7 @@ | |||
6 | 6 | ||
7 | ----------------------------------------------------------------------------- | 7 | ----------------------------------------------------------------------------- |
8 | -- | | 8 | -- | |
9 | -- Module : Data.Packed.Numeric | 9 | -- Module : Internal.Container |
10 | -- Copyright : (c) Alberto Ruiz 2010-14 | 10 | -- Copyright : (c) Alberto Ruiz 2010-14 |
11 | -- License : BSD3 | 11 | -- License : BSD3 |
12 | -- Maintainer : Alberto Ruiz | 12 | -- Maintainer : Alberto Ruiz |
@@ -21,59 +21,14 @@ | |||
21 | -- numeric Haskell classes provided by "Numeric.LinearAlgebra". | 21 | -- numeric Haskell classes provided by "Numeric.LinearAlgebra". |
22 | -- | 22 | -- |
23 | ----------------------------------------------------------------------------- | 23 | ----------------------------------------------------------------------------- |
24 | {-# OPTIONS_HADDOCK hide #-} | 24 | |
25 | 25 | module Internal.Container where | |
26 | module Data.Packed.Numeric ( | 26 | |
27 | -- * Basic functions | 27 | import Internal.Vector |
28 | module Data.Packed, | 28 | import Internal.Matrix |
29 | Konst(..), Build(..), | 29 | import Internal.Element |
30 | linspace, | 30 | import Internal.Numeric |
31 | diag, ident, | 31 | import Internal.Algorithms(Field,linearSolveSVD) |
32 | ctrans, | ||
33 | -- * Generic operations | ||
34 | Container(..), Numeric, | ||
35 | -- add, mul, sub, divide, equal, scaleRecip, addConstant, | ||
36 | scalar, conj, scale, arctan2, cmap, | ||
37 | atIndex, minIndex, maxIndex, minElement, maxElement, | ||
38 | sumElements, prodElements, | ||
39 | step, cond, find, assoc, accum, | ||
40 | Transposable(..), Linear(..), | ||
41 | -- * Matrix product | ||
42 | Product(..), udot, dot, (<·>), (#>), app, | ||
43 | Mul(..), | ||
44 | (<.>), | ||
45 | optimiseMult, | ||
46 | mXm,mXv,vXm,LSDiv,(<\>), | ||
47 | outer, kronecker, | ||
48 | -- * Random numbers | ||
49 | RandDist(..), | ||
50 | randomVector, | ||
51 | gaussianSample, | ||
52 | uniformSample, | ||
53 | meanCov, | ||
54 | -- * sorting | ||
55 | sortVector, | ||
56 | -- * Element conversion | ||
57 | Convert(..), | ||
58 | Complexable(), | ||
59 | RealElement(), | ||
60 | RealOf, ComplexOf, SingleOf, DoubleOf, | ||
61 | roundVector, | ||
62 | IndexOf, | ||
63 | module Data.Complex, | ||
64 | -- * IO | ||
65 | module Data.Packed.IO, | ||
66 | -- * Misc | ||
67 | Testable(..) | ||
68 | ) where | ||
69 | |||
70 | import Data.Packed | ||
71 | import Data.Packed.Internal.Numeric | ||
72 | import Data.Complex | ||
73 | import Numeric.LinearAlgebra.Algorithms(Field,linearSolveSVD) | ||
74 | import Data.Monoid(Monoid(mconcat)) | ||
75 | import Data.Packed.IO | ||
76 | import Numeric.LinearAlgebra.Random | ||
77 | 32 | ||
78 | ------------------------------------------------------------------ | 33 | ------------------------------------------------------------------ |
79 | 34 | ||
@@ -89,7 +44,7 @@ Logarithmic spacing can be defined as follows: | |||
89 | 44 | ||
90 | @logspace n (a,b) = 10 ** linspace n (a,b)@ | 45 | @logspace n (a,b) = 10 ** linspace n (a,b)@ |
91 | -} | 46 | -} |
92 | linspace :: (Container Vector e) => Int -> (e, e) -> Vector e | 47 | linspace :: (Fractional e, Container Vector e) => Int -> (e, e) -> Vector e |
93 | linspace 0 _ = fromList[] | 48 | linspace 0 _ = fromList[] |
94 | linspace 1 (a,b) = fromList[(a+b)/2] | 49 | linspace 1 (a,b) = fromList[(a+b)/2] |
95 | linspace n (a,b) = addConstant a $ scale s $ fromList $ map fromIntegral [0 .. n-1] | 50 | linspace n (a,b) = addConstant a $ scale s $ fromList $ map fromIntegral [0 .. n-1] |
@@ -97,31 +52,26 @@ linspace n (a,b) = addConstant a $ scale s $ fromList $ map fromIntegral [0 .. n | |||
97 | 52 | ||
98 | -------------------------------------------------------------------------------- | 53 | -------------------------------------------------------------------------------- |
99 | 54 | ||
100 | infixl 7 <.> | 55 | infixr 8 <.> |
101 | -- | An infix synonym for 'dot' | 56 | {- | An infix synonym for 'dot' |
102 | (<.>) :: Numeric t => Vector t -> Vector t -> t | ||
103 | (<.>) = dot | ||
104 | 57 | ||
58 | >>> vector [1,2,3,4] <.> vector [-2,0,1,1] | ||
59 | 5.0 | ||
105 | 60 | ||
106 | infixr 8 <·>, #> | 61 | >>> let 𝑖 = 0:+1 :: C |
62 | >>> fromList [1+𝑖,1] <.> fromList [1,1+𝑖] | ||
63 | 2.0 :+ 0.0 | ||
107 | 64 | ||
108 | {- | infix synonym for 'dot' | 65 | -} |
109 | 66 | ||
110 | >>> vector [1,2,3,4] <·> vector [-2,0,1,1] | 67 | (<.>) :: Numeric t => Vector t -> Vector t -> t |
111 | 5.0 | 68 | (<.>) = dot |
112 | 69 | ||
113 | >>> let 𝑖 = 0:+1 :: ℂ | ||
114 | >>> fromList [1+𝑖,1] <·> fromList [1,1+𝑖] | ||
115 | 2.0 :+ 0.0 | ||
116 | 70 | ||
117 | (the dot symbol "·" is obtained by Alt-Gr .) | ||
118 | 71 | ||
119 | -} | ||
120 | (<·>) :: Numeric t => Vector t -> Vector t -> t | ||
121 | (<·>) = dot | ||
122 | 72 | ||
123 | 73 | ||
124 | {- | infix synonym for 'app' | 74 | {- | dense matrix-vector product |
125 | 75 | ||
126 | >>> let m = (2><3) [1..] | 76 | >>> let m = (2><3) [1..] |
127 | >>> m | 77 | >>> m |
@@ -135,6 +85,7 @@ infixr 8 <·>, #> | |||
135 | fromList [140.0,320.0] | 85 | fromList [140.0,320.0] |
136 | 86 | ||
137 | -} | 87 | -} |
88 | infixr 8 #> | ||
138 | (#>) :: Numeric t => Matrix t -> Vector t -> Vector t | 89 | (#>) :: Numeric t => Matrix t -> Vector t -> Vector t |
139 | (#>) = mXv | 90 | (#>) = mXv |
140 | 91 | ||
@@ -142,6 +93,11 @@ fromList [140.0,320.0] | |||
142 | app :: Numeric t => Matrix t -> Vector t -> Vector t | 93 | app :: Numeric t => Matrix t -> Vector t -> Vector t |
143 | app = (#>) | 94 | app = (#>) |
144 | 95 | ||
96 | infixl 8 <# | ||
97 | -- | dense vector-matrix product | ||
98 | (<#) :: Numeric t => Vector t -> Matrix t -> Vector t | ||
99 | (<#) = vXm | ||
100 | |||
145 | -------------------------------------------------------------------------------- | 101 | -------------------------------------------------------------------------------- |
146 | 102 | ||
147 | class Mul a b c | a b -> c where | 103 | class Mul a b c | a b -> c where |
@@ -201,29 +157,6 @@ instance LSDiv Matrix | |||
201 | 157 | ||
202 | -------------------------------------------------------------------------------- | 158 | -------------------------------------------------------------------------------- |
203 | 159 | ||
204 | class Konst e d c | d -> c, c -> d | ||
205 | where | ||
206 | -- | | ||
207 | -- >>> konst 7 3 :: Vector Float | ||
208 | -- fromList [7.0,7.0,7.0] | ||
209 | -- | ||
210 | -- >>> konst i (3::Int,4::Int) | ||
211 | -- (3><4) | ||
212 | -- [ 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0 | ||
213 | -- , 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0 | ||
214 | -- , 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0 ] | ||
215 | -- | ||
216 | konst :: e -> d -> c e | ||
217 | |||
218 | instance Container Vector e => Konst e Int Vector | ||
219 | where | ||
220 | konst = konst' | ||
221 | |||
222 | instance Container Vector e => Konst e (Int,Int) Matrix | ||
223 | where | ||
224 | konst = konst' | ||
225 | |||
226 | -------------------------------------------------------------------------------- | ||
227 | 160 | ||
228 | class Build d f c e | d -> c, c -> d, f -> e, f -> d, f -> c, c e -> f, d e -> f | 161 | class Build d f c e | d -> c, c -> d, f -> e, f -> d, f -> c, c e -> f, d e -> f |
229 | where | 162 | where |
@@ -284,16 +217,81 @@ meanCov x = (med,cov) where | |||
284 | 217 | ||
285 | -------------------------------------------------------------------------------- | 218 | -------------------------------------------------------------------------------- |
286 | 219 | ||
287 | class ( Container Vector t | 220 | sortVector :: (Ord t, Element t) => Vector t -> Vector t |
288 | , Container Matrix t | 221 | sortVector = sortV |
289 | , Konst t Int Vector | ||
290 | , Konst t (Int,Int) Matrix | ||
291 | , Product t | ||
292 | ) => Numeric t | ||
293 | 222 | ||
294 | instance Numeric Double | 223 | {- | |
295 | instance Numeric (Complex Double) | ||
296 | instance Numeric Float | ||
297 | instance Numeric (Complex Float) | ||
298 | 224 | ||
225 | >>> m <- randn 4 10 | ||
226 | >>> disp 2 m | ||
227 | 4x10 | ||
228 | -0.31 0.41 0.43 -0.19 -0.17 -0.23 -0.17 -1.04 -0.07 -1.24 | ||
229 | 0.26 0.19 0.14 0.83 -1.54 -0.09 0.37 -0.63 0.71 -0.50 | ||
230 | -0.11 -0.10 -1.29 -1.40 -1.04 -0.89 -0.68 0.35 -1.46 1.86 | ||
231 | 1.04 -0.29 0.19 -0.75 -2.20 -0.01 1.06 0.11 -2.09 -1.58 | ||
232 | |||
233 | >>> disp 2 $ m ?? (All, Pos $ sortIndex (m!1)) | ||
234 | 4x10 | ||
235 | -0.17 -1.04 -1.24 -0.23 0.43 0.41 -0.31 -0.17 -0.07 -0.19 | ||
236 | -1.54 -0.63 -0.50 -0.09 0.14 0.19 0.26 0.37 0.71 0.83 | ||
237 | -1.04 0.35 1.86 -0.89 -1.29 -0.10 -0.11 -0.68 -1.46 -1.40 | ||
238 | -2.20 0.11 -1.58 -0.01 0.19 -0.29 1.04 1.06 -2.09 -0.75 | ||
239 | |||
240 | -} | ||
241 | sortIndex :: (Ord t, Element t) => Vector t -> Vector I | ||
242 | sortIndex = sortI | ||
243 | |||
244 | ccompare :: (Ord t, Container c t) => c t -> c t -> c I | ||
245 | ccompare = ccompare' | ||
246 | |||
247 | cselect :: (Container c t) => c I -> c t -> c t -> c t -> c t | ||
248 | cselect = cselect' | ||
249 | |||
250 | {- | Extract elements from positions given in matrices of rows and columns. | ||
251 | |||
252 | >>> r | ||
253 | (3><3) | ||
254 | [ 1, 1, 1 | ||
255 | , 1, 2, 2 | ||
256 | , 1, 2, 3 ] | ||
257 | >>> c | ||
258 | (3><3) | ||
259 | [ 0, 1, 5 | ||
260 | , 2, 2, 1 | ||
261 | , 4, 4, 1 ] | ||
262 | >>> m | ||
263 | (4><6) | ||
264 | [ 0, 1, 2, 3, 4, 5 | ||
265 | , 6, 7, 8, 9, 10, 11 | ||
266 | , 12, 13, 14, 15, 16, 17 | ||
267 | , 18, 19, 20, 21, 22, 23 ] | ||
268 | |||
269 | >>> remap r c m | ||
270 | (3><3) | ||
271 | [ 6, 7, 11 | ||
272 | , 8, 14, 13 | ||
273 | , 10, 16, 19 ] | ||
274 | |||
275 | The indexes are autoconformable. | ||
276 | |||
277 | >>> c' | ||
278 | (3><1) | ||
279 | [ 1 | ||
280 | , 2 | ||
281 | , 4 ] | ||
282 | >>> remap r c' m | ||
283 | (3><3) | ||
284 | [ 7, 7, 7 | ||
285 | , 8, 14, 14 | ||
286 | , 10, 16, 22 ] | ||
287 | |||
288 | -} | ||
289 | remap :: Element t => Matrix I -> Matrix I -> Matrix t -> Matrix t | ||
290 | remap i j m | ||
291 | | minElement i >= 0 && maxElement i < fromIntegral (rows m) && | ||
292 | minElement j >= 0 && maxElement j < fromIntegral (cols m) = remapM i' j' m | ||
293 | | otherwise = error $ "out of range index in remap" | ||
294 | where | ||
295 | [i',j'] = conformMs [i,j] | ||
296 | |||
299 | 297 | ||
diff --git a/packages/base/src/Numeric/Conversion.hs b/packages/base/src/Internal/Conversion.hs index a1f9385..4541ec4 100644 --- a/packages/base/src/Numeric/Conversion.hs +++ b/packages/base/src/Internal/Conversion.hs | |||
@@ -16,16 +16,16 @@ | |||
16 | -- Conversion routines | 16 | -- Conversion routines |
17 | -- | 17 | -- |
18 | ----------------------------------------------------------------------------- | 18 | ----------------------------------------------------------------------------- |
19 | {-# OPTIONS_HADDOCK hide #-} | ||
20 | 19 | ||
21 | 20 | ||
22 | module Numeric.Conversion ( | 21 | module Internal.Conversion ( |
23 | Complexable(..), RealElement, | 22 | Complexable(..), RealElement, |
24 | module Data.Complex | 23 | module Data.Complex |
25 | ) where | 24 | ) where |
26 | 25 | ||
27 | import Data.Packed.Internal.Vector | 26 | import Internal.Vector |
28 | import Data.Packed.Internal.Matrix | 27 | import Internal.Matrix |
28 | import Internal.Vectorized | ||
29 | import Data.Complex | 29 | import Data.Complex |
30 | import Control.Arrow((***)) | 30 | import Control.Arrow((***)) |
31 | 31 | ||
@@ -44,10 +44,13 @@ instance Precision (Complex Float) (Complex Double) where | |||
44 | double2FloatG = asComplex . double2FloatV . asReal | 44 | double2FloatG = asComplex . double2FloatV . asReal |
45 | float2DoubleG = asComplex . float2DoubleV . asReal | 45 | float2DoubleG = asComplex . float2DoubleV . asReal |
46 | 46 | ||
47 | instance Precision I Z where | ||
48 | double2FloatG = long2intV | ||
49 | float2DoubleG = int2longV | ||
50 | |||
51 | |||
47 | -- | Supported real types | 52 | -- | Supported real types |
48 | class (Element t, Element (Complex t), RealFloat t | 53 | class (Element t, Element (Complex t), RealFloat t) |
49 | -- , RealOf t ~ t, RealOf (Complex t) ~ t | ||
50 | ) | ||
51 | => RealElement t | 54 | => RealElement t |
52 | 55 | ||
53 | instance RealElement Double | 56 | instance RealElement Double |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/Convolution.hs b/packages/base/src/Internal/Convolution.hs index c9e75de..384fdf8 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util/Convolution.hs +++ b/packages/base/src/Internal/Convolution.hs | |||
@@ -1,7 +1,7 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | 1 | {-# LANGUAGE FlexibleContexts #-} |
2 | ----------------------------------------------------------------------------- | 2 | ----------------------------------------------------------------------------- |
3 | {- | | 3 | {- | |
4 | Module : Numeric.LinearAlgebra.Util.Convolution | 4 | Module : Internal.Convolution |
5 | Copyright : (c) Alberto Ruiz 2012 | 5 | Copyright : (c) Alberto Ruiz 2012 |
6 | License : BSD3 | 6 | License : BSD3 |
7 | Maintainer : Alberto Ruiz | 7 | Maintainer : Alberto Ruiz |
@@ -11,13 +11,18 @@ Stability : provisional | |||
11 | ----------------------------------------------------------------------------- | 11 | ----------------------------------------------------------------------------- |
12 | {-# OPTIONS_HADDOCK hide #-} | 12 | {-# OPTIONS_HADDOCK hide #-} |
13 | 13 | ||
14 | module Numeric.LinearAlgebra.Util.Convolution( | 14 | module Internal.Convolution( |
15 | corr, conv, corrMin, | 15 | corr, conv, corrMin, |
16 | corr2, conv2, separable | 16 | corr2, conv2, separable |
17 | ) where | 17 | ) where |
18 | 18 | ||
19 | import qualified Data.Vector.Storable as SV | 19 | import qualified Data.Vector.Storable as SV |
20 | import Data.Packed.Numeric | 20 | import Internal.Vector |
21 | import Internal.Matrix | ||
22 | import Internal.Numeric | ||
23 | import Internal.Element | ||
24 | import Internal.Conversion | ||
25 | import Internal.Container | ||
21 | 26 | ||
22 | 27 | ||
23 | vectSS :: Element t => Int -> Vector t -> Matrix t | 28 | vectSS :: Element t => Int -> Vector t -> Matrix t |
diff --git a/packages/base/src/Internal/Devel.hs b/packages/base/src/Internal/Devel.hs new file mode 100644 index 0000000..92b5604 --- /dev/null +++ b/packages/base/src/Internal/Devel.hs | |||
@@ -0,0 +1,95 @@ | |||
1 | {-# LANGUAGE TypeOperators #-} | ||
2 | {-# LANGUAGE TypeFamilies #-} | ||
3 | |||
4 | -- | | ||
5 | -- Module : Internal.Devel | ||
6 | -- Copyright : (c) Alberto Ruiz 2007-15 | ||
7 | -- License : BSD3 | ||
8 | -- Maintainer : Alberto Ruiz | ||
9 | -- Stability : provisional | ||
10 | -- | ||
11 | |||
12 | module Internal.Devel where | ||
13 | |||
14 | |||
15 | import Control.Monad ( when ) | ||
16 | import Foreign.C.Types ( CInt ) | ||
17 | --import Foreign.Storable.Complex () | ||
18 | import Foreign.Ptr(Ptr) | ||
19 | import Control.Exception as E ( SomeException, catch ) | ||
20 | import Internal.Vector(Vector,avec) | ||
21 | import Foreign.Storable(Storable) | ||
22 | |||
23 | -- | postfix function application (@flip ($)@) | ||
24 | (//) :: x -> (x -> y) -> y | ||
25 | infixl 0 // | ||
26 | (//) = flip ($) | ||
27 | |||
28 | |||
29 | -- GSL error codes are <= 1024 | ||
30 | -- | error codes for the auxiliary functions required by the wrappers | ||
31 | errorCode :: CInt -> String | ||
32 | errorCode 2000 = "bad size" | ||
33 | errorCode 2001 = "bad function code" | ||
34 | errorCode 2002 = "memory problem" | ||
35 | errorCode 2003 = "bad file" | ||
36 | errorCode 2004 = "singular" | ||
37 | errorCode 2005 = "didn't converge" | ||
38 | errorCode 2006 = "the input matrix is not positive definite" | ||
39 | errorCode 2007 = "not yet supported in this OS" | ||
40 | errorCode n = "code "++show n | ||
41 | |||
42 | |||
43 | -- | clear the fpu | ||
44 | foreign import ccall unsafe "asm_finit" finit :: IO () | ||
45 | |||
46 | -- | check the error code | ||
47 | check :: String -> IO CInt -> IO () | ||
48 | check msg f = do | ||
49 | -- finit | ||
50 | err <- f | ||
51 | when (err/=0) $ error (msg++": "++errorCode err) | ||
52 | return () | ||
53 | |||
54 | |||
55 | -- | postfix error code check | ||
56 | infixl 0 #| | ||
57 | (#|) = flip check | ||
58 | |||
59 | -- | Error capture and conversion to Maybe | ||
60 | mbCatch :: IO x -> IO (Maybe x) | ||
61 | mbCatch act = E.catch (Just `fmap` act) f | ||
62 | where f :: SomeException -> IO (Maybe x) | ||
63 | f _ = return Nothing | ||
64 | |||
65 | -------------------------------------------------------------------------------- | ||
66 | |||
67 | type CM b r = CInt -> CInt -> Ptr b -> r | ||
68 | type CV b r = CInt -> Ptr b -> r | ||
69 | type OM b r = CInt -> CInt -> CInt -> CInt -> Ptr b -> r | ||
70 | |||
71 | type CIdxs r = CV CInt r | ||
72 | type Ok = IO CInt | ||
73 | |||
74 | infixr 5 :>, ::>, ..> | ||
75 | type (:>) t r = CV t r | ||
76 | type (::>) t r = OM t r | ||
77 | type (..>) t r = CM t r | ||
78 | |||
79 | class TransArray c | ||
80 | where | ||
81 | type Trans c b | ||
82 | type TransRaw c b | ||
83 | apply :: (Trans c b) -> c -> b | ||
84 | applyRaw :: (TransRaw c b) -> c -> b | ||
85 | infixl 1 `apply`, `applyRaw` | ||
86 | |||
87 | instance Storable t => TransArray (Vector t) | ||
88 | where | ||
89 | type Trans (Vector t) b = CInt -> Ptr t -> b | ||
90 | type TransRaw (Vector t) b = CInt -> Ptr t -> b | ||
91 | apply = avec | ||
92 | {-# INLINE apply #-} | ||
93 | applyRaw = avec | ||
94 | {-# INLINE applyRaw #-} | ||
95 | |||
diff --git a/packages/base/src/Data/Packed/Matrix.hs b/packages/base/src/Internal/Element.hs index 70b9232..a459678 100644 --- a/packages/base/src/Data/Packed/Matrix.hs +++ b/packages/base/src/Internal/Element.hs | |||
@@ -18,35 +18,19 @@ | |||
18 | -- This module provides basic functions for manipulation of structure. | 18 | -- This module provides basic functions for manipulation of structure. |
19 | 19 | ||
20 | ----------------------------------------------------------------------------- | 20 | ----------------------------------------------------------------------------- |
21 | {-# OPTIONS_HADDOCK hide #-} | ||
22 | |||
23 | module Data.Packed.Matrix ( | ||
24 | Matrix, | ||
25 | Element, | ||
26 | rows,cols, | ||
27 | (><), | ||
28 | trans, | ||
29 | reshape, flatten, | ||
30 | fromLists, toLists, buildMatrix, | ||
31 | (@@>), | ||
32 | asRow, asColumn, | ||
33 | fromRows, toRows, fromColumns, toColumns, | ||
34 | fromBlocks, diagBlock, toBlocks, toBlocksEvery, | ||
35 | repmat, | ||
36 | flipud, fliprl, | ||
37 | subMatrix, takeRows, dropRows, takeColumns, dropColumns, | ||
38 | extractRows, extractColumns, | ||
39 | diagRect, takeDiag, | ||
40 | mapMatrix, mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_, | ||
41 | liftMatrix, liftMatrix2, liftMatrix2Auto,fromArray2D | ||
42 | ) where | ||
43 | |||
44 | import Data.Packed.Internal | ||
45 | import qualified Data.Packed.ST as ST | ||
46 | import Data.Array | ||
47 | 21 | ||
22 | module Internal.Element where | ||
23 | |||
24 | import Internal.Vector | ||
25 | import Internal.Matrix | ||
26 | import Internal.Vectorized | ||
27 | import qualified Internal.ST as ST | ||
28 | import Data.Array | ||
29 | import Text.Printf | ||
48 | import Data.List(transpose,intersperse) | 30 | import Data.List(transpose,intersperse) |
31 | import Data.List.Split(chunksOf) | ||
49 | import Foreign.Storable(Storable) | 32 | import Foreign.Storable(Storable) |
33 | import System.IO.Unsafe(unsafePerformIO) | ||
50 | import Control.Monad(liftM) | 34 | import Control.Monad(liftM) |
51 | 35 | ||
52 | ------------------------------------------------------------------- | 36 | ------------------------------------------------------------------- |
@@ -95,7 +79,126 @@ instance (Element a, Read a) => Read (Matrix a) where | |||
95 | breakAt c l = (a++[c],tail b) where | 79 | breakAt c l = (a++[c],tail b) where |
96 | (a,b) = break (==c) l | 80 | (a,b) = break (==c) l |
97 | 81 | ||
98 | ------------------------------------------------------------------ | 82 | -------------------------------------------------------------------------------- |
83 | -- | Specification of indexes for the operator '??'. | ||
84 | data Extractor | ||
85 | = All | ||
86 | | Range Int Int Int | ||
87 | | Pos (Vector I) | ||
88 | | PosCyc (Vector I) | ||
89 | | Take Int | ||
90 | | TakeLast Int | ||
91 | | Drop Int | ||
92 | | DropLast Int | ||
93 | deriving Show | ||
94 | |||
95 | ppext All = ":" | ||
96 | ppext (Range a 1 c) = printf "%d:%d" a c | ||
97 | ppext (Range a b c) = printf "%d:%d:%d" a b c | ||
98 | ppext (Pos v) = show (toList v) | ||
99 | ppext (PosCyc v) = "Cyclic"++show (toList v) | ||
100 | ppext (Take n) = printf "Take %d" n | ||
101 | ppext (Drop n) = printf "Drop %d" n | ||
102 | ppext (TakeLast n) = printf "TakeLast %d" n | ||
103 | ppext (DropLast n) = printf "DropLast %d" n | ||
104 | |||
105 | {- | General matrix slicing. | ||
106 | |||
107 | >>> m | ||
108 | (4><5) | ||
109 | [ 0, 1, 2, 3, 4 | ||
110 | , 5, 6, 7, 8, 9 | ||
111 | , 10, 11, 12, 13, 14 | ||
112 | , 15, 16, 17, 18, 19 ] | ||
113 | |||
114 | >>> m ?? (Take 3, DropLast 2) | ||
115 | (3><3) | ||
116 | [ 0, 1, 2 | ||
117 | , 5, 6, 7 | ||
118 | , 10, 11, 12 ] | ||
119 | |||
120 | >>> m ?? (Pos (idxs[2,1]), All) | ||
121 | (2><5) | ||
122 | [ 10, 11, 12, 13, 14 | ||
123 | , 5, 6, 7, 8, 9 ] | ||
124 | |||
125 | >>> m ?? (PosCyc (idxs[-7,80]), Range 4 (-2) 0) | ||
126 | (2><3) | ||
127 | [ 9, 7, 5 | ||
128 | , 4, 2, 0 ] | ||
129 | |||
130 | -} | ||
131 | infixl 9 ?? | ||
132 | (??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t | ||
133 | |||
134 | minEl = toScalarI Min | ||
135 | maxEl = toScalarI Max | ||
136 | cmodi = vectorMapValI ModVS | ||
137 | |||
138 | extractError m (e1,e2)= error $ printf "can't extract (%s,%s) from matrix %dx%d" (ppext e1::String) (ppext e2::String) (rows m) (cols m) | ||
139 | |||
140 | m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e) | ||
141 | m ?? (e,Range a s b) | s /= 1 = m ?? (e, Pos (idxs [a,a+s .. b])) | ||
142 | |||
143 | m ?? e@(Range a _ b,_) | a < 0 || b >= rows m = extractError m e | ||
144 | m ?? e@(_,Range a _ b) | a < 0 || b >= cols m = extractError m e | ||
145 | |||
146 | m ?? e@(Pos vs,_) | dim vs>0 && (minEl vs < 0 || maxEl vs >= fi (rows m)) = extractError m e | ||
147 | m ?? e@(_,Pos vs) | dim vs>0 && (minEl vs < 0 || maxEl vs >= fi (cols m)) = extractError m e | ||
148 | |||
149 | m ?? (All,All) = m | ||
150 | |||
151 | m ?? (Range a _ b,e) | a > b = m ?? (Take 0,e) | ||
152 | m ?? (e,Range a _ b) | a > b = m ?? (e,Take 0) | ||
153 | |||
154 | m ?? (Take n,e) | ||
155 | | n <= 0 = (0><cols m) [] ?? (All,e) | ||
156 | | n >= rows m = m ?? (All,e) | ||
157 | |||
158 | m ?? (e,Take n) | ||
159 | | n <= 0 = (rows m><0) [] ?? (e,All) | ||
160 | | n >= cols m = m ?? (e,All) | ||
161 | |||
162 | m ?? (Drop n,e) | ||
163 | | n <= 0 = m ?? (All,e) | ||
164 | | n >= rows m = (0><cols m) [] ?? (All,e) | ||
165 | |||
166 | m ?? (e,Drop n) | ||
167 | | n <= 0 = m ?? (e,All) | ||
168 | | n >= cols m = (rows m><0) [] ?? (e,All) | ||
169 | |||
170 | m ?? (TakeLast n, e) = m ?? (Drop (rows m - n), e) | ||
171 | m ?? (e, TakeLast n) = m ?? (e, Drop (cols m - n)) | ||
172 | |||
173 | m ?? (DropLast n, e) = m ?? (Take (rows m - n), e) | ||
174 | m ?? (e, DropLast n) = m ?? (e, Take (cols m - n)) | ||
175 | |||
176 | m ?? (er,ec) = unsafePerformIO $ extractR (orderOf m) m moder rs modec cs | ||
177 | where | ||
178 | (moder,rs) = mkExt (rows m) er | ||
179 | (modec,cs) = mkExt (cols m) ec | ||
180 | ran a b = (0, idxs [a,b]) | ||
181 | pos ks = (1, ks) | ||
182 | mkExt _ (Pos ks) = pos ks | ||
183 | mkExt n (PosCyc ks) | ||
184 | | n == 0 = mkExt n (Take 0) | ||
185 | | otherwise = pos (cmodi (fi n) ks) | ||
186 | mkExt _ (Range mn _ mx) = ran mn mx | ||
187 | mkExt _ (Take k) = ran 0 (k-1) | ||
188 | mkExt n (Drop k) = ran k (n-1) | ||
189 | mkExt n _ = ran 0 (n-1) -- All | ||
190 | |||
191 | -------------------------------------------------------------------------------- | ||
192 | |||
193 | -- | obtains the common value of a property of a list | ||
194 | common :: (Eq a) => (b->a) -> [b] -> Maybe a | ||
195 | common f = commonval . map f | ||
196 | where | ||
197 | commonval :: (Eq a) => [a] -> Maybe a | ||
198 | commonval [] = Nothing | ||
199 | commonval [a] = Just a | ||
200 | commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing | ||
201 | |||
99 | 202 | ||
100 | -- | creates a matrix from a vertical list of matrices | 203 | -- | creates a matrix from a vertical list of matrices |
101 | joinVert :: Element t => [Matrix t] -> Matrix t | 204 | joinVert :: Element t => [Matrix t] -> Matrix t |
@@ -141,7 +244,7 @@ adaptBlocks ms = ms' where | |||
141 | rs = map (compatdim . map rows) ms | 244 | rs = map (compatdim . map rows) ms |
142 | cs = map (compatdim . map cols) (transpose ms) | 245 | cs = map (compatdim . map cols) (transpose ms) |
143 | szs = sequence [rs,cs] | 246 | szs = sequence [rs,cs] |
144 | ms' = splitEvery bc $ zipWith g szs (concat ms) | 247 | ms' = chunksOf bc $ zipWith g szs (concat ms) |
145 | 248 | ||
146 | g [Just nr,Just nc] m | 249 | g [Just nr,Just nc] m |
147 | | nr == r && nc == c = m | 250 | | nr == r && nc == c = m |
@@ -218,13 +321,13 @@ diagRect z v r c = ST.runSTMatrix $ do | |||
218 | 321 | ||
219 | -- | extracts the diagonal from a rectangular matrix | 322 | -- | extracts the diagonal from a rectangular matrix |
220 | takeDiag :: (Element t) => Matrix t -> Vector t | 323 | takeDiag :: (Element t) => Matrix t -> Vector t |
221 | takeDiag m = fromList [flatten m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] | 324 | takeDiag m = fromList [flatten m @> (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] |
222 | 325 | ||
223 | ------------------------------------------------------------ | 326 | ------------------------------------------------------------ |
224 | 327 | ||
225 | {- | create a general matrix | 328 | {- | Create a matrix from a list of elements |
226 | 329 | ||
227 | >>> (2><3) [2, 4, 7+2*𝑖, -3, 11, 0] | 330 | >>> (2><3) [2, 4, 7+2*iC, -3, 11, 0] |
228 | (2><3) | 331 | (2><3) |
229 | [ 2.0 :+ 0.0, 4.0 :+ 0.0, 7.0 :+ 2.0 | 332 | [ 2.0 :+ 0.0, 4.0 :+ 0.0, 7.0 :+ 2.0 |
230 | , (-3.0) :+ (-0.0), 11.0 :+ 0.0, 0.0 :+ 0.0 ] | 333 | , (-3.0) :+ (-0.0), 11.0 :+ 0.0, 0.0 :+ 0.0 ] |
@@ -250,19 +353,34 @@ r >< c = f where | |||
250 | 353 | ||
251 | ---------------------------------------------------------------- | 354 | ---------------------------------------------------------------- |
252 | 355 | ||
253 | -- | Creates a matrix with the first n rows of another matrix | ||
254 | takeRows :: Element t => Int -> Matrix t -> Matrix t | 356 | takeRows :: Element t => Int -> Matrix t -> Matrix t |
255 | takeRows n mt = subMatrix (0,0) (n, cols mt) mt | 357 | takeRows n mt = subMatrix (0,0) (n, cols mt) mt |
256 | -- | Creates a copy of a matrix without the first n rows | 358 | |
359 | -- | Creates a matrix with the last n rows of another matrix | ||
360 | takeLastRows :: Element t => Int -> Matrix t -> Matrix t | ||
361 | takeLastRows n mt = subMatrix (rows mt - n, 0) (n, cols mt) mt | ||
362 | |||
257 | dropRows :: Element t => Int -> Matrix t -> Matrix t | 363 | dropRows :: Element t => Int -> Matrix t -> Matrix t |
258 | dropRows n mt = subMatrix (n,0) (rows mt - n, cols mt) mt | 364 | dropRows n mt = subMatrix (n,0) (rows mt - n, cols mt) mt |
259 | -- |Creates a matrix with the first n columns of another matrix | 365 | |
366 | -- | Creates a copy of a matrix without the last n rows | ||
367 | dropLastRows :: Element t => Int -> Matrix t -> Matrix t | ||
368 | dropLastRows n mt = subMatrix (0,0) (rows mt - n, cols mt) mt | ||
369 | |||
260 | takeColumns :: Element t => Int -> Matrix t -> Matrix t | 370 | takeColumns :: Element t => Int -> Matrix t -> Matrix t |
261 | takeColumns n mt = subMatrix (0,0) (rows mt, n) mt | 371 | takeColumns n mt = subMatrix (0,0) (rows mt, n) mt |
262 | -- | Creates a copy of a matrix without the first n columns | 372 | |
373 | -- |Creates a matrix with the last n columns of another matrix | ||
374 | takeLastColumns :: Element t => Int -> Matrix t -> Matrix t | ||
375 | takeLastColumns n mt = subMatrix (0, cols mt - n) (rows mt, n) mt | ||
376 | |||
263 | dropColumns :: Element t => Int -> Matrix t -> Matrix t | 377 | dropColumns :: Element t => Int -> Matrix t -> Matrix t |
264 | dropColumns n mt = subMatrix (0,n) (rows mt, cols mt - n) mt | 378 | dropColumns n mt = subMatrix (0,n) (rows mt, cols mt - n) mt |
265 | 379 | ||
380 | -- | Creates a copy of a matrix without the last n columns | ||
381 | dropLastColumns :: Element t => Int -> Matrix t -> Matrix t | ||
382 | dropLastColumns n mt = subMatrix (0,0) (rows mt, cols mt - n) mt | ||
383 | |||
266 | ---------------------------------------------------------------- | 384 | ---------------------------------------------------------------- |
267 | 385 | ||
268 | {- | Creates a 'Matrix' from a list of lists (considered as rows). | 386 | {- | Creates a 'Matrix' from a list of lists (considered as rows). |
@@ -331,24 +449,11 @@ fromArray2D m = (r><c) (elems m) | |||
331 | 449 | ||
332 | -- | rearranges the rows of a matrix according to the order given in a list of integers. | 450 | -- | rearranges the rows of a matrix according to the order given in a list of integers. |
333 | extractRows :: Element t => [Int] -> Matrix t -> Matrix t | 451 | extractRows :: Element t => [Int] -> Matrix t -> Matrix t |
334 | extractRows [] m = emptyM 0 (cols m) | 452 | extractRows l m = m ?? (Pos (idxs l), All) |
335 | extractRows l m = fromRows $ extract (toRows m) l | ||
336 | where | ||
337 | extract l' is = [l'!!i | i<- map verify is] | ||
338 | verify k | ||
339 | | k >= 0 && k < rows m = k | ||
340 | | otherwise = error $ "can't extract row " | ||
341 | ++show k++" in list " ++ show l ++ " from matrix " ++ shSize m | ||
342 | 453 | ||
343 | -- | rearranges the rows of a matrix according to the order given in a list of integers. | 454 | -- | rearranges the rows of a matrix according to the order given in a list of integers. |
344 | extractColumns :: Element t => [Int] -> Matrix t -> Matrix t | 455 | extractColumns :: Element t => [Int] -> Matrix t -> Matrix t |
345 | extractColumns l m = trans . extractRows (map verify l) . trans $ m | 456 | extractColumns l m = m ?? (All, Pos (idxs l)) |
346 | where | ||
347 | verify k | ||
348 | | k >= 0 && k < cols m = k | ||
349 | | otherwise = error $ "can't extract column " | ||
350 | ++show k++" in list " ++ show l ++ " from matrix " ++ shSize m | ||
351 | |||
352 | 457 | ||
353 | 458 | ||
354 | {- | creates matrix by repetition of a matrix a given number of rows and columns | 459 | {- | creates matrix by repetition of a matrix a given number of rows and columns |
@@ -386,9 +491,13 @@ liftMatrix2Auto f m1 m2 | |||
386 | -- FIXME do not flatten if equal order | 491 | -- FIXME do not flatten if equal order |
387 | lM f m1 m2 = matrixFromVector | 492 | lM f m1 m2 = matrixFromVector |
388 | RowMajor | 493 | RowMajor |
389 | (max (rows m1) (rows m2)) | 494 | (max' (rows m1) (rows m2)) |
390 | (max (cols m1) (cols m2)) | 495 | (max' (cols m1) (cols m2)) |
391 | (f (flatten m1) (flatten m2)) | 496 | (f (flatten m1) (flatten m2)) |
497 | where | ||
498 | max' 1 b = b | ||
499 | max' a 1 = a | ||
500 | max' a b = max a b | ||
392 | 501 | ||
393 | compat' :: Matrix a -> Matrix b -> Bool | 502 | compat' :: Matrix a -> Matrix b -> Bool |
394 | compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 | 503 | compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 |
@@ -490,5 +599,6 @@ mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m | |||
490 | where | 599 | where |
491 | c = cols m | 600 | c = cols m |
492 | 601 | ||
493 | mapMatrix :: (Storable a, Storable b) => (a -> b) -> Matrix a -> Matrix b | 602 | mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b |
494 | mapMatrix f = liftMatrix (mapVector f) | 603 | mapMatrix f = liftMatrix (mapVector f) |
604 | |||
diff --git a/packages/base/src/Data/Packed/Foreign.hs b/packages/base/src/Internal/Foreign.hs index 1ec3694..ea071a4 100644 --- a/packages/base/src/Data/Packed/Foreign.hs +++ b/packages/base/src/Internal/Foreign.hs | |||
@@ -6,17 +6,19 @@ | |||
6 | -- @ glUniformMatrix4fv 0 1 (fromIntegral gl_TRUE) \`appMatrix\` perspective 0.01 100 (pi\/2) (4\/3) | 6 | -- @ glUniformMatrix4fv 0 1 (fromIntegral gl_TRUE) \`appMatrix\` perspective 0.01 100 (pi\/2) (4\/3) |
7 | -- @ | 7 | -- @ |
8 | -- | 8 | -- |
9 | {-# OPTIONS_HADDOCK hide #-} | 9 | |
10 | module Data.Packed.Foreign | 10 | module Internal.Foreign |
11 | ( app | 11 | ( app |
12 | , appVector, appVectorLen | 12 | , appVector, appVectorLen |
13 | , appMatrix, appMatrixLen, appMatrixRaw, appMatrixRawLen | 13 | , appMatrix, appMatrixLen, appMatrixRaw, appMatrixRawLen |
14 | , unsafeMatrixToVector, unsafeMatrixToForeignPtr | 14 | , unsafeMatrixToVector, unsafeMatrixToForeignPtr |
15 | ) where | 15 | ) where |
16 | import Data.Packed.Internal | 16 | |
17 | import Foreign.C.Types(CInt) | ||
18 | import Internal.Vector | ||
19 | import Internal.Matrix | ||
17 | import qualified Data.Vector.Storable as S | 20 | import qualified Data.Vector.Storable as S |
18 | import Foreign (Ptr, ForeignPtr, Storable) | 21 | import Foreign (Ptr, ForeignPtr, Storable) |
19 | import Foreign.C.Types (CInt) | ||
20 | import GHC.Base (IO(..), realWorld#) | 22 | import GHC.Base (IO(..), realWorld#) |
21 | 23 | ||
22 | {-# INLINE unsafeInlinePerformIO #-} | 24 | {-# INLINE unsafeInlinePerformIO #-} |
diff --git a/packages/base/src/Data/Packed/IO.hs b/packages/base/src/Internal/IO.hs index 85f1b37..a899cfd 100644 --- a/packages/base/src/Data/Packed/IO.hs +++ b/packages/base/src/Internal/IO.hs | |||
@@ -1,6 +1,6 @@ | |||
1 | ----------------------------------------------------------------------------- | 1 | ----------------------------------------------------------------------------- |
2 | -- | | 2 | -- | |
3 | -- Module : Data.Packed.IO | 3 | -- Module : Internal.IO |
4 | -- Copyright : (c) Alberto Ruiz 2010 | 4 | -- Copyright : (c) Alberto Ruiz 2010 |
5 | -- License : BSD3 | 5 | -- License : BSD3 |
6 | -- | 6 | -- |
@@ -10,20 +10,32 @@ | |||
10 | -- Display, formatting and IO functions for numeric 'Vector' and 'Matrix' | 10 | -- Display, formatting and IO functions for numeric 'Vector' and 'Matrix' |
11 | -- | 11 | -- |
12 | ----------------------------------------------------------------------------- | 12 | ----------------------------------------------------------------------------- |
13 | {-# OPTIONS_HADDOCK hide #-} | ||
14 | 13 | ||
15 | module Data.Packed.IO ( | 14 | module Internal.IO ( |
16 | dispf, disps, dispcf, vecdisp, latexFormat, format, | 15 | dispf, disps, dispcf, vecdisp, latexFormat, format, |
17 | readMatrix, fromArray2D, loadMatrix, loadMatrix', saveMatrix | 16 | loadMatrix, loadMatrix', saveMatrix |
18 | ) where | 17 | ) where |
19 | 18 | ||
20 | import Data.Packed | 19 | import Internal.Devel |
20 | import Internal.Vector | ||
21 | import Internal.Matrix | ||
22 | import Internal.Vectorized | ||
21 | import Text.Printf(printf) | 23 | import Text.Printf(printf) |
22 | import Data.List(intersperse) | 24 | import Data.List(intersperse,transpose) |
23 | import Data.Complex | 25 | import Data.Complex |
24 | import Numeric.Vectorized(vectorScan,saveMatrix) | 26 | |
25 | import Control.Applicative((<$>)) | 27 | |
26 | import Data.Packed.Internal | 28 | -- | Formatting tool |
29 | table :: String -> [[String]] -> String | ||
30 | table sep as = unlines . map unwords' $ transpose mtp | ||
31 | where | ||
32 | mt = transpose as | ||
33 | longs = map (maximum . map length) mt | ||
34 | mtp = zipWith (\a b -> map (pad a) b) longs mt | ||
35 | pad n str = replicate (n - length str) ' ' ++ str | ||
36 | unwords' = concat . intersperse sep | ||
37 | |||
38 | |||
27 | 39 | ||
28 | {- | Creates a string from a matrix given a separator and a function to show each entry. Using | 40 | {- | Creates a string from a matrix given a separator and a function to show each entry. Using |
29 | this function the user can easily define any desired display function: | 41 | this function the user can easily define any desired display function: |
@@ -137,12 +149,6 @@ dispcf d m = sdims m ++ "\n" ++ format " " (showComplex d) m | |||
137 | 149 | ||
138 | -------------------------------------------------------------------- | 150 | -------------------------------------------------------------------- |
139 | 151 | ||
140 | -- | reads a matrix from a string containing a table of numbers. | ||
141 | readMatrix :: String -> Matrix Double | ||
142 | readMatrix = fromLists . map (map read). map words . filter (not.null) . lines | ||
143 | |||
144 | -------------------------------------------------------------------------------- | ||
145 | |||
146 | apparentCols :: FilePath -> IO Int | 152 | apparentCols :: FilePath -> IO Int |
147 | apparentCols s = f . dropWhile null . map words . lines <$> readFile s | 153 | apparentCols s = f . dropWhile null . map words . lines <$> readFile s |
148 | where | 154 | where |
diff --git a/packages/base/src/Numeric/LinearAlgebra/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index e088fdc..c2c140b 100644 --- a/packages/base/src/Numeric/LinearAlgebra/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -1,3 +1,6 @@ | |||
1 | {-# LANGUAGE TypeOperators #-} | ||
2 | {-# LANGUAGE ViewPatterns #-} | ||
3 | |||
1 | ----------------------------------------------------------------------------- | 4 | ----------------------------------------------------------------------------- |
2 | -- | | 5 | -- | |
3 | -- Module : Numeric.LinearAlgebra.LAPACK | 6 | -- Module : Numeric.LinearAlgebra.LAPACK |
@@ -9,44 +12,15 @@ | |||
9 | -- Functional interface to selected LAPACK functions (<http://www.netlib.org/lapack>). | 12 | -- Functional interface to selected LAPACK functions (<http://www.netlib.org/lapack>). |
10 | -- | 13 | -- |
11 | ----------------------------------------------------------------------------- | 14 | ----------------------------------------------------------------------------- |
12 | {-# OPTIONS_HADDOCK hide #-} | ||
13 | |||
14 | |||
15 | module Numeric.LinearAlgebra.LAPACK ( | ||
16 | -- * Matrix product | ||
17 | multiplyR, multiplyC, multiplyF, multiplyQ, | ||
18 | -- * Linear systems | ||
19 | linearSolveR, linearSolveC, | ||
20 | mbLinearSolveR, mbLinearSolveC, | ||
21 | lusR, lusC, | ||
22 | cholSolveR, cholSolveC, | ||
23 | linearSolveLSR, linearSolveLSC, | ||
24 | linearSolveSVDR, linearSolveSVDC, | ||
25 | -- * SVD | ||
26 | svR, svRd, svC, svCd, | ||
27 | svdR, svdRd, svdC, svdCd, | ||
28 | thinSVDR, thinSVDRd, thinSVDC, thinSVDCd, | ||
29 | rightSVR, rightSVC, leftSVR, leftSVC, | ||
30 | -- * Eigensystems | ||
31 | eigR, eigC, eigS, eigS', eigH, eigH', | ||
32 | eigOnlyR, eigOnlyC, eigOnlyS, eigOnlyH, | ||
33 | -- * LU | ||
34 | luR, luC, | ||
35 | -- * Cholesky | ||
36 | cholS, cholH, mbCholS, mbCholH, | ||
37 | -- * QR | ||
38 | qrR, qrC, qrgrR, qrgrC, | ||
39 | -- * Hessenberg | ||
40 | hessR, hessC, | ||
41 | -- * Schur | ||
42 | schurR, schurC | ||
43 | ) where | ||
44 | |||
45 | import Data.Packed.Development | ||
46 | import Data.Packed | ||
47 | import Data.Packed.Internal | ||
48 | import Numeric.Conversion | ||
49 | 15 | ||
16 | |||
17 | module Internal.LAPACK where | ||
18 | |||
19 | import Internal.Devel | ||
20 | import Internal.Vector | ||
21 | import Internal.Matrix hiding ((#)) | ||
22 | import Internal.Conversion | ||
23 | import Internal.Element | ||
50 | import Foreign.Ptr(nullPtr) | 24 | import Foreign.Ptr(nullPtr) |
51 | import Foreign.C.Types | 25 | import Foreign.C.Types |
52 | import Control.Monad(when) | 26 | import Control.Monad(when) |
@@ -54,22 +28,35 @@ import System.IO.Unsafe(unsafePerformIO) | |||
54 | 28 | ||
55 | ----------------------------------------------------------------------------------- | 29 | ----------------------------------------------------------------------------------- |
56 | 30 | ||
57 | foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM | 31 | infixl 1 # |
58 | foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM | 32 | a # b = apply a b |
59 | foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TFMFMFM | 33 | {-# INLINE (#) #-} |
60 | foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TQMQMQM | 34 | |
35 | ----------------------------------------------------------------------------------- | ||
36 | |||
37 | type TMMM t = t ::> t ::> t ::> Ok | ||
38 | |||
39 | type F = Float | ||
40 | type Q = Complex Float | ||
41 | |||
42 | foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM R | ||
43 | foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C | ||
44 | foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F | ||
45 | foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q | ||
46 | foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> TMMM I | ||
47 | foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> TMMM Z | ||
61 | 48 | ||
62 | isT Matrix{order = ColumnMajor} = 0 | 49 | isT (rowOrder -> False) = 0 |
63 | isT Matrix{order = RowMajor} = 1 | 50 | isT _ = 1 |
64 | 51 | ||
65 | tt x@Matrix{order = ColumnMajor} = x | 52 | tt x@(rowOrder -> False) = x |
66 | tt x@Matrix{order = RowMajor} = trans x | 53 | tt x = trans x |
67 | 54 | ||
68 | multiplyAux f st a b = unsafePerformIO $ do | 55 | multiplyAux f st a b = unsafePerformIO $ do |
69 | when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ | 56 | when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ |
70 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | 57 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) |
71 | s <- createMatrix ColumnMajor (rows a) (cols b) | 58 | s <- createMatrix ColumnMajor (rows a) (cols b) |
72 | app3 (f (isT a) (isT b)) mat (tt a) mat (tt b) mat s st | 59 | f (isT a) (isT b) # (tt a) # (tt b) # s #| st |
73 | return s | 60 | return s |
74 | 61 | ||
75 | -- | Matrix product based on BLAS's /dgemm/. | 62 | -- | Matrix product based on BLAS's /dgemm/. |
@@ -88,178 +75,213 @@ multiplyF a b = multiplyAux sgemmc "sgemmc" a b | |||
88 | multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float) | 75 | multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float) |
89 | multiplyQ a b = multiplyAux cgemmc "cgemmc" a b | 76 | multiplyQ a b = multiplyAux cgemmc "cgemmc" a b |
90 | 77 | ||
78 | multiplyI :: I -> Matrix CInt -> Matrix CInt -> Matrix CInt | ||
79 | multiplyI m a b = unsafePerformIO $ do | ||
80 | when (cols a /= rows b) $ error $ | ||
81 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b | ||
82 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
83 | c_multiplyI m # a # b # s #|"c_multiplyI" | ||
84 | return s | ||
85 | |||
86 | multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z | ||
87 | multiplyL m a b = unsafePerformIO $ do | ||
88 | when (cols a /= rows b) $ error $ | ||
89 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b | ||
90 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
91 | c_multiplyL m # a # b # s #|"c_multiplyL" | ||
92 | return s | ||
93 | |||
91 | ----------------------------------------------------------------------------- | 94 | ----------------------------------------------------------------------------- |
92 | foreign import ccall unsafe "svd_l_R" dgesvd :: TMMVM | 95 | |
93 | foreign import ccall unsafe "svd_l_C" zgesvd :: TCMCMVCM | 96 | type TSVD t = t ::> t ::> R :> t ::> Ok |
94 | foreign import ccall unsafe "svd_l_Rdd" dgesdd :: TMMVM | 97 | |
95 | foreign import ccall unsafe "svd_l_Cdd" zgesdd :: TCMCMVCM | 98 | foreign import ccall unsafe "svd_l_R" dgesvd :: TSVD R |
99 | foreign import ccall unsafe "svd_l_C" zgesvd :: TSVD C | ||
100 | foreign import ccall unsafe "svd_l_Rdd" dgesdd :: TSVD R | ||
101 | foreign import ccall unsafe "svd_l_Cdd" zgesdd :: TSVD C | ||
96 | 102 | ||
97 | -- | Full SVD of a real matrix using LAPACK's /dgesvd/. | 103 | -- | Full SVD of a real matrix using LAPACK's /dgesvd/. |
98 | svdR :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) | 104 | svdR :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) |
99 | svdR = svdAux dgesvd "svdR" . fmat | 105 | svdR = svdAux dgesvd "svdR" |
100 | 106 | ||
101 | -- | Full SVD of a real matrix using LAPACK's /dgesdd/. | 107 | -- | Full SVD of a real matrix using LAPACK's /dgesdd/. |
102 | svdRd :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) | 108 | svdRd :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) |
103 | svdRd = svdAux dgesdd "svdRdd" . fmat | 109 | svdRd = svdAux dgesdd "svdRdd" |
104 | 110 | ||
105 | -- | Full SVD of a complex matrix using LAPACK's /zgesvd/. | 111 | -- | Full SVD of a complex matrix using LAPACK's /zgesvd/. |
106 | svdC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double)) | 112 | svdC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double)) |
107 | svdC = svdAux zgesvd "svdC" . fmat | 113 | svdC = svdAux zgesvd "svdC" |
108 | 114 | ||
109 | -- | Full SVD of a complex matrix using LAPACK's /zgesdd/. | 115 | -- | Full SVD of a complex matrix using LAPACK's /zgesdd/. |
110 | svdCd :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double)) | 116 | svdCd :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double)) |
111 | svdCd = svdAux zgesdd "svdCdd" . fmat | 117 | svdCd = svdAux zgesdd "svdCdd" |
112 | 118 | ||
113 | svdAux f st x = unsafePerformIO $ do | 119 | svdAux f st x = unsafePerformIO $ do |
120 | a <- copy ColumnMajor x | ||
114 | u <- createMatrix ColumnMajor r r | 121 | u <- createMatrix ColumnMajor r r |
115 | s <- createVector (min r c) | 122 | s <- createVector (min r c) |
116 | v <- createMatrix ColumnMajor c c | 123 | v <- createMatrix ColumnMajor c c |
117 | app4 f mat x mat u vec s mat v st | 124 | f # a # u # s # v #| st |
118 | return (u,s,trans v) | 125 | return (u,s,v) |
119 | where r = rows x | 126 | where |
120 | c = cols x | 127 | r = rows x |
128 | c = cols x | ||
121 | 129 | ||
122 | 130 | ||
123 | -- | Thin SVD of a real matrix, using LAPACK's /dgesvd/ with jobu == jobvt == \'S\'. | 131 | -- | Thin SVD of a real matrix, using LAPACK's /dgesvd/ with jobu == jobvt == \'S\'. |
124 | thinSVDR :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) | 132 | thinSVDR :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) |
125 | thinSVDR = thinSVDAux dgesvd "thinSVDR" . fmat | 133 | thinSVDR = thinSVDAux dgesvd "thinSVDR" |
126 | 134 | ||
127 | -- | Thin SVD of a complex matrix, using LAPACK's /zgesvd/ with jobu == jobvt == \'S\'. | 135 | -- | Thin SVD of a complex matrix, using LAPACK's /zgesvd/ with jobu == jobvt == \'S\'. |
128 | thinSVDC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double)) | 136 | thinSVDC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double)) |
129 | thinSVDC = thinSVDAux zgesvd "thinSVDC" . fmat | 137 | thinSVDC = thinSVDAux zgesvd "thinSVDC" |
130 | 138 | ||
131 | -- | Thin SVD of a real matrix, using LAPACK's /dgesdd/ with jobz == \'S\'. | 139 | -- | Thin SVD of a real matrix, using LAPACK's /dgesdd/ with jobz == \'S\'. |
132 | thinSVDRd :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) | 140 | thinSVDRd :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) |
133 | thinSVDRd = thinSVDAux dgesdd "thinSVDRdd" . fmat | 141 | thinSVDRd = thinSVDAux dgesdd "thinSVDRdd" |
134 | 142 | ||
135 | -- | Thin SVD of a complex matrix, using LAPACK's /zgesdd/ with jobz == \'S\'. | 143 | -- | Thin SVD of a complex matrix, using LAPACK's /zgesdd/ with jobz == \'S\'. |
136 | thinSVDCd :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double)) | 144 | thinSVDCd :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double)) |
137 | thinSVDCd = thinSVDAux zgesdd "thinSVDCdd" . fmat | 145 | thinSVDCd = thinSVDAux zgesdd "thinSVDCdd" |
138 | 146 | ||
139 | thinSVDAux f st x = unsafePerformIO $ do | 147 | thinSVDAux f st x = unsafePerformIO $ do |
148 | a <- copy ColumnMajor x | ||
140 | u <- createMatrix ColumnMajor r q | 149 | u <- createMatrix ColumnMajor r q |
141 | s <- createVector q | 150 | s <- createVector q |
142 | v <- createMatrix ColumnMajor q c | 151 | v <- createMatrix ColumnMajor q c |
143 | app4 f mat x mat u vec s mat v st | 152 | f # a # u # s # v #| st |
144 | return (u,s,trans v) | 153 | return (u,s,v) |
145 | where r = rows x | 154 | where |
146 | c = cols x | 155 | r = rows x |
147 | q = min r c | 156 | c = cols x |
157 | q = min r c | ||
148 | 158 | ||
149 | 159 | ||
150 | -- | Singular values of a real matrix, using LAPACK's /dgesvd/ with jobu == jobvt == \'N\'. | 160 | -- | Singular values of a real matrix, using LAPACK's /dgesvd/ with jobu == jobvt == \'N\'. |
151 | svR :: Matrix Double -> Vector Double | 161 | svR :: Matrix Double -> Vector Double |
152 | svR = svAux dgesvd "svR" . fmat | 162 | svR = svAux dgesvd "svR" |
153 | 163 | ||
154 | -- | Singular values of a complex matrix, using LAPACK's /zgesvd/ with jobu == jobvt == \'N\'. | 164 | -- | Singular values of a complex matrix, using LAPACK's /zgesvd/ with jobu == jobvt == \'N\'. |
155 | svC :: Matrix (Complex Double) -> Vector Double | 165 | svC :: Matrix (Complex Double) -> Vector Double |
156 | svC = svAux zgesvd "svC" . fmat | 166 | svC = svAux zgesvd "svC" |
157 | 167 | ||
158 | -- | Singular values of a real matrix, using LAPACK's /dgesdd/ with jobz == \'N\'. | 168 | -- | Singular values of a real matrix, using LAPACK's /dgesdd/ with jobz == \'N\'. |
159 | svRd :: Matrix Double -> Vector Double | 169 | svRd :: Matrix Double -> Vector Double |
160 | svRd = svAux dgesdd "svRd" . fmat | 170 | svRd = svAux dgesdd "svRd" |
161 | 171 | ||
162 | -- | Singular values of a complex matrix, using LAPACK's /zgesdd/ with jobz == \'N\'. | 172 | -- | Singular values of a complex matrix, using LAPACK's /zgesdd/ with jobz == \'N\'. |
163 | svCd :: Matrix (Complex Double) -> Vector Double | 173 | svCd :: Matrix (Complex Double) -> Vector Double |
164 | svCd = svAux zgesdd "svCd" . fmat | 174 | svCd = svAux zgesdd "svCd" |
165 | 175 | ||
166 | svAux f st x = unsafePerformIO $ do | 176 | svAux f st x = unsafePerformIO $ do |
177 | a <- copy ColumnMajor x | ||
167 | s <- createVector q | 178 | s <- createVector q |
168 | app2 g mat x vec s st | 179 | g # a # s #| st |
169 | return s | 180 | return s |
170 | where r = rows x | 181 | where |
171 | c = cols x | 182 | r = rows x |
172 | q = min r c | 183 | c = cols x |
173 | g ra ca pa nb pb = f ra ca pa 0 0 nullPtr nb pb 0 0 nullPtr | 184 | q = min r c |
185 | g ra ca xra xca pa nb pb = f ra ca xra xca pa 0 0 0 0 nullPtr nb pb 0 0 0 0 nullPtr | ||
174 | 186 | ||
175 | 187 | ||
176 | -- | Singular values and all right singular vectors of a real matrix, using LAPACK's /dgesvd/ with jobu == \'N\' and jobvt == \'A\'. | 188 | -- | Singular values and all right singular vectors of a real matrix, using LAPACK's /dgesvd/ with jobu == \'N\' and jobvt == \'A\'. |
177 | rightSVR :: Matrix Double -> (Vector Double, Matrix Double) | 189 | rightSVR :: Matrix Double -> (Vector Double, Matrix Double) |
178 | rightSVR = rightSVAux dgesvd "rightSVR" . fmat | 190 | rightSVR = rightSVAux dgesvd "rightSVR" |
179 | 191 | ||
180 | -- | Singular values and all right singular vectors of a complex matrix, using LAPACK's /zgesvd/ with jobu == \'N\' and jobvt == \'A\'. | 192 | -- | Singular values and all right singular vectors of a complex matrix, using LAPACK's /zgesvd/ with jobu == \'N\' and jobvt == \'A\'. |
181 | rightSVC :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double)) | 193 | rightSVC :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double)) |
182 | rightSVC = rightSVAux zgesvd "rightSVC" . fmat | 194 | rightSVC = rightSVAux zgesvd "rightSVC" |
183 | 195 | ||
184 | rightSVAux f st x = unsafePerformIO $ do | 196 | rightSVAux f st x = unsafePerformIO $ do |
197 | a <- copy ColumnMajor x | ||
185 | s <- createVector q | 198 | s <- createVector q |
186 | v <- createMatrix ColumnMajor c c | 199 | v <- createMatrix ColumnMajor c c |
187 | app3 g mat x vec s mat v st | 200 | g # a # s # v #| st |
188 | return (s,trans v) | 201 | return (s,v) |
189 | where r = rows x | 202 | where |
190 | c = cols x | 203 | r = rows x |
191 | q = min r c | 204 | c = cols x |
192 | g ra ca pa = f ra ca pa 0 0 nullPtr | 205 | q = min r c |
206 | g ra ca xra xca pa = f ra ca xra xca pa 0 0 0 0 nullPtr | ||
193 | 207 | ||
194 | 208 | ||
195 | -- | Singular values and all left singular vectors of a real matrix, using LAPACK's /dgesvd/ with jobu == \'A\' and jobvt == \'N\'. | 209 | -- | Singular values and all left singular vectors of a real matrix, using LAPACK's /dgesvd/ with jobu == \'A\' and jobvt == \'N\'. |
196 | leftSVR :: Matrix Double -> (Matrix Double, Vector Double) | 210 | leftSVR :: Matrix Double -> (Matrix Double, Vector Double) |
197 | leftSVR = leftSVAux dgesvd "leftSVR" . fmat | 211 | leftSVR = leftSVAux dgesvd "leftSVR" |
198 | 212 | ||
199 | -- | Singular values and all left singular vectors of a complex matrix, using LAPACK's /zgesvd/ with jobu == \'A\' and jobvt == \'N\'. | 213 | -- | Singular values and all left singular vectors of a complex matrix, using LAPACK's /zgesvd/ with jobu == \'A\' and jobvt == \'N\'. |
200 | leftSVC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double) | 214 | leftSVC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double) |
201 | leftSVC = leftSVAux zgesvd "leftSVC" . fmat | 215 | leftSVC = leftSVAux zgesvd "leftSVC" |
202 | 216 | ||
203 | leftSVAux f st x = unsafePerformIO $ do | 217 | leftSVAux f st x = unsafePerformIO $ do |
218 | a <- copy ColumnMajor x | ||
204 | u <- createMatrix ColumnMajor r r | 219 | u <- createMatrix ColumnMajor r r |
205 | s <- createVector q | 220 | s <- createVector q |
206 | app3 g mat x mat u vec s st | 221 | g # a # u # s #| st |
207 | return (u,s) | 222 | return (u,s) |
208 | where r = rows x | 223 | where |
209 | c = cols x | 224 | r = rows x |
210 | q = min r c | 225 | c = cols x |
211 | g ra ca pa ru cu pu nb pb = f ra ca pa ru cu pu nb pb 0 0 nullPtr | 226 | q = min r c |
227 | g ra ca xra xca pa ru cu xru xcu pu nb pb = f ra ca xra xca pa ru cu xru xcu pu nb pb 0 0 0 0 nullPtr | ||
212 | 228 | ||
213 | ----------------------------------------------------------------------------- | 229 | ----------------------------------------------------------------------------- |
214 | 230 | ||
215 | foreign import ccall unsafe "eig_l_R" dgeev :: TMMCVM | 231 | foreign import ccall unsafe "eig_l_R" dgeev :: R ::> R ::> C :> R ::> Ok |
216 | foreign import ccall unsafe "eig_l_C" zgeev :: TCMCMCVCM | 232 | foreign import ccall unsafe "eig_l_C" zgeev :: C ::> C ::> C :> C ::> Ok |
217 | foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> TMVM | 233 | foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> R :> R ::> Ok |
218 | foreign import ccall unsafe "eig_l_H" zheev :: CInt -> TCMVCM | 234 | foreign import ccall unsafe "eig_l_H" zheev :: CInt -> R :> C ::> Ok |
219 | 235 | ||
220 | eigAux f st m = unsafePerformIO $ do | 236 | eigAux f st m = unsafePerformIO $ do |
221 | l <- createVector r | 237 | a <- copy ColumnMajor m |
222 | v <- createMatrix ColumnMajor r r | 238 | l <- createVector r |
223 | app3 g mat m vec l mat v st | 239 | v <- createMatrix ColumnMajor r r |
224 | return (l,v) | 240 | g # a # l # v #| st |
225 | where r = rows m | 241 | return (l,v) |
226 | g ra ca pa = f ra ca pa 0 0 nullPtr | 242 | where |
243 | r = rows m | ||
244 | g ra ca xra xca pa = f ra ca xra xca pa 0 0 0 0 nullPtr | ||
227 | 245 | ||
228 | 246 | ||
229 | -- | Eigenvalues and right eigenvectors of a general complex matrix, using LAPACK's /zgeev/. | 247 | -- | Eigenvalues and right eigenvectors of a general complex matrix, using LAPACK's /zgeev/. |
230 | -- The eigenvectors are the columns of v. The eigenvalues are not sorted. | 248 | -- The eigenvectors are the columns of v. The eigenvalues are not sorted. |
231 | eigC :: Matrix (Complex Double) -> (Vector (Complex Double), Matrix (Complex Double)) | 249 | eigC :: Matrix (Complex Double) -> (Vector (Complex Double), Matrix (Complex Double)) |
232 | eigC = eigAux zgeev "eigC" . fmat | 250 | eigC = eigAux zgeev "eigC" |
233 | 251 | ||
234 | eigOnlyAux f st m = unsafePerformIO $ do | 252 | eigOnlyAux f st m = unsafePerformIO $ do |
235 | l <- createVector r | 253 | a <- copy ColumnMajor m |
236 | app2 g mat m vec l st | 254 | l <- createVector r |
237 | return l | 255 | g # a # l #| st |
238 | where r = rows m | 256 | return l |
239 | g ra ca pa nl pl = f ra ca pa 0 0 nullPtr nl pl 0 0 nullPtr | 257 | where |
258 | r = rows m | ||
259 | g ra ca xra xca pa nl pl = f ra ca xra xca pa 0 0 0 0 nullPtr nl pl 0 0 0 0 nullPtr | ||
240 | 260 | ||
241 | -- | Eigenvalues of a general complex matrix, using LAPACK's /zgeev/ with jobz == \'N\'. | 261 | -- | Eigenvalues of a general complex matrix, using LAPACK's /zgeev/ with jobz == \'N\'. |
242 | -- The eigenvalues are not sorted. | 262 | -- The eigenvalues are not sorted. |
243 | eigOnlyC :: Matrix (Complex Double) -> Vector (Complex Double) | 263 | eigOnlyC :: Matrix (Complex Double) -> Vector (Complex Double) |
244 | eigOnlyC = eigOnlyAux zgeev "eigOnlyC" . fmat | 264 | eigOnlyC = eigOnlyAux zgeev "eigOnlyC" |
245 | 265 | ||
246 | -- | Eigenvalues and right eigenvectors of a general real matrix, using LAPACK's /dgeev/. | 266 | -- | Eigenvalues and right eigenvectors of a general real matrix, using LAPACK's /dgeev/. |
247 | -- The eigenvectors are the columns of v. The eigenvalues are not sorted. | 267 | -- The eigenvectors are the columns of v. The eigenvalues are not sorted. |
248 | eigR :: Matrix Double -> (Vector (Complex Double), Matrix (Complex Double)) | 268 | eigR :: Matrix Double -> (Vector (Complex Double), Matrix (Complex Double)) |
249 | eigR m = (s', v'') | 269 | eigR m = (s', v'') |
250 | where (s,v) = eigRaux (fmat m) | 270 | where (s,v) = eigRaux m |
251 | s' = fixeig1 s | 271 | s' = fixeig1 s |
252 | v' = toRows $ trans v | 272 | v' = toRows $ trans v |
253 | v'' = fromColumns $ fixeig (toList s') v' | 273 | v'' = fromColumns $ fixeig (toList s') v' |
254 | 274 | ||
255 | eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double) | 275 | eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double) |
256 | eigRaux m = unsafePerformIO $ do | 276 | eigRaux m = unsafePerformIO $ do |
257 | l <- createVector r | 277 | a <- copy ColumnMajor m |
258 | v <- createMatrix ColumnMajor r r | 278 | l <- createVector r |
259 | app3 g mat m vec l mat v "eigR" | 279 | v <- createMatrix ColumnMajor r r |
260 | return (l,v) | 280 | g # a # l # v #| "eigR" |
261 | where r = rows m | 281 | return (l,v) |
262 | g ra ca pa = dgeev ra ca pa 0 0 nullPtr | 282 | where |
283 | r = rows m | ||
284 | g ra ca xra xca pa = dgeev ra ca xra xca pa 0 0 0 0 nullPtr | ||
263 | 285 | ||
264 | fixeig1 s = toComplex' (subVector 0 r (asReal s), subVector r r (asReal s)) | 286 | fixeig1 s = toComplex' (subVector 0 r (asReal s), subVector r r (asReal s)) |
265 | where r = dim s | 287 | where r = dim s |
@@ -275,118 +297,141 @@ fixeig _ _ = error "fixeig with impossible inputs" | |||
275 | -- | Eigenvalues of a general real matrix, using LAPACK's /dgeev/ with jobz == \'N\'. | 297 | -- | Eigenvalues of a general real matrix, using LAPACK's /dgeev/ with jobz == \'N\'. |
276 | -- The eigenvalues are not sorted. | 298 | -- The eigenvalues are not sorted. |
277 | eigOnlyR :: Matrix Double -> Vector (Complex Double) | 299 | eigOnlyR :: Matrix Double -> Vector (Complex Double) |
278 | eigOnlyR = fixeig1 . eigOnlyAux dgeev "eigOnlyR" . fmat | 300 | eigOnlyR = fixeig1 . eigOnlyAux dgeev "eigOnlyR" |
279 | 301 | ||
280 | 302 | ||
281 | ----------------------------------------------------------------------------- | 303 | ----------------------------------------------------------------------------- |
282 | 304 | ||
283 | eigSHAux f st m = unsafePerformIO $ do | 305 | eigSHAux f st m = unsafePerformIO $ do |
284 | l <- createVector r | 306 | l <- createVector r |
285 | v <- createMatrix ColumnMajor r r | 307 | v <- copy ColumnMajor m |
286 | app3 f mat m vec l mat v st | 308 | f # l # v #| st |
287 | return (l,v) | 309 | return (l,v) |
288 | where r = rows m | 310 | where |
311 | r = rows m | ||
289 | 312 | ||
290 | -- | Eigenvalues and right eigenvectors of a symmetric real matrix, using LAPACK's /dsyev/. | 313 | -- | Eigenvalues and right eigenvectors of a symmetric real matrix, using LAPACK's /dsyev/. |
291 | -- The eigenvectors are the columns of v. | 314 | -- The eigenvectors are the columns of v. |
292 | -- The eigenvalues are sorted in descending order (use 'eigS'' for ascending order). | 315 | -- The eigenvalues are sorted in descending order (use 'eigS'' for ascending order). |
293 | eigS :: Matrix Double -> (Vector Double, Matrix Double) | 316 | eigS :: Matrix Double -> (Vector Double, Matrix Double) |
294 | eigS m = (s', fliprl v) | 317 | eigS m = (s', fliprl v) |
295 | where (s,v) = eigS' (fmat m) | 318 | where (s,v) = eigS' m |
296 | s' = fromList . reverse . toList $ s | 319 | s' = fromList . reverse . toList $ s |
297 | 320 | ||
298 | -- | 'eigS' in ascending order | 321 | -- | 'eigS' in ascending order |
299 | eigS' :: Matrix Double -> (Vector Double, Matrix Double) | 322 | eigS' :: Matrix Double -> (Vector Double, Matrix Double) |
300 | eigS' = eigSHAux (dsyev 1) "eigS'" . fmat | 323 | eigS' = eigSHAux (dsyev 1) "eigS'" |
301 | 324 | ||
302 | -- | Eigenvalues and right eigenvectors of a hermitian complex matrix, using LAPACK's /zheev/. | 325 | -- | Eigenvalues and right eigenvectors of a hermitian complex matrix, using LAPACK's /zheev/. |
303 | -- The eigenvectors are the columns of v. | 326 | -- The eigenvectors are the columns of v. |
304 | -- The eigenvalues are sorted in descending order (use 'eigH'' for ascending order). | 327 | -- The eigenvalues are sorted in descending order (use 'eigH'' for ascending order). |
305 | eigH :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double)) | 328 | eigH :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double)) |
306 | eigH m = (s', fliprl v) | 329 | eigH m = (s', fliprl v) |
307 | where (s,v) = eigH' (fmat m) | 330 | where |
308 | s' = fromList . reverse . toList $ s | 331 | (s,v) = eigH' m |
332 | s' = fromList . reverse . toList $ s | ||
309 | 333 | ||
310 | -- | 'eigH' in ascending order | 334 | -- | 'eigH' in ascending order |
311 | eigH' :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double)) | 335 | eigH' :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double)) |
312 | eigH' = eigSHAux (zheev 1) "eigH'" . fmat | 336 | eigH' = eigSHAux (zheev 1) "eigH'" |
313 | 337 | ||
314 | 338 | ||
315 | -- | Eigenvalues of a symmetric real matrix, using LAPACK's /dsyev/ with jobz == \'N\'. | 339 | -- | Eigenvalues of a symmetric real matrix, using LAPACK's /dsyev/ with jobz == \'N\'. |
316 | -- The eigenvalues are sorted in descending order. | 340 | -- The eigenvalues are sorted in descending order. |
317 | eigOnlyS :: Matrix Double -> Vector Double | 341 | eigOnlyS :: Matrix Double -> Vector Double |
318 | eigOnlyS = vrev . fst. eigSHAux (dsyev 0) "eigS'" . fmat | 342 | eigOnlyS = vrev . fst. eigSHAux (dsyev 0) "eigS'" |
319 | 343 | ||
320 | -- | Eigenvalues of a hermitian complex matrix, using LAPACK's /zheev/ with jobz == \'N\'. | 344 | -- | Eigenvalues of a hermitian complex matrix, using LAPACK's /zheev/ with jobz == \'N\'. |
321 | -- The eigenvalues are sorted in descending order. | 345 | -- The eigenvalues are sorted in descending order. |
322 | eigOnlyH :: Matrix (Complex Double) -> Vector Double | 346 | eigOnlyH :: Matrix (Complex Double) -> Vector Double |
323 | eigOnlyH = vrev . fst. eigSHAux (zheev 0) "eigH'" . fmat | 347 | eigOnlyH = vrev . fst. eigSHAux (zheev 0) "eigH'" |
324 | 348 | ||
325 | vrev = flatten . flipud . reshape 1 | 349 | vrev = flatten . flipud . reshape 1 |
326 | 350 | ||
327 | ----------------------------------------------------------------------------- | 351 | ----------------------------------------------------------------------------- |
328 | foreign import ccall unsafe "linearSolveR_l" dgesv :: TMMM | 352 | foreign import ccall unsafe "linearSolveR_l" dgesv :: R ::> R ::> Ok |
329 | foreign import ccall unsafe "linearSolveC_l" zgesv :: TCMCMCM | 353 | foreign import ccall unsafe "linearSolveC_l" zgesv :: C ::> C ::> Ok |
330 | foreign import ccall unsafe "cholSolveR_l" dpotrs :: TMMM | ||
331 | foreign import ccall unsafe "cholSolveC_l" zpotrs :: TCMCMCM | ||
332 | 354 | ||
333 | linearSolveSQAux g f st a b | 355 | linearSolveSQAux g f st a b |
334 | | n1==n2 && n1==r = unsafePerformIO . g $ do | 356 | | n1==n2 && n1==r = unsafePerformIO . g $ do |
335 | s <- createMatrix ColumnMajor r c | 357 | a' <- copy ColumnMajor a |
336 | app3 f mat a mat b mat s st | 358 | s <- copy ColumnMajor b |
359 | f # a' # s #| st | ||
337 | return s | 360 | return s |
338 | | otherwise = error $ st ++ " of nonsquare matrix" | 361 | | otherwise = error $ st ++ " of nonsquare matrix" |
339 | where n1 = rows a | 362 | where |
340 | n2 = cols a | 363 | n1 = rows a |
341 | r = rows b | 364 | n2 = cols a |
342 | c = cols b | 365 | r = rows b |
343 | 366 | ||
344 | -- | Solve a real linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, based on LAPACK's /dgesv/. For underconstrained or overconstrained systems use 'linearSolveLSR' or 'linearSolveSVDR'. See also 'lusR'. | 367 | -- | Solve a real linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, based on LAPACK's /dgesv/. For underconstrained or overconstrained systems use 'linearSolveLSR' or 'linearSolveSVDR'. See also 'lusR'. |
345 | linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double | 368 | linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double |
346 | linearSolveR a b = linearSolveSQAux id dgesv "linearSolveR" (fmat a) (fmat b) | 369 | linearSolveR a b = linearSolveSQAux id dgesv "linearSolveR" a b |
347 | 370 | ||
348 | mbLinearSolveR :: Matrix Double -> Matrix Double -> Maybe (Matrix Double) | 371 | mbLinearSolveR :: Matrix Double -> Matrix Double -> Maybe (Matrix Double) |
349 | mbLinearSolveR a b = linearSolveSQAux mbCatch dgesv "linearSolveR" (fmat a) (fmat b) | 372 | mbLinearSolveR a b = linearSolveSQAux mbCatch dgesv "linearSolveR" a b |
350 | 373 | ||
351 | 374 | ||
352 | -- | Solve a complex linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, based on LAPACK's /zgesv/. For underconstrained or overconstrained systems use 'linearSolveLSC' or 'linearSolveSVDC'. See also 'lusC'. | 375 | -- | Solve a complex linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, based on LAPACK's /zgesv/. For underconstrained or overconstrained systems use 'linearSolveLSC' or 'linearSolveSVDC'. See also 'lusC'. |
353 | linearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 376 | linearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
354 | linearSolveC a b = linearSolveSQAux id zgesv "linearSolveC" (fmat a) (fmat b) | 377 | linearSolveC a b = linearSolveSQAux id zgesv "linearSolveC" a b |
355 | 378 | ||
356 | mbLinearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Maybe (Matrix (Complex Double)) | 379 | mbLinearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Maybe (Matrix (Complex Double)) |
357 | mbLinearSolveC a b = linearSolveSQAux mbCatch zgesv "linearSolveC" (fmat a) (fmat b) | 380 | mbLinearSolveC a b = linearSolveSQAux mbCatch zgesv "linearSolveC" a b |
381 | |||
382 | -------------------------------------------------------------------------------- | ||
383 | foreign import ccall unsafe "cholSolveR_l" dpotrs :: R ::> R ::> Ok | ||
384 | foreign import ccall unsafe "cholSolveC_l" zpotrs :: C ::> C ::> Ok | ||
385 | |||
386 | |||
387 | linearSolveSQAux2 g f st a b | ||
388 | | n1==n2 && n1==r = unsafePerformIO . g $ do | ||
389 | s <- copy ColumnMajor b | ||
390 | f # a # s #| st | ||
391 | return s | ||
392 | | otherwise = error $ st ++ " of nonsquare matrix" | ||
393 | where | ||
394 | n1 = rows a | ||
395 | n2 = cols a | ||
396 | r = rows b | ||
358 | 397 | ||
359 | -- | Solves a symmetric positive definite system of linear equations using a precomputed Cholesky factorization obtained by 'cholS'. | 398 | -- | Solves a symmetric positive definite system of linear equations using a precomputed Cholesky factorization obtained by 'cholS'. |
360 | cholSolveR :: Matrix Double -> Matrix Double -> Matrix Double | 399 | cholSolveR :: Matrix Double -> Matrix Double -> Matrix Double |
361 | cholSolveR a b = linearSolveSQAux id dpotrs "cholSolveR" (fmat a) (fmat b) | 400 | cholSolveR a b = linearSolveSQAux2 id dpotrs "cholSolveR" (fmat a) b |
362 | 401 | ||
363 | -- | Solves a Hermitian positive definite system of linear equations using a precomputed Cholesky factorization obtained by 'cholH'. | 402 | -- | Solves a Hermitian positive definite system of linear equations using a precomputed Cholesky factorization obtained by 'cholH'. |
364 | cholSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 403 | cholSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
365 | cholSolveC a b = linearSolveSQAux id zpotrs "cholSolveC" (fmat a) (fmat b) | 404 | cholSolveC a b = linearSolveSQAux2 id zpotrs "cholSolveC" (fmat a) b |
366 | 405 | ||
367 | ----------------------------------------------------------------------------------- | 406 | ----------------------------------------------------------------------------------- |
368 | foreign import ccall unsafe "linearSolveLSR_l" dgels :: TMMM | 407 | |
369 | foreign import ccall unsafe "linearSolveLSC_l" zgels :: TCMCMCM | 408 | foreign import ccall unsafe "linearSolveLSR_l" dgels :: R ::> R ::> Ok |
370 | foreign import ccall unsafe "linearSolveSVDR_l" dgelss :: Double -> TMMM | 409 | foreign import ccall unsafe "linearSolveLSC_l" zgels :: C ::> C ::> Ok |
371 | foreign import ccall unsafe "linearSolveSVDC_l" zgelss :: Double -> TCMCMCM | 410 | foreign import ccall unsafe "linearSolveSVDR_l" dgelss :: Double -> R ::> R ::> Ok |
372 | 411 | foreign import ccall unsafe "linearSolveSVDC_l" zgelss :: Double -> C ::> C ::> Ok | |
373 | linearSolveAux f st a b = unsafePerformIO $ do | 412 | |
374 | r <- createMatrix ColumnMajor (max m n) nrhs | 413 | linearSolveAux f st a b |
375 | app3 f mat a mat b mat r st | 414 | | m == rows b = unsafePerformIO $ do |
376 | return r | 415 | a' <- copy ColumnMajor a |
377 | where m = rows a | 416 | r <- createMatrix ColumnMajor (max m n) nrhs |
378 | n = cols a | 417 | setRect 0 0 b r |
379 | nrhs = cols b | 418 | f # a' # r #| st |
419 | return r | ||
420 | | otherwise = error $ "different number of rows in linearSolve ("++st++")" | ||
421 | where | ||
422 | m = rows a | ||
423 | n = cols a | ||
424 | nrhs = cols b | ||
380 | 425 | ||
381 | -- | Least squared error solution of an overconstrained real linear system, or the minimum norm solution of an underconstrained system, using LAPACK's /dgels/. For rank-deficient systems use 'linearSolveSVDR'. | 426 | -- | Least squared error solution of an overconstrained real linear system, or the minimum norm solution of an underconstrained system, using LAPACK's /dgels/. For rank-deficient systems use 'linearSolveSVDR'. |
382 | linearSolveLSR :: Matrix Double -> Matrix Double -> Matrix Double | 427 | linearSolveLSR :: Matrix Double -> Matrix Double -> Matrix Double |
383 | linearSolveLSR a b = subMatrix (0,0) (cols a, cols b) $ | 428 | linearSolveLSR a b = subMatrix (0,0) (cols a, cols b) $ |
384 | linearSolveAux dgels "linearSolverLSR" (fmat a) (fmat b) | 429 | linearSolveAux dgels "linearSolverLSR" a b |
385 | 430 | ||
386 | -- | Least squared error solution of an overconstrained complex linear system, or the minimum norm solution of an underconstrained system, using LAPACK's /zgels/. For rank-deficient systems use 'linearSolveSVDC'. | 431 | -- | Least squared error solution of an overconstrained complex linear system, or the minimum norm solution of an underconstrained system, using LAPACK's /zgels/. For rank-deficient systems use 'linearSolveSVDC'. |
387 | linearSolveLSC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 432 | linearSolveLSC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
388 | linearSolveLSC a b = subMatrix (0,0) (cols a, cols b) $ | 433 | linearSolveLSC a b = subMatrix (0,0) (cols a, cols b) $ |
389 | linearSolveAux zgels "linearSolveLSC" (fmat a) (fmat b) | 434 | linearSolveAux zgels "linearSolveLSC" a b |
390 | 435 | ||
391 | -- | Minimum norm solution of a general real linear least squares problem Ax=B using the SVD, based on LAPACK's /dgelss/. Admits rank-deficient systems but it is slower than 'linearSolveLSR'. The effective rank of A is determined by treating as zero those singular valures which are less than rcond times the largest singular value. If rcond == Nothing machine precision is used. | 436 | -- | Minimum norm solution of a general real linear least squares problem Ax=B using the SVD, based on LAPACK's /dgelss/. Admits rank-deficient systems but it is slower than 'linearSolveLSR'. The effective rank of A is determined by treating as zero those singular valures which are less than rcond times the largest singular value. If rcond == Nothing machine precision is used. |
392 | linearSolveSVDR :: Maybe Double -- ^ rcond | 437 | linearSolveSVDR :: Maybe Double -- ^ rcond |
@@ -394,8 +439,8 @@ linearSolveSVDR :: Maybe Double -- ^ rcond | |||
394 | -> Matrix Double -- ^ right hand sides (as columns) | 439 | -> Matrix Double -- ^ right hand sides (as columns) |
395 | -> Matrix Double -- ^ solution vectors (as columns) | 440 | -> Matrix Double -- ^ solution vectors (as columns) |
396 | linearSolveSVDR (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ | 441 | linearSolveSVDR (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ |
397 | linearSolveAux (dgelss rcond) "linearSolveSVDR" (fmat a) (fmat b) | 442 | linearSolveAux (dgelss rcond) "linearSolveSVDR" a b |
398 | linearSolveSVDR Nothing a b = linearSolveSVDR (Just (-1)) (fmat a) (fmat b) | 443 | linearSolveSVDR Nothing a b = linearSolveSVDR (Just (-1)) a b |
399 | 444 | ||
400 | -- | Minimum norm solution of a general complex linear least squares problem Ax=B using the SVD, based on LAPACK's /zgelss/. Admits rank-deficient systems but it is slower than 'linearSolveLSC'. The effective rank of A is determined by treating as zero those singular valures which are less than rcond times the largest singular value. If rcond == Nothing machine precision is used. | 445 | -- | Minimum norm solution of a general complex linear least squares problem Ax=B using the SVD, based on LAPACK's /zgelss/. Admits rank-deficient systems but it is slower than 'linearSolveLSC'. The effective rank of A is determined by treating as zero those singular valures which are less than rcond times the largest singular value. If rcond == Nothing machine precision is used. |
401 | linearSolveSVDC :: Maybe Double -- ^ rcond | 446 | linearSolveSVDC :: Maybe Double -- ^ rcond |
@@ -403,59 +448,62 @@ linearSolveSVDC :: Maybe Double -- ^ rcond | |||
403 | -> Matrix (Complex Double) -- ^ right hand sides (as columns) | 448 | -> Matrix (Complex Double) -- ^ right hand sides (as columns) |
404 | -> Matrix (Complex Double) -- ^ solution vectors (as columns) | 449 | -> Matrix (Complex Double) -- ^ solution vectors (as columns) |
405 | linearSolveSVDC (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ | 450 | linearSolveSVDC (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ |
406 | linearSolveAux (zgelss rcond) "linearSolveSVDC" (fmat a) (fmat b) | 451 | linearSolveAux (zgelss rcond) "linearSolveSVDC" a b |
407 | linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) (fmat a) (fmat b) | 452 | linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) a b |
408 | 453 | ||
409 | ----------------------------------------------------------------------------------- | 454 | ----------------------------------------------------------------------------------- |
410 | foreign import ccall unsafe "chol_l_H" zpotrf :: TCMCM | 455 | |
411 | foreign import ccall unsafe "chol_l_S" dpotrf :: TMM | 456 | foreign import ccall unsafe "chol_l_H" zpotrf :: C ::> Ok |
457 | foreign import ccall unsafe "chol_l_S" dpotrf :: R ::> Ok | ||
412 | 458 | ||
413 | cholAux f st a = do | 459 | cholAux f st a = do |
414 | r <- createMatrix ColumnMajor n n | 460 | r <- copy ColumnMajor a |
415 | app2 f mat a mat r st | 461 | f # r #| st |
416 | return r | 462 | return r |
417 | where n = rows a | ||
418 | 463 | ||
419 | -- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/. | 464 | -- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/. |
420 | cholH :: Matrix (Complex Double) -> Matrix (Complex Double) | 465 | cholH :: Matrix (Complex Double) -> Matrix (Complex Double) |
421 | cholH = unsafePerformIO . cholAux zpotrf "cholH" . fmat | 466 | cholH = unsafePerformIO . cholAux zpotrf "cholH" |
422 | 467 | ||
423 | -- | Cholesky factorization of a real symmetric positive definite matrix, using LAPACK's /dpotrf/. | 468 | -- | Cholesky factorization of a real symmetric positive definite matrix, using LAPACK's /dpotrf/. |
424 | cholS :: Matrix Double -> Matrix Double | 469 | cholS :: Matrix Double -> Matrix Double |
425 | cholS = unsafePerformIO . cholAux dpotrf "cholS" . fmat | 470 | cholS = unsafePerformIO . cholAux dpotrf "cholS" |
426 | 471 | ||
427 | -- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/ ('Maybe' version). | 472 | -- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/ ('Maybe' version). |
428 | mbCholH :: Matrix (Complex Double) -> Maybe (Matrix (Complex Double)) | 473 | mbCholH :: Matrix (Complex Double) -> Maybe (Matrix (Complex Double)) |
429 | mbCholH = unsafePerformIO . mbCatch . cholAux zpotrf "cholH" . fmat | 474 | mbCholH = unsafePerformIO . mbCatch . cholAux zpotrf "cholH" |
430 | 475 | ||
431 | -- | Cholesky factorization of a real symmetric positive definite matrix, using LAPACK's /dpotrf/ ('Maybe' version). | 476 | -- | Cholesky factorization of a real symmetric positive definite matrix, using LAPACK's /dpotrf/ ('Maybe' version). |
432 | mbCholS :: Matrix Double -> Maybe (Matrix Double) | 477 | mbCholS :: Matrix Double -> Maybe (Matrix Double) |
433 | mbCholS = unsafePerformIO . mbCatch . cholAux dpotrf "cholS" . fmat | 478 | mbCholS = unsafePerformIO . mbCatch . cholAux dpotrf "cholS" |
434 | 479 | ||
435 | ----------------------------------------------------------------------------------- | 480 | ----------------------------------------------------------------------------------- |
436 | foreign import ccall unsafe "qr_l_R" dgeqr2 :: TMVM | 481 | |
437 | foreign import ccall unsafe "qr_l_C" zgeqr2 :: TCMCVCM | 482 | type TMVM t = t ::> t :> t ::> Ok |
483 | |||
484 | foreign import ccall unsafe "qr_l_R" dgeqr2 :: R :> R ::> Ok | ||
485 | foreign import ccall unsafe "qr_l_C" zgeqr2 :: C :> C ::> Ok | ||
438 | 486 | ||
439 | -- | QR factorization of a real matrix, using LAPACK's /dgeqr2/. | 487 | -- | QR factorization of a real matrix, using LAPACK's /dgeqr2/. |
440 | qrR :: Matrix Double -> (Matrix Double, Vector Double) | 488 | qrR :: Matrix Double -> (Matrix Double, Vector Double) |
441 | qrR = qrAux dgeqr2 "qrR" . fmat | 489 | qrR = qrAux dgeqr2 "qrR" |
442 | 490 | ||
443 | -- | QR factorization of a complex matrix, using LAPACK's /zgeqr2/. | 491 | -- | QR factorization of a complex matrix, using LAPACK's /zgeqr2/. |
444 | qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) | 492 | qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) |
445 | qrC = qrAux zgeqr2 "qrC" . fmat | 493 | qrC = qrAux zgeqr2 "qrC" |
446 | 494 | ||
447 | qrAux f st a = unsafePerformIO $ do | 495 | qrAux f st a = unsafePerformIO $ do |
448 | r <- createMatrix ColumnMajor m n | 496 | r <- copy ColumnMajor a |
449 | tau <- createVector mn | 497 | tau <- createVector mn |
450 | app3 f mat a vec tau mat r st | 498 | f # tau # r #| st |
451 | return (r,tau) | 499 | return (r,tau) |
452 | where | 500 | where |
453 | m = rows a | 501 | m = rows a |
454 | n = cols a | 502 | n = cols a |
455 | mn = min m n | 503 | mn = min m n |
456 | 504 | ||
457 | foreign import ccall unsafe "c_dorgqr" dorgqr :: TMVM | 505 | foreign import ccall unsafe "c_dorgqr" dorgqr :: R :> R ::> Ok |
458 | foreign import ccall unsafe "c_zungqr" zungqr :: TCMCVCM | 506 | foreign import ccall unsafe "c_zungqr" zungqr :: C :> C ::> Ok |
459 | 507 | ||
460 | -- | build rotation from reflectors | 508 | -- | build rotation from reflectors |
461 | qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double | 509 | qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double |
@@ -465,96 +513,128 @@ qrgrC :: Int -> (Matrix (Complex Double), Vector (Complex Double)) -> Matrix (Co | |||
465 | qrgrC = qrgrAux zungqr "qrgrC" | 513 | qrgrC = qrgrAux zungqr "qrgrC" |
466 | 514 | ||
467 | qrgrAux f st n (a, tau) = unsafePerformIO $ do | 515 | qrgrAux f st n (a, tau) = unsafePerformIO $ do |
468 | res <- createMatrix ColumnMajor (rows a) n | 516 | res <- copy ColumnMajor (subMatrix (0,0) (rows a,n) a) |
469 | app3 f mat (fmat a) vec (subVector 0 n tau') mat res st | 517 | f # (subVector 0 n tau') # res #| st |
470 | return res | 518 | return res |
471 | where | 519 | where |
472 | tau' = vjoin [tau, constantD 0 n] | 520 | tau' = vjoin [tau, constantD 0 n] |
473 | 521 | ||
474 | ----------------------------------------------------------------------------------- | 522 | ----------------------------------------------------------------------------------- |
475 | foreign import ccall unsafe "hess_l_R" dgehrd :: TMVM | 523 | foreign import ccall unsafe "hess_l_R" dgehrd :: R :> R ::> Ok |
476 | foreign import ccall unsafe "hess_l_C" zgehrd :: TCMCVCM | 524 | foreign import ccall unsafe "hess_l_C" zgehrd :: C :> C ::> Ok |
477 | 525 | ||
478 | -- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/. | 526 | -- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/. |
479 | hessR :: Matrix Double -> (Matrix Double, Vector Double) | 527 | hessR :: Matrix Double -> (Matrix Double, Vector Double) |
480 | hessR = hessAux dgehrd "hessR" . fmat | 528 | hessR = hessAux dgehrd "hessR" |
481 | 529 | ||
482 | -- | Hessenberg factorization of a square complex matrix, using LAPACK's /zgehrd/. | 530 | -- | Hessenberg factorization of a square complex matrix, using LAPACK's /zgehrd/. |
483 | hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) | 531 | hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) |
484 | hessC = hessAux zgehrd "hessC" . fmat | 532 | hessC = hessAux zgehrd "hessC" |
485 | 533 | ||
486 | hessAux f st a = unsafePerformIO $ do | 534 | hessAux f st a = unsafePerformIO $ do |
487 | r <- createMatrix ColumnMajor m n | 535 | r <- copy ColumnMajor a |
488 | tau <- createVector (mn-1) | 536 | tau <- createVector (mn-1) |
489 | app3 f mat a vec tau mat r st | 537 | f # tau # r #| st |
490 | return (r,tau) | 538 | return (r,tau) |
491 | where m = rows a | 539 | where |
492 | n = cols a | 540 | m = rows a |
493 | mn = min m n | 541 | n = cols a |
542 | mn = min m n | ||
494 | 543 | ||
495 | ----------------------------------------------------------------------------------- | 544 | ----------------------------------------------------------------------------------- |
496 | foreign import ccall unsafe "schur_l_R" dgees :: TMMM | 545 | foreign import ccall unsafe "schur_l_R" dgees :: R ::> R ::> Ok |
497 | foreign import ccall unsafe "schur_l_C" zgees :: TCMCMCM | 546 | foreign import ccall unsafe "schur_l_C" zgees :: C ::> C ::> Ok |
498 | 547 | ||
499 | -- | Schur factorization of a square real matrix, using LAPACK's /dgees/. | 548 | -- | Schur factorization of a square real matrix, using LAPACK's /dgees/. |
500 | schurR :: Matrix Double -> (Matrix Double, Matrix Double) | 549 | schurR :: Matrix Double -> (Matrix Double, Matrix Double) |
501 | schurR = schurAux dgees "schurR" . fmat | 550 | schurR = schurAux dgees "schurR" |
502 | 551 | ||
503 | -- | Schur factorization of a square complex matrix, using LAPACK's /zgees/. | 552 | -- | Schur factorization of a square complex matrix, using LAPACK's /zgees/. |
504 | schurC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double)) | 553 | schurC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double)) |
505 | schurC = schurAux zgees "schurC" . fmat | 554 | schurC = schurAux zgees "schurC" |
506 | 555 | ||
507 | schurAux f st a = unsafePerformIO $ do | 556 | schurAux f st a = unsafePerformIO $ do |
508 | u <- createMatrix ColumnMajor n n | 557 | u <- createMatrix ColumnMajor n n |
509 | s <- createMatrix ColumnMajor n n | 558 | s <- copy ColumnMajor a |
510 | app3 f mat a mat u mat s st | 559 | f # u # s #| st |
511 | return (u,s) | 560 | return (u,s) |
512 | where n = rows a | 561 | where |
562 | n = rows a | ||
513 | 563 | ||
514 | ----------------------------------------------------------------------------------- | 564 | ----------------------------------------------------------------------------------- |
515 | foreign import ccall unsafe "lu_l_R" dgetrf :: TMVM | 565 | foreign import ccall unsafe "lu_l_R" dgetrf :: R :> R ::> Ok |
516 | foreign import ccall unsafe "lu_l_C" zgetrf :: TCMVCM | 566 | foreign import ccall unsafe "lu_l_C" zgetrf :: R :> C ::> Ok |
517 | 567 | ||
518 | -- | LU factorization of a general real matrix, using LAPACK's /dgetrf/. | 568 | -- | LU factorization of a general real matrix, using LAPACK's /dgetrf/. |
519 | luR :: Matrix Double -> (Matrix Double, [Int]) | 569 | luR :: Matrix Double -> (Matrix Double, [Int]) |
520 | luR = luAux dgetrf "luR" . fmat | 570 | luR = luAux dgetrf "luR" |
521 | 571 | ||
522 | -- | LU factorization of a general complex matrix, using LAPACK's /zgetrf/. | 572 | -- | LU factorization of a general complex matrix, using LAPACK's /zgetrf/. |
523 | luC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int]) | 573 | luC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int]) |
524 | luC = luAux zgetrf "luC" . fmat | 574 | luC = luAux zgetrf "luC" |
525 | 575 | ||
526 | luAux f st a = unsafePerformIO $ do | 576 | luAux f st a = unsafePerformIO $ do |
527 | lu <- createMatrix ColumnMajor n m | 577 | lu <- copy ColumnMajor a |
528 | piv <- createVector (min n m) | 578 | piv <- createVector (min n m) |
529 | app3 f mat a vec piv mat lu st | 579 | f # piv # lu #| st |
530 | return (lu, map (pred.round) (toList piv)) | 580 | return (lu, map (pred.round) (toList piv)) |
531 | where n = rows a | 581 | where |
532 | m = cols a | 582 | n = rows a |
583 | m = cols a | ||
533 | 584 | ||
534 | ----------------------------------------------------------------------------------- | 585 | ----------------------------------------------------------------------------------- |
535 | type TW a = CInt -> PD -> a | ||
536 | type TQ a = CInt -> CInt -> PC -> a | ||
537 | 586 | ||
538 | foreign import ccall unsafe "luS_l_R" dgetrs :: TMVMM | 587 | foreign import ccall unsafe "luS_l_R" dgetrs :: R ::> R :> R ::> Ok |
539 | foreign import ccall unsafe "luS_l_C" zgetrs :: TQ (TW (TQ (TQ (IO CInt)))) | 588 | foreign import ccall unsafe "luS_l_C" zgetrs :: C ::> R :> C ::> Ok |
540 | 589 | ||
541 | -- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/. | 590 | -- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/. |
542 | lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double | 591 | lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double |
543 | lusR a piv b = lusAux dgetrs "lusR" (fmat a) piv (fmat b) | 592 | lusR a piv b = lusAux dgetrs "lusR" (fmat a) piv b |
544 | 593 | ||
545 | -- | Solve a real linear system from a precomputed LU decomposition ('luC'), using LAPACK's /zgetrs/. | 594 | -- | Solve a complex linear system from a precomputed LU decomposition ('luC'), using LAPACK's /zgetrs/. |
546 | lusC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double) | 595 | lusC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double) |
547 | lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b) | 596 | lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv b |
548 | 597 | ||
549 | lusAux f st a piv b | 598 | lusAux f st a piv b |
550 | | n1==n2 && n2==n =unsafePerformIO $ do | 599 | | n1==n2 && n2==n =unsafePerformIO $ do |
551 | x <- createMatrix ColumnMajor n m | 600 | x <- copy ColumnMajor b |
552 | app4 f mat a vec piv' mat b mat x st | 601 | f # a # piv' # x #| st |
553 | return x | 602 | return x |
554 | | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" | 603 | | otherwise = error st |
555 | where n1 = rows a | 604 | where |
556 | n2 = cols a | 605 | n1 = rows a |
557 | n = rows b | 606 | n2 = cols a |
558 | m = cols b | 607 | n = rows b |
559 | piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double | 608 | piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double |
609 | |||
610 | ----------------------------------------------------------------------------------- | ||
611 | foreign import ccall unsafe "ldl_R" dsytrf :: R :> R ::> Ok | ||
612 | foreign import ccall unsafe "ldl_C" zhetrf :: R :> C ::> Ok | ||
613 | |||
614 | -- | LDL factorization of a symmetric real matrix, using LAPACK's /dsytrf/. | ||
615 | ldlR :: Matrix Double -> (Matrix Double, [Int]) | ||
616 | ldlR = ldlAux dsytrf "ldlR" | ||
617 | |||
618 | -- | LDL factorization of a hermitian complex matrix, using LAPACK's /zhetrf/. | ||
619 | ldlC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int]) | ||
620 | ldlC = ldlAux zhetrf "ldlC" | ||
621 | |||
622 | ldlAux f st a = unsafePerformIO $ do | ||
623 | ldl <- copy ColumnMajor a | ||
624 | piv <- createVector (rows a) | ||
625 | f # piv # ldl #| st | ||
626 | return (ldl, map (pred.round) (toList piv)) | ||
627 | |||
628 | ----------------------------------------------------------------------------------- | ||
629 | |||
630 | foreign import ccall unsafe "ldl_S_R" dsytrs :: R ::> R :> R ::> Ok | ||
631 | foreign import ccall unsafe "ldl_S_C" zsytrs :: C ::> R :> C ::> Ok | ||
632 | |||
633 | -- | Solve a real linear system from a precomputed LDL decomposition ('ldlR'), using LAPACK's /dsytrs/. | ||
634 | ldlsR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double | ||
635 | ldlsR a piv b = lusAux dsytrs "ldlsR" (fmat a) piv b | ||
636 | |||
637 | -- | Solve a complex linear system from a precomputed LDL decomposition ('ldlC'), using LAPACK's /zsytrs/. | ||
638 | ldlsC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
639 | ldlsC a piv b = lusAux zsytrs "ldlsC" (fmat a) piv b | ||
560 | 640 | ||
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs new file mode 100644 index 0000000..3082e8d --- /dev/null +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -0,0 +1,598 @@ | |||
1 | {-# LANGUAGE ForeignFunctionInterface #-} | ||
2 | {-# LANGUAGE FlexibleContexts #-} | ||
3 | {-# LANGUAGE FlexibleInstances #-} | ||
4 | {-# LANGUAGE BangPatterns #-} | ||
5 | {-# LANGUAGE TypeOperators #-} | ||
6 | {-# LANGUAGE TypeFamilies #-} | ||
7 | {-# LANGUAGE ViewPatterns #-} | ||
8 | |||
9 | |||
10 | |||
11 | -- | | ||
12 | -- Module : Internal.Matrix | ||
13 | -- Copyright : (c) Alberto Ruiz 2007-15 | ||
14 | -- License : BSD3 | ||
15 | -- Maintainer : Alberto Ruiz | ||
16 | -- Stability : provisional | ||
17 | -- | ||
18 | -- Internal matrix representation | ||
19 | -- | ||
20 | |||
21 | module Internal.Matrix where | ||
22 | |||
23 | import Internal.Vector | ||
24 | import Internal.Devel | ||
25 | import Internal.Vectorized hiding ((#)) | ||
26 | import Foreign.Marshal.Alloc ( free ) | ||
27 | import Foreign.Marshal.Array(newArray) | ||
28 | import Foreign.Ptr ( Ptr ) | ||
29 | import Foreign.Storable ( Storable ) | ||
30 | import Data.Complex ( Complex ) | ||
31 | import Foreign.C.Types ( CInt(..) ) | ||
32 | import Foreign.C.String ( CString, newCString ) | ||
33 | import System.IO.Unsafe ( unsafePerformIO ) | ||
34 | import Control.DeepSeq ( NFData(..) ) | ||
35 | import Text.Printf | ||
36 | |||
37 | ----------------------------------------------------------------- | ||
38 | |||
39 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | ||
40 | |||
41 | -- | Matrix representation suitable for BLAS\/LAPACK computations. | ||
42 | |||
43 | data Matrix t = Matrix | ||
44 | { irows :: {-# UNPACK #-} !Int | ||
45 | , icols :: {-# UNPACK #-} !Int | ||
46 | , xRow :: {-# UNPACK #-} !Int | ||
47 | , xCol :: {-# UNPACK #-} !Int | ||
48 | , xdat :: {-# UNPACK #-} !(Vector t) | ||
49 | } | ||
50 | |||
51 | |||
52 | rows :: Matrix t -> Int | ||
53 | rows = irows | ||
54 | {-# INLINE rows #-} | ||
55 | |||
56 | cols :: Matrix t -> Int | ||
57 | cols = icols | ||
58 | {-# INLINE cols #-} | ||
59 | |||
60 | size m = (irows m, icols m) | ||
61 | {-# INLINE size #-} | ||
62 | |||
63 | rowOrder m = xCol m == 1 || cols m == 1 | ||
64 | {-# INLINE rowOrder #-} | ||
65 | |||
66 | colOrder m = xRow m == 1 || rows m == 1 | ||
67 | {-# INLINE colOrder #-} | ||
68 | |||
69 | is1d (size->(r,c)) = r==1 || c==1 | ||
70 | {-# INLINE is1d #-} | ||
71 | |||
72 | -- data is not contiguous | ||
73 | isSlice m@(size->(r,c)) = r*c < dim (xdat m) | ||
74 | {-# INLINE isSlice #-} | ||
75 | |||
76 | orderOf :: Matrix t -> MatrixOrder | ||
77 | orderOf m = if rowOrder m then RowMajor else ColumnMajor | ||
78 | |||
79 | |||
80 | showInternal :: Storable t => Matrix t -> IO () | ||
81 | showInternal m = printf "%dx%d %s %s %d:%d (%d)\n" r c slc ord xr xc dv | ||
82 | where | ||
83 | r = rows m | ||
84 | c = cols m | ||
85 | xr = xRow m | ||
86 | xc = xCol m | ||
87 | slc = if isSlice m then "slice" else "full" | ||
88 | ord = if is1d m then "1d" else if rowOrder m then "rows" else "cols" | ||
89 | dv = dim (xdat m) | ||
90 | |||
91 | -------------------------------------------------------------------------------- | ||
92 | |||
93 | -- | Matrix transpose. | ||
94 | trans :: Matrix t -> Matrix t | ||
95 | trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } = | ||
96 | m { irows = c, icols = r, xRow = xc, xCol = xr } | ||
97 | |||
98 | |||
99 | cmat :: (Element t) => Matrix t -> Matrix t | ||
100 | cmat m | ||
101 | | rowOrder m = m | ||
102 | | otherwise = extractAll RowMajor m | ||
103 | |||
104 | |||
105 | fmat :: (Element t) => Matrix t -> Matrix t | ||
106 | fmat m | ||
107 | | colOrder m = m | ||
108 | | otherwise = extractAll ColumnMajor m | ||
109 | |||
110 | |||
111 | -- C-Haskell matrix adapters | ||
112 | {-# INLINE amatr #-} | ||
113 | amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b | ||
114 | amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) | ||
115 | where | ||
116 | r = fi (rows x) | ||
117 | c = fi (cols x) | ||
118 | |||
119 | {-# INLINE amat #-} | ||
120 | amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b | ||
121 | amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) | ||
122 | where | ||
123 | r = fi (rows x) | ||
124 | c = fi (cols x) | ||
125 | sr = fi (xRow x) | ||
126 | sc = fi (xCol x) | ||
127 | |||
128 | |||
129 | instance Storable t => TransArray (Matrix t) | ||
130 | where | ||
131 | type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b | ||
132 | type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b | ||
133 | apply = amat | ||
134 | {-# INLINE apply #-} | ||
135 | applyRaw = amatr | ||
136 | {-# INLINE applyRaw #-} | ||
137 | |||
138 | infixl 1 # | ||
139 | a # b = apply a b | ||
140 | {-# INLINE (#) #-} | ||
141 | |||
142 | -------------------------------------------------------------------------------- | ||
143 | |||
144 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) | ||
145 | |||
146 | extractAll ord m = unsafePerformIO (copy ord m) | ||
147 | |||
148 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. | ||
149 | |||
150 | >>> flatten (ident 3) | ||
151 | fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0] | ||
152 | |||
153 | -} | ||
154 | flatten :: Element t => Matrix t -> Vector t | ||
155 | flatten m | ||
156 | | isSlice m || not (rowOrder m) = xdat (extractAll RowMajor m) | ||
157 | | otherwise = xdat m | ||
158 | |||
159 | |||
160 | -- | the inverse of 'Data.Packed.Matrix.fromLists' | ||
161 | toLists :: (Element t) => Matrix t -> [[t]] | ||
162 | toLists = map toList . toRows | ||
163 | |||
164 | |||
165 | |||
166 | -- | common value with \"adaptable\" 1 | ||
167 | compatdim :: [Int] -> Maybe Int | ||
168 | compatdim [] = Nothing | ||
169 | compatdim [a] = Just a | ||
170 | compatdim (a:b:xs) | ||
171 | | a==b = compatdim (b:xs) | ||
172 | | a==1 = compatdim (b:xs) | ||
173 | | b==1 = compatdim (a:xs) | ||
174 | | otherwise = Nothing | ||
175 | |||
176 | |||
177 | |||
178 | |||
179 | -- | Create a matrix from a list of vectors. | ||
180 | -- All vectors must have the same dimension, | ||
181 | -- or dimension 1, which is are automatically expanded. | ||
182 | fromRows :: Element t => [Vector t] -> Matrix t | ||
183 | fromRows [] = emptyM 0 0 | ||
184 | fromRows vs = case compatdim (map dim vs) of | ||
185 | Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs) | ||
186 | Just 0 -> emptyM r 0 | ||
187 | Just c -> matrixFromVector RowMajor r c . vjoin . map (adapt c) $ vs | ||
188 | where | ||
189 | r = length vs | ||
190 | adapt c v | ||
191 | | c == 0 = fromList[] | ||
192 | | dim v == c = v | ||
193 | | otherwise = constantD (v@>0) c | ||
194 | |||
195 | -- | extracts the rows of a matrix as a list of vectors | ||
196 | toRows :: Element t => Matrix t -> [Vector t] | ||
197 | toRows m | ||
198 | | rowOrder m = map sub rowRange | ||
199 | | otherwise = map ext rowRange | ||
200 | where | ||
201 | rowRange = [0..rows m-1] | ||
202 | sub k = subVector (k*xRow m) (cols m) (xdat m) | ||
203 | ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) | ||
204 | |||
205 | |||
206 | -- | Creates a matrix from a list of vectors, as columns | ||
207 | fromColumns :: Element t => [Vector t] -> Matrix t | ||
208 | fromColumns m = trans . fromRows $ m | ||
209 | |||
210 | -- | Creates a list of vectors from the columns of a matrix | ||
211 | toColumns :: Element t => Matrix t -> [Vector t] | ||
212 | toColumns m = toRows . trans $ m | ||
213 | |||
214 | -- | Reads a matrix position. | ||
215 | (@@>) :: Storable t => Matrix t -> (Int,Int) -> t | ||
216 | infixl 9 @@> | ||
217 | m@Matrix {irows = r, icols = c} @@> (i,j) | ||
218 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | ||
219 | | otherwise = atM' m i j | ||
220 | {-# INLINE (@@>) #-} | ||
221 | |||
222 | -- Unsafe matrix access without range checking | ||
223 | atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) | ||
224 | {-# INLINE atM' #-} | ||
225 | |||
226 | ------------------------------------------------------------------ | ||
227 | |||
228 | matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } | ||
229 | matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } | ||
230 | matrixFromVector o r c v | ||
231 | | r * c == dim v = m | ||
232 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m | ||
233 | where | ||
234 | m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 } | ||
235 | | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r } | ||
236 | |||
237 | -- allocates memory for a new matrix | ||
238 | createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) | ||
239 | createMatrix ord r c = do | ||
240 | p <- createVector (r*c) | ||
241 | return (matrixFromVector ord r c p) | ||
242 | |||
243 | {- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = tr' . reshape r@ | ||
244 | where r is the desired number of rows.) | ||
245 | |||
246 | >>> reshape 4 (fromList [1..12]) | ||
247 | (3><4) | ||
248 | [ 1.0, 2.0, 3.0, 4.0 | ||
249 | , 5.0, 6.0, 7.0, 8.0 | ||
250 | , 9.0, 10.0, 11.0, 12.0 ] | ||
251 | |||
252 | -} | ||
253 | reshape :: Storable t => Int -> Vector t -> Matrix t | ||
254 | reshape 0 v = matrixFromVector RowMajor 0 0 v | ||
255 | reshape c v = matrixFromVector RowMajor (dim v `div` c) c v | ||
256 | |||
257 | |||
258 | -- | application of a vector function on the flattened matrix elements | ||
259 | liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | ||
260 | liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} | ||
261 | | isSlice m = matrixFromVector RowMajor r c (f (flatten m)) | ||
262 | | otherwise = matrixFromVector (orderOf m) r c (f d) | ||
263 | |||
264 | -- | application of a vector function on the flattened matrices elements | ||
265 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | ||
266 | liftMatrix2 f m1@(size->(r,c)) m2 | ||
267 | | (r,c)/=size m2 = error "nonconformant matrices in liftMatrix2" | ||
268 | | rowOrder m1 = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2)) | ||
269 | | otherwise = matrixFromVector ColumnMajor r c (f (flatten (trans m1)) (flatten (trans m2))) | ||
270 | |||
271 | ------------------------------------------------------------------ | ||
272 | |||
273 | -- | Supported matrix elements. | ||
274 | class (Storable a) => Element a where | ||
275 | constantD :: a -> Int -> Vector a | ||
276 | extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) | ||
277 | setRect :: Int -> Int -> Matrix a -> Matrix a -> IO () | ||
278 | sortI :: Ord a => Vector a -> Vector CInt | ||
279 | sortV :: Ord a => Vector a -> Vector a | ||
280 | compareV :: Ord a => Vector a -> Vector a -> Vector CInt | ||
281 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a | ||
282 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | ||
283 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | ||
284 | gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () | ||
285 | |||
286 | |||
287 | instance Element Float where | ||
288 | constantD = constantAux cconstantF | ||
289 | extractR = extractAux c_extractF | ||
290 | setRect = setRectAux c_setRectF | ||
291 | sortI = sortIdxF | ||
292 | sortV = sortValF | ||
293 | compareV = compareF | ||
294 | selectV = selectF | ||
295 | remapM = remapF | ||
296 | rowOp = rowOpAux c_rowOpF | ||
297 | gemm = gemmg c_gemmF | ||
298 | |||
299 | instance Element Double where | ||
300 | constantD = constantAux cconstantR | ||
301 | extractR = extractAux c_extractD | ||
302 | setRect = setRectAux c_setRectD | ||
303 | sortI = sortIdxD | ||
304 | sortV = sortValD | ||
305 | compareV = compareD | ||
306 | selectV = selectD | ||
307 | remapM = remapD | ||
308 | rowOp = rowOpAux c_rowOpD | ||
309 | gemm = gemmg c_gemmD | ||
310 | |||
311 | instance Element (Complex Float) where | ||
312 | constantD = constantAux cconstantQ | ||
313 | extractR = extractAux c_extractQ | ||
314 | setRect = setRectAux c_setRectQ | ||
315 | sortI = undefined | ||
316 | sortV = undefined | ||
317 | compareV = undefined | ||
318 | selectV = selectQ | ||
319 | remapM = remapQ | ||
320 | rowOp = rowOpAux c_rowOpQ | ||
321 | gemm = gemmg c_gemmQ | ||
322 | |||
323 | instance Element (Complex Double) where | ||
324 | constantD = constantAux cconstantC | ||
325 | extractR = extractAux c_extractC | ||
326 | setRect = setRectAux c_setRectC | ||
327 | sortI = undefined | ||
328 | sortV = undefined | ||
329 | compareV = undefined | ||
330 | selectV = selectC | ||
331 | remapM = remapC | ||
332 | rowOp = rowOpAux c_rowOpC | ||
333 | gemm = gemmg c_gemmC | ||
334 | |||
335 | instance Element (CInt) where | ||
336 | constantD = constantAux cconstantI | ||
337 | extractR = extractAux c_extractI | ||
338 | setRect = setRectAux c_setRectI | ||
339 | sortI = sortIdxI | ||
340 | sortV = sortValI | ||
341 | compareV = compareI | ||
342 | selectV = selectI | ||
343 | remapM = remapI | ||
344 | rowOp = rowOpAux c_rowOpI | ||
345 | gemm = gemmg c_gemmI | ||
346 | |||
347 | instance Element Z where | ||
348 | constantD = constantAux cconstantL | ||
349 | extractR = extractAux c_extractL | ||
350 | setRect = setRectAux c_setRectL | ||
351 | sortI = sortIdxL | ||
352 | sortV = sortValL | ||
353 | compareV = compareL | ||
354 | selectV = selectL | ||
355 | remapM = remapL | ||
356 | rowOp = rowOpAux c_rowOpL | ||
357 | gemm = gemmg c_gemmL | ||
358 | |||
359 | ------------------------------------------------------------------- | ||
360 | |||
361 | -- | reference to a rectangular slice of a matrix (no data copy) | ||
362 | subMatrix :: Element a | ||
363 | => (Int,Int) -- ^ (r0,c0) starting position | ||
364 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | ||
365 | -> Matrix a -- ^ input matrix | ||
366 | -> Matrix a -- ^ result | ||
367 | subMatrix (r0,c0) (rt,ct) m | ||
368 | | rt <= 0 || ct <= 0 = matrixFromVector RowMajor (max 0 rt) (max 0 ct) (fromList []) | ||
369 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && | ||
370 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = res | ||
371 | | otherwise = error $ "wrong subMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m | ||
372 | where | ||
373 | p = r0 * xRow m + c0 * xCol m | ||
374 | tot | rowOrder m = ct + (rt-1) * xRow m | ||
375 | | otherwise = rt + (ct-1) * xCol m | ||
376 | res = m { irows = rt, icols = ct, xdat = subVector p tot (xdat m) } | ||
377 | |||
378 | -------------------------------------------------------------------------- | ||
379 | |||
380 | maxZ xs = if minimum xs == 0 then 0 else maximum xs | ||
381 | |||
382 | conformMs ms = map (conformMTo (r,c)) ms | ||
383 | where | ||
384 | r = maxZ (map rows ms) | ||
385 | c = maxZ (map cols ms) | ||
386 | |||
387 | |||
388 | conformVs vs = map (conformVTo n) vs | ||
389 | where | ||
390 | n = maxZ (map dim vs) | ||
391 | |||
392 | conformMTo (r,c) m | ||
393 | | size m == (r,c) = m | ||
394 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) | ||
395 | | size m == (r,1) = repCols c m | ||
396 | | size m == (1,c) = repRows r m | ||
397 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) | ||
398 | |||
399 | conformVTo n v | ||
400 | | dim v == n = v | ||
401 | | dim v == 1 = constantD (v@>0) n | ||
402 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n | ||
403 | |||
404 | repRows n x = fromRows (replicate n (flatten x)) | ||
405 | repCols n x = fromColumns (replicate n (flatten x)) | ||
406 | |||
407 | shSize = shDim . size | ||
408 | |||
409 | shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" | ||
410 | |||
411 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) | ||
412 | |||
413 | ---------------------------------------------------------------------- | ||
414 | |||
415 | instance (Storable t, NFData t) => NFData (Matrix t) | ||
416 | where | ||
417 | rnf m | d > 0 = rnf (v @> 0) | ||
418 | | otherwise = () | ||
419 | where | ||
420 | d = dim v | ||
421 | v = xdat m | ||
422 | |||
423 | --------------------------------------------------------------- | ||
424 | |||
425 | extractAux f ord m moder vr modec vc = do | ||
426 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | ||
427 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | ||
428 | r <- createMatrix ord nr nc | ||
429 | f moder modec # vr # vc # m # r #|"extract" | ||
430 | return r | ||
431 | |||
432 | type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) | ||
433 | |||
434 | foreign import ccall unsafe "extractD" c_extractD :: Extr Double | ||
435 | foreign import ccall unsafe "extractF" c_extractF :: Extr Float | ||
436 | foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double) | ||
437 | foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float) | ||
438 | foreign import ccall unsafe "extractI" c_extractI :: Extr CInt | ||
439 | foreign import ccall unsafe "extractL" c_extractL :: Extr Z | ||
440 | |||
441 | --------------------------------------------------------------- | ||
442 | |||
443 | setRectAux f i j m r = f (fi i) (fi j) # m # r #|"setRect" | ||
444 | |||
445 | type SetRect x = I -> I -> x ::> x::> Ok | ||
446 | |||
447 | foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double | ||
448 | foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float | ||
449 | foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double) | ||
450 | foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float) | ||
451 | foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I | ||
452 | foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z | ||
453 | |||
454 | -------------------------------------------------------------------------------- | ||
455 | |||
456 | sortG f v = unsafePerformIO $ do | ||
457 | r <- createVector (dim v) | ||
458 | f # v # r #|"sortG" | ||
459 | return r | ||
460 | |||
461 | sortIdxD = sortG c_sort_indexD | ||
462 | sortIdxF = sortG c_sort_indexF | ||
463 | sortIdxI = sortG c_sort_indexI | ||
464 | sortIdxL = sortG c_sort_indexL | ||
465 | |||
466 | sortValD = sortG c_sort_valD | ||
467 | sortValF = sortG c_sort_valF | ||
468 | sortValI = sortG c_sort_valI | ||
469 | sortValL = sortG c_sort_valL | ||
470 | |||
471 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) | ||
472 | foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt)) | ||
473 | foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt)) | ||
474 | foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok | ||
475 | |||
476 | foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt)) | ||
477 | foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt)) | ||
478 | foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt)) | ||
479 | foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok | ||
480 | |||
481 | -------------------------------------------------------------------------------- | ||
482 | |||
483 | compareG f u v = unsafePerformIO $ do | ||
484 | r <- createVector (dim v) | ||
485 | f # u # v # r #|"compareG" | ||
486 | return r | ||
487 | |||
488 | compareD = compareG c_compareD | ||
489 | compareF = compareG c_compareF | ||
490 | compareI = compareG c_compareI | ||
491 | compareL = compareG c_compareL | ||
492 | |||
493 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) | ||
494 | foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt))) | ||
495 | foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt))) | ||
496 | foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok | ||
497 | |||
498 | -------------------------------------------------------------------------------- | ||
499 | |||
500 | selectG f c u v w = unsafePerformIO $ do | ||
501 | r <- createVector (dim v) | ||
502 | f # c # u # v # w # r #|"selectG" | ||
503 | return r | ||
504 | |||
505 | selectD = selectG c_selectD | ||
506 | selectF = selectG c_selectF | ||
507 | selectI = selectG c_selectI | ||
508 | selectL = selectG c_selectL | ||
509 | selectC = selectG c_selectC | ||
510 | selectQ = selectG c_selectQ | ||
511 | |||
512 | type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) | ||
513 | |||
514 | foreign import ccall unsafe "chooseD" c_selectD :: Sel Double | ||
515 | foreign import ccall unsafe "chooseF" c_selectF :: Sel Float | ||
516 | foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt | ||
517 | foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) | ||
518 | foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) | ||
519 | foreign import ccall unsafe "chooseL" c_selectL :: Sel Z | ||
520 | |||
521 | --------------------------------------------------------------------------- | ||
522 | |||
523 | remapG f i j m = unsafePerformIO $ do | ||
524 | r <- createMatrix RowMajor (rows i) (cols i) | ||
525 | f # i # j # m # r #|"remapG" | ||
526 | return r | ||
527 | |||
528 | remapD = remapG c_remapD | ||
529 | remapF = remapG c_remapF | ||
530 | remapI = remapG c_remapI | ||
531 | remapL = remapG c_remapL | ||
532 | remapC = remapG c_remapC | ||
533 | remapQ = remapG c_remapQ | ||
534 | |||
535 | type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) | ||
536 | |||
537 | foreign import ccall unsafe "remapD" c_remapD :: Rem Double | ||
538 | foreign import ccall unsafe "remapF" c_remapF :: Rem Float | ||
539 | foreign import ccall unsafe "remapI" c_remapI :: Rem CInt | ||
540 | foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) | ||
541 | foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) | ||
542 | foreign import ccall unsafe "remapL" c_remapL :: Rem Z | ||
543 | |||
544 | -------------------------------------------------------------------------------- | ||
545 | |||
546 | rowOpAux f c x i1 i2 j1 j2 m = do | ||
547 | px <- newArray [x] | ||
548 | f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2) # m #|"rowOp" | ||
549 | free px | ||
550 | |||
551 | type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok | ||
552 | |||
553 | foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R | ||
554 | foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float | ||
555 | foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C | ||
556 | foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float) | ||
557 | foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I | ||
558 | foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z | ||
559 | foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I | ||
560 | foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | ||
561 | |||
562 | -------------------------------------------------------------------------------- | ||
563 | |||
564 | gemmg f v m1 m2 m3 = f # v # m1 # m2 # m3 #|"gemmg" | ||
565 | |||
566 | type Tgemm x = x :> x ::> x ::> x ::> Ok | ||
567 | |||
568 | foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R | ||
569 | foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float | ||
570 | foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C | ||
571 | foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float) | ||
572 | foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I | ||
573 | foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z | ||
574 | foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I | ||
575 | foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z | ||
576 | |||
577 | -------------------------------------------------------------------------------- | ||
578 | |||
579 | foreign import ccall unsafe "saveMatrix" c_saveMatrix | ||
580 | :: CString -> CString -> Double ::> Ok | ||
581 | |||
582 | {- | save a matrix as a 2D ASCII table | ||
583 | -} | ||
584 | saveMatrix | ||
585 | :: FilePath | ||
586 | -> String -- ^ \"printf\" format (e.g. \"%.2f\", \"%g\", etc.) | ||
587 | -> Matrix Double | ||
588 | -> IO () | ||
589 | saveMatrix name format m = do | ||
590 | cname <- newCString name | ||
591 | cformat <- newCString format | ||
592 | c_saveMatrix cname cformat # m #|"saveMatrix" | ||
593 | free cname | ||
594 | free cformat | ||
595 | return () | ||
596 | |||
597 | -------------------------------------------------------------------------------- | ||
598 | |||
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs new file mode 100644 index 0000000..239c742 --- /dev/null +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -0,0 +1,469 @@ | |||
1 | {-# LANGUAGE DataKinds #-} | ||
2 | {-# LANGUAGE KindSignatures #-} | ||
3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
5 | {-# LANGUAGE FunctionalDependencies #-} | ||
6 | {-# LANGUAGE FlexibleContexts #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | {-# LANGUAGE Rank2Types #-} | ||
9 | {-# LANGUAGE FlexibleInstances #-} | ||
10 | {-# LANGUAGE UndecidableInstances #-} | ||
11 | {-# LANGUAGE GADTs #-} | ||
12 | {-# LANGUAGE TypeFamilies #-} | ||
13 | {-# LANGUAGE TypeOperators #-} | ||
14 | |||
15 | {- | | ||
16 | Module : Internal.Modular | ||
17 | Copyright : (c) Alberto Ruiz 2015 | ||
18 | License : BSD3 | ||
19 | Stability : experimental | ||
20 | |||
21 | Proof of concept of statically checked modular arithmetic. | ||
22 | |||
23 | -} | ||
24 | |||
25 | module Internal.Modular( | ||
26 | Mod, type (./.) | ||
27 | ) where | ||
28 | |||
29 | import Internal.Vector | ||
30 | import Internal.Matrix hiding (size) | ||
31 | import Internal.Numeric | ||
32 | import Internal.Element | ||
33 | import Internal.Container | ||
34 | import Internal.Vectorized (prodI,sumI,prodL,sumL) | ||
35 | import Internal.LAPACK (multiplyI, multiplyL) | ||
36 | import Internal.Algorithms(luFact,LU(..)) | ||
37 | import Internal.Util(Normed(..),Indexable(..), | ||
38 | gaussElim, gaussElim_1, gaussElim_2, | ||
39 | luST, luSolve', luPacked', magnit, invershur) | ||
40 | import Internal.ST(mutable) | ||
41 | import GHC.TypeLits | ||
42 | import Data.Proxy(Proxy) | ||
43 | import Foreign.ForeignPtr(castForeignPtr) | ||
44 | import Foreign.Storable | ||
45 | import Data.Ratio | ||
46 | import Data.Complex | ||
47 | import Control.DeepSeq ( NFData(..) ) | ||
48 | |||
49 | |||
50 | |||
51 | -- | Wrapper with a phantom integer for statically checked modular arithmetic. | ||
52 | newtype Mod (n :: Nat) t = Mod {unMod:: t} | ||
53 | deriving (Storable) | ||
54 | |||
55 | instance (NFData t) => NFData (Mod n t) | ||
56 | where | ||
57 | rnf (Mod x) = rnf x | ||
58 | |||
59 | infixr 5 ./. | ||
60 | type (./.) x n = Mod n x | ||
61 | |||
62 | instance (Integral t, Enum t, KnownNat m) => Enum (Mod m t) | ||
63 | where | ||
64 | toEnum = l0 (\m x -> fromIntegral $ x `mod` (fromIntegral m)) | ||
65 | fromEnum = fromIntegral . unMod | ||
66 | |||
67 | instance (Eq t, KnownNat m) => Eq (Mod m t) | ||
68 | where | ||
69 | a == b = (unMod a) == (unMod b) | ||
70 | |||
71 | instance (Ord t, KnownNat m) => Ord (Mod m t) | ||
72 | where | ||
73 | compare a b = compare (unMod a) (unMod b) | ||
74 | |||
75 | instance (Real t, KnownNat m, Integral (Mod m t)) => Real (Mod m t) | ||
76 | where | ||
77 | toRational x = toInteger x % 1 | ||
78 | |||
79 | instance (Integral t, KnownNat m, Num (Mod m t)) => Integral (Mod m t) | ||
80 | where | ||
81 | toInteger = toInteger . unMod | ||
82 | quotRem a b = (Mod q, Mod r) | ||
83 | where | ||
84 | (q,r) = quotRem (unMod a) (unMod b) | ||
85 | |||
86 | -- | this instance is only valid for prime m | ||
87 | instance (Show (Mod m t), Num (Mod m t), Eq t, KnownNat m) => Fractional (Mod m t) | ||
88 | where | ||
89 | recip x | ||
90 | | x*r == 1 = r | ||
91 | | otherwise = error $ show x ++" does not have a multiplicative inverse mod "++show m' | ||
92 | where | ||
93 | r = x^(m'-2 :: Integer) | ||
94 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
95 | fromRational x = fromInteger (numerator x) / fromInteger (denominator x) | ||
96 | |||
97 | l2 :: forall m a b c. (Num c, KnownNat m) => (c -> a -> b -> c) -> Mod m a -> Mod m b -> Mod m c | ||
98 | l2 f (Mod u) (Mod v) = Mod (f m' u v) | ||
99 | where | ||
100 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
101 | |||
102 | l1 :: forall m a b . (Num b, KnownNat m) => (b -> a -> b) -> Mod m a -> Mod m b | ||
103 | l1 f (Mod u) = Mod (f m' u) | ||
104 | where | ||
105 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
106 | |||
107 | l0 :: forall m a b . (Num b, KnownNat m) => (b -> a -> b) -> a -> Mod m b | ||
108 | l0 f u = Mod (f m' u) | ||
109 | where | ||
110 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
111 | |||
112 | |||
113 | instance Show t => Show (Mod n t) | ||
114 | where | ||
115 | show = show . unMod | ||
116 | |||
117 | instance forall n t . (Integral t, KnownNat n) => Num (Mod n t) | ||
118 | where | ||
119 | (+) = l2 (\m a b -> (a + b) `mod` (fromIntegral m)) | ||
120 | (*) = l2 (\m a b -> (a * b) `mod` (fromIntegral m)) | ||
121 | (-) = l2 (\m a b -> (a - b) `mod` (fromIntegral m)) | ||
122 | abs = l1 (const abs) | ||
123 | signum = l1 (const signum) | ||
124 | fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) | ||
125 | |||
126 | |||
127 | instance KnownNat m => Element (Mod m I) | ||
128 | where | ||
129 | constantD x n = i2f (constantD (unMod x) n) | ||
130 | extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js | ||
131 | setRect i j m x = setRect i j (f2iM m) (f2iM x) | ||
132 | sortI = sortI . f2i | ||
133 | sortV = i2f . sortV . f2i | ||
134 | compareV u v = compareV (f2i u) (f2i v) | ||
135 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) | ||
136 | remapM i j m = i2fM (remap i j (f2iM m)) | ||
137 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) | ||
138 | where | ||
139 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
140 | gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c) | ||
141 | where | ||
142 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
143 | |||
144 | instance KnownNat m => Element (Mod m Z) | ||
145 | where | ||
146 | constantD x n = i2f (constantD (unMod x) n) | ||
147 | extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js | ||
148 | setRect i j m x = setRect i j (f2iM m) (f2iM x) | ||
149 | sortI = sortI . f2i | ||
150 | sortV = i2f . sortV . f2i | ||
151 | compareV u v = compareV (f2i u) (f2i v) | ||
152 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) | ||
153 | remapM i j m = i2fM (remap i j (f2iM m)) | ||
154 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) | ||
155 | where | ||
156 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
157 | gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c) | ||
158 | where | ||
159 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
160 | |||
161 | |||
162 | instance forall m . KnownNat m => CTrans (Mod m I) | ||
163 | instance forall m . KnownNat m => CTrans (Mod m Z) | ||
164 | |||
165 | |||
166 | instance forall m . KnownNat m => Container Vector (Mod m I) | ||
167 | where | ||
168 | conj' = id | ||
169 | size' = dim | ||
170 | scale' s x = vmod (scale (unMod s) (f2i x)) | ||
171 | addConstant c x = vmod (addConstant (unMod c) (f2i x)) | ||
172 | add' a b = vmod (add' (f2i a) (f2i b)) | ||
173 | sub a b = vmod (sub (f2i a) (f2i b)) | ||
174 | mul a b = vmod (mul (f2i a) (f2i b)) | ||
175 | equal u v = equal (f2i u) (f2i v) | ||
176 | scalar' x = fromList [x] | ||
177 | konst' x = i2f . konst (unMod x) | ||
178 | build' n f = build n (fromIntegral . f) | ||
179 | cmap' = mapVector | ||
180 | atIndex' x k = fromIntegral (atIndex (f2i x) k) | ||
181 | minIndex' = minIndex . f2i | ||
182 | maxIndex' = maxIndex . f2i | ||
183 | minElement' = Mod . minElement . f2i | ||
184 | maxElement' = Mod . maxElement . f2i | ||
185 | sumElements' = fromIntegral . sumI m' . f2i | ||
186 | where | ||
187 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
188 | prodElements' = fromIntegral . prodI m' . f2i | ||
189 | where | ||
190 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
191 | step' = i2f . step . f2i | ||
192 | find' = findV | ||
193 | assoc' = assocV | ||
194 | accum' = accumV | ||
195 | ccompare' a b = ccompare (f2i a) (f2i b) | ||
196 | cselect' c l e g = i2f $ cselect c (f2i l) (f2i e) (f2i g) | ||
197 | scaleRecip s x = scale' s (cmap recip x) | ||
198 | divide x y = mul x (cmap recip y) | ||
199 | arctan2' = undefined | ||
200 | cmod' m = vmod . cmod' (unMod m) . f2i | ||
201 | fromInt' = vmod | ||
202 | toInt' = f2i | ||
203 | fromZ' = vmod . fromZ' | ||
204 | toZ' = toZ' . f2i | ||
205 | |||
206 | instance forall m . KnownNat m => Container Vector (Mod m Z) | ||
207 | where | ||
208 | conj' = id | ||
209 | size' = dim | ||
210 | scale' s x = vmod (scale (unMod s) (f2i x)) | ||
211 | addConstant c x = vmod (addConstant (unMod c) (f2i x)) | ||
212 | add' a b = vmod (add' (f2i a) (f2i b)) | ||
213 | sub a b = vmod (sub (f2i a) (f2i b)) | ||
214 | mul a b = vmod (mul (f2i a) (f2i b)) | ||
215 | equal u v = equal (f2i u) (f2i v) | ||
216 | scalar' x = fromList [x] | ||
217 | konst' x = i2f . konst (unMod x) | ||
218 | build' n f = build n (fromIntegral . f) | ||
219 | cmap' = mapVector | ||
220 | atIndex' x k = fromIntegral (atIndex (f2i x) k) | ||
221 | minIndex' = minIndex . f2i | ||
222 | maxIndex' = maxIndex . f2i | ||
223 | minElement' = Mod . minElement . f2i | ||
224 | maxElement' = Mod . maxElement . f2i | ||
225 | sumElements' = fromIntegral . sumL m' . f2i | ||
226 | where | ||
227 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
228 | prodElements' = fromIntegral . prodL m' . f2i | ||
229 | where | ||
230 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
231 | step' = i2f . step . f2i | ||
232 | find' = findV | ||
233 | assoc' = assocV | ||
234 | accum' = accumV | ||
235 | ccompare' a b = ccompare (f2i a) (f2i b) | ||
236 | cselect' c l e g = i2f $ cselect c (f2i l) (f2i e) (f2i g) | ||
237 | scaleRecip s x = scale' s (cmap recip x) | ||
238 | divide x y = mul x (cmap recip y) | ||
239 | arctan2' = undefined | ||
240 | cmod' m = vmod . cmod' (unMod m) . f2i | ||
241 | fromInt' = vmod . fromInt' | ||
242 | toInt' = toInt . f2i | ||
243 | fromZ' = vmod | ||
244 | toZ' = f2i | ||
245 | |||
246 | |||
247 | instance (Storable t, Indexable (Vector t) t) => Indexable (Vector (Mod m t)) (Mod m t) | ||
248 | where | ||
249 | (!) = (@>) | ||
250 | |||
251 | type instance RealOf (Mod n I) = I | ||
252 | type instance RealOf (Mod n Z) = Z | ||
253 | |||
254 | instance KnownNat m => Product (Mod m I) where | ||
255 | norm2 = undefined | ||
256 | absSum = undefined | ||
257 | norm1 = undefined | ||
258 | normInf = undefined | ||
259 | multiply = lift2m (multiplyI m') | ||
260 | where | ||
261 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
262 | |||
263 | instance KnownNat m => Product (Mod m Z) where | ||
264 | norm2 = undefined | ||
265 | absSum = undefined | ||
266 | norm1 = undefined | ||
267 | normInf = undefined | ||
268 | multiply = lift2m (multiplyL m') | ||
269 | where | ||
270 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
271 | |||
272 | instance KnownNat m => Normed (Vector (Mod m I)) | ||
273 | where | ||
274 | norm_0 = norm_0 . toInt | ||
275 | norm_1 = norm_1 . toInt | ||
276 | norm_2 = norm_2 . toInt | ||
277 | norm_Inf = norm_Inf . toInt | ||
278 | |||
279 | instance KnownNat m => Normed (Vector (Mod m Z)) | ||
280 | where | ||
281 | norm_0 = norm_0 . toZ | ||
282 | norm_1 = norm_1 . toZ | ||
283 | norm_2 = norm_2 . toZ | ||
284 | norm_Inf = norm_Inf . toZ | ||
285 | |||
286 | |||
287 | instance KnownNat m => Numeric (Mod m I) | ||
288 | instance KnownNat m => Numeric (Mod m Z) | ||
289 | |||
290 | i2f :: Storable t => Vector t -> Vector (Mod n t) | ||
291 | i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) | ||
292 | where (fp,i,n) = unsafeToForeignPtr v | ||
293 | |||
294 | f2i :: Storable t => Vector (Mod n t) -> Vector t | ||
295 | f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) | ||
296 | where (fp,i,n) = unsafeToForeignPtr v | ||
297 | |||
298 | f2iM :: (Element t, Element (Mod n t)) => Matrix (Mod n t) -> Matrix t | ||
299 | f2iM m = m { xdat = f2i (xdat m) } | ||
300 | |||
301 | i2fM :: (Element t, Element (Mod n t)) => Matrix t -> Matrix (Mod n t) | ||
302 | i2fM m = m { xdat = i2f (xdat m) } | ||
303 | |||
304 | vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t) | ||
305 | vmod = i2f . cmod' m' | ||
306 | where | ||
307 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
308 | |||
309 | lift1 f a = vmod (f (f2i a)) | ||
310 | lift2 f a b = vmod (f (f2i a) (f2i b)) | ||
311 | |||
312 | lift2m f a b = liftMatrix vmod (f (f2iM a) (f2iM b)) | ||
313 | |||
314 | instance forall m . KnownNat m => Num (Vector (Mod m I)) | ||
315 | where | ||
316 | (+) = lift2 (+) | ||
317 | (*) = lift2 (*) | ||
318 | (-) = lift2 (-) | ||
319 | abs = lift1 abs | ||
320 | signum = lift1 signum | ||
321 | negate = lift1 negate | ||
322 | fromInteger x = fromInt (fromInteger x) | ||
323 | |||
324 | instance forall m . KnownNat m => Num (Vector (Mod m Z)) | ||
325 | where | ||
326 | (+) = lift2 (+) | ||
327 | (*) = lift2 (*) | ||
328 | (-) = lift2 (-) | ||
329 | abs = lift1 abs | ||
330 | signum = lift1 signum | ||
331 | negate = lift1 negate | ||
332 | fromInteger x = fromZ (fromInteger x) | ||
333 | |||
334 | -------------------------------------------------------------------------------- | ||
335 | |||
336 | instance (KnownNat m) => Testable (Matrix (Mod m I)) | ||
337 | where | ||
338 | checkT _ = test | ||
339 | |||
340 | test = (ok, info) | ||
341 | where | ||
342 | v = fromList [3,-5,75] :: Vector (Mod 11 I) | ||
343 | m = (3><3) [1..] :: Matrix (Mod 11 I) | ||
344 | |||
345 | a = (3><3) [1,2 , 3 | ||
346 | ,4,5 , 6 | ||
347 | ,0,10,-3] :: Matrix I | ||
348 | |||
349 | b = (3><2) [0..] :: Matrix I | ||
350 | |||
351 | am = fromInt a :: Matrix (Mod 13 I) | ||
352 | bm = fromInt b :: Matrix (Mod 13 I) | ||
353 | ad = fromInt a :: Matrix Double | ||
354 | bd = fromInt b :: Matrix Double | ||
355 | |||
356 | g = (3><3) (repeat (40000)) :: Matrix I | ||
357 | gm = fromInt g :: Matrix (Mod 100000 I) | ||
358 | |||
359 | lg = (3><3) (repeat (3*10^(9::Int))) :: Matrix Z | ||
360 | lgm = fromZ lg :: Matrix (Mod 10000000000 Z) | ||
361 | |||
362 | gen n = diagRect 1 (konst 5 n) n n :: Numeric t => Matrix t | ||
363 | |||
364 | rgen n = gen n :: Matrix R | ||
365 | cgen n = complex (rgen n) + fliprl (complex (rgen n)) * scalar (0:+1) :: Matrix C | ||
366 | sgen n = single (cgen n) | ||
367 | |||
368 | checkGen x = norm_Inf $ flatten $ invg x <> x - ident (rows x) | ||
369 | |||
370 | invg t = gaussElim t (ident (rows t)) | ||
371 | |||
372 | checkLU okf t = norm_Inf $ flatten (l <> u <> p - t) | ||
373 | where | ||
374 | (l,u,p,_) = luFact (LU x' p') | ||
375 | where | ||
376 | (x',p') = mutable (luST okf) t | ||
377 | |||
378 | checkSolve aa = norm_Inf $ flatten (aa <> x - bb) | ||
379 | where | ||
380 | bb = flipud aa | ||
381 | x = luSolve' (luPacked' aa) bb | ||
382 | |||
383 | tmm = diagRect 1 (fromList [2..6]) 5 5 :: Matrix (Mod 19 I) | ||
384 | |||
385 | info = do | ||
386 | print v | ||
387 | print m | ||
388 | print (tr m) | ||
389 | print $ v+v | ||
390 | print $ m+m | ||
391 | print $ m <> m | ||
392 | print $ m #> v | ||
393 | |||
394 | print $ am <> gaussElim am bm - bm | ||
395 | print $ ad <> gaussElim ad bd - bd | ||
396 | |||
397 | print g | ||
398 | print $ g <> g | ||
399 | print gm | ||
400 | print $ gm <> gm | ||
401 | |||
402 | print lg | ||
403 | print $ lg <> lg | ||
404 | print lgm | ||
405 | print $ lgm <> lgm | ||
406 | |||
407 | putStrLn "checkGen" | ||
408 | print (checkGen (gen 5 :: Matrix R)) | ||
409 | print (checkGen (gen 5 :: Matrix Float)) | ||
410 | print (checkGen (cgen 5 :: Matrix C)) | ||
411 | print (checkGen (sgen 5 :: Matrix (Complex Float))) | ||
412 | print (invg (gen 5) :: Matrix (Mod 7 I)) | ||
413 | print (invg (gen 5) :: Matrix (Mod 7 Z)) | ||
414 | |||
415 | print $ mutable (luST (const True)) (gen 5 :: Matrix R) | ||
416 | print $ mutable (luST (const True)) (gen 5 :: Matrix (Mod 11 Z)) | ||
417 | |||
418 | putStrLn "checkLU" | ||
419 | print $ checkLU (magnit 0) (gen 5 :: Matrix R) | ||
420 | print $ checkLU (magnit 0) (gen 5 :: Matrix Float) | ||
421 | print $ checkLU (magnit 0) (cgen 5 :: Matrix C) | ||
422 | print $ checkLU (magnit 0) (sgen 5 :: Matrix (Complex Float)) | ||
423 | print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I)) | ||
424 | print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z)) | ||
425 | |||
426 | putStrLn "checkSolve" | ||
427 | print $ checkSolve (gen 5 :: Matrix R) | ||
428 | print $ checkSolve (gen 5 :: Matrix Float) | ||
429 | print $ checkSolve (cgen 5 :: Matrix C) | ||
430 | print $ checkSolve (sgen 5 :: Matrix (Complex Float)) | ||
431 | print $ checkSolve (gen 5 :: Matrix (Mod 7 I)) | ||
432 | print $ checkSolve (gen 5 :: Matrix (Mod 7 Z)) | ||
433 | |||
434 | putStrLn "luSolve'" | ||
435 | print $ luSolve' (luPacked' tmm) (ident (rows tmm)) | ||
436 | print $ invershur tmm | ||
437 | |||
438 | |||
439 | ok = and | ||
440 | [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) | ||
441 | , am <> gaussElim_1 am bm == bm | ||
442 | , am <> gaussElim_2 am bm == bm | ||
443 | , am <> gaussElim am bm == bm | ||
444 | , (checkGen (gen 5 :: Matrix R)) < 1E-15 | ||
445 | , (checkGen (gen 5 :: Matrix Float)) < 2E-7 | ||
446 | , (checkGen (cgen 5 :: Matrix C)) < 1E-15 | ||
447 | , (checkGen (sgen 5 :: Matrix (Complex Float))) < 3E-7 | ||
448 | , (checkGen (gen 5 :: Matrix (Mod 7 I))) == 0 | ||
449 | , (checkGen (gen 5 :: Matrix (Mod 7 Z))) == 0 | ||
450 | , (checkLU (magnit 1E-10) (gen 5 :: Matrix R)) < 2E-15 | ||
451 | , (checkLU (magnit 1E-5) (gen 5 :: Matrix Float)) < 1E-6 | ||
452 | , (checkLU (magnit 1E-10) (cgen 5 :: Matrix C)) < 5E-15 | ||
453 | , (checkLU (magnit 1E-5) (sgen 5 :: Matrix (Complex Float))) < 1E-6 | ||
454 | , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I))) == 0 | ||
455 | , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z))) == 0 | ||
456 | , checkSolve (gen 5 :: Matrix R) < 2E-15 | ||
457 | , checkSolve (gen 5 :: Matrix Float) < 1E-6 | ||
458 | , checkSolve (cgen 5 :: Matrix C) < 4E-15 | ||
459 | , checkSolve (sgen 5 :: Matrix (Complex Float)) < 1E-6 | ||
460 | , checkSolve (gen 5 :: Matrix (Mod 7 I)) == 0 | ||
461 | , checkSolve (gen 5 :: Matrix (Mod 7 Z)) == 0 | ||
462 | , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I)) | ||
463 | , gm <> gm == konst 0 (3,3) | ||
464 | , lgm <> lgm == konst 0 (3,3) | ||
465 | , invershur tmm == luSolve' (luPacked' tmm) (ident (rows tmm)) | ||
466 | , luSolve' (luPacked' (tr $ ident 5 :: Matrix (I ./. 2))) (ident 5) == ident 5 | ||
467 | ] | ||
468 | |||
469 | |||
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs index 257ad73..e8c7440 100644 --- a/packages/base/src/Data/Packed/Internal/Numeric.hs +++ b/packages/base/src/Internal/Numeric.hs | |||
@@ -16,44 +16,18 @@ | |||
16 | -- | 16 | -- |
17 | ----------------------------------------------------------------------------- | 17 | ----------------------------------------------------------------------------- |
18 | 18 | ||
19 | module Data.Packed.Internal.Numeric ( | 19 | module Internal.Numeric where |
20 | -- * Basic functions | ||
21 | ident, diag, ctrans, | ||
22 | -- * Generic operations | ||
23 | Container(..), | ||
24 | scalar, conj, scale, arctan2, cmap, | ||
25 | atIndex, minIndex, maxIndex, minElement, maxElement, | ||
26 | sumElements, prodElements, | ||
27 | step, cond, find, assoc, accum, | ||
28 | Transposable(..), Linear(..), Testable(..), | ||
29 | -- * Matrix product and related functions | ||
30 | Product(..), udot, | ||
31 | mXm,mXv,vXm, | ||
32 | outer, kronecker, | ||
33 | -- * sorting | ||
34 | sortVector, | ||
35 | -- * Element conversion | ||
36 | Convert(..), | ||
37 | Complexable(), | ||
38 | RealElement(), | ||
39 | roundVector, | ||
40 | RealOf, ComplexOf, SingleOf, DoubleOf, | ||
41 | IndexOf, | ||
42 | module Data.Complex | ||
43 | ) where | ||
44 | |||
45 | import Data.Packed | ||
46 | import Data.Packed.ST as ST | ||
47 | import Numeric.Conversion | ||
48 | import Data.Packed.Development | ||
49 | import Numeric.Vectorized | ||
50 | import Data.Complex | ||
51 | import Control.Applicative((<*>)) | ||
52 | |||
53 | import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ) | ||
54 | import Data.Packed.Internal | ||
55 | 20 | ||
56 | ------------------------------------------------------------------- | 21 | import Internal.Vector |
22 | import Internal.Matrix | ||
23 | import Internal.Element | ||
24 | import Internal.ST as ST | ||
25 | import Internal.Conversion | ||
26 | import Internal.Vectorized | ||
27 | import Internal.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ,multiplyI,multiplyL) | ||
28 | import Data.List.Split(chunksOf) | ||
29 | |||
30 | -------------------------------------------------------------------------------- | ||
57 | 31 | ||
58 | type family IndexOf (c :: * -> *) | 32 | type family IndexOf (c :: * -> *) |
59 | 33 | ||
@@ -65,30 +39,21 @@ type family ArgOf (c :: * -> *) a | |||
65 | type instance ArgOf Vector a = a -> a | 39 | type instance ArgOf Vector a = a -> a |
66 | type instance ArgOf Matrix a = a -> a -> a | 40 | type instance ArgOf Matrix a = a -> a -> a |
67 | 41 | ||
68 | ------------------------------------------------------------------- | 42 | -------------------------------------------------------------------------------- |
69 | 43 | ||
70 | -- | Basic element-by-element functions for numeric containers | 44 | -- | Basic element-by-element functions for numeric containers |
71 | class (Complexable c, Fractional e, Element e) => Container c e | 45 | class Element e => Container c e |
72 | where | 46 | where |
47 | conj' :: c e -> c e | ||
73 | size' :: c e -> IndexOf c | 48 | size' :: c e -> IndexOf c |
74 | scalar' :: e -> c e | 49 | scalar' :: e -> c e |
75 | conj' :: c e -> c e | ||
76 | scale' :: e -> c e -> c e | 50 | scale' :: e -> c e -> c e |
77 | -- | scale the element by element reciprocal of the object: | ||
78 | -- | ||
79 | -- @scaleRecip 2 (fromList [5,i]) == 2 |> [0.4 :+ 0.0,0.0 :+ (-2.0)]@ | ||
80 | scaleRecip :: e -> c e -> c e | ||
81 | addConstant :: e -> c e -> c e | 51 | addConstant :: e -> c e -> c e |
82 | add :: c e -> c e -> c e | 52 | add' :: c e -> c e -> c e |
83 | sub :: c e -> c e -> c e | 53 | sub :: c e -> c e -> c e |
84 | -- | element by element multiplication | 54 | -- | element by element multiplication |
85 | mul :: c e -> c e -> c e | 55 | mul :: c e -> c e -> c e |
86 | -- | element by element division | ||
87 | divide :: c e -> c e -> c e | ||
88 | equal :: c e -> c e -> Bool | 56 | equal :: c e -> c e -> Bool |
89 | -- | ||
90 | -- element by element inverse tangent | ||
91 | arctan2' :: c e -> c e -> c e | ||
92 | cmap' :: (Element b) => (e -> b) -> c e -> c b | 57 | cmap' :: (Element b) => (e -> b) -> c e -> c b |
93 | konst' :: e -> IndexOf c -> c e | 58 | konst' :: e -> IndexOf c -> c e |
94 | build' :: IndexOf c -> (ArgOf c e) -> c e | 59 | build' :: IndexOf c -> (ArgOf c e) -> c e |
@@ -99,14 +64,9 @@ class (Complexable c, Fractional e, Element e) => Container c e | |||
99 | maxElement' :: c e -> e | 64 | maxElement' :: c e -> e |
100 | sumElements' :: c e -> e | 65 | sumElements' :: c e -> e |
101 | prodElements' :: c e -> e | 66 | prodElements' :: c e -> e |
102 | step' :: RealElement e => c e -> c e | 67 | step' :: Ord e => c e -> c e |
103 | cond' :: RealElement e | 68 | ccompare' :: Ord e => c e -> c e -> c I |
104 | => c e -- ^ a | 69 | cselect' :: c I -> c e -> c e -> c e -> c e |
105 | -> c e -- ^ b | ||
106 | -> c e -- ^ l | ||
107 | -> c e -- ^ e | ||
108 | -> c e -- ^ g | ||
109 | -> c e -- ^ result | ||
110 | find' :: (e -> Bool) -> c e -> [IndexOf c] | 70 | find' :: (e -> Bool) -> c e -> [IndexOf c] |
111 | assoc' :: IndexOf c -- ^ size | 71 | assoc' :: IndexOf c -- ^ size |
112 | -> e -- ^ default value | 72 | -> e -- ^ default value |
@@ -117,24 +77,115 @@ class (Complexable c, Fractional e, Element e) => Container c e | |||
117 | -> [(IndexOf c, e)] -- ^ association list | 77 | -> [(IndexOf c, e)] -- ^ association list |
118 | -> c e -- ^ result | 78 | -> c e -- ^ result |
119 | 79 | ||
80 | -- | scale the element by element reciprocal of the object: | ||
81 | -- | ||
82 | -- @scaleRecip 2 (fromList [5,i]) == 2 |> [0.4 :+ 0.0,0.0 :+ (-2.0)]@ | ||
83 | scaleRecip :: Fractional e => e -> c e -> c e | ||
84 | -- | element by element division | ||
85 | divide :: Fractional e => c e -> c e -> c e | ||
86 | -- | ||
87 | -- element by element inverse tangent | ||
88 | arctan2' :: Fractional e => c e -> c e -> c e | ||
89 | cmod' :: Integral e => e -> c e -> c e | ||
90 | fromInt' :: c I -> c e | ||
91 | toInt' :: c e -> c I | ||
92 | fromZ' :: c Z -> c e | ||
93 | toZ' :: c e -> c Z | ||
94 | |||
120 | -------------------------------------------------------------------------- | 95 | -------------------------------------------------------------------------- |
121 | 96 | ||
97 | instance Container Vector I | ||
98 | where | ||
99 | conj' = id | ||
100 | size' = dim | ||
101 | scale' = vectorMapValI Scale | ||
102 | addConstant = vectorMapValI AddConstant | ||
103 | add' = vectorZipI Add | ||
104 | sub = vectorZipI Sub | ||
105 | mul = vectorZipI Mul | ||
106 | equal u v = dim u == dim v && maxElement' (vectorMapI Abs (sub u v)) == 0 | ||
107 | scalar' x = fromList [x] | ||
108 | konst' = constantD | ||
109 | build' = buildV | ||
110 | cmap' = mapVector | ||
111 | atIndex' = (@>) | ||
112 | minIndex' = emptyErrorV "minIndex" (fromIntegral . toScalarI MinIdx) | ||
113 | maxIndex' = emptyErrorV "maxIndex" (fromIntegral . toScalarI MaxIdx) | ||
114 | minElement' = emptyErrorV "minElement" (toScalarI Min) | ||
115 | maxElement' = emptyErrorV "maxElement" (toScalarI Max) | ||
116 | sumElements' = sumI 1 | ||
117 | prodElements' = prodI 1 | ||
118 | step' = stepI | ||
119 | find' = findV | ||
120 | assoc' = assocV | ||
121 | accum' = accumV | ||
122 | ccompare' = compareCV compareV | ||
123 | cselect' = selectCV selectV | ||
124 | scaleRecip = undefined -- cannot match | ||
125 | divide = undefined | ||
126 | arctan2' = undefined | ||
127 | cmod' m x | ||
128 | | m /= 0 = vectorMapValI ModVS m x | ||
129 | | otherwise = error $ "cmod 0 on vector of size "++(show $ dim x) | ||
130 | fromInt' = id | ||
131 | toInt' = id | ||
132 | fromZ' = long2intV | ||
133 | toZ' = int2longV | ||
134 | |||
135 | |||
136 | instance Container Vector Z | ||
137 | where | ||
138 | conj' = id | ||
139 | size' = dim | ||
140 | scale' = vectorMapValL Scale | ||
141 | addConstant = vectorMapValL AddConstant | ||
142 | add' = vectorZipL Add | ||
143 | sub = vectorZipL Sub | ||
144 | mul = vectorZipL Mul | ||
145 | equal u v = dim u == dim v && maxElement' (vectorMapL Abs (sub u v)) == 0 | ||
146 | scalar' x = fromList [x] | ||
147 | konst' = constantD | ||
148 | build' = buildV | ||
149 | cmap' = mapVector | ||
150 | atIndex' = (@>) | ||
151 | minIndex' = emptyErrorV "minIndex" (fromIntegral . toScalarL MinIdx) | ||
152 | maxIndex' = emptyErrorV "maxIndex" (fromIntegral . toScalarL MaxIdx) | ||
153 | minElement' = emptyErrorV "minElement" (toScalarL Min) | ||
154 | maxElement' = emptyErrorV "maxElement" (toScalarL Max) | ||
155 | sumElements' = sumL 1 | ||
156 | prodElements' = prodL 1 | ||
157 | step' = stepL | ||
158 | find' = findV | ||
159 | assoc' = assocV | ||
160 | accum' = accumV | ||
161 | ccompare' = compareCV compareV | ||
162 | cselect' = selectCV selectV | ||
163 | scaleRecip = undefined -- cannot match | ||
164 | divide = undefined | ||
165 | arctan2' = undefined | ||
166 | cmod' m x | ||
167 | | m /= 0 = vectorMapValL ModVS m x | ||
168 | | otherwise = error $ "cmod 0 on vector of size "++(show $ dim x) | ||
169 | fromInt' = int2longV | ||
170 | toInt' = long2intV | ||
171 | fromZ' = id | ||
172 | toZ' = id | ||
173 | |||
174 | |||
175 | |||
122 | instance Container Vector Float | 176 | instance Container Vector Float |
123 | where | 177 | where |
178 | conj' = id | ||
124 | size' = dim | 179 | size' = dim |
125 | scale' = vectorMapValF Scale | 180 | scale' = vectorMapValF Scale |
126 | scaleRecip = vectorMapValF Recip | ||
127 | addConstant = vectorMapValF AddConstant | 181 | addConstant = vectorMapValF AddConstant |
128 | add = vectorZipF Add | 182 | add' = vectorZipF Add |
129 | sub = vectorZipF Sub | 183 | sub = vectorZipF Sub |
130 | mul = vectorZipF Mul | 184 | mul = vectorZipF Mul |
131 | divide = vectorZipF Div | ||
132 | equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0 | 185 | equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0 |
133 | arctan2' = vectorZipF ATan2 | ||
134 | scalar' x = fromList [x] | 186 | scalar' x = fromList [x] |
135 | konst' = constantD | 187 | konst' = constantD |
136 | build' = buildV | 188 | build' = buildV |
137 | conj' = id | ||
138 | cmap' = mapVector | 189 | cmap' = mapVector |
139 | atIndex' = (@>) | 190 | atIndex' = (@>) |
140 | minIndex' = emptyErrorV "minIndex" (round . toScalarF MinIdx) | 191 | minIndex' = emptyErrorV "minIndex" (round . toScalarF MinIdx) |
@@ -147,24 +198,31 @@ instance Container Vector Float | |||
147 | find' = findV | 198 | find' = findV |
148 | assoc' = assocV | 199 | assoc' = assocV |
149 | accum' = accumV | 200 | accum' = accumV |
150 | cond' = condV condF | 201 | ccompare' = compareCV compareV |
202 | cselect' = selectCV selectV | ||
203 | scaleRecip = vectorMapValF Recip | ||
204 | divide = vectorZipF Div | ||
205 | arctan2' = vectorZipF ATan2 | ||
206 | cmod' = undefined | ||
207 | fromInt' = int2floatV | ||
208 | toInt' = float2IntV | ||
209 | fromZ' = (single :: Vector R-> Vector Float) . fromZ' | ||
210 | toZ' = toZ' . double | ||
211 | |||
151 | 212 | ||
152 | instance Container Vector Double | 213 | instance Container Vector Double |
153 | where | 214 | where |
215 | conj' = id | ||
154 | size' = dim | 216 | size' = dim |
155 | scale' = vectorMapValR Scale | 217 | scale' = vectorMapValR Scale |
156 | scaleRecip = vectorMapValR Recip | ||
157 | addConstant = vectorMapValR AddConstant | 218 | addConstant = vectorMapValR AddConstant |
158 | add = vectorZipR Add | 219 | add' = vectorZipR Add |
159 | sub = vectorZipR Sub | 220 | sub = vectorZipR Sub |
160 | mul = vectorZipR Mul | 221 | mul = vectorZipR Mul |
161 | divide = vectorZipR Div | ||
162 | equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0 | 222 | equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0 |
163 | arctan2' = vectorZipR ATan2 | ||
164 | scalar' x = fromList [x] | 223 | scalar' x = fromList [x] |
165 | konst' = constantD | 224 | konst' = constantD |
166 | build' = buildV | 225 | build' = buildV |
167 | conj' = id | ||
168 | cmap' = mapVector | 226 | cmap' = mapVector |
169 | atIndex' = (@>) | 227 | atIndex' = (@>) |
170 | minIndex' = emptyErrorV "minIndex" (round . toScalarR MinIdx) | 228 | minIndex' = emptyErrorV "minIndex" (round . toScalarR MinIdx) |
@@ -177,24 +235,31 @@ instance Container Vector Double | |||
177 | find' = findV | 235 | find' = findV |
178 | assoc' = assocV | 236 | assoc' = assocV |
179 | accum' = accumV | 237 | accum' = accumV |
180 | cond' = condV condD | 238 | ccompare' = compareCV compareV |
239 | cselect' = selectCV selectV | ||
240 | scaleRecip = vectorMapValR Recip | ||
241 | divide = vectorZipR Div | ||
242 | arctan2' = vectorZipR ATan2 | ||
243 | cmod' = undefined | ||
244 | fromInt' = int2DoubleV | ||
245 | toInt' = double2IntV | ||
246 | fromZ' = long2DoubleV | ||
247 | toZ' = double2longV | ||
248 | |||
181 | 249 | ||
182 | instance Container Vector (Complex Double) | 250 | instance Container Vector (Complex Double) |
183 | where | 251 | where |
252 | conj' = conjugateC | ||
184 | size' = dim | 253 | size' = dim |
185 | scale' = vectorMapValC Scale | 254 | scale' = vectorMapValC Scale |
186 | scaleRecip = vectorMapValC Recip | ||
187 | addConstant = vectorMapValC AddConstant | 255 | addConstant = vectorMapValC AddConstant |
188 | add = vectorZipC Add | 256 | add' = vectorZipC Add |
189 | sub = vectorZipC Sub | 257 | sub = vectorZipC Sub |
190 | mul = vectorZipC Mul | 258 | mul = vectorZipC Mul |
191 | divide = vectorZipC Div | ||
192 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 | 259 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 |
193 | arctan2' = vectorZipC ATan2 | ||
194 | scalar' x = fromList [x] | 260 | scalar' x = fromList [x] |
195 | konst' = constantD | 261 | konst' = constantD |
196 | build' = buildV | 262 | build' = buildV |
197 | conj' = conjugateC | ||
198 | cmap' = mapVector | 263 | cmap' = mapVector |
199 | atIndex' = (@>) | 264 | atIndex' = (@>) |
200 | minIndex' = emptyErrorV "minIndex" (minIndex' . fst . fromComplex . (mul <*> conj')) | 265 | minIndex' = emptyErrorV "minIndex" (minIndex' . fst . fromComplex . (mul <*> conj')) |
@@ -207,24 +272,30 @@ instance Container Vector (Complex Double) | |||
207 | find' = findV | 272 | find' = findV |
208 | assoc' = assocV | 273 | assoc' = assocV |
209 | accum' = accumV | 274 | accum' = accumV |
210 | cond' = undefined -- cannot match | 275 | ccompare' = undefined -- cannot match |
276 | cselect' = selectCV selectV | ||
277 | scaleRecip = vectorMapValC Recip | ||
278 | divide = vectorZipC Div | ||
279 | arctan2' = vectorZipC ATan2 | ||
280 | cmod' = undefined | ||
281 | fromInt' = complex . int2DoubleV | ||
282 | toInt' = toInt' . fst . fromComplex | ||
283 | fromZ' = complex . long2DoubleV | ||
284 | toZ' = toZ' . fst . fromComplex | ||
211 | 285 | ||
212 | instance Container Vector (Complex Float) | 286 | instance Container Vector (Complex Float) |
213 | where | 287 | where |
288 | conj' = conjugateQ | ||
214 | size' = dim | 289 | size' = dim |
215 | scale' = vectorMapValQ Scale | 290 | scale' = vectorMapValQ Scale |
216 | scaleRecip = vectorMapValQ Recip | ||
217 | addConstant = vectorMapValQ AddConstant | 291 | addConstant = vectorMapValQ AddConstant |
218 | add = vectorZipQ Add | 292 | add' = vectorZipQ Add |
219 | sub = vectorZipQ Sub | 293 | sub = vectorZipQ Sub |
220 | mul = vectorZipQ Mul | 294 | mul = vectorZipQ Mul |
221 | divide = vectorZipQ Div | ||
222 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 | 295 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 |
223 | arctan2' = vectorZipQ ATan2 | ||
224 | scalar' x = fromList [x] | 296 | scalar' x = fromList [x] |
225 | konst' = constantD | 297 | konst' = constantD |
226 | build' = buildV | 298 | build' = buildV |
227 | conj' = conjugateQ | ||
228 | cmap' = mapVector | 299 | cmap' = mapVector |
229 | atIndex' = (@>) | 300 | atIndex' = (@>) |
230 | minIndex' = emptyErrorV "minIndex" (minIndex' . fst . fromComplex . (mul <*> conj')) | 301 | minIndex' = emptyErrorV "minIndex" (minIndex' . fst . fromComplex . (mul <*> conj')) |
@@ -237,26 +308,32 @@ instance Container Vector (Complex Float) | |||
237 | find' = findV | 308 | find' = findV |
238 | assoc' = assocV | 309 | assoc' = assocV |
239 | accum' = accumV | 310 | accum' = accumV |
240 | cond' = undefined -- cannot match | 311 | ccompare' = undefined -- cannot match |
312 | cselect' = selectCV selectV | ||
313 | scaleRecip = vectorMapValQ Recip | ||
314 | divide = vectorZipQ Div | ||
315 | arctan2' = vectorZipQ ATan2 | ||
316 | cmod' = undefined | ||
317 | fromInt' = complex . int2floatV | ||
318 | toInt' = toInt' . fst . fromComplex | ||
319 | fromZ' = complex . single . long2DoubleV | ||
320 | toZ' = toZ' . double . fst . fromComplex | ||
241 | 321 | ||
242 | --------------------------------------------------------------- | 322 | --------------------------------------------------------------- |
243 | 323 | ||
244 | instance (Fractional a, Element a, Container Vector a) => Container Matrix a | 324 | instance (Num a, Element a, Container Vector a) => Container Matrix a |
245 | where | 325 | where |
326 | conj' = liftMatrix conj' | ||
246 | size' = size | 327 | size' = size |
247 | scale' x = liftMatrix (scale' x) | 328 | scale' x = liftMatrix (scale' x) |
248 | scaleRecip x = liftMatrix (scaleRecip x) | ||
249 | addConstant x = liftMatrix (addConstant x) | 329 | addConstant x = liftMatrix (addConstant x) |
250 | add = liftMatrix2 add | 330 | add' = liftMatrix2 add' |
251 | sub = liftMatrix2 sub | 331 | sub = liftMatrix2 sub |
252 | mul = liftMatrix2 mul | 332 | mul = liftMatrix2 mul |
253 | divide = liftMatrix2 divide | ||
254 | equal a b = cols a == cols b && flatten a `equal` flatten b | 333 | equal a b = cols a == cols b && flatten a `equal` flatten b |
255 | arctan2' = liftMatrix2 arctan2' | ||
256 | scalar' x = (1><1) [x] | 334 | scalar' x = (1><1) [x] |
257 | konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c)) | 335 | konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c)) |
258 | build' = buildM | 336 | build' = buildM |
259 | conj' = liftMatrix conj' | ||
260 | cmap' f = liftMatrix (mapVector f) | 337 | cmap' f = liftMatrix (mapVector f) |
261 | atIndex' = (@@>) | 338 | atIndex' = (@@>) |
262 | minIndex' = emptyErrorM "minIndex of Matrix" $ | 339 | minIndex' = emptyErrorM "minIndex of Matrix" $ |
@@ -265,19 +342,30 @@ instance (Fractional a, Element a, Container Vector a) => Container Matrix a | |||
265 | \m -> divMod (maxIndex' $ flatten m) (cols m) | 342 | \m -> divMod (maxIndex' $ flatten m) (cols m) |
266 | minElement' = emptyErrorM "minElement of Matrix" (atIndex' <*> minIndex') | 343 | minElement' = emptyErrorM "minElement of Matrix" (atIndex' <*> minIndex') |
267 | maxElement' = emptyErrorM "maxElement of Matrix" (atIndex' <*> maxIndex') | 344 | maxElement' = emptyErrorM "maxElement of Matrix" (atIndex' <*> maxIndex') |
268 | sumElements' = sumElements . flatten | 345 | sumElements' = sumElements' . flatten |
269 | prodElements' = prodElements . flatten | 346 | prodElements' = prodElements' . flatten |
270 | step' = liftMatrix step | 347 | step' = liftMatrix step' |
271 | find' = findM | 348 | find' = findM |
272 | assoc' = assocM | 349 | assoc' = assocM |
273 | accum' = accumM | 350 | accum' = accumM |
274 | cond' = condM | 351 | ccompare' = compareM |
352 | cselect' = selectM | ||
353 | scaleRecip x = liftMatrix (scaleRecip x) | ||
354 | divide = liftMatrix2 divide | ||
355 | arctan2' = liftMatrix2 arctan2' | ||
356 | cmod' m x | ||
357 | | m /= 0 = liftMatrix (cmod' m) x | ||
358 | | otherwise = error $ "cmod 0 on matrix "++shSize x | ||
359 | fromInt' = liftMatrix fromInt' | ||
360 | toInt' = liftMatrix toInt' | ||
361 | fromZ' = liftMatrix fromZ' | ||
362 | toZ' = liftMatrix toZ' | ||
275 | 363 | ||
276 | 364 | ||
277 | emptyErrorV msg f v = | 365 | emptyErrorV msg f v = |
278 | if dim v > 0 | 366 | if dim v > 0 |
279 | then f v | 367 | then f v |
280 | else error $ msg ++ " of Vector with dim = 0" | 368 | else error $ msg ++ " of empty Vector" |
281 | 369 | ||
282 | emptyErrorM msg f m = | 370 | emptyErrorM msg f m = |
283 | if rows m > 0 && cols m > 0 | 371 | if rows m > 0 && cols m > 0 |
@@ -299,18 +387,47 @@ scalar = scalar' | |||
299 | conj :: Container c e => c e -> c e | 387 | conj :: Container c e => c e -> c e |
300 | conj = conj' | 388 | conj = conj' |
301 | 389 | ||
302 | -- | multiplication by scalar | ||
303 | scale :: Container c e => e -> c e -> c e | ||
304 | scale = scale' | ||
305 | 390 | ||
306 | arctan2 :: Container c e => c e -> c e -> c e | 391 | arctan2 :: (Fractional e, Container c e) => c e -> c e -> c e |
307 | arctan2 = arctan2' | 392 | arctan2 = arctan2' |
308 | 393 | ||
394 | -- | 'mod' for integer arrays | ||
395 | -- | ||
396 | -- >>> cmod 3 (range 5) | ||
397 | -- fromList [0,1,2,0,1] | ||
398 | cmod :: (Integral e, Container c e) => e -> c e -> c e | ||
399 | cmod = cmod' | ||
400 | |||
401 | -- | | ||
402 | -- >>>fromInt ((2><2) [0..3]) :: Matrix (Complex Double) | ||
403 | -- (2><2) | ||
404 | -- [ 0.0 :+ 0.0, 1.0 :+ 0.0 | ||
405 | -- , 2.0 :+ 0.0, 3.0 :+ 0.0 ] | ||
406 | -- | ||
407 | fromInt :: (Container c e) => c I -> c e | ||
408 | fromInt = fromInt' | ||
409 | |||
410 | toInt :: (Container c e) => c e -> c I | ||
411 | toInt = toInt' | ||
412 | |||
413 | fromZ :: (Container c e) => c Z -> c e | ||
414 | fromZ = fromZ' | ||
415 | |||
416 | toZ :: (Container c e) => c e -> c Z | ||
417 | toZ = toZ' | ||
418 | |||
309 | -- | like 'fmap' (cannot implement instance Functor because of Element class constraint) | 419 | -- | like 'fmap' (cannot implement instance Functor because of Element class constraint) |
310 | cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b | 420 | cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b |
311 | cmap = cmap' | 421 | cmap = cmap' |
312 | 422 | ||
313 | -- | indexing function | 423 | -- | generic indexing function |
424 | -- | ||
425 | -- >>> vector [1,2,3] `atIndex` 1 | ||
426 | -- 2.0 | ||
427 | -- | ||
428 | -- >>> matrix 3 [0..8] `atIndex` (2,0) | ||
429 | -- 6.0 | ||
430 | -- | ||
314 | atIndex :: Container c e => c e -> IndexOf c -> e | 431 | atIndex :: Container c e => c e -> IndexOf c -> e |
315 | atIndex = atIndex' | 432 | atIndex = atIndex' |
316 | 433 | ||
@@ -345,7 +462,7 @@ prodElements = prodElements' | |||
345 | -- 5 |> [0.0,0.0,0.0,1.0,1.0] | 462 | -- 5 |> [0.0,0.0,0.0,1.0,1.0] |
346 | -- | 463 | -- |
347 | step | 464 | step |
348 | :: (RealElement e, Container c e) | 465 | :: (Ord e, Container c e) |
349 | => c e | 466 | => c e |
350 | -> c e | 467 | -> c e |
351 | step = step' | 468 | step = step' |
@@ -361,15 +478,17 @@ step = step' | |||
361 | -- , 0.0, 100.0, 7.0, 8.0 | 478 | -- , 0.0, 100.0, 7.0, 8.0 |
362 | -- , 0.0, 0.0, 100.0, 12.0 ] | 479 | -- , 0.0, 0.0, 100.0, 12.0 ] |
363 | -- | 480 | -- |
481 | -- >>> let chop x = cond (abs x) 1E-6 0 0 x | ||
482 | -- | ||
364 | cond | 483 | cond |
365 | :: (RealElement e, Container c e) | 484 | :: (Ord e, Container c e, Container c x) |
366 | => c e -- ^ a | 485 | => c e -- ^ a |
367 | -> c e -- ^ b | 486 | -> c e -- ^ b |
368 | -> c e -- ^ l | 487 | -> c x -- ^ l |
369 | -> c e -- ^ e | 488 | -> c x -- ^ e |
370 | -> c e -- ^ g | 489 | -> c x -- ^ g |
371 | -> c e -- ^ result | 490 | -> c x -- ^ result |
372 | cond = cond' | 491 | cond a b l e g = cselect' (ccompare' a b) l e g |
373 | 492 | ||
374 | 493 | ||
375 | -- | Find index of elements which satisfy a predicate | 494 | -- | Find index of elements which satisfy a predicate |
@@ -427,6 +546,52 @@ accum | |||
427 | -> c e -- ^ result | 546 | -> c e -- ^ result |
428 | accum = accum' | 547 | accum = accum' |
429 | 548 | ||
549 | -------------------------------------------------------------------------------- | ||
550 | |||
551 | class Konst e d c | d -> c, c -> d | ||
552 | where | ||
553 | -- | | ||
554 | -- >>> konst 7 3 :: Vector Float | ||
555 | -- fromList [7.0,7.0,7.0] | ||
556 | -- | ||
557 | -- >>> konst i (3::Int,4::Int) | ||
558 | -- (3><4) | ||
559 | -- [ 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0 | ||
560 | -- , 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0 | ||
561 | -- , 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0 ] | ||
562 | -- | ||
563 | konst :: e -> d -> c e | ||
564 | |||
565 | instance Container Vector e => Konst e Int Vector | ||
566 | where | ||
567 | konst = konst' | ||
568 | |||
569 | instance (Num e, Container Vector e) => Konst e (Int,Int) Matrix | ||
570 | where | ||
571 | konst = konst' | ||
572 | |||
573 | -------------------------------------------------------------------------------- | ||
574 | |||
575 | class ( Container Vector t | ||
576 | , Container Matrix t | ||
577 | , Konst t Int Vector | ||
578 | , Konst t (Int,Int) Matrix | ||
579 | , CTrans t | ||
580 | , Product t | ||
581 | , Additive (Vector t) | ||
582 | , Additive (Matrix t) | ||
583 | , Linear t Vector | ||
584 | , Linear t Matrix | ||
585 | ) => Numeric t | ||
586 | |||
587 | instance Numeric Double | ||
588 | instance Numeric (Complex Double) | ||
589 | instance Numeric Float | ||
590 | instance Numeric (Complex Float) | ||
591 | instance Numeric I | ||
592 | instance Numeric Z | ||
593 | |||
594 | -------------------------------------------------------------------------------- | ||
430 | 595 | ||
431 | -------------------------------------------------------------------------------- | 596 | -------------------------------------------------------------------------------- |
432 | 597 | ||
@@ -439,7 +604,7 @@ class (Num e, Element e) => Product e where | |||
439 | -- | sum of absolute value of elements | 604 | -- | sum of absolute value of elements |
440 | norm1 :: Vector e -> RealOf e | 605 | norm1 :: Vector e -> RealOf e |
441 | -- | euclidean norm | 606 | -- | euclidean norm |
442 | norm2 :: Vector e -> RealOf e | 607 | norm2 :: Floating e => Vector e -> RealOf e |
443 | -- | element of maximum magnitude | 608 | -- | element of maximum magnitude |
444 | normInf :: Vector e -> RealOf e | 609 | normInf :: Vector e -> RealOf e |
445 | 610 | ||
@@ -471,6 +636,21 @@ instance Product (Complex Double) where | |||
471 | normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs) | 636 | normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs) |
472 | multiply = emptyMul multiplyC | 637 | multiply = emptyMul multiplyC |
473 | 638 | ||
639 | instance Product I where | ||
640 | norm2 = undefined | ||
641 | absSum = emptyVal (sumElements . vectorMapI Abs) | ||
642 | norm1 = absSum | ||
643 | normInf = emptyVal (maxElement . vectorMapI Abs) | ||
644 | multiply = emptyMul (multiplyI 1) | ||
645 | |||
646 | instance Product Z where | ||
647 | norm2 = undefined | ||
648 | absSum = emptyVal (sumElements . vectorMapL Abs) | ||
649 | norm1 = absSum | ||
650 | normInf = emptyVal (maxElement . vectorMapL Abs) | ||
651 | multiply = emptyMul (multiplyL 1) | ||
652 | |||
653 | |||
474 | emptyMul m a b | 654 | emptyMul m a b |
475 | | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c) | 655 | | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c) |
476 | | otherwise = m a b | 656 | | otherwise = m a b |
@@ -546,7 +726,7 @@ m2=(4><3) | |||
546 | -} | 726 | -} |
547 | kronecker :: (Product t) => Matrix t -> Matrix t -> Matrix t | 727 | kronecker :: (Product t) => Matrix t -> Matrix t -> Matrix t |
548 | kronecker a b = fromBlocks | 728 | kronecker a b = fromBlocks |
549 | . splitEvery (cols a) | 729 | . chunksOf (cols a) |
550 | . map (reshape (cols b)) | 730 | . map (reshape (cols b)) |
551 | . toRows | 731 | . toRows |
552 | $ flatten a `outer` flatten b | 732 | $ flatten a `outer` flatten b |
@@ -555,12 +735,12 @@ kronecker a b = fromBlocks | |||
555 | 735 | ||
556 | 736 | ||
557 | class Convert t where | 737 | class Convert t where |
558 | real :: Container c t => c (RealOf t) -> c t | 738 | real :: Complexable c => c (RealOf t) -> c t |
559 | complex :: Container c t => c t -> c (ComplexOf t) | 739 | complex :: Complexable c => c t -> c (ComplexOf t) |
560 | single :: Container c t => c t -> c (SingleOf t) | 740 | single :: Complexable c => c t -> c (SingleOf t) |
561 | double :: Container c t => c t -> c (DoubleOf t) | 741 | double :: Complexable c => c t -> c (DoubleOf t) |
562 | toComplex :: (Container c t, RealElement t) => (c t, c t) -> c (Complex t) | 742 | toComplex :: (Complexable c, RealElement t) => (c t, c t) -> c (Complex t) |
563 | fromComplex :: (Container c t, RealElement t) => c (Complex t) -> (c t, c t) | 743 | fromComplex :: (Complexable c, RealElement t) => c (Complex t) -> (c t, c t) |
564 | 744 | ||
565 | 745 | ||
566 | instance Convert Double where | 746 | instance Convert Double where |
@@ -605,6 +785,9 @@ type instance RealOf (Complex Double) = Double | |||
605 | type instance RealOf Float = Float | 785 | type instance RealOf Float = Float |
606 | type instance RealOf (Complex Float) = Float | 786 | type instance RealOf (Complex Float) = Float |
607 | 787 | ||
788 | type instance RealOf I = I | ||
789 | type instance RealOf Z = Z | ||
790 | |||
608 | type family ComplexOf x | 791 | type family ComplexOf x |
609 | 792 | ||
610 | type instance ComplexOf Double = Complex Double | 793 | type instance ComplexOf Double = Complex Double |
@@ -642,9 +825,6 @@ buildV n f = fromList [f k | k <- ks] | |||
642 | where ks = map fromIntegral [0 .. (n-1)] | 825 | where ks = map fromIntegral [0 .. (n-1)] |
643 | 826 | ||
644 | -------------------------------------------------------- | 827 | -------------------------------------------------------- |
645 | -- | conjugate transpose | ||
646 | ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e | ||
647 | ctrans = liftMatrix conj' . trans | ||
648 | 828 | ||
649 | -- | Creates a square matrix with a given diagonal. | 829 | -- | Creates a square matrix with a given diagonal. |
650 | diag :: (Num a, Element a) => Vector a -> Matrix a | 830 | diag :: (Num a, Element a) => Vector a -> Matrix a |
@@ -683,31 +863,80 @@ accumM m0 f xs = ST.runSTMatrix $ do | |||
683 | 863 | ||
684 | ---------------------------------------------------------------------- | 864 | ---------------------------------------------------------------------- |
685 | 865 | ||
686 | condM a b l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cond a' b' l' e' t' | 866 | compareM a b = matrixFromVector RowMajor (rows a'') (cols a'') $ ccompare' a' b' |
867 | where | ||
868 | args@(a'':_) = conformMs [a,b] | ||
869 | [a', b'] = map flatten args | ||
870 | |||
871 | compareCV f a b = f a' b' | ||
872 | where | ||
873 | [a', b'] = conformVs [a,b] | ||
874 | |||
875 | selectM c l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cselect' (toInt c') l' e' t' | ||
687 | where | 876 | where |
688 | args@(a'':_) = conformMs [a,b,l,e,t] | 877 | args@(a'':_) = conformMs [fromInt c,l,e,t] |
689 | [a', b', l', e', t'] = map flatten args | 878 | [c', l', e', t'] = map flatten args |
690 | 879 | ||
691 | condV f a b l e t = f a' b' l' e' t' | 880 | selectCV f c l e t = f (toInt c') l' e' t' |
692 | where | 881 | where |
693 | [a', b', l', e', t'] = conformVs [a,b,l,e,t] | 882 | [c', l', e', t'] = conformVs [fromInt c,l,e,t] |
694 | 883 | ||
695 | -------------------------------------------------------------------------------- | 884 | -------------------------------------------------------------------------------- |
696 | 885 | ||
886 | class CTrans t | ||
887 | where | ||
888 | ctrans :: Matrix t -> Matrix t | ||
889 | ctrans = trans | ||
890 | |||
891 | instance CTrans Float | ||
892 | instance CTrans R | ||
893 | instance CTrans I | ||
894 | instance CTrans Z | ||
895 | |||
896 | instance CTrans C | ||
897 | where | ||
898 | ctrans = conj . trans | ||
899 | |||
900 | instance CTrans (Complex Float) | ||
901 | where | ||
902 | ctrans = conj . trans | ||
903 | |||
697 | class Transposable m mt | m -> mt, mt -> m | 904 | class Transposable m mt | m -> mt, mt -> m |
698 | where | 905 | where |
699 | -- | (conjugate) transpose | 906 | -- | conjugate transpose |
700 | tr :: m -> mt | 907 | tr :: m -> mt |
908 | -- | transpose | ||
909 | tr' :: m -> mt | ||
910 | |||
911 | instance (CTrans t, Container Vector t) => Transposable (Matrix t) (Matrix t) | ||
912 | where | ||
913 | tr = ctrans | ||
914 | tr' = trans | ||
915 | |||
916 | class Additive c | ||
917 | where | ||
918 | add :: c -> c -> c | ||
919 | |||
920 | class Linear t c | ||
921 | where | ||
922 | scale :: t -> c t -> c t | ||
923 | |||
924 | |||
925 | instance Container Vector t => Linear t Vector | ||
926 | where | ||
927 | scale = scale' | ||
928 | |||
929 | instance Container Matrix t => Linear t Matrix | ||
930 | where | ||
931 | scale = scale' | ||
701 | 932 | ||
702 | instance (Container Vector t) => Transposable (Matrix t) (Matrix t) | 933 | instance Container Vector t => Additive (Vector t) |
703 | where | 934 | where |
704 | tr = ctrans | 935 | add = add' |
705 | 936 | ||
706 | class Linear t v | 937 | instance Container Matrix t => Additive (Matrix t) |
707 | where | 938 | where |
708 | scalarL :: t -> v | 939 | add = add' |
709 | addL :: v -> v -> v | ||
710 | scaleL :: t -> v -> v | ||
711 | 940 | ||
712 | 941 | ||
713 | class Testable t | 942 | class Testable t |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Random.hs b/packages/base/src/Internal/Random.hs index b66988e..8c792eb 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Random.hs +++ b/packages/base/src/Internal/Random.hs | |||
@@ -10,7 +10,7 @@ | |||
10 | -- | 10 | -- |
11 | ----------------------------------------------------------------------------- | 11 | ----------------------------------------------------------------------------- |
12 | 12 | ||
13 | module Numeric.LinearAlgebra.Random ( | 13 | module Internal.Random ( |
14 | Seed, | 14 | Seed, |
15 | RandDist(..), | 15 | RandDist(..), |
16 | randomVector, | 16 | randomVector, |
@@ -19,13 +19,13 @@ module Numeric.LinearAlgebra.Random ( | |||
19 | rand, randn | 19 | rand, randn |
20 | ) where | 20 | ) where |
21 | 21 | ||
22 | import Numeric.Vectorized | 22 | import Internal.Vectorized |
23 | import Data.Packed | 23 | import Internal.Vector |
24 | import Data.Packed.Internal.Numeric | 24 | import Internal.Matrix |
25 | import Numeric.LinearAlgebra.Algorithms | 25 | import Internal.Numeric |
26 | import Internal.Algorithms | ||
26 | import System.Random(randomIO) | 27 | import System.Random(randomIO) |
27 | 28 | ||
28 | |||
29 | -- | Obtains a matrix whose rows are pseudorandom samples from a multivariate | 29 | -- | Obtains a matrix whose rows are pseudorandom samples from a multivariate |
30 | -- Gaussian distribution. | 30 | -- Gaussian distribution. |
31 | gaussianSample :: Seed | 31 | gaussianSample :: Seed |
diff --git a/packages/base/src/Data/Packed/ST.hs b/packages/base/src/Internal/ST.hs index 5c45c7b..544c9e4 100644 --- a/packages/base/src/Data/Packed/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -1,28 +1,29 @@ | |||
1 | {-# LANGUAGE CPP #-} | ||
2 | {-# LANGUAGE TypeOperators #-} | ||
3 | {-# LANGUAGE Rank2Types #-} | 1 | {-# LANGUAGE Rank2Types #-} |
4 | {-# LANGUAGE BangPatterns #-} | 2 | {-# LANGUAGE BangPatterns #-} |
3 | {-# LANGUAGE ViewPatterns #-} | ||
4 | |||
5 | ----------------------------------------------------------------------------- | 5 | ----------------------------------------------------------------------------- |
6 | -- | | 6 | -- | |
7 | -- Module : Data.Packed.ST | 7 | -- Module : Internal.ST |
8 | -- Copyright : (c) Alberto Ruiz 2008 | 8 | -- Copyright : (c) Alberto Ruiz 2008 |
9 | -- License : BSD3 | 9 | -- License : BSD3 |
10 | -- Maintainer : Alberto Ruiz | 10 | -- Maintainer : Alberto Ruiz |
11 | -- Stability : provisional | 11 | -- Stability : provisional |
12 | -- | 12 | -- |
13 | -- In-place manipulation inside the ST monad. | 13 | -- In-place manipulation inside the ST monad. |
14 | -- See examples/inplace.hs in the distribution. | 14 | -- See @examples/inplace.hs@ in the repository. |
15 | -- | 15 | -- |
16 | ----------------------------------------------------------------------------- | 16 | ----------------------------------------------------------------------------- |
17 | {-# OPTIONS_HADDOCK hide #-} | ||
18 | 17 | ||
19 | module Data.Packed.ST ( | 18 | module Internal.ST ( |
19 | ST, runST, | ||
20 | -- * Mutable Vectors | 20 | -- * Mutable Vectors |
21 | STVector, newVector, thawVector, freezeVector, runSTVector, | 21 | STVector, newVector, thawVector, freezeVector, runSTVector, |
22 | readVector, writeVector, modifyVector, liftSTVector, | 22 | readVector, writeVector, modifyVector, liftSTVector, |
23 | -- * Mutable Matrices | 23 | -- * Mutable Matrices |
24 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, | 24 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, |
25 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, | 25 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, |
26 | mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..), | ||
26 | -- * Unsafe functions | 27 | -- * Unsafe functions |
27 | newUndefinedVector, | 28 | newUndefinedVector, |
28 | unsafeReadVector, unsafeWriteVector, | 29 | unsafeReadVector, unsafeWriteVector, |
@@ -32,16 +33,12 @@ module Data.Packed.ST ( | |||
32 | unsafeThawMatrix, unsafeFreezeMatrix | 33 | unsafeThawMatrix, unsafeFreezeMatrix |
33 | ) where | 34 | ) where |
34 | 35 | ||
35 | import Data.Packed.Internal | 36 | import Internal.Vector |
36 | 37 | import Internal.Matrix | |
38 | import Internal.Vectorized | ||
37 | import Control.Monad.ST(ST, runST) | 39 | import Control.Monad.ST(ST, runST) |
38 | import Foreign.Storable(Storable, peekElemOff, pokeElemOff) | 40 | import Foreign.Storable(Storable, peekElemOff, pokeElemOff) |
39 | |||
40 | #if MIN_VERSION_base(4,4,0) | ||
41 | import Control.Monad.ST.Unsafe(unsafeIOToST) | 41 | import Control.Monad.ST.Unsafe(unsafeIOToST) |
42 | #else | ||
43 | import Control.Monad.ST(unsafeIOToST) | ||
44 | #endif | ||
45 | 42 | ||
46 | {-# INLINE ioReadV #-} | 43 | {-# INLINE ioReadV #-} |
47 | ioReadV :: Storable t => Vector t -> Int -> IO t | 44 | ioReadV :: Storable t => Vector t -> Int -> IO t |
@@ -74,13 +71,13 @@ unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k | |||
74 | modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () | 71 | modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () |
75 | modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k | 72 | modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k |
76 | 73 | ||
77 | liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a | 74 | liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s t -> ST s a |
78 | liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x | 75 | liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x |
79 | 76 | ||
80 | freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) | 77 | freezeVector :: (Storable t) => STVector s t -> ST s (Vector t) |
81 | freezeVector v = liftSTVector id v | 78 | freezeVector v = liftSTVector id v |
82 | 79 | ||
83 | unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) | 80 | unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t) |
84 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x | 81 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x |
85 | 82 | ||
86 | {-# INLINE safeIndexV #-} | 83 | {-# INLINE safeIndexV #-} |
@@ -112,17 +109,17 @@ newVector x n = do | |||
112 | 109 | ||
113 | {-# INLINE ioReadM #-} | 110 | {-# INLINE ioReadM #-} |
114 | ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t | 111 | ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t |
115 | ioReadM (Matrix _ nc cv RowMajor) r c = ioReadV cv (r*nc+c) | 112 | ioReadM m r c = ioReadV (xdat m) (r * xRow m + c * xCol m) |
116 | ioReadM (Matrix nr _ fv ColumnMajor) r c = ioReadV fv (c*nr+r) | 113 | |
117 | 114 | ||
118 | {-# INLINE ioWriteM #-} | 115 | {-# INLINE ioWriteM #-} |
119 | ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () | 116 | ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () |
120 | ioWriteM (Matrix _ nc cv RowMajor) r c val = ioWriteV cv (r*nc+c) val | 117 | ioWriteM m r c val = ioWriteV (xdat m) (r * xRow m + c * xCol m) val |
121 | ioWriteM (Matrix nr _ fv ColumnMajor) r c val = ioWriteV fv (c*nr+r) val | 118 | |
122 | 119 | ||
123 | newtype STMatrix s t = STMatrix (Matrix t) | 120 | newtype STMatrix s t = STMatrix (Matrix t) |
124 | 121 | ||
125 | thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) | 122 | thawMatrix :: Element t => Matrix t -> ST s (STMatrix s t) |
126 | thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix | 123 | thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix |
127 | 124 | ||
128 | unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) | 125 | unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) |
@@ -143,16 +140,17 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c | |||
143 | modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () | 140 | modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () |
144 | modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c | 141 | modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c |
145 | 142 | ||
146 | liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a | 143 | liftSTMatrix :: (Element t) => (Matrix t -> a) -> STMatrix s t -> ST s a |
147 | liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x | 144 | liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x |
148 | 145 | ||
149 | unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) | 146 | unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) |
150 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | 147 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x |
151 | 148 | ||
152 | freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) | 149 | |
150 | freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) | ||
153 | freezeMatrix m = liftSTMatrix id m | 151 | freezeMatrix m = liftSTMatrix id m |
154 | 152 | ||
155 | cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) | 153 | cloneMatrix m = copy (orderOf m) m |
156 | 154 | ||
157 | {-# INLINE safeIndexM #-} | 155 | {-# INLINE safeIndexM #-} |
158 | safeIndexM f (STMatrix m) r c | 156 | safeIndexM f (STMatrix m) r c |
@@ -169,6 +167,9 @@ readMatrix = safeIndexM unsafeReadMatrix | |||
169 | writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () | 167 | writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () |
170 | writeMatrix = safeIndexM unsafeWriteMatrix | 168 | writeMatrix = safeIndexM unsafeWriteMatrix |
171 | 169 | ||
170 | setMatrix :: Element t => STMatrix s t -> Int -> Int -> Matrix t -> ST s () | ||
171 | setMatrix (STMatrix x) i j m = unsafeIOToST $ setRect i j m x | ||
172 | |||
172 | newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) | 173 | newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) |
173 | newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c | 174 | newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c |
174 | 175 | ||
@@ -176,3 +177,73 @@ newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c | |||
176 | newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t) | 177 | newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t) |
177 | newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) | 178 | newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) |
178 | 179 | ||
180 | -------------------------------------------------------------------------------- | ||
181 | |||
182 | data ColRange = AllCols | ||
183 | | ColRange Int Int | ||
184 | | Col Int | ||
185 | | FromCol Int | ||
186 | |||
187 | getColRange c AllCols = (0,c-1) | ||
188 | getColRange c (ColRange a b) = (a `mod` c, b `mod` c) | ||
189 | getColRange c (Col a) = (a `mod` c, a `mod` c) | ||
190 | getColRange c (FromCol a) = (a `mod` c, c-1) | ||
191 | |||
192 | data RowRange = AllRows | ||
193 | | RowRange Int Int | ||
194 | | Row Int | ||
195 | | FromRow Int | ||
196 | |||
197 | getRowRange r AllRows = (0,r-1) | ||
198 | getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) | ||
199 | getRowRange r (Row a) = (a `mod` r, a `mod` r) | ||
200 | getRowRange r (FromRow a) = (a `mod` r, r-1) | ||
201 | |||
202 | data RowOper t = AXPY t Int Int ColRange | ||
203 | | SCAL t RowRange ColRange | ||
204 | | SWAP Int Int ColRange | ||
205 | |||
206 | rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s () | ||
207 | |||
208 | rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m | ||
209 | where | ||
210 | (j1,j2) = getColRange (cols m) r | ||
211 | i1' = i1 `mod` (rows m) | ||
212 | i2' = i2 `mod` (rows m) | ||
213 | |||
214 | rowOper (SCAL x rr rc) (STMatrix m) = unsafeIOToST $ rowOp 1 x i1 i2 j1 j2 m | ||
215 | where | ||
216 | (i1,i2) = getRowRange (rows m) rr | ||
217 | (j1,j2) = getColRange (cols m) rc | ||
218 | |||
219 | rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m | ||
220 | where | ||
221 | (j1,j2) = getColRange (cols m) r | ||
222 | i1' = i1 `mod` (rows m) | ||
223 | i2' = i2 `mod` (rows m) | ||
224 | |||
225 | |||
226 | extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) | ||
227 | where | ||
228 | (i1,i2) = getRowRange (rows m) rr | ||
229 | (j1,j2) = getColRange (cols m) rc | ||
230 | |||
231 | -- | r0 c0 height width | ||
232 | data Slice s t = Slice (STMatrix s t) Int Int Int Int | ||
233 | |||
234 | slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m | ||
235 | |||
236 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () | ||
237 | gemmm beta (slice->r) alpha (slice->a) (slice->b) = res | ||
238 | where | ||
239 | res = unsafeIOToST (gemm v a b r) | ||
240 | v = fromList [alpha,beta] | ||
241 | |||
242 | |||
243 | mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) | ||
244 | mutable f a = runST $ do | ||
245 | x <- thawMatrix a | ||
246 | info <- f (rows a, cols a) x | ||
247 | r <- unsafeFreezeMatrix x | ||
248 | return (r,info) | ||
249 | |||
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Internal/Sparse.hs index f1516ec..1604e7e 100644 --- a/packages/base/src/Numeric/Sparse.hs +++ b/packages/base/src/Internal/Sparse.hs | |||
@@ -2,7 +2,7 @@ | |||
2 | {-# LANGUAGE MultiParamTypeClasses #-} | 2 | {-# LANGUAGE MultiParamTypeClasses #-} |
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | 4 | ||
5 | module Numeric.Sparse( | 5 | module Internal.Sparse( |
6 | GMatrix(..), CSR(..), mkCSR, fromCSR, | 6 | GMatrix(..), CSR(..), mkCSR, fromCSR, |
7 | mkSparse, mkDiagR, mkDense, | 7 | mkSparse, mkDiagR, mkDense, |
8 | AssocMatrix, | 8 | AssocMatrix, |
@@ -10,7 +10,9 @@ module Numeric.Sparse( | |||
10 | gmXv, (!#>) | 10 | gmXv, (!#>) |
11 | )where | 11 | )where |
12 | 12 | ||
13 | import Data.Packed.Numeric | 13 | import Internal.Vector |
14 | import Internal.Matrix | ||
15 | import Internal.Numeric | ||
14 | import qualified Data.Vector.Storable as V | 16 | import qualified Data.Vector.Storable as V |
15 | import Data.Function(on) | 17 | import Data.Function(on) |
16 | import Control.Arrow((***)) | 18 | import Control.Arrow((***)) |
@@ -18,7 +20,7 @@ import Control.Monad(when) | |||
18 | import Data.List(groupBy, sort) | 20 | import Data.List(groupBy, sort) |
19 | import Foreign.C.Types(CInt(..)) | 21 | import Foreign.C.Types(CInt(..)) |
20 | 22 | ||
21 | import Data.Packed.Development | 23 | import Internal.Devel |
22 | import System.IO.Unsafe(unsafePerformIO) | 24 | import System.IO.Unsafe(unsafePerformIO) |
23 | import Foreign(Ptr) | 25 | import Foreign(Ptr) |
24 | import Text.Printf(printf) | 26 | import Text.Printf(printf) |
@@ -142,13 +144,13 @@ gmXv :: GMatrix -> Vector Double -> Vector Double | |||
142 | gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do | 144 | gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do |
143 | dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | 145 | dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) |
144 | r <- createVector nRows | 146 | r <- createVector nRows |
145 | app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" | 147 | c_smXv # csrVals # csrCols # csrRows # v # r #|"CSRXv" |
146 | return r | 148 | return r |
147 | 149 | ||
148 | gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do | 150 | gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do |
149 | dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | 151 | dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) |
150 | r <- createVector nRows | 152 | r <- createVector nRows |
151 | app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" | 153 | c_smTXv # cscVals # cscRows # cscCols # v # r #|"CSCXv" |
152 | return r | 154 | return r |
153 | 155 | ||
154 | gmXv Diag{..} v | 156 | gmXv Diag{..} v |
@@ -195,10 +197,12 @@ toDense asm = assoc (r+1,c+1) 0 asm | |||
195 | instance Transposable CSR CSC | 197 | instance Transposable CSR CSC |
196 | where | 198 | where |
197 | tr (CSR vs cs rs n m) = CSC vs cs rs m n | 199 | tr (CSR vs cs rs n m) = CSC vs cs rs m n |
200 | tr' = tr | ||
198 | 201 | ||
199 | instance Transposable CSC CSR | 202 | instance Transposable CSC CSR |
200 | where | 203 | where |
201 | tr (CSC vs rs cs n m) = CSR vs rs cs m n | 204 | tr (CSC vs rs cs n m) = CSR vs rs cs m n |
205 | tr' = tr | ||
202 | 206 | ||
203 | instance Transposable GMatrix GMatrix | 207 | instance Transposable GMatrix GMatrix |
204 | where | 208 | where |
@@ -206,5 +210,5 @@ instance Transposable GMatrix GMatrix | |||
206 | tr (SparseC s n m) = SparseR (tr s) m n | 210 | tr (SparseC s n m) = SparseR (tr s) m n |
207 | tr (Diag v n m) = Diag v m n | 211 | tr (Diag v n m) = Diag v m n |
208 | tr (Dense a n m) = Dense (tr a) m n | 212 | tr (Dense a n m) = Dense (tr a) m n |
209 | 213 | tr' = tr | |
210 | 214 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs b/packages/base/src/Internal/Static.hs index ec02cf6..0068313 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs +++ b/packages/base/src/Internal/Static.hs | |||
@@ -13,27 +13,30 @@ | |||
13 | {-# LANGUAGE ViewPatterns #-} | 13 | {-# LANGUAGE ViewPatterns #-} |
14 | 14 | ||
15 | {- | | 15 | {- | |
16 | Module : Numeric.LinearAlgebra.Static.Internal | 16 | Module : Internal.Static |
17 | Copyright : (c) Alberto Ruiz 2006-14 | 17 | Copyright : (c) Alberto Ruiz 2006-14 |
18 | License : BSD3 | 18 | License : BSD3 |
19 | Stability : provisional | 19 | Stability : provisional |
20 | 20 | ||
21 | -} | 21 | -} |
22 | 22 | ||
23 | module Numeric.LinearAlgebra.Static.Internal where | 23 | module Internal.Static where |
24 | 24 | ||
25 | 25 | ||
26 | import GHC.TypeLits | 26 | import GHC.TypeLits |
27 | import qualified Numeric.LinearAlgebra.HMatrix as LA | 27 | import qualified Numeric.LinearAlgebra as LA |
28 | import Numeric.LinearAlgebra.HMatrix hiding (konst,size) | 28 | import Numeric.LinearAlgebra hiding (konst,size,R,C) |
29 | import Data.Packed as D | 29 | import Internal.Vector as D hiding (R,C) |
30 | import Data.Packed.ST | 30 | import Internal.ST |
31 | import Data.Proxy(Proxy) | 31 | import Data.Proxy(Proxy) |
32 | import Foreign.Storable(Storable) | 32 | import Foreign.Storable(Storable) |
33 | import Text.Printf | 33 | import Text.Printf |
34 | 34 | ||
35 | -------------------------------------------------------------------------------- | 35 | -------------------------------------------------------------------------------- |
36 | 36 | ||
37 | type ℝ = Double | ||
38 | type ℂ = Complex Double | ||
39 | |||
37 | newtype Dim (n :: Nat) t = Dim t | 40 | newtype Dim (n :: Nat) t = Dim t |
38 | deriving Show | 41 | deriving Show |
39 | 42 | ||
@@ -244,11 +247,14 @@ instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m) | |||
244 | where | 247 | where |
245 | tr a@(isDiag -> Just _) = mkL (extract a) | 248 | tr a@(isDiag -> Just _) = mkL (extract a) |
246 | tr (extract -> a) = mkL (tr a) | 249 | tr (extract -> a) = mkL (tr a) |
250 | tr' = tr | ||
247 | 251 | ||
248 | instance (KnownNat n, KnownNat m) => Transposable (M m n) (M n m) | 252 | instance (KnownNat n, KnownNat m) => Transposable (M m n) (M n m) |
249 | where | 253 | where |
250 | tr a@(isDiagC -> Just _) = mkM (extract a) | 254 | tr a@(isDiagC -> Just _) = mkM (extract a) |
251 | tr (extract -> a) = mkM (tr a) | 255 | tr (extract -> a) = mkM (tr a) |
256 | tr' a@(isDiagC -> Just _) = mkM (extract a) | ||
257 | tr' (extract -> a) = mkM (tr' a) | ||
252 | 258 | ||
253 | -------------------------------------------------------------------------------- | 259 | -------------------------------------------------------------------------------- |
254 | 260 | ||
@@ -322,12 +328,12 @@ instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) | |||
322 | negate = lift1F negate | 328 | negate = lift1F negate |
323 | fromInteger x = Dim (fromInteger x) | 329 | fromInteger x = Dim (fromInteger x) |
324 | 330 | ||
325 | instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim n (Vector t)) | 331 | instance (Num (Vector t), Num (Matrix t), Fractional t, Numeric t) => Fractional (Dim n (Vector t)) |
326 | where | 332 | where |
327 | fromRational x = Dim (fromRational x) | 333 | fromRational x = Dim (fromRational x) |
328 | (/) = lift2F (/) | 334 | (/) = lift2F (/) |
329 | 335 | ||
330 | instance (Floating (Vector t), Numeric t) => Floating (Dim n (Vector t)) where | 336 | instance (Fractional t, Floating (Vector t), Numeric t) => Floating (Dim n (Vector t)) where |
331 | sin = lift1F sin | 337 | sin = lift1F sin |
332 | cos = lift1F cos | 338 | cos = lift1F cos |
333 | tan = lift1F tan | 339 | tan = lift1F tan |
@@ -357,12 +363,12 @@ instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) | |||
357 | negate = (lift1F . lift1F) negate | 363 | negate = (lift1F . lift1F) negate |
358 | fromInteger x = Dim (Dim (fromInteger x)) | 364 | fromInteger x = Dim (Dim (fromInteger x)) |
359 | 365 | ||
360 | instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim m (Dim n (Matrix t))) | 366 | instance (Num (Vector t), Num (Matrix t), Fractional t, Numeric t) => Fractional (Dim m (Dim n (Matrix t))) |
361 | where | 367 | where |
362 | fromRational x = Dim (Dim (fromRational x)) | 368 | fromRational x = Dim (Dim (fromRational x)) |
363 | (/) = (lift2F.lift2F) (/) | 369 | (/) = (lift2F.lift2F) (/) |
364 | 370 | ||
365 | instance (Num (Vector t), Floating (Matrix t), Numeric t) => Floating (Dim m (Dim n (Matrix t))) where | 371 | instance (Num (Vector t), Floating (Matrix t), Fractional t, Numeric t) => Floating (Dim m (Dim n (Matrix t))) where |
366 | sin = (lift1F . lift1F) sin | 372 | sin = (lift1F . lift1F) sin |
367 | cos = (lift1F . lift1F) cos | 373 | cos = (lift1F . lift1F) cos |
368 | tan = (lift1F . lift1F) tan | 374 | tan = (lift1F . lift1F) tan |
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs new file mode 100644 index 0000000..cf42961 --- /dev/null +++ b/packages/base/src/Internal/Util.hs | |||
@@ -0,0 +1,896 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | ||
2 | {-# LANGUAGE FlexibleInstances #-} | ||
3 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
4 | {-# LANGUAGE FunctionalDependencies #-} | ||
5 | {-# LANGUAGE ViewPatterns #-} | ||
6 | |||
7 | |||
8 | ----------------------------------------------------------------------------- | ||
9 | {- | | ||
10 | Module : Internal.Util | ||
11 | Copyright : (c) Alberto Ruiz 2013 | ||
12 | License : BSD3 | ||
13 | Maintainer : Alberto Ruiz | ||
14 | Stability : provisional | ||
15 | |||
16 | -} | ||
17 | ----------------------------------------------------------------------------- | ||
18 | |||
19 | module Internal.Util( | ||
20 | |||
21 | -- * Convenience functions | ||
22 | vector, matrix, | ||
23 | disp, | ||
24 | formatSparse, | ||
25 | approxInt, | ||
26 | dispDots, | ||
27 | dispBlanks, | ||
28 | formatShort, | ||
29 | dispShort, | ||
30 | zeros, ones, | ||
31 | diagl, | ||
32 | row, | ||
33 | col, | ||
34 | (&), (¦), (|||), (——), (===), | ||
35 | (?), (¿), | ||
36 | Indexable(..), size, | ||
37 | Numeric, | ||
38 | rand, randn, | ||
39 | cross, | ||
40 | norm, | ||
41 | ℕ,ℤ,ℝ,ℂ,iC, | ||
42 | Normed(..), norm_Frob, norm_nuclear, | ||
43 | magnit, | ||
44 | unitary, | ||
45 | mt, | ||
46 | (~!~), | ||
47 | pairwiseD2, | ||
48 | rowOuters, | ||
49 | null1, | ||
50 | null1sym, | ||
51 | -- * Convolution | ||
52 | -- ** 1D | ||
53 | corr, conv, corrMin, | ||
54 | -- ** 2D | ||
55 | corr2, conv2, separable, | ||
56 | block2x2,block3x3,view1,unView1,foldMatrix, | ||
57 | gaussElim_1, gaussElim_2, gaussElim, | ||
58 | luST, luSolve', luSolve'', luPacked', luPacked'', | ||
59 | invershur | ||
60 | ) where | ||
61 | |||
62 | import Internal.Vector | ||
63 | import Internal.Matrix hiding (size) | ||
64 | import Internal.Numeric | ||
65 | import Internal.Element | ||
66 | import Internal.Container | ||
67 | import Internal.Vectorized | ||
68 | import Internal.IO | ||
69 | import Internal.Algorithms hiding (Normed,linearSolve',luSolve', luPacked') | ||
70 | import Numeric.Matrix() | ||
71 | import Numeric.Vector() | ||
72 | import Internal.Random | ||
73 | import Internal.Convolution | ||
74 | import Control.Monad(when,forM_) | ||
75 | import Text.Printf | ||
76 | import Data.List.Split(splitOn) | ||
77 | import Data.List(intercalate,sortBy,foldl') | ||
78 | import Control.Arrow((&&&),(***)) | ||
79 | import Data.Complex | ||
80 | import Data.Function(on) | ||
81 | import Internal.ST | ||
82 | |||
83 | type ℝ = Double | ||
84 | type ℕ = Int | ||
85 | type ℤ = Int | ||
86 | type ℂ = Complex Double | ||
87 | |||
88 | -- | imaginary unit | ||
89 | iC :: C | ||
90 | iC = 0:+1 | ||
91 | |||
92 | {- | Create a real vector. | ||
93 | |||
94 | >>> vector [1..5] | ||
95 | fromList [1.0,2.0,3.0,4.0,5.0] | ||
96 | |||
97 | -} | ||
98 | vector :: [R] -> Vector R | ||
99 | vector = fromList | ||
100 | |||
101 | {- | Create a real matrix. | ||
102 | |||
103 | >>> matrix 5 [1..15] | ||
104 | (3><5) | ||
105 | [ 1.0, 2.0, 3.0, 4.0, 5.0 | ||
106 | , 6.0, 7.0, 8.0, 9.0, 10.0 | ||
107 | , 11.0, 12.0, 13.0, 14.0, 15.0 ] | ||
108 | |||
109 | -} | ||
110 | matrix | ||
111 | :: Int -- ^ number of columns | ||
112 | -> [R] -- ^ elements in row order | ||
113 | -> Matrix R | ||
114 | matrix c = reshape c . fromList | ||
115 | |||
116 | |||
117 | {- | print a real matrix with given number of digits after the decimal point | ||
118 | |||
119 | >>> disp 5 $ ident 2 / 3 | ||
120 | 2x2 | ||
121 | 0.33333 0.00000 | ||
122 | 0.00000 0.33333 | ||
123 | |||
124 | -} | ||
125 | disp :: Int -> Matrix Double -> IO () | ||
126 | |||
127 | disp n = putStr . dispf n | ||
128 | |||
129 | |||
130 | {- | create a real diagonal matrix from a list | ||
131 | |||
132 | >>> diagl [1,2,3] | ||
133 | (3><3) | ||
134 | [ 1.0, 0.0, 0.0 | ||
135 | , 0.0, 2.0, 0.0 | ||
136 | , 0.0, 0.0, 3.0 ] | ||
137 | |||
138 | -} | ||
139 | diagl :: [Double] -> Matrix Double | ||
140 | diagl = diag . fromList | ||
141 | |||
142 | -- | a real matrix of zeros | ||
143 | zeros :: Int -- ^ rows | ||
144 | -> Int -- ^ columns | ||
145 | -> Matrix Double | ||
146 | zeros r c = konst 0 (r,c) | ||
147 | |||
148 | -- | a real matrix of ones | ||
149 | ones :: Int -- ^ rows | ||
150 | -> Int -- ^ columns | ||
151 | -> Matrix Double | ||
152 | ones r c = konst 1 (r,c) | ||
153 | |||
154 | -- | concatenation of real vectors | ||
155 | infixl 3 & | ||
156 | (&) :: Vector Double -> Vector Double -> Vector Double | ||
157 | a & b = vjoin [a,b] | ||
158 | |||
159 | {- | horizontal concatenation | ||
160 | |||
161 | >>> ident 3 ||| konst 7 (3,4) | ||
162 | (3><7) | ||
163 | [ 1.0, 0.0, 0.0, 7.0, 7.0, 7.0, 7.0 | ||
164 | , 0.0, 1.0, 0.0, 7.0, 7.0, 7.0, 7.0 | ||
165 | , 0.0, 0.0, 1.0, 7.0, 7.0, 7.0, 7.0 ] | ||
166 | |||
167 | -} | ||
168 | infixl 3 ||| | ||
169 | (|||) :: Element t => Matrix t -> Matrix t -> Matrix t | ||
170 | a ||| b = fromBlocks [[a,b]] | ||
171 | |||
172 | -- | a synonym for ('|||') (unicode 0x00a6, broken bar) | ||
173 | infixl 3 ¦ | ||
174 | (¦) :: Matrix Double -> Matrix Double -> Matrix Double | ||
175 | (¦) = (|||) | ||
176 | |||
177 | |||
178 | -- | vertical concatenation | ||
179 | -- | ||
180 | (===) :: Element t => Matrix t -> Matrix t -> Matrix t | ||
181 | infixl 2 === | ||
182 | a === b = fromBlocks [[a],[b]] | ||
183 | |||
184 | -- | a synonym for ('===') (unicode 0x2014, em dash) | ||
185 | (——) :: Matrix Double -> Matrix Double -> Matrix Double | ||
186 | infixl 2 —— | ||
187 | (——) = (===) | ||
188 | |||
189 | |||
190 | -- | create a single row real matrix from a list | ||
191 | -- | ||
192 | -- >>> row [2,3,1,8] | ||
193 | -- (1><4) | ||
194 | -- [ 2.0, 3.0, 1.0, 8.0 ] | ||
195 | -- | ||
196 | row :: [Double] -> Matrix Double | ||
197 | row = asRow . fromList | ||
198 | |||
199 | -- | create a single column real matrix from a list | ||
200 | -- | ||
201 | -- >>> col [7,-2,4] | ||
202 | -- (3><1) | ||
203 | -- [ 7.0 | ||
204 | -- , -2.0 | ||
205 | -- , 4.0 ] | ||
206 | -- | ||
207 | col :: [Double] -> Matrix Double | ||
208 | col = asColumn . fromList | ||
209 | |||
210 | {- | extract rows | ||
211 | |||
212 | >>> (20><4) [1..] ? [2,1,1] | ||
213 | (3><4) | ||
214 | [ 9.0, 10.0, 11.0, 12.0 | ||
215 | , 5.0, 6.0, 7.0, 8.0 | ||
216 | , 5.0, 6.0, 7.0, 8.0 ] | ||
217 | |||
218 | -} | ||
219 | infixl 9 ? | ||
220 | (?) :: Element t => Matrix t -> [Int] -> Matrix t | ||
221 | (?) = flip extractRows | ||
222 | |||
223 | {- | extract columns | ||
224 | |||
225 | (unicode 0x00bf, inverted question mark, Alt-Gr ?) | ||
226 | |||
227 | >>> (3><4) [1..] ¿ [3,0] | ||
228 | (3><2) | ||
229 | [ 4.0, 1.0 | ||
230 | , 8.0, 5.0 | ||
231 | , 12.0, 9.0 ] | ||
232 | |||
233 | -} | ||
234 | infixl 9 ¿ | ||
235 | (¿) :: Element t => Matrix t -> [Int] -> Matrix t | ||
236 | (¿)= flip extractColumns | ||
237 | |||
238 | |||
239 | cross :: Product t => Vector t -> Vector t -> Vector t | ||
240 | -- ^ cross product (for three-element vectors) | ||
241 | cross x y | dim x == 3 && dim y == 3 = fromList [z1,z2,z3] | ||
242 | | otherwise = error $ "the cross product requires 3-element vectors (sizes given: " | ||
243 | ++show (dim x)++" and "++show (dim y)++")" | ||
244 | where | ||
245 | [x1,x2,x3] = toList x | ||
246 | [y1,y2,y3] = toList y | ||
247 | z1 = x2*y3-x3*y2 | ||
248 | z2 = x3*y1-x1*y3 | ||
249 | z3 = x1*y2-x2*y1 | ||
250 | |||
251 | {-# SPECIALIZE cross :: Vector Double -> Vector Double -> Vector Double #-} | ||
252 | {-# SPECIALIZE cross :: Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) #-} | ||
253 | |||
254 | norm :: Vector Double -> Double | ||
255 | -- ^ 2-norm of real vector | ||
256 | norm = pnorm PNorm2 | ||
257 | |||
258 | class Normed a | ||
259 | where | ||
260 | norm_0 :: a -> R | ||
261 | norm_1 :: a -> R | ||
262 | norm_2 :: a -> R | ||
263 | norm_Inf :: a -> R | ||
264 | |||
265 | |||
266 | instance Normed (Vector R) | ||
267 | where | ||
268 | norm_0 v = sumElements (step (abs v - scalar (eps*normInf v))) | ||
269 | norm_1 = pnorm PNorm1 | ||
270 | norm_2 = pnorm PNorm2 | ||
271 | norm_Inf = pnorm Infinity | ||
272 | |||
273 | instance Normed (Vector C) | ||
274 | where | ||
275 | norm_0 v = sumElements (step (fst (fromComplex (abs v)) - scalar (eps*normInf v))) | ||
276 | norm_1 = pnorm PNorm1 | ||
277 | norm_2 = pnorm PNorm2 | ||
278 | norm_Inf = pnorm Infinity | ||
279 | |||
280 | instance Normed (Matrix R) | ||
281 | where | ||
282 | norm_0 = norm_0 . flatten | ||
283 | norm_1 = pnorm PNorm1 | ||
284 | norm_2 = pnorm PNorm2 | ||
285 | norm_Inf = pnorm Infinity | ||
286 | |||
287 | instance Normed (Matrix C) | ||
288 | where | ||
289 | norm_0 = norm_0 . flatten | ||
290 | norm_1 = pnorm PNorm1 | ||
291 | norm_2 = pnorm PNorm2 | ||
292 | norm_Inf = pnorm Infinity | ||
293 | |||
294 | instance Normed (Vector I) | ||
295 | where | ||
296 | norm_0 = fromIntegral . sumElements . step . abs | ||
297 | norm_1 = fromIntegral . norm1 | ||
298 | norm_2 v = sqrt . fromIntegral $ dot v v | ||
299 | norm_Inf = fromIntegral . normInf | ||
300 | |||
301 | instance Normed (Vector Z) | ||
302 | where | ||
303 | norm_0 = fromIntegral . sumElements . step . abs | ||
304 | norm_1 = fromIntegral . norm1 | ||
305 | norm_2 v = sqrt . fromIntegral $ dot v v | ||
306 | norm_Inf = fromIntegral . normInf | ||
307 | |||
308 | instance Normed (Vector Float) | ||
309 | where | ||
310 | norm_0 = norm_0 . double | ||
311 | norm_1 = norm_1 . double | ||
312 | norm_2 = norm_2 . double | ||
313 | norm_Inf = norm_Inf . double | ||
314 | |||
315 | instance Normed (Vector (Complex Float)) | ||
316 | where | ||
317 | norm_0 = norm_0 . double | ||
318 | norm_1 = norm_1 . double | ||
319 | norm_2 = norm_2 . double | ||
320 | norm_Inf = norm_Inf . double | ||
321 | |||
322 | |||
323 | norm_Frob :: (Normed (Vector t), Element t) => Matrix t -> R | ||
324 | norm_Frob = norm_2 . flatten | ||
325 | |||
326 | norm_nuclear :: Field t => Matrix t -> R | ||
327 | norm_nuclear = sumElements . singularValues | ||
328 | |||
329 | {- | Check if the absolute value or complex magnitude is greater than a given threshold | ||
330 | |||
331 | >>> magnit 1E-6 (1E-12 :: R) | ||
332 | False | ||
333 | >>> magnit 1E-6 (3+iC :: C) | ||
334 | True | ||
335 | >>> magnit 0 (3 :: I ./. 5) | ||
336 | True | ||
337 | |||
338 | -} | ||
339 | magnit :: (Element t, Normed (Vector t)) => R -> t -> Bool | ||
340 | magnit e x = norm_1 (fromList [x]) > e | ||
341 | |||
342 | |||
343 | -- | Obtains a vector in the same direction with 2-norm=1 | ||
344 | unitary :: Vector Double -> Vector Double | ||
345 | unitary v = v / scalar (norm v) | ||
346 | |||
347 | |||
348 | -- | trans . inv | ||
349 | mt :: Matrix Double -> Matrix Double | ||
350 | mt = trans . inv | ||
351 | |||
352 | -------------------------------------------------------------------------------- | ||
353 | {- | | ||
354 | |||
355 | >>> size $ vector [1..10] | ||
356 | 10 | ||
357 | >>> size $ (2><5)[1..10::Double] | ||
358 | (2,5) | ||
359 | |||
360 | -} | ||
361 | size :: Container c t => c t -> IndexOf c | ||
362 | size = size' | ||
363 | |||
364 | {- | Alternative indexing function. | ||
365 | |||
366 | >>> vector [1..10] ! 3 | ||
367 | 4.0 | ||
368 | |||
369 | On a matrix it gets the k-th row as a vector: | ||
370 | |||
371 | >>> matrix 5 [1..15] ! 1 | ||
372 | fromList [6.0,7.0,8.0,9.0,10.0] | ||
373 | |||
374 | >>> matrix 5 [1..15] ! 1 ! 3 | ||
375 | 9.0 | ||
376 | |||
377 | -} | ||
378 | class Indexable c t | c -> t , t -> c | ||
379 | where | ||
380 | infixl 9 ! | ||
381 | (!) :: c -> Int -> t | ||
382 | |||
383 | instance Indexable (Vector Double) Double | ||
384 | where | ||
385 | (!) = (@>) | ||
386 | |||
387 | instance Indexable (Vector Float) Float | ||
388 | where | ||
389 | (!) = (@>) | ||
390 | |||
391 | instance Indexable (Vector I) I | ||
392 | where | ||
393 | (!) = (@>) | ||
394 | |||
395 | instance Indexable (Vector Z) Z | ||
396 | where | ||
397 | (!) = (@>) | ||
398 | |||
399 | instance Indexable (Vector (Complex Double)) (Complex Double) | ||
400 | where | ||
401 | (!) = (@>) | ||
402 | |||
403 | instance Indexable (Vector (Complex Float)) (Complex Float) | ||
404 | where | ||
405 | (!) = (@>) | ||
406 | |||
407 | instance Element t => Indexable (Matrix t) (Vector t) | ||
408 | where | ||
409 | m!j = subVector (j*c) c (flatten m) | ||
410 | where | ||
411 | c = cols m | ||
412 | |||
413 | -------------------------------------------------------------------------------- | ||
414 | |||
415 | -- | Matrix of pairwise squared distances of row vectors | ||
416 | -- (using the matrix product trick in blog.smola.org) | ||
417 | pairwiseD2 :: Matrix Double -> Matrix Double -> Matrix Double | ||
418 | pairwiseD2 x y | ok = x2 `outer` oy + ox `outer` y2 - 2* x <> trans y | ||
419 | | otherwise = error $ "pairwiseD2 with different number of columns: " | ||
420 | ++ show (size x) ++ ", " ++ show (size y) | ||
421 | where | ||
422 | ox = one (rows x) | ||
423 | oy = one (rows y) | ||
424 | oc = one (cols x) | ||
425 | one k = konst 1 k | ||
426 | x2 = x * x <> oc | ||
427 | y2 = y * y <> oc | ||
428 | ok = cols x == cols y | ||
429 | |||
430 | -------------------------------------------------------------------------------- | ||
431 | |||
432 | {- | outer products of rows | ||
433 | |||
434 | >>> a | ||
435 | (3><2) | ||
436 | [ 1.0, 2.0 | ||
437 | , 10.0, 20.0 | ||
438 | , 100.0, 200.0 ] | ||
439 | >>> b | ||
440 | (3><3) | ||
441 | [ 1.0, 2.0, 3.0 | ||
442 | , 4.0, 5.0, 6.0 | ||
443 | , 7.0, 8.0, 9.0 ] | ||
444 | |||
445 | >>> rowOuters a (b ||| 1) | ||
446 | (3><8) | ||
447 | [ 1.0, 2.0, 3.0, 1.0, 2.0, 4.0, 6.0, 2.0 | ||
448 | , 40.0, 50.0, 60.0, 10.0, 80.0, 100.0, 120.0, 20.0 | ||
449 | , 700.0, 800.0, 900.0, 100.0, 1400.0, 1600.0, 1800.0, 200.0 ] | ||
450 | |||
451 | -} | ||
452 | rowOuters :: Matrix Double -> Matrix Double -> Matrix Double | ||
453 | rowOuters a b = a' * b' | ||
454 | where | ||
455 | a' = kronecker a (ones 1 (cols b)) | ||
456 | b' = kronecker (ones 1 (cols a)) b | ||
457 | |||
458 | -------------------------------------------------------------------------------- | ||
459 | |||
460 | -- | solution of overconstrained homogeneous linear system | ||
461 | null1 :: Matrix R -> Vector R | ||
462 | null1 = last . toColumns . snd . rightSV | ||
463 | |||
464 | -- | solution of overconstrained homogeneous symmetric linear system | ||
465 | null1sym :: Herm R -> Vector R | ||
466 | null1sym = last . toColumns . snd . eigSH | ||
467 | |||
468 | -------------------------------------------------------------------------------- | ||
469 | |||
470 | infixl 0 ~!~ | ||
471 | c ~!~ msg = when c (error msg) | ||
472 | |||
473 | -------------------------------------------------------------------------------- | ||
474 | |||
475 | formatSparse :: String -> String -> String -> Int -> Matrix Double -> String | ||
476 | |||
477 | formatSparse zeroI _zeroF sep _ (approxInt -> Just m) = format sep f m | ||
478 | where | ||
479 | f 0 = zeroI | ||
480 | f x = printf "%.0f" x | ||
481 | |||
482 | formatSparse zeroI zeroF sep n m = format sep f m | ||
483 | where | ||
484 | f x | abs (x::Double) < 2*peps = zeroI++zeroF | ||
485 | | abs (fromIntegral (round x::Int) - x) / abs x < 2*peps | ||
486 | = printf ("%.0f."++replicate n ' ') x | ||
487 | | otherwise = printf ("%."++show n++"f") x | ||
488 | |||
489 | approxInt m | ||
490 | | norm_Inf (v - vi) < 2*peps * norm_Inf v = Just (reshape (cols m) vi) | ||
491 | | otherwise = Nothing | ||
492 | where | ||
493 | v = flatten m | ||
494 | vi = roundVector v | ||
495 | |||
496 | dispDots n = putStr . formatSparse "." (replicate n ' ') " " n | ||
497 | |||
498 | dispBlanks n = putStr . formatSparse "" "" " " n | ||
499 | |||
500 | formatShort sep fmt maxr maxc m = auxm4 | ||
501 | where | ||
502 | (rm,cm) = size m | ||
503 | (r1,r2,r3) | ||
504 | | rm <= maxr = (rm,0,0) | ||
505 | | otherwise = (maxr-3,rm-maxr+1,2) | ||
506 | (c1,c2,c3) | ||
507 | | cm <= maxc = (cm,0,0) | ||
508 | | otherwise = (maxc-3,cm-maxc+1,2) | ||
509 | [ [a,_,b] | ||
510 | ,[_,_,_] | ||
511 | ,[c,_,d]] = toBlocks [r1,r2,r3] | ||
512 | [c1,c2,c3] m | ||
513 | auxm = fromBlocks [[a,b],[c,d]] | ||
514 | auxm2 | ||
515 | | cm > maxc = format "|" fmt auxm | ||
516 | | otherwise = format sep fmt auxm | ||
517 | auxm3 | ||
518 | | cm > maxc = map (f . splitOn "|") (lines auxm2) | ||
519 | | otherwise = (lines auxm2) | ||
520 | f items = intercalate sep (take (maxc-3) items) ++ " .. " ++ | ||
521 | intercalate sep (drop (maxc-3) items) | ||
522 | auxm4 | ||
523 | | rm > maxr = unlines (take (maxr-3) auxm3 ++ vsep : drop (maxr-3) auxm3) | ||
524 | | otherwise = unlines auxm3 | ||
525 | vsep = map g (head auxm3) | ||
526 | g '.' = ':' | ||
527 | g _ = ' ' | ||
528 | |||
529 | |||
530 | dispShort :: Int -> Int -> Int -> Matrix Double -> IO () | ||
531 | dispShort maxr maxc dec m = | ||
532 | printf "%dx%d\n%s" (rows m) (cols m) (formatShort " " fmt maxr maxc m) | ||
533 | where | ||
534 | fmt = printf ("%."++show dec ++"f") | ||
535 | |||
536 | -------------------------------------------------------------------------------- | ||
537 | |||
538 | -- matrix views | ||
539 | |||
540 | block2x2 r c m = [[m11,m12],[m21,m22]] | ||
541 | where | ||
542 | m11 = m ?? (Take r, Take c) | ||
543 | m12 = m ?? (Take r, Drop c) | ||
544 | m21 = m ?? (Drop r, Take c) | ||
545 | m22 = m ?? (Drop r, Drop c) | ||
546 | |||
547 | block3x3 r nr c nc m = [[m ?? (er !! i, ec !! j) | j <- [0..2] ] | i <- [0..2] ] | ||
548 | where | ||
549 | er = [ Range 0 1 (r-1), Range r 1 (r+nr-1), Drop (nr+r) ] | ||
550 | ec = [ Range 0 1 (c-1), Range c 1 (c+nc-1), Drop (nc+c) ] | ||
551 | |||
552 | view1 :: Numeric t => Matrix t -> Maybe (View1 t) | ||
553 | view1 m | ||
554 | | rows m > 0 && cols m > 0 = Just (e, flatten m12, flatten m21 , m22) | ||
555 | | otherwise = Nothing | ||
556 | where | ||
557 | [[m11,m12],[m21,m22]] = block2x2 1 1 m | ||
558 | e = m11 `atIndex` (0, 0) | ||
559 | |||
560 | unView1 :: Numeric t => View1 t -> Matrix t | ||
561 | unView1 (e,r,c,m) = fromBlocks [[scalar e, asRow r],[asColumn c, m]] | ||
562 | |||
563 | type View1 t = (t, Vector t, Vector t, Matrix t) | ||
564 | |||
565 | foldMatrix :: Numeric t => (Matrix t -> Matrix t) -> (View1 t -> View1 t) -> (Matrix t -> Matrix t) | ||
566 | foldMatrix g f ( (f <$>) . view1 . g -> Just (e,r,c,m)) = unView1 (e, r, c, foldMatrix g f m) | ||
567 | foldMatrix _ _ m = m | ||
568 | |||
569 | |||
570 | swapMax k m | ||
571 | | rows m > 0 && j>0 = (j, m ?? (Pos (idxs swapped), All)) | ||
572 | | otherwise = (0,m) | ||
573 | where | ||
574 | j = maxIndex $ abs (tr m ! k) | ||
575 | swapped = j:[1..j-1] ++ 0:[j+1..rows m-1] | ||
576 | |||
577 | down g a = foldMatrix g f a | ||
578 | where | ||
579 | f (e,r,c,m) | ||
580 | | e /= 0 = (1, r', 0, m - outer c r') | ||
581 | | otherwise = error "singular!" | ||
582 | where | ||
583 | r' = r / scalar e | ||
584 | |||
585 | |||
586 | -- | generic reference implementation of gaussian elimination | ||
587 | -- | ||
588 | -- @a <> gaussElim a b = b@ | ||
589 | -- | ||
590 | gaussElim_2 | ||
591 | :: (Eq t, Fractional t, Num (Vector t), Numeric t) | ||
592 | => Matrix t -> Matrix t -> Matrix t | ||
593 | |||
594 | gaussElim_2 a b = flipudrl r | ||
595 | where | ||
596 | flipudrl = flipud . fliprl | ||
597 | splitColsAt n = (takeColumns n &&& dropColumns n) | ||
598 | go f x y = splitColsAt (cols a) (down f $ x ||| y) | ||
599 | (a1,b1) = go (snd . swapMax 0) a b | ||
600 | ( _, r) = go id (flipudrl $ a1) (flipudrl $ b1) | ||
601 | |||
602 | -------------------------------------------------------------------------------- | ||
603 | |||
604 | gaussElim_1 | ||
605 | :: (Fractional t, Num (Vector t), Ord t, Indexable (Vector t) t, Numeric t) | ||
606 | => Matrix t -> Matrix t -> Matrix t | ||
607 | |||
608 | gaussElim_1 x y = dropColumns (rows x) (flipud $ fromRows s2) | ||
609 | where | ||
610 | rs = toRows $ x ||| y | ||
611 | s1 = fromRows $ pivotDown (rows x) 0 rs -- interesting | ||
612 | s2 = pivotUp (rows x-1) (toRows $ flipud s1) | ||
613 | |||
614 | pivotDown t n xs | ||
615 | | t == n = [] | ||
616 | | otherwise = y : pivotDown t (n+1) ys | ||
617 | where | ||
618 | y:ys = redu (pivot n xs) | ||
619 | |||
620 | pivot k = (const k &&& id) | ||
621 | . sortBy (flip compare `on` (abs. (!k))) | ||
622 | |||
623 | redu (k,x:zs) | ||
624 | | p == 0 = error "gauss: singular!" -- FIXME | ||
625 | | otherwise = u : map f zs | ||
626 | where | ||
627 | p = x!k | ||
628 | u = scale (recip (x!k)) x | ||
629 | f z = z - scale (z!k) u | ||
630 | redu (_,[]) = [] | ||
631 | |||
632 | |||
633 | pivotUp n xs | ||
634 | | n == -1 = [] | ||
635 | | otherwise = y : pivotUp (n-1) ys | ||
636 | where | ||
637 | y:ys = redu' (n,xs) | ||
638 | |||
639 | redu' (k,x:zs) = u : map f zs | ||
640 | where | ||
641 | u = x | ||
642 | f z = z - scale (z!k) u | ||
643 | redu' (_,[]) = [] | ||
644 | |||
645 | -------------------------------------------------------------------------------- | ||
646 | |||
647 | gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (a ||| b) | ||
648 | |||
649 | gaussST (r,_) x = do | ||
650 | let n = r-1 | ||
651 | axpy m a i j = rowOper (AXPY a i j AllCols) m | ||
652 | swap m i j = rowOper (SWAP i j AllCols) m | ||
653 | scal m a i = rowOper (SCAL a (Row i) AllCols) m | ||
654 | forM_ [0..n] $ \i -> do | ||
655 | c <- maxIndex . abs . flatten <$> extractMatrix x (FromRow i) (Col i) | ||
656 | swap x i (i+c) | ||
657 | a <- readMatrix x i i | ||
658 | when (a == 0) $ error "singular!" | ||
659 | scal x (recip a) i | ||
660 | forM_ [i+1..n] $ \j -> do | ||
661 | b <- readMatrix x j i | ||
662 | axpy x (-b) i j | ||
663 | forM_ [n,n-1..1] $ \i -> do | ||
664 | forM_ [i-1,i-2..0] $ \j -> do | ||
665 | b <- readMatrix x j i | ||
666 | axpy x (-b) i j | ||
667 | |||
668 | |||
669 | |||
670 | luST ok (r,_) x = do | ||
671 | let axpy m a i j = rowOper (AXPY a i j (FromCol (i+1))) m | ||
672 | swap m i j = rowOper (SWAP i j AllCols) m | ||
673 | p <- newUndefinedVector r | ||
674 | forM_ [0..r-1] $ \i -> do | ||
675 | k <- maxIndex . abs . flatten <$> extractMatrix x (FromRow i) (Col i) | ||
676 | writeVector p i (k+i) | ||
677 | swap x i (i+k) | ||
678 | a <- readMatrix x i i | ||
679 | when (ok a) $ do | ||
680 | forM_ [i+1..r-1] $ \j -> do | ||
681 | b <- (/a) <$> readMatrix x j i | ||
682 | axpy x (-b) i j | ||
683 | writeMatrix x j i b | ||
684 | v <- unsafeFreezeVector p | ||
685 | return (toList v) | ||
686 | |||
687 | {- | Experimental implementation of 'luPacked' | ||
688 | for any Fractional element type, including 'Mod' n 'I' and 'Mod' n 'Z'. | ||
689 | |||
690 | >>> let m = ident 5 + (5><5) [0..] :: Matrix (Z ./. 17) | ||
691 | (5><5) | ||
692 | [ 1, 1, 2, 3, 4 | ||
693 | , 5, 7, 7, 8, 9 | ||
694 | , 10, 11, 13, 13, 14 | ||
695 | , 15, 16, 0, 2, 2 | ||
696 | , 3, 4, 5, 6, 8 ] | ||
697 | |||
698 | >>> let (l,u,p,s) = luFact $ luPacked' m | ||
699 | >>> l | ||
700 | (5><5) | ||
701 | [ 1, 0, 0, 0, 0 | ||
702 | , 6, 1, 0, 0, 0 | ||
703 | , 12, 7, 1, 0, 0 | ||
704 | , 7, 10, 7, 1, 0 | ||
705 | , 8, 2, 6, 11, 1 ] | ||
706 | >>> u | ||
707 | (5><5) | ||
708 | [ 15, 16, 0, 2, 2 | ||
709 | , 0, 13, 7, 13, 14 | ||
710 | , 0, 0, 15, 0, 11 | ||
711 | , 0, 0, 0, 15, 15 | ||
712 | , 0, 0, 0, 0, 1 ] | ||
713 | |||
714 | -} | ||
715 | luPacked' x = LU m p | ||
716 | where | ||
717 | (m,p) = mutable (luST (magnit 0)) x | ||
718 | |||
719 | -------------------------------------------------------------------------------- | ||
720 | |||
721 | scalS a (Slice x r0 c0 nr nc) = rowOper (SCAL a (RowRange r0 (r0+nr-1)) (ColRange c0 (c0+nc-1))) x | ||
722 | |||
723 | view x k r = do | ||
724 | d <- readMatrix x k k | ||
725 | let rr = r-1-k | ||
726 | o = if k < r-1 then 1 else 0 | ||
727 | s = Slice x (k+1) (k+1) rr rr | ||
728 | u = Slice x k (k+1) o rr | ||
729 | l = Slice x (k+1) k rr o | ||
730 | return (d,u,l,s) | ||
731 | |||
732 | withVec r f = \s x -> do | ||
733 | p <- newUndefinedVector r | ||
734 | _ <- f s x p | ||
735 | v <- unsafeFreezeVector p | ||
736 | return v | ||
737 | |||
738 | |||
739 | luPacked'' m = (id *** toList) (mutable (withVec (rows m) lu2) m) | ||
740 | where | ||
741 | lu2 (r,_) x p = do | ||
742 | forM_ [0..r-1] $ \k -> do | ||
743 | pivot x p k | ||
744 | (d,u,l,s) <- view x k r | ||
745 | when (magnit 0 d) $ do | ||
746 | scalS (recip d) l | ||
747 | gemmm 1 s (-1) l u | ||
748 | |||
749 | pivot x p k = do | ||
750 | j <- maxIndex . abs . flatten <$> extractMatrix x (FromRow k) (Col k) | ||
751 | writeVector p k (j+k) | ||
752 | swap k (k+j) | ||
753 | where | ||
754 | swap i j = rowOper (SWAP i j AllCols) x | ||
755 | |||
756 | -------------------------------------------------------------------------------- | ||
757 | |||
758 | rowRange m = [0..rows m -1] | ||
759 | |||
760 | at k = Pos (idxs[k]) | ||
761 | |||
762 | backSust' lup rhs = foldl' f (rhs?[]) (reverse ls) | ||
763 | where | ||
764 | ls = [ (d k , u k , b k) | k <- rowRange lup ] | ||
765 | where | ||
766 | d k = lup ?? (at k, at k) | ||
767 | u k = lup ?? (at k, Drop (k+1)) | ||
768 | b k = rhs ?? (at k, All) | ||
769 | |||
770 | f x (d,u,b) = (b - u<>x) / d | ||
771 | === | ||
772 | x | ||
773 | |||
774 | |||
775 | forwSust' lup rhs = foldl' f (rhs?[]) ls | ||
776 | where | ||
777 | ls = [ (l k , b k) | k <- rowRange lup ] | ||
778 | where | ||
779 | l k = lup ?? (at k, Take k) | ||
780 | b k = rhs ?? (at k, All) | ||
781 | |||
782 | f x (l,b) = x | ||
783 | === | ||
784 | (b - l<>x) | ||
785 | |||
786 | |||
787 | luSolve'' (LU lup p) b = backSust' lup (forwSust' lup pb) | ||
788 | where | ||
789 | pb = b ?? (Pos (fixPerm' p), All) | ||
790 | |||
791 | -------------------------------------------------------------------------------- | ||
792 | |||
793 | forwSust lup rhs = fst $ mutable f rhs | ||
794 | where | ||
795 | f (r,c) x = do | ||
796 | l <- unsafeThawMatrix lup | ||
797 | let go k = gemmm 1 (Slice x k 0 1 c) (-1) (Slice l k 0 1 k) (Slice x 0 0 k c) | ||
798 | mapM_ go [0..r-1] | ||
799 | |||
800 | |||
801 | backSust lup rhs = fst $ mutable f rhs | ||
802 | where | ||
803 | f (r,c) m = do | ||
804 | l <- unsafeThawMatrix lup | ||
805 | let d k = recip (lup `atIndex` (k,k)) | ||
806 | u k = Slice l k (k+1) 1 (r-1-k) | ||
807 | b k = Slice m k 0 1 c | ||
808 | x k = Slice m (k+1) 0 (r-1-k) c | ||
809 | scal k = rowOper (SCAL (d k) (Row k) AllCols) m | ||
810 | |||
811 | go k = gemmm 1 (b k) (-1) (u k) (x k) >> scal k | ||
812 | mapM_ go [r-1,r-2..0] | ||
813 | |||
814 | |||
815 | {- | Experimental implementation of 'luSolve' for any Fractional element type, including 'Mod' n 'I' and 'Mod' n 'Z'. | ||
816 | |||
817 | >>> let a = (2><2) [1,2,3,5] :: Matrix (Z ./. 13) | ||
818 | (2><2) | ||
819 | [ 1, 2 | ||
820 | , 3, 5 ] | ||
821 | >>> b | ||
822 | (2><3) | ||
823 | [ 5, 1, 3 | ||
824 | , 8, 6, 3 ] | ||
825 | |||
826 | >>> luSolve' (luPacked' a) b | ||
827 | (2><3) | ||
828 | [ 4, 7, 4 | ||
829 | , 7, 10, 6 ] | ||
830 | |||
831 | -} | ||
832 | luSolve' (LU lup p) b = backSust lup (forwSust lup pb) | ||
833 | where | ||
834 | pb = b ?? (Pos (fixPerm' p), All) | ||
835 | |||
836 | |||
837 | -------------------------------------------------------------------------------- | ||
838 | |||
839 | data MatrixView t b | ||
840 | = Elem t | ||
841 | | Block b b b b | ||
842 | deriving Show | ||
843 | |||
844 | |||
845 | viewBlock' r c m | ||
846 | | (rt,ct) == (1,1) = Elem (atM' m 0 0) | ||
847 | | otherwise = Block m11 m12 m21 m22 | ||
848 | where | ||
849 | (rt,ct) = size m | ||
850 | m11 = subm (0,0) (r,c) m | ||
851 | m12 = subm (0,c) (r,ct-c) m | ||
852 | m21 = subm (r,0) (rt-r,c) m | ||
853 | m22 = subm (r,c) (rt-r,ct-c) m | ||
854 | subm = subMatrix | ||
855 | |||
856 | viewBlock m = viewBlock' n n m | ||
857 | where | ||
858 | n = rows m `div` 2 | ||
859 | |||
860 | invershur (viewBlock -> Block a b c d) = fromBlocks [[a',b'],[c',d']] | ||
861 | where | ||
862 | r1 = invershur a | ||
863 | r2 = c <> r1 | ||
864 | r3 = r1 <> b | ||
865 | r4 = c <> r3 | ||
866 | r5 = r4-d | ||
867 | r6 = invershur r5 | ||
868 | b' = r3 <> r6 | ||
869 | c' = r6 <> r2 | ||
870 | r7 = r3 <> c' | ||
871 | a' = r1-r7 | ||
872 | d' = -r6 | ||
873 | |||
874 | invershur x = recip x | ||
875 | |||
876 | -------------------------------------------------------------------------------- | ||
877 | |||
878 | instance Testable (Matrix I) where | ||
879 | checkT _ = test | ||
880 | |||
881 | test :: (Bool, IO()) | ||
882 | test = (and ok, return ()) | ||
883 | where | ||
884 | m = (3><4) [1..12] :: Matrix I | ||
885 | r = (2><3) [1,2,3,4,3,2] | ||
886 | c = (3><2) [0,4,4,1,2,3] | ||
887 | p = (9><10) [0..89] :: Matrix I | ||
888 | ep = (2><3) [10,24,32,44,31,23] | ||
889 | md = fromInt m :: Matrix Double | ||
890 | ok = [ tr m <> m == toInt (tr md <> md) | ||
891 | , m <> tr m == toInt (md <> tr md) | ||
892 | , m ?? (Take 2, Take 3) == remap (asColumn (range 2)) (asRow (range 3)) m | ||
893 | , remap r (tr c) p == ep | ||
894 | , tr p ?? (PosCyc (idxs[-5,13]), Pos (idxs[3,7,1])) == (2><3) [35,75,15,33,73,13] | ||
895 | ] | ||
896 | |||
diff --git a/packages/base/src/Data/Packed/Internal/Vector.hs b/packages/base/src/Internal/Vector.hs index d0bc143..de4e670 100644 --- a/packages/base/src/Data/Packed/Internal/Vector.hs +++ b/packages/base/src/Internal/Vector.hs | |||
@@ -1,60 +1,64 @@ | |||
1 | {-# LANGUAGE MagicHash, CPP, UnboxedTuples, BangPatterns, FlexibleContexts #-} | 1 | {-# LANGUAGE MagicHash, CPP, UnboxedTuples, BangPatterns, FlexibleContexts #-} |
2 | {-# LANGUAGE TypeSynonymInstances #-} | ||
3 | |||
4 | |||
2 | -- | | 5 | -- | |
3 | -- Module : Data.Packed.Internal.Vector | 6 | -- Module : Internal.Vector |
4 | -- Copyright : (c) Alberto Ruiz 2007 | 7 | -- Copyright : (c) Alberto Ruiz 2007-15 |
5 | -- License : BSD3 | 8 | -- License : BSD3 |
6 | -- Maintainer : Alberto Ruiz | 9 | -- Maintainer : Alberto Ruiz |
7 | -- Stability : provisional | 10 | -- Stability : provisional |
8 | -- | 11 | -- |
9 | -- Vector implementation | ||
10 | -- | ||
11 | -------------------------------------------------------------------------------- | ||
12 | 12 | ||
13 | module Data.Packed.Internal.Vector ( | 13 | module Internal.Vector( |
14 | Vector, dim, | 14 | I,Z,R,C, |
15 | fromList, toList, (|>), | 15 | fi,ti, |
16 | vjoin, (@>), safe, at, at', subVector, takesV, | 16 | Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith, |
17 | mapVector, mapVectorWithIndex, zipVectorWith, unzipVectorWith, | 17 | createVector, avec, inlinePerformIO, |
18 | mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_, | 18 | toList, dim, (@>), at', (|>), |
19 | foldVector, foldVectorG, foldLoop, foldVectorWithIndex, | 19 | vjoin, subVector, takesV, idxs, |
20 | createVector, vec, | 20 | buildVector, |
21 | asComplex, asReal, float2DoubleV, double2FloatV, | 21 | asReal, asComplex, |
22 | stepF, stepD, condF, condD, | 22 | toByteString,fromByteString, |
23 | conjugateQ, conjugateC, | 23 | zipVector, unzipVector, zipVectorWith, unzipVectorWith, |
24 | cloneVector, | 24 | foldVector, foldVectorG, foldVectorWithIndex, foldLoop, |
25 | unsafeToForeignPtr, | 25 | mapVector, mapVectorM, mapVectorM_, |
26 | unsafeFromForeignPtr, | 26 | mapVectorWithIndex, mapVectorWithIndexM, mapVectorWithIndexM_ |
27 | unsafeWith | ||
28 | ) where | 27 | ) where |
29 | 28 | ||
30 | import Data.Packed.Internal.Common | 29 | import Foreign.Marshal.Array |
31 | import Data.Packed.Internal.Signatures | 30 | import Foreign.ForeignPtr |
32 | import Foreign.Marshal.Array(peekArray, copyArray, advancePtr) | 31 | import Foreign.Ptr |
33 | import Foreign.ForeignPtr(ForeignPtr, castForeignPtr) | 32 | import Foreign.Storable |
34 | import Foreign.Ptr(Ptr) | 33 | import Foreign.C.Types(CInt) |
35 | import Foreign.Storable(Storable, peekElemOff, pokeElemOff, sizeOf) | 34 | import Data.Int(Int64) |
36 | import Foreign.C.Types | ||
37 | import Data.Complex | 35 | import Data.Complex |
38 | import Control.Monad(when) | ||
39 | import System.IO.Unsafe(unsafePerformIO) | 36 | import System.IO.Unsafe(unsafePerformIO) |
37 | import GHC.ForeignPtr(mallocPlainForeignPtrBytes) | ||
38 | import GHC.Base(realWorld#, IO(IO), when) | ||
39 | import qualified Data.Vector.Storable as Vector | ||
40 | import Data.Vector.Storable(Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith) | ||
40 | 41 | ||
41 | #if __GLASGOW_HASKELL__ >= 605 | 42 | #ifdef BINARY |
42 | import GHC.ForeignPtr (mallocPlainForeignPtrBytes) | 43 | import Data.Binary |
43 | #else | 44 | import Control.Monad(replicateM) |
44 | import Foreign.ForeignPtr (mallocForeignPtrBytes) | 45 | import qualified Data.ByteString.Internal as BS |
46 | import Data.Vector.Storable.Internal(updPtr) | ||
45 | #endif | 47 | #endif |
46 | 48 | ||
47 | import GHC.Base | 49 | type I = CInt |
48 | #if __GLASGOW_HASKELL__ < 612 | 50 | type Z = Int64 |
49 | import GHC.IOBase hiding (liftIO) | 51 | type R = Double |
50 | #endif | 52 | type C = Complex Double |
51 | 53 | ||
52 | import qualified Data.Vector.Storable as Vector | 54 | |
53 | import Data.Vector.Storable(Vector, | 55 | -- | specialized fromIntegral |
54 | fromList, | 56 | fi :: Int -> CInt |
55 | unsafeToForeignPtr, | 57 | fi = fromIntegral |
56 | unsafeFromForeignPtr, | 58 | |
57 | unsafeWith) | 59 | -- | specialized fromIntegral |
60 | ti :: CInt -> Int | ||
61 | ti = fromIntegral | ||
58 | 62 | ||
59 | 63 | ||
60 | -- | Number of elements | 64 | -- | Number of elements |
@@ -63,14 +67,10 @@ dim = Vector.length | |||
63 | 67 | ||
64 | 68 | ||
65 | -- C-Haskell vector adapter | 69 | -- C-Haskell vector adapter |
66 | -- vec :: Adapt (CInt -> Ptr t -> r) (Vector t) r | 70 | {-# INLINE avec #-} |
67 | vec :: (Storable t) => Vector t -> (((CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b | 71 | avec :: Storable a => (CInt -> Ptr a -> b) -> Vector a -> b |
68 | vec x f = unsafeWith x $ \p -> do | 72 | avec f v = inlinePerformIO (unsafeWith v (return . f (fromIntegral (Vector.length v)))) |
69 | let v g = do | 73 | infixl 1 `avec` |
70 | g (fi $ dim x) p | ||
71 | f v | ||
72 | {-# INLINE vec #-} | ||
73 | |||
74 | 74 | ||
75 | -- allocates memory for a new vector | 75 | -- allocates memory for a new vector |
76 | createVector :: Storable a => Int -> IO (Vector a) | 76 | createVector :: Storable a => Int -> IO (Vector a) |
@@ -85,11 +85,7 @@ createVector n = do | |||
85 | -- | 85 | -- |
86 | doMalloc :: Storable b => b -> IO (ForeignPtr b) | 86 | doMalloc :: Storable b => b -> IO (ForeignPtr b) |
87 | doMalloc dummy = do | 87 | doMalloc dummy = do |
88 | #if __GLASGOW_HASKELL__ >= 605 | ||
89 | mallocPlainForeignPtrBytes (n * sizeOf dummy) | 88 | mallocPlainForeignPtrBytes (n * sizeOf dummy) |
90 | #else | ||
91 | mallocForeignPtrBytes (n * sizeOf dummy) | ||
92 | #endif | ||
93 | 89 | ||
94 | {- | creates a Vector from a list: | 90 | {- | creates a Vector from a list: |
95 | 91 | ||
@@ -105,7 +101,7 @@ inlinePerformIO :: IO a -> a | |||
105 | inlinePerformIO (IO m) = case m realWorld# of (# _, r #) -> r | 101 | inlinePerformIO (IO m) = case m realWorld# of (# _, r #) -> r |
106 | {-# INLINE inlinePerformIO #-} | 102 | {-# INLINE inlinePerformIO #-} |
107 | 103 | ||
108 | {- | extracts the Vector elements to a list | 104 | {- extracts the Vector elements to a list |
109 | 105 | ||
110 | >>> toList (linspace 5 (1,10)) | 106 | >>> toList (linspace 5 (1,10)) |
111 | [1.0,3.25,5.5,7.75,10.0] | 107 | [1.0,3.25,5.5,7.75,10.0] |
@@ -115,7 +111,7 @@ toList :: Storable a => Vector a -> [a] | |||
115 | toList v = safeRead v $ peekArray (dim v) | 111 | toList v = safeRead v $ peekArray (dim v) |
116 | 112 | ||
117 | {- | Create a vector from a list of elements and explicit dimension. The input | 113 | {- | Create a vector from a list of elements and explicit dimension. The input |
118 | list is explicitly truncated if it is too long, so it may safely | 114 | list is truncated if it is too long, so it may safely |
119 | be used, for instance, with infinite lists. | 115 | be used, for instance, with infinite lists. |
120 | 116 | ||
121 | >>> 5 |> [1..] | 117 | >>> 5 |> [1..] |
@@ -124,36 +120,16 @@ fromList [1.0,2.0,3.0,4.0,5.0] | |||
124 | -} | 120 | -} |
125 | (|>) :: (Storable a) => Int -> [a] -> Vector a | 121 | (|>) :: (Storable a) => Int -> [a] -> Vector a |
126 | infixl 9 |> | 122 | infixl 9 |> |
127 | n |> l = if length l' == n | 123 | n |> l |
128 | then fromList l' | 124 | | length l' == n = fromList l' |
129 | else error "list too short for |>" | 125 | | otherwise = error "list too short for |>" |
130 | where l' = take n l | 126 | where |
131 | 127 | l' = take n l | |
132 | 128 | ||
133 | -- | access to Vector elements without range checking | ||
134 | at' :: Storable a => Vector a -> Int -> a | ||
135 | at' v n = safeRead v $ flip peekElemOff n | ||
136 | {-# INLINE at' #-} | ||
137 | 129 | ||
138 | -- | 130 | -- | Create a vector of indexes, useful for matrix extraction using '(??)' |
139 | -- turn off bounds checking with -funsafe at configure time. | 131 | idxs :: [Int] -> Vector I |
140 | -- ghc will optimise away the salways true case at compile time. | 132 | idxs js = fromList (map fromIntegral js) :: Vector I |
141 | -- | ||
142 | #if defined(UNSAFE) | ||
143 | safe :: Bool | ||
144 | safe = False | ||
145 | #else | ||
146 | safe = True | ||
147 | #endif | ||
148 | |||
149 | -- | access to Vector elements with range checking. | ||
150 | at :: Storable a => Vector a -> Int -> a | ||
151 | at v n | ||
152 | | safe = if n >= 0 && n < dim v | ||
153 | then at' v n | ||
154 | else error "vector index out of range" | ||
155 | | otherwise = at' v n | ||
156 | {-# INLINE at #-} | ||
157 | 133 | ||
158 | {- | takes a number of consecutive elements from a Vector | 134 | {- | takes a number of consecutive elements from a Vector |
159 | 135 | ||
@@ -168,6 +144,8 @@ subVector :: Storable t => Int -- ^ index of the starting element | |||
168 | subVector = Vector.slice | 144 | subVector = Vector.slice |
169 | 145 | ||
170 | 146 | ||
147 | |||
148 | |||
171 | {- | Reads a vector position: | 149 | {- | Reads a vector position: |
172 | 150 | ||
173 | >>> fromList [0..9] @> 7 | 151 | >>> fromList [0..9] @> 7 |
@@ -176,8 +154,15 @@ subVector = Vector.slice | |||
176 | -} | 154 | -} |
177 | (@>) :: Storable t => Vector t -> Int -> t | 155 | (@>) :: Storable t => Vector t -> Int -> t |
178 | infixl 9 @> | 156 | infixl 9 @> |
179 | (@>) = at | 157 | v @> n |
158 | | n >= 0 && n < dim v = at' v n | ||
159 | | otherwise = error "vector index out of range" | ||
160 | {-# INLINE (@>) #-} | ||
180 | 161 | ||
162 | -- | access to Vector elements without range checking | ||
163 | at' :: Storable a => Vector a -> Int -> a | ||
164 | at' v n = safeRead v $ flip peekElemOff n | ||
165 | {-# INLINE at' #-} | ||
181 | 166 | ||
182 | {- | concatenate a list of vectors | 167 | {- | concatenate a list of vectors |
183 | 168 | ||
@@ -226,84 +211,8 @@ asComplex :: (RealFloat a, Storable a) => Vector a -> Vector (Complex a) | |||
226 | asComplex v = unsafeFromForeignPtr (castForeignPtr fp) (i `div` 2) (n `div` 2) | 211 | asComplex v = unsafeFromForeignPtr (castForeignPtr fp) (i `div` 2) (n `div` 2) |
227 | where (fp,i,n) = unsafeToForeignPtr v | 212 | where (fp,i,n) = unsafeToForeignPtr v |
228 | 213 | ||
229 | --------------------------------------------------------------- | ||
230 | |||
231 | float2DoubleV :: Vector Float -> Vector Double | ||
232 | float2DoubleV v = unsafePerformIO $ do | ||
233 | r <- createVector (dim v) | ||
234 | app2 c_float2double vec v vec r "float2double" | ||
235 | return r | ||
236 | |||
237 | double2FloatV :: Vector Double -> Vector Float | ||
238 | double2FloatV v = unsafePerformIO $ do | ||
239 | r <- createVector (dim v) | ||
240 | app2 c_double2float vec v vec r "double2float2" | ||
241 | return r | ||
242 | |||
243 | |||
244 | foreign import ccall unsafe "float2double" c_float2double:: TFV | ||
245 | foreign import ccall unsafe "double2float" c_double2float:: TVF | ||
246 | |||
247 | --------------------------------------------------------------- | ||
248 | |||
249 | stepF :: Vector Float -> Vector Float | ||
250 | stepF v = unsafePerformIO $ do | ||
251 | r <- createVector (dim v) | ||
252 | app2 c_stepF vec v vec r "stepF" | ||
253 | return r | ||
254 | |||
255 | stepD :: Vector Double -> Vector Double | ||
256 | stepD v = unsafePerformIO $ do | ||
257 | r <- createVector (dim v) | ||
258 | app2 c_stepD vec v vec r "stepD" | ||
259 | return r | ||
260 | |||
261 | foreign import ccall unsafe "stepF" c_stepF :: TFF | ||
262 | foreign import ccall unsafe "stepD" c_stepD :: TVV | ||
263 | |||
264 | --------------------------------------------------------------- | ||
265 | |||
266 | condF :: Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float | ||
267 | condF x y l e g = unsafePerformIO $ do | ||
268 | r <- createVector (dim x) | ||
269 | app6 c_condF vec x vec y vec l vec e vec g vec r "condF" | ||
270 | return r | ||
271 | |||
272 | condD :: Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double | ||
273 | condD x y l e g = unsafePerformIO $ do | ||
274 | r <- createVector (dim x) | ||
275 | app6 c_condD vec x vec y vec l vec e vec g vec r "condD" | ||
276 | return r | ||
277 | |||
278 | foreign import ccall unsafe "condF" c_condF :: CInt -> PF -> CInt -> PF -> CInt -> PF -> TFFF | ||
279 | foreign import ccall unsafe "condD" c_condD :: CInt -> PD -> CInt -> PD -> CInt -> PD -> TVVV | ||
280 | |||
281 | -------------------------------------------------------------------------------- | ||
282 | |||
283 | conjugateAux fun x = unsafePerformIO $ do | ||
284 | v <- createVector (dim x) | ||
285 | app2 fun vec x vec v "conjugateAux" | ||
286 | return v | ||
287 | |||
288 | conjugateQ :: Vector (Complex Float) -> Vector (Complex Float) | ||
289 | conjugateQ = conjugateAux c_conjugateQ | ||
290 | foreign import ccall unsafe "conjugateQ" c_conjugateQ :: TQVQV | ||
291 | |||
292 | conjugateC :: Vector (Complex Double) -> Vector (Complex Double) | ||
293 | conjugateC = conjugateAux c_conjugateC | ||
294 | foreign import ccall unsafe "conjugateC" c_conjugateC :: TCVCV | ||
295 | |||
296 | -------------------------------------------------------------------------------- | 214 | -------------------------------------------------------------------------------- |
297 | 215 | ||
298 | cloneVector :: Storable t => Vector t -> IO (Vector t) | ||
299 | cloneVector v = do | ||
300 | let n = dim v | ||
301 | r <- createVector n | ||
302 | let f _ s _ d = copyArray d s n >> return 0 | ||
303 | app2 f vec v vec r "cloneVector" | ||
304 | return r | ||
305 | |||
306 | ------------------------------------------------------------------ | ||
307 | 216 | ||
308 | -- | map on Vectors | 217 | -- | map on Vectors |
309 | mapVector :: (Storable a, Storable b) => (a-> b) -> Vector a -> Vector b | 218 | mapVector :: (Storable a, Storable b) => (a-> b) -> Vector a -> Vector b |
@@ -381,7 +290,7 @@ foldLoop f s0 d = go (d - 1) s0 | |||
381 | go !j !s = go (j - 1) (f j s) | 290 | go !j !s = go (j - 1) (f j s) |
382 | 291 | ||
383 | foldVectorG f s0 v = foldLoop g s0 (dim v) | 292 | foldVectorG f s0 v = foldLoop g s0 (dim v) |
384 | where g !k !s = f k (at' v) s | 293 | where g !k !s = f k (safeRead v . flip peekElemOff) s |
385 | {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479) | 294 | {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479) |
386 | {-# INLINE foldVectorG #-} | 295 | {-# INLINE foldVectorG #-} |
387 | 296 | ||
@@ -468,4 +377,85 @@ mapVectorWithIndex f v = unsafePerformIO $ do | |||
468 | return w | 377 | return w |
469 | {-# INLINE mapVectorWithIndex #-} | 378 | {-# INLINE mapVectorWithIndex #-} |
470 | 379 | ||
380 | -------------------------------------------------------------------------------- | ||
381 | |||
382 | |||
383 | #ifdef BINARY | ||
384 | |||
385 | -- a 64K cache, with a Double taking 13 bytes in Bytestring, | ||
386 | -- implies a chunk size of 5041 | ||
387 | chunk :: Int | ||
388 | chunk = 5000 | ||
389 | |||
390 | chunks :: Int -> [Int] | ||
391 | chunks d = let c = d `div` chunk | ||
392 | m = d `mod` chunk | ||
393 | in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk) | ||
394 | |||
395 | putVector v = mapM_ put $! toList v | ||
396 | |||
397 | getVector d = do | ||
398 | xs <- replicateM d get | ||
399 | return $! fromList xs | ||
400 | |||
401 | -------------------------------------------------------------------------------- | ||
402 | |||
403 | toByteString :: Storable t => Vector t -> BS.ByteString | ||
404 | toByteString v = BS.PS (castForeignPtr fp) (sz*o) (sz * dim v) | ||
405 | where | ||
406 | (fp,o,_n) = unsafeToForeignPtr v | ||
407 | sz = sizeOf (v@>0) | ||
408 | |||
409 | |||
410 | fromByteString :: Storable t => BS.ByteString -> Vector t | ||
411 | fromByteString (BS.PS fp o n) = r | ||
412 | where | ||
413 | r = unsafeFromForeignPtr (castForeignPtr (updPtr (`plusPtr` o) fp)) 0 n' | ||
414 | n' = n `div` sz | ||
415 | sz = sizeOf (r@>0) | ||
416 | |||
417 | -------------------------------------------------------------------------------- | ||
418 | |||
419 | instance (Binary a, Storable a) => Binary (Vector a) where | ||
420 | |||
421 | put v = do | ||
422 | let d = dim v | ||
423 | put d | ||
424 | mapM_ putVector $! takesV (chunks d) v | ||
425 | |||
426 | -- put = put . v2bs | ||
427 | |||
428 | get = do | ||
429 | d <- get | ||
430 | vs <- mapM getVector $ chunks d | ||
431 | return $! vjoin vs | ||
432 | |||
433 | -- get = fmap bs2v get | ||
434 | |||
435 | #endif | ||
436 | |||
437 | |||
438 | ------------------------------------------------------------------- | ||
439 | |||
440 | {- | creates a Vector of the specified length using the supplied function to | ||
441 | to map the index to the value at that index. | ||
442 | |||
443 | @> buildVector 4 fromIntegral | ||
444 | 4 |> [0.0,1.0,2.0,3.0]@ | ||
445 | |||
446 | -} | ||
447 | buildVector :: Storable a => Int -> (Int -> a) -> Vector a | ||
448 | buildVector len f = | ||
449 | fromList $ map f [0 .. (len - 1)] | ||
450 | |||
451 | |||
452 | -- | zip for Vectors | ||
453 | zipVector :: (Storable a, Storable b, Storable (a,b)) => Vector a -> Vector b -> Vector (a,b) | ||
454 | zipVector = zipVectorWith (,) | ||
455 | |||
456 | -- | unzip for Vectors | ||
457 | unzipVector :: (Storable a, Storable b, Storable (a,b)) => Vector (a,b) -> (Vector a,Vector b) | ||
458 | unzipVector = unzipVectorWith id | ||
459 | |||
460 | ------------------------------------------------------------------- | ||
471 | 461 | ||
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs new file mode 100644 index 0000000..03bcf90 --- /dev/null +++ b/packages/base/src/Internal/Vectorized.hs | |||
@@ -0,0 +1,518 @@ | |||
1 | {-# LANGUAGE TypeOperators #-} | ||
2 | {-# LANGUAGE TypeFamilies #-} | ||
3 | |||
4 | ----------------------------------------------------------------------------- | ||
5 | -- | | ||
6 | -- Module : Numeric.Vectorized | ||
7 | -- Copyright : (c) Alberto Ruiz 2007-15 | ||
8 | -- License : BSD3 | ||
9 | -- Maintainer : Alberto Ruiz | ||
10 | -- Stability : provisional | ||
11 | -- | ||
12 | -- Low level interface to vector operations. | ||
13 | -- | ||
14 | ----------------------------------------------------------------------------- | ||
15 | |||
16 | module Internal.Vectorized where | ||
17 | |||
18 | import Internal.Vector | ||
19 | import Internal.Devel | ||
20 | import Data.Complex | ||
21 | import Foreign.Marshal.Alloc(free,malloc) | ||
22 | import Foreign.Marshal.Array(newArray,copyArray) | ||
23 | import Foreign.Ptr(Ptr) | ||
24 | import Foreign.Storable(peek,Storable) | ||
25 | import Foreign.C.Types | ||
26 | import Foreign.C.String | ||
27 | import System.IO.Unsafe(unsafePerformIO) | ||
28 | import Control.Monad(when) | ||
29 | |||
30 | infixl 1 # | ||
31 | a # b = applyRaw a b | ||
32 | {-# INLINE (#) #-} | ||
33 | |||
34 | fromei x = fromIntegral (fromEnum x) :: CInt | ||
35 | |||
36 | data FunCodeV = Sin | ||
37 | | Cos | ||
38 | | Tan | ||
39 | | Abs | ||
40 | | ASin | ||
41 | | ACos | ||
42 | | ATan | ||
43 | | Sinh | ||
44 | | Cosh | ||
45 | | Tanh | ||
46 | | ASinh | ||
47 | | ACosh | ||
48 | | ATanh | ||
49 | | Exp | ||
50 | | Log | ||
51 | | Sign | ||
52 | | Sqrt | ||
53 | deriving Enum | ||
54 | |||
55 | data FunCodeSV = Scale | ||
56 | | Recip | ||
57 | | AddConstant | ||
58 | | Negate | ||
59 | | PowSV | ||
60 | | PowVS | ||
61 | | ModSV | ||
62 | | ModVS | ||
63 | deriving Enum | ||
64 | |||
65 | data FunCodeVV = Add | ||
66 | | Sub | ||
67 | | Mul | ||
68 | | Div | ||
69 | | Pow | ||
70 | | ATan2 | ||
71 | | Mod | ||
72 | deriving Enum | ||
73 | |||
74 | data FunCodeS = Norm2 | ||
75 | | AbsSum | ||
76 | | MaxIdx | ||
77 | | Max | ||
78 | | MinIdx | ||
79 | | Min | ||
80 | deriving Enum | ||
81 | |||
82 | ------------------------------------------------------------------ | ||
83 | |||
84 | -- | sum of elements | ||
85 | sumF :: Vector Float -> Float | ||
86 | sumF = sumg c_sumF | ||
87 | |||
88 | -- | sum of elements | ||
89 | sumR :: Vector Double -> Double | ||
90 | sumR = sumg c_sumR | ||
91 | |||
92 | -- | sum of elements | ||
93 | sumQ :: Vector (Complex Float) -> Complex Float | ||
94 | sumQ = sumg c_sumQ | ||
95 | |||
96 | -- | sum of elements | ||
97 | sumC :: Vector (Complex Double) -> Complex Double | ||
98 | sumC = sumg c_sumC | ||
99 | |||
100 | sumI m = sumg (c_sumI m) | ||
101 | |||
102 | sumL m = sumg (c_sumL m) | ||
103 | |||
104 | sumg f x = unsafePerformIO $ do | ||
105 | r <- createVector 1 | ||
106 | f # x # r #| "sum" | ||
107 | return $ r @> 0 | ||
108 | |||
109 | type TVV t = t :> t :> Ok | ||
110 | |||
111 | foreign import ccall unsafe "sumF" c_sumF :: TVV Float | ||
112 | foreign import ccall unsafe "sumR" c_sumR :: TVV Double | ||
113 | foreign import ccall unsafe "sumQ" c_sumQ :: TVV (Complex Float) | ||
114 | foreign import ccall unsafe "sumC" c_sumC :: TVV (Complex Double) | ||
115 | foreign import ccall unsafe "sumI" c_sumI :: I -> TVV I | ||
116 | foreign import ccall unsafe "sumL" c_sumL :: Z -> TVV Z | ||
117 | |||
118 | -- | product of elements | ||
119 | prodF :: Vector Float -> Float | ||
120 | prodF = prodg c_prodF | ||
121 | |||
122 | -- | product of elements | ||
123 | prodR :: Vector Double -> Double | ||
124 | prodR = prodg c_prodR | ||
125 | |||
126 | -- | product of elements | ||
127 | prodQ :: Vector (Complex Float) -> Complex Float | ||
128 | prodQ = prodg c_prodQ | ||
129 | |||
130 | -- | product of elements | ||
131 | prodC :: Vector (Complex Double) -> Complex Double | ||
132 | prodC = prodg c_prodC | ||
133 | |||
134 | prodI :: I-> Vector I -> I | ||
135 | prodI = prodg . c_prodI | ||
136 | |||
137 | prodL :: Z-> Vector Z -> Z | ||
138 | prodL = prodg . c_prodL | ||
139 | |||
140 | prodg f x = unsafePerformIO $ do | ||
141 | r <- createVector 1 | ||
142 | f # x # r #| "prod" | ||
143 | return $ r @> 0 | ||
144 | |||
145 | |||
146 | foreign import ccall unsafe "prodF" c_prodF :: TVV Float | ||
147 | foreign import ccall unsafe "prodR" c_prodR :: TVV Double | ||
148 | foreign import ccall unsafe "prodQ" c_prodQ :: TVV (Complex Float) | ||
149 | foreign import ccall unsafe "prodC" c_prodC :: TVV (Complex Double) | ||
150 | foreign import ccall unsafe "prodI" c_prodI :: I -> TVV I | ||
151 | foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z | ||
152 | |||
153 | ------------------------------------------------------------------ | ||
154 | |||
155 | toScalarAux fun code v = unsafePerformIO $ do | ||
156 | r <- createVector 1 | ||
157 | fun (fromei code) # v # r #|"toScalarAux" | ||
158 | return (r @> 0) | ||
159 | |||
160 | vectorMapAux fun code v = unsafePerformIO $ do | ||
161 | r <- createVector (dim v) | ||
162 | fun (fromei code) # v # r #|"vectorMapAux" | ||
163 | return r | ||
164 | |||
165 | vectorMapValAux fun code val v = unsafePerformIO $ do | ||
166 | r <- createVector (dim v) | ||
167 | pval <- newArray [val] | ||
168 | fun (fromei code) pval # v # r #|"vectorMapValAux" | ||
169 | free pval | ||
170 | return r | ||
171 | |||
172 | vectorZipAux fun code u v = unsafePerformIO $ do | ||
173 | r <- createVector (dim u) | ||
174 | fun (fromei code) # u # v # r #|"vectorZipAux" | ||
175 | return r | ||
176 | |||
177 | --------------------------------------------------------------------- | ||
178 | |||
179 | -- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. | ||
180 | toScalarR :: FunCodeS -> Vector Double -> Double | ||
181 | toScalarR oper = toScalarAux c_toScalarR (fromei oper) | ||
182 | |||
183 | foreign import ccall unsafe "toScalarR" c_toScalarR :: CInt -> TVV Double | ||
184 | |||
185 | -- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. | ||
186 | toScalarF :: FunCodeS -> Vector Float -> Float | ||
187 | toScalarF oper = toScalarAux c_toScalarF (fromei oper) | ||
188 | |||
189 | foreign import ccall unsafe "toScalarF" c_toScalarF :: CInt -> TVV Float | ||
190 | |||
191 | -- | obtains different functions of a vector: only norm1, norm2 | ||
192 | toScalarC :: FunCodeS -> Vector (Complex Double) -> Double | ||
193 | toScalarC oper = toScalarAux c_toScalarC (fromei oper) | ||
194 | |||
195 | foreign import ccall unsafe "toScalarC" c_toScalarC :: CInt -> Complex Double :> Double :> Ok | ||
196 | |||
197 | -- | obtains different functions of a vector: only norm1, norm2 | ||
198 | toScalarQ :: FunCodeS -> Vector (Complex Float) -> Float | ||
199 | toScalarQ oper = toScalarAux c_toScalarQ (fromei oper) | ||
200 | |||
201 | foreign import ccall unsafe "toScalarQ" c_toScalarQ :: CInt -> Complex Float :> Float :> Ok | ||
202 | |||
203 | -- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. | ||
204 | toScalarI :: FunCodeS -> Vector CInt -> CInt | ||
205 | toScalarI oper = toScalarAux c_toScalarI (fromei oper) | ||
206 | |||
207 | foreign import ccall unsafe "toScalarI" c_toScalarI :: CInt -> TVV CInt | ||
208 | |||
209 | -- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. | ||
210 | toScalarL :: FunCodeS -> Vector Z -> Z | ||
211 | toScalarL oper = toScalarAux c_toScalarL (fromei oper) | ||
212 | |||
213 | foreign import ccall unsafe "toScalarL" c_toScalarL :: CInt -> TVV Z | ||
214 | |||
215 | |||
216 | ------------------------------------------------------------------ | ||
217 | |||
218 | -- | map of real vectors with given function | ||
219 | vectorMapR :: FunCodeV -> Vector Double -> Vector Double | ||
220 | vectorMapR = vectorMapAux c_vectorMapR | ||
221 | |||
222 | foreign import ccall unsafe "mapR" c_vectorMapR :: CInt -> TVV Double | ||
223 | |||
224 | -- | map of complex vectors with given function | ||
225 | vectorMapC :: FunCodeV -> Vector (Complex Double) -> Vector (Complex Double) | ||
226 | vectorMapC oper = vectorMapAux c_vectorMapC (fromei oper) | ||
227 | |||
228 | foreign import ccall unsafe "mapC" c_vectorMapC :: CInt -> TVV (Complex Double) | ||
229 | |||
230 | -- | map of real vectors with given function | ||
231 | vectorMapF :: FunCodeV -> Vector Float -> Vector Float | ||
232 | vectorMapF = vectorMapAux c_vectorMapF | ||
233 | |||
234 | foreign import ccall unsafe "mapF" c_vectorMapF :: CInt -> TVV Float | ||
235 | |||
236 | -- | map of real vectors with given function | ||
237 | vectorMapQ :: FunCodeV -> Vector (Complex Float) -> Vector (Complex Float) | ||
238 | vectorMapQ = vectorMapAux c_vectorMapQ | ||
239 | |||
240 | foreign import ccall unsafe "mapQ" c_vectorMapQ :: CInt -> TVV (Complex Float) | ||
241 | |||
242 | -- | map of real vectors with given function | ||
243 | vectorMapI :: FunCodeV -> Vector CInt -> Vector CInt | ||
244 | vectorMapI = vectorMapAux c_vectorMapI | ||
245 | |||
246 | foreign import ccall unsafe "mapI" c_vectorMapI :: CInt -> TVV CInt | ||
247 | |||
248 | -- | map of real vectors with given function | ||
249 | vectorMapL :: FunCodeV -> Vector Z -> Vector Z | ||
250 | vectorMapL = vectorMapAux c_vectorMapL | ||
251 | |||
252 | foreign import ccall unsafe "mapL" c_vectorMapL :: CInt -> TVV Z | ||
253 | |||
254 | ------------------------------------------------------------------- | ||
255 | |||
256 | -- | map of real vectors with given function | ||
257 | vectorMapValR :: FunCodeSV -> Double -> Vector Double -> Vector Double | ||
258 | vectorMapValR oper = vectorMapValAux c_vectorMapValR (fromei oper) | ||
259 | |||
260 | foreign import ccall unsafe "mapValR" c_vectorMapValR :: CInt -> Ptr Double -> TVV Double | ||
261 | |||
262 | -- | map of complex vectors with given function | ||
263 | vectorMapValC :: FunCodeSV -> Complex Double -> Vector (Complex Double) -> Vector (Complex Double) | ||
264 | vectorMapValC = vectorMapValAux c_vectorMapValC | ||
265 | |||
266 | foreign import ccall unsafe "mapValC" c_vectorMapValC :: CInt -> Ptr (Complex Double) -> TVV (Complex Double) | ||
267 | |||
268 | -- | map of real vectors with given function | ||
269 | vectorMapValF :: FunCodeSV -> Float -> Vector Float -> Vector Float | ||
270 | vectorMapValF oper = vectorMapValAux c_vectorMapValF (fromei oper) | ||
271 | |||
272 | foreign import ccall unsafe "mapValF" c_vectorMapValF :: CInt -> Ptr Float -> TVV Float | ||
273 | |||
274 | -- | map of complex vectors with given function | ||
275 | vectorMapValQ :: FunCodeSV -> Complex Float -> Vector (Complex Float) -> Vector (Complex Float) | ||
276 | vectorMapValQ oper = vectorMapValAux c_vectorMapValQ (fromei oper) | ||
277 | |||
278 | foreign import ccall unsafe "mapValQ" c_vectorMapValQ :: CInt -> Ptr (Complex Float) -> TVV (Complex Float) | ||
279 | |||
280 | -- | map of real vectors with given function | ||
281 | vectorMapValI :: FunCodeSV -> CInt -> Vector CInt -> Vector CInt | ||
282 | vectorMapValI oper = vectorMapValAux c_vectorMapValI (fromei oper) | ||
283 | |||
284 | foreign import ccall unsafe "mapValI" c_vectorMapValI :: CInt -> Ptr CInt -> TVV CInt | ||
285 | |||
286 | -- | map of real vectors with given function | ||
287 | vectorMapValL :: FunCodeSV -> Z -> Vector Z -> Vector Z | ||
288 | vectorMapValL oper = vectorMapValAux c_vectorMapValL (fromei oper) | ||
289 | |||
290 | foreign import ccall unsafe "mapValL" c_vectorMapValL :: CInt -> Ptr Z -> TVV Z | ||
291 | |||
292 | |||
293 | ------------------------------------------------------------------- | ||
294 | |||
295 | type TVVV t = t :> t :> t :> Ok | ||
296 | |||
297 | -- | elementwise operation on real vectors | ||
298 | vectorZipR :: FunCodeVV -> Vector Double -> Vector Double -> Vector Double | ||
299 | vectorZipR = vectorZipAux c_vectorZipR | ||
300 | |||
301 | foreign import ccall unsafe "zipR" c_vectorZipR :: CInt -> TVVV Double | ||
302 | |||
303 | -- | elementwise operation on complex vectors | ||
304 | vectorZipC :: FunCodeVV -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) | ||
305 | vectorZipC = vectorZipAux c_vectorZipC | ||
306 | |||
307 | foreign import ccall unsafe "zipC" c_vectorZipC :: CInt -> TVVV (Complex Double) | ||
308 | |||
309 | -- | elementwise operation on real vectors | ||
310 | vectorZipF :: FunCodeVV -> Vector Float -> Vector Float -> Vector Float | ||
311 | vectorZipF = vectorZipAux c_vectorZipF | ||
312 | |||
313 | foreign import ccall unsafe "zipF" c_vectorZipF :: CInt -> TVVV Float | ||
314 | |||
315 | -- | elementwise operation on complex vectors | ||
316 | vectorZipQ :: FunCodeVV -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float) | ||
317 | vectorZipQ = vectorZipAux c_vectorZipQ | ||
318 | |||
319 | foreign import ccall unsafe "zipQ" c_vectorZipQ :: CInt -> TVVV (Complex Float) | ||
320 | |||
321 | -- | elementwise operation on CInt vectors | ||
322 | vectorZipI :: FunCodeVV -> Vector CInt -> Vector CInt -> Vector CInt | ||
323 | vectorZipI = vectorZipAux c_vectorZipI | ||
324 | |||
325 | foreign import ccall unsafe "zipI" c_vectorZipI :: CInt -> TVVV CInt | ||
326 | |||
327 | -- | elementwise operation on CInt vectors | ||
328 | vectorZipL :: FunCodeVV -> Vector Z -> Vector Z -> Vector Z | ||
329 | vectorZipL = vectorZipAux c_vectorZipL | ||
330 | |||
331 | foreign import ccall unsafe "zipL" c_vectorZipL :: CInt -> TVVV Z | ||
332 | |||
333 | -------------------------------------------------------------------------------- | ||
334 | |||
335 | foreign import ccall unsafe "vectorScan" c_vectorScan | ||
336 | :: CString -> Ptr CInt -> Ptr (Ptr Double) -> IO CInt | ||
337 | |||
338 | vectorScan :: FilePath -> IO (Vector Double) | ||
339 | vectorScan s = do | ||
340 | pp <- malloc | ||
341 | pn <- malloc | ||
342 | cs <- newCString s | ||
343 | ok <- c_vectorScan cs pn pp | ||
344 | when (not (ok == 0)) $ | ||
345 | error ("vectorScan: file \"" ++ s ++"\" not found") | ||
346 | n <- fromIntegral <$> peek pn | ||
347 | p <- peek pp | ||
348 | v <- createVector n | ||
349 | free pn | ||
350 | free cs | ||
351 | unsafeWith v $ \pv -> copyArray pv p n | ||
352 | free p | ||
353 | free pp | ||
354 | return v | ||
355 | |||
356 | -------------------------------------------------------------------------------- | ||
357 | |||
358 | type Seed = Int | ||
359 | |||
360 | data RandDist = Uniform -- ^ uniform distribution in [0,1) | ||
361 | | Gaussian -- ^ normal distribution with mean zero and standard deviation one | ||
362 | deriving Enum | ||
363 | |||
364 | -- | Obtains a vector of pseudorandom elements (use randomIO to get a random seed). | ||
365 | randomVector :: Seed | ||
366 | -> RandDist -- ^ distribution | ||
367 | -> Int -- ^ vector size | ||
368 | -> Vector Double | ||
369 | randomVector seed dist n = unsafePerformIO $ do | ||
370 | r <- createVector n | ||
371 | c_random_vector (fi seed) ((fi.fromEnum) dist) # r #|"randomVector" | ||
372 | return r | ||
373 | |||
374 | foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> Double :> Ok | ||
375 | |||
376 | -------------------------------------------------------------------------------- | ||
377 | |||
378 | roundVector v = unsafePerformIO $ do | ||
379 | r <- createVector (dim v) | ||
380 | c_round_vector # v # r #|"roundVector" | ||
381 | return r | ||
382 | |||
383 | foreign import ccall unsafe "round_vector" c_round_vector :: TVV Double | ||
384 | |||
385 | -------------------------------------------------------------------------------- | ||
386 | |||
387 | -- | | ||
388 | -- >>> range 5 | ||
389 | -- fromList [0,1,2,3,4] | ||
390 | -- | ||
391 | range :: Int -> Vector I | ||
392 | range n = unsafePerformIO $ do | ||
393 | r <- createVector n | ||
394 | c_range_vector # r #|"range" | ||
395 | return r | ||
396 | |||
397 | foreign import ccall unsafe "range_vector" c_range_vector :: CInt :> Ok | ||
398 | |||
399 | |||
400 | float2DoubleV :: Vector Float -> Vector Double | ||
401 | float2DoubleV = tog c_float2double | ||
402 | |||
403 | double2FloatV :: Vector Double -> Vector Float | ||
404 | double2FloatV = tog c_double2float | ||
405 | |||
406 | double2IntV :: Vector Double -> Vector CInt | ||
407 | double2IntV = tog c_double2int | ||
408 | |||
409 | int2DoubleV :: Vector CInt -> Vector Double | ||
410 | int2DoubleV = tog c_int2double | ||
411 | |||
412 | double2longV :: Vector Double -> Vector Z | ||
413 | double2longV = tog c_double2long | ||
414 | |||
415 | long2DoubleV :: Vector Z -> Vector Double | ||
416 | long2DoubleV = tog c_long2double | ||
417 | |||
418 | |||
419 | float2IntV :: Vector Float -> Vector CInt | ||
420 | float2IntV = tog c_float2int | ||
421 | |||
422 | int2floatV :: Vector CInt -> Vector Float | ||
423 | int2floatV = tog c_int2float | ||
424 | |||
425 | int2longV :: Vector I -> Vector Z | ||
426 | int2longV = tog c_int2long | ||
427 | |||
428 | long2intV :: Vector Z -> Vector I | ||
429 | long2intV = tog c_long2int | ||
430 | |||
431 | |||
432 | tog f v = unsafePerformIO $ do | ||
433 | r <- createVector (dim v) | ||
434 | f # v # r #|"tog" | ||
435 | return r | ||
436 | |||
437 | foreign import ccall unsafe "float2double" c_float2double :: Float :> Double :> Ok | ||
438 | foreign import ccall unsafe "double2float" c_double2float :: Double :> Float :> Ok | ||
439 | foreign import ccall unsafe "int2double" c_int2double :: CInt :> Double :> Ok | ||
440 | foreign import ccall unsafe "double2int" c_double2int :: Double :> CInt :> Ok | ||
441 | foreign import ccall unsafe "long2double" c_long2double :: Z :> Double :> Ok | ||
442 | foreign import ccall unsafe "double2long" c_double2long :: Double :> Z :> Ok | ||
443 | foreign import ccall unsafe "int2float" c_int2float :: CInt :> Float :> Ok | ||
444 | foreign import ccall unsafe "float2int" c_float2int :: Float :> CInt :> Ok | ||
445 | foreign import ccall unsafe "int2long" c_int2long :: I :> Z :> Ok | ||
446 | foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok | ||
447 | |||
448 | |||
449 | --------------------------------------------------------------- | ||
450 | |||
451 | stepg f v = unsafePerformIO $ do | ||
452 | r <- createVector (dim v) | ||
453 | f # v # r #|"step" | ||
454 | return r | ||
455 | |||
456 | stepD :: Vector Double -> Vector Double | ||
457 | stepD = stepg c_stepD | ||
458 | |||
459 | stepF :: Vector Float -> Vector Float | ||
460 | stepF = stepg c_stepF | ||
461 | |||
462 | stepI :: Vector CInt -> Vector CInt | ||
463 | stepI = stepg c_stepI | ||
464 | |||
465 | stepL :: Vector Z -> Vector Z | ||
466 | stepL = stepg c_stepL | ||
467 | |||
468 | |||
469 | foreign import ccall unsafe "stepF" c_stepF :: TVV Float | ||
470 | foreign import ccall unsafe "stepD" c_stepD :: TVV Double | ||
471 | foreign import ccall unsafe "stepI" c_stepI :: TVV CInt | ||
472 | foreign import ccall unsafe "stepL" c_stepL :: TVV Z | ||
473 | |||
474 | -------------------------------------------------------------------------------- | ||
475 | |||
476 | conjugateAux fun x = unsafePerformIO $ do | ||
477 | v <- createVector (dim x) | ||
478 | fun # x # v #|"conjugateAux" | ||
479 | return v | ||
480 | |||
481 | conjugateQ :: Vector (Complex Float) -> Vector (Complex Float) | ||
482 | conjugateQ = conjugateAux c_conjugateQ | ||
483 | foreign import ccall unsafe "conjugateQ" c_conjugateQ :: TVV (Complex Float) | ||
484 | |||
485 | conjugateC :: Vector (Complex Double) -> Vector (Complex Double) | ||
486 | conjugateC = conjugateAux c_conjugateC | ||
487 | foreign import ccall unsafe "conjugateC" c_conjugateC :: TVV (Complex Double) | ||
488 | |||
489 | -------------------------------------------------------------------------------- | ||
490 | |||
491 | cloneVector :: Storable t => Vector t -> IO (Vector t) | ||
492 | cloneVector v = do | ||
493 | let n = dim v | ||
494 | r <- createVector n | ||
495 | let f _ s _ d = copyArray d s n >> return 0 | ||
496 | f # v # r #|"cloneVector" | ||
497 | return r | ||
498 | |||
499 | -------------------------------------------------------------------------------- | ||
500 | |||
501 | constantAux fun x n = unsafePerformIO $ do | ||
502 | v <- createVector n | ||
503 | px <- newArray [x] | ||
504 | fun px # v #|"constantAux" | ||
505 | free px | ||
506 | return v | ||
507 | |||
508 | type TConst t = Ptr t -> t :> Ok | ||
509 | |||
510 | foreign import ccall unsafe "constantF" cconstantF :: TConst Float | ||
511 | foreign import ccall unsafe "constantR" cconstantR :: TConst Double | ||
512 | foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float) | ||
513 | foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double) | ||
514 | foreign import ccall unsafe "constantI" cconstantI :: TConst CInt | ||
515 | foreign import ccall unsafe "constantL" cconstantL :: TConst Z | ||
516 | |||
517 | ---------------------------------------------------------------------- | ||
518 | |||
diff --git a/packages/base/src/Numeric/Container.hs b/packages/base/src/Numeric/Container.hs deleted file mode 100644 index f78bfb9..0000000 --- a/packages/base/src/Numeric/Container.hs +++ /dev/null | |||
@@ -1,49 +0,0 @@ | |||
1 | {-# OPTIONS_HADDOCK hide #-} | ||
2 | |||
3 | module Numeric.Container( | ||
4 | module Data.Packed, | ||
5 | constant, | ||
6 | linspace, | ||
7 | diag, | ||
8 | ident, | ||
9 | ctrans, | ||
10 | Container(scaleRecip, addConstant,add, sub, mul, divide, equal), | ||
11 | scalar, | ||
12 | conj, | ||
13 | scale, | ||
14 | arctan2, | ||
15 | cmap, | ||
16 | Konst(..), | ||
17 | Build(..), | ||
18 | atIndex, | ||
19 | minIndex, maxIndex, minElement, maxElement, | ||
20 | sumElements, prodElements, | ||
21 | step, cond, find, assoc, accum, | ||
22 | Element(..), | ||
23 | Product(..), dot, udot, | ||
24 | optimiseMult, | ||
25 | mXm, mXv, vXm, (<.>), | ||
26 | Mul(..), | ||
27 | LSDiv, (<\>), | ||
28 | outer, kronecker, | ||
29 | RandDist(..), | ||
30 | randomVector, gaussianSample, uniformSample, | ||
31 | meanCov, | ||
32 | Convert(..), | ||
33 | Complexable, | ||
34 | RealElement, | ||
35 | RealOf, ComplexOf, SingleOf, DoubleOf, IndexOf, | ||
36 | module Data.Complex, | ||
37 | dispf, disps, dispcf, vecdisp, latexFormat, format, | ||
38 | loadMatrix, saveMatrix, readMatrix | ||
39 | ) where | ||
40 | |||
41 | |||
42 | import Data.Packed.Numeric | ||
43 | import Data.Packed | ||
44 | import Data.Packed.Internal(constantD) | ||
45 | import Data.Complex | ||
46 | |||
47 | constant :: Element a => a -> Int -> Vector a | ||
48 | constant = constantD | ||
49 | |||
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index ad315e4..6a9c33a 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs | |||
@@ -1,22 +1,255 @@ | |||
1 | -------------------------------------------------------------------------------- | 1 | {-# LANGUAGE FlexibleContexts #-} |
2 | |||
3 | ----------------------------------------------------------------------------- | ||
2 | {- | | 4 | {- | |
3 | Module : Numeric.LinearAlgebra | 5 | Module : Numeric.LinearAlgebra |
4 | Copyright : (c) Alberto Ruiz 2006-14 | 6 | Copyright : (c) Alberto Ruiz 2006-15 |
5 | License : BSD3 | 7 | License : BSD3 |
6 | Maintainer : Alberto Ruiz | 8 | Maintainer : Alberto Ruiz |
7 | Stability : provisional | 9 | Stability : provisional |
8 | 10 | ||
9 | -} | ||
10 | -------------------------------------------------------------------------------- | ||
11 | {-# OPTIONS_HADDOCK hide #-} | ||
12 | 11 | ||
12 | -} | ||
13 | ----------------------------------------------------------------------------- | ||
13 | module Numeric.LinearAlgebra ( | 14 | module Numeric.LinearAlgebra ( |
14 | module Numeric.Container, | 15 | |
15 | module Numeric.LinearAlgebra.Algorithms | 16 | -- * Basic types and data manipulation |
17 | -- | This package works with 2D ('Matrix') and 1D ('Vector') | ||
18 | -- arrays of real ('R') or complex ('C') double precision numbers. | ||
19 | -- Single precision and machine integers are also supported for | ||
20 | -- basic arithmetic and data manipulation. | ||
21 | module Numeric.LinearAlgebra.Data, | ||
22 | |||
23 | -- * Numeric classes | ||
24 | -- | | ||
25 | -- The standard numeric classes are defined elementwise: | ||
26 | -- | ||
27 | -- >>> vector [1,2,3] * vector [3,0,-2] | ||
28 | -- fromList [3.0,0.0,-6.0] | ||
29 | -- | ||
30 | -- >>> matrix 3 [1..9] * ident 3 | ||
31 | -- (3><3) | ||
32 | -- [ 1.0, 0.0, 0.0 | ||
33 | -- , 0.0, 5.0, 0.0 | ||
34 | -- , 0.0, 0.0, 9.0 ] | ||
35 | |||
36 | -- * Autoconformable dimensions | ||
37 | -- | | ||
38 | -- In most operations, single-element vectors and matrices | ||
39 | -- (created from numeric literals or using 'scalar'), and matrices | ||
40 | -- with just one row or column, automatically | ||
41 | -- expand to match the dimensions of the other operand: | ||
42 | -- | ||
43 | -- >>> 5 + 2*ident 3 :: Matrix Double | ||
44 | -- (3><3) | ||
45 | -- [ 7.0, 5.0, 5.0 | ||
46 | -- , 5.0, 7.0, 5.0 | ||
47 | -- , 5.0, 5.0, 7.0 ] | ||
48 | -- | ||
49 | -- >>> (4><3) [1..] + row [10,20,30] | ||
50 | -- (4><3) | ||
51 | -- [ 11.0, 22.0, 33.0 | ||
52 | -- , 14.0, 25.0, 36.0 | ||
53 | -- , 17.0, 28.0, 39.0 | ||
54 | -- , 20.0, 31.0, 42.0 ] | ||
55 | -- | ||
56 | |||
57 | -- * Products | ||
58 | -- ** Dot | ||
59 | dot, (<.>), | ||
60 | -- ** Matrix-vector | ||
61 | (#>), (<#), (!#>), | ||
62 | -- ** Matrix-matrix | ||
63 | (<>), | ||
64 | -- | The matrix product is also implemented in the "Data.Monoid" instance, where | ||
65 | -- single-element matrices (created from numeric literals or using 'scalar') | ||
66 | -- are used for scaling. | ||
67 | -- | ||
68 | -- >>> import Data.Monoid as M | ||
69 | -- >>> let m = matrix 3 [1..6] | ||
70 | -- >>> m M.<> 2 M.<> diagl[0.5,1,0] | ||
71 | -- (2><3) | ||
72 | -- [ 1.0, 4.0, 0.0 | ||
73 | -- , 4.0, 10.0, 0.0 ] | ||
74 | -- | ||
75 | -- 'mconcat' uses 'optimiseMult' to get the optimal association order. | ||
76 | |||
77 | |||
78 | -- ** Other | ||
79 | outer, kronecker, cross, | ||
80 | scale, add, | ||
81 | sumElements, prodElements, | ||
82 | |||
83 | -- * Linear systems | ||
84 | -- ** General | ||
85 | (<\>), | ||
86 | linearSolveLS, | ||
87 | linearSolveSVD, | ||
88 | -- ** Determined | ||
89 | linearSolve, | ||
90 | luSolve, luPacked, | ||
91 | luSolve', luPacked', | ||
92 | -- ** Symmetric indefinite | ||
93 | ldlSolve, ldlPacked, | ||
94 | -- ** Positive definite | ||
95 | cholSolve, | ||
96 | -- ** Sparse | ||
97 | cgSolve, | ||
98 | cgSolve', | ||
99 | |||
100 | -- * Inverse and pseudoinverse | ||
101 | inv, pinv, pinvTol, | ||
102 | |||
103 | -- * Determinant and rank | ||
104 | rcond, rank, | ||
105 | det, invlndet, | ||
106 | |||
107 | -- * Norms | ||
108 | Normed(..), | ||
109 | norm_Frob, norm_nuclear, | ||
110 | |||
111 | -- * Nullspace and range | ||
112 | orth, | ||
113 | nullspace, null1, null1sym, | ||
114 | |||
115 | -- * Singular value decomposition | ||
116 | svd, | ||
117 | thinSVD, | ||
118 | compactSVD, | ||
119 | singularValues, | ||
120 | leftSV, rightSV, | ||
121 | |||
122 | -- * Eigendecomposition | ||
123 | eig, eigSH, | ||
124 | eigenvalues, eigenvaluesSH, | ||
125 | geigSH, | ||
126 | |||
127 | -- * QR | ||
128 | qr, rq, qrRaw, qrgr, | ||
129 | |||
130 | -- * Cholesky | ||
131 | chol, mbChol, | ||
132 | |||
133 | -- * LU | ||
134 | lu, luFact, | ||
135 | |||
136 | -- * Hessenberg | ||
137 | hess, | ||
138 | |||
139 | -- * Schur | ||
140 | schur, | ||
141 | |||
142 | -- * Matrix functions | ||
143 | expm, | ||
144 | sqrtm, | ||
145 | matFunc, | ||
146 | |||
147 | -- * Correlation and convolution | ||
148 | corr, conv, corrMin, corr2, conv2, | ||
149 | |||
150 | -- * Random arrays | ||
151 | |||
152 | Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, | ||
153 | |||
154 | -- * Misc | ||
155 | meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, magnit, | ||
156 | haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, | ||
157 | iC, sym, mTm, trustSym, unSym, | ||
158 | -- * Auxiliary classes | ||
159 | Element, Container, Product, Numeric, LSDiv, Herm, | ||
160 | Complexable, RealElement, | ||
161 | RealOf, ComplexOf, SingleOf, DoubleOf, | ||
162 | IndexOf, | ||
163 | Field, Linear(), Additive(), | ||
164 | Transposable, | ||
165 | LU(..), | ||
166 | LDL(..), | ||
167 | QR(..), | ||
168 | CGState(..), | ||
169 | Testable(..) | ||
16 | ) where | 170 | ) where |
17 | 171 | ||
18 | import Numeric.Container | 172 | import Numeric.LinearAlgebra.Data |
19 | import Numeric.LinearAlgebra.Algorithms | 173 | |
20 | import Numeric.Matrix() | 174 | import Numeric.Matrix() |
21 | import Numeric.Vector() | 175 | import Numeric.Vector() |
176 | import Internal.Matrix | ||
177 | import Internal.Container hiding ((<>)) | ||
178 | import Internal.Numeric hiding (mul) | ||
179 | import Internal.Algorithms hiding (linearSolve,Normed,orth,luPacked',linearSolve',luSolve',ldlPacked') | ||
180 | import qualified Internal.Algorithms as A | ||
181 | import Internal.Util | ||
182 | import Internal.Random | ||
183 | import Internal.Sparse((!#>)) | ||
184 | import Internal.CG | ||
185 | import Internal.Conversion | ||
186 | |||
187 | {- | dense matrix product | ||
188 | |||
189 | >>> let a = (3><5) [1..] | ||
190 | >>> a | ||
191 | (3><5) | ||
192 | [ 1.0, 2.0, 3.0, 4.0, 5.0 | ||
193 | , 6.0, 7.0, 8.0, 9.0, 10.0 | ||
194 | , 11.0, 12.0, 13.0, 14.0, 15.0 ] | ||
195 | |||
196 | >>> let b = (5><2) [1,3, 0,2, -1,5, 7,7, 6,0] | ||
197 | >>> b | ||
198 | (5><2) | ||
199 | [ 1.0, 3.0 | ||
200 | , 0.0, 2.0 | ||
201 | , -1.0, 5.0 | ||
202 | , 7.0, 7.0 | ||
203 | , 6.0, 0.0 ] | ||
204 | |||
205 | >>> a <> b | ||
206 | (3><2) | ||
207 | [ 56.0, 50.0 | ||
208 | , 121.0, 135.0 | ||
209 | , 186.0, 220.0 ] | ||
210 | |||
211 | -} | ||
212 | (<>) :: Numeric t => Matrix t -> Matrix t -> Matrix t | ||
213 | (<>) = mXm | ||
214 | infixr 8 <> | ||
215 | |||
216 | |||
217 | {- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, returning Nothing for a singular system. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'. | ||
218 | |||
219 | @ | ||
220 | a = (2><2) | ||
221 | [ 1.0, 2.0 | ||
222 | , 3.0, 5.0 ] | ||
223 | @ | ||
224 | |||
225 | @ | ||
226 | b = (2><3) | ||
227 | [ 6.0, 1.0, 10.0 | ||
228 | , 15.0, 3.0, 26.0 ] | ||
229 | @ | ||
230 | |||
231 | >>> linearSolve a b | ||
232 | Just (2><3) | ||
233 | [ -1.4802973661668753e-15, 0.9999999999999997, 1.999999999999997 | ||
234 | , 3.000000000000001, 1.6653345369377348e-16, 4.000000000000002 ] | ||
235 | |||
236 | >>> let Just x = it | ||
237 | >>> disp 5 x | ||
238 | 2x3 | ||
239 | -0.00000 1.00000 2.00000 | ||
240 | 3.00000 0.00000 4.00000 | ||
241 | |||
242 | >>> a <> x | ||
243 | (2><3) | ||
244 | [ 6.0, 1.0, 10.0 | ||
245 | , 15.0, 3.0, 26.0 ] | ||
246 | |||
247 | -} | ||
248 | linearSolve m b = A.mbLinearSolve m b | ||
249 | |||
250 | -- | return an orthonormal basis of the null space of a matrix. See also 'nullspaceSVD'. | ||
251 | nullspace m = nullspaceSVD (Left (1*eps)) m (rightSV m) | ||
252 | |||
253 | -- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'. | ||
254 | orth m = orthSVD (Left (1*eps)) m (leftSV m) | ||
22 | 255 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs index 6dea407..a389aac 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Data.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs | |||
@@ -1,83 +1,121 @@ | |||
1 | {-# LANGUAGE TypeOperators #-} | ||
2 | |||
1 | -------------------------------------------------------------------------------- | 3 | -------------------------------------------------------------------------------- |
2 | {- | | 4 | {- | |
3 | Module : Numeric.LinearAlgebra.Data | 5 | Module : Numeric.LinearAlgebra.Data |
4 | Copyright : (c) Alberto Ruiz 2014 | 6 | Copyright : (c) Alberto Ruiz 2015 |
5 | License : BSD3 | 7 | License : BSD3 |
6 | Maintainer : Alberto Ruiz | 8 | Maintainer : Alberto Ruiz |
7 | Stability : provisional | 9 | Stability : provisional |
8 | 10 | ||
9 | Basic data processing. | 11 | This module provides functions for creation and manipulation of vectors and matrices, IO, and other utilities. |
10 | 12 | ||
11 | -} | 13 | -} |
12 | -------------------------------------------------------------------------------- | 14 | -------------------------------------------------------------------------------- |
13 | 15 | ||
14 | module Numeric.LinearAlgebra.Data( | 16 | module Numeric.LinearAlgebra.Data( |
15 | 17 | ||
18 | -- * Elements | ||
19 | R,C,I,Z,type(./.), | ||
20 | |||
16 | -- * Vector | 21 | -- * Vector |
17 | -- | 1D arrays are storable vectors from the vector package. | 22 | {- | 1D arrays are storable vectors directly reexported from the vector package. |
18 | 23 | -} | |
19 | vector, (|>), | 24 | |
25 | fromList, toList, (|>), vector, range, idxs, | ||
20 | 26 | ||
21 | -- * Matrix | 27 | -- * Matrix |
22 | 28 | ||
23 | matrix, (><), tr, | 29 | {- | The main data type of hmatrix is a 2D dense array defined on top of |
24 | 30 | a storable vector. The internal representation is suitable for direct | |
31 | interface with standard numeric libraries. | ||
32 | -} | ||
33 | |||
34 | (><), matrix, tr, tr', | ||
35 | |||
36 | -- * Dimensions | ||
37 | |||
38 | size, rows, cols, | ||
39 | |||
40 | -- * Conversion from\/to lists | ||
41 | |||
42 | fromLists, toLists, | ||
43 | row, col, | ||
44 | |||
45 | -- * Conversions vector\/matrix | ||
46 | |||
47 | flatten, reshape, asRow, asColumn, | ||
48 | fromRows, toRows, fromColumns, toColumns, | ||
49 | |||
25 | -- * Indexing | 50 | -- * Indexing |
26 | 51 | ||
27 | size, | 52 | atIndex, |
28 | Indexable(..), | 53 | Indexable(..), |
29 | 54 | ||
30 | -- * Construction | 55 | -- * Construction |
31 | scalar, Konst(..), Build(..), assoc, accum, linspace, -- ones, zeros, | 56 | scalar, Konst(..), Build(..), assoc, accum, linspace, -- ones, zeros, |
32 | 57 | ||
33 | -- * Diagonal | 58 | -- * Diagonal |
34 | ident, diag, diagl, diagRect, takeDiag, | 59 | ident, diag, diagl, diagRect, takeDiag, |
35 | 60 | ||
36 | -- * Data manipulation | 61 | -- * Vector extraction |
37 | fromList, toList, subVector, takesV, vjoin, | 62 | subVector, takesV, vjoin, |
38 | flatten, reshape, asRow, asColumn, row, col, | 63 | |
39 | fromRows, toRows, fromColumns, toColumns, fromLists, toLists, fromArray2D, | 64 | -- * Matrix extraction |
40 | takeRows, dropRows, takeColumns, dropColumns, subMatrix, (?), (¿), fliprl, flipud, | 65 | Extractor(..), (??), |
41 | 66 | ||
67 | (?), (¿), fliprl, flipud, | ||
68 | |||
69 | subMatrix, takeRows, dropRows, takeColumns, dropColumns, | ||
70 | |||
71 | remap, | ||
72 | |||
42 | -- * Block matrix | 73 | -- * Block matrix |
43 | fromBlocks, (|||), (===), diagBlock, repmat, toBlocks, toBlocksEvery, | 74 | fromBlocks, (|||), (===), diagBlock, repmat, toBlocks, toBlocksEvery, |
44 | 75 | ||
45 | -- * Mapping functions | 76 | -- * Mapping functions |
46 | conj, cmap, step, cond, | 77 | conj, cmap, cmod, |
47 | 78 | ||
79 | step, cond, | ||
80 | |||
48 | -- * Find elements | 81 | -- * Find elements |
49 | find, maxIndex, minIndex, maxElement, minElement, atIndex, | 82 | find, maxIndex, minIndex, maxElement, minElement, |
50 | sortVector, | 83 | sortVector, sortIndex, |
51 | 84 | ||
52 | -- * Sparse | 85 | -- * Sparse |
53 | AssocMatrix, toDense, | 86 | AssocMatrix, toDense, |
54 | mkSparse, mkDiagR, mkDense, | 87 | mkSparse, mkDiagR, mkDense, |
55 | 88 | ||
56 | -- * IO | 89 | -- * IO |
57 | disp, | 90 | disp, |
58 | loadMatrix, loadMatrix', saveMatrix, | 91 | loadMatrix, loadMatrix', saveMatrix, |
59 | latexFormat, | 92 | latexFormat, |
60 | dispf, disps, dispcf, format, | 93 | dispf, disps, dispcf, format, |
61 | dispDots, dispBlanks, dispShort, | 94 | dispDots, dispBlanks, dispShort, |
62 | -- * Conversion | 95 | -- * Element conversion |
63 | Convert(..), | 96 | Convert(..), |
64 | roundVector, | 97 | roundVector, |
98 | fromInt,toInt,fromZ,toZ, | ||
65 | -- * Misc | 99 | -- * Misc |
66 | arctan2, | 100 | arctan2, |
67 | rows, cols, | ||
68 | separable, | 101 | separable, |
69 | (¦),(——), | 102 | fromArray2D, |
70 | module Data.Complex, | 103 | module Data.Complex, |
71 | 104 | Mod, | |
72 | Vector, Matrix, GMatrix, nRows, nCols | 105 | Vector, Matrix, GMatrix, nRows, nCols |
73 | 106 | ||
74 | ) where | 107 | ) where |
75 | 108 | ||
76 | import Data.Packed.Vector | 109 | import Internal.Vector |
77 | import Data.Packed.Matrix | 110 | import Internal.Vectorized |
78 | import Data.Packed.Numeric | 111 | import Internal.Matrix hiding (size) |
79 | import Numeric.LinearAlgebra.Util hiding ((&),(#)) | 112 | import Internal.Element |
113 | import Internal.IO | ||
114 | import Internal.Numeric | ||
115 | import Internal.Container | ||
116 | import Internal.Util hiding ((&)) | ||
80 | import Data.Complex | 117 | import Data.Complex |
81 | import Numeric.Sparse | 118 | import Internal.Sparse |
119 | import Internal.Modular | ||
82 | 120 | ||
83 | 121 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs index 55894e0..57a68e7 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs | |||
@@ -17,16 +17,23 @@ module Numeric.LinearAlgebra.Devel( | |||
17 | -- | 17 | -- |
18 | -- @ glUniformMatrix4fv 0 1 (fromIntegral gl_TRUE) \`appMatrix\` perspective 0.01 100 (pi\/2) (4\/3) | 18 | -- @ glUniformMatrix4fv 0 1 (fromIntegral gl_TRUE) \`appMatrix\` perspective 0.01 100 (pi\/2) (4\/3) |
19 | -- @ | 19 | -- @ |
20 | module Data.Packed.Foreign, | 20 | module Internal.Foreign, |
21 | 21 | ||
22 | -- * FFI tools | 22 | -- * FFI tools |
23 | -- | Illustrative usage examples can be found | 23 | -- | See @examples/devel@ in the repository. |
24 | -- in the @examples\/devel@ folder included in the package. | 24 | |
25 | module Data.Packed.Development, | 25 | createVector, createMatrix, |
26 | TransArray(..), | ||
27 | MatrixOrder(..), orderOf, cmat, fmat, | ||
28 | matrixFromVector, | ||
29 | unsafeFromForeignPtr, | ||
30 | unsafeToForeignPtr, | ||
31 | check, (//), (#|), | ||
32 | at', atM', fi, ti, | ||
26 | 33 | ||
27 | -- * ST | 34 | -- * ST |
28 | -- | In-place manipulation inside the ST monad. | 35 | -- | In-place manipulation inside the ST monad. |
29 | -- See examples\/inplace.hs in the distribution. | 36 | -- See @examples/inplace.hs@ in the repository. |
30 | 37 | ||
31 | -- ** Mutable Vectors | 38 | -- ** Mutable Vectors |
32 | STVector, newVector, thawVector, freezeVector, runSTVector, | 39 | STVector, newVector, thawVector, freezeVector, runSTVector, |
@@ -34,6 +41,7 @@ module Numeric.LinearAlgebra.Devel( | |||
34 | -- ** Mutable Matrices | 41 | -- ** Mutable Matrices |
35 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, | 42 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, |
36 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, | 43 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, |
44 | mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..), | ||
37 | -- ** Unsafe functions | 45 | -- ** Unsafe functions |
38 | newUndefinedVector, | 46 | newUndefinedVector, |
39 | unsafeReadVector, unsafeWriteVector, | 47 | unsafeReadVector, unsafeWriteVector, |
@@ -54,13 +62,15 @@ module Numeric.LinearAlgebra.Devel( | |||
54 | GMatrix(..), | 62 | GMatrix(..), |
55 | 63 | ||
56 | -- * Misc | 64 | -- * Misc |
57 | toByteString, fromByteString | 65 | toByteString, fromByteString, showInternal |
58 | 66 | ||
59 | ) where | 67 | ) where |
60 | 68 | ||
61 | import Data.Packed.Foreign | 69 | import Internal.Foreign |
62 | import Data.Packed.Development | 70 | import Internal.Devel |
63 | import Data.Packed.ST | 71 | import Internal.ST |
64 | import Data.Packed | 72 | import Internal.Vector |
65 | import Numeric.Sparse | 73 | import Internal.Matrix |
74 | import Internal.Element | ||
75 | import Internal.Sparse | ||
66 | 76 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs b/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs index 677f9ee..5ce529c 100644 --- a/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs +++ b/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs | |||
@@ -1,4 +1,4 @@ | |||
1 | ----------------------------------------------------------------------------- | 1 | -------------------------------------------------------------------------------- |
2 | {- | | 2 | {- | |
3 | Module : Numeric.LinearAlgebra.HMatrix | 3 | Module : Numeric.LinearAlgebra.HMatrix |
4 | Copyright : (c) Alberto Ruiz 2006-14 | 4 | Copyright : (c) Alberto Ruiz 2006-14 |
@@ -6,230 +6,25 @@ License : BSD3 | |||
6 | Maintainer : Alberto Ruiz | 6 | Maintainer : Alberto Ruiz |
7 | Stability : provisional | 7 | Stability : provisional |
8 | 8 | ||
9 | -} | 9 | compatibility with previous version, to be removed |
10 | ----------------------------------------------------------------------------- | ||
11 | module Numeric.LinearAlgebra.HMatrix ( | ||
12 | |||
13 | -- * Basic types and data processing | ||
14 | module Numeric.LinearAlgebra.Data, | ||
15 | |||
16 | -- * Arithmetic and numeric classes | ||
17 | -- | | ||
18 | -- The standard numeric classes are defined elementwise: | ||
19 | -- | ||
20 | -- >>> vector [1,2,3] * vector [3,0,-2] | ||
21 | -- fromList [3.0,0.0,-6.0] | ||
22 | -- | ||
23 | -- >>> matrix 3 [1..9] * ident 3 | ||
24 | -- (3><3) | ||
25 | -- [ 1.0, 0.0, 0.0 | ||
26 | -- , 0.0, 5.0, 0.0 | ||
27 | -- , 0.0, 0.0, 9.0 ] | ||
28 | -- | ||
29 | -- In arithmetic operations single-element vectors and matrices | ||
30 | -- (created from numeric literals or using 'scalar') automatically | ||
31 | -- expand to match the dimensions of the other operand: | ||
32 | -- | ||
33 | -- >>> 5 + 2*ident 3 :: Matrix Double | ||
34 | -- (3><3) | ||
35 | -- [ 7.0, 5.0, 5.0 | ||
36 | -- , 5.0, 7.0, 5.0 | ||
37 | -- , 5.0, 5.0, 7.0 ] | ||
38 | -- | ||
39 | -- >>> matrix 3 [1..9] + matrix 1 [10,20,30] | ||
40 | -- (3><3) | ||
41 | -- [ 11.0, 12.0, 13.0 | ||
42 | -- , 24.0, 25.0, 26.0 | ||
43 | -- , 37.0, 38.0, 39.0 ] | ||
44 | -- | ||
45 | |||
46 | -- * Products | ||
47 | -- ** dot | ||
48 | dot, (<·>), | ||
49 | -- ** matrix-vector | ||
50 | app, (#>), (!#>), | ||
51 | -- ** matrix-matrix | ||
52 | mul, (<>), | ||
53 | -- | The matrix product is also implemented in the "Data.Monoid" instance, where | ||
54 | -- single-element matrices (created from numeric literals or using 'scalar') | ||
55 | -- are used for scaling. | ||
56 | -- | ||
57 | -- >>> import Data.Monoid as M | ||
58 | -- >>> let m = matrix 3 [1..6] | ||
59 | -- >>> m M.<> 2 M.<> diagl[0.5,1,0] | ||
60 | -- (2><3) | ||
61 | -- [ 1.0, 4.0, 0.0 | ||
62 | -- , 4.0, 10.0, 0.0 ] | ||
63 | -- | ||
64 | -- 'mconcat' uses 'optimiseMult' to get the optimal association order. | ||
65 | |||
66 | |||
67 | -- ** other | ||
68 | outer, kronecker, cross, | ||
69 | scale, | ||
70 | sumElements, prodElements, | ||
71 | |||
72 | -- * Linear Systems | ||
73 | (<\>), | ||
74 | linearSolve, | ||
75 | linearSolveLS, | ||
76 | linearSolveSVD, | ||
77 | luSolve, | ||
78 | cholSolve, | ||
79 | cgSolve, | ||
80 | cgSolve', | ||
81 | |||
82 | -- * Inverse and pseudoinverse | ||
83 | inv, pinv, pinvTol, | ||
84 | |||
85 | -- * Determinant and rank | ||
86 | rcond, rank, | ||
87 | det, invlndet, | ||
88 | |||
89 | -- * Norms | ||
90 | Normed(..), | ||
91 | norm_Frob, norm_nuclear, | ||
92 | |||
93 | -- * Nullspace and range | ||
94 | orth, | ||
95 | nullspace, null1, null1sym, | ||
96 | |||
97 | -- * SVD | ||
98 | svd, | ||
99 | thinSVD, | ||
100 | compactSVD, | ||
101 | singularValues, | ||
102 | leftSV, rightSV, | ||
103 | |||
104 | -- * Eigensystems | ||
105 | eig, eigSH, eigSH', | ||
106 | eigenvalues, eigenvaluesSH, eigenvaluesSH', | ||
107 | geigSH', | ||
108 | |||
109 | -- * QR | ||
110 | qr, rq, qrRaw, qrgr, | ||
111 | |||
112 | -- * Cholesky | ||
113 | chol, cholSH, mbCholSH, | ||
114 | |||
115 | -- * Hessenberg | ||
116 | hess, | ||
117 | |||
118 | -- * Schur | ||
119 | schur, | ||
120 | |||
121 | -- * LU | ||
122 | lu, luPacked, | ||
123 | |||
124 | -- * Matrix functions | ||
125 | expm, | ||
126 | sqrtm, | ||
127 | matFunc, | ||
128 | |||
129 | -- * Correlation and convolution | ||
130 | corr, conv, corrMin, corr2, conv2, | ||
131 | |||
132 | -- * Random arrays | ||
133 | |||
134 | Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, | ||
135 | |||
136 | -- * Misc | ||
137 | meanCov, rowOuters, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, | ||
138 | ℝ,ℂ,iC, | ||
139 | -- * Auxiliary classes | ||
140 | Element, Container, Product, Numeric, LSDiv, | ||
141 | Complexable, RealElement, | ||
142 | RealOf, ComplexOf, SingleOf, DoubleOf, | ||
143 | IndexOf, | ||
144 | Field, | ||
145 | -- Normed, | ||
146 | Transposable, | ||
147 | CGState(..), | ||
148 | Testable(..) | ||
149 | ) where | ||
150 | |||
151 | import Numeric.LinearAlgebra.Data | ||
152 | |||
153 | import Numeric.Matrix() | ||
154 | import Numeric.Vector() | ||
155 | import Data.Packed.Numeric hiding ((<>), mul) | ||
156 | import Numeric.LinearAlgebra.Algorithms hiding (linearSolve,Normed,orth) | ||
157 | import qualified Numeric.LinearAlgebra.Algorithms as A | ||
158 | import Numeric.LinearAlgebra.Util | ||
159 | import Numeric.LinearAlgebra.Random | ||
160 | import Numeric.Sparse((!#>)) | ||
161 | import Numeric.LinearAlgebra.Util.CG | ||
162 | |||
163 | {- | infix synonym of 'mul' | ||
164 | |||
165 | >>> let a = (3><5) [1..] | ||
166 | >>> a | ||
167 | (3><5) | ||
168 | [ 1.0, 2.0, 3.0, 4.0, 5.0 | ||
169 | , 6.0, 7.0, 8.0, 9.0, 10.0 | ||
170 | , 11.0, 12.0, 13.0, 14.0, 15.0 ] | ||
171 | |||
172 | >>> let b = (5><2) [1,3, 0,2, -1,5, 7,7, 6,0] | ||
173 | >>> b | ||
174 | (5><2) | ||
175 | [ 1.0, 3.0 | ||
176 | , 0.0, 2.0 | ||
177 | , -1.0, 5.0 | ||
178 | , 7.0, 7.0 | ||
179 | , 6.0, 0.0 ] | ||
180 | |||
181 | >>> a <> b | ||
182 | (3><2) | ||
183 | [ 56.0, 50.0 | ||
184 | , 121.0, 135.0 | ||
185 | , 186.0, 220.0 ] | ||
186 | 10 | ||
187 | -} | 11 | -} |
188 | (<>) :: Numeric t => Matrix t -> Matrix t -> Matrix t | 12 | -------------------------------------------------------------------------------- |
189 | (<>) = mXm | ||
190 | infixr 8 <> | ||
191 | |||
192 | -- | dense matrix product | ||
193 | mul :: Numeric t => Matrix t -> Matrix t -> Matrix t | ||
194 | mul = mXm | ||
195 | |||
196 | |||
197 | {- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, returning Nothing for a singular system. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'. | ||
198 | |||
199 | @ | ||
200 | a = (2><2) | ||
201 | [ 1.0, 2.0 | ||
202 | , 3.0, 5.0 ] | ||
203 | @ | ||
204 | 13 | ||
205 | @ | 14 | module Numeric.LinearAlgebra.HMatrix ( |
206 | b = (2><3) | 15 | module Numeric.LinearAlgebra, |
207 | [ 6.0, 1.0, 10.0 | 16 | (¦),(——),ℝ,ℂ,(<·>),app,mul, cholSH, mbCholSH, eigSH', eigenvaluesSH', geigSH' |
208 | , 15.0, 3.0, 26.0 ] | 17 | ) where |
209 | @ | ||
210 | |||
211 | >>> linearSolve a b | ||
212 | Just (2><3) | ||
213 | [ -1.4802973661668753e-15, 0.9999999999999997, 1.999999999999997 | ||
214 | , 3.000000000000001, 1.6653345369377348e-16, 4.000000000000002 ] | ||
215 | |||
216 | >>> let Just x = it | ||
217 | >>> disp 5 x | ||
218 | 2x3 | ||
219 | -0.00000 1.00000 2.00000 | ||
220 | 3.00000 0.00000 4.00000 | ||
221 | 18 | ||
222 | >>> a <> x | 19 | import Numeric.LinearAlgebra |
223 | (2><3) | 20 | import Internal.Util |
224 | [ 6.0, 1.0, 10.0 | 21 | import Internal.Algorithms(cholSH, mbCholSH, eigSH', eigenvaluesSH', geigSH') |
225 | , 15.0, 3.0, 26.0 ] | ||
226 | 22 | ||
227 | -} | 23 | infixr 8 <·> |
228 | linearSolve m b = A.mbLinearSolve m b | 24 | (<·>) :: Numeric t => Vector t -> Vector t -> t |
25 | (<·>) = dot | ||
229 | 26 | ||
230 | -- | return an orthonormal basis of the null space of a matrix. See also 'nullspaceSVD'. | 27 | app m v = m #> v |
231 | nullspace m = nullspaceSVD (Left (1*eps)) m (rightSV m) | ||
232 | 28 | ||
233 | -- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'. | 29 | mul a b = a <> b |
234 | orth m = orthSVD (Left (1*eps)) m (leftSV m) | ||
235 | 30 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index 3398e6a..843c727 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -1,5 +1,3 @@ | |||
1 | #if __GLASGOW_HASKELL__ >= 708 | ||
2 | |||
3 | {-# LANGUAGE DataKinds #-} | 1 | {-# LANGUAGE DataKinds #-} |
4 | {-# LANGUAGE KindSignatures #-} | 2 | {-# LANGUAGE KindSignatures #-} |
5 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | 3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} |
@@ -13,7 +11,6 @@ | |||
13 | {-# LANGUAGE TypeOperators #-} | 11 | {-# LANGUAGE TypeOperators #-} |
14 | {-# LANGUAGE ViewPatterns #-} | 12 | {-# LANGUAGE ViewPatterns #-} |
15 | {-# LANGUAGE GADTs #-} | 13 | {-# LANGUAGE GADTs #-} |
16 | {-# LANGUAGE OverlappingInstances #-} | ||
17 | {-# LANGUAGE TypeFamilies #-} | 14 | {-# LANGUAGE TypeFamilies #-} |
18 | 15 | ||
19 | 16 | ||
@@ -25,19 +22,19 @@ Stability : experimental | |||
25 | 22 | ||
26 | Experimental interface with statically checked dimensions. | 23 | Experimental interface with statically checked dimensions. |
27 | 24 | ||
28 | This module is under active development and the interface is subject to changes. | 25 | See code examples at http://dis.um.es/~alberto/hmatrix/static.html. |
29 | 26 | ||
30 | -} | 27 | -} |
31 | 28 | ||
32 | module Numeric.LinearAlgebra.Static( | 29 | module Numeric.LinearAlgebra.Static( |
33 | -- * Vector | 30 | -- * Vector |
34 | ℝ, R, | 31 | ℝ, R, |
35 | vec2, vec3, vec4, (&), (#), split, headTail, | 32 | vec2, vec3, vec4, (&), (#), split, headTail, |
36 | vector, | 33 | vector, |
37 | linspace, range, dim, | 34 | linspace, range, dim, |
38 | -- * Matrix | 35 | -- * Matrix |
39 | L, Sq, build, | 36 | L, Sq, build, |
40 | row, col, (¦),(——), splitRows, splitCols, | 37 | row, col, (|||),(===), splitRows, splitCols, |
41 | unrow, uncol, | 38 | unrow, uncol, |
42 | tr, | 39 | tr, |
43 | eye, | 40 | eye, |
@@ -47,7 +44,7 @@ module Numeric.LinearAlgebra.Static( | |||
47 | -- * Complex | 44 | -- * Complex |
48 | C, M, Her, her, 𝑖, | 45 | C, M, Her, her, 𝑖, |
49 | -- * Products | 46 | -- * Products |
50 | (<>),(#>),(<·>), | 47 | (<>),(#>),(<.>), |
51 | -- * Linear Systems | 48 | -- * Linear Systems |
52 | linSolve, (<\>), | 49 | linSolve, (<\>), |
53 | -- * Factorizations | 50 | -- * Factorizations |
@@ -58,26 +55,22 @@ module Numeric.LinearAlgebra.Static( | |||
58 | Disp(..), Domain(..), | 55 | Disp(..), Domain(..), |
59 | withVector, withMatrix, | 56 | withVector, withMatrix, |
60 | toRows, toColumns, | 57 | toRows, toColumns, |
61 | Sized(..), Diag(..), Sym, sym, mTm, unSym | 58 | Sized(..), Diag(..), Sym, sym, mTm, unSym, (<·>) |
62 | ) where | 59 | ) where |
63 | 60 | ||
64 | 61 | ||
65 | import GHC.TypeLits | 62 | import GHC.TypeLits |
66 | import Numeric.LinearAlgebra.HMatrix hiding ( | 63 | import Numeric.LinearAlgebra hiding ( |
67 | (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——), | 64 | (<>),(#>),(<.>),Konst(..),diag, disp,(===),(|||), |
68 | row,col,vector,matrix,linspace,toRows,toColumns, | 65 | row,col,vector,matrix,linspace,toRows,toColumns, |
69 | (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH', | 66 | (<\>),fromList,takeDiag,svd,eig,eigSH, |
70 | eigenvalues,eigenvaluesSH,eigenvaluesSH',build, | 67 | eigenvalues,eigenvaluesSH,build, |
71 | qr,size,app,mul,dot,chol) | 68 | qr,size,dot,chol,range,R,C,sym,mTm,unSym) |
72 | import qualified Numeric.LinearAlgebra.HMatrix as LA | 69 | import qualified Numeric.LinearAlgebra as LA |
73 | import Data.Proxy(Proxy) | 70 | import Data.Proxy(Proxy) |
74 | import Numeric.LinearAlgebra.Static.Internal | 71 | import Internal.Static |
75 | import Control.Arrow((***)) | 72 | import Control.Arrow((***)) |
76 | 73 | ||
77 | |||
78 | |||
79 | |||
80 | |||
81 | ud1 :: R n -> Vector ℝ | 74 | ud1 :: R n -> Vector ℝ |
82 | ud1 (R (Dim v)) = v | 75 | ud1 (R (Dim v)) = v |
83 | 76 | ||
@@ -171,21 +164,22 @@ unrow = mkR . head . LA.toRows . ud2 | |||
171 | uncol v = unrow . tr $ v | 164 | uncol v = unrow . tr $ v |
172 | 165 | ||
173 | 166 | ||
174 | infixl 2 —— | 167 | infixl 2 === |
175 | (——) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c | 168 | (===) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c |
176 | a —— b = mkL (extract a LA.—— extract b) | 169 | a === b = mkL (extract a LA.=== extract b) |
177 | 170 | ||
178 | 171 | ||
179 | infixl 3 ¦ | 172 | infixl 3 ||| |
180 | -- (¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) | 173 | -- (|||) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) |
181 | a ¦ b = tr (tr a —— tr b) | 174 | a ||| b = tr (tr a === tr b) |
182 | 175 | ||
183 | 176 | ||
184 | type Sq n = L n n | 177 | type Sq n = L n n |
185 | --type CSq n = CL n n | 178 | --type CSq n = CL n n |
186 | 179 | ||
187 | type GL = forall n m. (KnownNat n, KnownNat m) => L m n | 180 | |
188 | type GSq = forall n. KnownNat n => Sq n | 181 | type GL = forall n m . (KnownNat n, KnownNat m) => L m n |
182 | type GSq = forall n . KnownNat n => Sq n | ||
189 | 183 | ||
190 | isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int)) | 184 | isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int)) |
191 | isKonst s@(unwrap -> x) | 185 | isKonst s@(unwrap -> x) |
@@ -213,6 +207,10 @@ infixr 8 <·> | |||
213 | (<·>) :: R n -> R n -> ℝ | 207 | (<·>) :: R n -> R n -> ℝ |
214 | (<·>) = dotR | 208 | (<·>) = dotR |
215 | 209 | ||
210 | infixr 8 <.> | ||
211 | (<.>) :: R n -> R n -> ℝ | ||
212 | (<.>) = dotR | ||
213 | |||
216 | -------------------------------------------------------------------------------- | 214 | -------------------------------------------------------------------------------- |
217 | 215 | ||
218 | class Diag m d | m -> d | 216 | class Diag m d | m -> d |
@@ -294,10 +292,10 @@ her m = Her $ (m + LA.tr m)/2 | |||
294 | 292 | ||
295 | instance KnownNat n => Eigen (Sym n) (R n) (L n n) | 293 | instance KnownNat n => Eigen (Sym n) (R n) (L n n) |
296 | where | 294 | where |
297 | eigenvalues (Sym (extract -> m)) = mkR . LA.eigenvaluesSH' $ m | 295 | eigenvalues (Sym (extract -> m)) = mkR . LA.eigenvaluesSH . LA.trustSym $ m |
298 | eigensystem (Sym (extract -> m)) = (mkR l, mkL v) | 296 | eigensystem (Sym (extract -> m)) = (mkR l, mkL v) |
299 | where | 297 | where |
300 | (l,v) = LA.eigSH' m | 298 | (l,v) = LA.eigSH . LA.trustSym $ m |
301 | 299 | ||
302 | instance KnownNat n => Eigen (Sq n) (C n) (M n n) | 300 | instance KnownNat n => Eigen (Sq n) (C n) (M n n) |
303 | where | 301 | where |
@@ -307,7 +305,7 @@ instance KnownNat n => Eigen (Sq n) (C n) (M n n) | |||
307 | (l,v) = LA.eig m | 305 | (l,v) = LA.eig m |
308 | 306 | ||
309 | chol :: KnownNat n => Sym n -> Sq n | 307 | chol :: KnownNat n => Sym n -> Sq n |
310 | chol (extract . unSym -> m) = mkL $ LA.cholSH m | 308 | chol (extract . unSym -> m) = mkL $ LA.chol $ LA.trustSym m |
311 | 309 | ||
312 | -------------------------------------------------------------------------------- | 310 | -------------------------------------------------------------------------------- |
313 | 311 | ||
@@ -502,7 +500,7 @@ appC m v = mkC (extract m LA.#> extract v) | |||
502 | dotC :: KnownNat n => C n -> C n -> ℂ | 500 | dotC :: KnownNat n => C n -> C n -> ℂ |
503 | dotC (unwrap -> u) (unwrap -> v) | 501 | dotC (unwrap -> u) (unwrap -> v) |
504 | | singleV u || singleV v = sumElements (conj u * v) | 502 | | singleV u || singleV v = sumElements (conj u * v) |
505 | | otherwise = u LA.<·> v | 503 | | otherwise = u LA.<.> v |
506 | 504 | ||
507 | 505 | ||
508 | crossC :: C 3 -> C 3 -> C 3 | 506 | crossC :: C 3 -> C 3 -> C 3 |
@@ -590,12 +588,12 @@ test = (ok,info) | |||
590 | where | 588 | where |
591 | q = tm :: L 10 3 | 589 | q = tm :: L 10 3 |
592 | 590 | ||
593 | thingD = vjoin [ud1 u, 1] LA.<·> tr m LA.#> m LA.#> ud1 v | 591 | thingD = vjoin [ud1 u, 1] LA.<.> tr m LA.#> m LA.#> ud1 v |
594 | where | 592 | where |
595 | m = LA.matrix 3 [1..30] | 593 | m = LA.matrix 3 [1..30] |
596 | 594 | ||
597 | precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v | 595 | precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v |
598 | precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (LA.size (ud1 u) +1, LA.size (ud1 v)) LA.#> ud1 v | 596 | precD = 1 + 2 * vjoin[ud1 u, 6] LA.<.> LA.konst 2 (LA.size (ud1 u) +1, LA.size (ud1 v)) LA.#> ud1 v |
599 | 597 | ||
600 | 598 | ||
601 | splittest | 599 | splittest |
@@ -618,23 +616,3 @@ instance (KnownNat n', KnownNat m') => Testable (L n' m') | |||
618 | where | 616 | where |
619 | checkT _ = test | 617 | checkT _ = test |
620 | 618 | ||
621 | #else | ||
622 | |||
623 | {- | | ||
624 | Module : Numeric.LinearAlgebra.Static | ||
625 | Copyright : (c) Alberto Ruiz 2014 | ||
626 | License : BSD3 | ||
627 | Stability : experimental | ||
628 | |||
629 | Experimental interface with statically checked dimensions. | ||
630 | |||
631 | This module requires GHC >= 7.8 | ||
632 | |||
633 | -} | ||
634 | |||
635 | module Numeric.LinearAlgebra.Static | ||
636 | {-# WARNING "This module requires GHC >= 7.8" #-} | ||
637 | where | ||
638 | |||
639 | #endif | ||
640 | |||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util.hs b/packages/base/src/Numeric/LinearAlgebra/Util.hs deleted file mode 100644 index 043aa21..0000000 --- a/packages/base/src/Numeric/LinearAlgebra/Util.hs +++ /dev/null | |||
@@ -1,505 +0,0 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | ||
2 | {-# LANGUAGE FlexibleInstances #-} | ||
3 | {-# LANGUAGE TypeFamilies #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
5 | {-# LANGUAGE FunctionalDependencies #-} | ||
6 | {-# LANGUAGE ViewPatterns #-} | ||
7 | |||
8 | |||
9 | ----------------------------------------------------------------------------- | ||
10 | {- | | ||
11 | Module : Numeric.LinearAlgebra.Util | ||
12 | Copyright : (c) Alberto Ruiz 2013 | ||
13 | License : BSD3 | ||
14 | Maintainer : Alberto Ruiz | ||
15 | Stability : provisional | ||
16 | |||
17 | -} | ||
18 | ----------------------------------------------------------------------------- | ||
19 | {-# OPTIONS_HADDOCK hide #-} | ||
20 | |||
21 | module Numeric.LinearAlgebra.Util( | ||
22 | |||
23 | -- * Convenience functions | ||
24 | vector, matrix, | ||
25 | disp, | ||
26 | formatSparse, | ||
27 | approxInt, | ||
28 | dispDots, | ||
29 | dispBlanks, | ||
30 | formatShort, | ||
31 | dispShort, | ||
32 | zeros, ones, | ||
33 | diagl, | ||
34 | row, | ||
35 | col, | ||
36 | (&), (¦), (|||), (——), (===), (#), | ||
37 | (?), (¿), | ||
38 | Indexable(..), size, | ||
39 | Numeric, | ||
40 | rand, randn, | ||
41 | cross, | ||
42 | norm, | ||
43 | ℕ,ℤ,ℝ,ℂ,iC, | ||
44 | Normed(..), norm_Frob, norm_nuclear, | ||
45 | unitary, | ||
46 | mt, | ||
47 | (~!~), | ||
48 | pairwiseD2, | ||
49 | rowOuters, | ||
50 | null1, | ||
51 | null1sym, | ||
52 | -- * Convolution | ||
53 | -- ** 1D | ||
54 | corr, conv, corrMin, | ||
55 | -- ** 2D | ||
56 | corr2, conv2, separable, | ||
57 | -- * Tools for the Kronecker product | ||
58 | -- | ||
59 | -- | (see A. Fusiello, A matter of notation: Several uses of the Kronecker product in | ||
60 | -- 3d computer vision, Pattern Recognition Letters 28 (15) (2007) 2127-2132) | ||
61 | |||
62 | -- | ||
63 | -- | @`vec` (a \<> x \<> b) == ('trans' b ` 'kronecker' ` a) \<> 'vec' x@ | ||
64 | vec, | ||
65 | vech, | ||
66 | dup, | ||
67 | vtrans | ||
68 | ) where | ||
69 | |||
70 | import Data.Packed.Numeric | ||
71 | import Numeric.LinearAlgebra.Algorithms hiding (i,Normed) | ||
72 | --import qualified Numeric.LinearAlgebra.Algorithms as A | ||
73 | import Numeric.Matrix() | ||
74 | import Numeric.Vector() | ||
75 | import Numeric.LinearAlgebra.Random | ||
76 | import Numeric.LinearAlgebra.Util.Convolution | ||
77 | import Control.Monad(when) | ||
78 | import Text.Printf | ||
79 | import Data.List.Split(splitOn) | ||
80 | import Data.List(intercalate) | ||
81 | |||
82 | type ℝ = Double | ||
83 | type ℕ = Int | ||
84 | type ℤ = Int | ||
85 | type ℂ = Complex Double | ||
86 | |||
87 | -- | imaginary unit | ||
88 | iC :: ℂ | ||
89 | iC = 0:+1 | ||
90 | |||
91 | {- | create a real vector | ||
92 | |||
93 | >>> vector [1..5] | ||
94 | fromList [1.0,2.0,3.0,4.0,5.0] | ||
95 | |||
96 | -} | ||
97 | vector :: [ℝ] -> Vector ℝ | ||
98 | vector = fromList | ||
99 | |||
100 | {- | create a real matrix | ||
101 | |||
102 | >>> matrix 5 [1..15] | ||
103 | (3><5) | ||
104 | [ 1.0, 2.0, 3.0, 4.0, 5.0 | ||
105 | , 6.0, 7.0, 8.0, 9.0, 10.0 | ||
106 | , 11.0, 12.0, 13.0, 14.0, 15.0 ] | ||
107 | |||
108 | -} | ||
109 | matrix | ||
110 | :: Int -- ^ columns | ||
111 | -> [ℝ] -- ^ elements | ||
112 | -> Matrix ℝ | ||
113 | matrix c = reshape c . fromList | ||
114 | |||
115 | |||
116 | {- | print a real matrix with given number of digits after the decimal point | ||
117 | |||
118 | >>> disp 5 $ ident 2 / 3 | ||
119 | 2x2 | ||
120 | 0.33333 0.00000 | ||
121 | 0.00000 0.33333 | ||
122 | |||
123 | -} | ||
124 | disp :: Int -> Matrix Double -> IO () | ||
125 | |||
126 | disp n = putStr . dispf n | ||
127 | |||
128 | |||
129 | {- | create a real diagonal matrix from a list | ||
130 | |||
131 | >>> diagl [1,2,3] | ||
132 | (3><3) | ||
133 | [ 1.0, 0.0, 0.0 | ||
134 | , 0.0, 2.0, 0.0 | ||
135 | , 0.0, 0.0, 3.0 ] | ||
136 | |||
137 | -} | ||
138 | diagl :: [Double] -> Matrix Double | ||
139 | diagl = diag . fromList | ||
140 | |||
141 | -- | a real matrix of zeros | ||
142 | zeros :: Int -- ^ rows | ||
143 | -> Int -- ^ columns | ||
144 | -> Matrix Double | ||
145 | zeros r c = konst 0 (r,c) | ||
146 | |||
147 | -- | a real matrix of ones | ||
148 | ones :: Int -- ^ rows | ||
149 | -> Int -- ^ columns | ||
150 | -> Matrix Double | ||
151 | ones r c = konst 1 (r,c) | ||
152 | |||
153 | -- | concatenation of real vectors | ||
154 | infixl 3 & | ||
155 | (&) :: Vector Double -> Vector Double -> Vector Double | ||
156 | a & b = vjoin [a,b] | ||
157 | |||
158 | {- | horizontal concatenation of real matrices | ||
159 | |||
160 | >>> ident 3 ||| konst 7 (3,4) | ||
161 | (3><7) | ||
162 | [ 1.0, 0.0, 0.0, 7.0, 7.0, 7.0, 7.0 | ||
163 | , 0.0, 1.0, 0.0, 7.0, 7.0, 7.0, 7.0 | ||
164 | , 0.0, 0.0, 1.0, 7.0, 7.0, 7.0, 7.0 ] | ||
165 | |||
166 | -} | ||
167 | infixl 3 ||| | ||
168 | (|||) :: Matrix Double -> Matrix Double -> Matrix Double | ||
169 | a ||| b = fromBlocks [[a,b]] | ||
170 | |||
171 | -- | a synonym for ('|||') (unicode 0x00a6, broken bar) | ||
172 | infixl 3 ¦ | ||
173 | (¦) :: Matrix Double -> Matrix Double -> Matrix Double | ||
174 | (¦) = (|||) | ||
175 | |||
176 | |||
177 | -- | vertical concatenation of real matrices | ||
178 | -- | ||
179 | (===) :: Matrix Double -> Matrix Double -> Matrix Double | ||
180 | infixl 2 === | ||
181 | a === b = fromBlocks [[a],[b]] | ||
182 | |||
183 | -- | a synonym for ('===') (unicode 0x2014, em dash) | ||
184 | (——) :: Matrix Double -> Matrix Double -> Matrix Double | ||
185 | infixl 2 —— | ||
186 | (——) = (===) | ||
187 | |||
188 | |||
189 | (#) :: Matrix Double -> Matrix Double -> Matrix Double | ||
190 | infixl 2 # | ||
191 | a # b = fromBlocks [[a],[b]] | ||
192 | |||
193 | -- | create a single row real matrix from a list | ||
194 | row :: [Double] -> Matrix Double | ||
195 | row = asRow . fromList | ||
196 | |||
197 | -- | create a single column real matrix from a list | ||
198 | col :: [Double] -> Matrix Double | ||
199 | col = asColumn . fromList | ||
200 | |||
201 | {- | extract rows | ||
202 | |||
203 | >>> (20><4) [1..] ? [2,1,1] | ||
204 | (3><4) | ||
205 | [ 9.0, 10.0, 11.0, 12.0 | ||
206 | , 5.0, 6.0, 7.0, 8.0 | ||
207 | , 5.0, 6.0, 7.0, 8.0 ] | ||
208 | |||
209 | -} | ||
210 | infixl 9 ? | ||
211 | (?) :: Element t => Matrix t -> [Int] -> Matrix t | ||
212 | (?) = flip extractRows | ||
213 | |||
214 | {- | extract columns | ||
215 | |||
216 | (unicode 0x00bf, inverted question mark, Alt-Gr ?) | ||
217 | |||
218 | >>> (3><4) [1..] ¿ [3,0] | ||
219 | (3><2) | ||
220 | [ 4.0, 1.0 | ||
221 | , 8.0, 5.0 | ||
222 | , 12.0, 9.0 ] | ||
223 | |||
224 | -} | ||
225 | infixl 9 ¿ | ||
226 | (¿) :: Element t => Matrix t -> [Int] -> Matrix t | ||
227 | (¿)= flip extractColumns | ||
228 | |||
229 | |||
230 | cross :: Vector Double -> Vector Double -> Vector Double | ||
231 | -- ^ cross product (for three-element real vectors) | ||
232 | cross x y | dim x == 3 && dim y == 3 = fromList [z1,z2,z3] | ||
233 | | otherwise = error $ "cross ("++show x++") ("++show y++")" | ||
234 | where | ||
235 | [x1,x2,x3] = toList x | ||
236 | [y1,y2,y3] = toList y | ||
237 | z1 = x2*y3-x3*y2 | ||
238 | z2 = x3*y1-x1*y3 | ||
239 | z3 = x1*y2-x2*y1 | ||
240 | |||
241 | norm :: Vector Double -> Double | ||
242 | -- ^ 2-norm of real vector | ||
243 | norm = pnorm PNorm2 | ||
244 | |||
245 | class Normed a | ||
246 | where | ||
247 | norm_0 :: a -> ℝ | ||
248 | norm_1 :: a -> ℝ | ||
249 | norm_2 :: a -> ℝ | ||
250 | norm_Inf :: a -> ℝ | ||
251 | |||
252 | |||
253 | instance Normed (Vector ℝ) | ||
254 | where | ||
255 | norm_0 v = sumElements (step (abs v - scalar (eps*normInf v))) | ||
256 | norm_1 = pnorm PNorm1 | ||
257 | norm_2 = pnorm PNorm2 | ||
258 | norm_Inf = pnorm Infinity | ||
259 | |||
260 | instance Normed (Vector ℂ) | ||
261 | where | ||
262 | norm_0 v = sumElements (step (fst (fromComplex (abs v)) - scalar (eps*normInf v))) | ||
263 | norm_1 = pnorm PNorm1 | ||
264 | norm_2 = pnorm PNorm2 | ||
265 | norm_Inf = pnorm Infinity | ||
266 | |||
267 | instance Normed (Matrix ℝ) | ||
268 | where | ||
269 | norm_0 = norm_0 . flatten | ||
270 | norm_1 = pnorm PNorm1 | ||
271 | norm_2 = pnorm PNorm2 | ||
272 | norm_Inf = pnorm Infinity | ||
273 | |||
274 | instance Normed (Matrix ℂ) | ||
275 | where | ||
276 | norm_0 = norm_0 . flatten | ||
277 | norm_1 = pnorm PNorm1 | ||
278 | norm_2 = pnorm PNorm2 | ||
279 | norm_Inf = pnorm Infinity | ||
280 | |||
281 | |||
282 | norm_Frob :: (Normed (Vector t), Element t) => Matrix t -> ℝ | ||
283 | norm_Frob = norm_2 . flatten | ||
284 | |||
285 | norm_nuclear :: Field t => Matrix t -> ℝ | ||
286 | norm_nuclear = sumElements . singularValues | ||
287 | |||
288 | |||
289 | -- | Obtains a vector in the same direction with 2-norm=1 | ||
290 | unitary :: Vector Double -> Vector Double | ||
291 | unitary v = v / scalar (norm v) | ||
292 | |||
293 | |||
294 | -- | trans . inv | ||
295 | mt :: Matrix Double -> Matrix Double | ||
296 | mt = trans . inv | ||
297 | |||
298 | -------------------------------------------------------------------------------- | ||
299 | {- | | ||
300 | |||
301 | >>> size $ fromList[1..10::Double] | ||
302 | 10 | ||
303 | >>> size $ (2><5)[1..10::Double] | ||
304 | (2,5) | ||
305 | |||
306 | -} | ||
307 | size :: Container c t => c t -> IndexOf c | ||
308 | size = size' | ||
309 | |||
310 | {- | | ||
311 | |||
312 | >>> vect [1..10] ! 3 | ||
313 | 4.0 | ||
314 | |||
315 | >>> mat 5 [1..15] ! 1 | ||
316 | fromList [6.0,7.0,8.0,9.0,10.0] | ||
317 | |||
318 | >>> mat 5 [1..15] ! 1 ! 3 | ||
319 | 9.0 | ||
320 | |||
321 | -} | ||
322 | class Indexable c t | c -> t , t -> c | ||
323 | where | ||
324 | infixl 9 ! | ||
325 | (!) :: c -> Int -> t | ||
326 | |||
327 | instance Indexable (Vector Double) Double | ||
328 | where | ||
329 | (!) = (@>) | ||
330 | |||
331 | instance Indexable (Vector Float) Float | ||
332 | where | ||
333 | (!) = (@>) | ||
334 | |||
335 | instance Indexable (Vector (Complex Double)) (Complex Double) | ||
336 | where | ||
337 | (!) = (@>) | ||
338 | |||
339 | instance Indexable (Vector (Complex Float)) (Complex Float) | ||
340 | where | ||
341 | (!) = (@>) | ||
342 | |||
343 | instance Element t => Indexable (Matrix t) (Vector t) | ||
344 | where | ||
345 | m!j = subVector (j*c) c (flatten m) | ||
346 | where | ||
347 | c = cols m | ||
348 | |||
349 | -------------------------------------------------------------------------------- | ||
350 | |||
351 | -- | Matrix of pairwise squared distances of row vectors | ||
352 | -- (using the matrix product trick in blog.smola.org) | ||
353 | pairwiseD2 :: Matrix Double -> Matrix Double -> Matrix Double | ||
354 | pairwiseD2 x y | ok = x2 `outer` oy + ox `outer` y2 - 2* x <> trans y | ||
355 | | otherwise = error $ "pairwiseD2 with different number of columns: " | ||
356 | ++ show (size x) ++ ", " ++ show (size y) | ||
357 | where | ||
358 | ox = one (rows x) | ||
359 | oy = one (rows y) | ||
360 | oc = one (cols x) | ||
361 | one k = konst 1 k | ||
362 | x2 = x * x <> oc | ||
363 | y2 = y * y <> oc | ||
364 | ok = cols x == cols y | ||
365 | |||
366 | -------------------------------------------------------------------------------- | ||
367 | |||
368 | {- | outer products of rows | ||
369 | |||
370 | >>> a | ||
371 | (3><2) | ||
372 | [ 1.0, 2.0 | ||
373 | , 10.0, 20.0 | ||
374 | , 100.0, 200.0 ] | ||
375 | >>> b | ||
376 | (3><3) | ||
377 | [ 1.0, 2.0, 3.0 | ||
378 | , 4.0, 5.0, 6.0 | ||
379 | , 7.0, 8.0, 9.0 ] | ||
380 | |||
381 | >>> rowOuters a (b ||| 1) | ||
382 | (3><8) | ||
383 | [ 1.0, 2.0, 3.0, 1.0, 2.0, 4.0, 6.0, 2.0 | ||
384 | , 40.0, 50.0, 60.0, 10.0, 80.0, 100.0, 120.0, 20.0 | ||
385 | , 700.0, 800.0, 900.0, 100.0, 1400.0, 1600.0, 1800.0, 200.0 ] | ||
386 | |||
387 | -} | ||
388 | rowOuters :: Matrix Double -> Matrix Double -> Matrix Double | ||
389 | rowOuters a b = a' * b' | ||
390 | where | ||
391 | a' = kronecker a (ones 1 (cols b)) | ||
392 | b' = kronecker (ones 1 (cols a)) b | ||
393 | |||
394 | -------------------------------------------------------------------------------- | ||
395 | |||
396 | -- | solution of overconstrained homogeneous linear system | ||
397 | null1 :: Matrix Double -> Vector Double | ||
398 | null1 = last . toColumns . snd . rightSV | ||
399 | |||
400 | -- | solution of overconstrained homogeneous symmetric linear system | ||
401 | null1sym :: Matrix Double -> Vector Double | ||
402 | null1sym = last . toColumns . snd . eigSH' | ||
403 | |||
404 | -------------------------------------------------------------------------------- | ||
405 | |||
406 | vec :: Element t => Matrix t -> Vector t | ||
407 | -- ^ stacking of columns | ||
408 | vec = flatten . trans | ||
409 | |||
410 | |||
411 | vech :: Element t => Matrix t -> Vector t | ||
412 | -- ^ half-vectorization (of the lower triangular part) | ||
413 | vech m = vjoin . zipWith f [0..] . toColumns $ m | ||
414 | where | ||
415 | f k v = subVector k (dim v - k) v | ||
416 | |||
417 | |||
418 | dup :: (Num t, Num (Vector t), Element t) => Int -> Matrix t | ||
419 | -- ^ duplication matrix (@'dup' k \<> 'vech' m == 'vec' m@, for symmetric m of 'dim' k) | ||
420 | dup k = trans $ fromRows $ map f es | ||
421 | where | ||
422 | rs = zip [0..] (toRows (ident (k^(2::Int)))) | ||
423 | es = [(i,j) | j <- [0..k-1], i <- [0..k-1], i>=j ] | ||
424 | f (i,j) | i == j = g (k*j + i) | ||
425 | | otherwise = g (k*j + i) + g (k*i + j) | ||
426 | g j = v | ||
427 | where | ||
428 | Just v = lookup j rs | ||
429 | |||
430 | |||
431 | vtrans :: Element t => Int -> Matrix t -> Matrix t | ||
432 | -- ^ generalized \"vector\" transposition: @'vtrans' 1 == 'trans'@, and @'vtrans' ('rows' m) m == 'asColumn' ('vec' m)@ | ||
433 | vtrans p m | r == 0 = fromBlocks . map (map asColumn . takesV (replicate q p)) . toColumns $ m | ||
434 | | otherwise = error $ "vtrans " ++ show p ++ " of matrix with " ++ show (rows m) ++ " rows" | ||
435 | where | ||
436 | (q,r) = divMod (rows m) p | ||
437 | |||
438 | -------------------------------------------------------------------------------- | ||
439 | |||
440 | infixl 0 ~!~ | ||
441 | c ~!~ msg = when c (error msg) | ||
442 | |||
443 | -------------------------------------------------------------------------------- | ||
444 | |||
445 | formatSparse :: String -> String -> String -> Int -> Matrix Double -> String | ||
446 | |||
447 | formatSparse zeroI _zeroF sep _ (approxInt -> Just m) = format sep f m | ||
448 | where | ||
449 | f 0 = zeroI | ||
450 | f x = printf "%.0f" x | ||
451 | |||
452 | formatSparse zeroI zeroF sep n m = format sep f m | ||
453 | where | ||
454 | f x | abs (x::Double) < 2*peps = zeroI++zeroF | ||
455 | | abs (fromIntegral (round x::Int) - x) / abs x < 2*peps | ||
456 | = printf ("%.0f."++replicate n ' ') x | ||
457 | | otherwise = printf ("%."++show n++"f") x | ||
458 | |||
459 | approxInt m | ||
460 | | norm_Inf (v - vi) < 2*peps * norm_Inf v = Just (reshape (cols m) vi) | ||
461 | | otherwise = Nothing | ||
462 | where | ||
463 | v = flatten m | ||
464 | vi = roundVector v | ||
465 | |||
466 | dispDots n = putStr . formatSparse "." (replicate n ' ') " " n | ||
467 | |||
468 | dispBlanks n = putStr . formatSparse "" "" " " n | ||
469 | |||
470 | formatShort sep fmt maxr maxc m = auxm4 | ||
471 | where | ||
472 | (rm,cm) = size m | ||
473 | (r1,r2,r3) | ||
474 | | rm <= maxr = (rm,0,0) | ||
475 | | otherwise = (maxr-3,rm-maxr+1,2) | ||
476 | (c1,c2,c3) | ||
477 | | cm <= maxc = (cm,0,0) | ||
478 | | otherwise = (maxc-3,cm-maxc+1,2) | ||
479 | [ [a,_,b] | ||
480 | ,[_,_,_] | ||
481 | ,[c,_,d]] = toBlocks [r1,r2,r3] | ||
482 | [c1,c2,c3] m | ||
483 | auxm = fromBlocks [[a,b],[c,d]] | ||
484 | auxm2 | ||
485 | | cm > maxc = format "|" fmt auxm | ||
486 | | otherwise = format sep fmt auxm | ||
487 | auxm3 | ||
488 | | cm > maxc = map (f . splitOn "|") (lines auxm2) | ||
489 | | otherwise = (lines auxm2) | ||
490 | f items = intercalate sep (take (maxc-3) items) ++ " .. " ++ | ||
491 | intercalate sep (drop (maxc-3) items) | ||
492 | auxm4 | ||
493 | | rm > maxr = unlines (take (maxr-3) auxm3 ++ vsep : drop (maxr-3) auxm3) | ||
494 | | otherwise = unlines auxm3 | ||
495 | vsep = map g (head auxm3) | ||
496 | g '.' = ':' | ||
497 | g _ = ' ' | ||
498 | |||
499 | |||
500 | dispShort :: Int -> Int -> Int -> Matrix Double -> IO () | ||
501 | dispShort maxr maxc dec m = | ||
502 | printf "%dx%d\n%s" (rows m) (cols m) (formatShort " " fmt maxr maxc m) | ||
503 | where | ||
504 | fmt = printf ("%."++show dec ++"f") | ||
505 | |||
diff --git a/packages/base/src/Numeric/Matrix.hs b/packages/base/src/Numeric/Matrix.hs index a9022c6..5400f26 100644 --- a/packages/base/src/Numeric/Matrix.hs +++ b/packages/base/src/Numeric/Matrix.hs | |||
@@ -26,18 +26,20 @@ module Numeric.Matrix ( | |||
26 | 26 | ||
27 | ------------------------------------------------------------------- | 27 | ------------------------------------------------------------------- |
28 | 28 | ||
29 | import Data.Packed | 29 | import Internal.Vector |
30 | import Data.Packed.Internal.Numeric | 30 | import Internal.Matrix |
31 | import Internal.Element | ||
32 | import Internal.Numeric | ||
31 | import qualified Data.Monoid as M | 33 | import qualified Data.Monoid as M |
32 | import Data.List(partition) | 34 | import Data.List(partition) |
33 | import Numeric.Chain | 35 | import Internal.Chain |
34 | 36 | ||
35 | ------------------------------------------------------------------- | 37 | ------------------------------------------------------------------- |
36 | 38 | ||
37 | instance Container Matrix a => Eq (Matrix a) where | 39 | instance Container Matrix a => Eq (Matrix a) where |
38 | (==) = equal | 40 | (==) = equal |
39 | 41 | ||
40 | instance (Container Matrix a, Num (Vector a)) => Num (Matrix a) where | 42 | instance (Container Matrix a, Num a, Num (Vector a)) => Num (Matrix a) where |
41 | (+) = liftMatrix2Auto (+) | 43 | (+) = liftMatrix2Auto (+) |
42 | (-) = liftMatrix2Auto (-) | 44 | (-) = liftMatrix2Auto (-) |
43 | negate = liftMatrix negate | 45 | negate = liftMatrix negate |
@@ -48,7 +50,7 @@ instance (Container Matrix a, Num (Vector a)) => Num (Matrix a) where | |||
48 | 50 | ||
49 | --------------------------------------------------- | 51 | --------------------------------------------------- |
50 | 52 | ||
51 | instance (Container Vector a, Fractional (Vector a), Num (Matrix a)) => Fractional (Matrix a) where | 53 | instance (Container Vector a, Fractional a, Fractional (Vector a), Num (Matrix a)) => Fractional (Matrix a) where |
52 | fromRational n = (1><1) [fromRational n] | 54 | fromRational n = (1><1) [fromRational n] |
53 | (/) = liftMatrix2Auto (/) | 55 | (/) = liftMatrix2Auto (/) |
54 | 56 | ||
diff --git a/packages/base/src/Numeric/Vector.hs b/packages/base/src/Numeric/Vector.hs index 28b453f..017196c 100644 --- a/packages/base/src/Numeric/Vector.hs +++ b/packages/base/src/Numeric/Vector.hs | |||
@@ -19,9 +19,10 @@ | |||
19 | 19 | ||
20 | module Numeric.Vector () where | 20 | module Numeric.Vector () where |
21 | 21 | ||
22 | import Numeric.Vectorized | 22 | import Internal.Vectorized |
23 | import Data.Packed.Vector | 23 | import Internal.Vector |
24 | import Data.Packed.Internal.Numeric | 24 | import Internal.Numeric |
25 | import Internal.Conversion | ||
25 | 26 | ||
26 | ------------------------------------------------------------------- | 27 | ------------------------------------------------------------------- |
27 | 28 | ||
@@ -32,6 +33,22 @@ adaptScalar f1 f2 f3 x y | |||
32 | 33 | ||
33 | ------------------------------------------------------------------ | 34 | ------------------------------------------------------------------ |
34 | 35 | ||
36 | instance Num (Vector I) where | ||
37 | (+) = adaptScalar addConstant add (flip addConstant) | ||
38 | negate = scale (-1) | ||
39 | (*) = adaptScalar scale mul (flip scale) | ||
40 | signum = vectorMapI Sign | ||
41 | abs = vectorMapI Abs | ||
42 | fromInteger = fromList . return . fromInteger | ||
43 | |||
44 | instance Num (Vector Z) where | ||
45 | (+) = adaptScalar addConstant add (flip addConstant) | ||
46 | negate = scale (-1) | ||
47 | (*) = adaptScalar scale mul (flip scale) | ||
48 | signum = vectorMapL Sign | ||
49 | abs = vectorMapL Abs | ||
50 | fromInteger = fromList . return . fromInteger | ||
51 | |||
35 | instance Num (Vector Float) where | 52 | instance Num (Vector Float) where |
36 | (+) = adaptScalar addConstant add (flip addConstant) | 53 | (+) = adaptScalar addConstant add (flip addConstant) |
37 | negate = scale (-1) | 54 | negate = scale (-1) |
@@ -66,7 +83,7 @@ instance Num (Vector (Complex Float)) where | |||
66 | 83 | ||
67 | --------------------------------------------------- | 84 | --------------------------------------------------- |
68 | 85 | ||
69 | instance (Container Vector a, Num (Vector a)) => Fractional (Vector a) where | 86 | instance (Container Vector a, Num (Vector a), Fractional a) => Fractional (Vector a) where |
70 | fromRational n = fromList [fromRational n] | 87 | fromRational n = fromList [fromRational n] |
71 | (/) = adaptScalar f divide g where | 88 | (/) = adaptScalar f divide g where |
72 | r `f` v = scaleRecip r v | 89 | r `f` v = scaleRecip r v |
diff --git a/packages/base/src/Numeric/Vectorized.hs b/packages/base/src/Numeric/Vectorized.hs deleted file mode 100644 index 6f0d240..0000000 --- a/packages/base/src/Numeric/Vectorized.hs +++ /dev/null | |||
@@ -1,365 +0,0 @@ | |||
1 | ----------------------------------------------------------------------------- | ||
2 | -- | | ||
3 | -- Module : Numeric.Vectorized | ||
4 | -- Copyright : (c) Alberto Ruiz 2007-14 | ||
5 | -- License : BSD3 | ||
6 | -- Maintainer : Alberto Ruiz | ||
7 | -- Stability : provisional | ||
8 | -- | ||
9 | -- Low level interface to vector operations. | ||
10 | -- | ||
11 | ----------------------------------------------------------------------------- | ||
12 | |||
13 | module Numeric.Vectorized ( | ||
14 | sumF, sumR, sumQ, sumC, | ||
15 | prodF, prodR, prodQ, prodC, | ||
16 | FunCodeS(..), toScalarR, toScalarF, toScalarC, toScalarQ, | ||
17 | FunCodeV(..), vectorMapR, vectorMapC, vectorMapF, vectorMapQ, | ||
18 | FunCodeSV(..), vectorMapValR, vectorMapValC, vectorMapValF, vectorMapValQ, | ||
19 | FunCodeVV(..), vectorZipR, vectorZipC, vectorZipF, vectorZipQ, | ||
20 | vectorScan, saveMatrix, | ||
21 | Seed, RandDist(..), randomVector, | ||
22 | sortVector, roundVector | ||
23 | ) where | ||
24 | |||
25 | import Data.Packed.Internal.Common | ||
26 | import Data.Packed.Internal.Signatures | ||
27 | import Data.Packed.Internal.Vector | ||
28 | import Data.Packed.Internal.Matrix | ||
29 | |||
30 | import Data.Complex | ||
31 | import Foreign.Marshal.Alloc(free,malloc) | ||
32 | import Foreign.Marshal.Array(newArray,copyArray) | ||
33 | import Foreign.Ptr(Ptr) | ||
34 | import Foreign.Storable(peek) | ||
35 | import Foreign.C.Types | ||
36 | import Foreign.C.String | ||
37 | import System.IO.Unsafe(unsafePerformIO) | ||
38 | |||
39 | import Control.Monad(when) | ||
40 | import Control.Applicative((<$>)) | ||
41 | |||
42 | |||
43 | |||
44 | fromei x = fromIntegral (fromEnum x) :: CInt | ||
45 | |||
46 | data FunCodeV = Sin | ||
47 | | Cos | ||
48 | | Tan | ||
49 | | Abs | ||
50 | | ASin | ||
51 | | ACos | ||
52 | | ATan | ||
53 | | Sinh | ||
54 | | Cosh | ||
55 | | Tanh | ||
56 | | ASinh | ||
57 | | ACosh | ||
58 | | ATanh | ||
59 | | Exp | ||
60 | | Log | ||
61 | | Sign | ||
62 | | Sqrt | ||
63 | deriving Enum | ||
64 | |||
65 | data FunCodeSV = Scale | ||
66 | | Recip | ||
67 | | AddConstant | ||
68 | | Negate | ||
69 | | PowSV | ||
70 | | PowVS | ||
71 | deriving Enum | ||
72 | |||
73 | data FunCodeVV = Add | ||
74 | | Sub | ||
75 | | Mul | ||
76 | | Div | ||
77 | | Pow | ||
78 | | ATan2 | ||
79 | deriving Enum | ||
80 | |||
81 | data FunCodeS = Norm2 | ||
82 | | AbsSum | ||
83 | | MaxIdx | ||
84 | | Max | ||
85 | | MinIdx | ||
86 | | Min | ||
87 | deriving Enum | ||
88 | |||
89 | ------------------------------------------------------------------ | ||
90 | |||
91 | -- | sum of elements | ||
92 | sumF :: Vector Float -> Float | ||
93 | sumF x = unsafePerformIO $ do | ||
94 | r <- createVector 1 | ||
95 | app2 c_sumF vec x vec r "sumF" | ||
96 | return $ r @> 0 | ||
97 | |||
98 | -- | sum of elements | ||
99 | sumR :: Vector Double -> Double | ||
100 | sumR x = unsafePerformIO $ do | ||
101 | r <- createVector 1 | ||
102 | app2 c_sumR vec x vec r "sumR" | ||
103 | return $ r @> 0 | ||
104 | |||
105 | -- | sum of elements | ||
106 | sumQ :: Vector (Complex Float) -> Complex Float | ||
107 | sumQ x = unsafePerformIO $ do | ||
108 | r <- createVector 1 | ||
109 | app2 c_sumQ vec x vec r "sumQ" | ||
110 | return $ r @> 0 | ||
111 | |||
112 | -- | sum of elements | ||
113 | sumC :: Vector (Complex Double) -> Complex Double | ||
114 | sumC x = unsafePerformIO $ do | ||
115 | r <- createVector 1 | ||
116 | app2 c_sumC vec x vec r "sumC" | ||
117 | return $ r @> 0 | ||
118 | |||
119 | foreign import ccall unsafe "sumF" c_sumF :: TFF | ||
120 | foreign import ccall unsafe "sumR" c_sumR :: TVV | ||
121 | foreign import ccall unsafe "sumQ" c_sumQ :: TQVQV | ||
122 | foreign import ccall unsafe "sumC" c_sumC :: TCVCV | ||
123 | |||
124 | -- | product of elements | ||
125 | prodF :: Vector Float -> Float | ||
126 | prodF x = unsafePerformIO $ do | ||
127 | r <- createVector 1 | ||
128 | app2 c_prodF vec x vec r "prodF" | ||
129 | return $ r @> 0 | ||
130 | |||
131 | -- | product of elements | ||
132 | prodR :: Vector Double -> Double | ||
133 | prodR x = unsafePerformIO $ do | ||
134 | r <- createVector 1 | ||
135 | app2 c_prodR vec x vec r "prodR" | ||
136 | return $ r @> 0 | ||
137 | |||
138 | -- | product of elements | ||
139 | prodQ :: Vector (Complex Float) -> Complex Float | ||
140 | prodQ x = unsafePerformIO $ do | ||
141 | r <- createVector 1 | ||
142 | app2 c_prodQ vec x vec r "prodQ" | ||
143 | return $ r @> 0 | ||
144 | |||
145 | -- | product of elements | ||
146 | prodC :: Vector (Complex Double) -> Complex Double | ||
147 | prodC x = unsafePerformIO $ do | ||
148 | r <- createVector 1 | ||
149 | app2 c_prodC vec x vec r "prodC" | ||
150 | return $ r @> 0 | ||
151 | |||
152 | foreign import ccall unsafe "prodF" c_prodF :: TFF | ||
153 | foreign import ccall unsafe "prodR" c_prodR :: TVV | ||
154 | foreign import ccall unsafe "prodQ" c_prodQ :: TQVQV | ||
155 | foreign import ccall unsafe "prodC" c_prodC :: TCVCV | ||
156 | |||
157 | ------------------------------------------------------------------ | ||
158 | |||
159 | toScalarAux fun code v = unsafePerformIO $ do | ||
160 | r <- createVector 1 | ||
161 | app2 (fun (fromei code)) vec v vec r "toScalarAux" | ||
162 | return (r `at` 0) | ||
163 | |||
164 | vectorMapAux fun code v = unsafePerformIO $ do | ||
165 | r <- createVector (dim v) | ||
166 | app2 (fun (fromei code)) vec v vec r "vectorMapAux" | ||
167 | return r | ||
168 | |||
169 | vectorMapValAux fun code val v = unsafePerformIO $ do | ||
170 | r <- createVector (dim v) | ||
171 | pval <- newArray [val] | ||
172 | app2 (fun (fromei code) pval) vec v vec r "vectorMapValAux" | ||
173 | free pval | ||
174 | return r | ||
175 | |||
176 | vectorZipAux fun code u v = unsafePerformIO $ do | ||
177 | r <- createVector (dim u) | ||
178 | app3 (fun (fromei code)) vec u vec v vec r "vectorZipAux" | ||
179 | return r | ||
180 | |||
181 | --------------------------------------------------------------------- | ||
182 | |||
183 | -- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. | ||
184 | toScalarR :: FunCodeS -> Vector Double -> Double | ||
185 | toScalarR oper = toScalarAux c_toScalarR (fromei oper) | ||
186 | |||
187 | foreign import ccall unsafe "toScalarR" c_toScalarR :: CInt -> TVV | ||
188 | |||
189 | -- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. | ||
190 | toScalarF :: FunCodeS -> Vector Float -> Float | ||
191 | toScalarF oper = toScalarAux c_toScalarF (fromei oper) | ||
192 | |||
193 | foreign import ccall unsafe "toScalarF" c_toScalarF :: CInt -> TFF | ||
194 | |||
195 | -- | obtains different functions of a vector: only norm1, norm2 | ||
196 | toScalarC :: FunCodeS -> Vector (Complex Double) -> Double | ||
197 | toScalarC oper = toScalarAux c_toScalarC (fromei oper) | ||
198 | |||
199 | foreign import ccall unsafe "toScalarC" c_toScalarC :: CInt -> TCVV | ||
200 | |||
201 | -- | obtains different functions of a vector: only norm1, norm2 | ||
202 | toScalarQ :: FunCodeS -> Vector (Complex Float) -> Float | ||
203 | toScalarQ oper = toScalarAux c_toScalarQ (fromei oper) | ||
204 | |||
205 | foreign import ccall unsafe "toScalarQ" c_toScalarQ :: CInt -> TQVF | ||
206 | |||
207 | ------------------------------------------------------------------ | ||
208 | |||
209 | -- | map of real vectors with given function | ||
210 | vectorMapR :: FunCodeV -> Vector Double -> Vector Double | ||
211 | vectorMapR = vectorMapAux c_vectorMapR | ||
212 | |||
213 | foreign import ccall unsafe "mapR" c_vectorMapR :: CInt -> TVV | ||
214 | |||
215 | -- | map of complex vectors with given function | ||
216 | vectorMapC :: FunCodeV -> Vector (Complex Double) -> Vector (Complex Double) | ||
217 | vectorMapC oper = vectorMapAux c_vectorMapC (fromei oper) | ||
218 | |||
219 | foreign import ccall unsafe "mapC" c_vectorMapC :: CInt -> TCVCV | ||
220 | |||
221 | -- | map of real vectors with given function | ||
222 | vectorMapF :: FunCodeV -> Vector Float -> Vector Float | ||
223 | vectorMapF = vectorMapAux c_vectorMapF | ||
224 | |||
225 | foreign import ccall unsafe "mapF" c_vectorMapF :: CInt -> TFF | ||
226 | |||
227 | -- | map of real vectors with given function | ||
228 | vectorMapQ :: FunCodeV -> Vector (Complex Float) -> Vector (Complex Float) | ||
229 | vectorMapQ = vectorMapAux c_vectorMapQ | ||
230 | |||
231 | foreign import ccall unsafe "mapQ" c_vectorMapQ :: CInt -> TQVQV | ||
232 | |||
233 | ------------------------------------------------------------------- | ||
234 | |||
235 | -- | map of real vectors with given function | ||
236 | vectorMapValR :: FunCodeSV -> Double -> Vector Double -> Vector Double | ||
237 | vectorMapValR oper = vectorMapValAux c_vectorMapValR (fromei oper) | ||
238 | |||
239 | foreign import ccall unsafe "mapValR" c_vectorMapValR :: CInt -> Ptr Double -> TVV | ||
240 | |||
241 | -- | map of complex vectors with given function | ||
242 | vectorMapValC :: FunCodeSV -> Complex Double -> Vector (Complex Double) -> Vector (Complex Double) | ||
243 | vectorMapValC = vectorMapValAux c_vectorMapValC | ||
244 | |||
245 | foreign import ccall unsafe "mapValC" c_vectorMapValC :: CInt -> Ptr (Complex Double) -> TCVCV | ||
246 | |||
247 | -- | map of real vectors with given function | ||
248 | vectorMapValF :: FunCodeSV -> Float -> Vector Float -> Vector Float | ||
249 | vectorMapValF oper = vectorMapValAux c_vectorMapValF (fromei oper) | ||
250 | |||
251 | foreign import ccall unsafe "mapValF" c_vectorMapValF :: CInt -> Ptr Float -> TFF | ||
252 | |||
253 | -- | map of complex vectors with given function | ||
254 | vectorMapValQ :: FunCodeSV -> Complex Float -> Vector (Complex Float) -> Vector (Complex Float) | ||
255 | vectorMapValQ oper = vectorMapValAux c_vectorMapValQ (fromei oper) | ||
256 | |||
257 | foreign import ccall unsafe "mapValQ" c_vectorMapValQ :: CInt -> Ptr (Complex Float) -> TQVQV | ||
258 | |||
259 | ------------------------------------------------------------------- | ||
260 | |||
261 | -- | elementwise operation on real vectors | ||
262 | vectorZipR :: FunCodeVV -> Vector Double -> Vector Double -> Vector Double | ||
263 | vectorZipR = vectorZipAux c_vectorZipR | ||
264 | |||
265 | foreign import ccall unsafe "zipR" c_vectorZipR :: CInt -> TVVV | ||
266 | |||
267 | -- | elementwise operation on complex vectors | ||
268 | vectorZipC :: FunCodeVV -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) | ||
269 | vectorZipC = vectorZipAux c_vectorZipC | ||
270 | |||
271 | foreign import ccall unsafe "zipC" c_vectorZipC :: CInt -> TCVCVCV | ||
272 | |||
273 | -- | elementwise operation on real vectors | ||
274 | vectorZipF :: FunCodeVV -> Vector Float -> Vector Float -> Vector Float | ||
275 | vectorZipF = vectorZipAux c_vectorZipF | ||
276 | |||
277 | foreign import ccall unsafe "zipF" c_vectorZipF :: CInt -> TFFF | ||
278 | |||
279 | -- | elementwise operation on complex vectors | ||
280 | vectorZipQ :: FunCodeVV -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float) | ||
281 | vectorZipQ = vectorZipAux c_vectorZipQ | ||
282 | |||
283 | foreign import ccall unsafe "zipQ" c_vectorZipQ :: CInt -> TQVQVQV | ||
284 | |||
285 | -------------------------------------------------------------------------------- | ||
286 | |||
287 | foreign import ccall unsafe "vectorScan" c_vectorScan | ||
288 | :: CString -> Ptr CInt -> Ptr (Ptr Double) -> IO CInt | ||
289 | |||
290 | vectorScan :: FilePath -> IO (Vector Double) | ||
291 | vectorScan s = do | ||
292 | pp <- malloc | ||
293 | pn <- malloc | ||
294 | cs <- newCString s | ||
295 | ok <- c_vectorScan cs pn pp | ||
296 | when (not (ok == 0)) $ | ||
297 | error ("vectorScan: file \"" ++ s ++"\" not found") | ||
298 | n <- fromIntegral <$> peek pn | ||
299 | p <- peek pp | ||
300 | v <- createVector n | ||
301 | free pn | ||
302 | free cs | ||
303 | unsafeWith v $ \pv -> copyArray pv p n | ||
304 | free p | ||
305 | free pp | ||
306 | return v | ||
307 | |||
308 | -------------------------------------------------------------------------------- | ||
309 | |||
310 | foreign import ccall unsafe "saveMatrix" c_saveMatrix | ||
311 | :: CString -> CString -> TM | ||
312 | |||
313 | {- | save a matrix as a 2D ASCII table | ||
314 | -} | ||
315 | saveMatrix | ||
316 | :: FilePath | ||
317 | -> String -- ^ \"printf\" format (e.g. \"%.2f\", \"%g\", etc.) | ||
318 | -> Matrix Double | ||
319 | -> IO () | ||
320 | saveMatrix name format m = do | ||
321 | cname <- newCString name | ||
322 | cformat <- newCString format | ||
323 | app1 (c_saveMatrix cname cformat) mat m "saveMatrix" | ||
324 | free cname | ||
325 | free cformat | ||
326 | return () | ||
327 | |||
328 | -------------------------------------------------------------------------------- | ||
329 | |||
330 | type Seed = Int | ||
331 | |||
332 | data RandDist = Uniform -- ^ uniform distribution in [0,1) | ||
333 | | Gaussian -- ^ normal distribution with mean zero and standard deviation one | ||
334 | deriving Enum | ||
335 | |||
336 | -- | Obtains a vector of pseudorandom elements (use randomIO to get a random seed). | ||
337 | randomVector :: Seed | ||
338 | -> RandDist -- ^ distribution | ||
339 | -> Int -- ^ vector size | ||
340 | -> Vector Double | ||
341 | randomVector seed dist n = unsafePerformIO $ do | ||
342 | r <- createVector n | ||
343 | app1 (c_random_vector (fi seed) ((fi.fromEnum) dist)) vec r "randomVector" | ||
344 | return r | ||
345 | |||
346 | foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> TV | ||
347 | |||
348 | -------------------------------------------------------------------------------- | ||
349 | |||
350 | sortVector v = unsafePerformIO $ do | ||
351 | r <- createVector (dim v) | ||
352 | app2 c_sort_values vec v vec r "sortVector" | ||
353 | return r | ||
354 | |||
355 | foreign import ccall unsafe "sort_values" c_sort_values :: TVV | ||
356 | |||
357 | -------------------------------------------------------------------------------- | ||
358 | |||
359 | roundVector v = unsafePerformIO $ do | ||
360 | r <- createVector (dim v) | ||
361 | app2 c_round_vector vec v vec r "roundVector" | ||
362 | return r | ||
363 | |||
364 | foreign import ccall unsafe "round_vector" c_round_vector :: TVV | ||
365 | |||
diff --git a/packages/base/stack.yaml b/packages/base/stack.yaml new file mode 100644 index 0000000..f4001c6 --- /dev/null +++ b/packages/base/stack.yaml | |||
@@ -0,0 +1,7 @@ | |||
1 | flags: | ||
2 | hmatrix: | ||
3 | openblas: false | ||
4 | packages: | ||
5 | - '.' | ||
6 | extra-deps: [] | ||
7 | resolver: lts-3.3 | ||
diff --git a/packages/glpk/hmatrix-glpk.cabal b/packages/glpk/hmatrix-glpk.cabal index 5a1b59c..8593e0a 100644 --- a/packages/glpk/hmatrix-glpk.cabal +++ b/packages/glpk/hmatrix-glpk.cabal | |||
@@ -1,5 +1,5 @@ | |||
1 | Name: hmatrix-glpk | 1 | Name: hmatrix-glpk |
2 | Version: 0.4.1.0 | 2 | Version: 0.5.0.0 |
3 | License: GPL | 3 | License: GPL |
4 | License-file: LICENSE | 4 | License-file: LICENSE |
5 | Author: Alberto Ruiz | 5 | Author: Alberto Ruiz |
@@ -23,7 +23,7 @@ extra-source-files: examples/simplex1.hs | |||
23 | examples/simplex5.hs | 23 | examples/simplex5.hs |
24 | 24 | ||
25 | library | 25 | library |
26 | Build-Depends: base <5, hmatrix >= 0.16, containers >= 0.5.4.0 | 26 | Build-Depends: base <5, hmatrix >= 0.17, containers |
27 | 27 | ||
28 | hs-source-dirs: src | 28 | hs-source-dirs: src |
29 | 29 | ||
diff --git a/packages/glpk/src/Numeric/LinearProgramming.hs b/packages/glpk/src/Numeric/LinearProgramming.hs index 6a0c47d..0a776fa 100644 --- a/packages/glpk/src/Numeric/LinearProgramming.hs +++ b/packages/glpk/src/Numeric/LinearProgramming.hs | |||
@@ -85,8 +85,8 @@ module Numeric.LinearProgramming( | |||
85 | Solution(..) | 85 | Solution(..) |
86 | ) where | 86 | ) where |
87 | 87 | ||
88 | import Data.Packed | 88 | import Numeric.LinearAlgebra.HMatrix |
89 | import Data.Packed.Development | 89 | import Numeric.LinearAlgebra.Devel hiding (Dense) |
90 | import Foreign(Ptr) | 90 | import Foreign(Ptr) |
91 | import System.IO.Unsafe(unsafePerformIO) | 91 | import System.IO.Unsafe(unsafePerformIO) |
92 | import Foreign.C.Types | 92 | import Foreign.C.Types |
@@ -180,16 +180,17 @@ exact opt constr@(General _) bnds = exact opt (sparseOfGeneral constr) bnds | |||
180 | 180 | ||
181 | adapt :: Optimization -> (Int, Double, [Double]) | 181 | adapt :: Optimization -> (Int, Double, [Double]) |
182 | adapt opt = case opt of | 182 | adapt opt = case opt of |
183 | Maximize x -> (size x, 1 ,x) | 183 | Maximize x -> (sz x, 1 ,x) |
184 | Minimize x -> (size x, -1, (map negate x)) | 184 | Minimize x -> (sz x, -1, (map negate x)) |
185 | where size x | null x = error "simplex: objective function with zero variables" | 185 | where |
186 | | otherwise = length x | 186 | sz x | null x = error "simplex: objective function with zero variables" |
187 | | otherwise = length x | ||
187 | 188 | ||
188 | extract :: Double -> Vector Double -> Solution | 189 | extract :: Double -> Vector Double -> Solution |
189 | extract sg sol = r where | 190 | extract sg sol = r where |
190 | z = sg * (sol@>1) | 191 | z = sg * (sol!1) |
191 | v = toList $ subVector 2 (dim sol -2) sol | 192 | v = toList $ subVector 2 (size sol -2) sol |
192 | r = case round(sol@>0)::Int of | 193 | r = case round(sol!0)::Int of |
193 | 1 -> Undefined | 194 | 1 -> Undefined |
194 | 2 -> Feasible (z,v) | 195 | 2 -> Feasible (z,v) |
195 | 3 -> Infeasible (z,v) | 196 | 3 -> Infeasible (z,v) |
@@ -261,7 +262,7 @@ mkConstrD n f b1 | ok = fromLists (ob ++ co) | |||
261 | ok = all (==n) ls | 262 | ok = all (==n) ls |
262 | den = fromLists cs | 263 | den = fromLists cs |
263 | ob = map (([0,0]++).return) f | 264 | ob = map (([0,0]++).return) f |
264 | co = [[fromIntegral i, fromIntegral j,den@@>(i-1,j-1)]| i<-[1 ..rows den], j<-[1 .. cols den]] | 265 | co = [[fromIntegral i, fromIntegral j,den `atIndex` (i-1,j-1)]| i<-[1 ..rows den], j<-[1 .. cols den]] |
265 | 266 | ||
266 | mkConstrS :: Int -> [Double] -> [Bound [(Double, Int)]] -> Matrix Double | 267 | mkConstrS :: Int -> [Double] -> [Bound [(Double, Int)]] -> Matrix Double |
267 | mkConstrS n objfun b1 = fromLists (ob ++ co) where | 268 | mkConstrS n objfun b1 = fromLists (ob ++ co) where |
@@ -274,6 +275,11 @@ mkConstrS n objfun b1 = fromLists (ob ++ co) where | |||
274 | 275 | ||
275 | ----------------------------------------------------- | 276 | ----------------------------------------------------- |
276 | 277 | ||
278 | (##) :: TransArray c => TransRaw c b -> c -> b | ||
279 | infixl 1 ## | ||
280 | a ## b = applyRaw a b | ||
281 | {-# INLINE (##) #-} | ||
282 | |||
277 | foreign import ccall unsafe "c_simplex_sparse" c_simplex_sparse | 283 | foreign import ccall unsafe "c_simplex_sparse" c_simplex_sparse |
278 | :: CInt -> CInt -- rows and cols | 284 | :: CInt -> CInt -- rows and cols |
279 | -> CInt -> CInt -> Ptr Double -- coeffs | 285 | -> CInt -> CInt -> Ptr Double -- coeffs |
@@ -284,7 +290,7 @@ foreign import ccall unsafe "c_simplex_sparse" c_simplex_sparse | |||
284 | simplexSparse :: Int -> Int -> Matrix Double -> Matrix Double -> Vector Double | 290 | simplexSparse :: Int -> Int -> Matrix Double -> Matrix Double -> Vector Double |
285 | simplexSparse m n c b = unsafePerformIO $ do | 291 | simplexSparse m n c b = unsafePerformIO $ do |
286 | s <- createVector (2+n) | 292 | s <- createVector (2+n) |
287 | app3 (c_simplex_sparse (fi m) (fi n)) mat (cmat c) mat (cmat b) vec s "c_simplex_sparse" | 293 | c_simplex_sparse (fi m) (fi n) ## (cmat c) ## (cmat b) ## s #|"c_simplex_sparse" |
288 | return s | 294 | return s |
289 | 295 | ||
290 | foreign import ccall unsafe "c_exact_sparse" c_exact_sparse | 296 | foreign import ccall unsafe "c_exact_sparse" c_exact_sparse |
@@ -297,7 +303,7 @@ foreign import ccall unsafe "c_exact_sparse" c_exact_sparse | |||
297 | exactSparse :: Int -> Int -> Matrix Double -> Matrix Double -> Vector Double | 303 | exactSparse :: Int -> Int -> Matrix Double -> Matrix Double -> Vector Double |
298 | exactSparse m n c b = unsafePerformIO $ do | 304 | exactSparse m n c b = unsafePerformIO $ do |
299 | s <- createVector (2+n) | 305 | s <- createVector (2+n) |
300 | app3 (c_exact_sparse (fi m) (fi n)) mat (cmat c) mat (cmat b) vec s "c_exact_sparse" | 306 | c_exact_sparse (fi m) (fi n) ## (cmat c) ## (cmat b) ## s #|"c_exact_sparse" |
301 | return s | 307 | return s |
302 | 308 | ||
303 | glpFR, glpLO, glpUP, glpDB, glpFX :: Double | 309 | glpFR, glpLO, glpUP, glpDB, glpFX :: Double |
diff --git a/packages/glpk/src/Numeric/LinearProgramming/L1.hs b/packages/glpk/src/Numeric/LinearProgramming/L1.hs index f55c721..d7f1258 100644 --- a/packages/glpk/src/Numeric/LinearProgramming/L1.hs +++ b/packages/glpk/src/Numeric/LinearProgramming/L1.hs | |||
@@ -14,7 +14,7 @@ module Numeric.LinearProgramming.L1 ( | |||
14 | l1SolveU, | 14 | l1SolveU, |
15 | ) where | 15 | ) where |
16 | 16 | ||
17 | import Numeric.LinearAlgebra | 17 | import Numeric.LinearAlgebra.HMatrix |
18 | import Numeric.LinearProgramming | 18 | import Numeric.LinearProgramming |
19 | 19 | ||
20 | -- | L_inf solution of overconstrained system Ax=b. | 20 | -- | L_inf solution of overconstrained system Ax=b. |
diff --git a/packages/gsl/CHANGELOG b/packages/gsl/CHANGELOG new file mode 100644 index 0000000..091dc0e --- /dev/null +++ b/packages/gsl/CHANGELOG | |||
@@ -0,0 +1,14 @@ | |||
1 | 0.17.0.0 | ||
2 | -------- | ||
3 | |||
4 | * Added interpolation modules | ||
5 | |||
6 | * Added simulated annealing module | ||
7 | |||
8 | * Added odeSolveVWith | ||
9 | |||
10 | 0.16.0.0 | ||
11 | -------- | ||
12 | |||
13 | * The modules Numeric.GSL.* have been moved from hmatrix to the new package hmatrix-gsl. | ||
14 | |||
diff --git a/packages/gsl/THANKS.md b/packages/gsl/THANKS.md new file mode 100644 index 0000000..9cb2584 --- /dev/null +++ b/packages/gsl/THANKS.md | |||
@@ -0,0 +1,3 @@ | |||
1 | |||
2 | See the THANKS file of the hmatrix package. | ||
3 | |||
diff --git a/packages/gsl/hmatrix-gsl.cabal b/packages/gsl/hmatrix-gsl.cabal index 2f6f51b..f288c64 100644 --- a/packages/gsl/hmatrix-gsl.cabal +++ b/packages/gsl/hmatrix-gsl.cabal | |||
@@ -1,5 +1,5 @@ | |||
1 | Name: hmatrix-gsl | 1 | Name: hmatrix-gsl |
2 | Version: 0.16.0.3 | 2 | Version: 0.17.0.0 |
3 | License: GPL | 3 | License: GPL |
4 | License-file: LICENSE | 4 | License-file: LICENSE |
5 | Author: Alberto Ruiz | 5 | Author: Alberto Ruiz |
@@ -25,7 +25,7 @@ flag onlygsl | |||
25 | 25 | ||
26 | library | 26 | library |
27 | 27 | ||
28 | Build-Depends: base<5, hmatrix>=0.16, array, vector, | 28 | Build-Depends: base<5, hmatrix>=0.17, array, vector, |
29 | process, random | 29 | process, random |
30 | 30 | ||
31 | 31 | ||
@@ -44,6 +44,7 @@ library | |||
44 | Numeric.GSL, | 44 | Numeric.GSL, |
45 | Numeric.GSL.LinearAlgebra, | 45 | Numeric.GSL.LinearAlgebra, |
46 | Numeric.GSL.Interpolation, | 46 | Numeric.GSL.Interpolation, |
47 | Numeric.GSL.SimulatedAnnealing, | ||
47 | Graphics.Plot | 48 | Graphics.Plot |
48 | other-modules: Numeric.GSL.Internal, | 49 | other-modules: Numeric.GSL.Internal, |
49 | Numeric.GSL.Vector, | 50 | Numeric.GSL.Vector, |
@@ -53,7 +54,12 @@ library | |||
53 | 54 | ||
54 | C-sources: src/Numeric/GSL/gsl-aux.c | 55 | C-sources: src/Numeric/GSL/gsl-aux.c |
55 | 56 | ||
56 | cc-options: -O4 -msse2 -Wall | 57 | cc-options: -O4 -Wall |
58 | |||
59 | if arch(x86_64) | ||
60 | cc-options: -msse2 | ||
61 | if arch(i386) | ||
62 | cc-options: -msse2 | ||
57 | 63 | ||
58 | ghc-options: -Wall -fno-warn-missing-signatures | 64 | ghc-options: -Wall -fno-warn-missing-signatures |
59 | -fno-warn-orphans | 65 | -fno-warn-orphans |
diff --git a/packages/gsl/src/Graphics/Plot.hs b/packages/gsl/src/Graphics/Plot.hs index 0ea41ac..d2ea192 100644 --- a/packages/gsl/src/Graphics/Plot.hs +++ b/packages/gsl/src/Graphics/Plot.hs | |||
@@ -27,13 +27,13 @@ module Graphics.Plot( | |||
27 | 27 | ||
28 | ) where | 28 | ) where |
29 | 29 | ||
30 | import Numeric.Container | 30 | import Numeric.LinearAlgebra.HMatrix |
31 | import Data.List(intersperse) | 31 | import Data.List(intersperse) |
32 | import System.Process (system) | 32 | import System.Process (system) |
33 | 33 | ||
34 | -- | From vectors x and y, it generates a pair of matrices to be used as x and y arguments for matrix functions. | 34 | -- | From vectors x and y, it generates a pair of matrices to be used as x and y arguments for matrix functions. |
35 | meshdom :: Vector Double -> Vector Double -> (Matrix Double , Matrix Double) | 35 | meshdom :: Vector Double -> Vector Double -> (Matrix Double , Matrix Double) |
36 | meshdom r1 r2 = (outer r1 (constant 1 (dim r2)), outer (constant 1 (dim r1)) r2) | 36 | meshdom r1 r2 = (outer r1 (konst 1 (size r2)), outer (konst 1 (size r1)) r2) |
37 | 37 | ||
38 | 38 | ||
39 | {- | Draws a 3D surface representation of a real matrix. | 39 | {- | Draws a 3D surface representation of a real matrix. |
diff --git a/packages/gsl/src/Numeric/GSL/Fitting.hs b/packages/gsl/src/Numeric/GSL/Fitting.hs index 0a92373..8eb93a7 100644 --- a/packages/gsl/src/Numeric/GSL/Fitting.hs +++ b/packages/gsl/src/Numeric/GSL/Fitting.hs | |||
@@ -1,3 +1,5 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | ||
2 | |||
1 | {- | | 3 | {- | |
2 | Module : Numeric.GSL.Fitting | 4 | Module : Numeric.GSL.Fitting |
3 | Copyright : (c) Alberto Ruiz 2010 | 5 | Copyright : (c) Alberto Ruiz 2010 |
@@ -50,7 +52,7 @@ module Numeric.GSL.Fitting ( | |||
50 | fitModelScaled, fitModel | 52 | fitModelScaled, fitModel |
51 | ) where | 53 | ) where |
52 | 54 | ||
53 | import Numeric.LinearAlgebra | 55 | import Numeric.LinearAlgebra.HMatrix |
54 | import Numeric.GSL.Internal | 56 | import Numeric.GSL.Internal |
55 | 57 | ||
56 | import Foreign.Ptr(FunPtr, freeHaskellFunPtr) | 58 | import Foreign.Ptr(FunPtr, freeHaskellFunPtr) |
@@ -80,13 +82,13 @@ nlFitting :: FittingMethod | |||
80 | nlFitting method epsabs epsrel maxit fun jac xinit = nlFitGen (fi (fromEnum method)) fun jac xinit epsabs epsrel maxit | 82 | nlFitting method epsabs epsrel maxit fun jac xinit = nlFitGen (fi (fromEnum method)) fun jac xinit epsabs epsrel maxit |
81 | 83 | ||
82 | nlFitGen m f jac xiv epsabs epsrel maxit = unsafePerformIO $ do | 84 | nlFitGen m f jac xiv epsabs epsrel maxit = unsafePerformIO $ do |
83 | let p = dim xiv | 85 | let p = size xiv |
84 | n = dim (f xiv) | 86 | n = size (f xiv) |
85 | fp <- mkVecVecfun (aux_vTov (checkdim1 n p . f)) | 87 | fp <- mkVecVecfun (aux_vTov (checkdim1 n p . f)) |
86 | jp <- mkVecMatfun (aux_vTom (checkdim2 n p . jac)) | 88 | jp <- mkVecMatfun (aux_vTom (checkdim2 n p . jac)) |
87 | rawpath <- createMatrix RowMajor maxit (2+p) | 89 | rawpath <- createMatrix RowMajor maxit (2+p) |
88 | app2 (c_nlfit m fp jp epsabs epsrel (fi maxit) (fi n)) vec xiv mat rawpath "c_nlfit" | 90 | c_nlfit m fp jp epsabs epsrel (fi maxit) (fi n) # xiv # rawpath #|"c_nlfit" |
89 | let it = round (rawpath @@> (maxit-1,0)) | 91 | let it = round (rawpath `atIndex` (maxit-1,0)) |
90 | path = takeRows it rawpath | 92 | path = takeRows it rawpath |
91 | [sol] = toRows $ dropRows (it-1) path | 93 | [sol] = toRows $ dropRows (it-1) path |
92 | freeHaskellFunPtr fp | 94 | freeHaskellFunPtr fp |
@@ -99,7 +101,7 @@ foreign import ccall safe "nlfit" | |||
99 | ------------------------------------------------------- | 101 | ------------------------------------------------------- |
100 | 102 | ||
101 | checkdim1 n _p v | 103 | checkdim1 n _p v |
102 | | dim v == n = v | 104 | | size v == n = v |
103 | | otherwise = error $ "Error: "++ show n | 105 | | otherwise = error $ "Error: "++ show n |
104 | ++ " components expected in the result of the function supplied to nlFitting" | 106 | ++ " components expected in the result of the function supplied to nlFitting" |
105 | 107 | ||
@@ -114,9 +116,9 @@ err (model,deriv) dat vsol = zip sol errs where | |||
114 | sol = toList vsol | 116 | sol = toList vsol |
115 | c = max 1 (chi/sqrt (fromIntegral dof)) | 117 | c = max 1 (chi/sqrt (fromIntegral dof)) |
116 | dof = length dat - (rows cov) | 118 | dof = length dat - (rows cov) |
117 | chi = norm2 (fromList $ cost (resMs model) dat sol) | 119 | chi = norm_2 (fromList $ cost (resMs model) dat sol) |
118 | js = fromLists $ jacobian (resDs deriv) dat sol | 120 | js = fromLists $ jacobian (resDs deriv) dat sol |
119 | cov = inv $ trans js <> js | 121 | cov = inv $ tr js <> js |
120 | errs = toList $ scalar c * sqrt (takeDiag cov) | 122 | errs = toList $ scalar c * sqrt (takeDiag cov) |
121 | 123 | ||
122 | 124 | ||
diff --git a/packages/gsl/src/Numeric/GSL/Fourier.hs b/packages/gsl/src/Numeric/GSL/Fourier.hs index 734325b..1c2c053 100644 --- a/packages/gsl/src/Numeric/GSL/Fourier.hs +++ b/packages/gsl/src/Numeric/GSL/Fourier.hs | |||
@@ -1,3 +1,5 @@ | |||
1 | {-# LANGUAGE TypeFamilies #-} | ||
2 | |||
1 | {- | | 3 | {- | |
2 | Module : Numeric.GSL.Fourier | 4 | Module : Numeric.GSL.Fourier |
3 | Copyright : (c) Alberto Ruiz 2006 | 5 | Copyright : (c) Alberto Ruiz 2006 |
@@ -16,15 +18,14 @@ module Numeric.GSL.Fourier ( | |||
16 | ifft | 18 | ifft |
17 | ) where | 19 | ) where |
18 | 20 | ||
19 | import Data.Packed | 21 | import Numeric.LinearAlgebra.HMatrix |
20 | import Numeric.GSL.Internal | 22 | import Numeric.GSL.Internal |
21 | import Data.Complex | ||
22 | import Foreign.C.Types | 23 | import Foreign.C.Types |
23 | import System.IO.Unsafe (unsafePerformIO) | 24 | import System.IO.Unsafe (unsafePerformIO) |
24 | 25 | ||
25 | genfft code v = unsafePerformIO $ do | 26 | genfft code v = unsafePerformIO $ do |
26 | r <- createVector (dim v) | 27 | r <- createVector (size v) |
27 | app2 (c_fft code) vec v vec r "fft" | 28 | c_fft code # v # r #|"fft" |
28 | return r | 29 | return r |
29 | 30 | ||
30 | foreign import ccall unsafe "gsl-aux.h fft" c_fft :: CInt -> TCV (TCV Res) | 31 | foreign import ccall unsafe "gsl-aux.h fft" c_fft :: CInt -> TCV (TCV Res) |
@@ -42,3 +43,4 @@ fft = genfft 0 | |||
42 | -- | The inverse of 'fft', using /gsl_fft_complex_inverse/. | 43 | -- | The inverse of 'fft', using /gsl_fft_complex_inverse/. |
43 | ifft :: Vector (Complex Double) -> Vector (Complex Double) | 44 | ifft :: Vector (Complex Double) -> Vector (Complex Double) |
44 | ifft = genfft 1 | 45 | ifft = genfft 1 |
46 | |||
diff --git a/packages/gsl/src/Numeric/GSL/IO.hs b/packages/gsl/src/Numeric/GSL/IO.hs index 0d6031a..936f6bf 100644 --- a/packages/gsl/src/Numeric/GSL/IO.hs +++ b/packages/gsl/src/Numeric/GSL/IO.hs | |||
@@ -14,7 +14,7 @@ module Numeric.GSL.IO ( | |||
14 | fileDimensions, loadMatrix, fromFile | 14 | fileDimensions, loadMatrix, fromFile |
15 | ) where | 15 | ) where |
16 | 16 | ||
17 | import Data.Packed | 17 | import Numeric.LinearAlgebra.HMatrix hiding(saveMatrix, loadMatrix) |
18 | import Numeric.GSL.Vector | 18 | import Numeric.GSL.Vector |
19 | import System.Process(readProcess) | 19 | import System.Process(readProcess) |
20 | 20 | ||
diff --git a/packages/gsl/src/Numeric/GSL/Internal.hs b/packages/gsl/src/Numeric/GSL/Internal.hs index a1c4e0c..dcd3bc4 100644 --- a/packages/gsl/src/Numeric/GSL/Internal.hs +++ b/packages/gsl/src/Numeric/GSL/Internal.hs | |||
@@ -22,21 +22,20 @@ module Numeric.GSL.Internal( | |||
22 | aux_vTom, | 22 | aux_vTom, |
23 | createV, | 23 | createV, |
24 | createMIO, | 24 | createMIO, |
25 | module Data.Packed.Development, | 25 | module Numeric.LinearAlgebra.Devel, |
26 | check, | 26 | check,(#),vec, ww2, |
27 | Res,TV,TM,TCV,TCM | 27 | Res,TV,TM,TCV,TCM |
28 | ) where | 28 | ) where |
29 | 29 | ||
30 | import Data.Packed | 30 | import Numeric.LinearAlgebra.HMatrix |
31 | import Data.Packed.Development hiding (check) | 31 | import Numeric.LinearAlgebra.Devel hiding (check) |
32 | import Data.Complex | ||
33 | 32 | ||
34 | import Foreign.Marshal.Array(copyArray) | 33 | import Foreign.Marshal.Array(copyArray) |
35 | import Foreign.Ptr(Ptr, FunPtr) | 34 | import Foreign.Ptr(Ptr, FunPtr) |
36 | import Foreign.C.Types | 35 | import Foreign.C.Types |
37 | import Foreign.C.String(peekCString) | 36 | import Foreign.C.String(peekCString) |
38 | import System.IO.Unsafe(unsafePerformIO) | 37 | import System.IO.Unsafe(unsafePerformIO) |
39 | import Data.Vector.Storable(unsafeWith) | 38 | import Data.Vector.Storable as V (unsafeWith,length) |
40 | import Control.Monad(when) | 39 | import Control.Monad(when) |
41 | 40 | ||
42 | iv :: (Vector Double -> Double) -> (CInt -> Ptr Double -> Double) | 41 | iv :: (Vector Double -> Double) -> (CInt -> Ptr Double -> Double) |
@@ -87,12 +86,12 @@ aux_vTom f n p rr cr r = g where | |||
87 | 86 | ||
88 | createV n fun msg = unsafePerformIO $ do | 87 | createV n fun msg = unsafePerformIO $ do |
89 | r <- createVector n | 88 | r <- createVector n |
90 | app1 fun vec r msg | 89 | fun # r #| msg |
91 | return r | 90 | return r |
92 | 91 | ||
93 | createMIO r c fun msg = do | 92 | createMIO r c fun msg = do |
94 | res <- createMatrix RowMajor r c | 93 | res <- createMatrix RowMajor r c |
95 | app1 fun mat res msg | 94 | fun # res #| msg |
96 | return res | 95 | return res |
97 | 96 | ||
98 | -------------------------------------------------------------------------------- | 97 | -------------------------------------------------------------------------------- |
@@ -124,3 +123,15 @@ type TCM x = CInt -> CInt -> PC -> x | |||
124 | type TVV = TV (TV Res) | 123 | type TVV = TV (TV Res) |
125 | type TVM = TV (TM Res) | 124 | type TVM = TV (TM Res) |
126 | 125 | ||
126 | ww2 w1 o1 w2 o2 f = w1 o1 $ \a1 -> w2 o2 $ \a2 -> f a1 a2 | ||
127 | |||
128 | vec x f = unsafeWith x $ \p -> do | ||
129 | let v g = do | ||
130 | g (fi $ V.length x) p | ||
131 | f v | ||
132 | {-# INLINE vec #-} | ||
133 | |||
134 | infixl 1 # | ||
135 | a # b = applyRaw a b | ||
136 | {-# INLINE (#) #-} | ||
137 | |||
diff --git a/packages/gsl/src/Numeric/GSL/Interpolation.hs b/packages/gsl/src/Numeric/GSL/Interpolation.hs index 4d72ee2..d060468 100644 --- a/packages/gsl/src/Numeric/GSL/Interpolation.hs +++ b/packages/gsl/src/Numeric/GSL/Interpolation.hs | |||
@@ -32,8 +32,7 @@ module Numeric.GSL.Interpolation ( | |||
32 | , evaluateIntegralV | 32 | , evaluateIntegralV |
33 | ) where | 33 | ) where |
34 | 34 | ||
35 | import Data.Packed.Vector(Vector, fromList, dim) | 35 | import Numeric.LinearAlgebra(Vector, fromList, size, Numeric) |
36 | import Data.Packed.Foreign(appVector) | ||
37 | import Foreign.C.Types | 36 | import Foreign.C.Types |
38 | import Foreign.Marshal.Alloc(alloca) | 37 | import Foreign.Marshal.Alloc(alloca) |
39 | import Foreign.Ptr(Ptr) | 38 | import Foreign.Ptr(Ptr) |
@@ -57,6 +56,9 @@ methodToInt CSplinePeriodic = 3 | |||
57 | methodToInt Akima = 4 | 56 | methodToInt Akima = 4 |
58 | methodToInt AkimaPeriodic = 5 | 57 | methodToInt AkimaPeriodic = 5 |
59 | 58 | ||
59 | dim :: Numeric t => Vector t -> Int | ||
60 | dim = size | ||
61 | |||
60 | applyCFun hsname cname fun mth xs ys x | 62 | applyCFun hsname cname fun mth xs ys x |
61 | | dim xs /= dim ys = error $ | 63 | | dim xs /= dim ys = error $ |
62 | "Error: Vectors of unequal sizes " ++ | 64 | "Error: Vectors of unequal sizes " ++ |
@@ -115,7 +117,7 @@ evaluate :: InterpolationMethod -- ^ What method to use to interpolate | |||
115 | -> Double -- ^ Point at which to evaluate the function | 117 | -> Double -- ^ Point at which to evaluate the function |
116 | -> Double -- ^ Interpolated result | 118 | -> Double -- ^ Interpolated result |
117 | evaluate mth pts = | 119 | evaluate mth pts = |
118 | applyCFun "evaluate" "spline_eval" c_spline_eval_deriv | 120 | applyCFun "evaluate" "spline_eval" c_spline_eval |
119 | mth (fromList xs) (fromList ys) | 121 | mth (fromList xs) (fromList ys) |
120 | where | 122 | where |
121 | (xs, ys) = unzip pts | 123 | (xs, ys) = unzip pts |
diff --git a/packages/gsl/src/Numeric/GSL/LinearAlgebra.hs b/packages/gsl/src/Numeric/GSL/LinearAlgebra.hs index 17e2258..6ffe306 100644 --- a/packages/gsl/src/Numeric/GSL/LinearAlgebra.hs +++ b/packages/gsl/src/Numeric/GSL/LinearAlgebra.hs | |||
@@ -15,7 +15,7 @@ module Numeric.GSL.LinearAlgebra ( | |||
15 | fileDimensions, loadMatrix, fromFile | 15 | fileDimensions, loadMatrix, fromFile |
16 | ) where | 16 | ) where |
17 | 17 | ||
18 | import Data.Packed | 18 | import Numeric.LinearAlgebra.HMatrix hiding (RandDist,randomVector,saveMatrix,loadMatrix) |
19 | import Numeric.GSL.Internal hiding (TV,TM,TCV,TCM) | 19 | import Numeric.GSL.Internal hiding (TV,TM,TCV,TCM) |
20 | 20 | ||
21 | import Foreign.Marshal.Alloc(free) | 21 | import Foreign.Marshal.Alloc(free) |
@@ -40,7 +40,7 @@ randomVector :: Int -- ^ seed | |||
40 | -> Vector Double | 40 | -> Vector Double |
41 | randomVector seed dist n = unsafePerformIO $ do | 41 | randomVector seed dist n = unsafePerformIO $ do |
42 | r <- createVector n | 42 | r <- createVector n |
43 | app1 (c_random_vector (fi seed) ((fi.fromEnum) dist)) vec r "randomVector" | 43 | c_random_vector (fi seed) ((fi.fromEnum) dist) # r #|"randomVector" |
44 | return r | 44 | return r |
45 | 45 | ||
46 | foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> TV | 46 | foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> TV |
@@ -56,7 +56,7 @@ saveMatrix filename fmt m = do | |||
56 | charname <- newCString filename | 56 | charname <- newCString filename |
57 | charfmt <- newCString fmt | 57 | charfmt <- newCString fmt |
58 | let o = if orderOf m == RowMajor then 1 else 0 | 58 | let o = if orderOf m == RowMajor then 1 else 0 |
59 | app1 (matrix_fprintf charname charfmt o) mat m "matrix_fprintf" | 59 | matrix_fprintf charname charfmt o # m #|"matrix_fprintf" |
60 | free charname | 60 | free charname |
61 | free charfmt | 61 | free charfmt |
62 | 62 | ||
@@ -69,7 +69,7 @@ fscanfVector :: FilePath -> Int -> IO (Vector Double) | |||
69 | fscanfVector filename n = do | 69 | fscanfVector filename n = do |
70 | charname <- newCString filename | 70 | charname <- newCString filename |
71 | res <- createVector n | 71 | res <- createVector n |
72 | app1 (gsl_vector_fscanf charname) vec res "gsl_vector_fscanf" | 72 | gsl_vector_fscanf charname # res #|"gsl_vector_fscanf" |
73 | free charname | 73 | free charname |
74 | return res | 74 | return res |
75 | 75 | ||
@@ -80,7 +80,7 @@ fprintfVector :: FilePath -> String -> Vector Double -> IO () | |||
80 | fprintfVector filename fmt v = do | 80 | fprintfVector filename fmt v = do |
81 | charname <- newCString filename | 81 | charname <- newCString filename |
82 | charfmt <- newCString fmt | 82 | charfmt <- newCString fmt |
83 | app1 (gsl_vector_fprintf charname charfmt) vec v "gsl_vector_fprintf" | 83 | gsl_vector_fprintf charname charfmt # v #|"gsl_vector_fprintf" |
84 | free charname | 84 | free charname |
85 | free charfmt | 85 | free charfmt |
86 | 86 | ||
@@ -91,7 +91,7 @@ freadVector :: FilePath -> Int -> IO (Vector Double) | |||
91 | freadVector filename n = do | 91 | freadVector filename n = do |
92 | charname <- newCString filename | 92 | charname <- newCString filename |
93 | res <- createVector n | 93 | res <- createVector n |
94 | app1 (gsl_vector_fread charname) vec res "gsl_vector_fread" | 94 | gsl_vector_fread charname # res #| "gsl_vector_fread" |
95 | free charname | 95 | free charname |
96 | return res | 96 | return res |
97 | 97 | ||
@@ -101,7 +101,7 @@ foreign import ccall unsafe "vector_fread" gsl_vector_fread:: Ptr CChar -> TV | |||
101 | fwriteVector :: FilePath -> Vector Double -> IO () | 101 | fwriteVector :: FilePath -> Vector Double -> IO () |
102 | fwriteVector filename v = do | 102 | fwriteVector filename v = do |
103 | charname <- newCString filename | 103 | charname <- newCString filename |
104 | app1 (gsl_vector_fwrite charname) vec v "gsl_vector_fwrite" | 104 | gsl_vector_fwrite charname # v #|"gsl_vector_fwrite" |
105 | free charname | 105 | free charname |
106 | 106 | ||
107 | foreign import ccall unsafe "vector_fwrite" gsl_vector_fwrite :: Ptr CChar -> TV | 107 | foreign import ccall unsafe "vector_fwrite" gsl_vector_fwrite :: Ptr CChar -> TV |
diff --git a/packages/gsl/src/Numeric/GSL/Minimization.hs b/packages/gsl/src/Numeric/GSL/Minimization.hs index 056d463..a0e5306 100644 --- a/packages/gsl/src/Numeric/GSL/Minimization.hs +++ b/packages/gsl/src/Numeric/GSL/Minimization.hs | |||
@@ -1,3 +1,6 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | ||
2 | |||
3 | |||
1 | {- | | 4 | {- | |
2 | Module : Numeric.GSL.Minimization | 5 | Module : Numeric.GSL.Minimization |
3 | Copyright : (c) Alberto Ruiz 2006-9 | 6 | Copyright : (c) Alberto Ruiz 2006-9 |
@@ -56,7 +59,7 @@ module Numeric.GSL.Minimization ( | |||
56 | ) where | 59 | ) where |
57 | 60 | ||
58 | 61 | ||
59 | import Data.Packed | 62 | import Numeric.LinearAlgebra.HMatrix hiding(step) |
60 | import Numeric.GSL.Internal | 63 | import Numeric.GSL.Internal |
61 | 64 | ||
62 | import Foreign.Ptr(Ptr, FunPtr, freeHaskellFunPtr) | 65 | import Foreign.Ptr(Ptr, FunPtr, freeHaskellFunPtr) |
@@ -99,7 +102,7 @@ uniMinimizeGen m f xmin xl xu epsrel maxit = unsafePerformIO $ do | |||
99 | rawpath <- createMIO maxit 4 | 102 | rawpath <- createMIO maxit 4 |
100 | (c_uniMinize m fp epsrel (fi maxit) xmin xl xu) | 103 | (c_uniMinize m fp epsrel (fi maxit) xmin xl xu) |
101 | "uniMinimize" | 104 | "uniMinimize" |
102 | let it = round (rawpath @@> (maxit-1,0)) | 105 | let it = round (rawpath `atIndex` (maxit-1,0)) |
103 | path = takeRows it rawpath | 106 | path = takeRows it rawpath |
104 | [sol] = toLists $ dropRows (it-1) path | 107 | [sol] = toLists $ dropRows (it-1) path |
105 | freeHaskellFunPtr fp | 108 | freeHaskellFunPtr fp |
@@ -134,16 +137,16 @@ minimizeV :: MinimizeMethod | |||
134 | minimize method eps maxit sz f xi = v2l $ minimizeV method eps maxit (fromList sz) (f.toList) (fromList xi) | 137 | minimize method eps maxit sz f xi = v2l $ minimizeV method eps maxit (fromList sz) (f.toList) (fromList xi) |
135 | where v2l (v,m) = (toList v, m) | 138 | where v2l (v,m) = (toList v, m) |
136 | 139 | ||
137 | ww2 w1 o1 w2 o2 f = w1 o1 $ \a1 -> w2 o2 $ \a2 -> f a1 a2 | 140 | |
138 | 141 | ||
139 | minimizeV method eps maxit szv f xiv = unsafePerformIO $ do | 142 | minimizeV method eps maxit szv f xiv = unsafePerformIO $ do |
140 | let n = dim xiv | 143 | let n = size xiv |
141 | fp <- mkVecfun (iv f) | 144 | fp <- mkVecfun (iv f) |
142 | rawpath <- ww2 vec xiv vec szv $ \xiv' szv' -> | 145 | rawpath <- ww2 vec xiv vec szv $ \xiv' szv' -> |
143 | createMIO maxit (n+3) | 146 | createMIO maxit (n+3) |
144 | (c_minimize (fi (fromEnum method)) fp eps (fi maxit) // xiv' // szv') | 147 | (c_minimize (fi (fromEnum method)) fp eps (fi maxit) // xiv' // szv') |
145 | "minimize" | 148 | "minimize" |
146 | let it = round (rawpath @@> (maxit-1,0)) | 149 | let it = round (rawpath `atIndex` (maxit-1,0)) |
147 | path = takeRows it rawpath | 150 | path = takeRows it rawpath |
148 | sol = flatten $ dropColumns 3 $ dropRows (it-1) path | 151 | sol = flatten $ dropColumns 3 $ dropRows (it-1) path |
149 | freeHaskellFunPtr fp | 152 | freeHaskellFunPtr fp |
@@ -191,7 +194,7 @@ minimizeD method eps maxit istep tol f df xi = v2l $ minimizeVD | |||
191 | 194 | ||
192 | 195 | ||
193 | minimizeVD method eps maxit istep tol f df xiv = unsafePerformIO $ do | 196 | minimizeVD method eps maxit istep tol f df xiv = unsafePerformIO $ do |
194 | let n = dim xiv | 197 | let n = size xiv |
195 | f' = f | 198 | f' = f |
196 | df' = (checkdim1 n . df) | 199 | df' = (checkdim1 n . df) |
197 | fp <- mkVecfun (iv f') | 200 | fp <- mkVecfun (iv f') |
@@ -200,7 +203,7 @@ minimizeVD method eps maxit istep tol f df xiv = unsafePerformIO $ do | |||
200 | createMIO maxit (n+2) | 203 | createMIO maxit (n+2) |
201 | (c_minimizeD (fi (fromEnum method)) fp dfp istep tol eps (fi maxit) // xiv') | 204 | (c_minimizeD (fi (fromEnum method)) fp dfp istep tol eps (fi maxit) // xiv') |
202 | "minimizeD" | 205 | "minimizeD" |
203 | let it = round (rawpath @@> (maxit-1,0)) | 206 | let it = round (rawpath `atIndex` (maxit-1,0)) |
204 | path = takeRows it rawpath | 207 | path = takeRows it rawpath |
205 | sol = flatten $ dropColumns 2 $ dropRows (it-1) path | 208 | sol = flatten $ dropColumns 2 $ dropRows (it-1) path |
206 | freeHaskellFunPtr fp | 209 | freeHaskellFunPtr fp |
@@ -217,6 +220,6 @@ foreign import ccall safe "gsl-aux.h minimizeD" | |||
217 | --------------------------------------------------------------------- | 220 | --------------------------------------------------------------------- |
218 | 221 | ||
219 | checkdim1 n v | 222 | checkdim1 n v |
220 | | dim v == n = v | 223 | | size v == n = v |
221 | | otherwise = error $ "Error: "++ show n | 224 | | otherwise = error $ "Error: "++ show n |
222 | ++ " components expected in the result of the gradient supplied to minimizeD" | 225 | ++ " components expected in the result of the gradient supplied to minimizeD" |
diff --git a/packages/gsl/src/Numeric/GSL/ODE.hs b/packages/gsl/src/Numeric/GSL/ODE.hs index 7549a65..9e52873 100644 --- a/packages/gsl/src/Numeric/GSL/ODE.hs +++ b/packages/gsl/src/Numeric/GSL/ODE.hs | |||
@@ -1,3 +1,6 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | ||
2 | |||
3 | |||
1 | {- | | 4 | {- | |
2 | Module : Numeric.GSL.ODE | 5 | Module : Numeric.GSL.ODE |
3 | Copyright : (c) Alberto Ruiz 2010 | 6 | Copyright : (c) Alberto Ruiz 2010 |
@@ -29,10 +32,10 @@ main = mplot (ts : toColumns sol) | |||
29 | ----------------------------------------------------------------------------- | 32 | ----------------------------------------------------------------------------- |
30 | 33 | ||
31 | module Numeric.GSL.ODE ( | 34 | module Numeric.GSL.ODE ( |
32 | odeSolve, odeSolveV, ODEMethod(..), Jacobian | 35 | odeSolve, odeSolveV, odeSolveVWith, ODEMethod(..), Jacobian, StepControl(..) |
33 | ) where | 36 | ) where |
34 | 37 | ||
35 | import Data.Packed | 38 | import Numeric.LinearAlgebra.HMatrix |
36 | import Numeric.GSL.Internal | 39 | import Numeric.GSL.Internal |
37 | 40 | ||
38 | import Foreign.Ptr(FunPtr, nullFunPtr, freeHaskellFunPtr) | 41 | import Foreign.Ptr(FunPtr, nullFunPtr, freeHaskellFunPtr) |
@@ -41,9 +44,10 @@ import System.IO.Unsafe(unsafePerformIO) | |||
41 | 44 | ||
42 | ------------------------------------------------------------------------- | 45 | ------------------------------------------------------------------------- |
43 | 46 | ||
44 | type TVV = TV (TV Res) | 47 | type TVV = TV (TV Res) |
45 | type TVM = TV (TM Res) | 48 | type TVM = TV (TM Res) |
46 | type TVVM = TV (TV (TM Res)) | 49 | type TVVM = TV (TV (TM Res)) |
50 | type TVVVM = TV (TV (TV (TM Res))) | ||
47 | 51 | ||
48 | type Jacobian = Double -> Vector Double -> Matrix Double | 52 | type Jacobian = Double -> Vector Double -> Matrix Double |
49 | 53 | ||
@@ -60,73 +64,105 @@ data ODEMethod = RK2 -- ^ Embedded Runge-Kutta (2, 3) method. | |||
60 | | MSAdams -- ^ A variable-coefficient linear multistep Adams method in Nordsieck form. This stepper uses explicit Adams-Bashforth (predictor) and implicit Adams-Moulton (corrector) methods in P(EC)^m functional iteration mode. Method order varies dynamically between 1 and 12. | 64 | | MSAdams -- ^ A variable-coefficient linear multistep Adams method in Nordsieck form. This stepper uses explicit Adams-Bashforth (predictor) and implicit Adams-Moulton (corrector) methods in P(EC)^m functional iteration mode. Method order varies dynamically between 1 and 12. |
61 | | MSBDF Jacobian -- ^ A variable-coefficient linear multistep backward differentiation formula (BDF) method in Nordsieck form. This stepper uses the explicit BDF formula as predictor and implicit BDF formula as corrector. A modified Newton iteration method is used to solve the system of non-linear equations. Method order varies dynamically between 1 and 5. The method is generally suitable for stiff problems. | 65 | | MSBDF Jacobian -- ^ A variable-coefficient linear multistep backward differentiation formula (BDF) method in Nordsieck form. This stepper uses the explicit BDF formula as predictor and implicit BDF formula as corrector. A modified Newton iteration method is used to solve the system of non-linear equations. Method order varies dynamically between 1 and 5. The method is generally suitable for stiff problems. |
62 | 66 | ||
67 | -- | Adaptive step-size control functions | ||
68 | data StepControl = X Double Double -- ^ abs. and rel. tolerance for x(t) | ||
69 | | X' Double Double -- ^ abs. and rel. tolerance for x'(t) | ||
70 | | XX' Double Double Double Double -- ^ include both via rel. tolerance scaling factors a_x, a_x' | ||
71 | | ScXX' Double Double Double Double (Vector Double) -- ^ scale abs. tolerance of x(t) components | ||
63 | 72 | ||
64 | -- | A version of 'odeSolveV' with reasonable default parameters and system of equations defined using lists. | 73 | -- | A version of 'odeSolveV' with reasonable default parameters and system of equations defined using lists. |
65 | odeSolve | 74 | odeSolve |
66 | :: (Double -> [Double] -> [Double]) -- ^ xdot(t,x) | 75 | :: (Double -> [Double] -> [Double]) -- ^ x'(t,x) |
67 | -> [Double] -- ^ initial conditions | 76 | -> [Double] -- ^ initial conditions |
68 | -> Vector Double -- ^ desired solution times | 77 | -> Vector Double -- ^ desired solution times |
69 | -> Matrix Double -- ^ solution | 78 | -> Matrix Double -- ^ solution |
70 | odeSolve xdot xi ts = odeSolveV RKf45 hi epsAbs epsRel (l2v xdot) (fromList xi) ts | 79 | odeSolve xdot xi ts = odeSolveV RKf45 hi epsAbs epsRel (l2v xdot) (fromList xi) ts |
71 | where hi = (ts@>1 - ts@>0)/100 | 80 | where hi = (ts!1 - ts!0)/100 |
72 | epsAbs = 1.49012e-08 | 81 | epsAbs = 1.49012e-08 |
73 | epsRel = 1.49012e-08 | 82 | epsRel = epsAbs |
74 | l2v f = \t -> fromList . f t . toList | 83 | l2v f = \t -> fromList . f t . toList |
75 | 84 | ||
76 | -- | Evolution of the system with adaptive step-size control. | 85 | -- | A version of 'odeSolveVWith' with reasonable default step control. |
77 | odeSolveV | 86 | odeSolveV |
78 | :: ODEMethod | 87 | :: ODEMethod |
79 | -> Double -- ^ initial step size | 88 | -> Double -- ^ initial step size |
80 | -> Double -- ^ absolute tolerance for the state vector | 89 | -> Double -- ^ absolute tolerance for the state vector |
81 | -> Double -- ^ relative tolerance for the state vector | 90 | -> Double -- ^ relative tolerance for the state vector |
82 | -> (Double -> Vector Double -> Vector Double) -- ^ xdot(t,x) | 91 | -> (Double -> Vector Double -> Vector Double) -- ^ x'(t,x) |
83 | -> Vector Double -- ^ initial conditions | 92 | -> Vector Double -- ^ initial conditions |
84 | -> Vector Double -- ^ desired solution times | 93 | -> Vector Double -- ^ desired solution times |
85 | -> Matrix Double -- ^ solution | 94 | -> Matrix Double -- ^ solution |
86 | odeSolveV RK2 = odeSolveV' 0 Nothing | 95 | odeSolveV meth hi epsAbs epsRel = odeSolveVWith meth (XX' epsAbs epsRel 1 1) hi |
87 | odeSolveV RK4 = odeSolveV' 1 Nothing | 96 | |
88 | odeSolveV RKf45 = odeSolveV' 2 Nothing | 97 | -- | Evolution of the system with adaptive step-size control. |
89 | odeSolveV RKck = odeSolveV' 3 Nothing | 98 | odeSolveVWith |
90 | odeSolveV RK8pd = odeSolveV' 4 Nothing | 99 | :: ODEMethod |
91 | odeSolveV (RK2imp jac) = odeSolveV' 5 (Just jac) | 100 | -> StepControl |
92 | odeSolveV (RK4imp jac) = odeSolveV' 6 (Just jac) | 101 | -> Double -- ^ initial step size |
93 | odeSolveV (BSimp jac) = odeSolveV' 7 (Just jac) | 102 | -> (Double -> Vector Double -> Vector Double) -- ^ x'(t,x) |
94 | odeSolveV (RK1imp jac) = odeSolveV' 8 (Just jac) | ||
95 | odeSolveV MSAdams = odeSolveV' 9 Nothing | ||
96 | odeSolveV (MSBDF jac) = odeSolveV' 10 (Just jac) | ||
97 | |||
98 | |||
99 | odeSolveV' | ||
100 | :: CInt | ||
101 | -> Maybe (Double -> Vector Double -> Matrix Double) -- ^ optional jacobian | ||
102 | -> Double -- ^ initial step size | ||
103 | -> Double -- ^ absolute tolerance for the state vector | ||
104 | -> Double -- ^ relative tolerance for the state vector | ||
105 | -> (Double -> Vector Double -> Vector Double) -- ^ xdot(t,x) | ||
106 | -> Vector Double -- ^ initial conditions | 103 | -> Vector Double -- ^ initial conditions |
107 | -> Vector Double -- ^ desired solution times | 104 | -> Vector Double -- ^ desired solution times |
108 | -> Matrix Double -- ^ solution | 105 | -> Matrix Double -- ^ solution |
109 | odeSolveV' method mbjac h epsAbs epsRel f xiv ts = unsafePerformIO $ do | 106 | odeSolveVWith method control = odeSolveVWith' m mbj c epsAbs epsRel aX aX' mbsc |
110 | let n = dim xiv | 107 | where (m, mbj) = case method of |
111 | fp <- mkDoubleVecVecfun (\t -> aux_vTov (checkdim1 n . f t)) | 108 | RK2 -> (0 , Nothing ) |
112 | jp <- case mbjac of | 109 | RK4 -> (1 , Nothing ) |
113 | Just jac -> mkDoubleVecMatfun (\t -> aux_vTom (checkdim2 n . jac t)) | 110 | RKf45 -> (2 , Nothing ) |
114 | Nothing -> return nullFunPtr | 111 | RKck -> (3 , Nothing ) |
115 | sol <- vec xiv $ \xiv' -> | 112 | RK8pd -> (4 , Nothing ) |
116 | vec (checkTimes ts) $ \ts' -> | 113 | RK2imp jac -> (5 , Just jac) |
117 | createMIO (dim ts) n | 114 | RK4imp jac -> (6 , Just jac) |
118 | (ode_c (method) h epsAbs epsRel fp jp // xiv' // ts' ) | 115 | BSimp jac -> (7 , Just jac) |
119 | "ode" | 116 | RK1imp jac -> (8 , Just jac) |
120 | freeHaskellFunPtr fp | 117 | MSAdams -> (9 , Nothing ) |
121 | return sol | 118 | MSBDF jac -> (10, Just jac) |
119 | (c, epsAbs, epsRel, aX, aX', mbsc) = case control of | ||
120 | X ea er -> (0, ea, er, 1 , 0 , Nothing) | ||
121 | X' ea er -> (0, ea, er, 0 , 1 , Nothing) | ||
122 | XX' ea er ax ax' -> (0, ea, er, ax, ax', Nothing) | ||
123 | ScXX' ea er ax ax' sc -> (1, ea, er, ax, ax', Just sc) | ||
124 | |||
125 | odeSolveVWith' | ||
126 | :: CInt -- ^ stepping function | ||
127 | -> Maybe (Double -> Vector Double -> Matrix Double) -- ^ optional jacobian | ||
128 | -> CInt -- ^ step-size control function | ||
129 | -> Double -- ^ absolute tolerance for step-size control | ||
130 | -> Double -- ^ relative tolerance for step-size control | ||
131 | -> Double -- ^ scaling factor for relative tolerance of x(t) | ||
132 | -> Double -- ^ scaling factor for relative tolerance of x'(t) | ||
133 | -> Maybe (Vector Double) -- ^ optional scaling for absolute error | ||
134 | -> Double -- ^ initial step size | ||
135 | -> (Double -> Vector Double -> Vector Double) -- ^ x'(t,x) | ||
136 | -> Vector Double -- ^ initial conditions | ||
137 | -> Vector Double -- ^ desired solution times | ||
138 | -> Matrix Double -- ^ solution | ||
139 | odeSolveVWith' method mbjac control epsAbs epsRel aX aX' mbsc h f xiv ts = | ||
140 | unsafePerformIO $ do | ||
141 | let n = size xiv | ||
142 | sc = case mbsc of | ||
143 | Just scv -> checkdim1 n scv | ||
144 | Nothing -> xiv | ||
145 | fp <- mkDoubleVecVecfun (\t -> aux_vTov (checkdim1 n . f t)) | ||
146 | jp <- case mbjac of | ||
147 | Just jac -> mkDoubleVecMatfun (\t -> aux_vTom (checkdim2 n . jac t)) | ||
148 | Nothing -> return nullFunPtr | ||
149 | sol <- vec sc $ \sc' -> vec xiv $ \xiv' -> | ||
150 | vec (checkTimes ts) $ \ts' -> createMIO (size ts) n | ||
151 | (ode_c method control h epsAbs epsRel aX aX' fp jp | ||
152 | // sc' // xiv' // ts' ) | ||
153 | "ode" | ||
154 | freeHaskellFunPtr fp | ||
155 | return sol | ||
122 | 156 | ||
123 | foreign import ccall safe "ode" | 157 | foreign import ccall safe "ode" |
124 | ode_c :: CInt -> Double -> Double -> Double -> FunPtr (Double -> TVV) -> FunPtr (Double -> TVM) -> TVVM | 158 | ode_c :: CInt -> CInt -> Double |
159 | -> Double -> Double -> Double -> Double | ||
160 | -> FunPtr (Double -> TVV) -> FunPtr (Double -> TVM) -> TVVVM | ||
125 | 161 | ||
126 | ------------------------------------------------------- | 162 | ------------------------------------------------------- |
127 | 163 | ||
128 | checkdim1 n v | 164 | checkdim1 n v |
129 | | dim v == n = v | 165 | | size v == n = v |
130 | | otherwise = error $ "Error: "++ show n | 166 | | otherwise = error $ "Error: "++ show n |
131 | ++ " components expected in the result of the function supplied to odeSolve" | 167 | ++ " components expected in the result of the function supplied to odeSolve" |
132 | 168 | ||
@@ -135,6 +171,6 @@ checkdim2 n m | |||
135 | | otherwise = error $ "Error: "++ show n ++ "x" ++ show n | 171 | | otherwise = error $ "Error: "++ show n ++ "x" ++ show n |
136 | ++ " Jacobian expected in odeSolve" | 172 | ++ " Jacobian expected in odeSolve" |
137 | 173 | ||
138 | checkTimes ts | dim ts > 1 && all (>0) (zipWith subtract ts' (tail ts')) = ts | 174 | checkTimes ts | size ts > 1 && all (>0) (zipWith subtract ts' (tail ts')) = ts |
139 | | otherwise = error "odeSolve requires increasing times" | 175 | | otherwise = error "odeSolve requires increasing times" |
140 | where ts' = toList ts | 176 | where ts' = toList ts |
diff --git a/packages/gsl/src/Numeric/GSL/Polynomials.hs b/packages/gsl/src/Numeric/GSL/Polynomials.hs index b1be85d..8890f8f 100644 --- a/packages/gsl/src/Numeric/GSL/Polynomials.hs +++ b/packages/gsl/src/Numeric/GSL/Polynomials.hs | |||
@@ -16,9 +16,8 @@ module Numeric.GSL.Polynomials ( | |||
16 | polySolve | 16 | polySolve |
17 | ) where | 17 | ) where |
18 | 18 | ||
19 | import Data.Packed | 19 | import Numeric.LinearAlgebra.HMatrix |
20 | import Numeric.GSL.Internal | 20 | import Numeric.GSL.Internal |
21 | import Data.Complex | ||
22 | import System.IO.Unsafe (unsafePerformIO) | 21 | import System.IO.Unsafe (unsafePerformIO) |
23 | 22 | ||
24 | #if __GLASGOW_HASKELL__ >= 704 | 23 | #if __GLASGOW_HASKELL__ >= 704 |
@@ -47,9 +46,9 @@ polySolve :: [Double] -> [Complex Double] | |||
47 | polySolve = toList . polySolve' . fromList | 46 | polySolve = toList . polySolve' . fromList |
48 | 47 | ||
49 | polySolve' :: Vector Double -> Vector (Complex Double) | 48 | polySolve' :: Vector Double -> Vector (Complex Double) |
50 | polySolve' v | dim v > 1 = unsafePerformIO $ do | 49 | polySolve' v | size v > 1 = unsafePerformIO $ do |
51 | r <- createVector (dim v-1) | 50 | r <- createVector (size v-1) |
52 | app2 c_polySolve vec v vec r "polySolve" | 51 | c_polySolve # v # r #| "polySolve" |
53 | return r | 52 | return r |
54 | | otherwise = error "polySolve on a polynomial of degree zero" | 53 | | otherwise = error "polySolve on a polynomial of degree zero" |
55 | 54 | ||
diff --git a/packages/gsl/src/Numeric/GSL/Random.hs b/packages/gsl/src/Numeric/GSL/Random.hs index f1f49e5..139c921 100644 --- a/packages/gsl/src/Numeric/GSL/Random.hs +++ b/packages/gsl/src/Numeric/GSL/Random.hs | |||
@@ -21,11 +21,13 @@ module Numeric.GSL.Random ( | |||
21 | ) where | 21 | ) where |
22 | 22 | ||
23 | import Numeric.GSL.Vector | 23 | import Numeric.GSL.Vector |
24 | import Numeric.LinearAlgebra(cholSH) | 24 | import Numeric.LinearAlgebra.HMatrix hiding ( |
25 | import Numeric.Container hiding ( | ||
26 | randomVector, | 25 | randomVector, |
27 | gaussianSample, | 26 | gaussianSample, |
28 | uniformSample | 27 | uniformSample, |
28 | Seed, | ||
29 | rand, | ||
30 | randn | ||
29 | ) | 31 | ) |
30 | import System.Random(randomIO) | 32 | import System.Random(randomIO) |
31 | 33 | ||
@@ -40,10 +42,10 @@ gaussianSample :: Seed | |||
40 | -> Matrix Double -- ^ covariance matrix | 42 | -> Matrix Double -- ^ covariance matrix |
41 | -> Matrix Double -- ^ result | 43 | -> Matrix Double -- ^ result |
42 | gaussianSample seed n med cov = m where | 44 | gaussianSample seed n med cov = m where |
43 | c = dim med | 45 | c = size med |
44 | meds = konst 1 n `outer` med | 46 | meds = konst 1 n `outer` med |
45 | rs = reshape c $ randomVector seed Gaussian (c * n) | 47 | rs = reshape c $ randomVector seed Gaussian (c * n) |
46 | m = rs `mXm` cholSH cov `add` meds | 48 | m = rs <> cholSH cov + meds |
47 | 49 | ||
48 | -- | Obtains a matrix whose rows are pseudorandom samples from a multivariate | 50 | -- | Obtains a matrix whose rows are pseudorandom samples from a multivariate |
49 | -- uniform distribution. | 51 | -- uniform distribution. |
@@ -55,10 +57,10 @@ uniformSample seed n rgs = m where | |||
55 | (as,bs) = unzip rgs | 57 | (as,bs) = unzip rgs |
56 | a = fromList as | 58 | a = fromList as |
57 | cs = zipWith subtract as bs | 59 | cs = zipWith subtract as bs |
58 | d = dim a | 60 | d = size a |
59 | dat = toRows $ reshape n $ randomVector seed Uniform (n*d) | 61 | dat = toRows $ reshape n $ randomVector seed Uniform (n*d) |
60 | am = konst 1 n `outer` a | 62 | am = konst 1 n `outer` a |
61 | m = fromColumns (zipWith scale cs dat) `add` am | 63 | m = fromColumns (zipWith scale cs dat) + am |
62 | 64 | ||
63 | -- | pseudorandom matrix with uniform elements between 0 and 1 | 65 | -- | pseudorandom matrix with uniform elements between 0 and 1 |
64 | randm :: RandDist | 66 | randm :: RandDist |
diff --git a/packages/gsl/src/Numeric/GSL/Root.hs b/packages/gsl/src/Numeric/GSL/Root.hs index b9f3b94..724f32f 100644 --- a/packages/gsl/src/Numeric/GSL/Root.hs +++ b/packages/gsl/src/Numeric/GSL/Root.hs | |||
@@ -1,3 +1,5 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | ||
2 | |||
1 | {- | | 3 | {- | |
2 | Module : Numeric.GSL.Root | 4 | Module : Numeric.GSL.Root |
3 | Copyright : (c) Alberto Ruiz 2009 | 5 | Copyright : (c) Alberto Ruiz 2009 |
@@ -39,7 +41,7 @@ module Numeric.GSL.Root ( | |||
39 | rootJ, RootMethodJ(..), | 41 | rootJ, RootMethodJ(..), |
40 | ) where | 42 | ) where |
41 | 43 | ||
42 | import Data.Packed | 44 | import Numeric.LinearAlgebra.HMatrix |
43 | import Numeric.GSL.Internal | 45 | import Numeric.GSL.Internal |
44 | import Foreign.Ptr(FunPtr, freeHaskellFunPtr) | 46 | import Foreign.Ptr(FunPtr, freeHaskellFunPtr) |
45 | import Foreign.C.Types | 47 | import Foreign.C.Types |
@@ -69,7 +71,7 @@ uniRootGen m f xl xu epsrel maxit = unsafePerformIO $ do | |||
69 | rawpath <- createMIO maxit 4 | 71 | rawpath <- createMIO maxit 4 |
70 | (c_root m fp epsrel (fi maxit) xl xu) | 72 | (c_root m fp epsrel (fi maxit) xl xu) |
71 | "root" | 73 | "root" |
72 | let it = round (rawpath @@> (maxit-1,0)) | 74 | let it = round (rawpath `atIndex` (maxit-1,0)) |
73 | path = takeRows it rawpath | 75 | path = takeRows it rawpath |
74 | [sol] = toLists $ dropRows (it-1) path | 76 | [sol] = toLists $ dropRows (it-1) path |
75 | freeHaskellFunPtr fp | 77 | freeHaskellFunPtr fp |
@@ -100,7 +102,7 @@ uniRootJGen m f df x epsrel maxit = unsafePerformIO $ do | |||
100 | rawpath <- createMIO maxit 2 | 102 | rawpath <- createMIO maxit 2 |
101 | (c_rootj m fp dfp epsrel (fi maxit) x) | 103 | (c_rootj m fp dfp epsrel (fi maxit) x) |
102 | "rootj" | 104 | "rootj" |
103 | let it = round (rawpath @@> (maxit-1,0)) | 105 | let it = round (rawpath `atIndex` (maxit-1,0)) |
104 | path = takeRows it rawpath | 106 | path = takeRows it rawpath |
105 | [sol] = toLists $ dropRows (it-1) path | 107 | [sol] = toLists $ dropRows (it-1) path |
106 | freeHaskellFunPtr fp | 108 | freeHaskellFunPtr fp |
@@ -132,13 +134,13 @@ root method epsabs maxit fun xinit = rootGen (fi (fromEnum method)) fun xinit ep | |||
132 | 134 | ||
133 | rootGen m f xi epsabs maxit = unsafePerformIO $ do | 135 | rootGen m f xi epsabs maxit = unsafePerformIO $ do |
134 | let xiv = fromList xi | 136 | let xiv = fromList xi |
135 | n = dim xiv | 137 | n = size xiv |
136 | fp <- mkVecVecfun (aux_vTov (checkdim1 n . fromList . f . toList)) | 138 | fp <- mkVecVecfun (aux_vTov (checkdim1 n . fromList . f . toList)) |
137 | rawpath <- vec xiv $ \xiv' -> | 139 | rawpath <- vec xiv $ \xiv' -> |
138 | createMIO maxit (2*n+1) | 140 | createMIO maxit (2*n+1) |
139 | (c_multiroot m fp epsabs (fi maxit) // xiv') | 141 | (c_multiroot m fp epsabs (fi maxit) // xiv') |
140 | "multiroot" | 142 | "multiroot" |
141 | let it = round (rawpath @@> (maxit-1,0)) | 143 | let it = round (rawpath `atIndex` (maxit-1,0)) |
142 | path = takeRows it rawpath | 144 | path = takeRows it rawpath |
143 | [sol] = toLists $ dropRows (it-1) path | 145 | [sol] = toLists $ dropRows (it-1) path |
144 | freeHaskellFunPtr fp | 146 | freeHaskellFunPtr fp |
@@ -169,14 +171,14 @@ rootJ method epsabs maxit fun jac xinit = rootJGen (fi (fromEnum method)) fun ja | |||
169 | 171 | ||
170 | rootJGen m f jac xi epsabs maxit = unsafePerformIO $ do | 172 | rootJGen m f jac xi epsabs maxit = unsafePerformIO $ do |
171 | let xiv = fromList xi | 173 | let xiv = fromList xi |
172 | n = dim xiv | 174 | n = size xiv |
173 | fp <- mkVecVecfun (aux_vTov (checkdim1 n . fromList . f . toList)) | 175 | fp <- mkVecVecfun (aux_vTov (checkdim1 n . fromList . f . toList)) |
174 | jp <- mkVecMatfun (aux_vTom (checkdim2 n . fromLists . jac . toList)) | 176 | jp <- mkVecMatfun (aux_vTom (checkdim2 n . fromLists . jac . toList)) |
175 | rawpath <- vec xiv $ \xiv' -> | 177 | rawpath <- vec xiv $ \xiv' -> |
176 | createMIO maxit (2*n+1) | 178 | createMIO maxit (2*n+1) |
177 | (c_multirootj m fp jp epsabs (fi maxit) // xiv') | 179 | (c_multirootj m fp jp epsabs (fi maxit) // xiv') |
178 | "multiroot" | 180 | "multiroot" |
179 | let it = round (rawpath @@> (maxit-1,0)) | 181 | let it = round (rawpath `atIndex` (maxit-1,0)) |
180 | path = takeRows it rawpath | 182 | path = takeRows it rawpath |
181 | [sol] = toLists $ dropRows (it-1) path | 183 | [sol] = toLists $ dropRows (it-1) path |
182 | freeHaskellFunPtr fp | 184 | freeHaskellFunPtr fp |
@@ -189,7 +191,7 @@ foreign import ccall safe "multirootj" | |||
189 | ------------------------------------------------------- | 191 | ------------------------------------------------------- |
190 | 192 | ||
191 | checkdim1 n v | 193 | checkdim1 n v |
192 | | dim v == n = v | 194 | | size v == n = v |
193 | | otherwise = error $ "Error: "++ show n | 195 | | otherwise = error $ "Error: "++ show n |
194 | ++ " components expected in the result of the function supplied to root" | 196 | ++ " components expected in the result of the function supplied to root" |
195 | 197 | ||
diff --git a/packages/gsl/src/Numeric/GSL/SimulatedAnnealing.hs b/packages/gsl/src/Numeric/GSL/SimulatedAnnealing.hs new file mode 100644 index 0000000..11b22d3 --- /dev/null +++ b/packages/gsl/src/Numeric/GSL/SimulatedAnnealing.hs | |||
@@ -0,0 +1,245 @@ | |||
1 | {- | | ||
2 | Module : Numeric.GSL.Interpolation | ||
3 | Copyright : (c) Matthew Peddie 2015 | ||
4 | License : GPL | ||
5 | Maintainer : Alberto Ruiz | ||
6 | Stability : provisional | ||
7 | |||
8 | Simulated annealing routines. | ||
9 | |||
10 | <https://www.gnu.org/software/gsl/manual/html_node/Simulated-Annealing.html#Simulated-Annealing> | ||
11 | |||
12 | Here is a translation of the simple example given in | ||
13 | <https://www.gnu.org/software/gsl/manual/html_node/Trivial-example.html#Trivial-example the GSL manual>: | ||
14 | |||
15 | > import Numeric.GSL.SimulatedAnnealing | ||
16 | > import Numeric.LinearAlgebra.HMatrix | ||
17 | > | ||
18 | > main = print $ simanSolve 0 1 exampleParams 15.5 exampleE exampleM exampleS (Just show) | ||
19 | > | ||
20 | > exampleParams = SimulatedAnnealingParams 200 1000 1.0 1.0 0.008 1.003 2.0e-6 | ||
21 | > | ||
22 | > exampleE x = exp (-(x - 1)**2) * sin (8 * x) | ||
23 | > | ||
24 | > exampleM x y = abs $ x - y | ||
25 | > | ||
26 | > exampleS rands stepSize current = (rands ! 0) * 2 * stepSize - stepSize + current | ||
27 | |||
28 | The manual states: | ||
29 | |||
30 | > The first example, in one dimensional Cartesian space, sets up an | ||
31 | > energy function which is a damped sine wave; this has many local | ||
32 | > minima, but only one global minimum, somewhere between 1.0 and | ||
33 | > 1.5. The initial guess given is 15.5, which is several local minima | ||
34 | > away from the global minimum. | ||
35 | |||
36 | This global minimum is around 1.36. | ||
37 | |||
38 | -} | ||
39 | {-# OPTIONS_GHC -Wall #-} | ||
40 | |||
41 | module Numeric.GSL.SimulatedAnnealing ( | ||
42 | -- * Searching for minima | ||
43 | simanSolve | ||
44 | -- * Configuring the annealing process | ||
45 | , SimulatedAnnealingParams(..) | ||
46 | ) where | ||
47 | |||
48 | import Numeric.GSL.Internal | ||
49 | import Numeric.LinearAlgebra.HMatrix hiding(step) | ||
50 | |||
51 | import Data.Vector.Storable(generateM) | ||
52 | import Foreign.Storable(Storable(..)) | ||
53 | import Foreign.Marshal.Utils(with) | ||
54 | import Foreign.Ptr(Ptr, FunPtr, nullFunPtr) | ||
55 | import Foreign.StablePtr(StablePtr, newStablePtr, deRefStablePtr, freeStablePtr) | ||
56 | import Foreign.C.Types | ||
57 | import System.IO.Unsafe(unsafePerformIO) | ||
58 | |||
59 | import System.IO (hFlush, stdout) | ||
60 | |||
61 | import Data.IORef (IORef, newIORef, writeIORef, readIORef, modifyIORef') | ||
62 | |||
63 | -- | 'SimulatedAnnealingParams' is a translation of the | ||
64 | -- @gsl_siman_params_t@ structure documented in | ||
65 | -- <https://www.gnu.org/software/gsl/manual/html_node/Simulated-Annealing-functions.html#Simulated-Annealing-functions the GSL manual>, | ||
66 | -- which controls the simulated annealing algorithm. | ||
67 | -- | ||
68 | -- The annealing process is parameterized by the Boltzmann | ||
69 | -- distribution and the /cooling schedule/. For more details, see | ||
70 | -- <https://www.gnu.org/software/gsl/manual/html_node/Simulated-Annealing-algorithm.html#Simulated-Annealing-algorithm the relevant section of the manual>. | ||
71 | data SimulatedAnnealingParams = SimulatedAnnealingParams { | ||
72 | n_tries :: CInt -- ^ The number of points to try for each step. | ||
73 | , iters_fixed_T :: CInt -- ^ The number of iterations at each temperature | ||
74 | , step_size :: Double -- ^ The maximum step size in the random walk | ||
75 | , boltzmann_k :: Double -- ^ Boltzmann distribution parameter | ||
76 | , cooling_t_initial :: Double -- ^ Initial temperature | ||
77 | , cooling_mu_t :: Double -- ^ Cooling rate parameter | ||
78 | , cooling_t_min :: Double -- ^ Final temperature | ||
79 | } deriving (Eq, Show, Read) | ||
80 | |||
81 | instance Storable SimulatedAnnealingParams where | ||
82 | sizeOf p = sizeOf (n_tries p) + | ||
83 | sizeOf (iters_fixed_T p) + | ||
84 | sizeOf (step_size p) + | ||
85 | sizeOf (boltzmann_k p) + | ||
86 | sizeOf (cooling_t_initial p) + | ||
87 | sizeOf (cooling_mu_t p) + | ||
88 | sizeOf (cooling_t_min p) | ||
89 | -- TODO(MP): is this safe? | ||
90 | alignment p = alignment (step_size p) | ||
91 | -- TODO(MP): Is there a more automatic way to write these? | ||
92 | peek ptr = SimulatedAnnealingParams <$> | ||
93 | peekByteOff ptr 0 <*> | ||
94 | peekByteOff ptr i <*> | ||
95 | peekByteOff ptr (2*i) <*> | ||
96 | peekByteOff ptr (2*i + d) <*> | ||
97 | peekByteOff ptr (2*i + 2*d) <*> | ||
98 | peekByteOff ptr (2*i + 3*d) <*> | ||
99 | peekByteOff ptr (2*i + 4*d) | ||
100 | where | ||
101 | i = sizeOf (0 :: CInt) | ||
102 | d = sizeOf (0 :: Double) | ||
103 | poke ptr sap = do | ||
104 | pokeByteOff ptr 0 (n_tries sap) | ||
105 | pokeByteOff ptr i (iters_fixed_T sap) | ||
106 | pokeByteOff ptr (2*i) (step_size sap) | ||
107 | pokeByteOff ptr (2*i + d) (boltzmann_k sap) | ||
108 | pokeByteOff ptr (2*i + 2*d) (cooling_t_initial sap) | ||
109 | pokeByteOff ptr (2*i + 3*d) (cooling_mu_t sap) | ||
110 | pokeByteOff ptr (2*i + 4*d) (cooling_t_min sap) | ||
111 | where | ||
112 | i = sizeOf (0 :: CInt) | ||
113 | d = sizeOf (0 :: Double) | ||
114 | |||
115 | -- We use a StablePtr to an IORef so that we can keep hold of | ||
116 | -- StablePtr values but mutate their contents. A simple 'StablePtr a' | ||
117 | -- won't work, since we'd have no way to write 'copyConfig'. | ||
118 | type P a = StablePtr (IORef a) | ||
119 | |||
120 | copyConfig :: P a -> P a -> IO () | ||
121 | copyConfig src' dest' = do | ||
122 | dest <- deRefStablePtr dest' | ||
123 | src <- deRefStablePtr src' | ||
124 | readIORef src >>= writeIORef dest | ||
125 | |||
126 | copyConstructConfig :: P a -> IO (P a) | ||
127 | copyConstructConfig x = do | ||
128 | conf <- deRefRead x | ||
129 | newconf <- newIORef conf | ||
130 | newStablePtr newconf | ||
131 | |||
132 | destroyConfig :: P a -> IO () | ||
133 | destroyConfig p = do | ||
134 | freeStablePtr p | ||
135 | |||
136 | deRefRead :: P a -> IO a | ||
137 | deRefRead p = deRefStablePtr p >>= readIORef | ||
138 | |||
139 | wrapEnergy :: (a -> Double) -> P a -> Double | ||
140 | wrapEnergy f p = unsafePerformIO $ f <$> deRefRead p | ||
141 | |||
142 | wrapMetric :: (a -> a -> Double) -> P a -> P a -> Double | ||
143 | wrapMetric f x y = unsafePerformIO $ f <$> deRefRead x <*> deRefRead y | ||
144 | |||
145 | wrapStep :: Int | ||
146 | -> (Vector Double -> Double -> a -> a) | ||
147 | -> GSLRNG | ||
148 | -> P a | ||
149 | -> Double | ||
150 | -> IO () | ||
151 | wrapStep nrand f (GSLRNG rng) confptr stepSize = do | ||
152 | v <- generateM nrand (\_ -> gslRngUniform rng) | ||
153 | conf <- deRefStablePtr confptr | ||
154 | modifyIORef' conf $ f v stepSize | ||
155 | |||
156 | wrapPrint :: (a -> String) -> P a -> IO () | ||
157 | wrapPrint pf ptr = deRefRead ptr >>= putStr . pf >> hFlush stdout | ||
158 | |||
159 | foreign import ccall safe "wrapper" | ||
160 | mkEnergyFun :: (P a -> Double) -> IO (FunPtr (P a -> Double)) | ||
161 | |||
162 | foreign import ccall safe "wrapper" | ||
163 | mkMetricFun :: (P a -> P a -> Double) -> IO (FunPtr (P a -> P a -> Double)) | ||
164 | |||
165 | foreign import ccall safe "wrapper" | ||
166 | mkStepFun :: (GSLRNG -> P a -> Double -> IO ()) | ||
167 | -> IO (FunPtr (GSLRNG -> P a -> Double -> IO ())) | ||
168 | |||
169 | foreign import ccall safe "wrapper" | ||
170 | mkCopyFun :: (P a -> P a -> IO ()) -> IO (FunPtr (P a -> P a -> IO ())) | ||
171 | |||
172 | foreign import ccall safe "wrapper" | ||
173 | mkCopyConstructorFun :: (P a -> IO (P a)) -> IO (FunPtr (P a -> IO (P a))) | ||
174 | |||
175 | foreign import ccall safe "wrapper" | ||
176 | mkDestructFun :: (P a -> IO ()) -> IO (FunPtr (P a -> IO ())) | ||
177 | |||
178 | newtype GSLRNG = GSLRNG (Ptr GSLRNG) | ||
179 | |||
180 | foreign import ccall safe "gsl_rng.h gsl_rng_uniform" | ||
181 | gslRngUniform :: Ptr GSLRNG -> IO Double | ||
182 | |||
183 | foreign import ccall safe "gsl-aux.h siman" | ||
184 | siman :: CInt -- ^ RNG seed (for repeatability) | ||
185 | -> Ptr SimulatedAnnealingParams -- ^ params | ||
186 | -> P a -- ^ Configuration | ||
187 | -> FunPtr (P a -> Double) -- ^ Energy functional | ||
188 | -> FunPtr (P a -> P a -> Double) -- ^ Metric definition | ||
189 | -> FunPtr (GSLRNG -> P a -> Double -> IO ()) -- ^ Step evaluation | ||
190 | -> FunPtr (P a -> P a -> IO ()) -- ^ Copy config | ||
191 | -> FunPtr (P a -> IO (P a)) -- ^ Copy constructor for config | ||
192 | -> FunPtr (P a -> IO ()) -- ^ Destructor for config | ||
193 | -> FunPtr (P a -> IO ()) -- ^ Print function | ||
194 | -> IO CInt | ||
195 | |||
196 | -- | | ||
197 | -- Calling | ||
198 | -- | ||
199 | -- > simanSolve seed nrand params x0 e m step print | ||
200 | -- | ||
201 | -- performs a simulated annealing search through a given space. So | ||
202 | -- that any configuration type may be used, the space is specified by | ||
203 | -- providing the functions @e@ (the energy functional) and @m@ (the | ||
204 | -- metric definition). @x0@ is the initial configuration of the | ||
205 | -- system. The simulated annealing steps are generated using the | ||
206 | -- user-provided function @step@, which should randomly construct a | ||
207 | -- new system configuration. | ||
208 | -- | ||
209 | -- If 'Nothing' is passed instead of a printing function, no | ||
210 | -- incremental output will be generated. Otherwise, the GSL-formatted | ||
211 | -- output, including the configuration description the user function | ||
212 | -- generates, will be printed to stdout. | ||
213 | -- | ||
214 | -- Each time the step function is called, it is supplied with a random | ||
215 | -- vector containing @nrand@ 'Double' values, uniformly distributed in | ||
216 | -- @[0, 1)@. It should use these values to generate its new | ||
217 | -- configuration. | ||
218 | simanSolve :: Int -- ^ Seed for the random number generator | ||
219 | -> Int -- ^ @nrand@, the number of random 'Double's the | ||
220 | -- step function requires | ||
221 | -> SimulatedAnnealingParams -- ^ Parameters to configure the solver | ||
222 | -> a -- ^ Initial configuration @x0@ | ||
223 | -> (a -> Double) -- ^ Energy functional @e@ | ||
224 | -> (a -> a -> Double) -- ^ Metric definition @m@ | ||
225 | -> (Vector Double -> Double -> a -> a) -- ^ Stepping function @step@ | ||
226 | -> Maybe (a -> String) -- ^ Optional printing function | ||
227 | -> a -- ^ Best configuration the solver has found | ||
228 | simanSolve seed nrand params conf e m step printfun = | ||
229 | unsafePerformIO $ with params $ \paramptr -> do | ||
230 | ewrap <- mkEnergyFun $ wrapEnergy e | ||
231 | mwrap <- mkMetricFun $ wrapMetric m | ||
232 | stepwrap <- mkStepFun $ wrapStep nrand step | ||
233 | confptr <- newIORef conf >>= newStablePtr | ||
234 | cpwrap <- mkCopyFun copyConfig | ||
235 | ccwrap <- mkCopyConstructorFun copyConstructConfig | ||
236 | dwrap <- mkDestructFun destroyConfig | ||
237 | pwrap <- case printfun of | ||
238 | Nothing -> return nullFunPtr | ||
239 | Just pf -> mkDestructFun $ wrapPrint pf | ||
240 | siman (fromIntegral seed) | ||
241 | paramptr confptr | ||
242 | ewrap mwrap stepwrap cpwrap ccwrap dwrap pwrap // check "siman" | ||
243 | result <- deRefRead confptr | ||
244 | freeStablePtr confptr | ||
245 | return result | ||
diff --git a/packages/gsl/src/Numeric/GSL/Vector.hs b/packages/gsl/src/Numeric/GSL/Vector.hs index af79f32..fb982c5 100644 --- a/packages/gsl/src/Numeric/GSL/Vector.hs +++ b/packages/gsl/src/Numeric/GSL/Vector.hs | |||
@@ -14,8 +14,7 @@ module Numeric.GSL.Vector ( | |||
14 | fwriteVector, freadVector, fprintfVector, fscanfVector | 14 | fwriteVector, freadVector, fprintfVector, fscanfVector |
15 | ) where | 15 | ) where |
16 | 16 | ||
17 | import Data.Packed | 17 | import Numeric.LinearAlgebra.HMatrix hiding(randomVector, saveMatrix) |
18 | import Numeric.LinearAlgebra(RandDist(..)) | ||
19 | import Numeric.GSL.Internal hiding (TV,TM,TCV,TCM) | 18 | import Numeric.GSL.Internal hiding (TV,TM,TCV,TCM) |
20 | 19 | ||
21 | import Foreign.Marshal.Alloc(free) | 20 | import Foreign.Marshal.Alloc(free) |
@@ -35,7 +34,7 @@ randomVector :: Int -- ^ seed | |||
35 | -> Vector Double | 34 | -> Vector Double |
36 | randomVector seed dist n = unsafePerformIO $ do | 35 | randomVector seed dist n = unsafePerformIO $ do |
37 | r <- createVector n | 36 | r <- createVector n |
38 | app1 (c_random_vector_GSL (fi seed) ((fi.fromEnum) dist)) vec r "randomVectorGSL" | 37 | c_random_vector_GSL (fi seed) ((fi.fromEnum) dist) # r #|"randomVectorGSL" |
39 | return r | 38 | return r |
40 | 39 | ||
41 | foreign import ccall unsafe "random_vector_GSL" c_random_vector_GSL :: CInt -> CInt -> TV | 40 | foreign import ccall unsafe "random_vector_GSL" c_random_vector_GSL :: CInt -> CInt -> TV |
@@ -51,7 +50,7 @@ saveMatrix filename fmt m = do | |||
51 | charname <- newCString filename | 50 | charname <- newCString filename |
52 | charfmt <- newCString fmt | 51 | charfmt <- newCString fmt |
53 | let o = if orderOf m == RowMajor then 1 else 0 | 52 | let o = if orderOf m == RowMajor then 1 else 0 |
54 | app1 (matrix_fprintf charname charfmt o) mat m "matrix_fprintf" | 53 | matrix_fprintf charname charfmt o # m #|"matrix_fprintf" |
55 | free charname | 54 | free charname |
56 | free charfmt | 55 | free charfmt |
57 | 56 | ||
@@ -64,7 +63,7 @@ fscanfVector :: FilePath -> Int -> IO (Vector Double) | |||
64 | fscanfVector filename n = do | 63 | fscanfVector filename n = do |
65 | charname <- newCString filename | 64 | charname <- newCString filename |
66 | res <- createVector n | 65 | res <- createVector n |
67 | app1 (gsl_vector_fscanf charname) vec res "gsl_vector_fscanf" | 66 | gsl_vector_fscanf charname # res #|"gsl_vector_fscanf" |
68 | free charname | 67 | free charname |
69 | return res | 68 | return res |
70 | 69 | ||
@@ -75,7 +74,7 @@ fprintfVector :: FilePath -> String -> Vector Double -> IO () | |||
75 | fprintfVector filename fmt v = do | 74 | fprintfVector filename fmt v = do |
76 | charname <- newCString filename | 75 | charname <- newCString filename |
77 | charfmt <- newCString fmt | 76 | charfmt <- newCString fmt |
78 | app1 (gsl_vector_fprintf charname charfmt) vec v "gsl_vector_fprintf" | 77 | gsl_vector_fprintf charname charfmt # v #|"gsl_vector_fprintf" |
79 | free charname | 78 | free charname |
80 | free charfmt | 79 | free charfmt |
81 | 80 | ||
@@ -86,7 +85,7 @@ freadVector :: FilePath -> Int -> IO (Vector Double) | |||
86 | freadVector filename n = do | 85 | freadVector filename n = do |
87 | charname <- newCString filename | 86 | charname <- newCString filename |
88 | res <- createVector n | 87 | res <- createVector n |
89 | app1 (gsl_vector_fread charname) vec res "gsl_vector_fread" | 88 | gsl_vector_fread charname # res #|"gsl_vector_fread" |
90 | free charname | 89 | free charname |
91 | return res | 90 | return res |
92 | 91 | ||
@@ -96,7 +95,7 @@ foreign import ccall unsafe "vector_fread" gsl_vector_fread:: Ptr CChar -> TV | |||
96 | fwriteVector :: FilePath -> Vector Double -> IO () | 95 | fwriteVector :: FilePath -> Vector Double -> IO () |
97 | fwriteVector filename v = do | 96 | fwriteVector filename v = do |
98 | charname <- newCString filename | 97 | charname <- newCString filename |
99 | app1 (gsl_vector_fwrite charname) vec v "gsl_vector_fwrite" | 98 | gsl_vector_fwrite charname # v #|"gsl_vector_fwrite" |
100 | free charname | 99 | free charname |
101 | 100 | ||
102 | foreign import ccall unsafe "vector_fwrite" gsl_vector_fwrite :: Ptr CChar -> TV | 101 | foreign import ccall unsafe "vector_fwrite" gsl_vector_fwrite :: Ptr CChar -> TV |
diff --git a/packages/gsl/src/Numeric/GSL/gsl-aux.c b/packages/gsl/src/Numeric/GSL/gsl-aux.c index e1b189c..1ca8199 100644 --- a/packages/gsl/src/Numeric/GSL/gsl-aux.c +++ b/packages/gsl/src/Numeric/GSL/gsl-aux.c | |||
@@ -36,6 +36,8 @@ | |||
36 | #include <gsl/gsl_roots.h> | 36 | #include <gsl/gsl_roots.h> |
37 | #include <gsl/gsl_spline.h> | 37 | #include <gsl/gsl_spline.h> |
38 | #include <gsl/gsl_multifit_nlin.h> | 38 | #include <gsl/gsl_multifit_nlin.h> |
39 | #include <gsl/gsl_siman.h> | ||
40 | |||
39 | #include <string.h> | 41 | #include <string.h> |
40 | #include <stdio.h> | 42 | #include <stdio.h> |
41 | 43 | ||
@@ -475,7 +477,30 @@ int uniMinimize(int method, double f(double), | |||
475 | OK | 477 | OK |
476 | } | 478 | } |
477 | 479 | ||
478 | 480 | int siman(int seed, | |
481 | gsl_siman_params_t *params, void *xp0, | ||
482 | double energy(void *), double metric(void *, void *), | ||
483 | void step(const gsl_rng *, void *, double), | ||
484 | void copy(void *, void *), void *copycons(void *), | ||
485 | void destroy(void *), void print(void *)) { | ||
486 | DEBUGMSG("siman"); | ||
487 | gsl_rng *gen = gsl_rng_alloc (gsl_rng_mt19937); | ||
488 | gsl_rng_set(gen, seed); | ||
489 | |||
490 | // The simulated annealing routine doesn't indicate with a return | ||
491 | // code how things went -- there's little notion of convergence for | ||
492 | // a randomized minimizer on a potentially non-convex problem, and I | ||
493 | // suppose it doesn't detect egregious failures like malloc errors | ||
494 | // in the copy-constructor. | ||
495 | gsl_siman_solve(gen, xp0, | ||
496 | energy, step, | ||
497 | metric, print, | ||
498 | copy, copycons, | ||
499 | destroy, 0, *params); | ||
500 | |||
501 | gsl_rng_free(gen); | ||
502 | OK | ||
503 | } | ||
479 | 504 | ||
480 | // this version returns info about intermediate steps | 505 | // this version returns info about intermediate steps |
481 | int minimize(int method, double f(int, double*), double tolsize, int maxit, | 506 | int minimize(int method, double f(int, double*), double tolsize, int maxit, |
diff --git a/packages/gsl/src/Numeric/GSL/gsl-ode.c b/packages/gsl/src/Numeric/GSL/gsl-ode.c index 3f2771b..a6bdb55 100644 --- a/packages/gsl/src/Numeric/GSL/gsl-ode.c +++ b/packages/gsl/src/Numeric/GSL/gsl-ode.c | |||
@@ -23,10 +23,11 @@ int odejac (double t, const double y[], double *dfdy, double dfdt[], void *param | |||
23 | } | 23 | } |
24 | 24 | ||
25 | 25 | ||
26 | int ode(int method, double h, double eps_abs, double eps_rel, | 26 | int ode(int method, int control, double h, |
27 | double eps_abs, double eps_rel, double a_y, double a_dydt, | ||
27 | int f(double, int, const double*, int, double*), | 28 | int f(double, int, const double*, int, double*), |
28 | int jac(double, int, const double*, int, int, double*), | 29 | int jac(double, int, const double*, int, int, double*), |
29 | KRVEC(xi), KRVEC(ts), RMAT(sol)) { | 30 | KRVEC(sc), KRVEC(xi), KRVEC(ts), RMAT(sol)) { |
30 | 31 | ||
31 | const gsl_odeiv_step_type * T; | 32 | const gsl_odeiv_step_type * T; |
32 | 33 | ||
@@ -46,8 +47,16 @@ int ode(int method, double h, double eps_abs, double eps_rel, | |||
46 | } | 47 | } |
47 | 48 | ||
48 | gsl_odeiv_step * s = gsl_odeiv_step_alloc (T, xin); | 49 | gsl_odeiv_step * s = gsl_odeiv_step_alloc (T, xin); |
49 | gsl_odeiv_control * c = gsl_odeiv_control_y_new (eps_abs, eps_rel); | ||
50 | gsl_odeiv_evolve * e = gsl_odeiv_evolve_alloc (xin); | 50 | gsl_odeiv_evolve * e = gsl_odeiv_evolve_alloc (xin); |
51 | gsl_odeiv_control * c; | ||
52 | |||
53 | switch(control) { | ||
54 | case 0: { c = gsl_odeiv_control_standard_new | ||
55 | (eps_abs, eps_rel, a_y, a_dydt); break; } | ||
56 | case 1: { c = gsl_odeiv_control_scaled_new | ||
57 | (eps_abs, eps_rel, a_y, a_dydt, scp, scn); break; } | ||
58 | default: ERROR(BAD_CODE); | ||
59 | } | ||
51 | 60 | ||
52 | Tode P; | 61 | Tode P; |
53 | P.f = f; | 62 | P.f = f; |
@@ -112,10 +121,11 @@ int odejac (double t, const double y[], double *dfdy, double dfdt[], void *param | |||
112 | } | 121 | } |
113 | 122 | ||
114 | 123 | ||
115 | int ode(int method, double h, double eps_abs, double eps_rel, | 124 | int ode(int method, int control, double h, |
125 | double eps_abs, double eps_rel, double a_y, double a_dydt, | ||
116 | int f(double, int, const double*, int, double*), | 126 | int f(double, int, const double*, int, double*), |
117 | int jac(double, int, const double*, int, int, double*), | 127 | int jac(double, int, const double*, int, int, double*), |
118 | KRVEC(xi), KRVEC(ts), RMAT(sol)) { | 128 | KRVEC(sc), KRVEC(xi), KRVEC(ts), RMAT(sol)) { |
119 | 129 | ||
120 | const gsl_odeiv2_step_type * T; | 130 | const gsl_odeiv2_step_type * T; |
121 | 131 | ||
@@ -141,8 +151,15 @@ int ode(int method, double h, double eps_abs, double eps_rel, | |||
141 | 151 | ||
142 | gsl_odeiv2_system sys = {odefunc, odejac, xin, &P}; | 152 | gsl_odeiv2_system sys = {odefunc, odejac, xin, &P}; |
143 | 153 | ||
144 | gsl_odeiv2_driver * d = | 154 | gsl_odeiv2_driver * d; |
145 | gsl_odeiv2_driver_alloc_y_new (&sys, T, h, eps_abs, eps_rel); | 155 | |
156 | switch(control) { | ||
157 | case 0: { d = gsl_odeiv2_driver_alloc_standard_new | ||
158 | (&sys, T, h, eps_abs, eps_rel, a_y, a_dydt); break; } | ||
159 | case 1: { d = gsl_odeiv2_driver_alloc_scaled_new | ||
160 | (&sys, T, h, eps_abs, eps_rel, a_y, a_dydt, scp); break; } | ||
161 | default: ERROR(BAD_CODE); | ||
162 | } | ||
146 | 163 | ||
147 | double t = tsp[0]; | 164 | double t = tsp[0]; |
148 | 165 | ||
diff --git a/packages/sparse/hmatrix-sparse.cabal b/packages/sparse/hmatrix-sparse.cabal index d048086..55eb424 100644 --- a/packages/sparse/hmatrix-sparse.cabal +++ b/packages/sparse/hmatrix-sparse.cabal | |||
@@ -29,7 +29,13 @@ library | |||
29 | 29 | ||
30 | c-sources: src/Numeric/LinearAlgebra/sparse.c | 30 | c-sources: src/Numeric/LinearAlgebra/sparse.c |
31 | 31 | ||
32 | cc-options: -O4 -msse2 -Wall | 32 | cc-options: -O4 -Wall |
33 | |||
34 | if arch(x86_64) | ||
35 | cc-options: -msse2 | ||
36 | |||
37 | if arch(i386) | ||
38 | cc-options: -msse2 | ||
33 | 39 | ||
34 | extra-libraries: mkl_intel mkl_sequential mkl_core | 40 | extra-libraries: mkl_intel mkl_sequential mkl_core |
35 | 41 | ||
diff --git a/packages/sparse/src/Numeric/LinearAlgebra/Sparse.hs b/packages/sparse/src/Numeric/LinearAlgebra/Sparse.hs index 8608394..b2ca7f0 100644 --- a/packages/sparse/src/Numeric/LinearAlgebra/Sparse.hs +++ b/packages/sparse/src/Numeric/LinearAlgebra/Sparse.hs | |||
@@ -13,8 +13,11 @@ import System.IO.Unsafe(unsafePerformIO) | |||
13 | import Foreign(Ptr) | 13 | import Foreign(Ptr) |
14 | import Numeric.LinearAlgebra.HMatrix | 14 | import Numeric.LinearAlgebra.HMatrix |
15 | import Text.Printf | 15 | import Text.Printf |
16 | import Numeric.LinearAlgebra.Util((~!~)) | 16 | import Control.Monad(when) |
17 | 17 | ||
18 | (???) :: Bool -> String -> IO () | ||
19 | infixl 0 ??? | ||
20 | c ??? msg = when c (error msg) | ||
18 | 21 | ||
19 | type IV t = CInt -> Ptr CInt -> t | 22 | type IV t = CInt -> Ptr CInt -> t |
20 | type V t = CInt -> Ptr Double -> t | 23 | type V t = CInt -> Ptr Double -> t |
@@ -22,9 +25,9 @@ type SMxV = V (IV (IV (V (V (IO CInt))))) | |||
22 | 25 | ||
23 | dss :: CSR -> Vector Double -> Vector Double | 26 | dss :: CSR -> Vector Double -> Vector Double |
24 | dss CSR{..} b = unsafePerformIO $ do | 27 | dss CSR{..} b = unsafePerformIO $ do |
25 | size b /= csrNRows ~!~ printf "dss: incorrect sizes: (%d,%d) x %d" csrNRows csrNCols (size b) | 28 | size b /= csrNRows ??? printf "dss: incorrect sizes: (%d,%d) x %d" csrNRows csrNCols (size b) |
26 | r <- createVector csrNCols | 29 | r <- createVector csrNCols |
27 | app5 c_dss vec csrVals vec csrCols vec csrRows vec b vec r "dss" | 30 | c_dss `apply` csrVals `apply` csrCols `apply` csrRows `apply` b `apply` r #|"dss" |
28 | return r | 31 | return r |
29 | 32 | ||
30 | foreign import ccall unsafe "dss" | 33 | foreign import ccall unsafe "dss" |
diff --git a/packages/special/hmatrix-special.cabal b/packages/special/hmatrix-special.cabal index 28b294b..3b122c8 100644 --- a/packages/special/hmatrix-special.cabal +++ b/packages/special/hmatrix-special.cabal | |||
@@ -1,5 +1,5 @@ | |||
1 | Name: hmatrix-special | 1 | Name: hmatrix-special |
2 | Version: 0.3.0.1 | 2 | Version: 0.4.0.0 |
3 | License: GPL | 3 | License: GPL |
4 | License-file: LICENSE | 4 | License-file: LICENSE |
5 | Author: Alberto Ruiz | 5 | Author: Alberto Ruiz |
@@ -27,7 +27,7 @@ flag safe-cheap | |||
27 | default: False | 27 | default: False |
28 | 28 | ||
29 | library | 29 | library |
30 | Build-Depends: base <5, hmatrix, hmatrix-gsl | 30 | Build-Depends: base <5, hmatrix>=0.17, hmatrix-gsl |
31 | 31 | ||
32 | Extensions: ForeignFunctionInterface, | 32 | Extensions: ForeignFunctionInterface, |
33 | CPP | 33 | CPP |
diff --git a/packages/special/lib/Numeric/GSL/Special/Internal.hsc b/packages/special/lib/Numeric/GSL/Special/Internal.hsc index e7c38e8..a9aab9b 100644 --- a/packages/special/lib/Numeric/GSL/Special/Internal.hsc +++ b/packages/special/lib/Numeric/GSL/Special/Internal.hsc | |||
@@ -33,7 +33,7 @@ import Foreign.Storable | |||
33 | import Foreign.Ptr | 33 | import Foreign.Ptr |
34 | import Foreign.Marshal | 34 | import Foreign.Marshal |
35 | import System.IO.Unsafe(unsafePerformIO) | 35 | import System.IO.Unsafe(unsafePerformIO) |
36 | import Data.Packed.Development(check,(//)) | 36 | import Numeric.LinearAlgebra.Devel(check,(//)) |
37 | import Foreign.C.Types | 37 | import Foreign.C.Types |
38 | 38 | ||
39 | data Precision = PrecDouble | PrecSingle | PrecApprox | 39 | data Precision = PrecDouble | PrecSingle | PrecApprox |
diff --git a/packages/tests/hmatrix-tests.cabal b/packages/tests/hmatrix-tests.cabal index 0514843..49e0640 100644 --- a/packages/tests/hmatrix-tests.cabal +++ b/packages/tests/hmatrix-tests.cabal | |||
@@ -1,5 +1,5 @@ | |||
1 | Name: hmatrix-tests | 1 | Name: hmatrix-tests |
2 | Version: 0.4.1.0 | 2 | Version: 0.5.0.0 |
3 | License: BSD3 | 3 | License: BSD3 |
4 | License-file: LICENSE | 4 | License-file: LICENSE |
5 | Author: Alberto Ruiz | 5 | Author: Alberto Ruiz |
@@ -26,11 +26,11 @@ flag gsl | |||
26 | 26 | ||
27 | library | 27 | library |
28 | 28 | ||
29 | Build-Depends: base >= 4 && < 5, | 29 | Build-Depends: base >= 4 && < 5, deepseq, |
30 | QuickCheck >= 2, HUnit, random, | 30 | QuickCheck >= 2, HUnit, random, |
31 | hmatrix >= 0.16 | 31 | hmatrix >= 0.17 |
32 | if flag(gsl) | 32 | if flag(gsl) |
33 | Build-Depends: hmatrix-gsl >= 0.16 | 33 | Build-Depends: hmatrix-gsl >= 0.17 |
34 | 34 | ||
35 | hs-source-dirs: src | 35 | hs-source-dirs: src |
36 | 36 | ||
diff --git a/packages/tests/src/Numeric/GSL/Tests.hs b/packages/tests/src/Numeric/GSL/Tests.hs index 9dff6f5..025427b 100644 --- a/packages/tests/src/Numeric/GSL/Tests.hs +++ b/packages/tests/src/Numeric/GSL/Tests.hs | |||
@@ -19,10 +19,11 @@ import System.Exit (exitFailure) | |||
19 | 19 | ||
20 | import Test.HUnit (runTestTT, failures, Test(..), errors) | 20 | import Test.HUnit (runTestTT, failures, Test(..), errors) |
21 | 21 | ||
22 | import Numeric.LinearAlgebra | 22 | import Numeric.LinearAlgebra.HMatrix |
23 | import Numeric.GSL | 23 | import Numeric.GSL |
24 | import Numeric.GSL.SimulatedAnnealing | ||
24 | import Numeric.LinearAlgebra.Tests (qCheck, utest) | 25 | import Numeric.LinearAlgebra.Tests (qCheck, utest) |
25 | import Numeric.LinearAlgebra.Tests.Properties ((|~|), (~~)) | 26 | import Numeric.LinearAlgebra.Tests.Properties ((|~|), (~~), (~=)) |
26 | 27 | ||
27 | --------------------------------------------------------------------- | 28 | --------------------------------------------------------------------- |
28 | 29 | ||
@@ -42,7 +43,7 @@ fittingTest = utest "levmar" (ok1 && ok2) | |||
42 | sol = fst $ fitModel 1E-4 1E-4 20 (expModel, expModelDer) dat [1,0,0] | 43 | sol = fst $ fitModel 1E-4 1E-4 20 (expModel, expModelDer) dat [1,0,0] |
43 | 44 | ||
44 | ok1 = and (zipWith f sols [5,0.1,1]) where f (x,d) r = abs (x-r)<2*d | 45 | ok1 = and (zipWith f sols [5,0.1,1]) where f (x,d) r = abs (x-r)<2*d |
45 | ok2 = norm2 (fromList (map fst sols) - fromList sol) < 1E-5 | 46 | ok2 = norm_2 (fromList (map fst sols) - fromList sol) < 1E-5 |
46 | 47 | ||
47 | --------------------------------------------------------------------- | 48 | --------------------------------------------------------------------- |
48 | 49 | ||
@@ -66,6 +67,59 @@ rootFindingTest = TestList [ utest "root Hybrids" (fst sol1 ~~ [1,1]) | |||
66 | jacobian a b [x,_y] = [ [-a , 0] | 67 | jacobian a b [x,_y] = [ [-a , 0] |
67 | , [-2*b*x, b] ] | 68 | , [-2*b*x, b] ] |
68 | 69 | ||
70 | -------------------------------------------------------------------- | ||
71 | |||
72 | interpolationTest = TestList [ | ||
73 | utest "interpolation evaluateV" (esol ~= ev) | ||
74 | , utest "interpolation evaluate" (esol ~= eval) | ||
75 | , utest "interpolation evaluateDerivativeV" (desol ~= dev) | ||
76 | , utest "interpolation evaluateDerivative" (desol ~= de) | ||
77 | , utest "interpolation evaluateDerivative2V" (d2esol ~= d2ev) | ||
78 | , utest "interpolation evaluateDerivative2" (d2esol ~= d2e) | ||
79 | , utest "interpolation evaluateIntegralV" (intesol ~= intev) | ||
80 | , utest "interpolation evaluateIntegral" (intesol ~= inte) | ||
81 | ] | ||
82 | where | ||
83 | xtest = 2.2 | ||
84 | applyVec f = f Akima xs ys xtest | ||
85 | applyList f = f Akima (zip xs' ys') xtest | ||
86 | |||
87 | esol = xtest**2 | ||
88 | ev = applyVec evaluateV | ||
89 | eval = applyList evaluate | ||
90 | |||
91 | desol = 2*xtest | ||
92 | dev = applyVec evaluateDerivativeV | ||
93 | de = applyList evaluateDerivative | ||
94 | |||
95 | d2esol = 2 | ||
96 | d2ev = applyVec evaluateDerivative2V | ||
97 | d2e = applyList evaluateDerivative2 | ||
98 | |||
99 | intesol = 1/3 * xtest**3 | ||
100 | intev = evaluateIntegralV Akima xs ys 0 xtest | ||
101 | inte = evaluateIntegral Akima (zip xs' ys') (0, xtest) | ||
102 | |||
103 | xs' = [-1..10] | ||
104 | ys' = map (**2) xs' | ||
105 | xs = vector xs' | ||
106 | ys = vector ys' | ||
107 | |||
108 | --------------------------------------------------------------------- | ||
109 | |||
110 | simanTest = TestList [ | ||
111 | -- We use a slightly more relaxed tolerance here because the | ||
112 | -- simulated annealer is randomized | ||
113 | utest "simulated annealing manual example" $ abs (result - 1.3631300) < 1e-6 | ||
114 | ] | ||
115 | where | ||
116 | -- This is the example from the GSL manual. | ||
117 | result = simanSolve 0 1 exampleParams 15.5 exampleE exampleM exampleS Nothing | ||
118 | exampleParams = SimulatedAnnealingParams 200 10000 1.0 1.0 0.008 1.003 2.0e-6 | ||
119 | exampleE x = exp (-(x - 1)**2) * sin (8 * x) | ||
120 | exampleM x y = abs $ x - y | ||
121 | exampleS rands stepSize current = (rands ! 0) * 2 * stepSize - stepSize + current | ||
122 | |||
69 | --------------------------------------------------------------------- | 123 | --------------------------------------------------------------------- |
70 | 124 | ||
71 | minimizationTest = TestList | 125 | minimizationTest = TestList |
@@ -123,6 +177,8 @@ runTests n = do | |||
123 | , odeTest | 177 | , odeTest |
124 | , rootFindingTest | 178 | , rootFindingTest |
125 | , minimizationTest | 179 | , minimizationTest |
180 | , interpolationTest | ||
181 | , simanTest | ||
126 | , utest "deriv" derivTest | 182 | , utest "deriv" derivTest |
127 | , utest "integrate" (abs (volSphere 2.5 - 4/3*pi*2.5**3) < 1E-8) | 183 | , utest "integrate" (abs (volSphere 2.5 - 4/3*pi*2.5**3) < 1E-8) |
128 | , utest "polySolve" (polySolveProp [1,2,3,4]) | 184 | , utest "polySolve" (polySolveProp [1,2,3,4]) |
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs index 713af79..4b631cf 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs | |||
@@ -4,6 +4,8 @@ | |||
4 | {-# LANGUAGE TypeFamilies #-} | 4 | {-# LANGUAGE TypeFamilies #-} |
5 | {-# LANGUAGE FlexibleContexts #-} | 5 | {-# LANGUAGE FlexibleContexts #-} |
6 | {-# LANGUAGE RankNTypes #-} | 6 | {-# LANGUAGE RankNTypes #-} |
7 | {-# LANGUAGE TypeOperators #-} | ||
8 | {-# LANGUAGE ViewPatterns #-} | ||
7 | 9 | ||
8 | ----------------------------------------------------------------------------- | 10 | ----------------------------------------------------------------------------- |
9 | {- | | 11 | {- | |
@@ -28,12 +30,9 @@ module Numeric.LinearAlgebra.Tests( | |||
28 | --, runBigTests | 30 | --, runBigTests |
29 | ) where | 31 | ) where |
30 | 32 | ||
31 | import Numeric.LinearAlgebra | 33 | import Numeric.LinearAlgebra hiding (unitary) |
32 | import Numeric.LinearAlgebra.HMatrix hiding ((<>),linearSolve) | 34 | import Numeric.LinearAlgebra.Devel |
33 | import Numeric.LinearAlgebra.Static(L) | 35 | import Numeric.LinearAlgebra.Static(L) |
34 | import Numeric.LinearAlgebra.Util(col,row) | ||
35 | import Data.Packed | ||
36 | import Numeric.LinearAlgebra.LAPACK | ||
37 | import Numeric.LinearAlgebra.Tests.Instances | 36 | import Numeric.LinearAlgebra.Tests.Instances |
38 | import Numeric.LinearAlgebra.Tests.Properties | 37 | import Numeric.LinearAlgebra.Tests.Properties |
39 | import Test.HUnit hiding ((~:),test,Testable,State) | 38 | import Test.HUnit hiding ((~:),test,Testable,State) |
@@ -44,15 +43,13 @@ import qualified Prelude | |||
44 | import System.CPUTime | 43 | import System.CPUTime |
45 | import System.Exit | 44 | import System.Exit |
46 | import Text.Printf | 45 | import Text.Printf |
47 | import Data.Packed.Development(unsafeFromForeignPtr,unsafeToForeignPtr) | 46 | import Numeric.LinearAlgebra.Devel(unsafeFromForeignPtr,unsafeToForeignPtr) |
48 | import Control.Arrow((***)) | 47 | import Control.Arrow((***)) |
49 | import Debug.Trace | 48 | import Debug.Trace |
50 | import Control.Monad(when) | 49 | import Control.Monad(when) |
51 | import Numeric.LinearAlgebra.Util hiding (ones,row,col) | ||
52 | import Control.Applicative | 50 | import Control.Applicative |
53 | import Control.Monad(ap) | 51 | import Control.Monad(ap) |
54 | 52 | import Control.DeepSeq ( NFData(..) ) | |
55 | import Data.Packed.ST | ||
56 | 53 | ||
57 | import Test.QuickCheck(Arbitrary,arbitrary,coarbitrary,choose,vector | 54 | import Test.QuickCheck(Arbitrary,arbitrary,coarbitrary,choose,vector |
58 | ,sized,classify,Testable,Property | 55 | ,sized,classify,Testable,Property |
@@ -81,7 +78,7 @@ detTest1 = det m == 26 | |||
81 | && det mc == 38 :+ (-3) | 78 | && det mc == 38 :+ (-3) |
82 | && det (feye 2) == -1 | 79 | && det (feye 2) == -1 |
83 | where | 80 | where |
84 | m = (3><3) | 81 | m = (3><3) |
85 | [ 1, 2, 3 | 82 | [ 1, 2, 3 |
86 | , 4, 5, 7 | 83 | , 4, 5, 7 |
87 | , 2, 8, 4 :: Double | 84 | , 2, 8, 4 :: Double |
@@ -89,7 +86,7 @@ detTest1 = det m == 26 | |||
89 | mc = (3><3) | 86 | mc = (3><3) |
90 | [ 1, 2, 3 | 87 | [ 1, 2, 3 |
91 | , 4, 5, 7 | 88 | , 4, 5, 7 |
92 | , 2, 8, i | 89 | , 2, 8, iC |
93 | ] | 90 | ] |
94 | 91 | ||
95 | detTest2 = inv1 |~| inv2 && [det1] ~~ [det2] | 92 | detTest2 = inv1 |~| inv2 && [det1] ~~ [det2] |
@@ -130,8 +127,8 @@ expmTest2 = expm nd2 :~15~: (2><2) | |||
130 | mbCholTest = utest "mbCholTest" (ok1 && ok2) where | 127 | mbCholTest = utest "mbCholTest" (ok1 && ok2) where |
131 | m1 = (2><2) [2,5,5,8 :: Double] | 128 | m1 = (2><2) [2,5,5,8 :: Double] |
132 | m2 = (2><2) [3,5,5,9 :: Complex Double] | 129 | m2 = (2><2) [3,5,5,9 :: Complex Double] |
133 | ok1 = mbCholSH m1 == Nothing | 130 | ok1 = mbChol (trustSym m1) == Nothing |
134 | ok2 = mbCholSH m2 == Just (chol m2) | 131 | ok2 = mbChol (trustSym m2) == Just (chol $ trustSym m2) |
135 | 132 | ||
136 | --------------------------------------------------------------------- | 133 | --------------------------------------------------------------------- |
137 | 134 | ||
@@ -140,7 +137,7 @@ randomTestGaussian = c :~1~: snd (meanCov dat) where | |||
140 | 2,4,0, | 137 | 2,4,0, |
141 | -2,2,1] | 138 | -2,2,1] |
142 | m = 3 |> [1,2,3] | 139 | m = 3 |> [1,2,3] |
143 | c = a <> trans a | 140 | c = a <> tr a |
144 | dat = gaussianSample 7 (10^6) m c | 141 | dat = gaussianSample 7 (10^6) m c |
145 | 142 | ||
146 | randomTestUniform = c :~1~: snd (meanCov dat) where | 143 | randomTestUniform = c :~1~: snd (meanCov dat) where |
@@ -174,54 +171,54 @@ offsetTest = y == y' where | |||
174 | 171 | ||
175 | normsVTest = TestList [ | 172 | normsVTest = TestList [ |
176 | utest "normv2CD" $ norm2PropC v | 173 | utest "normv2CD" $ norm2PropC v |
177 | , utest "normv2CF" $ norm2PropC (single v) | 174 | -- , utest "normv2CF" $ norm2PropC (single v) |
178 | #ifndef NONORMVTEST | 175 | #ifndef NONORMVTEST |
179 | , utest "normv2D" $ norm2PropR x | 176 | , utest "normv2D" $ norm2PropR x |
180 | , utest "normv2F" $ norm2PropR (single x) | 177 | -- , utest "normv2F" $ norm2PropR (single x) |
181 | #endif | 178 | #endif |
182 | , utest "normv1CD" $ norm1 v == 8 | 179 | , utest "normv1CD" $ norm_1 v == 8 |
183 | , utest "normv1CF" $ norm1 (single v) == 8 | 180 | -- , utest "normv1CF" $ norm_1 (single v) == 8 |
184 | , utest "normv1D" $ norm1 x == 6 | 181 | , utest "normv1D" $ norm_1 x == 6 |
185 | , utest "normv1F" $ norm1 (single x) == 6 | 182 | -- , utest "normv1F" $ norm_1 (single x) == 6 |
186 | 183 | ||
187 | , utest "normvInfCD" $ normInf v == 5 | 184 | , utest "normvInfCD" $ norm_Inf v == 5 |
188 | , utest "normvInfCF" $ normInf (single v) == 5 | 185 | -- , utest "normvInfCF" $ norm_Inf (single v) == 5 |
189 | , utest "normvInfD" $ normInf x == 3 | 186 | , utest "normvInfD" $ norm_Inf x == 3 |
190 | , utest "normvInfF" $ normInf (single x) == 3 | 187 | -- , utest "normvInfF" $ norm_Inf (single x) == 3 |
191 | 188 | ||
192 | ] where v = fromList [1,-2,3:+4] :: Vector (Complex Double) | 189 | ] where v = fromList [1,-2,3:+4] :: Vector (Complex Double) |
193 | x = fromList [1,2,-3] :: Vector Double | 190 | x = fromList [1,2,-3] :: Vector Double |
194 | #ifndef NONORMVTEST | 191 | #ifndef NONORMVTEST |
195 | norm2PropR a = norm2 a =~= sqrt (udot a a) | 192 | norm2PropR a = norm_2 a =~= sqrt (udot a a) |
196 | #endif | 193 | #endif |
197 | norm2PropC a = norm2 a =~= realPart (sqrt (a <.> a)) | 194 | norm2PropC a = norm_2 a =~= realPart (sqrt (a `dot` a)) |
198 | a =~= b = fromList [a] |~| fromList [b] | 195 | a =~= b = fromList [a] |~| fromList [b] |
199 | 196 | ||
200 | normsMTest = TestList [ | 197 | normsMTest = TestList [ |
201 | utest "norm2mCD" $ pnorm PNorm2 v =~= 8.86164970498005 | 198 | utest "norm2mCD" $ norm_2 v =~= 8.86164970498005 |
202 | , utest "norm2mCF" $ pnorm PNorm2 (single v) =~= 8.86164970498005 | 199 | -- , utest "norm2mCF" $ norm_2 (single v) =~= 8.86164970498005 |
203 | , utest "norm2mD" $ pnorm PNorm2 x =~= 5.96667765076216 | 200 | , utest "norm2mD" $ norm_2 x =~= 5.96667765076216 |
204 | , utest "norm2mF" $ pnorm PNorm2 (single x) =~= 5.96667765076216 | 201 | -- , utest "norm2mF" $ norm_2 (single x) =~= 5.96667765076216 |
205 | 202 | ||
206 | , utest "norm1mCD" $ pnorm PNorm1 v == 9 | 203 | , utest "norm1mCD" $ norm_1 v == 9 |
207 | , utest "norm1mCF" $ pnorm PNorm1 (single v) == 9 | 204 | -- , utest "norm1mCF" $ norm_1 (single v) == 9 |
208 | , utest "norm1mD" $ pnorm PNorm1 x == 7 | 205 | , utest "norm1mD" $ norm_1 x == 7 |
209 | , utest "norm1mF" $ pnorm PNorm1 (single x) == 7 | 206 | -- , utest "norm1mF" $ norm_1 (single x) == 7 |
210 | 207 | ||
211 | , utest "normmInfCD" $ pnorm Infinity v == 12 | 208 | , utest "normmInfCD" $ norm_Inf v == 12 |
212 | , utest "normmInfCF" $ pnorm Infinity (single v) == 12 | 209 | -- , utest "normmInfCF" $ norm_Inf (single v) == 12 |
213 | , utest "normmInfD" $ pnorm Infinity x == 8 | 210 | , utest "normmInfD" $ norm_Inf x == 8 |
214 | , utest "normmInfF" $ pnorm Infinity (single x) == 8 | 211 | -- , utest "normmInfF" $ norm_Inf (single x) == 8 |
215 | 212 | ||
216 | , utest "normmFroCD" $ pnorm Frobenius v =~= 8.88819441731559 | 213 | , utest "normmFroCD" $ norm_Frob v =~= 8.88819441731559 |
217 | , utest "normmFroCF" $ pnorm Frobenius (single v) =~~= 8.88819441731559 | 214 | -- , utest "normmFroCF" $ norm_Frob (single v) =~~= 8.88819441731559 |
218 | , utest "normmFroD" $ pnorm Frobenius x =~= 6.24499799839840 | 215 | , utest "normmFroD" $ norm_Frob x =~= 6.24499799839840 |
219 | , utest "normmFroF" $ pnorm Frobenius (single x) =~~= 6.24499799839840 | 216 | -- , utest "normmFroF" $ norm_Frob (single x) =~~= 6.24499799839840 |
220 | 217 | ||
221 | ] where v = (2><2) [1,-2*i,3:+4,7] :: Matrix (Complex Double) | 218 | ] where v = (2><2) [1,-2*iC,3:+4,7] :: Matrix (Complex Double) |
222 | x = (2><2) [1,2,-3,5] :: Matrix Double | 219 | x = (2><2) [1,2,-3,5] :: Matrix Double |
223 | a =~= b = fromList [a] :~10~: fromList [b] | 220 | a =~= b = fromList [a] :~10~: fromList [b] |
224 | a =~~= b = fromList [a] :~5~: fromList [b] | 221 | -- a =~~= b = fromList [a] :~5~: fromList [b] |
225 | 222 | ||
226 | --------------------------------------------------------------------- | 223 | --------------------------------------------------------------------- |
227 | 224 | ||
@@ -236,7 +233,7 @@ sumprodTest = TestList [ | |||
236 | , utest "prodD" $ prodProp v | 233 | , utest "prodD" $ prodProp v |
237 | , utest "prodF" $ prodProp (single v) | 234 | , utest "prodF" $ prodProp (single v) |
238 | ] where v = fromList [1,2,3] :: Vector Double | 235 | ] where v = fromList [1,2,3] :: Vector Double |
239 | z = fromList [1,2-i,3+i] | 236 | z = fromList [1,2-iC,3+iC] |
240 | prodProp x = prodElements x == product (toList x) | 237 | prodProp x = prodElements x == product (toList x) |
241 | 238 | ||
242 | --------------------------------------------------------------------- | 239 | --------------------------------------------------------------------- |
@@ -250,7 +247,7 @@ chainTest = utest "chain" $ foldl1' (<>) ms |~| optimiseMult ms where | |||
250 | 247 | ||
251 | --------------------------------------------------------------------- | 248 | --------------------------------------------------------------------- |
252 | 249 | ||
253 | conjuTest m = mapVector conjugate (flatten (trans m)) == flatten (ctrans m) | 250 | conjuTest m = cmap conjugate (flatten (conj (tr m))) == flatten (tr m) |
254 | 251 | ||
255 | --------------------------------------------------------------------- | 252 | --------------------------------------------------------------------- |
256 | 253 | ||
@@ -306,7 +303,7 @@ lift_maybe m = MaybeT $ do | |||
306 | 303 | ||
307 | -- apply a test to successive elements of a vector, evaluates to true iff test passes for all pairs | 304 | -- apply a test to successive elements of a vector, evaluates to true iff test passes for all pairs |
308 | --successive_ :: Storable a => (a -> a -> Bool) -> Vector a -> Bool | 305 | --successive_ :: Storable a => (a -> a -> Bool) -> Vector a -> Bool |
309 | successive_ t v = maybe False (\_ -> True) $ evalState (runMaybeT (mapVectorM_ stp (subVector 1 (dim v - 1) v))) (v @> 0) | 306 | successive_ t v = maybe False (\_ -> True) $ evalState (runMaybeT (mapVectorM_ stp (subVector 1 (size v - 1) v))) (v ! 0) |
310 | where stp e = do | 307 | where stp e = do |
311 | ep <- lift_maybe $ state_get | 308 | ep <- lift_maybe $ state_get |
312 | if t e ep | 309 | if t e ep |
@@ -315,7 +312,7 @@ successive_ t v = maybe False (\_ -> True) $ evalState (runMaybeT (mapVectorM_ s | |||
315 | 312 | ||
316 | -- operate on successive elements of a vector and return the resulting vector, whose length 1 less than that of the input | 313 | -- operate on successive elements of a vector and return the resulting vector, whose length 1 less than that of the input |
317 | --successive :: (Storable a, Storable b) => (a -> a -> b) -> Vector a -> Vector b | 314 | --successive :: (Storable a, Storable b) => (a -> a -> b) -> Vector a -> Vector b |
318 | successive f v = evalState (mapVectorM stp (subVector 1 (dim v - 1) v)) (v @> 0) | 315 | successive f v = evalState (mapVectorM stp (subVector 1 (size v - 1) v)) (v ! 0) |
319 | where stp e = do | 316 | where stp e = do |
320 | ep <- state_get | 317 | ep <- state_get |
321 | state_put e | 318 | state_put e |
@@ -362,7 +359,7 @@ accumTest = utest "accum" ok | |||
362 | ,0,1,7 | 359 | ,0,1,7 |
363 | ,0,0,4] | 360 | ,0,0,4] |
364 | && | 361 | && |
365 | toList (flatten x) == [1,0,0,0,1,0,0,0,1] | 362 | toList (flatten x) == [1,0,0,0,1,0,0,0,1] |
366 | 363 | ||
367 | -------------------------------------------------------------------------------- | 364 | -------------------------------------------------------------------------------- |
368 | 365 | ||
@@ -377,28 +374,19 @@ convolutionTest = utest "convolution" ok | |||
377 | 374 | ||
378 | -------------------------------------------------------------------------------- | 375 | -------------------------------------------------------------------------------- |
379 | 376 | ||
380 | kroneckerTest = utest "kronecker" ok | 377 | sparseTest = utest "sparse" (fst $ checkT (undefined :: GMatrix)) |
381 | where | ||
382 | a,x,b :: Matrix Double | ||
383 | a = (3><4) [1..] | ||
384 | x = (4><2) [3,5..] | ||
385 | b = (2><5) [0,5..] | ||
386 | v1 = vec (a <> x <> b) | ||
387 | v2 = (trans b `kronecker` a) <> vec x | ||
388 | s = trans b <> b | ||
389 | v3 = vec s | ||
390 | v4 = (dup 5 :: Matrix Double) <> vech s | ||
391 | ok = v1 == v2 && v3 == v4 | ||
392 | && vtrans 1 a == trans a | ||
393 | && vtrans (rows a) a == asColumn (vec a) | ||
394 | 378 | ||
395 | -------------------------------------------------------------------------------- | 379 | -------------------------------------------------------------------------------- |
396 | 380 | ||
397 | sparseTest = utest "sparse" (fst $ checkT (undefined :: GMatrix)) | 381 | staticTest = utest "static" (fst $ checkT (undefined :: L 3 5)) |
398 | 382 | ||
399 | -------------------------------------------------------------------------------- | 383 | -------------------------------------------------------------------------------- |
400 | 384 | ||
401 | staticTest = utest "static" (fst $ checkT (undefined :: L 3 5)) | 385 | intTest = utest "int ops" (fst $ checkT (undefined :: Matrix I)) |
386 | |||
387 | -------------------------------------------------------------------------------- | ||
388 | |||
389 | modularTest = utest "modular ops" (fst $ checkT (undefined :: Matrix (Mod 13 I))) | ||
402 | 390 | ||
403 | -------------------------------------------------------------------------------- | 391 | -------------------------------------------------------------------------------- |
404 | 392 | ||
@@ -414,6 +402,150 @@ indexProp g f x = a1 == g a2 && a2 == a3 && b1 == g b2 && b2 == b3 | |||
414 | 402 | ||
415 | -------------------------------------------------------------------------------- | 403 | -------------------------------------------------------------------------------- |
416 | 404 | ||
405 | sliceTest = utest "slice test" $ and | ||
406 | [ testSlice (chol . trustSym) (gen 5 :: Matrix R) | ||
407 | , testSlice (chol . trustSym) (gen 5 :: Matrix C) | ||
408 | , testSlice qr (rec :: Matrix R) | ||
409 | , testSlice qr (rec :: Matrix C) | ||
410 | , testSlice hess (agen 5 :: Matrix R) | ||
411 | , testSlice hess (agen 5 :: Matrix C) | ||
412 | , testSlice schur (agen 5 :: Matrix R) | ||
413 | , testSlice schur (agen 5 :: Matrix C) | ||
414 | , testSlice lu (agen 5 :: Matrix R) | ||
415 | , testSlice lu (agen 5 :: Matrix C) | ||
416 | , testSlice (luSolve (luPacked (agen 5 :: Matrix R))) (agen 5) | ||
417 | , testSlice (luSolve (luPacked (agen 5 :: Matrix C))) (agen 5) | ||
418 | , test_lus (agen 5 :: Matrix R) | ||
419 | , test_lus (agen 5 :: Matrix C) | ||
420 | |||
421 | , testSlice eig (agen 5 :: Matrix R) | ||
422 | , testSlice eig (agen 5 :: Matrix C) | ||
423 | , testSlice (eigSH . trustSym) (gen 5 :: Matrix R) | ||
424 | , testSlice (eigSH . trustSym) (gen 5 :: Matrix C) | ||
425 | , testSlice eigenvalues (agen 5 :: Matrix R) | ||
426 | , testSlice eigenvalues (agen 5 :: Matrix C) | ||
427 | , testSlice (eigenvaluesSH . trustSym) (gen 5 :: Matrix R) | ||
428 | , testSlice (eigenvaluesSH . trustSym) (gen 5 :: Matrix C) | ||
429 | |||
430 | , testSlice svd (rec :: Matrix R) | ||
431 | , testSlice thinSVD (rec :: Matrix R) | ||
432 | , testSlice compactSVD (rec :: Matrix R) | ||
433 | , testSlice leftSV (rec :: Matrix R) | ||
434 | , testSlice rightSV (rec :: Matrix R) | ||
435 | , testSlice singularValues (rec :: Matrix R) | ||
436 | |||
437 | , testSlice svd (rec :: Matrix C) | ||
438 | , testSlice thinSVD (rec :: Matrix C) | ||
439 | , testSlice compactSVD (rec :: Matrix C) | ||
440 | , testSlice leftSV (rec :: Matrix C) | ||
441 | , testSlice rightSV (rec :: Matrix C) | ||
442 | , testSlice singularValues (rec :: Matrix C) | ||
443 | |||
444 | , testSlice (linearSolve (agen 5:: Matrix R)) (agen 5) | ||
445 | , testSlice (flip linearSolve (agen 5:: Matrix R)) (agen 5) | ||
446 | |||
447 | , testSlice (linearSolve (agen 5:: Matrix C)) (agen 5) | ||
448 | , testSlice (flip linearSolve (agen 5:: Matrix C)) (agen 5) | ||
449 | |||
450 | , testSlice (linearSolveLS (ogen 5:: Matrix R)) (ogen 5) | ||
451 | , testSlice (flip linearSolveLS (ogen 5:: Matrix R)) (ogen 5) | ||
452 | |||
453 | , testSlice (linearSolveLS (ogen 5:: Matrix C)) (ogen 5) | ||
454 | , testSlice (flip linearSolveLS (ogen 5:: Matrix C)) (ogen 5) | ||
455 | |||
456 | , testSlice (linearSolveSVD (ogen 5:: Matrix R)) (ogen 5) | ||
457 | , testSlice (flip linearSolveSVD (ogen 5:: Matrix R)) (ogen 5) | ||
458 | |||
459 | , testSlice (linearSolveSVD (ogen 5:: Matrix C)) (ogen 5) | ||
460 | , testSlice (flip linearSolveSVD (ogen 5:: Matrix C)) (ogen 5) | ||
461 | |||
462 | , testSlice (linearSolveLS (ugen 5:: Matrix R)) (ugen 5) | ||
463 | , testSlice (flip linearSolveLS (ugen 5:: Matrix R)) (ugen 5) | ||
464 | |||
465 | , testSlice (linearSolveLS (ugen 5:: Matrix C)) (ugen 5) | ||
466 | , testSlice (flip linearSolveLS (ugen 5:: Matrix C)) (ugen 5) | ||
467 | |||
468 | , testSlice (linearSolveSVD (ugen 5:: Matrix R)) (ugen 5) | ||
469 | , testSlice (flip linearSolveSVD (ugen 5:: Matrix R)) (ugen 5) | ||
470 | |||
471 | , testSlice (linearSolveSVD (ugen 5:: Matrix C)) (ugen 5) | ||
472 | , testSlice (flip linearSolveSVD (ugen 5:: Matrix C)) (ugen 5) | ||
473 | |||
474 | , testSlice ((<>) (ogen 5:: Matrix R)) (gen 5) | ||
475 | , testSlice (flip (<>) (gen 5:: Matrix R)) (ogen 5) | ||
476 | , testSlice ((<>) (ogen 5:: Matrix C)) (gen 5) | ||
477 | , testSlice (flip (<>) (gen 5:: Matrix C)) (ogen 5) | ||
478 | , testSlice ((<>) (ogen 5:: Matrix Float)) (gen 5) | ||
479 | , testSlice (flip (<>) (gen 5:: Matrix Float)) (ogen 5) | ||
480 | , testSlice ((<>) (ogen 5:: Matrix (Complex Float))) (gen 5) | ||
481 | , testSlice (flip (<>) (gen 5:: Matrix (Complex Float))) (ogen 5) | ||
482 | , testSlice ((<>) (ogen 5:: Matrix I)) (gen 5) | ||
483 | , testSlice (flip (<>) (gen 5:: Matrix I)) (ogen 5) | ||
484 | , testSlice ((<>) (ogen 5:: Matrix Z)) (gen 5) | ||
485 | , testSlice (flip (<>) (gen 5:: Matrix Z)) (ogen 5) | ||
486 | |||
487 | , testSlice ((<>) (ogen 5:: Matrix (I ./. 7))) (gen 5) | ||
488 | , testSlice (flip (<>) (gen 5:: Matrix (I ./. 7))) (ogen 5) | ||
489 | , testSlice ((<>) (ogen 5:: Matrix (Z ./. 7))) (gen 5) | ||
490 | , testSlice (flip (<>) (gen 5:: Matrix (Z ./. 7))) (ogen 5) | ||
491 | |||
492 | , testSlice (flip cholSolve (agen 5:: Matrix R)) (chol $ trustSym $ gen 5) | ||
493 | , testSlice (flip cholSolve (agen 5:: Matrix C)) (chol $ trustSym $ gen 5) | ||
494 | , testSlice (cholSolve (chol $ trustSym $ gen 5:: Matrix R)) (agen 5) | ||
495 | , testSlice (cholSolve (chol $ trustSym $ gen 5:: Matrix C)) (agen 5) | ||
496 | |||
497 | , ok_qrgr (rec :: Matrix R) | ||
498 | , ok_qrgr (rec :: Matrix C) | ||
499 | , testSlice (test_qrgr 4 tau1) qrr1 | ||
500 | , testSlice (test_qrgr 4 tau2) qrr2 | ||
501 | ] | ||
502 | where | ||
503 | QR qrr1 tau1 = qrRaw (rec :: Matrix R) | ||
504 | QR qrr2 tau2 = qrRaw (rec :: Matrix C) | ||
505 | |||
506 | test_qrgr n t x = qrgr n (QR x t) | ||
507 | |||
508 | ok_qrgr x = simeq 1E-15 q q' | ||
509 | where | ||
510 | (q,_) = qr x | ||
511 | atau = qrRaw x | ||
512 | q' = qrgr (rows q) atau | ||
513 | |||
514 | simeq eps a b = not $ magnit eps (norm_1 $ flatten (a-b)) | ||
515 | |||
516 | test_lus m = testSlice f lup | ||
517 | where | ||
518 | f x = luSolve (LU x p) m | ||
519 | (LU lup p) = luPacked m | ||
520 | |||
521 | gen :: Numeric t => Int -> Matrix t | ||
522 | gen n = diagRect 1 (konst 5 n) n n | ||
523 | |||
524 | agen :: (Numeric t, Num (Vector t))=> Int -> Matrix t | ||
525 | agen n = gen n + fromInt ((n><n)[0..]) | ||
526 | |||
527 | ogen :: (Numeric t, Num (Vector t))=> Int -> Matrix t | ||
528 | ogen n = gen n === gen n | ||
529 | |||
530 | ugen :: (Numeric t, Num (Vector t))=> Int -> Matrix t | ||
531 | ugen n = takeRows 3 (gen n) | ||
532 | |||
533 | |||
534 | rec :: Numeric t => Matrix t | ||
535 | rec = subMatrix (0,0) (4,5) (gen 5) | ||
536 | |||
537 | testSlice f x@(size->sz@(r,c)) = all (==f x) (map f (g y1 ++ g y2)) | ||
538 | where | ||
539 | subm = subMatrix | ||
540 | g y = [ subm (a*r,b*c) sz y | a <-[0..2], b <- [0..2]] | ||
541 | h z = fromBlocks (replicate 3 (replicate 3 z)) | ||
542 | y1 = h x | ||
543 | y2 = (tr . h . tr) x | ||
544 | |||
545 | |||
546 | |||
547 | -------------------------------------------------------------------------------- | ||
548 | |||
417 | -- | All tests must pass with a maximum dimension of about 20 | 549 | -- | All tests must pass with a maximum dimension of about 20 |
418 | -- (some tests may fail with bigger sizes due to precision loss). | 550 | -- (some tests may fail with bigger sizes due to precision loss). |
419 | runTests :: Int -- ^ maximum dimension | 551 | runTests :: Int -- ^ maximum dimension |
@@ -435,11 +567,11 @@ runTests n = do | |||
435 | test (multProp1 10 . cConsist) | 567 | test (multProp1 10 . cConsist) |
436 | test (multProp2 10 . rConsist) | 568 | test (multProp2 10 . rConsist) |
437 | test (multProp2 10 . cConsist) | 569 | test (multProp2 10 . cConsist) |
438 | putStrLn "------ mult Float" | 570 | -- putStrLn "------ mult Float" |
439 | test (multProp1 6 . (single *** single) . rConsist) | 571 | -- test (multProp1 6 . (single *** single) . rConsist) |
440 | test (multProp1 6 . (single *** single) . cConsist) | 572 | -- test (multProp1 6 . (single *** single) . cConsist) |
441 | test (multProp2 6 . (single *** single) . rConsist) | 573 | -- test (multProp2 6 . (single *** single) . rConsist) |
442 | test (multProp2 6 . (single *** single) . cConsist) | 574 | -- test (multProp2 6 . (single *** single) . cConsist) |
443 | putStrLn "------ sub-trans" | 575 | putStrLn "------ sub-trans" |
444 | test (subProp . rM) | 576 | test (subProp . rM) |
445 | test (subProp . cM) | 577 | test (subProp . cM) |
@@ -455,9 +587,12 @@ runTests n = do | |||
455 | putStrLn "------ luSolve" | 587 | putStrLn "------ luSolve" |
456 | test (linearSolveProp (luSolve.luPacked) . rSqWC) | 588 | test (linearSolveProp (luSolve.luPacked) . rSqWC) |
457 | test (linearSolveProp (luSolve.luPacked) . cSqWC) | 589 | test (linearSolveProp (luSolve.luPacked) . cSqWC) |
590 | putStrLn "------ ldlSolve" | ||
591 | test (linearSolvePropH (ldlSolve.ldlPacked) . rSymWC) | ||
592 | test (linearSolvePropH (ldlSolve.ldlPacked) . cSymWC) | ||
458 | putStrLn "------ cholSolve" | 593 | putStrLn "------ cholSolve" |
459 | test (linearSolveProp (cholSolve.chol) . rPosDef) | 594 | test (linearSolveProp (cholSolve.chol.trustSym) . rPosDef) |
460 | test (linearSolveProp (cholSolve.chol) . cPosDef) | 595 | test (linearSolveProp (cholSolve.chol.trustSym) . cPosDef) |
461 | putStrLn "------ luSolveLS" | 596 | putStrLn "------ luSolveLS" |
462 | test (linearSolveProp linearSolveLS . rSqWC) | 597 | test (linearSolveProp linearSolveLS . rSqWC) |
463 | test (linearSolveProp linearSolveLS . cSqWC) | 598 | test (linearSolveProp linearSolveLS . cSqWC) |
@@ -472,16 +607,16 @@ runTests n = do | |||
472 | putStrLn "------ svd" | 607 | putStrLn "------ svd" |
473 | test (svdProp1 . rM) | 608 | test (svdProp1 . rM) |
474 | test (svdProp1 . cM) | 609 | test (svdProp1 . cM) |
475 | test (svdProp1a svdR) | 610 | test (svdProp1a svd . rM) |
476 | test (svdProp1a svdC) | 611 | test (svdProp1a svd . cM) |
477 | test (svdProp1a svdRd) | 612 | -- test (svdProp1a svdRd) |
478 | test (svdProp1b svdR) | 613 | test (svdProp1b svd . rM) |
479 | test (svdProp1b svdC) | 614 | test (svdProp1b svd . cM) |
480 | test (svdProp1b svdRd) | 615 | -- test (svdProp1b svdRd) |
481 | test (svdProp2 thinSVDR) | 616 | test (svdProp2 thinSVD . rM) |
482 | test (svdProp2 thinSVDC) | 617 | test (svdProp2 thinSVD . cM) |
483 | test (svdProp2 thinSVDRd) | 618 | -- test (svdProp2 thinSVDRd) |
484 | test (svdProp2 thinSVDCd) | 619 | -- test (svdProp2 thinSVDCd) |
485 | test (svdProp3 . rM) | 620 | test (svdProp3 . rM) |
486 | test (svdProp3 . cM) | 621 | test (svdProp3 . cM) |
487 | test (svdProp4 . rM) | 622 | test (svdProp4 . rM) |
@@ -492,12 +627,12 @@ runTests n = do | |||
492 | test (svdProp6b) | 627 | test (svdProp6b) |
493 | test (svdProp7 . rM) | 628 | test (svdProp7 . rM) |
494 | test (svdProp7 . cM) | 629 | test (svdProp7 . cM) |
495 | putStrLn "------ svdCd" | 630 | -- putStrLn "------ svdCd" |
496 | #ifdef NOZGESDD | 631 | #ifdef NOZGESDD |
497 | putStrLn "Omitted" | 632 | -- putStrLn "Omitted" |
498 | #else | 633 | #else |
499 | test (svdProp1a svdCd) | 634 | -- test (svdProp1a svdCd) |
500 | test (svdProp1b svdCd) | 635 | -- test (svdProp1b svdCd) |
501 | #endif | 636 | #endif |
502 | putStrLn "------ eig" | 637 | putStrLn "------ eig" |
503 | test (eigSHProp . rHer) | 638 | test (eigSHProp . rHer) |
@@ -515,10 +650,10 @@ runTests n = do | |||
515 | test (qrProp . rM) | 650 | test (qrProp . rM) |
516 | test (qrProp . cM) | 651 | test (qrProp . cM) |
517 | test (rqProp . rM) | 652 | test (rqProp . rM) |
518 | test (rqProp . cM) | 653 | -- test (rqProp . cM) |
519 | test (rqProp1 . cM) | 654 | test (rqProp1 . cM) |
520 | test (rqProp2 . cM) | 655 | test (rqProp2 . cM) |
521 | test (rqProp3 . cM) | 656 | -- test (rqProp3 . cM) |
522 | putStrLn "------ hess" | 657 | putStrLn "------ hess" |
523 | test (hessProp . rSq) | 658 | test (hessProp . rSq) |
524 | test (hessProp . cSq) | 659 | test (hessProp . cSq) |
@@ -528,8 +663,8 @@ runTests n = do | |||
528 | putStrLn "------ chol" | 663 | putStrLn "------ chol" |
529 | test (cholProp . rPosDef) | 664 | test (cholProp . rPosDef) |
530 | test (cholProp . cPosDef) | 665 | test (cholProp . cPosDef) |
531 | test (exactProp . rPosDef) | 666 | -- test (exactProp . rPosDef) |
532 | test (exactProp . cPosDef) | 667 | -- test (exactProp . cPosDef) |
533 | putStrLn "------ expm" | 668 | putStrLn "------ expm" |
534 | test (expmDiagProp . complex. rSqWC) | 669 | test (expmDiagProp . complex. rSqWC) |
535 | test (expmDiagProp . cSqWC) | 670 | test (expmDiagProp . cSqWC) |
@@ -539,12 +674,12 @@ runTests n = do | |||
539 | test (\u -> sin u ** 2 + cos u ** 2 |~| (1::RM)) | 674 | test (\u -> sin u ** 2 + cos u ** 2 |~| (1::RM)) |
540 | test (\u -> cos u * tan u |~| sin (u::RM)) | 675 | test (\u -> cos u * tan u |~| sin (u::RM)) |
541 | test $ (\u -> cos u * tan u |~| sin (u::CM)) . liftMatrix makeUnitary | 676 | test $ (\u -> cos u * tan u |~| sin (u::CM)) . liftMatrix makeUnitary |
542 | putStrLn "------ vector operations - Float" | 677 | -- putStrLn "------ vector operations - Float" |
543 | test (\u -> sin u ^ 2 + cos u ^ 2 |~~| (1::FM)) | 678 | -- test (\u -> sin u ^ 2 + cos u ^ 2 |~~| (1::FM)) |
544 | test $ (\u -> sin u ^ 2 + cos u ^ 2 |~~| (1::ZM)) . liftMatrix makeUnitary | 679 | -- test $ (\u -> sin u ^ 2 + cos u ^ 2 |~~| (1::ZM)) . liftMatrix makeUnitary |
545 | test (\u -> sin u ** 2 + cos u ** 2 |~~| (1::FM)) | 680 | -- test (\u -> sin u ** 2 + cos u ** 2 |~~| (1::FM)) |
546 | test (\u -> cos u * tan u |~~| sin (u::FM)) | 681 | -- test (\u -> cos u * tan u |~~| sin (u::FM)) |
547 | test $ (\u -> cos u * tan u |~~| sin (u::ZM)) . liftMatrix makeUnitary | 682 | -- test $ (\u -> cos u * tan u |~~| sin (u::ZM)) . liftMatrix makeUnitary |
548 | putStrLn "------ read . show" | 683 | putStrLn "------ read . show" |
549 | test (\m -> (m::RM) == read (show m)) | 684 | test (\m -> (m::RM) == read (show m)) |
550 | test (\m -> (m::CM) == read (show m)) | 685 | test (\m -> (m::CM) == read (show m)) |
@@ -562,8 +697,8 @@ runTests n = do | |||
562 | , utest "expm1" (expmTest1) | 697 | , utest "expm1" (expmTest1) |
563 | , utest "expm2" (expmTest2) | 698 | , utest "expm2" (expmTest2) |
564 | , utest "arith1" $ ((ones (100,100) * 5 + 2)/0.5 - 7)**2 |~| (49 :: RM) | 699 | , utest "arith1" $ ((ones (100,100) * 5 + 2)/0.5 - 7)**2 |~| (49 :: RM) |
565 | , utest "arith2" $ ((scalar (1+i) * ones (100,100) * 5 + 2)/0.5 - 7)**2 |~| ( scalar (140*i-51) :: CM) | 700 | , utest "arith2" $ ((scalar (1+iC) * ones (100,100) * 5 + 2)/0.5 - 7)**2 |~| ( scalar (140*iC-51) :: CM) |
566 | , utest "arith3" $ exp (scalar i * ones(10,10)*pi) + 1 |~| 0 | 701 | , utest "arith3" $ exp (scalar iC * ones(10,10)*pi) + 1 |~| 0 |
567 | , utest "<\\>" $ (3><2) [2,0,0,3,1,1::Double] <\> 3|>[4,9,5] |~| 2|>[2,3] | 702 | , utest "<\\>" $ (3><2) [2,0,0,3,1,1::Double] <\> 3|>[4,9,5] |~| 2|>[2,3] |
568 | -- , utest "gamma" (gamma 5 == 24.0) | 703 | -- , utest "gamma" (gamma 5 == 24.0) |
569 | -- , besselTest | 704 | -- , besselTest |
@@ -571,10 +706,10 @@ runTests n = do | |||
571 | , utest "randomGaussian" randomTestGaussian | 706 | , utest "randomGaussian" randomTestGaussian |
572 | , utest "randomUniform" randomTestUniform | 707 | , utest "randomUniform" randomTestUniform |
573 | , utest "buildVector/Matrix" $ | 708 | , utest "buildVector/Matrix" $ |
574 | complex (10 |> [0::Double ..]) == buildVector 10 fromIntegral | 709 | complex (10 |> [0::Double ..]) == build 10 id |
575 | && ident 5 == buildMatrix 5 5 (\(r,c) -> if r==c then 1::Double else 0) | 710 | && ident 5 == build (5,5) (\r c -> if r==c then 1::Double else 0) |
576 | , utest "rank" $ rank ((2><3)[1,0,0,1,5*eps,0]) == 1 | 711 | , utest "rank" $ rank ((2><3)[1,0,0,1,5*peps,0::Double]) == 1 |
577 | && rank ((2><3)[1,0,0,1,7*eps,0]) == 2 | 712 | && rank ((2><3)[1,0,0,1,7*peps,0::Double]) == 2 |
578 | , utest "block" $ fromBlocks [[ident 3,0],[0,ident 4]] == (ident 7 :: CM) | 713 | , utest "block" $ fromBlocks [[ident 3,0],[0,ident 4]] == (ident 7 :: CM) |
579 | , mbCholTest | 714 | , mbCholTest |
580 | , utest "offset" offsetTest | 715 | , utest "offset" offsetTest |
@@ -588,21 +723,23 @@ runTests n = do | |||
588 | , conformTest | 723 | , conformTest |
589 | , accumTest | 724 | , accumTest |
590 | , convolutionTest | 725 | , convolutionTest |
591 | , kroneckerTest | ||
592 | , sparseTest | 726 | , sparseTest |
593 | , staticTest | 727 | , staticTest |
728 | , intTest | ||
729 | , modularTest | ||
730 | , sliceTest | ||
594 | ] | 731 | ] |
595 | when (errors c + failures c > 0) exitFailure | 732 | when (errors c + failures c > 0) exitFailure |
596 | return () | 733 | return () |
597 | 734 | ||
598 | 735 | ||
599 | -- single precision approximate equality | 736 | -- single precision approximate equality |
600 | infixl 4 |~~| | 737 | -- infixl 4 |~~| |
601 | a |~~| b = a :~6~: b | 738 | -- a |~~| b = a :~6~: b |
602 | 739 | ||
603 | makeUnitary v | realPart n > 1 = v / scalar n | 740 | makeUnitary v | realPart n > 1 = v / scalar n |
604 | | otherwise = v | 741 | | otherwise = v |
605 | where n = sqrt (v <.> v) | 742 | where n = sqrt (v `dot` v) |
606 | 743 | ||
607 | -- -- | Some additional tests on big matrices. They take a few minutes. | 744 | -- -- | Some additional tests on big matrices. They take a few minutes. |
608 | -- runBigTests :: IO () | 745 | -- runBigTests :: IO () |
@@ -625,6 +762,8 @@ runBenchmarks = do | |||
625 | mkVecBench | 762 | mkVecBench |
626 | multBench | 763 | multBench |
627 | cholBench | 764 | cholBench |
765 | luBench | ||
766 | luBench_2 | ||
628 | svdBench | 767 | svdBench |
629 | eigBench | 768 | eigBench |
630 | putStrLn "" | 769 | putStrLn "" |
@@ -668,9 +807,9 @@ manyvec5 xs = sumElements $ fromRows $ map (\x -> vec3 x (x**2) (x**3)) xs | |||
668 | 807 | ||
669 | 808 | ||
670 | manyvec2 xs = sum $ map (\x -> sqrt(x^2 + (x**2)^2 +(x**3)^2)) xs | 809 | manyvec2 xs = sum $ map (\x -> sqrt(x^2 + (x**2)^2 +(x**3)^2)) xs |
671 | manyvec3 xs = sum $ map (pnorm PNorm2 . (\x -> fromList [x,x**2,x**3])) xs | 810 | manyvec3 xs = sum $ map (norm_2 . (\x -> fromList [x,x**2,x**3])) xs |
672 | 811 | ||
673 | manyvec4 xs = sum $ map (pnorm PNorm2 . (\x -> vec3 x (x**2) (x**3))) xs | 812 | manyvec4 xs = sum $ map (norm_2 . (\x -> vec3 x (x**2) (x**3))) xs |
674 | 813 | ||
675 | vec3 :: Double -> Double -> Double -> Vector Double | 814 | vec3 :: Double -> Double -> Double -> Vector Double |
676 | vec3 a b c = runSTVector $ do | 815 | vec3 a b c = runSTVector $ do |
@@ -695,11 +834,11 @@ mkVecBench = do | |||
695 | 834 | ||
696 | subBench = do | 835 | subBench = do |
697 | putStrLn "" | 836 | putStrLn "" |
698 | let g = foldl1' (.) (replicate (10^5) (\v -> subVector 1 (dim v -1) v)) | 837 | let g = foldl1' (.) (replicate (10^5) (\v -> subVector 1 (size v -1) v)) |
699 | time "0.1M subVector " (g (konst 1 (1+10^5) :: Vector Double) @> 0) | 838 | time "0.1M subVector " (g (konst 1 (1+10^5) :: Vector Double) ! 0) |
700 | let f = foldl1' (.) (replicate (10^5) (fromRows.toRows)) | 839 | let f = foldl1' (.) (replicate (10^5) (fromRows.toRows)) |
701 | time "subVector-join 3" (f (ident 3 :: Matrix Double) @@>(0,0)) | 840 | time "subVector-join 3" (f (ident 3 :: Matrix Double) `atIndex` (0,0)) |
702 | time "subVector-join 10" (f (ident 10 :: Matrix Double) @@>(0,0)) | 841 | time "subVector-join 10" (f (ident 10 :: Matrix Double) `atIndex` (0,0)) |
703 | 842 | ||
704 | -------------------------------- | 843 | -------------------------------- |
705 | 844 | ||
@@ -724,10 +863,10 @@ multBench = do | |||
724 | 863 | ||
725 | eigBench = do | 864 | eigBench = do |
726 | let m = reshape 1000 (randomVector 777 Uniform (1000*1000)) | 865 | let m = reshape 1000 (randomVector 777 Uniform (1000*1000)) |
727 | s = m + trans m | 866 | s = m + tr m |
728 | m `seq` s `seq` putStrLn "" | 867 | m `seq` s `seq` putStrLn "" |
729 | time "eigenvalues symmetric 1000x1000" (eigenvaluesSH' m) | 868 | time "eigenvalues symmetric 1000x1000" (eigenvaluesSH (trustSym m)) |
730 | time "eigenvectors symmetric 1000x1000" (snd $ eigSH' m) | 869 | time "eigenvectors symmetric 1000x1000" (snd $ eigSH (trustSym m)) |
731 | time "eigenvalues general 1000x1000" (eigenvalues m) | 870 | time "eigenvalues general 1000x1000" (eigenvalues m) |
732 | time "eigenvectors general 1000x1000" (snd $ eig m) | 871 | time "eigenvectors general 1000x1000" (snd $ eig m) |
733 | 872 | ||
@@ -736,7 +875,7 @@ eigBench = do | |||
736 | svdBench = do | 875 | svdBench = do |
737 | let a = reshape 500 (randomVector 777 Uniform (3000*500)) | 876 | let a = reshape 500 (randomVector 777 Uniform (3000*500)) |
738 | b = reshape 1000 (randomVector 777 Uniform (1000*1000)) | 877 | b = reshape 1000 (randomVector 777 Uniform (1000*1000)) |
739 | fv (_,_,v) = v@@>(0,0) | 878 | fv (_,_,v) = v `atIndex` (0,0) |
740 | a `seq` b `seq` putStrLn "" | 879 | a `seq` b `seq` putStrLn "" |
741 | time "singular values 3000x500" (singularValues a) | 880 | time "singular values 3000x500" (singularValues a) |
742 | time "thin svd 3000x500" (fv $ thinSVD a) | 881 | time "thin svd 3000x500" (fv $ thinSVD a) |
@@ -748,26 +887,28 @@ svdBench = do | |||
748 | 887 | ||
749 | solveBenchN n = do | 888 | solveBenchN n = do |
750 | let x = uniformSample 777 (2*n) (replicate n (-1,1)) | 889 | let x = uniformSample 777 (2*n) (replicate n (-1,1)) |
751 | a = trans x <> x | 890 | a = tr x <> x |
752 | b = asColumn $ randomVector 666 Uniform n | 891 | b = asColumn $ randomVector 666 Uniform n |
753 | a `seq` b `seq` putStrLn "" | 892 | a `seq` b `seq` putStrLn "" |
754 | time ("svd solve " ++ show n) (linearSolveSVD a b) | 893 | time ("svd solve " ++ show n) (linearSolveSVD a b) |
755 | time (" ls solve " ++ show n) (linearSolveLS a b) | 894 | time (" ls solve " ++ show n) (linearSolveLS a b) |
756 | time (" solve " ++ show n) (linearSolve a b) | 895 | time (" solve " ++ show n) (linearSolve a b) |
757 | time ("cholSolve " ++ show n) (cholSolve (chol a) b) | 896 | -- time (" LU solve " ++ show n) (luSolve (luPacked a) b) |
897 | time ("LDL solve " ++ show n) (ldlSolve (ldlPacked (trustSym a)) b) | ||
898 | time ("cholSolve " ++ show n) (cholSolve (chol $ trustSym a) b) | ||
758 | 899 | ||
759 | solveBench = do | 900 | solveBench = do |
760 | solveBenchN 500 | 901 | solveBenchN 500 |
761 | solveBenchN 1000 | 902 | solveBenchN 1000 |
762 | -- solveBenchN 1500 | 903 | solveBenchN 1500 |
763 | 904 | ||
764 | -------------------------------- | 905 | -------------------------------- |
765 | 906 | ||
766 | cholBenchN n = do | 907 | cholBenchN n = do |
767 | let x = uniformSample 777 (2*n) (replicate n (-1,1)) | 908 | let x = uniformSample 777 (2*n) (replicate n (-1,1)) |
768 | a = trans x <> x | 909 | a = tr x <> x |
769 | a `seq` putStr "" | 910 | a `seq` putStr "" |
770 | time ("chol " ++ show n) (chol a) | 911 | time ("chol " ++ show n) (chol $ trustSym a) |
771 | 912 | ||
772 | cholBench = do | 913 | cholBench = do |
773 | putStrLn "" | 914 | putStrLn "" |
@@ -776,3 +917,32 @@ cholBench = do | |||
776 | cholBenchN 300 | 917 | cholBenchN 300 |
777 | -- cholBenchN 150 | 918 | -- cholBenchN 150 |
778 | -- cholBenchN 50 | 919 | -- cholBenchN 50 |
920 | |||
921 | -------------------------------------------------------------------------------- | ||
922 | |||
923 | luBenchN f n x msg = do | ||
924 | let m = diagRect 1 (fromList (replicate n x)) n n | ||
925 | m `seq` putStr "" | ||
926 | time (msg ++ " "++ show n) (rnf $ f m) | ||
927 | |||
928 | luBench = do | ||
929 | putStrLn "" | ||
930 | luBenchN luPacked 1000 (5::R) "luPacked Double " | ||
931 | luBenchN luPacked' 1000 (5::R) "luPacked' Double " | ||
932 | luBenchN luPacked' 1000 (5::Mod 9973 I) "luPacked' I mod 9973" | ||
933 | luBenchN luPacked' 1000 (5::Mod 9973 Z) "luPacked' Z mod 9973" | ||
934 | |||
935 | luBenchN_2 f g n x msg = do | ||
936 | let m = diagRect 1 (fromList (replicate n x)) n n | ||
937 | b = flipud m | ||
938 | m `seq` b `seq` putStr "" | ||
939 | time (msg ++ " "++ show n) (f (g m) b) | ||
940 | |||
941 | luBench_2 = do | ||
942 | putStrLn "" | ||
943 | luBenchN_2 luSolve luPacked 500 (5::R) "luSolve .luPacked Double " | ||
944 | luBenchN_2 luSolve' luPacked' 500 (5::R) "luSolve'.luPacked' Double " | ||
945 | luBenchN_2 luSolve' luPacked' 500 (5::Mod 9973 I) "luSolve'.luPacked' I mod 9973" | ||
946 | luBenchN_2 luSolve' luPacked' 500 (5::Mod 9973 Z) "luSolve'.luPacked' Z mod 9973" | ||
947 | |||
948 | |||
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs index 53fc4d2..3d5441d 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs | |||
@@ -1,5 +1,4 @@ | |||
1 | {-# LANGUAGE FlexibleContexts, UndecidableInstances, CPP, FlexibleInstances #-} | 1 | {-# LANGUAGE FlexibleContexts, UndecidableInstances, FlexibleInstances #-} |
2 | {-# OPTIONS_GHC -fno-warn-unused-imports #-} | ||
3 | ----------------------------------------------------------------------------- | 2 | ----------------------------------------------------------------------------- |
4 | {- | | 3 | {- | |
5 | Module : Numeric.LinearAlgebra.Tests.Instances | 4 | Module : Numeric.LinearAlgebra.Tests.Instances |
@@ -15,9 +14,9 @@ Arbitrary instances for vectors, matrices. | |||
15 | module Numeric.LinearAlgebra.Tests.Instances( | 14 | module Numeric.LinearAlgebra.Tests.Instances( |
16 | Sq(..), rSq,cSq, | 15 | Sq(..), rSq,cSq, |
17 | Rot(..), rRot,cRot, | 16 | Rot(..), rRot,cRot, |
18 | Her(..), rHer,cHer, | 17 | rHer,cHer, |
19 | WC(..), rWC,cWC, | 18 | WC(..), rWC,cWC, |
20 | SqWC(..), rSqWC, cSqWC, | 19 | SqWC(..), rSqWC, cSqWC, rSymWC, cSymWC, |
21 | PosDef(..), rPosDef, cPosDef, | 20 | PosDef(..), rPosDef, cPosDef, |
22 | Consistent(..), rConsist, cConsist, | 21 | Consistent(..), rConsist, cConsist, |
23 | RM,CM, rM,cM, | 22 | RM,CM, rM,cM, |
@@ -26,15 +25,11 @@ module Numeric.LinearAlgebra.Tests.Instances( | |||
26 | 25 | ||
27 | import System.Random | 26 | import System.Random |
28 | 27 | ||
29 | import Numeric.LinearAlgebra | 28 | import Numeric.LinearAlgebra.HMatrix hiding (vector) |
30 | import Numeric.LinearAlgebra.Devel | ||
31 | import Numeric.Container | ||
32 | import Control.Monad(replicateM) | 29 | import Control.Monad(replicateM) |
33 | import Test.QuickCheck(Arbitrary,arbitrary,coarbitrary,choose,vector | 30 | import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink) |
34 | ,sized,classify,Testable,Property | 31 | |
35 | ,quickCheckWith,maxSize,stdArgs,shrink) | ||
36 | 32 | ||
37 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
38 | shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]] | 33 | shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]] |
39 | shrinkListElementwise [] = [] | 34 | shrinkListElementwise [] = [] |
40 | shrinkListElementwise (x:xs) = [ y:xs | y <- shrink x ] | 35 | shrinkListElementwise (x:xs) = [ y:xs | y <- shrink x ] |
@@ -42,25 +37,6 @@ shrinkListElementwise (x:xs) = [ y:xs | y <- shrink x ] | |||
42 | 37 | ||
43 | shrinkPair :: (Arbitrary a, Arbitrary b) => (a,b) -> [(a,b)] | 38 | shrinkPair :: (Arbitrary a, Arbitrary b) => (a,b) -> [(a,b)] |
44 | shrinkPair (a,b) = [ (a,x) | x <- shrink b ] ++ [ (x,b) | x <- shrink a ] | 39 | shrinkPair (a,b) = [ (a,x) | x <- shrink b ] ++ [ (x,b) | x <- shrink a ] |
45 | #endif | ||
46 | |||
47 | #if MIN_VERSION_QuickCheck(2,1,1) | ||
48 | #else | ||
49 | instance (Arbitrary a, RealFloat a) => Arbitrary (Complex a) where | ||
50 | arbitrary = do | ||
51 | re <- arbitrary | ||
52 | im <- arbitrary | ||
53 | return (re :+ im) | ||
54 | |||
55 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
56 | shrink (re :+ im) = | ||
57 | [ u :+ v | (u,v) <- shrinkPair (re,im) ] | ||
58 | #else | ||
59 | -- this has been moved to the 'Coarbitrary' class in QuickCheck 2 | ||
60 | coarbitrary = undefined | ||
61 | #endif | ||
62 | |||
63 | #endif | ||
64 | 40 | ||
65 | chooseDim = sized $ \m -> choose (1,max 1 m) | 41 | chooseDim = sized $ \m -> choose (1,max 1 m) |
66 | 42 | ||
@@ -68,15 +44,9 @@ instance (Field a, Arbitrary a) => Arbitrary (Vector a) where | |||
68 | arbitrary = do m <- chooseDim | 44 | arbitrary = do m <- chooseDim |
69 | l <- vector m | 45 | l <- vector m |
70 | return $ fromList l | 46 | return $ fromList l |
71 | |||
72 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
73 | -- shrink any one of the components | 47 | -- shrink any one of the components |
74 | shrink = map fromList . shrinkListElementwise . toList | 48 | shrink = map fromList . shrinkListElementwise . toList |
75 | 49 | ||
76 | #else | ||
77 | coarbitrary = undefined | ||
78 | #endif | ||
79 | |||
80 | instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where | 50 | instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where |
81 | arbitrary = do | 51 | arbitrary = do |
82 | m <- chooseDim | 52 | m <- chooseDim |
@@ -84,16 +54,11 @@ instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where | |||
84 | l <- vector (m*n) | 54 | l <- vector (m*n) |
85 | return $ (m><n) l | 55 | return $ (m><n) l |
86 | 56 | ||
87 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
88 | -- shrink any one of the components | 57 | -- shrink any one of the components |
89 | shrink a = map (rows a >< cols a) | 58 | shrink a = map (rows a >< cols a) |
90 | . shrinkListElementwise | 59 | . shrinkListElementwise |
91 | . concat . toLists | 60 | . concat . toLists |
92 | $ a | 61 | $ a |
93 | #else | ||
94 | coarbitrary = undefined | ||
95 | #endif | ||
96 | |||
97 | 62 | ||
98 | -- a square matrix | 63 | -- a square matrix |
99 | newtype (Sq a) = Sq (Matrix a) deriving Show | 64 | newtype (Sq a) = Sq (Matrix a) deriving Show |
@@ -103,11 +68,7 @@ instance (Element a, Arbitrary a) => Arbitrary (Sq a) where | |||
103 | l <- vector (n*n) | 68 | l <- vector (n*n) |
104 | return $ Sq $ (n><n) l | 69 | return $ Sq $ (n><n) l |
105 | 70 | ||
106 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
107 | shrink (Sq a) = [ Sq b | b <- shrink a ] | 71 | shrink (Sq a) = [ Sq b | b <- shrink a ] |
108 | #else | ||
109 | coarbitrary = undefined | ||
110 | #endif | ||
111 | 72 | ||
112 | 73 | ||
113 | -- a unitary matrix | 74 | -- a unitary matrix |
@@ -118,24 +79,14 @@ instance (Field a, Arbitrary a) => Arbitrary (Rot a) where | |||
118 | let (q,_) = qr m | 79 | let (q,_) = qr m |
119 | return (Rot q) | 80 | return (Rot q) |
120 | 81 | ||
121 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
122 | #else | ||
123 | coarbitrary = undefined | ||
124 | #endif | ||
125 | |||
126 | 82 | ||
127 | -- a complex hermitian or real symmetric matrix | 83 | -- a complex hermitian or real symmetric matrix |
128 | newtype (Her a) = Her (Matrix a) deriving Show | 84 | instance (Field a, Arbitrary a, Num (Vector a)) => Arbitrary (Herm a) where |
129 | instance (Field a, Arbitrary a, Num (Vector a)) => Arbitrary (Her a) where | ||
130 | arbitrary = do | 85 | arbitrary = do |
131 | Sq m <- arbitrary | 86 | Sq m <- arbitrary |
132 | let m' = m/2 | 87 | let m' = m/2 |
133 | return $ Her (m' + ctrans m') | 88 | return $ sym m' |
134 | 89 | ||
135 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
136 | #else | ||
137 | coarbitrary = undefined | ||
138 | #endif | ||
139 | 90 | ||
140 | class (Field a, Arbitrary a, Element (RealOf a), Random (RealOf a)) => ArbitraryField a | 91 | class (Field a, Arbitrary a, Element (RealOf a), Random (RealOf a)) => ArbitraryField a |
141 | instance ArbitraryField Double | 92 | instance ArbitraryField Double |
@@ -144,7 +95,7 @@ instance ArbitraryField (Complex Double) | |||
144 | 95 | ||
145 | -- a well-conditioned general matrix (the singular values are between 1 and 100) | 96 | -- a well-conditioned general matrix (the singular values are between 1 and 100) |
146 | newtype (WC a) = WC (Matrix a) deriving Show | 97 | newtype (WC a) = WC (Matrix a) deriving Show |
147 | instance (ArbitraryField a) => Arbitrary (WC a) where | 98 | instance (Numeric a, ArbitraryField a) => Arbitrary (WC a) where |
148 | arbitrary = do | 99 | arbitrary = do |
149 | m <- arbitrary | 100 | m <- arbitrary |
150 | let (u,_,v) = svd m | 101 | let (u,_,v) = svd m |
@@ -153,48 +104,33 @@ instance (ArbitraryField a) => Arbitrary (WC a) where | |||
153 | n = min r c | 104 | n = min r c |
154 | sv' <- replicateM n (choose (1,100)) | 105 | sv' <- replicateM n (choose (1,100)) |
155 | let s = diagRect 0 (fromList sv') r c | 106 | let s = diagRect 0 (fromList sv') r c |
156 | return $ WC (u `mXm` real s `mXm` trans v) | 107 | return $ WC (u <> real s <> tr v) |
157 | |||
158 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
159 | #else | ||
160 | coarbitrary = undefined | ||
161 | #endif | ||
162 | 108 | ||
163 | 109 | ||
164 | -- a well-conditioned square matrix (the singular values are between 1 and 100) | 110 | -- a well-conditioned square matrix (the singular values are between 1 and 100) |
165 | newtype (SqWC a) = SqWC (Matrix a) deriving Show | 111 | newtype (SqWC a) = SqWC (Matrix a) deriving Show |
166 | instance (ArbitraryField a) => Arbitrary (SqWC a) where | 112 | instance (ArbitraryField a, Numeric a) => Arbitrary (SqWC a) where |
167 | arbitrary = do | 113 | arbitrary = do |
168 | Sq m <- arbitrary | 114 | Sq m <- arbitrary |
169 | let (u,_,v) = svd m | 115 | let (u,_,v) = svd m |
170 | n = rows m | 116 | n = rows m |
171 | sv' <- replicateM n (choose (1,100)) | 117 | sv' <- replicateM n (choose (1,100)) |
172 | let s = diag (fromList sv') | 118 | let s = diag (fromList sv') |
173 | return $ SqWC (u `mXm` real s `mXm` trans v) | 119 | return $ SqWC (u <> real s <> tr v) |
174 | |||
175 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
176 | #else | ||
177 | coarbitrary = undefined | ||
178 | #endif | ||
179 | 120 | ||
180 | 121 | ||
181 | -- a positive definite square matrix (the eigenvalues are between 0 and 100) | 122 | -- a positive definite square matrix (the eigenvalues are between 0 and 100) |
182 | newtype (PosDef a) = PosDef (Matrix a) deriving Show | 123 | newtype (PosDef a) = PosDef (Matrix a) deriving Show |
183 | instance (ArbitraryField a, Num (Vector a)) | 124 | instance (Numeric a, ArbitraryField a, Num (Vector a)) |
184 | => Arbitrary (PosDef a) where | 125 | => Arbitrary (PosDef a) where |
185 | arbitrary = do | 126 | arbitrary = do |
186 | Her m <- arbitrary | 127 | m <- arbitrary |
187 | let (_,v) = eigSH m | 128 | let (_,v) = eigSH m |
188 | n = rows m | 129 | n = rows (unSym m) |
189 | l <- replicateM n (choose (0,100)) | 130 | l <- replicateM n (choose (0,100)) |
190 | let s = diag (fromList l) | 131 | let s = diag (fromList l) |
191 | p = v `mXm` real s `mXm` ctrans v | 132 | p = v <> real s <> tr v |
192 | return $ PosDef (0.5 * p + 0.5 * ctrans p) | 133 | return $ PosDef (0.5 * p + 0.5 * tr p) |
193 | |||
194 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
195 | #else | ||
196 | coarbitrary = undefined | ||
197 | #endif | ||
198 | 134 | ||
199 | 135 | ||
200 | -- a pair of matrices that can be multiplied | 136 | -- a pair of matrices that can be multiplied |
@@ -208,11 +144,7 @@ instance (Field a, Arbitrary a) => Arbitrary (Consistent a) where | |||
208 | lb <- vector (k*m) | 144 | lb <- vector (k*m) |
209 | return $ Consistent ((n><k) la, (k><m) lb) | 145 | return $ Consistent ((n><k) la, (k><m) lb) |
210 | 146 | ||
211 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
212 | shrink (Consistent (x,y)) = [ Consistent (u,v) | (u,v) <- shrinkPair (x,y) ] | 147 | shrink (Consistent (x,y)) = [ Consistent (u,v) | (u,v) <- shrinkPair (x,y) ] |
213 | #else | ||
214 | coarbitrary = undefined | ||
215 | #endif | ||
216 | 148 | ||
217 | 149 | ||
218 | 150 | ||
@@ -228,8 +160,8 @@ fM m = m :: FM | |||
228 | zM m = m :: ZM | 160 | zM m = m :: ZM |
229 | 161 | ||
230 | 162 | ||
231 | rHer (Her m) = m :: RM | 163 | rHer m = unSym m :: RM |
232 | cHer (Her m) = m :: CM | 164 | cHer m = unSym m :: CM |
233 | 165 | ||
234 | rRot (Rot m) = m :: RM | 166 | rRot (Rot m) = m :: RM |
235 | cRot (Rot m) = m :: CM | 167 | cRot (Rot m) = m :: CM |
@@ -243,6 +175,9 @@ cWC (WC m) = m :: CM | |||
243 | rSqWC (SqWC m) = m :: RM | 175 | rSqWC (SqWC m) = m :: RM |
244 | cSqWC (SqWC m) = m :: CM | 176 | cSqWC (SqWC m) = m :: CM |
245 | 177 | ||
178 | rSymWC (SqWC m) = sym m :: Herm R | ||
179 | cSymWC (SqWC m) = sym m :: Herm C | ||
180 | |||
246 | rPosDef (PosDef m) = m :: RM | 181 | rPosDef (PosDef m) = m :: RM |
247 | cPosDef (PosDef m) = m :: CM | 182 | cPosDef (PosDef m) = m :: CM |
248 | 183 | ||
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs index a5c37f4..046644f 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs | |||
@@ -1,5 +1,4 @@ | |||
1 | {-# LANGUAGE CPP, FlexibleContexts #-} | 1 | {-# LANGUAGE FlexibleContexts #-} |
2 | {-# OPTIONS_GHC -fno-warn-unused-imports #-} | ||
3 | {-# LANGUAGE TypeFamilies #-} | 2 | {-# LANGUAGE TypeFamilies #-} |
4 | 3 | ||
5 | ----------------------------------------------------------------------------- | 4 | ----------------------------------------------------------------------------- |
@@ -15,7 +14,7 @@ Testing properties. | |||
15 | -} | 14 | -} |
16 | 15 | ||
17 | module Numeric.LinearAlgebra.Tests.Properties ( | 16 | module Numeric.LinearAlgebra.Tests.Properties ( |
18 | dist, (|~|), (~~), (~:), Aprox((:~)), | 17 | dist, (|~|), (~~), (~:), Aprox((:~)), (~=), |
19 | zeros, ones, | 18 | zeros, ones, |
20 | square, | 19 | square, |
21 | unitary, | 20 | unitary, |
@@ -29,7 +28,7 @@ module Numeric.LinearAlgebra.Tests.Properties ( | |||
29 | pinvProp, | 28 | pinvProp, |
30 | detProp, | 29 | detProp, |
31 | nullspaceProp, | 30 | nullspaceProp, |
32 | bugProp, | 31 | -- bugProp, |
33 | svdProp1, svdProp1a, svdProp1b, svdProp2, svdProp3, svdProp4, | 32 | svdProp1, svdProp1a, svdProp1b, svdProp2, svdProp3, svdProp4, |
34 | svdProp5a, svdProp5b, svdProp6a, svdProp6b, svdProp7, | 33 | svdProp5a, svdProp5b, svdProp6a, svdProp6b, svdProp7, |
35 | eigProp, eigSHProp, eigProp2, eigSHProp2, | 34 | eigProp, eigSHProp, eigProp2, eigSHProp2, |
@@ -40,23 +39,21 @@ module Numeric.LinearAlgebra.Tests.Properties ( | |||
40 | expmDiagProp, | 39 | expmDiagProp, |
41 | multProp1, multProp2, | 40 | multProp1, multProp2, |
42 | subProp, | 41 | subProp, |
43 | linearSolveProp, linearSolveProp2 | 42 | linearSolveProp, linearSolvePropH, linearSolveProp2 |
44 | ) where | 43 | ) where |
45 | 44 | ||
46 | import Numeric.Container | 45 | import Numeric.LinearAlgebra.HMatrix hiding (Testable,unitary) |
47 | import Numeric.LinearAlgebra --hiding (real,complex) | 46 | import Test.QuickCheck |
48 | import Numeric.LinearAlgebra.LAPACK | 47 | |
49 | import Debug.Trace | 48 | (~=) :: Double -> Double -> Bool |
50 | import Test.QuickCheck(Arbitrary,arbitrary,coarbitrary,choose,vector | 49 | a ~= b = abs (a - b) < 1e-10 |
51 | ,sized,classify,Testable,Property | ||
52 | ,quickCheckWith,maxSize,stdArgs,shrink) | ||
53 | 50 | ||
54 | trivial :: Testable a => Bool -> a -> Property | 51 | trivial :: Testable a => Bool -> a -> Property |
55 | trivial = (`classify` "trivial") | 52 | trivial = (`classify` "trivial") |
56 | 53 | ||
57 | -- relative error | 54 | -- relative error |
58 | dist :: (Normed c t, Num (c t)) => c t -> c t -> Double | 55 | dist :: (Num a, Normed a) => a -> a -> Double |
59 | dist = relativeError Infinity | 56 | dist = relativeError norm_Inf |
60 | 57 | ||
61 | infixl 4 |~| | 58 | infixl 4 |~| |
62 | a |~| b = a :~10~: b | 59 | a |~| b = a :~10~: b |
@@ -73,11 +70,11 @@ a :~n~: b = dist a b < 10^^(-n) | |||
73 | square m = rows m == cols m | 70 | square m = rows m == cols m |
74 | 71 | ||
75 | -- orthonormal columns | 72 | -- orthonormal columns |
76 | orthonormal m = ctrans m <> m |~| ident (cols m) | 73 | orthonormal m = tr m <> m |~| ident (cols m) |
77 | 74 | ||
78 | unitary m = square m && orthonormal m | 75 | unitary m = square m && orthonormal m |
79 | 76 | ||
80 | hermitian m = square m && m |~| ctrans m | 77 | hermitian m = square m && m |~| tr m |
81 | 78 | ||
82 | wellCond m = rcond m > 1/100 | 79 | wellCond m = rcond m > 1/100 |
83 | 80 | ||
@@ -85,12 +82,12 @@ positiveDefinite m = minimum (toList e) > 0 | |||
85 | where (e,_v) = eigSH m | 82 | where (e,_v) = eigSH m |
86 | 83 | ||
87 | upperTriang m = rows m == 1 || down == z | 84 | upperTriang m = rows m == 1 || down == z |
88 | where down = fromList $ concat $ zipWith drop [1..] (toLists (ctrans m)) | 85 | where down = fromList $ concat $ zipWith drop [1..] (toLists (tr m)) |
89 | z = konst 0 (dim down) | 86 | z = konst 0 (size down) |
90 | 87 | ||
91 | upperHessenberg m = rows m < 3 || down == z | 88 | upperHessenberg m = rows m < 3 || down == z |
92 | where down = fromList $ concat $ zipWith drop [2..] (toLists (ctrans m)) | 89 | where down = fromList $ concat $ zipWith drop [2..] (toLists (tr m)) |
93 | z = konst 0 (dim down) | 90 | z = konst 0 (size down) |
94 | 91 | ||
95 | zeros (r,c) = reshape c (konst 0 (r*c)) | 92 | zeros (r,c) = reshape c (konst 0 (r*c)) |
96 | 93 | ||
@@ -118,81 +115,94 @@ detProp m = s d1 |~| s d2 | |||
118 | s x = fromList [x] | 115 | s x = fromList [x] |
119 | 116 | ||
120 | nullspaceProp m = null nl `trivial` (null nl || m <> n |~| zeros (r,c) | 117 | nullspaceProp m = null nl `trivial` (null nl || m <> n |~| zeros (r,c) |
121 | && orthonormal (fromColumns nl)) | 118 | && orthonormal n) |
122 | where nl = nullspacePrec 1 m | 119 | where n = nullspaceSVD (Left (1*peps)) m (rightSV m) |
123 | n = fromColumns nl | 120 | nl = toColumns n |
124 | r = rows m | 121 | r = rows m |
125 | c = cols m - rank m | 122 | c = cols m - rank m |
126 | 123 | ||
127 | ------------------------------------------------------------------ | 124 | ------------------------------------------------------------------ |
128 | 125 | {- | |
129 | -- testcase for nonempty fpu stack | 126 | -- testcase for nonempty fpu stack |
130 | -- uncommenting unitary' signature eliminates the problem | 127 | -- uncommenting unitary' signature eliminates the problem |
131 | bugProp m = m |~| u <> real d <> trans v && unitary' u && unitary' v | 128 | bugProp m = m |~| u <> real d <> tr v && unitary' u && unitary' v |
132 | where (u,d,v) = fullSVD m | 129 | where (u,d,v) = svd m |
133 | -- unitary' :: (Num (Vector t), Field t) => Matrix t -> Bool | 130 | -- unitary' :: (Num (Vector t), Field t) => Matrix t -> Bool |
134 | unitary' a = unitary a | 131 | unitary' a = unitary a |
135 | 132 | -} | |
136 | ------------------------------------------------------------------ | 133 | ------------------------------------------------------------------ |
137 | 134 | ||
138 | -- fullSVD | 135 | -- fullSVD |
139 | svdProp1 m = m |~| u <> real d <> trans v && unitary u && unitary v | 136 | svdProp1 m = m |~| u <> real d <> tr v && unitary u && unitary v |
140 | where (u,d,v) = fullSVD m | 137 | where |
138 | (u,s,v) = svd m | ||
139 | d = diagRect 0 s (rows m) (cols m) | ||
141 | 140 | ||
142 | svdProp1a svdfun m = m |~| u <> real d <> trans v && unitary u && unitary v where | 141 | svdProp1a svdfun m = m |~| u <> real d <> tr v && unitary u && unitary v |
142 | where | ||
143 | (u,s,v) = svdfun m | 143 | (u,s,v) = svdfun m |
144 | d = diagRect 0 s (rows m) (cols m) | 144 | d = diagRect 0 s (rows m) (cols m) |
145 | 145 | ||
146 | svdProp1b svdfun m = unitary u && unitary v where | 146 | svdProp1b svdfun m = unitary u && unitary v |
147 | where | ||
147 | (u,_,v) = svdfun m | 148 | (u,_,v) = svdfun m |
148 | 149 | ||
149 | -- thinSVD | 150 | -- thinSVD |
150 | svdProp2 thinSVDfun m = m |~| u <> diag (real s) <> trans v && orthonormal u && orthonormal v && dim s == min (rows m) (cols m) | 151 | svdProp2 thinSVDfun m |
151 | where (u,s,v) = thinSVDfun m | 152 | = m |~| u <> diag (real s) <> tr v |
153 | && orthonormal u && orthonormal v | ||
154 | && size s == min (rows m) (cols m) | ||
155 | where | ||
156 | (u,s,v) = thinSVDfun m | ||
152 | 157 | ||
153 | -- compactSVD | 158 | -- compactSVD |
154 | svdProp3 m = (m |~| u <> real (diag s) <> trans v | 159 | svdProp3 m = (m |~| u <> real (diag s) <> tr v |
155 | && orthonormal u && orthonormal v) | 160 | && orthonormal u && orthonormal v) |
156 | where (u,s,v) = compactSVD m | 161 | where |
162 | (u,s,v) = compactSVD m | ||
157 | 163 | ||
158 | svdProp4 m' = m |~| u <> real (diag s) <> trans v | 164 | svdProp4 m' = m |~| u <> real (diag s) <> tr v |
159 | && orthonormal u && orthonormal v | 165 | && orthonormal u && orthonormal v |
160 | && (dim s == r || r == 0 && dim s == 1) | 166 | && (size s == r || r == 0 && size s == 1) |
161 | where (u,s,v) = compactSVD m | 167 | where |
162 | m = fromBlocks [[m'],[m']] | 168 | (u,s,v) = compactSVD m |
163 | r = rank m' | 169 | m = fromBlocks [[m'],[m']] |
164 | 170 | r = rank m' | |
165 | svdProp5a m = all (s1|~|) [s2,s3,s4,s5,s6] where | 171 | |
166 | s1 = svR m | 172 | svdProp5a m = all (s1|~|) [s3,s5] where |
167 | s2 = svRd m | 173 | s1 = singularValues (m :: Matrix Double) |
168 | (_,s3,_) = svdR m | 174 | -- s2 = svRd m |
169 | (_,s4,_) = svdRd m | 175 | (_,s3,_) = svd m |
170 | (_,s5,_) = thinSVDR m | 176 | -- (_,s4,_) = svdRd m |
171 | (_,s6,_) = thinSVDRd m | 177 | (_,s5,_) = thinSVD m |
172 | 178 | -- (_,s6,_) = thinSVDRd m | |
173 | svdProp5b m = all (s1|~|) [s2,s3,s4,s5,s6] where | 179 | |
174 | s1 = svC m | 180 | svdProp5b m = all (s1|~|) [s3,s5] where |
175 | s2 = svCd m | 181 | s1 = singularValues (m :: Matrix (Complex Double)) |
176 | (_,s3,_) = svdC m | 182 | -- s2 = svCd m |
177 | (_,s4,_) = svdCd m | 183 | (_,s3,_) = svd m |
178 | (_,s5,_) = thinSVDC m | 184 | -- (_,s4,_) = svdCd m |
179 | (_,s6,_) = thinSVDCd m | 185 | (_,s5,_) = thinSVD m |
186 | -- (_,s6,_) = thinSVDCd m | ||
180 | 187 | ||
181 | svdProp6a m = s |~| s' && v |~| v' && s |~| s'' && u |~| u' | 188 | svdProp6a m = s |~| s' && v |~| v' && s |~| s'' && u |~| u' |
182 | where (u,s,v) = svdR m | 189 | where |
183 | (s',v') = rightSVR m | 190 | (u,s,v) = svd (m :: Matrix Double) |
184 | (u',s'') = leftSVR m | 191 | (s',v') = rightSV m |
192 | (u',s'') = leftSV m | ||
185 | 193 | ||
186 | svdProp6b m = s |~| s' && v |~| v' && s |~| s'' && u |~| u' | 194 | svdProp6b m = s |~| s' && v |~| v' && s |~| s'' && u |~| u' |
187 | where (u,s,v) = svdC m | 195 | where |
188 | (s',v') = rightSVC m | 196 | (u,s,v) = svd (m :: Matrix (Complex Double)) |
189 | (u',s'') = leftSVC m | 197 | (s',v') = rightSV m |
198 | (u',s'') = leftSV m | ||
190 | 199 | ||
191 | svdProp7 m = s |~| s' && u |~| u' && v |~| v' && s |~| s''' | 200 | svdProp7 m = s |~| s' && u |~| u' && v |~| v' && s |~| s''' |
192 | where (u,s,v) = svd m | 201 | where |
193 | (s',v') = rightSV m | 202 | (u,s,v) = svd m |
194 | (u',_s'') = leftSV m | 203 | (s',v') = rightSV m |
195 | s''' = singularValues m | 204 | (u',_s'') = leftSV m |
205 | s''' = singularValues m | ||
196 | 206 | ||
197 | ------------------------------------------------------------------ | 207 | ------------------------------------------------------------------ |
198 | 208 | ||
@@ -201,12 +211,12 @@ eigProp m = complex m <> v |~| v <> diag s | |||
201 | 211 | ||
202 | eigSHProp m = m <> v |~| v <> real (diag s) | 212 | eigSHProp m = m <> v |~| v <> real (diag s) |
203 | && unitary v | 213 | && unitary v |
204 | && m |~| v <> real (diag s) <> ctrans v | 214 | && m |~| v <> real (diag s) <> tr v |
205 | where (s, v) = eigSH m | 215 | where (s, v) = eigSH' m |
206 | 216 | ||
207 | eigProp2 m = fst (eig m) |~| eigenvalues m | 217 | eigProp2 m = fst (eig m) |~| eigenvalues m |
208 | 218 | ||
209 | eigSHProp2 m = fst (eigSH m) |~| eigenvaluesSH m | 219 | eigSHProp2 m = fst (eigSH' m) |~| eigenvaluesSH' m |
210 | 220 | ||
211 | ------------------------------------------------------------------ | 221 | ------------------------------------------------------------------ |
212 | 222 | ||
@@ -226,22 +236,22 @@ rqProp3 m = upperTriang' r | |||
226 | where (r,_q) = rq m | 236 | where (r,_q) = rq m |
227 | 237 | ||
228 | upperTriang' r = upptr (rows r) (cols r) * r |~| r | 238 | upperTriang' r = upptr (rows r) (cols r) * r |~| r |
229 | where upptr f c = buildMatrix f c $ \(r',c') -> if r'-t > c' then 0 else 1 | 239 | where upptr f c = build (f,c) $ \r' c' -> if r'-t > c' then 0 else 1 |
230 | where t = f-c | 240 | where t = fromIntegral (f-c) |
231 | 241 | ||
232 | hessProp m = m |~| p <> h <> ctrans p && unitary p && upperHessenberg h | 242 | hessProp m = m |~| p <> h <> tr p && unitary p && upperHessenberg h |
233 | where (p,h) = hess m | 243 | where (p,h) = hess m |
234 | 244 | ||
235 | schurProp1 m = m |~| u <> s <> ctrans u && unitary u && upperTriang s | 245 | schurProp1 m = m |~| u <> s <> tr u && unitary u && upperTriang s |
236 | where (u,s) = schur m | 246 | where (u,s) = schur m |
237 | 247 | ||
238 | schurProp2 m = m |~| u <> s <> ctrans u && unitary u && upperHessenberg s -- fixme | 248 | schurProp2 m = m |~| u <> s <> tr u && unitary u && upperHessenberg s -- fixme |
239 | where (u,s) = schur m | 249 | where (u,s) = schur m |
240 | 250 | ||
241 | cholProp m = m |~| ctrans c <> c && upperTriang c | 251 | cholProp m = m |~| tr c <> c && upperTriang c |
242 | where c = chol m | 252 | where c = chol (trustSym m) |
243 | 253 | ||
244 | exactProp m = chol m == chol (m+0) | 254 | exactProp m = chol (trustSym m) == chol (trustSym (m+0)) |
245 | 255 | ||
246 | expmDiagProp m = expm (logm m) :~ 7 ~: complex m | 256 | expmDiagProp m = expm (logm m) :~ 7 ~: complex m |
247 | where logm = matFunc log | 257 | where logm = matFunc log |
@@ -252,14 +262,16 @@ mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ] | |||
252 | 262 | ||
253 | multProp1 p (a,b) = (a <> b) :~p~: (mulH a b) | 263 | multProp1 p (a,b) = (a <> b) :~p~: (mulH a b) |
254 | 264 | ||
255 | multProp2 p (a,b) = (ctrans (a <> b)) :~p~: (ctrans b <> ctrans a) | 265 | multProp2 p (a,b) = (tr (a <> b)) :~p~: (tr b <> tr a) |
256 | 266 | ||
257 | linearSolveProp f m = f m m |~| ident (rows m) | 267 | linearSolveProp f m = f m m |~| ident (rows m) |
258 | 268 | ||
269 | linearSolvePropH f m = f m (unSym m) |~| ident (rows (unSym m)) | ||
270 | |||
259 | linearSolveProp2 f (a,x) = not wc `trivial` (not wc || a <> f a b |~| b) | 271 | linearSolveProp2 f (a,x) = not wc `trivial` (not wc || a <> f a b |~| b) |
260 | where q = min (rows a) (cols a) | 272 | where q = min (rows a) (cols a) |
261 | b = a <> x | 273 | b = a <> x |
262 | wc = rank a == q | 274 | wc = rank a == q |
263 | 275 | ||
264 | subProp m = m == (trans . fromColumns . toRows) m | 276 | subProp m = m == (conj . tr . fromColumns . toRows) m |
265 | 277 | ||
diff --git a/stack.yaml b/stack.yaml new file mode 100644 index 0000000..88394c7 --- /dev/null +++ b/stack.yaml | |||
@@ -0,0 +1,18 @@ | |||
1 | flags: | ||
2 | hmatrix-special: | ||
3 | safe-cheap: false | ||
4 | hmatrix-tests: | ||
5 | gsl: true | ||
6 | hmatrix: | ||
7 | openblas: false | ||
8 | hmatrix-gsl: | ||
9 | onlygsl: false | ||
10 | packages: | ||
11 | - packages\tests\ | ||
12 | - packages\special\ | ||
13 | - packages\sparse\ | ||
14 | - packages\gsl\ | ||
15 | - packages\glpk\ | ||
16 | - packages\base\ | ||
17 | extra-deps: [] | ||
18 | resolver: lts-3.3 | ||