Skip to content

Commit

Permalink
Add support for JSONField
Browse files Browse the repository at this point in the history
Fixes #23
  • Loading branch information
timgraham committed Apr 16, 2023
1 parent 3cf90d1 commit 19cbd8f
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## 3.2 beta 1 - TBD

- Added support for `JSONField`.

- The `regex` lookup pattern is no longer implicitly anchored at both ends.

## 3.2 alpha 2 - 2022-03-03
Expand Down
16 changes: 14 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ DATABASES = {
- Snowflake doesn't support indexes. Thus, Django ignores any indexes defined
on models or fields.

- `JSONField` is [not supported](https://github.com/cedar-team/django-snowflake/issues/23).

- Snowflake doesn't support check constraints, so the various
`PositiveIntegerField` model fields allow negative values (though validation
at the form level still works).
Expand Down Expand Up @@ -95,6 +93,20 @@ if you encounter an issue worth documenting.
transactions to speed it up. A future version of Django (5.0 at the earliest)
may leverage Snowflake's single-layer transactions to provide some speedup.

* Due to snowflake-connector-python's [lack of VARIANT support](https://github.com/snowflakedb/snowflake-connector-python/issues/244),
some `JSONField` queries with complex JSON parameters [don't work](https://github.com/cedar-team/django-snowflake/issues/58).

For example, if `value` is a `JSONField`, this won't work:
```python
>>> JSONModel.objects.filter(value__k={"l": "m"})
```
A workaround is:
```python
>>> from django.db.models.expressions import RawSQL
>>> JSONModel.objects.filter(value__k=RawSQL("PARSE_JSON(%s)", ('{"l": "m"}',)))
```
In addition, `QuerySet.bulk_update()` isn't supported for `JSONField`.

* Interval math where the interval is a column
[is not supported](https://github.com/cedar-team/django-snowflake/issues/27).

Expand Down
2 changes: 2 additions & 0 deletions django_snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@
check_django_compatability()

from .functions import register_functions # noqa
from .lookups import register_lookups # noqa

register_functions()
register_lookups()
14 changes: 14 additions & 0 deletions django_snowflake/compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from itertools import chain

from django.db.models import JSONField
from django.db.models.sql import compiler


Expand All @@ -16,6 +17,7 @@ def as_sql(self):
fields = self.query.fields or [opts.pk]
result.append("(%s)" % ", ".join(qn(f.column) for f in fields))

select_columns = []
if self.query.fields:
value_rows = [
[
Expand All @@ -24,6 +26,15 @@ def as_sql(self):
]
for obj in self.query.objs
]
has_json_field = False
for i, field in enumerate(fields, 1):
if isinstance(field, JSONField):
has_json_field = True
select_columns.append(f'parse_json(${i})')
else:
select_columns.append(f'${i}')
if not has_json_field:
select_columns = []
else:
# An empty object.
value_rows = [
Expand Down Expand Up @@ -68,6 +79,9 @@ def as_sql(self):
params += [self.returning_params]
return [(" ".join(result), tuple(chain.from_iterable(params)))]

if select_columns:
result.append('SELECT ' + (", ".join(c for c in select_columns)) + ' FROM')

if can_bulk:
result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
if ignore_conflicts_suffix_sql:
Expand Down
49 changes: 46 additions & 3 deletions django_snowflake/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

class DatabaseFeatures(BaseDatabaseFeatures):
can_clone_databases = True
can_introspect_json_field = False
closed_cursor_error_class = InterfaceError
create_test_procedure_without_params_sql = """
CREATE PROCEDURE test_procedure() RETURNS varchar LANGUAGE JAVASCRIPT AS $$
Expand Down Expand Up @@ -35,8 +34,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# This feature is specific to the Django fork used for testing.
supports_indexes = False
supports_index_column_ordering = False
# Not yet implemented in this backend.
supports_json_field = False
supports_json_field_contains = False
supports_over_clause = True
supports_partial_indexes = False
# https://docs.snowflake.com/en/sql-reference/functions-regexp.html#backreferences
Expand Down Expand Up @@ -89,6 +87,49 @@ class DatabaseFeatures(BaseDatabaseFeatures):
'expressions.tests.FTimeDeltaTests.test_date_subtraction',
'expressions.tests.FTimeDeltaTests.test_datetime_subtraction',
'expressions.tests.FTimeDeltaTests.test_time_subtraction',
# JSONField queries with complex JSON parameters don't work:
# https://github.com/cedar-team/django-snowflake/issues/58
# Query:
# WHERE "MODEL_FIELDS_NULLABLEJSONMODEL"."VALUE" = 'null'
# needs to operate as:
# WHERE "MODEL_FIELDS_NULLABLEJSONMODEL"."VALUE" = PARSE_JSON('null')
'model_fields.test_jsonfield.TestSaveLoad.test_json_null_different_from_sql_null',
# Query:
# WHERE TO_JSON("MODEL_FIELDS_NULLABLEJSONMODEL"."VALUE":k) = '{"l": "m"}'
# needs to operate as:
# WHERE TO_JSON("MODEL_FIELDS_NULLABLEJSONMODEL"."VALUE":k) = PARSE_JSON('{"l": "m"}')
'model_fields.test_jsonfield.TestQuerying.test_shallow_lookup_obj_target',
# Query:
# WHERE "MODEL_FIELDS_NULLABLEJSONMODEL"."VALUE" = '{"a": "b", "c": 14}'
# needs to operate as:
# WHERE "MODEL_FIELDS_NULLABLEJSONMODEL"."VALUE" = PARSE_JSON('{"a": "b", "c": 14}')
'model_fields.test_jsonfield.TestQuerying.test_exact_complex',
# Three cases:
# lookup='value__bar__in', value=[['foo', 'bar']]
# lookup='value__bar__in', value=[['foo', 'bar'], ['a']]
# lookup='value__bax__in', value=[{'foo': 'bar'}, {'a': 'b'}]
# Query:
# WHERE TO_JSON("MODEL_FIELDS_NULLABLEJSONMODEL"."VALUE":bar) IN ('["foo", "bar"]')
# needs to operate as:
# WHERE TO_JSON("MODEL_FIELDS_NULLABLEJSONMODEL"."VALUE":bar) IN (PARSE_JSON('["foo", "bar"]'))
'model_fields.test_jsonfield.TestQuerying.test_key_in',
# QuerySet.bulk_update() not supported for JSONField:
# Expression type does not match column data type, expecting VARIANT
# but got VARCHAR(16777216) for column JSON_FIELD
'queries.test_bulk_update.BulkUpdateTests.test_json_field',
# Server-side bug?
# CAST(TO_JSON("MODEL_FIELDS_NULLABLEJSONMODEL"."VALUE":d) AS VARIANT)
# gives '"[\\"e\\",{\\"f\\":\\"g\\"}]"' and appending [0] gives None.
# The expected result ('"e"') is given by:
# PARSE_JSON(TO_JSON("MODEL_FIELDS_NULLABLEJSONMODEL"."VALUE":d))[0]
# Possibly this backend could rewrite CAST(... AS VARIANT) to PARSE_JSON(...)?
'model_fields.test_jsonfield.TestQuerying.test_key_transform_annotation_expression',
'model_fields.test_jsonfield.TestQuerying.test_key_transform_expression',
'model_fields.test_jsonfield.TestQuerying.test_nested_key_transform_annotation_expression',
'model_fields.test_jsonfield.TestQuerying.test_nested_key_transform_expression',
# Fixed if TO_JSON is removed from the ORDER BY clause (or it may be fine
# as is, since some databases give the same ordering that Snowflake does).
'model_fields.test_jsonfield.TestQuerying.test_ordering_by_transform',
}

django_test_skips = {
Expand Down Expand Up @@ -155,6 +196,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
'expressions.tests.FTimeDeltaTests.test_datetime_subquery_subtraction',
'expressions_window.tests.WindowFunctionTests.test_subquery_row_range_rank',
'lookup.tests.LookupTests.test_nested_outerref_lhs',
'model_fields.test_jsonfield.TestQuerying.test_nested_key_transform_on_subquery',
'model_fields.test_jsonfield.TestQuerying.test_obj_subquery_lookup',
'queries.test_qs_combinators.QuerySetSetOperationTests.test_union_with_values_list_on_annotated_and_unannotated', # noqa
'queries.tests.ExcludeTest17600.test_exclude_plain',
'queries.tests.ExcludeTest17600.test_exclude_plain_distinct',
Expand Down
1 change: 1 addition & 0 deletions django_snowflake/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
'TIME': 'TimeField',
'TIMESTAMP_LTZ': 'DateTimeField',
'VARCHAR': 'CharField',
'VARIANT': 'JSONField',
}

def get_constraints(self, cursor, table_name):
Expand Down
71 changes: 71 additions & 0 deletions django_snowflake/lookups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from django.db.models.fields.json import (
HasKeyLookup, KeyTextTransform, KeyTransform,
)


def compile_json_path(key_transforms):
    """
    Build the Snowflake path suffix that traverses ``key_transforms``.

    Every element that ``int()`` accepts is emitted as an array index
    (``[N]``). Any other element is emitted as a double-quoted key, prefixed
    with a colon when it is the first path element and with a period
    otherwise. The returned string is meant to be appended to the SQL of the
    column (or expression) being traversed; an empty iterable yields ``''``.
    """
    segments = []
    for key in key_transforms:
        try:
            index = int(key)
        except ValueError:
            # Not an integer, so this is an object key. Only the first path
            # element is introduced by a colon; later ones use a period.
            prefix = '.' if segments else ':'
            # Escape quotes to protect against SQL injection.
            # NOTE(review): backslashes in the key are passed through as is;
            # confirm how Snowflake treats them inside a quoted path element.
            quoted = key.replace('"', '\\"')
            segments.append(f'{prefix}"{quoted}"')
        else:
            # An integer lookup is an array index.
            segments.append(f'[{index}]')
    # Double any percent signs since snowflake-connector-python uses
    # interpolation to bind parameters.
    return ''.join(segments).replace('%', '%%')


def key_text_transform(self, compiler, connection):
    """
    Compile a ``KeyTextTransform`` for Snowflake.

    Return a ``(sql, params)`` pair: the SQL is the preprocessed left-hand
    side followed by the JSON path for its key transforms and a cast to
    ``VARCHAR``; the params are the left-hand side's parameters as a tuple.
    """
    base_sql, base_params, path_keys = self.preprocess_lhs(compiler, connection)
    path = compile_json_path(path_keys)
    return f'{base_sql}{path}::VARCHAR', tuple(base_params)


def has_key_lookup(self, compiler, connection):
    """
    Compile a ``HasKeyLookup`` (or a subclass that sets ``logical_operator``)
    for Snowflake.

    One ``IS_NULL_VALUE(<lhs><path>) IS NOT NULL`` condition is produced per
    right-hand side key. When ``self.logical_operator`` is set, the
    conditions for all keys are joined with it and wrapped in parentheses;
    otherwise only the condition for the first key is used.

    Return a ``(sql, params)`` pair where ``params`` holds the left-hand
    side's parameters, repeated once for every condition in ``sql``.
    """
    # Process JSON path from the left-hand side.
    if isinstance(self.lhs, KeyTransform):
        lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
            compiler, connection
        )
        lhs_json_path = compile_json_path(lhs_key_transforms)
    else:
        lhs, lhs_params = self.process_lhs(compiler, connection)
        lhs_json_path = ''
    # Process JSON path from the right-hand side.
    rhs = self.rhs
    if not isinstance(rhs, (list, tuple)):
        rhs = [rhs]
    rhs_json_paths = []
    for key in rhs:
        if isinstance(key, KeyTransform):
            *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
        else:
            rhs_key_transforms = [key]
        rhs_json_paths.append(compile_json_path(rhs_key_transforms))
    if not self.logical_operator:
        # A plain "has key" lookup only checks a single key.
        rhs_json_paths = rhs_json_paths[:1]
    # Add condition for each key. Each condition is built directly rather
    # than by %-formatting a template: the left-hand side may contain "%s"
    # placeholders and the JSON paths contain "%%" escapes, both of which
    # %-formatting would consume or corrupt.
    conditions = [
        f'IS_NULL_VALUE({lhs}{lhs_json_path}{path}) IS NOT NULL'
        for path in rhs_json_paths
    ]
    if self.logical_operator:
        sql = f'({self.logical_operator.join(conditions)})'
    else:
        sql = conditions[0]
    # The left-hand side SQL appears once per condition, so its parameters
    # must be supplied once per condition too.
    return sql, tuple(lhs_params) * len(conditions)


def key_transform(self, compiler, connection):
    """
    Compile a ``KeyTransform`` for Snowflake.

    Return a ``(sql, params)`` pair: the SQL wraps the preprocessed left-hand
    side plus the JSON path for its key transforms in ``TO_JSON(...)``; the
    params are the left-hand side's parameters as a tuple.
    """
    base_sql, base_params, path_keys = self.preprocess_lhs(compiler, connection)
    path = compile_json_path(path_keys)
    return f'TO_JSON({base_sql}{path})', tuple(base_params)


def register_lookups():
    """
    Install this module's compilers as the ``as_snowflake`` method of
    Django's ``HasKeyLookup``, ``KeyTextTransform`` and ``KeyTransform``.
    """
    snowflake_compilers = (
        (HasKeyLookup, has_key_lookup),
        (KeyTextTransform, key_text_transform),
        (KeyTransform, key_transform),
    )
    for lookup_class, as_snowflake in snowflake_compilers:
        lookup_class.as_snowflake = as_snowflake

0 comments on commit 19cbd8f

Please sign in to comment.