diff options
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Static')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs | 107 |
1 files changed, 60 insertions, 47 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs index 7968d77..339ef7d 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs | |||
@@ -7,13 +7,10 @@ | |||
7 | {-# LANGUAGE FunctionalDependencies #-} | 7 | {-# LANGUAGE FunctionalDependencies #-} |
8 | {-# LANGUAGE FlexibleContexts #-} | 8 | {-# LANGUAGE FlexibleContexts #-} |
9 | {-# LANGUAGE ScopedTypeVariables #-} | 9 | {-# LANGUAGE ScopedTypeVariables #-} |
10 | {-# LANGUAGE EmptyDataDecls #-} | ||
11 | {-# LANGUAGE Rank2Types #-} | 10 | {-# LANGUAGE Rank2Types #-} |
12 | {-# LANGUAGE FlexibleInstances #-} | 11 | {-# LANGUAGE FlexibleInstances #-} |
13 | {-# LANGUAGE TypeOperators #-} | 12 | {-# LANGUAGE TypeOperators #-} |
14 | {-# LANGUAGE ViewPatterns #-} | 13 | {-# LANGUAGE ViewPatterns #-} |
15 | {-# LANGUAGE GADTs #-} | ||
16 | |||
17 | 14 | ||
18 | {- | | 15 | {- | |
19 | Module : Numeric.LinearAlgebra.Static.Internal | 16 | Module : Numeric.LinearAlgebra.Static.Internal |
@@ -28,7 +25,7 @@ module Numeric.LinearAlgebra.Static.Internal where | |||
28 | 25 | ||
29 | import GHC.TypeLits | 26 | import GHC.TypeLits |
30 | import qualified Numeric.LinearAlgebra.HMatrix as LA | 27 | import qualified Numeric.LinearAlgebra.HMatrix as LA |
31 | import Numeric.LinearAlgebra.HMatrix hiding (konst) | 28 | import Numeric.LinearAlgebra.HMatrix hiding (konst,size) |
32 | import Data.Packed as D | 29 | import Data.Packed as D |
33 | import Data.Packed.ST | 30 | import Data.Packed.ST |
34 | import Data.Proxy(Proxy) | 31 | import Data.Proxy(Proxy) |
@@ -83,7 +80,7 @@ ud :: Dim n (Vector t) -> Vector t | |||
83 | ud (Dim v) = v | 80 | ud (Dim v) = v |
84 | 81 | ||
85 | mkV :: forall (n :: Nat) t . t -> Dim n t | 82 | mkV :: forall (n :: Nat) t . t -> Dim n t |
86 | mkV = Dim | 83 | mkV = Dim |
87 | 84 | ||
88 | 85 | ||
89 | vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) | 86 | vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) |
@@ -92,9 +89,9 @@ vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) | |||
92 | where | 89 | where |
93 | du = fromIntegral . natVal $ (undefined :: Proxy n) | 90 | du = fromIntegral . natVal $ (undefined :: Proxy n) |
94 | dv = fromIntegral . natVal $ (undefined :: Proxy m) | 91 | dv = fromIntegral . natVal $ (undefined :: Proxy m) |
95 | u' | du > 1 && size u == 1 = LA.konst (u D.@> 0) du | 92 | u' | du > 1 && LA.size u == 1 = LA.konst (u D.@> 0) du |
96 | | otherwise = u | 93 | | otherwise = u |
97 | v' | dv > 1 && size v == 1 = LA.konst (v D.@> 0) dv | 94 | v' | dv > 1 && LA.size v == 1 = LA.konst (v D.@> 0) dv |
98 | | otherwise = v | 95 | | otherwise = v |
99 | 96 | ||
100 | 97 | ||
@@ -132,7 +129,7 @@ gvect st xs' | |||
132 | | otherwise = abort (show xs) | 129 | | otherwise = abort (show xs) |
133 | where | 130 | where |
134 | (xs,rest) = splitAt d xs' | 131 | (xs,rest) = splitAt d xs' |
135 | ok = size v == d && null rest | 132 | ok = LA.size v == d && null rest |
136 | v = LA.fromList xs | 133 | v = LA.fromList xs |
137 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 134 | d = fromIntegral . natVal $ (undefined :: Proxy n) |
138 | abort info = error $ st++" "++show d++" can't be created from elements "++info | 135 | abort info = error $ st++" "++show d++" can't be created from elements "++info |
@@ -153,7 +150,7 @@ gmat st xs' | |||
153 | (xs,rest) = splitAt (m'*n') xs' | 150 | (xs,rest) = splitAt (m'*n') xs' |
154 | v = LA.fromList xs | 151 | v = LA.fromList xs |
155 | x = reshape n' v | 152 | x = reshape n' v |
156 | ok = rem (size v) n' == 0 && size x == (m',n') && null rest | 153 | ok = rem (LA.size v) n' == 0 && LA.size x == (m',n') && null rest |
157 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | 154 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int |
158 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 155 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int |
159 | abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info | 156 | abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info |
@@ -162,66 +159,84 @@ gmat st xs' | |||
162 | 159 | ||
163 | class Num t => Sized t s d | s -> t, s -> d | 160 | class Num t => Sized t s d | s -> t, s -> d |
164 | where | 161 | where |
165 | konst :: t -> s | 162 | konst :: t -> s |
166 | unwrap :: s -> d | 163 | unwrap :: s -> d t |
167 | fromList :: [t] -> s | 164 | fromList :: [t] -> s |
168 | extract :: s -> d | 165 | extract :: s -> d t |
169 | 166 | create :: d t -> Maybe s | |
170 | singleV v = size v == 1 | 167 | size :: s -> IndexOf d |
168 | |||
169 | singleV v = LA.size v == 1 | ||
171 | singleM m = rows m == 1 && cols m == 1 | 170 | singleM m = rows m == 1 && cols m == 1 |
172 | 171 | ||
173 | 172 | ||
174 | instance forall n. KnownNat n => Sized ℂ (C n) (Vector ℂ) | 173 | instance forall n. KnownNat n => Sized ℂ (C n) Vector |
175 | where | 174 | where |
175 | size _ = fromIntegral . natVal $ (undefined :: Proxy n) | ||
176 | konst x = mkC (LA.scalar x) | 176 | konst x = mkC (LA.scalar x) |
177 | unwrap (C (Dim v)) = v | 177 | unwrap (C (Dim v)) = v |
178 | fromList xs = C (gvect "C" xs) | 178 | fromList xs = C (gvect "C" xs) |
179 | extract (unwrap -> v) | 179 | extract s@(unwrap -> v) |
180 | | singleV v = LA.konst (v!0) d | 180 | | singleV v = LA.konst (v!0) (size s) |
181 | | otherwise = v | 181 | | otherwise = v |
182 | where | 182 | create v |
183 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 183 | | LA.size v == size r = Just r |
184 | | otherwise = Nothing | ||
185 | where | ||
186 | r = mkC v :: C n | ||
184 | 187 | ||
185 | 188 | ||
186 | instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) | 189 | instance forall n. KnownNat n => Sized ℝ (R n) Vector |
187 | where | 190 | where |
191 | size _ = fromIntegral . natVal $ (undefined :: Proxy n) | ||
188 | konst x = mkR (LA.scalar x) | 192 | konst x = mkR (LA.scalar x) |
189 | unwrap (R (Dim v)) = v | 193 | unwrap (R (Dim v)) = v |
190 | fromList xs = R (gvect "R" xs) | 194 | fromList xs = R (gvect "R" xs) |
191 | extract (unwrap -> v) | 195 | extract s@(unwrap -> v) |
192 | | singleV v = LA.konst (v!0) d | 196 | | singleV v = LA.konst (v!0) (size s) |
193 | | otherwise = v | 197 | | otherwise = v |
194 | where | 198 | create v |
195 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 199 | | LA.size v == size r = Just r |
200 | | otherwise = Nothing | ||
201 | where | ||
202 | r = mkR v :: R n | ||
196 | 203 | ||
197 | 204 | ||
198 | 205 | ||
199 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) | 206 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) Matrix |
200 | where | 207 | where |
208 | size _ = ((fromIntegral . natVal) (undefined :: Proxy m) | ||
209 | ,(fromIntegral . natVal) (undefined :: Proxy n)) | ||
201 | konst x = mkL (LA.scalar x) | 210 | konst x = mkL (LA.scalar x) |
202 | fromList xs = L (gmat "L" xs) | 211 | fromList xs = L (gmat "L" xs) |
203 | unwrap (L (Dim (Dim m))) = m | 212 | unwrap (L (Dim (Dim m))) = m |
204 | extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' | 213 | extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' |
205 | extract (unwrap -> a) | 214 | extract s@(unwrap -> a) |
206 | | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') | 215 | | singleM a = LA.konst (a `atIndex` (0,0)) (size s) |
207 | | otherwise = a | 216 | | otherwise = a |
217 | create x | ||
218 | | LA.size x == size r = Just r | ||
219 | | otherwise = Nothing | ||
208 | where | 220 | where |
209 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 221 | r = mkL x :: L m n |
210 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
211 | 222 | ||
212 | 223 | ||
213 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) (Matrix ℂ) | 224 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) Matrix |
214 | where | 225 | where |
226 | size _ = ((fromIntegral . natVal) (undefined :: Proxy m) | ||
227 | ,(fromIntegral . natVal) (undefined :: Proxy n)) | ||
215 | konst x = mkM (LA.scalar x) | 228 | konst x = mkM (LA.scalar x) |
216 | fromList xs = M (gmat "M" xs) | 229 | fromList xs = M (gmat "M" xs) |
217 | unwrap (M (Dim (Dim m))) = m | 230 | unwrap (M (Dim (Dim m))) = m |
218 | extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n' | 231 | extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n' |
219 | extract (unwrap -> a) | 232 | extract s@(unwrap -> a) |
220 | | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') | 233 | | singleM a = LA.konst (a `atIndex` (0,0)) (size s) |
221 | | otherwise = a | 234 | | otherwise = a |
235 | create x | ||
236 | | LA.size x == size r = Just r | ||
237 | | otherwise = Nothing | ||
222 | where | 238 | where |
223 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 239 | r = mkM x :: M m n |
224 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
225 | 240 | ||
226 | -------------------------------------------------------------------------------- | 241 | -------------------------------------------------------------------------------- |
227 | 242 | ||
@@ -254,8 +269,8 @@ isDiagg (Dim (Dim x)) | |||
254 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 269 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int |
255 | v = flatten x | 270 | v = flatten x |
256 | z = v `atIndex` 0 | 271 | z = v `atIndex` 0 |
257 | y = subVector 1 (size v-1) v | 272 | y = subVector 1 (LA.size v-1) v |
258 | ny = size y | 273 | ny = LA.size y |
259 | zeros = LA.konst 0 (max 0 (min m' n' - ny)) | 274 | zeros = LA.konst 0 (max 0 (min m' n' - ny)) |
260 | yz = vjoin [y,zeros] | 275 | yz = vjoin [y,zeros] |
261 | 276 | ||
@@ -263,39 +278,37 @@ isDiagg (Dim (Dim x)) | |||
263 | 278 | ||
264 | instance forall n . KnownNat n => Show (R n) | 279 | instance forall n . KnownNat n => Show (R n) |
265 | where | 280 | where |
266 | show (R (Dim v)) | 281 | show s@(R (Dim v)) |
267 | | singleV v = "("++show (v!0)++" :: R "++show d++")" | 282 | | singleV v = "("++show (v!0)++" :: R "++show d++")" |
268 | | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")" | 283 | | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")" |
269 | where | 284 | where |
270 | d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 285 | d = size s |
271 | 286 | ||
272 | instance forall n . KnownNat n => Show (C n) | 287 | instance forall n . KnownNat n => Show (C n) |
273 | where | 288 | where |
274 | show (C (Dim v)) | 289 | show s@(C (Dim v)) |
275 | | singleV v = "("++show (v!0)++" :: C "++show d++")" | 290 | | singleV v = "("++show (v!0)++" :: C "++show d++")" |
276 | | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")" | 291 | | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")" |
277 | where | 292 | where |
278 | d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 293 | d = size s |
279 | 294 | ||
280 | instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) | 295 | instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) |
281 | where | 296 | where |
282 | show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' | 297 | show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' |
283 | show (L (Dim (Dim x))) | 298 | show s@(L (Dim (Dim x))) |
284 | | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' | 299 | | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' |
285 | | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" | 300 | | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" |
286 | where | 301 | where |
287 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | 302 | (m',n') = size s |
288 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
289 | 303 | ||
290 | instance forall m n . (KnownNat m, KnownNat n) => Show (M m n) | 304 | instance forall m n . (KnownNat m, KnownNat n) => Show (M m n) |
291 | where | 305 | where |
292 | show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n' | 306 | show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n' |
293 | show (M (Dim (Dim x))) | 307 | show s@(M (Dim (Dim x))) |
294 | | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n' | 308 | | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n' |
295 | | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")" | 309 | | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")" |
296 | where | 310 | where |
297 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | 311 | (m',n') = size s |
298 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
299 | 312 | ||
300 | -------------------------------------------------------------------------------- | 313 | -------------------------------------------------------------------------------- |
301 | 314 | ||