summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs524
1 files changed, 0 insertions, 524 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs
deleted file mode 100644
index 7b770e0..0000000
--- a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs
+++ /dev/null
@@ -1,524 +0,0 @@
1#if __GLASGOW_HASKELL__ >= 708
2
3{-# LANGUAGE DataKinds #-}
4{-# LANGUAGE KindSignatures #-}
5{-# LANGUAGE GeneralizedNewtypeDeriving #-}
6{-# LANGUAGE MultiParamTypeClasses #-}
7{-# LANGUAGE FunctionalDependencies #-}
8{-# LANGUAGE FlexibleContexts #-}
9{-# LANGUAGE ScopedTypeVariables #-}
10{-# LANGUAGE Rank2Types #-}
11{-# LANGUAGE FlexibleInstances #-}
12{-# LANGUAGE TypeOperators #-}
13{-# LANGUAGE ViewPatterns #-}
14
15{- |
16Module : Numeric.LinearAlgebra.Static.Internal
17Copyright : (c) Alberto Ruiz 2006-14
18License : BSD3
19Stability : provisional
20
21-}
22
23module Numeric.LinearAlgebra.Static.Internal where
24
25
26import GHC.TypeLits
27import qualified Numeric.LinearAlgebra as LA
28import Numeric.LinearAlgebra hiding (konst,size)
29import Data.Packed as D
30import Data.Packed.ST
31import Data.Proxy(Proxy)
32import Foreign.Storable(Storable)
33import Text.Printf
34
35--------------------------------------------------------------------------------
36
37newtype Dim (n :: Nat) t = Dim t
38 deriving Show
39
40lift1F
41 :: (c t -> c t)
42 -> Dim n (c t) -> Dim n (c t)
43lift1F f (Dim v) = Dim (f v)
44
45lift2F
46 :: (c t -> c t -> c t)
47 -> Dim n (c t) -> Dim n (c t) -> Dim n (c t)
48lift2F f (Dim u) (Dim v) = Dim (f u v)
49
50--------------------------------------------------------------------------------
51
52newtype R n = R (Dim n (Vector ℝ))
53 deriving (Num,Fractional,Floating)
54
55newtype C n = C (Dim n (Vector ℂ))
56 deriving (Num,Fractional,Floating)
57
58newtype L m n = L (Dim m (Dim n (Matrix ℝ)))
59
60newtype M m n = M (Dim m (Dim n (Matrix ℂ)))
61
62
63mkR :: Vector ℝ -> R n
64mkR = R . Dim
65
66mkC :: Vector ℂ -> C n
67mkC = C . Dim
68
69mkL :: Matrix ℝ -> L m n
70mkL x = L (Dim (Dim x))
71
72mkM :: Matrix ℂ -> M m n
73mkM x = M (Dim (Dim x))
74
75--------------------------------------------------------------------------------
76
77type V n t = Dim n (Vector t)
78
79ud :: Dim n (Vector t) -> Vector t
80ud (Dim v) = v
81
82mkV :: forall (n :: Nat) t . t -> Dim n t
83mkV = Dim
84
85
86vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t)
87 => V n t -> V m t -> V (n+m) t
88(ud -> u) `vconcat` (ud -> v) = mkV (vjoin [u', v'])
89 where
90 du = fromIntegral . natVal $ (undefined :: Proxy n)
91 dv = fromIntegral . natVal $ (undefined :: Proxy m)
92 u' | du > 1 && LA.size u == 1 = LA.konst (u D.@> 0) du
93 | otherwise = u
94 v' | dv > 1 && LA.size v == 1 = LA.konst (v D.@> 0) dv
95 | otherwise = v
96
97
98gvec2 :: Storable t => t -> t -> V 2 t
99gvec2 a b = mkV $ runSTVector $ do
100 v <- newUndefinedVector 2
101 writeVector v 0 a
102 writeVector v 1 b
103 return v
104
105gvec3 :: Storable t => t -> t -> t -> V 3 t
106gvec3 a b c = mkV $ runSTVector $ do
107 v <- newUndefinedVector 3
108 writeVector v 0 a
109 writeVector v 1 b
110 writeVector v 2 c
111 return v
112
113
114gvec4 :: Storable t => t -> t -> t -> t -> V 4 t
115gvec4 a b c d = mkV $ runSTVector $ do
116 v <- newUndefinedVector 4
117 writeVector v 0 a
118 writeVector v 1 b
119 writeVector v 2 c
120 writeVector v 3 d
121 return v
122
123
124gvect :: forall n t . (Show t, KnownNat n, Numeric t) => String -> [t] -> V n t
125gvect st xs'
126 | ok = mkV v
127 | not (null rest) && null (tail rest) = abort (show xs')
128 | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]")
129 | otherwise = abort (show xs)
130 where
131 (xs,rest) = splitAt d xs'
132 ok = LA.size v == d && null rest
133 v = LA.fromList xs
134 d = fromIntegral . natVal $ (undefined :: Proxy n)
135 abort info = error $ st++" "++show d++" can't be created from elements "++info
136
137
138--------------------------------------------------------------------------------
139
140type GM m n t = Dim m (Dim n (Matrix t))
141
142
143gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> GM m n t
144gmat st xs'
145 | ok = Dim (Dim x)
146 | not (null rest) && null (tail rest) = abort (show xs')
147 | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]")
148 | otherwise = abort (show xs)
149 where
150 (xs,rest) = splitAt (m'*n') xs'
151 v = LA.fromList xs
152 x = reshape n' v
153 ok = null rest && ((n' == 0 && dim v == 0) || n'> 0 && (rem (LA.size v) n' == 0) && LA.size x == (m',n'))
154 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
155 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
156 abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info
157
158--------------------------------------------------------------------------------
159
160class Num t => Sized t s d | s -> t, s -> d
161 where
162 konst :: t -> s
163 unwrap :: s -> d t
164 fromList :: [t] -> s
165 extract :: s -> d t
166 create :: d t -> Maybe s
167 size :: s -> IndexOf d
168
169singleV v = LA.size v == 1
170singleM m = rows m == 1 && cols m == 1
171
172
173instance forall n. KnownNat n => Sized ℂ (C n) Vector
174 where
175 size _ = fromIntegral . natVal $ (undefined :: Proxy n)
176 konst x = mkC (LA.scalar x)
177 unwrap (C (Dim v)) = v
178 fromList xs = C (gvect "C" xs)
179 extract s@(unwrap -> v)
180 | singleV v = LA.konst (v!0) (size s)
181 | otherwise = v
182 create v
183 | LA.size v == size r = Just r
184 | otherwise = Nothing
185 where
186 r = mkC v :: C n
187
188
189instance forall n. KnownNat n => Sized ℝ (R n) Vector
190 where
191 size _ = fromIntegral . natVal $ (undefined :: Proxy n)
192 konst x = mkR (LA.scalar x)
193 unwrap (R (Dim v)) = v
194 fromList xs = R (gvect "R" xs)
195 extract s@(unwrap -> v)
196 | singleV v = LA.konst (v!0) (size s)
197 | otherwise = v
198 create v
199 | LA.size v == size r = Just r
200 | otherwise = Nothing
201 where
202 r = mkR v :: R n
203
204
205
206instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) Matrix
207 where
208 size _ = ((fromIntegral . natVal) (undefined :: Proxy m)
209 ,(fromIntegral . natVal) (undefined :: Proxy n))
210 konst x = mkL (LA.scalar x)
211 fromList xs = L (gmat "L" xs)
212 unwrap (L (Dim (Dim m))) = m
213 extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n'
214 extract s@(unwrap -> a)
215 | singleM a = LA.konst (a `atIndex` (0,0)) (size s)
216 | otherwise = a
217 create x
218 | LA.size x == size r = Just r
219 | otherwise = Nothing
220 where
221 r = mkL x :: L m n
222
223
224instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) Matrix
225 where
226 size _ = ((fromIntegral . natVal) (undefined :: Proxy m)
227 ,(fromIntegral . natVal) (undefined :: Proxy n))
228 konst x = mkM (LA.scalar x)
229 fromList xs = M (gmat "M" xs)
230 unwrap (M (Dim (Dim m))) = m
231 extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n'
232 extract s@(unwrap -> a)
233 | singleM a = LA.konst (a `atIndex` (0,0)) (size s)
234 | otherwise = a
235 create x
236 | LA.size x == size r = Just r
237 | otherwise = Nothing
238 where
239 r = mkM x :: M m n
240
241--------------------------------------------------------------------------------
242
243instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m)
244 where
245 tr a@(isDiag -> Just _) = mkL (extract a)
246 tr (extract -> a) = mkL (tr a)
247 tr' = tr
248
249instance (KnownNat n, KnownNat m) => Transposable (M m n) (M n m)
250 where
251 tr a@(isDiagC -> Just _) = mkM (extract a)
252 tr (extract -> a) = mkM (tr a)
253 tr' a@(isDiagC -> Just _) = mkM (extract a)
254 tr' (extract -> a) = mkM (tr' a)
255
256--------------------------------------------------------------------------------
257
258isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int))
259isDiag (L x) = isDiagg x
260
261isDiagC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ, Vector ℂ, (Int,Int))
262isDiagC (M x) = isDiagg x
263
264
265isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> Maybe (t, Vector t, (Int,Int))
266isDiagg (Dim (Dim x))
267 | singleM x = Nothing
268 | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n'))
269 | otherwise = Nothing
270 where
271 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
272 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
273 v = flatten x
274 z = v `atIndex` 0
275 y = subVector 1 (LA.size v-1) v
276 ny = LA.size y
277 zeros = LA.konst 0 (max 0 (min m' n' - ny))
278 yz = vjoin [y,zeros]
279
280--------------------------------------------------------------------------------
281
282instance forall n . KnownNat n => Show (R n)
283 where
284 show s@(R (Dim v))
285 | singleV v = "("++show (v!0)++" :: R "++show d++")"
286 | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")"
287 where
288 d = size s
289
290instance forall n . KnownNat n => Show (C n)
291 where
292 show s@(C (Dim v))
293 | singleV v = "("++show (v!0)++" :: C "++show d++")"
294 | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")"
295 where
296 d = size s
297
298instance forall m n . (KnownNat m, KnownNat n) => Show (L m n)
299 where
300 show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n'
301 show s@(L (Dim (Dim x)))
302 | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n'
303 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")"
304 where
305 (m',n') = size s
306
307instance forall m n . (KnownNat m, KnownNat n) => Show (M m n)
308 where
309 show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n'
310 show s@(M (Dim (Dim x)))
311 | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n'
312 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")"
313 where
314 (m',n') = size s
315
316--------------------------------------------------------------------------------
317
318instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t))
319 where
320 (+) = lift2F (+)
321 (*) = lift2F (*)
322 (-) = lift2F (-)
323 abs = lift1F abs
324 signum = lift1F signum
325 negate = lift1F negate
326 fromInteger x = Dim (fromInteger x)
327
328instance (Num (Vector t), Num (Matrix t), Fractional t, Numeric t) => Fractional (Dim n (Vector t))
329 where
330 fromRational x = Dim (fromRational x)
331 (/) = lift2F (/)
332
333instance (Fractional t, Floating (Vector t), Numeric t) => Floating (Dim n (Vector t)) where
334 sin = lift1F sin
335 cos = lift1F cos
336 tan = lift1F tan
337 asin = lift1F asin
338 acos = lift1F acos
339 atan = lift1F atan
340 sinh = lift1F sinh
341 cosh = lift1F cosh
342 tanh = lift1F tanh
343 asinh = lift1F asinh
344 acosh = lift1F acosh
345 atanh = lift1F atanh
346 exp = lift1F exp
347 log = lift1F log
348 sqrt = lift1F sqrt
349 (**) = lift2F (**)
350 pi = Dim pi
351
352
353instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t)))
354 where
355 (+) = (lift2F . lift2F) (+)
356 (*) = (lift2F . lift2F) (*)
357 (-) = (lift2F . lift2F) (-)
358 abs = (lift1F . lift1F) abs
359 signum = (lift1F . lift1F) signum
360 negate = (lift1F . lift1F) negate
361 fromInteger x = Dim (Dim (fromInteger x))
362
363instance (Num (Vector t), Num (Matrix t), Fractional t, Numeric t) => Fractional (Dim m (Dim n (Matrix t)))
364 where
365 fromRational x = Dim (Dim (fromRational x))
366 (/) = (lift2F.lift2F) (/)
367
368instance (Num (Vector t), Floating (Matrix t), Fractional t, Numeric t) => Floating (Dim m (Dim n (Matrix t))) where
369 sin = (lift1F . lift1F) sin
370 cos = (lift1F . lift1F) cos
371 tan = (lift1F . lift1F) tan
372 asin = (lift1F . lift1F) asin
373 acos = (lift1F . lift1F) acos
374 atan = (lift1F . lift1F) atan
375 sinh = (lift1F . lift1F) sinh
376 cosh = (lift1F . lift1F) cosh
377 tanh = (lift1F . lift1F) tanh
378 asinh = (lift1F . lift1F) asinh
379 acosh = (lift1F . lift1F) acosh
380 atanh = (lift1F . lift1F) atanh
381 exp = (lift1F . lift1F) exp
382 log = (lift1F . lift1F) log
383 sqrt = (lift1F . lift1F) sqrt
384 (**) = (lift2F . lift2F) (**)
385 pi = Dim (Dim pi)
386
387--------------------------------------------------------------------------------
388
389
390adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b
391adaptDiag f a b@(isDiag -> Just _) | isFull a = f a (mkL (extract b))
392adaptDiag f a b = f a b
393
394isFull m = isDiag m == Nothing && not (singleM (unwrap m))
395
396
397lift1L f (L v) = L (f v)
398lift2L f (L a) (L b) = L (f a b)
399lift2LD f = adaptDiag (lift2L f)
400
401
402instance (KnownNat n, KnownNat m) => Num (L n m)
403 where
404 (+) = lift2LD (+)
405 (*) = lift2LD (*)
406 (-) = lift2LD (-)
407 abs = lift1L abs
408 signum = lift1L signum
409 negate = lift1L negate
410 fromInteger = L . Dim . Dim . fromInteger
411
412instance (KnownNat n, KnownNat m) => Fractional (L n m)
413 where
414 fromRational = L . Dim . Dim . fromRational
415 (/) = lift2LD (/)
416
417instance (KnownNat n, KnownNat m) => Floating (L n m) where
418 sin = lift1L sin
419 cos = lift1L cos
420 tan = lift1L tan
421 asin = lift1L asin
422 acos = lift1L acos
423 atan = lift1L atan
424 sinh = lift1L sinh
425 cosh = lift1L cosh
426 tanh = lift1L tanh
427 asinh = lift1L asinh
428 acosh = lift1L acosh
429 atanh = lift1L atanh
430 exp = lift1L exp
431 log = lift1L log
432 sqrt = lift1L sqrt
433 (**) = lift2LD (**)
434 pi = konst pi
435
436--------------------------------------------------------------------------------
437
438adaptDiagC f a@(isDiagC -> Just _) b | isFullC b = f (mkM (extract a)) b
439adaptDiagC f a b@(isDiagC -> Just _) | isFullC a = f a (mkM (extract b))
440adaptDiagC f a b = f a b
441
442isFullC m = isDiagC m == Nothing && not (singleM (unwrap m))
443
444lift1M f (M v) = M (f v)
445lift2M f (M a) (M b) = M (f a b)
446lift2MD f = adaptDiagC (lift2M f)
447
448instance (KnownNat n, KnownNat m) => Num (M n m)
449 where
450 (+) = lift2MD (+)
451 (*) = lift2MD (*)
452 (-) = lift2MD (-)
453 abs = lift1M abs
454 signum = lift1M signum
455 negate = lift1M negate
456 fromInteger = M . Dim . Dim . fromInteger
457
458instance (KnownNat n, KnownNat m) => Fractional (M n m)
459 where
460 fromRational = M . Dim . Dim . fromRational
461 (/) = lift2MD (/)
462
463instance (KnownNat n, KnownNat m) => Floating (M n m) where
464 sin = lift1M sin
465 cos = lift1M cos
466 tan = lift1M tan
467 asin = lift1M asin
468 acos = lift1M acos
469 atan = lift1M atan
470 sinh = lift1M sinh
471 cosh = lift1M cosh
472 tanh = lift1M tanh
473 asinh = lift1M asinh
474 acosh = lift1M acosh
475 atanh = lift1M atanh
476 exp = lift1M exp
477 log = lift1M log
478 sqrt = lift1M sqrt
479 (**) = lift2MD (**)
480 pi = M pi
481
482--------------------------------------------------------------------------------
483
484
485class Disp t
486 where
487 disp :: Int -> t -> IO ()
488
489
490instance (KnownNat m, KnownNat n) => Disp (L m n)
491 where
492 disp n x = do
493 let a = extract x
494 let su = LA.dispf n a
495 printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
496
497instance (KnownNat m, KnownNat n) => Disp (M m n)
498 where
499 disp n x = do
500 let a = extract x
501 let su = LA.dispcf n a
502 printf "M %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
503
504
505instance KnownNat n => Disp (R n)
506 where
507 disp n v = do
508 let su = LA.dispf n (asRow $ extract v)
509 putStr "R " >> putStr (tail . dropWhile (/='x') $ su)
510
511instance KnownNat n => Disp (C n)
512 where
513 disp n v = do
514 let su = LA.dispcf n (asRow $ extract v)
515 putStr "C " >> putStr (tail . dropWhile (/='x') $ su)
516
517--------------------------------------------------------------------------------
518
519#else
520
521module Numeric.LinearAlgebra.Static.Internal where
522
523#endif
524