summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Numeric/HMatrix.hs2
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Real.hs97
2 files changed, 77 insertions, 22 deletions
diff --git a/packages/base/src/Numeric/HMatrix.hs b/packages/base/src/Numeric/HMatrix.hs
index 7f27fd4..786fb6d 100644
--- a/packages/base/src/Numeric/HMatrix.hs
+++ b/packages/base/src/Numeric/HMatrix.hs
@@ -144,7 +144,7 @@ module Numeric.HMatrix (
144 Transposable, 144 Transposable,
145 CGState(..), 145 CGState(..),
146 Testable(..), 146 Testable(..),
147 โ„•,โ„ค,โ„,โ„‚, ๐‘–, i_C --โ„ 147 โ„•,โ„ค,โ„,โ„‚, i_C
148) where 148) where
149 149
150import Numeric.LinearAlgebra.Data 150import Numeric.LinearAlgebra.Data
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs
index d03ca6e..0e54555 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Real.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs
@@ -32,10 +32,10 @@ module Numeric.LinearAlgebra.Real(
32 vect, 32 vect,
33 linspace, range, dim, 33 linspace, range, dim,
34 -- * Matrix 34 -- * Matrix
35 L, Sq, 35 L, Sq, M,
36 row, col, (ยฆ),(โ€”โ€”), 36 row, col, (ยฆ),(โ€”โ€”),
37 unrow, uncol, 37 unrow, uncol,
38 38
39 eye, 39 eye,
40 diagR, diag, 40 diagR, diag,
41 blockAt, 41 blockAt,
@@ -50,7 +50,7 @@ module Numeric.LinearAlgebra.Real(
50 Disp(..), 50 Disp(..),
51 -- * Misc 51 -- * Misc
52 withVector, withMatrix, 52 withVector, withMatrix,
53 Sized(..), Diag(..), Sym, sym, -- Her, her, 53 Sized(..), Diag(..), Sym, sym, Her, her, ๐‘–,
54 module Numeric.HMatrix 54 module Numeric.HMatrix
55) where 55) where
56 56
@@ -60,11 +60,13 @@ import Numeric.HMatrix hiding (
60 (<>),(#>),(<ยท>),Konst(..),diag, disp,(ยฆ),(โ€”โ€”),row,col,vect,mat,linspace, 60 (<>),(#>),(<ยท>),Konst(..),diag, disp,(ยฆ),(โ€”โ€”),row,col,vect,mat,linspace,
61 (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH') 61 (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH')
62import qualified Numeric.HMatrix as LA 62import qualified Numeric.HMatrix as LA
63import Data.Packed.Internal(mbCatch)
64import Data.Proxy(Proxy) 63import Data.Proxy(Proxy)
65import Numeric.LinearAlgebra.Static 64import Numeric.LinearAlgebra.Static
66import Text.Printf 65import Text.Printf
67 66
67๐‘– :: Sized โ„‚ s c => s
68๐‘– = konst i_C
69
68instance forall n . KnownNat n => Show (R n) 70instance forall n . KnownNat n => Show (R n)
69 where 71 where
70 show (ud1 -> v) 72 show (ud1 -> v)
@@ -73,6 +75,15 @@ instance forall n . KnownNat n => Show (R n)
73 where 75 where
74 d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 76 d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
75 77
78instance forall n . KnownNat n => Show (C n)
79 where
80 show (C (Dim v))
81 | singleV v = "("++show (v!0)++" :: C "++show d++")"
82 | otherwise = "(fromList"++ drop 8 (show v)++" :: C "++show d++")"
83 where
84 d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
85
86
76 87
77ud1 :: R n -> Vector โ„ 88ud1 :: R n -> Vector โ„
78ud1 (R (Dim v)) = v 89ud1 (R (Dim v)) = v
@@ -144,19 +155,30 @@ mkM x = M (Dim (Dim x))
144instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) 155instance forall m n . (KnownNat m, KnownNat n) => Show (L m n)
145 where 156 where
146 show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' 157 show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n'
147 show (ud2 -> x) 158 show (ud2 -> x)
148 | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' 159 | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n'
149 | otherwise = "(mat"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" 160 | otherwise = "(mat"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")"
150 where 161 where
151 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 162 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
152 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 163 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
153 164
165instance forall m n . (KnownNat m, KnownNat n) => Show (M m n)
166 where
167 show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n'
168 show (M (Dim (Dim x)))
169 | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n'
170 | otherwise = "(fromList"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")"
171 where
172 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
173 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
174
175
154-------------------------------------------------------------------------------- 176--------------------------------------------------------------------------------
155 177
156instance forall n. KnownNat n => Sized โ„‚ (C n) (Vector โ„‚) 178instance forall n. KnownNat n => Sized โ„‚ (C n) (Vector โ„‚)
157 where 179 where
158 konst x = mkC (LA.scalar x) 180 konst x = mkC (LA.scalar x)
159 unwrap (C (Dim v)) = v 181 unwrap (C (Dim v)) = v
160 fromList xs = C (gvect "C" xs) 182 fromList xs = C (gvect "C" xs)
161 extract (unwrap -> v) 183 extract (unwrap -> v)
162 | singleV v = LA.konst (v!0) d 184 | singleV v = LA.konst (v!0) d
@@ -240,7 +262,7 @@ blockAt x r c a = mkL res
240 262
241mat :: forall m n . (KnownNat m, KnownNat n) => [โ„] -> L m n 263mat :: forall m n . (KnownNat m, KnownNat n) => [โ„] -> L m n
242mat xs = L (gmat "L" xs) 264mat xs = L (gmat "L" xs)
243 265
244-------------------------------------------------------------------------------- 266--------------------------------------------------------------------------------
245 267
246class Disp t 268class Disp t
@@ -315,7 +337,7 @@ isKonst (unwrap -> x)
315 where 337 where
316 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 338 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
317 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 339 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
318 340
319 341
320 342
321isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (โ„, Vector โ„, (Int,Int)) 343isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (โ„, Vector โ„, (Int,Int))
@@ -329,7 +351,7 @@ isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> May
329isDiagg (Dim (Dim x)) 351isDiagg (Dim (Dim x))
330 | singleM x = Nothing 352 | singleM x = Nothing
331 | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n')) 353 | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n'))
332 | otherwise = Nothing 354 | otherwise = Nothing
333 where 355 where
334 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 356 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
335 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 357 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
@@ -377,6 +399,12 @@ instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m)
377 tr a@(isDiag -> Just _) = mkL (extract a) 399 tr a@(isDiag -> Just _) = mkL (extract a)
378 tr (extract -> a) = mkL (tr a) 400 tr (extract -> a) = mkL (tr a)
379 401
402instance (KnownNat n, KnownNat m) => Transposable (M m n) (M n m)
403 where
404 tr a@(isDiagC -> Just _) = mkM (extract a)
405 tr (extract -> a) = mkM (tr a)
406
407
380-------------------------------------------------------------------------------- 408--------------------------------------------------------------------------------
381 409
382adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b 410adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b
@@ -408,6 +436,33 @@ instance (KnownNat n, KnownNat m) => Fractional (L n m)
408 436
409-------------------------------------------------------------------------------- 437--------------------------------------------------------------------------------
410 438
439adaptDiagC f a@(isDiagC -> Just _) b | isFullC b = f (mkM (extract a)) b
440adaptDiagC f a b@(isDiagC -> Just _) | isFullC a = f a (mkM (extract b))
441adaptDiagC f a b = f a b
442
443isFullC m = isDiagC m == Nothing && not (singleM (unwrap m))
444
445lift1M f (M v) = M (f v)
446lift2M f (M a) (M b) = M (f a b)
447lift2MD f = adaptDiagC (lift2M f)
448
449instance (KnownNat n, KnownNat m) => Num (M n m)
450 where
451 (+) = lift2MD (+)
452 (*) = lift2MD (*)
453 (-) = lift2MD (-)
454 abs = lift1M abs
455 signum = lift1M signum
456 negate = lift1M negate
457 fromInteger = M . Dim . Dim . fromInteger
458
459instance (KnownNat n, KnownNat m) => Fractional (M n m)
460 where
461 fromRational = M . Dim . Dim . fromRational
462 (/) = lift2MD (/)
463
464--------------------------------------------------------------------------------
465
411{- 466{-
412class Minim (n :: Nat) (m :: Nat) 467class Minim (n :: Nat) (m :: Nat)
413 where 468 where
@@ -481,16 +536,16 @@ class Eigen m l v | m -> l, m -> v
481 where 536 where
482 eigensystem :: m -> (l,v) 537 eigensystem :: m -> (l,v)
483 eigenvalues :: m -> l 538 eigenvalues :: m -> l
484 539
485newtype Sym n = Sym (Sq n) deriving Show 540newtype Sym n = Sym (Sq n) deriving Show
486 541
487--newtype Her n = Her (CSq n) 542newtype Her n = Her (M n n)
488 543
489sym :: KnownNat n => Sq n -> Sym n 544sym :: KnownNat n => Sq n -> Sym n
490sym m = Sym $ (m + tr m)/2 545sym m = Sym $ (m + tr m)/2
491 546
492--her :: KnownNat n => CSq n -> Her n 547her :: KnownNat n => M n n -> Her n
493--her = undefined -- Her $ (m + tr m)/2 548her m = Her $ (m + tr m)/2
494 549
495instance KnownNat n => Eigen (Sym n) (R n) (L n n) 550instance KnownNat n => Eigen (Sym n) (R n) (L n n)
496 where 551 where
@@ -505,12 +560,12 @@ instance KnownNat n => Eigen (Sq n) (C n) (M n n)
505 eigensystem (extract -> m) = (mkC l, mkM v) 560 eigensystem (extract -> m) = (mkC l, mkM v)
506 where 561 where
507 (l,v) = LA.eig m 562 (l,v) = LA.eig m
508 563
509-------------------------------------------------------------------------------- 564--------------------------------------------------------------------------------
510 565
511withVector 566withVector
512 :: forall z 567 :: forall z
513 . Vector โ„ 568 . Vector โ„
514 -> (forall n . (KnownNat n) => R n -> z) 569 -> (forall n . (KnownNat n) => R n -> z)
515 -> z 570 -> z
516withVector v f = 571withVector v f =
@@ -521,16 +576,16 @@ withVector v f =
521 576
522withMatrix 577withMatrix
523 :: forall z 578 :: forall z
524 . Matrix โ„ 579 . Matrix โ„
525 -> (forall m n . (KnownNat m, KnownNat n) => L m n -> z) 580 -> (forall m n . (KnownNat m, KnownNat n) => L m n -> z)
526 -> z 581 -> z
527withMatrix a f = 582withMatrix a f =
528 case someNatVal $ fromIntegral $ rows a of 583 case someNatVal $ fromIntegral $ rows a of
529 Nothing -> error "static/dynamic mismatch" 584 Nothing -> error "static/dynamic mismatch"
530 Just (SomeNat (_ :: Proxy m)) -> 585 Just (SomeNat (_ :: Proxy m)) ->
531 case someNatVal $ fromIntegral $ cols a of 586 case someNatVal $ fromIntegral $ cols a of
532 Nothing -> error "static/dynamic mismatch" 587 Nothing -> error "static/dynamic mismatch"
533 Just (SomeNat (_ :: Proxy n)) -> 588 Just (SomeNat (_ :: Proxy n)) ->
534 f (mkL a :: L n m) 589 f (mkL a :: L n m)
535 590
536-------------------------------------------------------------------------------- 591--------------------------------------------------------------------------------
@@ -539,8 +594,8 @@ test :: (Bool, IO ())
539test = (ok,info) 594test = (ok,info)
540 where 595 where
541 ok = extract (eye :: Sq 5) == ident 5 596 ok = extract (eye :: Sq 5) == ident 5
542 && ud2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] 597 && unwrap (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..]
543 && ud2 (tm :: L 3 5) == LA.mat 5 [1..15] 598 && unwrap (tm :: L 3 5) == LA.mat 5 [1..15]
544 && thingS == thingD 599 && thingS == thingD
545 && precS == precD 600 && precS == precD
546 && withVector (LA.vect [1..15]) sumV == sumElements (LA.fromList [1..15]) 601 && withVector (LA.vect [1..15]) sumV == sumElements (LA.fromList [1..15])
@@ -557,7 +612,7 @@ test = (ok,info)
557 print precS 612 print precS
558 print precD 613 print precD
559 print $ withVector (LA.vect [1..15]) sumV 614 print $ withVector (LA.vect [1..15]) sumV
560 615
561 sumV w = w <ยท> konst 1 616 sumV w = w <ยท> konst 1
562 617
563 u = vec2 3 5 618 u = vec2 3 5