Multi node torch 4 pr #1933

Closed

Changes from all commits · 24 commits

d8f0d7e
Update generate_final_report.py
arjunsuresh Oct 31, 2024
390a8cb
Merge branch 'master' into dev
arjunsuresh Nov 7, 2024
6b1a0f8
Fix sdxl (#1911)
arjunsuresh Nov 7, 2024
a4ba51f
Fixes for filtering invalid results
arjunsuresh Nov 7, 2024
7097ef5
Merge branch 'master' into dev
arjunsuresh Nov 7, 2024
190ee41
Merge 7097ef540bfa0286c65c81fbfdcb300e6d54f770 into d3c01ed3de6618a8d…
arjunsuresh Nov 7, 2024
451b310
[Automated Commit] Format Codebase
arjunsuresh Nov 7, 2024
4c109ea
Update preprocess_submission.py
arjunsuresh Nov 7, 2024
40c1fe0
Added an option to pass in sample_ids.txt for SDXL accuracy check
arjunsuresh Nov 7, 2024
2a61df9
Merge 40c1fe0c28364b243b5944b3569000611ddf2b7d into d3c01ed3de6618a8d…
arjunsuresh Nov 7, 2024
89a2ffe
[Automated Commit] Format Codebase
arjunsuresh Nov 7, 2024
69ffdc0
Update accuracy_coco.py
arjunsuresh Nov 7, 2024
76b703b
Merge 69ffdc0aa783f9127af612a7de57c6329703c1dc into d3c01ed3de6618a8d…
arjunsuresh Nov 7, 2024
d1d642e
[Automated Commit] Format Codebase
arjunsuresh Nov 7, 2024
8d3b8ab
Fix typo
arjunsuresh Nov 7, 2024
b09b1ef
Not use default for sample_ids.txt
arjunsuresh Nov 8, 2024
857494f
Merge branch 'master' into dev
arjunsuresh Nov 12, 2024
df5049d
Update requirements.txt (#1907)
arjunsuresh Nov 14, 2024
a7e8c8a
Fix preprocess_sudbmission for a bug
arjunsuresh Nov 15, 2024
213c239
Fix conflict
arjunsuresh Nov 15, 2024
8915a90
Update submission_checker.py | Removed TEST05
arjunsuresh Nov 16, 2024
36d5b74
Merge branch 'master' into dev
arjunsuresh Nov 16, 2024
941c0c4
move changes to fork 4 pr
zixianwang2022 Nov 17, 2024
dffdd59
update changes with fork 4 pr
zixianwang2022 Nov 17, 2024
112 changes: 51 additions & 61 deletions text_to_image/backend_pytorch.py
@@ -17,9 +17,9 @@ def __init__(
         model_id="xl",
         guidance=8,
         steps=20,
-        batch_size=1,
+        batch_size=2,
         device="cuda",
-        precision="fp32",
+        precision="fp16",
         negative_prompt="normal quality, low quality, worst quality, low res, blurry, nsfw, nude",
     ):
         super(BackendPytorch, self).__init__()
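
This hunk flips the backend defaults from single-image fp32 to batch-size 2 at fp16. As a hedged illustration only (the constructor body that consumes the flag is outside this hunk, and the names below are hypothetical), a precision string like the new default would typically resolve to a torch dtype along these lines:

    import torch

    # Hypothetical mapping; illustrative, not the PR's code.
    PRECISION_TO_DTYPE = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    dtype = PRECISION_TO_DTYPE["fp16"]  # new default: weights load in half precision
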
@@ -57,39 +57,41 @@ def image_format(self):
         return "NCHW"
 
     def load(self):
-        if self.model_path is None:
-            log.warning(
-                "Model path not provided, running with default hugging face weights\n"
-                "This may not be valid for official submissions"
-            )
-            self.scheduler = EulerDiscreteScheduler.from_pretrained(
-                self.model_id, subfolder="scheduler"
-            )
-            self.pipe = StableDiffusionXLPipeline.from_pretrained(
-                self.model_id,
-                scheduler=self.scheduler,
-                safety_checker=None,
-                add_watermarker=False,
-                variant="fp16" if (self.dtype == torch.float16) else None,
-                torch_dtype=self.dtype,
-            )
+        # if self.model_path is None:
+        #     log.warning(
+        #         "Model path not provided, running with default hugging face weights\n"
+        #         "This may not be valid for official submissions"
+        #     )
+        self.scheduler = EulerDiscreteScheduler.from_pretrained(
+            self.model_id, subfolder="scheduler"
+        )
+        self.pipe = StableDiffusionXLPipeline.from_pretrained(
+            self.model_id,
+            scheduler=self.scheduler,
+            safety_checker=None,
+            add_watermarker=False,
+            # variant="fp16" if (self.dtype == torch.float16) else None,
+            variant="fp16" ,
+            torch_dtype=self.dtype,
+        )
         # self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True)
-        else:
-            self.scheduler = EulerDiscreteScheduler.from_pretrained(
-                os.path.join(self.model_path, "checkpoint_scheduler"),
-                subfolder="scheduler",
-            )
-            self.pipe = StableDiffusionXLPipeline.from_pretrained(
-                os.path.join(self.model_path, "checkpoint_pipe"),
-                scheduler=self.scheduler,
-                safety_checker=None,
-                add_watermarker=False,
-                torch_dtype=self.dtype,
-            )
+        # else:
+        #     self.scheduler = EulerDiscreteScheduler.from_pretrained(
+        #         os.path.join(self.model_path, "checkpoint_scheduler"),
+        #         subfolder="scheduler",
+        #     )
+        #     self.pipe = StableDiffusionXLPipeline.from_pretrained(
+        #         os.path.join(self.model_path, "checkpoint_pipe"),
+        #         scheduler=self.scheduler,
+        #         safety_checker=None,
+        #         add_watermarker=False,
+        #         variant="fp16" if (self.dtype == torch.float16) else None,
+        #         torch_dtype=self.dtype,
+        #     )
         # self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True)
 
         self.pipe.to(self.device)
-        # self.pipe.set_progress_bar_config(disable=True)
+        #self.pipe.set_progress_bar_config(disable=True)
 
         self.negative_prompt_tokens = self.pipe.tokenizer(
             self.convert_prompt(self.negative_prompt, self.pipe.tokenizer),
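
Net effect of this hunk: load() no longer honors model_path and always pulls the fp16 variant of the Hugging Face weights. A minimal stand-alone sketch of the path it now takes (the model_id string is an assumption; in the benchmark it comes from the harness config):

    import torch
    from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline

    model_id = "stabilityai/stable-diffusion-xl-base-1.0"  # assumed checkpoint id
    scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_id,
        scheduler=scheduler,
        add_watermarker=False,
        variant="fp16",            # hard-coded by this PR, even when dtype is fp32
        torch_dtype=torch.float16,
    )
    pipe.to("cuda")
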
@@ -210,15 +212,13 @@ def encode_tokens(
                 text_input_ids.to(device), output_hidden_states=True
             )
 
-            # We are only ALWAYS interested in the pooled output of the
-            # final text encoder
+            # We are only ALWAYS interested in the pooled output of the final text encoder
             pooled_prompt_embeds = prompt_embeds[0]
             if clip_skip is None:
                 prompt_embeds = prompt_embeds.hidden_states[-2]
             else:
                 # "2" because SDXL always indexes from the penultimate layer.
-                prompt_embeds = prompt_embeds.hidden_states[-(
-                    clip_skip + 2)]
+                prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
 
             prompt_embeds_list.append(prompt_embeds)
 
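These are formatting-only reflows of upstream's clip_skip logic; the indexing itself is unchanged. A toy example of that arithmetic, for reference:

    # Stand-in list for prompt_embeds.hidden_states (layer outputs in order).
    hidden_states = ["layer0", "layer1", "layer2", "final"]

    assert hidden_states[-2] == "layer2"                # clip_skip is None: penultimate layer
    clip_skip = 1
    assert hidden_states[-(clip_skip + 2)] == "layer1"  # skip one layer further back
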
@@ -234,8 +234,7 @@ def encode_tokens(
             and zero_out_negative_prompt
         ):
             negative_prompt_embeds = torch.zeros_like(prompt_embeds)
-            negative_pooled_prompt_embeds = torch.zeros_like(
-                pooled_prompt_embeds)
+            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
         elif do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
             negative_prompt_2 = negative_prompt_2 or negative_prompt
@@ -262,35 +261,30 @@ def encode_tokens(
                     uncond_input.to(device),
                     output_hidden_states=True,
                 )
-                # We are only ALWAYS interested in the pooled output of the
-                # final text encoder
+                # We are only ALWAYS interested in the pooled output of the final text encoder
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
-            negative_prompt_embeds = torch.concat(
-                negative_prompt_embeds_list, dim=-1)
+            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
         if pipe.text_encoder_2 is not None:
             prompt_embeds = prompt_embeds.to(
                 dtype=pipe.text_encoder_2.dtype, device=device
             )
         else:
-            prompt_embeds = prompt_embeds.to(
-                dtype=pipe.unet.dtype, device=device)
+            prompt_embeds = prompt_embeds.to(dtype=pipe.unet.dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps
-        # friendly method
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(
             bs_embed * num_images_per_prompt, seq_len, -1
         )
 
         if do_classifier_free_guidance:
-            # duplicate unconditional embeddings for each generation per
-            # prompt, using mps friendly method
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
             if pipe.text_encoder_2 is not None:
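
The reflowed hunk keeps upstream's "mps friendly" duplication trick (repeat along the sequence axis, then view back to a batch axis). A quick self-check of that pattern:

    import torch

    bs_embed, seq_len, dim = 2, 77, 8
    num_images_per_prompt = 3
    e = torch.randn(bs_embed, seq_len, dim)
    e = e.repeat(1, num_images_per_prompt, 1)                  # (2, 231, 8)
    e = e.view(bs_embed * num_images_per_prompt, seq_len, -1)  # (6, 77, 8)
    assert e.shape == (6, 77, 8)
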
@@ -322,7 +316,7 @@ def encode_tokens(
             pooled_prompt_embeds,
             negative_pooled_prompt_embeds,
         )
-
+
     def prepare_inputs(self, inputs, i):
         if self.batch_size == 1:
             return self.encode_tokens(
@@ -337,7 +331,7 @@ def prepare_inputs(self, inputs, i):
             negative_prompt_embeds = []
             pooled_prompt_embeds = []
             negative_pooled_prompt_embeds = []
-            for prompt in inputs[i: min(i + self.batch_size, len(inputs))]:
+            for prompt in inputs[i:min(i+self.batch_size, len(inputs))]:
                 assert isinstance(prompt, dict)
                 text_input = prompt["input_tokens"]
                 text_input_2 = prompt["input_tokens_2"]
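
Only spacing changes here; the tail-safe slice is the same. Note that Python slicing already clamps the stop index, so the min() is redundant, as this toy check shows:

    inputs = list(range(5))
    batch_size = 2

    batches = [inputs[i:min(i + batch_size, len(inputs))]
               for i in range(0, len(inputs), batch_size)]
    assert batches == [[0, 1], [2, 3], [4]]     # last batch may be short
    assert inputs[4:4 + batch_size] == [4]      # min() is not actually needed
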
@@ -358,26 +352,19 @@ def prepare_inputs(self, inputs, i):
                 pooled_prompt_embeds.append(p_p_e)
                 negative_pooled_prompt_embeds.append(n_p_p_e)
 
+
             prompt_embeds = torch.cat(prompt_embeds)
             negative_prompt_embeds = torch.cat(negative_prompt_embeds)
             pooled_prompt_embeds = torch.cat(pooled_prompt_embeds)
-            negative_pooled_prompt_embeds = torch.cat(
-                negative_pooled_prompt_embeds)
-            return (
-                prompt_embeds,
-                negative_prompt_embeds,
-                pooled_prompt_embeds,
-                negative_pooled_prompt_embeds,
-            )
+            negative_pooled_prompt_embeds = torch.cat(negative_pooled_prompt_embeds)
+            return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
     def predict(self, inputs):
         images = []
         with torch.no_grad():
             for i in range(0, len(inputs), self.batch_size):
-                latents_input = [
-                    inputs[idx]["latents"]
-                    for idx in range(i, min(i + self.batch_size, len(inputs)))
-                ]
+                print (f'self.steps BEFORE pipe: {self.steps}')
+                latents_input = [inputs[idx]["latents"] for idx in range(i, min(i+self.batch_size, len(inputs)))]
                 latents_input = torch.cat(latents_input).to(self.device)
                 (
                     prompt_embeds,
@@ -392,8 +379,11 @@ def predict(self, inputs):
                     negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
                     guidance_scale=self.guidance,
                     num_inference_steps=self.steps,
+                    # num_inference_steps=20,
                     output_type="pt",
                     latents=latents_input,
                 ).images
+                print (f'self.steps AFTER pipe: {self.steps}')
                 images.extend(generated)
         return images
+
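The last two hunks add debug print() calls around the pipeline invocation (likely leftovers from multi-node debugging). For orientation, a hedged sketch of the whole predict() flow with those prints omitted; prepare_inputs stands in for the method shown above:

    import torch

    def predict_sketch(pipe, prepare_inputs, inputs, batch_size, guidance, steps, device="cuda"):
        # Batch precomputed latents, fetch the matching prompt embeddings, and
        # decode without gradients; mirrors BackendPytorch.predict minus the prints.
        images = []
        with torch.no_grad():
            for i in range(0, len(inputs), batch_size):
                batch = inputs[i:i + batch_size]
                latents = torch.cat([item["latents"] for item in batch]).to(device)
                p_e, n_p_e, pp_e, n_pp_e = prepare_inputs(inputs, i)
                generated = pipe(
                    prompt_embeds=p_e,
                    negative_prompt_embeds=n_p_e,
                    pooled_prompt_embeds=pp_e,
                    negative_pooled_prompt_embeds=n_pp_e,
                    guidance_scale=guidance,
                    num_inference_steps=steps,
                    output_type="pt",
                    latents=latents,
                ).images
                images.extend(generated)
        return images
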
58 changes: 23 additions & 35 deletions text_to_image/coco.py
@@ -38,19 +38,23 @@ def __init__(
         **kwargs,
     ):
         super().__init__()
-        self.captions_df = pd.read_csv(
-            f"{data_path}/captions/captions.tsv", sep="\t")
+        self.captions_df = pd.read_csv(f"{data_path}/captions/captions.tsv", sep="\t")
         self.image_size = image_size
         self.preprocessed_dir = os.path.abspath(f"{data_path}/preprocessed/")
         self.img_dir = os.path.abspath(f"{data_path}/validation/data/")
         self.name = name
 
+        self.pipe_tokenizer = pipe_tokenizer
+        self.pipe_tokenizer_2 = pipe_tokenizer_2
+
         # Preprocess prompts
         self.captions_df["input_tokens"] = self.captions_df["caption"].apply(
-            lambda x: self.preprocess(x, pipe_tokenizer)
+            # lambda x: self.preprocess(x, pipe_tokenizer)
+            lambda x: x
         )
         self.captions_df["input_tokens_2"] = self.captions_df["caption"].apply(
-            lambda x: self.preprocess(x, pipe_tokenizer_2)
+            # lambda x: self.preprocess(x, pipe_tokenizer_2)
+            lambda x: x
        )
         self.latent_dtype = latent_dtype
         self.latent_device = latent_device if torch.cuda.is_available() else "cpu"
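
The dataset change mirrors the backend change: instead of pre-tokenizing each caption, the fork stores the raw string (lambda x: x) and keeps the tokenizer handles (self.pipe_tokenizer, self.pipe_tokenizer_2) so tokenization can happen later, presumably per worker node. A toy illustration of the difference, with a stand-in for self.preprocess:

    def fake_preprocess(caption):            # stand-in for self.preprocess(x, tokenizer)
        return [hash(word) % 1000 for word in caption.split()]

    caption = "a photo of an astronaut riding a horse"
    old_cell = fake_preprocess(caption)      # upstream: token ids in the dataframe
    new_cell = (lambda x: x)(caption)        # this PR: the raw caption string
    assert new_cell == caption
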
@@ -117,10 +121,7 @@ def get_item_count(self):
         return len(self.captions_df)
 
     def get_img(self, id):
-        img = Image.open(
-            self.img_dir +
-            "/" +
-            self.captions_df.loc[id]["file_name"])
+        img = Image.open(self.img_dir + "/" + self.captions_df.loc[id]["file_name"])
         return self.image_to_tensor(img)
 
     def get_imgs(self, id_list):
@@ -141,11 +142,7 @@ def get_item_loc(self, id):
 
 class PostProcessCoco:
     def __init__(
-        self,
-        device="cpu",
-        dtype="uint8",
-        statistics_path=os.path.join(
-            os.path.dirname(__file__), "tools", "val2014.npz"),
+        self, device="cpu", dtype="uint8", statistics_path=os.path.join(os.path.dirname(__file__), "tools", "val2014.npz")
     ):
         self.results = []
         self.good = 0
@@ -167,33 +164,27 @@ def add_results(self, results):
     def __call__(self, results, ids, expected=None, result_dict=None):
         self.content_ids.extend(ids)
         return [
-            (t.cpu().permute(1, 2, 0).float().numpy() * 255)
-            .round()
-            .astype(self.numpy_dtype)
+            (t.cpu().permute(1, 2, 0).float().numpy() * 255).round().astype(self.numpy_dtype)
             for t in results
         ]
 
     def save_images(self, ids, ds):
         info = []
         idx = {}
-        for i, image_id in enumerate(self.content_ids):
-            if image_id in ids:
-                idx[image_id] = i
+        for i, id in enumerate(self.content_ids):
+            if id in ids:
+                idx[id] = i
         if not os.path.exists("images/"):
             os.makedirs("images/", exist_ok=True)
-        for image_id in ids:
-            if not idx.get(image_id):
-                print(
-                    f"image id {image_id} is missing in the results. Hence not saved.")
-                continue
-            caption = ds.get_caption(image_id)
-            generated = Image.fromarray(self.results[idx[image_id]])
-            image_path_tmp = f"images/{self.content_ids[idx[image_id]]}.png"
+        for id in ids:
+            caption = ds.get_caption(id)
+            generated = Image.fromarray(self.results[idx[id]])
+            image_path_tmp = f"images/{self.content_ids[idx[id]]}.png"
             generated.save(image_path_tmp)
-            info.append((self.content_ids[idx[image_id]], caption))
+            info.append((self.content_ids[idx[id]], caption))
         with open("images/captions.txt", "w+") as f:
-            for image_id, caption in info:
-                f.write(f"{image_id} {caption}\n")
+            for id, caption in info:
+                f.write(f"{id} {caption}\n")
 
     def start(self):
         self.results = []
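
Beyond renaming image_id to the builtin-shadowing id, this hunk drops the guard for ids that never produced a result, so idx[id] can now raise KeyError. A hedged sketch of the dropped check as a hypothetical helper (a membership test also fixes the upstream quirk that a falsy index 0 counted as missing):

    def collect_saved(ids, idx, results):
        # Hypothetical helper, not part of the PR: skip ids with no result.
        saved = []
        for image_id in ids:
            if image_id not in idx:              # upstream used idx.get(), falsy at 0
                print(f"image id {image_id} is missing in the results. Hence not saved.")
                continue
            saved.append(results[idx[image_id]])
        return saved
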
@@ -209,10 +200,7 @@ def finalize(self, result_dict, ds=None, output_dir=None):
                     100 * clip.get_clip_score(caption, generated).item()
                 )
 
-        fid_score = compute_fid(
-            self.results,
-            self.statistics_path,
-            self.device)
+        fid_score = compute_fid(self.results, self.statistics_path, self.device)
         result_dict["FID_SCORE"] = fid_score
         result_dict["CLIP_SCORE"] = np.mean(self.clip_scores)
 