Using Heuristics to Prune Experiments

Building a cutting-edge machine learning system requires time, effort, and experimentation. At Modulate, we've run thousands of experiments consuming hundreds of thousands of hours of GPU time in our ongoing pursuit of building perfect voice skins. Each ML engineer often has 8-12 experiments running concurrently, which makes detailed monitoring of individual experiments difficult and calls for tools that show experiment health at a glance. This post discusses one of the techniques we use to get the most out of each experiment: pruning bad experiments using heuristics that we developed by watching training statistics over time and correlating them with voice skin performance.

We'll discuss two early signals, Adversary Catching and Mean Runaway, which we watch for signs of failure within the first 12 hours of training, long before final model quality can be judged. Then we'll discuss a longer-term statistic, Dead RELUs, that gives at-a-glance information about the health of an experiment.

ADVERSARY CATCHING  

Adversary catching is a qualitative switch in the behavior of the voice skin network that arises from combining adversarial training with a variety of other losses, including "content" losses that preserve the emotional and phonetic content of the input speech. This shift must occur during training for the adversary to force the voice skin network to convert to the target voice and to reduce distortion in the output audio. We usually observe the catch as the content loss begins to level out, around 10k steps in; if it does not happen, the experiment is failing and can be terminated early.

At the beginning of training, the voice skin network is mostly learning how to produce any plausible speech at all, guided by the content losses. At this stage the adversary's job is easy, and it quickly satisfies its training objective. In the plots below, the adversary loss is minimized quickly, and the gradient penalty regularizer on the adversary stays relatively low, showing that the adversary isn't working very hard for its low loss values. The content loss is also dropping rapidly, since it is the driving force optimizing the voice skin network to produce speech at this early stage.

The adversary network's loss at the beginning of training. Early on, the adversary is able to distinguish between real and synthetic audio with high accuracy.

The adversary network's gradient penalty regularizer loss at the beginning of training. The relatively low penalty indicates that the network is able to easily do its job.

The voice skin network's content loss at the beginning of training. The sharp slope indicates that minimizing this loss is a driving force in the network early on.

After the initial stage of training, the content loss has forced the voice skin network to produce output that starts to resemble human speech. Suddenly, the adversary can no longer rely on easy features (noise level, volume level, etc.) to distinguish real from synthetic speech, and it is forced to begin training in earnest. What follows is another ~10k steps of alternating high and low accuracy from the adversary as it rapidly traverses the loss landscape, before it settles into a relatively well-behaved, long-term back-and-forth with the voice skin network.

The adversary network's loss catching after ~10k steps. At this point, the adversary must begin to learn more about the structure of audio, and begins oscillating through the loss landscape.

The adversary network's gradient penalty regularizer loss at the catch point. The loss begins to increase, indicating that distinguishing real from synthetic audio is becoming more difficult.

The voice skin network's content loss at the catch point. The content loss begins decreasing more slowly as the focus shifts to the adversary network, eventually leveling off entirely.

By watching these three plots over time, we can determine whether adversary catching is likely to occur in the first ~10-20k steps of training. If catching has not happened by the time the content loss begins to level out, it is unlikely to happen at all. At that point the experiment will not produce plausible-sounding audio no matter how long it trains, so we kill it.
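The kill decision described above can be written as a simple rule. What follows is a minimal sketch, not our production monitor: it assumes the content and adversary losses are logged once per step as Python lists, and the window size, tolerances, the 20k-step deadline, and all function names are hypothetical choices for illustration. It treats the content loss as leveled out when its windowed mean stops improving, and treats the adversary as having caught when its recent loss rises well above the lowest windowed loss it reached earlier. The same rise test could also be applied to the gradient penalty curve.

```python
from statistics import fmean

# Hypothetical defaults; real values would need tuning per setup.
WINDOW = 1000           # steps per averaging window
PLATEAU_REL_TOL = 0.02  # <2% improvement between windows counts as "leveled out"
RISE_FACTOR = 2.0       # recent adversary loss must exceed 2x its early floor
DEADLINE_STEP = 20_000  # end of the ~10-20k step window for catching


def content_has_plateaued(losses, window=WINDOW, rel_tol=PLATEAU_REL_TOL):
    """True if the last window's mean loss improved by less than rel_tol
    (relative) over the window before it."""
    if len(losses) < 2 * window:
        return False
    prev = fmean(losses[-2 * window:-window])
    recent = fmean(losses[-window:])
    return (prev - recent) < rel_tol * abs(prev)


def adversary_has_caught(adv_losses, window=WINDOW, rise=RISE_FACTOR):
    """True if the adversary's recent mean loss sits well above the lowest
    windowed mean it reached earlier in training."""
    if len(adv_losses) < 2 * window:
        return False
    history = adv_losses[:-window]  # everything before the most recent window
    early_floor = min(
        fmean(history[i:i + window])
        for i in range(0, len(history) - window + 1, window)
    )
    recent = fmean(adv_losses[-window:])
    return recent > rise * early_floor


def should_kill(step, content_losses, adv_losses, deadline=DEADLINE_STEP):
    """Kill once past the deadline if the content loss has leveled out but
    the adversary never caught."""
    return (
        step >= deadline
        and content_has_plateaued(content_losses)
        and not adversary_has_caught(adv_losses)
    )
```

Comparing windowed means rather than raw per-step values keeps the oscillation that follows a catch from flipping the result from one step to the next.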

MEAN RUNAWAY  

In contrast with Adversary Catching, Mean Runaway is relatively straightforward: we want the distribution of sample values in the voice skin network's outputs to match the distribution of sample values in real audio. In particular, our content loss explicitly ignores DC bias (a constant offset) in the output audio, relying on the adversary to enforce a plausible offset of sample values. This works for most of training, but early in the process, before adversary catching, the voice skin network's output can "run away" to implausible values.
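One simple way for a loss to ignore DC bias is to subtract each clip's mean sample value before comparing it with the reference. The sketch below assumes a plain sample-level L1 loss purely for illustration; it is not our actual content loss, which targets speech content rather than raw samples, and the function names are hypothetical. It shows why the offset is left unconstrained: shifting the output by a constant leaves the loss unchanged, so nothing in this term stops the mean from drifting.

```python
from statistics import fmean


def remove_dc(samples):
    """Subtract a clip's mean sample value (its DC offset)."""
    offset = fmean(samples)
    return [s - offset for s in samples]


def dc_invariant_l1(output, reference):
    """Illustrative sample-level L1 loss that ignores DC bias: both clips are
    re-centred to zero mean before comparing, so adding a constant to
    `output` leaves the loss unchanged."""
    out, ref = remove_dc(output), remove_dc(reference)
    return fmean(abs(o - r) for o, r in zip(out, ref))


# A constant offset of +0.5 is invisible to this loss.
reference = [0.1, -0.2, 0.3, -0.1]
shifted = [s + 0.5 for s in reference]
print(dc_invariant_l1(shifted, reference))  # effectively zero, up to float rounding
```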

The means of the audio sample distributions output by the voice skin network. This example shows a healthy training run. At the beginning of training the audio means can fluctuate dramatically; ensuring that they converge towards the true distribution helps promote adversary catching.

The true distribution of audio samples from a sample audio clip. The distribution of samples output by the voice skin network should converge to something similar to this.

While the distribution of means can fluctuate significantly early in training, it must settle close to the true distribution before adversary catching can occur (otherwise the adversary can simply use the DC bias of the audio as a real/synthetic feature). If the means remain poorly distributed as the content loss begins to level off, the adversary is likely to keep a near-zero loss forever, and the experiment will fail. In that case the experiment is usually killed.
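A mean-runaway check can compare the per-clip means (DC offsets) of generated audio against those of real audio. This is a minimal sketch under stated assumptions: clips are lists of float samples, "poorly distributed" is approximated by a z-score of the average generated mean against the real clips' means, and the 3-standard-deviation threshold and function names are hypothetical. In a kill decision it would be combined with a check that the content loss has leveled off.

```python
from statistics import fmean, pstdev

Z_THRESHOLD = 3.0  # hypothetical: how many standard deviations counts as runaway


def clip_means(clips):
    """Mean sample value (DC offset) of each clip."""
    return [fmean(clip) for clip in clips]


def mean_runaway(generated_clips, real_clips, z_threshold=Z_THRESHOLD):
    """Flag runaway when the average DC offset of the generated clips lies
    more than z_threshold standard deviations from the mean of the real
    clips' DC offsets."""
    real = clip_means(real_clips)
    mu = fmean(real)
    sigma = max(pstdev(real), 1e-8)  # guard against identical real offsets
    gen_mu = fmean(clip_means(generated_clips))
    return abs(gen_mu - mu) / sigma > z_threshold
```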

DEAD RELUS

After the first ~12 hours of training, having cleared adversary catching and most of the risk of mean runaway, many experiments will go on to successfully test their hypotheses. Nevertheless, things can still go haywire when gradual trends reach a qualitative tipping point.

One place where this commonly happens is a RELU nonlinearity layer, which passes positive inputs through untouched and zeroes out negative inputs. The zeroed-out inputs carry no information further through the network, so whatever features they represent play no part in the network's calculation for that input. By letting some inputs through and zeroing others, the network can choose which features to look at for any given input.

However, an input feature, whether by chance or by the network's own selection, can be zeroed by the RELU many times in a row. The network then evolves away from using that output of the RELU layer, so the occasions when the feature is active produce larger loss, which leads to even more zeroing of that feature. This cycle can eventually produce a "dead" RELU, whose input feature is always negative and is therefore always zeroed. A dead RELU represents a loss of capacity: the weights that compute the input feature, and the weights that consume the output feature, are wasted, because the feature is always zero.
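The zero-gradient part of this cycle can be seen on a single toy unit. The sketch below takes one gradient-descent step for y = relu(w * x + b); the mean-squared-error loss, the data, and the function names are hypothetical choices for illustration only. When the pre-activation is negative for every input, the ReLU's derivative is zero everywhere, so the step leaves w and b exactly where they were: gradient descent on its own cannot revive the unit.

```python
def relu(z):
    return z if z > 0.0 else 0.0


def relu_grad(z):
    # Derivative of ReLU; taken as 0 at z == 0, a common convention.
    return 1.0 if z > 0.0 else 0.0


def sgd_step(w, b, inputs, targets, lr=0.1):
    """One gradient-descent step on mean squared error for a single unit
    y = relu(w * x + b). Returns the updated (w, b)."""
    grad_w = grad_b = 0.0
    n = len(inputs)
    for x, t in zip(inputs, targets):
        z = w * x + b                       # pre-activation
        dloss_dy = 2.0 * (relu(z) - t) / n  # d(MSE)/d(output)
        grad_w += dloss_dy * relu_grad(z) * x
        grad_b += dloss_dy * relu_grad(z)
    return w - lr * grad_w, b - lr * grad_b


inputs, targets = [0.5, 1.0, 2.0], [1.0, 2.0, 4.0]
print(sgd_step(1.0, 0.0, inputs, targets))    # live unit: w and b both move
print(sgd_step(1.0, -10.0, inputs, targets))  # (1.0, -10.0): dead unit, unchanged
```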

Features from two different inputs passing through the same RELU layer. Features in green are active in all cases and pass information through linearly. Features in blue are sometimes active and sometimes zeroed, representing the selection of some features for use in the output. Features in red are "dead", always zeroed by the RELU, and show a loss of capacity for the network.
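A dead-RELU count like the one plotted later in this post can be computed from a batch of post-RELU activations by labeling each unit according to the color scheme in the figure above. This is a minimal sketch with hypothetical names; it assumes activations arrive as one list of unit outputs per example. A unit that is silent on one batch may still fire on others, so in practice one would presumably aggregate over many batches before calling a unit dead.

```python
def classify_units(activations):
    """Label each unit from a batch of post-RELU activations.

    activations: one list per example, holding that example's output for
    every unit. Labels follow the figure: "always" (green), "sometimes"
    (blue), "dead" (red)."""
    n_examples = len(activations)
    labels = []
    for unit in range(len(activations[0])):
        active = sum(1 for row in activations if row[unit] > 0.0)
        if active == 0:
            labels.append("dead")
        elif active == n_examples:
            labels.append("always")
        else:
            labels.append("sometimes")
    return labels


def count_dead(activations):
    """Number of units that output zero for every example in the batch."""
    return classify_units(activations).count("dead")


batch = [
    [0.3, 0.0, 0.0, 1.2],  # example 1
    [0.7, 0.4, 0.0, 0.1],  # example 2
]
print(classify_units(batch))  # ['always', 'sometimes', 'dead', 'always']
print(count_dead(batch))      # 1
```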

When a RELU output stays dead over a long period, neither the weights that produce its input nor the weights that process its output receive any gradients, so they fall out of sync with the rest of the network. At times, changes to earlier layers can push an input feature from strongly negative to positive, reactivating a once-dead RELU. These "undead" RELUs now pass information to badly out-of-date output weights, which can overwhelm the rest of the network. The resulting output typically causes a jump in the loss, which in turn produces large gradients that can throw the network further off course.
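Undead RELUs can be detected by comparing the set of dead units in two consecutive measurement windows: any unit that was dead before but fires now has been reactivated. The sketch below assumes activations are given as one list of unit outputs per example; the function names and the default spike threshold of 10 units are hypothetical.

```python
def dead_units(activations):
    """Indices of units whose post-RELU output is zero for every example.
    activations: one list of unit outputs per example."""
    n_units = len(activations[0])
    # RELU outputs are never negative, so <= 0.0 means exactly zero here.
    return {u for u in range(n_units) if all(row[u] <= 0.0 for row in activations)}


def undead_units(prev_activations, curr_activations):
    """Units that were dead in the previous window but fire in the current one."""
    return sorted(dead_units(prev_activations) - dead_units(curr_activations))


def undead_spike(prev_activations, curr_activations, threshold=10):
    """Hypothetical alarm: flag when at least `threshold` units revive at once."""
    return len(undead_units(prev_activations, curr_activations)) >= threshold


prev = [[0.0, 0.0, 0.5], [0.0, 0.0, 0.2]]
curr = [[0.9, 0.0, 0.1], [0.0, 0.0, 0.3]]
print(undead_units(prev, curr))  # [0]
```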

The number of Dead RELUs over time. Highlighted spikes indicate a number of "undead" RELUs which can cause the network training to veer off track.

A large number of dead RELUs can indicate that the network is underperforming, since the lost capacity reduces its expressiveness. Spikes of "undead" RELUs can signal that a network is spiraling out of control and should be shut down. If undead RELUs occur often during training, their impact can be reduced with gradient clipping or other techniques that limit the effect of single extreme outputs.
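Gradient clipping by global norm is one common form of the mitigation mentioned above: if the combined L2 norm of all gradients exceeds a limit, every gradient is scaled down by the same factor, which preserves the direction of the update while capping its size. Frameworks provide this directly (for example, PyTorch's torch.nn.utils.clip_grad_norm_); the plain-Python sketch below, with a hypothetical function name, only illustrates the arithmetic.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradient vectors by one common factor so that their combined
    L2 norm is at most max_norm. Returns (clipped_grads, norm_before)."""
    total = math.sqrt(sum(g * g for vec in grads for g in vec))
    if total <= max_norm:
        return [list(vec) for vec in grads], total
    scale = max_norm / total
    return [[g * scale for g in vec] for vec in grads], total


clipped, norm = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
print(norm)     # 5.0
print(clipped)  # approximately [[0.6], [0.8]]
```

Scaling all gradients by a single factor, rather than clipping each value independently, keeps one extreme gradient from changing the relative size of the others.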

CONCLUSION

We've discussed a few statistics that we measure to evaluate network performance. For Adversary Catching and Mean Runaway, abnormal behavior can indicate that an experiment should be killed within the first ~12 hours of training, avoiding wasted time and GPU resources. By monitoring dead RELUs, we can see at a glance whether an entire batch of experiments is behaving well, and which experiments are unlikely to perform well in QA.