forked from fastnlp/fastNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_dpcnn.py
122 lines (94 loc) · 4.18 KB
/
train_dpcnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
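# First, the paths below must be added to the environment variables; this is currently
# open for internal testing only, so the paths have to be declared manually.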
import torch.cuda
from fastNLP.core.utils import cache_results
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from fastNLP.core.trainer import Trainer
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP.embeddings import StaticEmbedding
from reproduction.text_classification.model.dpcnn import DPCNN
from fastNLP.io.data_loader import YelpLoader
from fastNLP.core.sampler import BucketSampler
from fastNLP.core import LRScheduler
from fastNLP.core.const import Const as C
from fastNLP.core.vocabulary import VocabularyOption
from utils.util_init import set_rng_seeds
import os
# Process-wide settings, applied in one call:
#   FASTNLP_BASE_URL / FASTNLP_CACHE_DIR - presumably the download server and
#   local cache directory fastNLP uses for pretrained resources (verify
#   against the fastNLP version in use).
#   CUDA_DEVICE_ORDER=PCI_BUS_ID - asks CUDA to number GPUs by PCI bus id.
os.environ.update({
    'FASTNLP_BASE_URL': 'http://10.141.222.118:8888/file/download/',
    'FASTNLP_CACHE_DIR': '/remote-home/hyan01/fastnlp_caches',
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
})
# Hyper-parameters
class Config():
    """Hyper-parameters and dataset locations for this DPCNN training script.

    The class attributes hold the defaults. Instantiating the class resolves
    ``datadir`` against the user's home directory and builds ``datapath``,
    a mapping from split name to the full path of that split's data file.
    """

    seed = 12345  # handed to set_rng_seeds
    model_dir_or_name = "dpcnn-yelp-f"  # prefix of the data-cache name
    embedding_grad = True  # requires_grad flag of the StaticEmbedding
    train_epoch = 30  # n_epochs given to the Trainer
    batch_size = 100  # used by both the BucketSampler and the Trainer
    task = "yelp_f"
    # Alternative dataset directories (relative to the home directory):
    # datadir = 'workdir/datasets/SST'
    # datadir = 'workdir/datasets/yelp_polarity'
    datadir = 'workdir/datasets/yelp_full'
    # datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}
    datafile = {"train": "train.csv", "test": "test.csv"}
    lr = 1e-3  # SGD learning rate
    src_vocab_op = VocabularyOption(max_size=100000)
    embed_dropout = 0.3  # passed to DPCNN
    cls_dropout = 0.1  # passed to DPCNN
    weight_decay = 1e-5  # SGD weight decay

    def __init__(self):
        """Resolve ``datadir`` and build the per-split ``datapath`` mapping."""
        # Keep using $HOME when it is set (the original behaviour), but fall
        # back to os.path.expanduser so that an unset HOME (e.g. on Windows)
        # no longer raises KeyError.
        home = os.environ.get('HOME')
        if home is None:
            home = os.path.expanduser('~')
        # The instance attribute shadows the relative class-level ``datadir``.
        self.datadir = os.path.join(home, self.datadir)
        self.datapath = {
            split: os.path.join(self.datadir, filename)
            for split, filename in self.datafile.items()
        }
# Build the run configuration, then seed the random number generators
# (set_rng_seeds comes from utils.util_init; its exact coverage is not
# visible here) and log the seed that was used.
ops = Config()
set_rng_seeds(ops.seed)
print(f'RNG SEED: {ops.seed}')
# 1. Task-related information: use the data loader to load the dataInfo
#datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train'])
@cache_results(ops.model_dir_or_name+'-data-cache')
def load_data():
    """Load the Yelp data and build the pretrained word embedding.

    Every dataset returned by the loader gets a sequence-length field
    (``C.INPUT_LEN``, computed with ``len`` over ``C.INPUT``); the input and
    length fields are marked as model inputs and ``C.TARGET`` as the target.

    The ``cache_results`` decorator presumably stores the returned pair under
    the given cache name so later runs can skip this step - confirm against
    the fastNLP documentation.

    Returns:
        A ``(datainfo, embedding)`` tuple: the object produced by
        ``YelpLoader.process`` and a ``StaticEmbedding`` built over its
        ``'words'`` vocabulary.
    """
    loader = YelpLoader(fine_grained=True, lower=True)
    bundle = loader.process(
        paths=ops.datapath,
        train_ds=['train'],
        src_vocab_op=ops.src_vocab_op,
    )

    for dataset in bundle.datasets.values():
        dataset.apply_field(len, C.INPUT, C.INPUT_LEN)
        dataset.set_input(C.INPUT, C.INPUT_LEN)
        dataset.set_target(C.TARGET)

    word_embedding = StaticEmbedding(
        bundle.vocabs['words'],
        model_dir_or_name='en-glove-840b-300',
        requires_grad=ops.embedding_grad,
        normalize=False,
    )
    return bundle, word_embedding
datainfo, embedding = load_data()

# Rescale the embedding table in place by its overall standard deviation
# (so the table ends up with unit std), then report its mean and std.
_embed_weight = embedding.embedding.weight
_embed_weight.data /= _embed_weight.data.std()
print(_embed_weight.mean(), _embed_weight.std())

# 2. Or directly reuse a fastNLP model
# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])

# Show a summary of the loaded data and the first training instance.
print(datainfo)
print(datainfo.datasets['train'][0])

# The number of classes is the size of the target vocabulary.
model = DPCNN(
    init_embed=embedding,
    num_cls=len(datainfo.vocabs[C.TARGET]),
    embed_dropout=ops.embed_dropout,
    cls_dropout=ops.cls_dropout,
)
print(model)
# 3. Declare the loss, metric and optimizer
loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)

# Only parameters that still require gradients are handed to SGD.
optimizer = SGD(
    [p for p in model.parameters() if p.requires_grad],
    lr=ops.lr,
    momentum=0.9,
    weight_decay=ops.weight_decay,
)

# Cosine-annealed learning rate (T_max=5), wrapped as a fastNLP callback.
callbacks = [LRScheduler(CosineAnnealingLR(optimizer, 5))]
# Alternative schedule / logging callbacks, currently disabled:
# callbacks.append(
#     LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch <
#                          ops.train_epoch * 0.8 else ops.lr * 0.1))
# )
# callbacks.append(
#     FitlogCallback(data=datainfo.datasets, verbose=1)
# )

# Train on the first GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)
# 4. Define the training procedure
# NOTE(review): the 'test' split is what gets passed as dev_data here; the
# configured data files only contain 'train' and 'test'.
trainer = Trainer(
    datainfo.datasets['train'],
    model,
    loss=loss,
    optimizer=optimizer,
    metrics=[metric],
    dev_data=datainfo.datasets['test'],
    sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
    batch_size=ops.batch_size,
    n_epochs=ops.train_epoch,
    callbacks=callbacks,
    device=device,
    num_workers=4,
    check_code_level=-1,  # presumably skips fastNLP's pre-run code check - confirm
)

if __name__ == "__main__":
    # Run training only when executed as a script, and print what train() returns.
    training_result = trainer.train()
    print(training_result)