Hierarchical representations improve image retrieval
A new metric-learning loss function groups together superclasses and learns commonalities within them.
Image matching has many practical applications. For instance, image retrieval systems like Amazon’s StyleSnap or the Amazon Shopping app’s Camera Search let customers upload photos to search for similar images. Image matching usually works by mapping images to a representational space (an embedding space) and finding images whose mappings are nearby.
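To make the idea of "nearby mappings" concrete, here is a minimal sketch of retrieval by nearest neighbors in an embedding space. The three-dimensional vectors, the item names, and the `retrieve` helper are all invented for illustration, and cosine similarity is just one common choice of proximity measure; a production system would use learned, high-dimensional embeddings and an approximate nearest-neighbor index.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def retrieve(query, catalog, top_k=2):
    """Return the names of the top_k catalog items most similar to the query."""
    ranked = sorted(
        catalog,
        key=lambda name: cosine_similarity(query, catalog[name]),
        reverse=True,
    )
    return ranked[:top_k]

# Hypothetical embeddings; a real system would compute these with a trained network.
catalog = {
    "red_t_shirt": [0.9, 0.1, 0.0],
    "blue_t_shirt": [0.8, 0.2, 0.1],
    "running_shoe": [0.0, 0.1, 0.9],
}
query = [1.0, 0.1, 0.0]  # embedding of the customer's uploaded photo
results = retrieve(query, catalog)
```

With these toy vectors, the two T-shirts rank ahead of the shoe because their embeddings point in nearly the same direction as the query.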
In a paper that we presented last week at WACV 2022, we explained how to improve image retrieval accuracy by explicitly modeling object hierarchies when training neural networks to compute image representations.
A shopping site, for instance, might classify a group of products as apparel, a superclass that contains the classes T-shirt and hoodie, which in turn contain instances of specific T-shirts and specific hoodies.
In our paper, we show how to leverage such hierarchies when building image retrieval systems or, if no hierarchies exist, how to construct them. In experiments that compared our approach to nine predecessors on five different datasets using multiple performance measures, we found that our approach delivered the best results in a large majority of cases.
Deep metric learning
Image matching for image retrieval typically relies on deep metric learning (DML), in which a deep neural network learns not only how to map inputs to an embedding space but also the distance function used to measure proximity in that space.
There are two dominant families of loss functions for training DML networks: pairwise losses and proxy losses. Pairwise losses (e.g., contrastive, triplet) are computed between positive and negative pairs, pulling positive pairs closer while pushing negative pairs apart. Proxy losses (e.g., proxy-NCA, proxy anchor) learn a set of embeddings, called proxies, each of which represents the average location of the members of a class, i.e., the class centroid. The loss for each training sample is computed with respect to the proxies.
Pairwise losses need to sample informative pairs or triplets from the training data. Proxy losses do not, which removes the complexity of pair sampling and speeds up training. In particular, the proxy anchor loss has been shown to achieve state-of-the-art image retrieval accuracy while converging much faster than pairwise losses.
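The difference between the two families can be sketched in a few lines of plain Python. This is an illustrative simplification, not the implementation from our paper: the function names, the margin value, and the two-dimensional toy vectors are assumptions, and real implementations typically L2-normalize embeddings and learn the proxies by gradient descent.

```python
import math

def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def triplet_loss(anchor, positive, negative, margin=0.2):
    """Pairwise-style loss: needs a sampled (anchor, positive, negative) triplet."""
    return max(0.0, sq_dist(anchor, positive) - sq_dist(anchor, negative) + margin)

def proxy_nca_loss(embedding, label, proxies):
    """Proxy-NCA-style loss for one sample.

    proxies maps each class label to its proxy vector (at least two classes).
    The sample is pulled toward its own proxy and pushed away from all others,
    so no pair or triplet sampling is required.
    """
    pull = sq_dist(embedding, proxies[label])
    push = sum(
        math.exp(-sq_dist(embedding, proxy))
        for cls, proxy in proxies.items()
        if cls != label
    )
    return pull + math.log(push)

# Hypothetical proxies and sample, for illustration only.
proxies = {"t_shirt": [1.0, 0.0], "hoodie": [0.0, 1.0]}
sample = [0.9, 0.1]
loss_correct = proxy_nca_loss(sample, "t_shirt", proxies)
loss_wrong = proxy_nca_loss(sample, "hoodie", proxies)
```

The sample sits close to the T-shirt proxy, so the loss is lower when it is labeled as a T-shirt than when it is labeled as a hoodie. Note that the triplet loss needs three inputs chosen by some sampling strategy, while the proxy loss needs only the sample, its label, and the current proxies.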
Our work proposes a new proxy loss that explicitly uses information about class hierarchies to improve image retrieval accuracy.
With hierarchical data, there is an opportunity to impose additional constraints on the embedding space via the loss function, so that images in the same superclass are also grouped together, as shown below. Because the model learns the commonalities within superclasses, this not only helps it generalize to unseen classes but also leads to more reasonable retrievals when it makes mistakes.
Hierarchical proxy loss
Our hierarchical proxy loss (HPL) is an extension of existing proxy losses. HPL consists of a hierarchy of proxies, and each training image is assigned to a single proxy at every level, as shown in the next figure. Then, the loss is computed as the weighted sum of the proxy losses of all levels.
At every level, each image is pulled to the assigned proxy and pushed away from all the other proxies. This induces the network to group images hierarchically by learning the commonalities within each group at every level.
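As a rough sketch of the weighted sum described above, the following code applies a proxy-NCA-style loss at each level of a two-level hierarchy and combines the results. The level weights, the proxy vectors, and the function names are hypothetical choices made for illustration; the paper should be consulted for the actual formulation and hyperparameters.

```python
import math

def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def proxy_nca_loss(embedding, assigned, proxies):
    """Pull the sample toward its assigned proxy, push it from the others.

    Requires at least two proxies at the level in question.
    """
    pull = sq_dist(embedding, proxies[assigned])
    push = sum(
        math.exp(-sq_dist(embedding, proxy))
        for key, proxy in proxies.items()
        if key != assigned
    )
    return pull + math.log(push)

def hierarchical_proxy_loss(embedding, assignments, proxy_levels, weights):
    """Weighted sum of per-level proxy losses.

    assignments[i]  : key of the proxy assigned to the sample at level i
    proxy_levels[i] : dict of proxies at level i
    weights[i]      : weight of level i in the total loss
    """
    return sum(
        w * proxy_nca_loss(embedding, assigned, proxies)
        for w, assigned, proxies in zip(weights, assignments, proxy_levels)
    )

# Hypothetical two-level hierarchy: fine classes and coarse superclasses.
fine = {"t_shirt": [1.0, 0.0], "hoodie": [0.8, 0.6], "sneaker": [0.0, 1.0]}
coarse = {"apparel": [0.9, 0.3], "footwear": [0.0, 1.0]}
weights = [1.0, 0.5]  # assumed values, not taken from the paper

sample = [0.95, 0.1]
loss_good = hierarchical_proxy_loss(sample, ["t_shirt", "apparel"], [fine, coarse], weights)
loss_bad = hierarchical_proxy_loss(sample, ["sneaker", "footwear"], [fine, coarse], weights)
```

Because the sample lies near both the T-shirt proxy and the apparel proxy, `loss_good` is lower than `loss_bad`. The coarse-level term is what penalizes embeddings that drift toward the wrong superclass.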
The remaining question is how to build the proxy hierarchy when no data hierarchy is provided by, say, an e-commerce catalogue. In such cases, we apply online clustering to the lower-level proxies during training to obtain the higher-level proxies.
We begin with a DML model that has been trained to generate proxies at the most fine-grained level. Then we run the following training algorithm:
1. Run a clustering algorithm, e.g., k-means, on fine proxies to obtain coarse proxies; assign each sample to one coarse proxy.
2. Train the network for T iterations, updating both the network and the fine-grained proxies.
3. After every T iterations, update the assignments of samples to coarse proxies and update the higher-level proxies by averaging the assigned lower-level proxies.
4. Repeat steps 2–3 until convergence.
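The clustering side of this loop can be sketched as follows. This is a toy illustration under several assumptions: the k-means routine is a bare-bones version written for this example, `train_step` is a hypothetical placeholder for the T iterations of gradient updates, the loop runs for a fixed number of rounds instead of testing convergence, and each sample is assumed to inherit the coarse proxy of its class's fine proxy.

```python
import random

def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def mean(vectors):
    """Component-wise mean of a non-empty list of vectors."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]

def assign(fine, coarse):
    """Index of the nearest coarse proxy for each fine proxy."""
    return [
        min(range(len(coarse)), key=lambda k: sq_dist(f, coarse[k]))
        for f in fine
    ]

def update_coarse(fine, assignment, coarse):
    """Average the fine proxies assigned to each coarse proxy.

    A coarse proxy with no assigned fine proxies is left unchanged.
    """
    updated = []
    for k, old in enumerate(coarse):
        members = [f for f, a in zip(fine, assignment) if a == k]
        updated.append(mean(members) if members else old)
    return updated

def kmeans(fine, k, iters=10, seed=0):
    """Minimal k-means over the fine proxies (step 1)."""
    rng = random.Random(seed)
    coarse = [list(p) for p in rng.sample(fine, k)]
    for _ in range(iters):
        coarse = update_coarse(fine, assign(fine, coarse), coarse)
    return coarse

def train_step(fine):
    """Hypothetical stub for T training iterations (step 2).

    A real implementation would update the network and the fine proxies here.
    """
    return fine

# Four toy fine proxies forming two obvious groups.
fine = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]]
coarse = kmeans(fine, k=2)                             # step 1
for _ in range(3):                                     # step 4, fixed rounds
    fine = train_step(fine)                            # step 2
    assignment = assign(fine, coarse)                  # step 3: reassign
    coarse = update_coarse(fine, assignment, coarse)   # step 3: re-average
```

On this toy input, the first two fine proxies end up under one coarse proxy and the last two under the other. In practice a library clustering routine would likely replace the hand-written k-means.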
We implemented HPL on top of two recent proxy losses, proxy-NCA and proxy anchor loss, the second of which is the state-of-the-art loss in metric learning. We evaluated image retrieval accuracy on five standard metric-learning datasets and found that HPL consistently improved retrieval accuracy over both proxy-NCA and proxy anchor loss, achieving a new state of the art. A comprehensive experimental evaluation is available in our paper.