diff --git a/function_schema/core.py b/function_schema/core.py index fdd0775..ed65881 100644 --- a/function_schema/core.py +++ b/function_schema/core.py @@ -12,6 +12,7 @@ get_type_hints, ) +from .field import FieldInfo from .types import FunctionSchema, Doc, DocMeta from .utils import unwrap_doc @@ -22,7 +23,8 @@ UnionType = Union # type: ignore -__all__ = ("get_function_schema", "guess_type", "Doc", "Annotated") +__all__ = ("get_function_schema", "guess_type", + "Doc", "Annotated", "FieldInfo") def get_function_schema( @@ -88,11 +90,19 @@ def get_function_schema( if type_hint is not None: param_args = get_args(type_hint) is_annotated = get_origin(type_hint) is Annotated + + # process Optional type for python <= 3.9 + if get_origin(type_hint) is Union and type(None) in param_args: + type_hint = next(t for t in param_args if t is not type(None)) + param_args = get_args(type_hint) + is_annotated = get_origin(type_hint) is Annotated + else: param_args = [] is_annotated = False enum_ = None + field_info = {} default_value = inspect._empty if is_annotated: @@ -122,6 +132,14 @@ def get_function_schema( # use typing.Literal as enum if no enum found get_origin(T) is Literal and get_args(T) or None, ) + + # find fieldinfo in param_args tuple + for arg in param_args: + if isinstance(arg, FieldInfo): + # XXX: latest field_info will override previous ones + field_info.update(arg.to_dict()) + if isinstance(arg, dict): # XXX: allow dict to be passed as field_info ? + field_info.update(arg) else: T = param.annotation description = f"The {name} parameter" @@ -135,8 +153,13 @@ def get_function_schema( schema["properties"][name] = { "type": guess_type(T), "description": description, # type: ignore + **field_info, } + if "required" in field_info and field_info["required"]: + schema["required"].append(name) + del schema["properties"][name]["required"] + if enum_ is not None: schema["properties"][name]["enum"] = [ t for t in enum_ if t is not None] diff --git a/function_schema/field.py b/function_schema/field.py new file mode 100644 index 0000000..2d48ae7 --- /dev/null +++ b/function_schema/field.py @@ -0,0 +1,89 @@ +from dataclasses import dataclass, field +from typing import Optional, Literal, Union, List + +Numeric = Union[int, float] +AlphaNumeric = Union[str, Numeric, None] + + +@dataclass +class FieldInfo: + """ + A class to represent the schema of a field with various attributes. + Not all attributes that exist in the JSON schema are included here (yet). + Attributes: + ---------- + type : Optional[Literal["object", "array", "string", "number", "integer", "boolean"]] + The type of the field. + description : Optional[str] + A brief description of the field. + enum : Optional[List[AlphaNumeric]] + A list of allowed values for the field. + required : Optional[bool] + Indicates if the field is required. + minimum : Optional[Numeric] + The minimum value for the field. + maximum : Optional[Numeric] + The maximum value for the field. + exclusive_minimum : Optional[Numeric] + The exclusive minimum value for the field. + exclusive_maximum : Optional[Numeric] + The exclusive maximum value for the field. + max_length : Optional[int] + The maximum length of the field. + min_length : Optional[int] + The minimum length of the field. + pattern : Optional[str] + A regex pattern that the field value must match. + """ + + type: Optional[Literal["object", "array", "string", + "number", "integer", "boolean"]] = None + description: Optional[str] = None + enum: Optional[List[AlphaNumeric]] = None + required: Optional[bool] = None + minimum: Optional[Numeric] = None + maximum: Optional[Numeric] = None + exclusive_minimum: Optional[Numeric] = None + exclusive_maximum: Optional[Numeric] = None + max_length: Optional[int] = None + min_length: Optional[int] = None + pattern: Optional[str] = None + + def __ge__(self, value: Numeric): + self.minimum = value + return self + + def __gt__(self, value: Numeric): + self.exclusive_minimum = value + return self + + def __le__(self, value: Numeric): + self.maximum = value + return self + + def __lt__(self, value: Numeric): + self.exclusive_maximum = value + return self + + def __rlshift__(self, value: Numeric): + self.minimum = value + return self + + def to_dict(self): + result = { + "type": self.type, + "description": self.description, + "enum": self.enum, + "required": self.required, + "minimum": self.minimum, + "maximum": self.maximum, + "exclusiveMinimum": self.exclusive_minimum, + "exclusiveMaximum": self.exclusive_maximum, + "maxLength": self.max_length, + "minLength": self.min_length, + "pattern": self.pattern, + } + return {k: v for k, v in result.items() if v is not None} + + +F = FieldInfo() diff --git a/test/test_additional_field.py b/test/test_additional_field.py new file mode 100644 index 0000000..8e2ca56 --- /dev/null +++ b/test/test_additional_field.py @@ -0,0 +1,122 @@ +from typing import Annotated, Optional +from function_schema.core import get_function_schema, FieldInfo + + +def test_fieldinfo_type(): + """Test if FieldInfo type overrides the guessed type""" + def func(a: Annotated[int, FieldInfo(type="integer")]): + ... + + schema = get_function_schema(func) + assert schema["parameters"]["properties"]["a"]["type"] == "integer" + + +def test_fieldinfo_description(): + """Test if FieldInfo description is added to the schema""" + def func(a: Annotated[int, FieldInfo(description="An integer parameter")]): + ... + + schema = get_function_schema(func) + assert schema["parameters"]["properties"]["a"]["description"] == "An integer parameter" + + +def test_fieldinfo_enum(): + """Test if FieldInfo enum is added to the schema""" + def func(a: Annotated[str, FieldInfo(enum=["red", "green", "blue"])]): + ... + + schema = get_function_schema(func) + assert schema["parameters"]["properties"]["a"]["enum"] == [ + "red", "green", "blue"] + + +def test_fieldinfo_required(): + """Test if FieldInfo required is added to the schema""" + def func(a: Annotated[Optional[int], FieldInfo(required=True)]): + ... + + schema = get_function_schema(func) + assert "required" in schema["parameters"] + assert "a" in schema["parameters"]["required"] + + +def test_fieldinfo_min_max(): + def func(a: Annotated[int, FieldInfo(minimum=1, maximum=100)]): + ... + + schema = get_function_schema(func) + assert schema["parameters"]["properties"]["a"]["minimum"] == 1 + assert schema["parameters"]["properties"]["a"]["maximum"] == 100 + + +def test_fieldinfo_exclusive_min_max(): + def func(a: Annotated[int, FieldInfo(exclusive_minimum=0, exclusive_maximum=10)]): + ... + + schema = get_function_schema(func) + assert schema["parameters"]["properties"]["a"]["exclusiveMinimum"] == 0 + assert schema["parameters"]["properties"]["a"]["exclusiveMaximum"] == 10 + + +def test_fieldinfo_length(): + def func(a: Annotated[str, FieldInfo(min_length=5, max_length=10)]): + ... + + schema = get_function_schema(func) + assert schema["parameters"]["properties"]["a"]["minLength"] == 5 + assert schema["parameters"]["properties"]["a"]["maxLength"] == 10 + + +def test_fieldinfo_pattern(): + def func(a: Annotated[str, FieldInfo(pattern="^[a-zA-Z0-9]+$")]): + ... + + schema = get_function_schema(func) + assert schema["parameters"]["properties"]["a"]["pattern"] == "^[a-zA-Z0-9]+$" + + +def test_more_than_one_fieldinfo(): + def func(a: Annotated[int, FieldInfo(minimum=1, maximum=100), FieldInfo(pattern="^[0-9]+$")]): + ... + + schema = get_function_schema(func) + assert schema["parameters"]["properties"]["a"]["minimum"] == 1 + assert schema["parameters"]["properties"]["a"]["maximum"] == 100 + assert schema["parameters"]["properties"]["a"]["pattern"] == "^[0-9]+$" + + +def test_fieldinfo_expression(): + def func(a: Annotated[int, FieldInfo() >= 1]): + ... + + schema = get_function_schema(func) + assert schema["parameters"]["properties"]["a"]["minimum"] == 1 + + +def test_fieldinfo_expression_chain(): + F = FieldInfo() + + def func(a: Annotated[int, 1 <= F < 2]): + ... + + schema = get_function_schema(func) + assert schema["parameters"]["properties"]["a"]["minimum"] == 1 + assert schema["parameters"]["properties"]["a"]["exclusiveMaximum"] == 2 + + +def test_fieldinfo_expression_combo(): + def func(a: Annotated[int, 1 <= FieldInfo(description="The number") <= 42]): + ... + + schema = get_function_schema(func) + assert schema["parameters"]["properties"]["a"]["minimum"] == 1 + assert schema["parameters"]["properties"]["a"]["maximum"] == 42 + assert schema["parameters"]["properties"]["a"]["description"] == "The number" + + +def test_dict_as_additional_field(): + def func(a: Annotated[int, {"minimum": 10}]): + ... + + schema = get_function_schema(func) + assert schema["parameters"]["properties"]["a"]["minimum"] == 10