diff --git a/IM/InfrastructureList.py b/IM/InfrastructureList.py index eff8673b..12e366d8 100644 --- a/IM/InfrastructureList.py +++ b/IM/InfrastructureList.py @@ -234,10 +234,12 @@ def _save_data_to_db(db_url, inf_list, inf_id=None): data = inf.serialize() if db.db_type == DataBase.MONGO: res = db.replace("inf_list", {"id": inf.id}, {"id": inf.id, "deleted": int(inf.deleted), - "data": data, "date": time.time()}) + "data": data, "date": time.time(), + "auth": inf.auth.serialize()}) else: - res = db.execute("replace into inf_list (id, deleted, data, date) values (%s, %s, %s, now())", - (inf.id, int(inf.deleted), data)) + res = db.execute("replace into inf_list (id, deleted, data, date, auth)" + " values (%s, %s, %s, now(), %s)", + (inf.id, int(inf.deleted), data, inf.auth.serialize())) db.close() return res diff --git a/IM/db.py b/IM/db.py index 8a27e388..1685330a 100644 --- a/IM/db.py +++ b/IM/db.py @@ -281,6 +281,17 @@ def replace(self, table_name, filt, replacement): res = self.connection[table_name].replace_one(filt, replacement, True) return res.modified_count == 1 or res.upserted_id is not None + def update(self, table_name, filt, updates): + """ insert/replace elements """ + if self.db_type != DataBase.MONGO: + raise Exception("Operation only supported in MongoDB") + + if self.connection is None: + raise Exception("DataBase object not connected") + else: + res = self.connection[table_name].update_one(filt, updates, True) + return res.modified_count == 1 or res.upserted_id is not None + def delete(self, table_name, filt): """ delete elements """ if self.db_type != DataBase.MONGO: diff --git a/scripts/db_1_14_X_to_1_15_X.py b/scripts/db_1_14_X_to_1_15_X.py index a12e082d..8fecaa44 100644 --- a/scripts/db_1_14_X_to_1_15_X.py +++ b/scripts/db_1_14_X_to_1_15_X.py @@ -77,7 +77,10 @@ def get_data_from_db(db, inf_id): print(inf_id) if inf: auth = inf.auth.serialize() - res = db.execute("UPDATE `inf_list` SET `auth` = %s WHERE id = %s", (auth, inf_id)) + if db.db_type == DataBase.MONGO: + res = db.update("inf_list", {"id": inf_id}, {"auth": auth}) + else: + res = db.execute("UPDATE `inf_list` SET `auth` = %s WHERE id = %s", (auth, inf_id)) except Exception as e: sys.stderr.write("Error updating auth field in Inf ID: %s. Ignoring.\n" % inf_id) else: diff --git a/test/unit/db.py b/test/unit/db.py index 51eedb43..c662dd03 100644 --- a/test/unit/db.py +++ b/test/unit/db.py @@ -63,7 +63,7 @@ def test_mysql_db(self, mdb_conn): db.close() - @patch('IM.db.MongoClient') + @patch('IM.db.MongoClient') def test_mongo_db(self, mongo): client = MagicMock() mongo.return_value = client @@ -88,6 +88,10 @@ def test_mongo_db(self, mongo): self.assertTrue(res) self.assertEqual(table.replace_one.call_args_list[0][0], ({}, {'data': 'test1', 'id': 1}, True)) + res = db.update('table', {'id': 1}, {'data': 'test1'}) + self.assertTrue(res) + self.assertEqual(table.update_one.call_args_list[0][0], ({'id': 1}, {'data': 'test1'}, True)) + table.find.return_value = [{'id': 2, 'data': 'test2', '_id': 2}] res = db.find('table', {'id': 2}, {'data': True}) self.assertEqual(len(res), 1)