PyTorch's torch.flatten() function and Tensor.flatten() method flatten all dimensions by default, starting from dimension 0. This includes the batch dimension, which is usually not intended.
When working with batched data in neural networks, the first dimension typically represents the batch size. Flattening this dimension results in a
1D tensor that loses the batch structure, making it impossible to process multiple samples independently.
This behavior differs from frameworks like Keras, where flatten operations automatically preserve the batch dimension. Developers transitioning
from Keras or those new to PyTorch often encounter this issue, leading to unexpected tensor shapes and model failures.
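For comparison, here is a minimal Keras sketch (assuming TensorFlow is installed; shown only to illustrate the differing defaults):
import tensorflow as tf

x = tf.random.normal((32, 100, 100))
flattened = tf.keras.layers.Flatten()(x)
print(flattened.shape)  # (32, 10000) - Keras keeps the batch dimension by default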
For example, if you have a tensor with shape (32, 100, 100) representing 32 images of size 100x100:
- torch.flatten(x) produces shape (320000,) - all dimensions flattened
- torch.flatten(x, start_dim=1) produces shape (32, 10000) - batch dimension preserved
The second approach maintains the batch structure, allowing proper batch processing in neural networks.
What is the potential impact?
Flattening the batch dimension can cause the following problems, illustrated in the sketch after this list:
- Model training failures due to incompatible tensor shapes
- Silent bugs where the model processes data incorrectly
- Performance issues when batch processing is broken
- Difficult-to-debug shape mismatches in neural network layers
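Such failures typically surface as a shape mismatch once the flattened tensor reaches a fully connected layer. The following sketch (the layer sizes are illustrative assumptions, not part of the rule) shows the difference:
import torch
import torch.nn as nn

x = torch.randn(32, 100, 100)                 # batch of 32 samples
linear = nn.Linear(100 * 100, 10)             # expects input of shape (batch, 10000)

flat_all = torch.flatten(x)                   # shape (320000,) - batch dimension lost
# linear(flat_all) would raise a RuntimeError because of the shape mismatch

flat_batched = torch.flatten(x, start_dim=1)  # shape (32, 10000)
out = linear(flat_batched)                    # works: output shape (32, 10)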
How to fix in PyTorch?
Specify the start_dim parameter to preserve the batch dimension when flattening tensors.
If the first dimension should also be part of the flattening, it is good practice to be explicit and specify start_dim=0.
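For instance, when a fully flattened 1D view is genuinely intended, spelling out start_dim=0 makes that intent explicit (illustrative sketch):
import torch

x = torch.randn(32, 100, 100)
flat_all = torch.flatten(x, start_dim=0)  # Shape: (320000,) - intentional full flatten, stated explicitly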
Non-compliant code example
import torch
x = torch.randn(32, 100, 100) # Shape: (batch_size, height, width)
flattened = torch.flatten(x) # Noncompliant: flattens all dimensions
# Result shape: (320000,) - batch dimension lost
Compliant code example
import torch
x = torch.randn(32, 100, 100) # Shape: (batch_size, height, width)
flattened = torch.flatten(x, start_dim=1) # Preserves batch dimension
# Result shape: (32, 10000) - batch dimension preserved
Alternatively, use the nn.Flatten() layer, which preserves the batch dimension by default (its start_dim defaults to 1).
Non-compliant code example
import torch
x = torch.randn(32, 100, 100)
flattened = x.flatten() # Noncompliant: flattens all dimensions
Compliant code example
import torch
import torch.nn as nn
x = torch.randn(32, 100, 100)
flatten_layer = nn.Flatten() # Defaults to start_dim=1
flattened = flatten_layer(x) # Preserves batch dimension
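In practice, nn.Flatten() usually sits between the feature-extraction layers and the fully connected head of a model. A minimal sketch, assuming the same (32, 100, 100) input and an arbitrary 10-class head:
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Flatten(),              # (batch, 100, 100) -> (batch, 10000); batch dimension kept
    nn.Linear(100 * 100, 10),  # per-sample classification head
)

x = torch.randn(32, 100, 100)
out = model(x)                 # Shape: (32, 10)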
Documentation
- PyTorch documentation - torch.flatten: https://pytorch.org/docs/stable/generated/torch.flatten.html
- PyTorch documentation - torch.nn.Flatten: https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html