Skip to content

Commit

Permalink
[Backport] Orca tf example matplotlib failed on macos bugfix (#3236)
Browse files Browse the repository at this point in the history
* fix matplotlib crash on macos

* fix jenkins error

* mac fix
  • Loading branch information
cyita authored Dec 15, 2020
1 parent 6558f27 commit bb58d5f
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 25 deletions.
12 changes: 10 additions & 2 deletions pyzoo/zoo/examples/orca/learn/tf/image_segmentation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,24 @@ You should manually download the dataset from kaggle [carvana-image-masking-chal

## Run example on local
```bash
python image_segmentation.py --cluster_mode local
# linux
python image_segmentation.py --cluster_mode local
# macos
python image_segmentation.py --cluster_mode local --platform mac
```

## Run example on yarn cluster
```bash
python image_segmentation.py --cluster_mode yarn
# linux
python image_segmentation.py --cluster_mode yarn
# macos
python image_segmentation.py --cluster_mode yarn --platform mac
```

Options
* `--cluster_mode` The mode for the Spark cluster: `local` or `yarn`. Defaults to `local`.
* `--file_path` The directory containing the carvana train.zip, train_mask.zip and train_masks.csv.zip files. Defaults to `/tmp/carvana/`.
* `--epochs` The number of epochs to train the model. Defaults to 8.
* `--batch_size` Batch size for training and prediction. Defaults to 8.
* `--platform` The platform on which you run the example. Defaults to `linux`; pass `mac` when running on macOS.
* `--non_interactive` Flag that disables visualization of the results (no plot window is opened).
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ def load_data(file_path):
load_data_from_zip(file_path, 'train_masks.csv.zip')


def main(cluster_mode, max_epoch, file_path, batch_size):
def main(cluster_mode, max_epoch, file_path, batch_size, platform, non_interactive):
import matplotlib
if not non_interactive and platform == "mac":
matplotlib.use('qt5agg')

if cluster_mode == "local":
init_orca_context(cluster_mode="local", cores=4, memory="3g")
elif cluster_mode == "yarn":
Expand Down Expand Up @@ -174,25 +178,27 @@ def bce_dice_loss(y_true, y_pred):
val_image_label = val_shards.collect()[0]
val_image = val_image_label["x"]
val_label = val_image_label["y"]
# visualize 5 predicted results
plt.figure(figsize=(10, 20))
for i in range(5):
img = val_image[i]
label = val_label[i]
predicted_label = pred[i]

plt.subplot(5, 3, 3 * i + 1)
plt.imshow(img)
plt.title("Input image")

plt.subplot(5, 3, 3 * i + 2)
plt.imshow(label[:, :, 0], cmap='gray')
plt.title("Actual Mask")
plt.subplot(5, 3, 3 * i + 3)
plt.imshow(predicted_label, cmap='gray')
plt.title("Predicted Mask")
plt.suptitle("Examples of Input Image, Label, and Prediction")
plt.show()
if not non_interactive:
# visualize 5 predicted results
plt.figure(figsize=(10, 20))
for i in range(5):
img = val_image[i]
label = val_label[i]
predicted_label = pred[i]

plt.subplot(5, 3, 3 * i + 1)
plt.imshow(img)
plt.title("Input image")

plt.subplot(5, 3, 3 * i + 2)
plt.imshow(label[:, :, 0], cmap='gray')
plt.title("Actual Mask")
plt.subplot(5, 3, 3 * i + 3)
plt.imshow(predicted_label, cmap='gray')
plt.title("Predicted Mask")
plt.suptitle("Examples of Input Image, Label, and Prediction")

plt.show()

stop_orca_context()

Expand All @@ -207,6 +213,11 @@ def bce_dice_loss(y_true, y_pred):
help="The number of epochs to train the model")
parser.add_argument('--batch_size', type=int, default=8,
help="Batch size for training and prediction")
parser.add_argument('--platform', type=str, default="linux",
help="The platform you used. Only linux and mac are supported.")
parser.add_argument('--non_interactive', default=False, action="store_true",
help="Flag to not visualize the result.")

args = parser.parse_args()
main(args.cluster_mode, args.epochs, args.file_path, args.batch_size)
main(args.cluster_mode, args.epochs, args.file_path, args.batch_size, args.platform,
args.non_interactive)
2 changes: 1 addition & 1 deletion pyzoo/zoo/examples/run-example-tests-pip.sh
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ fi
# Run the example
export SPARK_DRIVER_MEMORY=3g
python ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/orca/learn/tf/image_segmentation/image_segmentation.py \
--file_path analytics-zoo-data/data/carvana
--file_path analytics-zoo-data/data/carvana --epochs 1 --non_interactive
exit_status=$?
if [ $exit_status -ne 0 ];
then
Expand Down
2 changes: 1 addition & 1 deletion pyzoo/zoo/examples/run-example-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ ${ANALYTICS_ZOO_HOME}/bin/spark-submit-python-with-zoo.sh \
--driver-memory 3g \
--executor-memory 3g \
${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/orca/learn/tf/image_segmentation/image_segmentation.py \
--file_path analytics-zoo-data/data/carvana --epochs 1
--file_path analytics-zoo-data/data/carvana --epochs 1 --non_interactive
exit_status=$?
if [ $exit_status -ne 0 ];
then
Expand Down

0 comments on commit bb58d5f

Please sign in to comment.