summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Real.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Real.hs')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Real.hs231
1 files changed, 1 insertions, 230 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs
index aa48687..97c462e 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Real.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs
@@ -62,28 +62,13 @@ import Numeric.HMatrix hiding (
62import qualified Numeric.HMatrix as LA 62import qualified Numeric.HMatrix as LA
63import Data.Proxy(Proxy) 63import Data.Proxy(Proxy)
64import Numeric.LinearAlgebra.Static 64import Numeric.LinearAlgebra.Static
65import Text.Printf
66import Control.Arrow((***)) 65import Control.Arrow((***))
67 66
68 67
69๐‘– :: Sized โ„‚ s c => s 68๐‘– :: Sized โ„‚ s c => s
70๐‘– = konst i_C 69๐‘– = konst iC
71 70
72instance forall n . KnownNat n => Show (R n)
73 where
74 show (ud1 -> v)
75 | singleV v = "("++show (v!0)++" :: R "++show d++")"
76 | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")"
77 where
78 d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
79 71
80instance forall n . KnownNat n => Show (C n)
81 where
82 show (C (Dim v))
83 | singleV v = "("++show (v!0)++" :: C "++show d++")"
84 | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")"
85 where
86 d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
87 72
88 73
89 74
@@ -91,12 +76,6 @@ ud1 :: R n -> Vector โ„
91ud1 (R (Dim v)) = v 76ud1 (R (Dim v)) = v
92 77
93 78
94mkR :: Vector โ„ -> R n
95mkR = R . Dim
96
97mkC :: Vector โ„‚ -> C n
98mkC = C . Dim
99
100infixl 4 & 79infixl 4 &
101(&) :: forall n . KnownNat n 80(&) :: forall n . KnownNat n
102 => R n -> โ„ -> R (n+1) 81 => R n -> โ„ -> R (n+1)
@@ -143,95 +122,12 @@ dim = mkR (scalar d)
143 122
144-------------------------------------------------------------------------------- 123--------------------------------------------------------------------------------
145 124
146newtype L m n = L (Dim m (Dim n (Matrix โ„)))
147
148newtype M m n = M (Dim m (Dim n (Matrix โ„‚)))
149 125
150ud2 :: L m n -> Matrix โ„ 126ud2 :: L m n -> Matrix โ„
151ud2 (L (Dim (Dim x))) = x 127ud2 (L (Dim (Dim x))) = x
152 128
153 129
154mkL :: Matrix โ„ -> L m n
155mkL x = L (Dim (Dim x))
156
157mkM :: Matrix โ„‚ -> M m n
158mkM x = M (Dim (Dim x))
159
160instance forall m n . (KnownNat m, KnownNat n) => Show (L m n)
161 where
162 show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n'
163 show (ud2 -> x)
164 | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n'
165 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")"
166 where
167 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
168 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
169
170instance forall m n . (KnownNat m, KnownNat n) => Show (M m n)
171 where
172 show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n'
173 show (M (Dim (Dim x)))
174 | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n'
175 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")"
176 where
177 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
178 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
179
180
181-------------------------------------------------------------------------------- 130--------------------------------------------------------------------------------
182
183instance forall n. KnownNat n => Sized โ„‚ (C n) (Vector โ„‚)
184 where
185 konst x = mkC (LA.scalar x)
186 unwrap (C (Dim v)) = v
187 fromList xs = C (gvect "C" xs)
188 extract (unwrap -> v)
189 | singleV v = LA.konst (v!0) d
190 | otherwise = v
191 where
192 d = fromIntegral . natVal $ (undefined :: Proxy n)
193
194
195instance forall n. KnownNat n => Sized โ„ (R n) (Vector โ„)
196 where
197 konst x = mkR (LA.scalar x)
198 unwrap = ud1
199 fromList xs = R (gvect "R" xs)
200 extract (unwrap -> v)
201 | singleV v = LA.konst (v!0) d
202 | otherwise = v
203 where
204 d = fromIntegral . natVal $ (undefined :: Proxy n)
205
206
207
208instance forall m n . (KnownNat m, KnownNat n) => Sized โ„ (L m n) (Matrix โ„)
209 where
210 konst x = mkL (LA.scalar x)
211 fromList xs = L (gmat "L" xs)
212 unwrap = ud2
213 extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n'
214 extract (unwrap -> a)
215 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n')
216 | otherwise = a
217 where
218 m' = fromIntegral . natVal $ (undefined :: Proxy m)
219 n' = fromIntegral . natVal $ (undefined :: Proxy n)
220
221
222instance forall m n . (KnownNat m, KnownNat n) => Sized โ„‚ (M m n) (Matrix โ„‚)
223 where
224 konst x = mkM (LA.scalar x)
225 fromList xs = M (gmat "M" xs)
226 unwrap (M (Dim (Dim m))) = m
227 extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n'
228 extract (unwrap -> a)
229 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n')
230 | otherwise = a
231 where
232 m' = fromIntegral . natVal $ (undefined :: Proxy m)
233 n' = fromIntegral . natVal $ (undefined :: Proxy n)
234
235-------------------------------------------------------------------------------- 131--------------------------------------------------------------------------------
236 132
237diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => โ„ -> R k -> L m n 133diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => โ„ -> R k -> L m n
@@ -269,41 +165,6 @@ blockAt x r c a = mkL res
269 165
270-------------------------------------------------------------------------------- 166--------------------------------------------------------------------------------
271 167
272class Disp t
273 where
274 disp :: Int -> t -> IO ()
275
276
277instance (KnownNat m, KnownNat n) => Disp (L m n)
278 where
279 disp n x = do
280 let a = extract x
281 let su = LA.dispf n a
282 printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
283
284instance (KnownNat m, KnownNat n) => Disp (M m n)
285 where
286 disp n x = do
287 let a = extract x
288 let su = LA.dispcf n a
289 printf "M %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
290
291
292instance KnownNat n => Disp (R n)
293 where
294 disp n v = do
295 let su = LA.dispf n (asRow $ extract v)
296 putStr "R " >> putStr (tail . dropWhile (/='x') $ su)
297
298instance KnownNat n => Disp (C n)
299 where
300 disp n v = do
301 let su = LA.dispcf n (asRow $ extract v)
302 putStr "C " >> putStr (tail . dropWhile (/='x') $ su)
303
304
305--------------------------------------------------------------------------------
306
307 168
308row :: R n -> L 1 n 169row :: R n -> L 1 n
309row = mkL . asRow . ud1 170row = mkL . asRow . ud1
@@ -344,28 +205,6 @@ isKonst (unwrap -> x)
344 205
345 206
346 207
347isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (โ„, Vector โ„, (Int,Int))
348isDiag (L x) = isDiagg x
349
350isDiagC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (โ„‚, Vector โ„‚, (Int,Int))
351isDiagC (M x) = isDiagg x
352
353
354isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> Maybe (t, Vector t, (Int,Int))
355isDiagg (Dim (Dim x))
356 | singleM x = Nothing
357 | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n'))
358 | otherwise = Nothing
359 where
360 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
361 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
362 v = flatten x
363 z = v `atIndex` 0
364 y = subVector 1 (size v-1) v
365 ny = size y
366 zeros = LA.konst 0 (max 0 (min m' n' - ny))
367 yz = vjoin [y,zeros]
368
369 208
370infixr 8 <> 209infixr 8 <>
371(<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n 210(<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n
@@ -397,74 +236,6 @@ infixr 8 <ยท>
397 | singleV u || singleV v = sumElements (u * v) 236 | singleV u || singleV v = sumElements (u * v)
398 | otherwise = udot u v 237 | otherwise = udot u v
399 238
400
401instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m)
402 where
403 tr a@(isDiag -> Just _) = mkL (extract a)
404 tr (extract -> a) = mkL (tr a)
405
406instance (KnownNat n, KnownNat m) => Transposable (M m n) (M n m)
407 where
408 tr a@(isDiagC -> Just _) = mkM (extract a)
409 tr (extract -> a) = mkM (tr a)
410
411
412--------------------------------------------------------------------------------
413
414adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b
415adaptDiag f a b@(isDiag -> Just _) | isFull a = f a (mkL (extract b))
416adaptDiag f a b = f a b
417
418isFull m = isDiag m == Nothing && not (singleM (unwrap m))
419
420
421lift1L f (L v) = L (f v)
422lift2L f (L a) (L b) = L (f a b)
423lift2LD f = adaptDiag (lift2L f)
424
425
426instance (KnownNat n, KnownNat m) => Num (L n m)
427 where
428 (+) = lift2LD (+)
429 (*) = lift2LD (*)
430 (-) = lift2LD (-)
431 abs = lift1L abs
432 signum = lift1L signum
433 negate = lift1L negate
434 fromInteger = L . Dim . Dim . fromInteger
435
436instance (KnownNat n, KnownNat m) => Fractional (L n m)
437 where
438 fromRational = L . Dim . Dim . fromRational
439 (/) = lift2LD (/)
440
441--------------------------------------------------------------------------------
442
443adaptDiagC f a@(isDiagC -> Just _) b | isFullC b = f (mkM (extract a)) b
444adaptDiagC f a b@(isDiagC -> Just _) | isFullC a = f a (mkM (extract b))
445adaptDiagC f a b = f a b
446
447isFullC m = isDiagC m == Nothing && not (singleM (unwrap m))
448
449lift1M f (M v) = M (f v)
450lift2M f (M a) (M b) = M (f a b)
451lift2MD f = adaptDiagC (lift2M f)
452
453instance (KnownNat n, KnownNat m) => Num (M n m)
454 where
455 (+) = lift2MD (+)
456 (*) = lift2MD (*)
457 (-) = lift2MD (-)
458 abs = lift1M abs
459 signum = lift1M signum
460 negate = lift1M negate
461 fromInteger = M . Dim . Dim . fromInteger
462
463instance (KnownNat n, KnownNat m) => Fractional (M n m)
464 where
465 fromRational = M . Dim . Dim . fromRational
466 (/) = lift2MD (/)
467
468-------------------------------------------------------------------------------- 239--------------------------------------------------------------------------------
469 240
470{- 241{-