diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-09-09 15:45:06 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-09-09 15:45:06 +0000 |
commit | 631a32fbdc0d61f647d3217da86bcb1d552e5e5a (patch) | |
tree | 2b81e4e6b9d7ae646787bc3ff8684b2d759b24c5 /lib | |
parent | 34380f2b5d7b048a4d68197f16a8db0e53742030 (diff) |
simplified (but wrong)
Diffstat (limited to 'lib')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 131 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 2 | ||||
-rw-r--r-- | lib/Data/Packed/Matrix.hs | 4 | ||||
-rw-r--r-- | lib/GSL.hs | 2 | ||||
-rw-r--r-- | lib/GSL/Compat.hs | 3 | ||||
-rw-r--r-- | lib/GSL/Matrix.hs | 42 | ||||
-rw-r--r-- | lib/LAPACK.hs | 73 |
7 files changed, 188 insertions, 69 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 48652f3..6ba2d06 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -27,6 +27,10 @@ import Data.Maybe(fromJust) | |||
27 | 27 | ||
28 | ---------------------------------------------------------------- | 28 | ---------------------------------------------------------------- |
29 | 29 | ||
30 | -- the condition Storable a => Field a means that we can only put | ||
31 | -- in Field types that are in Storable, and therefore Storable a | ||
32 | -- is not required in signatures if we have a Field a. | ||
33 | |||
30 | class Storable a => Field a where | 34 | class Storable a => Field a where |
31 | constant :: a -> Int -> Vector a | 35 | constant :: a -> Int -> Vector a |
32 | transdata :: Int -> Vector a -> Int -> Vector a | 36 | transdata :: Int -> Vector a -> Int -> Vector a |
@@ -36,7 +40,6 @@ class Storable a => Field a where | |||
36 | -> Matrix a -> Matrix a | 40 | -> Matrix a -> Matrix a |
37 | diag :: Vector a -> Matrix a | 41 | diag :: Vector a -> Matrix a |
38 | 42 | ||
39 | |||
40 | instance Field Double where | 43 | instance Field Double where |
41 | constant = constantR | 44 | constant = constantR |
42 | transdata = transdataR | 45 | transdata = transdataR |
@@ -78,12 +81,40 @@ foreign import ccall safe "aux.h transC" | |||
78 | 81 | ||
79 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | 82 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d |
80 | 83 | ||
84 | {- Design considerations for the Matrix Type | ||
85 | ----------------------------------------- | ||
86 | |||
87 | - we must easily handle both row major and column major order, | ||
88 | for bindings to LAPACK and GSL/C | ||
89 | |||
90 | - we'd like to simplify redundant matrix transposes: | ||
91 | - Some of them arise from the order requirements of some functions | ||
92 | - some functions (matrix product) admit transposed arguments | ||
81 | 93 | ||
94 | - maybe we don't really need this kind of simplification: | ||
95 | - more complex code | ||
96 | - some computational overhead | ||
97 | - only appreciable gain in code with a lot of redundant transpositions | ||
98 | and cheap matrix computations | ||
82 | 99 | ||
100 | - we could carry both the matrix and its (lazily computed) transpose. | ||
101 | This may save some transpositions, but it is necessary to keep track of the | ||
102 | data which is actually computed to be used by functions like the matrix product | ||
103 | which admit both orders. Therefore, maybe it is better to have something like | ||
104 | viewC and viewF, which may actually perform a transpose if required. | ||
83 | 105 | ||
106 | - but if we need the transposed data and it is not in the structure, we must make | ||
107 | sure that we touch the same foreignptr that is used in the computation. Access | ||
108 | to such pointer cannot be made by creating a new vector. | ||
109 | |||
110 | -} | ||
84 | 111 | ||
85 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | 112 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) |
86 | 113 | ||
114 | {- | ||
115 | |||
116 | |||
117 | |||
87 | data Matrix t = M { rows :: Int | 118 | data Matrix t = M { rows :: Int |
88 | , cols :: Int | 119 | , cols :: Int |
89 | , dat :: Vector t | 120 | , dat :: Vector t |
@@ -91,28 +122,26 @@ data Matrix t = M { rows :: Int | |||
91 | , isTrans :: Bool | 122 | , isTrans :: Bool |
92 | , order :: MatrixOrder | 123 | , order :: MatrixOrder |
93 | } -- deriving Typeable | 124 | } -- deriving Typeable |
125 | -} | ||
94 | 126 | ||
127 | data Matrix t = MC { rows :: Int, cols :: Int, dat :: Vector t } -- row major order | ||
128 | | MF { rows :: Int, cols :: Int, dat :: Vector t } -- column major order | ||
95 | 129 | ||
96 | data NMat t = MC { rws, cls :: Int, dtc :: Vector t} | 130 | -- transposition just changes the data order |
97 | | MF { rws, cls :: Int, dtf :: Vector t} | 131 | trans :: Matrix t -> Matrix t |
98 | | Tr (NMat t) | 132 | trans MC {rows = r, cols = c, dat = d} = MF {rows = c, cols = r, dat = d} |
99 | 133 | trans MF {rows = r, cols = c, dat = d} = MC {rows = c, cols = r, dat = d} | |
100 | ntrans (Tr m) = m | ||
101 | ntrans m = Tr m | ||
102 | 134 | ||
103 | viewC m@MC{} = m | 135 | viewC m@MC{} = m |
104 | viewF m@MF{} = m | 136 | viewC MF {rows = r, cols = c, dat = d} = MC {rows = r, cols = c, dat = transdata r d c} |
105 | 137 | ||
106 | fortran m = order m == ColumnMajor | 138 | viewF m@MF{} = m |
139 | viewF MC {rows = r, cols = c, dat = d} = MF {rows = r, cols = c, dat = transdata c d r} | ||
107 | 140 | ||
108 | cdat m = if fortran m `xor` isTrans m then tdat m else dat m | 141 | --fortran m = order m == ColumnMajor |
109 | fdat m = if fortran m `xor` isTrans m then dat m else tdat m | ||
110 | 142 | ||
111 | trans :: Matrix t -> Matrix t | 143 | cdat m = dat (viewC m) |
112 | trans m = m { rows = cols m | 144 | fdat m = dat (viewF m) |
113 | , cols = rows m | ||
114 | , isTrans = not (isTrans m) | ||
115 | } | ||
116 | 145 | ||
117 | type Mt t s = Int -> Int -> Ptr t -> s | 146 | type Mt t s = Int -> Int -> Ptr t -> s |
118 | -- not yet admitted by my haddock version | 147 | -- not yet admitted by my haddock version |
@@ -120,11 +149,14 @@ type Mt t s = Int -> Int -> Ptr t -> s | |||
120 | -- type t ::> s = Mt t s | 149 | -- type t ::> s = Mt t s |
121 | 150 | ||
122 | mat d m f = f (rows m) (cols m) (ptr (d m)) | 151 | mat d m f = f (rows m) (cols m) (ptr (d m)) |
152 | --mat m f = f (rows m) (cols m) (ptr (dat m)) | ||
153 | --matC m f = f (rows m) (cols m) (ptr (cdat m)) | ||
154 | |||
123 | 155 | ||
124 | toLists :: (Storable t) => Matrix t -> [[t]] | 156 | --toLists :: (Storable t) => Matrix t -> [[t]] |
125 | toLists m = partit (cols m) . toList . cdat $ m | 157 | toLists m = partit (cols m) . toList . cdat $ m |
126 | 158 | ||
127 | instance (Show a, Storable a) => (Show (Matrix a)) where | 159 | instance (Show a, Field a) => (Show (Matrix a)) where |
128 | show m = (sizes++) . dsp . map (map show) . toLists $ m | 160 | show m = (sizes++) . dsp . map (map show) . toLists $ m |
129 | where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" | 161 | where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" |
130 | 162 | ||
@@ -136,6 +168,7 @@ dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unw | |||
136 | pad n str = replicate (n - length str) ' ' ++ str | 168 | pad n str = replicate (n - length str) ' ' ++ str |
137 | unwords' = concat . intersperse ", " | 169 | unwords' = concat . intersperse ", " |
138 | 170 | ||
171 | {- | ||
139 | matrixFromVector RowMajor c v = | 172 | matrixFromVector RowMajor c v = |
140 | M { rows = r | 173 | M { rows = r |
141 | , cols = c | 174 | , cols = c |
@@ -147,8 +180,6 @@ matrixFromVector RowMajor c v = | |||
147 | r | m==0 = d | 180 | r | m==0 = d |
148 | | otherwise = error "matrixFromVector" | 181 | | otherwise = error "matrixFromVector" |
149 | 182 | ||
150 | -- r = dim v `div` c -- TODO check mod=0 | ||
151 | |||
152 | matrixFromVector ColumnMajor c v = | 183 | matrixFromVector ColumnMajor c v = |
153 | M { rows = r | 184 | M { rows = r |
154 | , cols = c | 185 | , cols = c |
@@ -160,6 +191,23 @@ matrixFromVector ColumnMajor c v = | |||
160 | r | m==0 = d | 191 | r | m==0 = d |
161 | | otherwise = error "matrixFromVector" | 192 | | otherwise = error "matrixFromVector" |
162 | 193 | ||
194 | -} | ||
195 | |||
196 | matrixFromVector RowMajor c v = MC { rows = r, cols = c, dat = v} | ||
197 | where (d,m) = dim v `divMod` c | ||
198 | r | m==0 = d | ||
199 | | otherwise = error "matrixFromVector" | ||
200 | |||
201 | matrixFromVector ColumnMajor c v = MF { rows = r, cols = c, dat = v} | ||
202 | where (d,m) = dim v `divMod` c | ||
203 | r | m==0 = d | ||
204 | | otherwise = error "matrixFromVector" | ||
205 | |||
206 | |||
207 | |||
208 | |||
209 | |||
210 | |||
163 | createMatrix order r c = do | 211 | createMatrix order r c = do |
164 | p <- createVector (r*c) | 212 | p <- createVector (r*c) |
165 | return (matrixFromVector order c p) | 213 | return (matrixFromVector order c p) |
@@ -178,10 +226,10 @@ reshape c v = matrixFromVector RowMajor c v | |||
178 | 226 | ||
179 | singleton x = reshape 1 (fromList [x]) | 227 | singleton x = reshape 1 (fromList [x]) |
180 | 228 | ||
181 | liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | 229 | --liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b |
182 | liftMatrix f m = reshape (cols m) (f (cdat m)) | 230 | liftMatrix f m = reshape (cols m) (f (cdat m)) |
183 | 231 | ||
184 | liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 232 | --liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
185 | liftMatrix2 f m1 m2 | compat m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) | 233 | liftMatrix2 f m1 m2 | compat m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) |
186 | | otherwise = error "nonconformant matrices in liftMatrix2" | 234 | | otherwise = error "nonconformant matrices in liftMatrix2" |
187 | ------------------------------------------------------------------ | 235 | ------------------------------------------------------------------ |
@@ -203,6 +251,7 @@ multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multipl | |||
203 | 251 | ||
204 | ------------------------------------------------------------------ | 252 | ------------------------------------------------------------------ |
205 | 253 | ||
254 | {- | ||
206 | gmatC m f | fortran m = | 255 | gmatC m f | fortran m = |
207 | if (isTrans m) | 256 | if (isTrans m) |
208 | then f 0 (rows m) (cols m) (ptr (dat m)) | 257 | then f 0 (rows m) (cols m) (ptr (dat m)) |
@@ -211,7 +260,11 @@ gmatC m f | fortran m = | |||
211 | if isTrans m | 260 | if isTrans m |
212 | then f 1 (cols m) (rows m) (ptr (dat m)) | 261 | then f 1 (cols m) (rows m) (ptr (dat m)) |
213 | else f 0 (rows m) (cols m) (ptr (dat m)) | 262 | else f 0 (rows m) (cols m) (ptr (dat m)) |
263 | -} | ||
214 | 264 | ||
265 | gmatC MF {rows = r, cols = c, dat = d} f = f 1 c r (ptr d) | ||
266 | gmatC MC {rows = r, cols = c, dat = d} f = f 0 r c (ptr d) | ||
267 | {-# INLINE gmatC #-} | ||
215 | 268 | ||
216 | multiplyAux fun order a b = unsafePerformIO $ do | 269 | multiplyAux fun order a b = unsafePerformIO $ do |
217 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | 270 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ |
@@ -219,6 +272,7 @@ multiplyAux fun order a b = unsafePerformIO $ do | |||
219 | r <- createMatrix order (rows a) (cols b) | 272 | r <- createMatrix order (rows a) (cols b) |
220 | fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] | 273 | fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] |
221 | return r | 274 | return r |
275 | {-# INLINE multiplyAux #-} | ||
222 | 276 | ||
223 | foreign import ccall safe "aux.h multiplyR" | 277 | foreign import ccall safe "aux.h multiplyR" |
224 | cmultiplyR :: Int -> Int -> Int -> Ptr Double | 278 | cmultiplyR :: Int -> Int -> Int -> Ptr Double |
@@ -234,13 +288,15 @@ foreign import ccall safe "aux.h multiplyC" | |||
234 | 288 | ||
235 | multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a | 289 | multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a |
236 | multiply RowMajor a b = multiplyD RowMajor a b | 290 | multiply RowMajor a b = multiplyD RowMajor a b |
237 | multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor} | 291 | multiply ColumnMajor a b = MF {rows = c, cols = r, dat = d} |
238 | where m = multiplyD RowMajor (trans b) (trans a) | 292 | where MC {rows = r, cols = c, dat = d } = multiplyD RowMajor (trans b) (trans a) |
239 | 293 | ||
240 | 294 | ||
241 | multiplyR = multiplyAux cmultiplyR | 295 | multiplyR = multiplyAux cmultiplyR' |
242 | multiplyC = multiplyAux cmultiplyC | 296 | multiplyC = multiplyAux cmultiplyC |
243 | 297 | ||
298 | cmultiplyR' p1 p2 p3 p4 q1 q2 q3 q4 r1 r2 r3 = {-# SCC "mulR" #-} cmultiplyR p1 p2 p3 p4 q1 q2 q3 q4 r1 r2 r3 | ||
299 | |||
244 | ---------------------------------------------------------------------- | 300 | ---------------------------------------------------------------------- |
245 | 301 | ||
246 | -- | extraction of a submatrix of a real matrix | 302 | -- | extraction of a submatrix of a real matrix |
@@ -249,7 +305,7 @@ subMatrixR :: (Int,Int) -- ^ (r0,c0) starting position | |||
249 | -> Matrix Double -> Matrix Double | 305 | -> Matrix Double -> Matrix Double |
250 | subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do | 306 | subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do |
251 | r <- createMatrix RowMajor rt ct | 307 | r <- createMatrix RowMajor rt ct |
252 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat cdat r // check "subMatrixR" [dat r] | 308 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat dat r // check "subMatrixR" [dat r] |
253 | return r | 309 | return r |
254 | foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM | 310 | foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM |
255 | 311 | ||
@@ -278,8 +334,8 @@ subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 | |||
278 | 334 | ||
279 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do | 335 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do |
280 | m <- createMatrix RowMajor n n | 336 | m <- createMatrix RowMajor n n |
281 | fun // vec v // mat dat m // check msg [dat m] | 337 | fun // vec v // mat cdat m // check msg [dat m] |
282 | return m {tdat = dat m} | 338 | return m -- {tdat = dat m} |
283 | 339 | ||
284 | -- | diagonal matrix from a real vector | 340 | -- | diagonal matrix from a real vector |
285 | diagR :: Vector Double -> Matrix Double | 341 | diagR :: Vector Double -> Matrix Double |
@@ -305,13 +361,13 @@ diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1.. | |||
305 | | otherwise = 0 | 361 | | otherwise = 0 |
306 | 362 | ||
307 | -- | creates a Matrix from a list of vectors | 363 | -- | creates a Matrix from a list of vectors |
308 | fromRows :: Field t => [Vector t] -> Matrix t | 364 | --fromRows :: Field t => [Vector t] -> Matrix t |
309 | fromRows vs = case common dim vs of | 365 | fromRows vs = case common dim vs of |
310 | Nothing -> error "fromRows applied to [] or to vectors with different sizes" | 366 | Nothing -> error "fromRows applied to [] or to vectors with different sizes" |
311 | Just c -> reshape c (join vs) | 367 | Just c -> reshape c (join vs) |
312 | 368 | ||
313 | -- | extracts the rows of a matrix as a list of vectors | 369 | -- | extracts the rows of a matrix as a list of vectors |
314 | toRows :: Storable t => Matrix t -> [Vector t] | 370 | --toRows :: Storable t => Matrix t -> [Vector t] |
315 | toRows m = toRows' 0 where | 371 | toRows m = toRows' 0 where |
316 | v = cdat m | 372 | v = cdat m |
317 | r = rows m | 373 | r = rows m |
@@ -324,16 +380,25 @@ fromColumns :: Field t => [Vector t] -> Matrix t | |||
324 | fromColumns m = trans . fromRows $ m | 380 | fromColumns m = trans . fromRows $ m |
325 | 381 | ||
326 | -- | Creates a list of vectors from the columns of a matrix | 382 | -- | Creates a list of vectors from the columns of a matrix |
327 | toColumns :: Storable t => Matrix t -> [Vector t] | 383 | toColumns :: Field t => Matrix t -> [Vector t] |
328 | toColumns m = toRows . trans $ m | 384 | toColumns m = toRows . trans $ m |
329 | 385 | ||
330 | 386 | ||
331 | -- | Reads a matrix position. | 387 | -- | Reads a matrix position. |
332 | (@@>) :: Storable t => Matrix t -> (Int,Int) -> t | 388 | (@@>) :: Storable t => Matrix t -> (Int,Int) -> t |
333 | infixl 9 @@> | 389 | infixl 9 @@> |
334 | m@M {rows = r, cols = c} @@> (i,j) | 390 | --m@M {rows = r, cols = c} @@> (i,j) |
391 | -- | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | ||
392 | -- | otherwise = cdat m `at` (i*c+j) | ||
393 | |||
394 | MC {rows = r, cols = c, dat = v} @@> (i,j) | ||
395 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | ||
396 | | otherwise = v `at` (i*c+j) | ||
397 | |||
398 | MF {rows = r, cols = c, dat = v} @@> (i,j) | ||
335 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | 399 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" |
336 | | otherwise = cdat m `at` (i*c+j) | 400 | | otherwise = v `at` (j*r+i) |
401 | |||
337 | 402 | ||
338 | ------------------------------------------------------------------ | 403 | ------------------------------------------------------------------ |
339 | 404 | ||
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 6876685..dea1636 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -92,7 +92,7 @@ tensor dssig vec = T d v `withIdx` seqind where | |||
92 | tensorFromVector :: IdxType -> Vector t -> Tensor t | 92 | tensorFromVector :: IdxType -> Vector t -> Tensor t |
93 | tensorFromVector tp v = T {dims = [IdxDesc (dim v) tp "1"], ten = v} | 93 | tensorFromVector tp v = T {dims = [IdxDesc (dim v) tp "1"], ten = v} |
94 | 94 | ||
95 | tensorFromMatrix :: IdxType -> IdxType -> Matrix t -> Tensor t | 95 | tensorFromMatrix :: Field t => IdxType -> IdxType -> Matrix t -> Tensor t |
96 | tensorFromMatrix tpr tpc m = T {dims = [IdxDesc (rows m) tpr "1",IdxDesc (cols m) tpc "2"] | 96 | tensorFromMatrix tpr tpc m = T {dims = [IdxDesc (rows m) tpr "1",IdxDesc (cols m) tpc "2"] |
97 | , ten = cdat m} | 97 | , ten = cdat m} |
98 | 98 | ||
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index 2e8cb3d..45aaaba 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs | |||
@@ -77,7 +77,7 @@ diagRect s r c | |||
77 | | r > c = joinVert [diag s , zeros (r-c,c)] | 77 | | r > c = joinVert [diag s , zeros (r-c,c)] |
78 | where zeros (r,c) = reshape c $ constant 0 (r*c) | 78 | where zeros (r,c) = reshape c $ constant 0 (r*c) |
79 | 79 | ||
80 | takeDiag :: (Storable t) => Matrix t -> Vector t | 80 | takeDiag :: (Field t) => Matrix t -> Vector t |
81 | takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] | 81 | takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] |
82 | 82 | ||
83 | ident :: (Num t, Field t) => Int -> Matrix t | 83 | ident :: (Num t, Field t) => Int -> Matrix t |
@@ -119,7 +119,7 @@ dropColumns n mat = subMatrix (0,n) (rows mat, cols mat - n) mat | |||
119 | @\> flatten ('ident' 3) | 119 | @\> flatten ('ident' 3) |
120 | 9 # [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@ | 120 | 9 # [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@ |
121 | -} | 121 | -} |
122 | flatten :: Matrix t -> Vector t | 122 | flatten :: Field t => Matrix t -> Vector t |
123 | flatten = cdat | 123 | flatten = cdat |
124 | 124 | ||
125 | -- | Creates a 'Matrix' from a list of lists (considered as rows). | 125 | -- | Creates a 'Matrix' from a list of lists (considered as rows). |
@@ -21,7 +21,7 @@ module LinearAlgebra.Algorithms, | |||
21 | module LAPACK, | 21 | module LAPACK, |
22 | module GSL.Integration, | 22 | module GSL.Integration, |
23 | module GSL.Differentiation, | 23 | module GSL.Differentiation, |
24 | module GSL.Special, | 24 | --module GSL.Special, |
25 | module GSL.Fourier, | 25 | module GSL.Fourier, |
26 | module GSL.Polynomials, | 26 | module GSL.Polynomials, |
27 | module GSL.Minimization, | 27 | module GSL.Minimization, |
diff --git a/lib/GSL/Compat.hs b/lib/GSL/Compat.hs index 809a1f5..1d6f7b9 100644 --- a/lib/GSL/Compat.hs +++ b/lib/GSL/Compat.hs | |||
@@ -38,7 +38,7 @@ adaptScalar f1 f2 f3 x y | |||
38 | | dim y == 1 = f3 x (y@>0) | 38 | | dim y == 1 = f3 x (y@>0) |
39 | | otherwise = f2 x y | 39 | | otherwise = f2 x y |
40 | 40 | ||
41 | liftMatrix2' :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 41 | liftMatrix2' :: (Field t, Field a, Field b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
42 | liftMatrix2' f m1 m2 | compat' m1 m2 = reshape (max (cols m1) (cols m2)) (f (cdat m1) (cdat m2)) | 42 | liftMatrix2' f m1 m2 | compat' m1 m2 = reshape (max (cols m1) (cols m2)) (f (cdat m1) (cdat m2)) |
43 | | otherwise = error "nonconformant matrices in liftMatrix2'" | 43 | | otherwise = error "nonconformant matrices in liftMatrix2'" |
44 | 44 | ||
@@ -63,6 +63,7 @@ instance (Eq a, Field a) => Eq (Matrix a) where | |||
63 | 63 | ||
64 | instance (Field a, Linear Vector a) => Num (Matrix a) where | 64 | instance (Field a, Linear Vector a) => Num (Matrix a) where |
65 | (+) = liftMatrix2' (+) | 65 | (+) = liftMatrix2' (+) |
66 | (-) = liftMatrix2' (-) | ||
66 | negate = liftMatrix negate | 67 | negate = liftMatrix negate |
67 | (*) = liftMatrix2' (*) | 68 | (*) = liftMatrix2' (*) |
68 | signum = liftMatrix signum | 69 | signum = liftMatrix signum |
diff --git a/lib/GSL/Matrix.hs b/lib/GSL/Matrix.hs index 26c5e2a..15710df 100644 --- a/lib/GSL/Matrix.hs +++ b/lib/GSL/Matrix.hs | |||
@@ -46,13 +46,14 @@ import Foreign.C.String | |||
46 | 46 | ||
47 | -} | 47 | -} |
48 | eigSg :: Matrix Double -> (Vector Double, Matrix Double) | 48 | eigSg :: Matrix Double -> (Vector Double, Matrix Double) |
49 | eigSg (m@M {rows = r}) | 49 | eigSg m |
50 | | r == 1 = (fromList [cdat m `at` 0], singleton 1) | 50 | | r == 1 = (fromList [cdat m `at` 0], singleton 1) |
51 | | otherwise = unsafePerformIO $ do | 51 | | otherwise = unsafePerformIO $ do |
52 | l <- createVector r | 52 | l <- createVector r |
53 | v <- createMatrix RowMajor r r | 53 | v <- createMatrix RowMajor r r |
54 | c_eigS // mat cdat m // vec l // mat dat v // check "eigSg" [cdat m] | 54 | c_eigS // mat cdat m // vec l // mat dat v // check "eigSg" [cdat m] |
55 | return (l,v) | 55 | return (l,v) |
56 | where r = rows m | ||
56 | foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM | 57 | foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM |
57 | 58 | ||
58 | ------------------------------------------------------------------ | 59 | ------------------------------------------------------------------ |
@@ -76,13 +77,14 @@ foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM | |||
76 | 77 | ||
77 | -} | 78 | -} |
78 | eigHg :: Matrix (Complex Double)-> (Vector Double, Matrix (Complex Double)) | 79 | eigHg :: Matrix (Complex Double)-> (Vector Double, Matrix (Complex Double)) |
79 | eigHg (m@M {rows = r}) | 80 | eigHg m |
80 | | r == 1 = (fromList [realPart $ cdat m `at` 0], singleton 1) | 81 | | r == 1 = (fromList [realPart $ cdat m `at` 0], singleton 1) |
81 | | otherwise = unsafePerformIO $ do | 82 | | otherwise = unsafePerformIO $ do |
82 | l <- createVector r | 83 | l <- createVector r |
83 | v <- createMatrix RowMajor r r | 84 | v <- createMatrix RowMajor r r |
84 | c_eigH // mat cdat m // vec l // mat dat v // check "eigHg" [cdat m] | 85 | c_eigH // mat cdat m // vec l // mat dat v // check "eigHg" [cdat m] |
85 | return (l,v) | 86 | return (l,v) |
87 | where r = rows m | ||
86 | foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM | 88 | foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM |
87 | 89 | ||
88 | 90 | ||
@@ -108,16 +110,18 @@ foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM | |||
108 | 110 | ||
109 | -} | 111 | -} |
110 | svdg :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) | 112 | svdg :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) |
111 | svdg x@M {rows = r, cols = c} = if r>=c | 113 | svdg x = if rows x >= cols x |
112 | then svd' x | 114 | then svd' x |
113 | else (v, s, u) where (u,s,v) = svd' (trans x) | 115 | else (v, s, u) where (u,s,v) = svd' (trans x) |
114 | 116 | ||
115 | svd' x@M {rows = r, cols = c} = unsafePerformIO $ do | 117 | svd' x = unsafePerformIO $ do |
116 | u <- createMatrix RowMajor r c | 118 | u <- createMatrix RowMajor r c |
117 | s <- createVector c | 119 | s <- createVector c |
118 | v <- createMatrix RowMajor c c | 120 | v <- createMatrix RowMajor c c |
119 | c_svd // mat cdat x // mat dat u // vec s // mat dat v // check "svdg" [cdat x] | 121 | c_svd // mat cdat x // mat dat u // vec s // mat dat v // check "svdg" [cdat x] |
120 | return (u,s,v) | 122 | return (u,s,v) |
123 | where r = rows x | ||
124 | c = cols x | ||
121 | foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM | 125 | foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM |
122 | 126 | ||
123 | {- | QR decomposition of a real matrix using /gsl_linalg_QR_decomp/ and /gsl_linalg_QR_unpack/. | 127 | {- | QR decomposition of a real matrix using /gsl_linalg_QR_decomp/ and /gsl_linalg_QR_unpack/. |
@@ -138,11 +142,13 @@ foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM | |||
138 | 142 | ||
139 | -} | 143 | -} |
140 | qr :: Matrix Double -> (Matrix Double, Matrix Double) | 144 | qr :: Matrix Double -> (Matrix Double, Matrix Double) |
141 | qr x@M {rows = r, cols = c} = unsafePerformIO $ do | 145 | qr x = unsafePerformIO $ do |
142 | q <- createMatrix RowMajor r r | 146 | q <- createMatrix RowMajor r r |
143 | rot <- createMatrix RowMajor r c | 147 | rot <- createMatrix RowMajor r c |
144 | c_qr // mat cdat x // mat dat q // mat dat rot // check "qr" [cdat x] | 148 | c_qr // mat cdat x // mat dat q // mat dat rot // check "qr" [cdat x] |
145 | return (q,rot) | 149 | return (q,rot) |
150 | where r = rows x | ||
151 | c = cols x | ||
146 | foreign import ccall "gsl-aux.h QR" c_qr :: TMMM | 152 | foreign import ccall "gsl-aux.h QR" c_qr :: TMMM |
147 | 153 | ||
148 | {- | Cholesky decomposition of a symmetric positive definite real matrix using /gsl_linalg_cholesky_decomp/. | 154 | {- | Cholesky decomposition of a symmetric positive definite real matrix using /gsl_linalg_cholesky_decomp/. |
@@ -159,11 +165,11 @@ foreign import ccall "gsl-aux.h QR" c_qr :: TMMM | |||
159 | 165 | ||
160 | -} | 166 | -} |
161 | chol :: Matrix Double -> Matrix Double | 167 | chol :: Matrix Double -> Matrix Double |
162 | --chol x@(M r _ p) = createM [p] "chol" r r $ m c_chol x | 168 | chol x = unsafePerformIO $ do |
163 | chol x@M {rows = r} = unsafePerformIO $ do | ||
164 | res <- createMatrix RowMajor r r | 169 | res <- createMatrix RowMajor r r |
165 | c_chol // mat cdat x // mat dat res // check "chol" [cdat x] | 170 | c_chol // mat cdat x // mat dat res // check "chol" [cdat x] |
166 | return res | 171 | return res |
172 | where r = rows x | ||
167 | foreign import ccall "gsl-aux.h chol" c_chol :: TMM | 173 | foreign import ccall "gsl-aux.h chol" c_chol :: TMM |
168 | 174 | ||
169 | -------------------------------------------------------- | 175 | -------------------------------------------------------- |
@@ -171,43 +177,53 @@ foreign import ccall "gsl-aux.h chol" c_chol :: TMM | |||
171 | {- -| efficient multiplication by the inverse of a matrix (for real matrices) | 177 | {- -| efficient multiplication by the inverse of a matrix (for real matrices) |
172 | -} | 178 | -} |
173 | luSolveR :: Matrix Double -> Matrix Double -> Matrix Double | 179 | luSolveR :: Matrix Double -> Matrix Double -> Matrix Double |
174 | luSolveR a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c}) | 180 | luSolveR a b |
175 | | n1==n2 && n1==r = unsafePerformIO $ do | 181 | | n1==n2 && n1==r = unsafePerformIO $ do |
176 | s <- createMatrix RowMajor r c | 182 | s <- createMatrix RowMajor r c |
177 | c_luSolveR // mat cdat a // mat cdat b // mat dat s // check "luSolveR" [cdat a, cdat b] | 183 | c_luSolveR // mat cdat a // mat cdat b // mat dat s // check "luSolveR" [cdat a, cdat b] |
178 | return s | 184 | return s |
179 | | otherwise = error "luSolveR of nonsquare matrix" | 185 | | otherwise = error "luSolveR of nonsquare matrix" |
180 | 186 | where n1 = rows a | |
187 | n2 = cols a | ||
188 | r = rows b | ||
189 | c = cols b | ||
181 | foreign import ccall "gsl-aux.h luSolveR" c_luSolveR :: TMMM | 190 | foreign import ccall "gsl-aux.h luSolveR" c_luSolveR :: TMMM |
182 | 191 | ||
183 | {- -| efficient multiplication by the inverse of a matrix (for complex matrices). | 192 | {- -| efficient multiplication by the inverse of a matrix (for complex matrices). |
184 | -} | 193 | -} |
185 | luSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 194 | luSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
186 | luSolveC a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c}) | 195 | luSolveC a b |
187 | | n1==n2 && n1==r = unsafePerformIO $ do | 196 | | n1==n2 && n1==r = unsafePerformIO $ do |
188 | s <- createMatrix RowMajor r c | 197 | s <- createMatrix RowMajor r c |
189 | c_luSolveC // mat cdat a // mat cdat b // mat dat s // check "luSolveC" [cdat a, cdat b] | 198 | c_luSolveC // mat cdat a // mat cdat b // mat dat s // check "luSolveC" [cdat a, cdat b] |
190 | return s | 199 | return s |
191 | | otherwise = error "luSolveC of nonsquare matrix" | 200 | | otherwise = error "luSolveC of nonsquare matrix" |
192 | 201 | where n1 = rows a | |
202 | n2 = cols a | ||
203 | r = rows b | ||
204 | c = cols b | ||
193 | foreign import ccall "gsl-aux.h luSolveC" c_luSolveC :: TCMCMCM | 205 | foreign import ccall "gsl-aux.h luSolveC" c_luSolveC :: TCMCMCM |
194 | 206 | ||
195 | {- | lu decomposition of real matrix (packed as a vector including l, u, the permutation and sign) | 207 | {- | lu decomposition of real matrix (packed as a vector including l, u, the permutation and sign) |
196 | -} | 208 | -} |
197 | luRaux :: Matrix Double -> Vector Double | 209 | luRaux :: Matrix Double -> Vector Double |
198 | luRaux x@M {rows = r, cols = c} = unsafePerformIO $ do | 210 | luRaux x = unsafePerformIO $ do |
199 | res <- createVector (r*r+r+1) | 211 | res <- createVector (r*r+r+1) |
200 | c_luRaux // mat cdat x // vec res // check "luRaux" [cdat x] | 212 | c_luRaux // mat cdat x // vec res // check "luRaux" [cdat x] |
201 | return res | 213 | return res |
214 | where r = rows x | ||
215 | c = cols x | ||
202 | foreign import ccall "gsl-aux.h luRaux" c_luRaux :: TMV | 216 | foreign import ccall "gsl-aux.h luRaux" c_luRaux :: TMV |
203 | 217 | ||
204 | {- | lu decomposition of complex matrix (packed as a vector including l, u, the permutation and sign) | 218 | {- | lu decomposition of complex matrix (packed as a vector including l, u, the permutation and sign) |
205 | -} | 219 | -} |
206 | luCaux :: Matrix (Complex Double) -> Vector (Complex Double) | 220 | luCaux :: Matrix (Complex Double) -> Vector (Complex Double) |
207 | luCaux x@M {rows = r, cols = c} = unsafePerformIO $ do | 221 | luCaux x = unsafePerformIO $ do |
208 | res <- createVector (r*r+r+1) | 222 | res <- createVector (r*r+r+1) |
209 | c_luCaux // mat cdat x // vec res // check "luCaux" [cdat x] | 223 | c_luCaux // mat cdat x // vec res // check "luCaux" [cdat x] |
210 | return res | 224 | return res |
225 | where r = rows x | ||
226 | c = cols x | ||
211 | foreign import ccall "gsl-aux.h luCaux" c_luCaux :: TCMCV | 227 | foreign import ccall "gsl-aux.h luCaux" c_luCaux :: TCMCV |
212 | 228 | ||
213 | {- | The LU decomposition of a square matrix. Is based on /gsl_linalg_LU_decomp/ and /gsl_linalg_complex_LU_decomp/ as described in <http://www.gnu.org/software/gsl/manual/gsl-ref_13.html#SEC223>. | 229 | {- | The LU decomposition of a square matrix. Is based on /gsl_linalg_LU_decomp/ and /gsl_linalg_complex_LU_decomp/ as described in <http://www.gnu.org/software/gsl/manual/gsl-ref_13.html#SEC223>. |
diff --git a/lib/LAPACK.hs b/lib/LAPACK.hs index b0008b1..ba72681 100644 --- a/lib/LAPACK.hs +++ b/lib/LAPACK.hs | |||
@@ -37,16 +37,19 @@ foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM | |||
37 | -- | 37 | -- |
38 | -- @(u,s,v)=svdR m@ so that @m=u \<\> s \<\> 'trans' v@. | 38 | -- @(u,s,v)=svdR m@ so that @m=u \<\> s \<\> 'trans' v@. |
39 | svdR :: Matrix Double -> (Matrix Double, Matrix Double, Matrix Double) | 39 | svdR :: Matrix Double -> (Matrix Double, Matrix Double, Matrix Double) |
40 | svdR x@M {rows = r, cols = c} = (u, diagRect s r c, v) where (u,s,v) = svdR' x | 40 | svdR x = (u, diagRect s r c, v) where (u,s,v) = svdR' x |
41 | r = rows x | ||
42 | c = cols x | ||
41 | 43 | ||
42 | svdR' :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) | 44 | svdR' :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) |
43 | svdR' x@M {rows = r, cols = c} = unsafePerformIO $ do | 45 | svdR' x = unsafePerformIO $ do |
44 | u <- createMatrix ColumnMajor r r | 46 | u <- createMatrix ColumnMajor r r |
45 | s <- createVector (min r c) | 47 | s <- createVector (min r c) |
46 | v <- createMatrix ColumnMajor c c | 48 | v <- createMatrix ColumnMajor c c |
47 | dgesvd // mat fdat x // mat dat u // vec s // mat dat v // check "svdR" [fdat x] | 49 | dgesvd // mat fdat x // mat dat u // vec s // mat dat v // check "svdR" [fdat x] |
48 | return (u,s,trans v) | 50 | return (u,s,trans v) |
49 | 51 | where r = rows x | |
52 | c = cols x | ||
50 | ----------------------------------------------------------------------------- | 53 | ----------------------------------------------------------------------------- |
51 | foreign import ccall "LAPACK/lapack-aux.h svd_l_Rdd" dgesdd :: TMMVM | 54 | foreign import ccall "LAPACK/lapack-aux.h svd_l_Rdd" dgesdd :: TMMVM |
52 | 55 | ||
@@ -54,15 +57,19 @@ foreign import ccall "LAPACK/lapack-aux.h svd_l_Rdd" dgesdd :: TMMVM | |||
54 | -- | 57 | -- |
55 | -- @(u,s,v)=svdRdd m@ so that @m=u \<\> s \<\> 'trans' v@. | 58 | -- @(u,s,v)=svdRdd m@ so that @m=u \<\> s \<\> 'trans' v@. |
56 | svdRdd :: Matrix Double -> (Matrix Double, Matrix Double , Matrix Double) | 59 | svdRdd :: Matrix Double -> (Matrix Double, Matrix Double , Matrix Double) |
57 | svdRdd x@M {rows = r, cols = c} = (u, diagRect s r c, v) where (u,s,v) = svdRdd' x | 60 | svdRdd x = (u, diagRect s r c, v) where (u,s,v) = svdRdd' x |
61 | r = rows x | ||
62 | c = cols x | ||
58 | 63 | ||
59 | svdRdd' :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) | 64 | svdRdd' :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) |
60 | svdRdd' x@M {rows = r, cols = c} = unsafePerformIO $ do | 65 | svdRdd' x = unsafePerformIO $ do |
61 | u <- createMatrix ColumnMajor r r | 66 | u <- createMatrix ColumnMajor r r |
62 | s <- createVector (min r c) | 67 | s <- createVector (min r c) |
63 | v <- createMatrix ColumnMajor c c | 68 | v <- createMatrix ColumnMajor c c |
64 | dgesdd // mat fdat x // mat dat u // vec s // mat dat v // check "svdRdd" [fdat x] | 69 | dgesdd // mat fdat x // mat dat u // vec s // mat dat v // check "svdRdd" [fdat x] |
65 | return (u,s,trans v) | 70 | return (u,s,trans v) |
71 | where r = rows x | ||
72 | c = cols x | ||
66 | 73 | ||
67 | ----------------------------------------------------------------------------- | 74 | ----------------------------------------------------------------------------- |
68 | foreign import ccall "LAPACK/lapack-aux.h svd_l_C" zgesvd :: TCMCMVCM | 75 | foreign import ccall "LAPACK/lapack-aux.h svd_l_C" zgesvd :: TCMCMVCM |
@@ -71,15 +78,20 @@ foreign import ccall "LAPACK/lapack-aux.h svd_l_C" zgesvd :: TCMCMVCM | |||
71 | -- | 78 | -- |
72 | -- @(u,s,v)=svdC m@ so that @m=u \<\> s \<\> 'trans' v@. | 79 | -- @(u,s,v)=svdC m@ so that @m=u \<\> s \<\> 'trans' v@. |
73 | svdC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix Double, Matrix (Complex Double)) | 80 | svdC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix Double, Matrix (Complex Double)) |
74 | svdC x@M {rows = r, cols = c} = (u, diagRect s r c, v) where (u,s,v) = svdC' x | 81 | svdC x = (u, diagRect s r c, v) where (u,s,v) = svdC' x |
82 | r = rows x | ||
83 | c = cols x | ||
75 | 84 | ||
76 | svdC' :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double)) | 85 | svdC' :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double)) |
77 | svdC' x@M {rows = r, cols = c} = unsafePerformIO $ do | 86 | svdC' x = unsafePerformIO $ do |
78 | u <- createMatrix ColumnMajor r r | 87 | u <- createMatrix ColumnMajor r r |
79 | s <- createVector (min r c) | 88 | s <- createVector (min r c) |
80 | v <- createMatrix ColumnMajor c c | 89 | v <- createMatrix ColumnMajor c c |
81 | zgesvd // mat fdat x // mat dat u // vec s // mat dat v // check "svdC" [fdat x] | 90 | zgesvd // mat fdat x // mat dat u // vec s // mat dat v // check "svdC" [fdat x] |
82 | return (u,s,trans v) | 91 | return (u,s,trans v) |
92 | where r = rows x | ||
93 | c = cols x | ||
94 | |||
83 | 95 | ||
84 | ----------------------------------------------------------------------------- | 96 | ----------------------------------------------------------------------------- |
85 | foreign import ccall "LAPACK/lapack-aux.h eig_l_C" zgeev :: TCMCMCVCM | 97 | foreign import ccall "LAPACK/lapack-aux.h eig_l_C" zgeev :: TCMCMCVCM |
@@ -91,7 +103,7 @@ foreign import ccall "LAPACK/lapack-aux.h eig_l_C" zgeev :: TCMCMCVCM | |||
91 | -- The eigenvectors are the columns of v. | 103 | -- The eigenvectors are the columns of v. |
92 | -- The eigenvalues are not sorted. | 104 | -- The eigenvalues are not sorted. |
93 | eigC :: Matrix (Complex Double) -> (Vector (Complex Double), Matrix (Complex Double)) | 105 | eigC :: Matrix (Complex Double) -> (Vector (Complex Double), Matrix (Complex Double)) |
94 | eigC (m@M {rows = r}) | 106 | eigC m |
95 | | r == 1 = (fromList [cdat m `at` 0], singleton 1) | 107 | | r == 1 = (fromList [cdat m `at` 0], singleton 1) |
96 | | otherwise = unsafePerformIO $ do | 108 | | otherwise = unsafePerformIO $ do |
97 | l <- createVector r | 109 | l <- createVector r |
@@ -99,6 +111,7 @@ eigC (m@M {rows = r}) | |||
99 | dummy <- createMatrix ColumnMajor 1 1 | 111 | dummy <- createMatrix ColumnMajor 1 1 |
100 | zgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigC" [fdat m] | 112 | zgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigC" [fdat m] |
101 | return (l,v) | 113 | return (l,v) |
114 | where r = rows m | ||
102 | 115 | ||
103 | ----------------------------------------------------------------------------- | 116 | ----------------------------------------------------------------------------- |
104 | foreign import ccall "LAPACK/lapack-aux.h eig_l_R" dgeev :: TMMCVM | 117 | foreign import ccall "LAPACK/lapack-aux.h eig_l_R" dgeev :: TMMCVM |
@@ -110,14 +123,15 @@ foreign import ccall "LAPACK/lapack-aux.h eig_l_R" dgeev :: TMMCVM | |||
110 | -- The eigenvectors are the columns of v. | 123 | -- The eigenvectors are the columns of v. |
111 | -- The eigenvalues are not sorted. | 124 | -- The eigenvalues are not sorted. |
112 | eigR :: Matrix Double -> (Vector (Complex Double), Matrix (Complex Double)) | 125 | eigR :: Matrix Double -> (Vector (Complex Double), Matrix (Complex Double)) |
113 | eigR (m@M {rows = r}) = (s', v'') | 126 | eigR m = (s', v'') |
114 | where (s,v) = eigRaux m | 127 | where (s,v) = eigRaux m |
115 | s' = toComplex (subVector 0 r (asReal s), subVector r r (asReal s)) | 128 | s' = toComplex (subVector 0 r (asReal s), subVector r r (asReal s)) |
116 | v' = toRows $ trans v | 129 | v' = toRows $ trans v |
117 | v'' = fromColumns $ fixeig (toList s') v' | 130 | v'' = fromColumns $ fixeig (toList s') v' |
131 | r = rows m | ||
118 | 132 | ||
119 | eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double) | 133 | eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double) |
120 | eigRaux (m@M {rows = r}) | 134 | eigRaux m |
121 | | r == 1 = (fromList [(cdat m `at` 0):+0], singleton 1) | 135 | | r == 1 = (fromList [(cdat m `at` 0):+0], singleton 1) |
122 | | otherwise = unsafePerformIO $ do | 136 | | otherwise = unsafePerformIO $ do |
123 | l <- createVector r | 137 | l <- createVector r |
@@ -125,6 +139,7 @@ eigRaux (m@M {rows = r}) | |||
125 | dummy <- createMatrix ColumnMajor 1 1 | 139 | dummy <- createMatrix ColumnMajor 1 1 |
126 | dgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigR" [fdat m] | 140 | dgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigR" [fdat m] |
127 | return (l,v) | 141 | return (l,v) |
142 | where r = rows m | ||
128 | 143 | ||
129 | fixeig [] _ = [] | 144 | fixeig [] _ = [] |
130 | fixeig [r] [v] = [comp v] | 145 | fixeig [r] [v] = [comp v] |
@@ -148,13 +163,14 @@ eigS m = (s', fliprl v) | |||
148 | where (s,v) = eigS' m | 163 | where (s,v) = eigS' m |
149 | s' = fromList . reverse . toList $ s | 164 | s' = fromList . reverse . toList $ s |
150 | 165 | ||
151 | eigS' (m@M {rows = r}) | 166 | eigS' m |
152 | | r == 1 = (fromList [cdat m `at` 0], singleton 1) | 167 | | r == 1 = (fromList [cdat m `at` 0], singleton 1) |
153 | | otherwise = unsafePerformIO $ do | 168 | | otherwise = unsafePerformIO $ do |
154 | l <- createVector r | 169 | l <- createVector r |
155 | v <- createMatrix ColumnMajor r r | 170 | v <- createMatrix ColumnMajor r r |
156 | dsyev // mat fdat m // vec l // mat dat v // check "eigS" [fdat m] | 171 | dsyev // mat fdat m // vec l // mat dat v // check "eigS" [fdat m] |
157 | return (l,v) | 172 | return (l,v) |
173 | where r = rows m | ||
158 | 174 | ||
159 | ----------------------------------------------------------------------------- | 175 | ----------------------------------------------------------------------------- |
160 | foreign import ccall "LAPACK/lapack-aux.h eig_l_H" zheev :: TCMVCM | 176 | foreign import ccall "LAPACK/lapack-aux.h eig_l_H" zheev :: TCMVCM |
@@ -170,37 +186,46 @@ eigH m = (s', fliprl v) | |||
170 | where (s,v) = eigH' m | 186 | where (s,v) = eigH' m |
171 | s' = fromList . reverse . toList $ s | 187 | s' = fromList . reverse . toList $ s |
172 | 188 | ||
173 | eigH' (m@M {rows = r}) | 189 | eigH' m |
174 | | r == 1 = (fromList [realPart (cdat m `at` 0)], singleton 1) | 190 | | r == 1 = (fromList [realPart (cdat m `at` 0)], singleton 1) |
175 | | otherwise = unsafePerformIO $ do | 191 | | otherwise = unsafePerformIO $ do |
176 | l <- createVector r | 192 | l <- createVector r |
177 | v <- createMatrix ColumnMajor r r | 193 | v <- createMatrix ColumnMajor r r |
178 | zheev // mat fdat m // vec l // mat dat v // check "eigH" [fdat m] | 194 | zheev // mat fdat m // vec l // mat dat v // check "eigH" [fdat m] |
179 | return (l,v) | 195 | return (l,v) |
196 | where r = rows m | ||
180 | 197 | ||
181 | ----------------------------------------------------------------------------- | 198 | ----------------------------------------------------------------------------- |
182 | foreign import ccall "LAPACK/lapack-aux.h linearSolveR_l" dgesv :: TMMM | 199 | foreign import ccall "LAPACK/lapack-aux.h linearSolveR_l" dgesv :: TMMM |
183 | 200 | ||
184 | -- | Wrapper for LAPACK's /dgesv/, which solves a general real linear system (for several right-hand sides) internally using the lu decomposition. | 201 | -- | Wrapper for LAPACK's /dgesv/, which solves a general real linear system (for several right-hand sides) internally using the lu decomposition. |
185 | linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double | 202 | linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double |
186 | linearSolveR a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c}) | 203 | linearSolveR a b |
187 | | n1==n2 && n1==r = unsafePerformIO $ do | 204 | | n1==n2 && n1==r = unsafePerformIO $ do |
188 | s <- createMatrix ColumnMajor r c | 205 | s <- createMatrix ColumnMajor r c |
189 | dgesv // mat fdat a // mat fdat b // mat dat s // check "linearSolveR" [fdat a, fdat b] | 206 | dgesv // mat fdat a // mat fdat b // mat dat s // check "linearSolveR" [fdat a, fdat b] |
190 | return s | 207 | return s |
191 | | otherwise = error "linearSolveR of nonsquare matrix" | 208 | | otherwise = error "linearSolveR of nonsquare matrix" |
209 | where n1 = rows a | ||
210 | n2 = cols a | ||
211 | r = rows b | ||
212 | c = cols b | ||
192 | 213 | ||
193 | ----------------------------------------------------------------------------- | 214 | ----------------------------------------------------------------------------- |
194 | foreign import ccall "LAPACK/lapack-aux.h linearSolveC_l" zgesv :: TCMCMCM | 215 | foreign import ccall "LAPACK/lapack-aux.h linearSolveC_l" zgesv :: TCMCMCM |
195 | 216 | ||
196 | -- | Wrapper for LAPACK's /zgesv/, which solves a general complex linear system (for several right-hand sides) internally using the lu decomposition. | 217 | -- | Wrapper for LAPACK's /zgesv/, which solves a general complex linear system (for several right-hand sides) internally using the lu decomposition. |
197 | linearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 218 | linearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
198 | linearSolveC a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c}) | 219 | linearSolveC a b |
199 | | n1==n2 && n1==r = unsafePerformIO $ do | 220 | | n1==n2 && n1==r = unsafePerformIO $ do |
200 | s <- createMatrix ColumnMajor r c | 221 | s <- createMatrix ColumnMajor r c |
201 | zgesv // mat fdat a // mat fdat b // mat dat s // check "linearSolveC" [fdat a, fdat b] | 222 | zgesv // mat fdat a // mat fdat b // mat dat s // check "linearSolveC" [fdat a, fdat b] |
202 | return s | 223 | return s |
203 | | otherwise = error "linearSolveC of nonsquare matrix" | 224 | | otherwise = error "linearSolveC of nonsquare matrix" |
225 | where n1 = rows a | ||
226 | n2 = cols a | ||
227 | r = rows b | ||
228 | c = cols b | ||
204 | 229 | ||
205 | ----------------------------------------------------------------------------------- | 230 | ----------------------------------------------------------------------------------- |
206 | foreign import ccall "LAPACK/lapack-aux.h linearSolveLSR_l" dgels :: TMMM | 231 | foreign import ccall "LAPACK/lapack-aux.h linearSolveLSR_l" dgels :: TMMM |
@@ -209,10 +234,13 @@ foreign import ccall "LAPACK/lapack-aux.h linearSolveLSR_l" dgels :: TMMM | |||
209 | linearSolveLSR :: Matrix Double -> Matrix Double -> Matrix Double | 234 | linearSolveLSR :: Matrix Double -> Matrix Double -> Matrix Double |
210 | linearSolveLSR a b = subMatrix (0,0) (cols a, cols b) $ linearSolveLSR_l a b | 235 | linearSolveLSR a b = subMatrix (0,0) (cols a, cols b) $ linearSolveLSR_l a b |
211 | 236 | ||
212 | linearSolveLSR_l a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformIO $ do | 237 | linearSolveLSR_l a b = unsafePerformIO $ do |
213 | r <- createMatrix ColumnMajor (max m n) nrhs | 238 | r <- createMatrix ColumnMajor (max m n) nrhs |
214 | dgels // mat fdat a // mat fdat b // mat dat r // check "linearSolveLSR" [fdat a, fdat b] | 239 | dgels // mat fdat a // mat fdat b // mat dat r // check "linearSolveLSR" [fdat a, fdat b] |
215 | return r | 240 | return r |
241 | where m = rows a | ||
242 | n = cols a | ||
243 | nrhs = cols b | ||
216 | 244 | ||
217 | ----------------------------------------------------------------------------------- | 245 | ----------------------------------------------------------------------------------- |
218 | foreign import ccall "LAPACK/lapack-aux.h linearSolveLSC_l" zgels :: TCMCMCM | 246 | foreign import ccall "LAPACK/lapack-aux.h linearSolveLSC_l" zgels :: TCMCMCM |
@@ -221,10 +249,13 @@ foreign import ccall "LAPACK/lapack-aux.h linearSolveLSC_l" zgels :: TCMCMCM | |||
221 | linearSolveLSC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 249 | linearSolveLSC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
222 | linearSolveLSC a b = subMatrix (0,0) (cols a, cols b) $ linearSolveLSC_l a b | 250 | linearSolveLSC a b = subMatrix (0,0) (cols a, cols b) $ linearSolveLSC_l a b |
223 | 251 | ||
224 | linearSolveLSC_l a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformIO $ do | 252 | linearSolveLSC_l a b = unsafePerformIO $ do |
225 | r <- createMatrix ColumnMajor (max m n) nrhs | 253 | r <- createMatrix ColumnMajor (max m n) nrhs |
226 | zgels // mat fdat a // mat fdat b // mat dat r // check "linearSolveLSC" [fdat a, fdat b] | 254 | zgels // mat fdat a // mat fdat b // mat dat r // check "linearSolveLSC" [fdat a, fdat b] |
227 | return r | 255 | return r |
256 | where m = rows a | ||
257 | n = cols a | ||
258 | nrhs = cols b | ||
228 | 259 | ||
229 | ----------------------------------------------------------------------------------- | 260 | ----------------------------------------------------------------------------------- |
230 | foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDR_l" dgelss :: Double -> TMMM | 261 | foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDR_l" dgelss :: Double -> TMMM |
@@ -237,10 +268,13 @@ linearSolveSVDR :: Maybe Double -- ^ rcond | |||
237 | linearSolveSVDR (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ linearSolveSVDR_l rcond a b | 268 | linearSolveSVDR (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ linearSolveSVDR_l rcond a b |
238 | linearSolveSVDR Nothing a b = linearSolveSVDR (Just (-1)) a b | 269 | linearSolveSVDR Nothing a b = linearSolveSVDR (Just (-1)) a b |
239 | 270 | ||
240 | linearSolveSVDR_l rcond a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformIO $ do | 271 | linearSolveSVDR_l rcond a b = unsafePerformIO $ do |
241 | r <- createMatrix ColumnMajor (max m n) nrhs | 272 | r <- createMatrix ColumnMajor (max m n) nrhs |
242 | dgelss rcond // mat fdat a // mat fdat b // mat dat r // check "linearSolveSVDR" [fdat a, fdat b] | 273 | dgelss rcond // mat fdat a // mat fdat b // mat dat r // check "linearSolveSVDR" [fdat a, fdat b] |
243 | return r | 274 | return r |
275 | where m = rows a | ||
276 | n = cols a | ||
277 | nrhs = cols b | ||
244 | 278 | ||
245 | ----------------------------------------------------------------------------------- | 279 | ----------------------------------------------------------------------------------- |
246 | foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDC_l" zgelss :: Double -> TCMCMCM | 280 | foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDC_l" zgelss :: Double -> TCMCMCM |
@@ -253,8 +287,11 @@ linearSolveSVDC :: Maybe Double -- ^ rcond | |||
253 | linearSolveSVDC (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ linearSolveSVDC_l rcond a b | 287 | linearSolveSVDC (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ linearSolveSVDC_l rcond a b |
254 | linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) a b | 288 | linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) a b |
255 | 289 | ||
256 | linearSolveSVDC_l rcond a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformIO $ do | 290 | linearSolveSVDC_l rcond a b = unsafePerformIO $ do |
257 | r <- createMatrix ColumnMajor (max m n) nrhs | 291 | r <- createMatrix ColumnMajor (max m n) nrhs |
258 | zgelss rcond // mat fdat a // mat fdat b // mat dat r // check "linearSolveSVDC" [fdat a, fdat b] | 292 | zgelss rcond // mat fdat a // mat fdat b // mat dat r // check "linearSolveSVDC" [fdat a, fdat b] |
259 | return r | 293 | return r |
294 | where m = rows a | ||
295 | n = cols a | ||
296 | nrhs = cols b | ||
260 | 297 | ||