@@ -50,17 +50,47 @@ def retrieve_timesteps(
     sigmas: list[float] | None = None,
     **kwargs,
 ):
+    r"""
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`list[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`list[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
     if timesteps is not None and sigmas is not None:
-        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
     if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
         timesteps = scheduler.timesteps
         num_inference_steps = len(timesteps)
     elif sigmas is not None:
         accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
         if not accept_sigmas:
             raise ValueError(
-                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas."
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
             )
         scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
         timesteps = scheduler.timesteps
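
As a sanity check for reviewers, here is a minimal usage sketch of `retrieve_timesteps` (not part of the diff; it assumes a scheduler whose `set_timesteps` accepts custom `sigmas`, such as diffusers' `FlowMatchEulerDiscreteScheduler`):

```python
# Illustrative sketch, not part of this PR. Assumes `retrieve_timesteps` as
# defined above and a scheduler whose `set_timesteps` accepts custom `sigmas`.
import torch
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()

# Default path: the scheduler builds its own timestep spacing.
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")

# Custom path: override the spacing with explicit sigmas. Passing `timesteps`
# and `sigmas` together raises the ValueError added in this diff.
custom_sigmas = torch.linspace(1.0, 1 / 30, 30).tolist()
timesteps, num_steps = retrieve_timesteps(scheduler, sigmas=custom_sigmas, device="cpu")
```
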
@@ -71,14 +101,11 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
 def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
-    # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape
-    # [B, C, F // p_t, p_t, H // p, p, W // p, p].
+    # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
     # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
     # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
-    # dim=0 is the batch size, dim=1 is the effective video sequence length,
-    # dim=2 is the effective number of input features
+    # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
     batch_size, num_channels, num_frames, height, width = latents.shape
     post_patch_num_frames = num_frames // patch_size_t
     post_patch_height = height // patch_size
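
To make the packing comment concrete, a quick shape walk-through (a sketch assuming the LTX-style defaults `patch_size=1`, `patch_size_t=1`; the example tensor sizes are arbitrary):

```python
# Shape walk-through for _pack_latents (illustrative; sizes are arbitrary).
import torch

latents = torch.randn(1, 128, 9, 16, 16)  # [B, C, F, H, W]
packed = _pack_latents(latents, patch_size=1, patch_size_t=1)

# With p = p_t = 1 the sequence length is F * H * W = 9 * 16 * 16 = 2304
# and each token keeps C * p_t * p * p = 128 features.
assert packed.shape == (1, 2304, 128)
```
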
@@ -97,7 +124,6 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
     return latents


-# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
 def _normalize_latents(
     latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
 ) -> torch.Tensor: