diff --git a/dj_database_url.py b/dj_database_url.py index a7f5da7..ba78d8b 100644 --- a/dj_database_url.py +++ b/dj_database_url.py @@ -1,17 +1,17 @@ import os import urllib.parse as urlparse - -try: - from django import VERSION as DJANGO_VERSION -except ImportError: - DJANGO_VERSION = None +from importlib import import_module def get_postgres_backend(): # Django deprecated the `django.db.backends.postgresql_psycopg2` in 2.0. # https://docs.djangoproject.com/en/stable/releases/2.0/#id1 - if DJANGO_VERSION and DJANGO_VERSION < (2, 0): - return "django.db.backends.postgresql_psycopg2" + try: + django = import_module("django") + if django.VERSION < (2, 0): + return "django.db.backends.postgresql_psycopg2" + except ModuleNotFoundError: + pass return "django.db.backends.postgresql" diff --git a/test_dj_database_url.py b/test_dj_database_url.py index ba58136..97dfcc7 100644 --- a/test_dj_database_url.py +++ b/test_dj_database_url.py @@ -1,6 +1,7 @@ import os import re import unittest +from types import SimpleNamespace from unittest.mock import patch from urllib.parse import uses_netloc @@ -120,14 +121,24 @@ def test_provide_conn_max_age__use_it_in_final_config(self): config = dj_database_url.parse(URL, CONN_MAX_AGE=600) assert config["CONN_MAX_AGE"] == 600 - @patch("dj_database_url.DJANGO_VERSION", (1, 11, 0, "final", 1)) - def test_django_version_pre_2__use_postgresql_psycopg2_backend(self): + @patch("dj_database_url.import_module") + def test_django_version_pre_2__use_postgresql_psycopg2_backend(self, mock): + version = (1, 11, 0, "final", 1) expected = "django.db.backends.postgresql_psycopg2" + mock.return_value = SimpleNamespace(VERSION=version) assert dj_database_url.get_postgres_backend() == expected - @patch("dj_database_url.DJANGO_VERSION", (3, 2, 0, "final", 0)) - def test_django_version_post_2__use_postgresql_backend(self): + @patch("dj_database_url.import_module") + def test_django_version_post_2__use_postgresql_backend(self, mock): + version = (3, 2, 0, "final", 0) expected = "django.db.backends.postgresql" + mock.return_value = SimpleNamespace(VERSION=version) + assert dj_database_url.get_postgres_backend() == expected + + @patch("dj_database_url.import_module") + def test_no_django_installed__use_postgresql_backend(self, mock): + expected = "django.db.backends.postgresql" + mock.side_effect = ModuleNotFoundError() assert dj_database_url.get_postgres_backend() == expected def test_register_multiple_times__no_duplicates_in_uses_netloc(self): @@ -137,7 +148,6 @@ def test_register_multiple_times__no_duplicates_in_uses_netloc(self): # that list is short and performs linear search on it. dj_database_url.register("django.contrib.db.backends.bag_end", "bag-end") dj_database_url.register("django.contrib.db.backends.bag_end", "bag-end") - assert len(uses_netloc) == len(set(uses_netloc)) def test_postgres_parsing(self):