TorchScript has limited support for Python’s super() mechanism, which can lead to compilation errors when converting PyTorch models for deployment.
TorchScript is PyTorch’s way to create serializable and optimizable models from PyTorch code. It lets you run models independently of Python, which is essential for production deployment, mobile applications, and performance optimization.
However, TorchScript supports only a subset of Python’s features. The super() function relies on Python’s method resolution order (MRO) and dynamic attribute lookup, neither of which is fully supported in TorchScript’s static compilation environment.
When TorchScript encounters a super() call, it may fail to resolve the target method during compilation, resulting in errors or unexpected behavior. This is particularly problematic in forward() methods of neural network modules, where inheritance is commonly used.
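The dynamic resolution that super() depends on can be illustrated in plain Python, without PyTorch. The sketch below uses hypothetical classes (Base, Left, Right, Diamond) that are not part of any library. It shows that the same super() call inside Left.describe resolves to a different class depending on the runtime type of the instance, which is the kind of lookup a static compiler cannot easily pin down ahead of time.

```python
class Base:
    def describe(self):
        return ["Base"]


class Left(Base):
    def describe(self):
        # The target of super() here is decided at runtime from type(self).__mro__.
        return ["Left"] + super().describe()


class Right(Base):
    def describe(self):
        return ["Right"] + super().describe()


class Diamond(Left, Right):
    def describe(self):
        return ["Diamond"] + super().describe()


# For a Left instance, super() inside Left.describe resolves to Base.
left_only = Left().describe()

# For a Diamond instance, the same super() call resolves to Right instead.
diamond = Diamond().describe()

# The MRO that drives the lookup for Diamond instances.
mro_names = [cls.__name__ for cls in Diamond.__mro__]

print(left_only)
print(diamond)
print(mro_names)
```

Because the call target changes with the instance's class hierarchy, calling a uniquely named method on self gives a statically compiled method an unambiguous target.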
What is the potential impact?
Using super() calls in TorchScript methods can cause:
- Compilation failures when converting models to TorchScript format
- Runtime errors in deployed models
- Inconsistent behavior between eager mode and TorchScript execution
- Deployment issues in production environments and mobile applications
How to fix in PyTorch?
Replace super() calls with direct method calls or refactor the inheritance structure to avoid super() usage in TorchScript methods.
Non-compliant code example
import torch
import torch.nn as nn

class MyModel(nn.Module):
    @torch.jit.script_method
    def forward(self, x):
        return super().forward(x)  # Noncompliant
Compliant code example
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def forward(self, x):
        # Avoid super() in TorchScript methods
        return self.process(x)

    def process(self, x):
        return x
For complex inheritance scenarios, explicitly call parent class methods by name instead of using super().
Non-compliant code example
import torch
import torch.nn as nn

class BaseModel(nn.Module):
    def forward(self, x):
        return x * 2

class DerivedModel(BaseModel):
    @torch.jit.script_method
    def forward(self, x):
        result = super().forward(x)  # Noncompliant
        return result + 1
Compliant code example
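import torch
import torch.nn as nn

class BaseModel(nn.Module):
    def base_forward(self, x):
        # Parent logic lives in a uniquely named method so subclasses can call it directly
        return x * 2

    def forward(self, x):
        return self.base_forward(x)

class DerivedModel(BaseModel):
    # No method-level decorator: @torch.jit.script_method belongs to the legacy
    # torch.jit.ScriptModule API. Compile an instance with torch.jit.script(DerivedModel()).
    def forward(self, x):
        result = self.base_forward(x)  # Explicit call by name instead of super()
        return result + 1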
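One way to confirm that a refactored model behaves the same in eager mode and under TorchScript is to compile it with torch.jit.script and compare outputs. The following is a minimal sketch, assuming PyTorch is installed and the code is run from a file (TorchScript needs access to the source). The classes mirror the shape of the compliant example, without method-level decorators, and the type annotations and sample input are illustrative additions.

```python
import torch
import torch.nn as nn


class BaseModel(nn.Module):
    def base_forward(self, x: torch.Tensor) -> torch.Tensor:
        # Parent logic exposed under a unique name
        return x * 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base_forward(x)


class DerivedModel(BaseModel):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Direct call by name; no super() involved
        return self.base_forward(x) + 1


model = DerivedModel()

# Compile the module instance; forward and the methods it calls are scripted.
scripted = torch.jit.script(model)

x = torch.tensor([1.0, 2.0, 3.0])
eager_out = model(x)
scripted_out = scripted(x)

print(eager_out)
print(scripted_out)
```

A check like this can be kept as a small regression test so that a reintroduced super() call surfaces at compile time and not after deployment.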
Documentation