Skip to content

Commit

Permalink
0.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 19, 2021
1 parent 741932a commit 290a00d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
13 changes: 9 additions & 4 deletions big_sleep/big_sleep.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,16 @@ def __init__(
num_latents = 32,
num_cutouts = 128,
loss_coef = 100,
image_width = 512
image_width = 512,
bilinear = False
):
super().__init__()
self.loss_coef = loss_coef
self.image_width = image_width
self.num_cutouts = num_cutouts

self.interpolation_settings = {'mode': 'bilinear', 'align_corners': False} if bilinear else {'mode': 'nearest'}

self.model = Model(
num_latents = num_latents,
image_width = image_width
Expand All @@ -77,7 +80,7 @@ def forward(self, text, return_loss = True):
offsetx = torch.randint(0, width - size, ())
offsety = torch.randint(0, width - size, ())
apper = out[:, :, offsetx:offsetx + size, offsety:offsety + size]
apper = F.interpolate(apper, (224,224), mode='nearest')
apper = F.interpolate(apper, (224,224), **self.interpolation_settings)
pieces.append(apper)

into = torch.cat(pieces)
Expand Down Expand Up @@ -121,14 +124,16 @@ def __init__(
image_width = 512,
epochs = 20,
iterations = 1050,
save_progress = False
save_progress = False,
bilinear = False
):
super().__init__()
self.epochs = epochs
self.iterations = iterations

model = BigSleep(
num_latents = num_latents
num_latents = num_latents,
bilinear = bilinear
).cuda()

self.model = model
Expand Down
8 changes: 5 additions & 3 deletions big_sleep/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ def train(
gradient_accumulate_every = 1,
epochs = 20,
iterations = 1050,
save_every = 25,
save_every = 100,
overwrite = False,
save_progress = False
save_progress = False,
bilinear = False
):

imagine = Imagine(
Expand All @@ -22,7 +23,8 @@ def train(
epochs = epochs,
iterations = iterations,
save_every = save_every,
save_progress = save_progress
save_progress = save_progress,
bilinear = bilinear
)

if not overwrite and imagine.filename.exists():
Expand Down
2 changes: 1 addition & 1 deletion big_sleep/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.2'
__version__ = '0.1.0'

0 comments on commit 290a00d

Please sign in to comment.