Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
xinwucwp committed May 6, 2020
1 parent 56f0a45 commit ea756a4
Show file tree
Hide file tree
Showing 14 changed files with 120 additions and 8 deletions.
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@

# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
/model/*.hdf5
/data/train/
/data/validation/
/data/
/check/
/log/
/png/
#/src
#/build
#/libs
Expand Down
Binary file removed .train.py.swp
Binary file not shown.
Binary file removed .unet3.py.swp
Binary file not shown.
Binary file removed .utils.py.swp
Binary file not shown.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ to be sufficient to train a good fault segmentation network.**

### Training

Run train3 to start training a new faultSeg model by using the 200 synthetic datasets
Run train to start training a new faultSeg model by using the 200 synthetic datasets

## Publications

Expand Down
Binary file added __pycache__/fnet.cpython-37.pyc
Binary file not shown.
Binary file modified __pycache__/unet3.cpython-37.pyc
Binary file not shown.
Binary file modified __pycache__/utils.cpython-37.pyc
Binary file not shown.
105 changes: 105 additions & 0 deletions apply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import math
import skimage
import numpy as np
import os
import matplotlib.pyplot as plt

import tensorflow as tf
#from keras import backend
from keras.layers import *
from keras.models import load_model
from skimage.measure import compare_psnr
from unet3 import cross_entropy_balanced
import os
pngDir = './png/'
model = load_model('check/fseg-'+'25.hdf5',
#'impd.hdf5',
custom_objects={
'cross_entropy_balanced': cross_entropy_balanced
#}
)
def main():
#goTrainTest()
goValidTest()
#goF3Test()

def goTrainTest():
seismPath = "./data/train/seis/"
faultPath = "./data/train/fault/"
n1,n2,n3=128,128,128
dk = 100
gx = np.fromfile(seismPath+str(dk)+'.dat',dtype=np.single)
fx = np.fromfile(faultPath+str(dk)+'.dat',dtype=np.single)
gx = np.reshape(gx,(n1,n2,n3))
fx = np.reshape(fx,(n1,n2,n3))
gm = np.mean(gx)
gs = np.std(gx)
gx = gx-gm
gx = gx/gs
gx = np.transpose(gx)
fx = np.transpose(fx)
fp = model.predict(np.reshape(gx,(1,n1,n2,n3,1)),verbose=1)
fp = fp[0,:,:,:,0]
gx1 = gx[50,:,:]
fx1 = fx[50,:,:]
fp1 = fp[50,:,:]
plot2d(gx1,fx1,fp1,png='fp')

def goValidTest():
seismPath = "./data/validation/seis/"
faultPath = "./data/validation/fault/"
n1,n2,n3=128,128,128
dk = 2
gx = np.fromfile(seismPath+str(dk)+'.dat',dtype=np.single)
fx = np.fromfile(faultPath+str(dk)+'.dat',dtype=np.single)
gx = np.reshape(gx,(n1,n2,n3))
fx = np.reshape(fx,(n1,n2,n3))
gm = np.mean(gx)
gs = np.std(gx)
gx = gx-gm
gx = gx/gs
gx = np.transpose(gx)
fx = np.transpose(fx)
fp = model.predict(np.reshape(gx,(1,n1,n2,n3,1)),verbose=1)
fp = fp[0,:,:,:,0]
gx1 = gx[50,:,:]
fx1 = fx[50,:,:]
fp1 = fp[50,:,:]
plot2d(gx1,fx1,fp1,png='fp')

def goF3Test():
seismPath = "./data/prediction/f3d/"
n3,n2,n1=512,384,128
gx = np.fromfile(seismPath+'gxl.dat',dtype=np.single)
gx = np.reshape(gx,(n3,n2,n1))
gm = np.mean(gx)
gs = np.std(gx)
gx = gx-gm
gx = gx/gs
gx = np.transpose(gx)
fp = model.predict(np.reshape(gx,(1,n1,n2,n3,1)),verbose=1)
fp = fp[0,:,:,:,0]
gx1 = gx[99,:,:]
fp1 = fp[99,:,:]
plot2d(gx1,fp1,fp1,png='f3d/fp')

def plot2d(gx,fx,fp,png=None):
fig = plt.figure(figsize=(15,5))
#fig = plt.figure()
ax = fig.add_subplot(131)
ax.imshow(gx,vmin=-2,vmax=2,cmap=plt.cm.bone,interpolation='bicubic',aspect=1)
ax = fig.add_subplot(132)
ax.imshow(fx,vmin=0,vmax=1,cmap=plt.cm.bone,interpolation='bicubic',aspect=1)
ax = fig.add_subplot(133)
ax.imshow(fp,vmin=0,vmax=1.0,cmap=plt.cm.bone,interpolation='bicubic',aspect=1)
if png:
plt.savefig(pngDir+png+'.png')
#cbar = plt.colorbar()
#cbar.set_label('Fault probability')
plt.tight_layout()
plt.show()

if __name__ == '__main__':
main()


5 changes: 5 additions & 0 deletions cpurun
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/sh

#org.python.util.jython $*
CUDA_VISIBLE_DEVICES='' python $*

Binary file modified log/training/events.out.tfevents.1588727343.ustc
Binary file not shown.
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def goTrain():
valid_generator = DataGenerator(dpath=seismPathV,fpath=faultPathV,
data_IDs=valid_ID,**params)
model = unet(input_size=(None, None, None,1))
model.compile(optimizer=Adam(lr=1e-3), loss='binary_crossentropy',
model.compile(optimizer=Adam(lr=1e-3), loss=cross_entropy_balanced,
metrics=['accuracy'])
model.summary()

Expand Down
4 changes: 2 additions & 2 deletions unet3.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def unet(pretrained_weights = None,input_size = (None,None,None,1)):
conv3 = Conv3D(64, (3,3,3), activation='relu', padding='same')(conv3)
pool3 = MaxPooling3D(pool_size=(2,2,2))(conv3)

conv4 = Conv3D(512, (3,3,3), activation='relu', padding='same')(pool3)
conv4 = Conv3D(512, (3,3,3), activation='relu', padding='same')(conv4)
conv4 = Conv3D(128, (3,3,3), activation='relu', padding='same')(pool3)
conv4 = Conv3D(128, (3,3,3), activation='relu', padding='same')(conv4)

up5 = concatenate([UpSampling3D(size=(2,2,2))(conv4), conv3], axis=-1)
conv5 = Conv3D(64, (3,3,3), activation='relu', padding='same')(up5)
Expand Down
5 changes: 3 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import keras
import random
from keras.utils import to_categorical

class DataGenerator(keras.utils.Sequence):
Expand Down Expand Up @@ -43,8 +44,6 @@ def on_epoch_end(self):
def __data_generation(self, data_IDs_temp):
'Generates data containing batch_size samples'
# Initialization
X = np.zeros((4, *self.dim, self.n_channels),dtype=np.single)
Y = np.zeros((4, *self.dim, self.n_channels),dtype=np.single)
gx = np.fromfile(self.dpath+str(data_IDs_temp[0])+'.dat',dtype=np.single)
fx = np.fromfile(self.fpath+str(data_IDs_temp[0])+'.dat',dtype=np.single)
gx = np.reshape(gx,self.dim)
Expand All @@ -59,6 +58,8 @@ def __data_generation(self, data_IDs_temp):
gx = np.transpose(gx)
fx = np.transpose(fx)
# Generate data
X = np.zeros((4, *self.dim, self.n_channels),dtype=np.single)
Y = np.zeros((4, *self.dim, self.n_channels),dtype=np.single)
for i in range(4):
X[i,] = np.reshape(np.rot90(gx,i,(2,1)), (*self.dim,self.n_channels))
Y[i,] = np.reshape(np.rot90(fx,i,(2,1)), (*self.dim,self.n_channels))
Expand Down

0 comments on commit ea756a4

Please sign in to comment.