diff --git a/training/export.py b/training/export.py index 65aaf2e..d6d0361 100644 --- a/training/export.py +++ b/training/export.py @@ -88,18 +88,19 @@ def export(config, output_path): first = tf.reshape(first, (-1, first.shape[-1])) stages[0][0]["weights"] = first.numpy().tolist() + network = { + "geometry": { + "intersections": config["projection"]["config"]["geometry"]["intersections"], + "radius": config["projection"]["config"]["geometry"]["radius"], + "shape": config["projection"]["config"]["geometry"]["shape"], + }, + "mesh": config["projection"]["config"]["mesh"]["model"], + "network": stages, + } + + # Add classification meta data + if config["label"]["type"] == "Classification": + network["class_map"] = {c["name"]: i for i, c in enumerate(config["label"]["config"]["classes"])} + with open(os.path.join(output_path, "model.yaml"), "w") as out: - yaml.dump( - { - "mesh": config["projection"]["config"]["mesh"]["model"], - "geometry": { - "shape": config["projection"]["config"]["geometry"]["shape"], - "radius": config["projection"]["config"]["geometry"]["radius"], - "intersections": config["projection"]["config"]["geometry"]["intersections"], - }, - "network": stages, - }, - out, - default_flow_style=None, - width=float("inf"), - ) + yaml.dump(network, out, default_flow_style=None, width=float("inf"))