diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 7808cf6..fe3cb06 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,3 +1,4 @@ +graphql-core>=3.2.6 loguru Pillow>=11.0.0 pydantic>=2.5,<3 diff --git a/tests/conftest.py b/tests/conftest.py index a9fc93e..82cb744 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ FeedbackSubmitResponse, CreditUsage, ) -from vlmrun.client.predictions import SchemaCastMixin +from vlmrun.client.predictions import SchemaHandlerMixin @pytest.fixture @@ -30,7 +30,7 @@ def mock_client(monkeypatch): """Mock the VLMRun class.""" class MockVLMRun: - class AudioPredictions(SchemaCastMixin): + class AudioPredictions(SchemaHandlerMixin): def __init__(self, client): self._client = client @@ -221,7 +221,7 @@ class MockInvoiceSchema(BaseModel): schemas = {"document.invoice": MockInvoiceSchema, "general": None} return schemas.get(domain) - class ImagePredictions(SchemaCastMixin): + class ImagePredictions(SchemaHandlerMixin): def __init__(self, client): self._client = client @@ -243,7 +243,7 @@ def generate(self, domain: str, images=None, urls=None, **kwargs): self._cast_response_to_schema(prediction, domain, kwargs.get("config")) return prediction - class VideoPredictions(SchemaCastMixin): + class VideoPredictions(SchemaHandlerMixin): def __init__(self, client): self._client = client @@ -259,7 +259,7 @@ def generate(self, domain: str = None, **kwargs): self._cast_response_to_schema(prediction, domain, kwargs.get("config")) return prediction - class DocumentPredictions(SchemaCastMixin): + class DocumentPredictions(SchemaHandlerMixin): def __init__(self, client): self._client = client diff --git a/tests/test_predictions.py b/tests/test_predictions.py index 596b72b..9ecf912 100644 --- a/tests/test_predictions.py +++ b/tests/test_predictions.py @@ -4,6 +4,7 @@ from pydantic import BaseModel from PIL import Image from vlmrun.client.types import PredictionResponse, 
class AddressSchema(BaseModel):
    """Mock address schema used to exercise nested GQL field selection."""

    street: str
    city: str
    state: str
    zip: str


class NestedInvoiceSchema(BaseModel):
    """Mock invoice schema with both scalar fields and a nested model field."""

    invoice_id: str
    total_amount: float
    period_start: str
    period_end: str
    address: AddressSchema


def test_create_pydantic_model_from_gql_basic():
    """Selecting a subset of top-level scalar fields keeps only those fields."""
    gql_query = """
    {
        invoice_id
        total_amount
    }
    """

    FilteredModel = create_pydantic_model_from_gql(NestedInvoiceSchema, gql_query)
    assert set(FilteredModel.model_fields.keys()) == {"invoice_id", "total_amount"}


def test_create_pydantic_model_from_gql_nested():
    """A nested selection set produces a filtered nested model."""
    gql_query = """
    {
        invoice_id
        address {
            state
            zip
        }
    }
    """

    FilteredModel = create_pydantic_model_from_gql(NestedInvoiceSchema, gql_query)

    assert set(FilteredModel.model_fields.keys()) == {"invoice_id", "address"}

    # The nested model is reachable via the field's annotation.
    AddressModel = FilteredModel.model_fields["address"].annotation
    assert set(AddressModel.model_fields.keys()) == {"state", "zip"}


def test_create_pydantic_model_from_gql_invalid_field():
    """Querying a field absent from the top-level model raises ValueError."""
    gql_query = """
    {
        invoice_id
        nonexistent_field
    }
    """

    with pytest.raises(
        ValueError, match="Field 'nonexistent_field' not found in model"
    ):
        create_pydantic_model_from_gql(NestedInvoiceSchema, gql_query)


def test_create_pydantic_model_from_gql_invalid_nested_field():
    """Querying a field absent from a nested model raises a distinct ValueError."""
    gql_query = """
    {
        invoice_id
        address {
            nonexistent_field
        }
    }
    """

    with pytest.raises(
        ValueError, match="Field 'nonexistent_field' not found in nested model"
    ):
        create_pydantic_model_from_gql(NestedInvoiceSchema, gql_query)


def test_create_pydantic_model_from_gql_malformed():
    """A syntactically invalid GQL query propagates the parser's syntax error."""
    malformed_query = """
    {
        invoice_id
        address {
    """

    with pytest.raises(Exception, match="Syntax Error"):
        create_pydantic_model_from_gql(NestedInvoiceSchema, malformed_query)


