summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Complex.hs80
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Real.hs395
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs193
3 files changed, 514 insertions, 154 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Complex.hs b/packages/base/src/Numeric/LinearAlgebra/Complex.hs
new file mode 100644
index 0000000..17bc397
--- /dev/null
+++ b/packages/base/src/Numeric/LinearAlgebra/Complex.hs
@@ -0,0 +1,80 @@
1{-# LANGUAGE DataKinds #-}
2{-# LANGUAGE KindSignatures #-}
3{-# LANGUAGE GeneralizedNewtypeDeriving #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE FlexibleContexts #-}
7{-# LANGUAGE ScopedTypeVariables #-}
8{-# LANGUAGE EmptyDataDecls #-}
9{-# LANGUAGE Rank2Types #-}
10{-# LANGUAGE FlexibleInstances #-}
11{-# LANGUAGE TypeOperators #-}
12{-# LANGUAGE ViewPatterns #-}
13{-# LANGUAGE GADTs #-}
14
15
16{- |
17Module : Numeric.LinearAlgebra.Complex
18Copyright : (c) Alberto Ruiz 2006-14
19License : BSD3
20Stability : experimental
21
22-}
23
24module Numeric.LinearAlgebra.Complex(
25 C,
26 vec2, vec3, vec4, (&), (#),
27 vect,
28 R
29) where
30
31import GHC.TypeLits
32import Numeric.HMatrix hiding (
33 (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace)
34import qualified Numeric.HMatrix as LA
35import Data.Proxy(Proxy)
36import Numeric.LinearAlgebra.Static
37
38
39
40instance forall n . KnownNat n => Show (C n)
41 where
42 show (ud1 -> v)
43 | size v == 1 = "("++show (v!0)++" :: C "++show d++")"
44 | otherwise = "(vect"++ drop 8 (show v)++" :: C "++show d++")"
45 where
46 d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
47
48
49ud1 :: C n -> Vector ℂ
50ud1 (C (Dim v)) = v
51
52mkC :: Vector ℂ -> C n
53mkC = C . Dim
54
55
56infixl 4 &
57(&) :: forall n . KnownNat n
58 => C n -> ℂ -> C (n+1)
59u & x = u # (mkC (LA.scalar x) :: C 1)
60
61infixl 4 #
62(#) :: forall n m . (KnownNat n, KnownNat m)
63 => C n -> C m -> C (n+m)
64(C u) # (C v) = C (vconcat u v)
65
66
67
68vec2 :: ℂ -> ℂ -> C 2
69vec2 a b = C (gvec2 a b)
70
71vec3 :: ℂ -> ℂ -> ℂ -> C 3
72vec3 a b c = C (gvec3 a b c)
73
74
75vec4 :: ℂ -> ℂ -> ℂ -> ℂ -> C 4
76vec4 a b c d = C (gvec4 a b c d)
77
78vect :: forall n . KnownNat n => [ℂ] -> C n
79vect xs = C (gvect "C" xs)
80
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs
index 5634031..424e766 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Real.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs
@@ -11,13 +11,15 @@
11{-# LANGUAGE TypeOperators #-} 11{-# LANGUAGE TypeOperators #-}
12{-# LANGUAGE ViewPatterns #-} 12{-# LANGUAGE ViewPatterns #-}
13{-# LANGUAGE GADTs #-} 13{-# LANGUAGE GADTs #-}
14{-# LANGUAGE OverlappingInstances #-}
15{-# LANGUAGE TypeFamilies #-}
14 16
15 17
16{- | 18{- |
17Module : Numeric.LinearAlgebra.Real 19Module : Numeric.LinearAlgebra.Real
18Copyright : (c) Alberto Ruiz 2006-14 20Copyright : (c) Alberto Ruiz 2006-14
19License : BSD3 21License : BSD3
20Stability : provisional 22Stability : experimental
21 23
22Experimental interface for real arrays with statically checked dimensions. 24Experimental interface for real arrays with statically checked dimensions.
23 25
@@ -26,165 +28,173 @@ Experimental interface for real arrays with statically checked dimensions.
26module Numeric.LinearAlgebra.Real( 28module Numeric.LinearAlgebra.Real(
27 -- * Vector 29 -- * Vector
28 R, 30 R,
29 vec2, vec3, vec4, 𝕧, (&), 31 vec2, vec3, vec4, (&), (#),
32 vect,
33 linspace, range, dim,
30 -- * Matrix 34 -- * Matrix
31 L, Sq, 35 L, Sq,
32 row, col, (¦),(——), 36 row, col, (¦),(——),
33 Konst(..), 37 unrow, uncol,
38 Sized(..),
34 eye, 39 eye,
35 diagR, diag, 40 diagR, diag, Diag(..),
36 blockAt, 41 blockAt,
42 mat,
37 -- * Products 43 -- * Products
38 (<>),(#>),(<·>), 44 (<>),(#>),(<·>),
45 -- * Linear Systems
46 linSolve, -- (<\>),
39 -- * Pretty printing 47 -- * Pretty printing
40 Disp(..), 48 Disp(..),
41 -- * Misc 49 -- * Misc
42 Dim, unDim, 50 C,
51 withVector, withMatrix,
43 module Numeric.HMatrix 52 module Numeric.HMatrix
44) where 53) where
45 54
46 55
47import GHC.TypeLits 56import GHC.TypeLits
48import Numeric.HMatrix hiding ((<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col) 57import Numeric.HMatrix hiding (
58 (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace,(<\>),fromList,takeDiag)
49import qualified Numeric.HMatrix as LA 59import qualified Numeric.HMatrix as LA
50import Data.Packed.ST
51import Data.Proxy(Proxy) 60import Data.Proxy(Proxy)
61import Numeric.LinearAlgebra.Static
62import Text.Printf
63
64instance forall n . KnownNat n => Show (R n)
65 where
66 show (ud1 -> v)
67 | singleV v = "("++show (v!0)++" :: R "++show d++")"
68 | otherwise = "(vect"++ drop 8 (show v)++" :: R "++show d++")"
69 where
70 d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
52 71
53newtype Dim (n :: Nat) t = Dim t
54 deriving Show
55 72
56unDim :: Dim n t -> t 73ud1 :: R n -> Vector ℝ
57unDim (Dim x) = x 74ud1 (R (Dim v)) = v
58 75
59-- data Proxy :: Nat -> *
60 76
77mkR :: Vector ℝ -> R n
78mkR = R . Dim
61 79
62lift1F
63 :: (c t -> c t)
64 -> Dim n (c t) -> Dim n (c t)
65lift1F f (Dim v) = Dim (f v)
66 80
67lift2F 81infixl 4 &
68 :: (c t -> c t -> c t) 82(&) :: forall n . KnownNat n
69 -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) 83 => R n -> ℝ -> R (n+1)
70lift2F f (Dim u) (Dim v) = Dim (f u v) 84u & x = u # (konst x :: R 1)
85
86infixl 4 #
87(#) :: forall n m . (KnownNat n, KnownNat m)
88 => R n -> R m -> R (n+m)
89(R u) # (R v) = R (vconcat u v)
71 90
72 91
73 92
74type R n = Dim n (Vector ℝ) 93vec2 :: ℝ -> ℝ -> R 2
94vec2 a b = R (gvec2 a b)
75 95
76type L m n = Dim m (Dim n (Matrix ℝ)) 96vec3 :: ℝ -> ℝ -> ℝ -> R 3
97vec3 a b c = R (gvec3 a b c)
77 98
78 99
79infixl 4 & 100vec4 :: ℝ -> ℝ -> ℝ -> ℝ -> R 4
80(&) :: forall n . KnownNat n 101vec4 a b c d = R (gvec4 a b c d)
81 => R n -> ℝ -> R (n+1) 102
82Dim v & x = Dim (vjoin [v', scalar x]) 103vect :: forall n . KnownNat n => [ℝ] -> R n
104vect xs = R (gvect "R" xs)
105
106linspace :: forall n . KnownNat n => (ℝ,ℝ) -> R n
107linspace (a,b) = mkR (LA.linspace d (a,b))
83 where 108 where
84 d = fromIntegral . natVal $ (undefined :: Proxy n) 109 d = fromIntegral . natVal $ (undefined :: Proxy n)
85 v' | d > 1 && size v == 1 = LA.konst (v!0) d
86 | otherwise = v
87 110
111range :: forall n . KnownNat n => R n
112range = mkR (LA.linspace d (1,fromIntegral d))
113 where
114 d = fromIntegral . natVal $ (undefined :: Proxy n)
88 115
89-- vect0 :: R 0 116dim :: forall n . KnownNat n => R n
90-- vect0 = Dim (fromList[]) 117dim = mkR (scalar d)
118 where
119 d = fromIntegral . natVal $ (undefined :: Proxy n)
91 120
92𝕧 :: ℝ -> R 1
93𝕧 = Dim . scalar
94 121
122--------------------------------------------------------------------------------
95 123
96vec2 :: ℝ -> ℝ -> R 2 124newtype L m n = L (Dim m (Dim n (Matrix ℝ)))
97vec2 a b = Dim $ runSTVector $ do 125 deriving (Num,Fractional)
98 v <- newUndefinedVector 2
99 writeVector v 0 a
100 writeVector v 1 b
101 return v
102 126
103vec3 :: ℝ -> ℝ -> ℝ -> R 3
104vec3 a b c = Dim $ runSTVector $ do
105 v <- newUndefinedVector 3
106 writeVector v 0 a
107 writeVector v 1 b
108 writeVector v 2 c
109 return v
110 127
128ud2 :: L m n -> Matrix ℝ
129ud2 (L (Dim (Dim x))) = x
111 130
112vec4 :: ℝ -> ℝ -> ℝ -> ℝ -> R 4
113vec4 a b c d = Dim $ runSTVector $ do
114 v <- newUndefinedVector 4
115 writeVector v 0 a
116 writeVector v 1 b
117 writeVector v 2 c
118 writeVector v 3 d
119 return v
120 131
121 132
122 133
134mkL :: Matrix ℝ -> L m n
135mkL x = L (Dim (Dim x))
123 136
124instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t))
125 where
126 (+) = lift2F (+)
127 (*) = lift2F (*)
128 (-) = lift2F (-)
129 abs = lift1F abs
130 signum = lift1F signum
131 negate = lift1F negate
132 fromInteger x = Dim (fromInteger x)
133
134instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t)))
135 where
136 (+) = (lift2F . lift2F) (+)
137 (*) = (lift2F . lift2F) (*)
138 (-) = (lift2F . lift2F) (-)
139 abs = (lift1F . lift1F) abs
140 signum = (lift1F . lift1F) signum
141 negate = (lift1F . lift1F) negate
142 fromInteger x = Dim (Dim (fromInteger x))
143
144instance Fractional (Dim n (Vector Double))
145 where
146 fromRational x = Dim (fromRational x)
147 (/) = lift2F (/)
148 137
149instance Fractional (Dim m (Dim n (Matrix Double))) 138instance forall m n . (KnownNat m, KnownNat n) => Show (L m n)
150 where 139 where
151 fromRational x = Dim (Dim (fromRational x)) 140 show (ud2 -> x)
152 (/) = (lift2F.lift2F) (/) 141 | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n'
142 | isDiag = printf "(diag %s %s :: L %d %d)" (show z) shy m' n'
143 | otherwise = "(mat"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")"
144 where
145 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
146 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
147 isDiag = rows x == 1 && m' > 1
148 v = flatten x
149 z = v!0
150 y = subVector 1 (size v-1) v
151 shy = drop 9 (show y)
153 152
154-------------------------------------------------------------------------------- 153--------------------------------------------------------------------------------
155 154
156class Konst t 155instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ)
157 where 156 where
158 konst :: ℝ -> t 157 konst x = mkR (LA.scalar x)
158 extract = ud1
159 fromList = vect
160 expand (extract -> v)
161 | singleV v = LA.konst (v!0) d
162 | otherwise = v
163 where
164 d = fromIntegral . natVal $ (undefined :: Proxy n)
159 165
160instance forall n. KnownNat n => Konst (R n)
161 where
162 konst x = Dim (LA.konst x d)
163 where
164 d = fromIntegral . natVal $ (undefined :: Proxy n)
165 166
166instance forall m n . (KnownNat m, KnownNat n) => Konst (L m n) 167instance forall m n . (KnownNat m, KnownNat n) => Sized (L m n) (Matrix ℝ)
167 where 168 where
168 konst x = Dim (Dim (LA.konst x (m',n'))) 169 konst x = mkL (LA.scalar x)
170 extract = ud2
171 fromList = mat
172 expand (extract -> a)
173 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n')
174 | rows a == 1 && m'>1 = diagRect x y m' n'
175 | otherwise = a
169 where 176 where
170 m' = fromIntegral . natVal $ (undefined :: Proxy m) 177 m' = fromIntegral . natVal $ (undefined :: Proxy m)
171 n' = fromIntegral . natVal $ (undefined :: Proxy n) 178 n' = fromIntegral . natVal $ (undefined :: Proxy n)
179 v = flatten a
180 x = v!0
181 y = subVector 1 (size v -1) v
172 182
173-------------------------------------------------------------------------------- 183--------------------------------------------------------------------------------
174 184
175diagR :: forall m n k . (KnownNat m, KnownNat n) => ℝ -> R k -> L m n 185diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n
176diagR x v = Dim (Dim (diagRect x (unDim v) m' n')) 186diagR x v = mkL (asRow (vjoin [scalar x, expand v]))
177 where
178 m' = fromIntegral . natVal $ (undefined :: Proxy m)
179 n' = fromIntegral . natVal $ (undefined :: Proxy n)
180 187
181diag :: KnownNat n => R n -> Sq n 188diag :: KnownNat n => R n -> Sq n
182diag = diagR 0 189diag = diagR 0
183 190
191eye :: KnownNat n => Sq n
192eye = diag 1
193
184-------------------------------------------------------------------------------- 194--------------------------------------------------------------------------------
185 195
186blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n 196blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n
187blockAt x r c a = Dim (Dim res) 197blockAt x r c a = mkL res
188 where 198 where
189 z = scalar x 199 z = scalar x
190 z1 = LA.konst x (r,c) 200 z1 = LA.konst x (r,c)
@@ -196,117 +206,189 @@ blockAt x r c a = Dim (Dim res)
196 n' = fromIntegral . natVal $ (undefined :: Proxy n) 206 n' = fromIntegral . natVal $ (undefined :: Proxy n)
197 res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] 207 res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]]
198 208
199{-
200matrix :: (KnownNat m, KnownNat n) => Matrix Double -> L n m
201matrix = blockAt 0 0 0
202-}
203 209
210
211mat :: forall m n . (KnownNat m, KnownNat n) => [ℝ] -> L m n
212mat xs = L (gmat "L" xs)
213
204-------------------------------------------------------------------------------- 214--------------------------------------------------------------------------------
205 215
206class Disp t 216class Disp t
207 where 217 where
208 disp :: Int -> t -> IO () 218 disp :: Int -> t -> IO ()
209 219
210instance Disp (L n m) 220
221instance (KnownNat m, KnownNat n) => Disp (L m n)
211 where 222 where
212 disp n (d2 -> a) = do 223 disp n x = do
224 let a = expand x
225 let su = LA.dispf n a
226 printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)
227
228{-
229 disp n (ud2 -> a) = do
213 if rows a == 1 && cols a == 1 230 if rows a == 1 && cols a == 1
214 then putStrLn $ "Const " ++ (last . words . LA.dispf n $ a) 231 then putStrLn $ "Const " ++ (last . words . LA.dispf n $ a)
215 else putStr "Dim " >> LA.disp n a 232 else putStr "Dim " >> LA.disp n a
233-}
216 234
217instance Disp (R n) 235instance KnownNat n => Disp (R n)
218 where 236 where
219 disp n (unDim -> v) = do 237 disp n v = do
220 let su = LA.dispf n (asRow v) 238 let su = LA.dispf n (asRow $ expand v)
221 if LA.size v == 1 239 putStr "R " >> putStr (tail . dropWhile (/='x') $ su)
222 then putStrLn $ "Const " ++ (last . words $ su )
223 else putStr "Dim " >> putStr (tail . dropWhile (/='x') $ su)
224 240
225-------------------------------------------------------------------------------- 241--------------------------------------------------------------------------------
226{-
227infixl 3 #
228(#) :: L r c -> R c -> L (r+1) c
229Dim (Dim m) # Dim v = Dim (Dim (m LA.—— asRow v))
230 242
231 243
232𝕞 :: forall n . KnownNat n => L 0 n
233𝕞 = Dim (Dim (LA.konst 0 (0,d)))
234 where
235 d = fromIntegral . natVal $ (undefined :: Proxy n)
236-}
237
238row :: R n -> L 1 n 244row :: R n -> L 1 n
239row (Dim v) = Dim (Dim (asRow v)) 245row = mkL . asRow . ud1
240 246
241col :: R n -> L n 1 247col :: R n -> L n 1
242col = tr . row 248col = tr . row
243 249
244infixl 3 ¦ 250unrow :: L 1 n -> R n
245(¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) 251unrow = mkR . head . toRows . ud2
246a ¦ b = rjoin (expk a) (expk b) 252
247 where 253uncol :: L n 1 -> R n
248 Dim (Dim a') `rjoin` Dim (Dim b') = Dim (Dim (a' LA.¦ b')) 254uncol = unrow . tr
255
249 256
250infixl 2 —— 257infixl 2 ——
251(——) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c 258(——) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c
252a —— b = cjoin (expk a) (expk b) 259a —— b = mkL (expand a LA.—— expand b)
253 where
254 Dim (Dim a') `cjoin` Dim (Dim b') = Dim (Dim (a' LA.—— b'))
255
256expk :: (KnownNat n, KnownNat m) => L m n -> L m n
257expk x | singleton x = konst (d2 x `atIndex` (0,0))
258 | otherwise = x
259 where
260 singleton (d2 -> m) = rows m == 1 && cols m == 1
261 260
262 261
263{- 262infixl 3 ¦
263(¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2)
264a ¦ b = tr (tr a —— tr b)
264 265
265-}
266 266
267type Sq n = L n n 267type Sq n = L n n
268 268
269type GL = (KnownNat n, KnownNat m) => L m n 269type GL = (KnownNat n, KnownNat m) => L m n
270type GSq = KnownNat n => Sq n 270type GSq = KnownNat n => Sq n
271 271
272isDiag0 :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (Vector ℝ)
273isDiag0 (extract -> x)
274 | rows x == 1 && m' > 1 && z == 0 = Just y
275 | otherwise = Nothing
276 where
277 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
278 v = flatten x
279 z = v!0
280 y = subVector 1 (size v-1) v
281
282
272infixr 8 <> 283infixr 8 <>
273(<>) :: L m k -> L k n -> L m n 284(<>) :: (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n
274(d2 -> a) <> (d2 -> b) = Dim (Dim (a LA.<> b)) 285a <> b = mkL (expand a LA.<> expand b)
275 286
276infixr 8 #> 287infixr 8 #>
277(#>) :: L m n -> R n -> R m 288(#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m
278(d2 -> m) #> (unDim -> v) = Dim (m LA.#> v) 289(isDiag0 -> Just w) #> v = mkR (w' * v')
290 where
291 v' = expand v
292 w' = subVector 0 (max 0 (size w - size v')) (vjoin [w , z])
293 z = LA.konst 0 (max 0 (size v' - size w))
294
295m #> v = mkR (expand m LA.#> expand v)
279 296
280infixr 8 <·> 297infixr 8 <·>
281(<·>) :: R n -> R n -> ℝ 298(<·>) :: R n -> R n -> ℝ
282(unDim -> u) <·> (unDim -> v) = udot u v 299(ud1 -> u) <·> (ud1 -> v)
300 | singleV u || singleV v = sumElements (u * v)
301 | otherwise = udot u v
283 302
284 303
285d2 :: forall c (n :: Nat) (n1 :: Nat). Dim n1 (Dim n c) -> c 304instance Transposable (L m n) (L n m)
286d2 = unDim . unDim 305 where
306 tr (ud2 -> a) = mkL (tr a)
287 307
308--------------------------------------------------------------------------------
309{-
310class Minim (n :: Nat) (m :: Nat)
311 where
312 type Mini n m :: Nat
288 313
289instance Transposable (L m n) (L n m) 314instance forall (n :: Nat) . Minim n n
290 where 315 where
291 tr (Dim (Dim a)) = Dim (Dim (tr a)) 316 type Mini n n = n
292 317
293 318
294eye :: forall n . KnownNat n => Sq n 319instance forall (n :: Nat) (m :: Nat) . (n <= m+1) => Minim n m
295eye = Dim (Dim (ident d))
296 where 320 where
297 d = fromIntegral . natVal $ (undefined :: Proxy n) 321 type Mini n m = n
322
323instance forall (n :: Nat) (m :: Nat) . (m <= n+1) => Minim n m
324 where
325 type Mini n m = m
326-}
327
328class Diag m d | m -> d
329 where
330 takeDiag :: m -> d
331
332
333
334instance forall n . (KnownNat n) => Diag (L n n) (R n)
335 where
336 takeDiag m = mkR (LA.takeDiag (expand m))
337
338
339instance forall m n . (KnownNat m, KnownNat n, m <= n+1) => Diag (L m n) (R m)
340 where
341 takeDiag m = mkR (LA.takeDiag (expand m))
342
343
344instance forall m n . (KnownNat m, KnownNat n, n <= m+1) => Diag (L m n) (R n)
345 where
346 takeDiag m = mkR (LA.takeDiag (expand m))
298 347
299 348
300-------------------------------------------------------------------------------- 349--------------------------------------------------------------------------------
301 350
351linSolve :: L m m -> L m n -> L m n
352linSolve (ud2 -> a) (ud2 -> b) = mkL (LA.linearSolve a b)
353
354--------------------------------------------------------------------------------
355
356withVector
357 :: forall z
358 . Vector ℝ
359 -> (forall n . (KnownNat n) => R n -> z)
360 -> z
361withVector v f =
362 case someNatVal $ fromIntegral $ size v of
363 Nothing -> error "static/dynamic mismatch"
364 Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m)
365
366
367withMatrix
368 :: forall z
369 . Matrix ℝ
370 -> (forall m n . (KnownNat m, KnownNat n) => L m n -> z)
371 -> z
372withMatrix a f =
373 case someNatVal $ fromIntegral $ rows a of
374 Nothing -> error "static/dynamic mismatch"
375 Just (SomeNat (_ :: Proxy m)) ->
376 case someNatVal $ fromIntegral $ cols a of
377 Nothing -> error "static/dynamic mismatch"
378 Just (SomeNat (_ :: Proxy n)) ->
379 f (mkL a :: L n m)
380
381--------------------------------------------------------------------------------
382
302test :: (Bool, IO ()) 383test :: (Bool, IO ())
303test = (ok,info) 384test = (ok,info)
304 where 385 where
305 ok = d2 (eye :: Sq 5) == ident 5 386 ok = expand (eye :: Sq 5) == ident 5
306 && d2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] 387 && ud2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..]
307 && d2 (tm :: L 3 5) == mat 5 [1..15] 388 && ud2 (tm :: L 3 5) == LA.mat 5 [1..15]
308 && thingS == thingD 389 && thingS == thingD
309 && precS == precD 390 && precS == precD
391 && withVector (LA.vect [1..15]) sumV == sumElements (LA.fromList [1..15])
310 392
311 info = do 393 info = do
312 print $ u 394 print $ u
@@ -319,19 +401,24 @@ test = (ok,info)
319 print thingD 401 print thingD
320 print precS 402 print precS
321 print precD 403 print precD
404 print $ withVector (LA.vect [1..15]) sumV
405
406 sumV w = w <·> konst 1
322 407
323 u = vec2 3 5 408 u = vec2 3 5
324 409
410 𝕧 x = vect [x] :: R 1
411
325 v = 𝕧 2 & 4 & 7 412 v = 𝕧 2 & 4 & 7
326 413
327 mTm :: L n m -> Sq m 414-- mTm :: L n m -> Sq m
328 mTm a = tr a <> a 415 mTm a = tr a <> a
329 416
330 tm :: GL 417 tm :: GL
331 tm = lmat 0 [1..] 418 tm = lmat 0 [1..]
332 419
333 lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n 420 lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n
334 lmat z xs = Dim . Dim . reshape n' . fromList . take (m'*n') $ xs ++ repeat z 421 lmat z xs = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z
335 where 422 where
336 m' = fromIntegral . natVal $ (undefined :: Proxy m) 423 m' = fromIntegral . natVal $ (undefined :: Proxy m)
337 n' = fromIntegral . natVal $ (undefined :: Proxy n) 424 n' = fromIntegral . natVal $ (undefined :: Proxy n)
@@ -343,12 +430,12 @@ test = (ok,info)
343 where 430 where
344 q = tm :: L 10 3 431 q = tm :: L 10 3
345 432
346 thingD = vjoin [unDim u, 1] LA.<·> tr m LA.#> m LA.#> unDim v 433 thingD = vjoin [ud1 u, 1] LA.<·> tr m LA.#> m LA.#> ud1 v
347 where 434 where
348 m = mat 3 [1..30] 435 m = LA.mat 3 [1..30]
349 436
350 precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v 437 precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v
351 precD = 1 + 2 * vjoin[unDim u, 6] LA.<·> LA.konst 2 (size (unDim u) +1, size (unDim v)) LA.#> unDim v 438 precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v
352 439
353 440
354instance (KnownNat n', KnownNat m') => Testable (L n' m') 441instance (KnownNat n', KnownNat m') => Testable (L n' m')
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
new file mode 100644
index 0000000..f9e935d
--- /dev/null
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -0,0 +1,193 @@
1{-# LANGUAGE DataKinds #-}
2{-# LANGUAGE KindSignatures #-}
3{-# LANGUAGE GeneralizedNewtypeDeriving #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE FlexibleContexts #-}
7{-# LANGUAGE ScopedTypeVariables #-}
8{-# LANGUAGE EmptyDataDecls #-}
9{-# LANGUAGE Rank2Types #-}
10{-# LANGUAGE FlexibleInstances #-}
11{-# LANGUAGE TypeOperators #-}
12{-# LANGUAGE ViewPatterns #-}
13{-# LANGUAGE GADTs #-}
14
15
16{- |
17Module : Numeric.LinearAlgebra.Static
18Copyright : (c) Alberto Ruiz 2006-14
19License : BSD3
20Stability : provisional
21
22-}
23
24module Numeric.LinearAlgebra.Static(
25 Dim(..),
26 R(..), C(..),
27 lift1F, lift2F,
28 vconcat, gvec2, gvec3, gvec4, gvect, gmat,
29 Sized(..),
30 singleV, singleM
31) where
32
33
34import GHC.TypeLits
35import Numeric.HMatrix as LA
36import Data.Packed as D
37import Data.Packed.ST
38import Data.Proxy(Proxy)
39import Foreign.Storable(Storable)
40
41
42
43newtype R n = R (Dim n (Vector ℝ))
44 deriving (Num,Fractional)
45
46
47newtype C n = C (Dim n (Vector ℂ))
48 deriving (Num,Fractional)
49
50
51
52newtype Dim (n :: Nat) t = Dim t
53 deriving Show
54
55lift1F
56 :: (c t -> c t)
57 -> Dim n (c t) -> Dim n (c t)
58lift1F f (Dim v) = Dim (f v)
59
60lift2F
61 :: (c t -> c t -> c t)
62 -> Dim n (c t) -> Dim n (c t) -> Dim n (c t)
63lift2F f (Dim u) (Dim v) = Dim (f u v)
64
65--------------------------------------------------------------------------------
66
67instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t))
68 where
69 (+) = lift2F (+)
70 (*) = lift2F (*)
71 (-) = lift2F (-)
72 abs = lift1F abs
73 signum = lift1F signum
74 negate = lift1F negate
75 fromInteger x = Dim (fromInteger x)
76
77instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t)))
78 where
79 (+) = (lift2F . lift2F) (+)
80 (*) = (lift2F . lift2F) (*)
81 (-) = (lift2F . lift2F) (-)
82 abs = (lift1F . lift1F) abs
83 signum = (lift1F . lift1F) signum
84 negate = (lift1F . lift1F) negate
85 fromInteger x = Dim (Dim (fromInteger x))
86
87instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim n (Vector t))
88 where
89 fromRational x = Dim (fromRational x)
90 (/) = lift2F (/)
91
92instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim m (Dim n (Matrix t)))
93 where
94 fromRational x = Dim (Dim (fromRational x))
95 (/) = (lift2F.lift2F) (/)
96
97--------------------------------------------------------------------------------
98
99type V n t = Dim n (Vector t)
100
101ud :: Dim n (Vector t) -> Vector t
102ud (Dim v) = v
103
104mkV :: forall (n :: Nat) t . t -> Dim n t
105mkV = Dim
106
107type M m n t = Dim m (Dim n (Matrix t))
108
109ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t
110ud2 (Dim (Dim m)) = m
111
112mkM :: forall (m :: Nat) (n :: Nat) t . t -> Dim m (Dim n t)
113mkM = Dim . Dim
114
115
116vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t)
117 => V n t -> V m t -> V (n+m) t
118(ud -> u) `vconcat` (ud -> v) = mkV (vjoin [u', v'])
119 where
120 du = fromIntegral . natVal $ (undefined :: Proxy n)
121 dv = fromIntegral . natVal $ (undefined :: Proxy m)
122 u' | du > 1 && size u == 1 = LA.konst (u D.@> 0) du
123 | otherwise = u
124 v' | dv > 1 && size v == 1 = LA.konst (v D.@> 0) dv
125 | otherwise = v
126
127
128gvec2 :: Storable t => t -> t -> V 2 t
129gvec2 a b = mkV $ runSTVector $ do
130 v <- newUndefinedVector 2
131 writeVector v 0 a
132 writeVector v 1 b
133 return v
134
135gvec3 :: Storable t => t -> t -> t -> V 3 t
136gvec3 a b c = mkV $ runSTVector $ do
137 v <- newUndefinedVector 3
138 writeVector v 0 a
139 writeVector v 1 b
140 writeVector v 2 c
141 return v
142
143
144gvec4 :: Storable t => t -> t -> t -> t -> V 4 t
145gvec4 a b c d = mkV $ runSTVector $ do
146 v <- newUndefinedVector 4
147 writeVector v 0 a
148 writeVector v 1 b
149 writeVector v 2 c
150 writeVector v 3 d
151 return v
152
153
154gvect :: forall n t . (Show t, KnownNat n, Numeric t) => String -> [t] -> V n t
155gvect st xs'
156 | ok = mkV v
157 | not (null rest) && null (tail rest) = abort (show xs')
158 | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]")
159 | otherwise = abort (show xs)
160 where
161 (xs,rest) = splitAt d xs'
162 ok = size v == d && null rest
163 v = LA.fromList xs
164 d = fromIntegral . natVal $ (undefined :: Proxy n)
165 abort info = error $ st++" "++show d++" can't be created from elements "++info
166
167
168gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> M m n t
169gmat st xs'
170 | ok = mkM x
171 | not (null rest) && null (tail rest) = abort (show xs')
172 | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]")
173 | otherwise = abort (show xs)
174 where
175 (xs,rest) = splitAt (m'*n') xs'
176 v = LA.fromList xs
177 x = reshape n' v
178 ok = rem (size v) n' == 0 && size x == (m',n') && null rest
179 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
180 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
181 abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info
182
183
184class Num t => Sized t s d | s -> t, s -> d
185 where
186 konst :: t -> s
187 extract :: s -> d
188 fromList :: [t] -> s
189 expand :: s -> d
190
191singleV v = size v == 1
192singleM m = rows m == 1 && cols m == 1
193