Packages

import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

plt.rcParams['figure.figsize'] = (10, 6)

Load the dataset

Here, we will use the Cats and Dogs dataset from Kaggle, which poses a binary classification problem. For simplicity, we work with a filtered subset containing only some of the images.
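If the dataset is not already on disk, one way to fetch it is sketched below. This assumes the filtered archive hosted for the TensorFlow tutorials is still available at that URL; Keras caches the download under ~/.keras/datasets, so you may need to move or symlink the extracted folder so that it matches the base_dir used below.

zip_path = tf.keras.utils.get_file(
    'cats_and_dogs_filtered.zip',
    origin='https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip',
    extract=True  # unzip next to the downloaded archive
)
print(zip_path)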

base_dir = './dataset/cats_and_dogs_filtered'
os.listdir(base_dir)
['vectorize.py', 'validation', 'train']
train_dir = os.path.join(base_dir, 'train')
val_dir = os.path.join(base_dir, 'validation')
os.listdir(train_dir)
['dogs', 'cats']
os.listdir(val_dir)
['dogs', 'cats']

To use ImageDataGenerator in TensorFlow, the dataset folder must be organized hierarchically, with one subfolder per class. For example,

  • train
    • label_1
    • label_2
    • ...
  • val
    • label_1
    • label_2
    • ...

We also prepare the per-class directory paths for convenience.

train_cat_dir = os.path.join(train_dir, 'cats')
train_dog_dir = os.path.join(train_dir, 'dogs')
val_cat_dir = os.path.join(val_dir, 'cats')
val_dog_dir = os.path.join(val_dir, 'dogs')

train_dog_fname = os.listdir(train_dog_dir)
train_cat_fname = os.listdir(train_cat_dir)
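
Before plotting anything, it is worth checking how many images each class folder contains. The filtered dataset used here should hold 1,000 training and 500 validation images per class; the exact counts may differ if you filtered the data yourself.

print('training cats  :', len(train_cat_fname))
print('training dogs  :', len(train_dog_fname))
print('validation cats:', len(os.listdir(val_cat_dir)))
print('validation dogs:', len(os.listdir(val_dog_dir)))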

Check the sample images

nrows = 4
ncols = 4

# Select the first 8 cat and 8 dog image paths to display
pic_idx = 0

fig = plt.gcf()
fig.set_size_inches(ncols * 4, nrows * 4)

pic_idx += 8
next_cat_pic = [os.path.join(train_cat_dir, fname) for fname in train_cat_fname[pic_idx - 8:pic_idx]]
next_dog_pic = [os.path.join(train_dog_dir, fname) for fname in train_dog_fname[pic_idx - 8:pic_idx]]

for i, img_path in enumerate(next_cat_pic + next_dog_pic):
    sp = plt.subplot(nrows, ncols, i + 1)
    sp.axis('off')
    
    img = mpimg.imread(img_path)
    plt.imshow(img)
    
plt.show()

As you can see, each image has a different shape, and some images contain unlabeled objects such as a human hand or a cage. These characteristics make it harder for a model to learn a generalizable cat-versus-dog classifier. We will see how a CNN performs on the naive (unaugmented) dataset shortly.
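
We can confirm the varying shapes directly by reading a few of the files we just displayed; the exact sizes printed will depend on which images are in your copy of the dataset.

for img_path in (next_cat_pic + next_dog_pic)[:4]:
    print(os.path.basename(img_path), mpimg.imread(img_path).shape)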

TensorFlow ImageDataGenerator

We could load and preprocess the images by hand, but TensorFlow offers a convenient API called ImageDataGenerator. As long as the dataset folder has the correct structure and we pass appropriate options, it gives us a Python generator over the dataset. To compare performance with and without image augmentation, let's first build a naive data generator. Here we only apply rescaling, which maps the pixel values into the range 0 to 1.

train_datagen = ImageDataGenerator(rescale=1/255.)
val_datagen = ImageDataGenerator(rescale=1/255.)
train_generator = train_datagen.flow_from_directory(
    directory=train_dir,
    batch_size=20,
    target_size=(150, 150),
    class_mode='binary'
)

val_generator = val_datagen.flow_from_directory(
    directory=val_dir,
    batch_size=20,
    target_size=(150, 150),
    class_mode='binary'
)
Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.

We need to think about the target size, batch size, and class_mode. Since we are solving a binary classification problem, class_mode must be 'binary'; if we wanted to extend this to multiclass classification, we would set it to 'categorical'. The target_size resizes every image to a common shape, and its value determines the input_shape of the CNN.

After that, we have a Python generator that yields batches of the dataset. Thankfully, label encoding is applied automatically. To find out which label is mapped to which class, we can inspect the generator's class_indices attribute.

train_generator.class_indices
{'cats': 0, 'dogs': 1}
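
We can also pull a single batch from the generator to confirm the shapes and the label encoding. The printed labels depend on which files happen to land in the first batch.

sample_images, sample_labels = next(train_generator)
print(sample_images.shape)                        # (20, 150, 150, 3): batch x target_size x channels
print(sample_images.min(), sample_images.max())   # rescaled into [0, 1]
print(sample_labels[:10])                         # 0 = cats, 1 = dogs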

Model Build

Let's build a simple CNN model. Since the purpose of this notebook is to show the power of image augmentation, we do not build a complex model.

model = Sequential([
    Conv2D(16, (3, 3), activation='relu', input_shape=(150, 150, 3)),
    MaxPooling2D(2, 2),
    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Flatten(),
    Dense(512, activation='relu'),
    Dense(1, activation='sigmoid')
])

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 148, 148, 16)      448       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 74, 74, 16)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 72, 72, 32)        4640      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 36, 36, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 34, 34, 64)        18496     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 17, 17, 64)        0         
_________________________________________________________________
flatten (Flatten)            (None, 18496)             0         
_________________________________________________________________
dense (Dense)                (None, 512)               9470464   
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 513       
=================================================================
Total params: 9,494,561
Trainable params: 9,494,561
Non-trainable params: 0
_________________________________________________________________
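The parameter counts in the summary can be reproduced by hand: a Conv2D layer needs (kernel height * kernel width * input channels + 1 bias) parameters per filter, and a Dense layer needs (inputs + 1) parameters per unit.

print((3 * 3 * 3 + 1) * 16)       # 448      conv2d
print((3 * 3 * 16 + 1) * 32)      # 4640     conv2d_1
print((3 * 3 * 32 + 1) * 64)      # 18496    conv2d_2
print((17 * 17 * 64 + 1) * 512)   # 9470464  dense (the flattened 17x17x64 tensor has 18496 values)
print((512 + 1) * 1)              # 513      dense_1
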
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001), loss='binary_crossentropy', metrics=['accuracy'])
history_naive = model.fit(train_generator, steps_per_epoch=100, epochs=100, validation_data=val_generator, validation_steps=50)
Epoch 1/100
100/100 [==============================] - 6s 54ms/step - loss: 0.7442 - accuracy: 0.5620 - val_loss: 0.6405 - val_accuracy: 0.6700
Epoch 2/100
100/100 [==============================] - 5s 52ms/step - loss: 0.6123 - accuracy: 0.6830 - val_loss: 0.6601 - val_accuracy: 0.6170
Epoch 3/100
100/100 [==============================] - 5s 52ms/step - loss: 0.5366 - accuracy: 0.7270 - val_loss: 0.5534 - val_accuracy: 0.7190
Epoch 4/100
100/100 [==============================] - 5s 53ms/step - loss: 0.4738 - accuracy: 0.7765 - val_loss: 0.5489 - val_accuracy: 0.7340
Epoch 5/100
100/100 [==============================] - 5s 54ms/step - loss: 0.4059 - accuracy: 0.8115 - val_loss: 0.5849 - val_accuracy: 0.7100
Epoch 6/100
100/100 [==============================] - 5s 54ms/step - loss: 0.3210 - accuracy: 0.8610 - val_loss: 0.5813 - val_accuracy: 0.7350
Epoch 7/100
100/100 [==============================] - 5s 52ms/step - loss: 0.2312 - accuracy: 0.9040 - val_loss: 0.6729 - val_accuracy: 0.7390
Epoch 8/100
100/100 [==============================] - 5s 52ms/step - loss: 0.1853 - accuracy: 0.9275 - val_loss: 0.7735 - val_accuracy: 0.7370
Epoch 9/100
100/100 [==============================] - 5s 52ms/step - loss: 0.1275 - accuracy: 0.9620 - val_loss: 0.9282 - val_accuracy: 0.7140
Epoch 10/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0941 - accuracy: 0.9610 - val_loss: 1.3261 - val_accuracy: 0.7010
Epoch 11/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0654 - accuracy: 0.9760 - val_loss: 1.2061 - val_accuracy: 0.7380
Epoch 12/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0665 - accuracy: 0.9755 - val_loss: 1.1133 - val_accuracy: 0.7170
Epoch 13/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0464 - accuracy: 0.9845 - val_loss: 1.2408 - val_accuracy: 0.7420
Epoch 14/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0413 - accuracy: 0.9890 - val_loss: 1.4219 - val_accuracy: 0.7490
Epoch 15/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0251 - accuracy: 0.9920 - val_loss: 1.8764 - val_accuracy: 0.7430
Epoch 16/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0683 - accuracy: 0.9840 - val_loss: 1.5557 - val_accuracy: 0.7290
Epoch 17/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0277 - accuracy: 0.9920 - val_loss: 2.5855 - val_accuracy: 0.7240
Epoch 18/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0505 - accuracy: 0.9875 - val_loss: 1.6732 - val_accuracy: 0.7340
Epoch 19/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0220 - accuracy: 0.9925 - val_loss: 2.0212 - val_accuracy: 0.7390
Epoch 20/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0809 - accuracy: 0.9855 - val_loss: 1.8006 - val_accuracy: 0.7170
Epoch 21/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0145 - accuracy: 0.9940 - val_loss: 2.1849 - val_accuracy: 0.7360
Epoch 22/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0236 - accuracy: 0.9950 - val_loss: 2.2045 - val_accuracy: 0.7340
Epoch 23/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0497 - accuracy: 0.9905 - val_loss: 2.2413 - val_accuracy: 0.7000
Epoch 24/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0170 - accuracy: 0.9970 - val_loss: 2.5266 - val_accuracy: 0.7300
Epoch 25/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0292 - accuracy: 0.9905 - val_loss: 2.0093 - val_accuracy: 0.7280
Epoch 26/100
100/100 [==============================] - 5s 53ms/step - loss: 4.4668e-04 - accuracy: 1.0000 - val_loss: 2.7501 - val_accuracy: 0.7560
Epoch 27/100
100/100 [==============================] - 5s 52ms/step - loss: 0.1578 - accuracy: 0.9905 - val_loss: 2.6624 - val_accuracy: 0.7380
Epoch 28/100
100/100 [==============================] - 5s 52ms/step - loss: 0.1436 - accuracy: 0.9910 - val_loss: 2.6825 - val_accuracy: 0.7260
Epoch 29/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0560 - accuracy: 0.9935 - val_loss: 2.5473 - val_accuracy: 0.7370
Epoch 30/100
100/100 [==============================] - 5s 52ms/step - loss: 7.7166e-05 - accuracy: 1.0000 - val_loss: 3.0506 - val_accuracy: 0.7380
Epoch 31/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0772 - accuracy: 0.9895 - val_loss: 2.7408 - val_accuracy: 0.7310
Epoch 32/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0111 - accuracy: 0.9955 - val_loss: 3.2875 - val_accuracy: 0.7330
Epoch 33/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0381 - accuracy: 0.9950 - val_loss: 3.0033 - val_accuracy: 0.7450
Epoch 34/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0172 - accuracy: 0.9965 - val_loss: 2.6957 - val_accuracy: 0.7310
Epoch 35/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0096 - accuracy: 0.9975 - val_loss: 2.8700 - val_accuracy: 0.7370
Epoch 36/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0069 - accuracy: 0.9975 - val_loss: 2.8511 - val_accuracy: 0.7210
Epoch 37/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0248 - accuracy: 0.9960 - val_loss: 3.5881 - val_accuracy: 0.6690
Epoch 38/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0184 - accuracy: 0.9960 - val_loss: 2.8461 - val_accuracy: 0.7390
Epoch 39/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0133 - accuracy: 0.9965 - val_loss: 2.5951 - val_accuracy: 0.7290
Epoch 40/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0175 - accuracy: 0.9940 - val_loss: 3.4084 - val_accuracy: 0.7310
Epoch 41/100
100/100 [==============================] - 5s 55ms/step - loss: 0.0105 - accuracy: 0.9980 - val_loss: 3.7552 - val_accuracy: 0.7190
Epoch 42/100
100/100 [==============================] - 5s 54ms/step - loss: 0.0094 - accuracy: 0.9990 - val_loss: 2.9929 - val_accuracy: 0.7170
Epoch 43/100
100/100 [==============================] - 5s 55ms/step - loss: 0.0215 - accuracy: 0.9950 - val_loss: 3.9527 - val_accuracy: 0.7490
Epoch 44/100
100/100 [==============================] - 5s 54ms/step - loss: 0.0182 - accuracy: 0.9960 - val_loss: 3.4820 - val_accuracy: 0.7360
Epoch 45/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0049 - accuracy: 0.9995 - val_loss: 3.1149 - val_accuracy: 0.7280
Epoch 46/100
100/100 [==============================] - 5s 52ms/step - loss: 0.0292 - accuracy: 0.9945 - val_loss: 3.4735 - val_accuracy: 0.7180
Epoch 47/100
100/100 [==============================] - 5s 53ms/step - loss: 6.4612e-04 - accuracy: 0.9995 - val_loss: 5.0020 - val_accuracy: 0.6890
Epoch 48/100
100/100 [==============================] - 5s 54ms/step - loss: 0.0318 - accuracy: 0.9960 - val_loss: 4.1888 - val_accuracy: 0.7160
Epoch 49/100
100/100 [==============================] - 5s 54ms/step - loss: 0.0206 - accuracy: 0.9955 - val_loss: 3.4874 - val_accuracy: 0.7340
Epoch 50/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0040 - accuracy: 0.9990 - val_loss: 3.5001 - val_accuracy: 0.7310
Epoch 51/100
100/100 [==============================] - 5s 54ms/step - loss: 1.6889e-06 - accuracy: 1.0000 - val_loss: 3.9648 - val_accuracy: 0.7340
Epoch 52/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0195 - accuracy: 0.9950 - val_loss: 4.3161 - val_accuracy: 0.7300
Epoch 53/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0040 - accuracy: 0.9985 - val_loss: 4.3087 - val_accuracy: 0.7360
Epoch 54/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0402 - accuracy: 0.9935 - val_loss: 4.0819 - val_accuracy: 0.7340
Epoch 55/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0114 - accuracy: 0.9965 - val_loss: 6.3068 - val_accuracy: 0.6610
Epoch 56/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0175 - accuracy: 0.9970 - val_loss: 4.5979 - val_accuracy: 0.7260
Epoch 57/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0030 - accuracy: 0.9990 - val_loss: 4.6900 - val_accuracy: 0.7370
Epoch 58/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0106 - accuracy: 0.9975 - val_loss: 4.8222 - val_accuracy: 0.7360
Epoch 59/100
100/100 [==============================] - 5s 54ms/step - loss: 0.0036 - accuracy: 0.9995 - val_loss: 4.7843 - val_accuracy: 0.7340
Epoch 60/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0141 - accuracy: 0.9975 - val_loss: 5.1487 - val_accuracy: 0.7270
Epoch 61/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0041 - accuracy: 0.9990 - val_loss: 4.4973 - val_accuracy: 0.7150
Epoch 62/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0055 - accuracy: 0.9990 - val_loss: 5.6428 - val_accuracy: 0.7240
Epoch 63/100
100/100 [==============================] - 5s 53ms/step - loss: 2.6188e-06 - accuracy: 1.0000 - val_loss: 5.4401 - val_accuracy: 0.7420
Epoch 64/100
100/100 [==============================] - 5s 52ms/step - loss: 4.8284e-05 - accuracy: 1.0000 - val_loss: 11.4272 - val_accuracy: 0.6490
Epoch 65/100
100/100 [==============================] - 5s 54ms/step - loss: 0.0472 - accuracy: 0.9990 - val_loss: 5.7082 - val_accuracy: 0.7190
Epoch 66/100
100/100 [==============================] - 5s 53ms/step - loss: 0.2012 - accuracy: 0.9900 - val_loss: 5.5670 - val_accuracy: 0.7240
Epoch 67/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0140 - accuracy: 0.9970 - val_loss: 5.1957 - val_accuracy: 0.7400
Epoch 68/100
100/100 [==============================] - 5s 54ms/step - loss: 6.3547e-06 - accuracy: 1.0000 - val_loss: 5.1489 - val_accuracy: 0.7290
Epoch 69/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0849 - accuracy: 0.9915 - val_loss: 4.5071 - val_accuracy: 0.7320
Epoch 70/100
100/100 [==============================] - 5s 53ms/step - loss: 8.5806e-07 - accuracy: 1.0000 - val_loss: 4.8427 - val_accuracy: 0.7530
Epoch 71/100
100/100 [==============================] - 5s 54ms/step - loss: 0.0211 - accuracy: 0.9965 - val_loss: 4.4191 - val_accuracy: 0.7330
Epoch 72/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0409 - accuracy: 0.9955 - val_loss: 4.2959 - val_accuracy: 0.7330
Epoch 73/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0114 - accuracy: 0.9985 - val_loss: 5.0376 - val_accuracy: 0.7370
Epoch 74/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0087 - accuracy: 0.9995 - val_loss: 5.6852 - val_accuracy: 0.7290
Epoch 75/100
100/100 [==============================] - 5s 53ms/step - loss: 2.8096e-06 - accuracy: 1.0000 - val_loss: 5.4950 - val_accuracy: 0.7480
Epoch 76/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0217 - accuracy: 0.9965 - val_loss: 4.8881 - val_accuracy: 0.7340
Epoch 77/100
100/100 [==============================] - 5s 53ms/step - loss: 3.1602e-05 - accuracy: 1.0000 - val_loss: 4.8897 - val_accuracy: 0.7380
Epoch 78/100
100/100 [==============================] - 5s 53ms/step - loss: 1.2130e-07 - accuracy: 1.0000 - val_loss: 5.0578 - val_accuracy: 0.7310
Epoch 79/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0154 - accuracy: 0.9970 - val_loss: 5.8194 - val_accuracy: 0.7000
Epoch 80/100
100/100 [==============================] - 5s 53ms/step - loss: 2.3486e-05 - accuracy: 1.0000 - val_loss: 5.0369 - val_accuracy: 0.7260
Epoch 81/100
100/100 [==============================] - 5s 53ms/step - loss: 7.5696e-05 - accuracy: 1.0000 - val_loss: 4.8706 - val_accuracy: 0.7240
Epoch 82/100
100/100 [==============================] - 5s 53ms/step - loss: 1.2321e-05 - accuracy: 1.0000 - val_loss: 7.2887 - val_accuracy: 0.7100
Epoch 83/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0503 - accuracy: 0.9960 - val_loss: 6.1369 - val_accuracy: 0.7230
Epoch 84/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0143 - accuracy: 0.9980 - val_loss: 7.3023 - val_accuracy: 0.7340
Epoch 85/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0073 - accuracy: 0.9995 - val_loss: 5.7609 - val_accuracy: 0.7310
Epoch 86/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0248 - accuracy: 0.9965 - val_loss: 6.0219 - val_accuracy: 0.7220
Epoch 87/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0016 - accuracy: 0.9990 - val_loss: 6.6455 - val_accuracy: 0.7180
Epoch 88/100
100/100 [==============================] - 5s 52ms/step - loss: 1.5886e-07 - accuracy: 1.0000 - val_loss: 6.8651 - val_accuracy: 0.7230
Epoch 89/100
100/100 [==============================] - 5s 53ms/step - loss: 2.4064e-09 - accuracy: 1.0000 - val_loss: 7.2050 - val_accuracy: 0.7190
Epoch 90/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0101 - accuracy: 0.9980 - val_loss: 5.6294 - val_accuracy: 0.7260
Epoch 91/100
100/100 [==============================] - 5s 53ms/step - loss: 0.0103 - accuracy: 0.9980 - val_loss: 5.5818 - val_accuracy: 0.7410
Epoch 92/100
100/100 [==============================] - 5s 53ms/step - loss: 5.4872e-06 - accuracy: 1.0000 - val_loss: 5.6043 - val_accuracy: 0.7390
Epoch 93/100
100/100 [==============================] - 5s 53ms/step - loss: 1.4232e-08 - accuracy: 1.0000 - val_loss: 5.6822 - val_accuracy: 0.7420
Epoch 94/100
100/100 [==============================] - 5s 53ms/step - loss: 1.9938e-09 - accuracy: 1.0000 - val_loss: 5.9912 - val_accuracy: 0.7460
Epoch 95/100
100/100 [==============================] - 5s 52ms/step - loss: 1.0459e-10 - accuracy: 1.0000 - val_loss: 6.0301 - val_accuracy: 0.7440
Epoch 96/100
100/100 [==============================] - 5s 53ms/step - loss: 5.5668e-11 - accuracy: 1.0000 - val_loss: 6.0432 - val_accuracy: 0.7420
Epoch 97/100
100/100 [==============================] - 5s 53ms/step - loss: 4.6533e-11 - accuracy: 1.0000 - val_loss: 6.0673 - val_accuracy: 0.7410
Epoch 98/100
100/100 [==============================] - 5s 53ms/step - loss: 4.3628e-11 - accuracy: 1.0000 - val_loss: 6.0834 - val_accuracy: 0.7400
Epoch 99/100
100/100 [==============================] - 5s 53ms/step - loss: 4.1981e-11 - accuracy: 1.0000 - val_loss: 6.0965 - val_accuracy: 0.7400
Epoch 100/100
100/100 [==============================] - 5s 53ms/step - loss: 4.1266e-11 - accuracy: 1.0000 - val_loss: 6.1099 - val_accuracy: 0.7400
accuracy_naive = history_naive.history['accuracy']
val_accuracy_naive = history_naive.history['val_accuracy']
loss_naive = history_naive.history['loss']
val_loss_naive = history_naive.history['val_loss']

epochs = range(len(accuracy_naive))

plt.subplot(1, 2, 1)
plt.plot(epochs, accuracy_naive, 'bo', label='Training accuracy')
plt.plot(epochs, val_accuracy_naive, 'b', label='Validation accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs, loss_naive, 'bo', label='Training loss')
plt.plot(epochs, val_loss_naive, 'b', label='Validation loss')
plt.legend()
plt.show()

As you can see, while training accuracy reaches almost 100%, validation accuracy gets stuck around 72-73%. From this result, we can see that the model is overfitting and does not generalize well.

Image Augmentation

A few approaches exist to overcome this problem. One sure way is to gather more data, but in most cases that is difficult. Instead, we can generate additional synthetic images from the original data through transformations such as rotation, zoom, and shearing. This technique is called image augmentation, and we can apply it easily with the ImageDataGenerator mentioned above.
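
Before training on augmented data, it helps to see what these transformations actually produce. The sketch below draws a few random augmentations of a single training image; it reuses the train_cat_dir path from above and keeps the pixel values in the 0-255 range, so the array is cast back to uint8 for display.

from tensorflow.keras.preprocessing.image import load_img, img_to_array

demo_datagen = ImageDataGenerator(
    rotation_range=40, width_shift_range=0.2, height_shift_range=0.2,
    shear_range=0.2, zoom_range=0.2, horizontal_flip=True, fill_mode='nearest'
)

img = load_img(os.path.join(train_cat_dir, train_cat_fname[0]), target_size=(150, 150))
x = img_to_array(img).reshape((1, 150, 150, 3))   # flow() expects a batch dimension

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, batch in zip(axes, demo_datagen.flow(x, batch_size=1)):
    ax.imshow(batch[0].astype('uint8'))
    ax.axis('off')
plt.show()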

train_aug_datagen = ImageDataGenerator(
    rescale=1/255.,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)
val_datagen = ImageDataGenerator(
    rescale=1/255.
)

train_generator = train_aug_datagen.flow_from_directory(
    directory=train_dir,
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary'
)
val_generator = val_datagen.flow_from_directory(
    directory=val_dir,
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary'
)
Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.

Note that the validation dataset should remain in its original form; that is why we only apply rescaling in the validation data generator.

Let's repeat the same process with the same model architecture.

model = Sequential([
    Conv2D(16, (3, 3), activation='relu', input_shape=(150, 150, 3)),
    MaxPooling2D(2, 2),
    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Flatten(),
    Dense(512, activation='relu'),
    Dense(1, activation='sigmoid')
])

model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_3 (Conv2D)            (None, 148, 148, 16)      448       
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 74, 74, 16)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 72, 72, 32)        4640      
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 36, 36, 32)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 34, 34, 64)        18496     
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 17, 17, 64)        0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 18496)             0         
_________________________________________________________________
dense_2 (Dense)              (None, 512)               9470464   
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 513       
=================================================================
Total params: 9,494,561
Trainable params: 9,494,561
Non-trainable params: 0
_________________________________________________________________
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001), loss='binary_crossentropy', metrics=['accuracy'])
history_augmentation = model.fit(train_generator, steps_per_epoch=100, epochs=100, validation_data=val_generator, validation_steps=50)
Epoch 1/100
100/100 [==============================] - 12s 115ms/step - loss: 0.8108 - accuracy: 0.5315 - val_loss: 0.7738 - val_accuracy: 0.5000
Epoch 2/100
100/100 [==============================] - 11s 113ms/step - loss: 0.7056 - accuracy: 0.5580 - val_loss: 0.6624 - val_accuracy: 0.5780
Epoch 3/100
100/100 [==============================] - 11s 111ms/step - loss: 0.6771 - accuracy: 0.5910 - val_loss: 0.6278 - val_accuracy: 0.6380
Epoch 4/100
100/100 [==============================] - 11s 112ms/step - loss: 0.6687 - accuracy: 0.6395 - val_loss: 0.5900 - val_accuracy: 0.6670
Epoch 5/100
100/100 [==============================] - 11s 112ms/step - loss: 0.6328 - accuracy: 0.6490 - val_loss: 0.5757 - val_accuracy: 0.6770
Epoch 6/100
100/100 [==============================] - 11s 111ms/step - loss: 0.6265 - accuracy: 0.6495 - val_loss: 0.6053 - val_accuracy: 0.6550
Epoch 7/100
100/100 [==============================] - 11s 111ms/step - loss: 0.6149 - accuracy: 0.6660 - val_loss: 0.5520 - val_accuracy: 0.7110
Epoch 8/100
100/100 [==============================] - 11s 111ms/step - loss: 0.5982 - accuracy: 0.6875 - val_loss: 0.5817 - val_accuracy: 0.6850
Epoch 9/100
100/100 [==============================] - 11s 111ms/step - loss: 0.5955 - accuracy: 0.6950 - val_loss: 0.5597 - val_accuracy: 0.7010
Epoch 10/100
100/100 [==============================] - 11s 112ms/step - loss: 0.5924 - accuracy: 0.6795 - val_loss: 0.5469 - val_accuracy: 0.7160
Epoch 11/100
100/100 [==============================] - 11s 112ms/step - loss: 0.5843 - accuracy: 0.7015 - val_loss: 0.5384 - val_accuracy: 0.7310
Epoch 12/100
100/100 [==============================] - 11s 111ms/step - loss: 0.5732 - accuracy: 0.7035 - val_loss: 0.5452 - val_accuracy: 0.7240
Epoch 13/100
100/100 [==============================] - 11s 111ms/step - loss: 0.5818 - accuracy: 0.6970 - val_loss: 0.5238 - val_accuracy: 0.7230
Epoch 14/100
100/100 [==============================] - 11s 112ms/step - loss: 0.5735 - accuracy: 0.7120 - val_loss: 0.5491 - val_accuracy: 0.7290
Epoch 15/100
100/100 [==============================] - 12s 117ms/step - loss: 0.5817 - accuracy: 0.7175 - val_loss: 0.5108 - val_accuracy: 0.7410
Epoch 16/100
100/100 [==============================] - 11s 112ms/step - loss: 0.5577 - accuracy: 0.7145 - val_loss: 0.5311 - val_accuracy: 0.7400
Epoch 17/100
100/100 [==============================] - 11s 112ms/step - loss: 0.5562 - accuracy: 0.7170 - val_loss: 0.5227 - val_accuracy: 0.7300
Epoch 18/100
100/100 [==============================] - 11s 114ms/step - loss: 0.5611 - accuracy: 0.7110 - val_loss: 0.4971 - val_accuracy: 0.7610
Epoch 19/100
100/100 [==============================] - 11s 112ms/step - loss: 0.5434 - accuracy: 0.7385 - val_loss: 0.5103 - val_accuracy: 0.7400
Epoch 20/100
100/100 [==============================] - 11s 111ms/step - loss: 0.5424 - accuracy: 0.7365 - val_loss: 0.5104 - val_accuracy: 0.7460
Epoch 21/100
100/100 [==============================] - 11s 115ms/step - loss: 0.5606 - accuracy: 0.7360 - val_loss: 0.4856 - val_accuracy: 0.7640
Epoch 22/100
100/100 [==============================] - 12s 115ms/step - loss: 0.5342 - accuracy: 0.7365 - val_loss: 0.4925 - val_accuracy: 0.7640
Epoch 23/100
100/100 [==============================] - 12s 116ms/step - loss: 0.5430 - accuracy: 0.7320 - val_loss: 0.4980 - val_accuracy: 0.7480
Epoch 24/100
100/100 [==============================] - 11s 114ms/step - loss: 0.5438 - accuracy: 0.7435 - val_loss: 0.4864 - val_accuracy: 0.7500
Epoch 25/100
100/100 [==============================] - 11s 113ms/step - loss: 0.5332 - accuracy: 0.7370 - val_loss: 0.4968 - val_accuracy: 0.7580
Epoch 26/100
100/100 [==============================] - 11s 113ms/step - loss: 0.5376 - accuracy: 0.7285 - val_loss: 0.5124 - val_accuracy: 0.7530
Epoch 27/100
100/100 [==============================] - 11s 113ms/step - loss: 0.5203 - accuracy: 0.7355 - val_loss: 0.4824 - val_accuracy: 0.7730
Epoch 28/100
100/100 [==============================] - 12s 115ms/step - loss: 0.5105 - accuracy: 0.7500 - val_loss: 0.6269 - val_accuracy: 0.7230
Epoch 29/100
100/100 [==============================] - 11s 113ms/step - loss: 0.5342 - accuracy: 0.7380 - val_loss: 0.5559 - val_accuracy: 0.7200
Epoch 30/100
100/100 [==============================] - 11s 112ms/step - loss: 0.5027 - accuracy: 0.7595 - val_loss: 0.7527 - val_accuracy: 0.6910
Epoch 31/100
100/100 [==============================] - 11s 113ms/step - loss: 0.5102 - accuracy: 0.7575 - val_loss: 0.4887 - val_accuracy: 0.7710
Epoch 32/100
100/100 [==============================] - 11s 113ms/step - loss: 0.5169 - accuracy: 0.7590 - val_loss: 0.4729 - val_accuracy: 0.7640
Epoch 33/100
100/100 [==============================] - 11s 112ms/step - loss: 0.5048 - accuracy: 0.7565 - val_loss: 0.4876 - val_accuracy: 0.7590
Epoch 34/100
100/100 [==============================] - 11s 113ms/step - loss: 0.5117 - accuracy: 0.7595 - val_loss: 0.5356 - val_accuracy: 0.7410
Epoch 35/100
100/100 [==============================] - 11s 112ms/step - loss: 0.4885 - accuracy: 0.7575 - val_loss: 0.5715 - val_accuracy: 0.7560
Epoch 36/100
100/100 [==============================] - 11s 113ms/step - loss: 0.4971 - accuracy: 0.7565 - val_loss: 0.5159 - val_accuracy: 0.7640
Epoch 37/100
100/100 [==============================] - 11s 112ms/step - loss: 0.5092 - accuracy: 0.7550 - val_loss: 0.4873 - val_accuracy: 0.7770
Epoch 38/100
100/100 [==============================] - 11s 113ms/step - loss: 0.5014 - accuracy: 0.7560 - val_loss: 0.5949 - val_accuracy: 0.7240
Epoch 39/100
100/100 [==============================] - 11s 113ms/step - loss: 0.4994 - accuracy: 0.7690 - val_loss: 0.4809 - val_accuracy: 0.7810
Epoch 40/100
100/100 [==============================] - 11s 114ms/step - loss: 0.4952 - accuracy: 0.7670 - val_loss: 0.5135 - val_accuracy: 0.7320
Epoch 41/100
100/100 [==============================] - 11s 115ms/step - loss: 0.4925 - accuracy: 0.7635 - val_loss: 0.5809 - val_accuracy: 0.7450
Epoch 42/100
100/100 [==============================] - 11s 114ms/step - loss: 0.4752 - accuracy: 0.7730 - val_loss: 0.5508 - val_accuracy: 0.7440
Epoch 43/100
100/100 [==============================] - 11s 114ms/step - loss: 0.4766 - accuracy: 0.7755 - val_loss: 0.4906 - val_accuracy: 0.7490
Epoch 44/100
100/100 [==============================] - 11s 113ms/step - loss: 0.4882 - accuracy: 0.7670 - val_loss: 0.6110 - val_accuracy: 0.7090
Epoch 45/100
100/100 [==============================] - 12s 116ms/step - loss: 0.4770 - accuracy: 0.7655 - val_loss: 0.4894 - val_accuracy: 0.7580
Epoch 46/100
100/100 [==============================] - 11s 113ms/step - loss: 0.4807 - accuracy: 0.7725 - val_loss: 0.5099 - val_accuracy: 0.7620
Epoch 47/100
100/100 [==============================] - 12s 116ms/step - loss: 0.4946 - accuracy: 0.7785 - val_loss: 0.5033 - val_accuracy: 0.7570
Epoch 48/100
100/100 [==============================] - 11s 113ms/step - loss: 0.4751 - accuracy: 0.7780 - val_loss: 0.4813 - val_accuracy: 0.7790
Epoch 49/100
100/100 [==============================] - 11s 114ms/step - loss: 0.4822 - accuracy: 0.7690 - val_loss: 0.4589 - val_accuracy: 0.7800
Epoch 50/100
100/100 [==============================] - 11s 113ms/step - loss: 0.4871 - accuracy: 0.7820 - val_loss: 0.5263 - val_accuracy: 0.7510
Epoch 51/100
100/100 [==============================] - 12s 115ms/step - loss: 0.4750 - accuracy: 0.7785 - val_loss: 0.5108 - val_accuracy: 0.7730
Epoch 52/100
100/100 [==============================] - 11s 112ms/step - loss: 0.4765 - accuracy: 0.7855 - val_loss: 0.4637 - val_accuracy: 0.7740
Epoch 53/100
100/100 [==============================] - 11s 115ms/step - loss: 0.4676 - accuracy: 0.7800 - val_loss: 0.4660 - val_accuracy: 0.7680
Epoch 54/100
100/100 [==============================] - 11s 113ms/step - loss: 0.4673 - accuracy: 0.7880 - val_loss: 0.4775 - val_accuracy: 0.7910
Epoch 55/100
100/100 [==============================] - 11s 113ms/step - loss: 0.4684 - accuracy: 0.7860 - val_loss: 0.4838 - val_accuracy: 0.7550
Epoch 56/100
100/100 [==============================] - 11s 111ms/step - loss: 0.4679 - accuracy: 0.7845 - val_loss: 0.5295 - val_accuracy: 0.7510
Epoch 57/100
100/100 [==============================] - 11s 113ms/step - loss: 0.4873 - accuracy: 0.7690 - val_loss: 0.4552 - val_accuracy: 0.7870
Epoch 58/100
100/100 [==============================] - 11s 114ms/step - loss: 0.4819 - accuracy: 0.7925 - val_loss: 0.4601 - val_accuracy: 0.7800
Epoch 59/100
100/100 [==============================] - 11s 112ms/step - loss: 0.4585 - accuracy: 0.7905 - val_loss: 0.4601 - val_accuracy: 0.7720
Epoch 60/100
100/100 [==============================] - 11s 113ms/step - loss: 0.4723 - accuracy: 0.7840 - val_loss: 0.4987 - val_accuracy: 0.7760
Epoch 61/100
100/100 [==============================] - 11s 113ms/step - loss: 0.4661 - accuracy: 0.7815 - val_loss: 0.5187 - val_accuracy: 0.7490
Epoch 62/100
100/100 [==============================] - 12s 117ms/step - loss: 0.4745 - accuracy: 0.7820 - val_loss: 0.4579 - val_accuracy: 0.8040
Epoch 63/100
100/100 [==============================] - 12s 117ms/step - loss: 0.4606 - accuracy: 0.7925 - val_loss: 0.4651 - val_accuracy: 0.7880
Epoch 64/100
100/100 [==============================] - 12s 116ms/step - loss: 0.4600 - accuracy: 0.7920 - val_loss: 0.4515 - val_accuracy: 0.7900
Epoch 65/100
100/100 [==============================] - 11s 112ms/step - loss: 0.4531 - accuracy: 0.7945 - val_loss: 0.4695 - val_accuracy: 0.7790
Epoch 66/100
100/100 [==============================] - 11s 114ms/step - loss: 0.4688 - accuracy: 0.7915 - val_loss: 0.4800 - val_accuracy: 0.7810
Epoch 67/100
100/100 [==============================] - 11s 114ms/step - loss: 0.4598 - accuracy: 0.7950 - val_loss: 0.4883 - val_accuracy: 0.7750
Epoch 68/100
100/100 [==============================] - 11s 111ms/step - loss: 0.4523 - accuracy: 0.7795 - val_loss: 0.5102 - val_accuracy: 0.7610
Epoch 69/100
100/100 [==============================] - 11s 115ms/step - loss: 0.4636 - accuracy: 0.7900 - val_loss: 0.4693 - val_accuracy: 0.7950
Epoch 70/100
100/100 [==============================] - 12s 115ms/step - loss: 0.4623 - accuracy: 0.7860 - val_loss: 0.4707 - val_accuracy: 0.7740
Epoch 71/100
100/100 [==============================] - 12s 117ms/step - loss: 0.4573 - accuracy: 0.7985 - val_loss: 0.4793 - val_accuracy: 0.7730
Epoch 72/100
100/100 [==============================] - 11s 114ms/step - loss: 0.4564 - accuracy: 0.7975 - val_loss: 0.5321 - val_accuracy: 0.7610
Epoch 73/100
100/100 [==============================] - 11s 112ms/step - loss: 0.4682 - accuracy: 0.7850 - val_loss: 0.4448 - val_accuracy: 0.7850
Epoch 74/100
100/100 [==============================] - 11s 113ms/step - loss: 0.4440 - accuracy: 0.7980 - val_loss: 1.1144 - val_accuracy: 0.6700
Epoch 75/100
100/100 [==============================] - 11s 114ms/step - loss: 0.4432 - accuracy: 0.7935 - val_loss: 0.4834 - val_accuracy: 0.7740
Epoch 76/100
100/100 [==============================] - 12s 116ms/step - loss: 0.4506 - accuracy: 0.8030 - val_loss: 0.4502 - val_accuracy: 0.7890
Epoch 77/100
100/100 [==============================] - 11s 115ms/step - loss: 0.4514 - accuracy: 0.7945 - val_loss: 0.4740 - val_accuracy: 0.7900
Epoch 78/100
100/100 [==============================] - 11s 114ms/step - loss: 0.4540 - accuracy: 0.8005 - val_loss: 0.4838 - val_accuracy: 0.7900
Epoch 79/100
100/100 [==============================] - 12s 117ms/step - loss: 0.4757 - accuracy: 0.7815 - val_loss: 0.5458 - val_accuracy: 0.7480
Epoch 80/100
100/100 [==============================] - 12s 119ms/step - loss: 0.4659 - accuracy: 0.7910 - val_loss: 0.5645 - val_accuracy: 0.7300
Epoch 81/100
100/100 [==============================] - 12s 117ms/step - loss: 0.4481 - accuracy: 0.7990 - val_loss: 0.4628 - val_accuracy: 0.8150
Epoch 82/100
100/100 [==============================] - 11s 112ms/step - loss: 0.4589 - accuracy: 0.7920 - val_loss: 0.5816 - val_accuracy: 0.8060
Epoch 83/100
100/100 [==============================] - 11s 113ms/step - loss: 0.4590 - accuracy: 0.7925 - val_loss: 0.4792 - val_accuracy: 0.7980
Epoch 84/100
100/100 [==============================] - 11s 115ms/step - loss: 0.4438 - accuracy: 0.8005 - val_loss: 0.5137 - val_accuracy: 0.7810
Epoch 85/100
100/100 [==============================] - 11s 114ms/step - loss: 0.4614 - accuracy: 0.7955 - val_loss: 0.6726 - val_accuracy: 0.7400
Epoch 86/100
100/100 [==============================] - 11s 115ms/step - loss: 0.4543 - accuracy: 0.7990 - val_loss: 0.4684 - val_accuracy: 0.8080
Epoch 87/100
100/100 [==============================] - 12s 115ms/step - loss: 0.4626 - accuracy: 0.7980 - val_loss: 0.4441 - val_accuracy: 0.7990
Epoch 88/100
100/100 [==============================] - 11s 114ms/step - loss: 0.4784 - accuracy: 0.7805 - val_loss: 0.5799 - val_accuracy: 0.7530
Epoch 89/100
100/100 [==============================] - 11s 113ms/step - loss: 0.4484 - accuracy: 0.8025 - val_loss: 0.6416 - val_accuracy: 0.7380
Epoch 90/100
100/100 [==============================] - 11s 113ms/step - loss: 0.4505 - accuracy: 0.7900 - val_loss: 0.5750 - val_accuracy: 0.7780
Epoch 91/100
100/100 [==============================] - 11s 112ms/step - loss: 0.4425 - accuracy: 0.8095 - val_loss: 0.4978 - val_accuracy: 0.7800
Epoch 92/100
100/100 [==============================] - 11s 114ms/step - loss: 0.4483 - accuracy: 0.7880 - val_loss: 0.4720 - val_accuracy: 0.7840
Epoch 93/100
100/100 [==============================] - 12s 116ms/step - loss: 0.4597 - accuracy: 0.7970 - val_loss: 0.4987 - val_accuracy: 0.7860
Epoch 94/100
100/100 [==============================] - 12s 116ms/step - loss: 0.4643 - accuracy: 0.7935 - val_loss: 0.4682 - val_accuracy: 0.7940
Epoch 95/100
100/100 [==============================] - 12s 117ms/step - loss: 0.4600 - accuracy: 0.8000 - val_loss: 0.4809 - val_accuracy: 0.7840
Epoch 96/100
100/100 [==============================] - 12s 117ms/step - loss: 0.4568 - accuracy: 0.7940 - val_loss: 0.5734 - val_accuracy: 0.7960
Epoch 97/100
100/100 [==============================] - 12s 115ms/step - loss: 0.4330 - accuracy: 0.8025 - val_loss: 0.5742 - val_accuracy: 0.7580
Epoch 98/100
100/100 [==============================] - 11s 114ms/step - loss: 0.4572 - accuracy: 0.7920 - val_loss: 0.4439 - val_accuracy: 0.8030
Epoch 99/100
100/100 [==============================] - 11s 112ms/step - loss: 0.4435 - accuracy: 0.7995 - val_loss: 0.5557 - val_accuracy: 0.7560
Epoch 100/100
100/100 [==============================] - 11s 112ms/step - loss: 0.4423 - accuracy: 0.8095 - val_loss: 0.4796 - val_accuracy: 0.8040
accuracy_aug = history_augmentation.history['accuracy']
val_accuracy_aug = history_augmentation.history['val_accuracy']
loss_aug = history_augmentation.history['loss']
val_loss_aug = history_augmentation.history['val_loss']

epochs = range(len(accuracy_aug))

plt.subplot(1, 2, 1)
plt.plot(epochs, accuracy_aug, 'bo', label='Training accuracy')
plt.plot(epochs, val_accuracy_aug, 'b', label='Validation accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs, loss_aug, 'bo', label='Training loss')
plt.plot(epochs, val_loss_aug, 'b', label='Validation loss')
plt.legend()
plt.show()

As you can see, the validation accuracy improves from around 72-73% to around 80%, and the training and validation curves stay much closer together, which indicates that overfitting has been reduced.
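
To see the effect side by side, we can plot the two validation accuracy curves on a single axis, reusing the history values collected above.

plt.plot(epochs, val_accuracy_naive, label='validation accuracy (no augmentation)')
plt.plot(epochs, val_accuracy_aug, label='validation accuracy (with augmentation)')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()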