summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Static
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Static')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs422
1 files changed, 422 insertions, 0 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs
new file mode 100644
index 0000000..c9641d5
--- /dev/null
+++ b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs
@@ -0,0 +1,422 @@
1{-# LANGUAGE DataKinds #-}
2{-# LANGUAGE KindSignatures #-}
3{-# LANGUAGE GeneralizedNewtypeDeriving #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE FlexibleContexts #-}
7{-# LANGUAGE ScopedTypeVariables #-}
8{-# LANGUAGE EmptyDataDecls #-}
9{-# LANGUAGE Rank2Types #-}
10{-# LANGUAGE FlexibleInstances #-}
11{-# LANGUAGE TypeOperators #-}
12{-# LANGUAGE ViewPatterns #-}
13{-# LANGUAGE GADTs #-}
14
15
16{- |
17Module : Numeric.LinearAlgebra.Static.Internal
18Copyright : (c) Alberto Ruiz 2006-14
19License : BSD3
20Stability : provisional
21
22-}
23
24module Numeric.LinearAlgebra.Static.Internal where
25
26
27import GHC.TypeLits
28import Numeric.LinearAlgebra.HMatrix as LA
29import Data.Packed as D
30import Data.Packed.ST
31import Data.Proxy(Proxy)
32import Foreign.Storable(Storable)
33import Text.Printf
34
35--------------------------------------------------------------------------------
36
37newtype Dim (n :: Nat) t = Dim t
38 deriving Show
39
40lift1F
41 :: (c t -> c t)
42 -> Dim n (c t) -> Dim n (c t)
43lift1F f (Dim v) = Dim (f v)
44
45lift2F
46 :: (c t -> c t -> c t)
47 -> Dim n (c t) -> Dim n (c t) -> Dim n (c t)
48lift2F f (Dim u) (Dim v) = Dim (f u v)
49
50--------------------------------------------------------------------------------
51
52newtype R n = R (Dim n (Vector ℝ))
53 deriving (Num,Fractional)
54
55newtype C n = C (Dim n (Vector ℂ))
56 deriving (Num,Fractional)
57
58newtype L m n = L (Dim m (Dim n (Matrix ℝ)))
59
60newtype M m n = M (Dim m (Dim n (Matrix ℂ)))
61
62
63mkR :: Vector ℝ -> R n
64mkR = R . Dim
65
66mkC :: Vector ℂ -> C n
67mkC = C . Dim
68
69mkL :: Matrix ℝ -> L m n
70mkL x = L (Dim (Dim x))
71
72mkM :: Matrix ℂ -> M m n
73mkM x = M (Dim (Dim x))
74
75--------------------------------------------------------------------------------
76
77type V n t = Dim n (Vector t)
78
79ud :: Dim n (Vector t) -> Vector t
80ud (Dim v) = v
81
82mkV :: forall (n :: Nat) t . t -> Dim n t
83mkV = Dim
84
85
86vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t)
87 => V n t -> V m t -> V (n+m) t
88(ud -> u) `vconcat` (ud -> v) = mkV (vjoin [u', v'])
89 where
90 du = fromIntegral . natVal $ (undefined :: Proxy n)
91 dv = fromIntegral . natVal $ (undefined :: Proxy m)
92 u' | du > 1 && size u == 1 = LA.konst (u D.@> 0) du
93 | otherwise = u
94 v' | dv > 1 && size v == 1 = LA.konst (v D.@> 0) dv
95 | otherwise = v
96
97
98gvec2 :: Storable t => t -> t -> V 2 t
99gvec2 a b = mkV $ runSTVector $ do
100 v <- newUndefinedVector 2
101 writeVector v 0 a
102 writeVector v 1 b
103 return v
104
105gvec3 :: Storable t => t -> t -> t -> V 3 t
106gvec3 a b c = mkV $ runSTVector $ do
107 v <- newUndefinedVector 3
108 writeVector v 0 a
109 writeVector v 1 b
110 writeVector v 2 c
111 return v
112
113
114gvec4 :: Storable t => t -> t -> t -> t -> V 4 t
115gvec4 a b c d = mkV $ runSTVector $ do
116 v <- newUndefinedVector 4
117 writeVector v 0 a
118 writeVector v 1 b
119 writeVector v 2 c
120 writeVector v 3 d
121 return v
122
123
124gvect :: forall n t . (Show t, KnownNat n, Numeric t) => String -> [t] -> V n t
125gvect st xs'
126 | ok = mkV v
127 | not (null rest) && null (tail rest) = abort (show xs')
128 | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]")
129 | otherwise = abort (show xs)
130 where
131 (xs,rest) = splitAt d xs'
132 ok = size v == d && null rest
133 v = LA.fromList xs
134 d = fromIntegral . natVal $ (undefined :: Proxy n)
135 abort info = error $ st++" "++show d++" can't be created from elements "++info
136
137
138--------------------------------------------------------------------------------
139
140type GM m n t = Dim m (Dim n (Matrix t))
141
142
143gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> GM m n t
144gmat st xs'
145 | ok = Dim (Dim x)
146 | not (null rest) && null (tail rest) = abort (show xs')
147 | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]")
148 | otherwise = abort (show xs)
149 where
150 (xs,rest) = splitAt (m'*n') xs'
151 v = LA.fromList xs
152 x = reshape n' v
153 ok = rem (size v) n' == 0 && size x == (m',n') && null rest
154 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
155 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
156 abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info
157
158--------------------------------------------------------------------------------
159
160class Num t => Sized t s d | s -> t, s -> d
161 where
162 konst :: t -> s
163 unwrap :: s -> d
164 fromList :: [t] -> s
165 extract :: s -> d
166
167singleV v = size v == 1
168singleM m = rows m == 1 && cols m == 1
169
170
171instance forall n. KnownNat n => Sized ℂ (C n) (Vector ℂ)
172 where
173 konst x = mkC (LA.scalar x)
174 unwrap (C (Dim v)) = v
175 fromList xs = C (gvect "C" xs)
176 extract (unwrap -> v)
177 | singleV v = LA.konst (v!0) d
178 | otherwise = v
179 where
180 d = fromIntegral . natVal $ (undefined :: Proxy n)
181
182
183instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ)
184 where
185 konst x = mkR (LA.scalar x)
186 unwrap (R (Dim v)) = v
187 fromList xs = R (gvect "R" xs)
188 extract (unwrap -> v)
189 | singleV v = LA.konst (v!0) d
190 | otherwise = v
191 where
192 d = fromIntegral . natVal $ (undefined :: Proxy n)
193
194
195
196instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ)
197 where
198 konst x = mkL (LA.scalar x)
199 fromList xs = L (gmat "L" xs)
200 unwrap (L (Dim (Dim m))) = m
201 extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n'
202 extract (unwrap -> a)
203 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n')
204 | otherwise = a
205 where
206 m' = fromIntegral . natVal $ (undefined :: Proxy m)
207 n' = fromIntegral . natVal $ (undefined :: Proxy n)
208
209
210instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) (Matrix ℂ)
211 where
212 konst x = mkM (LA.scalar x)
213 fromList xs = M (gmat "M" xs)
214 unwrap (M (Dim (Dim m))) = m
215 extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n'
216 extract (unwrap -> a)
217 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n')
218 | otherwise = a
219 where
220 m' = fromIntegral . natVal $ (undefined :: Proxy m)
221 n' = fromIntegral . natVal $ (undefined :: Proxy n)
222
223--------------------------------------------------------------------------------
224
225instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m)
226 where
227 tr a@(isDiag -> Just _) = mkL (extract a)
228 tr (extract -> a) = mkL (tr a)
229
230instance (KnownNat n, KnownNat m) => Transposable (M m n) (M n m)
231 where
232 tr a@(isDiagC -> Just _) = mkM (extract a)
233 tr (extract -> a) = mkM (tr a)
234
235--------------------------------------------------------------------------------
236
237isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int))
238isDiag (L x) = isDiagg x
239
240isDiagC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ, Vector ℂ, (Int,Int))
241isDiagC (M x) = isDiagg x
242
243
244isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> Maybe (t, Vector t, (Int,Int))
245isDiagg (Dim (Dim x))
246 | singleM x = Nothing
247 | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n'))
248 | otherwise = Nothing
249 where
250 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
251 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
252 v = flatten x
253 z = v `atIndex` 0
254 y = subVector 1 (size v-1) v
255 ny = size y
256 zeros = LA.konst 0 (max 0 (min m' n' - ny))
257 yz = vjoin [y,zeros]
258
259--------------------------------------------------------------------------------
260
261instance forall n . KnownNat n => Show (R n)
262 where
263 show (R (Dim v))
264 | singleV v = "("++show (v!0)++" :: R "++show d++")"
265 | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")"
266 where
267 d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
268
269instance forall n . KnownNat n => Show (C n)
270 where
271 show (C (Dim v))
272 | singleV v = "("++show (v!0)++" :: C "++show d++")"
273 | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")"
274 where
275 d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
276
277instance forall m n . (KnownNat m, KnownNat n) => Show (L m n)
278 where
279 show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n'
280 show (L (Dim (Dim x)))
281 | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n'
282 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")"
283 where
284 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
285 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
286
287instance forall m n . (KnownNat m, KnownNat n) => Show (M m n)
288 where
289 show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n'
290 show (M (Dim (Dim x)))
291 | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n'
292 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")"
293 where
294 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
295 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
296
297--------------------------------------------------------------------------------
298
299instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t))
300 where
301 (+) = lift2F (+)
302 (*) = lift2F (*)
303 (-) = lift2F (-)
304 abs = lift1F abs
305 signum = lift1F signum
306 negate = lift1F negate
307 fromInteger x = Dim (fromInteger x)
308
309instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim n (Vector t))
310 where
311 fromRational x = Dim (fromRational x)
312 (/) = lift2F (/)
313
314
315instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t)))
316 where
317 (+) = (lift2F . lift2F) (+)
318 (*) = (lift2F . lift2F) (*)
319 (-) = (lift2F . lift2F) (-)
320 abs = (lift1F . lift1F) abs
321 signum = (lift1F . lift1F) signum
322 negate = (lift1F . lift1F) negate
323 fromInteger x = Dim (Dim (fromInteger x))
324
325instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim m (Dim n (Matrix t)))
326 where
327 fromRational x = Dim (Dim (fromRational x))
328 (/) = (lift2F.lift2F) (/)
329
330--------------------------------------------------------------------------------
331
332
333adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b
334adaptDiag f a b@(isDiag -> Just _) | isFull a = f a (mkL (extract b))
335adaptDiag f a b = f a b
336
337isFull m = isDiag m == Nothing && not (singleM (unwrap m))
338
339
340lift1L f (L v) = L (f v)
341lift2L f (L a) (L b) = L (f a b)
342lift2LD f = adaptDiag (lift2L f)
343
344
345instance (KnownNat n, KnownNat m) => Num (L n m)
346 where
347 (+) = lift2LD (+)
348 (*) = lift2LD (*)
349 (-) = lift2LD (-)
350 abs = lift1L abs
351 signum = lift1L signum
352 negate = lift1L negate
353 fromInteger = L . Dim . Dim . fromInteger
354
355instance (KnownNat n, KnownNat m) => Fractional (L n m)
356 where
357 fromRational = L . Dim . Dim . fromRational
358 (/) = lift2LD (/)
359
360--------------------------------------------------------------------------------
361
362adaptDiagC f a@(isDiagC -> Just _) b | isFullC b = f (mkM (extract a)) b
363adaptDiagC f a b@(isDiagC -> Just _) | isFullC a = f a (mkM (extract b))
364adaptDiagC f a b = f a b
365
366isFullC m = isDiagC m == Nothing && not (singleM (unwrap m))
367
368lift1M f (M v) = M (f v)
369lift2M f (M a) (M b) = M (f a b)
370lift2MD f = adaptDiagC (lift2M f)
371
372instance (KnownNat n, KnownNat m) => Num (M n m)
373 where
374 (+) = lift2MD (+)
375 (*) = lift2MD (*)
376 (-) = lift2MD (-)
377 abs = lift1M abs
378 signum = lift1M signum
379 negate = lift1M negate
380 fromInteger = M . Dim . Dim . fromInteger
381
382instance (KnownNat n, KnownNat m) => Fractional (M n m)
383 where
384 fromRational = M . Dim . Dim . fromRational
385 (/) = lift2MD (/)
386
387--------------------------------------------------------------------------------
388
389
390class Disp t
391 where
392 disp :: Int -> t -> IO ()
393
394
395instance (KnownNat m, KnownNat n) => Disp (L m n)
396 where
397 disp n x = do
398 let a = extract x
399 let su = LA.dispf n a
400 printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
401
402instance (KnownNat m, KnownNat n) => Disp (M m n)
403 where
404 disp n x = do
405 let a = extract x
406 let su = LA.dispcf n a
407 printf "M %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
408
409
410instance KnownNat n => Disp (R n)
411 where
412 disp n v = do
413 let su = LA.dispf n (asRow $ extract v)
414 putStr "R " >> putStr (tail . dropWhile (/='x') $ su)
415
416instance KnownNat n => Disp (C n)
417 where
418 disp n v = do
419 let su = LA.dispcf n (asRow $ extract v)
420 putStr "C " >> putStr (tail . dropWhile (/='x') $ su)
421
422--------------------------------------------------------------------------------