Skip to content

Commit

Permalink
Add basic typing support
Browse files Browse the repository at this point in the history
Only `Factory.build()` and `Factory.create()` are properly typed,
provided the class is declared as `class UserFactory(Factory[User]):`.

Relies on mypy for tests.

Reviewed-By: Raphaël Barrois <[email protected]>
  • Loading branch information
last-partizan authored and rbarrois committed Jan 18, 2024
1 parent 69809cf commit 95dfa90
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 13 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ testall:
# DOC: Run tests for the currently installed version
# Remove cgi warning when dropping support for Django<=4.1.
test:
mypy --ignore-missing-imports tests/test_typing.py
python \
-b \
-X dev \
Expand Down
6 changes: 4 additions & 2 deletions factory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright: See the LICENSE file.

import sys

from .base import (
BaseDictFactory,
BaseListFactory,
Expand Down Expand Up @@ -70,10 +72,10 @@
pass

__author__ = 'Raphaël Barrois <[email protected]>'
try:
if sys.version_info >= (3, 8):
# Python 3.8+
import importlib.metadata as importlib_metadata
except ImportError:
else:
import importlib_metadata

__version__ = importlib_metadata.version("factory_boy")
21 changes: 14 additions & 7 deletions factory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
import collections
import logging
import warnings
from typing import Generic, List, TypeVar

from . import builder, declarations, enums, errors, utils

logger = logging.getLogger('factory.generate')

T = TypeVar('T')

# Factory metaclasses


Expand Down Expand Up @@ -405,7 +408,7 @@ def reset(self, next_value=0):
self.seq = next_value


class BaseFactory:
class BaseFactory(Generic[T]):
"""Factory base support for sequences, attributes and stubs."""

# Backwards compatibility
Expand Down Expand Up @@ -506,12 +509,12 @@ def _create(cls, model_class, *args, **kwargs):
return model_class(*args, **kwargs)

@classmethod
def build(cls, **kwargs):
def build(cls, **kwargs) -> T:
"""Build an instance of the associated class, with overridden attrs."""
return cls._generate(enums.BUILD_STRATEGY, kwargs)

@classmethod
def build_batch(cls, size, **kwargs):
def build_batch(cls, size: int, **kwargs) -> List[T]:
"""Build a batch of instances of the given class, with overridden attrs.
Args:
Expand All @@ -523,12 +526,12 @@ def build_batch(cls, size, **kwargs):
return [cls.build(**kwargs) for _ in range(size)]

@classmethod
def create(cls, **kwargs):
def create(cls, **kwargs) -> T:
"""Create an instance of the associated class, with overridden attrs."""
return cls._generate(enums.CREATE_STRATEGY, kwargs)

@classmethod
def create_batch(cls, size, **kwargs):
def create_batch(cls, size: int, **kwargs) -> List[T]:
"""Create a batch of instances of the given class, with overridden attrs.
Args:
Expand Down Expand Up @@ -627,18 +630,22 @@ def simple_generate_batch(cls, create, size, **kwargs):
return cls.generate_batch(strategy, size, **kwargs)


class Factory(BaseFactory, metaclass=FactoryMetaClass):
class Factory(BaseFactory[T], metaclass=FactoryMetaClass):
"""Factory base with build and create support.
This class has the ability to support multiple ORMs by using custom creation
functions.
"""

# Backwards compatibility
AssociatedClassError: type[Exception]

class Meta(BaseMeta):
pass


# Backwards compatibility
# Add the association after metaclass execution.
# Otherwise, AssociatedClassError would be detected as a declaration.
Factory.AssociatedClassError = errors.AssociatedClassError


Expand Down
7 changes: 4 additions & 3 deletions factory/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import os
import warnings
from typing import TypeVar

from django.contrib.auth.hashers import make_password
from django.core import files as django_files
Expand All @@ -20,9 +21,9 @@


DEFAULT_DB_ALIAS = 'default' # Same as django.db.DEFAULT_DB_ALIAS
T = TypeVar("T")


_LAZY_LOADS = {}
_LAZY_LOADS: dict[str, object] = {}


def get_model(app, model):
Expand Down Expand Up @@ -72,7 +73,7 @@ def get_model_class(self):
return self.model


class DjangoModelFactory(base.Factory):
class DjangoModelFactory(base.Factory[T]):
"""Factory for Django models.
This makes sure that the 'sequence' field of created objects is a new id.
Expand Down
2 changes: 1 addition & 1 deletion factory/faker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def evaluate(self, instance, step, extra):
subfaker = self._get_faker(locale)
return subfaker.format(self.provider, **extra)

_FAKER_REGISTRY = {}
_FAKER_REGISTRY: dict[str, faker.Faker] = {}
_DEFAULT_LOCALE = faker.config.DEFAULT_LOCALE

@classmethod
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ dev =
Django
flake8
isort
mypy
Pillow
SQLAlchemy
sqlalchemy_utils
Expand Down
28 changes: 28 additions & 0 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright: See the LICENSE file.

import dataclasses
import unittest
import typing

import factory


@dataclasses.dataclass
class User:
name: str
email: str
id: int


class TypingTests(unittest.TestCase):
def test_simple_factory(self) -> None:
class UserFactory(factory.Factory[User]):
name = "John Doe"
email = "[email protected]"
id = 42
class Meta:
model = User

result: User
result = UserFactory.build()
result = UserFactory.create()

0 comments on commit 95dfa90

Please sign in to comment.