Skip to content

Neural network classifier for handwritten digits (the MNIST dataset)

Notifications You must be signed in to change notification settings

kromych/digit-recognizer

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

What this is

This is a neural network classifier for handwritten digits (trained with the MNIST dataset[0]). It's implemented as a feedforward neural network and aims to be a lean-and-mean demo of the concepts, hopefully hackable and easy to play with.

Provides ~97% accuracy.

The meat is in model.rs, ~150 LoC.

What this does

First, a neural network is trained, then it is given a handwritten digit from another dataset to recognize it. Finally, the digit is printed out (and some numbers to provide insight into how the trainiong went).

How to run this stuff

Do once

# Clone the repo
git clone https://github.com/kromych/digit-recognizer
# Get into the project dir
cd digit-recognizer

# Decompress the data set (please figure something out if you're on Windows)
cd data
gunzip *.gz
cd ..

Enjoy many times

# Without --release, that'll be veeeery slow.
#
# With --release, ~2 min for training on my machine (M1 Ultra)
# with the default --epochs=32

cargo run --release

and then observe something like and like that:

2026-01-05T03:38:59.198808Z  INFO digit_recognizer: Loading MNIST data...
2026-01-05T03:38:59.346420Z  INFO digit_recognizer::dataloader: MNIST data validation passed! Loaded 60000 images.
2026-01-05T03:38:59.368984Z  INFO digit_recognizer::dataloader: MNIST data validation passed! Loaded 10000 images.
2026-01-05T03:38:59.368996Z  INFO digit_recognizer: Training samples: 60000
2026-01-05T03:38:59.368998Z  INFO digit_recognizer: Test samples: 10000
2026-01-05T03:38:59.369679Z  INFO digit_recognizer: Training model for 32 epochs with learning rate 0.01...
2026-01-05T03:38:59.369682Z  INFO digit_recognizer::model: Starting epoch 1/32
....
2026-01-05T03:41:29.131753Z  INFO digit_recognizer::model: Starting epoch 32/32
2026-01-05T03:41:33.990653Z  INFO digit_recognizer::model: Completed epoch 32/32, average entropy loss: 0.0008
2026-01-05T03:41:33.990674Z  INFO digit_recognizer: Testing model...
2026-01-05T03:41:34.185698Z  INFO digit_recognizer: Accuracy: 97.52%
2026-01-05T03:41:34.185706Z  INFO digit_recognizer: Sixel graphics: Disabled (use --graphics to enable)
2026-01-05T03:41:34.185710Z  INFO digit_recognizer: Displaying 3 example predictions...
╭─────────────────────────────────╮
│              info               │
├─────────────────────────────────┤
│ Sample 1: True: 7, Predicted: 7 │
│ ✓ Correct prediction!           │
╰─────────────────────────────────╯

    @@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@*:-=#%@@@@@@@@@@@@@@@@
    @@@@@@.     ::::::::-#@@@@@@
    @@@@@@#+*+-. .    .  =@@@@@@
    @@@@@@@@@@@%#@####%. +@@@@@@
    @@@@@@@@@@@@@@@@@@* :%@@@@@@
    @@@@@@@@@@@@@@@@@%. *@@@@@@@
    @@@@@@@@@@@@@@@@@= .#@@@@@@@
    @@@@@@@@@@@@@@@@#  #@@@@@@@@
    @@@@@@@@@@@@@@@@= :@@@@@@@@@
    @@@@@@@@@@@@@@@@: #@@@@@@@@@
    @@@@@@@@@@@@@@@+ -@@@@@@@@@@
    @@@@@@@@@@@@@@* .#@@@@@@@@@@
    @@@@@@@@@@@@@%. -@@@@@@@@@@@
    @@@@@@@@@@@@@: .%@@@@@@@@@@@
    @@@@@@@@@@@@%  *@@@@@@@@@@@@
    @@@@@@@@@@@%. +@@@@@@@@@@@@@
    @@@@@@@@@@@=  #@@@@@@@@@@@@@
    @@@@@@@@@@#   #@@@@@@@@@@@@@
    @@@@@@@@@@+  .%@@@@@@@@@@@@@
    @@@@@@@@@@+ :%@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@

