Image Classification with Tarantella

This section delves into a more advanced usage of Tarantella by looking at distributed training for state-of-the-art image classification models.

The image classification model architectures are imported through the tf.keras.applications module, available in recent TensorFlow releases.


Deep Residual Networks (ResNets) represented a breakthrough in the field of computer vision, enabling deeper and more complex convolutional networks. Introduced in [He], ResNet-50 has become a standard model for image classification tasks, and has been shown to scale to a very large number of nodes in data parallel training [Goyal].

Run ResNet-50 with Tarantella

Let’s assume we have access to two nodes (listed in hostfile), each equipped with 4 GPUs. We can now run ResNet-50 training as follows:

cd examples/image_classification
tarantella --hostfile ./hostfile --devices-per-node 4 \
-- ./ --model_arch=resnet50 --batch_size=256 --train_epochs=3 \
                            --val_freq=3 --train_num_samples=2560 --val_num_samples=256 \
                            --synthetic_data

The above command will train a ResNet-50 model in parallel on the 8 available devices for 3 epochs. The --val_freq parameter specifies how often the model is evaluated on the validation dataset in between training epochs.

Note the --batch_size parameter, which specifies the global batch size used in training.
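To make the numbers in the example command concrete, the following sketch works out the quantities they imply. It is plain arithmetic rather than Tarantella code, and the variable names are illustrative. The last step assumes that --val_freq behaves like the validation_freq argument of the Keras fit function (validation after every n-th epoch), which is not stated explicitly here.

```python
# Values taken from the example command above.
num_nodes = 2
devices_per_node = 4
global_batch_size = 256
train_num_samples = 2560
train_epochs = 3
val_freq = 3

# Total number of devices taking part in data parallel training.
num_devices = num_nodes * devices_per_node

# Share of each global batch processed by a single device.
micro_batch_size = global_batch_size // num_devices

# Number of global batches (training steps) per epoch.
steps_per_epoch = train_num_samples // global_batch_size

# Assumption: validation runs after every `val_freq`-th epoch,
# as with the Keras `validation_freq` argument.
validation_epochs = [epoch for epoch in range(1, train_epochs + 1)
                     if epoch % val_freq == 0]

print(num_devices, micro_batch_size, steps_per_epoch, validation_epochs)
```

Under these assumptions, each device handles a fraction of every global batch, and with --train_epochs=3 and --val_freq=3 the validation dataset is evaluated once, after the final epoch.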

The --synthetic_data flag instructs the code to generate a synthetic ImageNet-like dataset that can be used to showcase the training procedure. However, it will not provide any meaningful results. Remove the --synthetic_data parameter and specify a --data_dir path to an actual ImageNet directory to properly train the model.


On the STYX cluster, a pre-downloaded version of the ImageNet dataset can be found in /home/DATA/PUBLIC_DATA/ImageNet.


On STYX, don’t forget to add the following -x flag to the tarantella command so that the GPUs are detected correctly (and redo the steps from here if in a new terminal):

tarantella -x XLA_FLAGS="--xla_gpu_cuda_data_dir=${CONDA_ENV_PATH}/lib" ...

Implementation overview

We will now take a closer look at the implementation of the ResNet-50 training scheme. The main training steps reside in the examples/image_classification/ file.

The most important step in enabling data parallelism with Tarantella is to wrap the Keras model into a Tarantella model that uses data parallelism for speeding up training.

This is summarized below for the ResNet-50 model:

model = tf.keras.applications.resnet50.ResNet50(include_top=True, weights=None, classes=1000,
                                                input_shape=(224, 224, 3), input_tensor=None,
                                                pooling=None, classifier_activation='softmax')
model = tnt.Model(model,
                  parallel_strategy = tnt.ParallelStrategy.DATA)

Next, the training procedure can simply be written down as it would be for a standard, TensorFlow-only model. No further changes are required to train the model in a distributed manner.

In particular, the ImageNet dataset is loaded and preprocessed as follows:

train_input_dataset = load_dataset(dataset_type='train',
                                   data_dir=args.data_dir, num_samples = args.train_num_samples,
                                   batch_size=args.batch_size, dtype=tf.float32)

The load_dataset function reads the input files in data_dir, loads the training samples, and processes them into TensorFlow datasets.

The user only needs to pass the global batch_size value, and the Tarantella framework will ensure that the dataset is properly distributed among devices, such that:

  • each device will process an independent set of samples

  • each device will group its samples into micro-batches, where the micro-batch size is roughly batch_size / num_devices. If the batch size is not a multiple of the number of devices, the remaining samples are distributed evenly among the participating ranks, so that some ranks use a micro-batch of size floor(batch_size / num_devices) + 1.

  • each device will apply the same set of transformations to its input samples as specified in the load_dataset function.

The advantage of using the automatic dataset distribution mechanism of Tarantella is that users can reason about their I/O pipeline without taking care of the details about how to distribute it.
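The splitting rule described in the list above can be illustrated with a small sketch. This is not Tarantella's implementation: the helper names micro_batch_sizes and micro_batch_indices are hypothetical, and both the choice of which ranks receive the extra sample and the use of contiguous index ranges are assumptions made for illustration.

```python
def micro_batch_sizes(batch_size, num_devices):
    """Split a global batch size into one micro-batch size per rank."""
    base, remainder = divmod(batch_size, num_devices)
    # Assumption: the first `remainder` ranks take one extra sample each.
    return [base + 1 if rank < remainder else base
            for rank in range(num_devices)]


def micro_batch_indices(batch_size, num_devices):
    """Assign each rank a disjoint, contiguous index range of one global batch."""
    index_ranges = []
    start = 0
    for size in micro_batch_sizes(batch_size, num_devices):
        index_ranges.append(range(start, start + size))
        start += size
    return index_ranges


# A batch size that divides evenly, and one that does not.
print(micro_batch_sizes(256, 8))
print(micro_batch_sizes(258, 8))
```

In both cases the micro-batch sizes add up to the global batch size and no sample is assigned to more than one rank, which is the property the list above relies on.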

Before starting the training, the model is compiled using a standard Keras optimizer and loss.

model.compile(optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9),
              loss = tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics = [tf.keras.metrics.SparseCategoricalAccuracy()])
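The definition of lr_schedule is not shown in this section. As an illustration only, the sketch below implements the schedule described in [Goyal] for ResNet-50 on ImageNet: a learning rate scaled linearly with the global batch size, a gradual warmup over the first 5 epochs, and a tenfold decay at epochs 30, 60 and 80. The function name and its parameters are hypothetical, and the schedule used in the example code may differ.

```python
def learning_rate(epoch, global_batch_size, base_lr=0.1, warmup_epochs=5,
                  decay_epochs=(30, 60, 80), decay_factor=0.1):
    """Learning rate at a (possibly fractional) epoch, following [Goyal]."""
    # Linear scaling rule: the target rate grows with the global batch size,
    # relative to a reference batch size of 256.
    peak_lr = base_lr * global_batch_size / 256
    if epoch < warmup_epochs:
        # Gradual warmup: ramp linearly from base_lr up to peak_lr.
        return base_lr + (peak_lr - base_lr) * epoch / warmup_epochs
    # Step decay: apply decay_factor once for every boundary already reached.
    num_decays = sum(1 for boundary in decay_epochs if epoch >= boundary)
    return peak_lr * decay_factor ** num_decays


# Learning rate at selected epochs for a global batch size of 1024.
for epoch in (0, 2.5, 5, 30, 60, 80):
    print(epoch, learning_rate(epoch, global_batch_size=1024))
```

To pass such a schedule to the optimizer as in the compile call above, it would have to be expressed in terms of training steps, for example as a subclass of tf.keras.optimizers.schedules.LearningRateSchedule.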

We provide flags to enable the most commonly used Keras callbacks, such as the TensorBoard profiler, which can simply be passed to the fit function of the Tarantella model.

callbacks.append(tf.keras.callbacks.TensorBoard(log_dir = flags_obj.model_dir,
                                                profile_batch = 2))

There is no need for any further changes to proceed with distributed training:

history = model.fit(train_input_dataset,
                    validation_data = val_dataset,
                    callbacks = callbacks)



[Goyal] Goyal, Priya, et al. “Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour.” arXiv preprint arXiv:1706.02677 (2017).

[He] He, Kaiming, et al. “Deep residual learning for image recognition.” Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2016). arXiv preprint arXiv:1512.03385.