Incremental tensor concatenation with torch.cat() inside a loop creates significant performance and correctness problems. Each call to torch.cat() allocates a new tensor and copies all previously accumulated data into it. Repeated in a loop, this yields quadratic O(n²) time complexity instead of linear O(n). For example, concatenating 1,000 tensors one at a time copies roughly 500,000 tensors' worth of data (1 + 2 + … + 999 copies of earlier results) instead of about 1,000.
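The quadratic growth is easy to observe empirically. The following sketch times both approaches (the helper names and sizes are illustrative, not part of any API); doubling n should roughly quadruple the incremental version's runtime, while the collect-then-concatenate version scales linearly:

import time
import torch

def incremental_cat(n):
    # Antipattern: every torch.cat() re-copies all rows accumulated so far
    out = torch.zeros(0, 20)
    for _ in range(n):
        out = torch.cat((out, torch.rand(1, 20)))
    return out

def list_then_cat(n):
    # Collect first, then concatenate once at the end
    parts = [torch.rand(1, 20) for _ in range(n)]
    return torch.cat(parts)

for n in (1_000, 2_000, 4_000):
    start = time.perf_counter()
    incremental_cat(n)
    t_inc = time.perf_counter() - start
    start = time.perf_counter()
    list_then_cat(n)
    t_once = time.perf_counter() - start
    print(f"n={n}: incremental {t_inc:.3f}s, single cat {t_once:.3f}s")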
Starting from an empty tensor created with torch.Tensor() introduces additional issues. Such a tensor is one-dimensional with zero elements and uses the default dtype (float32) on the CPU, which may not match the dimensionality, data type, or device of the tensors being concatenated. This can lead to runtime errors or unexpected shapes and dtypes in the result.
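A minimal sketch of what torch.Tensor() actually produces and one of the silent surprises it can cause (the values in the comments assume PyTorch's default settings):

import torch

seed = torch.Tensor()  # zero-element, 1-D tensor with default dtype and device
print(seed.shape, seed.dtype, seed.device)  # torch.Size([0]) torch.float32 cpu

# torch.cat applies type promotion, so mixing the float32 seed with float64
# data silently yields a float64 result; a tensor on a different device
# (e.g. CUDA) would raise a RuntimeError instead
data = torch.rand(3, dtype=torch.float64)
print(torch.cat((seed, data)).dtype)  # torch.float64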
The repeated allocations and deallocations of intermediate tensors also create unnecessary allocator and garbage-collection pressure, further degrading performance in memory-constrained environments.
What is the potential impact?
Performance degradation becomes severe with larger datasets, potentially making training or inference impractically slow. Memory usage increases significantly due to the intermediate tensor copies. Runtime errors may occur when concatenating with empty tensors due to dimension or type mismatches.
How to fix?
Collect tensors in a Python list during the loop, then make a single call at the end: torch.stack() to create a new dimension, or torch.cat() to extend an existing one. This reduces the overall complexity from O(n²) to O(n).
Non-compliant code example
import torch

# Inefficient: every iteration copies all previously accumulated data
batched_tensors = torch.Tensor()  # Noncompliant: empty 1-D float32 CPU tensor
for i in range(100):
    tensor = torch.rand(10, 20)
    batched_tensors = torch.cat((batched_tensors, tensor))  # Noncompliant
Compliant code example
import torch

# Efficient: collect tensors in a list, then combine them once
tensor_list = []
for i in range(100):
    tensor = torch.rand(10, 20)
    tensor_list.append(tensor)
batched_tensors = torch.stack(tensor_list)  # Shape: [100, 10, 20]
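If the goal is to grow an existing dimension instead of adding a new one, replace the final torch.stack() call with a single torch.cat():

batched_tensors = torch.cat(tensor_list, dim=0)  # Shape: [1000, 20]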
Documentation
PyTorch documentation - torch.cat: https://pytorch.org/docs/stable/generated/torch.cat.html
PyTorch documentation - torch.stack: https://pytorch.org/docs/stable/generated/torch.stack.html