summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Real.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Real.hs')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Real.hs105
1 files changed, 81 insertions, 24 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs
index 2ff69c7..d03ca6e 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Real.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs
@@ -45,7 +45,7 @@ module Numeric.LinearAlgebra.Real(
45 -- * Linear Systems 45 -- * Linear Systems
46 linSolve, (<\>), 46 linSolve, (<\>),
47 -- * Factorizations 47 -- * Factorizations
48 svd, svdTall, svdFlat, eig, 48 svd, svdTall, svdFlat, Eigen(..),
49 -- * Pretty printing 49 -- * Pretty printing
50 Disp(..), 50 Disp(..),
51 -- * Misc 51 -- * Misc
@@ -58,8 +58,9 @@ module Numeric.LinearAlgebra.Real(
58import GHC.TypeLits 58import GHC.TypeLits
59import Numeric.HMatrix hiding ( 59import Numeric.HMatrix hiding (
60 (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace, 60 (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace,
61 (<\>),fromList,takeDiag,svd,eig) 61 (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH')
62import qualified Numeric.HMatrix as LA 62import qualified Numeric.HMatrix as LA
63import Data.Packed.Internal(mbCatch)
63import Data.Proxy(Proxy) 64import Data.Proxy(Proxy)
64import Numeric.LinearAlgebra.Static 65import Numeric.LinearAlgebra.Static
65import Text.Printf 66import Text.Printf
@@ -80,6 +81,8 @@ ud1 (R (Dim v)) = v
80mkR :: Vector ℝ -> R n 81mkR :: Vector ℝ -> R n
81mkR = R . Dim 82mkR = R . Dim
82 83
84mkC :: Vector ℂ -> C n
85mkC = C . Dim
83 86
84infixl 4 & 87infixl 4 &
85(&) :: forall n . KnownNat n 88(&) :: forall n . KnownNat n
@@ -126,17 +129,17 @@ dim = mkR (scalar d)
126 129
127newtype L m n = L (Dim m (Dim n (Matrix ℝ))) 130newtype L m n = L (Dim m (Dim n (Matrix ℝ)))
128 131
129-- newtype CL m n = CL (Dim m (Dim n (Matrix ℂ))) 132newtype M m n = M (Dim m (Dim n (Matrix ℂ)))
130 133
131ud2 :: L m n -> Matrix ℝ 134ud2 :: L m n -> Matrix ℝ
132ud2 (L (Dim (Dim x))) = x 135ud2 (L (Dim (Dim x))) = x
133 136
134 137
135
136
137mkL :: Matrix ℝ -> L m n 138mkL :: Matrix ℝ -> L m n
138mkL x = L (Dim (Dim x)) 139mkL x = L (Dim (Dim x))
139 140
141mkM :: Matrix ℂ -> M m n
142mkM x = M (Dim (Dim x))
140 143
141instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) 144instance forall m n . (KnownNat m, KnownNat n) => Show (L m n)
142 where 145 where
@@ -150,6 +153,18 @@ instance forall m n . (KnownNat m, KnownNat n) => Show (L m n)
150 153
151-------------------------------------------------------------------------------- 154--------------------------------------------------------------------------------
152 155
156instance forall n. KnownNat n => Sized ℂ (C n) (Vector ℂ)
157 where
158 konst x = mkC (LA.scalar x)
159 unwrap (C (Dim v)) = v
160 fromList xs = C (gvect "C" xs)
161 extract (unwrap -> v)
162 | singleV v = LA.konst (v!0) d
163 | otherwise = v
164 where
165 d = fromIntegral . natVal $ (undefined :: Proxy n)
166
167
153instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) 168instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ)
154 where 169 where
155 konst x = mkR (LA.scalar x) 170 konst x = mkR (LA.scalar x)
@@ -162,11 +177,12 @@ instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ)
162 d = fromIntegral . natVal $ (undefined :: Proxy n) 177 d = fromIntegral . natVal $ (undefined :: Proxy n)
163 178
164 179
180
165instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) 181instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ)
166 where 182 where
167 konst x = mkL (LA.scalar x) 183 konst x = mkL (LA.scalar x)
168 unwrap = ud2
169 fromList = mat 184 fromList = mat
185 unwrap = ud2
170 extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' 186 extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n'
171 extract (unwrap -> a) 187 extract (unwrap -> a)
172 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') 188 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n')
@@ -175,6 +191,20 @@ instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ)
175 m' = fromIntegral . natVal $ (undefined :: Proxy m) 191 m' = fromIntegral . natVal $ (undefined :: Proxy m)
176 n' = fromIntegral . natVal $ (undefined :: Proxy n) 192 n' = fromIntegral . natVal $ (undefined :: Proxy n)
177 193
194
195instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) (Matrix ℂ)
196 where
197 konst x = mkM (LA.scalar x)
198 fromList xs = M (gmat "M" xs)
199 unwrap (M (Dim (Dim m))) = m
200 extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n'
201 extract (unwrap -> a)
202 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n')
203 | otherwise = a
204 where
205 m' = fromIntegral . natVal $ (undefined :: Proxy m)
206 n' = fromIntegral . natVal $ (undefined :: Proxy n)
207
178-------------------------------------------------------------------------------- 208--------------------------------------------------------------------------------
179 209
180diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n 210diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n
@@ -225,26 +255,41 @@ instance (KnownNat m, KnownNat n) => Disp (L m n)
225 let su = LA.dispf n a 255 let su = LA.dispf n a
226 printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) 256 printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
227 257
258instance (KnownNat m, KnownNat n) => Disp (M m n)
259 where
260 disp n x = do
261 let a = extract x
262 let su = LA.dispcf n a
263 printf "M %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
264
265
228instance KnownNat n => Disp (R n) 266instance KnownNat n => Disp (R n)
229 where 267 where
230 disp n v = do 268 disp n v = do
231 let su = LA.dispf n (asRow $ extract v) 269 let su = LA.dispf n (asRow $ extract v)
232 putStr "R " >> putStr (tail . dropWhile (/='x') $ su) 270 putStr "R " >> putStr (tail . dropWhile (/='x') $ su)
233 271
272instance KnownNat n => Disp (C n)
273 where
274 disp n v = do
275 let su = LA.dispcf n (asRow $ extract v)
276 putStr "C " >> putStr (tail . dropWhile (/='x') $ su)
277
278
234-------------------------------------------------------------------------------- 279--------------------------------------------------------------------------------
235 280
236 281
237row :: R n -> L 1 n 282row :: R n -> L 1 n
238row = mkL . asRow . ud1 283row = mkL . asRow . ud1
239 284
240col :: R n -> L n 1 285--col :: R n -> L n 1
241col = tr . row 286col v = tr . row $ v
242 287
243unrow :: L 1 n -> R n 288unrow :: L 1 n -> R n
244unrow = mkR . head . toRows . ud2 289unrow = mkR . head . toRows . ud2
245 290
246uncol :: L n 1 -> R n 291--uncol :: L n 1 -> R n
247uncol = unrow . tr 292uncol v = unrow . tr $ v
248 293
249 294
250infixl 2 —— 295infixl 2 ——
@@ -253,7 +298,7 @@ a —— b = mkL (extract a LA.—— extract b)
253 298
254 299
255infixl 3 ¦ 300infixl 3 ¦
256(¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) 301-- (¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2)
257a ¦ b = tr (tr a —— tr b) 302a ¦ b = tr (tr a —— tr b)
258 303
259 304
@@ -274,7 +319,14 @@ isKonst (unwrap -> x)
274 319
275 320
276isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int)) 321isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int))
277isDiag (unwrap -> x) 322isDiag (L x) = isDiagg x
323
324isDiagC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ, Vector ℂ, (Int,Int))
325isDiagC (M x) = isDiagg x
326
327
328isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> Maybe (t, Vector t, (Int,Int))
329isDiagg (Dim (Dim x))
278 | singleM x = Nothing 330 | singleM x = Nothing
279 | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n')) 331 | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n'))
280 | otherwise = Nothing 332 | otherwise = Nothing
@@ -282,7 +334,7 @@ isDiag (unwrap -> x)
282 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 334 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
283 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 335 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
284 v = flatten x 336 v = flatten x
285 z = v!0 337 z = v `atIndex` 0
286 y = subVector 1 (size v-1) v 338 y = subVector 1 (size v-1) v
287 ny = size y 339 ny = size y
288 zeros = LA.konst 0 (max 0 (min m' n' - ny)) 340 zeros = LA.konst 0 (max 0 (min m' n' - ny))
@@ -320,9 +372,10 @@ infixr 8 <·>
320 | otherwise = udot u v 372 | otherwise = udot u v
321 373
322 374
323instance Transposable (L m n) (L n m) 375instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m)
324 where 376 where
325 tr (ud2 -> a) = mkL (tr a) 377 tr a@(isDiag -> Just _) = mkL (extract a)
378 tr (extract -> a) = mkL (tr a)
326 379
327-------------------------------------------------------------------------------- 380--------------------------------------------------------------------------------
328 381
@@ -424,11 +477,12 @@ svdFlat (extract -> m) = (mkL u, mkR s, mkL v)
424 477
425-------------------------------------------------------------------------------- 478--------------------------------------------------------------------------------
426 479
427class Eig m r | m -> r 480class Eigen m l v | m -> l, m -> v
428 where 481 where
429 eig :: m -> r 482 eigensystem :: m -> (l,v)
483 eigenvalues :: m -> l
430 484
431newtype Sym n = Sym (Sq n) 485newtype Sym n = Sym (Sq n) deriving Show
432 486
433--newtype Her n = Her (CSq n) 487--newtype Her n = Her (CSq n)
434 488
@@ -438,16 +492,19 @@ sym m = Sym $ (m + tr m)/2
438--her :: KnownNat n => CSq n -> Her n 492--her :: KnownNat n => CSq n -> Her n
439--her = undefined -- Her $ (m + tr m)/2 493--her = undefined -- Her $ (m + tr m)/2
440 494
441 495instance KnownNat n => Eigen (Sym n) (R n) (L n n)
442instance KnownNat n => Eig (Sym n) (R n, Sq n)
443 where 496 where
444 eig (Sym (extract -> m)) = (mkR l, mkL v) 497 eigenvalues (Sym (extract -> m)) = mkR . LA.eigenvaluesSH' $ m
498 eigensystem (Sym (extract -> m)) = (mkR l, mkL v)
445 where 499 where
446 (l,v) = eigSH m 500 (l,v) = LA.eigSH' m
447 501
448instance KnownNat n => Eig (Sq n) (C n) 502instance KnownNat n => Eigen (Sq n) (C n) (M n n)
449 where 503 where
450 eig (extract -> m) = C . Dim . eigenvalues $ m 504 eigenvalues (extract -> m) = mkC . LA.eigenvalues $ m
505 eigensystem (extract -> m) = (mkC l, mkM v)
506 where
507 (l,v) = LA.eig m
451 508
452-------------------------------------------------------------------------------- 509--------------------------------------------------------------------------------
453 510