Update to Tensorflow v2 #5

Open · wants to merge 5 commits into base: master
131 changes: 73 additions & 58 deletions janni/jmain.py
@@ -26,7 +26,6 @@
SOFTWARE.
"""

import argparse
import sys
import json
import os
@@ -46,6 +45,7 @@

ARGPARSER = None


def create_config_parser(parser):
config_required_group = parser.add_argument_group(
"Required arguments",
@@ -62,7 +62,7 @@ def create_config_parser(parser):
"test": 'user_input.endswith("json")',
"message": "File has to end with .json!",
},
"default_file": "config_janni.json"
"default_file": "config_janni.json",
},
)

@@ -94,11 +94,10 @@ def create_config_parser(parser):
"test": 'user_input.endswith("h5")',
"message": "File has to end with .h5!",
},
"default_file": "janni_model.h5"
"default_file": "janni_model.h5",
},
)


config_optional_group = parser.add_argument_group(
"Optional arguments",
"The arguments are optional to create a config file for JANNI",
@@ -108,7 +107,10 @@
"--loss",
default="mae",
help="Loss function that is used during training: Mean squared error (mse) or mean absolute error (mae).",
choices=["mae", "mse",],
choices=[
"mae",
"mse",
],
)
config_optional_group.add_argument(
"--patch_size",
@@ -148,9 +150,7 @@ def create_train_parser(parser):
"config_path",
help="Path to config.json",
widget="FileChooser",
gooey_options={
"wildcard": "*.json"
}
gooey_options={"wildcard": "*.json"},
)

optional_group = parser.add_argument_group(
@@ -161,6 +161,7 @@
"-g", "--gpu", type=int, default=-1, help="GPU ID to run on"
)


def create_predict_parser(parser):
required_group = parser.add_argument_group(
"Required arguments", "These options are mandatory to run JANNI"
@@ -181,9 +182,7 @@
"model_path",
help="File path to trained model.",
widget="FileChooser",
gooey_options={
"wildcard": "*.h5"
}
gooey_options={"wildcard": "*.h5"},
)

optional_group = parser.add_argument_group(
@@ -205,21 +204,25 @@
"-g", "--gpu", type=int, default=-1, help="GPU ID to run on"
)


def create_parser(parser):

subparsers = parser.add_subparsers(help="sub-command help")

parser_config= subparsers.add_parser("config", help="Create the configuration file for JANNI")
parser_config = subparsers.add_parser(
"config", help="Create the configuration file for JANNI"
)
create_config_parser(parser_config)

parser_train = subparsers.add_parser("train", help="Train JANNI for your dataset.")
create_train_parser(parser_train)

parser_predict = subparsers.add_parser("denoise", help="Denoise micrographs using a (pre)trained model.")
parser_predict = subparsers.add_parser(
"denoise", help="Denoise micrographs using a (pre)trained model."
)
create_predict_parser(parser_predict)



def get_parser():
parser = GooeyParser(description="Just another noise to noise implementation")
create_parser(parser)
@@ -231,14 +234,14 @@ def _main_():
import sys

if len(sys.argv) >= 2:
if not "--ignore-gooey" in sys.argv:
if "--ignore-gooey" not in sys.argv:
sys.argv.append("--ignore-gooey")

kwargs = {"terminal_font_family": "monospace", "richtext_controls": True}
Gooey(
main,
program_name="JANNI " + ini.__version__,
#image_dir=os.path.join(os.path.abspath(os.path.dirname(__file__)), "../icons"),
# image_dir=os.path.join(os.path.abspath(os.path.dirname(__file__)), "../icons"),
progress_regex=r"^.* \( Progress:\s+(-?\d+) % \)$",
disable_progress_bar_animation=True,
tabbed_groups=True,
@@ -253,20 +256,20 @@ def main(args=None):
parser = get_parser()
args = parser.parse_args()



if "config" in sys.argv[1]:
generate_config_file(config_out_path=args.config_out_path,
architecture="unet",
patch_size=args.patch_size,
movie_dir=args.movie_dir,
even_dir=args.even_dir,
odd_dir=args.odd_dir,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
nb_epoch=args.nb_epoch,
saved_weights_name=args.saved_weights_name,
loss=args.loss,)
generate_config_file(
config_out_path=args.config_out_path,
architecture="unet",
patch_size=args.patch_size,
movie_dir=args.movie_dir,
even_dir=args.even_dir,
odd_dir=args.odd_dir,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
nb_epoch=args.nb_epoch,
saved_weights_name=args.saved_weights_name,
loss=args.loss,
)
else:
if isinstance(args.gpu, list):
if len(args.gpu) == 1:
@@ -281,14 +284,16 @@
config = read_config(args.config_path)

from . import train

loss = "mae"
if "loss" in config["train"]:
if "mae" == config["train"]["loss"] or "mse" == config["train"]["loss"]:\
if "mae" == config["train"]["loss"] or "mse" == config["train"]["loss"]:
loss = config["train"]["loss"]
else:
print("Unsupported loss chosen:",config["train"]["loss"])
print("Unsupported loss chosen:", config["train"]["loss"])
print("Use default loss MAE")
from . import utils

fbinning = utils.fourier_binning
if "binning" in config["train"]:
if config["train"]["binning"] == "rescale":
@@ -306,7 +311,10 @@
learning_rate=config["train"]["learning_rate"],
epochs=config["train"]["nb_epoch"],
model=config["model"]["architecture"],
patch_size=(config["model"]["patch_size"], config["model"]["patch_size"]),
patch_size=(
config["model"]["patch_size"],
config["model"]["patch_size"],
),
batch_size=config["train"]["batch_size"],
loss=loss,
fbinning=fbinning,
@@ -330,7 +338,7 @@ def main(args=None):
try:
u = model.tolist()
model = u.decode()
except:
except Exception:
pass
patch_size = tuple(f["patch_size"])
except KeyError:
@@ -353,37 +361,44 @@ def main(args=None):
batch_size=batch_size,
)

def generate_config_file(config_out_path,
architecture,
patch_size,
movie_dir,
even_dir,
odd_dir,
batch_size,
learning_rate,
nb_epoch,
saved_weights_name,
loss):
model_dict = {'architecture': architecture,
'patch_size': patch_size,
}

train_dict = {'movie_dir': movie_dir,
'even_dir': even_dir,
'odd_dir': odd_dir,
'batch_size': batch_size,
'learning_rate': learning_rate,
'nb_epoch': nb_epoch,
"saved_weights_name": saved_weights_name,
"loss": loss,
}

def generate_config_file(
config_out_path,
architecture,
patch_size,
movie_dir,
even_dir,
odd_dir,
batch_size,
learning_rate,
nb_epoch,
saved_weights_name,
loss,
):
model_dict = {
"architecture": architecture,
"patch_size": patch_size,
}

train_dict = {
"movie_dir": movie_dir,
"even_dir": even_dir,
"odd_dir": odd_dir,
"batch_size": batch_size,
"learning_rate": learning_rate,
"nb_epoch": nb_epoch,
"saved_weights_name": saved_weights_name,
"loss": loss,
}

from json import dump

dict = {"model": model_dict, "train": train_dict}
with open(config_out_path, 'w') as f:
with open(config_out_path, "w") as f:
dump(dict, f, ensure_ascii=False, indent=4)
print("Wrote config to", config_out_path)


def read_config(config_path):
with open(config_path) as config_buffer:
try:
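For reviewers trying out the reformatted config helpers above, here is a minimal usage sketch (not part of the diff); every argument value below is a placeholder rather than a project default, and `janni.jmain` is simply the module shown in this file.

```python
# Illustrative sketch only; argument values are placeholders, not project defaults.
from janni.jmain import generate_config_file, read_config

generate_config_file(
    config_out_path="config_janni.json",
    architecture="unet",
    patch_size=1024,           # single int; main() expands it to a square tuple
    movie_dir="movies/",
    even_dir="even/",
    odd_dir="odd/",
    batch_size=4,
    learning_rate=0.001,
    nb_epoch=100,
    saved_weights_name="janni_model.h5",
    loss="mae",                # "mae" or "mse", as validated in main()
)

config = read_config("config_janni.json")
print(config["model"]["architecture"], config["train"]["loss"])
```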
17 changes: 12 additions & 5 deletions janni/models.py
@@ -26,11 +26,18 @@
SOFTWARE.
"""

from keras.models import Model
from keras.layers import Input, Add, Conv2DTranspose, MaxPooling2D, UpSampling2D, ReLU
from keras.layers.convolutional import Conv2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.merge import concatenate
from tensorflow.keras import Model
from tensorflow.keras.layers import (
Add,
Conv2D,
Conv2DTranspose,
Input,
LeakyReLU,
MaxPooling2D,
ReLU,
UpSampling2D,
concatenate,
)


def get_rednet(
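As a quick, standalone sanity check (not from the PR) that the migrated `tensorflow.keras` imports compose the same way the old standalone-keras ones did, the toy residual block below uses a few of the layers imported in models.py; it is only an illustration, not the REDNet/U-Net architecture the module actually defines.

```python
# Sanity-check sketch: builds a toy residual block with the migrated imports.
from tensorflow.keras import Model
from tensorflow.keras.layers import Add, Conv2D, Input, LeakyReLU

inp = Input(shape=(None, None, 1))            # single-channel micrograph patch
x = Conv2D(32, (3, 3), padding="same")(inp)
x = LeakyReLU()(x)
x = Conv2D(1, (3, 3), padding="same")(x)
out = Add()([inp, x])                         # residual / skip connection
model = Model(inp, out)
model.summary()
```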
3 changes: 1 addition & 2 deletions janni/patch_pair_generator.py
@@ -26,10 +26,9 @@
SOFTWARE.
"""

from keras.utils import Sequence
from tensorflow.keras.utils import Sequence
from random import shuffle
import numpy as np
import mrcfile
from . import utils


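The only change to patch_pair_generator.py is the Sequence import. As a reminder of the contract that base class imposes, here is a minimal, purely hypothetical tf.keras Sequence yielding even/odd patch batches; the class name, attributes, and shapes are assumptions, not the project's actual generator.

```python
# Hypothetical minimal Sequence; NOT the project's patch_pair_generator.
import numpy as np
from tensorflow.keras.utils import Sequence


class EvenOddPairs(Sequence):
    def __init__(self, even_patches, odd_patches, batch_size=4):
        self.even = even_patches          # e.g. ndarray of shape (N, H, W, 1)
        self.odd = odd_patches
        self.batch_size = batch_size

    def __len__(self):
        # number of batches per epoch
        return int(np.ceil(len(self.even) / self.batch_size))

    def __getitem__(self, idx):
        sl = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.even[sl], self.odd[sl]
```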