diff options
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Complex.hs | 80 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Real.hs | 395 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static.hs | 193 |
3 files changed, 514 insertions, 154 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Complex.hs b/packages/base/src/Numeric/LinearAlgebra/Complex.hs new file mode 100644 index 0000000..17bc397 --- /dev/null +++ b/packages/base/src/Numeric/LinearAlgebra/Complex.hs | |||
@@ -0,0 +1,80 @@ | |||
1 | {-# LANGUAGE DataKinds #-} | ||
2 | {-# LANGUAGE KindSignatures #-} | ||
3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
5 | {-# LANGUAGE FunctionalDependencies #-} | ||
6 | {-# LANGUAGE FlexibleContexts #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | {-# LANGUAGE EmptyDataDecls #-} | ||
9 | {-# LANGUAGE Rank2Types #-} | ||
10 | {-# LANGUAGE FlexibleInstances #-} | ||
11 | {-# LANGUAGE TypeOperators #-} | ||
12 | {-# LANGUAGE ViewPatterns #-} | ||
13 | {-# LANGUAGE GADTs #-} | ||
14 | |||
15 | |||
16 | {- | | ||
17 | Module : Numeric.LinearAlgebra.Complex | ||
18 | Copyright : (c) Alberto Ruiz 2006-14 | ||
19 | License : BSD3 | ||
20 | Stability : experimental | ||
21 | |||
22 | -} | ||
23 | |||
24 | module Numeric.LinearAlgebra.Complex( | ||
25 | C, | ||
26 | vec2, vec3, vec4, (&), (#), | ||
27 | vect, | ||
28 | R | ||
29 | ) where | ||
30 | |||
31 | import GHC.TypeLits | ||
32 | import Numeric.HMatrix hiding ( | ||
33 | (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace) | ||
34 | import qualified Numeric.HMatrix as LA | ||
35 | import Data.Proxy(Proxy) | ||
36 | import Numeric.LinearAlgebra.Static | ||
37 | |||
38 | |||
39 | |||
40 | instance forall n . KnownNat n => Show (C n) | ||
41 | where | ||
42 | show (ud1 -> v) | ||
43 | | size v == 1 = "("++show (v!0)++" :: C "++show d++")" | ||
44 | | otherwise = "(vect"++ drop 8 (show v)++" :: C "++show d++")" | ||
45 | where | ||
46 | d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
47 | |||
48 | |||
49 | ud1 :: C n -> Vector ℂ | ||
50 | ud1 (C (Dim v)) = v | ||
51 | |||
52 | mkC :: Vector ℂ -> C n | ||
53 | mkC = C . Dim | ||
54 | |||
55 | |||
56 | infixl 4 & | ||
57 | (&) :: forall n . KnownNat n | ||
58 | => C n -> ℂ -> C (n+1) | ||
59 | u & x = u # (mkC (LA.scalar x) :: C 1) | ||
60 | |||
61 | infixl 4 # | ||
62 | (#) :: forall n m . (KnownNat n, KnownNat m) | ||
63 | => C n -> C m -> C (n+m) | ||
64 | (C u) # (C v) = C (vconcat u v) | ||
65 | |||
66 | |||
67 | |||
68 | vec2 :: ℂ -> ℂ -> C 2 | ||
69 | vec2 a b = C (gvec2 a b) | ||
70 | |||
71 | vec3 :: ℂ -> ℂ -> ℂ -> C 3 | ||
72 | vec3 a b c = C (gvec3 a b c) | ||
73 | |||
74 | |||
75 | vec4 :: ℂ -> ℂ -> ℂ -> ℂ -> C 4 | ||
76 | vec4 a b c d = C (gvec4 a b c d) | ||
77 | |||
78 | vect :: forall n . KnownNat n => [ℂ] -> C n | ||
79 | vect xs = C (gvect "C" xs) | ||
80 | |||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs index 5634031..424e766 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Real.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs | |||
@@ -11,13 +11,15 @@ | |||
11 | {-# LANGUAGE TypeOperators #-} | 11 | {-# LANGUAGE TypeOperators #-} |
12 | {-# LANGUAGE ViewPatterns #-} | 12 | {-# LANGUAGE ViewPatterns #-} |
13 | {-# LANGUAGE GADTs #-} | 13 | {-# LANGUAGE GADTs #-} |
14 | {-# LANGUAGE OverlappingInstances #-} | ||
15 | {-# LANGUAGE TypeFamilies #-} | ||
14 | 16 | ||
15 | 17 | ||
16 | {- | | 18 | {- | |
17 | Module : Numeric.LinearAlgebra.Real | 19 | Module : Numeric.LinearAlgebra.Real |
18 | Copyright : (c) Alberto Ruiz 2006-14 | 20 | Copyright : (c) Alberto Ruiz 2006-14 |
19 | License : BSD3 | 21 | License : BSD3 |
20 | Stability : provisional | 22 | Stability : experimental |
21 | 23 | ||
22 | Experimental interface for real arrays with statically checked dimensions. | 24 | Experimental interface for real arrays with statically checked dimensions. |
23 | 25 | ||
@@ -26,165 +28,173 @@ Experimental interface for real arrays with statically checked dimensions. | |||
26 | module Numeric.LinearAlgebra.Real( | 28 | module Numeric.LinearAlgebra.Real( |
27 | -- * Vector | 29 | -- * Vector |
28 | R, | 30 | R, |
29 | vec2, vec3, vec4, 𝕧, (&), | 31 | vec2, vec3, vec4, (&), (#), |
32 | vect, | ||
33 | linspace, range, dim, | ||
30 | -- * Matrix | 34 | -- * Matrix |
31 | L, Sq, | 35 | L, Sq, |
32 | row, col, (¦),(——), | 36 | row, col, (¦),(——), |
33 | Konst(..), | 37 | unrow, uncol, |
38 | Sized(..), | ||
34 | eye, | 39 | eye, |
35 | diagR, diag, | 40 | diagR, diag, Diag(..), |
36 | blockAt, | 41 | blockAt, |
42 | mat, | ||
37 | -- * Products | 43 | -- * Products |
38 | (<>),(#>),(<·>), | 44 | (<>),(#>),(<·>), |
45 | -- * Linear Systems | ||
46 | linSolve, -- (<\>), | ||
39 | -- * Pretty printing | 47 | -- * Pretty printing |
40 | Disp(..), | 48 | Disp(..), |
41 | -- * Misc | 49 | -- * Misc |
42 | Dim, unDim, | 50 | C, |
51 | withVector, withMatrix, | ||
43 | module Numeric.HMatrix | 52 | module Numeric.HMatrix |
44 | ) where | 53 | ) where |
45 | 54 | ||
46 | 55 | ||
47 | import GHC.TypeLits | 56 | import GHC.TypeLits |
48 | import Numeric.HMatrix hiding ((<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col) | 57 | import Numeric.HMatrix hiding ( |
58 | (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace,(<\>),fromList,takeDiag) | ||
49 | import qualified Numeric.HMatrix as LA | 59 | import qualified Numeric.HMatrix as LA |
50 | import Data.Packed.ST | ||
51 | import Data.Proxy(Proxy) | 60 | import Data.Proxy(Proxy) |
61 | import Numeric.LinearAlgebra.Static | ||
62 | import Text.Printf | ||
63 | |||
64 | instance forall n . KnownNat n => Show (R n) | ||
65 | where | ||
66 | show (ud1 -> v) | ||
67 | | singleV v = "("++show (v!0)++" :: R "++show d++")" | ||
68 | | otherwise = "(vect"++ drop 8 (show v)++" :: R "++show d++")" | ||
69 | where | ||
70 | d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
52 | 71 | ||
53 | newtype Dim (n :: Nat) t = Dim t | ||
54 | deriving Show | ||
55 | 72 | ||
56 | unDim :: Dim n t -> t | 73 | ud1 :: R n -> Vector ℝ |
57 | unDim (Dim x) = x | 74 | ud1 (R (Dim v)) = v |
58 | 75 | ||
59 | -- data Proxy :: Nat -> * | ||
60 | 76 | ||
77 | mkR :: Vector ℝ -> R n | ||
78 | mkR = R . Dim | ||
61 | 79 | ||
62 | lift1F | ||
63 | :: (c t -> c t) | ||
64 | -> Dim n (c t) -> Dim n (c t) | ||
65 | lift1F f (Dim v) = Dim (f v) | ||
66 | 80 | ||
67 | lift2F | 81 | infixl 4 & |
68 | :: (c t -> c t -> c t) | 82 | (&) :: forall n . KnownNat n |
69 | -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) | 83 | => R n -> ℝ -> R (n+1) |
70 | lift2F f (Dim u) (Dim v) = Dim (f u v) | 84 | u & x = u # (konst x :: R 1) |
85 | |||
86 | infixl 4 # | ||
87 | (#) :: forall n m . (KnownNat n, KnownNat m) | ||
88 | => R n -> R m -> R (n+m) | ||
89 | (R u) # (R v) = R (vconcat u v) | ||
71 | 90 | ||
72 | 91 | ||
73 | 92 | ||
74 | type R n = Dim n (Vector ℝ) | 93 | vec2 :: ℝ -> ℝ -> R 2 |
94 | vec2 a b = R (gvec2 a b) | ||
75 | 95 | ||
76 | type L m n = Dim m (Dim n (Matrix ℝ)) | 96 | vec3 :: ℝ -> ℝ -> ℝ -> R 3 |
97 | vec3 a b c = R (gvec3 a b c) | ||
77 | 98 | ||
78 | 99 | ||
79 | infixl 4 & | 100 | vec4 :: ℝ -> ℝ -> ℝ -> ℝ -> R 4 |
80 | (&) :: forall n . KnownNat n | 101 | vec4 a b c d = R (gvec4 a b c d) |
81 | => R n -> ℝ -> R (n+1) | 102 | |
82 | Dim v & x = Dim (vjoin [v', scalar x]) | 103 | vect :: forall n . KnownNat n => [ℝ] -> R n |
104 | vect xs = R (gvect "R" xs) | ||
105 | |||
106 | linspace :: forall n . KnownNat n => (ℝ,ℝ) -> R n | ||
107 | linspace (a,b) = mkR (LA.linspace d (a,b)) | ||
83 | where | 108 | where |
84 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 109 | d = fromIntegral . natVal $ (undefined :: Proxy n) |
85 | v' | d > 1 && size v == 1 = LA.konst (v!0) d | ||
86 | | otherwise = v | ||
87 | 110 | ||
111 | range :: forall n . KnownNat n => R n | ||
112 | range = mkR (LA.linspace d (1,fromIntegral d)) | ||
113 | where | ||
114 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
88 | 115 | ||
89 | -- vect0 :: R 0 | 116 | dim :: forall n . KnownNat n => R n |
90 | -- vect0 = Dim (fromList[]) | 117 | dim = mkR (scalar d) |
118 | where | ||
119 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
91 | 120 | ||
92 | 𝕧 :: ℝ -> R 1 | ||
93 | 𝕧 = Dim . scalar | ||
94 | 121 | ||
122 | -------------------------------------------------------------------------------- | ||
95 | 123 | ||
96 | vec2 :: ℝ -> ℝ -> R 2 | 124 | newtype L m n = L (Dim m (Dim n (Matrix ℝ))) |
97 | vec2 a b = Dim $ runSTVector $ do | 125 | deriving (Num,Fractional) |
98 | v <- newUndefinedVector 2 | ||
99 | writeVector v 0 a | ||
100 | writeVector v 1 b | ||
101 | return v | ||
102 | 126 | ||
103 | vec3 :: ℝ -> ℝ -> ℝ -> R 3 | ||
104 | vec3 a b c = Dim $ runSTVector $ do | ||
105 | v <- newUndefinedVector 3 | ||
106 | writeVector v 0 a | ||
107 | writeVector v 1 b | ||
108 | writeVector v 2 c | ||
109 | return v | ||
110 | 127 | ||
128 | ud2 :: L m n -> Matrix ℝ | ||
129 | ud2 (L (Dim (Dim x))) = x | ||
111 | 130 | ||
112 | vec4 :: ℝ -> ℝ -> ℝ -> ℝ -> R 4 | ||
113 | vec4 a b c d = Dim $ runSTVector $ do | ||
114 | v <- newUndefinedVector 4 | ||
115 | writeVector v 0 a | ||
116 | writeVector v 1 b | ||
117 | writeVector v 2 c | ||
118 | writeVector v 3 d | ||
119 | return v | ||
120 | 131 | ||
121 | 132 | ||
122 | 133 | ||
134 | mkL :: Matrix ℝ -> L m n | ||
135 | mkL x = L (Dim (Dim x)) | ||
123 | 136 | ||
124 | instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) | ||
125 | where | ||
126 | (+) = lift2F (+) | ||
127 | (*) = lift2F (*) | ||
128 | (-) = lift2F (-) | ||
129 | abs = lift1F abs | ||
130 | signum = lift1F signum | ||
131 | negate = lift1F negate | ||
132 | fromInteger x = Dim (fromInteger x) | ||
133 | |||
134 | instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) | ||
135 | where | ||
136 | (+) = (lift2F . lift2F) (+) | ||
137 | (*) = (lift2F . lift2F) (*) | ||
138 | (-) = (lift2F . lift2F) (-) | ||
139 | abs = (lift1F . lift1F) abs | ||
140 | signum = (lift1F . lift1F) signum | ||
141 | negate = (lift1F . lift1F) negate | ||
142 | fromInteger x = Dim (Dim (fromInteger x)) | ||
143 | |||
144 | instance Fractional (Dim n (Vector Double)) | ||
145 | where | ||
146 | fromRational x = Dim (fromRational x) | ||
147 | (/) = lift2F (/) | ||
148 | 137 | ||
149 | instance Fractional (Dim m (Dim n (Matrix Double))) | 138 | instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) |
150 | where | 139 | where |
151 | fromRational x = Dim (Dim (fromRational x)) | 140 | show (ud2 -> x) |
152 | (/) = (lift2F.lift2F) (/) | 141 | | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' |
142 | | isDiag = printf "(diag %s %s :: L %d %d)" (show z) shy m' n' | ||
143 | | otherwise = "(mat"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" | ||
144 | where | ||
145 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
146 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
147 | isDiag = rows x == 1 && m' > 1 | ||
148 | v = flatten x | ||
149 | z = v!0 | ||
150 | y = subVector 1 (size v-1) v | ||
151 | shy = drop 9 (show y) | ||
153 | 152 | ||
154 | -------------------------------------------------------------------------------- | 153 | -------------------------------------------------------------------------------- |
155 | 154 | ||
156 | class Konst t | 155 | instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) |
157 | where | 156 | where |
158 | konst :: ℝ -> t | 157 | konst x = mkR (LA.scalar x) |
158 | extract = ud1 | ||
159 | fromList = vect | ||
160 | expand (extract -> v) | ||
161 | | singleV v = LA.konst (v!0) d | ||
162 | | otherwise = v | ||
163 | where | ||
164 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
159 | 165 | ||
160 | instance forall n. KnownNat n => Konst (R n) | ||
161 | where | ||
162 | konst x = Dim (LA.konst x d) | ||
163 | where | ||
164 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
165 | 166 | ||
166 | instance forall m n . (KnownNat m, KnownNat n) => Konst (L m n) | 167 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) |
167 | where | 168 | where |
168 | konst x = Dim (Dim (LA.konst x (m',n'))) | 169 | konst x = mkL (LA.scalar x) |
170 | extract = ud2 | ||
171 | fromList = mat | ||
172 | expand (extract -> a) | ||
173 | | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') | ||
174 | | rows a == 1 && m'>1 = diagRect x y m' n' | ||
175 | | otherwise = a | ||
169 | where | 176 | where |
170 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 177 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
171 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | 178 | n' = fromIntegral . natVal $ (undefined :: Proxy n) |
179 | v = flatten a | ||
180 | x = v!0 | ||
181 | y = subVector 1 (size v -1) v | ||
172 | 182 | ||
173 | -------------------------------------------------------------------------------- | 183 | -------------------------------------------------------------------------------- |
174 | 184 | ||
175 | diagR :: forall m n k . (KnownNat m, KnownNat n) => ℝ -> R k -> L m n | 185 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n |
176 | diagR x v = Dim (Dim (diagRect x (unDim v) m' n')) | 186 | diagR x v = mkL (asRow (vjoin [scalar x, expand v])) |
177 | where | ||
178 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
179 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
180 | 187 | ||
181 | diag :: KnownNat n => R n -> Sq n | 188 | diag :: KnownNat n => R n -> Sq n |
182 | diag = diagR 0 | 189 | diag = diagR 0 |
183 | 190 | ||
191 | eye :: KnownNat n => Sq n | ||
192 | eye = diag 1 | ||
193 | |||
184 | -------------------------------------------------------------------------------- | 194 | -------------------------------------------------------------------------------- |
185 | 195 | ||
186 | blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n | 196 | blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n |
187 | blockAt x r c a = Dim (Dim res) | 197 | blockAt x r c a = mkL res |
188 | where | 198 | where |
189 | z = scalar x | 199 | z = scalar x |
190 | z1 = LA.konst x (r,c) | 200 | z1 = LA.konst x (r,c) |
@@ -196,117 +206,189 @@ blockAt x r c a = Dim (Dim res) | |||
196 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | 206 | n' = fromIntegral . natVal $ (undefined :: Proxy n) |
197 | res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] | 207 | res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] |
198 | 208 | ||
199 | {- | ||
200 | matrix :: (KnownNat m, KnownNat n) => Matrix Double -> L n m | ||
201 | matrix = blockAt 0 0 0 | ||
202 | -} | ||
203 | 209 | ||
210 | |||
211 | mat :: forall m n . (KnownNat m, KnownNat n) => [ℝ] -> L m n | ||
212 | mat xs = L (gmat "L" xs) | ||
213 | |||
204 | -------------------------------------------------------------------------------- | 214 | -------------------------------------------------------------------------------- |
205 | 215 | ||
206 | class Disp t | 216 | class Disp t |
207 | where | 217 | where |
208 | disp :: Int -> t -> IO () | 218 | disp :: Int -> t -> IO () |
209 | 219 | ||
210 | instance Disp (L n m) | 220 | |
221 | instance (KnownNat m, KnownNat n) => Disp (L m n) | ||
211 | where | 222 | where |
212 | disp n (d2 -> a) = do | 223 | disp n x = do |
224 | let a = expand x | ||
225 | let su = LA.dispf n a | ||
226 | printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) | ||
227 | |||
228 | {- | ||
229 | disp n (ud2 -> a) = do | ||
213 | if rows a == 1 && cols a == 1 | 230 | if rows a == 1 && cols a == 1 |
214 | then putStrLn $ "Const " ++ (last . words . LA.dispf n $ a) | 231 | then putStrLn $ "Const " ++ (last . words . LA.dispf n $ a) |
215 | else putStr "Dim " >> LA.disp n a | 232 | else putStr "Dim " >> LA.disp n a |
233 | -} | ||
216 | 234 | ||
217 | instance Disp (R n) | 235 | instance KnownNat n => Disp (R n) |
218 | where | 236 | where |
219 | disp n (unDim -> v) = do | 237 | disp n v = do |
220 | let su = LA.dispf n (asRow v) | 238 | let su = LA.dispf n (asRow $ expand v) |
221 | if LA.size v == 1 | 239 | putStr "R " >> putStr (tail . dropWhile (/='x') $ su) |
222 | then putStrLn $ "Const " ++ (last . words $ su ) | ||
223 | else putStr "Dim " >> putStr (tail . dropWhile (/='x') $ su) | ||
224 | 240 | ||
225 | -------------------------------------------------------------------------------- | 241 | -------------------------------------------------------------------------------- |
226 | {- | ||
227 | infixl 3 # | ||
228 | (#) :: L r c -> R c -> L (r+1) c | ||
229 | Dim (Dim m) # Dim v = Dim (Dim (m LA.—— asRow v)) | ||
230 | 242 | ||
231 | 243 | ||
232 | 𝕞 :: forall n . KnownNat n => L 0 n | ||
233 | 𝕞 = Dim (Dim (LA.konst 0 (0,d))) | ||
234 | where | ||
235 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
236 | -} | ||
237 | |||
238 | row :: R n -> L 1 n | 244 | row :: R n -> L 1 n |
239 | row (Dim v) = Dim (Dim (asRow v)) | 245 | row = mkL . asRow . ud1 |
240 | 246 | ||
241 | col :: R n -> L n 1 | 247 | col :: R n -> L n 1 |
242 | col = tr . row | 248 | col = tr . row |
243 | 249 | ||
244 | infixl 3 ¦ | 250 | unrow :: L 1 n -> R n |
245 | (¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) | 251 | unrow = mkR . head . toRows . ud2 |
246 | a ¦ b = rjoin (expk a) (expk b) | 252 | |
247 | where | 253 | uncol :: L n 1 -> R n |
248 | Dim (Dim a') `rjoin` Dim (Dim b') = Dim (Dim (a' LA.¦ b')) | 254 | uncol = unrow . tr |
255 | |||
249 | 256 | ||
250 | infixl 2 —— | 257 | infixl 2 —— |
251 | (——) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c | 258 | (——) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c |
252 | a —— b = cjoin (expk a) (expk b) | 259 | a —— b = mkL (expand a LA.—— expand b) |
253 | where | ||
254 | Dim (Dim a') `cjoin` Dim (Dim b') = Dim (Dim (a' LA.—— b')) | ||
255 | |||
256 | expk :: (KnownNat n, KnownNat m) => L m n -> L m n | ||
257 | expk x | singleton x = konst (d2 x `atIndex` (0,0)) | ||
258 | | otherwise = x | ||
259 | where | ||
260 | singleton (d2 -> m) = rows m == 1 && cols m == 1 | ||
261 | 260 | ||
262 | 261 | ||
263 | {- | 262 | infixl 3 ¦ |
263 | (¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) | ||
264 | a ¦ b = tr (tr a —— tr b) | ||
264 | 265 | ||
265 | -} | ||
266 | 266 | ||
267 | type Sq n = L n n | 267 | type Sq n = L n n |
268 | 268 | ||
269 | type GL = (KnownNat n, KnownNat m) => L m n | 269 | type GL = (KnownNat n, KnownNat m) => L m n |
270 | type GSq = KnownNat n => Sq n | 270 | type GSq = KnownNat n => Sq n |
271 | 271 | ||
272 | isDiag0 :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (Vector ℝ) | ||
273 | isDiag0 (extract -> x) | ||
274 | | rows x == 1 && m' > 1 && z == 0 = Just y | ||
275 | | otherwise = Nothing | ||
276 | where | ||
277 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
278 | v = flatten x | ||
279 | z = v!0 | ||
280 | y = subVector 1 (size v-1) v | ||
281 | |||
282 | |||
272 | infixr 8 <> | 283 | infixr 8 <> |
273 | (<>) :: L m k -> L k n -> L m n | 284 | (<>) :: (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n |
274 | (d2 -> a) <> (d2 -> b) = Dim (Dim (a LA.<> b)) | 285 | a <> b = mkL (expand a LA.<> expand b) |
275 | 286 | ||
276 | infixr 8 #> | 287 | infixr 8 #> |
277 | (#>) :: L m n -> R n -> R m | 288 | (#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m |
278 | (d2 -> m) #> (unDim -> v) = Dim (m LA.#> v) | 289 | (isDiag0 -> Just w) #> v = mkR (w' * v') |
290 | where | ||
291 | v' = expand v | ||
292 | w' = subVector 0 (max 0 (size w - size v')) (vjoin [w , z]) | ||
293 | z = LA.konst 0 (max 0 (size v' - size w)) | ||
294 | |||
295 | m #> v = mkR (expand m LA.#> expand v) | ||
279 | 296 | ||
280 | infixr 8 <·> | 297 | infixr 8 <·> |
281 | (<·>) :: R n -> R n -> ℝ | 298 | (<·>) :: R n -> R n -> ℝ |
282 | (unDim -> u) <·> (unDim -> v) = udot u v | 299 | (ud1 -> u) <·> (ud1 -> v) |
300 | | singleV u || singleV v = sumElements (u * v) | ||
301 | | otherwise = udot u v | ||
283 | 302 | ||
284 | 303 | ||
285 | d2 :: forall c (n :: Nat) (n1 :: Nat). Dim n1 (Dim n c) -> c | 304 | instance Transposable (L m n) (L n m) |
286 | d2 = unDim . unDim | 305 | where |
306 | tr (ud2 -> a) = mkL (tr a) | ||
287 | 307 | ||
308 | -------------------------------------------------------------------------------- | ||
309 | {- | ||
310 | class Minim (n :: Nat) (m :: Nat) | ||
311 | where | ||
312 | type Mini n m :: Nat | ||
288 | 313 | ||
289 | instance Transposable (L m n) (L n m) | 314 | instance forall (n :: Nat) . Minim n n |
290 | where | 315 | where |
291 | tr (Dim (Dim a)) = Dim (Dim (tr a)) | 316 | type Mini n n = n |
292 | 317 | ||
293 | 318 | ||
294 | eye :: forall n . KnownNat n => Sq n | 319 | instance forall (n :: Nat) (m :: Nat) . (n <= m+1) => Minim n m |
295 | eye = Dim (Dim (ident d)) | ||
296 | where | 320 | where |
297 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 321 | type Mini n m = n |
322 | |||
323 | instance forall (n :: Nat) (m :: Nat) . (m <= n+1) => Minim n m | ||
324 | where | ||
325 | type Mini n m = m | ||
326 | -} | ||
327 | |||
328 | class Diag m d | m -> d | ||
329 | where | ||
330 | takeDiag :: m -> d | ||
331 | |||
332 | |||
333 | |||
334 | instance forall n . (KnownNat n) => Diag (L n n) (R n) | ||
335 | where | ||
336 | takeDiag m = mkR (LA.takeDiag (expand m)) | ||
337 | |||
338 | |||
339 | instance forall m n . (KnownNat m, KnownNat n, m <= n+1) => Diag (L m n) (R m) | ||
340 | where | ||
341 | takeDiag m = mkR (LA.takeDiag (expand m)) | ||
342 | |||
343 | |||
344 | instance forall m n . (KnownNat m, KnownNat n, n <= m+1) => Diag (L m n) (R n) | ||
345 | where | ||
346 | takeDiag m = mkR (LA.takeDiag (expand m)) | ||
298 | 347 | ||
299 | 348 | ||
300 | -------------------------------------------------------------------------------- | 349 | -------------------------------------------------------------------------------- |
301 | 350 | ||
351 | linSolve :: L m m -> L m n -> L m n | ||
352 | linSolve (ud2 -> a) (ud2 -> b) = mkL (LA.linearSolve a b) | ||
353 | |||
354 | -------------------------------------------------------------------------------- | ||
355 | |||
356 | withVector | ||
357 | :: forall z | ||
358 | . Vector ℝ | ||
359 | -> (forall n . (KnownNat n) => R n -> z) | ||
360 | -> z | ||
361 | withVector v f = | ||
362 | case someNatVal $ fromIntegral $ size v of | ||
363 | Nothing -> error "static/dynamic mismatch" | ||
364 | Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) | ||
365 | |||
366 | |||
367 | withMatrix | ||
368 | :: forall z | ||
369 | . Matrix ℝ | ||
370 | -> (forall m n . (KnownNat m, KnownNat n) => L m n -> z) | ||
371 | -> z | ||
372 | withMatrix a f = | ||
373 | case someNatVal $ fromIntegral $ rows a of | ||
374 | Nothing -> error "static/dynamic mismatch" | ||
375 | Just (SomeNat (_ :: Proxy m)) -> | ||
376 | case someNatVal $ fromIntegral $ cols a of | ||
377 | Nothing -> error "static/dynamic mismatch" | ||
378 | Just (SomeNat (_ :: Proxy n)) -> | ||
379 | f (mkL a :: L n m) | ||
380 | |||
381 | -------------------------------------------------------------------------------- | ||
382 | |||
302 | test :: (Bool, IO ()) | 383 | test :: (Bool, IO ()) |
303 | test = (ok,info) | 384 | test = (ok,info) |
304 | where | 385 | where |
305 | ok = d2 (eye :: Sq 5) == ident 5 | 386 | ok = expand (eye :: Sq 5) == ident 5 |
306 | && d2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] | 387 | && ud2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] |
307 | && d2 (tm :: L 3 5) == mat 5 [1..15] | 388 | && ud2 (tm :: L 3 5) == LA.mat 5 [1..15] |
308 | && thingS == thingD | 389 | && thingS == thingD |
309 | && precS == precD | 390 | && precS == precD |
391 | && withVector (LA.vect [1..15]) sumV == sumElements (LA.fromList [1..15]) | ||
310 | 392 | ||
311 | info = do | 393 | info = do |
312 | print $ u | 394 | print $ u |
@@ -319,19 +401,24 @@ test = (ok,info) | |||
319 | print thingD | 401 | print thingD |
320 | print precS | 402 | print precS |
321 | print precD | 403 | print precD |
404 | print $ withVector (LA.vect [1..15]) sumV | ||
405 | |||
406 | sumV w = w <·> konst 1 | ||
322 | 407 | ||
323 | u = vec2 3 5 | 408 | u = vec2 3 5 |
324 | 409 | ||
410 | 𝕧 x = vect [x] :: R 1 | ||
411 | |||
325 | v = 𝕧 2 & 4 & 7 | 412 | v = 𝕧 2 & 4 & 7 |
326 | 413 | ||
327 | mTm :: L n m -> Sq m | 414 | -- mTm :: L n m -> Sq m |
328 | mTm a = tr a <> a | 415 | mTm a = tr a <> a |
329 | 416 | ||
330 | tm :: GL | 417 | tm :: GL |
331 | tm = lmat 0 [1..] | 418 | tm = lmat 0 [1..] |
332 | 419 | ||
333 | lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n | 420 | lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n |
334 | lmat z xs = Dim . Dim . reshape n' . fromList . take (m'*n') $ xs ++ repeat z | 421 | lmat z xs = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z |
335 | where | 422 | where |
336 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 423 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
337 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | 424 | n' = fromIntegral . natVal $ (undefined :: Proxy n) |
@@ -343,12 +430,12 @@ test = (ok,info) | |||
343 | where | 430 | where |
344 | q = tm :: L 10 3 | 431 | q = tm :: L 10 3 |
345 | 432 | ||
346 | thingD = vjoin [unDim u, 1] LA.<·> tr m LA.#> m LA.#> unDim v | 433 | thingD = vjoin [ud1 u, 1] LA.<·> tr m LA.#> m LA.#> ud1 v |
347 | where | 434 | where |
348 | m = mat 3 [1..30] | 435 | m = LA.mat 3 [1..30] |
349 | 436 | ||
350 | precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v | 437 | precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v |
351 | precD = 1 + 2 * vjoin[unDim u, 6] LA.<·> LA.konst 2 (size (unDim u) +1, size (unDim v)) LA.#> unDim v | 438 | precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v |
352 | 439 | ||
353 | 440 | ||
354 | instance (KnownNat n', KnownNat m') => Testable (L n' m') | 441 | instance (KnownNat n', KnownNat m') => Testable (L n' m') |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs new file mode 100644 index 0000000..f9e935d --- /dev/null +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -0,0 +1,193 @@ | |||
1 | {-# LANGUAGE DataKinds #-} | ||
2 | {-# LANGUAGE KindSignatures #-} | ||
3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
5 | {-# LANGUAGE FunctionalDependencies #-} | ||
6 | {-# LANGUAGE FlexibleContexts #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | {-# LANGUAGE EmptyDataDecls #-} | ||
9 | {-# LANGUAGE Rank2Types #-} | ||
10 | {-# LANGUAGE FlexibleInstances #-} | ||
11 | {-# LANGUAGE TypeOperators #-} | ||
12 | {-# LANGUAGE ViewPatterns #-} | ||
13 | {-# LANGUAGE GADTs #-} | ||
14 | |||
15 | |||
16 | {- | | ||
17 | Module : Numeric.LinearAlgebra.Static | ||
18 | Copyright : (c) Alberto Ruiz 2006-14 | ||
19 | License : BSD3 | ||
20 | Stability : provisional | ||
21 | |||
22 | -} | ||
23 | |||
24 | module Numeric.LinearAlgebra.Static( | ||
25 | Dim(..), | ||
26 | R(..), C(..), | ||
27 | lift1F, lift2F, | ||
28 | vconcat, gvec2, gvec3, gvec4, gvect, gmat, | ||
29 | Sized(..), | ||
30 | singleV, singleM | ||
31 | ) where | ||
32 | |||
33 | |||
34 | import GHC.TypeLits | ||
35 | import Numeric.HMatrix as LA | ||
36 | import Data.Packed as D | ||
37 | import Data.Packed.ST | ||
38 | import Data.Proxy(Proxy) | ||
39 | import Foreign.Storable(Storable) | ||
40 | |||
41 | |||
42 | |||
43 | newtype R n = R (Dim n (Vector ℝ)) | ||
44 | deriving (Num,Fractional) | ||
45 | |||
46 | |||
47 | newtype C n = C (Dim n (Vector ℂ)) | ||
48 | deriving (Num,Fractional) | ||
49 | |||
50 | |||
51 | |||
52 | newtype Dim (n :: Nat) t = Dim t | ||
53 | deriving Show | ||
54 | |||
55 | lift1F | ||
56 | :: (c t -> c t) | ||
57 | -> Dim n (c t) -> Dim n (c t) | ||
58 | lift1F f (Dim v) = Dim (f v) | ||
59 | |||
60 | lift2F | ||
61 | :: (c t -> c t -> c t) | ||
62 | -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) | ||
63 | lift2F f (Dim u) (Dim v) = Dim (f u v) | ||
64 | |||
65 | -------------------------------------------------------------------------------- | ||
66 | |||
67 | instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) | ||
68 | where | ||
69 | (+) = lift2F (+) | ||
70 | (*) = lift2F (*) | ||
71 | (-) = lift2F (-) | ||
72 | abs = lift1F abs | ||
73 | signum = lift1F signum | ||
74 | negate = lift1F negate | ||
75 | fromInteger x = Dim (fromInteger x) | ||
76 | |||
77 | instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) | ||
78 | where | ||
79 | (+) = (lift2F . lift2F) (+) | ||
80 | (*) = (lift2F . lift2F) (*) | ||
81 | (-) = (lift2F . lift2F) (-) | ||
82 | abs = (lift1F . lift1F) abs | ||
83 | signum = (lift1F . lift1F) signum | ||
84 | negate = (lift1F . lift1F) negate | ||
85 | fromInteger x = Dim (Dim (fromInteger x)) | ||
86 | |||
87 | instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim n (Vector t)) | ||
88 | where | ||
89 | fromRational x = Dim (fromRational x) | ||
90 | (/) = lift2F (/) | ||
91 | |||
92 | instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim m (Dim n (Matrix t))) | ||
93 | where | ||
94 | fromRational x = Dim (Dim (fromRational x)) | ||
95 | (/) = (lift2F.lift2F) (/) | ||
96 | |||
97 | -------------------------------------------------------------------------------- | ||
98 | |||
99 | type V n t = Dim n (Vector t) | ||
100 | |||
101 | ud :: Dim n (Vector t) -> Vector t | ||
102 | ud (Dim v) = v | ||
103 | |||
104 | mkV :: forall (n :: Nat) t . t -> Dim n t | ||
105 | mkV = Dim | ||
106 | |||
107 | type M m n t = Dim m (Dim n (Matrix t)) | ||
108 | |||
109 | ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t | ||
110 | ud2 (Dim (Dim m)) = m | ||
111 | |||
112 | mkM :: forall (m :: Nat) (n :: Nat) t . t -> Dim m (Dim n t) | ||
113 | mkM = Dim . Dim | ||
114 | |||
115 | |||
116 | vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) | ||
117 | => V n t -> V m t -> V (n+m) t | ||
118 | (ud -> u) `vconcat` (ud -> v) = mkV (vjoin [u', v']) | ||
119 | where | ||
120 | du = fromIntegral . natVal $ (undefined :: Proxy n) | ||
121 | dv = fromIntegral . natVal $ (undefined :: Proxy m) | ||
122 | u' | du > 1 && size u == 1 = LA.konst (u D.@> 0) du | ||
123 | | otherwise = u | ||
124 | v' | dv > 1 && size v == 1 = LA.konst (v D.@> 0) dv | ||
125 | | otherwise = v | ||
126 | |||
127 | |||
128 | gvec2 :: Storable t => t -> t -> V 2 t | ||
129 | gvec2 a b = mkV $ runSTVector $ do | ||
130 | v <- newUndefinedVector 2 | ||
131 | writeVector v 0 a | ||
132 | writeVector v 1 b | ||
133 | return v | ||
134 | |||
135 | gvec3 :: Storable t => t -> t -> t -> V 3 t | ||
136 | gvec3 a b c = mkV $ runSTVector $ do | ||
137 | v <- newUndefinedVector 3 | ||
138 | writeVector v 0 a | ||
139 | writeVector v 1 b | ||
140 | writeVector v 2 c | ||
141 | return v | ||
142 | |||
143 | |||
144 | gvec4 :: Storable t => t -> t -> t -> t -> V 4 t | ||
145 | gvec4 a b c d = mkV $ runSTVector $ do | ||
146 | v <- newUndefinedVector 4 | ||
147 | writeVector v 0 a | ||
148 | writeVector v 1 b | ||
149 | writeVector v 2 c | ||
150 | writeVector v 3 d | ||
151 | return v | ||
152 | |||
153 | |||
154 | gvect :: forall n t . (Show t, KnownNat n, Numeric t) => String -> [t] -> V n t | ||
155 | gvect st xs' | ||
156 | | ok = mkV v | ||
157 | | not (null rest) && null (tail rest) = abort (show xs') | ||
158 | | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]") | ||
159 | | otherwise = abort (show xs) | ||
160 | where | ||
161 | (xs,rest) = splitAt d xs' | ||
162 | ok = size v == d && null rest | ||
163 | v = LA.fromList xs | ||
164 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
165 | abort info = error $ st++" "++show d++" can't be created from elements "++info | ||
166 | |||
167 | |||
168 | gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> M m n t | ||
169 | gmat st xs' | ||
170 | | ok = mkM x | ||
171 | | not (null rest) && null (tail rest) = abort (show xs') | ||
172 | | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]") | ||
173 | | otherwise = abort (show xs) | ||
174 | where | ||
175 | (xs,rest) = splitAt (m'*n') xs' | ||
176 | v = LA.fromList xs | ||
177 | x = reshape n' v | ||
178 | ok = rem (size v) n' == 0 && size x == (m',n') && null rest | ||
179 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
180 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
181 | abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info | ||
182 | |||
183 | |||
184 | class Num t => Sized t s d | s -> t, s -> d | ||
185 | where | ||
186 | konst :: t -> s | ||
187 | extract :: s -> d | ||
188 | fromList :: [t] -> s | ||
189 | expand :: s -> d | ||
190 | |||
191 | singleV v = size v == 1 | ||
192 | singleM m = rows m == 1 && cols m == 1 | ||
193 | |||