Skip to content

Commit

Permalink
Make the exporter export the list of classes into the output for clas…
Browse files Browse the repository at this point in the history
…sification networks
  • Loading branch information
TrentHouliston committed Jun 4, 2021
1 parent 426fa8e commit cd9d08e
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions training/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,19 @@ def export(config, output_path):
first = tf.reshape(first, (-1, first.shape[-1]))
stages[0][0]["weights"] = first.numpy().tolist()

network = {
"geometry": {
"intersections": config["projection"]["config"]["geometry"]["intersections"],
"radius": config["projection"]["config"]["geometry"]["radius"],
"shape": config["projection"]["config"]["geometry"]["shape"],
},
"mesh": config["projection"]["config"]["mesh"]["model"],
"network": stages,
}

# Add classification meta data
if config["label"]["type"] == "Classification":
network["class_map"] = {c["name"]: i for i, c in enumerate(config["label"]["config"]["classes"])}

with open(os.path.join(output_path, "model.yaml"), "w") as out:
yaml.dump(
{
"mesh": config["projection"]["config"]["mesh"]["model"],
"geometry": {
"shape": config["projection"]["config"]["geometry"]["shape"],
"radius": config["projection"]["config"]["geometry"]["radius"],
"intersections": config["projection"]["config"]["geometry"]["intersections"],
},
"network": stages,
},
out,
default_flow_style=None,
width=float("inf"),
)
yaml.dump(network, out, default_flow_style=None, width=float("inf"))

0 comments on commit cd9d08e

Please sign in to comment.