Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gowthamkpr committed Oct 22, 2024
1 parent 6797231 commit 4845b6a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import keras
from keras import layers

from keras_hub.src.api_export import keras_hub_export
Expand Down Expand Up @@ -53,9 +54,16 @@ def __init__(
def get_config(self):
config = super().get_config()
config["fpn_channels"] = self.fpn_channels
config["image_encoder"] = self.image_encoder
config["image_encoder"] = keras.layers.serialize(self.image_encoder)
return config

@classmethod
def from_config(cls, config):
config["image_encoder"] = keras.layers.deserialize(
config["image_encoder"]
)
return cls(**config)


def diffbin_fpn_model(inputs, out_channels):
in2 = layers.Conv2D(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,40 @@
from keras_hub.src.models.differential_binarization.differential_binarization import (
DifferentialBinarization,
)
from keras_hub.src.models.differential_binarization.differential_binarization_backbone import (
DifferentialBinarizationBackbone,
)
from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import (
ResNetImageClassifierPreprocessor,
)
from keras_hub.src.tests.test_case import TestCase


class DifferentialBinarizationTest(TestCase):
def setUp(self):
self.images = ops.ones((2, 224, 224, 3))
self.labels = ops.zeros((2, 224, 224, 4))
self.backbone = ResNetBackbone(
image_encoder = ResNetBackbone(
input_conv_filters=[64],
input_conv_kernel_sizes=[7],
stackwise_num_filters=[64, 128, 256, 512],
stackwise_num_blocks=[3, 4, 6, 3],
stackwise_num_strides=[1, 2, 2, 2],
block_type="bottleneck_block",
image_shape=(224, 224, 3),
include_rescaling=False,
)
self.backbone = DifferentialBinarizationBackbone(
image_encoder=image_encoder
)
self.preprocessor = ResNetImageClassifierPreprocessor()
self.init_kwargs = {
"backbone": self.backbone,
"preprocessor": self.preprocessor,
}
self.train_data = (self.images, self.labels)

def test_basics(self):
pytest.skip(
reason="TODO: enable after preprocessor flow is figured out"
)
self.run_task_test(
cls=DifferentialBinarization,
init_kwargs=self.init_kwargs,
Expand Down

0 comments on commit 4845b6a

Please sign in to comment.