Non-neural MNIST benchmarks
Introduction
I've mentioned a couple of times that I'd like to try some non-neural approaches to classifying the MNIST digits, so that I can have a better
idea of how good the neural networks (and other machine learning algorithms, if I ever work with them) really are. This is just me: I'm not
pretending that my results here are world-class classification benchmarks. They're just what I might have come up with if I'd been given the problem
of classifying the MNIST digits and didn't know where to start.
After reading through to chapter 3 of Michael Nielsen's book, I've got my best neural network
to 98.1% accuracy on MNIST. Earlier, in chapter 1, Nielsen had written:
"It's not difficult to find other ideas which achieve accuracies in the 20 to 50 percent range. If you work a bit harder you can get up
over 50 percent. But to get much higher accuracies it helps to use established machine learning algorithms."
I think that either this undersells how easy the MNIST dataset is to classify, or "machine learning algorithms" include a broader array of
pretty simple techniques than I'd thought (the Wikipedia machine learning template includes "linear regression", so maybe it's the latter). Taking
the dot product of the test image with the mean of each digit's training images gets me to 82%. With some more careful thought about PCA, I get 94%. With
k-nearest-neighbours (which does look like an "official" machine learning algorithm, albeit a very simple one) I get over 97%. So, at least on MNIST,
the advantage of neural networks for classification is a matter of squeezing the last percentage point or so out of the data, rather than being
enormously better than simple alternatives.
Comparing test images to digit means
This is a very simple algorithm:
- Treat each 28 × 28 image as a 784-length vector.
- For each digit 0-9, calculate the mean vector from the training set.
- For each test image, find the dot product with each of the training means.
- Return the digit with the highest dot product.
Both the mean vectors and the test images should be normalised to unit length, so as not to favour digits with more light pixels (forgetting to do
this led to me getting about 60% accuracy and therefore thinking that my then-weird PCA algorithm was good for getting 70%). On the test set the
algorithm gets 8236/10000 correct. Jittered scatterplot of test answers against true answers:
[Figure: jittered scatterplot of test answers against true answers.]
R code:
library(ggplot2)
load("../digit_data/r_images_nielsen.rda")
digit_image_matrices = list()
mean_digits = list()
length(digit_image_matrices) = 10
length(mean_digits) = 10
# Training: build one unit-length mean image per digit.
for (i in 0:9) {
  this_i = which(digits$training == i)
  num_digits = length(this_i)
  digit_image_matrices[[i + 1]] = matrix(0, nrow=num_digits, ncol=784)
  for (j in 1:num_digits) {
    digit_image_matrices[[i + 1]][j, ] = images$training[[this_i[j]]]
  }
  mean_digits[[i + 1]] = colMeans(digit_image_matrices[[i + 1]])
  this_norm = sqrt(sum(mean_digits[[i + 1]] * mean_digits[[i + 1]]))
  mean_digits[[i + 1]] = mean_digits[[i + 1]] / this_norm
}
# Testing
test_answers = list()
num_test_images = length(digits$test)
length(test_answers) = num_test_images
test_answers_vec = numeric(num_test_images)
for (i in 1:num_test_images) {
  test_answers[[i]] = numeric(10)
  # Normalise the test image once, outside the loop over digits.
  this_img = images$test[[i]]
  this_norm = sqrt(sum(this_img*this_img))
  this_img = this_img / this_norm
  for (j in 1:10) {
    test_answers[[i]][j] = sum(mean_digits[[j]] * this_img)
  }
  test_answers[[i]] = test_answers[[i]] / sum(test_answers[[i]])
  # which.max() returns the first maximum; subtract 1 to map index 1-10 to digit 0-9.
  test_answers_vec[i] = which.max(test_answers[[i]]) - 1
}
true_answers = digits$test
qplot(true_answers, test_answers_vec, alpha=.1, position=position_jitter()) + guides(alpha="none")
print(length(which(test_answers_vec == true_answers)))
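The same idea can be sketched more compactly with matrix operations. This is only an illustration on made-up toy data: the 4-pixel "images" and the names `prototypes`, `make_images` and `normalise_rows` are invented for the sketch and are not part of the MNIST code above.

```r
# Toy stand-in for MNIST: 3 classes of 4-pixel "images" (hypothetical data).
set.seed(1)
prototypes = rbind(c(1, 0, 0, 0.2),
                   c(0, 1, 0, 0.2),
                   c(0, 0, 1, 0.2))
make_images = function(labels) {
  prototypes[labels + 1, , drop=FALSE] +
    matrix(abs(rnorm(length(labels) * 4, sd=0.1)), ncol=4)
}
train_labels = rep(0:2, each=20)
test_labels = rep(0:2, times=5)
train = make_images(train_labels)
test = make_images(test_labels)

# Scale each row of a matrix to unit Euclidean length.
normalise_rows = function(m) m / sqrt(rowSums(m^2))

# One normalised mean image per class (rows ordered by class label).
class_means = normalise_rows(rowsum(train, train_labels) / as.vector(table(train_labels)))

# Dot product of every test image with every class mean, in one matrix product.
scores = normalise_rows(test) %*% t(class_means)

# Highest dot product wins; max.col() gives a 1-based column index.
predicted = max.col(scores, ties.method="first") - 1
print(sum(predicted == test_labels))
```

With both sides normalised, the dot product is the cosine of the angle between the test image and the class mean, which is why overall brightness no longer matters.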
PCA
My first attempt at a PCA-based classification method felt wrong but got me to what I then thought was a respectable 70%. After working through
the theory of PCA carefully (see my notes on the subject), I came up with a better way of using principal components to
classify digits, with accuracy of over 94%.
The logic behind the method goes like this. Consider all the '0' digit images, treat them as vectors of length 784, and put them in a big
matrix with 784 columns. Calculate the first few principal components of this data matrix – (too?) loosely speaking, these should correspond to
the main ways in which the digit is drawn. Repeat for the other digits, so that there is now a collection of each digit's first few
principal components.
Given a test image, the algorithm goes:
- For each digit 0-9:
  - Convert the test image to the digit's first few factors.
  - Back-transform from these factors to a 784-length vector representing an approximation to the original image.
  - Calculate a score based on how much the back-transformed image differs from the original.
- The digit with the best score (the smallest difference) is the returned classification.
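The steps above can be sketched on a tiny made-up example. Everything here is an assumption for illustration: two classes of 6-dimensional vectors, one principal component per class, and the helper names `make_class`, `reconstruction_error` and `classify`.

```r
# Hypothetical toy data: two classes of 6-dimensional "images", each varying
# mainly along its own direction around its own centre.
set.seed(2)
make_class = function(n, centre, direction) {
  t(centre + direction %o% rnorm(n)) + matrix(rnorm(n * 6, sd=0.05), ncol=6)
}
train = list(make_class(100, c(1, 1, 0, 0, 0, 0), c(1, -1, 0, 0, 0, 0)),
             make_class(100, c(0, 0, 0, 0, 1, 1), c(0, 0, 0, 0, 1, -1)))

# Fit one PCA per class, keeping its mean and its first num_evecs components.
num_evecs = 1
models = lapply(train, function(m) {
  list(centre=colMeans(m),
       basis=prcomp(m, center=TRUE)$rotation[ , 1:num_evecs, drop=FALSE])
})

# Score: total absolute difference between the centred image and its
# reconstruction from the class's leading components (lower is better).
reconstruction_error = function(image, model) {
  centred = image - model$centre
  factors = t(model$basis) %*% centred
  back_transformed = model$basis %*% factors
  sum(abs(centred - back_transformed))
}

# Returns the index (1 or 2) of the class that reconstructs the image best.
classify = function(image) {
  which.min(sapply(models, reconstruction_error, image=image))
}

# A vector displaced from the first class's centre along that class's direction.
test_image = c(1, 1, 0, 0, 0, 0) + 0.8 * c(1, -1, 0, 0, 0, 0)
print(classify(test_image))
```

The test vector lies along the first class's main direction of variation, so that class's components reproduce it almost exactly while the other class's do not.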
The question is how many principal components to keep. Using all 784 is obviously wrong, since every digit's full set of factors will perfectly
transform and back-transform any test image to itself. Using only the first factor will miss lots of variations on the digit (for example, if it's
shifted a little from where the first factor says it should be). After a bit of playing around, I found that about 30 factors gave near-optimal
results, with 9462 out of the 10000 test images classified correctly. Jittered scatterplot:
[Figure: jittered scatterplot of test answers against true answers.]
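To see why keeping every component is useless, here is a small sketch on made-up 5-dimensional data (the data, the `outsider` vector and the helper name are all invented for illustration). The reconstruction error for a vector that looks nothing like the training data should fall to essentially zero once all components are used.

```r
set.seed(3)
# Hypothetical training data for one class: 200 samples in 5 dimensions.
train = matrix(rnorm(200 * 5), ncol=5) %*% diag(c(3, 2, 1, 0.5, 0.1))
pca = prcomp(train, center=TRUE)
centre = colMeans(train)

# An arbitrary vector that looks nothing like the training data.
outsider = c(10, -7, 4, 0, 25)

# Sum of absolute differences between the centred vector and its
# reconstruction from the first num_evecs principal components.
reconstruction_error = function(image, num_evecs) {
  basis = pca$rotation[ , 1:num_evecs, drop=FALSE]
  centred = image - centre
  sum(abs(centred - basis %*% (t(basis) %*% centred)))
}

errors = sapply(1:5, function(k) reconstruction_error(outsider, k))
print(round(errors, 6))
```

With all 5 components the basis spans the whole space, so the score no longer says anything about whether the vector resembles the class.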
R code (sorry for the inconsistency between lists and matrices):
library(ggplot2)
load("../digit_data/r_images_nielsen.rda")
digit_image_matrices = list()
digit_pca = list()
digit_means = list()
length(digit_image_matrices) = 10
length(digit_pca) = 10
length(digit_means) = 10
# Training: one PCA per digit.
for (i in 0:9) {
  print(i)
  this_i = which(digits$training == i)
  num_digits = length(this_i)
  digit_image_matrices[[i+1]] = matrix(0, nrow=num_digits, ncol=784)
  for (j in 1:num_digits) {
    digit_image_matrices[[i+1]][j, ] = images$training[[this_i[j]]]
  }
  digit_pca[[i+1]]$pca = prcomp(digit_image_matrices[[i+1]], center=TRUE)
}
for (i in 0:9) {
  # Calculate means by digit for use in centring the test images.
  digit_means[[i + 1]] = colMeans(digit_image_matrices[[i + 1]])
}
# Testing
# Number of principal components to use:
num_evecs = 30
num_test = length(digits$test)
test_image_matrix = matrix(0, nrow=num_test, ncol=784)
for (j in 1:num_test) {
  test_image_matrix[j, ] = images$test[[j]]
}
test_answers = matrix(0, nrow=num_test, ncol=10)
test_answers_vec = numeric(num_test)
for (i in 0:9) {
  print(i)
  this_centred_matrix = test_image_matrix
  for (j in 1:num_test) {
    this_centred_matrix[j, ] = this_centred_matrix[j, ] - digit_means[[i + 1]]
  }
  # Convert the images to the first num_evecs factors, then convert back and
  # see how good the reproduction of the original image is.
  reduced_images = t(digit_pca[[i + 1]]$pca$rotation[ , 1:num_evecs]) %*% t(this_centred_matrix)
  back_transf_images = t(digit_pca[[i + 1]]$pca$rotation[ , 1:num_evecs] %*% reduced_images)
  # Score: sum of absolute pixel differences (lower is better).
  comparison = abs(this_centred_matrix - back_transf_images)
  test_answers[ , i + 1] = rowSums(comparison)
}
# which.min() returns a single index even when scores tie, so the result
# stays a plain vector rather than becoming a list.
test_answers_vec = apply(test_answers, 1, function(x) which.min(x) - 1)
true_answers = digits$test[1:num_test]
qplot(true_answers, test_answers_vec, alpha=.1, position=position_jitter()) + guides(alpha="none")
print(length(which(test_answers_vec == true_answers)))
k-nearest neighbours
(I didn't know about this algorithm ahead of time, but I saw it in the Wikipedia article
on MNIST, and it looked pretty simple.)
Consider each training image \( T^i \) as a vector of length 784, with j-th entry \( T^i_j \). Define the distance from \( T^i \) to a test image
\( T \) as
\begin{equation*}
d^i = \sqrt{\sum_{j=1}^{784} (T^i_j - T_j)^2}.
\end{equation*}
The nearest neighbour to \( T \) across the training images is simply the image \( T^i \) with the smallest distance \( d^i \). The simplest
implementation of k-nearest neighbours is just "1-nearest neighbour", which (for MNIST) returns the digit of this \( T^i \). More generally, for
\( k > 1 \), the classification is the most common digit among the \( k \) nearest neighbours in the training set.
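As a rough sketch of the definition in plain R (not the implementation used for the results in this post), the function below classifies a single test vector. The name `knn_classify` and the 2-D toy data are made up for illustration.

```r
# Classify one test vector by majority vote among its k nearest training rows.
knn_classify = function(training, training_labels, test_vec, k=1) {
  # Euclidean distance d^i from the test vector to every training row.
  distances = sqrt(rowSums(sweep(training, 2, test_vec)^2))
  nearest = order(distances)[1:k]
  votes = table(training_labels[nearest])
  # Ties in the vote go to whichever label table() lists first.
  as.integer(names(votes)[which.max(votes)])
}

# Hypothetical 2-D toy data: class 0 near the origin, class 1 near (5, 5).
training = rbind(c(0, 0), c(0, 1), c(1, 0), c(5, 5), c(5, 6), c(6, 5))
training_labels = c(0, 0, 0, 1, 1, 1)

print(knn_classify(training, training_labels, c(0.4, 0.3), k=1))
print(knn_classify(training, training_labels, c(4.5, 5.2), k=3))
```

Sorting every distance with `order()` is the simplest thing that works; for large training sets it is wasteful, since only the k smallest distances are needed.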
The class R package implements k-NN and works pretty much straight out of the box, and it
classified the digits so well that I thought I might have made a mistake, accidentally feeding the true answers into the system somewhere. I wrote my
own implementation of k-NN just to make sure it was all as simple and effective as it seemed, and my Rcpp code follows shortly.
Wikipedia says that k-NN can struggle with high-dimensional data sets,
but the 784 dimensions of MNIST posed no problems and the 1-nearest-neighbour got 96.7% of the test images correct. Nevertheless, the algorithm is
both much faster and slightly more accurate working in reduced dimensions, and using the first 30 principal components and \( k = 3 \) brings the score
up to 9738/10000. With my code, it took about 25 minutes to work through the classifications with 784-length vectors, and about half a minute with
30-length vectors. Jittered scatterplot:
[Figure: jittered scatterplot of test answers against true answers.]
(Unlike in the previous section, the principal components are calculated over the full training dataset, and not separately by digit.)
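A minimal sketch of that dimension reduction, on made-up 8-dimensional data rather than MNIST: the components come from the eigenvectors of the training covariance matrix, and the test vectors are centred with the training means before projection. The data and variable names are assumptions for illustration; the cross-check against `prcomp()` is only there to show the two routes agree up to sign.

```r
set.seed(4)
# Hypothetical data: 300 training and 10 test vectors in 8 dimensions.
# Column j has standard deviation 9 - j, so the eigenvalues are well separated.
scales = diag(8:1)
training = matrix(rnorm(300 * 8), ncol=8) %*% scales
test = matrix(rnorm(10 * 8), ncol=8) %*% scales

# Centre both sets with the *training* means.
training_means = colMeans(training)
training_c = sweep(training, 2, training_means)
test_c = sweep(test, 2, training_means)

# Principal components are the eigenvectors of the covariance matrix,
# which eigen() returns in order of decreasing eigenvalue.
evecs = eigen(cov(training_c))$vectors

new_dim = 3
reduced_training = training_c %*% evecs[ , 1:new_dim]
reduced_test = test_c %*% evecs[ , 1:new_dim]

# Cross-check against prcomp(); components can differ in sign, so compare
# absolute values.
reference = prcomp(training, center=TRUE)$x[ , 1:new_dim]
print(max(abs(abs(reduced_training) - abs(reference))))
```

The reduced matrices can then be passed to any k-NN routine in place of the full-length vectors.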
As in a previous post, I use Romain François's code to
return sorted indices for me. The main function to be called from R is k_nearest_classif(), and it is fewer than 40 lines long. Rcpp:
#include <Rcpp.h>
#include <queue>
using namespace Rcpp;
// http://gallery.rcpp.org/articles/top-elements-from-vectors-using-priority-queue/
template <int RTYPE>
class IndexComparator {
public:
  typedef typename Rcpp::traits::storage_type<RTYPE>::type STORAGE ;
  IndexComparator( const Vector<RTYPE>& data_ ) : data(data_.begin()){}
  inline bool operator()(int i, int j) const {
    return data[i] > data[j] || (data[i] == data[j] && j > i ) ;
  }
private:
  const STORAGE* data ;
} ;
template <>
class IndexComparator<STRSXP> {
public:
  IndexComparator( const CharacterVector& data_ ) : data(data_.begin()){}
  inline bool operator()(int i, int j) const {
    return (String)data[i] > (String)data[j] || (data[i] == data[j] && j > i );
  }
private:
  const SEXP* data ;
} ;
template <int RTYPE>
class IndexQueue {
public:
  typedef std::priority_queue<int, std::vector<int>, IndexComparator<RTYPE> > Queue ;
  IndexQueue( const Vector<RTYPE>& data_ ) : comparator(data_), q(comparator), data(data_) {}
  inline operator IntegerVector(){
    int n = q.size() ;
    IntegerVector res(n) ;
    for( int i=0; i<n; i++){
      // +1 for 1-based R indexing [deleted -- DB.]
      res[i] = q.top();
      q.pop() ;
    }
    return res ;
  }
  inline void input( int i){
    // if( data[ q.top() ] < data[i] ){
    if( comparator(i, q.top() ) ){
      q.pop();
      q.push(i) ;
    }
  }
  inline void pop(){ q.pop() ; }
  inline void push( int i){ q.push(i) ; }
private:
  IndexComparator<RTYPE> comparator ;
  Queue q ;
  const Vector<RTYPE>& data ;
} ;
// Returns the (0-based) indices of the n largest elements of v.
template <int RTYPE>
IntegerVector top_index(Vector<RTYPE> v, int n){
  int size = v.size() ;
  // not interesting case. Less data than n, so return every valid index
  if( size < n){
    return seq( 0, size-1 ) ;
  }
  IndexQueue<RTYPE> q( v ) ;
  for( int i=0; i<n; i++) q.push(i) ;
  for( int i=n; i<size; i++) q.input(i) ;
  return q ;
}
// [[Rcpp::export]]
IntegerVector top_index( SEXP x, int n){
  switch( TYPEOF(x) ){
  case INTSXP: return top_index<INTSXP>( x, n ) ;
  case REALSXP: return top_index<REALSXP>( x, n ) ;
  case STRSXP: return top_index<STRSXP>( x, n ) ;
  default: stop("type not handled") ;
  }
  return IntegerVector() ; // not used
}
// [[Rcpp::export]]
IntegerVector k_nearest_classif(NumericMatrix training, IntegerVector training_outputs, NumericMatrix test,
                                int k = 1) {
  long i, j, ct;
  long num_training = training_outputs.size();
  long num_test = test.nrow();
  long num_pixels = test.ncol();
  NumericVector distances(num_training);
  long min_class = min(training_outputs);
  long max_class = max(training_outputs);
  long num_classes = max_class - min_class + 1;
  IntegerVector nearest(k);
  IntegerVector classifs(num_test);
  for (i = 0; i < num_test; i++) {
    // Calculate distances to training images. These are stored as negative
    // squared distances: top_index() returns the indices of the largest
    // values, so the largest (least negative) entries are the nearest
    // neighbours. Skipping the square root doesn't change the ordering.
    for (j = 0; j < num_training; j++) {
      distances[j] = 0.0;
      for (ct = 0; ct < num_pixels; ct++) {
        distances[j] -= (training(j, ct) - test(i, ct)) * (training(j, ct) - test(i, ct));
      }
    }
    // Find k nearest:
    nearest = top_index(distances, k);
    // Count votes per class among the k nearest, then take the most common.
    IntegerVector classes(num_classes);
    for (j = 0; j < k; j++) {
      classes[training_outputs[nearest[j]] - min_class] += 1;
    }
    IntegerVector this_classif_i = top_index(classes, 1);
    classifs[i] = this_classif_i[0] + min_class;
  }
  return classifs;
}
R (the PCA is done manually because prcomp() ate up too much memory when I used it, possibly incorrectly, when first trying to
get this to work):
library(Rcpp)
library(ggplot2)
library(EBImage)
load("../digit_data/r_images_nielsen.rda")
sourceCpp("knn_by_me.cpp")
num_training = length(digits$training)
num_test = length(digits$test)
training_image_matrix = matrix(0, nrow=num_training, ncol=784)
test_image_matrix = matrix(0, nrow=num_test, ncol=784)
# Training
for (j in 1:num_training) {
  training_image_matrix[j, ] = images$training[[j]]
}
# Doing the PCA manually because prcomp soaked up too much memory (?? -- maybe I did something wrong with it).
# Both training and test images are centred with the training means.
digit_means = colMeans(training_image_matrix)
for (j in 1:num_training) {
  training_image_matrix[j, ] = training_image_matrix[j, ] - digit_means
}
for (j in 1:num_test) {
  test_image_matrix[j, ] = images$test[[j]] - digit_means
}
var_covar = cov(training_image_matrix)
# Remove zero-variance pixels (not needed unless the variances are normalised)
nonzero_i = which(diag(var_covar) > 0)
var_covar = var_covar[nonzero_i, nonzero_i]
training_image_matrix = training_image_matrix[ , nonzero_i]
test_image_matrix = test_image_matrix[ , nonzero_i]
# Usually you'd want to normalise the variances, but I get very marginally
# worse results when I do this:
#
# (var_covar is already restricted to the nonzero-variance pixels here)
# pixel_stdevs = sqrt(diag(var_covar))
#
# for (i in 1:num_training) {
#   training_image_matrix[i, ] = training_image_matrix[i, ] / pixel_stdevs
# }
#
# for (i in 1:num_test) {
#   test_image_matrix[i, ] = test_image_matrix[i, ] / pixel_stdevs
# }
eigenthings = eigen(var_covar)
# Project onto the first new_dim principal components.
new_dim = 30
reduced_training_matrix = training_image_matrix %*% eigenthings$vectors[ , 1:new_dim]
reduced_test_matrix = test_image_matrix %*% eigenthings$vectors[ , 1:new_dim]
test_answers_vec = k_nearest_classif(reduced_training_matrix, digits$training, reduced_test_matrix, k=3)
true_answers = digits$test
qplot(true_answers, test_answers_vec, alpha=.1, position=position_jitter()) + guides(alpha="none")
print(sprintf("score = %d", length(which(test_answers_vec == digits$test))))
Conclusions
MNIST digits are pretty easy to classify, and neural networks aren't a lot better on them than much less sophisticated algorithms: my best
network's 98.1% is less than a percentage point ahead of the 97.4% from k-nearest neighbours. As a dataset
to learn on, MNIST has the very nice property that it's small and doesn't take too long to process, but things will certainly be more interesting when I
branch out into new territory instead of trying to marginally improve the second or third significant figure in the number of MNIST digits correctly
classified.
Posted 2015-11-17.