Skip to content

Commit

Permalink
bring in the simple tokenizer released by openai, but also plan on le…
Browse files Browse the repository at this point in the history
…aving room for custom tokenizer with yttm
  • Loading branch information
lucidrains committed Apr 12, 2022
1 parent 4ff6d02 commit 7cf1637
Show file tree
Hide file tree
Showing 5 changed files with 262,394 additions and 11 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
recursive-include dalle2_pytorch *.txt
66 changes: 56 additions & 10 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ def exists(val):
def default(val, d):
return val if exists(val) else d

def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner

# for controlling freezing of CLIP

def set_module_requires_grad_(module, requires_grad):
Expand All @@ -30,24 +39,61 @@ def unfreeze_all_layers_(module):
# diffusion prior

class DiffusionPrior(nn.Module):
def __init__(self):
def __init__(
self,
*,
clip
):
super().__init__()
def forward(self, x):
return x
assert isinstance(clip, CLIP)

def forward(
self,
*,
text,
image
):
return text

# decoder

class Decoder(nn.Module):
def __init__(self):
def __init__(
self,
*,
clip,
prior
):
super().__init__()
def forward(self, x):
return x
assert isinstance(clip, CLIP)
assert isinstance(prior, DiffusionPrior)

def forward(
self,
*,
image
):
return image

# main class

class DALLE2(nn.Module):
def __init__(self):
def __init__(
self,
*,
clip,
prior,
decoder
):
super().__init__()

def forward(self, x):
return x
assert isinstance(clip), CLIP
assert isinstance(prior), DiffusionPrior
assert isinstance(decoder), Decoder

@torch.no_grad()
def forward(
self,
*,
text
):
return text
Loading

0 comments on commit 7cf1637

Please sign in to comment.