diff --git a/app/api/attendees.py b/app/api/attendees.py index 1e100b7ff5..475eda611a 100644 --- a/app/api/attendees.py +++ b/app/api/attendees.py @@ -275,7 +275,7 @@ def before_update_object(self, obj, data, kwargs): obj.attendee_notes, data['attendee_notes'] ) - validate_custom_form_constraints_request( + data['complex_field_values'] = validate_custom_form_constraints_request( 'attendee', self.resource.schema, obj, data ) diff --git a/app/api/helpers/custom_forms.py b/app/api/helpers/custom_forms.py index e02a8582f8..febecc9ee0 100644 --- a/app/api/helpers/custom_forms.py +++ b/app/api/helpers/custom_forms.py @@ -1,3 +1,4 @@ +import marshmallow from flask_rest_jsonapi.schema import get_relationships from sqlalchemy import inspect @@ -9,21 +10,36 @@ def object_as_dict(obj): return {c.key: getattr(obj, c.key) for c in inspect(obj).mapper.column_attrs} +def get_schema(form_fields): + attrs = {} + + for field in form_fields: + if field.type in ['text', 'checkbox', 'select']: + field_type = marshmallow.fields.Str + elif field.type == 'email': + field_type = marshmallow.fields.Email + elif field.type == 'number': + field_type = marshmallow.fields.Float + else: + raise UnprocessableEntityError( + {'pointer': '/data/complex-field-values/' + field.identifier}, + 'Invalid Field Type: ' + field.type, + ) + attrs[field.identifier] = field_type(required=field.is_required) + return type('DynamicSchema', (marshmallow.Schema,), attrs) + + def validate_custom_form_constraints(form, obj): - required_form_fields = CustomForms.query.filter_by( - form=form, - event_id=obj.event_id, - is_included=True, - is_required=True, - deleted_at=None, - ) + form_fields = CustomForms.query.filter_by( + form=form, event_id=obj.event_id, is_included=True, deleted_at=None, + ).all() + required_form_fields = filter(lambda field: field.is_required, form_fields) missing_required_fields = [] - for field in required_form_fields.all(): + for field in required_form_fields: if not field.is_complex: if not getattr(obj, field.identifier): missing_required_fields.append(field.identifier) else: - if not (obj.complex_field_values or {}).get(field.identifier): missing_required_fields.append(field.identifier) @@ -33,6 +49,17 @@ def validate_custom_form_constraints(form, obj): f'Missing required fields {missing_required_fields}', ) + if obj.complex_field_values: + complex_form_fields = filter(lambda field: field.is_complex, form_fields) + schema = get_schema(complex_form_fields)() + + data, errors = schema.load(obj.complex_field_values) + + if errors: + raise UnprocessableEntityError({'errors': errors}, 'Schema Validation Error') + + return data + def validate_custom_form_constraints_request(form, schema, obj, data): new_obj = type(obj)(**object_as_dict(obj)) @@ -41,4 +68,4 @@ def validate_custom_form_constraints_request(form, schema, obj, data): if hasattr(new_obj, key) and key not in relationship_fields: setattr(new_obj, key, value) - validate_custom_form_constraints(form, new_obj) + return validate_custom_form_constraints(form, new_obj) diff --git a/tests/all/integration/api/attendee/test_attendee_api.py b/tests/all/integration/api/attendee/test_attendee_api.py index ed52e822a5..8d32ec9d18 100644 --- a/tests/all/integration/api/attendee/test_attendee_api.py +++ b/tests/all/integration/api/attendee/test_attendee_api.py @@ -163,6 +163,16 @@ def get_complex_custom_form_attendee(db): is_required=True, is_complex=True, ) + CustomForms( + event=attendee.event, + form='attendee', + field_identifier='transFatContent', + name='Trans Fat Content', + type='number', + is_included=True, + is_required=False, + is_complex=True, + ) db.session.commit() return attendee @@ -288,6 +298,48 @@ def test_custom_form_complex_fields_complete(db, client, jwt): assert attendee.complex_field_values['best_friend'] == 'Tester' +def test_ignore_complex_custom_form_fields(db, client, jwt): + """Test to see that extra data from complex JSON is dropped""" + attendee = get_complex_custom_form_attendee(db) + + data = json.dumps( + { + 'data': { + 'type': 'attendee', + 'id': str(attendee.id), + "attributes": { + "firstname": "Areeb", + "lastname": "Jamal", + "job-title": "Software Engineer", + "complex-field-values": { + "bestFriend": "Bester", + "transFat-content": 20.08, + "shalimar": "sophie", + }, + }, + } + } + ) + + response = client.patch( + f'/v1/attendees/{attendee.id}', + content_type='application/vnd.api+json', + headers=jwt, + data=data, + ) + + db.session.refresh(attendee) + + assert response.status_code == 200 + + assert attendee.firstname == 'Areeb' + assert attendee.lastname == 'Jamal' + assert attendee.job_title == 'Software Engineer' + assert attendee.complex_field_values['best_friend'] == 'Bester' + assert attendee.complex_field_values['trans_fat_content'] == 20.08 + assert attendee.complex_field_values.get('shalimar') is None + + def test_edit_attendee_ticket(db, client, jwt): attendee = AttendeeOrderTicketSubFactory() ticket = TicketSubFactory(event=attendee.event)