
Commit ee206cf

Added blank prompt preservation
1 parent ca57ffc commit ee206cf

7 files changed (+143, -28 lines)

extensions_built_in/sd_trainer/SDTrainer.py

Lines changed: 42 additions & 11 deletions
@@ -95,8 +95,13 @@ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
                 raise ValueError("diff_output_preservation requires a network to be set")
             if self.train_config.train_text_encoder:
                 raise ValueError("diff_output_preservation is not supported with train_text_encoder")
-
-            # always do a prior prediction when doing diff output preservation
+
+        if self.train_config.blank_prompt_preservation:
+            if self.network_config is None:
+                raise ValueError("blank_prompt_preservation requires a network to be set")
+
+        if self.train_config.blank_prompt_preservation or self.train_config.diff_output_preservation:
+            # always do a prior prediction when doing output preservation
             self.do_prior_prediction = True
 
         # store the loss target for a batch so we can use it in a loss
@@ -343,6 +348,13 @@ def hook_before_train_loop(self):
             self.sd.text_encoder_to("cpu")
             flush()
 
+        if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None:
+            # make sure we have this if not unloading
+            self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs).to(
+                self.device_torch,
+                dtype=self.sd.torch_dtype
+            ).detach()
+
         if self.train_config.diffusion_feature_extractor_path is not None:
             vae = self.sd.vae
             # if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
@@ -1769,6 +1781,14 @@ def get_adapter_multiplier():
                 if self.train_config.diff_output_preservation:
                     prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
 
+                if self.train_config.blank_prompt_preservation:
+                    blank_embeds = self.cached_blank_embeds.clone().detach().to(
+                        self.device_torch, dtype=dtype
+                    )
+                    prior_embeds_to_use = concat_prompt_embeds(
+                        [blank_embeds] * noisy_latents.shape[0]
+                    )
+
                 prior_pred = self.get_prior_prediction(
                     noisy_latents=noisy_latents,
                     conditional_embeds=prior_embeds_to_use,
@@ -1944,7 +1964,8 @@ def get_adapter_multiplier():
                 prior_to_calculate_loss = prior_pred
                 # if we are doing diff_output_preservation and not noing inverted masked prior
                 # then we need to send none here so it will not target the prior
-                if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
+                doing_preservation = self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation
+                if doing_preservation and not do_inverted_masked_prior:
                     prior_to_calculate_loss = None
 
                 loss = self.calculate_loss(
@@ -1957,24 +1978,34 @@ def get_adapter_multiplier():
                     prior_pred=prior_to_calculate_loss,
                 )
 
-                if self.train_config.diff_output_preservation:
+                if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation:
                     # send the loss backwards otherwise checkpointing will fail
                     self.accelerator.backward(loss)
                     normal_loss = loss.detach()  # dont send backward again
 
-                    dop_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
-                    dop_pred = self.predict_noise(
+                    with torch.no_grad():
+                        if self.train_config.diff_output_preservation:
+                            preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
+                        elif self.train_config.blank_prompt_preservation:
+                            blank_embeds = self.cached_blank_embeds.clone().detach().to(
+                                self.device_torch, dtype=dtype
+                            )
+                            preservation_embeds = concat_prompt_embeds(
+                                [blank_embeds] * noisy_latents.shape[0]
+                            )
+                    preservation_pred = self.predict_noise(
                         noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
                         timesteps=timesteps,
-                        conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype),
+                        conditional_embeds=preservation_embeds.to(self.device_torch, dtype=dtype),
                         unconditional_embeds=unconditional_embeds,
                         batch=batch,
                         **pred_kwargs
                     )
-                    dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier
-                    self.accelerator.backward(dop_loss)
-
-                    loss = normal_loss + dop_loss
+                    multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier
+                    preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier
+                    self.accelerator.backward(preservation_loss)
+
+                    loss = normal_loss + preservation_loss
                 loss = loss.clone().detach()
                 # require grad again so the backward wont fail
                 loss.requires_grad_(True)
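
The training-loop changes above reuse the two-backward pattern the DOP path already had: the normal loss is sent backward first, then a second prediction against the preserved prior is sent backward separately. Below is a minimal sketch of that flow for the blank-prompt case; the helper names and flat signatures (predict_noise, network_bypassed) are simplified assumptions for illustration, not the toolkit's real API.

# Minimal sketch of the preservation flow implemented above; helper names and
# signatures are simplified assumptions, not the toolkit's actual methods.
import torch
import torch.nn.functional as F

def step_with_blank_prompt_preservation(model, accelerator, noisy_latents,
                                        timesteps, cond_embeds, blank_embeds,
                                        target, multiplier=1.0):
    # prior prediction: blank prompt with the LoRA/network bypassed, no grad
    with torch.no_grad(), model.network_bypassed():  # assumed context manager
        prior_pred = model.predict_noise(noisy_latents, timesteps, blank_embeds)

    # normal training loss on the real conditional embeds, sent backward immediately
    pred = model.predict_noise(noisy_latents, timesteps, cond_embeds)
    normal_loss = F.mse_loss(pred, target)
    accelerator.backward(normal_loss)

    # preservation loss: blank-prompt prediction with the LoRA active,
    # pulled toward the bypassed prior prediction
    preservation_pred = model.predict_noise(noisy_latents, timesteps, blank_embeds)
    preservation_loss = F.mse_loss(preservation_pred, prior_pred) * multiplier
    accelerator.backward(preservation_loss)

    # the reported loss is the detached sum, as in the hunk above
    return normal_loss.detach() + preservation_loss.detach()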

toolkit/config_modules.py

Lines changed: 8 additions & 1 deletion
@@ -451,7 +451,11 @@ def __init__(self, **kwargs):
         self.diff_output_preservation_multiplier = kwargs.get('diff_output_preservation_multiplier', 1.0)
         # If the trigger word is in the prompt, we will use this class name to replace it eg. "sks woman" -> "woman"
         self.diff_output_preservation_class = kwargs.get('diff_output_preservation_class', '')
-
+
+        # blank prompt preservation will preserve the model's knowledge of a blank prompt
+        self.blank_prompt_preservation = kwargs.get('blank_prompt_preservation', False)
+        self.blank_prompt_preservation_multiplier = kwargs.get('blank_prompt_preservation_multiplier', 1.0)
+
         # legacy
         if match_adapter_assist and self.match_adapter_chance == 0.0:
             self.match_adapter_chance = 1.0
@@ -1318,5 +1322,8 @@ def validate_configs(
     if model_config.arch == 'qwen_image_edit':
         if train_config.unload_text_encoder:
             raise ValueError("Cannot cache unload text encoder with qwen_image_edit model. Control images are encoded with text embeddings. You can cache the text embeddings though")
+
+    if train_config.diff_output_preservation and train_config.blank_prompt_preservation:
+        raise ValueError("Cannot use both differential output preservation and blank prompt preservation at the same time. Please set one of them to False.")
 
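
For reference, a hedged sketch of the kwargs the new options add to the train config constructor above; only the key names, defaults, and the mutual-exclusion rule come from this diff, and the surrounding dict is illustrative.

# Illustrative only: kwargs consumed by the train config __init__ shown above.
train_kwargs = {
    "blank_prompt_preservation": True,            # defaults to False
    "blank_prompt_preservation_multiplier": 1.0,  # scales the preservation loss
    "diff_output_preservation": False,            # validate_configs rejects enabling both
}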

ui/src/app/jobs/new/SimpleJob.tsx

Lines changed: 59 additions & 15 deletions
@@ -215,12 +215,12 @@ export default function SimpleJob({
           </FormGroup>
         )}
         {modelArch?.additionalSections?.includes('model.qie.match_target_res') && (
-        <Checkbox
-          label="Match Target Res"
-          docKey="model.qie.match_target_res"
-          checked={jobConfig.config.process[0].model.model_kwargs.match_target_res}
-          onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.match_target_res')}
-        />
+          <Checkbox
+            label="Match Target Res"
+            docKey="model.qie.match_target_res"
+            checked={jobConfig.config.process[0].model.model_kwargs.match_target_res}
+            onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.match_target_res')}
+          />
         )}
         {modelArch?.additionalSections?.includes('model.layer_offloading') && (
           <>
@@ -586,16 +586,27 @@ export default function SimpleJob({
               </FormGroup>
             </div>
             <div>
+              {disableSections.includes('train.diff_output_preservation') ||
+              disableSections.includes('train.blank_prompt_preservation') ? null : (
+                <FormGroup label="Regularization">
+                  <></>
+                </FormGroup>
+              )}
               {disableSections.includes('train.diff_output_preservation') ? null : (
                 <>
-                  <FormGroup label="Regularization">
-                    <Checkbox
-                      label="Differential Output Preservation"
-                      className="pt-1"
-                      checked={jobConfig.config.process[0].train.diff_output_preservation || false}
-                      onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
-                    />
-                  </FormGroup>
+                  <Checkbox
+                    label="Differential Output Preservation"
+                    docKey={'train.diff_output_preservation'}
+                    className="pt-1"
+                    checked={jobConfig.config.process[0].train.diff_output_preservation || false}
+                    onChange={value => {
+                      setJobConfig(value, 'config.process[0].train.diff_output_preservation');
+                      if (value && jobConfig.config.process[0].train.blank_prompt_preservation) {
+                        // only one can be enabled at a time
+                        setJobConfig(false, 'config.process[0].train.blank_prompt_preservation');
+                      }
+                    }}
+                  />
                   {jobConfig.config.process[0].train.diff_output_preservation && (
                     <>
                       <NumberInput
@@ -610,7 +621,7 @@ export default function SimpleJob({
                       />
                       <TextInput
                         label="DOP Preservation Class"
-                        className="pt-2"
+                        className="pt-2 pb-4"
                         value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
                         onChange={value =>
                           setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')
@@ -621,6 +632,39 @@ export default function SimpleJob({
                   )}
                 </>
               )}
+              {disableSections.includes('train.blank_prompt_preservation') ? null : (
+                <>
+                  <Checkbox
+                    label="Blank Prompt Preservation"
+                    docKey={'train.blank_prompt_preservation'}
+                    className="pt-1"
+                    checked={jobConfig.config.process[0].train.blank_prompt_preservation || false}
+                    onChange={value => {
+                      setJobConfig(value, 'config.process[0].train.blank_prompt_preservation');
+                      if (value && jobConfig.config.process[0].train.diff_output_preservation) {
+                        // only one can be enabled at a time
+                        setJobConfig(false, 'config.process[0].train.diff_output_preservation');
+                      }
+                    }}
+                  />
+                  {jobConfig.config.process[0].train.blank_prompt_preservation && (
+                    <>
+                      <NumberInput
+                        label="BPP Loss Multiplier"
+                        className="pt-2"
+                        value={
+                          (jobConfig.config.process[0].train.blank_prompt_preservation_multiplier as number) || 1.0
+                        }
+                        onChange={value =>
+                          setJobConfig(value, 'config.process[0].train.blank_prompt_preservation_multiplier')
+                        }
+                        placeholder="eg. 1.0"
+                        min={0}
+                      />
+                    </>
+                  )}
+                </>
+              )}
             </div>
           </div>
         </Card>

ui/src/app/jobs/new/options.ts

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ type DisableableSections =
   | 'network.conv'
   | 'trigger_word'
   | 'train.diff_output_preservation'
+  | 'train.blank_prompt_preservation'
   | 'train.unload_text_encoder'
   | 'slider';
 

ui/src/docs.tsx

Lines changed: 30 additions & 0 deletions
@@ -228,6 +228,36 @@ const docs: { [key: string]: ConfigDoc } = {
       </>
     ),
   },
+  'train.diff_output_preservation': {
+    title: 'Differential Output Preservation',
+    description: (
+      <>
+        Differential Output Preservation (DOP) is a technique to help preserve class of the trained concept during
+        training. For this, you must have a trigger word set to differentiate your concept from its class. For instance,
+        You may be training a woman named Alice. Your trigger word may be "Alice". The class is "woman", since Alice is
+        a woman. We want to teach the model to remember what it knows about the class "woman" while teaching it what is
+        different about Alice. During training, the trainer will make a prediction with your LoRA bypassed and your
+        trigger word in the prompt replaced with the class word. Making "photo of Alice" become "photo of woman". This
+        prediction is called the prior prediction. Each step, we will do the normal training step, but also do another
+        step with this prior prediction and the class prompt in order to teach our LoRA to preserve the knowledge of the
+        class. This should not only improve the performance of your trained concept, but also allow you to do things
+        like "Alice standing next to a woman" and not make both of the people look like Alice.
+      </>
+    ),
+  },
+  'train.blank_prompt_preservation': {
+    title: 'Blank Prompt Preservation',
+    description: (
+      <>
+        Blank Prompt Preservation (BPP) is a technique to help preserve the current models knowledge when unprompted.
+        This will not only help the model become more flexible, but will also help the quality of your concept during
+        inference, especially when a model uses CFG (Classifier Free Guidance) on inference. At each step during
+        training, a prior prediction is made with a blank prompt and with the LoRA disabled. This prediction is then
+        used as a target on an additional training step with a blank prompt, to preserve the model's knowledge when no
+        prompt is given. This helps the model to not overfit to the prompt and retain its generalization capabilities.
+      </>
+    ),
+  },
 };
 
 export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
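
The DOP description above hinges on swapping the trigger word for its class word before making the prior prediction. A tiny illustrative sketch of that substitution follows; the helper name is hypothetical and not part of the toolkit, and as the BPP doc text notes, BPP skips this swap and simply uses an empty prompt for the prior prediction.

# Hypothetical helper illustrating the prompt swap described in the DOP doc text;
# BPP instead uses "" (a blank prompt) for its prior prediction.
def to_class_prompt(prompt: str, trigger_word: str, class_word: str) -> str:
    return prompt.replace(trigger_word, class_word)

assert to_class_prompt("photo of Alice", "Alice", "woman") == "photo of woman"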

ui/src/types.ts

Lines changed: 2 additions & 0 deletions
@@ -135,6 +135,8 @@ export interface TrainConfig {
   diff_output_preservation: boolean;
   diff_output_preservation_multiplier: number;
   diff_output_preservation_class: string;
+  blank_prompt_preservation?: boolean;
+  blank_prompt_preservation_multiplier?: number;
   switch_boundary_every: number;
   loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped';
 }

version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-VERSION = "0.7.1"
+VERSION = "0.7.2"
