
Commit ad29b23

[LLGA] Replace aten::type_as with aten::to to support a corner-case (#3030)
* Replace aten::type_as with aten::to to support a corner case: an intermediate output of one of the partitions (the output of an add op) is used as an input to type_as later in the graph. It is easier to work around this issue by replacing all aten::type_as nodes with aten::to nodes. Another potential fix would be to add an LLGA End node after such intermediate outputs that are reused later in the graph, but that solution would require more extensive changes, since no information about the future use of intermediate outputs is available to IPEX/PyTorch while creating partitions at the subgraph level.

* Update graph_helper.cpp

* Update test_jit_llga_fuser.py
1 parent 98caa70 commit ad29b23
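
The rewrite relies on a simple equivalence: for tensors on the same device, x.type_as(y) and x.to(y.dtype) produce the same result. A minimal eager-mode sketch (illustrative only, not part of the commit):

import torch

x = torch.randn(4, 4)                          # float32
y = torch.randn(4, 4, dtype=torch.bfloat16)

out_type_as = x.type_as(y)                     # cast x to y's dtype via aten::type_as
out_to = x.to(y.dtype)                         # the same cast expressed as aten::to

assert out_type_as.dtype == out_to.dtype == torch.bfloat16
assert torch.equal(out_type_as, out_to)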

File tree

3 files changed: +42 -3 lines changed

csrc/cpu/jit/codegen/onednn/graph_helper.cpp

Lines changed: 14 additions & 0 deletions
@@ -739,6 +739,20 @@ bool LlgaGraphHelper::isSingleQuantDequantTo(Node* n) {
       n->kind() != Symbol::aten("quantize_per_channel") &&
       n->kind() != Symbol::aten("dequantize") && n->kind() != aten::to)
     return false;
+  // Check if aten::to is used for non-quantized case
+  if (n->kind() == aten::to) {
+    auto input_dtype = n->input(0)->type()->expect<TensorType>()->scalarType();
+    auto output_dtype =
+        n->outputs()[0]->type()->expect<TensorType>()->scalarType();
+    if (input_dtype.has_value() && output_dtype.has_value()) {
+      if ((input_dtype.value() == at::ScalarType::Float ||
+           input_dtype.value() == at::ScalarType::BFloat16) &&
+          (output_dtype.value() == at::ScalarType::Float ||
+           output_dtype.value() == at::ScalarType::BFloat16)) {
+        return false;
+      }
+    }
+  }
   if (!opToOwningPartition_.has(n))
     return false;
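
The added check means a lone fp32 <-> bf16 aten::to node is no longer treated as a quantize/dequantize-style cast and filtered out, which is why the updated test below expects such a TypeCast to map to a oneDNN Graph partition. A hypothetical module that produces exactly this kind of node (an illustration, not code from the commit):

import torch

class SingleCast(torch.nn.Module):
    def forward(self, x):
        # a single dtype cast, recorded as an aten::to node in the JIT graph
        return x.to(torch.bfloat16)

m = torch.jit.trace(SingleCast().eval(), torch.randn(2, 3))
print(m.graph)  # contains one aten::to node casting float32 -> bfloat16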

csrc/cpu/jit/codegen/onednn/prepare_binary.cpp

Lines changed: 24 additions & 0 deletions
@@ -90,6 +90,28 @@ void handleBinaryOpInputs(Node* node, int first_input, int second_input) {
   }
 }
 
+static void ReplaceTypeAsWithTo(Block* block) {
+  for (auto node : block->nodes()) {
+    for (auto sub : node->blocks()) {
+      ReplaceTypeAsWithTo(sub);
+    }
+
+    if (node->kind() == aten::type_as) {
+      auto nodeOutputTypePtr = node->output()->type()->expect<TensorType>();
+      c10::optional<at::ScalarType> outputDtype =
+          nodeOutputTypePtr->scalarType();
+      if (outputDtype.has_value()) {
+        auto g = node->prev()->owningGraph();
+        auto replacementNodeOutput =
+            g->insert(aten::to, {node->input(0), outputDtype.value()});
+        replacementNodeOutput->setType(
+            nodeOutputTypePtr->withScalarType(outputDtype.value()));
+        node->outputs()[0]->replaceAllUsesWith(replacementNodeOutput);
+      }
+    }
+  }
+}
+
 static void ConvertScalarToTensor(Block* block) {
   for (auto node : block->nodes()) {
     for (auto sub : node->blocks()) {

@@ -246,6 +268,8 @@ void PrepareBinaryForLLGA(const std::shared_ptr<Graph>& graph) {
   EliminateDeadCode(graph);
   // ConvertScalarToTensor must be placed after EliminateIdentityMulAddDiv
   replaceWithSelectOp(graph->block());
+  ReplaceTypeAsWithTo(graph->block());
+  EliminateDeadCode(graph);
   ConvertScalarToTensor(graph->block());
   // TODO: after conv-bn folding, bias will become bias? (Optional) after this
   // pass and will lose it when using mustNotBeNone to check Optional Bias
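
ReplaceTypeAsWithTo and the extra EliminateDeadCode run before ConvertScalarToTensor, so by the time partitions are formed no aten::type_as nodes remain. A hypothetical sketch of the corner case described in the commit message (shapes and ops are illustrative, not taken from the commit), where an intermediate add output is reused later in the graph as the type reference for type_as:

import torch

class CornerCase(torch.nn.Module):
    def forward(self, x, y):
        t = x + y              # intermediate output of one partition
        z = torch.relu(x)
        return z.type_as(t)    # t is reused later in the graph as the type reference

m = torch.jit.script(CornerCase())
print(m.graph)  # the scripted graph contains an aten::type_as node before the pass runs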

tests/cpu/test_jit_llga_fuser.py

Lines changed: 4 additions & 3 deletions
@@ -222,7 +222,8 @@ def forward(self, x):
         m = M()
         x = torch.rand(8, 12, 12, 12)
         graph, _ = self.checkTrace(m, [x])
-        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+        # One partition for softmax & another for TypeCast
+        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
 
     def _gen_binary_inputs(self, gen_permute=True):
         for xshape, yshape in [

@@ -507,8 +508,8 @@ def forward(self, x):
         m = M(dst_dtype)
 
         graph, _ = self.checkTrace(m, [x])
-        # we do not rewrite single to
-        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
+        # Even a single TypeCast is mapped to oneDNN Graph
+        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
 
     @llga_fp32_bf16_test_env
     def test_typecheck(self):
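
For reference, a check like assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, n) amounts to counting the nodes of a given kind in the JIT graph; a rough sketch of that idea (an assumption about the helper's intent, not its actual implementation):

def count_nodes(graph, kind):
    # count nodes in a torch._C.Graph whose kind matches the given string
    return sum(1 for node in graph.nodes() if node.kind() == kind)

# e.g. count_nodes(traced_module.graph, "aten::to")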