def test_create_pydantic_model_from_gql_nested_scalar():
    """Attaching a selection set to a scalar field raises ValueError."""
    gql_query = """
    {
        invoice_id {
            nested_field
        }
    }
    """

    with pytest.raises(ValueError, match="Cannot query nested fields of scalar type"):
        create_pydantic_model_from_gql(NestedInvoiceSchema, gql_query)


def test_create_pydantic_model_from_gql_all_fields():
    """Selecting every field reproduces the full (nested) field set."""
    gql_query = """
    {
        invoice_id
        total_amount
        period_start
        period_end
        address {
            street
            city
            state
            zip
        }
    }
    """

    FilteredModel = create_pydantic_model_from_gql(NestedInvoiceSchema, gql_query)

    assert set(FilteredModel.model_fields.keys()) == {
        "invoice_id",
        "total_amount",
        "period_start",
        "period_end",
        "address",
    }

    AddressModel = FilteredModel.model_fields["address"].annotation
    assert set(AddressModel.model_fields.keys()) == {"street", "city", "state", "zip"}
class SchemaHandlerMixin:
    """Mixin class to handle schema operations for predictions.

    Handles:
    - Schema casting for responses
    - GQL query processing
    - Request kwargs preparation with schema support
    """

    def _prepare_gql_schema(
        self,
        domain: str,
        config: GenerationConfig,
    ) -> Dict[str, Any]:
        """Build a JSON schema filtered down to the fields selected by ``config.gql``.

        The base model is resolved in priority order: an explicit
        ``config.response_model``, then ``config.json_schema`` (converted via
        ``jsonschema_to_model``), and finally the hub model registered for
        ``domain``.

        Args:
            domain: Domain identifier used to look up the hub model fallback.
            config: Generation config carrying the GQL query (``config.gql``).

        Returns:
            The ``model_json_schema()`` of the GQL-filtered Pydantic model.

        Raises:
            Exception: any parse/filter error is logged and re-raised.
        """
        try:
            base_model: Type[BaseModel]
            if config.response_model:
                base_model = config.response_model
            elif config.json_schema:
                base_model = jsonschema_to_model(config.json_schema)
            else:
                base_model = self._client.hub.get_pydantic_model(domain)

            return create_pydantic_model_from_gql(
                base_model=base_model, gql_query=config.gql
            ).model_json_schema()
        except Exception as e:
            logger.error(f"Failed to process schema with GQL: {e}")
            raise

    def _prepare_request_kwargs(
        self,
        domain: str,
        metadata: Optional[RequestMetadata] = None,
        config: Optional[GenerationConfig] = None,
    ) -> Dict[str, Any]:
        """Assemble the optional request kwargs shared by all ``generate()`` calls.

        Args:
            domain: Domain identifier forwarded to GQL schema resolution.
            metadata: Optional request metadata to serialize into the payload.
            config: Optional generation config; when ``config.gql`` is set, the
                serialized config carries the GQL-filtered ``json_schema``.

        Returns:
            Dict with ``"config"`` and/or ``"metadata"`` keys, each already
            ``model_dump()``-ed, ready to merge into the request data.
        """
        additional_kwargs: Dict[str, Any] = {}

        if config:
            if config.gql:
                filtered_schema = self._prepare_gql_schema(domain, config)
                # Work on a copy so the caller's GenerationConfig is not
                # mutated as a side effect: assigning config.json_schema in
                # place would leak the filtered schema back to the caller and
                # across reuses of the same config object.
                config = config.model_copy(update={"json_schema": filtered_schema})

            additional_kwargs["config"] = config.model_dump()

        if metadata:
            additional_kwargs["metadata"] = metadata.model_dump()

        return additional_kwargs
