diff --git a/flask_admin/contrib/sqla/form.py b/flask_admin/contrib/sqla/form.py index 6bc42f7e7..543342ae3 100644 --- a/flask_admin/contrib/sqla/form.py +++ b/flask_admin/contrib/sqla/form.py @@ -265,18 +265,21 @@ def convert(self, model, mapper, name, prop, field_args, hidden_pk): column=column, field_args=kwargs) return None + @classmethod + def _nullable_common(cls, column, field_args): + if column.nullable: + filters = field_args.get('filters', []) + filters.append(lambda x: x or None) + field_args['filters'] = filters + @classmethod def _string_common(cls, column, field_args, **extra): if hasattr(column.type, 'length') and isinstance(column.type.length, int) and column.type.length: field_args['validators'].append(validators.Length(max=column.type.length)) + cls._nullable_common(column, field_args) @converts('String') # includes VARCHAR, CHAR, and Unicode def conv_String(self, column, field_args, **extra): - if column.nullable: - filters = field_args.get('filters', []) - filters.append(lambda x: x or None) - field_args['filters'] = filters - self._string_common(column=column, field_args=field_args, **extra) return fields.StringField(**field_args) @@ -288,9 +291,8 @@ def convert_enum(self, column, field_args, **extra): if column.nullable: field_args['allow_blank'] = column.nullable accepted_values.append(None) - filters = field_args.get('filters', []) - filters.append(lambda x: x or None) - field_args['filters'] = filters + + self._nullable_common(column, field_args) field_args['choices'] = available_choices field_args['validators'].append(validators.AnyOf(accepted_values)) @@ -310,9 +312,8 @@ def convert_choice_type(self, column, field_args, **extra): if column.nullable: field_args['allow_blank'] = column.nullable accepted_values.append(None) - filters = field_args.get('filters', []) - filters.append(lambda x: x or None) - field_args['filters'] = filters + + self._nullable_common(column, field_args) field_args['choices'] = available_choices field_args['validators'].append(validators.AnyOf(accepted_values)) @@ -347,10 +348,7 @@ def convert_arrow_time(self, field_args, **extra): @converts('sqlalchemy_utils.types.email.EmailType') def convert_email(self, field_args, column=None, **extra): - if column.nullable: - filters = field_args.get('filters', []) - filters.append(lambda x: x or None) - field_args['filters'] = filters + self._nullable_common(column, field_args) field_args['validators'].append(validators.Email()) return fields.StringField(**field_args) diff --git a/flask_admin/tests/sqla/test_basic.py b/flask_admin/tests/sqla/test_basic.py index 8fe1785f2..b0fb2f348 100644 --- a/flask_admin/tests/sqla/test_basic.py +++ b/flask_admin/tests/sqla/test_basic.py @@ -235,8 +235,8 @@ def test_model(app, db, admin): model = db.session.query(Model1).first() assert model.test1 == u'test1large' assert model.test2 == u'test2' - assert model.test3 == u'' - assert model.test4 == u'' + assert model.test3 == None + assert model.test4 == None assert model.email_field == u'test@test.com' assert model.choice_field == u'choice-1' assert model.enum_field == u'model1_v1' @@ -285,8 +285,8 @@ def test_model(app, db, admin): model = db.session.query(Model1).first() assert model.test1 == 'test1small' assert model.test2 == 'test2large' - assert model.test3 == '' - assert model.test4 == '' + assert model.test3 == None + assert model.test4 == None assert model.email_field == u'test2@test.com' assert model.choice_field is None assert model.enum_field is None @@ -2639,3 +2639,82 @@ def test_export_csv(app, db, admin): data = rv.data.decode('utf-8') assert rv.status_code == 200 assert len(data.splitlines()) > 21 + + +STRING_CONSTANT = "Anyway, here's Wonderwall" + + +def test_string_null_behavior(app, db, admin): + with app.app_context(): + class StringTestModel(db.Model): + id = db.Column(db.Integer, primary_key=True) + test_no = db.Column(db.Integer, nullable=False) + string_field = db.Column(db.String) + string_field_nonull = db.Column(db.String, nullable=False) + string_field_nonull_default = db.Column(db.String, nullable=False, default='') + text_field = db.Column(db.Text) + text_field_nonull = db.Column(db.Text, nullable=False) + text_field_nonull_default = db.Column(db.Text, nullable=False, default='') + + db.create_all() + + view = CustomModelView(StringTestModel, db.session) + admin.add_view(view) + + client = app.test_client() + + valid_params = { + "test_no": 1, + "string_field_nonull": STRING_CONSTANT, + "text_field_nonull": STRING_CONSTANT, + } + rv = client.post('/admin/stringtestmodel/new/', + data=valid_params) + assert rv.status_code == 302 + + # Assert on defaults + valid_inst = db.session.query(StringTestModel).filter(StringTestModel.test_no == 1).one() + assert valid_inst.string_field is None + assert valid_inst.string_field_nonull == STRING_CONSTANT + assert valid_inst.string_field_nonull_default == '' + assert valid_inst.text_field is None + assert valid_inst.text_field_nonull == STRING_CONSTANT + assert valid_inst.text_field_nonull_default == '' + + # Assert that nulls are caught on the non-null fields + invalid_string_field = { + "test_no": 2, + "string_field_nonull": None, + "text_field_nonull": STRING_CONSTANT, + } + rv = client.post('/admin/stringtestmodel/new/', + data=invalid_string_field) + assert rv.status_code == 200 + assert b'This field is required.' in rv.data + assert db.session.query(StringTestModel).filter(StringTestModel.test_no == 2).all() == [] + + invalid_text_field = { + "test_no": 3, + "string_field_nonull": STRING_CONSTANT, + "text_field_nonull": None, + } + rv = client.post('/admin/stringtestmodel/new/', + data=invalid_text_field) + assert rv.status_code == 200 + assert b'This field is required.' in rv.data + assert db.session.query(StringTestModel).filter(StringTestModel.test_no == 3).all() == [] + + # Assert that empty strings are converted to None on nullable fields. + empty_strings = { + "test_no": 4, + "string_field": "", + "text_field": "", + "string_field_nonull": STRING_CONSTANT, + "text_field_nonull": STRING_CONSTANT, + } + rv = client.post('/admin/stringtestmodel/new/', + data=empty_strings) + assert rv.status_code == 302 + empty_string_inst = db.session.query(StringTestModel).filter(StringTestModel.test_no == 4).one() + assert empty_string_inst.string_field is None + assert empty_string_inst.text_field is None