diff --git a/train3.ipynb b/train3.ipynb
deleted file mode 100644
index 65f84b4..0000000
--- a/train3.ipynb
+++ /dev/null
@@ -1,213 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "\n",
-    "# Train your faultSeg model with synthetic datasets\n",
-    "\n",
-    "it is a binary segmentation problem\n",
-    "\n",
-    "the synthetic 3D seismic and corresponding fault images are in the folder ./data/train/ "
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Train with data generator"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
-   "outputs": [],
-   "source": [
-    "from faultSeg_classes import DataGenerator\n",
-    "\n",
-    "# training image dimensions\n",
-    "n1, n2, n3 = 128, 128, 128\n",
-    "params = {'batch_size': 1,\n",
-    " 'dim':(n1,n2,n3),\n",
-    " 'n_channels': 1,\n",
-    " 'shuffle': True}\n",
-    "\n",
-    "tdpath = 'data/train/seis/'\n",
-    "tfpath = 'data/train/fault/'\n",
-    "\n",
-    "vdpath = 'data/validation/seis/'\n",
-    "vfpath = 'data/validation/fault/'\n",
-    "tdata_IDs = range(200)\n",
-    "vdata_IDs = range(20)\n",
-    "training_generator = DataGenerator(dpath=tdpath,fpath=tfpath,data_IDs=tdata_IDs,**params)\n",
-    "validation_generator = DataGenerator(dpath=vdpath,fpath=vfpath,data_IDs=vdata_IDs,**params)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
-   "outputs": [],
-   "source": [
-    "from unet3 import *\n",
-    "from keras import callbacks\n",
-    "\n",
-    "K.set_image_data_format('channels_last')\n",
-    "model_name = 'fault'\n",
-    "model_dir = os.path.join('check', model_name)\n",
-    "csv_fn = os.path.join(model_dir, 'train_log.csv')\n",
-    "checkpoint_fn = os.path.join(model_dir, 'checkpoint.{epoch:02d}.hdf5')\n",
-    "\n",
-    "model = unet()\n",
-    "checkpointer = callbacks.ModelCheckpoint(filepath=checkpoint_fn, verbose=1, save_best_only=False)\n",
-    "csv_logger = callbacks.CSVLogger(csv_fn, append=True, separator=';')\n",
-    "tensorboard = callbacks.TensorBoard(log_dir=model_dir, histogram_freq=0, batch_size=2,\n",
-    " write_graph=True, write_grads=True, write_images=True)\n",
-    "history = model.fit_generator(\n",
-    " generator=training_generator,\n",
-    " validation_data=validation_generator,\n",
-    " epochs=25,verbose=1,callbacks=[checkpointer, csv_logger, tensorboard])\n",
-    "\n",
-    "print(history)\n",
-    "# serialize model to JSON\n",
-    "model_json = model.to_json()\n",
-    "with open(\"model3.json\", \"w\") as json_file:\n",
-    " json_file.write(model_json)\n",
-    "# serialize weights to HDF5\n",
-    "model.save_weights(\"model3.h5\")\n",
-    "print(\"Saved model to disk\") "
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
-   "outputs": [],
-   "source": [
-    "# list all data in history\n",
-    "print(history.history.keys())\n",
-    "fig = plt.figure(figsize=(10,6))\n",
-    "\n",
-    "# summarize history for accuracy\n",
-    "plt.plot(history.history['acc'])\n",
-    "plt.plot(history.history['val_acc'])\n",
-    "plt.title('Model accuracy',fontsize=20)\n",
-    "plt.ylabel('Accuracy',fontsize=20)\n",
-    "plt.xlabel('Epoch',fontsize=20)\n",
-    "plt.legend(['train', 'test'], loc='center right',fontsize=20)\n",
-    "plt.tick_params(axis='both', which='major', labelsize=18)\n",
-    "plt.tick_params(axis='both', which='minor', labelsize=18)\n",
-    "plt.show()\n",
-    "\n",
-    "# summarize history for loss\n",
-    "fig = plt.figure(figsize=(10,6))\n",
-    "plt.plot(history.history['loss'])\n",
-    "plt.plot(history.history['val_loss'])\n",
-    "plt.title('Model loss',fontsize=20)\n",
-    "plt.ylabel('Loss',fontsize=20)\n",
-    "plt.xlabel('Epoch',fontsize=20)\n",
-    "plt.legend(['train', 'test'], loc='center right',fontsize=20)\n",
-    "plt.tick_params(axis='both', which='major', labelsize=18)\n",
-    "plt.tick_params(axis='both', which='minor', labelsize=18)\n",
-    "plt.show()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Load the trained model and test on your own datasets"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
-   "outputs": [],
-   "source": [
-    "from keras.models import load_model\n",
-    "from unet3 import *\n",
-    "\n",
-    "# load json and create model \n",
-    "json_file = open('model3.json', 'r')\n",
-    "loaded_model_json = json_file.read()\n",
-    "json_file.close()\n",
-    "loaded_model = model_from_json(loaded_model_json)\n",
-    "# load weights into new model\n",
-    "loaded_model.load_weights(\"check4/fault/checkpoint.25.hdf5\")\n",
-    "print(\"Loaded model from disk\")\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
-   "outputs": [],
-   "source": [
-    "import matplotlib.pyplot as plt\n",
-    "gx,m1,m2,m3 = np.fromfile(\"data/validation/seis/6.dat\",dtype=np.single),128,128,128\n",
-    "gx = gx-np.min(gx)\n",
-    "gx = gx/np.max(gx)\n",
-    "gx = gx*255\n",
-    "k = 50\n",
-    "x = np.reshape(gx,(1,128,128,128,1))\n",
-    "Y = loaded_model.predict(x,verbose=1)\n",
-    "print(Y.shape)\n",
-    "\n",
-    "# Y1 = Y[0]\n",
-    "# Y2 = Y[1]\n",
-    "# Y3 = Y[2]\n",
-    "# Y4 = Y[3]\n",
-    "# Y5 = Y[4]\n",
-    "#Y6 = Y[5]\n",
-    "fig = plt.figure(figsize=(10,10))\n",
-    "plt.subplot(1, 2, 1)\n",
-    "imgplot1 = plt.imshow(np.transpose(x[0,k,:,:,0]),cmap=plt.cm.bone,interpolation='nearest',aspect=1)\n",
-    "plt.subplot(1, 2, 2)\n",
-    "imgplot2 = plt.imshow(np.transpose(Y[0,k,:,:,0]),cmap=plt.cm.bone,interpolation='nearest',aspect=1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.6.5"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}