Classification across artistic domains - Classifying illustrations
Chapters 1 and 2 of the fast.ai book and course cover just enough to train an image classification model. Here I want to practice those tools through a primitive investigation of a question I had while the lecture discussed neural networks learning features: how well do the features learned in one domain generalize to another?
I don’t have any empirical proof of this, but someone who has only seen real cats and dogs can likely differentiate between illustrations of cats and dogs (e.g. cats and dogs in animation). In other words, there are features of illustrated cats and dogs that let us differentiate between them. It’s likely that some of those features differentiate between real cats and dogs as well.
Can a neural network trained (or more accurately, fine-tuned) on real cats and dogs use its learned features to differentiate between illustrated cats and dogs? To look into this, we’ll do the following:
- Step 1. Fine-tune the pre-trained ResNet18 CNN included with Fast.ai on a dataset of illustrated cats and dogs and look at its performance. This is our baseline.
- Step 2. Fine-tune the pre-trained ResNet18 CNN again, this time on real cats and dogs. We then make predictions on illustrated cats and dogs using this model and compare with the baseline from Step 1.
Remark. I couldn’t find a convenient way to extract the error rate of a machine learning model on a custom dataset. As such, my evaluation of how well each model does on the test set is going to be an eyeball evaluation of the confusion matrix. I don’t love it, but I’ll let it slide since I’m betting we’ll pick up the proper statistical tools later in the course, and that’s not my focus right now.
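For what it’s worth, once you have a confusion matrix the error rate falls out of it directly: one minus the trace over the total. A minimal sketch in plain Python (the matrix values below are made up, just to show the arithmetic):

```python
def error_rate_from_confusion(matrix):
    """Error rate = 1 - (correct predictions / total predictions).
    Correct predictions sit on the diagonal of the confusion matrix."""
    total = sum(sum(row) for row in matrix)
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    return 1 - correct / total

# Hypothetical matrix: rows = actual class, columns = predicted class (cat, dog)
cm = [[40, 10],   # 40 cats predicted as cats, 10 as dogs
      [5, 45]]    # 5 dogs predicted as cats, 45 as dogs
print(round(error_rate_from_confusion(cm), 4))  # → 0.15
```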
Step 1: Illustrations predicting Illustrations
Step 1.0: Download images as training set and test set
As in the lectures, we’ll grab our data from DuckDuckGo. That is already a caveat: without curating out bad results, the training data is going to be a bit of a mess.
In contrast to the example in the fast.ai book, I’ll need both a training/validation set for training the model and a test set on which to compare the models. That means tweaking the code a little bit; let’s incorporate the split into our downloading.
import fastbook
fastbook.setup_book()
from fastbook import *
from fastai.vision.widgets import *
from fastbook import search_images_ddg
import random
animal_types = 'cat', 'dog'
path = Path('images', 'illus_cats_dogs')
I arbitrarily set aside 15% of the dataset as a test set. The search_images_ddg function returns the URLs of a bunch of images as a list, which I split into two lists according to that percentage: train_results and test_results. I then download the images corresponding to those URLs into images/illus_cats_dogs/train/o and images/illus_cats_dogs/test/o, respectively, where o is either cat or dog.
test_pct = .15
if not path.exists():
    path.mkdir()
    for o in animal_types:
        dir = o
        results = search_images_ddg(f'{o} illustration')
        random.shuffle(results)
        len_test = int(test_pct * len(results))
        train_results = results[len_test:]
        test_results = results[:len_test]
        dest_test = (path/'test'/dir)
        dest_test.mkdir(parents=True, exist_ok=True)
        download_images(dest_test, urls=test_results)
        dest_train = (path/'train'/dir)
        dest_train.mkdir(parents=True, exist_ok=True)
        download_images(dest_train, urls=train_results)
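The split above is just shuffle-then-slice. Factored into a small standalone helper (the function name and seeding are my own additions, not fastai’s), the logic looks like this:

```python
import random

def split_urls(results, test_pct, seed=42):
    """Shuffle a list of URLs and slice off the first test_pct as a test set.
    Returns (train_results, test_results)."""
    results = list(results)              # copy so the caller's list isn't mutated
    random.Random(seed).shuffle(results) # seeded so the split is reproducible
    len_test = int(test_pct * len(results))
    return results[len_test:], results[:len_test]

train, test = split_urls([f'url_{i}' for i in range(200)], test_pct=.15)
print(len(train), len(test))  # → 170 30
```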
Step 1.1: Model Training
As a first pass, we use transfer learning to train ResNet18 to classify illustrated cats and dogs. Here we follow the fast.ai book word for word.
path = Path('images', 'illus_cats_dogs', 'train')
fns = get_image_files(path)
failed = verify_images(fns)
failed.map(Path.unlink)
illus = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=61),
    get_y=parent_label,
    item_tfms=Resize(128)
)
illus = illus.new(item_tfms=Resize(128), batch_tfms=aug_transforms(mult=2))
dls = illus.dataloaders(path)
dls.valid.show_batch(max_n=6, nrows=2)

Now we train our model and check it out:
learn = vision_learner(dls, resnet18, metrics=error_rate)
learn.fine_tune(3)
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

interp.plot_top_losses(10, nrows=2)

The confusion matrix looks pretty reasonable! But the top losses are a bit worrying: some of the worst-predicted images should be unambiguous, and I have no idea how a few of these images ended up in the dataset at all. I also tried training with various transforms and with ResNet34, and those runs performed significantly worse. It makes me think there is either a problem with the dataset or that fine-tuning ResNet on illustrations isn’t optimal (probably the former).
Step 1.2: Making predictions on the test set
In the lectures we never interpreted or evaluated a model on any set other than the validation set; at most we used it to make predictions one by one. So I’m in the wilderness here! I’m trying to be as idiomatic as possible, but I’m sure that later in the course I’ll learn the proper way of doing all of this.
I want to evaluate this model on the test data that I separated at the start. The first step is to define and clean the path which contains our test images.
test_path = Path('images', 'illus_cats_dogs', 'test')
fns = get_image_files(test_path)
failed = verify_images(fns)
failed.map(Path.unlink);
A learner makes batch predictions on a dataloader (and maybe other things, but this is all we know about so far!). So we need to turn our test path into a dataloader. We could define a new DataBlock to do this, but we can also use the learn.dls.test_dl method, which reuses the DataBlock that defined learn to create a new dataloader from new data:
test_dl = learn.dls.test_dl(get_image_files(test_path), with_labels=True)
As far as I can tell, instead of a big Path object one feeds test_dl an iterable of test items. That’s easy to accommodate: just use get_image_files to return the Paths of all the images. We pass with_labels=True because we want to evaluate accuracy on our custom dataloader, not just make predictions on it.
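with_labels=True works here because the test directories follow the same parent-folder naming convention as the training data, so fastai can reuse the DataBlock’s parent_label labeling. The labeling rule itself is trivial; a plain-Python equivalent (my own sketch of what parent_label does, not fastai’s source):

```python
from pathlib import Path

def parent_label_equiv(path):
    """Label an image by the name of its parent directory,
    which is what fastai's parent_label does."""
    return Path(path).parent.name

print(parent_label_equiv('images/illus_cats_dogs/test/cat/00001.jpg'))  # → cat
print(parent_label_equiv('images/illus_cats_dogs/test/dog/00042.jpg'))  # → dog
```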
test_dl.show_batch(max_n=8, nrows=2)

interp = ClassificationInterpretation.from_learner(learn=learn, dl=test_dl)
interp.plot_confusion_matrix()

interp.plot_top_losses(5, nrows=1)

That’s frankly abysmal. The predictions for cats are especially bad, and the top losses are disheartening: the top three should be unambiguous. There are clearly some methodological problems here that a properly curated dataset of illustrations would help with.
This isn’t all for naught though. The point of the exercise is to see how a model trained on real cats and dogs handles illustrations. Let’s see how that goes.
Step 2. Real animals predicting illustrations
Step 2.0: Downloading data
Now we fine-tune our model on photos of real cats and dogs in order to see how it performs on illustrations. We download files as before:
animal_types = 'cat', 'dog'
path = Path('images', 'real_cats_dogs')
test_pct = .15
if not path.exists():
    path.mkdir()
    for o in animal_types:
        dir = o
        results = search_images_ddg(f'{o}')
        random.shuffle(results)
        len_test = int(test_pct * len(results))
        train_results = results[len_test:]
        test_results = results[:len_test]
        dest_test = (path/'test'/dir)
        dest_test.mkdir(parents=True, exist_ok=True)
        download_images(dest_test, urls=test_results)
        dest_train = (path/'train'/dir)
        dest_train.mkdir(parents=True, exist_ok=True)
        download_images(dest_train, urls=train_results)
path = Path('images', 'real_cats_dogs', 'train')
fns = get_image_files(path)
failed = verify_images(fns)
failed.map(Path.unlink);
Step 2.1: Training the model on real animals
Since a DataBlock is data-agnostic (it’s only a template for how to assemble the data), we can reuse the one we defined earlier for this model as well.
real_dls = illus.dataloaders(path)
real_dls.valid.show_batch(max_n=8, nrows=2)

real_learn = vision_learner(real_dls, resnet18, metrics=error_rate)
real_learn.fine_tune(4)
I don’t know whether it’s because of the dataset or the model, but ResNet18 converges to a very accurate model much more quickly than it did on the illustration dataset (to be fair, it is also a better training set).
real_interp = ClassificationInterpretation.from_learner(real_learn)
real_interp.plot_confusion_matrix()

Predictably brilliant performance.
Step 2.2: Evaluating model on illustrations
With the model in hand, we see how it performs on the test set of illustrations (encoded as test_dl). The way I’m going to do this today is to feed it into ClassificationInterpretation with the appropriate arguments:
test_interp = ClassificationInterpretation.from_learner(learn=real_learn, dl=test_dl)
test_interp.plot_confusion_matrix()

test_interp.plot_top_losses(5, nrows=1)

Terrible confusion matrix, but it doesn’t seem to work any worse than the model trained on illustrations. More revealing for me are the top losses. This is subjective, but I think the mistakes that this model makes feel more like misclassifications that a human would make than the mistakes the model from Step 1 made.
Conclusions
This was more an exercise in fast.ai than any legitimate research. But it was a good exercise in understanding dataloaders and learners, and I think I learned something working with this test_dl method. In the future I’d like to try the experiment again with a better dataset and better statistical tools from fast.ai.