From 058c5a1af231bb5eba52ae1c16358839dd64acd7 Mon Sep 17 00:00:00 2001
From: Pehat
Date: Wed, 24 Jun 2020 04:43:12 +0200
Subject: [PATCH] Allow using local models via command-line parameters

---
 PULSE.py         | 18 +++++++++++++-----
 requirements.txt |  2 ++
 run.py           |  4 +++-
 3 files changed, 18 insertions(+), 6 deletions(-)
 create mode 100644 requirements.txt

diff --git a/PULSE.py b/PULSE.py
index d499aaa..4ff6554 100644
--- a/PULSE.py
+++ b/PULSE.py
@@ -11,7 +11,7 @@
 
 
 class PULSE(torch.nn.Module):
-    def __init__(self, cache_dir, verbose=True):
+    def __init__(self, cache_dir, synthesis_path=None, mapping_path=None, verbose=True):
         super(PULSE, self).__init__()
 
         self.synthesis = G_synthesis().cuda()
@@ -20,7 +20,11 @@ def __init__(self, cache_dir, verbose=True):
         cache_dir = Path(cache_dir)
         cache_dir.mkdir(parents=True, exist_ok = True)
         if self.verbose: print("Loading Synthesis Network")
-        with open_url("https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8", cache_dir=cache_dir, verbose=verbose) as f:
+        if synthesis_path is None:
+            f = open_url("https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8", cache_dir=cache_dir, verbose=verbose)
+        else:
+            f = open(synthesis_path, "rb")
+        with f:
             self.synthesis.load_state_dict(torch.load(f))
 
         for param in self.synthesis.parameters():
@@ -34,8 +38,12 @@ def __init__(self, cache_dir, verbose=True):
             if self.verbose: print("\tLoading Mapping Network")
             mapping = G_mapping().cuda()
 
-            with open_url("https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k", cache_dir=cache_dir, verbose=verbose) as f:
-                mapping.load_state_dict(torch.load(f))
+            if mapping_path is None:
+                f = open_url("https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k", cache_dir=cache_dir, verbose=verbose)
+            else:
+                f = open(mapping_path, "rb")
+            with f:
+                mapping.load_state_dict(torch.load(f))
 
             if self.verbose: print("\tRunning Mapping Network")
             with torch.no_grad():
@@ -120,7 +128,7 @@ def forward(self, ref_im,
         }
         schedule_func = schedule_dict[lr_schedule]
         scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)
-        
+
         loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()
 
         min_loss = np.inf
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..ac988bd
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,2 @@
+torch
+torchvision
diff --git a/run.py b/run.py
index 8278e37..01aed31 100644
--- a/run.py
+++ b/run.py
@@ -46,6 +46,8 @@ def __getitem__(self, idx):
 parser.add_argument('-steps', type=int, default=100, help='Number of optimization steps')
 parser.add_argument('-lr_schedule', type=str, default='linear1cycledrop', help='fixed, linear1cycledrop, linear1cycle')
 parser.add_argument('-save_intermediate', action='store_true', help='Whether to store and save intermediate HR and LR images during optimization')
+parser.add_argument('-synthesis_path', type=str, default=None, help='Path to synthesis.pt file. If not specified, fetched from Google Drive')
+parser.add_argument('-mapping_path', type=str, default=None, help='Path to mapping.pt file. If not specified, fetched from Google Drive')
 
 kwargs = vars(parser.parse_args())
 
@@ -55,7 +57,7 @@ def __getitem__(self, idx):
 
 dataloader = DataLoader(dataset, batch_size=kwargs["batch_size"])
 
-model = PULSE(cache_dir=kwargs["cache_dir"])
+model = PULSE(cache_dir=kwargs["cache_dir"], synthesis_path=kwargs["synthesis_path"], mapping_path=kwargs["mapping_path"])
 model = DataParallel(model)
 
 toPIL = torchvision.transforms.ToPILImage()
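
Usage sketch: with this patch applied, run.py can load the pretrained weights
from disk instead of Google Drive. The weight-file paths below are
placeholders; every other argument keeps its default, and omitting either
flag falls back to the old behaviour of downloading that checkpoint from
Google Drive into cache_dir:

    python run.py -synthesis_path weights/synthesis.pt -mapping_path weights/mapping.pt

The same fallback is reachable from Python. A minimal sketch, assuming the
repository root is on the import path and a CUDA device is available (the
PULSE constructor moves both networks to the GPU):

    from PULSE import PULSE

    # Placeholder paths; passing None (the default) for either argument
    # makes that checkpoint be fetched from Google Drive into cache_dir.
    model = PULSE(cache_dir="cache",
                  synthesis_path="weights/synthesis.pt",
                  mapping_path="weights/mapping.pt")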