summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Static.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Static.hs')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs334
1 files changed, 281 insertions, 53 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
index 6acd9a3..2647f20 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -21,14 +21,7 @@ Stability : provisional
21 21
22-} 22-}
23 23
24module Numeric.LinearAlgebra.Static( 24module Numeric.LinearAlgebra.Static where
25 Dim(..),
26 R(..), C(..),
27 lift1F, lift2F,
28 vconcat, gvec2, gvec3, gvec4, gvect, gmat,
29 Sized(..),
30 singleV, singleM,GM
31) where
32 25
33 26
34import GHC.TypeLits 27import GHC.TypeLits
@@ -37,17 +30,9 @@ import Data.Packed as D
37import Data.Packed.ST 30import Data.Packed.ST
38import Data.Proxy(Proxy) 31import Data.Proxy(Proxy)
39import Foreign.Storable(Storable) 32import Foreign.Storable(Storable)
33import Text.Printf
40 34
41 35--------------------------------------------------------------------------------
42
43newtype R n = R (Dim n (Vector ℝ))
44 deriving (Num,Fractional)
45
46
47newtype C n = C (Dim n (Vector ℂ))
48 deriving (Num,Fractional)
49
50
51 36
52newtype Dim (n :: Nat) t = Dim t 37newtype Dim (n :: Nat) t = Dim t
53 deriving Show 38 deriving Show
@@ -64,36 +49,28 @@ lift2F f (Dim u) (Dim v) = Dim (f u v)
64 49
65-------------------------------------------------------------------------------- 50--------------------------------------------------------------------------------
66 51
67instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) 52newtype R n = R (Dim n (Vector ℝ))
68 where 53 deriving (Num,Fractional)
69 (+) = lift2F (+)
70 (*) = lift2F (*)
71 (-) = lift2F (-)
72 abs = lift1F abs
73 signum = lift1F signum
74 negate = lift1F negate
75 fromInteger x = Dim (fromInteger x)
76 54
77instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim n (Vector t)) 55newtype C n = C (Dim n (Vector ℂ))
78 where 56 deriving (Num,Fractional)
79 fromRational x = Dim (fromRational x)
80 (/) = lift2F (/)
81 57
58newtype L m n = L (Dim m (Dim n (Matrix ℝ)))
82 59
83instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) 60newtype M m n = M (Dim m (Dim n (Matrix ℂ)))
84 where
85 (+) = (lift2F . lift2F) (+)
86 (*) = (lift2F . lift2F) (*)
87 (-) = (lift2F . lift2F) (-)
88 abs = (lift1F . lift1F) abs
89 signum = (lift1F . lift1F) signum
90 negate = (lift1F . lift1F) negate
91 fromInteger x = Dim (Dim (fromInteger x))
92 61
93instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim m (Dim n (Matrix t))) 62
94 where 63mkR :: Vector ℝ -> R n
95 fromRational x = Dim (Dim (fromRational x)) 64mkR = R . Dim
96 (/) = (lift2F.lift2F) (/) 65
66mkC :: Vector ℂ -> C n
67mkC = C . Dim
68
69mkL :: Matrix ℝ -> L m n
70mkL x = L (Dim (Dim x))
71
72mkM :: Matrix ℂ -> M m n
73mkM x = M (Dim (Dim x))
97 74
98-------------------------------------------------------------------------------- 75--------------------------------------------------------------------------------
99 76
@@ -105,14 +82,6 @@ ud (Dim v) = v
105mkV :: forall (n :: Nat) t . t -> Dim n t 82mkV :: forall (n :: Nat) t . t -> Dim n t
106mkV = Dim 83mkV = Dim
107 84
108type GM m n t = Dim m (Dim n (Matrix t))
109
110--ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t
111--ud2 (Dim (Dim m)) = m
112
113mkM :: forall (m :: Nat) (n :: Nat) t . t -> Dim m (Dim n t)
114mkM = Dim . Dim
115
116 85
117vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) 86vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t)
118 => V n t -> V m t -> V (n+m) t 87 => V n t -> V m t -> V (n+m) t
@@ -166,9 +135,14 @@ gvect st xs'
166 abort info = error $ st++" "++show d++" can't be created from elements "++info 135 abort info = error $ st++" "++show d++" can't be created from elements "++info
167 136
168 137
138--------------------------------------------------------------------------------
139
140type GM m n t = Dim m (Dim n (Matrix t))
141
142
169gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> GM m n t 143gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> GM m n t
170gmat st xs' 144gmat st xs'
171 | ok = mkM x 145 | ok = Dim (Dim x)
172 | not (null rest) && null (tail rest) = abort (show xs') 146 | not (null rest) && null (tail rest) = abort (show xs')
173 | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]") 147 | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]")
174 | otherwise = abort (show xs) 148 | otherwise = abort (show xs)
@@ -181,6 +155,7 @@ gmat st xs'
181 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 155 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
182 abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info 156 abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info
183 157
158--------------------------------------------------------------------------------
184 159
185class Num t => Sized t s d | s -> t, s -> d 160class Num t => Sized t s d | s -> t, s -> d
186 where 161 where
@@ -192,3 +167,256 @@ class Num t => Sized t s d | s -> t, s -> d
192singleV v = size v == 1 167singleV v = size v == 1
193singleM m = rows m == 1 && cols m == 1 168singleM m = rows m == 1 && cols m == 1
194 169
170
171instance forall n. KnownNat n => Sized ℂ (C n) (Vector ℂ)
172 where
173 konst x = mkC (LA.scalar x)
174 unwrap (C (Dim v)) = v
175 fromList xs = C (gvect "C" xs)
176 extract (unwrap -> v)
177 | singleV v = LA.konst (v!0) d
178 | otherwise = v
179 where
180 d = fromIntegral . natVal $ (undefined :: Proxy n)
181
182
183instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ)
184 where
185 konst x = mkR (LA.scalar x)
186 unwrap (R (Dim v)) = v
187 fromList xs = R (gvect "R" xs)
188 extract (unwrap -> v)
189 | singleV v = LA.konst (v!0) d
190 | otherwise = v
191 where
192 d = fromIntegral . natVal $ (undefined :: Proxy n)
193
194
195
196instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ)
197 where
198 konst x = mkL (LA.scalar x)
199 fromList xs = L (gmat "L" xs)
200 unwrap (L (Dim (Dim m))) = m
201 extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n'
202 extract (unwrap -> a)
203 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n')
204 | otherwise = a
205 where
206 m' = fromIntegral . natVal $ (undefined :: Proxy m)
207 n' = fromIntegral . natVal $ (undefined :: Proxy n)
208
209
210instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) (Matrix ℂ)
211 where
212 konst x = mkM (LA.scalar x)
213 fromList xs = M (gmat "M" xs)
214 unwrap (M (Dim (Dim m))) = m
215 extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n'
216 extract (unwrap -> a)
217 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n')
218 | otherwise = a
219 where
220 m' = fromIntegral . natVal $ (undefined :: Proxy m)
221 n' = fromIntegral . natVal $ (undefined :: Proxy n)
222
223--------------------------------------------------------------------------------
224
225instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m)
226 where
227 tr a@(isDiag -> Just _) = mkL (extract a)
228 tr (extract -> a) = mkL (tr a)
229
230instance (KnownNat n, KnownNat m) => Transposable (M m n) (M n m)
231 where
232 tr a@(isDiagC -> Just _) = mkM (extract a)
233 tr (extract -> a) = mkM (tr a)
234
235--------------------------------------------------------------------------------
236
237isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int))
238isDiag (L x) = isDiagg x
239
240isDiagC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ, Vector ℂ, (Int,Int))
241isDiagC (M x) = isDiagg x
242
243
244isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> Maybe (t, Vector t, (Int,Int))
245isDiagg (Dim (Dim x))
246 | singleM x = Nothing
247 | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n'))
248 | otherwise = Nothing
249 where
250 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
251 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
252 v = flatten x
253 z = v `atIndex` 0
254 y = subVector 1 (size v-1) v
255 ny = size y
256 zeros = LA.konst 0 (max 0 (min m' n' - ny))
257 yz = vjoin [y,zeros]
258
259--------------------------------------------------------------------------------
260
261instance forall n . KnownNat n => Show (R n)
262 where
263 show (R (Dim v))
264 | singleV v = "("++show (v!0)++" :: R "++show d++")"
265 | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")"
266 where
267 d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
268
269instance forall n . KnownNat n => Show (C n)
270 where
271 show (C (Dim v))
272 | singleV v = "("++show (v!0)++" :: C "++show d++")"
273 | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")"
274 where
275 d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
276
277instance forall m n . (KnownNat m, KnownNat n) => Show (L m n)
278 where
279 show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n'
280 show (L (Dim (Dim x)))
281 | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n'
282 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")"
283 where
284 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
285 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
286
287instance forall m n . (KnownNat m, KnownNat n) => Show (M m n)
288 where
289 show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n'
290 show (M (Dim (Dim x)))
291 | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n'
292 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")"
293 where
294 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
295 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
296
297--------------------------------------------------------------------------------
298
299instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t))
300 where
301 (+) = lift2F (+)
302 (*) = lift2F (*)
303 (-) = lift2F (-)
304 abs = lift1F abs
305 signum = lift1F signum
306 negate = lift1F negate
307 fromInteger x = Dim (fromInteger x)
308
309instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim n (Vector t))
310 where
311 fromRational x = Dim (fromRational x)
312 (/) = lift2F (/)
313
314
315instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t)))
316 where
317 (+) = (lift2F . lift2F) (+)
318 (*) = (lift2F . lift2F) (*)
319 (-) = (lift2F . lift2F) (-)
320 abs = (lift1F . lift1F) abs
321 signum = (lift1F . lift1F) signum
322 negate = (lift1F . lift1F) negate
323 fromInteger x = Dim (Dim (fromInteger x))
324
325instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim m (Dim n (Matrix t)))
326 where
327 fromRational x = Dim (Dim (fromRational x))
328 (/) = (lift2F.lift2F) (/)
329
330--------------------------------------------------------------------------------
331
332
333adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b
334adaptDiag f a b@(isDiag -> Just _) | isFull a = f a (mkL (extract b))
335adaptDiag f a b = f a b
336
337isFull m = isDiag m == Nothing && not (singleM (unwrap m))
338
339
340lift1L f (L v) = L (f v)
341lift2L f (L a) (L b) = L (f a b)
342lift2LD f = adaptDiag (lift2L f)
343
344
345instance (KnownNat n, KnownNat m) => Num (L n m)
346 where
347 (+) = lift2LD (+)
348 (*) = lift2LD (*)
349 (-) = lift2LD (-)
350 abs = lift1L abs
351 signum = lift1L signum
352 negate = lift1L negate
353 fromInteger = L . Dim . Dim . fromInteger
354
355instance (KnownNat n, KnownNat m) => Fractional (L n m)
356 where
357 fromRational = L . Dim . Dim . fromRational
358 (/) = lift2LD (/)
359
360--------------------------------------------------------------------------------
361
362adaptDiagC f a@(isDiagC -> Just _) b | isFullC b = f (mkM (extract a)) b
363adaptDiagC f a b@(isDiagC -> Just _) | isFullC a = f a (mkM (extract b))
364adaptDiagC f a b = f a b
365
366isFullC m = isDiagC m == Nothing && not (singleM (unwrap m))
367
368lift1M f (M v) = M (f v)
369lift2M f (M a) (M b) = M (f a b)
370lift2MD f = adaptDiagC (lift2M f)
371
372instance (KnownNat n, KnownNat m) => Num (M n m)
373 where
374 (+) = lift2MD (+)
375 (*) = lift2MD (*)
376 (-) = lift2MD (-)
377 abs = lift1M abs
378 signum = lift1M signum
379 negate = lift1M negate
380 fromInteger = M . Dim . Dim . fromInteger
381
382instance (KnownNat n, KnownNat m) => Fractional (M n m)
383 where
384 fromRational = M . Dim . Dim . fromRational
385 (/) = lift2MD (/)
386
387--------------------------------------------------------------------------------
388
389
390class Disp t
391 where
392 disp :: Int -> t -> IO ()
393
394
395instance (KnownNat m, KnownNat n) => Disp (L m n)
396 where
397 disp n x = do
398 let a = extract x
399 let su = LA.dispf n a
400 printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
401
402instance (KnownNat m, KnownNat n) => Disp (M m n)
403 where
404 disp n x = do
405 let a = extract x
406 let su = LA.dispcf n a
407 printf "M %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
408
409
410instance KnownNat n => Disp (R n)
411 where
412 disp n v = do
413 let su = LA.dispf n (asRow $ extract v)
414 putStr "R " >> putStr (tail . dropWhile (/='x') $ su)
415
416instance KnownNat n => Disp (C n)
417 where
418 disp n v = do
419 let su = LA.dispcf n (asRow $ extract v)
420 putStr "C " >> putStr (tail . dropWhile (/='x') $ su)
421
422--------------------------------------------------------------------------------