Skip to content

Commit

Permalink
Merge pull request #4 from google-research/bugfix
Browse files Browse the repository at this point in the history
make evaluation step consistent between 1p and multihop
  • Loading branch information
Hanjun-Dai authored Feb 15, 2022
2 parents 02e5a89 + be4d1a6 commit e4ba95a
Show file tree
Hide file tree
Showing 14 changed files with 123 additions and 59 deletions.
17 changes: 3 additions & 14 deletions smore/common/embedding/embed_rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,11 @@ def sync(self):


class EmbeddingReadOnly(object):
def __init__(self, embed, gpu_id=-1, target_dtype=None):
def __init__(self, embed, gpu_id=-1):
super(EmbeddingReadOnly, self).__init__()
self.embed = embed.data
self.gpu_id = gpu_id
self.embed_dim = embed.shape[1]
if target_dtype is None:
self.target_dtype = embed.dtype
else:
self.target_dtype = target_dtype
if target_dtype != embed.dtype:
self.type_cast = lambda x: x.type(self.target_dtype)
else:
self.type_cast = lambda x: x
if gpu_id == -1:
self.device = 'cpu'
else:
Expand Down Expand Up @@ -74,15 +66,13 @@ def read(self, indices, name=None):
t = self.embed.to(self.device)
t.job_handle = self.dummy_job
return t
if indices.numel() == self.embed.shape[0] and self.embed.is_cuda: # TODO: make it more explicit
return self.embed
if not self.embed.is_cuda:
if name is not None and indices.numel() != self.embed.shape[0]: # TODO: make it more explicit
return self.async_read(indices, name)
for key in self.last_write_jobs:
self.last_write_jobs[key].sync()
indices = indices.view(-1)
submat = self.type_cast(self.embed[indices].to(self.device))
submat = self.embed[indices].to(self.device)
submat.job_handle = self.dummy_job
return submat

Expand All @@ -97,7 +87,7 @@ def async_read(self, indices, name):
self.embed,
buf,
out)
submat = self.type_cast(out[:indices.shape[0]])
submat = out[:indices.shape[0]]
submat.job_handle = job_handle
return submat

Expand All @@ -106,7 +96,6 @@ class EmbeddingRW(EmbeddingReadOnly):
def __init__(self, embed, gpu_id=-1):
super(EmbeddingRW, self).__init__(embed, gpu_id)
self.write_thread_pool = self.read_thread_pool
assert self.target_dtype == self.embed.dtype
self.write_buf = {}
self.write_src_cache = {}

Expand Down
76 changes: 74 additions & 2 deletions smore/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,15 @@
from _thread import start_new_thread
import traceback
import logging
import pdb

import os
from tqdm import tqdm
import shutil
import zipfile
import urllib.request as ur
from smore.common.config import name_query_dict, query_name_dict

GBFACTOR = float(1 << 30)

def cal_ent_loc(query_structure, idx):
if query_structure[0] == '<':
return cal_ent_loc(query_structure[1], idx)
Expand Down Expand Up @@ -236,6 +241,73 @@ def sample_negative_bidirectional(query, ent_in, ent_out, nent):
pass


def download_url(url, folder, log=True):
r"""Downloads the content of an URL to a specific folder.
Args:
url (string): The url.
folder (string): The folder.
log (bool, optional): If :obj:`False`, will not print anything to the
console. (default: :obj:`True`)
"""

filename = url.rpartition('/')[2]
path = osp.join(folder, filename)

if osp.exists(path) and osp.getsize(path) > 0: # pragma: no cover
if log:
print('Using exist file', filename)
return path

if log:
print('Downloading', url)

if not osp.exists(folder):
os.makedirs(folder)
data = ur.urlopen(url)

size = int(data.info()["Content-Length"])

chunk_size = 1024*1024
num_iter = int(size/chunk_size) + 2

downloaded_size = 0

try:
with open(path, 'wb') as f:
pbar = tqdm(range(num_iter))
for i in pbar:
chunk = data.read(chunk_size)
downloaded_size += len(chunk)
pbar.set_description("Downloaded {:.2f} GB".format(float(downloaded_size)/GBFACTOR))
f.write(chunk)
except:
if osp.exists(path):
os.remove(path)
raise RuntimeError('Stopped downloading due to interruption.')


return path

def maybe_download_dataset(data_path):
data_name = data_path.split('/')[-1]
if data_name in ['FB15k', 'FB15k-237', 'NELL', "FB400k"]:
if not (osp.exists(data_path) and osp.exists(osp.join(data_path, "stats.txt"))):
url = "https://snap.stanford.edu/betae/%s.zip" % data_name
path = download_url(url, osp.split(osp.abspath(data_path))[0])
extract_zip(path, osp.split(osp.abspath(data_path))[0])
os.unlink(path)

def extract_zip(path, folder):
r"""Extracts a zip archive to a specific folder.
Args:
path (string): The path to the tar archive.
folder (string): The folder.
"""
print('Extracting', path)
with zipfile.ZipFile(path, 'r') as f:
f.extractall(folder)


def thread_wrapped_func(func):
"""Wrapped func for torch.multiprocessing.Process.
Expand Down
2 changes: 1 addition & 1 deletion smore/models/kg_reasoning.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def attach_feature(self, name, feat, gpu_id, is_sparse):
else:
device = 'cuda:{}'.format(gpu_id)
if is_sparse:
feat_read = EmbeddingReadOnly(feat, gpu_id=gpu_id, target_dtype=torch.float32)
feat_read = EmbeddingReadOnly(feat, gpu_id=gpu_id)
setattr(self, "%s_feat" % name, feat_read)
else:
setattr(self, "%s_feat" % name, feat.to(device))
Expand Down
7 changes: 4 additions & 3 deletions smore/training/beta_scripts/train_15k.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,20 @@ eval_path=$data_folder/eval-betae
export CUDA_VISIBLE_DEVICES=0,1,2,3

#beta
python ../main_train.py --do_test --gpus '0.1.2.3' \
python ../main_train.py --do_train --do_test --gpus '0.1.2.3' \
--data_path $data_folder --eval_path $eval_path \
-n 1024 -b 512 -d 400 -g 60 \
-a 0.5 -adv \
-lr 0.0001 --max_steps 450001 --geo beta --valid_steps 15000 \
-betam "(1600,2)" --tasks '1p.2p.3p.2i.3i.ip.pi.2u.up' --training_tasks '1p.2p.3p.2i.3i' \
-betam '(1600,2,fisher,0.055,layer,True)' --tasks '1p.2p.3p.2i.3i.ip.pi.2u.up' --training_tasks '1p.2p.3p.2i.3i' \
--save_checkpoint_steps 150000 \
--share_negative \
--logit_impl custom \
--lr_schedule none \
--sampler_type naive \
--filter_test \
--share_optim_stats \
--port 29501 \
--port 29511 \
--online_sample --prefix '../logs' --online_sample_mode '(500,0,w,wstruct,120)' \
--train_online_mode '(single,3000,e,True,before)' --optim_mode '(aggr,adam,cpu,False,5)' --online_weighted_structure_prob '(20,20,20,10,10)' --print_on_screen \
$@
9 changes: 5 additions & 4 deletions smore/training/beta_scripts/train_237.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

data_name=FB15k-237-betae
data_name=FB15k-237
data_folder=$HOME/data/knowledge_graphs/$data_name
eval_path=$data_folder/eval-original
eval_path=$data_folder/eval-betae

export CUDA_VISIBLE_DEVICES=0,1,2,3

Expand All @@ -26,14 +26,15 @@ python ../main_train.py --do_train --do_test --gpus '0.1.2.3' \
-n 1024 -b 512 -d 400 -g 60 \
-a 0.5 -adv \
-lr 0.0001 --max_steps 450001 --geo beta --valid_steps 15000 \
-betam "(1600,2)" --tasks '1p.2p.3p.2i.3i.ip.pi.2u.up' --training_tasks '1p.2p.3p.2i.3i' \
-betam '(1600,2,fisher,0.055,layer,True)' --tasks '1p.2p.3p.2i.3i.ip.pi.2u.up' --training_tasks '1p.2p.3p.2i.3i' \
--save_checkpoint_steps 150000 \
--share_negative \
--lr_schedule none \
--logit_impl custom \
--sampler_type naive \
--filter_test \
--share_optim_stats \
--port 29500 \
--port 29510 \
--online_sample --prefix '../logs' --online_sample_mode '(500,0,w,wstruct,120)' \
--train_online_mode '(single,3000,e,True,before)' --optim_mode '(aggr,adam,cpu,False,5)' --online_weighted_structure_prob '(20,20,20,10,10)' --print_on_screen \
$@
5 changes: 3 additions & 2 deletions smore/training/beta_scripts/train_nell.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ python ../main_train.py --do_train --do_test --gpus '0.1.2.3' \
-n 1024 -b 512 -d 400 -g 60 \
-a 0.5 -adv \
-lr 0.0001 --max_steps 450001 --geo beta --valid_steps 15000 \
-betam "(1600,2)" --tasks '1p.2p.3p.2i.3i.ip.pi.2u.up' --training_tasks '1p.2p.3p.2i.3i' \
-betam '(1600,2,fisher,0.055,layer,True)' --tasks '1p.2p.3p.2i.3i.ip.pi.2u.up' --training_tasks '1p.2p.3p.2i.3i' \
--save_checkpoint_steps 150000 \
--share_negative \
--lr_schedule none \
--logit_impl custom \
--sampler_type naive \
--filter_test \
--share_optim_stats \
--port 29500 \
--port 29512 \
--online_sample --prefix '../logs' --online_sample_mode '(500,0,w,wstruct,120)' \
--train_online_mode '(single,3000,e,True,before)' --optim_mode '(aggr,adam,cpu,False,5)' --online_weighted_structure_prob '(20,20,20,10,10)' --print_on_screen \
$@
3 changes: 2 additions & 1 deletion smore/training/box_scripts/train_15k.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
python ../main_train.py --do_train --do_test --gpus '0.1.2.3' \
--data_path $data_folder --eval_path $eval_path \
-n 1024 -b 512 -d 400 -g 24 \
-lr 0.0001 --max_steps 1000001 --geo box --valid_steps 20000 \
-lr 0.0001 --max_steps 1500001 --geo box --valid_steps 20000 \
-boxm '(none,0.02)' --tasks '1p.2p.3p.2i.3i.ip.pi.2u.up' --training_tasks '1p.2p.3p.2i.3i' \
--save_checkpoint_steps 50000 \
--sampler_type naive \
--logit_impl custom \
--lr_schedule none \
--port 29500 \
--share_negative \
Expand Down
9 changes: 5 additions & 4 deletions smore/training/box_scripts/train_237.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

data_name=FB15k-237-betae
data_name=FB15k-237
data_folder=$HOME/data/knowledge_graphs/$data_name
eval_path=$data_folder/eval-original
eval_path=$data_folder/eval-betae

export CUDA_VISIBLE_DEVICES=0,1,2,3

#box
python ../main_train.py --do_test --gpus '0.1.2.3' \
python ../main_train.py --do_train --do_test --gpus '0.1.2.3' \
--data_path $data_folder --eval_path $eval_path \
-n 1024 -b 512 -d 400 -g 24 \
-lr 0.0001 --max_steps 450001 --geo box --valid_steps 15000 \
-boxm '(none,0.02)' --tasks '1p.2p.3p.2i.3i.ip.pi.2u.up' --training_tasks '1p.2p.3p.2i.3i' \
--save_checkpoint_steps 30000 \
--logit_impl custom \
--share_negative \
--filter_test \
--share_optim_stats \
--port 29500 \
--port 29502 \
--online_sample --prefix '../logs' --online_sample_mode '(500,0,w,wstruct,120)' \
--train_online_mode '(single,3000,e,True,before)' --optim_mode '(aggr,adam,cpu,False,5)' --online_weighted_structure_prob '(2,2,2,1,1)' --print_on_screen \
$@
3 changes: 2 additions & 1 deletion smore/training/box_scripts/train_nell.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ export CUDA_VISIBLE_DEVICES=4,5,6,7
python ../main_train.py --do_train --do_test --gpus '0.1.2.3' \
--data_path $data_folder --eval_path $eval_path \
-n 1024 -b 512 -d 400 -g 24 \
-lr 0.0001 --max_steps 450001 --geo box --valid_steps 15000 \
-lr 0.0001 --max_steps 600001 --geo box --valid_steps 15000 \
-boxm '(none,0.02)' --tasks '1p.2p.3p.2i.3i.ip.pi.2u.up' --training_tasks '1p.2p.3p.2i.3i' \
--save_checkpoint_steps 30000 \
--sampler_type naive \
--logit_impl custom \
--port 29501 \
--share_negative \
--filter_test \
Expand Down
4 changes: 3 additions & 1 deletion smore/training/main_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import math

from smore.models import build_model
from smore.common.util import flatten_query, list2tuple, parse_time, set_global_seed, eval_tuple, construct_graph, tuple2filterlist
from smore.common.util import flatten_query, list2tuple, parse_time, set_global_seed, eval_tuple, construct_graph, tuple2filterlist, maybe_download_dataset
from smore.common.config import parse_args, all_tasks, query_name_dict, name_query_dict
from smore.common.embedding.embed_optimizer import get_optim_class
from smore.cpp_sampler.sampler_clib import KGMem
Expand Down Expand Up @@ -305,6 +305,8 @@ def main(parser):
set_global_seed(args.seed)
gpus = [int(i) for i in args.gpus.split(".")]
assert args.gpus == '.'.join([str(i) for i in range(len(gpus))]), 'only support continuous gpu ids starting from 0, please set CUDA_VISIBLE_DEVICES instead'

maybe_download_dataset(args.data_path)

setup_train_mode(args)

Expand Down
24 changes: 8 additions & 16 deletions smore/training/train_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_step_mp(model, args, train_sampler, test_dataloader, result_buffer, tra
rank = dist.get_rank()
step = 0
total_steps = len(test_dataloader)
logs = collections.defaultdict(list)
logs = collections.defaultdict(collections.Counter)
all_embed = None
negative_sample_bias = None

Expand Down Expand Up @@ -153,27 +153,19 @@ def test_step_mp(model, args, train_sampler, test_dataloader, result_buffer, tra
h10 = torch.mean((cur_ranking <= 10).to(torch.float)).item()
h1m = ((cur_ranking[0] == 1).to(torch.float)).item()

logs[query_structure].append({
'MRR': mrr,
'HITS1': h1,
'HITS3': h3,
'HITS10': h10,
'HITS1max': h1m,
'num_hard_answer': num_hard,
})
logs[query_structure]['MRR'] += mrr
logs[query_structure]['HITS1'] += h1
logs[query_structure]['HITS3'] += h3
logs[query_structure]['HITS10'] += h10
logs[query_structure]['HITS1max'] += h1m
logs[query_structure]['num_hard_answer'] += 1
logs[query_structure]['num_queries'] += 1

if step % args.test_log_steps == 0:
logging.info('Evaluating the model... (%d/%d)' % (step, total_steps))

step += 1

# metrics = collections.defaultdict(lambda: collections.defaultdict(int))
# for query_structure in logs:
# for metric in logs[query_structure][0].keys():
# if metric in ['num_hard_answer']:
# continue
# metrics[query_structure][metric] = sum([log[metric] for log in logs[query_structure]])/len(logs[query_structure])
# metrics[query_structure]['num_queries'] = len(logs[query_structure])
result_buffer.put((logs, train_step))


Expand Down
5 changes: 3 additions & 2 deletions smore/training/vec_scripts/train_15k.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
python ../main_train.py --do_train --do_test --gpus '0.1.2.3' \
--data_path $data_folder --eval_path $eval_path \
-n 1024 -b 512 -d 800 -g 24 \
-lr 0.0001 --max_steps 450001 --geo vec --valid_steps 15000 \
-lr 0.0001 --max_steps 2000001 --geo vec --valid_steps 15000 \
--tasks '1p.2p.3p.2i.3i.ip.pi.2u.up' --training_tasks '1p.2p.3p.2i.3i' \
--save_checkpoint_steps 30000 \
--sampler_type naive \
--logit_impl custom \
--share_negative \
--filter_test \
--port 29500 \
--port 29503 \
--share_optim_stats \
--online_sample --prefix '../logs' --online_sample_mode '(500,0,w,wstruct,120)' \
--train_online_mode '(single,3000,e,True,before)' --optim_mode '(aggr,adam,cpu,False,5)' --online_weighted_structure_prob '(2,2,2,1,1)' --print_on_screen \
Expand Down
11 changes: 6 additions & 5 deletions smore/training/vec_scripts/train_237.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

data_name=FB15k-237-betae
data_name=FB15k-237
data_folder=$HOME/data/knowledge_graphs/$data_name
eval_path=$data_folder/eval-original
eval_path=$data_folder/eval-betae

export CUDA_VISIBLE_DEVICES=0,1,2,3

#vec
python ../main_train.py --do_train --do_test --gpus '0.1.2.3' \
--data_path $data_folder --eval_path $eval_path \
-n 1024 -b 512 -d 800 -g 24 \
-lr 0.0001 --max_steps 750001 --geo vec --valid_steps 15000 \
-lr 0.0001 --max_steps 1500001 --geo vec --valid_steps 15000 \
--tasks '1p.2p.3p.2i.3i.ip.pi.2u.up' --training_tasks '1p.2p.3p.2i.3i' \
--save_checkpoint_steps 30000 \
--lr_schedule none \
--sampler_type sqrt \
--sampler_type naive \
--logit_impl custom \
--share_negative \
--filter_test \
--port 29500 \
--port 29505 \
--share_optim_stats \
--online_sample --prefix '../logs' --online_sample_mode '(500,0,w,wstruct,120)' \
--train_online_mode '(single,3000,e,True,before)' --optim_mode '(aggr,adam,cpu,False,5)' --online_weighted_structure_prob '(2,2,2,1,1)' --print_on_screen \
Expand Down
Loading

0 comments on commit e4ba95a

Please sign in to comment.