diff --git a/src/python/turicreate/test/test_image_similarity.py b/src/python/turicreate/test/test_image_similarity.py index 21bfae2b48..964c731e30 100644 --- a/src/python/turicreate/test/test_image_similarity.py +++ b/src/python/turicreate/test/test_image_similarity.py @@ -251,21 +251,25 @@ def get_psnr(x, y): ) # Get model distances for comparison - img = data[0:1][self.feature][0] - img_fixed = tc.image_analysis.resize(img, *reversed(self.input_image_shape)) - tc_ret = self.model.query(img_fixed, k=data.num_rows()) - - if _mac_ver() >= (10, 13): - from PIL import Image as _PIL_Image - - pil_img = _PIL_Image.fromarray(img_fixed.pixel_data) - coreml_ret = coreml_model.predict({"awesome_image": pil_img}) - - # Compare distances - coreml_distances = np.array(coreml_ret["distance"]) - tc_distances = tc_ret.sort("reference_label")["distance"].to_numpy() - psnr_value = get_psnr(coreml_distances, tc_distances) - self.assertTrue(psnr_value > 50) + if self.feature == "awesome_image": + img = data[0:1][self.feature][0] + img_fixed = tc.image_analysis.resize(img, *reversed(self.input_image_shape)) + tc_ret = self.model.query(img_fixed, k=data.num_rows()) + + if _mac_ver() >= (10, 13): + from PIL import Image as _PIL_Image + + pil_img = _PIL_Image.fromarray(img_fixed.pixel_data) + coreml_ret = coreml_model.predict({"awesome_image": pil_img}) + + # Compare distances + coreml_distances = np.array(coreml_ret["distance"]) + tc_distances = tc_ret.sort("reference_label")["distance"].to_numpy() + psnr_value = get_psnr(coreml_distances, tc_distances) + self.assertTrue(psnr_value > 50) + else: + # Broad else clause to ignore features not supported in coreml + pass def test_save_and_load(self): with test_util.TempDirectory() as filename: @@ -287,6 +291,60 @@ def test_save_and_load(self): print("Export coreml passed") +class ImageSimilarityTestWithKwargs(unittest.TestCase): + @classmethod + def setUpClass(self, input_image_shape=(3, 224, 224), model="resnet-50"): + """ + The setup class method for the basic test case with all default values. + """ + self.feature = "awesome_image" + self.label = None + self.input_image_shape = input_image_shape + self.pre_trained_model = model + + # Create the model + self.def_opts = { + "model": "resnet-50", + "verbose": True, + } + + # Model + self.model = tc.image_similarity.create( + data, feature=self.feature, label=None, model=self.pre_trained_model, + method='lsh', distance='squared_euclidean' + ) + self.nn_model = self.model.feature_extractor + self.lm_model = self.model.similarity_model + self.opts = self.def_opts.copy() + + # Answers + self.get_ans = { + "similarity_model": lambda x: type(x) + == tc.nearest_neighbors.NearestNeighborsModel, + "feature": lambda x: x == self.feature, + "training_time": lambda x: x > 0, + "input_image_shape": lambda x: x == self.input_image_shape, + "label": lambda x: x == self.label, + "feature_extractor": lambda x: callable(x.extract_features), + "num_features": lambda x: x == self.lm_model.num_features, + "num_examples": lambda x: x == self.lm_model.num_examples, + "model": lambda x: ( + x == self.pre_trained_model + or ( + self.pre_trained_model == "VisionFeaturePrint_Screen" + and x == "VisionFeaturePrint_Scene" + ) + ), + } + self.fields_ans = self.get_ans.keys() + + def assertModelWorks(self): + self.assertEqual(self.model.similarity_model.distance[0][1], + 'squared_euclidean' + ) + + + class ImageSimilaritySqueezeNetTest(ImageSimilarityTest): @classmethod def setUpClass(self): @@ -306,7 +364,7 @@ def setUpClass(self): ) -# A test to gaurantee that old code using the incorrect name still works. +# A test to guarantee that old code using the incorrect name still works. @unittest.skipIf( _mac_ver() < (10, 14), "VisionFeaturePrint_Scene only supported on macOS 10.14+" ) @@ -316,3 +374,12 @@ def setUpClass(self): super(ImageSimilarityVisionFeaturePrintSceneTest_bad_name, self).setUpClass( model="VisionFeaturePrint_Screen", input_image_shape=(3, 299, 299) ) + + +# A test to ensure kwargs are still accepted in create() +class ImageSimilarityCreateKwargsTest(ImageSimilarityTest): + @classmethod + def setUpClass(self): + super(ImageSimilarityCreateKwargsTest, self).setUpClass( + model="resnet-50", input_image_shape=(3, 300, 300) + ) diff --git a/src/python/turicreate/toolkits/image_similarity/image_similarity.py b/src/python/turicreate/toolkits/image_similarity/image_similarity.py index bab3017153..02011a93c3 100644 --- a/src/python/turicreate/toolkits/image_similarity/image_similarity.py +++ b/src/python/turicreate/toolkits/image_similarity/image_similarity.py @@ -22,7 +22,13 @@ def create( - dataset, label=None, feature=None, model="resnet-50", verbose=True, batch_size=64 + dataset, + label=None, + feature=None, + model="resnet-50", + verbose=True, + batch_size=64, + **kwargs ): """ Create a :class:`ImageSimilarityModel` model. @@ -63,6 +69,9 @@ def create( batch_size : int, optional If you are getting memory errors, try decreasing this value. If you have a powerful computer, increasing this value may improve performance. + + **kwargs : optional + Options for downstream methods like :py:func:`tc.nearest_neighbors.create()`. Returns ------- @@ -148,6 +157,7 @@ def create( label=label, features=["__image_features__"], verbose=verbose, + **kwargs ) # set input image shape