..................... more stuff .................................................

2026-01-05T04:39:51.196877Z  INFO digit_recognizer: Analyzing misclassifications...

Interesting Misclassifications:
╭─────────────────────────────────╮
│              info               │
├─────────────────────────────────┤
│ Sample 9: True: 5, Predicted: 6 │
│ ✗ Incorrect prediction!         │
╰─────────────────────────────────╯

    @@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@%###%=*#@@@
    @@@@@@@@@@@@@@@*=.   .   @@@
    @@@@@@@@@@@@%=           @@@
    @@@@@@@@@@@@#       .----@@@
    @@@@@@@@%=*@#...-+#%@@@@@@@@
    @@@@@@@@: *@@%@%@@@@@@@@@@@@
    @@@@@@@- =@@@@@@@@@@@@@@@@@@
    @@@@@@= .%@@@@@@@@@@@@@@@@@@
    @@@@@* .%@@@@@@@@@@@@@@@@@@@
    @@@@@= -@@@@@@@@@@@@@@@@@@@@
    @@@@@.  +#@@@@@@@@@@@@@@@@@@
    @@@@@     =*+*%%*%@@@@@@@@@@
    @@@@@=           .-@@@@@@@@@
    @@@@@@#..          +@@@@@@@@
    @@@@@@@@@#*:       -@@@@@@@@
    @@@@@@@@@@@+  =#-  -@@@@@@@@
    @@@@@@@@@@@#   :   =@@@@@@@@
    @@@@@@@@@@@@-     -@@@@@@@@@
    @@@@@@@@@@@@%-. .=@@@@@@@@@@
    @@@@@@@@@@@@@@*=%@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@

If your terminal support Sixel graphics, add -g to the command line for more fun! There is more arguments to play with:

cargo run -- --help

Neural network

Input Layer (784) -> Hidden Layer (ReLU)  -> Output Layer (Softmax)
28 x 28 pixels       64 neurons (default)    10 neurons (digits 0-9)

Mathematics behind the scenes

Forward propagation

  • Input Layer: 784 neurons (flattened 28 x 28 pixel image, normalized to [0, 1])
  • Hidden Layer: Linear transformation: $z_1 = W_1 x + b_1$
  • Activation: $a_1 = \sigma (z_1)$ where $\sigma$ is the sigmoid function
  • Weight initialization: He initialization [1]

Output layer

  • Linear transformation: $z_2 = W_2 a_1 + b_2$
  • Activation: $a_2 = softmax(z_2)$

Activation functions

  • Hidden Layer: Sigmoid function $\sigma(x) = 1 / (1 + e^{-x})$
  • Output Layer: Softmax function $softmax(z_i) = e^{z_i} / \sum_{j=1}^{K} e^{z_j}$

Loss function

  • Categorical Cross-Entropy: $L(y, y') = -\sum_{i=1}^{K} y_i log(y'_i)$
  • $y$ is the one-hot encoded true label and $y'$ is the predicted probability distribution.

Backpropagation [2]

  • Gradients computed using chain rule
  • Weight updates via stochastic gradient descent (SGD) [3]
  • Learning rate: 0.01 (default, configurable)

Some limitations

  • Sequential training - one sample at a time
  • Single hidden layer configuration, hardcoded
  • No GPU acceleration
  • Simple SGD without momentum or adaptive learning rates - might be sensitive to the choice of the learning rate

Wouldn't be possible without

  1. MNIST dataset of 60,000 training images (28 x 28 pixels, grayscale) and 10,000 test images: LeCun, Y., Cortes, C., & Burges, C. J. C. The MNIST Database of Handwritten Digits, 1998
  2. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
  3. Rumelhart, D. E., Hinton, G. E., & Williams, R. J. Learning representations by back-propagating errors, Nature, 1986
  4. Stochastic gradient descent
  5. Cross-entropy

THANK YOU!!!