28 commits
4fc12e2: created test for pinning first and last block on device (bconstantine, Nov 28, 2025)
93e6d31: fix comments in tests for cleaner code (bconstantine, Nov 28, 2025)
3455019: Support explicit block modules in group offloading (Aki-07, Nov 29, 2025)
9c3c14f: Add pinning support to group offloading hooks (Aki-07, Nov 29, 2025)
3b3813d: Expose group offload pinning options in API (Aki-07, Nov 29, 2025)
b9e0994: created test for pinning first and last block on device (bconstantine, Nov 28, 2025)
a99755a: Support explicit block modules in group offloading (Aki-07, Nov 29, 2025)
ffad316: Expose group offload pinning options in API (Aki-07, Nov 29, 2025)
33d8b52: removed deprecated flag pin_first_last (bconstantine, Nov 30, 2025)
ed8a97a: created test for pinning first and last block on device (bconstantine, Nov 28, 2025)
de38128: Support explicit block modules in group offloading (Aki-07, Nov 29, 2025)
c72ddbc: Expose group offload pinning options in API (Aki-07, Nov 29, 2025)
1cd3355: removed deprecated flag pin_first_last (bconstantine, Nov 30, 2025)
1194a83: Address review feedback for group offload pinning (Aki-07, Dec 10, 2025)
3ef894d: Apply style fixes (github-actions[bot], Dec 11, 2025)
7a2f3f0: Merge branch 'feature/group-offload-pinning' of https://github.com/bc… (bconstantine, Dec 11, 2025)
005e51b: Merge branch 'feature/group-offload-pinning' of https://github.com/bc… (bconstantine, Dec 11, 2025)
1bd4539: Fix disk offload block_modules recursion to avoid extra files (Aki-07, Dec 11, 2025)
c82820e: Merge branch 'main' into feature/group-offload-pinning (Aki-07, Dec 12, 2025)
93c253f: Prefix block offload group ids with module prefix (Aki-07, Dec 12, 2025)
8d059e6: Attach group offload hook to root when fully grouped (Aki-07, Dec 12, 2025)
b950c74: Fix leaf-level group offload root hook (Aki-07, Dec 12, 2025)
8da39a3: Apply style fixes after lint (Aki-07, Dec 13, 2025)
6c5e41a: Avoid eager offload before adapters load (Aki-07, Dec 13, 2025)
d08d988: removed apply_block_offloading_to_submodule (bconstantine, Dec 15, 2025)
61b3662: normalize pin groups changed to validate pin groups (bconstantine, Dec 15, 2025)
0cbd079: added default_stream (bconstantine, Dec 15, 2025)
53659d8: Eagerly write disk offload tensors for safetensor checks (Aki-07, Dec 15, 2025)
196 changes: 173 additions & 23 deletions src/diffusers/hooks/group_offloading.py
@@ -17,7 +17,7 @@
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass, replace
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Callable, Dict, List, Optional, Set, Tuple, Union

import safetensors.torch
import torch
@@ -62,6 +62,7 @@ class GroupOffloadingConfig:
block_modules: Optional[List[str]] = None
exclude_kwargs: Optional[List[str]] = None
module_prefix: Optional[str] = ""
pin_groups: Optional[Union[str, Callable]] = None


class ModuleGroup:
@@ -94,6 +95,7 @@ def __init__(
self.record_stream = record_stream
self.onload_self = onload_self
self.low_cpu_mem_usage = low_cpu_mem_usage
self.pinned = False

self.offload_to_disk_path = offload_to_disk_path
self._is_offloaded_to_disk = False
@@ -156,7 +158,7 @@ def _pinned_memory_tensors(self):
finally:
pinned_dict = None

def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream=None):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor.data.record_stream(default_stream)
@@ -212,7 +214,6 @@ def _onload_from_memory(self):

context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
default_stream = self._torch_accelerator_module.current_stream() if self.stream is not None else None

with context:
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
@@ -291,7 +292,8 @@ def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None
self.config = config

def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.group.offload_leader == module:
# For disk offload we materialize the safetensor files upfront so callers can inspect them immediately.
if self.group.offload_to_disk_path is not None and self.group.offload_leader == module:
self.group.offload_()
return module

@@ -301,6 +303,27 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.group.onload_leader is None:
self.group.onload_leader = module

if self.group.pinned:
if self.group.onload_leader == module and not self._is_group_on_device():
self.group.onload_()

should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
if should_onload_next_group:
self.next_group.onload_()

should_synchronize = (
not self.group.onload_self
and self.group.stream is not None
and not should_onload_next_group
and not self.group.record_stream
)
if should_synchronize:
self.group.stream.synchronize()

args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
kwargs = self._send_kwargs_to_device(kwargs)
return args, kwargs

# If the current module is the onload_leader of the group, we onload the group if it is supposed
# to onload itself. In the case of using prefetching with streams, we onload the next group if
# it is not supposed to onload itself.
@@ -313,7 +336,10 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
self.next_group.onload_()

should_synchronize = (
not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
not self.group.onload_self
and self.group.stream is not None
and not should_onload_next_group
and not self.group.record_stream
)
if should_synchronize:
# If this group didn't onload itself, it means it was asynchronously onloaded by the
@@ -325,10 +351,18 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
self.group.stream.synchronize()

args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
kwargs = self._send_kwargs_to_device(kwargs)
return args, kwargs

def post_forward(self, module: torch.nn.Module, output):
if self.group.pinned:
return output

# Some Autoencoder models use a feature cache that is passed through submodules
# and modified in place. The `send_to_device` call returns a copy of this feature cache object
# which breaks the inplace updates. Use `exclude_kwargs` to mark these cache features
if self.group.offload_leader == module:
self.group.offload_()
return output

def _send_kwargs_to_device(self, kwargs):
exclude_kwargs = self.config.exclude_kwargs or []
if exclude_kwargs:
moved_kwargs = send_to_device(
@@ -337,15 +371,21 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
non_blocking=self.group.non_blocking,
)
kwargs.update(moved_kwargs)
else:
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
return kwargs
return send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)

return args, kwargs
def _is_group_on_device(self) -> bool:
Collaborator: We have removed the duplicate method definitions of _is_group_on_device.

tensors = []
for group_module in self.group.modules:
tensors.extend(list(group_module.parameters()))
tensors.extend(list(group_module.buffers()))
tensors.extend(self.group.parameters)
tensors.extend(self.group.buffers)

def post_forward(self, module: torch.nn.Module, output):
if self.group.offload_leader == module:
self.group.offload_()
return output
if len(tensors) == 0:
return True

return all(t.device == self.group.onload_device for t in tensors)


class LazyPrefetchGroupOffloadingHook(ModelHook):
@@ -358,9 +398,10 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):

_is_stateful = False

def __init__(self):
def __init__(self, pin_groups: Optional[Union[str, Callable]] = None):
self.execution_order: List[Tuple[str, torch.nn.Module]] = []
self._layer_execution_tracker_module_names = set()
self.pin_groups = pin_groups

def initialize_hook(self, module):
def make_execution_order_update_callback(current_name, current_submodule):
@@ -442,6 +483,50 @@ def post_forward(self, module, output):
group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
group_offloading_hooks[i].next_group.onload_self = False

if self.pin_groups is not None and num_executed > 0:
param_exec_info = []
for idx, ((name, submodule), hook) in enumerate(zip(self.execution_order, group_offloading_hooks)):
if hook is None:
continue
if next(submodule.parameters(), None) is None and next(submodule.buffers(), None) is None:
continue
param_exec_info.append((name, submodule, hook))

num_param_modules = len(param_exec_info)
if num_param_modules > 0:
pinned_indices = set()
if isinstance(self.pin_groups, str):
if self.pin_groups == "all":
pinned_indices = set(range(num_param_modules))
elif self.pin_groups == "first_last":
pinned_indices.add(0)
pinned_indices.add(num_param_modules - 1)
elif callable(self.pin_groups):
for idx, (name, submodule, _) in enumerate(param_exec_info):
should_pin = False
try:
should_pin = bool(self.pin_groups(submodule))
except TypeError:
try:
should_pin = bool(self.pin_groups(name, submodule))
except TypeError:
should_pin = bool(self.pin_groups(name, submodule, idx))
if should_pin:
pinned_indices.add(idx)

pinned_groups = set()
for idx in pinned_indices:
if idx >= num_param_modules:
continue
group = param_exec_info[idx][2].group
if group not in pinned_groups:
group.pinned = True
pinned_groups.add(group)

for group in pinned_groups:
if group.offload_device != group.onload_device:
group.onload_()

return output


@@ -461,6 +546,19 @@ def pre_forward(self, module, *args, **kwargs):
return args, kwargs


VALID_PIN_GROUPS = {"all", "first_last"}


def _validate_pin_groups(pin_groups: Optional[Union[str, Callable]]) -> Optional[Union[str, Callable]]:
if pin_groups is None or callable(pin_groups):
return pin_groups
if isinstance(pin_groups, str) and pin_groups in VALID_PIN_GROUPS:
return pin_groups
raise ValueError(
f"`pin_groups` must be None, {', '.join(repr(v) for v in sorted(VALID_PIN_GROUPS))}, or a callable."
)


def apply_group_offloading(
module: torch.nn.Module,
onload_device: Union[str, torch.device],
@@ -474,6 +572,7 @@
offload_to_disk_path: Optional[str] = None,
block_modules: Optional[List[str]] = None,
exclude_kwargs: Optional[List[str]] = None,
pin_groups: Optional[Union[str, Callable]] = None,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -535,9 +634,13 @@ def apply_group_offloading(
List of module names that should be treated as blocks for offloading. If provided, only these modules will
be considered for block-level offloading. If not provided, the default block detection logic will be used.
exclude_kwargs (`List[str]`, *optional*):
List of kwarg keys that should not be processed by send_to_device. This is useful for mutable state like
List of kwarg keys that should not be processed by `send_to_device`. This is useful for mutable state like
caching lists that need to maintain their object identity across forward passes. If not provided, will be
inferred from the module's `_skip_keys` attribute if it exists.
pin_groups (`"first_last"` or `"all"` or `Callable`, *optional*, defaults to `None`):
Optionally keeps selected groups on the onload device permanently. Use `"first_last"` to pin the first and
last parameter-bearing groups, `"all"` to pin every parameter-bearing group, or pass a callable that
receives a module (and optionally the module name and index) and returns `True` to pin that group.

Example:
```python
@@ -577,6 +680,7 @@ def apply_group_offloading(
if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'`.")

pin_groups = _validate_pin_groups(pin_groups)
_raise_error_if_accelerate_model_or_sequential_hook_present(module)

if block_modules is None:
@@ -597,11 +701,16 @@
offload_to_disk_path=offload_to_disk_path,
block_modules=block_modules,
exclude_kwargs=exclude_kwargs,
module_prefix="",
pin_groups=pin_groups,
)
_apply_group_offloading(module, config)


def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
registry._group_offload_pin_groups = config.pin_groups

if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
_apply_group_offloading_block_level(module, config)
elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
@@ -617,7 +726,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
done at the top-level blocks and modules specified in block_modules.

When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified
module, recursively apply block offloading to it.
module, we either offload the entire submodule or recursively apply block offloading to it.
"""
if config.stream is not None and config.num_blocks_per_group != 1:
logger.warning(
@@ -634,7 +743,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf

for name, submodule in module.named_children():
# Check if this is an explicitly defined block module
if name in block_modules:
if block_modules and name in block_modules:
# Track submodule using a prefix to avoid filename collisions during disk offload.
# Without this, submodules sharing the same model class would be assigned identical
# filenames (derived from the class name).
@@ -643,7 +752,6 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf

_apply_group_offloading_block_level(submodule, submodule_config)
modules_with_group_offloading.add(name)

elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
# Handle ModuleList and Sequential blocks as before
for i in range(0, len(submodule), config.num_blocks_per_group):
@@ -672,6 +780,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
else:
# This is an unmatched module
unmatched_modules.append((name, submodule))
modules_with_group_offloading.add(name)

# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
@@ -709,6 +818,25 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
elif config.stream is None and config.offload_to_disk_path is None:
# Ensure the top-level module always has a hook when no unmatched modules/params/buffers,
# to satisfy hook presence checks in tests. Using an empty group avoids extra offload files.
empty_group = ModuleGroup(
modules=[],
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=None,
offload_leader=module,
onload_leader=module,
parameters=[],
buffers=[],
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{config.module_prefix}{module.__class__.__name__}_empty_group",
)
_apply_group_offloading_hook(module, empty_group, config=config)


def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -735,7 +863,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=name,
group_id=f"{config.module_prefix}{name}",
)
_apply_group_offloading_hook(submodule, group, config=config)
modules_with_group_offloading.add(name)
@@ -782,10 +910,32 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=name,
group_id=f"{config.module_prefix}{name}",
)
_apply_group_offloading_hook(parent_module, group, config=config)

# Ensure the top-level module also has a group_offloading hook so hook presence checks pass,
# even when it holds no parameters/buffers itself.
if config.stream is None:
Collaborator: Why do we need this?

Author: Even when all real groups sit in child modules, the root needs a group_offloading hook so the model stays marked as offloaded. That keeps the guardrails working (_is_group_offload_enabled still blocks .to()/.cuda() and conflicting offloads, and the reapply/remove logic finds the hook). Without it, a wrapper with no parameters would look un-offloaded and could be moved or re-offloaded into a bad state.
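As a side note for readers of this thread, a small sketch of the guardrail the author describes; it is not part of the PR, and the import paths for the internal helpers are assumptions:

```python
import torch

# Internal names appear in the diff above; their import paths here are assumed.
from diffusers.hooks.group_offloading import _GROUP_OFFLOADING, apply_group_offloading
from diffusers.hooks.hooks import HookRegistry


class Wrapper(torch.nn.Module):
    # The root module owns no parameters itself; everything lives in the child.
    def __init__(self):
        super().__init__()
        self.inner = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.inner(x)


model = Wrapper()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),  # assumes an accelerator is available
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)

# Thanks to the empty root-level group registered above, the parameter-less root now
# carries a group_offloading hook, so offload detection on the root succeeds and the
# .to()/.cuda() guardrails have something to find.
registry = HookRegistry.check_if_exists_or_initialize(model)
print(registry.get_hook(_GROUP_OFFLOADING) is not None)  # True
```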

root_registry = HookRegistry.check_if_exists_or_initialize(module)
if root_registry.get_hook(_GROUP_OFFLOADING) is None:
empty_group = ModuleGroup(
modules=[],
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=None,
offload_leader=module,
onload_leader=module,
parameters=[],
buffers=[],
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{config.module_prefix}{module.__class__.__name__}_empty_group",
)
root_registry.register_hook(GroupOffloadingHook(empty_group, config=config), _GROUP_OFFLOADING)

if config.stream is not None:
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
# and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
@@ -838,7 +988,7 @@ def _apply_lazy_group_offloading_hook(
hook = GroupOffloadingHook(group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)

lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook(pin_groups=config.pin_groups)
registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)


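For orientation, a minimal usage sketch of the new `pin_groups` option added in this file. The toy model, device choice, and stream setting are illustrative assumptions rather than something taken from the PR; note that the pin selection shown in this diff runs inside `LazyPrefetchGroupOffloadingHook.post_forward`, so pinning takes effect after the first forward pass when streams are in use.

```python
import torch

# Import path taken from the file modified above; the public re-export may differ.
from diffusers.hooks.group_offloading import apply_group_offloading


# Hypothetical toy model standing in for a diffusion transformer (not from the PR).
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(64, 64) for _ in range(4))
        self.norm_out = torch.nn.LayerNorm(64)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.norm_out(x)


model = TinyModel()

# Pin the first and last parameter-bearing groups on the onload device; the other
# groups keep the usual onload/offload cycle. Assumes a CUDA device is available.
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
    pin_groups="first_last",
)

# A callable predicate is also accepted; the hook probes it with (module), then
# (name, module), then (name, module, index), matching the TypeError fallbacks above.
# apply_group_offloading(model, ..., pin_groups=lambda m: isinstance(m, torch.nn.LayerNorm))

with torch.no_grad():
    _ = model(torch.randn(2, 64, device="cuda"))  # first pass records order and applies pinning
```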
1 change: 1 addition & 0 deletions src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -966,6 +966,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
# keys to ignore when AlignDeviceHook moves inputs/outputs between devices
# these are shared mutable state modified in-place
_skip_keys = ["feat_cache", "feat_idx"]
_group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]

@register_to_config
def __init__(
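The `_group_offload_block_modules` attribute added here declares which direct children of the autoencoder should be treated as explicit offload blocks. Below is a minimal sketch of the pattern with a hypothetical model; only the attribute name comes from the diff, and since the ModelMixin integration that may read it automatically is not shown here, the sketch passes the list explicitly via `block_modules`.

```python
import torch

from diffusers.hooks.group_offloading import apply_group_offloading


class TinyAutoencoder(torch.nn.Module):
    # Direct children listed here are treated as explicit offload blocks instead of
    # relying on ModuleList/Sequential detection (mirroring the AutoencoderKLWan change).
    _group_offload_block_modules = ["encoder", "decoder"]

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.SiLU())
        self.decoder = torch.nn.Sequential(torch.nn.Conv2d(8, 3, 3, padding=1))

    def forward(self, x):
        return self.decoder(self.encoder(x))


model = TinyAutoencoder()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),  # assumes an accelerator is available
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    block_modules=model._group_offload_block_modules,
)
```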