Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reformat Python Project with Black and Isort #256

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
12 changes: 12 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[flake8]
max-line-length = 88
select = "E303, W293, W291, W292, E305, E231, E302"
exclude =
.tox,
__pycache__,
*.pyc,
.env
venv*/*,
.venv/*,
reports/*,
dist/*,
42 changes: 42 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
name: Python CI

on:
push:
branches: [ master ]
pull_request:
branches: [ master, dev ]

concurrency:
group: ${{ format('ci-{0}', github.head_ref && format('pr-{0}', github.event.pull_request.number) || github.sha) }}
cancel-in-progress: ${{ github.event_name == 'pull_request' }}

jobs:
lint:
runs-on: ubuntu-latest
env:
min-python-version: "3.10"

steps:
- name: Checkout repository
uses: actions/checkout@v3

- name: Set up Python ${{ env.min-python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ env.min-python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt

- name: Lint with flake8
run: flake8

- name: Check black formatting
run: black . --check
if: success() || failure()

- name: Check isort formatting
run: isort . --check
if: success() || failure()
10 changes: 10 additions & 0 deletions .isort.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[settings]
profile = black
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 88
sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
skip = .tox,__pycache__,*.pyc,venv*/*,reports,venv,env,node_modules,.env,.venv,dist
10 changes: 5 additions & 5 deletions ldm/data/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ def __init__(self, model_type):
self.transform = load_midas_transform(model_type)

def pt2np(self, x):
x = ((x + 1.0) * .5).detach().cpu().numpy()
x = ((x + 1.0) * 0.5).detach().cpu().numpy()
return x

def np2pt(self, x):
x = torch.from_numpy(x) * 2 - 1.
x = torch.from_numpy(x) * 2 - 1.0
return x

def __call__(self, sample):
# sample['jpg'] is tensor hwc in [-1, 1] at this point
x = self.pt2np(sample['jpg'])
x = self.pt2np(sample["jpg"])
x = self.transform({"image": x})["image"]
sample['midas_in'] = x
return sample
sample["midas_in"] = x
return sample
138 changes: 95 additions & 43 deletions ldm/models/autoencoder.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,49 @@
import torch
from contextlib import contextmanager

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from contextlib import contextmanager

from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.diffusionmodules.model import Decoder, Encoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

from ldm.util import instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.util import instantiate_from_config


class AutoencoderKL(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False
):
def __init__(
self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False,
):
super().__init__()
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor

self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
assert 0.0 < ema_decay < 1.0
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

Expand Down Expand Up @@ -112,19 +113,51 @@ def training_step(self, batch, batch_idx, optimizer_idx):

if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log(
"aeloss",
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
return aeloss

if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")

self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)

self.log(
"discloss",
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
return discloss

def validation_step(self, batch, batch_idx):
Expand All @@ -136,11 +169,25 @@ def validation_step(self, batch, batch_idx):
def _validation_step(self, batch, batch_idx, postfix=""):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)

discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + postfix,
)

discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + postfix,
)

self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
self.log_dict(log_dict_ae)
Expand All @@ -149,15 +196,19 @@ def _validation_step(self, batch, batch_idx, postfix=""):

def configure_optimizers(self):
lr = self.learning_rate
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
ae_params_list = (
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters())
)
if self.learn_logvar:
print(f"{self.__class__.__name__}: Learning logvar")
ae_params_list.append(self.loss.logvar)
opt_ae = torch.optim.Adam(ae_params_list,
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
)
return [opt_ae, opt_disc], []

def get_last_layer(self):
Expand All @@ -184,7 +235,9 @@ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
# colorize with random projection
assert xrec_ema.shape[1] > 3
xrec_ema = self.to_rgb(xrec_ema)
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
log["samples_ema"] = self.decode(
torch.randn_like(posterior_ema.sample())
)
log["reconstructions_ema"] = xrec_ema
log["inputs"] = x
return log
Expand All @@ -194,7 +247,7 @@ def to_rgb(self, x):
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x


Expand All @@ -216,4 +269,3 @@ def quantize(self, x, *args, **kwargs):

def forward(self, x, *args, **kwargs):
return x

Loading