@@ -95,8 +95,13 @@ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
                 raise ValueError("diff_output_preservation requires a network to be set")
             if self.train_config.train_text_encoder:
                 raise ValueError("diff_output_preservation is not supported with train_text_encoder")
-
-            # always do a prior prediction when doing diff output preservation
+
+        if self.train_config.blank_prompt_preservation:
+            if self.network_config is None:
+                raise ValueError("blank_prompt_preservation requires a network to be set")
+
+        if self.train_config.blank_prompt_preservation or self.train_config.diff_output_preservation:
+            # always do a prior prediction when doing output preservation
             self.do_prior_prediction = True

         # store the loss target for a batch so we can use it in a loss
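The checks above mirror the existing diff_output_preservation validation: blank_prompt_preservation also requires a network, and either flag forces a prior prediction on every step. A minimal sketch of that gating logic, assuming a plain config object (the fields below other than the two preservation flags are hypothetical stand-ins for the trainer's real train_config):

from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainConfig:
    # hypothetical, trimmed-down stand-in for the trainer's train_config
    diff_output_preservation: bool = False
    blank_prompt_preservation: bool = False
    train_text_encoder: bool = False

def needs_prior_prediction(train_config: TrainConfig, network_config: Optional[object]) -> bool:
    """Validate the preservation flags and report whether a prior prediction is required."""
    if train_config.diff_output_preservation:
        if network_config is None:
            raise ValueError("diff_output_preservation requires a network to be set")
        if train_config.train_text_encoder:
            raise ValueError("diff_output_preservation is not supported with train_text_encoder")
    if train_config.blank_prompt_preservation and network_config is None:
        raise ValueError("blank_prompt_preservation requires a network to be set")
    # either preservation mode needs a frozen prior prediction to compare against
    return train_config.blank_prompt_preservation or train_config.diff_output_preservation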
@@ -343,6 +348,13 @@ def hook_before_train_loop(self):
             self.sd.text_encoder_to("cpu")
             flush()

+        if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None:
+            # make sure we have this if not unloading
+            self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs).to(
+                self.device_torch,
+                dtype=self.sd.torch_dtype
+            ).detach()
+
         if self.train_config.diffusion_feature_extractor_path is not None:
             vae = self.sd.vae
             # if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
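Because the text encoder may already be offloaded to CPU at this point, the blank-prompt embedding is encoded once here and kept on the training device. A rough sketch of that cache-once pattern, assuming an encode_prompt-style callable whose result supports .to() and .detach() as the diff suggests (the real self.sd.encode_prompt signature and encode_kwargs come from the trainer and are not reproduced here):

import torch

def cache_blank_embeds(encode_prompt, device, dtype, cached=None, **encode_kwargs):
    # Encode the empty prompt once, detach it, and keep it on the training
    # device so later steps do not need the (possibly offloaded) text encoder.
    if cached is not None:
        return cached
    with torch.no_grad():
        embeds = encode_prompt("", **encode_kwargs)
    return embeds.to(device, dtype=dtype).detach()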
@@ -1769,6 +1781,14 @@ def get_adapter_multiplier():
             if self.train_config.diff_output_preservation:
                 prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])

+            if self.train_config.blank_prompt_preservation:
+                blank_embeds = self.cached_blank_embeds.clone().detach().to(
+                    self.device_torch, dtype=dtype
+                )
+                prior_embeds_to_use = concat_prompt_embeds(
+                    [blank_embeds] * noisy_latents.shape[0]
+                )
+
             prior_pred = self.get_prior_prediction(
                 noisy_latents=noisy_latents,
                 conditional_embeds=prior_embeds_to_use,
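For the prior pass, the single cached blank embedding is repeated once per sample so every latent in the batch is paired with an empty prompt. concat_prompt_embeds operates on the toolkit's prompt-embedding objects; a tensor-level approximation of the same batching, assuming a plain (1, seq_len, dim) tensor:

import torch

def expand_blank_embeds(blank_embeds: torch.Tensor, batch_size: int) -> torch.Tensor:
    # blank_embeds: (1, seq_len, dim), cached from encoding the empty prompt
    # mirrors concat_prompt_embeds([blank_embeds] * batch_size) at the tensor level
    return torch.cat([blank_embeds] * batch_size, dim=0)  # (batch_size, seq_len, dim)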
@@ -1944,7 +1964,8 @@ def get_adapter_multiplier():
                 prior_to_calculate_loss = prior_pred
                 # if we are doing diff_output_preservation and not doing inverted masked prior
                 # then we need to send none here so it will not target the prior
-                if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
+                doing_preservation = self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation
+                if doing_preservation and not do_inverted_masked_prior:
                     prior_to_calculate_loss = None

             loss = self.calculate_loss(
@@ -1957,24 +1978,34 @@ def get_adapter_multiplier():
                 prior_pred=prior_to_calculate_loss,
             )

-            if self.train_config.diff_output_preservation:
+            if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation:
                 # send the loss backwards otherwise checkpointing will fail
                 self.accelerator.backward(loss)
                 normal_loss = loss.detach()  # dont send backward again

-                dop_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
-                dop_pred = self.predict_noise(
+                with torch.no_grad():
+                    if self.train_config.diff_output_preservation:
+                        preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
+                    elif self.train_config.blank_prompt_preservation:
+                        blank_embeds = self.cached_blank_embeds.clone().detach().to(
+                            self.device_torch, dtype=dtype
+                        )
+                        preservation_embeds = concat_prompt_embeds(
+                            [blank_embeds] * noisy_latents.shape[0]
+                        )
+                preservation_pred = self.predict_noise(
                     noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
                     timesteps=timesteps,
-                    conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype),
+                    conditional_embeds=preservation_embeds.to(self.device_torch, dtype=dtype),
                     unconditional_embeds=unconditional_embeds,
                     batch=batch,
                     **pred_kwargs
                 )
-                dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier
-                self.accelerator.backward(dop_loss)
-
-                loss = normal_loss + dop_loss
+                multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier
+                preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier
+                self.accelerator.backward(preservation_loss)
+
+                loss = normal_loss + preservation_loss
             loss = loss.clone().detach()
             # require grad again so the backward wont fail
             loss.requires_grad_(True)
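Taken together, this branch backpropagates the normal loss first, then predicts noise for the preservation prompt (the DOP class prompt or the blank prompt) and pulls that prediction toward the frozen prior prediction, scaled by the matching multiplier; the summed value is only reported, since both terms have already been sent backward. A condensed sketch of the preservation term, assuming prior_pred is the detached prediction made with the network disabled:

import torch
import torch.nn.functional as F

def preservation_loss(preservation_pred: torch.Tensor,
                      prior_pred: torch.Tensor,
                      multiplier: float = 1.0) -> torch.Tensor:
    # MSE between the network's prediction for the preservation prompt and the
    # frozen prior prediction; only preservation_pred carries gradients.
    return F.mse_loss(preservation_pred, prior_pred.detach()) * multiplier

# schematic use inside the training step:
#   reported_loss = normal_loss + preservation_loss(preservation_pred, prior_pred, multiplier)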