summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-05-22 11:49:23 +0200
committerAlberto Ruiz <aruiz@um.es>2014-05-22 11:49:23 +0200
commit5158a1717f1d4caee25669a0781602fe64787302 (patch)
tree2846767b931e3a3429e0f2c5aed709c6800c94cc
parent3916d70b9d170633c6122cb3c46000f0b3f32018 (diff)
initial support for sparse matrix
-rw-r--r--packages/base/hmatrix.cabal1
-rw-r--r--packages/base/src/C/lapack-aux.h2
-rw-r--r--packages/base/src/C/vector-aux.c26
-rw-r--r--packages/base/src/Data/Packed/Internal/Numeric.hs23
-rw-r--r--packages/base/src/Data/Packed/Internal/Sparse.hs191
-rw-r--r--packages/base/src/Numeric/Container.hs37
-rw-r--r--packages/base/src/Numeric/LinearAlgebra.hs14
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Data.hs7
8 files changed, 278 insertions, 23 deletions
diff --git a/packages/base/hmatrix.cabal b/packages/base/hmatrix.cabal
index 1948897..06c1a3c 100644
--- a/packages/base/hmatrix.cabal
+++ b/packages/base/hmatrix.cabal
@@ -48,6 +48,7 @@ library
48 Numeric.LinearAlgebra.Data 48 Numeric.LinearAlgebra.Data
49 49
50 Numeric.LinearAlgebra.Compat 50 Numeric.LinearAlgebra.Compat
51 Data.Packed.Internal.Sparse
51 52
52 other-modules: Data.Packed.Internal, 53 other-modules: Data.Packed.Internal,
53 Data.Packed.Internal.Common, 54 Data.Packed.Internal.Common,
diff --git a/packages/base/src/C/lapack-aux.h b/packages/base/src/C/lapack-aux.h
index a3f1899..c95a2a3 100644
--- a/packages/base/src/C/lapack-aux.h
+++ b/packages/base/src/C/lapack-aux.h
@@ -36,6 +36,7 @@ typedef short ftnlen;
36 36
37/********************************************************/ 37/********************************************************/
38 38
39#define IVEC(A) int A##n, int*A##p
39#define FVEC(A) int A##n, float*A##p 40#define FVEC(A) int A##n, float*A##p
40#define DVEC(A) int A##n, double*A##p 41#define DVEC(A) int A##n, double*A##p
41#define QVEC(A) int A##n, complex*A##p 42#define QVEC(A) int A##n, complex*A##p
@@ -47,6 +48,7 @@ typedef short ftnlen;
47#define CMAT(A) int A##r, int A##c, doublecomplex* A##p 48#define CMAT(A) int A##r, int A##c, doublecomplex* A##p
48#define PMAT(A) int A##r, int A##c, void* A##p, int A##s 49#define PMAT(A) int A##r, int A##c, void* A##p, int A##s
49 50
51#define KIVEC(A) int A##n, const int*A##p
50#define KFVEC(A) int A##n, const float*A##p 52#define KFVEC(A) int A##n, const float*A##p
51#define KDVEC(A) int A##n, const double*A##p 53#define KDVEC(A) int A##n, const double*A##p
52#define KQVEC(A) int A##n, const complex*A##p 54#define KQVEC(A) int A##n, const complex*A##p
diff --git a/packages/base/src/C/vector-aux.c b/packages/base/src/C/vector-aux.c
index 5b9c171..53b56aa 100644
--- a/packages/base/src/C/vector-aux.c
+++ b/packages/base/src/C/vector-aux.c
@@ -744,3 +744,29 @@ int random_vector(int seed, int code, DVEC(r)) {
744 } 744 }
745} 745}
746 746
747////////////////////////////////////////////////////////////////////////////////
748
749int smXv(KDVEC(vals),KIVEC(cols),KIVEC(rows),KDVEC(x),DVEC(r)) {
750 int r, c;
751 for (r = 0; r < rowsn - 1; r++) {
752 rp[r] = 0;
753 for (c = rowsp[r]; c < rowsp[r+1]; c++) {
754 rp[r] += valsp[c-1] * xp[colsp[c-1]-1];
755 }
756 }
757 OK
758}
759
760int smTXv(KDVEC(vals),KIVEC(cols),KIVEC(rows),KDVEC(x),DVEC(r)) {
761 int r,c;
762 for (c = 0; c < rn; c++) {
763 rp[c] = 0;
764 }
765 for (r = 0; r < rowsn - 1; r++) {
766 for (c = rowsp[r]; c < rowsp[r+1]; c++) {
767 rp[colsp[c-1]-1] += valsp[c-1] * xp[r];
768 }
769 }
770 OK
771}
772
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs
index 81a8083..3528e96 100644
--- a/packages/base/src/Data/Packed/Internal/Numeric.hs
+++ b/packages/base/src/Data/Packed/Internal/Numeric.hs
@@ -20,6 +20,7 @@ module Data.Packed.Internal.Numeric (
20 ident, diag, ctrans, 20 ident, diag, ctrans,
21 -- * Generic operations 21 -- * Generic operations
22 Container(..), 22 Container(..),
23 Transposable(..), Linear(..), Testable(..),
23 -- * Matrix product and related functions 24 -- * Matrix product and related functions
24 Product(..), udot, 25 Product(..), udot,
25 mXm,mXv,vXm, 26 mXm,mXv,vXm,
@@ -605,3 +606,25 @@ condV f a b l e t = f a' b' l' e' t'
605 where 606 where
606 [a', b', l', e', t'] = conformVs [a,b,l,e,t] 607 [a', b', l', e', t'] = conformVs [a,b,l,e,t]
607 608
609--------------------------------------------------------------------------------
610
611class Transposable t
612 where
613 tr :: t -> t
614
615
616class Linear t v
617 where
618 scalarL :: t -> v
619 addL :: v -> v -> v
620 scaleL :: t -> v -> v
621
622
623class Testable t
624 where
625 checkT :: t -> (Bool, IO())
626 ioCheckT :: t -> IO (Bool, IO())
627 ioCheckT = return . checkT
628
629--------------------------------------------------------------------------------
630
diff --git a/packages/base/src/Data/Packed/Internal/Sparse.hs b/packages/base/src/Data/Packed/Internal/Sparse.hs
new file mode 100644
index 0000000..544c913
--- /dev/null
+++ b/packages/base/src/Data/Packed/Internal/Sparse.hs
@@ -0,0 +1,191 @@
1{-# LANGUAGE RecordWildCards #-}
2{-# LANGUAGE MultiParamTypeClasses #-}
3{-# LANGUAGE FlexibleInstances #-}
4
5
6
7module Data.Packed.Internal.Sparse(
8 SMatrix(..),
9 mkCSR, mkDiag,
10 AssocMatrix,
11 toDense,
12 smXv
13)where
14
15import Numeric.Container
16import qualified Data.Vector.Storable as V
17import Data.Function(on)
18import Control.Arrow((***))
19import Control.Monad(when)
20import Data.List(groupBy, sort)
21import Foreign.C.Types(CInt(..))
22import Numeric.LinearAlgebra.Devel
23import System.IO.Unsafe(unsafePerformIO)
24import Foreign(Ptr)
25import Text.Printf(printf)
26
27infixl 0 ~!~
28c ~!~ msg = when c (error msg)
29
30type AssocMatrix = [((Int,Int),Double)]
31
32data SMatrix
33 = CSR
34 { csrVals :: Vector Double
35 , csrCols :: Vector CInt
36 , csrRows :: Vector CInt
37 , nRows :: Int
38 , nCols :: Int
39 }
40 | CSC
41 { cscVals :: Vector Double
42 , cscRows :: Vector CInt
43 , cscCols :: Vector CInt
44 , nRows :: Int
45 , nCols :: Int
46 }
47 | Diag
48 { diagVals :: Vector Double
49 , nRows :: Int
50 , nCols :: Int
51 }
52-- | Banded
53 deriving Show
54
55mkCSR :: AssocMatrix -> SMatrix
56mkCSR sm' = CSR{..}
57 where
58 sm = sort sm'
59 rws = map ((fromList *** fromList)
60 . unzip
61 . map ((succ.fi.snd) *** id)
62 )
63 . groupBy ((==) `on` (fst.fst))
64 $ sm
65 rszs = map (fi . dim . fst) rws
66 csrRows = fromList (scanl (+) 1 rszs)
67 csrVals = vjoin (map snd rws)
68 csrCols = vjoin (map fst rws)
69 nRows = dim csrRows - 1
70 nCols = fromIntegral (V.maximum csrCols)
71
72
73mkDiagR r c v
74 | dim v <= min r c = Diag{..}
75 | otherwise = error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v)
76 where
77 nRows = r
78 nCols = c
79 diagVals = v
80
81mkDiag v = mkDiagR (dim v) (dim v) v
82
83
84type IV t = CInt -> Ptr CInt -> t
85type V t = CInt -> Ptr Double -> t
86type SMxV = V (IV (IV (V (V (IO CInt)))))
87
88smXv :: SMatrix -> Vector Double -> Vector Double
89smXv CSR{..} v = unsafePerformIO $ do
90 dim v /= nCols ~!~ printf "smXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
91 r <- createVector nRows
92 app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv"
93 return r
94
95smXv CSC{..} v = unsafePerformIO $ do
96 dim v /= nCols ~!~ printf "smXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
97 r <- createVector nRows
98 app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv"
99 return r
100
101smXv Diag{..} v
102 | dim v == nCols
103 = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals
104 , konst 0 (nRows - dim diagVals) ]
105 | otherwise = error $ printf "smXv (Diag): incorrect sizes: (%d,%d) [%d] x %d"
106 nRows nCols (dim diagVals) (dim v)
107
108
109instance Contraction SMatrix (Vector Double) (Vector Double)
110 where
111 contraction = smXv
112
113--------------------------------------------------------------------------------
114
115foreign import ccall unsafe "smXv"
116 c_smXv :: SMxV
117
118foreign import ccall unsafe "smTXv"
119 c_smTXv :: SMxV
120
121--------------------------------------------------------------------------------
122
123toDense :: AssocMatrix -> Matrix Double
124toDense asm = assoc (r+1,c+1) 0 asm
125 where
126 (r,c) = (maximum *** maximum) . unzip . map fst $ asm
127
128
129
130instance Transposable (SMatrix)
131 where
132 tr (CSR vs cs rs n m) = CSC vs cs rs m n
133 tr (CSC vs rs cs n m) = CSR vs rs cs m n
134 tr (Diag v n m) = Diag v m n
135
136instance Transposable (Matrix Double)
137 where
138 tr = trans
139
140
141
142--------------------------------------------------------------------------------
143
144instance Testable SMatrix
145 where
146 checkT _ = (ok,info)
147 where
148 sma = convo2 20 3
149 x1 = vect [1..20]
150 x2 = vect [1..40]
151 sm = mkCSR sma
152
153 s1 = sm ◇ x1
154 d1 = toDense sma ◇ x1
155
156 s2 = tr sm ◇ x2
157 d2 = tr (toDense sma) ◇ x2
158
159 sdia = mkDiagR 40 20 (vect [1..10])
160 s3 = sdia ◇ x1
161 s4 = tr sdia ◇ x2
162 ddia = diagRect 0 (vect [1..10]) 40 20
163 d3 = ddia ◇ x1
164 d4 = tr ddia ◇ x2
165
166 info = do
167 print sm
168 disp (toDense sma)
169 print s1; print d1
170 print s2; print d2
171 print s3; print d3
172 print s4; print d4
173
174 ok = s1==d1
175 && s2==d2
176 && s3==d3
177 && s4==d4
178
179 disp = putStr . dispf 2
180
181 vect = fromList :: [Double] -> Vector Double
182
183 convomat :: Int -> Int -> AssocMatrix
184 convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]]
185
186 convo2 :: Int -> Int -> AssocMatrix
187 convo2 n k = m1 ++ m2
188 where
189 m1 = convomat n k
190 m2 = map (((+n) *** id) *** id) m1
191
diff --git a/packages/base/src/Numeric/Container.hs b/packages/base/src/Numeric/Container.hs
index 264a619..0633640 100644
--- a/packages/base/src/Numeric/Container.hs
+++ b/packages/base/src/Numeric/Container.hs
@@ -32,11 +32,11 @@ module Numeric.Container (
32 diag, ident, 32 diag, ident,
33 ctrans, 33 ctrans,
34 -- * Generic operations 34 -- * Generic operations
35 Container(..), 35 Container(..), Transposable(..), Linear(..),
36 -- * Matrix product 36 -- * Matrix product
37 Product(..), udot, dot, (◇), 37 Product(..), udot, dot, (◇),
38 Mul(..), 38 Mul(..),
39 Contraction(..), 39 Contraction(..),(<.>),
40 optimiseMult, 40 optimiseMult,
41 mXm,mXv,vXm,LSDiv(..), 41 mXm,mXv,vXm,LSDiv(..),
42 outer, kronecker, 42 outer, kronecker,
@@ -55,7 +55,9 @@ module Numeric.Container (
55 IndexOf, 55 IndexOf,
56 module Data.Complex, 56 module Data.Complex,
57 -- * IO 57 -- * IO
58 module Data.Packed.IO 58 module Data.Packed.IO,
59 -- * Misc
60 Testable(..)
59) where 61) where
60 62
61import Data.Packed hiding (stepD, stepF, condD, condF, conjugateC, conjugateQ) 63import Data.Packed hiding (stepD, stepF, condD, condF, conjugateC, conjugateQ)
@@ -87,10 +89,9 @@ linspace n (a,b) = addConstant a $ scale s $ fromList $ map fromIntegral [0 .. n
87 89
88-------------------------------------------------------- 90--------------------------------------------------------
89 91
90class Contraction a b c | a b -> c 92{- | Matrix product, matrix - vector product, and dot product (equivalent to 'contraction')
91 where 93
92 infixl 7 <.> 94(This operator can also be written using the unicode symbol ◇ (25c7).)
93 {- | Matrix product, matrix - vector product, and dot product
94 95
95Examples: 96Examples:
96 97
@@ -129,22 +130,28 @@ For complex vectors the first argument is conjugated:
129 130
130>>> fromList [1,i,1-i] <.> complex a 131>>> fromList [1,i,1-i] <.> complex a
131fromList [10.0 :+ 4.0,12.0 :+ 4.0,14.0 :+ 4.0,16.0 :+ 4.0] 132fromList [10.0 :+ 4.0,12.0 :+ 4.0,14.0 :+ 4.0,16.0 :+ 4.0]
132
133-} 133-}
134 (<.>) :: a -> b -> c 134infixl 7 <.>
135(<.>) :: Contraction a b c => a -> b -> c
136(<.>) = contraction
135 137
136 138
139class Contraction a b c | a b -> c
140 where
141 -- | Matrix product, matrix - vector product, and dot product
142 contraction :: a -> b -> c
143
137instance (Product t, Container Vector t) => Contraction (Vector t) (Vector t) t where 144instance (Product t, Container Vector t) => Contraction (Vector t) (Vector t) t where
138 u <.> v = conj u `udot` v 145 u `contraction` v = conj u `udot` v
139 146
140instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where 147instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where
141 (<.>) = mXv 148 contraction = mXv
142 149
143instance (Container Vector t, Product t) => Contraction (Vector t) (Matrix t) (Vector t) where 150instance (Container Vector t, Product t) => Contraction (Vector t) (Matrix t) (Vector t) where
144 (<.>) v m = (conj v) `vXm` m 151 contraction v m = (conj v) `vXm` m
145 152
146instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where 153instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where
147 (<.>) = mXm 154 contraction = mXm
148 155
149 156
150-------------------------------------------------------------------------------- 157--------------------------------------------------------------------------------
@@ -229,10 +236,10 @@ instance Container Matrix e => Build (Int,Int) (e -> e -> e) Matrix e
229 236
230-------------------------------------------------------------------------------- 237--------------------------------------------------------------------------------
231 238
232-- | alternative unicode symbol (25c7) for the contraction operator '(\<.\>)' 239-- | alternative unicode symbol (25c7) for 'contraction'
233(◇) :: Contraction a b c => a -> b -> c 240(◇) :: Contraction a b c => a -> b -> c
234infixl 7 ◇ 241infixl 7 ◇
235(◇) = (<.>) 242(◇) = contraction
236 243
237-- | dot product: @cdot u v = 'udot' ('conj' u) v@ 244-- | dot product: @cdot u v = 'udot' ('conj' u) v@
238dot :: (Container Vector t, Product t) => Vector t -> Vector t -> t 245dot :: (Container Vector t, Product t) => Vector t -> Vector t -> t
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs
index 96bf29f..549ebd0 100644
--- a/packages/base/src/Numeric/LinearAlgebra.hs
+++ b/packages/base/src/Numeric/LinearAlgebra.hs
@@ -37,11 +37,11 @@ module Numeric.LinearAlgebra (
37 37
38 -- * Matrix product 38 -- * Matrix product
39 (<.>), 39 (<.>),
40 40
41 -- | This operator can also be written using the unicode symbol ◇ (25c7). 41 -- | The overloaded multiplication operator may need type annotations to remove
42 -- ambiguity. In those cases we can also use the specific functions 'mXm', 'mXv', and 'dot'.
42 -- 43 --
43 44 -- The matrix x matrix product is also implemented in the "Data.Monoid" instance, where
44 -- | The matrix x matrix product is also implemented in the "Data.Monoid" instance, where
45 -- single-element matrices (created from numeric literals or using 'scalar') 45 -- single-element matrices (created from numeric literals or using 'scalar')
46 -- are used for scaling. 46 -- are used for scaling.
47 -- 47 --
@@ -52,6 +52,7 @@ module Numeric.LinearAlgebra (
52 -- , 4.0, 10.0, 0.0 ] 52 -- , 4.0, 10.0, 0.0 ]
53 -- 53 --
54 -- 'mconcat' uses 'optimiseMult' to get the optimal association order. 54 -- 'mconcat' uses 'optimiseMult' to get the optimal association order.
55
55 56
56 -- * Other products 57 -- * Other products
57 outer, kronecker, cross, 58 outer, kronecker, cross,
@@ -125,7 +126,7 @@ module Numeric.LinearAlgebra (
125 RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, 126 RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample,
126 127
127 -- * Misc 128 -- * Misc
128 meanCov, peps, relativeError, haussholder, optimiseMult, udot, Seed, (◇) 129 meanCov, peps, relativeError, haussholder, optimiseMult, dot, udot, mXm, mXv, smXv, (<>), (◇), Seed, checkT
129) where 130) where
130 131
131import Numeric.LinearAlgebra.Data 132import Numeric.LinearAlgebra.Data
@@ -136,6 +137,5 @@ import Numeric.Container
136import Numeric.LinearAlgebra.Algorithms 137import Numeric.LinearAlgebra.Algorithms
137import Numeric.LinearAlgebra.Util 138import Numeric.LinearAlgebra.Util
138import Numeric.LinearAlgebra.Random 139import Numeric.LinearAlgebra.Random
139 140import Data.Packed.Internal.Sparse(smXv)
140
141 141
diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs
index 7e8af03..49bc1c0 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Data.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs
@@ -48,6 +48,10 @@ module Numeric.LinearAlgebra.Data(
48 latexFormat, 48 latexFormat,
49 dispf, disps, dispcf, format, 49 dispf, disps, dispcf, format,
50 50
51 -- * Sparse
52 SMatrix, AssocMatrix, mkCSR, toDense,
53 mkDiag,
54
51-- * Conversion 55-- * Conversion
52 Convert(..), 56 Convert(..),
53 57
@@ -56,7 +60,7 @@ module Numeric.LinearAlgebra.Data(
56 rows, cols, 60 rows, cols,
57 separable, 61 separable,
58 62
59 module Data.Complex 63 module Data.Complex,
60 64
61) where 65) where
62 66
@@ -65,4 +69,5 @@ import Data.Packed.Matrix
65import Numeric.Container 69import Numeric.Container
66import Numeric.LinearAlgebra.Util 70import Numeric.LinearAlgebra.Util
67import Data.Complex 71import Data.Complex
72import Data.Packed.Internal.Sparse
68 73