summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Static
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Static')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs107
1 files changed, 60 insertions, 47 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs
index 7968d77..339ef7d 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs
@@ -7,13 +7,10 @@
7{-# LANGUAGE FunctionalDependencies #-} 7{-# LANGUAGE FunctionalDependencies #-}
8{-# LANGUAGE FlexibleContexts #-} 8{-# LANGUAGE FlexibleContexts #-}
9{-# LANGUAGE ScopedTypeVariables #-} 9{-# LANGUAGE ScopedTypeVariables #-}
10{-# LANGUAGE EmptyDataDecls #-}
11{-# LANGUAGE Rank2Types #-} 10{-# LANGUAGE Rank2Types #-}
12{-# LANGUAGE FlexibleInstances #-} 11{-# LANGUAGE FlexibleInstances #-}
13{-# LANGUAGE TypeOperators #-} 12{-# LANGUAGE TypeOperators #-}
14{-# LANGUAGE ViewPatterns #-} 13{-# LANGUAGE ViewPatterns #-}
15{-# LANGUAGE GADTs #-}
16
17 14
18{- | 15{- |
19Module : Numeric.LinearAlgebra.Static.Internal 16Module : Numeric.LinearAlgebra.Static.Internal
@@ -28,7 +25,7 @@ module Numeric.LinearAlgebra.Static.Internal where
28 25
29import GHC.TypeLits 26import GHC.TypeLits
30import qualified Numeric.LinearAlgebra.HMatrix as LA 27import qualified Numeric.LinearAlgebra.HMatrix as LA
31import Numeric.LinearAlgebra.HMatrix hiding (konst) 28import Numeric.LinearAlgebra.HMatrix hiding (konst,size)
32import Data.Packed as D 29import Data.Packed as D
33import Data.Packed.ST 30import Data.Packed.ST
34import Data.Proxy(Proxy) 31import Data.Proxy(Proxy)
@@ -83,7 +80,7 @@ ud :: Dim n (Vector t) -> Vector t
83ud (Dim v) = v 80ud (Dim v) = v
84 81
85mkV :: forall (n :: Nat) t . t -> Dim n t 82mkV :: forall (n :: Nat) t . t -> Dim n t
86mkV = Dim 83mkV = Dim
87 84
88 85
89vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) 86vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t)
@@ -92,9 +89,9 @@ vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t)
92 where 89 where
93 du = fromIntegral . natVal $ (undefined :: Proxy n) 90 du = fromIntegral . natVal $ (undefined :: Proxy n)
94 dv = fromIntegral . natVal $ (undefined :: Proxy m) 91 dv = fromIntegral . natVal $ (undefined :: Proxy m)
95 u' | du > 1 && size u == 1 = LA.konst (u D.@> 0) du 92 u' | du > 1 && LA.size u == 1 = LA.konst (u D.@> 0) du
96 | otherwise = u 93 | otherwise = u
97 v' | dv > 1 && size v == 1 = LA.konst (v D.@> 0) dv 94 v' | dv > 1 && LA.size v == 1 = LA.konst (v D.@> 0) dv
98 | otherwise = v 95 | otherwise = v
99 96
100 97
@@ -132,7 +129,7 @@ gvect st xs'
132 | otherwise = abort (show xs) 129 | otherwise = abort (show xs)
133 where 130 where
134 (xs,rest) = splitAt d xs' 131 (xs,rest) = splitAt d xs'
135 ok = size v == d && null rest 132 ok = LA.size v == d && null rest
136 v = LA.fromList xs 133 v = LA.fromList xs
137 d = fromIntegral . natVal $ (undefined :: Proxy n) 134 d = fromIntegral . natVal $ (undefined :: Proxy n)
138 abort info = error $ st++" "++show d++" can't be created from elements "++info 135 abort info = error $ st++" "++show d++" can't be created from elements "++info
@@ -153,7 +150,7 @@ gmat st xs'
153 (xs,rest) = splitAt (m'*n') xs' 150 (xs,rest) = splitAt (m'*n') xs'
154 v = LA.fromList xs 151 v = LA.fromList xs
155 x = reshape n' v 152 x = reshape n' v
156 ok = rem (size v) n' == 0 && size x == (m',n') && null rest 153 ok = rem (LA.size v) n' == 0 && LA.size x == (m',n') && null rest
157 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 154 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
158 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 155 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
159 abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info 156 abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info
@@ -162,66 +159,84 @@ gmat st xs'
162 159
163class Num t => Sized t s d | s -> t, s -> d 160class Num t => Sized t s d | s -> t, s -> d
164 where 161 where
165 konst :: t -> s 162 konst :: t -> s
166 unwrap :: s -> d 163 unwrap :: s -> d t
167 fromList :: [t] -> s 164 fromList :: [t] -> s
168 extract :: s -> d 165 extract :: s -> d t
169 166 create :: d t -> Maybe s
170singleV v = size v == 1 167 size :: s -> IndexOf d
168
169singleV v = LA.size v == 1
171singleM m = rows m == 1 && cols m == 1 170singleM m = rows m == 1 && cols m == 1
172 171
173 172
174instance forall n. KnownNat n => Sized ℂ (C n) (Vector ℂ) 173instance forall n. KnownNat n => Sized ℂ (C n) Vector
175 where 174 where
175 size _ = fromIntegral . natVal $ (undefined :: Proxy n)
176 konst x = mkC (LA.scalar x) 176 konst x = mkC (LA.scalar x)
177 unwrap (C (Dim v)) = v 177 unwrap (C (Dim v)) = v
178 fromList xs = C (gvect "C" xs) 178 fromList xs = C (gvect "C" xs)
179 extract (unwrap -> v) 179 extract s@(unwrap -> v)
180 | singleV v = LA.konst (v!0) d 180 | singleV v = LA.konst (v!0) (size s)
181 | otherwise = v 181 | otherwise = v
182 where 182 create v
183 d = fromIntegral . natVal $ (undefined :: Proxy n) 183 | LA.size v == size r = Just r
184 | otherwise = Nothing
185 where
186 r = mkC v :: C n
184 187
185 188
186instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) 189instance forall n. KnownNat n => Sized ℝ (R n) Vector
187 where 190 where
191 size _ = fromIntegral . natVal $ (undefined :: Proxy n)
188 konst x = mkR (LA.scalar x) 192 konst x = mkR (LA.scalar x)
189 unwrap (R (Dim v)) = v 193 unwrap (R (Dim v)) = v
190 fromList xs = R (gvect "R" xs) 194 fromList xs = R (gvect "R" xs)
191 extract (unwrap -> v) 195 extract s@(unwrap -> v)
192 | singleV v = LA.konst (v!0) d 196 | singleV v = LA.konst (v!0) (size s)
193 | otherwise = v 197 | otherwise = v
194 where 198 create v
195 d = fromIntegral . natVal $ (undefined :: Proxy n) 199 | LA.size v == size r = Just r
200 | otherwise = Nothing
201 where
202 r = mkR v :: R n
196 203
197 204
198 205
199instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) 206instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) Matrix
200 where 207 where
208 size _ = ((fromIntegral . natVal) (undefined :: Proxy m)
209 ,(fromIntegral . natVal) (undefined :: Proxy n))
201 konst x = mkL (LA.scalar x) 210 konst x = mkL (LA.scalar x)
202 fromList xs = L (gmat "L" xs) 211 fromList xs = L (gmat "L" xs)
203 unwrap (L (Dim (Dim m))) = m 212 unwrap (L (Dim (Dim m))) = m
204 extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' 213 extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n'
205 extract (unwrap -> a) 214 extract s@(unwrap -> a)
206 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') 215 | singleM a = LA.konst (a `atIndex` (0,0)) (size s)
207 | otherwise = a 216 | otherwise = a
217 create x
218 | LA.size x == size r = Just r
219 | otherwise = Nothing
208 where 220 where
209 m' = fromIntegral . natVal $ (undefined :: Proxy m) 221 r = mkL x :: L m n
210 n' = fromIntegral . natVal $ (undefined :: Proxy n)
211 222
212 223
213instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) (Matrix ℂ) 224instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) Matrix
214 where 225 where
226 size _ = ((fromIntegral . natVal) (undefined :: Proxy m)
227 ,(fromIntegral . natVal) (undefined :: Proxy n))
215 konst x = mkM (LA.scalar x) 228 konst x = mkM (LA.scalar x)
216 fromList xs = M (gmat "M" xs) 229 fromList xs = M (gmat "M" xs)
217 unwrap (M (Dim (Dim m))) = m 230 unwrap (M (Dim (Dim m))) = m
218 extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n' 231 extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n'
219 extract (unwrap -> a) 232 extract s@(unwrap -> a)
220 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') 233 | singleM a = LA.konst (a `atIndex` (0,0)) (size s)
221 | otherwise = a 234 | otherwise = a
235 create x
236 | LA.size x == size r = Just r
237 | otherwise = Nothing
222 where 238 where
223 m' = fromIntegral . natVal $ (undefined :: Proxy m) 239 r = mkM x :: M m n
224 n' = fromIntegral . natVal $ (undefined :: Proxy n)
225 240
226-------------------------------------------------------------------------------- 241--------------------------------------------------------------------------------
227 242
@@ -254,8 +269,8 @@ isDiagg (Dim (Dim x))
254 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 269 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
255 v = flatten x 270 v = flatten x
256 z = v `atIndex` 0 271 z = v `atIndex` 0
257 y = subVector 1 (size v-1) v 272 y = subVector 1 (LA.size v-1) v
258 ny = size y 273 ny = LA.size y
259 zeros = LA.konst 0 (max 0 (min m' n' - ny)) 274 zeros = LA.konst 0 (max 0 (min m' n' - ny))
260 yz = vjoin [y,zeros] 275 yz = vjoin [y,zeros]
261 276
@@ -263,39 +278,37 @@ isDiagg (Dim (Dim x))
263 278
264instance forall n . KnownNat n => Show (R n) 279instance forall n . KnownNat n => Show (R n)
265 where 280 where
266 show (R (Dim v)) 281 show s@(R (Dim v))
267 | singleV v = "("++show (v!0)++" :: R "++show d++")" 282 | singleV v = "("++show (v!0)++" :: R "++show d++")"
268 | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")" 283 | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")"
269 where 284 where
270 d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 285 d = size s
271 286
272instance forall n . KnownNat n => Show (C n) 287instance forall n . KnownNat n => Show (C n)
273 where 288 where
274 show (C (Dim v)) 289 show s@(C (Dim v))
275 | singleV v = "("++show (v!0)++" :: C "++show d++")" 290 | singleV v = "("++show (v!0)++" :: C "++show d++")"
276 | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")" 291 | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")"
277 where 292 where
278 d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 293 d = size s
279 294
280instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) 295instance forall m n . (KnownNat m, KnownNat n) => Show (L m n)
281 where 296 where
282 show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' 297 show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n'
283 show (L (Dim (Dim x))) 298 show s@(L (Dim (Dim x)))
284 | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' 299 | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n'
285 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" 300 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")"
286 where 301 where
287 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 302 (m',n') = size s
288 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
289 303
290instance forall m n . (KnownNat m, KnownNat n) => Show (M m n) 304instance forall m n . (KnownNat m, KnownNat n) => Show (M m n)
291 where 305 where
292 show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n' 306 show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n'
293 show (M (Dim (Dim x))) 307 show s@(M (Dim (Dim x)))
294 | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n' 308 | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n'
295 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")" 309 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")"
296 where 310 where
297 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 311 (m',n') = size s
298 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
299 312
300-------------------------------------------------------------------------------- 313--------------------------------------------------------------------------------
301 314