diff options
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Real.hs')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Real.hs | 337 |
1 files changed, 337 insertions, 0 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs new file mode 100644 index 0000000..db15705 --- /dev/null +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs | |||
@@ -0,0 +1,337 @@ | |||
1 | {-# LANGUAGE DataKinds #-} | ||
2 | {-# LANGUAGE KindSignatures #-} | ||
3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
5 | {-# LANGUAGE FunctionalDependencies #-} | ||
6 | {-# LANGUAGE FlexibleContexts #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | {-# LANGUAGE EmptyDataDecls #-} | ||
9 | {-# LANGUAGE Rank2Types #-} | ||
10 | {-# LANGUAGE FlexibleInstances #-} | ||
11 | {-# LANGUAGE TypeOperators #-} | ||
12 | {-# LANGUAGE ViewPatterns #-} | ||
13 | {-# LANGUAGE GADTs #-} | ||
14 | |||
15 | |||
16 | {- | | ||
17 | Module : Numeric.LinearAlgebra.Real | ||
18 | Copyright : (c) Alberto Ruiz 2006-14 | ||
19 | License : BSD3 | ||
20 | Stability : provisional | ||
21 | |||
22 | Experimental interface for real arrays with statically checked dimensions. | ||
23 | |||
24 | -} | ||
25 | |||
26 | module Numeric.LinearAlgebra.Real( | ||
27 | -- * Vector | ||
28 | R, | ||
29 | vec2, vec3, vec4, ๐ง, (&), | ||
30 | -- * Matrix | ||
31 | L, Sq, | ||
32 | ๐, | ||
33 | (#),(ยฆ),(โโ), | ||
34 | Konst(..), | ||
35 | eye, | ||
36 | diagR, diag, | ||
37 | blockAt, | ||
38 | -- * Products | ||
39 | (<>),(#>),(<ยท>), | ||
40 | -- * Pretty printing | ||
41 | Disp(..), | ||
42 | -- * Misc | ||
43 | Dim, unDim, | ||
44 | module Numeric.HMatrix | ||
45 | ) where | ||
46 | |||
47 | |||
48 | import GHC.TypeLits | ||
49 | import Numeric.HMatrix hiding ((<>),(#>),(<ยท>),Konst(..),diag, disp,(ยฆ),(โโ)) | ||
50 | import qualified Numeric.HMatrix as LA | ||
51 | import Data.Packed.ST | ||
52 | |||
53 | newtype Dim (n :: Nat) t = Dim t | ||
54 | deriving Show | ||
55 | |||
56 | unDim :: Dim n t -> t | ||
57 | unDim (Dim x) = x | ||
58 | |||
59 | data Proxy :: Nat -> * | ||
60 | |||
61 | |||
62 | lift1F | ||
63 | :: (c t -> c t) | ||
64 | -> Dim n (c t) -> Dim n (c t) | ||
65 | lift1F f (Dim v) = Dim (f v) | ||
66 | |||
67 | lift2F | ||
68 | :: (c t -> c t -> c t) | ||
69 | -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) | ||
70 | lift2F f (Dim u) (Dim v) = Dim (f u v) | ||
71 | |||
72 | |||
73 | |||
74 | type R n = Dim n (Vector โ) | ||
75 | |||
76 | type L m n = Dim m (Dim n (Matrix โ)) | ||
77 | |||
78 | |||
79 | infixl 4 & | ||
80 | (&) :: forall n . KnownNat n | ||
81 | => R n -> โ -> R (n+1) | ||
82 | Dim v & x = Dim (vjoin [v', scalar x]) | ||
83 | where | ||
84 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
85 | v' | d > 1 && size v == 1 = LA.konst (v!0) d | ||
86 | | otherwise = v | ||
87 | |||
88 | |||
89 | -- vect0 :: R 0 | ||
90 | -- vect0 = Dim (fromList[]) | ||
91 | |||
92 | ๐ง :: โ -> R 1 | ||
93 | ๐ง = Dim . scalar | ||
94 | |||
95 | |||
96 | vec2 :: โ -> โ -> R 2 | ||
97 | vec2 a b = Dim $ runSTVector $ do | ||
98 | v <- newUndefinedVector 2 | ||
99 | writeVector v 0 a | ||
100 | writeVector v 1 b | ||
101 | return v | ||
102 | |||
103 | vec3 :: โ -> โ -> โ -> R 3 | ||
104 | vec3 a b c = Dim $ runSTVector $ do | ||
105 | v <- newUndefinedVector 3 | ||
106 | writeVector v 0 a | ||
107 | writeVector v 1 b | ||
108 | writeVector v 2 c | ||
109 | return v | ||
110 | |||
111 | |||
112 | vec4 :: โ -> โ -> โ -> โ -> R 4 | ||
113 | vec4 a b c d = Dim $ runSTVector $ do | ||
114 | v <- newUndefinedVector 4 | ||
115 | writeVector v 0 a | ||
116 | writeVector v 1 b | ||
117 | writeVector v 2 c | ||
118 | writeVector v 3 d | ||
119 | return v | ||
120 | |||
121 | |||
122 | |||
123 | |||
124 | instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) | ||
125 | where | ||
126 | (+) = lift2F (+) | ||
127 | (*) = lift2F (*) | ||
128 | (-) = lift2F (-) | ||
129 | abs = lift1F abs | ||
130 | signum = lift1F signum | ||
131 | negate = lift1F negate | ||
132 | fromInteger x = Dim (fromInteger x) | ||
133 | |||
134 | instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) | ||
135 | where | ||
136 | (+) = (lift2F . lift2F) (+) | ||
137 | (*) = (lift2F . lift2F) (*) | ||
138 | (-) = (lift2F . lift2F) (-) | ||
139 | abs = (lift1F . lift1F) abs | ||
140 | signum = (lift1F . lift1F) signum | ||
141 | negate = (lift1F . lift1F) negate | ||
142 | fromInteger x = Dim (Dim (fromInteger x)) | ||
143 | |||
144 | -------------------------------------------------------------------------------- | ||
145 | |||
146 | class Konst t | ||
147 | where | ||
148 | konst :: โ -> t | ||
149 | |||
150 | instance forall n. KnownNat n => Konst (R n) | ||
151 | where | ||
152 | konst x = Dim (LA.konst x d) | ||
153 | where | ||
154 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
155 | |||
156 | instance forall m n . (KnownNat m, KnownNat n) => Konst (L m n) | ||
157 | where | ||
158 | konst x = Dim (Dim (LA.konst x (m',n'))) | ||
159 | where | ||
160 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
161 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
162 | |||
163 | -------------------------------------------------------------------------------- | ||
164 | |||
165 | diagR :: forall m n k . (KnownNat m, KnownNat n) => โ -> R k -> L m n | ||
166 | diagR x v = Dim (Dim (diagRect x (unDim v) m' n')) | ||
167 | where | ||
168 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
169 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
170 | |||
171 | diag :: KnownNat n => R n -> Sq n | ||
172 | diag = diagR 0 | ||
173 | |||
174 | -------------------------------------------------------------------------------- | ||
175 | |||
176 | blockAt :: forall m n . (KnownNat m, KnownNat n) => โ -> Int -> Int -> Matrix Double -> L m n | ||
177 | blockAt x r c a = Dim (Dim res) | ||
178 | where | ||
179 | z = scalar x | ||
180 | z1 = LA.konst x (r,c) | ||
181 | z2 = LA.konst x (max 0 (m'-(ra+r)), max 0 (n'-(ca+c))) | ||
182 | ra = min (rows a) . max 0 $ m'-r | ||
183 | ca = min (cols a) . max 0 $ n'-c | ||
184 | sa = subMatrix (0,0) (ra, ca) a | ||
185 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
186 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
187 | res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] | ||
188 | |||
189 | {- | ||
190 | matrix :: (KnownNat m, KnownNat n) => Matrix Double -> L n m | ||
191 | matrix = blockAt 0 0 0 | ||
192 | -} | ||
193 | |||
194 | -------------------------------------------------------------------------------- | ||
195 | |||
196 | class Disp t | ||
197 | where | ||
198 | disp :: Int -> t -> IO () | ||
199 | |||
200 | instance Disp (L n m) | ||
201 | where | ||
202 | disp n (d2 -> a) = do | ||
203 | if rows a == 1 && cols a == 1 | ||
204 | then putStrLn $ "Const " ++ (last . words . LA.dispf n $ a) | ||
205 | else putStr "Dim " >> LA.disp n a | ||
206 | |||
207 | instance Disp (R n) | ||
208 | where | ||
209 | disp n (unDim -> v) = do | ||
210 | let su = LA.dispf n (asRow v) | ||
211 | if LA.size v == 1 | ||
212 | then putStrLn $ "Const " ++ (last . words $ su ) | ||
213 | else putStr "Dim " >> putStr (tail . dropWhile (/='x') $ su) | ||
214 | |||
215 | -------------------------------------------------------------------------------- | ||
216 | |||
217 | infixl 3 # | ||
218 | (#) :: L r c -> R c -> L (r+1) c | ||
219 | Dim (Dim m) # Dim v = Dim (Dim (m LA.โโ asRow v)) | ||
220 | |||
221 | |||
222 | ๐ :: forall n . KnownNat n => L 0 n | ||
223 | ๐ = Dim (Dim (LA.konst 0 (0,d))) | ||
224 | where | ||
225 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
226 | |||
227 | infixl 3 ยฆ | ||
228 | (ยฆ) :: L r c1 -> L r c2 -> L r (c1+c2) | ||
229 | Dim (Dim a) ยฆ Dim (Dim b) = Dim (Dim (a LA.ยฆ b)) | ||
230 | |||
231 | infixl 2 โโ | ||
232 | (โโ) :: L r1 c -> L r2 c -> L (r1+r2) c | ||
233 | Dim (Dim a) โโ Dim (Dim b) = Dim (Dim (a LA.โโ b)) | ||
234 | |||
235 | |||
236 | {- | ||
237 | |||
238 | -} | ||
239 | |||
240 | type Sq n = L n n | ||
241 | |||
242 | type GL = (KnownNat n, KnownNat m) => L m n | ||
243 | type GSq = KnownNat n => Sq n | ||
244 | |||
245 | infixr 8 <> | ||
246 | (<>) :: L m k -> L k n -> L m n | ||
247 | (d2 -> a) <> (d2 -> b) = Dim (Dim (a LA.<> b)) | ||
248 | |||
249 | infixr 8 #> | ||
250 | (#>) :: L m n -> R n -> R m | ||
251 | (d2 -> m) #> (unDim -> v) = Dim (m LA.#> v) | ||
252 | |||
253 | infixr 8 <ยท> | ||
254 | (<ยท>) :: R n -> R n -> โ | ||
255 | (unDim -> u) <ยท> (unDim -> v) = udot u v | ||
256 | |||
257 | |||
258 | d2 :: forall c (n :: Nat) (n1 :: Nat). Dim n1 (Dim n c) -> c | ||
259 | d2 = unDim . unDim | ||
260 | |||
261 | |||
262 | instance Transposable (L m n) (L n m) | ||
263 | where | ||
264 | tr (Dim (Dim a)) = Dim (Dim (tr a)) | ||
265 | |||
266 | |||
267 | eye :: forall n . KnownNat n => Sq n | ||
268 | eye = Dim (Dim (ident d)) | ||
269 | where | ||
270 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
271 | |||
272 | |||
273 | -------------------------------------------------------------------------------- | ||
274 | |||
275 | test :: (Bool, IO ()) | ||
276 | test = (ok,info) | ||
277 | where | ||
278 | ok = d2 (eye :: Sq 5) == ident 5 | ||
279 | && d2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] | ||
280 | && d2 (tm :: L 3 5) == mat 5 [1..15] | ||
281 | && thingS == thingD | ||
282 | && precS == precD | ||
283 | |||
284 | info = do | ||
285 | print $ u | ||
286 | print $ v | ||
287 | print (eye :: Sq 3) | ||
288 | print $ ((u & 5) + 1) <ยท> v | ||
289 | print (tm :: L 2 5) | ||
290 | print (tm <> sm :: L 2 3) | ||
291 | print thingS | ||
292 | print thingD | ||
293 | print precS | ||
294 | print precD | ||
295 | |||
296 | u = vec2 3 5 | ||
297 | |||
298 | v = ๐ง 2 & 4 & 7 | ||
299 | |||
300 | mTm :: L n m -> Sq m | ||
301 | mTm a = tr a <> a | ||
302 | |||
303 | tm :: GL | ||
304 | tm = lmat 0 [1..] | ||
305 | |||
306 | lmat :: forall m n . (KnownNat m, KnownNat n) => โ -> [โ] -> L m n | ||
307 | lmat z xs = Dim . Dim . reshape n' . fromList . take (m'*n') $ xs ++ repeat z | ||
308 | where | ||
309 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
310 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
311 | |||
312 | sm :: GSq | ||
313 | sm = lmat 0 [1..] | ||
314 | |||
315 | thingS = (u & 1) <ยท> tr q #> q #> v | ||
316 | where | ||
317 | q = tm :: L 10 3 | ||
318 | |||
319 | thingD = vjoin [unDim u, 1] LA.<ยท> tr m LA.#> m LA.#> unDim v | ||
320 | where | ||
321 | m = mat 3 [1..30] | ||
322 | |||
323 | precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <ยท> konst 2 #> v | ||
324 | precD = 1 + 2 * vjoin[unDim u, 6] LA.<ยท> LA.konst 2 (size (unDim u) +1, size (unDim v)) LA.#> unDim v | ||
325 | |||
326 | |||
327 | instance (KnownNat n', KnownNat m') => Testable (L n' m') | ||
328 | where | ||
329 | checkT _ = test | ||
330 | |||
331 | {- | ||
332 | do (snd test) | ||
333 | fst test | ||
334 | -} | ||
335 | |||
336 | |||
337 | |||