Demo 3: HKR classifier on MNIST dataset

This notebook demonstrates how to learn a binary classification task (digit 0 vs. digit 8) on the MNIST dataset.

# pip install deel-lip -qqq
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import binary_accuracy
from tensorflow.keras.models import Sequential

from deel.lip.layers import (
    SpectralDense,
    FrobeniusDense,
)
from deel.lip.activations import MaxMin, GroupSort, GroupSort2, FullSort
from deel.lip.losses import HKR, KR, HingeMargin
Data preparation

For this task we select two classes: 0 and 8. Labels are changed to {-1, 1} (digit 0 becomes 1, digit 8 becomes -1), which is compatible with the hinge term used in the loss.
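To see why labels in {-1, 1} suit a hinge term, here is a minimal NumPy sketch of the standard margin hinge penalty, max(0, min_margin - y * f(x)), averaged over samples. The helper `hinge_margin` is hypothetical and only illustrative; deel-lip's own `HingeMargin` may scale the margin differently depending on the version. A sample is penalized unless its output has the sign of its label and a magnitude of at least `min_margin`.

```python
import numpy as np


def hinge_margin(y_true, y_pred, min_margin=1.0):
    """Mean of max(0, min_margin - y * f(x)) for labels y in {-1, 1}."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean(np.maximum(0.0, min_margin - y_true * y_pred)))


# toy labels (1 for digit 0, -1 for digit 8) and toy classifier outputs f(x)
y = np.array([1.0, 1.0, -1.0, -1.0])
f = np.array([2.0, 0.5, -3.0, 0.2])

# y * f(x) is positive when the sign is right, and >= min_margin when no penalty applies
print(y * f)
# per-sample penalties are 0, 0.5, 0 and 1.2 (mean 0.425)
print(hinge_margin(y, f))
```

Because the labels are signed, a single product y * f(x) measures both correctness and margin for either class.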

from tensorflow.keras.datasets import mnist

# first we select the two classes
selected_classes = [0, 8]  # must be two classes as we perform binary classification

def prepare_data(x, y, class_a=0, class_b=8):
    """Convert the MNIST data to make it suitable for our binary classification."""
    # mask to select only items from class_a or class_b
    mask = (y == class_a) | (y == class_b)
    x = x[mask].astype("float32")
    y = y[mask]
    # convert from range int[0,255] to float32[0,1]
    x /= 255
    x = x.reshape((-1, 28, 28, 1))
    # change labels for binary classification: class_a -> 1.0, class_b -> -1.0
    # (derived from a single comparison, so it stays correct even if class_b == 1)
    y = (y == class_a).astype("float32") * 2.0 - 1.0
    return x, y

# now we load the dataset
(x_train, y_train_ord), (x_test, y_test_ord) = mnist.load_data()

# prepare the data
x_train, y_train = prepare_data(
    x_train, y_train_ord, selected_classes[0], selected_classes[1]
)
x_test, y_test = prepare_data(
    x_test, y_test_ord, selected_classes[0], selected_classes[1]
)

# display information about the dataset (proportion = share of samples labelled 1)
print(
    "train set size: %i samples, classes proportions: %.3f percent"
    % (y_train.shape[0], 100 * y_train[y_train == 1].sum() / y_train.shape[0])
)
print(
    "test set size: %i samples, classes proportions: %.3f percent"
    % (y_test.shape[0], 100 * y_test[y_test == 1].sum() / y_test.shape[0])
)
train set size: 11774 samples, classes proportions: 50.306 percent
test set size: 1954 samples, classes proportions: 50.154 percent

Build Lipschitz model

Let’s first make the parameters of this experiment explicit.

# training parameters
epochs = 10
batch_size = 128

# network parameters
activation = GroupSort  # alternatives: ReLU, MaxMin, GroupSort2 (note: the model below uses GroupSort2 directly)

# loss parameters
min_margin = 1.0
alpha = 10.0

Now we can build the network. Here the experiment is done with an MLP, but deel-lip also provides state-of-the-art 1-Lipschitz convolutions.
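A network is 1-Lipschitz when each of its layers is. The NumPy sketch below illustrates two such ingredients in simplified form: dense weights rescaled to spectral norm 1 (estimated by power iteration), and a GroupSort2-style activation that sorts consecutive pairs of features. The helpers `spectral_norm`, `spectral_normalize` and `group_sort2` are hypothetical; deel-lip's `SpectralDense` may apply further steps beyond this plain rescaling, and its `GroupSort2` may order outputs differently, so treat this as the general idea only.

```python
import numpy as np


def spectral_norm(w, n_iter=100, seed=0):
    """Estimate the largest singular value of w (shape: n_in x n_out) by power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=w.shape[1])
    for _ in range(n_iter):
        v = w @ u
        v /= np.linalg.norm(v)
        u = w.T @ v
        u /= np.linalg.norm(u)
    return float(v @ w @ u)


def spectral_normalize(w):
    """Rescale w so that the linear map x -> x @ w has Lipschitz constant (about) 1."""
    return w / spectral_norm(w)


def group_sort2(x):
    """Sort each consecutive pair of features in ascending order (last axis must be even)."""
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    return np.sort(pairs, axis=-1).reshape(x.shape)


rng = np.random.default_rng(1)
w = spectral_normalize(rng.normal(size=(28 * 28, 32)))  # same shape as the first layer
x1, x2 = rng.normal(size=28 * 28), rng.normal(size=28 * 28)
dist_in = np.linalg.norm(x1 - x2)
dist_out = np.linalg.norm(group_sort2(x1 @ w) - group_sort2(x2 @ w))
print(f"output/input distance ratio: {dist_out / dist_in:.3f} (expected <= 1)")
```

A spectrally normalized linear map and a pairwise sort are each 1-Lipschitz, so their composition is too; the printed ratio should therefore not exceed 1, up to the small error of the power-iteration estimate.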

# build the 1-Lipschitz MLP
wass = Sequential(
    [
        Input(shape=(28, 28, 1)),
        Flatten(),
        SpectralDense(32, GroupSort2(), use_bias=True),
        SpectralDense(16, GroupSort2(), use_bias=True),
        FrobeniusDense(1, activation=None, use_bias=False),
    ],
    name="lipModel",
)
wass.summary()
Model: "lipModel"
Layer (type)                 Output Shape              Param #
flatten (Flatten)            (None, 784)               0
spectral_dense (SpectralDens (None, 32)                50241
spectral_dense_1 (SpectralDe (None, 16)                1057
frobenius_dense (FrobeniusDe (None, 1)                 32
Total params: 51,330
Trainable params: 25,664
Non-trainable params: 25,666
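The trainable parameter count in the summary can be checked by hand, assuming each layer stores a standard kernel of shape (inputs, outputs) plus an optional bias, as a regular Keras `Dense` layer does. The per-layer `Param #` values are larger because deel-lip's layers also keep non-trainable state; what exactly it contains is an implementation detail not shown in this notebook, so the sketch only computes the difference. The helper `dense_params` is hypothetical.

```python
def dense_params(n_in, n_out, use_bias=True):
    """Trainable parameters of a dense layer: an n_in x n_out kernel plus an optional bias."""
    return n_in * n_out + (n_out if use_bias else 0)


# (layer name, inputs, outputs, use_bias), following the model definition above
layer_specs = [
    ("spectral_dense", 28 * 28, 32, True),
    ("spectral_dense_1", 32, 16, True),
    ("frobenius_dense", 16, 1, False),
]
trainable = {name: dense_params(n_in, n_out, bias) for name, n_in, n_out, bias in layer_specs}
print(trainable)  # {'spectral_dense': 25120, 'spectral_dense_1': 528, 'frobenius_dense': 16}
print(sum(trainable.values()))  # 25664, the "Trainable params" line of the summary

# per-layer "Param #" values reported by summary(); the excess is non-trainable state
reported = {"spectral_dense": 50241, "spectral_dense_1": 1057, "frobenius_dense": 32}
non_trainable = {name: reported[name] - trainable[name] for name in reported}
print(sum(non_trainable.values()))  # 25666, the "Non-trainable params" line
```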
optimizer = Adam(learning_rate=0.001)
# the classifier outputs an unbounded real score whose sign gives the class (labels are in {-1, 1}),
# so binary accuracy must be redefined on the signs
def HKR_binary_accuracy(y_true, y_pred):
    S_true = tf.dtypes.cast(tf.greater_equal(y_true[:, 0], 0), dtype=tf.float32)
    S_pred = tf.dtypes.cast(tf.greater_equal(y_pred[:, 0], 0), dtype=tf.float32)
    return binary_accuracy(S_true, S_pred)
wass.compile(
    # HKR stands for the hinge-regularized Kantorovich-Rubinstein (KR) loss
    loss=HKR(alpha=alpha, min_margin=min_margin),
    metrics=[
        KR,  # shows the KR term of the loss
        HingeMargin(min_margin=min_margin),  # shows the hinge term of the loss
        HKR_binary_accuracy,  # shows the classification accuracy
    ],
    optimizer=optimizer,
)
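For intuition, here is a plain NumPy sketch of the quantities passed to `compile`. It assumes that the KR term is the mean output on samples labelled 1 minus the mean output on samples labelled -1, that the hinge term is the standard margin hinge, and that the loss minimized is -KR + alpha * hinge (the training log below is consistent with this form). The helper names are hypothetical, and the exact deel-lip definitions may differ between versions. `sign_accuracy` mirrors `HKR_binary_accuracy`: a prediction counts as correct when f(x) and the label are on the same side of 0.

```python
import numpy as np


def kr_term(y_true, y_pred):
    """Mean output on label-1 samples minus mean output on label -1 samples."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(y_pred[y_true == 1].mean() - y_pred[y_true == -1].mean())


def hinge_term(y_true, y_pred, min_margin=1.0):
    """Mean of max(0, min_margin - y * f(x))."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.maximum(0.0, min_margin - y_true * y_pred).mean())


def hkr_loss(y_true, y_pred, alpha=10.0, min_margin=1.0):
    """Assumed HKR form: maximize KR (hence the minus sign) while penalizing margin violations."""
    return -kr_term(y_true, y_pred) + alpha * hinge_term(y_true, y_pred, min_margin)


def sign_accuracy(y_true, y_pred):
    """Share of samples where f(x) and the label are on the same side of 0 (0 counts as positive)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean((y_true >= 0) == (y_pred >= 0)))


y = np.array([1.0, 1.0, -1.0, -1.0])
f = np.array([2.0, 0.5, -3.0, 0.2])
print(kr_term(y, f))        # mean(2, 0.5) - mean(-3, 0.2) = 1.25 + 1.4, i.e. about 2.65
print(hkr_loss(y, f))       # about -2.65 + 10 * 0.425 = 1.6
print(sign_accuracy(y, f))  # 0.75: the last sample has the wrong sign
```

The weight `alpha` trades off separating the two classes (a large KR term) against keeping samples on the correct side of the margin.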

Learn classification on MNIST

Now that the model is built, we can learn the task.
wass.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
         validation_data=(x_test, y_test))
Epoch 1/10
92/92 [==============================] - 3s 10ms/step - loss: -0.5542 - KR: 3.2748 - HingeMargin: 0.2721 - HKR_binary_accuracy: 0.8725 - val_loss: -5.0345 - val_KR: 5.5790 - val_HingeMargin: 0.0553 - val_HKR_binary_accuracy: 0.9777
Epoch 2/10
92/92 [==============================] - 1s 6ms/step - loss: -4.8969 - KR: 5.4644 - HingeMargin: 0.0567 - HKR_binary_accuracy: 0.9785 - val_loss: -5.3840 - val_KR: 5.7409 - val_HingeMargin: 0.0383 - val_HKR_binary_accuracy: 0.9845
Epoch 3/10
92/92 [==============================] - 1s 6ms/step - loss: -5.3341 - KR: 5.7611 - HingeMargin: 0.0427 - HKR_binary_accuracy: 0.9840 - val_loss: -5.5146 - val_KR: 5.8514 - val_HingeMargin: 0.0360 - val_HKR_binary_accuracy: 0.9845
Epoch 4/10
92/92 [==============================] - 1s 6ms/step - loss: -5.4725 - KR: 5.8629 - HingeMargin: 0.0390 - HKR_binary_accuracy: 0.9858 - val_loss: -5.5682 - val_KR: 5.9083 - val_HingeMargin: 0.0362 - val_HKR_binary_accuracy: 0.9855
Epoch 5/10
92/92 [==============================] - 1s 6ms/step - loss: -5.4682 - KR: 5.8617 - HingeMargin: 0.0393 - HKR_binary_accuracy: 0.9862 - val_loss: -5.5683 - val_KR: 5.9196 - val_HingeMargin: 0.0366 - val_HKR_binary_accuracy: 0.9845
Epoch 6/10
92/92 [==============================] - 1s 6ms/step - loss: -5.5441 - KR: 5.9086 - HingeMargin: 0.0364 - HKR_binary_accuracy: 0.9878 - val_loss: -5.6268 - val_KR: 5.9399 - val_HingeMargin: 0.0336 - val_HKR_binary_accuracy: 0.9874
Epoch 7/10
92/92 [==============================] - 1s 6ms/step - loss: -5.6141 - KR: 5.9665 - HingeMargin: 0.0352 - HKR_binary_accuracy: 0.9877 - val_loss: -5.7121 - val_KR: 5.9817 - val_HingeMargin: 0.0300 - val_HKR_binary_accuracy: 0.9894
Epoch 8/10
92/92 [==============================] - 1s 6ms/step - loss: -5.6687 - KR: 6.0017 - HingeMargin: 0.0333 - HKR_binary_accuracy: 0.9875 - val_loss: -5.7358 - val_KR: 6.0305 - val_HingeMargin: 0.0322 - val_HKR_binary_accuracy: 0.9869
Epoch 9/10
92/92 [==============================] - 1s 6ms/step - loss: -5.6956 - KR: 6.0167 - HingeMargin: 0.0321 - HKR_binary_accuracy: 0.9883 - val_loss: -5.7684 - val_KR: 6.0966 - val_HingeMargin: 0.0350 - val_HKR_binary_accuracy: 0.9840
Epoch 10/10
92/92 [==============================] - 1s 6ms/step - loss: -5.7525 - KR: 6.0836 - HingeMargin: 0.0331 - HKR_binary_accuracy: 0.9881 - val_loss: -5.8637 - val_KR: 6.0924 - val_HingeMargin: 0.0260 - val_HKR_binary_accuracy: 0.9899
As we can see, the model reaches a very decent accuracy on this task: 98.81% on the training set and 98.99% on the test set after 10 epochs.
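The training log can also serve as a sanity check. Assuming the HKR loss is -KR + alpha * HingeMargin with alpha = 10.0, the logged training loss should be recovered from the logged KR and HingeMargin values, up to rounding. The snippet copies three training lines from the log, compares them, and checks that 92 steps per epoch matches 11774 samples split into batches of 128.

```python
import math

alpha = 10.0  # value set in the parameters cell

# (loss, KR, HingeMargin) copied from training lines of the log above
logged = {
    "epoch 1": (-0.5542, 3.2748, 0.2721),
    "epoch 5": (-5.4682, 5.8617, 0.0393),
    "epoch 10": (-5.7525, 6.0836, 0.0331),
}
for epoch, (loss, kr, hinge) in logged.items():
    reconstructed = -kr + alpha * hinge
    print(f"{epoch}: logged loss {loss:.4f}, -KR + alpha * HingeMargin = {reconstructed:.4f}")

# steps per epoch shown in the log ("92/92")
print(math.ceil(11774 / 128))  # 92
```

The validation columns agree less closely (differences of roughly 0.01 to 0.03), so the check is restricted to training lines.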