Mathematical operations involving floating-point arithmetic can suffer from precision loss and catastrophic cancellation, particularly when dealing
with values close to critical points. Two common patterns that exhibit these issues are:
- torch.log(1 + x): when x is close to zero, adding it to 1 can result in precision loss
- torch.exp(x) - 1: when x is close to zero, the subtraction can cause catastrophic cancellation
In floating-point arithmetic, these operations can lead to significant numerical errors:
- For torch.log(1 + x), when x is very small (like 1e-15), the computation 1 + x might not accurately represent the true mathematical result due to the limited number of significant digits that can be represented.
- For torch.exp(x) - 1, when x is close to zero, torch.exp(x) returns a value very close to 1.0, and subtracting 1 from this value can result in significant loss of precision.
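Both effects can be reproduced in a few lines. The following is a minimal sketch; the value 1e-10 is illustrative and chosen only because it sits well below float32 precision (machine epsilon is about 1.2e-7):
import torch
# In float32, 1 + 1e-10 rounds to exactly 1.0, so the contribution of x is lost
x = torch.tensor(1e-10, dtype=torch.float32)
print(1 + x)             # tensor(1.)
print(torch.log(1 + x))  # tensor(0.), but the true value is about 1e-10
# exp(x) is about 1.0000000001; subtracting 1 cancels nearly every significant digit
print(torch.exp(x) - 1)  # tensor(0.), but the true value is about 1e-10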
PyTorch provides specialized functions that are designed to compute these expressions accurately:
- torch.log1p(x) computes log(1 + x) accurately for small values of x
- torch.expm1(x) computes exp(x) - 1 accurately for small values of x
These functions use specialized algorithms that avoid the problematic intermediate computations, maintaining precision even when dealing with
values close to critical points.
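Continuing the illustrative example above, the stable functions recover a result close to the true value of about 1e-10, where the naive forms return exactly zero:
import torch
x = torch.tensor(1e-10, dtype=torch.float32)
print(torch.log(1 + x))  # tensor(0.) -- value lost
print(torch.exp(x) - 1)  # tensor(0.) -- value lost
print(torch.log1p(x))    # approximately 1e-10
print(torch.expm1(x))    # approximately 1e-10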
This numerical instability can propagate through calculations and significantly affect the accuracy of machine learning models, particularly in
scenarios involving:
- Small gradients during training
- Probability calculations with values close to critical points (illustrated after this list)
- Mathematical operations in activation functions
- Loss function computations
- Iterative optimization algorithms
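As one illustration of the probability case, the sketch below computes the log of a survival probability, log(1 - p), for a tiny p; the variable names and the value of p are hypothetical and chosen only for illustration:
import torch
p = torch.tensor(1e-12, dtype=torch.float32)  # tiny event probability (illustrative value)
# Unstable: 1 - p rounds to 1.0 in float32, so the log evaluates to exactly 0
log_survival_unstable = torch.log(1 - p)
# Stable: log1p(-p) avoids forming 1 - p directly and preserves the magnitude (about -1e-12)
log_survival_stable = torch.log1p(-p)
print(log_survival_unstable)  # tensor(0.)
print(log_survival_stable)    # approximately -1e-12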
What is the potential impact?
Using numerically unstable mathematical operations instead of their stable counterparts can lead to:
- Reduced accuracy in mathematical computations
- Potential instability in machine learning model training
- Incorrect results in probability calculations
- Accumulation of numerical errors in iterative algorithms
- Inaccurate gradient calculations during training
- Convergence issues in optimization algorithms
- Reduced model performance due to accumulated numerical errors
- Inconsistent results across different hardware or precision settings
How to fix?
Replace the numerically unstable patterns with their stable PyTorch equivalents:
- Replace torch.log(1 + x) with torch.log1p(x)
- Replace torch.exp(x) - 1 with torch.expm1(x)
These specialized functions compute the same mathematical results but with better numerical stability for values near critical points.
Non-compliant code examples
import torch
x = torch.tensor(1e-10)  # illustrative small value near zero
# Numerically unstable logarithm computation: 1 + x can round to exactly 1.0
result1 = torch.log(1 + x) # Noncompliant
# Numerically unstable exponential computation: subtracting 1 cancels significant digits
result2 = torch.exp(x) - 1 # Noncompliant
Compliant code examples
import torch
x = torch.tensor(1e-10)  # illustrative small value near zero
# Numerically stable logarithm computation
result1 = torch.log1p(x)
# Numerically stable exponential computation
result2 = torch.expm1(x)
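As a quick sanity check (not part of the rule examples above), the stable and naive forms agree for inputs away from zero, so the replacement is a drop-in change across the whole input range:
import torch
x = torch.tensor(2.0)
print(torch.log(1 + x), torch.log1p(x))  # both about 1.0986
print(torch.exp(x) - 1, torch.expm1(x))  # both about 6.3891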