Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

[WIP] kwargs flow through for Image Similarity -> Nearest Neighbors #3233

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 83 additions & 16 deletions src/python/turicreate/test/test_image_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,21 +251,25 @@ def get_psnr(x, y):
)

# Get model distances for comparison
img = data[0:1][self.feature][0]
img_fixed = tc.image_analysis.resize(img, *reversed(self.input_image_shape))
tc_ret = self.model.query(img_fixed, k=data.num_rows())

if _mac_ver() >= (10, 13):
from PIL import Image as _PIL_Image

pil_img = _PIL_Image.fromarray(img_fixed.pixel_data)
coreml_ret = coreml_model.predict({"awesome_image": pil_img})

# Compare distances
coreml_distances = np.array(coreml_ret["distance"])
tc_distances = tc_ret.sort("reference_label")["distance"].to_numpy()
psnr_value = get_psnr(coreml_distances, tc_distances)
self.assertTrue(psnr_value > 50)
if self.feature == "awesome_image":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are you trying to accomplish by this change?

img = data[0:1][self.feature][0]
img_fixed = tc.image_analysis.resize(img, *reversed(self.input_image_shape))
tc_ret = self.model.query(img_fixed, k=data.num_rows())

if _mac_ver() >= (10, 13):
from PIL import Image as _PIL_Image

pil_img = _PIL_Image.fromarray(img_fixed.pixel_data)
coreml_ret = coreml_model.predict({"awesome_image": pil_img})

# Compare distances
coreml_distances = np.array(coreml_ret["distance"])
tc_distances = tc_ret.sort("reference_label")["distance"].to_numpy()
psnr_value = get_psnr(coreml_distances, tc_distances)
self.assertTrue(psnr_value > 50)
else:
# Broad else clause to ignore features not supported in coreml
pass

def test_save_and_load(self):
with test_util.TempDirectory() as filename:
Expand All @@ -287,6 +291,60 @@ def test_save_and_load(self):
print("Export coreml passed")


class ImageSimilarityTestWithKwargs(unittest.TestCase):
@classmethod
def setUpClass(self, input_image_shape=(3, 224, 224), model="resnet-50"):
"""
The setup class method for the basic test case with all default values.
"""
self.feature = "awesome_image"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need the majority of the code in this setUpClass. It looks like the only member variable that's actually getting used is self.model. I think everything else should be removed.

self.label = None
self.input_image_shape = input_image_shape
self.pre_trained_model = model

# Create the model
self.def_opts = {
"model": "resnet-50",
"verbose": True,
}

# Model
self.model = tc.image_similarity.create(
data, feature=self.feature, label=None, model=self.pre_trained_model,
method='lsh', distance='squared_euclidean'
)
self.nn_model = self.model.feature_extractor
self.lm_model = self.model.similarity_model
self.opts = self.def_opts.copy()

# Answers
self.get_ans = {
"similarity_model": lambda x: type(x)
== tc.nearest_neighbors.NearestNeighborsModel,
"feature": lambda x: x == self.feature,
"training_time": lambda x: x > 0,
"input_image_shape": lambda x: x == self.input_image_shape,
"label": lambda x: x == self.label,
"feature_extractor": lambda x: callable(x.extract_features),
"num_features": lambda x: x == self.lm_model.num_features,
"num_examples": lambda x: x == self.lm_model.num_examples,
"model": lambda x: (
x == self.pre_trained_model
or (
self.pre_trained_model == "VisionFeaturePrint_Screen"
and x == "VisionFeaturePrint_Scene"
)
),
}
self.fields_ans = self.get_ans.keys()

def assertModelWorks(self):
self.assertEqual(self.model.similarity_model.distance[0][1],
'squared_euclidean'
)



class ImageSimilaritySqueezeNetTest(ImageSimilarityTest):
@classmethod
def setUpClass(self):
Expand All @@ -306,7 +364,7 @@ def setUpClass(self):
)


# A test to gaurantee that old code using the incorrect name still works.
# A test to guarantee that old code using the incorrect name still works.
@unittest.skipIf(
_mac_ver() < (10, 14), "VisionFeaturePrint_Scene only supported on macOS 10.14+"
)
Expand All @@ -316,3 +374,12 @@ def setUpClass(self):
super(ImageSimilarityVisionFeaturePrintSceneTest_bad_name, self).setUpClass(
model="VisionFeaturePrint_Screen", input_image_shape=(3, 299, 299)
)


# A test to ensure kwargs are still accepted in create()
class ImageSimilarityCreateKwargsTest(ImageSimilarityTest):
@classmethod
def setUpClass(self):
super(ImageSimilarityCreateKwargsTest, self).setUpClass(
model="resnet-50", input_image_shape=(3, 300, 300)
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@


def create(
dataset, label=None, feature=None, model="resnet-50", verbose=True, batch_size=64
dataset,
label=None,
feature=None,
model="resnet-50",
verbose=True,
batch_size=64,
**kwargs
):
"""
Create a :class:`ImageSimilarityModel` model.
Expand Down Expand Up @@ -63,6 +69,9 @@ def create(
batch_size : int, optional
If you are getting memory errors, try decreasing this value. If you
have a powerful computer, increasing this value may improve performance.

**kwargs : optional
Options for downstream methods like :py:func:`tc.nearest_neighbors.create()`.

Returns
-------
Expand Down Expand Up @@ -148,6 +157,7 @@ def create(
label=label,
features=["__image_features__"],
verbose=verbose,
**kwargs
)

# set input image shape
Expand Down