Commit

Merge pull request #30 from himkt/himkt/modify
Some updates
himkt authored Feb 11, 2019
2 parents e18c40b + 2189703 commit 5f2f35e
Showing 2 changed files with 15 additions and 19 deletions.
pyner/named_entity/dataset.py (6 changes: 3 additions & 3 deletions)
@@ -9,14 +9,14 @@
 logger = logging.getLogger(__name__)


-def update_instances(train_datas, params):
+def update_instances(train_datas, params, attr):
     train_size = params.get('train_size', 1.0)
     if train_size <= 0 or 1 <= train_size:
         assert Exception('train_size must be in (0, 1]')
     n_train = len(train_datas[0])
     instances = int(train_size * n_train)
     rate = 100 * train_size
-    logger.debug(f'Use {instances} example for training ({rate:.2f}%)')
+    logger.debug(f'Use {instances} example for {attr} ({rate:.2f}%)')
     return [t[:instances] for t in train_datas]


@@ -77,7 +77,7 @@ def __init__(self, vocab, params, attr, transform):
         word_sentences = vocab.load_word_sentences(word_path)
         tag_sentences = vocab.load_tag_sentences(tag_path)
         datas = [word_sentences, tag_sentences]
-        word_sentences, tag_sentences = update_instances(datas, params)
+        word_sentences, tag_sentences = update_instances(datas, params, attr)
         self.word_sentences = word_sentences
         self.tag_sentences = tag_sentences

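For reference, a minimal usage sketch (not part of the commit) showing how the new attr argument only relabels the debug message while 'train_size' still controls the subsampling; the sentences and parameter values below are made up:

# Hypothetical example, assuming the pyner package is importable.
from pyner.named_entity.dataset import update_instances

word_sentences = [['EU', 'rejects', 'German', 'call'], ['Peter', 'Blackburn']]
tag_sentences = [['B-ORG', 'O', 'B-MISC', 'O'], ['B-PER', 'I-PER']]
params = {'train_size': 0.5}

# keeps the first int(0.5 * 2) = 1 sentence of each list and logs
# "Use 1 example for train (50.00%)" at debug level
words, tags = update_instances([word_sentences, tag_sentences], params, 'train')
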
pyner/named_entity/train.py (28 changes: 12 additions & 16 deletions)
@@ -18,6 +18,7 @@
 import chainer.training as T
 import chainer.training.extensions as E

+import datetime
 import chainer
 import logging
 import yaml
@@ -27,7 +28,6 @@ def prepare_pretrained_word_vector(
         word2idx,
         gensim_model,
         syn0,
-        lowercase=False
 ):

     # if lowercased word is in pre-trained embeddings,
@@ -40,7 +40,7 @@
             syn0[idx, :] = word_vector
             match1 += 1

-        elif lowercase and word.lower() in gensim_model:
+        elif word.lower() in gensim_model:
             word_vector = gensim_model.wv.word_vec(word.lower())
             syn0[idx, :] = word_vector
             match2 += 1
@@ -54,21 +54,21 @@


 def create_iterator(vocab, configs, role, transform):
-    if 'batch' not in configs:
+    if 'iteration' not in configs:
         raise Exception('Batch configurations are not found')

     if 'external' not in configs:
         raise Exception('External data configurations are not found')

-    batch_configs = configs['batch']
+    iteration_configs = configs['iteration']
     external_configs = configs['external']

     is_train = role == 'train'
     shuffle = True if is_train else False
     repeat = True if is_train else False

     dataset = SequenceLabelingDataset(vocab, external_configs, role, transform)
-    batch_size = batch_configs['batch_size'] if is_train else len(dataset)
+    batch_size = iteration_configs['batch_size'] if is_train else len(dataset)

     iterator = It.SerialIterator(
         dataset,
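
A rough sketch (not from the repository) of the configs shape this function now expects; the concrete values are invented, and only the 'iteration', 'external', and 'output' keys plus 'batch_size' and 'epoch' are confirmed by this diff:

configs = {
    'iteration': {'batch_size': 32, 'epoch': 20},  # section renamed from 'batch' in this commit
    'external': {},  # data locations read by SequenceLabelingDataset (keys not shown in this diff)
    'output': 'models/conll2003',  # prefix of the trainer output directory (see the last hunk below)
}
# with vocab and transform prepared elsewhere:
#   train_iterator = create_iterator(vocab, configs, 'train', transform)
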
@@ -94,10 +94,7 @@ def create_iterator(vocab, configs, role, transform):
     configs = ConfigParser.parse(args.config)
     config_path = Path(args.config)

-    model_path = configs['output']
-    logger.debug(f'model_dir: {model_path}')
     vocab = Vocabulary.prepare(configs)
-
     num_word_vocab = max(vocab.dictionaries['word2idx'].values()) + 1
     num_char_vocab = max(vocab.dictionaries['char2idx'].values()) + 1
     num_tag_vocab = max(vocab.dictionaries['tag2idx'].values()) + 1
@@ -126,8 +123,7 @@ def create_iterator(vocab, configs, role, transform):
     syn0 = prepare_pretrained_word_vector(
         word2idx,
         vocab.gensim_model,
-        syn0,
-        preprocessing_configs['lower']
+        syn0
     )
     model.set_pretrained_word_vectors(syn0)

@@ -151,18 +147,18 @@ def create_iterator(vocab, configs, role, transform):
     params['num_char_vocab'] = num_char_vocab
     params['num_tag_vocab'] = num_tag_vocab

-    epoch = configs['batch']['epoch']
-    logger.debug(f'Create {model_path} for trainer\'s output')
+    epoch = configs['iteration']['epoch']
     trigger = (epoch, 'epoch')

-    output_path = Path(model_path)
-    output_path.mkdir(parents=True, exist_ok=True)
-    save_args(params, model_path)
+    model_path = configs['output']
+    timestamp = datetime.datetime.now()
+    timestamp_str = timestamp.isoformat()
+    output_path = Path(f'{model_path}.{timestamp_str}')

     trainer = T.Trainer(
         updater,
         trigger,
-        out=model_path
+        out=output_path
     )
+    save_args(params, output_path)
     msg = f'Create \x1b[31m{output_path}\x1b[0m for saving model snapshots'
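
The effect of the new output naming, as a standalone sketch (the model_path value is invented; the real one comes from configs['output']):

import datetime
from pathlib import Path

model_path = 'models/conll2003'  # hypothetical value of configs['output']
timestamp_str = datetime.datetime.now().isoformat()
output_path = Path(f'{model_path}.{timestamp_str}')
print(output_path)  # e.g. models/conll2003.2019-02-11T12:34:56.789012

Each run therefore writes to a fresh timestamped directory, and save_args now records the parameters against that same path after the Trainer is built; the explicit mkdir call is gone, presumably because the trainer is left to create its own output directory.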