1 change: 1 addition & 0 deletions requirements/requirements.txt
@@ -1,3 +1,4 @@
graphql-core>=3.2.6
loguru
Pillow>=11.0.0
pydantic>=2.5,<3
10 changes: 5 additions & 5 deletions tests/conftest.py
@@ -16,7 +16,7 @@
FeedbackSubmitResponse,
CreditUsage,
)
from vlmrun.client.predictions import SchemaCastMixin
from vlmrun.client.predictions import SchemaHandlerMixin


@pytest.fixture
@@ -30,7 +30,7 @@ def mock_client(monkeypatch):
"""Mock the VLMRun class."""

class MockVLMRun:
class AudioPredictions(SchemaCastMixin):
class AudioPredictions(SchemaHandlerMixin):
def __init__(self, client):
self._client = client

@@ -221,7 +221,7 @@ class MockInvoiceSchema(BaseModel):
schemas = {"document.invoice": MockInvoiceSchema, "general": None}
return schemas.get(domain)

class ImagePredictions(SchemaCastMixin):
class ImagePredictions(SchemaHandlerMixin):
def __init__(self, client):
self._client = client

@@ -243,7 +243,7 @@ def generate(self, domain: str, images=None, urls=None, **kwargs):
self._cast_response_to_schema(prediction, domain, kwargs.get("config"))
return prediction

class VideoPredictions(SchemaCastMixin):
class VideoPredictions(SchemaHandlerMixin):
def __init__(self, client):
self._client = client

@@ -259,7 +259,7 @@ def generate(self, domain: str = None, **kwargs):
self._cast_response_to_schema(prediction, domain, kwargs.get("config"))
return prediction

class DocumentPredictions(SchemaCastMixin):
class DocumentPredictions(SchemaHandlerMixin):
def __init__(self, client):
self._client = client

142 changes: 142 additions & 0 deletions tests/test_predictions.py
@@ -4,6 +4,7 @@
from pydantic import BaseModel
from PIL import Image
from vlmrun.client.types import PredictionResponse, GenerationConfig
from vlmrun.common.gql import create_pydantic_model_from_gql


class MockInvoiceSchema(BaseModel):
@@ -186,3 +187,144 @@ def mock_get_schema(domain):
)

assert isinstance(response.response, MockInvoiceSchema)


class AddressSchema(BaseModel):
"""Mock address schema for testing."""

street: str
city: str
state: str
zip: str


class NestedInvoiceSchema(BaseModel):
"""Mock invoice schema for testing."""

invoice_id: str
total_amount: float
period_start: str
period_end: str
address: AddressSchema


def test_create_pydantic_model_from_gql_basic():
"""Test basic field filtering with GQL."""
gql_query = """
{
invoice_id
total_amount
}
"""

FilteredModel = create_pydantic_model_from_gql(NestedInvoiceSchema, gql_query)
assert set(FilteredModel.model_fields.keys()) == {"invoice_id", "total_amount"}


def test_create_pydantic_model_from_gql_nested():
"""Test nested field filtering with GQL."""
gql_query = """
{
invoice_id
address {
state
zip
}
}
"""

FilteredModel = create_pydantic_model_from_gql(NestedInvoiceSchema, gql_query)

assert set(FilteredModel.model_fields.keys()) == {"invoice_id", "address"}

AddressModel = FilteredModel.model_fields["address"].annotation
assert set(AddressModel.model_fields.keys()) == {"state", "zip"}


def test_create_pydantic_model_from_gql_invalid_field():
"""Test handling of invalid field in GQL query."""
gql_query = """
{
invoice_id
nonexistent_field
}
"""

with pytest.raises(
ValueError, match="Field 'nonexistent_field' not found in model"
):
create_pydantic_model_from_gql(NestedInvoiceSchema, gql_query)


def test_create_pydantic_model_from_gql_invalid_nested_field():
"""Test handling of invalid nested field in GQL query."""
gql_query = """
{
invoice_id
address {
nonexistent_field
}
}
"""

with pytest.raises(
ValueError, match="Field 'nonexistent_field' not found in nested model"
):
create_pydantic_model_from_gql(NestedInvoiceSchema, gql_query)


def test_create_pydantic_model_from_gql_malformed():
"""Test handling of malformed GQL query."""
malformed_query = """
{
invoice_id
address {
"""

with pytest.raises(Exception, match="Syntax Error"):
create_pydantic_model_from_gql(NestedInvoiceSchema, malformed_query)


def test_create_pydantic_model_from_gql_nested_scalar():
"""Test handling of attempting to query nested fields of a scalar."""
gql_query = """
{
invoice_id {
nested_field
}
}
"""

with pytest.raises(ValueError, match="Cannot query nested fields of scalar type"):
create_pydantic_model_from_gql(NestedInvoiceSchema, gql_query)


def test_create_pydantic_model_from_gql_all_fields():
"""Test requesting all fields."""
gql_query = """
{
invoice_id
total_amount
period_start
period_end
address {
street
city
state
zip
}
}
"""

FilteredModel = create_pydantic_model_from_gql(NestedInvoiceSchema, gql_query)

assert set(FilteredModel.model_fields.keys()) == {
"invoice_id",
"total_amount",
"period_start",
"period_end",
"address",
}

AddressModel = FilteredModel.model_fields["address"].annotation
assert set(AddressModel.model_fields.keys()) == {"street", "city", "state", "zip"}
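
The tests above exercise `create_pydantic_model_from_gql` from `vlmrun/common/gql.py`, whose implementation is not part of this diff. Below is a minimal sketch of the idea, assuming graphql-core's `parse` (hence the `graphql-core>=3.2.6` requirement added above) and Pydantic v2's `create_model`; the helper name `filter_model_by_gql` and the exact error messages are illustrative, not the SDK's:

```python
from typing import Type

from graphql import parse  # graphql-core
from graphql.language.ast import FieldNode
from pydantic import BaseModel, create_model


def filter_model_by_gql(model: Type[BaseModel], gql_query: str) -> Type[BaseModel]:
    """Return a new model containing only the fields selected by the GQL query."""
    # parse() raises a GraphQLSyntaxError ("Syntax Error: ...") on malformed input,
    # matching the behaviour asserted in test_create_pydantic_model_from_gql_malformed.
    document = parse(gql_query)
    selection_set = document.definitions[0].selection_set

    def build(current: Type[BaseModel], selections) -> Type[BaseModel]:
        fields = {}
        for node in selections.selections:
            if not isinstance(node, FieldNode):
                continue
            name = node.name.value
            if name not in current.model_fields:
                raise ValueError(f"Field '{name}' not found in model {current.__name__}")
            annotation = current.model_fields[name].annotation
            if node.selection_set is not None:
                # Nested selections are only valid on nested Pydantic models.
                if not (isinstance(annotation, type) and issubclass(annotation, BaseModel)):
                    raise ValueError("Cannot query nested fields of scalar type")
                annotation = build(annotation, node.selection_set)
            fields[name] = (annotation, ...)
        return create_model(f"Filtered{current.__name__}", **fields)

    return build(model, selection_set)
```
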
73 changes: 58 additions & 15 deletions vlmrun/client/predictions.py
@@ -2,13 +2,14 @@

from __future__ import annotations
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Optional, Union, Dict, Any, Type
from PIL import Image
from loguru import logger

import time
from tqdm import tqdm
from vlmrun.common.image import encode_image
from vlmrun.common.gql import create_pydantic_model_from_gql
from vlmrun.client.base_requestor import APIRequestor
from vlmrun.types.abstract import VLMRunProtocol
from vlmrun.client.types import (
@@ -18,10 +19,60 @@
RequestMetadata,
)
from vlmrun.hub.utils import jsonschema_to_model
from pydantic import BaseModel


class SchemaCastMixin:
"""Mixin class to handle schema casting for predictions."""
class SchemaHandlerMixin:
"""Mixin class to handle schema operations for predictions.

Handles:
- Schema casting for responses
- GQL query processing
- Request kwargs preparation with schema support
"""

def _prepare_gql_schema(
self,
domain: str,
config: GenerationConfig,
) -> Dict[str, Any]:
"""Prepare schema for GQL query."""
try:
base_model: Type[BaseModel]

if config.response_model:
base_model = config.response_model
elif config.json_schema:
base_model = jsonschema_to_model(config.json_schema)
else:
base_model = self._client.hub.get_pydantic_model(domain)

return create_pydantic_model_from_gql(
base_model=base_model, gql_query=config.gql
).model_json_schema()
except Exception as e:
logger.error(f"Failed to process schema with GQL: {e}")
raise

def _prepare_request_kwargs(
self,
domain: str,
metadata: Optional[RequestMetadata] = None,
config: Optional[GenerationConfig] = None,
) -> Dict[str, Any]:
additional_kwargs = {}

if config:
if config.gql:
filtered_schema = self._prepare_gql_schema(domain, config)
config.json_schema = filtered_schema

additional_kwargs["config"] = config.model_dump()

if metadata:
additional_kwargs["metadata"] = metadata.model_dump()

return additional_kwargs

def _cast_response_to_schema(
self,
@@ -111,7 +162,7 @@ def wait(self, id: str, timeout: int = 60, sleep: int = 1) -> PredictionResponse
raise TimeoutError(f"Prediction {id} did not complete within {timeout} seconds")


class ImagePredictions(SchemaCastMixin, Predictions):
class ImagePredictions(SchemaHandlerMixin, Predictions):
"""Image prediction resource for VLM Run API."""

def generate(
@@ -169,11 +220,7 @@ def generate(
raise ValueError("All URLs must be strings")
images_data = urls

additional_kwargs = {}
if config:
additional_kwargs["config"] = config.model_dump()
if metadata:
additional_kwargs["metadata"] = metadata.model_dump()
additional_kwargs = self._prepare_request_kwargs(domain, metadata, config)
response, status_code, headers = self._requestor.request(
method="POST",
url="image/generate",
@@ -196,7 +243,7 @@
def FilePredictions(route: str):
"""File prediction resource for VLM Run API."""

class _FilePredictions(SchemaCastMixin, Predictions):
class _FilePredictions(SchemaHandlerMixin, Predictions):
"""File prediction resource for VLM Run API."""

def generate(
@@ -256,11 +303,7 @@ def generate(
"File or URL must be a pathlib.Path, str, or AnyHttpUrl"
)

additional_kwargs = {}
if config:
additional_kwargs["config"] = config.model_dump()
if metadata:
additional_kwargs["metadata"] = metadata.model_dump()
additional_kwargs = self._prepare_request_kwargs(domain, metadata, config)
response, status_code, headers = self._requestor.request(
method="POST",
url=f"{route}/generate",
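
Taken together, the new mixin methods turn a GQL query on `GenerationConfig` into a narrowed `json_schema` before the request is sent. A rough sketch of that flow outside the client, where the `Invoice` model and its field names are placeholders rather than anything defined in this PR:

```python
from pydantic import BaseModel

from vlmrun.client.types import GenerationConfig, RequestMetadata
from vlmrun.common.gql import create_pydantic_model_from_gql


class Invoice(BaseModel):
    invoice_id: str
    total_amount: float
    currency: str


config = GenerationConfig(gql="{ invoice_id total_amount }")

# Equivalent of _prepare_gql_schema: narrow the base model with the GQL query,
# then carry the resulting JSON schema as config.json_schema.
config.json_schema = create_pydantic_model_from_gql(
    base_model=Invoice, gql_query=config.gql
).model_json_schema()

# Equivalent of _prepare_request_kwargs: serialize config and metadata for the request body.
additional_kwargs = {
    "config": config.model_dump(),
    "metadata": RequestMetadata(environment="dev").model_dump(),
}
```
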
4 changes: 4 additions & 0 deletions vlmrun/client/types.py
@@ -127,6 +127,10 @@ class GenerationConfig(BaseModel):
confidence: bool = Field(default=False)
grounding: bool = Field(default=False)

gql: Optional[str] = Field(
default=None, description="GraphQL query to filter response fields"
)


class RequestMetadata(BaseModel):
environment: Literal["dev", "staging", "prod"] = Field(default="dev")
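
With the new `gql` field on `GenerationConfig`, a caller can restrict which fields the prediction is generated and cast against. A hedged end-to-end sketch — the import path, the `client.image` attribute, the constructor arguments, and the invoice image are assumptions based on the mocks in `tests/conftest.py`, not verified against the released SDK:

```python
from PIL import Image

from vlmrun import VLMRun  # import path assumed
from vlmrun.client.types import GenerationConfig

client = VLMRun(api_key="...")  # constructor arguments assumed
invoice_page = Image.open("invoice.png")

prediction = client.image.generate(
    images=[invoice_page],
    domain="document.invoice",
    config=GenerationConfig(gql="{ invoice_id total_amount }"),
)

# The mixin narrows the JSON schema before the request and casts the response,
# so only the selected fields are populated on prediction.response.
print(prediction.response)
```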