diff options
Diffstat (limited to 'examples/pca2.hs')
-rw-r--r-- | examples/pca2.hs | 28 |
1 files changed, 13 insertions, 15 deletions
diff --git a/examples/pca2.hs b/examples/pca2.hs index 8c20370..c38857c 100644 --- a/examples/pca2.hs +++ b/examples/pca2.hs | |||
@@ -9,33 +9,31 @@ import Control.Monad(when) | |||
9 | type Vec = Vector Double | 9 | type Vec = Vector Double |
10 | type Mat = Matrix Double | 10 | type Mat = Matrix Double |
11 | 11 | ||
12 | sumColumns m = constant 1 (rows m) <> m | 12 | -- Vector with the mean value of the columns of a matrix |
13 | mean a = constant (recip . fromIntegral . rows $ a) (rows a) <> a | ||
13 | 14 | ||
14 | -- Vector with the mean value of the columns of a Mat | 15 | -- covariance matrix of a list of observations stored as rows |
15 | mean x = sumColumns x / fromIntegral (rows x) | 16 | cov x = (trans xc <> xc) / fromIntegral (rows x - 1) |
17 | where xc = x - asRow (mean x) | ||
16 | 18 | ||
17 | -- covariance matrix of a list of observations as rows of a matrix | ||
18 | cov x = (trans xc <> xc) / fromIntegral (rows x -1) | ||
19 | where xc = center x | ||
20 | center m = m - constant 1 (rows m) `outer` mean m | ||
21 | 19 | ||
22 | type Stat = (Vec, [Double], Mat) | 20 | type Stat = (Vec, [Double], Mat) |
23 | -- 1st and 2nd order statistics of a dataset (mean, eigenvalues and eigenvectors of cov) | 21 | -- 1st and 2nd order statistics of a dataset (mean, eigenvalues and eigenvectors of cov) |
24 | stat :: Mat -> Stat | 22 | stat :: Mat -> Stat |
25 | stat x = (m, toList s, trans v) where | 23 | stat x = (m, toList s, trans v) where |
26 | m = mean x | 24 | m = mean x |
27 | (s,v) = eigSH' (cov x) | 25 | (s,v) = eigSH' (cov x) |
28 | 26 | ||
29 | -- creates the compression and decompression functions from the desired reconstruction | 27 | -- creates the compression and decompression functions from the desired reconstruction |
30 | -- quality and the statistics of a data set | 28 | -- quality and the statistics of a data set |
31 | pca :: Double -> Stat -> (Vec -> Vec , Vec -> Vec) | 29 | pca :: Double -> Stat -> (Vec -> Vec , Vec -> Vec) |
32 | pca prec (m,s,v) = (encode,decode) | 30 | pca prec (m,s,v) = (encode,decode) |
33 | where | 31 | where |
34 | encode x = vp <> (x - m) | 32 | encode x = vp <> (x - m) |
35 | decode x = x <> vp + m | 33 | decode x = x <> vp + m |
36 | vp = takeRows n v | 34 | vp = takeRows n v |
37 | n = 1 + (length $ fst $ span (< (prec'*sum s)) $ cumSum s) | 35 | n = 1 + (length $ fst $ span (< (prec'*sum s)) $ cumSum s) |
38 | cumSum = tail . scanl (+) 0.0 | 36 | cumSum = tail . scanl (+) 0.0 |
39 | prec' = if prec <=0.0 || prec >= 1.0 | 37 | prec' = if prec <=0.0 || prec >= 1.0 |
40 | then error "the precision in pca must be 0<prec<1" | 38 | then error "the precision in pca must be 0<prec<1" |
41 | else prec | 39 | else prec |
@@ -49,7 +47,7 @@ test st prec x = do | |||
49 | let (pe,pd) = pca prec st | 47 | let (pe,pd) = pca prec st |
50 | let y = pe x | 48 | let y = pe x |
51 | print $ dim y | 49 | print $ dim y |
52 | shdigit (pd y) | 50 | shdigit (pd y) |
53 | 51 | ||
54 | main = do | 52 | main = do |
55 | ok <- doesFileExist ("mnist.txt") | 53 | ok <- doesFileExist ("mnist.txt") |
@@ -58,7 +56,7 @@ main = do | |||
58 | system("wget -nv http://dis.um.es/~alberto/material/sp/mnist.txt.gz") | 56 | system("wget -nv http://dis.um.es/~alberto/material/sp/mnist.txt.gz") |
59 | system("gunzip mnist.txt.gz") | 57 | system("gunzip mnist.txt.gz") |
60 | return () | 58 | return () |
61 | m <- fromFile "mnist.txt" (5000,785) | 59 | m <- loadMatrix "mnist.txt" |
62 | let xs = takeColumns (cols m -1) m | 60 | let xs = takeColumns (cols m -1) m |
63 | let x = toRows xs !! 4 -- an arbitrary test vector | 61 | let x = toRows xs !! 4 -- an arbitrary test vector |
64 | shdigit x | 62 | shdigit x |