Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
Signed-off-by: Kim, Vinnam <[email protected]>
  • Loading branch information
vinnamkim committed Apr 15, 2024
1 parent b52a33b commit c96258e
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 24 deletions.
2 changes: 1 addition & 1 deletion src/otx/core/model/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def __init__(
torch_compile=torch_compile,
)

from otx.algo.hooks.recording_forward_hook import get_feature_vector
from otx.algo.explain.explain_algo import get_feature_vector

self.model.feature_vector_fn = get_feature_vector
self.model.explain_fn = self.get_explain_fn()
Expand Down
2 changes: 1 addition & 1 deletion src/otx/core/model/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def __init__(
torch_compile=torch_compile,
)

from otx.algo.hooks.recording_forward_hook import get_feature_vector
from otx.algo.explain.explain_algo import get_feature_vector

self.model.feature_vector_fn = get_feature_vector
self.model.explain_fn = self.get_explain_fn()
Expand Down
11 changes: 0 additions & 11 deletions tests/unit/core/model/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,3 @@ def test_reset_restore_model_forward(self, otx_model):
otx_model._restore_model_forward()
assert otx_model.original_model_forward is None
assert str(otx_model.model.forward) == str(initial_model_forward)

def test_export_parameters(self, otx_model):
otx_model.image_size = (1, 64, 64, 3)
otx_model.explain_mode = False
parameters = otx_model._export_parameters
assert isinstance(parameters, dict)
assert "output_names" in parameters

otx_model.explain_mode = True
parameters = otx_model._export_parameters
assert parameters["output_names"] == ["feature_vector", "saliency_map"]
11 changes: 0 additions & 11 deletions tests/unit/core/model/test_inst_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,3 @@ def test_reset_restore_model_forward(self, otx_model):
otx_model._restore_model_forward()
assert otx_model.original_model_forward is None
assert str(otx_model.model.forward) == str(initial_model_forward)

def test_export_parameters(self, otx_model):
otx_model.image_size = (1, 64, 64, 3)
otx_model.explain_mode = False
parameters = otx_model._export_parameters
assert isinstance(parameters, dict)
assert "output_names" in parameters

otx_model.explain_mode = True
parameters = otx_model._export_parameters
assert parameters["output_names"] == ["feature_vector", "saliency_map"]

0 comments on commit c96258e

Please sign in to comment.