summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-06-04 14:14:31 +0200
committerAlberto Ruiz <aruiz@um.es>2014-06-04 14:14:31 +0200
commit0476c58d0b9da4fdcbbcb05ea055f6d14097e116 (patch)
tree65cc2d7a0388820153038518554e27f67e359cf9 /packages/base/src/Numeric/LinearAlgebra
parent9a17969ad0ea9f940db6201a37b9aed19ad605df (diff)
operations with nonexpanded constant and diagonal matrices
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Real.hs198
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs19
2 files changed, 158 insertions, 59 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs
index 424e766..2ff69c7 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Real.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs
@@ -27,7 +27,7 @@ Experimental interface for real arrays with statically checked dimensions.
27 27
28module Numeric.LinearAlgebra.Real( 28module Numeric.LinearAlgebra.Real(
29 -- * Vector 29 -- * Vector
30 R, 30 R, C,
31 vec2, vec3, vec4, (&), (#), 31 vec2, vec3, vec4, (&), (#),
32 vect, 32 vect,
33 linspace, range, dim, 33 linspace, range, dim,
@@ -35,27 +35,30 @@ module Numeric.LinearAlgebra.Real(
35 L, Sq, 35 L, Sq,
36 row, col, (¦),(——), 36 row, col, (¦),(——),
37 unrow, uncol, 37 unrow, uncol,
38 Sized(..), 38
39 eye, 39 eye,
40 diagR, diag, Diag(..), 40 diagR, diag,
41 blockAt, 41 blockAt,
42 mat, 42 mat,
43 -- * Products 43 -- * Products
44 (<>),(#>),(<·>), 44 (<>),(#>),(<·>),
45 -- * Linear Systems 45 -- * Linear Systems
46 linSolve, -- (<\>), 46 linSolve, (<\>),
47 -- * Factorizations
48 svd, svdTall, svdFlat, eig,
47 -- * Pretty printing 49 -- * Pretty printing
48 Disp(..), 50 Disp(..),
49 -- * Misc 51 -- * Misc
50 C,
51 withVector, withMatrix, 52 withVector, withMatrix,
53 Sized(..), Diag(..), Sym, sym, -- Her, her,
52 module Numeric.HMatrix 54 module Numeric.HMatrix
53) where 55) where
54 56
55 57
56import GHC.TypeLits 58import GHC.TypeLits
57import Numeric.HMatrix hiding ( 59import Numeric.HMatrix hiding (
58 (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace,(<\>),fromList,takeDiag) 60 (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace,
61 (<\>),fromList,takeDiag,svd,eig)
59import qualified Numeric.HMatrix as LA 62import qualified Numeric.HMatrix as LA
60import Data.Proxy(Proxy) 63import Data.Proxy(Proxy)
61import Numeric.LinearAlgebra.Static 64import Numeric.LinearAlgebra.Static
@@ -122,8 +125,8 @@ dim = mkR (scalar d)
122-------------------------------------------------------------------------------- 125--------------------------------------------------------------------------------
123 126
124newtype L m n = L (Dim m (Dim n (Matrix ℝ))) 127newtype L m n = L (Dim m (Dim n (Matrix ℝ)))
125 deriving (Num,Fractional)
126 128
129-- newtype CL m n = CL (Dim m (Dim n (Matrix ℂ)))
127 130
128ud2 :: L m n -> Matrix ℝ 131ud2 :: L m n -> Matrix ℝ
129ud2 (L (Dim (Dim x))) = x 132ud2 (L (Dim (Dim x))) = x
@@ -137,27 +140,22 @@ mkL x = L (Dim (Dim x))
137 140
138instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) 141instance forall m n . (KnownNat m, KnownNat n) => Show (L m n)
139 where 142 where
143 show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n'
140 show (ud2 -> x) 144 show (ud2 -> x)
141 | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' 145 | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n'
142 | isDiag = printf "(diag %s %s :: L %d %d)" (show z) shy m' n'
143 | otherwise = "(mat"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" 146 | otherwise = "(mat"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")"
144 where 147 where
145 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 148 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
146 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 149 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
147 isDiag = rows x == 1 && m' > 1
148 v = flatten x
149 z = v!0
150 y = subVector 1 (size v-1) v
151 shy = drop 9 (show y)
152 150
153-------------------------------------------------------------------------------- 151--------------------------------------------------------------------------------
154 152
155instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) 153instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ)
156 where 154 where
157 konst x = mkR (LA.scalar x) 155 konst x = mkR (LA.scalar x)
158 extract = ud1 156 unwrap = ud1
159 fromList = vect 157 fromList = vect
160 expand (extract -> v) 158 extract (unwrap -> v)
161 | singleV v = LA.konst (v!0) d 159 | singleV v = LA.konst (v!0) d
162 | otherwise = v 160 | otherwise = v
163 where 161 where
@@ -167,23 +165,25 @@ instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ)
167instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) 165instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ)
168 where 166 where
169 konst x = mkL (LA.scalar x) 167 konst x = mkL (LA.scalar x)
170 extract = ud2 168 unwrap = ud2
171 fromList = mat 169 fromList = mat
172 expand (extract -> a) 170 extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n'
171 extract (unwrap -> a)
173 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') 172 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n')
174 | rows a == 1 && m'>1 = diagRect x y m' n'
175 | otherwise = a 173 | otherwise = a
176 where 174 where
177 m' = fromIntegral . natVal $ (undefined :: Proxy m) 175 m' = fromIntegral . natVal $ (undefined :: Proxy m)
178 n' = fromIntegral . natVal $ (undefined :: Proxy n) 176 n' = fromIntegral . natVal $ (undefined :: Proxy n)
179 v = flatten a
180 x = v!0
181 y = subVector 1 (size v -1) v
182 177
183-------------------------------------------------------------------------------- 178--------------------------------------------------------------------------------
184 179
185diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n 180diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n
186diagR x v = mkL (asRow (vjoin [scalar x, expand v])) 181diagR x v = mkL (asRow (vjoin [scalar x, ev, zeros]))
182 where
183 ev = extract v
184 zeros = LA.konst x (max 0 ((min m' n') - size ev))
185 m' = fromIntegral . natVal $ (undefined :: Proxy m)
186 n' = fromIntegral . natVal $ (undefined :: Proxy n)
187 187
188diag :: KnownNat n => R n -> Sq n 188diag :: KnownNat n => R n -> Sq n
189diag = diagR 0 189diag = diagR 0
@@ -221,21 +221,14 @@ class Disp t
221instance (KnownNat m, KnownNat n) => Disp (L m n) 221instance (KnownNat m, KnownNat n) => Disp (L m n)
222 where 222 where
223 disp n x = do 223 disp n x = do
224 let a = expand x 224 let a = extract x
225 let su = LA.dispf n a 225 let su = LA.dispf n a
226 printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) 226 printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
227 227
228{-
229 disp n (ud2 -> a) = do
230 if rows a == 1 && cols a == 1
231 then putStrLn $ "Const " ++ (last . words . LA.dispf n $ a)
232 else putStr "Dim " >> LA.disp n a
233-}
234
235instance KnownNat n => Disp (R n) 228instance KnownNat n => Disp (R n)
236 where 229 where
237 disp n v = do 230 disp n v = do
238 let su = LA.dispf n (asRow $ expand v) 231 let su = LA.dispf n (asRow $ extract v)
239 putStr "R " >> putStr (tail . dropWhile (/='x') $ su) 232 putStr "R " >> putStr (tail . dropWhile (/='x') $ su)
240 233
241-------------------------------------------------------------------------------- 234--------------------------------------------------------------------------------
@@ -256,7 +249,7 @@ uncol = unrow . tr
256 249
257infixl 2 —— 250infixl 2 ——
258(——) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c 251(——) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c
259a —— b = mkL (expand a LA.—— expand b) 252a —— b = mkL (extract a LA.—— extract b)
260 253
261 254
262infixl 3 ¦ 255infixl 3 ¦
@@ -264,35 +257,61 @@ infixl 3 ¦
264a ¦ b = tr (tr a —— tr b) 257a ¦ b = tr (tr a —— tr b)
265 258
266 259
267type Sq n = L n n 260type Sq n = L n n
261--type CSq n = CL n n
268 262
269type GL = (KnownNat n, KnownNat m) => L m n 263type GL = (KnownNat n, KnownNat m) => L m n
270type GSq = KnownNat n => Sq n 264type GSq = KnownNat n => Sq n
271 265
272isDiag0 :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (Vector ℝ) 266isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int))
273isDiag0 (extract -> x) 267isKonst (unwrap -> x)
274 | rows x == 1 && m' > 1 && z == 0 = Just y 268 | singleM x = Just (x `atIndex` (0,0), (m',n'))
269 | otherwise = Nothing
270 where
271 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
272 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
273
274
275
276isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int))
277isDiag (unwrap -> x)
278 | singleM x = Nothing
279 | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n'))
275 | otherwise = Nothing 280 | otherwise = Nothing
276 where 281 where
277 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 282 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
283 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
278 v = flatten x 284 v = flatten x
279 z = v!0 285 z = v!0
280 y = subVector 1 (size v-1) v 286 y = subVector 1 (size v-1) v
287 ny = size y
288 zeros = LA.konst 0 (max 0 (min m' n' - ny))
289 yz = vjoin [y,zeros]
281 290
282 291
283infixr 8 <> 292infixr 8 <>
284(<>) :: (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n 293(<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n
285a <> b = mkL (expand a LA.<> expand b) 294
295(isKonst -> Just (a,(_,k))) <> (isKonst -> Just (b,_)) = konst (a * b * fromIntegral k)
296
297(isDiag -> Just (0,a,_)) <> (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k)
298 where
299 v = a' * b'
300 n = min (size a) (size b)
301 a' = subVector 0 n a
302 b' = subVector 0 n b
303
304(isDiag -> Just (0,a,_)) <> (extract -> b) = mkL (asColumn a * takeRows (size a) b)
305
306(extract -> a) <> (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b)
307
308a <> b = mkL (extract a LA.<> extract b)
286 309
287infixr 8 #> 310infixr 8 #>
288(#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m 311(#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m
289(isDiag0 -> Just w) #> v = mkR (w' * v') 312(isDiag -> Just (0, w, _)) #> v = mkR (w * subVector 0 (size w) (extract v))
290 where 313m #> v = mkR (extract m LA.#> extract v)
291 v' = expand v
292 w' = subVector 0 (max 0 (size w - size v')) (vjoin [w , z])
293 z = LA.konst 0 (max 0 (size v' - size w))
294 314
295m #> v = mkR (expand m LA.#> expand v)
296 315
297infixr 8 <·> 316infixr 8 <·>
298(<·>) :: R n -> R n -> ℝ 317(<·>) :: R n -> R n -> ℝ
@@ -306,6 +325,36 @@ instance Transposable (L m n) (L n m)
306 tr (ud2 -> a) = mkL (tr a) 325 tr (ud2 -> a) = mkL (tr a)
307 326
308-------------------------------------------------------------------------------- 327--------------------------------------------------------------------------------
328
329adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b
330adaptDiag f a b@(isDiag -> Just _) | isFull a = f a (mkL (extract b))
331adaptDiag f a b = f a b
332
333isFull m = isDiag m == Nothing && not (singleM (unwrap m))
334
335
336lift1L f (L v) = L (f v)
337lift2L f (L a) (L b) = L (f a b)
338lift2LD f = adaptDiag (lift2L f)
339
340
341instance (KnownNat n, KnownNat m) => Num (L n m)
342 where
343 (+) = lift2LD (+)
344 (*) = lift2LD (*)
345 (-) = lift2LD (-)
346 abs = lift1L abs
347 signum = lift1L signum
348 negate = lift1L negate
349 fromInteger = L . Dim . Dim . fromInteger
350
351instance (KnownNat n, KnownNat m) => Fractional (L n m)
352 where
353 fromRational = L . Dim . Dim . fromRational
354 (/) = lift2LD (/)
355
356--------------------------------------------------------------------------------
357
309{- 358{-
310class Minim (n :: Nat) (m :: Nat) 359class Minim (n :: Nat) (m :: Nat)
311 where 360 where
@@ -333,24 +382,73 @@ class Diag m d | m -> d
333 382
334instance forall n . (KnownNat n) => Diag (L n n) (R n) 383instance forall n . (KnownNat n) => Diag (L n n) (R n)
335 where 384 where
336 takeDiag m = mkR (LA.takeDiag (expand m)) 385 takeDiag m = mkR (LA.takeDiag (extract m))
337 386
338 387
339instance forall m n . (KnownNat m, KnownNat n, m <= n+1) => Diag (L m n) (R m) 388instance forall m n . (KnownNat m, KnownNat n, m <= n+1) => Diag (L m n) (R m)
340 where 389 where
341 takeDiag m = mkR (LA.takeDiag (expand m)) 390 takeDiag m = mkR (LA.takeDiag (extract m))
342 391
343 392
344instance forall m n . (KnownNat m, KnownNat n, n <= m+1) => Diag (L m n) (R n) 393instance forall m n . (KnownNat m, KnownNat n, n <= m+1) => Diag (L m n) (R n)
345 where 394 where
346 takeDiag m = mkR (LA.takeDiag (expand m)) 395 takeDiag m = mkR (LA.takeDiag (extract m))
396
397
398--------------------------------------------------------------------------------
399
400linSolve :: (KnownNat m, KnownNat n) => L m m -> L m n -> L m n
401linSolve (extract -> a) (extract -> b) = mkL (LA.linearSolve a b)
402
403(<\>) :: (KnownNat m, KnownNat n, KnownNat r) => L m n -> L m r -> L n r
404(extract -> a) <\> (extract -> b) = mkL (a LA.<\> b)
405
406svd :: (KnownNat m, KnownNat n) => L m n -> (L m m, R n, L n n)
407svd (extract -> m) = (mkL u, mkR s', mkL v)
408 where
409 (u,s,v) = LA.svd m
410 s' = vjoin [s, z]
411 z = LA.konst 0 (max 0 (cols m - size s))
412
413
414svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n)
415svdTall (extract -> m) = (mkL u, mkR s, mkL v)
416 where
417 (u,s,v) = LA.thinSVD m
347 418
348 419
420svdFlat :: (KnownNat m, KnownNat n, m <= n) => L m n -> (L m m, R m, L m n)
421svdFlat (extract -> m) = (mkL u, mkR s, mkL v)
422 where
423 (u,s,v) = LA.thinSVD m
424
349-------------------------------------------------------------------------------- 425--------------------------------------------------------------------------------
350 426
351linSolve :: L m m -> L m n -> L m n 427class Eig m r | m -> r
352linSolve (ud2 -> a) (ud2 -> b) = mkL (LA.linearSolve a b) 428 where
429 eig :: m -> r
430
431newtype Sym n = Sym (Sq n)
432
433--newtype Her n = Her (CSq n)
434
435sym :: KnownNat n => Sq n -> Sym n
436sym m = Sym $ (m + tr m)/2
437
438--her :: KnownNat n => CSq n -> Her n
439--her = undefined -- Her $ (m + tr m)/2
353 440
441
442instance KnownNat n => Eig (Sym n) (R n, Sq n)
443 where
444 eig (Sym (extract -> m)) = (mkR l, mkL v)
445 where
446 (l,v) = eigSH m
447
448instance KnownNat n => Eig (Sq n) (C n)
449 where
450 eig (extract -> m) = C . Dim . eigenvalues $ m
451
354-------------------------------------------------------------------------------- 452--------------------------------------------------------------------------------
355 453
356withVector 454withVector
@@ -383,7 +481,7 @@ withMatrix a f =
383test :: (Bool, IO ()) 481test :: (Bool, IO ())
384test = (ok,info) 482test = (ok,info)
385 where 483 where
386 ok = expand (eye :: Sq 5) == ident 5 484 ok = extract (eye :: Sq 5) == ident 5
387 && ud2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] 485 && ud2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..]
388 && ud2 (tm :: L 3 5) == LA.mat 5 [1..15] 486 && ud2 (tm :: L 3 5) == LA.mat 5 [1..15]
389 && thingS == thingD 487 && thingS == thingD
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
index f9e935d..5caf6f8 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -74,6 +74,12 @@ instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t))
74 negate = lift1F negate 74 negate = lift1F negate
75 fromInteger x = Dim (fromInteger x) 75 fromInteger x = Dim (fromInteger x)
76 76
77instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim n (Vector t))
78 where
79 fromRational x = Dim (fromRational x)
80 (/) = lift2F (/)
81
82
77instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) 83instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t)))
78 where 84 where
79 (+) = (lift2F . lift2F) (+) 85 (+) = (lift2F . lift2F) (+)
@@ -84,11 +90,6 @@ instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t)))
84 negate = (lift1F . lift1F) negate 90 negate = (lift1F . lift1F) negate
85 fromInteger x = Dim (Dim (fromInteger x)) 91 fromInteger x = Dim (Dim (fromInteger x))
86 92
87instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim n (Vector t))
88 where
89 fromRational x = Dim (fromRational x)
90 (/) = lift2F (/)
91
92instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim m (Dim n (Matrix t))) 93instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim m (Dim n (Matrix t)))
93 where 94 where
94 fromRational x = Dim (Dim (fromRational x)) 95 fromRational x = Dim (Dim (fromRational x))
@@ -106,8 +107,8 @@ mkV = Dim
106 107
107type M m n t = Dim m (Dim n (Matrix t)) 108type M m n t = Dim m (Dim n (Matrix t))
108 109
109ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t 110--ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t
110ud2 (Dim (Dim m)) = m 111--ud2 (Dim (Dim m)) = m
111 112
112mkM :: forall (m :: Nat) (n :: Nat) t . t -> Dim m (Dim n t) 113mkM :: forall (m :: Nat) (n :: Nat) t . t -> Dim m (Dim n t)
113mkM = Dim . Dim 114mkM = Dim . Dim
@@ -184,9 +185,9 @@ gmat st xs'
184class Num t => Sized t s d | s -> t, s -> d 185class Num t => Sized t s d | s -> t, s -> d
185 where 186 where
186 konst :: t -> s 187 konst :: t -> s
187 extract :: s -> d 188 unwrap :: s -> d
188 fromList :: [t] -> s 189 fromList :: [t] -> s
189 expand :: s -> d 190 extract :: s -> d
190 191
191singleV v = size v == 1 192singleV v = size v == 1
192singleM m = rows m == 1 && cols m == 1 193singleM m = rows m == 1 && cols m == 1