diff --git a/eval_llada.py b/eval_llada.py
index 9e2a9a70..3d5a3ada 100644
--- a/eval_llada.py
+++ b/eval_llada.py
@@ -3,8 +3,6 @@
 '''
 import accelerate
 import torch
-import re
-from pathlib import Path
 import random
 import numpy as np
 import torch.nn.functional as F
@@ -70,22 +68,18 @@ def __init__(
             self.accelerator = accelerator
         else:
             self.accelerator = None
-
-        model_kwargs = {}
-        if self.accelerator is not None:
-            model_kwargs.update({'device_map': {'': f'{self.accelerator.device}'}})
-        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, **model_kwargs)
+        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
         self.model.eval()
 
-        self.device = torch.device(device)
 
         if self.accelerator is not None:
-            self.model = self.accelerator.prepare(self.model)
-            self.device = torch.device(f'{self.accelerator.device}')
+            self.device = self.accelerator.device
             self._rank = self.accelerator.local_process_index
             self._world_size = self.accelerator.num_processes
         else:
-            self.model = self.model.to(device)
+            self.device = torch.device(device)
+
+        self.model = self.model.to(device)
 
         self.mask_id = mask_id
         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -194,6 +188,11 @@ def suffix_greedy_prediction(self, prefix, target):
         return correct
 
     def _encode_pair(self, context, continuation):
+        """
+        Move spaces at the end of context to the beginning of continuation, and
+        encode both context and continuation into token ids. This is modified from
+        `lm_eval.api.model.TemplateLM._encode_pair`.
+        """
         n_spaces = len(context) - len(context.rstrip())
         if n_spaces > 0:
             continuation = context[-n_spaces:] + continuation
@@ -207,36 +206,44 @@ def _encode_pair(self, context, continuation):
 
         return context_enc, continuation_enc
 
-    def loglikelihood(self, requests):
-        def _tokenize(e):
-            prefix, target = self._encode_pair(e["prefix"], e["target"])
-            return {
-                "prefix_text": e["prefix"],
-                "target_text": e["target"],
-                "prefix": prefix,
-                "target": target,
-            }
-
-        ds = []
-        ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
-        ds = Dataset.from_list(ds)
-        ds = ds.map(_tokenize)
-        ds = ds.with_format("torch")
-        prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]
-
-        assert max(prompt_len) <= 4096
-
+    def loglikelihood(self, requests: list[Instance]):
+        """Compute log-likelihood of generating a continuation from a context.
+        Downstream tasks should attempt to use loglikelihood instead of other
+        LM calls whenever possible.
+
+        :param requests: list[Instance]
+            A list of Instance objects, with property `args` which returns a tuple (context, continuation).
+            `context: str`
+                Context string. Implementations of LM must be able to handle an
+                empty context string.
+            `continuation: str`
+                The continuation over which log likelihood will be calculated. If
+                there is a word boundary, the space should be in the continuation.
+                For example, context="hello" continuation=" world" is correct.
+
+        :return: list[tuple[float, bool]]
+            A list of pairs (logprob, isgreedy)
+            `logprob: float`
+                The log probability of `continuation`.
+            `isgreedy`:
+                Whether `continuation` would be generated by greedy sampling from `context`.
+        """
         out = []
         with torch.no_grad():
-            for elem in tqdm(ds, desc="Computing likelihood..."):
-                prefix = elem["prefix"]
-                target = elem["target"]
-
-                ll = self.get_loglikelihood(prefix, target)
+            for instance in tqdm(requests, desc="Computing likelihood..."):
+                context, continuation = self._encode_pair(*instance.args)
+                assert len(context) + len(continuation) <= self.max_length, (
+                    f"Context + continuation length exceeds {self.max_length} tokens: "
+                    f"{len(context)} + {len(continuation)}"
+                )
+
+                context = torch.tensor(context, device=self.device)
+                continuation = torch.tensor(continuation, device=self.device)
+
+                logprob = self.get_loglikelihood(context, continuation)
+                isgreedy = self.suffix_greedy_prediction(context, continuation)
+                out.append((logprob, isgreedy))
-                is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)
-
-                out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
 
             torch.cuda.empty_cache()
 
         return out
@@ -244,30 +251,32 @@ def loglikelihood_rolling(self, requests):
         raise NotImplementedError
 
     def generate_until(self, requests: list[Instance]):
-        def _tokenize(e):
-            return {
-                "question": self.tokenizer(e["question"])["input_ids"],
-                "question_text": e["question"],
-                "until": e["until"],
-            }
-
-        ds = [{"question": req.args[0], "until": req.args[1]['until']} for req in requests]
-        ds = Dataset.from_list(ds)
-        ds = ds.map(_tokenize)
-        ds = ds.with_format("torch")
-
+        """Generate greedily until a stopping sequence
+
+        :param requests: list[Instance]
+            A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
+            context: str
+                Context string
+            gen_kwargs: dict
+                A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
+        :return: list[str]
+            A list of model generated continuations.
+            continuation: str
+                The generated continuation.
+        """
         out = []
-        for elem in tqdm(ds, desc="Generating..."):
-            prompt = elem["question"].unsqueeze(0).to(self.device)
-            stop_tokens = elem["until"]
-
-            generated_answer = generate(self.model, prompt, steps=self.steps, gen_length=self.gen_length, block_length=self.block_length,
+        for instance in tqdm(requests, desc="Generating..."):
+            context, until = instance.args  # type: ignore
+            context = self.tokenizer(context, return_tensors="pt").input_ids
+            until = until["until"]
+
+            generated_answer = generate(self.model, context, steps=self.steps, gen_length=self.gen_length, block_length=self.block_length,
                                         temperature=0, cfg_scale=self.cfg, remasking=self.remasking, mask_id=self.mask_id)
-
-            generated_answer = self.tokenizer.decode(generated_answer[0][prompt.shape[1]:], skip_special_tokens=False)
-            for stop_seq in stop_tokens:
-                if stop_seq in generated_answer:
-                    generated_answer = generated_answer.split(stop_seq)[0]
+
+            generated_answer = self.tokenizer.decode(generated_answer[0][context.shape[1]:], skip_special_tokens=False)
+            for stop_seq in until:
+                if stop_seq in generated_answer:
+                    generated_answer = generated_answer.split(stop_seq)[0]
 
             # remove special tokens
             generated_answer_ids = self.tokenizer(generated_answer)["input_ids"]