Labels: module: inductor, oncall: pt2, triaged
Description
🐛 Describe the bug
At pytorch/torch/_inductor/compile_fx.py, line 1371 (commit 3061025):

    if config.freezing and not torch.is_grad_enabled():

torch.is_grad_enabled() is True here even though the compiled code runs under with torch.no_grad().
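For context, a minimal sketch of the grad-mode behavior involved: torch.is_grad_enabled() reports the grad mode active at the point where it is evaluated, so a no_grad block inside the compiled function does not change what a check performed elsewhere observes.

    import torch

    def check():
        # Reports the grad mode active at the call site,
        # not where check() is defined.
        return torch.is_grad_enabled()

    print(check())         # True under default settings
    with torch.no_grad():
        print(check())     # False only while the no_grad context is active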
Reproduction script:
import torch
import torch.nn.functional as F

class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5)

    def forward(self, x):
        return F.relu(self.conv1(x))

model = TestModule()
example_inputs = torch.randn([1, 1, 10, 10]).cpu()

@torch.compile
def run(model, example_inputs):
    with torch.no_grad():
        # Check inductor grad status at
        # https://github.com/pytorch/pytorch/blob/30610251ec7b8f7e0507df06c3aadbcf90658e0e/torch/_inductor/compile_fx.py#L1370
        # If we print torch._inductor.config.cpp.weight_prepack, TorchInductor
        # believes torch.is_grad_enabled() is False; otherwise, TorchInductor
        # believes torch.is_grad_enabled() is True.
        # print("torch._inductor.config.cpp.weight_prepack", torch._inductor.config.cpp.weight_prepack)
        model(example_inputs)

run(model, example_inputs)
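A possible workaround, assuming the freezing check runs in the caller's grad mode at compile time: establish no_grad at the call site rather than inside the compiled function, so the check observes torch.is_grad_enabled() == False. A minimal sketch of this variant:

    import torch
    import torch.nn.functional as F

    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(1, 20, 5)

        def forward(self, x):
            return F.relu(self.conv1(x))

    model = TestModule()
    example_inputs = torch.randn([1, 1, 10, 10]).cpu()

    @torch.compile
    def run(model, example_inputs):
        return model(example_inputs)

    # Assumption: with grad disabled at the call site, the check in
    # compile_fx.py sees torch.is_grad_enabled() == False when compilation
    # is triggered on the first call, so freezing can apply.
    with torch.no_grad():
        run(model, example_inputs)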
Versions
The latest nightly release (20240503).
cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire