diff --git a/pyner/named_entity/inference.py b/pyner/named_entity/inference.py index 422a744..3d8ff76 100644 --- a/pyner/named_entity/inference.py +++ b/pyner/named_entity/inference.py @@ -1,6 +1,7 @@ import json import logging import pathlib +from typing import Optional import chainer import click @@ -15,10 +16,10 @@ @click.command() @click.argument("model") -@click.option("--epoch", type=int, required=True) +@click.option("--epoch", type=int) @click.option("--device", type=int, default=-1) @click.option("--metric", type=str, default="validation/main/fscore") -def run_inference(model: str, epoch: int, device: str, metric: str): +def run_inference(model: str, epoch: Optional[int], device: str, metric: str): chainer.config.train = False if device >= 0: @@ -29,7 +30,6 @@ def run_inference(model: str, epoch: int, device: str, metric: str): model_dir = pathlib.Path(model) configs = json.load(open(model_dir / "args")) - metric = metric.replace("/", ".") snapshot_file, prediction_path = select_snapshot( epoch, metric, model, model_dir) logger.debug(f"creat prediction into {prediction_path}")