support Bedrock #81
Open
kevintruong wants to merge 2 commits into Tanuki:master from kevintruong:fea/bedrock
@@ -0,0 +1,2 @@
+class MonkeyPatchException(Exception):
+    pass
@@ -8,26 +8,24 @@
from monkey_patch.utils import approximate_token_count, prepare_object_for_saving, encode_int, decode_int
import copy

EXAMPLE_ELEMENT_LIMIT = 1000


class FunctionModeler(object):
-    def __init__(self, data_worker, workspace_id = 0, check_for_finetunes = True) -> None:
+    def __init__(self, data_worker, workspace_id=0, check_for_finetunes=True) -> None:
        self.function_configs = {}
        self.data_worker = data_worker
-        self.distillation_token_limit = 3000 # the token limit for finetuning
+        self.distillation_token_limit = 3000  # the token limit for finetuning
        self.align_buffer = {}
        self._get_datasets()
        self.workspace_id = workspace_id
        self.check_for_finetunes = check_for_finetunes

-    def _get_dataset_info(self, dataset_type, func_hash, type = "length"):
+    def _get_dataset_info(self, dataset_type, func_hash, type="length"):
        """
        Get the dataset size for a function hash
        """
-        return self.data_worker._load_dataset(dataset_type, func_hash, return_type = type)
+        return self.data_worker._load_dataset(dataset_type, func_hash, return_type=type)

    def _get_datasets(self):
        """
@@ -55,17 +53,16 @@ def save_align_statements(self, function_hash, args, kwargs, output):
        successfully_saved, new_datapoint = self.data_worker.log_align(function_hash, example)
        if successfully_saved:
            if function_hash in self.dataset_sizes["alignments"]:
-                self.dataset_sizes["alignments"][function_hash] += 1
+                self.dataset_sizes["alignments"][function_hash] += 1
            else:
                self.dataset_sizes["alignments"][function_hash] = 1

        if new_datapoint:
            # update align buffer
            if function_hash not in self.align_buffer:
                self.align_buffer[function_hash] = bytearray()
            self.align_buffer[function_hash].extend(str(example.__dict__).encode('utf-8') + b'\r\n')

    def save_datapoint(self, func_hash, example):
        """
        Save datapoint to the training data
@@ -75,13 +72,14 @@ def save_datapoint(self, func_hash, example):
        if func_hash in self.dataset_sizes["patches"]:
            # if the dataset size is -1, it means we havent read in the dataset size yet
            if self.dataset_sizes["patches"][func_hash] == -1:
-                self.dataset_sizes["patches"][func_hash] = self._get_dataset_info("patches", func_hash, type = "length")
+                self.dataset_sizes["patches"][func_hash] = self._get_dataset_info("patches", func_hash,
+                                                                                  type="length")
            else:
                self.dataset_sizes["patches"][func_hash] += datapoints
        else:
            self.dataset_sizes["patches"][func_hash] = datapoints
        return len(written_datapoints) > 0

    def get_alignments(self, func_hash, max=20):
        """
        Get all aligns for a function hash
@@ -104,7 +102,7 @@ def get_alignments(self, func_hash, max=20):
        # easy and straightforward way to get nr of words (not perfect but doesnt need to be)
        # Can do the proper way of tokenizing later, it might be slower and we dont need 100% accuracy
        example_element_limit = EXAMPLE_ELEMENT_LIMIT

        examples = []
        for example_bytes in split_buffer:
            if example_bytes in example_set:
@@ -128,18 +126,17 @@ def load_align_statements(self, function_hash):
        Load all align statements
        """
        if function_hash not in self.align_buffer:
-            dataset_size, align_dataset = self._get_dataset_info("alignments", function_hash, type = "both")
+            dataset_size, align_dataset = self._get_dataset_info("alignments", function_hash, type="both")
            if align_dataset:
                self.align_buffer[function_hash] = bytearray(align_dataset)
            self.dataset_sizes["alignments"][function_hash] = dataset_size

    def postprocess_datapoint(self, func_hash, function_description, example, repaired=True):
        """
        Postprocess the datapoint
        """
        try:

            added = self.save_datapoint(func_hash, example)
            if added:
                self._update_datapoint_config(repaired, func_hash)
@@ -154,21 +151,20 @@ def _load_function_config(self, func_hash, function_description):
        """
        Load the config file for a function hash
        """

        config, default = self.data_worker._load_function_config(func_hash)
-        if default and self.check_for_finetunes:
+        if default and self.check_for_finetunes and default.get('finetune_support', True):
            finetuned, finetune_config = self._check_for_finetunes(function_description)
            if finetuned:
                config = finetune_config
        self.function_configs[func_hash] = config
        return config

    def _check_for_finetunes(self, function_description):
        # This here should be discussed, what's the bestd way to do it

        # hash the function_hash into 16 characters
-        finetune_hash = function_description.__hash__(purpose = "finetune") + encode_int(self.workspace_id)
+        finetune_hash = function_description.__hash__(purpose="finetune") + encode_int(self.workspace_id)
        # List 10 fine-tuning jobs
        finetunes = openai.FineTuningJob.list(limit=1000)
        # Check if the function_hash is in the fine-tuning jobs
@@ -184,9 +180,9 @@ def _check_for_finetunes(self, function_description):
                    return True, config
                except:
                    return False, {}

        return False, {}

    def _construct_config_from_finetune(self, finetune_hash, finetune):
        model = finetune["fine_tuned_model"]
        # get the ending location of finetune hash in the model name
@@ -197,19 +193,16 @@ def _construct_config_from_finetune(self, finetune_hash, finetune):
        nr_of_training_runs = decode_int(next_char) + 1
        nr_of_training_points = (2 ** nr_of_training_runs) * 200
        config = {
-            "distilled_model": model,
-            "current_model_stats": {
-                "trained_on_datapoints": nr_of_training_points,
-                "running_faults": []},
-            "last_training_run": {"trained_on_datapoints": nr_of_training_points},
-            "current_training_run": {},
-            "teacher_models": ["gpt-4","gpt-4-32k"], # currently supported teacher models
-            "nr_of_training_runs": nr_of_training_runs}
-
-        return config
+            "distilled_model": model,
+            "current_model_stats": {
+                "trained_on_datapoints": nr_of_training_points,
+                "running_faults": []},
+            "last_training_run": {"trained_on_datapoints": nr_of_training_points},
+            "current_training_run": {},
+            "teacher_models": ["gpt-4", "gpt-4-32k"],  # currently supported teacher models
+            "nr_of_training_runs": nr_of_training_runs}

+        return config

    def get_models(self, function_description):
        """
@@ -220,7 +213,7 @@ def get_models(self, function_description):
            func_config = self.function_configs[func_hash]
        else:
            func_config = self._load_function_config(func_hash, function_description)

        # for backwards compatibility
        if "distilled_model" not in func_config:
            if func_config["current_model"] in func_config["teacher_models"]:
@@ -231,7 +224,7 @@ def get_models(self, function_description):
            distilled_model = func_config["distilled_model"]

        return distilled_model, func_config["teacher_models"]

    def _update_datapoint_config(self, repaired, func_hash):
        """
        Update the config to reflect the new datapoint in the training data
@@ -249,7 +242,7 @@ def _update_datapoint_config(self, repaired, func_hash):
                self.function_configs[func_hash]["current_model_stats"]["running_faults"].append(0)
            # take the last 100 datapoints
            self.function_configs[func_hash]["current_model_stats"]["running_faults"] = \
-                self.function_configs[func_hash]["current_model_stats"]["running_faults"][-100:]
+                self.function_configs[func_hash]["current_model_stats"]["running_faults"][-100:]

            # check if the last 10 datapoints are 50% faulty, this is the switch condition
            if sum(self.function_configs[func_hash]["current_model_stats"]["running_faults"][-10:]) / 10 > 0.5:
@@ -262,11 +255,10 @@ def _update_datapoint_config(self, repaired, func_hash):
            print(e)
            print("Could not update config file")
            pass

    def _update_config_file(self, func_hash):
        self.data_worker._update_function_config(func_hash, self.function_configs[func_hash])

    def check_for_finetuning(self, function_description, func_hash):
        """
        Check for finetuning status
@@ -285,7 +277,7 @@ def check_for_finetuning(self, function_description, func_hash):
        except Exception as e:
            print(e)
            print("Error checking for finetuning")

    def _check_finetuning_condition(self, func_hash):
        """
        Check if the finetuning condition is met
@@ -294,18 +286,19 @@ def _check_finetuning_condition(self, func_hash):
        if func_hash not in self.function_configs:
            return False

        training_threshold = (2 ** self.function_configs[func_hash]["nr_of_training_runs"]) * 200

-        align_dataset_size = self.dataset_sizes["alignments"][func_hash] if func_hash in self.dataset_sizes["alignments"] else 0
-        patch_dataset_size = self.dataset_sizes["patches"][func_hash] if func_hash in self.dataset_sizes["patches"] else 0
+        align_dataset_size = self.dataset_sizes["alignments"][func_hash] if func_hash in self.dataset_sizes[
+            "alignments"] else 0
+        patch_dataset_size = self.dataset_sizes["patches"][func_hash] if func_hash in self.dataset_sizes[
+            "patches"] else 0

        if patch_dataset_size == -1:
            # if havent read in the patch dataset size, read it in
-            patch_dataset_size = self._get_dataset_info("patches", func_hash, type = "length")
+            patch_dataset_size = self._get_dataset_info("patches", func_hash, type="length")
            self.dataset_sizes["patches"][func_hash] = patch_dataset_size
        return (patch_dataset_size + align_dataset_size) > training_threshold

    def _execute_finetuning(self, function_description, func_hash):
        """
        Execute the finetuning
@@ -315,24 +308,24 @@ def _execute_finetuning(self, function_description, func_hash):
        """
        # get function description
        function_string = str(function_description.__dict__.__repr__() + "\n")

        # get the align dataset
-        align_dataset = self._get_dataset_info("alignments", func_hash, type = "dataset")
+        align_dataset = self._get_dataset_info("alignments", func_hash, type="dataset")
        if not align_dataset:
            align_dataset = ""
        else:
            align_dataset = align_dataset.decode('utf-8')

        # get the patch dataset
-        patch_dataset = self._get_dataset_info("patches", func_hash, type = "dataset")
+        patch_dataset = self._get_dataset_info("patches", func_hash, type="dataset")
        if not patch_dataset:
            patch_dataset = ""
        else:
            patch_dataset = patch_dataset.decode('utf-8')

        if align_dataset == "" and patch_dataset == "":
            return

        dataset = align_dataset + patch_dataset

        dataset.replace("\\n", "[SEP_TOKEN]")
@@ -352,7 +345,7 @@ def _execute_finetuning(self, function_description, func_hash):
                "content": f"{instruction}\nFunction: {function_string}---\nInputs:\nArgs: {x['args']}\nKwargs: {x['kwargs']}\nOutput:"},
                {"role": "assistant", "content": str(x['output']) if x['output'] is not None else "None"}]}
            for x in dataset]

        # Create an in-memory text stream
        temp_file = io.StringIO()
        # Write data to the stream
@@ -365,7 +358,7 @@ def _execute_finetuning(self, function_description, func_hash):
        temp_file.seek(0)

        # create the finetune hash
-        finetune_hash = function_description.__hash__(purpose = "finetune")
+        finetune_hash = function_description.__hash__(purpose="finetune")
        nr_of_training_runs = self.function_configs[func_hash]["nr_of_training_runs"]
        finetune_hash += encode_int(self.workspace_id)
        finetune_hash += encode_int(nr_of_training_runs)
@@ -377,21 +370,23 @@ def _execute_finetuning(self, function_description, func_hash):
            return

        # here can be sure that datasets were read in as that is checked in the finetune_check
-        align_dataset_size = self.dataset_sizes["alignments"][func_hash] if func_hash in self.dataset_sizes["alignments"] else 0
-        patch_dataset_size = self.dataset_sizes["patches"][func_hash] if func_hash in self.dataset_sizes["patches"] else 0
+        align_dataset_size = self.dataset_sizes["alignments"][func_hash] if func_hash in self.dataset_sizes[
+            "alignments"] else 0
+        patch_dataset_size = self.dataset_sizes["patches"][func_hash] if func_hash in self.dataset_sizes[
+            "patches"] else 0
        total_dataset_size = align_dataset_size + patch_dataset_size
        training_file_id = response["id"]
        # submit the finetuning job
        try:
            finetuning_response = openai.FineTuningJob.create(training_file=training_file_id, model="gpt-3.5-turbo",
Review comment on this line: We need to figure out a way to call the logic that triggers a finetune from the relevant LLM provider class (i.e. OpenAI, Bedrock, etc.); this is not the right place for this logic to live now that we are adding support for more providers.
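One possible shape for that refactor, sketched under assumptions: the `FinetuneProvider` interface and class names below are hypothetical and not part of this PR. The idea is that each provider owns its own finetune trigger, and `_execute_finetuning` would call `provider.create_finetune(training_file_id, finetune_hash)` instead of calling `openai.FineTuningJob.create` directly.

```python
# Hypothetical sketch (not part of this PR): each LLM provider owns its own
# finetune trigger, so FunctionModeler only dispatches to the configured provider.
from abc import ABC, abstractmethod

import openai


class FinetuneProvider(ABC):
    """Illustrative interface; the names here are assumptions."""

    @abstractmethod
    def supports_finetuning(self) -> bool:
        ...

    @abstractmethod
    def create_finetune(self, training_file_id: str, suffix: str) -> dict:
        ...


class OpenAIProvider(FinetuneProvider):
    def supports_finetuning(self) -> bool:
        return True

    def create_finetune(self, training_file_id: str, suffix: str) -> dict:
        # the same call FunctionModeler makes today, moved behind the interface
        return openai.FineTuningJob.create(training_file=training_file_id,
                                           model="gpt-3.5-turbo",
                                           suffix=suffix)


class BedrockProvider(FinetuneProvider):
    def supports_finetuning(self) -> bool:
        # mirrors the `finetune_support` flag this PR adds to the function config
        return False

    def create_finetune(self, training_file_id: str, suffix: str) -> dict:
        raise NotImplementedError("Bedrock finetuning is not wired up in this PR")
```

Providers that cannot finetune would simply report `supports_finetuning()` as False, which is the same effect the new `finetune_support` config flag achieves in `_load_function_config`.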
-                                                              suffix=finetune_hash)
+                                                              suffix=finetune_hash)
        except Exception as e:
            return

        self.function_configs[func_hash]["current_training_run"] = {"job_id": finetuning_response["id"],
-                                                                    "trained_on_datapoints": total_dataset_size,
-                                                                    "last_checked": datetime.datetime.now().strftime(
-                                                                        "%Y-%m-%d %H:%M:%S")}
+                                                                    "trained_on_datapoints": total_dataset_size,
+                                                                    "last_checked": datetime.datetime.now().strftime(
+                                                                        "%Y-%m-%d %H:%M:%S")}
        # update the config json file
        try:
            self._update_config_file(func_hash)
@@ -424,11 +419,13 @@ def _update_finetune_config(self, response, func_hash, status):
        """
        if status == "failed":
            self.function_configs[func_hash]["current_training_run"] = {}
-        else:
+        else:
            self.function_configs[func_hash]["distilled_model"] = response["fine_tuned_model"]
-            self.function_configs[func_hash]["last_training_run"] = self.function_configs[func_hash]["current_training_run"]
+            self.function_configs[func_hash]["last_training_run"] = self.function_configs[func_hash][
+                "current_training_run"]
            self.function_configs[func_hash]["current_model_stats"] = {
-                "trained_on_datapoints": self.function_configs[func_hash]["current_training_run"]["trained_on_datapoints"],
+                "trained_on_datapoints": self.function_configs[func_hash]["current_training_run"][
+                    "trained_on_datapoints"],
                "running_faults": []}
            self.function_configs[func_hash]["nr_of_training_runs"] += 1
            self.function_configs[func_hash]["current_training_run"] = {}
@@ -437,4 +434,4 @@ def _update_finetune_config(self, response, func_hash, status):
        except Exception as e:
            print(e)
            print("Could not update config file after a successful finetuning run")
-            pass
+            pass
Review comment: We should add the biggest available Claude model here. @MartBakler, is this correct?
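Assuming "here" refers to the `teacher_models` list in `_construct_config_from_finetune`, the change might look like the sketch below; the Bedrock model identifier is an assumption (Claude 2 was the largest Claude model available on Bedrock around the time of this PR) and would need to be confirmed against what the Bedrock integration actually exposes.

```python
# Illustrative sketch only: extending the supported teacher models with a
# Bedrock-hosted Claude model. "anthropic.claude-v2" is an assumed model ID.
TEACHER_MODELS = [
    "gpt-4",
    "gpt-4-32k",
    "anthropic.claude-v2",  # assumed Bedrock ID for the biggest available Claude model
]
```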