Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

String null fix #2321

Merged
merged 4 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions flask_admin/contrib/sqla/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,18 +265,21 @@ def convert(self, model, mapper, name, prop, field_args, hidden_pk):
column=column, field_args=kwargs)
return None

@classmethod
def _nullable_common(cls, column, field_args):
if column.nullable:
filters = field_args.get('filters', [])
filters.append(lambda x: x or None)
field_args['filters'] = filters

@classmethod
def _string_common(cls, column, field_args, **extra):
if hasattr(column.type, 'length') and isinstance(column.type.length, int) and column.type.length:
field_args['validators'].append(validators.Length(max=column.type.length))
cls._nullable_common(column, field_args)

@converts('String') # includes VARCHAR, CHAR, and Unicode
def conv_String(self, column, field_args, **extra):
if column.nullable:
filters = field_args.get('filters', [])
filters.append(lambda x: x or None)
field_args['filters'] = filters

self._string_common(column=column, field_args=field_args, **extra)
return fields.StringField(**field_args)

Expand All @@ -288,9 +291,8 @@ def convert_enum(self, column, field_args, **extra):
if column.nullable:
field_args['allow_blank'] = column.nullable
accepted_values.append(None)
filters = field_args.get('filters', [])
filters.append(lambda x: x or None)
field_args['filters'] = filters

self._nullable_common(column, field_args)

field_args['choices'] = available_choices
field_args['validators'].append(validators.AnyOf(accepted_values))
Expand All @@ -310,9 +312,8 @@ def convert_choice_type(self, column, field_args, **extra):
if column.nullable:
field_args['allow_blank'] = column.nullable
accepted_values.append(None)
filters = field_args.get('filters', [])
filters.append(lambda x: x or None)
field_args['filters'] = filters

self._nullable_common(column, field_args)

field_args['choices'] = available_choices
field_args['validators'].append(validators.AnyOf(accepted_values))
Expand Down Expand Up @@ -347,10 +348,7 @@ def convert_arrow_time(self, field_args, **extra):

@converts('sqlalchemy_utils.types.email.EmailType')
def convert_email(self, field_args, column=None, **extra):
if column.nullable:
filters = field_args.get('filters', [])
filters.append(lambda x: x or None)
field_args['filters'] = filters
self._nullable_common(column, field_args)
field_args['validators'].append(validators.Email())
return fields.StringField(**field_args)

Expand Down
87 changes: 83 additions & 4 deletions flask_admin/tests/sqla/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ def test_model(app, db, admin):
model = db.session.query(Model1).first()
assert model.test1 == u'test1large'
assert model.test2 == u'test2'
assert model.test3 == u''
assert model.test4 == u''
assert model.test3 == None
assert model.test4 == None
assert model.email_field == u'[email protected]'
assert model.choice_field == u'choice-1'
assert model.enum_field == u'model1_v1'
Expand Down Expand Up @@ -285,8 +285,8 @@ def test_model(app, db, admin):
model = db.session.query(Model1).first()
assert model.test1 == 'test1small'
assert model.test2 == 'test2large'
assert model.test3 == ''
assert model.test4 == ''
assert model.test3 == None
assert model.test4 == None
assert model.email_field == u'[email protected]'
assert model.choice_field is None
assert model.enum_field is None
Expand Down Expand Up @@ -2639,3 +2639,82 @@ def test_export_csv(app, db, admin):
data = rv.data.decode('utf-8')
assert rv.status_code == 200
assert len(data.splitlines()) > 21


STRING_CONSTANT = "Anyway, here's Wonderwall"


def test_string_null_behavior(app, db, admin):
with app.app_context():
class StringTestModel(db.Model):
id = db.Column(db.Integer, primary_key=True)
test_no = db.Column(db.Integer, nullable=False)
string_field = db.Column(db.String)
string_field_nonull = db.Column(db.String, nullable=False)
string_field_nonull_default = db.Column(db.String, nullable=False, default='')
text_field = db.Column(db.Text)
text_field_nonull = db.Column(db.Text, nullable=False)
text_field_nonull_default = db.Column(db.Text, nullable=False, default='')

db.create_all()

view = CustomModelView(StringTestModel, db.session)
admin.add_view(view)

client = app.test_client()

valid_params = {
"test_no": 1,
"string_field_nonull": STRING_CONSTANT,
"text_field_nonull": STRING_CONSTANT,
}
rv = client.post('/admin/stringtestmodel/new/',
data=valid_params)
assert rv.status_code == 302

# Assert on defaults
valid_inst = db.session.query(StringTestModel).filter(StringTestModel.test_no == 1).one()
assert valid_inst.string_field is None
assert valid_inst.string_field_nonull == STRING_CONSTANT
assert valid_inst.string_field_nonull_default == ''
assert valid_inst.text_field is None
assert valid_inst.text_field_nonull == STRING_CONSTANT
assert valid_inst.text_field_nonull_default == ''

# Assert that nulls are caught on the non-null fields
invalid_string_field = {
"test_no": 2,
"string_field_nonull": None,
"text_field_nonull": STRING_CONSTANT,
}
rv = client.post('/admin/stringtestmodel/new/',
data=invalid_string_field)
assert rv.status_code == 200
assert b'This field is required.' in rv.data
assert db.session.query(StringTestModel).filter(StringTestModel.test_no == 2).all() == []

invalid_text_field = {
"test_no": 3,
"string_field_nonull": STRING_CONSTANT,
"text_field_nonull": None,
}
rv = client.post('/admin/stringtestmodel/new/',
data=invalid_text_field)
assert rv.status_code == 200
assert b'This field is required.' in rv.data
assert db.session.query(StringTestModel).filter(StringTestModel.test_no == 3).all() == []

# Assert that empty strings are converted to None on nullable fields.
empty_strings = {
"test_no": 4,
"string_field": "",
"text_field": "",
"string_field_nonull": STRING_CONSTANT,
"text_field_nonull": STRING_CONSTANT,
}
rv = client.post('/admin/stringtestmodel/new/',
data=empty_strings)
assert rv.status_code == 302
empty_string_inst = db.session.query(StringTestModel).filter(StringTestModel.test_no == 4).one()
assert empty_string_inst.string_field is None
assert empty_string_inst.text_field is None
Loading