diff --git a/models/engine/db_storage.py b/models/engine/db_storage.py index b8e7d291e6f..e01ad3e71dc 100755 --- a/models/engine/db_storage.py +++ b/models/engine/db_storage.py @@ -74,3 +74,28 @@ def reload(self): def close(self): """call remove() method on the private session attribute""" self.__session.remove() + + + def get(self, cls, id): + """" a method to retrieve one object """ + if cls not in classes.values(): + return None + cls_value = models.storage.all(cls) + for value in cls_value.values(): + if (value.id == id): + return value + + return None + + def count(self, cls=None): + """ a method t count the number of object in storage """ + object_no = classes.values() + + if not cls: + i = 0; + for clas in object_no: + i += len(models.storage.all(clas).values()) + else: + i = len(models.storage.all(cls).values()) + + return i diff --git a/models/engine/file_storage.py b/models/engine/file_storage.py index c8cb8c1764d..19b65b7149e 100755 --- a/models/engine/file_storage.py +++ b/models/engine/file_storage.py @@ -65,6 +65,28 @@ def delete(self, obj=None): if key in self.__objects: del self.__objects[key] - def close(self): - """call reload() method for deserializing the JSON file to objects""" - self.reload() + + def get(self, cls, id): + """" a method to retrieve one object """ + if cls not in classes.values(): + return None + cls_value = models.storage.all(cls) + for value in cls_value.values(): + if (value.id == id): + return value + + return None + + def count(self, cls=None): + """ a method t count the number of object in storage """ + object_no = classes.values() + + if not cls: + i = 0; + for clas in object_no: + i += len(models.storage.all(clas).values()) + else: + i = len(models.storage.all(cls).values()) + + return i + diff --git a/tests/test_models/test_engine/test_db_storage.py b/tests/test_models/test_engine/test_db_storage.py index 766e625b5af..df010c64478 100755 --- a/tests/test_models/test_engine/test_db_storage.py +++ b/tests/test_models/test_engine/test_db_storage.py @@ -86,3 +86,24 @@ def test_new(self): @unittest.skipIf(models.storage_t != 'db', "not testing db storage") def test_save(self): """Test that save properly saves objects to file.json""" + + def test_get_db(self): + """ Tests method for obtaining an instance db storage""" + dic = {"name": "Calabar"} + instance = State(**dic) + storage.new(instance) + storage.save() + get_instance = storage.get(State, instance.id) + self.assertEqual(get_instance, instance) + + def test_count(self): + """ Tests count method db storage """ + dic = {"name": "Lagos"} + state = State(**dic) + storage.new(state) + dic = {"name": "Abuja", "state_id": state.id} + city = City(**dic) + storage.new(city) + storage.save() + c = storage.count() + self.assertEqual(len(storage.all()), c) diff --git a/tests/test_models/test_engine/test_file_storage.py b/tests/test_models/test_engine/test_file_storage.py index 1474a34fec0..d9a24294c4c 100755 --- a/tests/test_models/test_engine/test_file_storage.py +++ b/tests/test_models/test_engine/test_file_storage.py @@ -113,3 +113,12 @@ def test_save(self): with open("file.json", "r") as f: js = f.read() self.assertEqual(json.loads(string), json.loads(js)) + + + def test_filestorage_count_cls(self): + ''' + Tests the count method with class name + ''' + all_obj = models.storage.all('State') + count_all_obj = models.storage.count('State') + self.assertEqual(len(all_obj), count_all_obj)