Skip to content

Commit a972af0

Browse files
committed
Fix check_copies: sync retrieve_timesteps and drop unsupported Copied from tags
1 parent a3d9b04 commit a972af0

4 files changed

Lines changed: 34 additions & 13 deletions

File tree

src/diffusers/modular_pipelines/ltx/before_denoise.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,47 @@ def retrieve_timesteps(
5050
sigmas: list[float] | None = None,
5151
**kwargs,
5252
):
53+
r"""
54+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
55+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
56+
57+
Args:
58+
scheduler (`SchedulerMixin`):
59+
The scheduler to get timesteps from.
60+
num_inference_steps (`int`):
61+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
62+
must be `None`.
63+
device (`str` or `torch.device`, *optional*):
64+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
65+
timesteps (`list[int]`, *optional*):
66+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
67+
`num_inference_steps` and `sigmas` must be `None`.
68+
sigmas (`list[float]`, *optional*):
69+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
70+
`num_inference_steps` and `timesteps` must be `None`.
71+
72+
Returns:
73+
`tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
74+
second element is the number of inference steps.
75+
"""
5376
if timesteps is not None and sigmas is not None:
54-
raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
77+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
5578
if timesteps is not None:
79+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
80+
if not accepts_timesteps:
81+
raise ValueError(
82+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
83+
f" timestep schedules. Please check whether you are using the correct scheduler."
84+
)
5685
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
5786
timesteps = scheduler.timesteps
5887
num_inference_steps = len(timesteps)
5988
elif sigmas is not None:
6089
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
6190
if not accept_sigmas:
6291
raise ValueError(
63-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas."
92+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
93+
f" sigmas schedules. Please check whether you are using the correct scheduler."
6494
)
6595
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
6696
timesteps = scheduler.timesteps
@@ -71,14 +101,11 @@ def retrieve_timesteps(
71101
return timesteps, num_inference_steps
72102

73103

74-
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
75104
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
76-
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape
77-
# [B, C, F // p_t, p_t, H // p, p, W // p, p].
105+
# Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
78106
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
79107
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
80-
# dim=0 is the batch size, dim=1 is the effective video sequence length,
81-
# dim=2 is the effective number of input features
108+
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
82109
batch_size, num_channels, num_frames, height, width = latents.shape
83110
post_patch_num_frames = num_frames // patch_size_t
84111
post_patch_height = height // patch_size
@@ -97,7 +124,6 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
97124
return latents
98125

99126

100-
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
101127
def _normalize_latents(
102128
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
103129
) -> torch.Tensor:

src/diffusers/modular_pipelines/ltx/decoders.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
logger = logging.get_logger(__name__)
2929

3030

31-
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
3231
def _unpack_latents(
3332
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
3433
) -> torch.Tensor:
@@ -42,7 +41,6 @@ def _unpack_latents(
4241
return latents
4342

4443

45-
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
4644
def _denormalize_latents(
4745
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
4846
) -> torch.Tensor:

src/diffusers/modular_pipelines/ltx/denoise.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
logger = logging.get_logger(__name__)
3535

3636

37-
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
3837
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
3938
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape
4039
# [B, C, F // p_t, p_t, H // p, p, W // p, p].
@@ -60,7 +59,6 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
6059
return latents
6160

6261

63-
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
6462
def _unpack_latents(
6563
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
6664
) -> torch.Tensor:

src/diffusers/modular_pipelines/ltx/encoders.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ def retrieve_latents(
186186
raise AttributeError("Could not access latents of provided encoder_output")
187187

188188

189-
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
190189
def _normalize_latents(
191190
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
192191
) -> torch.Tensor:

0 commit comments

Comments
 (0)