The importance of forgetting in artificial and animal intelligence
The surprising dynamics related to learning that are common to artificial and biological systems.
Deep neural networks (DNNs) have taken the AI research community by storm, approaching human-like performance in niche learning tasks from recognizing speech to finding objects in images. The industry has taken notice, with adoption growing by 37% in the past four years, according to Gartner, a leading research and advisory firm.
But how does a DNN learn? What “information” does it contain? How is such information represented, and where is it stored? How does information content in the DNN change during learning?
In 2016, my collaborators and I (then at UCLA) set out to answer some of these questions. To frame the questions mathematically, we had to form a viable definition of “information” in deep networks.
Traditional information theory is built around Claude Shannon’s idea of quantifying how many bits are needed to send a message. But as Shannon himself noted, this is a measure of information for communication. When it is used to measure how much information a DNN’s weights contain about the task the network is trying to solve, it tends to give degenerate, nonsensical values.
This shortcoming led to the introduction of a more general notion, the information Lagrangian, which defines information as the trade-off between how much noise could be added to the weights between layers and the resulting accuracy of the network’s input-output behavior. Intuitively, even if a network is very large, if we can replace most of its computations with random noise and still get the same output, then the DNN does not actually contain that much information. Pleasingly, for particular noise models, this definition specializes to recover Shannon’s original one.
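To make the noise-versus-accuracy trade-off concrete, here is a minimal toy sketch; it is not the authors’ formulation. It assumes a two-weight linear classifier, independent Gaussian noise N(0, σᵢ²) on each weight, a Gaussian prior N(0, 1) on the weights, and it measures “information” as the KL divergence between the noisy weights and that prior. It also uses error rate in place of a proper loss. The names `kl_gaussian`, `noisy_error`, `information_lagrangian`, and the parameter `beta` are all hypothetical choices made for illustration.

```python
import math
import random
from typing import List, Sequence, Tuple

Example = Tuple[Sequence[float], int]


def kl_gaussian(mu: float, sigma: float, prior_std: float) -> float:
    """KL(N(mu, sigma^2) || N(0, prior_std^2)) for one weight, in nats."""
    return (math.log(prior_std / sigma)
            + (sigma ** 2 + mu ** 2) / (2.0 * prior_std ** 2)
            - 0.5)


def noisy_error(weights: Sequence[float], sigmas: Sequence[float],
                data: List[Example], samples: int = 200, seed: int = 0) -> float:
    """Error rate of a linear classifier whose weights receive Gaussian noise N(0, sigma_i^2)."""
    rng = random.Random(seed)
    wrong = 0
    for _ in range(samples):
        noisy = [w + rng.gauss(0.0, s) for w, s in zip(weights, sigmas)]
        for x, y in data:
            score = sum(w * xi for w, xi in zip(noisy, x))
            wrong += int((score > 0) != (y == 1))
    return wrong / (samples * len(data))


def information_lagrangian(weights, sigmas, data, beta, prior_std=1.0):
    """Return (error + beta * info, error, info) for the given per-weight noise levels."""
    info = sum(kl_gaussian(w, s, prior_std) for w, s in zip(weights, sigmas))
    error = noisy_error(weights, sigmas, data)
    return error + beta * info, error, info


# Toy data (hypothetical): only the first feature determines the label.
data = [((1.0, 0.5), 1), ((-1.0, 0.5), 0), ((2.0, -0.5), 1), ((-2.0, -0.5), 0)]
weights = [2.0, 0.0]
for sigmas in ([0.01, 0.01], [0.01, 1.0]):
    total, err, info = information_lagrangian(weights, sigmas, data, beta=0.1)
    print(f"sigmas={sigmas}  error={err:.3f}  info={info:.2f} nats  L={total:.3f}")
```

In this toy, flooding the second (useless) weight with noise leaves the error essentially unchanged while cutting the measured information by about 4.1 nats. That is the intuition above: a weight that can be replaced by noise without hurting the output carries little information.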
The next step was computing this information for DNNs with millions of parameters.
As learning progresses, one would expect the amount of information stored in the weights of the network to increase monotonically: the more you train, the more you learn. However, the information in the weights (the blue line in the figure at right) follows a completely different path. First, it increases sharply, as if the network were trying to acquire information about the data set. Then it drops, almost as though the network were “forgetting”, or shedding information about the training data. Amazingly, this forgetting occurs while performance on the learning task, shown by the green dashed curve, continues to increase!
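The article does not say exactly how the curve in the figure was computed. A proxy commonly used for the information in the weights is the trace of the Fisher information matrix. The sketch below assumes that proxy and shows only the measurement machinery: a logistic-regression model trained by full-batch gradient descent, with the Fisher trace recorded after every epoch. Do not expect this convex toy to reproduce the rise-and-fall that the article describes for deep networks. The function names and toy data are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def fisher_trace(w: Sequence[float], b: float, xs: List[Sequence[float]]) -> float:
    """Trace of the Fisher information of logistic regression at parameters (w, b).

    For p = sigmoid(w.x + b), the score is (y - p) * (x, 1). Averaging its squared
    norm over y ~ Bernoulli(p) gives p * (1 - p) * (||x||^2 + 1) per input.
    """
    total = 0.0
    for x in xs:
        p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
        total += p * (1.0 - p) * (sum(xi * xi for xi in x) + 1.0)
    return total / len(xs)


def train_and_track(data: List[Tuple[Sequence[float], int]], epochs: int = 50,
                    lr: float = 0.5) -> List[float]:
    """Full-batch gradient descent on the logistic loss; records the Fisher trace each epoch."""
    dim = len(data[0][0])
    w, b = [0.0] * dim, 0.0
    xs = [x for x, _ in data]
    history = []
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in data:
            err = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) - y
            for i in range(dim):
                grad_w[i] += err * x[i]
            grad_b += err
        w = [wi - lr * g / len(data) for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / len(data)
        history.append(fisher_trace(w, b, xs))
    return history


data = [((1.0, 0.5), 1), ((-1.0, 0.5), 0), ((2.0, -0.5), 1), ((-2.0, -0.5), 0)]
trace_curve = train_and_track(data, epochs=20)
print([round(t, 3) for t in trace_curve])
```

In a deep network, the same idea applies with the Fisher trace estimated over all parameters, usually by sampling. Plotting that trace against training epochs is what would show whether information first grows and then shrinks.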
When we shared these findings with biologists, they were not surprised. In biological systems, forgetting is an important aspect of learning. Animal brains have a bounded capacity. There is an ongoing need to forget useless information and consolidate useful information. However, DNNs are not biological in nature. There is no apparent reason why memorizing first, and then forgetting, should be beneficial.
Our research uncovered a related phenomenon, one that surprised our biologist collaborator as well.
Biological networks have another fundamental property: they lose their plasticity over time. If people do not learn a skill (say, seeing or speaking) during a critical period of development, their ability to learn that skill is permanently impaired. In humans, for example, failure to correct visual defects early enough in childhood can result in lifelong amblyopia (impaired vision in one eye), even if the defect is later corrected. Critical learning periods are also pronounced elsewhere in the animal kingdom; for example, they are vital for birds developing the ability to sing.
The inability to learn a new skill later in life is considered a side effect of the loss of neuronal plasticity due to several biochemical factors. Artificial neural networks, on the other hand, are not subject to any such biochemical factors. They do not age. Why, then, would they have a critical learning period?
We set out to repeat a classical experiment by neuroscience pioneers Hubel and Wiesel, who in the 1950s and 1960s studied the effect of a temporary visual deficit in kittens shortly after birth and correlated it with permanent visual impairment later in life.
We “blindfolded” the DNNs by blurring the training images at the beginning of training. Then we let the networks train on clear images. We found that the deficit introduced in the initial period resulted in a permanent deficit (a loss of classification accuracy), no matter how much additional training the network performed.
In other words, DNNs exhibit critical learning periods just like biological systems. If we alter the data during the “information acquisition” phase, the network ends up in a state from which it cannot recover. Altering the data after this critical period has no such effect.
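The deficit schedule in this experiment amounts to a per-epoch switch on the training data. The sketch below assumes grayscale images stored as nested lists and uses a simple box blur as a stand-in for whatever blurring the original experiment applied. The names `box_blur`, `images_for_epoch`, and `deficit_epochs` are hypothetical.

```python
from typing import List

Image = List[List[float]]


def box_blur(img: Image, radius: int = 1) -> Image:
    """Mean filter over a (2*radius+1) x (2*radius+1) window, clipped at the borders."""
    h, w = len(img), len(img[0])
    out = [[0.0] * w for _ in range(h)]
    for r in range(h):
        for c in range(w):
            window = [img[rr][cc]
                      for rr in range(max(0, r - radius), min(h, r + radius + 1))
                      for cc in range(max(0, c - radius), min(w, c + radius + 1))]
            out[r][c] = sum(window) / len(window)
    return out


def images_for_epoch(images: List[Image], epoch: int, deficit_epochs: int,
                     radius: int = 2) -> List[Image]:
    """Blur every image during the first `deficit_epochs` epochs (the 'blindfold'),
    then return the clear images unchanged for the rest of training."""
    if epoch < deficit_epochs:
        return [box_blur(img, radius) for img in images]
    return images


img = [[0.0, 0.0, 0.0], [0.0, 9.0, 0.0], [0.0, 0.0, 0.0]]
print(images_for_epoch([img], epoch=0, deficit_epochs=5, radius=1)[0])
```

In a real training loop, one would call `images_for_epoch(train_images, epoch, deficit_epochs)` at each epoch and pass the result to a training step (not shown). Repeating the run with different values of `deficit_epochs` and comparing the final test accuracies is how a critical period would show up: short deficits are recoverable, while deficits that overlap the early phase are not.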
We then performed a kind of “artificial neural recording”, measuring the flow of information among different neurons. We found that during the critical period, the way information flows between layers is fluid; after the critical period, these pathways become fixed. So while a DNN has no neural plasticity, it exhibits a form of “information plasticity”: the ability to change how information is processed is lost during learning. Yet this loss is not a consequence of aging or of some complex biochemical phenomenon. This “forgetting” appears to be an essential part of learning, in artificial and biological systems alike.
Over the subsequent years, we tried to understand and analyze these dynamics related to learning that are common to artificial and biological systems.
The result was a rich set of findings, some of which are already making their way into our products. For instance, it is common in AI to train a DNN model to solve one task (say, finding cats and dogs in images) and then fine-tune it for a different task (say, recognizing objects for autonomous-driving applications). But how do we know which model to start from to solve a customer problem? When are two learning tasks “close”? How do we represent learning tasks mathematically, and how do we compute the distance between them?
To give just one practical application of our research, Task2Vec is a method for representing a learning task with a simple vector. This vector is a function of the information in the weights discussed earlier. The amount of information needed to fine-tune one model from another is an (asymmetric) distance between the tasks the two models represent, so we can now measure how difficult it would be to fine-tune a given model for a given task. This is part of our Amazon Rekognition Custom Labels service, where customers can provide a few sample images of objects, and the system learns a model to detect and classify them in never-before-seen images.
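As a rough illustration of how such task vectors can be compared, the sketch below assumes that each task is embedded as a vector of positive, diagonal Fisher-information values. It uses distances modeled on those described in the Task2Vec paper, as we understand them; treat both formulas as assumptions rather than the production method. The symmetric distance is a cosine distance after normalizing each coordinate by the sum of the two embeddings. The asymmetric “source to target” distance subtracts `alpha` times the source’s distance from a trivial reference task. The embeddings, `f_trivial`, and `alpha` below are hypothetical.

```python
import math
from typing import Sequence


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """1 minus the cosine similarity of two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def symmetric_task_distance(f_a: Sequence[float], f_b: Sequence[float]) -> float:
    """Cosine distance after normalizing each coordinate by f_a + f_b (entries must be positive)."""
    norm_a = [x / (x + y) for x, y in zip(f_a, f_b)]
    norm_b = [y / (x + y) for x, y in zip(f_a, f_b)]
    return cosine_distance(norm_a, norm_b)


def asymmetric_task_distance(f_a: Sequence[float], f_b: Sequence[float],
                             f_trivial: Sequence[float], alpha: float) -> float:
    """Distance for transferring from task a to task b.

    Subtracts alpha times a's distance from a trivial reference task, so a source
    task that is far from trivial (more complex) looks like a better starting point.
    """
    return (symmetric_task_distance(f_a, f_b)
            - alpha * symmetric_task_distance(f_a, f_trivial))


# Hypothetical diagonal-Fisher embeddings for two tasks and a trivial reference task.
f_a, f_b, f_trivial = [1.0, 4.0], [2.0, 1.0], [1.0, 1.0]
print(asymmetric_task_distance(f_a, f_b, f_trivial, alpha=0.5))
print(asymmetric_task_distance(f_b, f_a, f_trivial, alpha=0.5))
```

The two printed values differ because the correction term depends only on the source task. This is one way to capture the intuition that fine-tuning from a rich task to a simple one should be easier than the reverse.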
AI is truly in its infancy. The depth of the intellectual questions raised by the field is invigorating. For now, there’s consolation for those of us aging and beginning to forget things. We can take comfort in the knowledge that we are still learning.