additional_kwargs = {} - if config: - additional_kwargs["config"] = config.model_dump() - if metadata: - additional_kwargs["metadata"] = metadata.model_dump() + additional_kwargs = self._prepare_request_kwargs(domain, metadata, config) response, status_code, headers = self._requestor.request( method="POST", url="image/generate", @@ -196,7 +243,7 @@ def generate( def FilePredictions(route: str): """File prediction resource for VLM Run API.""" - class _FilePredictions(SchemaCastMixin, Predictions): + class _FilePredictions(SchemaHandlerMixin, Predictions): """File prediction resource for VLM Run API.""" def generate( @@ -256,11 +303,7 @@ def generate( "File or URL must be a pathlib.Path, str, or AnyHttpUrl" ) - additional_kwargs = {} - if config: - additional_kwargs["config"] = config.model_dump() - if metadata: - additional_kwargs["metadata"] = metadata.model_dump() + additional_kwargs = self._prepare_request_kwargs(domain, metadata, config) response, status_code, headers = self._requestor.request( method="POST", url=f"{route}/generate", diff --git a/vlmrun/client/types.py b/vlmrun/client/types.py index dae5e52..e7f1607 100644 --- a/vlmrun/client/types.py +++ b/vlmrun/client/types.py @@ -127,6 +127,10 @@ class GenerationConfig(BaseModel): confidence: bool = Field(default=False) grounding: bool = Field(default=False) + gql: Optional[str] = Field( + default=None, description="GraphQL query to filter response fields" + ) + class RequestMetadata(BaseModel): environment: Literal["dev", "staging", "prod"] = Field(default="dev") diff --git a/vlmrun/common/gql.py b/vlmrun/common/gql.py new file mode 100644 index 0000000..a69c766 --- /dev/null +++ b/vlmrun/common/gql.py @@ -0,0 +1,89 @@ +from typing import Type, Dict, Any +from pydantic import BaseModel, create_model +from graphql import parse +from graphql.language.ast import ( + SelectionSetNode, + FieldNode, + OperationDefinitionNode, + DocumentNode, +) +from loguru import logger + + +def create_pydantic_model_from_gql( + 
def create_pydantic_model_from_gql(
    base_model: Type[BaseModel], gql_query: str
) -> Type[BaseModel]:
    """Create a subset Pydantic model containing only the fields a GraphQL
    query selects.

    Args:
        base_model: The original Pydantic model containing all fields.
        gql_query: GraphQL query string specifying the desired fields.

    Returns:
        A new Pydantic model with only the selected fields; nested selection
        sets produce correspondingly filtered nested models.

    Raises:
        graphql.GraphQLSyntaxError: If the GQL query cannot be parsed.
        ValueError: If the query names a field missing from the model, attaches
            a selection set to a scalar field, or selects no fields at all.
    """
    try:
        ast: DocumentNode = parse(gql_query)
    except Exception as e:
        logger.error(f"Invalid GraphQL query: {e}")
        raise

    fields: Dict[str, Any] = {}

    def extract_fields(
        node: SelectionSetNode,
        current_model: Type[BaseModel],
        current_fields: Dict[str, Any],
    ) -> None:
        """Walk one selection set, copying matching (type, default) pairs from
        ``current_model`` into ``current_fields``."""
        if not isinstance(node, SelectionSetNode):
            return
        for selection in node.selections:
            if not isinstance(selection, FieldNode):
                # Fragments / inline fragments are not supported; skip them,
                # matching the original FieldNode-only traversal.
                continue

            field_name = selection.name.value
            original_field = current_model.model_fields.get(field_name)

            if original_field is None:
                # Distinguish top-level from nested misses (identity check:
                # only the root invocation receives base_model itself) so
                # callers get a precise error message.
                if current_model is base_model:
                    raise ValueError(f"Field '{field_name}' not found in model")
                raise ValueError(f"Field '{field_name}' not found in nested model")

            annotation = original_field.annotation
            if selection.selection_set:
                # A sub-selection is only valid on a nested BaseModel field.
                # NOTE(review): Optional[Model] / union annotations are treated
                # as scalars here and will raise — confirm this is intended.
                if not (isinstance(annotation, type) and issubclass(annotation, BaseModel)):
                    raise ValueError(
                        f"Cannot query nested fields of scalar type: {field_name}"
                    )

                nested_fields: Dict[str, Any] = {}
                extract_fields(selection.selection_set, annotation, nested_fields)
                if nested_fields:
                    current_fields[field_name] = (
                        create_model(
                            f"{current_model.__name__}{field_name.capitalize()}",
                            **nested_fields,
                        ),
                        original_field.default,
                    )
            else:
                current_fields[field_name] = (annotation, original_field.default)

    for definition in ast.definitions:
        if isinstance(definition, OperationDefinitionNode):
            extract_fields(definition.selection_set, base_model, fields)

    if not fields:
        raise ValueError("No valid fields found in GraphQL query")

    # base_model.__name__ is already a str; no f-string wrapper needed.
    return create_model(base_model.__name__, __base__=BaseModel, **fields)