import datetime
import functools
import json
import warnings
from decimal import Decimal
from enum import Enum, IntEnum
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Union
from uuid import UUID, uuid4
from pypika import functions
from pypika.enums import SqlTypes
from pypika.terms import Term
from tortoise import timezone
from tortoise.exceptions import ConfigurationError, FieldError
from tortoise.fields.base import Field
from tortoise.timezone import get_default_timezone, get_timezone, get_use_tz, localtime
from tortoise.validators import MaxLengthValidator
try:
from ciso8601 import parse_datetime
except ImportError: # pragma: nocoverage
from iso8601 import parse_date
parse_datetime = functools.partial(parse_date, default_timezone=None)
if TYPE_CHECKING: # pragma: nocoverage
from tortoise.models import Model
__all__ = (
"BigIntField",
"BinaryField",
"BooleanField",
"CharEnumField",
"CharField",
"DateField",
"DatetimeField",
"DecimalField",
"FloatField",
"IntEnumField",
"IntField",
"JSONField",
"SmallIntField",
"TextField",
"TimeDeltaField",
"UUIDField",
)
# Doing this we can replace json dumps/loads with different implementations
JsonDumpsFunc = Callable[[Any], str]
JsonLoadsFunc = Callable[[Union[str, bytes]], Any]
JSON_DUMPS: JsonDumpsFunc = functools.partial(json.dumps, separators=(",", ":"))
JSON_LOADS: JsonLoadsFunc = json.loads
try:
# Use orjson as an optional accelerator
import orjson
JSON_DUMPS = lambda x: orjson.dumps(x).decode() # noqa: E731
JSON_LOADS = orjson.loads
except ImportError: # pragma: nocoverage
pass
[docs]class IntField(Field[int], int):
"""
Integer field. (32-bit signed)
``primary_key`` (bool):
True if field is Primary Key.
"""
SQL_TYPE = "INT"
allows_generated = True
def __init__(self, primary_key: Optional[bool] = None, **kwargs: Any) -> None:
if primary_key or kwargs.get("pk"):
kwargs["generated"] = bool(kwargs.get("generated", True))
super().__init__(primary_key=primary_key, **kwargs)
@property
def constraints(self) -> dict:
return {
"ge": -2147483648,
"le": 2147483647,
}
class _db_postgres:
GENERATED_SQL = "SERIAL NOT NULL PRIMARY KEY"
class _db_sqlite:
GENERATED_SQL = "INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL"
class _db_mysql:
GENERATED_SQL = "INT NOT NULL PRIMARY KEY AUTO_INCREMENT"
class _db_mssql:
GENERATED_SQL = "INT IDENTITY(1,1) NOT NULL PRIMARY KEY"
class _db_oracle:
GENERATED_SQL = "INT GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY NOT NULL"
[docs]class BigIntField(IntField):
"""
Big integer field. (64-bit signed)
``primary_key`` (bool):
True if field is Primary Key.
"""
SQL_TYPE = "BIGINT"
@property
def constraints(self) -> dict:
return {
"ge": -9223372036854775808,
"le": 9223372036854775807,
}
class _db_postgres:
GENERATED_SQL = "BIGSERIAL NOT NULL PRIMARY KEY"
class _db_mysql:
GENERATED_SQL = "BIGINT NOT NULL PRIMARY KEY AUTO_INCREMENT"
class _db_mssql:
GENERATED_SQL = "BIGINT IDENTITY(1,1) NOT NULL PRIMARY KEY"
class _db_oracle:
SQL_TYPE = "INT"
GENERATED_SQL = "INT GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY NOT NULL"
[docs]class SmallIntField(IntField):
"""
Small integer field. (16-bit signed)
``primary_key`` (bool):
True if field is Primary Key.
"""
SQL_TYPE = "SMALLINT"
@property
def constraints(self) -> dict:
return {
"ge": -32768,
"le": 32767,
}
class _db_postgres:
GENERATED_SQL = "SMALLSERIAL NOT NULL PRIMARY KEY"
class _db_mysql:
GENERATED_SQL = "SMALLINT NOT NULL PRIMARY KEY AUTO_INCREMENT"
class _db_mssql:
GENERATED_SQL = "SMALLINT IDENTITY(1,1) NOT NULL PRIMARY KEY"
class _db_oracle:
GENERATED_SQL = "SMALLINT GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY NOT NULL"
[docs]class CharField(Field[str]):
"""
Character field.
You must provide the following:
``max_length`` (int):
Maximum length of the field in characters.
"""
field_type = str
def __init__(self, max_length: int, **kwargs: Any) -> None:
if int(max_length) < 1:
raise ConfigurationError("'max_length' must be >= 1")
self.max_length = int(max_length)
super().__init__(**kwargs)
self.validators.append(MaxLengthValidator(self.max_length))
@property
def constraints(self) -> dict:
return {
"max_length": self.max_length,
}
@property
def SQL_TYPE(self) -> str: # type: ignore
return f"VARCHAR({self.max_length})"
class _db_oracle:
def __init__(self, field: "CharField") -> None:
self.field = field
@property
def SQL_TYPE(self) -> str:
return f"NVARCHAR2({self.field.max_length})"
[docs]class TextField(Field[str], str): # type: ignore
"""
Large Text field.
"""
indexable = False
SQL_TYPE = "TEXT"
def __init__(
self,
primary_key: Optional[bool] = None,
unique: bool = False,
db_index: bool = False,
**kwargs: Any,
) -> None:
if primary_key or kwargs.get("pk"):
warnings.warn(
"TextField as a PrimaryKey is Deprecated, use CharField instead",
DeprecationWarning,
stacklevel=2,
)
if unique:
raise ConfigurationError(
"TextField doesn't support unique indexes, consider CharField or another strategy"
)
if db_index or kwargs.get("index"):
raise ConfigurationError("TextField can't be indexed, consider CharField")
super().__init__(primary_key=primary_key, **kwargs)
class _db_mysql:
SQL_TYPE = "LONGTEXT"
class _db_mssql:
SQL_TYPE = "NVARCHAR(MAX)"
class _db_oracle:
SQL_TYPE = "NCLOB"
[docs]class BooleanField(Field[bool]):
"""
Boolean field.
"""
# Bool is not subclassable, so we specify type here
field_type = bool
SQL_TYPE = "BOOL"
class _db_sqlite:
SQL_TYPE = "INT"
class _db_mssql:
SQL_TYPE = "BIT"
class _db_oracle:
SQL_TYPE = "NUMBER(1)"
[docs]class DecimalField(Field[Decimal], Decimal):
"""
Accurate decimal field.
You must provide the following:
``max_digits`` (int):
Max digits of significance of the decimal field.
``decimal_places`` (int):
How many of those significant digits is after the decimal point.
"""
skip_to_python_if_native = True
def __init__(self, max_digits: int, decimal_places: int, **kwargs: Any) -> None:
if int(max_digits) < 1:
raise ConfigurationError("'max_digits' must be >= 1")
if int(decimal_places) < 0:
raise ConfigurationError("'decimal_places' must be >= 0")
super().__init__(**kwargs)
self.max_digits = max_digits
self.decimal_places = decimal_places
self.quant = Decimal("1" if decimal_places == 0 else f"1.{('0' * decimal_places)}")
def to_python_value(self, value: Any) -> Optional[Decimal]:
if value is not None:
value = Decimal(value).quantize(self.quant).normalize()
self.validate(value)
return value
@property
def SQL_TYPE(self) -> str: # type: ignore
return f"DECIMAL({self.max_digits},{self.decimal_places})"
class _db_sqlite:
SQL_TYPE = "VARCHAR(40)"
def function_cast(self, term: Term) -> Term:
return functions.Cast(term, SqlTypes.NUMERIC)
# In case of queryset with filter `__year`/`__month`/`__day` ..., value can be int, float or str. Example:
# `await MyModel.filter(created_at__year=2024)`
# `await MyModel.filter(created_at__year=2024.0)`
# `await MyModel.filter(created_at__year='2024')`
DatetimeFieldQueryValueType = TypeVar(
"DatetimeFieldQueryValueType", datetime.datetime, int, float, str
)
[docs]class DatetimeField(Field[datetime.datetime], datetime.datetime):
"""
Datetime field.
``auto_now`` and ``auto_now_add`` is exclusive.
You can opt to set neither or only ONE of them.
``auto_now`` (bool):
Always set to ``datetime.utcnow()`` on save.
``auto_now_add`` (bool):
Set to ``datetime.utcnow()`` on first save only.
"""
SQL_TYPE = "TIMESTAMP"
class _db_mysql:
SQL_TYPE = "DATETIME(6)"
class _db_postgres:
SQL_TYPE = "TIMESTAMPTZ"
class _db_mssql:
SQL_TYPE = "DATETIME2"
class _db_oracle:
SQL_TYPE = "TIMESTAMP WITH TIME ZONE"
def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: Any) -> None:
if auto_now_add and auto_now:
raise ConfigurationError("You can choose only 'auto_now' or 'auto_now_add'")
super().__init__(**kwargs)
self.auto_now = auto_now
self.auto_now_add = auto_now | auto_now_add
def to_python_value(self, value: Any) -> Optional[datetime.datetime]:
if value is not None:
if isinstance(value, datetime.datetime):
value = value
elif isinstance(value, int):
value = datetime.datetime.fromtimestamp(value)
else:
value = parse_datetime(value)
if timezone.is_naive(value):
value = timezone.make_aware(value, get_timezone())
else:
value = localtime(value)
self.validate(value)
return value
def to_db_value(
self, value: Optional[DatetimeFieldQueryValueType], instance: "Union[Type[Model], Model]"
) -> Optional[DatetimeFieldQueryValueType]:
# Only do this if it is a Model instance, not class. Test for guaranteed instance var
if hasattr(instance, "_saved_in_db") and (
self.auto_now
or (self.auto_now_add and getattr(instance, self.model_field_name) is None)
):
now = timezone.now()
setattr(instance, self.model_field_name, now)
return now # type:ignore[return-value]
if value is not None:
if isinstance(value, datetime.datetime) and get_use_tz():
if timezone.is_naive(value):
warnings.warn(
"DateTimeField %s received a naive datetime (%s)"
" while time zone support is active." % (self.model_field_name, value),
RuntimeWarning,
)
value = timezone.make_aware(value, "UTC")
self.validate(value)
return value
@property
def constraints(self) -> dict:
data = {}
if self.auto_now_add:
data["readOnly"] = True
return data
[docs] def describe(self, serializable: bool) -> dict:
desc = super().describe(serializable)
desc["auto_now_add"] = self.auto_now_add
desc["auto_now"] = self.auto_now
return desc
[docs]class DateField(Field[datetime.date], datetime.date):
"""
Date field.
"""
skip_to_python_if_native = True
SQL_TYPE = "DATE"
def to_python_value(self, value: Any) -> Optional[datetime.date]:
if value is not None and not isinstance(value, datetime.date):
value = parse_datetime(value).date()
self.validate(value)
return value
def to_db_value(
self, value: Optional[Union[datetime.date, str]], instance: "Union[Type[Model], Model]"
) -> Optional[datetime.date]:
if value is not None and not isinstance(value, datetime.date):
value = parse_datetime(value).date()
self.validate(value)
return value
class TimeField(Field[datetime.time], datetime.time):
"""
Time field.
"""
skip_to_python_if_native = True
SQL_TYPE = "TIME"
class _db_oracle:
SQL_TYPE = "NVARCHAR2(8)"
def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: Any) -> None:
if auto_now_add and auto_now:
raise ConfigurationError("You can choose only 'auto_now' or 'auto_now_add'")
super().__init__(**kwargs)
self.auto_now = auto_now
self.auto_now_add = auto_now | auto_now_add
def to_python_value(self, value: Any) -> Optional[Union[datetime.time, datetime.timedelta]]:
if value is not None:
if isinstance(value, str):
value = datetime.time.fromisoformat(value)
if isinstance(value, datetime.timedelta):
return value
if timezone.is_naive(value):
value = value.replace(tzinfo=get_default_timezone())
self.validate(value)
return value
def to_db_value(
self,
value: Optional[Union[datetime.time, datetime.timedelta]],
instance: "Union[Type[Model], Model]",
) -> Optional[Union[datetime.time, datetime.timedelta]]:
# Only do this if it is a Model instance, not class. Test for guaranteed instance var
if hasattr(instance, "_saved_in_db") and (
self.auto_now
or (self.auto_now_add and getattr(instance, self.model_field_name) is None)
):
now = timezone.now().time()
setattr(instance, self.model_field_name, now)
return now
if value is not None:
if isinstance(value, datetime.timedelta):
return value
if get_use_tz():
if timezone.is_naive(value):
warnings.warn(
"TimeField %s received a naive time (%s)"
" while time zone support is active." % (self.model_field_name, value),
RuntimeWarning,
)
value = value.replace(tzinfo=get_default_timezone())
self.validate(value)
return value
class _db_mysql:
SQL_TYPE = "TIME(6)"
class _db_postgres:
SQL_TYPE = "TIMETZ"
[docs]class TimeDeltaField(Field[datetime.timedelta]):
"""
A field for storing time differences.
"""
field_type = datetime.timedelta
SQL_TYPE = "BIGINT"
class _db_oracle:
SQL_TYPE = "NUMBER(19)"
def to_python_value(self, value: Any) -> Optional[datetime.timedelta]:
self.validate(value)
if value is None or isinstance(value, datetime.timedelta):
return value
return datetime.timedelta(microseconds=value)
def to_db_value(
self, value: Optional[datetime.timedelta], instance: "Union[Type[Model], Model]"
) -> Optional[int]:
self.validate(value)
if value is None:
return None
return (value.days * 86400000000) + (value.seconds * 1000000) + value.microseconds
[docs]class FloatField(Field[float], float):
"""
Float (double) field.
"""
SQL_TYPE = "DOUBLE PRECISION"
class _db_sqlite:
SQL_TYPE = "REAL"
class _db_mysql:
SQL_TYPE = "DOUBLE"
[docs]class JSONField(Field[Union[dict, list]], dict, list): # type: ignore
"""
JSON field.
This field can store dictionaries or lists of any JSON-compliant structure.
You can specify your own custom JSON encoder/decoder, leaving at the default should work well.
If you have ``orjson`` installed, we default to using that,
else the default ``json`` module will be used.
``encoder``:
The custom JSON encoder.
``decoder``:
The custom JSON decoder.
"""
SQL_TYPE = "JSON"
indexable = False
class _db_postgres:
SQL_TYPE = "JSONB"
class _db_mssql:
SQL_TYPE = "NVARCHAR(MAX)"
class _db_oracle:
SQL_TYPE = "NCLOB"
def __init__(
self,
encoder: JsonDumpsFunc = JSON_DUMPS,
decoder: JsonLoadsFunc = JSON_LOADS,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
def to_db_value(
self, value: Optional[Union[dict, list, str, bytes]], instance: "Union[Type[Model], Model]"
) -> Optional[str]:
self.validate(value)
if value is not None:
if isinstance(value, (str, bytes)):
try:
self.decoder(value)
except Exception:
raise FieldError(f"Value {value!r} is invalid json value.")
if isinstance(value, bytes):
value = value.decode()
else:
value = self.encoder(value)
return value
def to_python_value(
self, value: Optional[Union[str, bytes, dict, list]]
) -> Optional[Union[dict, list]]:
if isinstance(value, (str, bytes)):
try:
return self.decoder(value)
except Exception:
raise FieldError(
f"Value {value if isinstance(value, str) else value.decode()} is invalid json value."
)
self.validate(value)
return value
[docs]class UUIDField(Field[UUID], UUID):
"""
UUID Field
This field can store uuid value.
If used as a primary key, it will auto-generate a UUID4 by default.
"""
SQL_TYPE = "CHAR(36)"
class _db_postgres:
SQL_TYPE = "UUID"
def __init__(self, **kwargs: Any) -> None:
if (kwargs.get("primary_key") or kwargs.get("pk", False)) and "default" not in kwargs:
kwargs["default"] = uuid4
super().__init__(**kwargs)
def to_db_value(self, value: Any, instance: "Union[Type[Model], Model]") -> Optional[str]:
return value and str(value)
def to_python_value(self, value: Any) -> Optional[UUID]:
if value is None or isinstance(value, UUID):
return value
return UUID(value)
[docs]class BinaryField(Field[bytes], bytes): # type: ignore
"""
Binary field.
This is for storing ``bytes`` objects.
Note that filter or queryset-update operations are not supported.
"""
indexable = False
SQL_TYPE = "BLOB"
class _db_postgres:
SQL_TYPE = "BYTEA"
class _db_mysql:
SQL_TYPE = "LONGBLOB"
class _db_mssql:
SQL_TYPE = "VARBINARY(MAX)"
class IntEnumFieldInstance(SmallIntField):
def __init__(
self,
enum_type: Type[IntEnum],
description: Optional[str] = None,
generated: bool = False,
**kwargs: Any,
) -> None:
# Validate values
minimum = 1 if generated else -32768
for item in enum_type:
try:
value = int(item.value)
except ValueError:
raise ConfigurationError("IntEnumField only supports integer enums!")
if not minimum <= value < 32768:
raise ConfigurationError(
"The valid range of IntEnumField's values is {}..32767!".format(minimum)
)
# Automatic description for the field if not specified by the user
if description is None:
description = "\n".join([f"{e.name}: {int(e.value)}" for e in enum_type])[:2048]
super().__init__(description=description, **kwargs)
self.enum_type = enum_type
def to_python_value(self, value: Union[int, None]) -> Union[IntEnum, None]:
value = self.enum_type(value) if value is not None else None
self.validate(value)
return value
def to_db_value(
self, value: Union[IntEnum, None, int], instance: "Union[Type[Model], Model]"
) -> Union[int, None]:
if isinstance(value, IntEnum):
value = int(value.value)
if isinstance(value, int):
value = int(self.enum_type(value))
self.validate(value)
return value
IntEnumType = TypeVar("IntEnumType", bound=IntEnum)
[docs]def IntEnumField(
enum_type: Type[IntEnumType],
description: Optional[str] = None,
**kwargs: Any,
) -> IntEnumType:
"""
Enum Field
A field representing an integer enumeration.
The description of the field is set automatically if not specified to a multiline list of
"name: value" pairs.
**Note**: Valid int value of ``enum_type`` is acceptable.
``enum_type``:
The enum class
``description``:
The description of the field. It is set automatically if not specified to a multiline list
of "name: value" pairs.
"""
return IntEnumFieldInstance(enum_type, description, **kwargs) # type: ignore
class CharEnumFieldInstance(CharField):
def __init__(
self,
enum_type: Type[Enum],
description: Optional[str] = None,
max_length: int = 0,
**kwargs: Any,
) -> None:
# Automatic description for the field if not specified by the user
if description is None:
description = "\n".join([f"{e.name}: {str(e.value)}" for e in enum_type])[:2048]
# Automatic CharField max_length
if max_length == 0:
for item in enum_type:
item_len = len(str(item.value))
if item_len > max_length:
max_length = item_len
super().__init__(description=description, max_length=max_length, **kwargs)
self.enum_type = enum_type
def to_python_value(self, value: Union[str, None]) -> Union[Enum, None]:
self.validate(value)
return self.enum_type(value) if value is not None else None
def to_db_value(
self, value: Union[Enum, None, str], instance: "Union[Type[Model], Model]"
) -> Union[str, None]:
self.validate(value)
if isinstance(value, Enum):
return str(value.value)
if isinstance(value, str):
return str(self.enum_type(value).value)
return value
CharEnumType = TypeVar("CharEnumType", bound=Enum)
[docs]def CharEnumField(
enum_type: Type[CharEnumType],
description: Optional[str] = None,
max_length: int = 0,
**kwargs: Any,
) -> CharEnumType:
"""
Char Enum Field
A field representing a character enumeration.
**Warning**: If ``max_length`` is not specified or equals to zero, the size of represented
char fields is automatically detected. So if later you update the enum, you need to update your
table schema as well.
**Note**: Valid str value of ``enum_type`` is acceptable.
``enum_type``:
The enum class
``description``:
The description of the field. It is set automatically if not specified to a multiline list
of "name: value" pairs.
``max_length``:
The length of the created CharField. If it is zero it is automatically detected from
enum_type.
"""
return CharEnumFieldInstance(enum_type, description, max_length, **kwargs) # type: ignore