diff --git a/extras_mongoengine/fields.py b/extras_mongoengine/fields.py index 5871ddf..05b7eb4 100644 --- a/extras_mongoengine/fields.py +++ b/extras_mongoengine/fields.py @@ -18,6 +18,8 @@ def to_mongo(self, value): return self.prepare_query_value(None, value) def to_python(self, value): + if isinstance(value, timedelta): + return value return timedelta(seconds=value) def prepare_query_value(self, op, value): diff --git a/tests/test_fields.py b/tests/test_fields.py index 8011b0f..a6d3b15 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -15,9 +15,18 @@ def total_seconds(self): raise AttributeError class TimedeltaFieldTestCase(unittest.TestCase): + def setUp(self): + connect(db='extrasmongoenginetest') + self.db = get_db() self.field = TimedeltaField() + def tearDown(self): + for collection in self.db.collection_names(): + if 'system.' in collection: + continue + self.db.drop_collection(collection) + def test_construct(self): self.assertIsInstance(self.field, TimedeltaField) @@ -29,6 +38,21 @@ def test_total_seconds_26(self): value = OldStyleTimedelta(minutes=1, seconds=10) self.assertEqual(self.field.total_seconds(value), 70) + def test_number_initialization(self): + class Doc(Document): + time = TimedeltaField() + + doc = Doc(time=3600).save() + self.assertEqual(doc.time, timedelta(hours=1)) + + def test_timedelta_initialization(self): + class Doc(Document): + time = TimedeltaField() + + test_time = timedelta(days=2) + doc = Doc(time=test_time).save() + self.assertEqual(doc.time, test_time) + class LowerStringFieldTestCase(unittest.TestCase): def setUp(self):