Skip to content

Commit

Permalink
Improve the exporting of the yaml models
Browse files Browse the repository at this point in the history
  • Loading branch information
TrentHouliston committed Sep 4, 2020
1 parent c2a5de1 commit 6cb28d0
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions training/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,24 @@ def export(config, output_path):
# We have to run a predict step so that everything is loaded properly
model.predict(training_dataset.take(1))

# Print out the model summary
model.summary()

stages = []
for m in model.stages:
op = model.ops[m]

if type(op[0]) is GraphConvolution:
stages.append([])
op = op[0]
stages.append(
[
{
"weights": op.dense.weights[0].numpy().tolist(),
"biases": op.dense.weights[1].numpy().tolist(),
"activation": op.dense.activation.__name__,
}
]
)
elif type(op[0]) is tf.keras.layers.Dense:
op = op[0]
stages[-1].append(
Expand All @@ -79,4 +91,4 @@ def export(config, output_path):
stages[0][0]["weights"] = first.numpy().tolist()

with open(os.path.join(output_path, "model.yaml"), "w") as out:
yaml.dump(stages, out)
yaml.dump(stages, out, default_flow_style=None, width=1000000)

0 comments on commit 6cb28d0

Please sign in to comment.