# BERTClassifier.py
# Forked from Sherlock-Voice/Sherlock-Voice_Model.
from torch import nn
import torch
class BERTClassifier(nn.Module):
    """Classification head on top of a BERT encoder.

    The wrapped encoder's pooled output is optionally passed through dropout
    and then through a single linear layer that produces ``num_classes``
    logits.
    """

    def __init__(self,
                 bert,
                 hidden_size: int = 768,
                 num_classes: int = 2,  # number of sentiment classes: positive / negative
                 dr_rate: float | None = None,
                 params=None):
        """Build the classifier.

        Args:
            bert: Encoder called in ``forward`` with ``input_ids``,
                ``token_type_ids``, ``attention_mask`` and
                ``return_dict=False``; it must return a 2-tuple whose second
                item is the pooled output (presumably a Hugging Face
                ``BertModel`` -- confirm against the caller).
            hidden_size: Input feature size of the linear classifier; must
                match the last dimension of the pooled output.
            num_classes: Number of output logits.
            dr_rate: Dropout probability. ``None`` or ``0`` disables dropout.
            params: Unused; kept so existing callers keep working.
        """
        super().__init__()
        self.bert = bert
        self.dr_rate = dr_rate
        self.classifier = nn.Linear(hidden_size, num_classes)
        # The dropout layer only exists when a non-zero rate was requested.
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Return a float mask with ones on the first ``valid_length[i]`` positions.

        Args:
            token_ids: Tensor of token ids; assumed to be ``(batch, seq_len)``
                -- TODO confirm. Only its shape, dtype and device are used.
            valid_length: Iterable yielding, per row, the count of leading
                positions to mark as valid.

        Returns:
            A float tensor shaped like ``token_ids`` holding 1.0 for valid
            positions and 0.0 elsewhere.
        """
        attention_mask = torch.zeros_like(token_ids)
        for row, length in enumerate(valid_length):
            attention_mask[row, :length] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Run the encoder and classify its pooled output.

        Args:
            token_ids: Tensor of input token ids.
            valid_length: Per-row count of non-padding tokens, used to build
                the attention mask.
            segment_ids: Token type ids; cast to ``long`` before being passed
                to the encoder.

        Returns:
            The logits produced by the linear classifier.
        """
        # The mask is already float and is created with ``zeros_like`` on
        # ``token_ids``, so it lives on the same device as the inputs.
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        _, pooler = self.bert(
            input_ids=token_ids,
            token_type_ids=segment_ids.long(),
            attention_mask=attention_mask,
            return_dict=False,
        )
        # Fall back to the raw pooled output when dropout is disabled;
        # previously ``out`` was left unassigned in that case and the
        # classifier call raised UnboundLocalError.
        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)