Building a cutting-edge machine learning system requires time, effort, and experimentation. At Modulate, we've run thousands of experiments consuming hundreds of thousands of hours of GPU time in our ongoing pursuit of building perfect voice skins. We often run 8-12 experiments concurrently per ML engineer, which makes detailed monitoring of each individual experiment difficult and necessitates tools that show experiment health at a glance. This post discusses one of the techniques we use to get the most out of each experiment: pruning bad experiments using heuristics we developed by watching training statistics over time and correlating them with voice skin performance.
We'll discuss two early signals, Adversary Catching and Mean Runaway, which we watch for signs of failure in the first ~12 hours of training, long before final model quality can be judged. Then we'll discuss a longer-term statistic, Dead RELUs, that gives at-a-glance information about the health of the experiment.
Adversary catching is a qualitative switch in the behavior of the voice skin network that arises from combining adversarial training with a variety of other losses, including "content" losses that preserve the emotional and phonetic content of the input speech. This shift must occur during training for the adversary to force the voice skin network to convert to the target voice and to reduce distortion in the output audio. We usually observe catching as the content loss levels out, around 10k steps in; an experiment that fails to catch is failing and can be terminated early.
At the beginning of training, the voice skin network is mostly learning how to produce any plausible speech at all, guided by the content losses. At this stage the adversary's job is easy, and it quickly satisfies its training objective. In the plots below, the adversary loss is minimized early on, and the gradient penalty regularizer on the adversary stays relatively low, showing that the adversary isn't working very hard for its low loss values. The content loss is also dropping rapidly, as it is the driving force optimizing the voice skin network to produce speech at this early stage.
The adversary network's loss at the beginning of training. Early on, the adversary is able to distinguish between real and synthetic audio with high accuracy.
The adversary network's gradient penalty regularizer loss at the beginning of training. The relatively low penalty indicates that the network is able to easily do its job.
The voice skin network's content loss at the beginning of training. The sharp slope indicates that minimizing this loss is a driving force in the network early on.
After the initial stage of training, the content loss has pushed the voice skin network to produce output that starts to resemble human speech. Suddenly the adversary can no longer rely on easy features (noise level, volume level, etc.) to distinguish real from synthetic speech, and it is forced to begin training in earnest. What follows is another ~10k steps of alternating high and low accuracy from the adversary as it rapidly traverses the loss landscape, before it settles into a relatively well-behaved, long-term back-and-forth with the voice skin network.
The adversary network's loss catching after ~10k steps. At this point, the adversary must begin to learn more about the structure of audio, and begins oscillating through the loss landscape.
The adversary network's gradient penalty regularizer loss at the catch point. The loss begins to increase, indicating that the adversary is finding it harder to distinguish real from synthetic audio.
The voice skin network's content loss at the catch point. The content loss begins decreasing more slowly as the focus shifts to the adversary network, eventually leveling off entirely.
By watching these three plots over time, we can tell whether adversary catching is likely to occur within the first ~10-20k steps of training. If it hasn't happened by the time the content loss begins to level out, it is unlikely to happen at all. At that point the experiment will not produce plausible-sounding audio no matter how long it trains, so we kill it.
In contrast with Adversary Catching, Mean Runaway is relatively straightforward: we want the distribution of sample values in the voice skin network's outputs to match the distribution of sample values in real audio. In particular, our content loss explicitly ignores DC bias in the output audio, relying on the adversary to enforce a plausible offset of sample values. This works for most of training, but early in the process, before adversary catching, the voice skin network's output can "run away" to implausible values.
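The kill rule for a missing catch can be expressed as a simple automated check. The sketch below is illustrative only: the function names, window sizes, and thresholds (`window`, `rel_tol`, `rise_factor`) are hypothetical, and it assumes the adversary loss is non-negative (for example, a cross-entropy loss), so that "rising well above its early floor" is meaningful. The same idea could be applied to the gradient penalty curve.

```python
from statistics import mean


def has_plateaued(losses, window=1000, rel_tol=0.05):
    """True if the mean loss over the last `window` steps improved by less than
    `rel_tol` (relative) over the mean of the `window` steps before it."""
    if len(losses) < 2 * window:
        return False
    previous = mean(losses[-2 * window:-window])
    recent = mean(losses[-window:])
    return (previous - recent) < rel_tol * abs(previous)


def adversary_caught(adv_losses, window=500, rise_factor=3.0):
    """True if the recent adversary loss sits well above the lowest windowed
    mean seen so far, i.e. the adversary is no longer winning easily.
    Assumes a non-negative adversary loss (hypothetical heuristic)."""
    if len(adv_losses) < 2 * window:
        return False
    # Non-overlapping chunk means; the smallest one stands in for the "easy win" floor.
    chunk_means = [
        mean(adv_losses[i:i + window])
        for i in range(0, len(adv_losses) - window + 1, window)
    ]
    floor = min(chunk_means)
    recent = mean(adv_losses[-window:])
    return recent > rise_factor * floor


def should_kill_for_no_catch(content_losses, adv_losses):
    """Hypothetical rule: the content loss has leveled out but the adversary never caught."""
    return has_plateaued(content_losses) and not adversary_caught(adv_losses)
```

In practice a check like this would run periodically against logged training curves; because adversary losses are noisy after the catch, heavier smoothing than a single window mean may be needed.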
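To make the DC-bias point concrete, here is a toy stand-in for a DC-invariant loss. It is not our actual content loss, which this post does not specify; it only shows how subtracting each clip's mean before comparing makes a loss blind to a constant offset, leaving the adversary as the only term that constrains it. The function names are hypothetical.

```python
from statistics import fmean


def remove_dc(samples):
    """Subtract the clip's mean (its DC offset) from every sample."""
    offset = fmean(samples)
    return [s - offset for s in samples]


def dc_invariant_l1(output, target):
    """Mean absolute error between two equal-length clips after removing each
    clip's DC offset. Adding a constant to `output` leaves this loss unchanged."""
    a, b = remove_dc(output), remove_dc(target)
    return fmean(abs(x - y) for x, y in zip(a, b))
```

For example, shifting a target clip up by 0.5 and comparing it with the original gives a loss of (numerically) zero, which is exactly why a runaway offset goes unpunished until the adversary starts using it.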
The means of the audio sample distributions output by the voice skin network in a healthy training run. At the beginning of training the audio means can fluctuate dramatically; ensuring that they converge toward the true distribution helps promote adversary catching.
While the distribution of means can fluctuate significantly early in training, it must settle close to the true distribution before adversary catching can occur (otherwise the adversary can simply use the DC bias in the audio as a real/synthetic feature). If the means stay poorly distributed while the content loss begins to level off, the adversary is likely to keep near-zero loss forever and the experiment will fail, so we usually kill it at that point.
After the first ~12 hours of training, once adversary catching has occurred and most of the risk of mean runaway has passed, many experiments will successfully test their intended hypotheses. Nevertheless, things can still go haywire when gradual trends reach a qualitative tipping point.
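A minimal sketch of the mean-runaway check, assuming access to a batch of generated clips and a batch of real clips as lists of float samples, plus a content-loss plateau flag computed elsewhere. The coverage test (the fraction of generated clip means that fall inside the range spanned by real clip means) and the `coverage=0.9` threshold are hypothetical choices; a statistical comparison of the two distributions of means would serve the same purpose.

```python
from statistics import fmean


def clip_means(clips):
    """Per-clip mean sample value, i.e. each clip's DC offset."""
    return [fmean(clip) for clip in clips]


def means_in_range(fake_means, real_means, coverage=0.9):
    """True if at least `coverage` of the generated clip means fall inside the
    range spanned by the real clip means."""
    lo, hi = min(real_means), max(real_means)
    inside = sum(lo <= m <= hi for m in fake_means)
    return inside >= coverage * len(fake_means)


def mean_runaway(fake_clips, real_clips, content_plateaued, coverage=0.9):
    """Flag a likely failure: the content loss has leveled off, but the
    generated means still sit outside the real distribution."""
    healthy = means_in_range(clip_means(fake_clips), clip_means(real_clips), coverage)
    return content_plateaued and not healthy
```

Gating on `content_plateaued` mirrors the reasoning above: wild means early in training are expected, and only become a kill signal once the content loss stops driving progress.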
One place where this commonly happens is a RELU nonlinearity layer, which passes positive inputs through untouched and zeroes out negative inputs. Zeroed inputs carry no information further through the network, so whatever features they represent play no part in the network's calculation for that input. By letting some inputs through and letting others die, the network can choose which features to look at for any given input.
However, an input feature can be zeroed by the RELU many times in a row, either by chance or because the network actively selects against it. The network then evolves away from using that RELU output: the times when the feature is active produce larger loss, which leads to even more zeroing of that feature. This cycle can eventually produce a "dead" RELU, whose input feature is always negative and therefore always zeroed. A dead RELU represents lost capacity: the weights that compute its input feature, and the weights that consume its output, are wasted because the feature is always zero.
When a RELU stays dead over a long period, it passes no gradients, so the weights that compute its input feature and the weights that process its output stop updating and drift out of sync with the rest of the network. At times, changes to earlier layers can push an input feature from strongly negative to positive, reactivating a once-dead RELU. These "undead" RELUs now pass information to badly out-of-date output weights, which can overwhelm the rest of the network. The resulting output typically causes a jump in the loss, which in turn produces large gradients that can throw the network further off course.
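Counting dead RELUs only requires the layer's pre-activations. A minimal sketch, assuming we collect them for a window of inputs as rows (one per example or timestep) with one value per unit; treating a unit as dead when it never went positive over the window is our assumption about how to operationalize "always zeroed", and the function names are hypothetical.

```python
def relu(x):
    """Pass positive values through unchanged; zero out everything else."""
    return x if x > 0.0 else 0.0


def dead_units(preacts):
    """Indices of units whose pre-activation was never positive across all rows,
    so their RELU output was zero for every input in the window."""
    if not preacts:
        return set()
    n_units = len(preacts[0])
    return {j for j in range(n_units) if all(row[j] <= 0.0 for row in preacts)}
```

Logging `len(dead_units(window))` once per logging interval produces the kind of dead-RELU curve discussed in this section.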
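One way to catch these reactivations, sketched under assumptions: track for each unit how many consecutive windows it has been dead, and flag it as "undead" when it fires again after a long streak. The class name and the `min_dead_windows` threshold are hypothetical, and a window is again a list of pre-activation rows with one value per unit.

```python
class DeadReluMonitor:
    """Tracks how many consecutive windows each unit has been dead and reports
    units that come back to life after a long dead streak ("undead")."""

    def __init__(self, n_units, min_dead_windows=10):
        self.streak = [0] * n_units  # consecutive dead windows per unit
        self.min_dead_windows = min_dead_windows

    def update(self, preacts):
        """Consume one window of pre-activations (rows x units).
        Returns (number of dead units in this window, list of undead unit indices)."""
        dead_count = 0
        undead = []
        for j in range(len(self.streak)):
            if all(row[j] <= 0.0 for row in preacts):
                self.streak[j] += 1
                dead_count += 1
            else:
                # The unit fired; it is "undead" only if it had been dead for a long time.
                if self.streak[j] >= self.min_dead_windows:
                    undead.append(j)
                self.streak[j] = 0
        return dead_count, undead
```

A sudden burst of indices in the second return value corresponds to the spikes of undead RELUs described here, and could trigger an alert or a closer look at the loss curve.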
The number of Dead RELUs over time. Highlighted spikes indicate a number of "undead" RELUs which can cause the network training to veer off track.
A large number of dead RELUs can indicate that the network will underperform, since the lost capacity reduces its expressiveness. Spikes of "undead" RELUs can signal that a network is spiraling out of control and should be shut down. If undead RELUs are common during training, their impact can be lessened with gradient clipping or other techniques that limit the effect of single extreme outputs.
We've discussed a few statistics that we measure to evaluate network performance. For Adversary Catching and Mean Runaway, abnormal behavior can indicate that an experiment should be killed within the first ~12 hours of training, avoiding wasted time and GPU resources. By monitoring dead RELUs, we can see at a glance whether an entire batch of experiments is behaving well, and which experiments are unlikely to perform well in QA.
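Gradient clipping by global norm is one common form of the mitigation mentioned above. The framework-free sketch below only illustrates the arithmetic, with gradients represented as a hypothetical list of flat float lists (one per parameter tensor); in PyTorch, `torch.nn.utils.clip_grad_norm_` performs the equivalent operation in place.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients by a common factor so that their combined L2 norm is
    at most `max_norm`. Returns (clipped gradients, norm before clipping)."""
    total = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if total <= max_norm:
        return [list(tensor) for tensor in grads], total
    scale = max_norm / total
    return [[g * scale for g in tensor] for tensor in grads], total
```

Scaling every tensor by the same factor preserves the gradient's direction while bounding the step size, so a single extreme output from an undead RELU cannot produce an arbitrarily large update.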