forked from KittenCN/stock_prediction
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path bert_data_preprocess.py
41 lines (36 loc) · 1.81 KB
/
bert_data_preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os

from common import *
train_first_line = True  # NOTE(review): not read in the visible code; kept in case other modules import it
x_list = []
y_list = []
## ----type I
# Build Train2.csv: pair each article in Train_DataSet.csv (columns id,
# title, content) with its label from Train_DataSet_Label.csv (columns id,
# label) and write rows of (label * 0.5, text).
data_dir = os.path.join(bert_data_path, 'data')
# Pass paths (not open() handles) so pandas opens and closes the files itself.
texts = pd.read_csv(os.path.join(data_dir, 'Train_DataSet.csv'), encoding='UTF-8')
labels = pd.read_csv(os.path.join(data_dir, 'Train_DataSet_Label.csv'), encoding='UTF-8')

# Translation table that deletes single and double quotes in one pass.
_QUOTE_TABLE = str.maketrans('', '', '\'"')


def _pick_text(title, content):
    """Return the training text for one article, or None if it has none.

    Uses the title alone, the content alone, or "title,content" when both
    are present. Single and double quotes are removed from the result.
    """
    has_title = not pd.isna(title)
    has_content = not pd.isna(content)
    if has_title and has_content:
        text = f'{title},{content}'
    elif has_title:
        text = title
    elif has_content:
        text = content
    else:
        return None
    # The previous code called text.replace(...) and discarded the result,
    # so quotes were never actually stripped.
    return str(text).translate(_QUOTE_TABLE)


# Index each table by id once, keeping the first row per id (the row the
# previous per-id boolean filter + .values[0] selected). This replaces a
# full-table scan per id (O(n^2) overall) with O(1) dict lookups.
first_labels = labels.drop_duplicates('id')
label_by_id = dict(zip(first_labels['id'], first_labels['label']))
first_texts = texts.drop_duplicates('id')
title_by_id = dict(zip(first_texts['id'], first_texts['title']))
content_by_id = dict(zip(first_texts['id'], first_texts['content']))

for text_id in tqdm(texts['id'], total=len(texts), leave=False):
    if text_id not in label_by_id:
        continue  # article has no label
    text = _pick_text(title_by_id[text_id], content_by_id[text_id])
    if text is None:
        continue  # both title and content are missing
    x_list.append(text)
    # NOTE(review): raw label is scaled by 0.5 — presumably maps 0/1/2 to 0/0.5/1; confirm against the label file.
    y_list.append(label_by_id[text_id] * 0.5)

data = pd.DataFrame({'label': y_list, 'text': x_list})
data.to_csv(os.path.join(data_dir, 'Train2.csv'), index=False, sep=',', encoding='utf-8')
##----type II
# Build Train3.csv from two plain-text corpora with one sentence per line:
# positive.txt is labeled 1 and negative.txt is labeled 0.


def _read_sentences(path):
    """Return the stripped, non-empty lines of a UTF-8 text file.

    The file is closed on return (the previous bare open() leaked the handle).
    Blank lines are skipped so they do not become empty training rows.
    """
    with open(path, encoding='UTF-8') as handle:
        return [s for s in (line.strip() for line in handle) if s]


data_dir = os.path.join(bert_data_path, 'data')
negative = _read_sentences(os.path.join(data_dir, 'negative.txt'))
positive = _read_sentences(os.path.join(data_dir, 'positive.txt'))
df_neg = pd.DataFrame({'label': 0, 'text': negative})
df_pos = pd.DataFrame({'label': 1, 'text': positive})
# Positive rows first, then negative; the index is not written, so reset it.
df = pd.concat([df_pos, df_neg], axis=0, ignore_index=True)
df.to_csv(os.path.join(data_dir, 'Train3.csv'), index=False, sep=',', encoding='utf-8')