diff options
author | Alberto Ruiz <aruiz@um.es> | 2014-06-04 14:14:31 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2014-06-04 14:14:31 +0200 |
commit | 0476c58d0b9da4fdcbbcb05ea055f6d14097e116 (patch) | |
tree | 65cc2d7a0388820153038518554e27f67e359cf9 /packages/base/src/Numeric/LinearAlgebra/Real.hs | |
parent | 9a17969ad0ea9f940db6201a37b9aed19ad605df (diff) |
operations with nonexpanded constant and diagonal matrices
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Real.hs')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Real.hs | 198 |
1 files changed, 148 insertions, 50 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs index 424e766..2ff69c7 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Real.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs | |||
@@ -27,7 +27,7 @@ Experimental interface for real arrays with statically checked dimensions. | |||
27 | 27 | ||
28 | module Numeric.LinearAlgebra.Real( | 28 | module Numeric.LinearAlgebra.Real( |
29 | -- * Vector | 29 | -- * Vector |
30 | R, | 30 | R, C, |
31 | vec2, vec3, vec4, (&), (#), | 31 | vec2, vec3, vec4, (&), (#), |
32 | vect, | 32 | vect, |
33 | linspace, range, dim, | 33 | linspace, range, dim, |
@@ -35,27 +35,30 @@ module Numeric.LinearAlgebra.Real( | |||
35 | L, Sq, | 35 | L, Sq, |
36 | row, col, (¦),(——), | 36 | row, col, (¦),(——), |
37 | unrow, uncol, | 37 | unrow, uncol, |
38 | Sized(..), | 38 | |
39 | eye, | 39 | eye, |
40 | diagR, diag, Diag(..), | 40 | diagR, diag, |
41 | blockAt, | 41 | blockAt, |
42 | mat, | 42 | mat, |
43 | -- * Products | 43 | -- * Products |
44 | (<>),(#>),(<·>), | 44 | (<>),(#>),(<·>), |
45 | -- * Linear Systems | 45 | -- * Linear Systems |
46 | linSolve, -- (<\>), | 46 | linSolve, (<\>), |
47 | -- * Factorizations | ||
48 | svd, svdTall, svdFlat, eig, | ||
47 | -- * Pretty printing | 49 | -- * Pretty printing |
48 | Disp(..), | 50 | Disp(..), |
49 | -- * Misc | 51 | -- * Misc |
50 | C, | ||
51 | withVector, withMatrix, | 52 | withVector, withMatrix, |
53 | Sized(..), Diag(..), Sym, sym, -- Her, her, | ||
52 | module Numeric.HMatrix | 54 | module Numeric.HMatrix |
53 | ) where | 55 | ) where |
54 | 56 | ||
55 | 57 | ||
56 | import GHC.TypeLits | 58 | import GHC.TypeLits |
57 | import Numeric.HMatrix hiding ( | 59 | import Numeric.HMatrix hiding ( |
58 | (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace,(<\>),fromList,takeDiag) | 60 | (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace, |
61 | (<\>),fromList,takeDiag,svd,eig) | ||
59 | import qualified Numeric.HMatrix as LA | 62 | import qualified Numeric.HMatrix as LA |
60 | import Data.Proxy(Proxy) | 63 | import Data.Proxy(Proxy) |
61 | import Numeric.LinearAlgebra.Static | 64 | import Numeric.LinearAlgebra.Static |
@@ -122,8 +125,8 @@ dim = mkR (scalar d) | |||
122 | -------------------------------------------------------------------------------- | 125 | -------------------------------------------------------------------------------- |
123 | 126 | ||
124 | newtype L m n = L (Dim m (Dim n (Matrix ℝ))) | 127 | newtype L m n = L (Dim m (Dim n (Matrix ℝ))) |
125 | deriving (Num,Fractional) | ||
126 | 128 | ||
129 | -- newtype CL m n = CL (Dim m (Dim n (Matrix ℂ))) | ||
127 | 130 | ||
128 | ud2 :: L m n -> Matrix ℝ | 131 | ud2 :: L m n -> Matrix ℝ |
129 | ud2 (L (Dim (Dim x))) = x | 132 | ud2 (L (Dim (Dim x))) = x |
@@ -137,27 +140,22 @@ mkL x = L (Dim (Dim x)) | |||
137 | 140 | ||
138 | instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) | 141 | instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) |
139 | where | 142 | where |
143 | show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' | ||
140 | show (ud2 -> x) | 144 | show (ud2 -> x) |
141 | | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' | 145 | | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' |
142 | | isDiag = printf "(diag %s %s :: L %d %d)" (show z) shy m' n' | ||
143 | | otherwise = "(mat"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" | 146 | | otherwise = "(mat"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" |
144 | where | 147 | where |
145 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | 148 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int |
146 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 149 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int |
147 | isDiag = rows x == 1 && m' > 1 | ||
148 | v = flatten x | ||
149 | z = v!0 | ||
150 | y = subVector 1 (size v-1) v | ||
151 | shy = drop 9 (show y) | ||
152 | 150 | ||
153 | -------------------------------------------------------------------------------- | 151 | -------------------------------------------------------------------------------- |
154 | 152 | ||
155 | instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) | 153 | instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) |
156 | where | 154 | where |
157 | konst x = mkR (LA.scalar x) | 155 | konst x = mkR (LA.scalar x) |
158 | extract = ud1 | 156 | unwrap = ud1 |
159 | fromList = vect | 157 | fromList = vect |
160 | expand (extract -> v) | 158 | extract (unwrap -> v) |
161 | | singleV v = LA.konst (v!0) d | 159 | | singleV v = LA.konst (v!0) d |
162 | | otherwise = v | 160 | | otherwise = v |
163 | where | 161 | where |
@@ -167,23 +165,25 @@ instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) | |||
167 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) | 165 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) |
168 | where | 166 | where |
169 | konst x = mkL (LA.scalar x) | 167 | konst x = mkL (LA.scalar x) |
170 | extract = ud2 | 168 | unwrap = ud2 |
171 | fromList = mat | 169 | fromList = mat |
172 | expand (extract -> a) | 170 | extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' |
171 | extract (unwrap -> a) | ||
173 | | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') | 172 | | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') |
174 | | rows a == 1 && m'>1 = diagRect x y m' n' | ||
175 | | otherwise = a | 173 | | otherwise = a |
176 | where | 174 | where |
177 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 175 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
178 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | 176 | n' = fromIntegral . natVal $ (undefined :: Proxy n) |
179 | v = flatten a | ||
180 | x = v!0 | ||
181 | y = subVector 1 (size v -1) v | ||
182 | 177 | ||
183 | -------------------------------------------------------------------------------- | 178 | -------------------------------------------------------------------------------- |
184 | 179 | ||
185 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n | 180 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n |
186 | diagR x v = mkL (asRow (vjoin [scalar x, expand v])) | 181 | diagR x v = mkL (asRow (vjoin [scalar x, ev, zeros])) |
182 | where | ||
183 | ev = extract v | ||
184 | zeros = LA.konst x (max 0 ((min m' n') - size ev)) | ||
185 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
186 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
187 | 187 | ||
188 | diag :: KnownNat n => R n -> Sq n | 188 | diag :: KnownNat n => R n -> Sq n |
189 | diag = diagR 0 | 189 | diag = diagR 0 |
@@ -221,21 +221,14 @@ class Disp t | |||
221 | instance (KnownNat m, KnownNat n) => Disp (L m n) | 221 | instance (KnownNat m, KnownNat n) => Disp (L m n) |
222 | where | 222 | where |
223 | disp n x = do | 223 | disp n x = do |
224 | let a = expand x | 224 | let a = extract x |
225 | let su = LA.dispf n a | 225 | let su = LA.dispf n a |
226 | printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) | 226 | printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) |
227 | 227 | ||
228 | {- | ||
229 | disp n (ud2 -> a) = do | ||
230 | if rows a == 1 && cols a == 1 | ||
231 | then putStrLn $ "Const " ++ (last . words . LA.dispf n $ a) | ||
232 | else putStr "Dim " >> LA.disp n a | ||
233 | -} | ||
234 | |||
235 | instance KnownNat n => Disp (R n) | 228 | instance KnownNat n => Disp (R n) |
236 | where | 229 | where |
237 | disp n v = do | 230 | disp n v = do |
238 | let su = LA.dispf n (asRow $ expand v) | 231 | let su = LA.dispf n (asRow $ extract v) |
239 | putStr "R " >> putStr (tail . dropWhile (/='x') $ su) | 232 | putStr "R " >> putStr (tail . dropWhile (/='x') $ su) |
240 | 233 | ||
241 | -------------------------------------------------------------------------------- | 234 | -------------------------------------------------------------------------------- |
@@ -256,7 +249,7 @@ uncol = unrow . tr | |||
256 | 249 | ||
257 | infixl 2 —— | 250 | infixl 2 —— |
258 | (——) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c | 251 | (——) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c |
259 | a —— b = mkL (expand a LA.—— expand b) | 252 | a —— b = mkL (extract a LA.—— extract b) |
260 | 253 | ||
261 | 254 | ||
262 | infixl 3 ¦ | 255 | infixl 3 ¦ |
@@ -264,35 +257,61 @@ infixl 3 ¦ | |||
264 | a ¦ b = tr (tr a —— tr b) | 257 | a ¦ b = tr (tr a —— tr b) |
265 | 258 | ||
266 | 259 | ||
267 | type Sq n = L n n | 260 | type Sq n = L n n |
261 | --type CSq n = CL n n | ||
268 | 262 | ||
269 | type GL = (KnownNat n, KnownNat m) => L m n | 263 | type GL = (KnownNat n, KnownNat m) => L m n |
270 | type GSq = KnownNat n => Sq n | 264 | type GSq = KnownNat n => Sq n |
271 | 265 | ||
272 | isDiag0 :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (Vector ℝ) | 266 | isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int)) |
273 | isDiag0 (extract -> x) | 267 | isKonst (unwrap -> x) |
274 | | rows x == 1 && m' > 1 && z == 0 = Just y | 268 | | singleM x = Just (x `atIndex` (0,0), (m',n')) |
269 | | otherwise = Nothing | ||
270 | where | ||
271 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
272 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
273 | |||
274 | |||
275 | |||
276 | isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int)) | ||
277 | isDiag (unwrap -> x) | ||
278 | | singleM x = Nothing | ||
279 | | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n')) | ||
275 | | otherwise = Nothing | 280 | | otherwise = Nothing |
276 | where | 281 | where |
277 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | 282 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int |
283 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
278 | v = flatten x | 284 | v = flatten x |
279 | z = v!0 | 285 | z = v!0 |
280 | y = subVector 1 (size v-1) v | 286 | y = subVector 1 (size v-1) v |
287 | ny = size y | ||
288 | zeros = LA.konst 0 (max 0 (min m' n' - ny)) | ||
289 | yz = vjoin [y,zeros] | ||
281 | 290 | ||
282 | 291 | ||
283 | infixr 8 <> | 292 | infixr 8 <> |
284 | (<>) :: (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n | 293 | (<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n |
285 | a <> b = mkL (expand a LA.<> expand b) | 294 | |
295 | (isKonst -> Just (a,(_,k))) <> (isKonst -> Just (b,_)) = konst (a * b * fromIntegral k) | ||
296 | |||
297 | (isDiag -> Just (0,a,_)) <> (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) | ||
298 | where | ||
299 | v = a' * b' | ||
300 | n = min (size a) (size b) | ||
301 | a' = subVector 0 n a | ||
302 | b' = subVector 0 n b | ||
303 | |||
304 | (isDiag -> Just (0,a,_)) <> (extract -> b) = mkL (asColumn a * takeRows (size a) b) | ||
305 | |||
306 | (extract -> a) <> (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b) | ||
307 | |||
308 | a <> b = mkL (extract a LA.<> extract b) | ||
286 | 309 | ||
287 | infixr 8 #> | 310 | infixr 8 #> |
288 | (#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m | 311 | (#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m |
289 | (isDiag0 -> Just w) #> v = mkR (w' * v') | 312 | (isDiag -> Just (0, w, _)) #> v = mkR (w * subVector 0 (size w) (extract v)) |
290 | where | 313 | m #> v = mkR (extract m LA.#> extract v) |
291 | v' = expand v | ||
292 | w' = subVector 0 (max 0 (size w - size v')) (vjoin [w , z]) | ||
293 | z = LA.konst 0 (max 0 (size v' - size w)) | ||
294 | 314 | ||
295 | m #> v = mkR (expand m LA.#> expand v) | ||
296 | 315 | ||
297 | infixr 8 <·> | 316 | infixr 8 <·> |
298 | (<·>) :: R n -> R n -> ℝ | 317 | (<·>) :: R n -> R n -> ℝ |
@@ -306,6 +325,36 @@ instance Transposable (L m n) (L n m) | |||
306 | tr (ud2 -> a) = mkL (tr a) | 325 | tr (ud2 -> a) = mkL (tr a) |
307 | 326 | ||
308 | -------------------------------------------------------------------------------- | 327 | -------------------------------------------------------------------------------- |
328 | |||
329 | adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b | ||
330 | adaptDiag f a b@(isDiag -> Just _) | isFull a = f a (mkL (extract b)) | ||
331 | adaptDiag f a b = f a b | ||
332 | |||
333 | isFull m = isDiag m == Nothing && not (singleM (unwrap m)) | ||
334 | |||
335 | |||
336 | lift1L f (L v) = L (f v) | ||
337 | lift2L f (L a) (L b) = L (f a b) | ||
338 | lift2LD f = adaptDiag (lift2L f) | ||
339 | |||
340 | |||
341 | instance (KnownNat n, KnownNat m) => Num (L n m) | ||
342 | where | ||
343 | (+) = lift2LD (+) | ||
344 | (*) = lift2LD (*) | ||
345 | (-) = lift2LD (-) | ||
346 | abs = lift1L abs | ||
347 | signum = lift1L signum | ||
348 | negate = lift1L negate | ||
349 | fromInteger = L . Dim . Dim . fromInteger | ||
350 | |||
351 | instance (KnownNat n, KnownNat m) => Fractional (L n m) | ||
352 | where | ||
353 | fromRational = L . Dim . Dim . fromRational | ||
354 | (/) = lift2LD (/) | ||
355 | |||
356 | -------------------------------------------------------------------------------- | ||
357 | |||
309 | {- | 358 | {- |
310 | class Minim (n :: Nat) (m :: Nat) | 359 | class Minim (n :: Nat) (m :: Nat) |
311 | where | 360 | where |
@@ -333,24 +382,73 @@ class Diag m d | m -> d | |||
333 | 382 | ||
334 | instance forall n . (KnownNat n) => Diag (L n n) (R n) | 383 | instance forall n . (KnownNat n) => Diag (L n n) (R n) |
335 | where | 384 | where |
336 | takeDiag m = mkR (LA.takeDiag (expand m)) | 385 | takeDiag m = mkR (LA.takeDiag (extract m)) |
337 | 386 | ||
338 | 387 | ||
339 | instance forall m n . (KnownNat m, KnownNat n, m <= n+1) => Diag (L m n) (R m) | 388 | instance forall m n . (KnownNat m, KnownNat n, m <= n+1) => Diag (L m n) (R m) |
340 | where | 389 | where |
341 | takeDiag m = mkR (LA.takeDiag (expand m)) | 390 | takeDiag m = mkR (LA.takeDiag (extract m)) |
342 | 391 | ||
343 | 392 | ||
344 | instance forall m n . (KnownNat m, KnownNat n, n <= m+1) => Diag (L m n) (R n) | 393 | instance forall m n . (KnownNat m, KnownNat n, n <= m+1) => Diag (L m n) (R n) |
345 | where | 394 | where |
346 | takeDiag m = mkR (LA.takeDiag (expand m)) | 395 | takeDiag m = mkR (LA.takeDiag (extract m)) |
396 | |||
397 | |||
398 | -------------------------------------------------------------------------------- | ||
399 | |||
400 | linSolve :: (KnownNat m, KnownNat n) => L m m -> L m n -> L m n | ||
401 | linSolve (extract -> a) (extract -> b) = mkL (LA.linearSolve a b) | ||
402 | |||
403 | (<\>) :: (KnownNat m, KnownNat n, KnownNat r) => L m n -> L m r -> L n r | ||
404 | (extract -> a) <\> (extract -> b) = mkL (a LA.<\> b) | ||
405 | |||
406 | svd :: (KnownNat m, KnownNat n) => L m n -> (L m m, R n, L n n) | ||
407 | svd (extract -> m) = (mkL u, mkR s', mkL v) | ||
408 | where | ||
409 | (u,s,v) = LA.svd m | ||
410 | s' = vjoin [s, z] | ||
411 | z = LA.konst 0 (max 0 (cols m - size s)) | ||
412 | |||
413 | |||
414 | svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n) | ||
415 | svdTall (extract -> m) = (mkL u, mkR s, mkL v) | ||
416 | where | ||
417 | (u,s,v) = LA.thinSVD m | ||
347 | 418 | ||
348 | 419 | ||
420 | svdFlat :: (KnownNat m, KnownNat n, m <= n) => L m n -> (L m m, R m, L m n) | ||
421 | svdFlat (extract -> m) = (mkL u, mkR s, mkL v) | ||
422 | where | ||
423 | (u,s,v) = LA.thinSVD m | ||
424 | |||
349 | -------------------------------------------------------------------------------- | 425 | -------------------------------------------------------------------------------- |
350 | 426 | ||
351 | linSolve :: L m m -> L m n -> L m n | 427 | class Eig m r | m -> r |
352 | linSolve (ud2 -> a) (ud2 -> b) = mkL (LA.linearSolve a b) | 428 | where |
429 | eig :: m -> r | ||
430 | |||
431 | newtype Sym n = Sym (Sq n) | ||
432 | |||
433 | --newtype Her n = Her (CSq n) | ||
434 | |||
435 | sym :: KnownNat n => Sq n -> Sym n | ||
436 | sym m = Sym $ (m + tr m)/2 | ||
437 | |||
438 | --her :: KnownNat n => CSq n -> Her n | ||
439 | --her = undefined -- Her $ (m + tr m)/2 | ||
353 | 440 | ||
441 | |||
442 | instance KnownNat n => Eig (Sym n) (R n, Sq n) | ||
443 | where | ||
444 | eig (Sym (extract -> m)) = (mkR l, mkL v) | ||
445 | where | ||
446 | (l,v) = eigSH m | ||
447 | |||
448 | instance KnownNat n => Eig (Sq n) (C n) | ||
449 | where | ||
450 | eig (extract -> m) = C . Dim . eigenvalues $ m | ||
451 | |||
354 | -------------------------------------------------------------------------------- | 452 | -------------------------------------------------------------------------------- |
355 | 453 | ||
356 | withVector | 454 | withVector |
@@ -383,7 +481,7 @@ withMatrix a f = | |||
383 | test :: (Bool, IO ()) | 481 | test :: (Bool, IO ()) |
384 | test = (ok,info) | 482 | test = (ok,info) |
385 | where | 483 | where |
386 | ok = expand (eye :: Sq 5) == ident 5 | 484 | ok = extract (eye :: Sq 5) == ident 5 |
387 | && ud2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] | 485 | && ud2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] |
388 | && ud2 (tm :: L 3 5) == LA.mat 5 [1..15] | 486 | && ud2 (tm :: L 3 5) == LA.mat 5 [1..15] |
389 | && thingS == thingD | 487 | && thingS == thingD |