diff --git a/pyzoo/zoo/examples/orca/learn/tf/image_segmentation/README.md b/pyzoo/zoo/examples/orca/learn/tf/image_segmentation/README.md
index febd444cfd..d9dd1b63df 100644
--- a/pyzoo/zoo/examples/orca/learn/tf/image_segmentation/README.md
+++ b/pyzoo/zoo/examples/orca/learn/tf/image_segmentation/README.md
@@ -24,12 +24,18 @@ You should manually download the dataset from kaggle [carvana-image-masking-chal
 
 ## Run example on local
 ```bash
-python image_segmentation.py --cluster_mode local
+# linux
+python image_segmentation.py --cluster_mode local
+# macos
+python image_segmentation.py --cluster_mode local --platform mac
 ```
 
 ## Run example on yarn cluster
 ```bash
-python image_segmentation.py --cluster_mode yarn
+# linux
+python image_segmentation.py --cluster_mode yarn
+# macos
+python image_segmentation.py --cluster_mode yarn --platform mac
 ```
 
 Options
@@ -37,3 +43,5 @@ Options
 * `--file_path` The path to carvana train.zip, train_mask.zip and train_mask.csv.zip. Default to be `/tmp/carvana/`.
 * `--epochs` The number of epochs to train the model. Default to be 8.
 * `--batch_size` Batch size for training and prediction. Default to be 8.
+* `--platform` The platform you used to run the example. Default to be `linux`. You should pass `mac` if you use macos.
+* `--non_interactive` Flag to not visualize the result.
diff --git a/pyzoo/zoo/examples/orca/learn/tf/image_segmentation/image_segmentation.py b/pyzoo/zoo/examples/orca/learn/tf/image_segmentation/image_segmentation.py
index 24e2508410..af8120a0b1 100644
--- a/pyzoo/zoo/examples/orca/learn/tf/image_segmentation/image_segmentation.py
+++ b/pyzoo/zoo/examples/orca/learn/tf/image_segmentation/image_segmentation.py
@@ -46,7 +46,11 @@ def load_data(file_path):
     load_data_from_zip(file_path, 'train_masks.csv.zip')
 
 
-def main(cluster_mode, max_epoch, file_path, batch_size):
+def main(cluster_mode, max_epoch, file_path, batch_size, platform, non_interactive):
+    import matplotlib
+    if not non_interactive and platform == "mac":
+        matplotlib.use('qt5agg')
+
     if cluster_mode == "local":
         init_orca_context(cluster_mode="local", cores=4, memory="3g")
     elif cluster_mode == "yarn":
@@ -174,25 +178,27 @@ def bce_dice_loss(y_true, y_pred):
     val_image_label = val_shards.collect()[0]
     val_image = val_image_label["x"]
     val_label = val_image_label["y"]
-    # visualize 5 predicted results
-    plt.figure(figsize=(10, 20))
-    for i in range(5):
-        img = val_image[i]
-        label = val_label[i]
-        predicted_label = pred[i]
-
-        plt.subplot(5, 3, 3 * i + 1)
-        plt.imshow(img)
-        plt.title("Input image")
-
-        plt.subplot(5, 3, 3 * i + 2)
-        plt.imshow(label[:, :, 0], cmap='gray')
-        plt.title("Actual Mask")
-        plt.subplot(5, 3, 3 * i + 3)
-        plt.imshow(predicted_label, cmap='gray')
-        plt.title("Predicted Mask")
-    plt.suptitle("Examples of Input Image, Label, and Prediction")
-    plt.show()
+    if not non_interactive:
+        # visualize 5 predicted results
+        plt.figure(figsize=(10, 20))
+        for i in range(5):
+            img = val_image[i]
+            label = val_label[i]
+            predicted_label = pred[i]
+
+            plt.subplot(5, 3, 3 * i + 1)
+            plt.imshow(img)
+            plt.title("Input image")
+
+            plt.subplot(5, 3, 3 * i + 2)
+            plt.imshow(label[:, :, 0], cmap='gray')
+            plt.title("Actual Mask")
+            plt.subplot(5, 3, 3 * i + 3)
+            plt.imshow(predicted_label, cmap='gray')
+            plt.title("Predicted Mask")
+        plt.suptitle("Examples of Input Image, Label, and Prediction")
+
+        plt.show()
 
     stop_orca_context()
 
@@ -207,6 +213,11 @@ def bce_dice_loss(y_true, y_pred):
                         help="The number of epochs to train the model")
     parser.add_argument('--batch_size', type=int, default=8,
                         help="Batch size for training and prediction")
+    parser.add_argument('--platform', type=str, default="linux",
+                        help="The platform you used. Only linux and mac are supported.")
+    parser.add_argument('--non_interactive', default=False, action="store_true",
+                        help="Flag to not visualize the result.")
 
     args = parser.parse_args()
-    main(args.cluster_mode, args.epochs, args.file_path, args.batch_size)
+    main(args.cluster_mode, args.epochs, args.file_path, args.batch_size, args.platform,
+         args.non_interactive)
diff --git a/pyzoo/zoo/examples/run-example-tests-pip.sh b/pyzoo/zoo/examples/run-example-tests-pip.sh
index d00def7f9f..b514930c40 100644
--- a/pyzoo/zoo/examples/run-example-tests-pip.sh
+++ b/pyzoo/zoo/examples/run-example-tests-pip.sh
@@ -602,7 +602,7 @@ fi
 # Run the example
 export SPARK_DRIVER_MEMORY=3g
 python ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/orca/learn/tf/image_segmentation/image_segmentation.py \
-  --file_path analytics-zoo-data/data/carvana
+  --file_path analytics-zoo-data/data/carvana --epochs 1 --non_interactive
 
 exit_status=$?
 if [ $exit_status -ne 0 ]; then
diff --git a/pyzoo/zoo/examples/run-example-tests.sh b/pyzoo/zoo/examples/run-example-tests.sh
index 540e45f84e..db31bad7b5 100644
--- a/pyzoo/zoo/examples/run-example-tests.sh
+++ b/pyzoo/zoo/examples/run-example-tests.sh
@@ -622,7 +622,7 @@ ${ANALYTICS_ZOO_HOME}/bin/spark-submit-python-with-zoo.sh \
   --driver-memory 3g \
   --executor-memory 3g \
   ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/orca/learn/tf/image_segmentation/image_segmentation.py \
-  --file_path analytics-zoo-data/data/carvana --epochs 1
+  --file_path analytics-zoo-data/data/carvana --epochs 1 --non_interactive
 
 exit_status=$?
 if [ $exit_status -ne 0 ]; then