summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-06-04 18:49:50 +0200
committerAlberto Ruiz <aruiz@um.es>2014-06-04 18:49:50 +0200
commit2addcfb5db6721b9520e8be9942278dfc17b7021 (patch)
tree6fd765a21adad6b219153fe4009395c55630056e /packages/base/src/Numeric
parent0476c58d0b9da4fdcbbcb05ea055f6d14097e116 (diff)
complex instances
Diffstat (limited to 'packages/base/src/Numeric')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Real.hs105
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs6
2 files changed, 84 insertions, 27 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs
index 2ff69c7..d03ca6e 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Real.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs
@@ -45,7 +45,7 @@ module Numeric.LinearAlgebra.Real(
45 -- * Linear Systems 45 -- * Linear Systems
46 linSolve, (<\>), 46 linSolve, (<\>),
47 -- * Factorizations 47 -- * Factorizations
48 svd, svdTall, svdFlat, eig, 48 svd, svdTall, svdFlat, Eigen(..),
49 -- * Pretty printing 49 -- * Pretty printing
50 Disp(..), 50 Disp(..),
51 -- * Misc 51 -- * Misc
@@ -58,8 +58,9 @@ module Numeric.LinearAlgebra.Real(
58import GHC.TypeLits 58import GHC.TypeLits
59import Numeric.HMatrix hiding ( 59import Numeric.HMatrix hiding (
60 (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace, 60 (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace,
61 (<\>),fromList,takeDiag,svd,eig) 61 (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH')
62import qualified Numeric.HMatrix as LA 62import qualified Numeric.HMatrix as LA
63import Data.Packed.Internal(mbCatch)
63import Data.Proxy(Proxy) 64import Data.Proxy(Proxy)
64import Numeric.LinearAlgebra.Static 65import Numeric.LinearAlgebra.Static
65import Text.Printf 66import Text.Printf
@@ -80,6 +81,8 @@ ud1 (R (Dim v)) = v
80mkR :: Vector ℝ -> R n 81mkR :: Vector ℝ -> R n
81mkR = R . Dim 82mkR = R . Dim
82 83
84mkC :: Vector ℂ -> C n
85mkC = C . Dim
83 86
84infixl 4 & 87infixl 4 &
85(&) :: forall n . KnownNat n 88(&) :: forall n . KnownNat n
@@ -126,17 +129,17 @@ dim = mkR (scalar d)
126 129
127newtype L m n = L (Dim m (Dim n (Matrix ℝ))) 130newtype L m n = L (Dim m (Dim n (Matrix ℝ)))
128 131
129-- newtype CL m n = CL (Dim m (Dim n (Matrix ℂ))) 132newtype M m n = M (Dim m (Dim n (Matrix ℂ)))
130 133
131ud2 :: L m n -> Matrix ℝ 134ud2 :: L m n -> Matrix ℝ
132ud2 (L (Dim (Dim x))) = x 135ud2 (L (Dim (Dim x))) = x
133 136
134 137
135
136
137mkL :: Matrix ℝ -> L m n 138mkL :: Matrix ℝ -> L m n
138mkL x = L (Dim (Dim x)) 139mkL x = L (Dim (Dim x))
139 140
141mkM :: Matrix ℂ -> M m n
142mkM x = M (Dim (Dim x))
140 143
141instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) 144instance forall m n . (KnownNat m, KnownNat n) => Show (L m n)
142 where 145 where
@@ -150,6 +153,18 @@ instance forall m n . (KnownNat m, KnownNat n) => Show (L m n)
150 153
151-------------------------------------------------------------------------------- 154--------------------------------------------------------------------------------
152 155
156instance forall n. KnownNat n => Sized ℂ (C n) (Vector ℂ)
157 where
158 konst x = mkC (LA.scalar x)
159 unwrap (C (Dim v)) = v
160 fromList xs = C (gvect "C" xs)
161 extract (unwrap -> v)
162 | singleV v = LA.konst (v!0) d
163 | otherwise = v
164 where
165 d = fromIntegral . natVal $ (undefined :: Proxy n)
166
167
153instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) 168instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ)
154 where 169 where
155 konst x = mkR (LA.scalar x) 170 konst x = mkR (LA.scalar x)
@@ -162,11 +177,12 @@ instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ)
162 d = fromIntegral . natVal $ (undefined :: Proxy n) 177 d = fromIntegral . natVal $ (undefined :: Proxy n)
163 178
164 179
180
165instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) 181instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ)
166 where 182 where
167 konst x = mkL (LA.scalar x) 183 konst x = mkL (LA.scalar x)
168 unwrap = ud2
169 fromList = mat 184 fromList = mat
185 unwrap = ud2
170 extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' 186 extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n'
171 extract (unwrap -> a) 187 extract (unwrap -> a)
172 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') 188 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n')
@@ -175,6 +191,20 @@ instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ)
175 m' = fromIntegral . natVal $ (undefined :: Proxy m) 191 m' = fromIntegral . natVal $ (undefined :: Proxy m)
176 n' = fromIntegral . natVal $ (undefined :: Proxy n) 192 n' = fromIntegral . natVal $ (undefined :: Proxy n)
177 193
194
195instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) (Matrix ℂ)
196 where
197 konst x = mkM (LA.scalar x)
198 fromList xs = M (gmat "M" xs)
199 unwrap (M (Dim (Dim m))) = m
200 extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n'
201 extract (unwrap -> a)
202 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n')
203 | otherwise = a
204 where
205 m' = fromIntegral . natVal $ (undefined :: Proxy m)
206 n' = fromIntegral . natVal $ (undefined :: Proxy n)
207
178-------------------------------------------------------------------------------- 208--------------------------------------------------------------------------------
179 209
180diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n 210diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n
@@ -225,26 +255,41 @@ instance (KnownNat m, KnownNat n) => Disp (L m n)
225 let su = LA.dispf n a 255 let su = LA.dispf n a
226 printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) 256 printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
227 257
258instance (KnownNat m, KnownNat n) => Disp (M m n)
259 where
260 disp n x = do
261 let a = extract x
262 let su = LA.dispcf n a
263 printf "M %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
264
265
228instance KnownNat n => Disp (R n) 266instance KnownNat n => Disp (R n)
229 where 267 where
230 disp n v = do 268 disp n v = do
231 let su = LA.dispf n (asRow $ extract v) 269 let su = LA.dispf n (asRow $ extract v)
232 putStr "R " >> putStr (tail . dropWhile (/='x') $ su) 270 putStr "R " >> putStr (tail . dropWhile (/='x') $ su)
233 271
272instance KnownNat n => Disp (C n)
273 where
274 disp n v = do
275 let su = LA.dispcf n (asRow $ extract v)
276 putStr "C " >> putStr (tail . dropWhile (/='x') $ su)
277
278
234-------------------------------------------------------------------------------- 279--------------------------------------------------------------------------------
235 280
236 281
237row :: R n -> L 1 n 282row :: R n -> L 1 n
238row = mkL . asRow . ud1 283row = mkL . asRow . ud1
239 284
240col :: R n -> L n 1 285--col :: R n -> L n 1
241col = tr . row 286col v = tr . row $ v
242 287
243unrow :: L 1 n -> R n 288unrow :: L 1 n -> R n
244unrow = mkR . head . toRows . ud2 289unrow = mkR . head . toRows . ud2
245 290
246uncol :: L n 1 -> R n 291--uncol :: L n 1 -> R n
247uncol = unrow . tr 292uncol v = unrow . tr $ v
248 293
249 294
250infixl 2 —— 295infixl 2 ——
@@ -253,7 +298,7 @@ a —— b = mkL (extract a LA.—— extract b)
253 298
254 299
255infixl 3 ¦ 300infixl 3 ¦
256(¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) 301-- (¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2)
257a ¦ b = tr (tr a —— tr b) 302a ¦ b = tr (tr a —— tr b)
258 303
259 304
@@ -274,7 +319,14 @@ isKonst (unwrap -> x)
274 319
275 320
276isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int)) 321isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int))
277isDiag (unwrap -> x) 322isDiag (L x) = isDiagg x
323
324isDiagC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ, Vector ℂ, (Int,Int))
325isDiagC (M x) = isDiagg x
326
327
328isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> Maybe (t, Vector t, (Int,Int))
329isDiagg (Dim (Dim x))
278 | singleM x = Nothing 330 | singleM x = Nothing
279 | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n')) 331 | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n'))
280 | otherwise = Nothing 332 | otherwise = Nothing
@@ -282,7 +334,7 @@ isDiag (unwrap -> x)
282 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 334 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
283 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 335 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
284 v = flatten x 336 v = flatten x
285 z = v!0 337 z = v `atIndex` 0
286 y = subVector 1 (size v-1) v 338 y = subVector 1 (size v-1) v
287 ny = size y 339 ny = size y
288 zeros = LA.konst 0 (max 0 (min m' n' - ny)) 340 zeros = LA.konst 0 (max 0 (min m' n' - ny))
@@ -320,9 +372,10 @@ infixr 8 <·>
320 | otherwise = udot u v 372 | otherwise = udot u v
321 373
322 374
323instance Transposable (L m n) (L n m) 375instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m)
324 where 376 where
325 tr (ud2 -> a) = mkL (tr a) 377 tr a@(isDiag -> Just _) = mkL (extract a)
378 tr (extract -> a) = mkL (tr a)
326 379
327-------------------------------------------------------------------------------- 380--------------------------------------------------------------------------------
328 381
@@ -424,11 +477,12 @@ svdFlat (extract -> m) = (mkL u, mkR s, mkL v)
424 477
425-------------------------------------------------------------------------------- 478--------------------------------------------------------------------------------
426 479
427class Eig m r | m -> r 480class Eigen m l v | m -> l, m -> v
428 where 481 where
429 eig :: m -> r 482 eigensystem :: m -> (l,v)
483 eigenvalues :: m -> l
430 484
431newtype Sym n = Sym (Sq n) 485newtype Sym n = Sym (Sq n) deriving Show
432 486
433--newtype Her n = Her (CSq n) 487--newtype Her n = Her (CSq n)
434 488
@@ -438,16 +492,19 @@ sym m = Sym $ (m + tr m)/2
438--her :: KnownNat n => CSq n -> Her n 492--her :: KnownNat n => CSq n -> Her n
439--her = undefined -- Her $ (m + tr m)/2 493--her = undefined -- Her $ (m + tr m)/2
440 494
441 495instance KnownNat n => Eigen (Sym n) (R n) (L n n)
442instance KnownNat n => Eig (Sym n) (R n, Sq n)
443 where 496 where
444 eig (Sym (extract -> m)) = (mkR l, mkL v) 497 eigenvalues (Sym (extract -> m)) = mkR . LA.eigenvaluesSH' $ m
498 eigensystem (Sym (extract -> m)) = (mkR l, mkL v)
445 where 499 where
446 (l,v) = eigSH m 500 (l,v) = LA.eigSH' m
447 501
448instance KnownNat n => Eig (Sq n) (C n) 502instance KnownNat n => Eigen (Sq n) (C n) (M n n)
449 where 503 where
450 eig (extract -> m) = C . Dim . eigenvalues $ m 504 eigenvalues (extract -> m) = mkC . LA.eigenvalues $ m
505 eigensystem (extract -> m) = (mkC l, mkM v)
506 where
507 (l,v) = LA.eig m
451 508
452-------------------------------------------------------------------------------- 509--------------------------------------------------------------------------------
453 510
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
index 5caf6f8..6acd9a3 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -27,7 +27,7 @@ module Numeric.LinearAlgebra.Static(
27 lift1F, lift2F, 27 lift1F, lift2F,
28 vconcat, gvec2, gvec3, gvec4, gvect, gmat, 28 vconcat, gvec2, gvec3, gvec4, gvect, gmat,
29 Sized(..), 29 Sized(..),
30 singleV, singleM 30 singleV, singleM,GM
31) where 31) where
32 32
33 33
@@ -105,7 +105,7 @@ ud (Dim v) = v
105mkV :: forall (n :: Nat) t . t -> Dim n t 105mkV :: forall (n :: Nat) t . t -> Dim n t
106mkV = Dim 106mkV = Dim
107 107
108type M m n t = Dim m (Dim n (Matrix t)) 108type GM m n t = Dim m (Dim n (Matrix t))
109 109
110--ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t 110--ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t
111--ud2 (Dim (Dim m)) = m 111--ud2 (Dim (Dim m)) = m
@@ -166,7 +166,7 @@ gvect st xs'
166 abort info = error $ st++" "++show d++" can't be created from elements "++info 166 abort info = error $ st++" "++show d++" can't be created from elements "++info
167 167
168 168
169gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> M m n t 169gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> GM m n t
170gmat st xs' 170gmat st xs'
171 | ok = mkM x 171 | ok = mkM x
172 | not (null rest) && null (tail rest) = abort (show xs') 172 | not (null rest) && null (tail rest) = abort (show xs')