diff --git a/generative/networks/schedulers/pndm.py b/generative/networks/schedulers/pndm.py
index b729315f..3fb1b844 100644
--- a/generative/networks/schedulers/pndm.py
+++ b/generative/networks/schedulers/pndm.py
@@ -51,6 +51,22 @@ class PNDMPredictionType(StrEnum):
     EPSILON = "epsilon"
     V_PREDICTION = "v_prediction"
 
+
+class PNDMTimestepSpacing(StrEnum):
+    """
+    Set of valid inference timestep spacing names for the PNDM scheduler's `timestep_spacing` argument.
+
+    See Table 2. of "Common Diffusion Noise Schedules and Sample Steps are Flawed" https://arxiv.org/abs/2305.08891
+
+    leading: first step is always included.
+    linspace: first and last step are always included.
+    trailing: last step is always included.
+    """
+
+    LEADING = "leading"
+    LINSPACE = "linspace"
+    TRAILING = "trailing"
+
 
 class PNDMScheduler(Scheduler):
     """
@@ -73,6 +89,7 @@ class PNDMScheduler(Scheduler):
             an offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
             stable diffusion.
+        timestep_spacing: member of PNDMTimestepSpacing. Controls which timesteps are included during inference.
         schedule_args: arguments to pass to the schedule function
     """
 
@@ -84,15 +101,19 @@ def __init__(
         set_alpha_to_one: bool = False,
         prediction_type: str = PNDMPredictionType.EPSILON,
         steps_offset: int = 0,
+        timestep_spacing: str = PNDMTimestepSpacing.LEADING,
         **schedule_args,
     ) -> None:
         super().__init__(num_train_timesteps, schedule, **schedule_args)
 
         if prediction_type not in PNDMPredictionType.__members__.values():
             raise ValueError("Argument `prediction_type` must be a member of PNDMPredictionType")
-
         self.prediction_type = prediction_type
 
+        if timestep_spacing not in PNDMTimestepSpacing.__members__.values():
+            raise ValueError("Argument `timestep_spacing` must be a member of PNDMTimestepSpacing")
+        self.timestep_spacing = timestep_spacing
+
         self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
 
         # standard deviation of the initial noise distribution
@@ -132,10 +153,22 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
 
         self.num_inference_steps = num_inference_steps
         step_ratio = self.num_train_timesteps // self.num_inference_steps
-        # creates integer timesteps by multiplying by ratio
-        # casting to int to avoid issues when num_inference_step is power of 3
-        self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
-        self._timesteps += self.steps_offset
+        if self.timestep_spacing == PNDMTimestepSpacing.LEADING:
+            # creates integer timesteps by multiplying by ratio
+            # casting to int to avoid issues when num_inference_step is power of 3
+            self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
+            self._timesteps += self.steps_offset
+        elif self.timestep_spacing == PNDMTimestepSpacing.LINSPACE:
+            # round (rather than truncate) so interior steps land on the nearest training timestep
+            self._timesteps = (
+                np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps).round().astype(np.int64)
+            )
+        elif self.timestep_spacing == PNDMTimestepSpacing.TRAILING:
+            # use the exact (non-floored) ratio and build the steps by index, so that exactly
+            # `num_inference_steps` timesteps are produced and the last one is `num_train_timesteps - 1`
+            exact_ratio = self.num_train_timesteps / self.num_inference_steps
+            self._timesteps = np.round(np.arange(1, self.num_inference_steps + 1) * exact_ratio).astype(np.int64)
+            self._timesteps -= 1
 
         if self.skip_prk_steps:
             # for some models like stable diffusion the prk steps can/should be skipped to