support Bedrock #81

Open · wants to merge 2 commits into master
2 changes: 2 additions & 0 deletions src/monkey_patch/exception.py
@@ -0,0 +1,2 @@
class MonkeyPatchException(Exception):
pass
123 changes: 60 additions & 63 deletions src/monkey_patch/function_modeler.py
@@ -8,26 +8,24 @@
from monkey_patch.utils import approximate_token_count, prepare_object_for_saving, encode_int, decode_int
import copy


EXAMPLE_ELEMENT_LIMIT = 1000


class FunctionModeler(object):
def __init__(self, data_worker, workspace_id = 0, check_for_finetunes = True) -> None:
def __init__(self, data_worker, workspace_id=0, check_for_finetunes=True) -> None:
self.function_configs = {}
self.data_worker = data_worker
self.distillation_token_limit = 3000 # the token limit for finetuning
self.distillation_token_limit = 3000 # the token limit for finetuning
self.align_buffer = {}
self._get_datasets()
self.workspace_id = workspace_id
self.check_for_finetunes = check_for_finetunes


def _get_dataset_info(self, dataset_type, func_hash, type = "length"):
def _get_dataset_info(self, dataset_type, func_hash, type="length"):
"""
Get the dataset size for a function hash
"""
return self.data_worker._load_dataset(dataset_type, func_hash, return_type = type)
return self.data_worker._load_dataset(dataset_type, func_hash, return_type=type)

def _get_datasets(self):
"""
@@ -55,17 +53,16 @@ def save_align_statements(self, function_hash, args, kwargs, output):
successfully_saved, new_datapoint = self.data_worker.log_align(function_hash, example)
if successfully_saved:
if function_hash in self.dataset_sizes["alignments"]:
self.dataset_sizes["alignments"][function_hash] += 1
self.dataset_sizes["alignments"][function_hash] += 1
else:
self.dataset_sizes["alignments"][function_hash] = 1

if new_datapoint:
# update align buffer
if function_hash not in self.align_buffer:
self.align_buffer[function_hash] = bytearray()
self.align_buffer[function_hash].extend(str(example.__dict__).encode('utf-8') + b'\r\n')


def save_datapoint(self, func_hash, example):
"""
Save datapoint to the training data
@@ -75,13 +72,14 @@ def save_datapoint(self, func_hash, example):
if func_hash in self.dataset_sizes["patches"]:
# if the dataset size is -1, it means we havent read in the dataset size yet
if self.dataset_sizes["patches"][func_hash] == -1:
self.dataset_sizes["patches"][func_hash] = self._get_dataset_info("patches", func_hash, type = "length")
self.dataset_sizes["patches"][func_hash] = self._get_dataset_info("patches", func_hash,
type="length")
else:
self.dataset_sizes["patches"][func_hash] += datapoints
else:
self.dataset_sizes["patches"][func_hash] = datapoints
return len(written_datapoints) > 0

def get_alignments(self, func_hash, max=20):
"""
Get all aligns for a function hash
@@ -104,7 +102,7 @@ def get_alignments(self, func_hash, max=20):
# easy and straightforward way to get nr of words (not perfect but doesnt need to be)
# Can do the proper way of tokenizing later, it might be slower and we dont need 100% accuracy
example_element_limit = EXAMPLE_ELEMENT_LIMIT

examples = []
for example_bytes in split_buffer:
if example_bytes in example_set:
@@ -128,18 +126,17 @@ def load_align_statements(self, function_hash):
Load all align statements
"""
if function_hash not in self.align_buffer:
dataset_size, align_dataset = self._get_dataset_info("alignments", function_hash, type = "both")
dataset_size, align_dataset = self._get_dataset_info("alignments", function_hash, type="both")
if align_dataset:
self.align_buffer[function_hash] = bytearray(align_dataset)
self.dataset_sizes["alignments"][function_hash] = dataset_size


def postprocess_datapoint(self, func_hash, function_description, example, repaired=True):
"""
Postprocess the datapoint
"""
try:

added = self.save_datapoint(func_hash, example)
if added:
self._update_datapoint_config(repaired, func_hash)
@@ -154,21 +151,20 @@ def _load_function_config(self, func_hash, function_description):
"""
Load the config file for a function hash
"""

config, default = self.data_worker._load_function_config(func_hash)
if default and self.check_for_finetunes:
if default and self.check_for_finetunes and default.get('finetune_support', True):
finetuned, finetune_config = self._check_for_finetunes(function_description)
if finetuned:
config = finetune_config
self.function_configs[func_hash] = config
return config



def _check_for_finetunes(self, function_description):
# This here should be discussed, what's the best way to do it

# hash the function_hash into 16 characters
finetune_hash = function_description.__hash__(purpose = "finetune") + encode_int(self.workspace_id)
finetune_hash = function_description.__hash__(purpose="finetune") + encode_int(self.workspace_id)
# List 10 fine-tuning jobs
finetunes = openai.FineTuningJob.list(limit=1000)
# Check if the function_hash is in the fine-tuning jobs
@@ -184,9 +180,9 @@ def _check_for_finetunes(self, function_description):
return True, config
except:
return False, {}

return False, {}

def _construct_config_from_finetune(self, finetune_hash, finetune):
model = finetune["fine_tuned_model"]
# get the ending location of finetune hash in the model name
@@ -197,19 +193,16 @@ def _construct_config_from_finetune(self, finetune_hash, finetune):
nr_of_training_runs = decode_int(next_char) + 1
nr_of_training_points = (2 ** nr_of_training_runs) * 200
config = {
"distilled_model": model,
"current_model_stats": {
"trained_on_datapoints": nr_of_training_points,
"running_faults": []},
"last_training_run": {"trained_on_datapoints": nr_of_training_points},
"current_training_run": {},
"teacher_models": ["gpt-4","gpt-4-32k"], # currently supported teacher models
"nr_of_training_runs": nr_of_training_runs}

return config


"distilled_model": model,
"current_model_stats": {
"trained_on_datapoints": nr_of_training_points,
"running_faults": []},
"last_training_run": {"trained_on_datapoints": nr_of_training_points},
"current_training_run": {},
"teacher_models": ["gpt-4", "gpt-4-32k"], # currently supported teacher models
Contributor comment: We should add the biggest available Claude model here. @MartBakler is this correct? (An illustrative config sketch follows below this hunk.)

"nr_of_training_runs": nr_of_training_runs}

return config
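Picking up the contributor comment above, a purely illustrative sketch of what the entry could look like with a Claude teacher added; the Bedrock model identifier "anthropic.claude-v2" is an assumption and would need @MartBakler's confirmation:

# Hypothetical only -- the exact Claude/Bedrock model id is not decided in this PR.
config = {
    "teacher_models": ["gpt-4", "gpt-4-32k", "anthropic.claude-v2"],  # assumed id for the largest Claude model on Bedrock
}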

def get_models(self, function_description):
"""
@@ -220,7 +213,7 @@ def get_models(self, function_description):
func_config = self.function_configs[func_hash]
else:
func_config = self._load_function_config(func_hash, function_description)

# for backwards compatibility
if "distilled_model" not in func_config:
if func_config["current_model"] in func_config["teacher_models"]:
@@ -231,7 +224,7 @@ def _update_datapoint_config(self, repaired, func_hash):
distilled_model = func_config["distilled_model"]

return distilled_model, func_config["teacher_models"]

def _update_datapoint_config(self, repaired, func_hash):
"""
Update the config to reflect the new datapoint in the training data
@@ -249,7 +242,7 @@ def _update_datapoint_config(self, repaired, func_hash):
self.function_configs[func_hash]["current_model_stats"]["running_faults"].append(0)
# take the last 100 datapoints
self.function_configs[func_hash]["current_model_stats"]["running_faults"] = \
self.function_configs[func_hash]["current_model_stats"]["running_faults"][-100:]
self.function_configs[func_hash]["current_model_stats"]["running_faults"][-100:]

# check if the last 10 datapoints are 50% faulty, this is the switch condition
if sum(self.function_configs[func_hash]["current_model_stats"]["running_faults"][-10:]) / 10 > 0.5:
@@ -262,11 +255,10 @@ def _update_datapoint_config(self, repaired, func_hash):
print(e)
print("Could not update config file")
pass

def _update_config_file(self, func_hash):
self.data_worker._update_function_config(func_hash, self.function_configs[func_hash])


def check_for_finetuning(self, function_description, func_hash):
"""
Check for finetuning status
@@ -285,7 +277,7 @@ def check_for_finetuning(self, function_description, func_hash):
except Exception as e:
print(e)
print("Error checking for finetuning")

def _check_finetuning_condition(self, func_hash):
"""
Check if the finetuning condition is met
@@ -294,18 +286,19 @@ def _check_finetuning_condition(self, func_hash):
if func_hash not in self.function_configs:
return False


training_threshold = (2 ** self.function_configs[func_hash]["nr_of_training_runs"]) * 200

align_dataset_size = self.dataset_sizes["alignments"][func_hash] if func_hash in self.dataset_sizes["alignments"] else 0
patch_dataset_size = self.dataset_sizes["patches"][func_hash] if func_hash in self.dataset_sizes["patches"] else 0
align_dataset_size = self.dataset_sizes["alignments"][func_hash] if func_hash in self.dataset_sizes[
"alignments"] else 0
patch_dataset_size = self.dataset_sizes["patches"][func_hash] if func_hash in self.dataset_sizes[
"patches"] else 0

if patch_dataset_size == -1:
# if havent read in the patch dataset size, read it in
patch_dataset_size = self._get_dataset_info("patches", func_hash, type = "length")
patch_dataset_size = self._get_dataset_info("patches", func_hash, type="length")
self.dataset_sizes["patches"][func_hash] = patch_dataset_size
return (patch_dataset_size + align_dataset_size) > training_threshold
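For context on the condition above: the threshold doubles with every completed training run, so the combined align + patch data must exceed 200 datapoints before the first finetune, 400 before the second, and so on. A quick sketch of that schedule:

# Threshold schedule implied by the expression above.
for nr_of_training_runs in range(4):
    training_threshold = (2 ** nr_of_training_runs) * 200
    print(nr_of_training_runs, training_threshold)  # 0 -> 200, 1 -> 400, 2 -> 800, 3 -> 1600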

def _execute_finetuning(self, function_description, func_hash):
"""
Execute the finetuning
@@ -315,24 +308,24 @@ def _execute_finetuning(self, function_description, func_hash):
"""
# get function description
function_string = str(function_description.__dict__.__repr__() + "\n")

# get the align dataset
align_dataset = self._get_dataset_info("alignments", func_hash, type = "dataset")
align_dataset = self._get_dataset_info("alignments", func_hash, type="dataset")
if not align_dataset:
align_dataset = ""
else:
align_dataset = align_dataset.decode('utf-8')

# get the patch dataset
patch_dataset = self._get_dataset_info("patches", func_hash, type = "dataset")
patch_dataset = self._get_dataset_info("patches", func_hash, type="dataset")
if not patch_dataset:
patch_dataset = ""
else:
patch_dataset = patch_dataset.decode('utf-8')

if align_dataset == "" and patch_dataset == "":
return

dataset = align_dataset + patch_dataset

dataset.replace("\\n", "[SEP_TOKEN]")
@@ -352,7 +345,7 @@ def _execute_finetuning(self, function_description, func_hash):
"content": f"{instruction}\nFunction: {function_string}---\nInputs:\nArgs: {x['args']}\nKwargs: {x['kwargs']}\nOutput:"},
{"role": "assistant", "content": str(x['output']) if x['output'] is not None else "None"}]}
for x in dataset]

# Create an in-memory text stream
temp_file = io.StringIO()
# Write data to the stream
@@ -365,7 +358,7 @@ def _execute_finetuning(self, function_description, func_hash):
temp_file.seek(0)

# create the finetune hash
finetune_hash = function_description.__hash__(purpose = "finetune")
finetune_hash = function_description.__hash__(purpose="finetune")
nr_of_training_runs = self.function_configs[func_hash]["nr_of_training_runs"]
finetune_hash += encode_int(self.workspace_id)
finetune_hash += encode_int(nr_of_training_runs)
@@ -377,21 +370,23 @@ def _execute_finetuning(self, function_description, func_hash):
return

# here can be sure that datasets were read in as that is checked in the finetune_check
align_dataset_size = self.dataset_sizes["alignments"][func_hash] if func_hash in self.dataset_sizes["alignments"] else 0
patch_dataset_size = self.dataset_sizes["patches"][func_hash] if func_hash in self.dataset_sizes["patches"] else 0
align_dataset_size = self.dataset_sizes["alignments"][func_hash] if func_hash in self.dataset_sizes[
"alignments"] else 0
patch_dataset_size = self.dataset_sizes["patches"][func_hash] if func_hash in self.dataset_sizes[
"patches"] else 0
total_dataset_size = align_dataset_size + patch_dataset_size
training_file_id = response["id"]
# submit the finetuning job
try:
finetuning_response = openai.FineTuningJob.create(training_file=training_file_id, model="gpt-3.5-turbo",
Contributor comment: We need to figure out a way to call the logic that triggers a finetune in the relevant LLM provider class (i.e. OpenAI, Bedrock, etc.); this is not the right place for this logic to live now that we are adding support for more providers. (One possible shape is sketched just after this try/except block.)

suffix=finetune_hash)
suffix=finetune_hash)
except Exception as e:
return
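Following the contributor comment above, one possible shape for moving the finetune trigger behind a provider interface. This is a sketch only; FinetuneProvider and OpenAIFinetuneProvider are assumed names that do not exist in this codebase, and a Bedrock implementation would plug in alongside them:

from abc import ABC, abstractmethod

import openai


class FinetuneProvider(ABC):
    # Hypothetical interface each LLM provider (OpenAI, Bedrock, ...) would implement.
    @abstractmethod
    def submit_finetune(self, training_file_id: str, suffix: str) -> dict:
        ...


class OpenAIFinetuneProvider(FinetuneProvider):
    def submit_finetune(self, training_file_id: str, suffix: str) -> dict:
        # Mirrors the call currently hard-coded in _execute_finetuning.
        return openai.FineTuningJob.create(training_file=training_file_id,
                                           model="gpt-3.5-turbo",
                                           suffix=suffix)

FunctionModeler would then call something like self.provider.submit_finetune(training_file_id, finetune_hash) and keep no provider-specific imports of its own.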

self.function_configs[func_hash]["current_training_run"] = {"job_id": finetuning_response["id"],
"trained_on_datapoints": total_dataset_size,
"last_checked": datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S")}
"trained_on_datapoints": total_dataset_size,
"last_checked": datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S")}
# update the config json file
try:
self._update_config_file(func_hash)
@@ -424,11 +419,13 @@ def _update_finetune_config(self, response, func_hash, status):
"""
if status == "failed":
self.function_configs[func_hash]["current_training_run"] = {}
else:
else:
self.function_configs[func_hash]["distilled_model"] = response["fine_tuned_model"]
self.function_configs[func_hash]["last_training_run"] = self.function_configs[func_hash]["current_training_run"]
self.function_configs[func_hash]["last_training_run"] = self.function_configs[func_hash][
"current_training_run"]
self.function_configs[func_hash]["current_model_stats"] = {
"trained_on_datapoints": self.function_configs[func_hash]["current_training_run"]["trained_on_datapoints"],
"trained_on_datapoints": self.function_configs[func_hash]["current_training_run"][
"trained_on_datapoints"],
"running_faults": []}
self.function_configs[func_hash]["nr_of_training_runs"] += 1
self.function_configs[func_hash]["current_training_run"] = {}
@@ -437,4 +434,4 @@ def _update_finetune_config(self, response, func_hash, status):
except Exception as e:
print(e)
print("Could not update config file after a successful finetuning run")
pass
pass