class CanEncode:
    """
    Validates that the field data can be encoded into a specific character set.

    Every character of the submitted value is checked individually so that
    the error message can name exactly which characters are not allowed.

    :param encoding: Name of the codec the value must be encodable with
        (defaults to ``"latin-1"``).
    :param field_type: How the field is described in the default error
        message, for example ``"a web address"``. Falls back to
        ``"this field"`` when not given.
    :param message: Optional error message that replaces the default one.
        It is used verbatim, no placeholders are substituted.
    :raises ValidationError: when called with a field whose data contains at
        least one character that cannot be encoded with ``encoding``.
    """

    def __init__(self, encoding="latin-1", field_type=None, message=None):
        self.encoding = encoding
        self.field_type = field_type
        self.message = message

    def __call__(self, form, field):
        # Empty / missing data is left for other validators (e.g. DataRequired).
        if not field.data:
            return

        # De-duplicate and sort so the message is stable regardless of the
        # order in which the offending characters appear in the input.
        unsupported = sorted({char for char in field.data if not self._can_encode(char)})
        if not unsupported:
            return

        message = self.message
        if message is None:
            message = self._default_message(unsupported)

        raise ValidationError(message)

    def _can_encode(self, char):
        """Return True if ``char`` can be encoded with the configured encoding."""
        try:
            char.encode(self.encoding)
        except UnicodeEncodeError:
            return False
        return True

    def _default_message(self, unsupported):
        """Build the default error message for a sorted list of unsupported characters."""
        field_type = "this field" if self.field_type is None else self.field_type
        return "You cannot use {} in {}. You must use percent encoding if you want to include {}.".format(
            formatted_list(unsupported, conjunction="or", before_each="", after_each=""),
            field_type,
            "these characters" if len(unsupported) > 1 else "this character",
        )
# Inputs containing characters outside latin-1, paired with the exact message
# the default-configured validator is expected to raise for them.
@pytest.mark.parametrize(
    "data, err_msg",
    [
        (
            "📵 ghi",
            "You cannot use 📵 in this field. You must use percent encoding if you want to include this character.",
        ),
        (
            "∆ abc 📲",
            "You cannot use ∆ or 📲 in this field. You must use percent encoding if you want to include these characters.",  # noqa
        ),
    ],
)
def test_can_encode_validation(data, err_msg, client_request):
    """The default message lists every unsupported character, singular or plural."""
    validator = CanEncode()
    field = _gen_mock_field(data)

    with pytest.raises(ValidationError) as excinfo:
        validator(None, field)

    assert str(excinfo.value) == err_msg


def test_string_can_encode_with_custom_field_type():
    """A custom ``field_type`` replaces "this field" in the default message."""
    expected = "You cannot use ∆ or 📲 in a web address. You must use percent encoding if you want to include these characters."  # noqa
    field = _gen_mock_field("∆ abc 📲", error_summary_messages=[])
    validator = CanEncode(field_type="a web address")

    with pytest.raises(ValidationError) as excinfo:
        validator(None, field)

    assert str(excinfo.value) == expected


@pytest.mark.parametrize("string", ["", "Résumé", "München"])
def test_string_can_encode_does_not_raise(string):
    """Empty input and latin-1 encodable text pass validation silently."""
    validator = CanEncode()
    validator(None, _gen_mock_field(string))