The first thing that springs to mind is that there might be NaNs somewhere in the data; however, that seems to have been ruled out in a comment to the OP. Putting that aside, what you are describing sounds a lot like exploding gradients, which occur when the gradients become excessively large during training, leading to unstable training and loss divergence that can eventually result in NaNs. Another possibility is vanishing gradients, which occur when gradients become exceedingly small during backpropagation, preventing effective weight updates in the earlier layers of the network. In my experience, however, exploding gradients are much more common, so this answer concentrates on them - though many of the ideas below can address both issues (and other problems encountered when training neural networks).
FWIW, my money is on NaNs causing the classification issue. To determine whether you have either problem, monitor the gradient magnitudes during training by tracking their norms (eg., the L2 norm) and visualising trends using tools like TensorBoard or bespoke visualisations. Watch the loss function for erratic spikes and check for sudden divergence in the model outputs, which may signal instability caused by exploding gradients. Additionally, inspect the weight updates to identify drastic changes between iterations, as such behaviour can indicate unstable updates contributing to the problem.
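If you are using PyTorch, a minimal sketch of this kind of monitoring might look like the following (the `model`, `step` and optional TensorBoard `SummaryWriter` are stand-ins for whatever you have in your own training loop); call it between `loss.backward()` and `optimizer.step()`:

```
import torch

def log_gradient_norms(model, step, writer=None):
    """Log per-parameter and global L2 gradient norms after loss.backward()."""
    total_sq = 0.0
    for name, p in model.named_parameters():
        if p.grad is not None:
            norm = p.grad.detach().norm(2).item()
            total_sq += norm ** 2
            if writer is not None:  # eg. a torch.utils.tensorboard.SummaryWriter
                writer.add_scalar(f"grad_norm/{name}", norm, step)
    total = total_sq ** 0.5
    if writer is not None:
        writer.add_scalar("grad_norm/global", total, step)
    return total
```

A global norm that spikes or grows steadily suggests exploding gradients; one that collapses towards zero in the earlier layers suggests vanishing gradients.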
To fix the problem of exploding gradients, there are several things you can try (in no particular order):
1. Use of Custom Layers
- Custom layers might involve operations like division, exponentials, or logarithms that can become unstable when processing edge cases or extreme inputs. Adding safeguards such as clamping or small constants (eg., adding $\epsilon$ to denominators) can help ensure stability.
- Custom layers might inadvertently block or distort gradient flow, especially if non-differentiable or poorly scaled operations are introduced. Visualise gradients during backpropagation (eg., using hooks in PyTorch or GradientTape in TensorFlow) to help ensure they propagate correctly through these layers (see the sketch below).
- Test each custom layer with synthetic data to confirm it behaves as expected across a wide range of inputs. Specifically, verify that outputs remain stable and gradients flow appropriately for inputs with different magnitudes or edge cases.
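As an illustration of these points, here is a toy custom layer (a made-up `SafeRatio`, not anything from your code) that guards a division with an epsilon and clamping, together with a tensor hook to check the gradient flowing back through it:

```
import torch
import torch.nn as nn

class SafeRatio(nn.Module):
    """Toy custom layer: an element-wise ratio with clamping and an epsilon in the denominator."""
    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, num, den):
        den = den.clamp(min=self.eps)      # guard against division by ~0
        out = num / den
        return out.clamp(-1e4, 1e4)        # guard against extreme outputs

# Synthetic-data test: check the gradient flowing back through the layer
layer = SafeRatio()
num = torch.randn(8, 16, requires_grad=True)
den = torch.rand(8, 16, requires_grad=True)    # small positive values
out = layer(num, den)
out.register_hook(lambda g: print("grad norm through SafeRatio:", g.norm().item()))
out.sum().backward()
```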
2. Use of a Custom Loss Function
- A custom loss function that outputs very large or very small values can destabilise training, especially if gradients are amplified or suppressed as a result. Check loss outputs across training batches to ensure they remain within a reasonable range.
- The custom loss function might fail to account for edge cases, such as zero inputs, extreme values, or invalid probabilities, leading to NaN or extreme gradients. Add checks and fallback mechanisms to handle such cases safely.
- Temporarily replace the custom loss function with a standard one (eg., binary cross-entropy) to see if the issue persists. If using the custom loss is important, compare its outputs to the standard loss to verify similar behaviour under normal conditions (see the sketch below).
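As a sketch of the last two points, here is a hypothetical custom binary cross-entropy that clamps probabilities away from 0 and 1 (to avoid `log(0)`), fails fast on non-finite values, and is compared against the built-in loss on the same batch:

```
import torch
import torch.nn.functional as F

def custom_bce(probs, targets, eps=1e-7):
    """Hypothetical custom loss with guards against log(0) and invalid probabilities."""
    probs = probs.clamp(eps, 1.0 - eps)      # keep probabilities strictly inside (0, 1)
    loss = -(targets * probs.log() + (1 - targets) * (1 - probs).log()).mean()
    if not torch.isfinite(loss):             # fail fast rather than propagating NaNs
        raise RuntimeError("Non-finite loss detected")
    return loss

# Compare against the standard loss on the same batch - the values should be very close
probs = torch.rand(32, 1)
targets = torch.randint(0, 2, (32, 1)).float()
print(custom_bce(probs, targets).item(), F.binary_cross_entropy(probs, targets).item())
```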
3. Adjust the Learning Rate
- If the learning rate is set too high, updates may overshoot the optimal values or oscillate, causing exploding gradients and making training unstable.
- A high learning rate can cause gradients to explode as updates become excessively large. Reducing the learning rate slows down updates, helping to stabilise training and control gradient magnitudes.
- A gradual warmup starts with a small learning rate and incrementally increases it over the initial training epochs. This allows the model to adapt to the optimisation process without sudden large updates that could destabilise the gradients.
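A minimal sketch of warmup followed by decay in PyTorch (the model, optimiser, learning rate and epoch counts are placeholders; Keras provides comparable learning-rate schedules):

```
import torch
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

model = torch.nn.Linear(128, 2)   # stand-in for your network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # try eg. 1e-4 if training is unstable

# Linear warmup over the first 5 epochs, then cosine decay for the remaining 95
warmup = LinearLR(optimizer, start_factor=0.1, total_iters=5)
decay = CosineAnnealingLR(optimizer, T_max=95)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[5])

for epoch in range(100):
    # ... forward pass, loss.backward(), optimizer.step() ...
    scheduler.step()
```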
4. Gradient Clipping
- Gradients are rescaled if their norm exceeds a predefined threshold, maintaining their direction but reducing their magnitude.
- Two popular methods are norm-based clipping, where the gradients are rescaled, and value-based clipping, where individual gradient values are truncated when they exceed a specified threshold.
- Deep learning frameworks including Keras, PyTorch and TensorFlow provide built-in clipping functionality.
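In PyTorch, for example, both styles are one-liners applied between `loss.backward()` and `optimizer.step()` (the tiny model and data below are just placeholders); in Keras you can pass `clipnorm` or `clipvalue` to the optimiser instead:

```
import torch
import torch.nn as nn

model = nn.Linear(16, 2)   # stand-in for your network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x, y = torch.randn(32, 16), torch.randint(0, 2, (32,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()

# Norm-based clipping: rescale all gradients if their global L2 norm exceeds 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Value-based alternative: truncate each gradient element to [-0.5, 0.5]
# torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

optimizer.step()
```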
5. Ensure Proper Weight Initialisation
- If the initial weights are too small or too large, they can cause gradients to vanish or explode during backpropagation, which destabilises training and may prevent convergence. Proper weight initialisation ensures that gradients remain within a manageable range throughout training.
- Xavier/Glorot initialisation (Glorot and Bengio, 2010) sets the initial weights by drawing values from a distribution whose variance is inversely proportional to the sum of the number of input and output neurons. This helps to ensure that activations and gradients flow properly through the network for activation functions like tanh.
- He initialisation (He et al., 2015) was developed for [ReLU][4] and its variants ([Leaky ReLU][4], for example).
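A sketch of applying either scheme in PyTorch (the layer sizes are arbitrary):

```
import torch.nn as nn

def init_weights(module):
    """He/Kaiming initialisation for ReLU layers; swap in Xavier for tanh/sigmoid networks."""
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")  # He et al. (2015)
        # nn.init.xavier_uniform_(module.weight)                     # Glorot and Bengio (2010)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 2))
model.apply(init_weights)
```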
6. Employ Normalisation Techniques
- Techniques like Batch Normalisation reduce internal covariate shift (the change in the distribution of layer inputs during training as the parameters of previous layers are updated, which can slow down training and destabilise optimisation) by standardising layer inputs during training. This helps to stabilise activations and gradients, which is important when training deep networks like transformers, helping to prevent exploding or vanishing gradients that could lead to NaN losses.
- Batch Normalisation also allows the use of higher learning rates, which can improve convergence. By normalising inputs across a batch, it makes the network more robust to varying input distributions.
- Unstable gradients might cause the model to predict a single class (mode collapse). Normalisation smooths the optimisation landscape, reducing extreme gradient updates and minimising the risk of the model locking onto one class prematurely.
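For a plain feed-forward network on numeric vectors, this amounts to a `BatchNorm1d` after each hidden linear layer. A rough sketch (the 128-dimensional input and layer widths are placeholders; transformer-style models would typically use `nn.LayerNorm` instead):

```
import torch.nn as nn

# MLP for numeric input vectors, with batch normalisation after each hidden linear layer
model = nn.Sequential(
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),   # standardises activations across the batch
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.BatchNorm1d(32),
    nn.ReLU(),
    nn.Linear(32, 2),
)
```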
7. Use a Subset of the Dataset
- Training on a subset reduces computation time, allowing faster iterations to identify whether the problem stems from the model architecture, training process, or dataset properties. If the issue persists with subsets, it’s likely related to the model or training setup rather than the data.
- While there are apparently no NaNs in the data, other issues, such as extreme outliers, inconsistencies in the input scales (eg., intensity or mass values), or unexpected correlations, could cause instability in training. Testing a small, clean subset can help isolate these data-related problems.
- Running the model on smaller subsets makes it easier to debug whether it is prematurely collapsing to a single class or exhibiting other unexpected behaviours, helping to pinpoint whether the issue lies in the optimisation process or the input-output relationship.
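A quick way to do this in PyTorch is with `torch.utils.data.Subset` (the random tensors below are a stand-in for your real dataset):

```
import torch
from torch.utils.data import TensorDataset, Subset, DataLoader

# Stand-in for your real dataset of numeric feature vectors and binary labels
full_dataset = TensorDataset(torch.randn(10_000, 128), torch.randint(0, 2, (10_000,)))

# Draw a small random subset for quick debugging runs
indices = torch.randperm(len(full_dataset))[:1000]
debug_loader = DataLoader(Subset(full_dataset, indices), batch_size=32, shuffle=True)

# Train for a few epochs on debug_loader: if the loss still blows up or the model
# still predicts a single class, the problem likely lies in the model/training setup
```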
8. Modify Activation Functions
- Activation functions like tanh and sigmoid are prone to vanishing gradients, especially in deep networks, as they squash values into a small range, limiting gradient flow (Hochreiter, 1998). Replacing them with functions like [ReLU][4] or its variants can mitigate this by maintaining stronger gradient signals during backpropagation.
- [ReLU][4] is computationally efficient and helps avoid the gradient problem by keeping gradients non-zero for positive inputs, enabling deeper models to train more effectively. [Leaky ReLU][4] addresses the issue of ReLU dying (outputs being permanently zero for negative inputs) by allowing small negative gradients, while [GELU][6] smooths activation transitions, providing a balance between ReLU's efficiency and gradient flow stability.
- In short: replace activations like tanh or sigmoid with [ReLU][4] or its variants (eg., [Leaky ReLU][4], [GELU][6]).
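Swapping activations is usually a one-line change per layer; eg., a hypothetical MLP with GELU in place of tanh:

```
import torch.nn as nn

# Same architecture as before, with saturating activations replaced by ReLU-family ones
model = nn.Sequential(
    nn.Linear(128, 64),
    nn.GELU(),            # alternatives: nn.ReLU(), nn.LeakyReLU(0.01)
    nn.Linear(64, 32),
    nn.GELU(),
    nn.Linear(32, 2),
)
```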
9. Optimise Network Architecture
- Overly complex networks can amplify gradient-related issues, such as vanishing or exploding gradients, and are more prone to instability during training. Simplifying the architecture reduces the depth or width of the network, making optimisation more stable and computationally efficient.
- If the task does not require a deep or wide network, consider reducing the number of layers or neurons per layer. This decreases the risk of numerical instability while maintaining adequate model capacity for simpler problems.
- Tailoring the architecture to the specific requirements of the task, such as the structure of the input data (eg., numeric vectors in this case), helps eliminate redundant parameters that may contribute to instability.
- An ablation study systematically removes or modifies components of the network to evaluate their impact on performance. This can help identify architectural elements that may be contributing to instability, such as unnecessary layers, overly large parameter counts, or specific design choices that exacerbate gradient-related issues.
10. Regularisation
- Regularisation discourages the model from overfitting or relying too heavily on specific parameters by adding a penalty term that penalises large weights. This can mitigate instability and reduce the likelihood of exploding gradients.
- L1 regularisation adds a penalty proportional to the absolute value of weights, promoting sparsity in the network by driving some weights to zero, which simplifies the model and enhances stability.
- L2 regularisation penalises the square of the weights, constraining their magnitudes and preventing them from growing excessively, which helps stabilise gradient updates during training.
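In PyTorch, L2 regularisation is usually applied via the optimiser's `weight_decay` argument, while an L1 penalty can be added to the loss by hand (the model, data and penalty weights below are placeholders):

```
import torch
import torch.nn as nn

model = nn.Linear(128, 2)   # stand-in for your network

# L2 regularisation: weight_decay adds a penalty on the squared weights
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# L1 regularisation: add the (scaled) sum of absolute weights to the loss
def l1_penalty(model, l1_lambda=1e-5):
    return l1_lambda * sum(p.abs().sum() for p in model.parameters())

x, y = torch.randn(32, 128), torch.randint(0, 2, (32,))
loss = nn.functional.cross_entropy(model(x), y) + l1_penalty(model)
loss.backward()
optimizer.step()
```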
Here are a few relevant links from Cross Validated which also might help:
Deep Learning sentiment analysis model always predicts same class
One class never gets predicted, regardless of the model
How do you know that your classifier is suffering from class imbalance?
References
Glorot, X., & Bengio, Y. (2010). Understanding the difficulty of training deep feedforward neural networks. Proceedings of the 13th International Conference on Artificial Intelligence and Statistics (AISTATS), 249–256.
https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
He, K., Zhang, X., Ren, S., & Sun, J. (2015). Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification. Proceedings of the IEEE International Conference on Computer Vision (ICCV), 1026–1034.
https://openaccess.thecvf.com/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf
Hochreiter, S. (1998). The vanishing gradient problem during learning recurrent neural nets and problem solutions. International Journal of Uncertainty, Fuzziness and Knowledge-Based Systems, 6(2), 107–116.
https://doi.org/10.1142/S0218488598000094
https://www.bioinf.jku.at/publications/older/2304.pdf