Skip to content

Commit

Permalink
Merge pull request #134 from jazzband/fix-testing-with-redshift
Browse files Browse the repository at this point in the history
fix bugs for testing with Redshift (include a new feature; add columun with UNIQUE)
  • Loading branch information
shimizukawa authored Jul 17, 2024
2 parents 5a98834 + 2a108f5 commit 829b1dc
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 35 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@ Features:
* #83 Drop Python-3.6 support.
* #127 Drop Python-3.7 support.
* #83 Drop Django-2.2 support.
* #134 Support adding COLUMN with UNIQUE; adding column without UNIQUE then add UNIQUE CONSTRAINT.

Bug Fixes:

* #134 inspectdb should suppress output 'id = AutoField(primary_key=True)'
* #134 fix for decreasing size of column with default by create-copy-drop-rename strategy.

3.0.0 (2022/02/27)
------------------

Expand Down
36 changes: 34 additions & 2 deletions django_redshift_backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,13 @@ def add_field(self, model, field):
"column": self.quote_name(field.column),
"definition": definition,
}

# ## Redshift
if field.unique:
# temporarily remove UNIQUE constraint from sql
# because Redshift can't add column with UNIQUE constraint.
sql = sql.rstrip(" UNIQUE")

# ## Redshift
if not field.null and self.effective_default(field) is None:
# Redshift Can't add NOT NULL column without DEFAULT.
Expand All @@ -321,6 +328,10 @@ def add_field(self, model, field):
# ## original BasePGDatabaseSchemaEditor.add_field has CREATE INDEX.
# ## Redshift doesn't support INDEX.

# Add UNIQUE constraints later
if field.unique:
self.deferred_sql.append(self._create_unique_sql(model, [field]))

