summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Static.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Static.hs')
-rw-r--r--packages/base/src/Internal/Static.hs527
1 files changed, 527 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs
new file mode 100644
index 0000000..0068313
--- /dev/null
+++ b/packages/base/src/Internal/Static.hs
@@ -0,0 +1,527 @@
1#if __GLASGOW_HASKELL__ >= 708
2
3{-# LANGUAGE DataKinds #-}
4{-# LANGUAGE KindSignatures #-}
5{-# LANGUAGE GeneralizedNewtypeDeriving #-}
6{-# LANGUAGE MultiParamTypeClasses #-}
7{-# LANGUAGE FunctionalDependencies #-}
8{-# LANGUAGE FlexibleContexts #-}
9{-# LANGUAGE ScopedTypeVariables #-}
10{-# LANGUAGE Rank2Types #-}
11{-# LANGUAGE FlexibleInstances #-}
12{-# LANGUAGE TypeOperators #-}
13{-# LANGUAGE ViewPatterns #-}
14
15{- |
16Module : Internal.Static
17Copyright : (c) Alberto Ruiz 2006-14
18License : BSD3
19Stability : provisional
20
21-}
22
23module Internal.Static where
24
25
26import GHC.TypeLits
27import qualified Numeric.LinearAlgebra as LA
28import Numeric.LinearAlgebra hiding (konst,size,R,C)
29import Internal.Vector as D hiding (R,C)
30import Internal.ST
31import Data.Proxy(Proxy)
32import Foreign.Storable(Storable)
33import Text.Printf
34
35--------------------------------------------------------------------------------
36
37type ℝ = Double
38type ℂ = Complex Double
39
40newtype Dim (n :: Nat) t = Dim t
41 deriving Show
42
43lift1F
44 :: (c t -> c t)
45 -> Dim n (c t) -> Dim n (c t)
46lift1F f (Dim v) = Dim (f v)
47
48lift2F
49 :: (c t -> c t -> c t)
50 -> Dim n (c t) -> Dim n (c t) -> Dim n (c t)
51lift2F f (Dim u) (Dim v) = Dim (f u v)
52
53--------------------------------------------------------------------------------
54
55newtype R n = R (Dim n (Vector ℝ))
56 deriving (Num,Fractional,Floating)
57
58newtype C n = C (Dim n (Vector ℂ))
59 deriving (Num,Fractional,Floating)
60
61newtype L m n = L (Dim m (Dim n (Matrix ℝ)))
62
63newtype M m n = M (Dim m (Dim n (Matrix ℂ)))
64
65
66mkR :: Vector ℝ -> R n
67mkR = R . Dim
68
69mkC :: Vector ℂ -> C n
70mkC = C . Dim
71
72mkL :: Matrix ℝ -> L m n
73mkL x = L (Dim (Dim x))
74
75mkM :: Matrix ℂ -> M m n
76mkM x = M (Dim (Dim x))
77
78--------------------------------------------------------------------------------
79
80type V n t = Dim n (Vector t)
81
82ud :: Dim n (Vector t) -> Vector t
83ud (Dim v) = v
84
85mkV :: forall (n :: Nat) t . t -> Dim n t
86mkV = Dim
87
88
89vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t)
90 => V n t -> V m t -> V (n+m) t
91(ud -> u) `vconcat` (ud -> v) = mkV (vjoin [u', v'])
92 where
93 du = fromIntegral . natVal $ (undefined :: Proxy n)
94 dv = fromIntegral . natVal $ (undefined :: Proxy m)
95 u' | du > 1 && LA.size u == 1 = LA.konst (u D.@> 0) du
96 | otherwise = u
97 v' | dv > 1 && LA.size v == 1 = LA.konst (v D.@> 0) dv
98 | otherwise = v
99
100
101gvec2 :: Storable t => t -> t -> V 2 t
102gvec2 a b = mkV $ runSTVector $ do
103 v <- newUndefinedVector 2
104 writeVector v 0 a
105 writeVector v 1 b
106 return v
107
108gvec3 :: Storable t => t -> t -> t -> V 3 t
109gvec3 a b c = mkV $ runSTVector $ do
110 v <- newUndefinedVector 3
111 writeVector v 0 a
112 writeVector v 1 b
113 writeVector v 2 c
114 return v
115
116
117gvec4 :: Storable t => t -> t -> t -> t -> V 4 t
118gvec4 a b c d = mkV $ runSTVector $ do
119 v <- newUndefinedVector 4
120 writeVector v 0 a
121 writeVector v 1 b
122 writeVector v 2 c
123 writeVector v 3 d
124 return v
125
126
127gvect :: forall n t . (Show t, KnownNat n, Numeric t) => String -> [t] -> V n t
128gvect st xs'
129 | ok = mkV v
130 | not (null rest) && null (tail rest) = abort (show xs')
131 | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]")
132 | otherwise = abort (show xs)
133 where
134 (xs,rest) = splitAt d xs'
135 ok = LA.size v == d && null rest
136 v = LA.fromList xs
137 d = fromIntegral . natVal $ (undefined :: Proxy n)
138 abort info = error $ st++" "++show d++" can't be created from elements "++info
139
140
141--------------------------------------------------------------------------------
142
143type GM m n t = Dim m (Dim n (Matrix t))
144
145
146gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> GM m n t
147gmat st xs'
148 | ok = Dim (Dim x)
149 | not (null rest) && null (tail rest) = abort (show xs')
150 | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]")
151 | otherwise = abort (show xs)
152 where
153 (xs,rest) = splitAt (m'*n') xs'
154 v = LA.fromList xs
155 x = reshape n' v
156 ok = null rest && ((n' == 0 && dim v == 0) || n'> 0 && (rem (LA.size v) n' == 0) && LA.size x == (m',n'))
157 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
158 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
159 abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info
160
161--------------------------------------------------------------------------------
162
163class Num t => Sized t s d | s -> t, s -> d
164 where
165 konst :: t -> s
166 unwrap :: s -> d t
167 fromList :: [t] -> s
168 extract :: s -> d t
169 create :: d t -> Maybe s
170 size :: s -> IndexOf d
171
172singleV v = LA.size v == 1
173singleM m = rows m == 1 && cols m == 1
174
175
176instance forall n. KnownNat n => Sized ℂ (C n) Vector
177 where
178 size _ = fromIntegral . natVal $ (undefined :: Proxy n)
179 konst x = mkC (LA.scalar x)
180 unwrap (C (Dim v)) = v
181 fromList xs = C (gvect "C" xs)
182 extract s@(unwrap -> v)
183 | singleV v = LA.konst (v!0) (size s)
184 | otherwise = v
185 create v
186 | LA.size v == size r = Just r
187 | otherwise = Nothing
188 where
189 r = mkC v :: C n
190
191
192instance forall n. KnownNat n => Sized ℝ (R n) Vector
193 where
194 size _ = fromIntegral . natVal $ (undefined :: Proxy n)
195 konst x = mkR (LA.scalar x)
196 unwrap (R (Dim v)) = v
197 fromList xs = R (gvect "R" xs)
198 extract s@(unwrap -> v)
199 | singleV v = LA.konst (v!0) (size s)
200 | otherwise = v
201 create v
202 | LA.size v == size r = Just r
203 | otherwise = Nothing
204 where
205 r = mkR v :: R n
206
207
208
209instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) Matrix
210 where
211 size _ = ((fromIntegral . natVal) (undefined :: Proxy m)
212 ,(fromIntegral . natVal) (undefined :: Proxy n))
213 konst x = mkL (LA.scalar x)
214 fromList xs = L (gmat "L" xs)
215 unwrap (L (Dim (Dim m))) = m
216 extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n'
217 extract s@(unwrap -> a)
218 | singleM a = LA.konst (a `atIndex` (0,0)) (size s)
219 | otherwise = a
220 create x
221 | LA.size x == size r = Just r
222 | otherwise = Nothing
223 where
224 r = mkL x :: L m n
225
226
227instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) Matrix
228 where
229 size _ = ((fromIntegral . natVal) (undefined :: Proxy m)
230 ,(fromIntegral . natVal) (undefined :: Proxy n))
231 konst x = mkM (LA.scalar x)
232 fromList xs = M (gmat "M" xs)
233 unwrap (M (Dim (Dim m))) = m
234 extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n'
235 extract s@(unwrap -> a)
236 | singleM a = LA.konst (a `atIndex` (0,0)) (size s)
237 | otherwise = a
238 create x
239 | LA.size x == size r = Just r
240 | otherwise = Nothing
241 where
242 r = mkM x :: M m n
243
244--------------------------------------------------------------------------------
245
246instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m)
247 where
248 tr a@(isDiag -> Just _) = mkL (extract a)
249 tr (extract -> a) = mkL (tr a)
250 tr' = tr
251
252instance (KnownNat n, KnownNat m) => Transposable (M m n) (M n m)
253 where
254 tr a@(isDiagC -> Just _) = mkM (extract a)
255 tr (extract -> a) = mkM (tr a)
256 tr' a@(isDiagC -> Just _) = mkM (extract a)
257 tr' (extract -> a) = mkM (tr' a)
258
259--------------------------------------------------------------------------------
260
261isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int))
262isDiag (L x) = isDiagg x
263
264isDiagC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ, Vector ℂ, (Int,Int))
265isDiagC (M x) = isDiagg x
266
267
268isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> Maybe (t, Vector t, (Int,Int))
269isDiagg (Dim (Dim x))
270 | singleM x = Nothing
271 | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n'))
272 | otherwise = Nothing
273 where
274 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
275 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
276 v = flatten x
277 z = v `atIndex` 0
278 y = subVector 1 (LA.size v-1) v
279 ny = LA.size y
280 zeros = LA.konst 0 (max 0 (min m' n' - ny))
281 yz = vjoin [y,zeros]
282
283--------------------------------------------------------------------------------
284
285instance forall n . KnownNat n => Show (R n)
286 where
287 show s@(R (Dim v))
288 | singleV v = "("++show (v!0)++" :: R "++show d++")"
289 | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")"
290 where
291 d = size s
292
293instance forall n . KnownNat n => Show (C n)
294 where
295 show s@(C (Dim v))
296 | singleV v = "("++show (v!0)++" :: C "++show d++")"
297 | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")"
298 where
299 d = size s
300
301instance forall m n . (KnownNat m, KnownNat n) => Show (L m n)
302 where
303 show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n'
304 show s@(L (Dim (Dim x)))
305 | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n'
306 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")"
307 where
308 (m',n') = size s
309
310instance forall m n . (KnownNat m, KnownNat n) => Show (M m n)
311 where
312 show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n'
313 show s@(M (Dim (Dim x)))
314 | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n'
315 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")"
316 where
317 (m',n') = size s
318
319--------------------------------------------------------------------------------
320
321instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t))
322 where
323 (+) = lift2F (+)
324 (*) = lift2F (*)
325 (-) = lift2F (-)
326 abs = lift1F abs
327 signum = lift1F signum
328 negate = lift1F negate
329 fromInteger x = Dim (fromInteger x)
330
331instance (Num (Vector t), Num (Matrix t), Fractional t, Numeric t) => Fractional (Dim n (Vector t))
332 where
333 fromRational x = Dim (fromRational x)
334 (/) = lift2F (/)
335
336instance (Fractional t, Floating (Vector t), Numeric t) => Floating (Dim n (Vector t)) where
337 sin = lift1F sin
338 cos = lift1F cos
339 tan = lift1F tan
340 asin = lift1F asin
341 acos = lift1F acos
342 atan = lift1F atan
343 sinh = lift1F sinh
344 cosh = lift1F cosh
345 tanh = lift1F tanh
346 asinh = lift1F asinh
347 acosh = lift1F acosh
348 atanh = lift1F atanh
349 exp = lift1F exp
350 log = lift1F log
351 sqrt = lift1F sqrt
352 (**) = lift2F (**)
353 pi = Dim pi
354
355
356instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t)))
357 where
358 (+) = (lift2F . lift2F) (+)
359 (*) = (lift2F . lift2F) (*)
360 (-) = (lift2F . lift2F) (-)
361 abs = (lift1F . lift1F) abs
362 signum = (lift1F . lift1F) signum
363 negate = (lift1F . lift1F) negate
364 fromInteger x = Dim (Dim (fromInteger x))
365
366instance (Num (Vector t), Num (Matrix t), Fractional t, Numeric t) => Fractional (Dim m (Dim n (Matrix t)))
367 where
368 fromRational x = Dim (Dim (fromRational x))
369 (/) = (lift2F.lift2F) (/)
370
371instance (Num (Vector t), Floating (Matrix t), Fractional t, Numeric t) => Floating (Dim m (Dim n (Matrix t))) where
372 sin = (lift1F . lift1F) sin
373 cos = (lift1F . lift1F) cos
374 tan = (lift1F . lift1F) tan
375 asin = (lift1F . lift1F) asin
376 acos = (lift1F . lift1F) acos
377 atan = (lift1F . lift1F) atan
378 sinh = (lift1F . lift1F) sinh
379 cosh = (lift1F . lift1F) cosh
380 tanh = (lift1F . lift1F) tanh
381 asinh = (lift1F . lift1F) asinh
382 acosh = (lift1F . lift1F) acosh
383 atanh = (lift1F . lift1F) atanh
384 exp = (lift1F . lift1F) exp
385 log = (lift1F . lift1F) log
386 sqrt = (lift1F . lift1F) sqrt
387 (**) = (lift2F . lift2F) (**)
388 pi = Dim (Dim pi)
389
390--------------------------------------------------------------------------------
391
392
393adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b
394adaptDiag f a b@(isDiag -> Just _) | isFull a = f a (mkL (extract b))
395adaptDiag f a b = f a b
396
397isFull m = isDiag m == Nothing && not (singleM (unwrap m))
398
399
400lift1L f (L v) = L (f v)
401lift2L f (L a) (L b) = L (f a b)
402lift2LD f = adaptDiag (lift2L f)
403
404
405instance (KnownNat n, KnownNat m) => Num (L n m)
406 where
407 (+) = lift2LD (+)
408 (*) = lift2LD (*)
409 (-) = lift2LD (-)
410 abs = lift1L abs
411 signum = lift1L signum
412 negate = lift1L negate
413 fromInteger = L . Dim . Dim . fromInteger
414
415instance (KnownNat n, KnownNat m) => Fractional (L n m)
416 where
417 fromRational = L . Dim . Dim . fromRational
418 (/) = lift2LD (/)
419
420instance (KnownNat n, KnownNat m) => Floating (L n m) where
421 sin = lift1L sin
422 cos = lift1L cos
423 tan = lift1L tan
424 asin = lift1L asin
425 acos = lift1L acos
426 atan = lift1L atan
427 sinh = lift1L sinh
428 cosh = lift1L cosh
429 tanh = lift1L tanh
430 asinh = lift1L asinh
431 acosh = lift1L acosh
432 atanh = lift1L atanh
433 exp = lift1L exp
434 log = lift1L log
435 sqrt = lift1L sqrt
436 (**) = lift2LD (**)
437 pi = konst pi
438
439--------------------------------------------------------------------------------
440
441adaptDiagC f a@(isDiagC -> Just _) b | isFullC b = f (mkM (extract a)) b
442adaptDiagC f a b@(isDiagC -> Just _) | isFullC a = f a (mkM (extract b))
443adaptDiagC f a b = f a b
444
445isFullC m = isDiagC m == Nothing && not (singleM (unwrap m))
446
447lift1M f (M v) = M (f v)
448lift2M f (M a) (M b) = M (f a b)
449lift2MD f = adaptDiagC (lift2M f)
450
451instance (KnownNat n, KnownNat m) => Num (M n m)
452 where
453 (+) = lift2MD (+)
454 (*) = lift2MD (*)
455 (-) = lift2MD (-)
456 abs = lift1M abs
457 signum = lift1M signum
458 negate = lift1M negate
459 fromInteger = M . Dim . Dim . fromInteger
460
461instance (KnownNat n, KnownNat m) => Fractional (M n m)
462 where
463 fromRational = M . Dim . Dim . fromRational
464 (/) = lift2MD (/)
465
466instance (KnownNat n, KnownNat m) => Floating (M n m) where
467 sin = lift1M sin
468 cos = lift1M cos
469 tan = lift1M tan
470 asin = lift1M asin
471 acos = lift1M acos
472 atan = lift1M atan
473 sinh = lift1M sinh
474 cosh = lift1M cosh
475 tanh = lift1M tanh
476 asinh = lift1M asinh
477 acosh = lift1M acosh
478 atanh = lift1M atanh
479 exp = lift1M exp
480 log = lift1M log
481 sqrt = lift1M sqrt
482 (**) = lift2MD (**)
483 pi = M pi
484
485--------------------------------------------------------------------------------
486
487
488class Disp t
489 where
490 disp :: Int -> t -> IO ()
491
492
493instance (KnownNat m, KnownNat n) => Disp (L m n)
494 where
495 disp n x = do
496 let a = extract x
497 let su = LA.dispf n a
498 printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
499
500instance (KnownNat m, KnownNat n) => Disp (M m n)
501 where
502 disp n x = do
503 let a = extract x
504 let su = LA.dispcf n a
505 printf "M %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
506
507
508instance KnownNat n => Disp (R n)
509 where
510 disp n v = do
511 let su = LA.dispf n (asRow $ extract v)
512 putStr "R " >> putStr (tail . dropWhile (/='x') $ su)
513
514instance KnownNat n => Disp (C n)
515 where
516 disp n v = do
517 let su = LA.dispcf n (asRow $ extract v)
518 putStr "C " >> putStr (tail . dropWhile (/='x') $ su)
519
520--------------------------------------------------------------------------------
521
522#else
523
524module Numeric.LinearAlgebra.Static.Internal where
525
526#endif
527