Continual learning in the federated-learning context
Using gradient diversity to optimize selection of past samples for retention improves performance while combatting catastrophic forgetting.
Federated learning is a process in which distributed devices, each with its own store of locally collected data, can contribute to a global machine learning model without transmitting the data itself. By keeping data local, federated learning both reduces network traffic and protects data privacy.
Continual learning is the process of continually updating a model as new data becomes available. The key challenge is avoiding “catastrophic forgetting”, in which model updates based on new data overwrite the parameter settings learned from earlier data, degrading performance on the old data.
In a paper we presented at this year’s Conference on Empirical Methods in Natural Language Processing (EMNLP), we combine these two techniques, presenting a new method for continual federated learning that improves upon its predecessors.
One way to protect against catastrophic forgetting is for each device to retain samples of the data it’s already seen. When new data comes in, it’s merged with the old data, and the model is retrained on the joint dataset.
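The rehearsal loop just described can be sketched as follows. This is a minimal illustration, not the paper's code: `train` is a hypothetical callback standing in for local model training, and `uniform_select` is a hypothetical placeholder for the selection procedure that keeps a uniform random sample, like the naïve baseline used in the experiments.

```python
import random
from typing import Callable, List, Optional, Sequence, TypeVar

T = TypeVar("T")

def uniform_select(pool: Sequence[T], budget: int, rng: random.Random) -> List[T]:
    """Placeholder selector: keep `budget` samples drawn uniformly at random."""
    if len(pool) <= budget:
        return list(pool)
    return rng.sample(list(pool), budget)

def continual_update(
    memory: List[T],
    new_data: List[T],
    budget: int,
    train: Callable[[List[T]], None],
    select: Callable[[Sequence[T], int, random.Random], List[T]] = uniform_select,
    rng: Optional[random.Random] = None,
) -> List[T]:
    """One rehearsal round: retrain on retained + new data, then refresh the memory."""
    if rng is None:
        rng = random.Random(0)
    joint = memory + new_data          # merge retained samples with the new arrivals
    train(joint)                       # hypothetical local training on the joint dataset
    return select(joint, budget, rng)  # decide which samples to keep for the next round


# Usage: record what each training call saw instead of training a real model.
seen: List[List[int]] = []
memory = continual_update([], list(range(10)), budget=4, train=seen.append)
memory = continual_update(memory, list(range(10, 20)), budget=4, train=seen.append)
```

Passing the selector as a parameter keeps the loop unchanged when the uniform placeholder is swapped for a smarter retention strategy.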
The crux of our method is a procedure for selecting data samples for retention. We present the procedure in two varieties: uncoordinated, in which each device selects its own samples locally; and coordinated, in which sample selection is coordinated across devices by a central server.
In experiments, we compared our sample selection approach to three earlier methods. Their relative performance depended on how many prior samples a device could store. At 50 and 100 samples, both versions of our method significantly outperformed the baselines, but the uncoordinated method performed slightly better than the coordinated one.
At 20 samples, our methods again enjoyed a significant advantage over the baselines, but the coordinated version was the top performer. At 10 samples, the baselines became competitive with our methods, and at 5 samples, the baselines pulled ahead.
Gradient-based sample selection
For any given data sample, the graph of a machine learning model’s loss function against the settings of its parameters can be envisioned as a landscape, with peaks representing high-error outputs and troughs representing low-error outputs. Given the model’s current parameter settings (a particular point on the landscape), the goal of the machine learning algorithm is to pick a direction that leads downhill, toward lower-error outputs. The gradient points in the direction of steepest ascent, so the steepest downhill direction is the negative of the gradient.
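To make the landscape picture concrete, the sketch below uses a linear model with squared-error loss on a single sample (an illustrative choice; the article does not specify a model). The hypothetical helper `gradient` returns the uphill direction, and `step_downhill` moves a small distance along its negative, which lowers the loss.

```python
from typing import List

Vector = List[float]

def predict(w: Vector, x: Vector) -> float:
    """Linear model output w · x."""
    return sum(wi * xi for wi, xi in zip(w, x))

def loss(w: Vector, x: Vector, y: float) -> float:
    """Squared error on a single sample (x, y)."""
    return (predict(w, x) - y) ** 2

def gradient(w: Vector, x: Vector, y: float) -> Vector:
    """d loss / d w = 2 * (w · x - y) * x; this vector points uphill."""
    err = predict(w, x) - y
    return [2.0 * err * xi for xi in x]

def step_downhill(w: Vector, x: Vector, y: float, lr: float = 0.05) -> Vector:
    """Move the parameters a small distance along the negative gradient."""
    return [wi - lr * gi for wi, gi in zip(w, gradient(w, x, y))]


w0 = [0.0, 0.0]
x, y = [1.0, 2.0], 3.0
w1 = step_downhill(w0, x, y)
print(loss(w0, x, y), loss(w1, x, y))  # the second value is smaller
```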
A common way to select samples for retention is to maximize the diversity of their gradients, which ensures a concomitant diversity in the types of information the samples contain. Since a gradient is simply a direction in a multidimensional space, selecting samples whose gradients sum to (nearly) zero promotes diversity: gradients that cancel each other out cannot all point in similar directions.
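One way to turn this idea into a number, sketched below under the assumption that "close to zero" is measured with the ordinary Euclidean norm: sum the chosen samples' gradients and take the length of the result. The helper names and the gradient values are made up for illustration.

```python
import math
from typing import List, Sequence

Vector = List[float]

def gradient_sum(grads: Sequence[Vector], chosen: Sequence[int]) -> Vector:
    """Coordinate-wise sum of the gradients of the chosen samples."""
    dim = len(grads[0])
    return [sum(grads[i][d] for i in chosen) for d in range(dim)]

def sum_norm(grads: Sequence[Vector], chosen: Sequence[int]) -> float:
    """Length of the summed gradient: the closer to zero, the more the directions cancel."""
    return math.sqrt(sum(v * v for v in gradient_sum(grads, chosen)))


# Two-dimensional per-sample gradients (made-up numbers for illustration).
grads = [[1.0, 0.0], [0.9, 0.1], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]]
print(sum_norm(grads, [0, 1]))        # nearly parallel gradients: large norm (about 1.9)
print(sum_norm(grads, [0, 2, 3, 4]))  # opposing pairs cancel: 0.0
```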
The problem of optimizing gradient diversity can be formulated as assigning each gradient a coefficient of 1 or 0 such that the sum of the gradients, each multiplied by its coefficient, is as close to zero as possible. The sum of the coefficients, in turn, should equal the memory budget available for storing samples: if a device has space for N samples, we want N coefficients to be 1 and the rest to be 0.
This is, however, an NP-complete problem: no efficient exact algorithm is known, and solving it exactly amounts to systematically trying out different combinations of N gradients, whose number grows rapidly with the number of stored samples. We propose relaxing this requirement so that, while the sum of the coefficients is still N, the coefficients themselves may be fractional. The relaxed problem is computationally tractable, since it can be solved by successively refining an initial guess. Finally, we select the N samples with the highest coefficients.
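Written as an optimization problem, with g_1 through g_M denoting the gradients of the M samples currently on the device and the squared Euclidean norm standing in for "as close to zero as possible" (both notational assumptions), the selection problem reads:

```latex
\min_{c \in \{0,1\}^{M}} \; \Bigl\lVert \sum_{i=1}^{M} c_i \, g_i \Bigr\rVert_2^{2}
\qquad \text{subject to} \qquad \sum_{i=1}^{M} c_i = N
```

Each c_i = 1 marks a retained sample, so the constraint enforces the memory budget.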
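A minimal sketch of the relaxation, under assumptions the text leaves open: the fractional coefficients are kept in [0, 1], the objective is the squared norm of the weighted gradient sum, and the successive refinement is done by projected gradient descent (one possible solver, not necessarily the one used in the paper). For comparison, `exact_select` performs the exhaustive search over all combinations, which is only feasible for a handful of samples. All function names are hypothetical.

```python
import itertools
from typing import List, Sequence

Vector = List[float]

def weighted_sum(grads: Sequence[Vector], c: Sequence[float]) -> Vector:
    """sum_i c_i * g_i, computed coordinate by coordinate."""
    return [sum(ci * g[d] for ci, g in zip(c, grads)) for d in range(len(grads[0]))]

def sq_norm(v: Sequence[float]) -> float:
    return sum(x * x for x in v)

def exact_select(grads: Sequence[Vector], n: int) -> List[int]:
    """Try every n-subset (C(M, n) of them) and keep the one with the smallest sum."""
    m = len(grads)
    best = min(
        itertools.combinations(range(m), n),
        key=lambda idx: sq_norm(weighted_sum(grads, [1.0 if i in idx else 0.0 for i in range(m)])),
    )
    return list(best)

def project(v: Sequence[float], n: int, iters: int = 60) -> List[float]:
    """Euclidean projection onto {c : 0 <= c_i <= 1, sum(c) = n}.

    The projection is clip(v_i - tau, 0, 1) for some shift tau; the clipped sum
    shrinks as tau grows, so tau can be found by bisection.
    """
    lo, hi = min(v) - 1.0, max(v)  # clipped sum equals len(v) at lo and 0 at hi
    for _ in range(iters):
        tau = (lo + hi) / 2
        if sum(min(1.0, max(0.0, x - tau)) for x in v) > n:
            lo = tau
        else:
            hi = tau
    tau = (lo + hi) / 2
    return [min(1.0, max(0.0, x - tau)) for x in v]

def relaxed_select(grads: Sequence[Vector], n: int, steps: int = 500) -> List[int]:
    """Projected gradient descent on ||sum_i c_i g_i||^2, then keep the n largest c_i."""
    m = len(grads)
    # Safe step size: the objective's curvature is at most 2 * sum_i ||g_i||^2.
    curvature = 2.0 * sum(sq_norm(g) for g in grads)
    step = 1.0 / curvature if curvature > 0 else 0.0
    c = [n / m] * m  # initial guess: spread the budget evenly
    for _ in range(steps):
        s = weighted_sum(grads, c)
        # Partial derivative of ||s||^2 with respect to c_i is 2 * (g_i · s).
        grad_c = [2.0 * sum(gd * sd for gd, sd in zip(g, s)) for g in grads]
        c = project([ci - step * gi for ci, gi in zip(c, grad_c)], n)
    top = sorted(range(m), key=lambda i: c[i], reverse=True)[:n]
    return sorted(top)


grads = [[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0], [1.0, 1.0], [1.0, 1.0]]
print(exact_select(grads, 4))    # [0, 1, 2, 3]
print(relaxed_select(grads, 4))  # [0, 1, 2, 3]
```

In this toy case both approaches agree; in general, rounding the fractional solution to the top N coefficients is a heuristic and need not recover the exact optimum.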
In our experiments, this uncoordinated approach, in which each device simply optimizes gradient diversity locally, was the best-performing method for continual federated learning with an N of 50 or higher. Presumably, with enough bites at the apple, local sampling provides good enough coverage of the gradients that matter for the model as a whole.
An N of 20, however, requires more careful sample selection, and that’s where our coordinated method performed best.
The coordinated method alternates between summing gradients locally and globally. First, each device finds the local coefficients whose weighted gradient sum is as close to zero as possible. Then it sends a central server the aggregate of its local samples’ gradients, weighted by their computed coefficients. Sending aggregated gradients, rather than individual ones, protects against potential attacks that try to reverse-engineer locally stored data from their gradients.
Usually, the local choice of coefficients will not yield a sum of exactly zero. The central server considers the existing, nonzero sums from all devices and computes the minimal modification of all of them that will yield a global sum of zero. Then it sends the modified sums back to the devices, as new nonzero targets for optimization.
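The device-server exchange described in the last two paragraphs can be sketched as follows. The main assumption is what "minimal modification" means: here it is taken to be the smallest total squared Euclidean change, which has a closed form (subtract the mean of the device sums from each one). The function names and toy numbers are hypothetical. Note that the server only ever sees one aggregated vector per device.

```python
from typing import List, Sequence

Vector = List[float]

def device_message(grads: Sequence[Vector], coeffs: Sequence[float]) -> Vector:
    """Device side: report only the coefficient-weighted sum of local gradients."""
    return [sum(c * g[d] for c, g in zip(coeffs, grads)) for d in range(len(grads[0]))]

def server_targets(sums: Sequence[Vector]) -> List[Vector]:
    """Server side: least-squares adjustment of the device sums so that they total zero.

    Minimizing sum_k ||t_k - s_k||^2 subject to sum_k t_k = 0 (Lagrange multipliers)
    gives t_k = s_k - mean(s), so every device's sum shifts by the same amount.
    """
    k, dim = len(sums), len(sums[0])
    mean = [sum(s[d] for s in sums) / k for d in range(dim)]
    return [[s[d] - mean[d] for d in range(dim)] for s in sums]


# Toy round with three devices, each holding two 2-D gradients and their coefficients.
local = [
    ([[1.0, 2.0], [3.0, 4.0]], [1.0, 0.5]),
    ([[-1.0, 0.0], [0.0, -1.0]], [1.0, 1.0]),
    ([[0.5, 0.5], [-2.0, 1.0]], [0.0, 1.0]),
]
sums = [device_message(g, c) for g, c in local]
targets = server_targets(sums)
print(sums)     # [[2.5, 4.0], [-1.0, -1.0], [-2.0, 1.0]]
print(targets)  # new, generally nonzero targets whose coordinates sum to (about) zero
```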
This process can be repeated as many times as necessary, but in our experiments, we found that just one iteration was generally enough to achieve a global sum that was very close to zero. After the final iteration, each device selects the data samples corresponding to the N largest coefficients in its sum.
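After receiving its target, a device can rerun a local optimization that pulls its weighted gradient sum toward the target rather than toward zero, then keep the samples with the N largest coefficients. This sketch assumes, since the text does not say, that coefficients stay in [0, 1], that closeness to the target is measured by squared Euclidean distance, and that the refinement uses projected gradient descent; `select_toward_target` and the toy data are hypothetical.

```python
from typing import List, Sequence

Vector = List[float]

def project(v: Sequence[float], n: int, iters: int = 60) -> List[float]:
    """Projection onto {c : 0 <= c_i <= 1, sum(c) = n}: clip(v - tau, 0, 1), tau by bisection."""
    lo, hi = min(v) - 1.0, max(v)
    for _ in range(iters):
        tau = (lo + hi) / 2
        if sum(min(1.0, max(0.0, x - tau)) for x in v) > n:
            lo = tau
        else:
            hi = tau
    tau = (lo + hi) / 2
    return [min(1.0, max(0.0, x - tau)) for x in v]

def select_toward_target(
    grads: Sequence[Vector], n: int, target: Vector, steps: int = 500
) -> List[int]:
    """Minimize ||sum_i c_i g_i - target||^2 over fractional c, then keep the n largest c_i."""
    m, dim = len(grads), len(target)
    curvature = 2.0 * sum(sum(x * x for x in g) for g in grads)  # bound on the curvature
    step = 1.0 / curvature if curvature > 0 else 0.0
    c = [n / m] * m  # start from an even split of the budget
    for _ in range(steps):
        # Residual between the current weighted sum and the server's target.
        r = [sum(ci * g[d] for ci, g in zip(c, grads)) - target[d] for d in range(dim)]
        grad_c = [2.0 * sum(gd * rd for gd, rd in zip(g, r)) for g in grads]
        c = project([ci - step * gi for ci, gi in zip(c, grad_c)], n)
    top = sorted(range(m), key=lambda i: c[i], reverse=True)[:n]
    return sorted(top)


grads = [[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0], [1.0, 1.0], [1.0, 1.0]]
print(select_toward_target(grads, 2, target=[2.0, 2.0]))  # [4, 5]
```

Warm-starting from the coefficients of the first local pass would be a natural refinement; the sketch restarts from an even split for simplicity.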
As baselines for our experiments, we used three prior sampling strategies. One was a naïve uniform sampling approach, which simply samples from all the data currently on the device, and the other two used weighted sampling to try to ensure a better balance between previously seen and newly acquired data.
At N = 10, the random-sampling approaches were competitive with our approach, and at N = 5, they outperformed it. But in practice, distributed devices are frequently able to store more than five or ten samples, and our paper provides a guide to matching the sample selection strategy to the device’s storage capacity.