diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 38dd501c4a..af90cfa823 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -21,7 +21,7 @@ from pydantic import VERSION as P_VERSION from pydantic import BaseModel from pydantic.fields import FieldInfo -from typing_extensions import Annotated, get_args, get_origin +from typing_extensions import Annotated, Literal, get_args, get_origin # Reassign variable to make it reexported for mypy PYDANTIC_VERSION = P_VERSION @@ -459,6 +459,8 @@ def is_field_noneable(field: "FieldInfo") -> bool: return field.allow_none # type: ignore[no-any-return, attr-defined] def get_sa_type_from_field(field: Any) -> Any: + if get_origin(field.type_) is Literal: + return Literal if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: return field.type_ raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 38c85915aa..404d1efd0d 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -655,6 +655,9 @@ def get_sqlalchemy_type(field: Any) -> Any: type_ = get_sa_type_from_field(field) metadata = get_field_metadata(field) + # Checks for `Literal` type annotation + if type_ is Literal: + return AutoString # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI if issubclass(type_, Enum): return sa_Enum(type_) diff --git a/tests/test_main.py b/tests/test_main.py index 60d5c40ebb..98b9abcd67 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,6 +4,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import RelationshipProperty from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select +from typing_extensions import Literal def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel): @@ -125,3 +126,25 @@ class Hero(SQLModel, table=True): # The next statement should not raise an AttributeError assert hero_rusty_man.team assert hero_rusty_man.team.name == "Preventers" + + +def test_literal_typehints_are_treated_as_strings(clear_sqlmodel): + """Test https://github.com/fastapi/sqlmodel/issues/57""" + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(unique=True) + weakness: Literal["Kryptonite", "Dehydration", "Munchies"] + + superguy = Hero(name="Superguy", weakness="Kryptonite") + + engine = create_engine("sqlite://", echo=True) + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(superguy) + session.commit() + session.refresh(superguy) + assert superguy.weakness == "Kryptonite" + assert isinstance(superguy.weakness, str)