Skip to content

Conversation

@sayakpaul
Copy link
Member

What does this PR do?

@AstraliteHeart could you also review?

@sayakpaul sayakpaul requested a review from stevhliu April 15, 2025 14:08
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@stevhliu stevhliu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

AuraFlow can be compiled with `torch.compile()` to speed up inference latency even for different resolutions. First, install PyTorch nightly following the instructions from [here](https://pytorch.org/). The snippet below shows the changes needed to enable this:

```diff
+ torch.fx.experimental._config.use_duck_shape = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be clearer if we also briefly explain what use_duck_shape is in the context of compiling with AuraFlow.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cc: @StrongerXi. Could you provide us with any hints?

Copy link

@StrongerXi StrongerXi Apr 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/pytorch/pytorch/blob/4c4a5df73bd40f4ff2f6f69acb636617f38f5320/torch/fx/experimental/_config.py#L86-L87

This flag changes whether we should use the same symbolic variable to represent input sizes that are the same.

So duck_shape = True means -- under dynamic shapes, upon compilation of a torch.compile-labelled component, if your inputs have dimensions that are the same, the compiler assumes they will be the same in the future as well.

For instance, this could happen if you generate 512x512 images first, which eventually becomes something like 64x64 hidden_states input to the transformer, and compilation adds this assumption s0 == s1 for future hidden_states input with dynamic shape s0 x s1. Then if you generate 1024x512 image next, the transformer block sees an input of, say, 128x64 hidden_states, and the aforementioned s0 == s1 assumption breaks, which triggers recompilation.

So duck_shape = False avoids adding that s0 == s1 assumption in the first place, so that no recompilation will be triggered from its breaking. Although this means compiler might have less opportunity to do optimizations (since it lost the s0 == s1 assumption), but its practical effects are hard to predict, and I only saw a small difference from my limited local experiment.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@AstraliteHeart
Copy link
Contributor

LGTM, perhaps mention that this enables from 100% (on low resolutions) to a 30% (on 1536x1536) speed improvements and link to any necessary deps/compilation docs (i.e. Triton)? I expect AuraFlow docs to get more attention after Pony V7 release so providing some extra context for users would be beneficial.

@sayakpaul sayakpaul merged commit ce1063a into main Apr 16, 2025
4 checks passed
@sayakpaul sayakpaul deleted the auraflow-compilation-docs branch April 16, 2025 05:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants