This is the second article in our blog post series about TensorFlow Mobile. The first post covered some of the theoretical background of on-device machine learning, including quantization and state-of-the-art model architectures. This article deals with quantization-aware model training using the TensorFlow Object Detection API. The third part of this series will describe how you can convert a model with the TensorFlow Lite model converter and how you can deploy the model using Android Studio.
Welcome back! I hope you enjoyed the last blog post about state-of-the-art deep learning on mobile devices. This part will be less theoretical. We will work through an example application to get hands-on experience with TensorFlow. The source code is available on GitHub. As an example use case, we will train an object detector for cars on the cars196 dataset, which we will fetch from TensorFlow Datasets – another handy component of the TensorFlow ecosystem.
I share an apartment with two other students who are enrolled in mechanical engineering and mechatronics. They are geeks when it comes to cars and engines, so almost always when I am out with the two of them, someone excitedly exclaims: “Wow, take a look at that car over there! It’s the new (insert arbitrary manufacturer) with a huge (some technical metric that I do not understand) hybrid drive”, followed by an extensive discussion about the advantages and disadvantages of hybrid drives and how they work in detail. Meanwhile, I am still looking for the car that is being talked about. Maybe I can speed up my forlorn search with a car classifier in my pocket that will quickly tell me the manufacturer, model, and year of each car in my field of view and highlight it with a bounding box. Fortunately, the cars196 dataset provides images with exactly this information.
Let’s see how we can use TensorFlow Datasets and TensorFlow Lite to build and train our car detector in a quantization-aware fashion.
Note: I tested the code on Linux and macOS. A hint for all the Mac users out there: make sure that you have the Xcode developer tools installed, since they are required to install the TensorFlow Object Detection API.
Prerequisites
- Python >= 3.7
- TensorFlow >= 2.3.0
- TensorFlow Datasets >= 2.1.0
- ProtoBuf Compiler (protoc) >= 3.0
A word about Hardware
We are going to train an object detector. Those models are quite heavy and require a lot of computing power and time to become good. I trained everything on an Nvidia GeForce Titan X Pascal with 12 GB GDDR5X VRAM and a CUDA compute capability of 6.1. The crucial part here is the memory. You should also be good to go with a GTX 1080 Ti with 11 GB, but it will be hard to train the model on smaller cards. In fact, the config used for training defines a batch size of only six images, which was the maximum for my setup. If you want to learn more about which GPUs are suitable for deep learning, I highly recommend reading this blog post by Tim Dettmers. If you do not have access to a large GPU, I also prepared a Google Colab notebook for quantization-aware training of an image classifier. You can head over to that notebook, train the classifier, and proceed with the next and final part of this blog post series, where we will deploy the models to an Android application with Android Studio.
Repository
To get you up and running conveniently, I created a GitHub repository containing all code snippets, results, and trained models discussed in this post. You can head over to the repository at any time and use the provided material to run your own experiments.
A short word about the structure of the repository: at the topmost level, you can find three directories, data, model, and od-api. data contains a script to download and prepare the cars196 dataset that we will use for our model training. Here you can also find the label map, which is used to translate between numerical labels and class names. The model directory contains a configuration that you can use to train an object detector. The directory od-api contains the TensorFlow Object Detection API as a git submodule. This lets us define all paths in the model configuration file relative to the git repository, so you can start model training directly once you have cloned the repository and installed the Object Detection API, independent of the machine you decide to use for training.
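To make this easier to picture, here is a rough outline of the layout, limited to the files and directories mentioned in this article (the label map file name is abbreviated as a placeholder; see the repository for the exact name):

.
├── data/
│   ├── cars196_to_pascal_voc_format.py    # downloads cars196 and writes sharded TFRecords
│   └── <label map>                        # translates numeric labels to class names
├── model/
│   └── ssd_mobilenet_v2/
│       ├── ssd_mobilenet_v2.config        # training pipeline configuration
│       ├── pretrained/                    # download script for the COCO checkpoint
│       ├── runs/                          # training logs and checkpoints (created later)
│       └── tflite/                        # exported model (created at the end)
└── od-api/
    └── models/research/                   # TensorFlow Object Detection API (git submodule)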
Installing the TensorFlow Object Detection API
If you have not already installed the Object Detection API, now is the time to do so. Move to the directory od-api/models/research in the repository. From here, you can install the API by issuing the following commands. Make sure that you have installed the protobuf compiler (protoc) on your machine.
# Compile protos.
protoc object_detection/protos/*.proto --python_out=.

# Install TensorFlow Object Detection API.
cp object_detection/packages/tf2/setup.py .
python -m pip install --use-feature=2020-resolver .
You can test the installation with the following command:
python object_detection/builders/model_builder_tf2_test.py
Dataset
As mentioned before, we will fetch the cars196 dataset from TensorFlow Datasets (TFDS). TFDS is a collection of ready-to-use datasets for TensorFlow and other Python ML frameworks such as JAX. All datasets are exposed as tf.data.Datasets, enabling easy-to-use and high-performance input pipelines. You can find more information about TFDS on their website.
Let’s have a quick look at some of the samples of our dataset to get a better feeling for it. For an interactive version of the following lines of code, please have a look at this colab notebook for quantization-aware training of an image classifier.
The following lines will download the dataset and split it into train, validation and test subsets.
import tensorflow_datasets as tfds

(train_ds, val_ds, test_ds), metadata = tfds.load(
    'cars196',
    split=['train[:80%]', 'train[80%:]', 'test'],
    with_info=True
)
We use the TFDS slicing API to reserve 80% of the original train split as actual training data and 20% as validation data to verify the training progress after each epoch. We also want to see which features the dataset provides, so we add the with_info=True argument to the function call.
Let’s see what the tfds.core.DatasetInfo object stored in metadata looks like.
tfds.core.DatasetInfo(
    name='cars196',
    version=2.0.0,
    description='The Cars dataset contains 16,185 images of 196 classes of cars. The data is split into 8,144 training images and 8,041 testing images, where each class has been split roughly in a 50-50 split. Classes are typically at the level of Make, Model, Year, e.g. 2012 Tesla Model S or 2012 BMW M3 coupe.',
    homepage='https://ai.stanford.edu/~jkrause/cars/car_dataset.html',
    features=FeaturesDict({
        'bbox': BBoxFeature(shape=(4,), dtype=tf.float32),
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=196),
    }),
    total_num_examples=16185,
    splits={
        'test': 8041,
        'train': 8144,
    },
    supervised_keys=('image', 'label'),
    citation="""@inproceedings{KrauseStarkDengFei-Fei_3DRR2013,
      title = {3D Object Representations for Fine-Grained Categorization},
      booktitle = {4th International IEEE Workshop on 3D Representation and Recognition (3dRR-13)},
      year = {2013},
      address = {Sydney, Australia},
      author = {Jonathan Krause and Michael Stark and Jia Deng and Li Fei-Fei}
    }""",
    redistribution_info=,
)
Alright, so we have a total of 16,185 images split roughly into a 50-50 train/test split. Each sample consists of a BBoxFeature, which contains normalized bounding box coordinates in the order [ymin, xmin, ymax, xmax], a numeric label, and the image of the car. Let’s plot some of the cars together with their bounding boxes.
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import tensorflow as tf

samples = train_ds.take(9)

fig = plt.figure(figsize=(10, 10))
for i, sample in enumerate(samples):
    ax = fig.add_subplot(3, 3, i + 1)  # subplot indices are 1-based
    image = sample['image']
    label = sample['label']
    height, width, _ = image.shape
    # Scale the normalized [ymin, xmin, ymax, xmax] box to pixel coordinates.
    bbox = sample['bbox'] * tf.constant([height, width, height, width], tf.float32)
    bbox_patch = patches.Rectangle((bbox[1], bbox[0]),
                                   width=bbox[3] - bbox[1],
                                   height=bbox[2] - bbox[0],
                                   linewidth=5, edgecolor='cyan', facecolor='none')
    ax.imshow(image)
    ax.add_patch(bbox_patch)
    plt.axis('off')
Looks good! Now let’s see how we can train the model using the TensorFlow Object Detection API.
Training the model with the TensorFlow Object Detection API
As I said earlier, we are going to train an object detector in this blog post. There are several reasons why I decided to write a tutorial about object detection rather than yet another image classification guide. The most important one is that there is already a large number of image classification tutorials showing how to convert an image classifier to TensorFlow Lite, but I have not found many tutorials about object detection. I assume this is because image classification is a bit easier to understand and set up. However, an object detector supports several other potential use cases, like object counting or multi-class classification. We will use the TensorFlow Object Detection API to train our model. Therefore, we need to transform the data into the specific format expected by the Object Detection API. The following lines transform an image with bounding boxes into a TFRecord example with the expected format.
from object_detection.utils import dataset_util  # helper functions from the TF Object Detection API


def create_tf_example(example):
    height, width = example['image'].shape[:-1]

    filename = b''  # Filename of the image. Empty, because image is not from file
    encoded_image_data = tf.image.encode_jpeg(example['image']).numpy()
    image_format = b'jpeg'

    ymins = [example['bbox'][0]]
    xmins = [example['bbox'][1]]
    ymaxs = [example['bbox'][2]]
    xmaxs = [example['bbox'][3]]

    # add 1 to label, since it was stored as [0, 195] instead of [1, 196]
    label = example['label'].numpy() + 1
    classes_text = [CATEGORY_INDEX[label]['name'].encode('utf-8')]  # List of string class name of bounding box
    classes = [label]  # List of integer class id of bounding box (1 per box)

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_image_data),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example
We do not need to go into too much detail, but I will outline some of the things happening here. Since we load our dataset from TFDS, we already have tensors to operate on. However, we want to store the image data together with all labels in an efficient binary format called TFRecord. Therefore, we need to specify the encoding of our image and pass the encoded string rather than the raw image data. Ironically, the TFDS copy of cars196 is already stored as TFRecords on your drive, but the TensorFlow Object Detection API requires us to restructure those records in order to work properly. In the end, a TFRecord stores a collection of TF Examples, which are something like dictionaries storing features as key-value pairs. The TensorFlow Object Detection API requires the structure of those TF Examples to match the structure used by PASCAL VOC (the Pattern Analysis, Statistical Modelling and Computational Learning Visual Object Classes challenge). The idea behind this format is that images are the first-order features and can comprise multiple bounding boxes and labels. Thus, we store the samples in a hierarchical manner, starting with image-level features (using keys like 'image/height' and 'image/encoded') and adding box coordinates and labels underneath (using keys like 'image/object/class/label') as lists.
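To make this key-value layout more tangible, here is a minimal sketch (not part of the conversion script) that reads one of the generated record files back and parses a subset of the features defined above. The shard file name follows the naming convention produced by the conversion script described further below; adjust the path if yours differs.

import tensorflow as tf

# Only a subset of the features written by create_tf_example.
feature_description = {
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
    'image/object/class/label': tf.io.VarLenFeature(tf.int64),
    'image/object/class/text': tf.io.VarLenFeature(tf.string),
}

raw_ds = tf.data.TFRecordDataset('data/train_cars196.record-00000-of-00010')
for raw_record in raw_ds.take(1):
    parsed = tf.io.parse_single_example(raw_record, feature_description)
    # Object-level features are variable-length lists, so they come back as sparse tensors.
    print(tf.sparse.to_dense(parsed['image/object/class/label']).numpy())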
As a second step, we need to write the TF Examples to TFRecord files. It is recommended to shard the dataset across multiple files if it comprises more than a few thousand examples. Since the cars196 dataset contains over 16k examples, we will store them across multiple TFRecord files. This allows the API to read input examples in parallel, improving throughput during training. We can write the files to disk with the following few lines of code:
import contextlib2
from object_detection.dataset_tools import tf_record_creation_util


def write_sharded_tf_records(examples, output_filebase, num_shards):
    with contextlib2.ExitStack() as tf_record_close_stack:
        output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
            tf_record_close_stack, output_filebase, num_shards)
        for index, example in examples.enumerate():
            tf_example = create_tf_example(example)
            output_shard_index = index % num_shards
            output_tfrecords[output_shard_index].write(tf_example.SerializeToString())
To generate the TFRecords for the cars196 dataset, just launch the following command from the repository root:
python data/cars196_to_pascal_voc_format.py --output_path=./data --num_shards=10
This will download the dataset using TensorFlow Datasets and convert it to sharded TFRecords. Once the execution of the script is finished, you can find the generated TFRecord files in the data directory. You should see 20 files in total following the naming convention (train|test)_cars196.record-?????-of-00010.
Model configuration
In the previous post of this series, I talked about state-of-the-art neural network architectures for deep learning on mobile devices. Those were all image classification networks, but they can be used as part of an object detector. We just have to add a region proposal network that proposes image regions likely to contain an interesting object. The classifier then uses those proposed regions and tells us whether there is something relevant (read: a class that is part of our dataset) or not. Since we want to deploy the model to a mobile device using TensorFlow Lite, we currently have to use a Single Shot MultiBox Detector (SSD) architecture for this purpose, because it is the only one supported by TensorFlow Lite at the time of this writing.
Fortunately, we do not have to implement the whole neural network architecture on our own since the TensorFlow Object Detection API contains already pre-trained networks. We will use an object detection model that is built on MobileNet, which we learned about in the previous blog post of this series. If you have not read it or want to refresh your knowledge about MobileNet and quantization-aware training, you can find the article here.
We will also load some weights together with the model architecture and use them to fine-tune our model. Most models in the TensorFlow Object Detection API are pre-trained on COCO (Common Objects in Context), a large-scale object detection, segmentation, and captioning dataset. It is widely used as a baseline detection dataset and therefore serves as a starting point for transfer learning. We have to manually download the model checkpoint from the TensorFlow 2 Detection Model Zoo. I created a small bash script in model/ssd_mobilenet_v2/pretrained/ that downloads the checkpoint for the model into the same directory. Just execute it from the repository root. We are now set up to configure the training pipeline.
The TensorFlow Object Detection API was designed using a configuration-driven approach and can be used from the command line. The model architecture, training configuration, data sources and checkpoint directories can be defined using a configuration language that resembles JSON. You can find the configuration for the SSD MobileNet in model/ssd_mobilenet_v2/ssd_mobilenet_v2.config. Let’s have a look at the structure of the config file.
model {
  (... Add model config here...)
}

train_config: {
  (... Add train_config here...)
}

train_input_reader: {
  (... Add train_input configuration here...)
}

eval_config: {
  (... Add eval_config here...)
}

eval_input_reader: {
  (... Add eval_input configuration here...)
}
For the sake of simplicity, we will only adjust the parts of this file that define the training and evaluation configurations and the input readers, since the model hyperparameters have already been carefully chosen by the TF Object Detection team. However, feel free to adjust them if you want to. The following list contains all the parameters in each section of the config file that need to be adjusted manually in order to start model training; a sketch of how these fields fit together follows the list.
train_config
- fine_tune_checkpoint: This specifies the location of the fine-tuning checkpoint that we downloaded in the previous section. It can be an absolute path or a path relative to the training script of the TensorFlow Object Detection API. Since we have the API as a submodule in our repository, we can specify the path relative to that directory. Therefore, you do not have to adjust the path if you followed all of the prior instructions of this article.
- batch_size: You probably need to adjust this parameter, since it depends on the amount of VRAM that your GPU provides.
The remaining parameters in this section are left at their defaults. Again, feel free to adjust them if you want to. You could, for example, try to change the hyperparameters of the optimizer or use a different optimizer altogether.
train_input_reader / eval_input_reader
- label_map_path: This parameter specifies the path to the label map. Again, it can be an absolute path or relative to the train script. I created the label map for the cars196 dataset and placed it in the data directory of the repository. As before, if you follow this article and want to reproduce my results, just use the path as it is specified.
- tf_record_input_reader: As before, you can leave this parameter untouched. However, if you want to use your own dataset, make sure to adjust the path properly. In particular, the suffix of the TFRecord files (?????-of-00010) is only required if your dataset is sharded across multiple TFRecord files.
- sample_1_of_n_examples: We add this optional parameter to the eval_input_reader. It tells the API to use only a fraction of the eval data for one evaluation round. You can tweak this parameter however you want to. If you set it to 1, the whole evaluation dataset will be used. Read the section about training and evaluation below to get a better understanding of how the evaluation happens in the TF Object Detection API.
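As a rough orientation, the relevant parts of the config might look like the following sketch. The paths, the label map file name, and the sample_1_of_n_examples value shown here are illustrative placeholders; the actual values are already set in the config file shipped in the repository.

train_config: {
  batch_size: 6
  fine_tune_checkpoint: "../../../model/ssd_mobilenet_v2/pretrained/checkpoint/ckpt-0"
  ...
}

train_input_reader: {
  label_map_path: "../../../data/cars196_label_map.pbtxt"
  tf_record_input_reader {
    input_path: "../../../data/train_cars196.record-?????-of-00010"
  }
}

eval_input_reader: {
  label_map_path: "../../../data/cars196_label_map.pbtxt"
  sample_1_of_n_examples: 4
  tf_record_input_reader {
    input_path: "../../../data/test_cars196.record-?????-of-00010"
  }
}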
Since we want to train the model quantization-aware, I added a config section for this at the end of the file:
graph_rewriter {
  quantization {
    delay: 60000
    weight_bits: 8
    activation_bits: 8
  }
}
Essentially, this section applies weight and activation quantization after 60,000 training steps and uses 8 bits each to store weights and activations. You can remove this section if you do not want to train the model quantization-aware. Remember that you can always apply post-training quantization with TFLite.
Training and Evaluation
We are almost there! We can now train the model with the TensorFlow Object Detection API. To do so, move to the directory od-api/models/research.
The main entry point to the API is the script object_detection/model_main_tf2.py. It is used to issue training and evaluation jobs alike. We have to provide paths to the pipeline configuration and a model directory, which is used to continually log checkpoints and metrics. Whether you launch a training or an evaluation job depends on the arguments passed to the script. If you add the --checkpoint_dir flag, the program launches in evaluation mode, using the provided checkpoint and the data specified in the eval_input_reader section of the pipeline configuration. I recommend launching both jobs in parallel with a small subset of the evaluation data since the evaluation job will run continually every time a new checkpoint is written to the model directory by the training job. You can then monitor the training progress with TensorBoard and compare training and evaluation performance.
First, create the directory in which checkpoints and logs will be stored during training. I suggest putting everything inside model/ssd_mobilenet_v2/runs. I will name the experiment “qat” for quantization-aware training. If you want to follow along, just launch the following command from the repository root:
mkdir model/ssd_mobilenet_v2/runs && mkdir model/ssd_mobilenet_v2/runs/qat
After that, for training, run the following from od-api/models/research:
PIPELINE_CONFIG_PATH=../../../model/ssd_mobilenet_v2/ssd_mobilenet_v2.config
MODEL_DIR=../../../model/ssd_mobilenet_v2/runs/qat/

python object_detection/model_main_tf2.py \
    --pipeline_config_path=${PIPELINE_CONFIG_PATH} \
    --model_dir=${MODEL_DIR} \
    --alsologtostderr
To launch the evaluation job, put the training process into the background and run the following command:
CHECKPOINT_DIR=${MODEL_DIR}

python object_detection/model_main_tf2.py \
    --pipeline_config_path=${PIPELINE_CONFIG_PATH} \
    --model_dir=${MODEL_DIR} \
    --checkpoint_dir=${CHECKPOINT_DIR} \
    --alsologtostderr
To watch the training progress with TensorBoard, launch the following command:
tensorboard --logdir=${MODEL_DIR}
Now it is time to wait for the training to finish. This can take a couple of hours. In the meantime, you can have a look at the results of my training run that I have uploaded to TensorBoard.dev. Your training metrics should behave similarly if you did not adjust any of the model or optimizer hyper-parameters in the pipeline config.
Let’s have a look at some of the results. When you head to the images tab in TensorBoard, you can see some example images and the predictions of your model. You can also see the improvement of your model over time if you move the slider above the images. Below are some examples of my model. The image on the left shows the prediction of the model and the one on the right shows the corresponding ground truth.
On those images, the results look good. The model detects everyday cars, like the Honda in the second image, as well as more exotic ones like the Ferrari in the fourth image, with high confidence, and the locations of the bounding boxes are also very close to the ground truth.
The last thing for today is to convert the model to a format that we can later use to deploy it using TensorFlow Lite. As you may already have guessed, the TensorFlow Object Detection API can help us here. What we want to achieve is to load the checkpoints that were logged during model training and convert them into a format that is readable for TFLite. We can do so by launching the following command from the od-api/models/research folder inside the repository:
python object_detection/export_tflite_graph_tf2.py \
    --pipeline_config_path $PIPELINE_CONFIG_PATH \
    --trained_checkpoint_dir $MODEL_DIR \
    --output_directory ../../../model/ssd_mobilenet_v2/tflite
This will take the latest checkpoint from the training run and convert it into TensorFlow’s SavedModel format, which we will use in the next post of this series to convert the model to TensorFlow Lite.
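If you want a quick sanity check that the export worked, you can load the result back with plain TensorFlow and list its serving signatures. The saved_model subdirectory name below is an assumption about how the export script lays out its output directory; adjust the path if yours differs. Run it from the repository root:

import tensorflow as tf

# Load the exported model and print the available serving signatures.
model = tf.saved_model.load('model/ssd_mobilenet_v2/tflite/saved_model')
print(list(model.signatures.keys()))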
Conclusion
In this article, we trained a neural network with the TensorFlow Object Detection API to detect different cars in images. The training was quantization-aware, so that the model can run efficiently on mobile devices. Finally, we exported our model to the SavedModel format. If you were not able to train the model due to missing hardware or ran into other problems, you can find the exported model in the GitHub repository in the folder model/ssd_mobilenet_v2/tflite.
I really hope you enjoyed this article! In the next part of this series, we are going to integrate our model into a mobile app and see how it works on images taken with a smartphone.
Read on
There’s more to AI-enabled apps than TensorFlow. Check out our dedicated website to learn more!