Non-neural MNIST benchmarks

Introduction

I've mentioned a couple of times that I'd like to try some non-neural approaches to classifying the MNIST digits, so that I can have a better idea of how good the neural networks (and other machine learning algorithms, if I ever work with them) really are. This is just me playing around: I'm not pretending that my results here are world-class classification benchmarks; they're just what I might have come up with if I'd been given the problem of classifying the MNIST digits and didn't know where to start.

Having read through to chapter 3 of Michael Nielsen's book, I've got my best neural network to 98.1% accuracy on MNIST. Earlier, in chapter 1, Nielsen had written:

It's not difficult to find other ideas which achieve accuracies in the 20 to 50 percent range. If you work a bit harder you can get up over 50 percent. But to get much higher accuracies it helps to use established machine learning algorithms.

I think that either this is underselling how easy the MNIST dataset is to classify, or "machine learning algorithms" covers a broader array of pretty simple techniques than I'd thought (the Wikipedia machine learning template includes "linear regression", so maybe it's the latter). Dotting the test image with the means of each digit's training images gets me to 82%. With some more careful thought about PCA, I get 94%. With k-nearest-neighbours (which does look like an "official" machine learning algorithm, albeit a very simple one) I get over 97%. So, at least on MNIST, the advantage of neural networks for classification lies in squeezing as much as possible out of the data, rather than in being enormously better than simple alternatives.

Comparing test images to digit means

This is a very simple algorithm:

1. For each digit 0-9, treat its training images as vectors of length 784 and average them to get a mean digit vector.
2. Normalise each mean vector to unit length.
3. Normalise the test image to unit length, and dot it with each of the ten mean vectors.
4. Classify the test image as the digit whose mean vector gives the largest dot product.
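
In symbols (notation mine, not from the code below): writing \( \mu_d \) for the mean of digit \( d \)'s training images and \( T \) for the test image, the classification is

\begin{equation*} \hat{d} = \arg\max_d \frac{\mu_d \cdot T}{\| \mu_d \| \, \| T \|}. \end{equation*}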

Vectors should be normalised, so as not to favour digits with more light pixels (forgetting to do this left me at about 60% accuracy, and therefore thinking that my then-weird PCA algorithm was good for getting 70%). On the test set it gets 8236/10000 correct. Jittered scatterplot of test answers against true answers:

R code:

library(ggplot2)

load("../digit_data/r_images_nielsen.rda")

digit_image_matrices = list()
mean_digits = list()
length(digit_image_matrices) = 10
length(mean_digits) = 10

# Training
for (i in 0:9) {
  this_i = which(digits$training == i)
  num_digits = length(this_i)
  
  digit_image_matrices[[i + 1]] = matrix(0, nrow=num_digits, ncol=784)
  for (j in 1:num_digits) {
    digit_image_matrices[[i + 1]][j, ] = images$training[[this_i[j]]]
  }
  
  mean_digits[[i + 1]] = colMeans(digit_image_matrices[[i + 1]])
  
  # Normalise the mean image to unit length, so that digits with more light
  # pixels aren't favoured.
  this_norm = sqrt(sum(mean_digits[[i + 1]] * mean_digits[[i + 1]]))
  mean_digits[[i + 1]] = mean_digits[[i + 1]] / this_norm
}

# Testing
test_answers = list()
num_test_images = length(digits$test)
length(test_answers) = num_test_images
test_answers_vec = numeric(num_test_images)

for (i in 1:num_test_images) {
  test_answers[[i]] = numeric(10)
  
  # Normalise the test image to unit length, as with the means above.
  this_img = images$test[[i]]
  this_img = this_img / sqrt(sum(this_img * this_img))
  
  # Dot the normalised test image with each digit's normalised mean.
  for (j in 1:10) {
    test_answers[[i]][j] = sum(mean_digits[[j]] * this_img)
  }
  
  test_answers[[i]] = test_answers[[i]] / sum(test_answers[[i]])
  test_answers_vec[i] = which.max(test_answers[[i]]) - 1
}

true_answers = digits$test

qplot(true_answers, test_answers_vec, alpha=.1, position=position_jitter()) + guides(alpha=FALSE)
print(length(which(test_answers_vec == true_answers)))

PCA

My first attempt at a PCA-based classification method felt wrong, but it got me to what I then thought was a respectable 70%. After working through the theory of PCA carefully (see my notes on the subject), I came up with a better way of using principal components to classify digits, which brings the accuracy to over 94%.

The logic behind the method goes like this. Consider all the '0' digit images, treat them as vectors of length 784, and put them in a big matrix with 784 columns. Calculate the first few principal components of this data matrix – (too?) loosely speaking, these should correspond to the main ways in which the digit is drawn. Repeat for the other digits, so that there is now a collection of each digit's first few principal components.

Given a test image, the algorithm goes:

1. For each digit, subtract that digit's training mean from the test image.
2. Project the centred image onto the digit's first few principal components.
3. Back-transform to the original 784-dimensional space, giving a reconstruction of the image.
4. Classify the test image as the digit whose reconstruction error is smallest.
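
In symbols (again, notation mine): if \( \mu_d \) is digit \( d \)'s training mean and \( V_d \) is the \( 784 \times m \) matrix whose columns are the first \( m \) principal components of digit \( d \), then the score for digit \( d \) is the reconstruction error

\begin{equation*} e_d = \left\| (T - \mu_d) - V_d V_d^\top (T - \mu_d) \right\|_1, \end{equation*}

and the classification is the digit with the smallest \( e_d \); the code below measures the error as the sum of absolute pixel differences, i.e. a 1-norm.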

The question is how many principal components to keep. Using all 784 is obviously wrong, since every digit's full set of factors will perfectly transform and back-transform any test image to itself. Using only the first factor will miss lots of variations on the digit (for example, if it's shifted a little from where the first factor says it should be). After a bit of playing around (a sketch of the sweep follows the code below), I found that about 30 factors gave near-optimal results, with 9462 out of the 10000 test images classified correctly. Jittered scatterplot:

R code (sorry for the inconsistency between lists and matrices):

library(ggplot2)

load("../digit_data/r_images_nielsen.rda")

digit_image_matrices = list()
digit_pca = list()
digit_means = list()
length(digit_image_matrices) = 10
length(digit_pca) = 10
length(digit_means) = 10

# Training
for (i in 0:9) {
  print(i)
  this_i = which(digits$training == i)
  num_digits = length(this_i)
  
  digit_image_matrices[[i+1]] = matrix(0, nrow=num_digits, ncol=784)
  for (j in 1:num_digits) {
    digit_image_matrices[[i+1]][j, ] = images$training[[this_i[j]]]
  }
  
  digit_pca[[i+1]]$pca = prcomp(digit_image_matrices[[i+1]], center=TRUE)
}


for (i in 0:9) {
  # Calculate means by digit for use in centring the test images.
  digit_means[[i + 1]] = colMeans(digit_image_matrices[[i + 1]])
}

# Testing

# Number of principal components to use:
num_evecs = 30

num_test = length(digits$test)
test_image_matrix = matrix(0, nrow=num_test, ncol=784)

for (j in 1:num_test) {
  test_image_matrix[j, ] = images$test[[j]]
}

test_answers = matrix(0, nrow=num_test, ncol=10)
test_answers_vec = numeric(num_test)

for (i in 0:9) {
  print(i)
  this_centred_matrix = test_image_matrix
  
  for (j in 1:num_test) {
    this_centred_matrix[j, ] = this_centred_matrix[j, ] - digit_means[[i + 1]]
  }
  
  # Convert the images to the first num_evecs factors, then convert back and
  # see how good the reproduction of the original image is.
  reduced_images = t(digit_pca[[i + 1]]$pca$rotation[ , 1:num_evecs]) %*% t(this_centred_matrix)
  back_transf_images = t(digit_pca[[i + 1]]$pca$rotation[ , 1:num_evecs] %*% reduced_images)
  
  # Sum of absolute pixel differences: an L1 reconstruction error.
  comparison = abs(this_centred_matrix - back_transf_images)
  test_answers[ , i + 1] = rowSums(comparison)
}

# which.min rather than which(x == min(x)), so that a tie can't make apply()
# return a list instead of a vector.
test_answers_vec = apply(test_answers, 1, function(x) which.min(x) - 1)
true_answers = digits$test[1:num_test]

qplot(true_answers, test_answers_vec, alpha=.1, position=position_jitter()) + guides(alpha=FALSE)
print(length(which(test_answers_vec == true_answers)))
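
As for choosing num_evecs: here's a minimal sketch of the sweep, assuming the training and testing code above has been wrapped in a hypothetical function pca_accuracy(num_evecs) that returns the number of test images classified correctly:

# pca_accuracy() is a hypothetical wrapper around the testing loop above; only
# num_evecs changes between runs, so the PCA itself need only be done once.
for (m in c(10, 20, 30, 40, 50)) {
  print(sprintf("num_evecs = %d: correct = %d", m, pca_accuracy(m)))
}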

k-nearest neighbours

(I didn't know about this algorithm ahead of time, but I saw it in the Wikipedia article on MNIST, and it looked pretty simple.)

Consider each training image \( T^i \) as a vector of length 784, with j-th entry \( T^i_j \). Define the distance from \( T^i \) to a test image \( T \) as

\begin{equation*} d^i = \sqrt{\sum_{j=1}^{784} (T^i_j - T_j)^2}. \end{equation*}

The nearest neighbour to \( T \) across the training images is simply the image \( T^i \) with the smallest distance \( d^i \). The simplest implementation of k-nearest neighbours, "1-nearest neighbour", just returns the digit of this \( T^i \); more generally, for \( k > 1 \), the classification is the most common digit among the \( k \) nearest neighbours in the training set.

The class R package implements k-NN and works pretty much straight out of the box. It classified the digits so well that I thought I might have made a mistake and accidentally fed the true answers into the system somewhere, so I wrote my own implementation of k-NN just to make sure it was all as simple and effective as it seemed; my Rcpp code follows shortly.
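
For reference, a minimal sketch of the class-package version (assuming the training and test matrices are built one image per row, as at the start of the script below):

library(class)

# knn() takes the training matrix, the test matrix, the training labels as a
# factor, and k, and returns a factor of predicted labels.
predictions = knn(train = training_image_matrix,
                  test = test_image_matrix,
                  cl = factor(digits$training),
                  k = 3)

# Convert the predicted factor levels back to digits before scoring.
print(length(which(as.integer(as.character(predictions)) == digits$test)))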

Wikipedia says that k-NN can struggle with high-dimensional data sets, but the 784 dimensions of MNIST posed no problems, and plain 1-nearest-neighbour got 96.7% of the test images correct. Nevertheless, the algorithm is both much faster and slightly more accurate when working in reduced dimensions: using the first 30 principal components and \( k = 3 \) brings the score up to 9738/10000. With my code, it took about 25 minutes to work through the classifications with 784-length vectors, and about half a minute with 30-length vectors. Jittered scatterplot:

(Unlike in the previous section, the principal components are calculated over the full training dataset, and not separately by digit.)

As in a previous post, I use Romain François's code to return sorted indices for me. The main function to be called from R is k_nearest_classif(), and it is less than 40 lines long. Rcpp:

#include <Rcpp.h>
#include <queue>
using namespace Rcpp;

// http://gallery.rcpp.org/articles/top-elements-from-vectors-using-priority-queue/
template <int RTYPE>
class IndexComparator {
public:
  typedef typename Rcpp::traits::storage_type<RTYPE>::type STORAGE ;
  
  IndexComparator( const Vector<RTYPE>& data_ ) : data(data_.begin()){}
  
  inline bool operator()(int i, int j) const {
    return data[i] > data[j] || (data[i] == data[j] && j > i ) ;    
  }
  
private:
  const STORAGE* data ;
} ;

template <>
class IndexComparator<STRSXP> {
public:
  IndexComparator( const CharacterVector& data_ ) : data(data_.begin()){}
  
  inline bool operator()(int i, int j) const {
    return (String)data[i] > (String)data[j] || (data[i] == data[j] && j > i );    
  }
  
private:
  const SEXP* data ;
} ;

template <int RTYPE>
class IndexQueue {
public:
  typedef std::priority_queue<int, std::vector<int>, IndexComparator<RTYPE> > Queue ;
  
  IndexQueue( const Vector<RTYPE>& data_ ) : comparator(data_), q(comparator), data(data_) {}
  
  inline operator IntegerVector(){
    int n = q.size() ;
    IntegerVector res(n) ;
    for( int i=0; i<n; i++){
      // The gallery code added +1 here for R's 1-based indexing; I've removed
      // it, since these indices are used from C++ (0-based). -- DB.
      res[i] = q.top();
      q.pop() ;
    }
    return res ;
  }
  inline void input( int i){ 
    // if( data[ q.top() ] < data[i] ){
    if( comparator(i, q.top() ) ){
      q.pop(); 
      q.push(i) ;    
    }
  }
  inline void pop(){ q.pop() ; }
  inline void push( int i){ q.push(i) ; }
  
private:
  IndexComparator<RTYPE> comparator ;
  Queue q ;  
  const Vector<RTYPE>& data ;
} ;


template <int RTYPE>
IntegerVector top_index(Vector<RTYPE> v, int n){
  int size = v.size() ;
  
  // Uninteresting case: fewer data than n, so return all the indices.
  if( size < n){
    return seq( 0, size-1 ) ;
  }
  
  IndexQueue<RTYPE> q( v )  ;
  for( int i=0; i<n; i++) q.push(i) ;
  for( int i=n; i<size; i++) q.input(i) ;   
  return q ;
}

// [[Rcpp::export]]
IntegerVector top_index( SEXP x, int n){
  switch( TYPEOF(x) ){
  case INTSXP: return top_index<INTSXP>( x, n ) ;
  case REALSXP: return top_index<REALSXP>( x, n ) ;
  case STRSXP: return top_index<STRSXP>( x, n ) ;
  default: stop("type not handled") ; 
  }
  return IntegerVector() ; // not used
}

// [[Rcpp::export]]
IntegerVector k_nearest_classif(NumericMatrix training, IntegerVector training_outputs, NumericMatrix test,
                                int k = 1) {
  long i, j, ct;
  long num_training = training_outputs.size();
  long num_test = test.nrow();
  long num_pixels = test.ncol();
  NumericVector distances(num_training);
  
  long min_class = min(training_outputs);
  long max_class = max(training_outputs);
  long num_classes = max_class - min_class + 1;
  
  IntegerVector nearest(k);
  IntegerVector classifs(num_test);
  
  for (i = 0; i < num_test; i++) {
    // Calculate distances to training images. The *negative* squared distance
    // is accumulated, so that top_index(), which finds the largest elements,
    // returns the nearest neighbours; the square root is skipped because it
    // doesn't change the ordering.
    for (j = 0; j < num_training; j++) {
      distances[j] = 0.0;
      for (ct = 0; ct < num_pixels; ct++) {
        distances[j] -= (training(j, ct) - test(i, ct)) * (training(j, ct) - test(i, ct));
      }
    }
    
    // Find k nearest:
    nearest = top_index(distances, k);
    IntegerVector classes(num_classes);
    
    // Tally the digit classes of the k nearest neighbours; the most common
    // class wins.
    for (j = 0; j < k; j++) {
      classes[training_outputs[nearest[j]] - min_class] += 1;
    }

    IntegerVector this_classif_i = top_index(classes, 1);
    classifs[i] = this_classif_i[0] + min_class;
  }
  
  return classifs;
}

R code (the PCA is done manually because prcomp() ate up too much memory when I first tried to get this to work, possibly because I was using it incorrectly):

library(Rcpp)
library(ggplot2)
library(EBImage)

load("../digit_data/r_images_nielsen.rda")
sourceCpp("knn_by_me.cpp")

num_training = length(digits$training)
num_test = length(digits$test)

training_image_matrix = matrix(0, nrow=num_training, ncol=784)
test_image_matrix = matrix(0, nrow=num_test, ncol=784)

# Training

for (j in 1:num_training) {
  training_image_matrix[j, ] = images$training[[j]]
}

# Doing the PCA manually because prcomp soaked up too much memory (?? -- maybe I did something wrong with it).

digit_means = colMeans(training_image_matrix)

for (j in 1:num_training) {
  training_image_matrix[j, ] = training_image_matrix[j, ] - digit_means
}

for (j in 1:num_test) {
  test_image_matrix[j, ] = images$test[[j]] - digit_means
}

var_covar = cov(training_image_matrix)

# Remove zero-variance pixels (not needed unless the variances are normalised)
nonzero_i = which(diag(var_covar) > 0)
var_covar = var_covar[nonzero_i, nonzero_i]
training_image_matrix = training_image_matrix[ , nonzero_i]
test_image_matrix = test_image_matrix[ , nonzero_i]

# Usually you'd want to normalise the variances, but I get very marginally
# worse results when I do this:
# 
# pixel_stdevs = sqrt(diag(var_covar))  # var_covar is already restricted to nonzero_i
# 
# for (i in 1:num_training) {
#   training_image_matrix[i, ] = training_image_matrix[i, ] / pixel_stdevs
# }
# 
# for (i in 1:num_test) {
#   test_image_matrix[i, ] = test_image_matrix[i, ] / pixel_stdevs
# }

eigenthings = eigen(var_covar)

new_dim = 30
reduced_training_matrix = training_image_matrix %*% eigenthings$vectors[ , 1:new_dim]
reduced_test_matrix = test_image_matrix %*% eigenthings$vectors[ , 1:new_dim]

test_answers_vec = k_nearest_classif(reduced_training_matrix, digits$training, reduced_test_matrix, k=3)
true_answers = digits$test

qplot(true_answers, test_answers_vec, alpha=.1, position=position_jitter()) + guides(alpha=FALSE)
print(sprintf("score = %d", length(which(test_answers_vec == digits$test))))

Conclusions

MNIST digits are pretty easy to classify, and neural networks aren't a lot better at it than much less sophisticated algorithms. As a dataset to learn on, MNIST has the very nice property that it's small and doesn't take too long to process, but things will certainly be more interesting when I branch out into new territory, instead of trying to marginally improve the second or third significant figure of the number of MNIST digits correctly classified.

Posted 2015-11-17.

