summary | refs | log | tree | commit | diff
path: root/examples/pca2.hs
diff options
context:
space:
mode:
author Alberto Ruiz <aruiz@um.es> 2010-02-03 17:53:51 +0000
committer Alberto Ruiz <aruiz@um.es> 2010-02-03 17:53:51 +0000
commit 3f5bf5985d3da0e4d01cd9c126cb781cb6fc28ef (patch)
tree f4ff2376e550668e19912b96e3ad88979a52cfb3 /examples/pca2.hs
parent ab5a2c1c8b4ecd7f8d9d9823d9976410aa94bb18 (diff)
updated examples, removed Util module
Diffstat (limited to 'examples/pca2.hs')
-rw-r--r-- examples/pca2.hs 28
1 files changed, 13 insertions, 15 deletions
diff --git a/examples/pca2.hs b/examples/pca2.hs
index 8c20370..c38857c 100644
--- a/examples/pca2.hs
+++ b/examples/pca2.hs
@@ -9,33 +9,31 @@ import Control.Monad(when)
9type Vec = Vector Double 9type Vec = Vector Double
10type Mat = Matrix Double 10type Mat = Matrix Double
11 11
12sumColumns m = constant 1 (rows m) <> m 12-- Vector with the mean value of the columns of a matrix
13mean a = constant (recip . fromIntegral . rows $ a) (rows a) <> a
13 14
14-- Vector with the mean value of the columns of a Mat 15-- covariance matrix of a list of observations stored as rows
15mean x = sumColumns x / fromIntegral (rows x) 16cov x = (trans xc <> xc) / fromIntegral (rows x - 1)
17 where xc = x - asRow (mean x)
16 18
17-- covariance matrix of a list of observations as rows of a matrix
18cov x = (trans xc <> xc) / fromIntegral (rows x -1)
19 where xc = center x
20 center m = m - constant 1 (rows m) `outer` mean m
21 19
22type Stat = (Vec, [Double], Mat) 20type Stat = (Vec, [Double], Mat)
23-- 1st and 2nd order statistics of a dataset (mean, eigenvalues and eigenvectors of cov) 21-- 1st and 2nd order statistics of a dataset (mean, eigenvalues and eigenvectors of cov)
24stat :: Mat -> Stat 22stat :: Mat -> Stat
25stat x = (m, toList s, trans v) where 23stat x = (m, toList s, trans v) where
26 m = mean x 24 m = mean x
27 (s,v) = eigSH' (cov x) 25 (s,v) = eigSH' (cov x)
28 26
29-- creates the compression and decompression functions from the desired reconstruction 27-- creates the compression and decompression functions from the desired reconstruction
30-- quality and the statistics of a data set 28-- quality and the statistics of a data set
31pca :: Double -> Stat -> (Vec -> Vec , Vec -> Vec) 29pca :: Double -> Stat -> (Vec -> Vec , Vec -> Vec)
32pca prec (m,s,v) = (encode,decode) 30pca prec (m,s,v) = (encode,decode)
33 where 31 where
34 encode x = vp <> (x - m) 32 encode x = vp <> (x - m)
35 decode x = x <> vp + m 33 decode x = x <> vp + m
36 vp = takeRows n v 34 vp = takeRows n v
37 n = 1 + (length $ fst $ span (< (prec'*sum s)) $ cumSum s) 35 n = 1 + (length $ fst $ span (< (prec'*sum s)) $ cumSum s)
38 cumSum = tail . scanl (+) 0.0 36 cumSum = tail . scanl (+) 0.0
39 prec' = if prec <=0.0 || prec >= 1.0 37 prec' = if prec <=0.0 || prec >= 1.0
40 then error "the precision in pca must be 0<prec<1" 38 then error "the precision in pca must be 0<prec<1"
41 else prec 39 else prec
@@ -49,7 +47,7 @@ test st prec x = do
49 let (pe,pd) = pca prec st 47 let (pe,pd) = pca prec st
50 let y = pe x 48 let y = pe x
51 print $ dim y 49 print $ dim y
52 shdigit (pd y) 50 shdigit (pd y)
53 51
54main = do 52main = do
55 ok <- doesFileExist ("mnist.txt") 53 ok <- doesFileExist ("mnist.txt")
@@ -58,7 +56,7 @@ main = do
58 system("wget -nv http://dis.um.es/~alberto/material/sp/mnist.txt.gz") 56 system("wget -nv http://dis.um.es/~alberto/material/sp/mnist.txt.gz")
59 system("gunzip mnist.txt.gz") 57 system("gunzip mnist.txt.gz")
60 return () 58 return ()
61 m <- fromFile "mnist.txt" (5000,785) 59 m <- loadMatrix "mnist.txt"
62 let xs = takeColumns (cols m -1) m 60 let xs = takeColumns (cols m -1) m
63 let x = toRows xs !! 4 -- an arbitrary test vector 61 let x = toRows xs !! 4 -- an arbitrary test vector
64 shdigit x 62 shdigit x