# Add any FK constraints later
if (
field.remote_field
Expand Down Expand Up @@ -838,13 +849,17 @@ def _get_max_length(field):

old_max_length = _get_max_length(old_field)
new_max_length = _get_max_length(new_field)
decrease_size_with_default = old_default is not None and (
old_max_length > new_max_length
)

# Size is changed
if (
type(old_field) == type(new_field)
type(old_field) is type(new_field)
and old_max_length is not None
and new_max_length is not None
and old_max_length != new_max_length
and not decrease_size_with_default
):
# if shrink size as `old_field.max_length > new_field.max_length` and
# larger data in database, this change will raise exception.
Expand Down Expand Up @@ -910,7 +925,11 @@ def _get_max_length(field):
fragment = actions.pop(0)

# Type or default is changed?
elif (old_type != new_type) or needs_database_default:
elif (
(old_type != new_type)
or needs_database_default
or decrease_size_with_default
):
fragment, actions = self._alter_column_with_recreate(
model, old_field, new_field
)
Expand Down Expand Up @@ -1069,6 +1088,19 @@ class DatabaseCreation(BasePGDatabaseCreation):


class DatabaseIntrospection(BasePGDatabaseIntrospection):
# to avoid output 'id = meta.AutoField(primary_key=True)',
# return 'AutoField' for 'identity'.
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if description.default and "identity" in description.default:
if field_type == "IntegerField":
return "AutoField"
elif field_type == "BigIntegerField":
return "BigAutoField"
elif field_type == "SmallIntegerField":
return "SmallAutoField"
return field_type

def get_table_description(self, cursor, table_name):
"""
Return a description of the table with the DB-API cursor.description
Expand Down
6 changes: 6 additions & 0 deletions doc/dev.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ To test the database migration as well, start postgres and test it as follows::
$ docker-compose up -d
$ TEST_WITH_POSTGRES=1 tox

To test migrations with Redshift, do it as follows:

1. Create your redshift cruster on AWS
2. Get a redshift endpoint URI
3. run tox as: `TEST_WITH_REDSHIFT=redshift://user:password@<cluster>.<slug>.<region>.redshift.amazonaws.com:5439/<database>?DISABLE_SERVER_SIDE_CURSORS=True tox`

CI (Continuous Integration)
----------------------------

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Python Modules",
]
dependencies = [
"django",
"django<4.2",
]

[project.optional-dependencies]
Expand Down
52 changes: 52 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,55 @@

from django.apps import apps # noqa E402
apps.populate(['testapp'])


import contextlib
from unittest import mock

import pytest

from django_redshift_backend.base import BasePGDatabaseWrapper

TEST_WITH_POSTGRES = os.environ.get('TEST_WITH_POSTGRES')
TEST_WITH_REDSHIFT = os.environ.get('TEST_WITH_REDSHIFT')

skipif_no_database = pytest.mark.skipif(
not TEST_WITH_POSTGRES and not TEST_WITH_REDSHIFT,
reason="no TEST_WITH_POSTGRES/TEST_WITH_REDSHIFT are found",
)
run_only_postgres = pytest.mark.skipif(
not TEST_WITH_POSTGRES,
reason="Test only for postgres",
)
run_only_redshift = pytest.mark.skipif(
not TEST_WITH_REDSHIFT,
reason="Test only for redshift",
)

@contextlib.contextmanager
def postgres_fixture():
"""A context manager that patches the database backend to use PostgreSQL
for local testing.
The purpose of the postgres_fixture context manager is to conditionally
patch the database backend to use PostgreSQL for testing, but only if the
TEST_WITH_POSTGRES variable is set to True.
The reason for not using pytest.fixture in the current setup is due to the
use of classes that inherit from TestCase. pytest fixtures do not directly
integrate with Django's TestCase based tests.
"""
if TEST_WITH_POSTGRES:
with \
mock.patch(
'django_redshift_backend.base.DatabaseWrapper.data_types',
BasePGDatabaseWrapper.data_types,
), \
mock.patch(
'django_redshift_backend.base.DatabaseSchemaEditor._get_create_options',
lambda self, model: '',
):
yield

else:
yield
17 changes: 9 additions & 8 deletions tests/settings.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# -*- coding: utf-8 -*-
import os
import environ

if uri := os.environ.get("TEST_WITH_REDSHIFT"):
# use URI if it has least one charactor.
os.environ["DATABASE_URL"] = uri
else:
os.environ["DATABASE_URL"] = "redshift://user:password@localhost:5439/testing"
env = environ.Env()

DATABASES = {
'default': {
'ENGINE': 'django_redshift_backend',
'NAME': 'testing',
'USER': 'user',
'PASSWORD': 'password',
'HOST': 'localhost',
'PORT': '5439',
}
'default': env.db()
}

SECRET_KEY = '<key>'
11 changes: 4 additions & 7 deletions tests/test_inspectdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@

from django.db import connections
from django.core.management import call_command
import pytest

from django_redshift_backend.base import BasePGDatabaseWrapper
from test_base import OperationTestBase

from conftest import skipif_no_database, postgres_fixture

def norm_sql(sql):
return ' '.join(sql.split()).replace('( ', '(').replace(' )', ')').replace(' ;', ';')
Expand Down Expand Up @@ -188,8 +187,7 @@ def test_get_get_constraints_does_not_use_unsupported_functions(self):
self.assertEqual(self.expected_indexes_query, executed_sql)


@pytest.mark.skipif(not os.environ.get('TEST_WITH_POSTGRES'),
reason='to run, TEST_WITH_POSTGRES=1 tox')
@skipif_no_database
class InspectDbTests(OperationTestBase):
available_apps = []
databases = {'default'}
Expand All @@ -210,11 +208,10 @@ class Meta:
def tearDown(self):
self.cleanup_test_tables()

@mock.patch('django_redshift_backend.base.DatabaseWrapper.data_types', BasePGDatabaseWrapper.data_types)
@mock.patch('django_redshift_backend.base.DatabaseSchemaEditor._get_create_options', lambda self, model: '')
@postgres_fixture()
def test_inspectdb(self):
self.set_up_test_model('test')
out = StringIO()
call_command('inspectdb', stdout=out)
call_command('inspectdb', 'test_pony', stdout=out)
print(out.getvalue())
self.assertIn(self.expected_pony_model, out.getvalue())
Loading

0 comments on commit 829b1dc

Please sign in to comment.