The Power of Image Augmentation
In this post, we show the effect of image augmentation when training a Convolutional Neural Network (CNN), and how to use `ImageDataGenerator` in TensorFlow.
- Packages
- Load the dataset
- Check the sample images
- Tensorflow ImageDataGenerator
- Model Build
- Image Augmentation
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
plt.rcParams['figure.figsize'] = (10, 6)
Here we use the Cats and Dogs dataset from Kaggle, which is a binary classification problem. For simplicity, we use a filtered version that contains only a subset of the images.
base_dir = './dataset/cats_and_dogs_filtered'
os.listdir(base_dir)
train_dir = os.path.join(base_dir, 'train')
val_dir = os.path.join(base_dir, 'validation')
os.listdir(train_dir)
os.listdir(val_dir)
To use `ImageDataGenerator` in TensorFlow (specifically its `flow_from_directory` method), the dataset folder must be organized hierarchically, with one subdirectory per label inside each split. For example:
- train
  - label_1
  - label_2
  - ...
- val
  - label_1
  - label_2
  - ...
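The layout above can be checked with a small standard-library sketch. It builds a throwaway copy of the `train`/`val` hierarchy in a temporary directory and counts the files in each label folder. The helper name `count_images_per_class` and the dummy file names are hypothetical, used only for illustration.

```python
import os
import tempfile


def count_images_per_class(split_dir):
    """Hypothetical helper: map each label subdirectory of split_dir to its file count."""
    counts = {}
    for label in sorted(os.listdir(split_dir)):
        label_dir = os.path.join(split_dir, label)
        if os.path.isdir(label_dir):  # only subdirectories are treated as labels
            counts[label] = len(os.listdir(label_dir))
    return counts


# Build a tiny dummy tree: <root>/train/{cats,dogs} and <root>/val/{cats,dogs}
layout = {'train': {'cats': 3, 'dogs': 2}, 'val': {'cats': 1, 'dogs': 1}}
with tempfile.TemporaryDirectory() as root:
    for split, labels in layout.items():
        for label, n_files in labels.items():
            label_dir = os.path.join(root, split, label)
            os.makedirs(label_dir)
            for i in range(n_files):
                # Empty placeholder files stand in for real .jpg images
                with open(os.path.join(label_dir, f'{label}.{i}.jpg'), 'w'):
                    pass
    train_counts = count_images_per_class(os.path.join(root, 'train'))
    val_counts = count_images_per_class(os.path.join(root, 'val'))

print(train_counts)  # {'cats': 3, 'dogs': 2}
print(val_counts)    # {'cats': 1, 'dogs': 1}
```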
For convenience, we also prepare the directory path of each class in each split.
train_cat_dir = os.path.join(train_dir, 'cats')
train_dog_dir = os.path.join(train_dir, 'dogs')
val_cat_dir = os.path.join(val_dir, 'cats')
val_dog_dir = os.path.join(val_dir, 'dogs')
train_dog_fname = os.listdir(train_dog_dir)
train_cat_fname = os.listdir(train_cat_dir)
nrows = 4
ncols = 4
pic_idx = 0
fig = plt.gcf()
fig.set_size_inches(ncols * 4, nrows * 4)
pic_idx += 8
next_cat_pic = [os.path.join(train_cat_dir, fname) for fname in train_cat_fname[pic_idx - 8:pic_idx]]
next_dog_pic = [os.path.join(train_dog_dir, fname) for fname in train_dog_fname[pic_idx - 8:pic_idx]]
for i, img_path in enumerate(next_cat_pic + next_dog_pic):
    # Set up a subplot; subplot indices start at 1
    sp = plt.subplot(nrows, ncols, i + 1)
    sp.axis('off')  # hide grid lines and tick labels
    img = mpimg.imread(img_path)
    plt.imshow(img)
plt.show()
As you can see, the images come in different shapes, and some of them contain unrelated objects such as human hands or cages. Features like these make it harder for a model to learn a representation that generalizes when classifying cats and dogs. We will see shortly how a CNN performs on this unaugmented dataset.
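Because the raw images differ in size, they have to be resized to one common shape before they can be stacked into a batch. `flow_from_directory` does this through `target_size`, and its `interpolation` argument defaults to `'nearest'`. The pure-Python sketch below illustrates nearest-neighbor resizing on tiny grayscale images stored as nested lists. The function `resize_nearest` is hypothetical and is not Keras's own code (Keras delegates image loading and resizing to PIL).

```python
def resize_nearest(image, out_h, out_w):
    """Hypothetical helper: nearest-neighbor resize of a 2-D list of pixel values."""
    in_h, in_w = len(image), len(image[0])
    resized = []
    for r in range(out_h):
        # Map the centre of each output row back to the closest input row
        src_r = min(int((r + 0.5) * in_h / out_h), in_h - 1)
        row = []
        for c in range(out_w):
            # Same mapping for columns
            src_c = min(int((c + 0.5) * in_w / out_w), in_w - 1)
            row.append(image[src_r][src_c])
        resized.append(row)
    return resized


small = [[0, 255],
         [128, 64]]          # 2x2 image
wide = [[1, 2, 3, 4, 5, 6]]  # 1x6 image with a different shape

print(resize_nearest(small, 4, 4))
# [[0, 0, 255, 255], [0, 0, 255, 255], [128, 128, 64, 64], [128, 128, 64, 64]]
print(resize_nearest(wide, 2, 3))
# [[2, 4, 6], [2, 4, 6]]
```

Whatever the input shape, every output has exactly `out_h` rows and `out_w` columns, which is what lets differently sized photos share one batch.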
Tensorflow ImageDataGenerator
We could, of course, load and preprocess the images ourselves for the training and validation data, but TensorFlow offers a convenient API called `ImageDataGenerator`. As long as the dataset folder has the correct structure and we choose appropriate options, it gives us a Python generator over the dataset. To compare performance with and without image augmentation, let's first build a naive data generator. Here we only apply rescaling, which maps pixel values from the range [0, 255] to [0, 1].
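`rescale=1/255.` multiplies every pixel value by 1/255, so 8-bit intensities in [0, 255] end up in [0, 1]; according to the Keras documentation, this multiplication happens after any other transformations. A minimal sketch of the arithmetic on a few sample pixel values (plain Python, no TensorFlow needed):

```python
RESCALE = 1 / 255.  # same factor passed to ImageDataGenerator(rescale=...)

pixels = [0, 51, 128, 255]  # sample 8-bit intensities
scaled = [p * RESCALE for p in pixels]  # each value now lies in [0, 1]
print(scaled)
```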
train_datagen = ImageDataGenerator(rescale=1/255.)
val_datagen = ImageDataGenerator(rescale=1/255.)
train_generator = train_datagen.flow_from_directory(
directory=train_dir,
batch_size=20,
target_size=(150, 150),
class_mode='binary'
)
val_generator = val_datagen.flow_from_directory(
directory=val_dir,
batch_size=20,
target_size=(150, 150),
class_mode='binary'
)
We need to choose the `target_size`, `batch_size`, and `class_mode`. Since this is a binary classification problem, we set `class_mode` to `'binary'`; to extend it to multiclass classification, we would set it to `'categorical'` instead. `target_size` resizes every image to a fixed shape, and its value determines the `input_shape` of the CNN.
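To make the link between `target_size` and the network concrete, here is a sketch of the shape arithmetic for the model used in this post. It assumes RGB input (the default `color_mode='rgb'`, hence 3 channels), 3x3 `Conv2D` layers with Keras's default stride 1 and `'valid'` padding, and 2x2 max pooling with stride 2. The helper `conv_pool_shapes` is hypothetical; `model.summary()` reports the same shapes and parameter counts.

```python
def conv_pool_shapes(target_size, channels=3, filters=(16, 32, 64), kernel=3, pool=2):
    """Hypothetical helper: trace (height, width, channels) through Conv2D + MaxPooling2D blocks."""
    h, w = target_size
    shapes = [(h, w, channels)]  # first entry is the CNN's input_shape
    params = 0
    in_ch = channels
    for f in filters:
        # 'valid' convolution with stride 1 shrinks each side by kernel - 1
        h, w = h - kernel + 1, w - kernel + 1
        params += kernel * kernel * in_ch * f + f  # weights + biases
        shapes.append((h, w, f))
        # 2x2 max pooling with stride 2 halves each side, rounding down
        h, w = (h - pool) // pool + 1, (w - pool) // pool + 1
        shapes.append((h, w, f))
        in_ch = f
    return shapes, params


shapes, conv_params = conv_pool_shapes((150, 150))
flat = shapes[-1][0] * shapes[-1][1] * shapes[-1][2]  # size after Flatten()
dense_params = flat * 512 + 512 + 512 * 1 + 1  # Dense(512) + Dense(1), weights + biases
print(shapes)
# [(150, 150, 3), (148, 148, 16), (74, 74, 16), (72, 72, 32), (36, 36, 32), (34, 34, 64), (17, 17, 64)]
print(flat, conv_params + dense_params)
# 18496 9494561
```

Changing `target_size` changes both the required `input_shape` and the Flatten size, so the first `Dense` layer's weight count depends directly on the chosen image size.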
The result is a Python generator that yields batches of images and labels. Conveniently, label encoding is applied automatically. To find out which label is mapped to which index, we can check the `class_indices` attribute of the generator returned by `flow_from_directory`.
train_generator.class_indices
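When no explicit `classes` list is given, `flow_from_directory` infers the classes from the subdirectory names and assigns indices in alphanumeric order. The sketch below mimics that mapping and shows how a label index is represented under `class_mode='binary'` (a single 0/1 value) versus `'categorical'` (a one-hot vector). The helpers `infer_class_indices` and `encode_label` are hypothetical illustrations, not Keras functions.

```python
def infer_class_indices(subdir_names):
    """Hypothetical helper: map class subdirectory names to indices in sorted order."""
    return {name: idx for idx, name in enumerate(sorted(subdir_names))}


def encode_label(class_index, num_classes, class_mode):
    """Hypothetical helper: encode one label the way each class_mode represents it."""
    if class_mode == 'binary':
        return float(class_index)  # single 0.0 / 1.0 target for a sigmoid output
    if class_mode == 'categorical':
        one_hot = [0.0] * num_classes  # one-hot target for a softmax output
        one_hot[class_index] = 1.0
        return one_hot
    raise ValueError(f'unsupported class_mode: {class_mode}')


indices = infer_class_indices(['dogs', 'cats'])
print(indices)                                          # {'cats': 0, 'dogs': 1}
print(encode_label(indices['dogs'], 2, 'binary'))       # 1.0
print(encode_label(indices['dogs'], 2, 'categorical'))  # [0.0, 1.0]
```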
model = Sequential([
Conv2D(16, (3, 3), activation='relu', input_shape=(150, 150, 3)),
MaxPooling2D(2, 2),
Conv2D(32, (3, 3), activation='relu'),
MaxPooling2D(2, 2),
Conv2D(64, (3, 3), activation='relu'),
MaxPooling2D(2, 2),
Flatten(),
Dense(512, activation='relu'),
Dense(1, activation='sigmoid')
])
model.summary()
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001), loss='binary_crossentropy', metrics=['accuracy'])
history_naive = model.fit(train_generator, steps_per_epoch=100, epochs=100, validation_data=val_generator, validation_steps=50)
accuracy_naive = history_naive.history['accuracy']
val_accuracy_naive = history_naive.history['val_accuracy']
loss_naive = history_naive.history['loss']
val_loss_naive = history_naive.history['val_loss']
epochs = range(len(accuracy_naive))
plt.subplot(1, 2, 1)
plt.plot(epochs, accuracy_naive, 'bo', label='Training accuracy')
plt.plot(epochs, val_accuracy_naive, 'b', label='Validation accuracy')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(epochs, loss_naive, 'bo', label='Training loss')
plt.plot(epochs, val_loss_naive, 'b', label='Validation loss')
plt.legend()
plt.show()
As you can see, training accuracy reaches almost 100%, while validation accuracy stalls at around 72 to 73%. This gap shows that the model is overfitting: it fits the training images well but does not generalize to unseen ones.
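One simple way to put a number on this overfitting is to compare the final training and validation accuracy stored in `history_naive.history`. The helper `accuracy_gap` below is a hypothetical sketch; it only assumes the history dictionary has the `'accuracy'` and `'val_accuracy'` keys already used above, and it averages the last few epochs to smooth out epoch-to-epoch noise.

```python
def accuracy_gap(history_dict, last_n=5):
    """Hypothetical helper: mean train accuracy minus mean val accuracy over the last epochs."""
    train = history_dict['accuracy'][-last_n:]
    val = history_dict['val_accuracy'][-last_n:]
    return sum(train) / len(train) - sum(val) / len(val)


# In the notebook this would be called as: accuracy_gap(history_naive.history)
# The values below are made-up placeholders, not results from this post.
toy_history = {'accuracy': [0.6, 0.8, 0.9, 1.0], 'val_accuracy': [0.6, 0.7, 0.7, 0.7]}
print(accuracy_gap(toy_history, last_n=2))
```

A gap close to zero means training and validation accuracy move together; a large positive gap is the numeric counterpart of the diverging curves in the plot.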
Image Augmentation
There are a few ways to address this problem. One reliable way is to collect much more data, but that is hard in most cases. Instead, we can generate synthetic images from the original data by applying transformations such as rotation, zoom, and shearing. This technique is called image augmentation, and we can apply it easily with the `ImageDataGenerator` mentioned earlier.
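To see what two of these transformations do to pixel data, here is a pure-Python sketch of a horizontal flip and a horizontal shift with `'nearest'` filling, applied to a single row of pixels. The Keras documentation describes `fill_mode='nearest'` as repeating the edge pixel (`aaaaaaaa|abcd|dddddddd`). The functions below are hypothetical simplifications of that behavior, not the library's implementation, which applies an affine transform to the whole image.

```python
def horizontal_flip(row):
    """Mirror a row of pixels left-to-right."""
    return row[::-1]


def shift_nearest(row, shift):
    """Shift a row right by `shift` pixels (left if negative), filling with the nearest edge pixel."""
    n = len(row)
    out = []
    for i in range(n):
        src = i - shift  # input position that lands at output position i
        src = min(max(src, 0), n - 1)  # clamp: out-of-range positions reuse the edge pixel
        out.append(row[src])
    return out


row = ['a', 'b', 'c', 'd']
print(horizontal_flip(row))    # ['d', 'c', 'b', 'a']
print(shift_nearest(row, 2))   # ['a', 'a', 'a', 'b']
print(shift_nearest(row, -1))  # ['b', 'c', 'd', 'd']
```

With `width_shift_range=0.2` on a 150-pixel-wide image, the shift would be at most 0.2 × 150 = 30 pixels in either direction.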
train_aug_datagen = ImageDataGenerator(
rescale=1/255.,
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest'
)
val_datagen = ImageDataGenerator(
rescale=1/255.
)
train_generator = train_aug_datagen.flow_from_directory(
directory=train_dir,
target_size=(150, 150),
batch_size=20,
class_mode='binary'
)
val_generator = val_datagen.flow_from_directory(
directory=val_dir,
target_size=(150, 150),
batch_size=20,
class_mode='binary'
)
Note that the validation data should be left unaugmented so that it reflects the original images the model will be evaluated on. That is why the validation data generator only applies rescaling.
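Each time a training image is drawn, `ImageDataGenerator` samples fresh random transformation parameters within the configured ranges, while the validation generator applies none. The sketch below approximates how the `train_aug_datagen` settings turn into per-image parameters, based on our reading of Keras's random-transform logic: a rotation angle uniform in `[-rotation_range, rotation_range]` degrees, shifts given as a fraction of the image size, independent zoom factors in `[1 - zoom_range, 1 + zoom_range]`, and a horizontal flip with probability 0.5. The Keras docs describe `shear_range` as a shear angle in degrees. The function `sample_transform` and its output keys are hypothetical.

```python
import random


def sample_transform(img_h, img_w, rotation_range=40, width_shift_range=0.2,
                     height_shift_range=0.2, shear_range=0.2, zoom_range=0.2,
                     horizontal_flip=True, augment=True, rng=random):
    """Hypothetical helper: draw one set of random augmentation parameters for one image."""
    if not augment:
        # Validation images: identity transform (rescaling is handled separately)
        return {'theta': 0.0, 'dx': 0.0, 'dy': 0.0, 'shear': 0.0,
                'zx': 1.0, 'zy': 1.0, 'flip': False}
    return {
        'theta': rng.uniform(-rotation_range, rotation_range),               # degrees
        'dx': rng.uniform(-width_shift_range, width_shift_range) * img_w,    # horizontal pixels
        'dy': rng.uniform(-height_shift_range, height_shift_range) * img_h,  # vertical pixels
        'shear': rng.uniform(-shear_range, shear_range),                     # degrees
        'zx': rng.uniform(1 - zoom_range, 1 + zoom_range),
        'zy': rng.uniform(1 - zoom_range, 1 + zoom_range),
        'flip': horizontal_flip and rng.random() < 0.5,
    }


rng = random.Random(0)  # seeded so the draw is reproducible
train_params = sample_transform(150, 150, rng=rng)
val_params = sample_transform(150, 150, augment=False)
print(train_params)
print(val_params)
```

Because new parameters are drawn for every image in every epoch, the model rarely sees exactly the same training picture twice, which is what makes augmentation act against overfitting.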
Let's repeat the same process with a freshly built model of the same architecture.
model = Sequential([
Conv2D(16, (3, 3), activation='relu', input_shape=(150, 150, 3)),
MaxPooling2D(2, 2),
Conv2D(32, (3, 3), activation='relu'),
MaxPooling2D(2, 2),
Conv2D(64, (3, 3), activation='relu'),
MaxPooling2D(2, 2),
Flatten(),
Dense(512, activation='relu'),
Dense(1, activation='sigmoid')
])
model.summary()
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001), loss='binary_crossentropy', metrics=['accuracy'])
history_augmentation = model.fit(train_generator, steps_per_epoch=100, epochs=100, validation_data=val_generator, validation_steps=50)
accuracy_aug = history_augmentation.history['accuracy']
val_accuracy_aug = history_augmentation.history['val_accuracy']
loss_aug = history_augmentation.history['loss']
val_loss_aug = history_augmentation.history['val_loss']
epochs = range(len(accuracy_aug))
plt.subplot(1, 2, 1)
plt.plot(epochs, accuracy_aug, 'bo', label='Training accuracy')
plt.plot(epochs, val_accuracy_aug, 'b', label='Validation accuracy')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(epochs, loss_aug, 'bo', label='Training loss')
plt.plot(epochs, val_loss_aug, 'b', label='Validation loss')
plt.legend()
plt.show()
As you can see, validation accuracy increases from about 72% to about 80%.