
[inductor] TorchInductor does not correctly recognize the grad status of model code #125474

@xuzhao9

Description


🐛 Describe the bug

At the check `if config.freezing and not torch.is_grad_enabled():` in `torch/_inductor/compile_fx.py`, TorchInductor finds that `torch.is_grad_enabled()` is True even when the compiled code runs under `with torch.no_grad():`.

Reproduction script:

import torch
import torch.nn.functional as F

class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5)

    def forward(self, x):
        return F.relu(self.conv1(x))

model = TestModule()
example_inputs = torch.randn([1, 1, 10, 10]).cpu()

@torch.compile
def run(model, example_inputs):
    with torch.no_grad():
        # Check inductor grad status at
        # https://github.com/pytorch/pytorch/blob/30610251ec7b8f7e0507df06c3aadbcf90658e0e/torch/_inductor/compile_fx.py#L1370
        # If we print `torch._inductor.config.cpp.weight_prepack` , torchinductor believes torch.is_grad_enabled() is False
        # Otherwise, torchinductor believes torch.is_grad_enabled() is True 
        # print("torch._inductor.config.cpp.weight_prepack", torch._inductor.config.cpp.weight_prepack)
        model(example_inputs)

run(model, example_inputs)
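For reference, `torch.is_grad_enabled()` is expected to track the thread-local grad mode at the point where it is called, which is what the check above should observe inside the `no_grad` block. A minimal eager-mode sketch of the expected semantics:

```python
import torch

# torch.is_grad_enabled() reflects the *current* thread-local grad mode.
# A check captured at compile time, outside the no_grad block, can therefore
# disagree with the state that holds at runtime inside the compiled region.
assert torch.is_grad_enabled()          # default: grad enabled
with torch.no_grad():
    assert not torch.is_grad_enabled()  # disabled inside the context
assert torch.is_grad_enabled()          # restored on exit
```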

Versions

The latest nightly release (20240503).

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire

Metadata

Labels: module: inductor, oncall: pt2, triaged
