PyTorch provides two ways to apply operations: module classes and functional operations. Module classes like nn.Softmax are designed to be instantiated once (typically in __init__) and reused, while functional operations like F.softmax are meant to be called directly on tensors.
Instantiating modules inline within forward methods creates several problems:
- Performance overhead: Creating new module instances on every forward pass is inefficient and wastes computational resources
- Memory waste: Each instantiation allocates unnecessary memory that could be avoided
- Potential bugs: Simply calling nn.Softmax(x) creates a module instance (the tensor is silently consumed as the dim argument) without applying the operation, so downstream code receives a module object instead of a tensor; see the sketch after this list
- Code clarity: The intent becomes unclear: it's not obvious whether you're creating a module or applying an operation
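For example, a minimal sketch (tensor shape chosen arbitrarily) of how calling the module class on a tensor goes wrong:

import torch
import torch.nn as nn

x = torch.randn(2, 3)

# nn.Softmax(x) only constructs a Softmax module; the tensor is consumed
# as the dim argument and no softmax is computed.
not_applied = nn.Softmax(x)
print(type(not_applied))   # <class 'torch.nn.modules.activation.Softmax'>

# The module has to be called on the tensor to actually apply the operation.
applied = nn.Softmax(dim=1)(x)
print(type(applied))       # <class 'torch.Tensor'>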
The forward method should focus on the computational flow, not object creation. Operations that don't maintain state should use functional equivalents, while stateful operations should use pre-instantiated modules.
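As a combined sketch (the Classifier model below is illustrative, not part of the rule's examples), both guidelines can coexist in one module: configured, stateful pieces are created in __init__, while the stateless softmax uses the functional API in forward:

import torch.nn as nn
import torch.nn.functional as F

class Classifier(nn.Module):
    def __init__(self, in_features=128, num_classes=10):
        super().__init__()
        # Configured modules are instantiated once, here.
        self.linear = nn.Linear(in_features, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        # forward only describes the computation; no modules are created here.
        x = self.dropout(self.linear(x))
        return F.softmax(x, dim=1)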
What is the potential impact?
This anti-pattern can lead to performance degradation due to repeated object instantiation, increased memory usage, and potential runtime errors
when the module is not properly applied to the input tensor.
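One rough way to observe the instantiation overhead (a sketch only; the input size, iteration count, and any resulting timings depend entirely on your hardware):

import timeit

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(64, 1000)

# Constructing a fresh module on every call pays the object-creation cost each time.
inline = timeit.timeit(lambda: nn.Softmax(dim=1)(x), number=10_000)

# The functional call skips module construction entirely.
functional = timeit.timeit(lambda: F.softmax(x, dim=1), number=10_000)

print(f"inline module: {inline:.3f}s  functional: {functional:.3f}s")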
How to fix in PyTorch?
Replace inline module instantiation with the corresponding functional operation from torch.nn.functional. This is the preferred approach for
stateless operations.
Non-compliant code example
import torch.nn as nn

class MyModel(nn.Module):
    def forward(self, x):
        output = nn.Softmax(dim=1)(x)  # Noncompliant
        return output
Compliant code example
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def forward(self, x):
        output = F.softmax(x, dim=1)
        return output
If you need to reuse the same module configuration, instantiate it once in __init__ and use it in forward. This approach is better for modules with learnable parameters or specific configurations.
Non-compliant code example
import torch.nn as nn

class MyModel(nn.Module):
    def forward(self, x):
        dropout = nn.Dropout(p=0.5)(x)  # Noncompliant
        return dropout
Compliant code example
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        dropout = self.dropout(x)
        return dropout
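Pre-instantiation also matters for modules whose behaviour depends on training mode. A brief sketch (class names are illustrative) of a side effect the non-compliant version can cause: a Dropout created inline in forward is never registered as a submodule, so model.eval() cannot switch it off:

import torch
import torch.nn as nn

class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.dropout(x)

class Inline(nn.Module):
    def forward(self, x):
        # Created fresh on every call, this module is not a registered child,
        # so eval() never reaches it and it stays in training mode.
        return nn.Dropout(p=0.5)(x)

x = torch.ones(1, 8)

print(torch.equal(Registered().eval()(x), x))  # True: dropout is disabled in eval mode
print(torch.equal(Inline().eval()(x), x))      # False: dropout still drops and rescales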
Documentation