diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index 9ba547b36009..a99883ef7b62 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -18,6 +18,7 @@ cast, ) +from litellm.router_utils.batch_utils import InMemoryFile from litellm.types.llms.openai import ( AllMessageValues, ChatCompletionAssistantMessage, @@ -453,6 +454,10 @@ def extract_file_data(file_data: FileTypes) -> ExtractedFileData: filename, file_content, content_type = file_data elif len(file_data) == 4: filename, file_content, content_type, file_headers = file_data + elif isinstance(file_data, InMemoryFile): + filename = file_data.name + file_content = file_data + content_type = file_data.content_type else: file_content = file_data # Convert content to bytes diff --git a/litellm/router_utils/batch_utils.py b/litellm/router_utils/batch_utils.py index 6617ad1f68ed..6c5d80afc1ba 100644 --- a/litellm/router_utils/batch_utils.py +++ b/litellm/router_utils/batch_utils.py @@ -7,9 +7,10 @@ class InMemoryFile(io.BytesIO): - def __init__(self, content: bytes, name: str): + def __init__(self, content: bytes, name: str, content_type: str = "application/jsonl"): super().__init__(content) self.name = name + self.content_type = content_type def should_replace_model_in_jsonl( @@ -63,7 +64,7 @@ def replace_model_in_jsonl(file_content: FileTypes, new_model_name: str) -> File # Reassemble the modified lines and return as bytes modified_file_content = "\n".join(modified_lines).encode("utf-8") - return InMemoryFile(modified_file_content, name="modified_file.jsonl") # type: ignore + return InMemoryFile(modified_file_content, name="modified_file.jsonl", content_type="application/jsonl") # type: ignore except (json.JSONDecodeError, UnicodeDecodeError, TypeError): # return the original file content if there is an error replacing the model name diff --git a/tests/router_unit_tests/test_router_batch_utils.py b/tests/router_unit_tests/test_router_batch_utils.py index 94cd6e001e4c..adbab7ce881b 100644 --- a/tests/router_unit_tests/test_router_batch_utils.py +++ b/tests/router_unit_tests/test_router_batch_utils.py @@ -18,7 +18,7 @@ from typing import Dict, List from litellm.router_utils.batch_utils import ( replace_model_in_jsonl, - _get_router_metadata_variable_name, + _get_router_metadata_variable_name, InMemoryFile, ) @@ -57,6 +57,9 @@ def test_bytes_input(sample_jsonl_bytes): result = replace_model_in_jsonl(sample_jsonl_bytes, new_model) assert result is not None + assert isinstance(result, InMemoryFile) + assert result.name == "modified_file.jsonl" + assert result.content_type == "application/jsonl" def test_tuple_input(sample_jsonl_bytes): @@ -66,6 +69,9 @@ def test_tuple_input(sample_jsonl_bytes): result = replace_model_in_jsonl(test_tuple, new_model) assert result is not None + assert isinstance(result, InMemoryFile) + assert result.name == "modified_file.jsonl" + assert result.content_type == "application/jsonl" def test_file_like_object(sample_file_like): @@ -74,6 +80,9 @@ def test_file_like_object(sample_file_like): result = replace_model_in_jsonl(sample_file_like, new_model) assert result is not None + assert isinstance(result, InMemoryFile) + assert result.name == "modified_file.jsonl" + assert result.content_type == "application/jsonl" def test_router_metadata_variable_name():