diff options
author | Alberto Ruiz <aruiz@um.es> | 2014-05-27 20:24:12 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2014-05-27 20:24:12 +0200 |
commit | 3c1c5e59e3d699f3e17519f19d47f7dab2403879 (patch) | |
tree | a749e0a3fb515ad1a904ce7387fbd3afd2ee0ed3 /packages/sparse/src/Numeric | |
parent | 53559833d2166010eed754027484fb8d5525e710 (diff) |
initial interface to MKL sparse solver
Diffstat (limited to 'packages/sparse/src/Numeric')
-rw-r--r-- | packages/sparse/src/Numeric/LinearAlgebra/Sparse.hs | 32 | ||||
-rw-r--r-- | packages/sparse/src/Numeric/LinearAlgebra/sparse.c | 65 |
2 files changed, 97 insertions, 0 deletions
diff --git a/packages/sparse/src/Numeric/LinearAlgebra/Sparse.hs b/packages/sparse/src/Numeric/LinearAlgebra/Sparse.hs new file mode 100644 index 0000000..ccf28b7 --- /dev/null +++ b/packages/sparse/src/Numeric/LinearAlgebra/Sparse.hs | |||
@@ -0,0 +1,32 @@ | |||
1 | {-# LANGUAGE ForeignFunctionInterface #-} | ||
2 | {-# LANGUAGE RecordWildCards #-} | ||
3 | |||
4 | |||
5 | |||
6 | module Numeric.LinearAlgebra.Sparse ( | ||
7 | dss | ||
8 | ) where | ||
9 | |||
10 | import Foreign.C.Types(CInt(..)) | ||
11 | import Data.Packed.Development | ||
12 | import System.IO.Unsafe(unsafePerformIO) | ||
13 | import Foreign(Ptr) | ||
14 | import Numeric.HMatrix | ||
15 | import Text.Printf | ||
16 | import Numeric.LinearAlgebra.Util((~!~)) | ||
17 | |||
18 | |||
19 | type IV t = CInt -> Ptr CInt -> t | ||
20 | type V t = CInt -> Ptr Double -> t | ||
21 | type SMxV = V (IV (IV (V (V (IO CInt))))) | ||
22 | |||
23 | dss :: CSR -> Vector Double -> Vector Double | ||
24 | dss CSR{..} b = unsafePerformIO $ do | ||
25 | size b /= csrNRows ~!~ printf "dss: incorrect sizes: (%d,%d) x %d" csrNRows csrNCols (size b) | ||
26 | r <- createVector csrNCols | ||
27 | app5 c_dss vec csrVals vec csrCols vec csrRows vec b vec r "dss" | ||
28 | return r | ||
29 | |||
30 | foreign import ccall unsafe "dss" | ||
31 | c_dss :: SMxV | ||
32 | |||
diff --git a/packages/sparse/src/Numeric/LinearAlgebra/sparse.c b/packages/sparse/src/Numeric/LinearAlgebra/sparse.c new file mode 100644 index 0000000..b1e257a --- /dev/null +++ b/packages/sparse/src/Numeric/LinearAlgebra/sparse.c | |||
@@ -0,0 +1,65 @@ | |||
1 | |||
2 | #include <stdio.h> | ||
3 | #include <stdlib.h> | ||
4 | #include <math.h> | ||
5 | |||
6 | #include "mkl_dss.h" | ||
7 | #include "mkl_types.h" | ||
8 | #include "mkl_spblas.h" | ||
9 | |||
10 | #define KIVEC(A) int A##n, const int*A##p | ||
11 | #define KDVEC(A) int A##n, const double*A##p | ||
12 | #define DVEC(A) int A##n, double*A##p | ||
13 | #define OK return 0; | ||
14 | |||
15 | |||
16 | void check_error(int error) | ||
17 | { | ||
18 | if(error != MKL_DSS_SUCCESS) { | ||
19 | printf ("Solver returned error code %d\n", error); | ||
20 | exit (1); | ||
21 | } | ||
22 | } | ||
23 | |||
24 | int dss(KDVEC(vals),KIVEC(cols),KIVEC(rows),KDVEC(x),DVEC(r)) { | ||
25 | MKL_INT nRows = rowsn-1, nCols = rn, nNonZeros = valsn, nRhs = 1; | ||
26 | MKL_INT *rowIndex = (MKL_INT*) rowsp; | ||
27 | MKL_INT *columns = (MKL_INT*) colsp; | ||
28 | double *values = (double*) valsp; | ||
29 | _DOUBLE_PRECISION_t *rhs = (_DOUBLE_PRECISION_t*) xp; | ||
30 | // _DOUBLE_PRECISION_t *obtrhs = (_DOUBLE_PRECISION_t*) malloc((nCols)*sizeof(_DOUBLE_PRECISION_t)); | ||
31 | _DOUBLE_PRECISION_t *solValues = (_DOUBLE_PRECISION_t*) rp; | ||
32 | |||
33 | _MKL_DSS_HANDLE_t handle; | ||
34 | _INTEGER_t error; | ||
35 | // _CHARACTER_t *uplo; | ||
36 | MKL_INT opt; | ||
37 | |||
38 | opt = MKL_DSS_DEFAULTS; | ||
39 | error = dss_create(handle, opt); | ||
40 | check_error(error); | ||
41 | |||
42 | opt = MKL_DSS_NON_SYMMETRIC; | ||
43 | error = dss_define_structure(handle, opt, rowIndex, nRows, nCols, columns, nNonZeros); | ||
44 | check_error(error); | ||
45 | |||
46 | opt = MKL_DSS_DEFAULTS; | ||
47 | error = dss_reorder(handle, opt, 0); | ||
48 | check_error(error); | ||
49 | |||
50 | opt = MKL_DSS_INDEFINITE; | ||
51 | error = dss_factor_real(handle, opt, values); | ||
52 | check_error(error); | ||
53 | |||
54 | int j; | ||
55 | for (j = 0; j < nCols; j++) { | ||
56 | solValues[j] = 0.0; | ||
57 | } | ||
58 | |||
59 | // Solve system | ||
60 | opt = MKL_DSS_REFINEMENT_ON; | ||
61 | error = dss_solve_real(handle, opt, rhs, nRhs, solValues); | ||
62 | check_error(error); | ||
63 | |||
64 | OK | ||
65 | } | ||