129 changes: 69 additions & 60 deletions eval_llada.py
@@ -3,8 +3,6 @@
'''
import accelerate
import torch
import re
from pathlib import Path
import random
import numpy as np
import torch.nn.functional as F
@@ -70,22 +68,18 @@ def __init__(
self.accelerator = accelerator
else:
self.accelerator = None

model_kwargs = {}
if self.accelerator is not None:
model_kwargs.update({'device_map': {'': f'{self.accelerator.device}'}})

self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, **model_kwargs)
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
self.model.eval()

self.device = torch.device(device)
if self.accelerator is not None:
self.model = self.accelerator.prepare(self.model)
self.device = torch.device(f'{self.accelerator.device}')
self.device = self.accelerator.device
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
else:
self.model = self.model.to(device)
self.device = torch.device(device)

self.model = self.model.to(device)

self.mask_id = mask_id
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -194,6 +188,11 @@ def suffix_greedy_prediction(self, prefix, target):
return correct

def _encode_pair(self, context, continuation):
"""
Move spaces at the end of context to the beginning of continuation, and
encode both context and continuation into token ids. This is modified from
`lm_eval.api.model.TemplateLM._encode_pair`.
"""
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
@@ -207,67 +206,77 @@ def _encode_pair(self, context, continuation):

return context_enc, continuation_enc
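
A minimal self-contained sketch of this helper (not part of the diff), assuming a Hugging Face tokenizer and following the upstream lm_eval.api.model.TemplateLM._encode_pair logic that the docstring cites; the function name is illustrative:

def encode_pair_sketch(tokenizer, context: str, continuation: str):
    # Move trailing spaces of `context` onto the front of `continuation`,
    # so the word-boundary space is tokenized together with the continuation.
    n_spaces = len(context) - len(context.rstrip())
    if n_spaces > 0:
        continuation = context[-n_spaces:] + continuation
        context = context[:-n_spaces]
    # Encode the concatenation once and split it, so BPE merges at the
    # context/continuation boundary match full-string encoding.
    whole_enc = tokenizer(context + continuation, add_special_tokens=False)["input_ids"]
    context_enc = tokenizer(context, add_special_tokens=False)["input_ids"]
    return context_enc, whole_enc[len(context_enc):]

# e.g. encode_pair_sketch(tok, "Question: 2+2= ", "4") tokenizes the pair
# as if it were ("Question: 2+2=", " 4").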

def loglikelihood(self, requests):
def _tokenize(e):
prefix, target = self._encode_pair(e["prefix"], e["target"])
return {
"prefix_text": e["prefix"],
"target_text": e["target"],
"prefix": prefix,
"target": target,
}

ds = []
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
ds = Dataset.from_list(ds)
ds = ds.map(_tokenize)
ds = ds.with_format("torch")
prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]

assert max(prompt_len) <= 4096

def loglikelihood(self, requests: list[Instance]):
"""Compute log-likelihood of generating a continuation from a context.
Downstream tasks should attempt to use loglikelihood instead of other
LM calls whenever possible.

:param requests: list[Instance]
A list of Instance objects, with property `args` which returns a tuple (context, continuation).
`context: str`
Context string. Implementations of LM must be able to handle an
empty context string.
`continuation: str`
The continuation over which log likelihood will be calculated. If
there is a word boundary, the space should be in the continuation.
For example, context="hello" continuation=" world" is correct.

:return: list[tuple[float, bool]]
A list of pairs (logprob, isgreedy)
`logprob: float`
The log probability of `continuation`.
`isgreedy`:
Whether `continuation` would be generated by greedy sampling from `context`.
"""
out = []
with torch.no_grad():
for elem in tqdm(ds, desc="Computing likelihood..."):
prefix = elem["prefix"]
target = elem["target"]

ll = self.get_loglikelihood(prefix, target)
for instance in tqdm(requests, desc="Computing likelihood..."):
context, continuation = self._encode_pair(*instance.args)
assert len(context) + len(continuation) <= self.max_length, (
f"Context + continuation length exceeds {self.max_length} tokens: "
f"{len(context)} + {len(continuation)}"
)

context = torch.tensor(context, device=self.device)
continuation = torch.tensor(continuation, device=self.device)

logprob = self.get_loglikelihood(context, continuation)
isgreedy = self.suffix_greedy_prediction(context, continuation)
out.append((logprob, isgreedy))

is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)

out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
torch.cuda.empty_cache()
return out
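
A hedged usage sketch (not part of the diff) of how this loglikelihood implementation could be exercised directly; SimpleNamespace stands in for lm_eval.api.instance.Instance, since only the .args tuple is read above, and `lm` denotes an instance of this evaluator class:

from types import SimpleNamespace

requests = [
    SimpleNamespace(args=("The capital of France is", " Paris")),
    SimpleNamespace(args=("2 + 2 =", " 4")),
]
results = lm.loglikelihood(requests)  # -> [(logprob, isgreedy), ...] per the docstring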

def loglikelihood_rolling(self, requests):
raise NotImplementedError

def generate_until(self, requests: list[Instance]):
def _tokenize(e):
return {
"question": self.tokenizer(e["question"])["input_ids"],
"question_text": e["question"],
"until": e["until"],
}

ds = [{"question": req.args[0], "until": req.args[1]['until']} for req in requests]
ds = Dataset.from_list(ds)
ds = ds.map(_tokenize)
ds = ds.with_format("torch")

"""Generate greedily until a stopping sequence

:param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
context: str
Context string
gen_kwargs: dict
A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
:return: list[str]
A list of model generated continuations.
continuation: str
The generated continuation.
"""
out = []
for elem in tqdm(ds, desc="Generating..."):
prompt = elem["question"].unsqueeze(0).to(self.device)
stop_tokens = elem["until"]

generated_answer = generate(self.model, prompt, steps=self.steps, gen_length=self.gen_length, block_length=self.block_length,
for instance in tqdm(requests, desc="Generating..."):
context, gen_kwargs = instance.args  # type: ignore
context = self.tokenizer(context, return_tensors="pt").input_ids.to(self.device)
until = gen_kwargs["until"]

generated_answer = generate(self.model, context, steps=self.steps, gen_length=self.gen_length, block_length=self.block_length,
temperature=0, cfg_scale=self.cfg, remasking=self.remasking, mask_id=self.mask_id)
generated_answer = self.tokenizer.decode(generated_answer[0][prompt.shape[1]:], skip_special_tokens=False)
for stop_seq in stop_tokens:
if stop_seq in generated_answer:
generated_answer = generated_answer.split(stop_seq)[0]

generated_answer = self.tokenizer.decode(generated_answer[0][context.shape[1]:], skip_special_tokens=False)
for stop_seq in until:
if stop_seq in generated_answer:
generated_answer = generated_answer.split(stop_seq)[0]

# remove special tokens
generated_answer_ids = self.tokenizer(generated_answer)["input_ids"]