Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion function_schema/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
get_type_hints,
)

from .field import FieldInfo
from .types import FunctionSchema, Doc, DocMeta
from .utils import unwrap_doc

Expand All @@ -22,7 +23,8 @@
UnionType = Union # type: ignore


__all__ = ("get_function_schema", "guess_type", "Doc", "Annotated")
__all__ = ("get_function_schema", "guess_type",
"Doc", "Annotated", "FieldInfo")


def get_function_schema(
Expand Down Expand Up @@ -88,11 +90,19 @@ def get_function_schema(
if type_hint is not None:
param_args = get_args(type_hint)
is_annotated = get_origin(type_hint) is Annotated

# process Optional type for python <= 3.9
if get_origin(type_hint) is Union and type(None) in param_args:
type_hint = next(t for t in param_args if t is not type(None))
param_args = get_args(type_hint)
is_annotated = get_origin(type_hint) is Annotated

else:
param_args = []
is_annotated = False

enum_ = None
field_info = {}
default_value = inspect._empty

if is_annotated:
Expand Down Expand Up @@ -122,6 +132,14 @@ def get_function_schema(
# use typing.Literal as enum if no enum found
get_origin(T) is Literal and get_args(T) or None,
)

# find fieldinfo in param_args tuple
for arg in param_args:
if isinstance(arg, FieldInfo):
# XXX: latest field_info will override previous ones
field_info.update(arg.to_dict())
if isinstance(arg, dict): # XXX: allow dict to be passed as field_info ?
field_info.update(arg)
else:
T = param.annotation
description = f"The {name} parameter"
Expand All @@ -135,8 +153,13 @@ def get_function_schema(
schema["properties"][name] = {
"type": guess_type(T),
"description": description, # type: ignore
**field_info,
}

if "required" in field_info and field_info["required"]:
schema["required"].append(name)
del schema["properties"][name]["required"]

if enum_ is not None:
schema["properties"][name]["enum"] = [
t for t in enum_ if t is not None]
Expand Down
89 changes: 89 additions & 0 deletions function_schema/field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from dataclasses import dataclass, field
from typing import Optional, Literal, Union, List

Numeric = Union[int, float]
AlphaNumeric = Union[str, Numeric, None]


@dataclass
class FieldInfo:
"""
A class to represent the schema of a field with various attributes.
Not all attributes that exist in the JSON schema are included here (yet).
Attributes:
----------
type : Optional[Literal["object", "array", "string", "number", "integer", "boolean"]]
The type of the field.
description : Optional[str]
A brief description of the field.
enum : Optional[List[AlphaNumeric]]
A list of allowed values for the field.
required : Optional[bool]
Indicates if the field is required.
minimum : Optional[Numeric]
The minimum value for the field.
maximum : Optional[Numeric]
The maximum value for the field.
exclusive_minimum : Optional[Numeric]
The exclusive minimum value for the field.
exclusive_maximum : Optional[Numeric]
The exclusive maximum value for the field.
max_length : Optional[int]
The maximum length of the field.
min_length : Optional[int]
The minimum length of the field.
pattern : Optional[str]
A regex pattern that the field value must match.
"""

type: Optional[Literal["object", "array", "string",
"number", "integer", "boolean"]] = None
description: Optional[str] = None
enum: Optional[List[AlphaNumeric]] = None
required: Optional[bool] = None
minimum: Optional[Numeric] = None
maximum: Optional[Numeric] = None
exclusive_minimum: Optional[Numeric] = None
exclusive_maximum: Optional[Numeric] = None
max_length: Optional[int] = None
min_length: Optional[int] = None
pattern: Optional[str] = None

def __ge__(self, value: Numeric):
self.minimum = value
return self

def __gt__(self, value: Numeric):
self.exclusive_minimum = value
return self

def __le__(self, value: Numeric):
self.maximum = value
return self

def __lt__(self, value: Numeric):
self.exclusive_maximum = value
return self

def __rlshift__(self, value: Numeric):
self.minimum = value
return self

def to_dict(self):
result = {
"type": self.type,
"description": self.description,
"enum": self.enum,
"required": self.required,
"minimum": self.minimum,
"maximum": self.maximum,
"exclusiveMinimum": self.exclusive_minimum,
"exclusiveMaximum": self.exclusive_maximum,
"maxLength": self.max_length,
"minLength": self.min_length,
"pattern": self.pattern,
}
return {k: v for k, v in result.items() if v is not None}


F = FieldInfo()
122 changes: 122 additions & 0 deletions test/test_additional_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from typing import Annotated, Optional
from function_schema.core import get_function_schema, FieldInfo


def test_fieldinfo_type():
"""Test if FieldInfo type overrides the guessed type"""
def func(a: Annotated[int, FieldInfo(type="integer")]):
...

schema = get_function_schema(func)
assert schema["parameters"]["properties"]["a"]["type"] == "integer"


def test_fieldinfo_description():
"""Test if FieldInfo description is added to the schema"""
def func(a: Annotated[int, FieldInfo(description="An integer parameter")]):
...

schema = get_function_schema(func)
assert schema["parameters"]["properties"]["a"]["description"] == "An integer parameter"


def test_fieldinfo_enum():
"""Test if FieldInfo enum is added to the schema"""
def func(a: Annotated[str, FieldInfo(enum=["red", "green", "blue"])]):
...

schema = get_function_schema(func)
assert schema["parameters"]["properties"]["a"]["enum"] == [
"red", "green", "blue"]


def test_fieldinfo_required():
"""Test if FieldInfo required is added to the schema"""
def func(a: Annotated[Optional[int], FieldInfo(required=True)]):
...

schema = get_function_schema(func)
assert "required" in schema["parameters"]
assert "a" in schema["parameters"]["required"]


def test_fieldinfo_min_max():
def func(a: Annotated[int, FieldInfo(minimum=1, maximum=100)]):
...

schema = get_function_schema(func)
assert schema["parameters"]["properties"]["a"]["minimum"] == 1
assert schema["parameters"]["properties"]["a"]["maximum"] == 100


def test_fieldinfo_exclusive_min_max():
def func(a: Annotated[int, FieldInfo(exclusive_minimum=0, exclusive_maximum=10)]):
...

schema = get_function_schema(func)
assert schema["parameters"]["properties"]["a"]["exclusiveMinimum"] == 0
assert schema["parameters"]["properties"]["a"]["exclusiveMaximum"] == 10


def test_fieldinfo_length():
def func(a: Annotated[str, FieldInfo(min_length=5, max_length=10)]):
...

schema = get_function_schema(func)
assert schema["parameters"]["properties"]["a"]["minLength"] == 5
assert schema["parameters"]["properties"]["a"]["maxLength"] == 10


def test_fieldinfo_pattern():
def func(a: Annotated[str, FieldInfo(pattern="^[a-zA-Z0-9]+$")]):
...

schema = get_function_schema(func)
assert schema["parameters"]["properties"]["a"]["pattern"] == "^[a-zA-Z0-9]+$"


def test_more_than_one_fieldinfo():
def func(a: Annotated[int, FieldInfo(minimum=1, maximum=100), FieldInfo(pattern="^[0-9]+$")]):
...

schema = get_function_schema(func)
assert schema["parameters"]["properties"]["a"]["minimum"] == 1
assert schema["parameters"]["properties"]["a"]["maximum"] == 100
assert schema["parameters"]["properties"]["a"]["pattern"] == "^[0-9]+$"


def test_fieldinfo_expression():
def func(a: Annotated[int, FieldInfo() >= 1]):
...

schema = get_function_schema(func)
assert schema["parameters"]["properties"]["a"]["minimum"] == 1


def test_fieldinfo_expression_chain():
F = FieldInfo()

def func(a: Annotated[int, 1 <= F < 2]):
...

schema = get_function_schema(func)
assert schema["parameters"]["properties"]["a"]["minimum"] == 1
assert schema["parameters"]["properties"]["a"]["exclusiveMaximum"] == 2


def test_fieldinfo_expression_combo():
def func(a: Annotated[int, 1 <= FieldInfo(description="The number") <= 42]):
...

schema = get_function_schema(func)
assert schema["parameters"]["properties"]["a"]["minimum"] == 1
assert schema["parameters"]["properties"]["a"]["maximum"] == 42
assert schema["parameters"]["properties"]["a"]["description"] == "The number"


def test_dict_as_additional_field():
def func(a: Annotated[int, {"minimum": 10}]):
...

schema = get_function_schema(func)
assert schema["parameters"]["properties"]["a"]["minimum"] == 10
Loading