Skip to content

Commit

Permalink
Merge branch 'new-api' into hopper-change
Browse files Browse the repository at this point in the history
  • Loading branch information
brandontrabucco committed Jan 29, 2024
2 parents 449fb7f + ecbf4e9 commit cef252f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 10 deletions.
4 changes: 2 additions & 2 deletions design_bench/datasets/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def __init__(self, x_shards, y_shards, internal_batch_size=32,
self._disable_transform = False
self._disable_subsample = False
self.dataset_visible_mask = np.full(
[self.dataset_size], True, dtype=np.bool)
[self.dataset_size], True, dtype=np.bool_)

# handle requests to normalize and subsample the dataset
if is_normalized_x:
Expand Down Expand Up @@ -1068,7 +1068,7 @@ def subsample(self, max_samples=None, distribution=None,
indices.size, max_samples, replace=False, p=probs / probs.sum())]

# binary mask that determines which samples are visible
visible_mask = np.full([y.shape[0]], False, dtype=np.bool)
visible_mask = np.full([y.shape[0]], False, dtype=np.bool_)
visible_mask[indices] = True
self.dataset_visible_mask = visible_mask
self.dataset_size = indices.size
Expand Down
52 changes: 44 additions & 8 deletions design_bench/oracles/exact/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,44 @@
from .hopper_controller_oracle import HopperControllerOracle
from .ant_morphology_oracle import AntMorphologyOracle
from .dkitty_morphology_oracle import DKittyMorphologyOracle
from .toy_continuous_oracle import ToyContinuousOracle
from .nas_bench_oracle import NASBenchOracle
from .tf_bind_8_oracle import TFBind8Oracle
from .tf_bind_10_oracle import TFBind10Oracle
from .toy_discrete_oracle import ToyDiscreteOracle
try:
from .ant_morphology_oracle import AntMorphologyOracle
except ImportError as e:
print("Skipping AntMorphologyOracle import:", e)

try:
from .cifar_nas_oracle import CIFARNASOracle
except ImportError as e:
print("Skipping CIFARNASOracle import:", e)

try:
from .dkitty_morphology_oracle import DKittyMorphologyOracle
except ImportError as e:
print("Skipping DKittyMorphologyOracle import:", e)

try:
from .hopper_controller_oracle import HopperControllerOracle
except ImportError as e:
print("Skipping HopperControllerOracle import:", e)

try:
from .nas_bench_oracle import NASBenchOracle
except ImportError as e:
print("Skipping NASBenchOracle import:", e)

try:
from .tf_bind_8_oracle import TFBind8Oracle
except ImportError as e:
print("Skipping TFBind8Oracle import:", e)

try:
from .tf_bind_10_oracle import TFBind10Oracle
except ImportError as e:
print("Skipping TFBind10Oracle import:", e)

try:
from .toy_continuous_oracle import ToyContinuousOracle
except ImportError as e:
print("Skipping ToyContinuousOracle import:", e)

try:
from .toy_discrete_oracle import ToyDiscreteOracle
except ImportError as e:
print("Skipping ToyDiscreteOracle import:", e)

0 comments on commit cef252f

Please sign in to comment.