Skip to content

feat: configurable examples #132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions gliner_api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from pydantic import AliasChoices, BaseModel, Field


class Entity(BaseModel):
    # One detected entity: where it sits in the input text, the matched text,
    # its type, and the detection confidence.

    # Span offsets into the input text; both are constrained to be >= 0.
    start: int = Field(description="Start index of the entity in the input text", ge=0)
    end: int = Field(description="End index of the entity in the input text", ge=0)
    text: str = Field(description="Text of the entity, extracted from the input text")
    # On validation the value may arrive under either the "type" or the "label" key.
    type: str = Field(
        description="Entity type or label",
        validation_alias=AliasChoices("type", "label"),
    )
    # Confidence is constrained to the closed interval [0, 1].
    score: float = Field(
        description="Confidence score of the entity detection, between 0 and 1",
        ge=0.0,
        le=1.0,
    )
2 changes: 1 addition & 1 deletion gliner_api/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from gliner import GLiNER
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN

from gliner_api import Entity
from gliner_api.config import Config, get_config
from gliner_api.datamodel import (
BatchRequest,
BatchResponse,
Entity,
ErrorMessage,
HealthCheckResponse,
InfoResponse,
Expand Down
92 changes: 19 additions & 73 deletions gliner_api/datamodel.py
Original file line number Diff line number Diff line change
@@ -1,136 +1,88 @@
from pydantic import AliasChoices, BaseModel, Field, TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter

from gliner_api import Entity
from gliner_api.config import Config, get_config
from gliner_api.examples import Examples, get_examples

# Resolved once when this module is imported. The models below take their field
# defaults from `config` and their schema examples from `examples`.
config: Config = get_config()
examples: Examples = get_examples()


class ErrorMessage(BaseModel):
    # Error payload: a short machine-friendly code plus a longer explanation.
    error: str = Field(description="Short error code")
    # Description typo fixed ("explanaiton" -> "explanation"); it is user-facing schema text.
    detail: str = Field(description="Detailed error explanation")


# One detected entity: span offsets, matched text, entity type and confidence.
# NOTE(review): `Entity` is also imported from `gliner_api` at the top of this
# module; this definition rebinds that name. Confirm which one is intended.
class Entity(BaseModel):
# Span offsets into the input text; both are constrained to be >= 0.
start: int = Field(
ge=0,
description="Start index of the entity in the input text",
)
end: int = Field(
ge=0,
description="End index of the entity in the input text",
)
text: str = Field(
description="Text of the entity, extracted from the input text",
)
# On validation the value may arrive under either the "type" or the "label" key.
type: str = Field(
validation_alias=AliasChoices("type", "label"),
description="Entity type or label",
)
# Confidence is constrained to the closed interval [0, 1].
score: float = Field(
ge=0.0,
le=1.0,
description="Confidence score of the entity detection, between 0 and 1",
)


class InvokeRequest(BaseModel):
    # Request body for entity detection on a single text.
    #
    # Each field passes `examples=` exactly once (a repeated keyword argument is
    # a SyntaxError). The values come from the configurable `examples.invoke`
    # entries, so position i of every field describes the same example.
    text: str = Field(
        description="Input text to analyze for entities",
        examples=[example.text for example in examples.invoke],
    )
    # Defaults to the configured threshold; constrained to [0, 1].
    threshold: float = Field(
        default=config.default_threshold,
        description="Threshold for entity detection; if not set, uses default threshold (see gliner config from /api/info endpoint)",
        examples=[example.threshold for example in examples.invoke],
        ge=0.0,
        le=1.0,
    )
    # Defaults to the configured entity types.
    entity_types: list[str] = Field(
        default=config.default_entities,
        description="List of entity types to detect; if not set, uses default entities (see gliner config from /api/info endpoint)",
        examples=[example.entity_types for example in examples.invoke],
    )
    flat_ner: bool = Field(
        default=True,
        description="Whether to return flat entities (default: True). If False, returns nested entities.",
        examples=[example.flat_ner for example in examples.invoke],
    )
    multi_label: bool = Field(
        default=False,
        description="Whether to allow multiple labels per entity (default: False). If True, there can be multiple entities returned for the same span.",
        examples=[example.multi_label for example in examples.invoke],
    )


class InvokeResponse(BaseModel):
    # Response body for single-text entity detection.
    #
    # `examples=` is passed exactly once (a repeated keyword argument is a
    # SyntaxError); the values are the entity lists of the configurable
    # `examples.invoke` entries.
    entities: list[Entity] = Field(
        description="List of detected entities in the input text",
        examples=[example.entities for example in examples.invoke],
    )


class BatchRequest(BaseModel):
    # Request body for entity detection on several texts at once.
    #
    # Each field passes `examples=` exactly once (a repeated keyword argument is
    # a SyntaxError). The values come from the configurable `examples.batch`
    # entries, so position i of every field describes the same example.
    texts: list[str] = Field(
        description="List of input texts to analyze for entities",
        examples=[example.texts for example in examples.batch],
        min_length=1,  # at least one text is required
    )
    # Defaults to the configured threshold; constrained to [0, 1].
    threshold: float = Field(
        default=config.default_threshold,
        description="Threshold for entity detection; if not set, uses default threshold (see gliner config from /api/info endpoint)",
        examples=[example.threshold for example in examples.batch],
        ge=0.0,
        le=1.0,
    )
    # Defaults to the configured entity types.
    entity_types: list[str] = Field(
        default=config.default_entities,
        description="List of entity types to detect; if not set, uses default entities (see gliner config from /api/info endpoint)",
        examples=[example.entity_types for example in examples.batch],
    )
    flat_ner: bool = Field(
        default=True,
        description="Whether to return flat entities (default: True). If False, returns nested entities.",
        examples=[example.flat_ner for example in examples.batch],
    )
    multi_label: bool = Field(
        default=False,
        description="Whether to allow multiple labels per entity (default: False). If True, there can be multiple entities returned for the same span.",
        examples=[example.multi_label for example in examples.batch],
    )


class BatchResponse(BaseModel):
    # Response body for batch entity detection: one entity list per input text.
    #
    # `examples=` is passed exactly once (a repeated keyword argument is a
    # SyntaxError); the values are the nested entity lists of the configurable
    # `examples.batch` entries.
    entities: list[list[Entity]] = Field(
        description="List of lists of detected entities for each input text",
        examples=[example.entities for example in examples.batch],
    )


Expand All @@ -145,37 +97,31 @@ class InfoResponse(BaseModel):
model_id: str = Field(
default=config.model_id,
description="The Huggingface model ID for a GLiNER model.",
examples=["knowledgator/gliner-x-base"],
)
default_entities: list[str] = Field(
default=config.default_entities,
description="The default entities to be detected, used if request includes no specific entities.",
examples=[["person", "organization", "location", "date"]],
)
default_threshold: float = Field(
default=config.default_threshold,
description="The default threshold for entity detection, used if request includes no specific threshold.",
examples=[0.5],
ge=0.0,
le=1.0,
)
api_key_required: bool = Field(
default=config.api_key is not None,
description="Whether an API key is required for requests",
examples=[False],
)
configured_use_case: str = Field(
default=config.use_case,
description="The configured use case for this deployment",
examples=["general"],
)
onnx_enabled: bool = Field(
default=config.onnx_enabled,
description="Whether the GLiNER model is loaded as an ONNX model",
examples=[False],
)


# Define the TypeAdapters for (nested) Entity lists once and reuse them.
# Each adapter is built a single time; a second assignment to the same name
# would only discard the first instance.
entity_list_adapter: TypeAdapter[list[Entity]] = TypeAdapter(type=list[Entity])
deep_entity_list_adapter: TypeAdapter[list[list[Entity]]] = TypeAdapter(type=list[list[Entity]])
124 changes: 124 additions & 0 deletions gliner_api/examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from functools import lru_cache

from pydantic import Field
from pydantic_settings import (
BaseSettings,
JsonConfigSettingsSource,
PydanticBaseSettingsSource,
SettingsConfigDict,
YamlConfigSettingsSource,
)

from gliner_api import Entity
from gliner_api.config import get_config


class Examples(BaseSettings):
    # Example requests/responses, with built-in defaults that can be overridden
    # from `examples.yaml` or `examples.json` (see `model_config` and
    # `settings_customise_sources` below).
    #
    # Entity `start`/`end` values are character offsets into the example text
    # they belong to; keep `text[start:end] == entity.text` when editing.

    # Examples for single-text invocation. Keys omitted from an entry fall back
    # to the defaults declared on `InvokeExample`.
    invoke: list["InvokeExample"] = Field(
        default=[
            {
                "text": "Steve Jobs founded Apple Inc. in Cupertino, CA on April 1, 1976.",
                "entities": [
                    Entity(start=0, end=10, text="Steve Jobs", type="person", score=0.88),
                    Entity(start=19, end=29, text="Apple Inc.", type="organization", score=0.84),
                    Entity(start=33, end=46, text="Cupertino, CA", type="location", score=0.63),
                    Entity(start=50, end=63, text="April 1, 1976", type="date", score=0.69),
                ],
            },
            {
                "text": "Until her death in 2022, the head of the Windsor family, Queen Elizabeth, resided in London.",
                "entity_types": ["person", "organization", "location", "date"],
                "entities": [],
            },
            {
                "text": "The Eiffel Tower was completed in 1889 and is located in Paris, France.",
                "entity_types": ["building", "location", "date"],
                "entities": [],
            },
            {
                "text": "Barack Obama served as the 44th President of the United States from 2009 to 2017.",
                "threshold": 0.4,
                "entity_types": ["person", "organization", "location", "date", "job title"],
                "flat_ner": False,
                "entities": [],
            },
            {
                "text": "Albert Einstein developed the theory of relativity, which revolutionized modern physics.",
                "threshold": 0.2,
                "entity_types": ["person", "research field", "topic", "physical law"],
                "multi_label": True,
                "entities": [],
            },
        ]
    )
    # Examples for batch invocation; `entities[i]` belongs to `texts[i]`.
    batch: list["BatchExample"] = Field(
        default=[
            {
                "texts": [
                    "Steve Jobs founded Apple Inc. in Cupertino, CA on April 1, 1976.",
                    "Until her death in 2022, the head of the Windsor family, Queen Elizabeth, resided in London.",
                ],
                "entities": [
                    [
                        # Offsets and texts match the first text above; the former
                        # values ("California", Cupertino at 28, date at 53-66)
                        # pointed at characters that are not in this sentence.
                        Entity(start=0, end=10, text="Steve Jobs", type="person", score=0.88),
                        Entity(start=19, end=29, text="Apple Inc.", type="organization", score=0.84),
                        Entity(start=33, end=46, text="Cupertino, CA", type="location", score=0.63),
                        Entity(start=50, end=63, text="April 1, 1976", type="date", score=0.69),
                    ],
                    [
                        Entity(start=19, end=23, text="2022", type="date", score=0.38),
                        Entity(start=41, end=55, text="Windsor family", type="organization", score=0.90),
                        Entity(start=57, end=72, text="Queen Elizabeth", type="person", score=0.99),
                        Entity(start=85, end=91, text="London", type="location", score=0.99),
                    ],
                ],
            }
        ]
    )

    # File names and encodings used by the YAML / JSON settings sources.
    model_config: SettingsConfigDict = SettingsConfigDict(
        yaml_file="examples.yaml",
        yaml_file_encoding="utf-8",
        json_file="examples.json",
        json_file_encoding="utf-8",
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Select the settings sources used to populate this class.

        Only constructor arguments, the YAML file and the JSON file are
        returned, in that order. The environment, dotenv and secrets sources
        passed in are deliberately not included.
        """
        return (
            init_settings,
            YamlConfigSettingsSource(settings_cls=settings_cls),
            JsonConfigSettingsSource(settings_cls=settings_cls),
        )


class InvokeExample(BaseSettings):
    # One example for single-text invocation: the request parameters plus the
    # entities shown as the matching response.
    text: str
    # `default_factory` must be a zero-argument callable. Passing the config
    # value itself (a float / a list) raises TypeError as soon as the default is
    # needed. The lambdas also defer the config lookup until that moment.
    threshold: float = Field(ge=0.0, le=1.0, default_factory=lambda: get_config().default_threshold)
    # Copy the configured list so examples never share one mutable list object.
    entity_types: list[str] = Field(default_factory=lambda: list(get_config().default_entities))
    flat_ner: bool = True
    multi_label: bool = False
    entities: list[Entity]


class BatchExample(BaseSettings):
    # One example for batch invocation: the request parameters plus one entity
    # list per input text, shown as the matching response.
    texts: list[str]
    # `default_factory` must be a zero-argument callable. Passing the config
    # value itself (a float / a list) raises TypeError as soon as the default is
    # needed. The lambdas also defer the config lookup until that moment.
    threshold: float = Field(ge=0.0, le=1.0, default_factory=lambda: get_config().default_threshold)
    # Copy the configured list so examples never share one mutable list object.
    entity_types: list[str] = Field(default_factory=lambda: list(get_config().default_entities))
    flat_ner: bool = True
    multi_label: bool = False
    entities: list[list[Entity]]


@lru_cache
def get_examples() -> Examples:
    """Return the `Examples` settings used by the API docs and the frontend.

    The result is memoized, so `Examples()` is only constructed on the first call.
    """
    loaded = Examples()
    return loaded
Loading