diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index 0aecc914..eb98e9bd 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -6,7 +6,7 @@ from gflownet.utils.common import set_device, set_float_precision -class Policy(ABC): +class Policy: def __init__(self, config, env, device, float_precision, base=None): # Device and float precision self.device = set_device(device) @@ -26,12 +26,8 @@ def parse_config(self, config): # If config is null, default to uniform if config is None: config = OmegaConf.create() - config.type = "uniform" + self.type = config.get("type", "uniform") self.checkpoint = config.get("checkpoint", None) - if "type" in config: - self.type = config.type - else: - self.type = "uniform" def instantiate(self): if self.type == "fixed":