diff --git a/eval_llada.py b/eval_llada.py
index 9e2a9a7..4588521 100644
--- a/eval_llada.py
+++ b/eval_llada.py
@@ -274,7 +274,8 @@ def _tokenize(e):
             generated_answer = self.tokenizer.decode(generated_answer_ids, skip_special_tokens=True)
             out.append(generated_answer)
 
-        self.accelerator.wait_for_everyone()
+        if self.accelerator is not None:
+            self.accelerator.wait_for_everyone()
         return out
 
diff --git a/generate.py b/generate.py
index c2cef3b..f97799b 100644
--- a/generate.py
+++ b/generate.py
@@ -55,7 +55,7 @@ def generate(model, prompt, steps=128, gen_length=128, block_length=128, tempera
         remasking: Remasking strategy. 'low_confidence' or 'random'.
         mask_id: The toke id of [MASK] is 126336.
     '''
 
-    x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
+    x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
     x[:, :prompt.shape[1]] = prompt.clone()
     prompt_index = (x != mask_id)