This rule raises an issue when torch.load is used to load a model.
Why is this an issue?
In PyTorch, it is common to load serialized models using the torch.load function. Under the hood, torch.load uses the pickle library to load the model and the weights. If the model comes from an untrusted source, an attacker could inject a malicious payload which would be executed during deserialization.
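The mechanism can be illustrated with the standard pickle module alone. The following is a minimal sketch, not taken from the rule itself: the Payload class is hypothetical, and eval on a harmless arithmetic expression stands in for whatever an attacker would actually run (for example, a shell command).

```python
import pickle


class Payload:
    """Hypothetical malicious object, for illustration only."""

    def __reduce__(self):
        # __reduce__ tells pickle how to rebuild the object: it returns a
        # callable and its arguments. pickle invokes that callable while
        # loading, so the author of the file chooses what code runs.
        # A harmless expression is used here in place of a real attack.
        return (eval, ("6 * 7",))


# The attacker serializes the payload, e.g. into a file named like a model.
malicious_bytes = pickle.dumps(Payload())

# The victim only deserializes, yet eval("6 * 7") runs during loading.
result = pickle.loads(malicious_bytes)
print(result)
```

The value returned by pickle.loads is whatever the chosen callable returned, not a Payload instance, which shows that the code ran as part of deserialization.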
How to fix it
Use a safer alternative to load the model, such as safetensors.torch.load_model. Alternatively, PyTorch can be instructed to load only the weights by setting the parameter weights_only=True. This restricts deserialization to tensors, primitive types, and plain containers such as dictionaries, so arbitrary objects cannot be reconstructed while loading.
Note that using weights_only requires saving only the state_dict of a model instead of the whole model.
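The weights_only approach described above could look like the following minimal sketch. The small torch.nn.Linear model and the temporary file path are assumptions chosen for illustration. They are not part of the rule.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-in for a real model.
model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pth")

    # Save only the state_dict (parameter names mapped to tensors),
    # not the whole model object.
    torch.save(model.state_dict(), path)

    # weights_only=True: load only the weights and reject arbitrary
    # pickled objects.
    state_dict = torch.load(path, weights_only=True)

# The architecture is rebuilt in code, then the weights are loaded into it.
restored = nn.Linear(4, 2)
restored.load_state_dict(state_dict)
```

Because only the state_dict is stored, the loading side must construct the model class itself before calling load_state_dict.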
Code examples
Noncompliant code example
import torch
model = torch.load('model.pth') # Noncompliant: torch.load is used to load the model
Compliant solution
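import torch
import safetensors.torch

model = MyModel()  # MyModel: the application's own torch.nn.Module subclass
safetensors.torch.load_model(model, 'model.safetensors')  # Compliant: the safetensors format does not use pickle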