summaryrefslogtreecommitdiff
path: root/packages/sparse/src/Numeric
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-05-27 20:24:12 +0200
committerAlberto Ruiz <aruiz@um.es>2014-05-27 20:24:12 +0200
commit3c1c5e59e3d699f3e17519f19d47f7dab2403879 (patch)
treea749e0a3fb515ad1a904ce7387fbd3afd2ee0ed3 /packages/sparse/src/Numeric
parent53559833d2166010eed754027484fb8d5525e710 (diff)
initial interface to MKL sparse solver
Diffstat (limited to 'packages/sparse/src/Numeric')
-rw-r--r--packages/sparse/src/Numeric/LinearAlgebra/Sparse.hs32
-rw-r--r--packages/sparse/src/Numeric/LinearAlgebra/sparse.c65
2 files changed, 97 insertions, 0 deletions
diff --git a/packages/sparse/src/Numeric/LinearAlgebra/Sparse.hs b/packages/sparse/src/Numeric/LinearAlgebra/Sparse.hs
new file mode 100644
index 0000000..ccf28b7
--- /dev/null
+++ b/packages/sparse/src/Numeric/LinearAlgebra/Sparse.hs
@@ -0,0 +1,32 @@
1{-# LANGUAGE ForeignFunctionInterface #-}
2{-# LANGUAGE RecordWildCards #-}
3
4
5
6module Numeric.LinearAlgebra.Sparse (
7 dss
8) where
9
10import Foreign.C.Types(CInt(..))
11import Data.Packed.Development
12import System.IO.Unsafe(unsafePerformIO)
13import Foreign(Ptr)
14import Numeric.HMatrix
15import Text.Printf
16import Numeric.LinearAlgebra.Util((~!~))
17
18
19type IV t = CInt -> Ptr CInt -> t
20type V t = CInt -> Ptr Double -> t
21type SMxV = V (IV (IV (V (V (IO CInt)))))
22
23dss :: CSR -> Vector Double -> Vector Double
24dss CSR{..} b = unsafePerformIO $ do
25 size b /= csrNRows ~!~ printf "dss: incorrect sizes: (%d,%d) x %d" csrNRows csrNCols (size b)
26 r <- createVector csrNCols
27 app5 c_dss vec csrVals vec csrCols vec csrRows vec b vec r "dss"
28 return r
29
30foreign import ccall unsafe "dss"
31 c_dss :: SMxV
32
diff --git a/packages/sparse/src/Numeric/LinearAlgebra/sparse.c b/packages/sparse/src/Numeric/LinearAlgebra/sparse.c
new file mode 100644
index 0000000..b1e257a
--- /dev/null
+++ b/packages/sparse/src/Numeric/LinearAlgebra/sparse.c
@@ -0,0 +1,65 @@
1
2#include <stdio.h>
3#include <stdlib.h>
4#include <math.h>
5
6#include "mkl_dss.h"
7#include "mkl_types.h"
8#include "mkl_spblas.h"
9
10#define KIVEC(A) int A##n, const int*A##p
11#define KDVEC(A) int A##n, const double*A##p
12#define DVEC(A) int A##n, double*A##p
13#define OK return 0;
14
15
16void check_error(int error)
17{
18 if(error != MKL_DSS_SUCCESS) {
19 printf ("Solver returned error code %d\n", error);
20 exit (1);
21 }
22}
23
24int dss(KDVEC(vals),KIVEC(cols),KIVEC(rows),KDVEC(x),DVEC(r)) {
25 MKL_INT nRows = rowsn-1, nCols = rn, nNonZeros = valsn, nRhs = 1;
26 MKL_INT *rowIndex = (MKL_INT*) rowsp;
27 MKL_INT *columns = (MKL_INT*) colsp;
28 double *values = (double*) valsp;
29 _DOUBLE_PRECISION_t *rhs = (_DOUBLE_PRECISION_t*) xp;
30// _DOUBLE_PRECISION_t *obtrhs = (_DOUBLE_PRECISION_t*) malloc((nCols)*sizeof(_DOUBLE_PRECISION_t));
31 _DOUBLE_PRECISION_t *solValues = (_DOUBLE_PRECISION_t*) rp;
32
33 _MKL_DSS_HANDLE_t handle;
34 _INTEGER_t error;
35// _CHARACTER_t *uplo;
36 MKL_INT opt;
37
38 opt = MKL_DSS_DEFAULTS;
39 error = dss_create(handle, opt);
40 check_error(error);
41
42 opt = MKL_DSS_NON_SYMMETRIC;
43 error = dss_define_structure(handle, opt, rowIndex, nRows, nCols, columns, nNonZeros);
44 check_error(error);
45
46 opt = MKL_DSS_DEFAULTS;
47 error = dss_reorder(handle, opt, 0);
48 check_error(error);
49
50 opt = MKL_DSS_INDEFINITE;
51 error = dss_factor_real(handle, opt, values);
52 check_error(error);
53
54 int j;
55 for (j = 0; j < nCols; j++) {
56 solValues[j] = 0.0;
57 }
58
59 // Solve system
60 opt = MKL_DSS_REFINEMENT_ON;
61 error = dss_solve_real(handle, opt, rhs, nRhs, solValues);
62 check_error(error);
63
64 OK
65}