diff options
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Data.hs | 6 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Real.hs | 10 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util.hs | 7 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Numeric/Sparse.hs | 92 |
5 files changed, 82 insertions, 35 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs index 3417a5e..33a2c9a 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Data.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs | |||
@@ -49,8 +49,8 @@ module Numeric.LinearAlgebra.Data( | |||
49 | find, maxIndex, minIndex, maxElement, minElement, atIndex, | 49 | find, maxIndex, minIndex, maxElement, minElement, atIndex, |
50 | 50 | ||
51 | -- * Sparse | 51 | -- * Sparse |
52 | GMatrix, AssocMatrix, mkSparse, toDense, | 52 | AssocMatrix, toDense, |
53 | mkDiagR, dense, | 53 | mkSparse, mkDiagR, mkDense, |
54 | 54 | ||
55 | -- * IO | 55 | -- * IO |
56 | disp, | 56 | disp, |
@@ -68,7 +68,7 @@ module Numeric.LinearAlgebra.Data( | |||
68 | 68 | ||
69 | module Data.Complex, | 69 | module Data.Complex, |
70 | 70 | ||
71 | Vector, Matrix | 71 | Vector, Matrix, GMatrix, CSR(..), mkCSR |
72 | 72 | ||
73 | ) where | 73 | ) where |
74 | 74 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs index db15705..1e8b544 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Real.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs | |||
@@ -141,6 +141,16 @@ instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) | |||
141 | negate = (lift1F . lift1F) negate | 141 | negate = (lift1F . lift1F) negate |
142 | fromInteger x = Dim (Dim (fromInteger x)) | 142 | fromInteger x = Dim (Dim (fromInteger x)) |
143 | 143 | ||
144 | instance Fractional (Dim n (Vector Double)) | ||
145 | where | ||
146 | fromRational x = Dim (fromRational x) | ||
147 | (/) = lift2F (/) | ||
148 | |||
149 | instance Fractional (Dim m (Dim n (Matrix Double))) | ||
150 | where | ||
151 | fromRational x = Dim (Dim (fromRational x)) | ||
152 | (/) = (lift2F.lift2F) (/) | ||
153 | |||
144 | -------------------------------------------------------------------------------- | 154 | -------------------------------------------------------------------------------- |
145 | 155 | ||
146 | class Konst t | 156 | class Konst t |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util.hs b/packages/base/src/Numeric/LinearAlgebra/Util.hs index 47b1090..aee21b8 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util.hs | |||
@@ -37,6 +37,7 @@ module Numeric.LinearAlgebra.Util( | |||
37 | mnorm_1, mnorm_2, mnorm_0, mnorm_Inf, | 37 | mnorm_1, mnorm_2, mnorm_0, mnorm_Inf, |
38 | unitary, | 38 | unitary, |
39 | mt, | 39 | mt, |
40 | (~!~), | ||
40 | pairwiseD2, | 41 | pairwiseD2, |
41 | rowOuters, | 42 | rowOuters, |
42 | null1, | 43 | null1, |
@@ -65,6 +66,7 @@ import Numeric.Matrix() | |||
65 | import Numeric.Vector() | 66 | import Numeric.Vector() |
66 | import Numeric.LinearAlgebra.Random | 67 | import Numeric.LinearAlgebra.Random |
67 | import Numeric.LinearAlgebra.Util.Convolution | 68 | import Numeric.LinearAlgebra.Util.Convolution |
69 | import Control.Monad(when) | ||
68 | 70 | ||
69 | type ℝ = Double | 71 | type ℝ = Double |
70 | type ℕ = Int | 72 | type ℕ = Int |
@@ -385,3 +387,8 @@ vtrans p m | r == 0 = fromBlocks . map (map asColumn . takesV (replicate q p)) . | |||
385 | where | 387 | where |
386 | (q,r) = divMod (rows m) p | 388 | (q,r) = divMod (rows m) p |
387 | 389 | ||
390 | -------------------------------------------------------------------------------- | ||
391 | |||
392 | infixl 0 ~!~ | ||
393 | c ~!~ msg = when c (error msg) | ||
394 | |||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs index 50372f1..f821b57 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | |||
@@ -115,7 +115,7 @@ instance Testable GMatrix | |||
115 | sma = convo2 20 3 | 115 | sma = convo2 20 3 |
116 | x1 = vect [1..20] | 116 | x1 = vect [1..20] |
117 | x2 = vect [1..40] | 117 | x2 = vect [1..40] |
118 | sm = mkSparse sma | 118 | sm = (mkSparse . mkCSR) sma |
119 | dm = toDense sma | 119 | dm = toDense sma |
120 | 120 | ||
121 | s1 = sm !#> x1 | 121 | s1 = sm !#> x1 |
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs index 4d05bdc..3c19c93 100644 --- a/packages/base/src/Numeric/Sparse.hs +++ b/packages/base/src/Numeric/Sparse.hs | |||
@@ -3,8 +3,8 @@ | |||
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | 4 | ||
5 | module Numeric.Sparse( | 5 | module Numeric.Sparse( |
6 | GMatrix(..), | 6 | GMatrix, CSR(..), mkCSR, |
7 | mkSparse, mkDiagR, dense, | 7 | mkSparse, mkDiagR, mkDense, |
8 | AssocMatrix, | 8 | AssocMatrix, |
9 | toDense, | 9 | toDense, |
10 | gmXv, (!#>) | 10 | gmXv, (!#>) |
@@ -28,18 +28,49 @@ c ~!~ msg = when c (error msg) | |||
28 | 28 | ||
29 | type AssocMatrix = [((Int,Int),Double)] | 29 | type AssocMatrix = [((Int,Int),Double)] |
30 | 30 | ||
31 | data CSR = CSR | ||
32 | { csrVals :: Vector Double | ||
33 | , csrCols :: Vector CInt | ||
34 | , csrRows :: Vector CInt | ||
35 | , csrNRows :: Int | ||
36 | , csrNCols :: Int | ||
37 | } deriving Show | ||
38 | |||
39 | data CSC = CSC | ||
40 | { cscVals :: Vector Double | ||
41 | , cscRows :: Vector CInt | ||
42 | , cscCols :: Vector CInt | ||
43 | , cscNRows :: Int | ||
44 | , cscNCols :: Int | ||
45 | } deriving Show | ||
46 | |||
47 | |||
48 | mkCSR :: AssocMatrix -> CSR | ||
49 | mkCSR sm' = CSR{..} | ||
50 | where | ||
51 | sm = sort sm' | ||
52 | rws = map ((fromList *** fromList) | ||
53 | . unzip | ||
54 | . map ((succ.fi.snd) *** id) | ||
55 | ) | ||
56 | . groupBy ((==) `on` (fst.fst)) | ||
57 | $ sm | ||
58 | rszs = map (fi . dim . fst) rws | ||
59 | csrRows = fromList (scanl (+) 1 rszs) | ||
60 | csrVals = vjoin (map snd rws) | ||
61 | csrCols = vjoin (map fst rws) | ||
62 | csrNRows = dim csrRows - 1 | ||
63 | csrNCols = fromIntegral (V.maximum csrCols) | ||
64 | |||
65 | |||
31 | data GMatrix | 66 | data GMatrix |
32 | = CSR | 67 | = SparseR |
33 | { csrVals :: Vector Double | 68 | { gmCSR :: CSR |
34 | , csrCols :: Vector CInt | ||
35 | , csrRows :: Vector CInt | ||
36 | , nRows :: Int | 69 | , nRows :: Int |
37 | , nCols :: Int | 70 | , nCols :: Int |
38 | } | 71 | } |
39 | | CSC | 72 | | SparseC |
40 | { cscVals :: Vector Double | 73 | { gmCSC :: CSC |
41 | , cscRows :: Vector CInt | ||
42 | , cscCols :: Vector CInt | ||
43 | , nRows :: Int | 74 | , nRows :: Int |
44 | , nCols :: Int | 75 | , nCols :: Int |
45 | } | 76 | } |
@@ -56,29 +87,21 @@ data GMatrix | |||
56 | -- | Banded | 87 | -- | Banded |
57 | deriving Show | 88 | deriving Show |
58 | 89 | ||
59 | dense :: Matrix Double -> GMatrix | 90 | |
60 | dense m = Dense{..} | 91 | mkDense :: Matrix Double -> GMatrix |
92 | mkDense m = Dense{..} | ||
61 | where | 93 | where |
62 | gmDense = m | 94 | gmDense = m |
63 | nRows = rows m | 95 | nRows = rows m |
64 | nCols = cols m | 96 | nCols = cols m |
65 | 97 | ||
66 | mkSparse :: AssocMatrix -> GMatrix | 98 | |
67 | mkSparse sm' = CSR{..} | 99 | mkSparse :: CSR -> GMatrix |
100 | mkSparse csr = SparseR {..} | ||
68 | where | 101 | where |
69 | sm = sort sm' | 102 | gmCSR @ CSR {..} = csr |
70 | rws = map ((fromList *** fromList) | 103 | nRows = csrNRows |
71 | . unzip | 104 | nCols = csrNCols |
72 | . map ((succ.fi.snd) *** id) | ||
73 | ) | ||
74 | . groupBy ((==) `on` (fst.fst)) | ||
75 | $ sm | ||
76 | rszs = map (fi . dim . fst) rws | ||
77 | csrRows = fromList (scanl (+) 1 rszs) | ||
78 | csrVals = vjoin (map snd rws) | ||
79 | csrCols = vjoin (map fst rws) | ||
80 | nRows = dim csrRows - 1 | ||
81 | nCols = fromIntegral (V.maximum csrCols) | ||
82 | 105 | ||
83 | 106 | ||
84 | mkDiagR r c v | 107 | mkDiagR r c v |
@@ -95,13 +118,13 @@ type V t = CInt -> Ptr Double -> t | |||
95 | type SMxV = V (IV (IV (V (V (IO CInt))))) | 118 | type SMxV = V (IV (IV (V (V (IO CInt))))) |
96 | 119 | ||
97 | gmXv :: GMatrix -> Vector Double -> Vector Double | 120 | gmXv :: GMatrix -> Vector Double -> Vector Double |
98 | gmXv CSR{..} v = unsafePerformIO $ do | 121 | gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do |
99 | dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | 122 | dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) |
100 | r <- createVector nRows | 123 | r <- createVector nRows |
101 | app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" | 124 | app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" |
102 | return r | 125 | return r |
103 | 126 | ||
104 | gmXv CSC{..} v = unsafePerformIO $ do | 127 | gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do |
105 | dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | 128 | dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) |
106 | r <- createVector nRows | 129 | r <- createVector nRows |
107 | app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" | 130 | app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" |
@@ -147,11 +170,18 @@ toDense asm = assoc (r+1,c+1) 0 asm | |||
147 | (r,c) = (maximum *** maximum) . unzip . map fst $ asm | 170 | (r,c) = (maximum *** maximum) . unzip . map fst $ asm |
148 | 171 | ||
149 | 172 | ||
150 | 173 | instance Transposable CSR CSC | |
151 | instance Transposable GMatrix GMatrix | ||
152 | where | 174 | where |
153 | tr (CSR vs cs rs n m) = CSC vs cs rs m n | 175 | tr (CSR vs cs rs n m) = CSC vs cs rs m n |
176 | |||
177 | instance Transposable CSC CSR | ||
178 | where | ||
154 | tr (CSC vs rs cs n m) = CSR vs rs cs m n | 179 | tr (CSC vs rs cs n m) = CSR vs rs cs m n |
180 | |||
181 | instance Transposable GMatrix GMatrix | ||
182 | where | ||
183 | tr (SparseR s n m) = SparseC (tr s) m n | ||
184 | tr (SparseC s n m) = SparseR (tr s) m n | ||
155 | tr (Diag v n m) = Diag v m n | 185 | tr (Diag v n m) = Diag v m n |
156 | tr (Dense a n m) = Dense (tr a) m n | 186 | tr (Dense a n m) = Dense (tr a) m n |
157 | 187 | ||