diff --git a/tenant_schemas_celery/app.py b/tenant_schemas_celery/app.py index c96e557..f6bc93c 100644 --- a/tenant_schemas_celery/app.py +++ b/tenant_schemas_celery/app.py @@ -76,10 +76,11 @@ def restore_schema(task, **kwargs): class CeleryApp(Celery): registry_cls = 'tenant_schemas_celery.registry:TenantTaskRegistry' - + task_cls = 'tenant_schemas_celery.task:TenantTask' + def create_task_cls(self): return self.subclass_with_self( - "tenant_schemas_celery.task:TenantTask", + self.task_cls, abstract=True, name="TenantTask", attribute="_app", diff --git a/tenant_schemas_celery/app_test.py b/tenant_schemas_celery/app_test.py new file mode 100644 index 0000000..58526f8 --- /dev/null +++ b/tenant_schemas_celery/app_test.py @@ -0,0 +1,32 @@ +from tenant_schemas_celery.app import CeleryApp +from tenant_schemas_celery.task import TenantTask + + +class DummyTask(TenantTask): + ... + + +def test_celery_app_should_allow_overriding_task_cls_as_object() -> None: + class App(CeleryApp): + task_cls = DummyTask + + app = App(set_as_current=False) + + @app.task() + def some_task() -> None: + ... + + assert isinstance(some_task, DummyTask) + + +def test_celery_app_should_allow_overriding_task_cls_as_string() -> None: + class App(CeleryApp): + task_cls = f"{DummyTask.__module__}:{DummyTask.__name__}" + + app = App(set_as_current=False) + + @app.task() + def some_task() -> None: + ... + + assert isinstance(some_task, DummyTask)