-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_sample_cond.py
59 lines (43 loc) · 2.24 KB
/
run_sample_cond.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import argparse
from load_model import load_model
from transformers import GPT2TokenizerFast
from sampling import OrderedSampler,DiffusionSampler
def main():
parser = argparse.ArgumentParser(description="Generate some samples")
parser.add_argument("--model_path", default="JingyangOu/radd-lambda-dce", type=str)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--length", type=int, default=1024)
parser.add_argument("--steps", type=int, default=1024)
parser.add_argument("--method", type=str, default="tweedie") # euler, tweedie
parser.add_argument("--strategy", type=str, default="direct") # direct, top_p, top_k
parser.add_argument("--strategy_para", type=int, default=0.8) # p for top_p, k for top_k, no use when direct
parser.add_argument("--prefix", type=str, default="I have a ")
parser.add_argument("--suffix", type=str, default=" in a joint framework.")
args = parser.parse_args()
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2-large')
prefix_ids = tokenizer(args.prefix).input_ids
suffix_ids = tokenizer(args.suffix).input_ids
input_ids = prefix_ids + suffix_ids
input_locs = list(range(len(prefix_ids))) + list(range(1024-len(suffix_ids), 1024))
# more generaly commands can be defined with something like below:
# input_ids = [0, 1, 512, 8080, 50256, 20000]
# input_locs = [5, 6, 19, 20, 1000, 10001]
input_ids = torch.tensor(input_ids, device="cuda")[None].repeat(args.batch_size, 1)
def proj_fun(x):
x[:, input_locs] = input_ids
return x
device = torch.device('cuda')
model, noise = load_model(args.model_path, device)
token_dim = model.config.tokens + 1
if args.method == 'euler' or args.method == 'tweedie':
sampler = DiffusionSampler(args.method, model, noise, (args.batch_size, args.length), token_dim, args.strategy, args.strategy_para, device=device)
else:
raise ValueError(f"Method {args.method} is not valid.")
samples = sampler.sample(args.steps, proj_fun=proj_fun)
text_samples = tokenizer.batch_decode(samples)
for i in text_samples:
print(i)
print("=================================================")
if __name__=="__main__":
main()