-
Notifications
You must be signed in to change notification settings - Fork 2
/
tot.py
190 lines (137 loc) · 4.26 KB
/
tot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import json
import logging
from pathlib import Path
from typing import NamedTuple, Dict, List
import ir_datasets
from ir_datasets.formats import TrecQrels, BaseDocs, BaseQueries
from ir_datasets.indices import PickleLz4FullStore
# Registry prefix: every dataset registered by this module is named "trec-tot:<split>".
NAME = "trec-tot"
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Document count for the 2024 corpus; passed to PickleLz4FullStore as
# count_hint when the docs store is built (see TrecToTDocs.docs_store).
N_DOCS = 3185450
class TrecToTDoc(NamedTuple):
    """A single TREC-ToT Wikipedia document.

    Note: although the raw corpus JSONL stores ``sections`` as a list of
    ``{"section", "start", "end"}`` offset records, ``TrecToTDocs.parse_sections``
    converts them to a ``{section title: section text}`` mapping before this
    tuple is constructed, so the field is annotated as a dict here (the
    previous ``List[Dict[str, str]]`` annotation did not match runtime values).
    """
    title: str                # page title
    doc_id: int               # numeric corpus identifier
    wikidata_id: str          # Wikidata identifier for the page
    text: str                 # full page text
    sections: Dict[str, str]  # section title -> section text
class TrecToTQuery(NamedTuple):
    """A TREC-ToT topic: an identifier plus its free-form query text."""
    query_id: str  # topic identifier (string, as read from queries.jsonl)
    query: str     # the tip-of-the-tongue description text
class TrecToTDocs(BaseDocs):
    """Document collection for the TREC ToT corpus (JSONL, one doc per line).

    Documents are materialized into a PickleLz4FullStore on first access so
    subsequent iterations and doc_id lookups avoid re-parsing the JSONL file.
    """

    def __init__(self, dlc, name="docs"):
        """
        Args:
            dlc: content object exposing ``stream()`` (e.g. LocalFileStream)
                that yields the corpus JSONL bytes.
            name: sub-namespace label used by ``docs_namespace``. Defaulted
                for backward compatibility — previously ``docs_namespace``
                always raised AttributeError because ``_name`` was never set.
        """
        super().__init__()
        self._dlc = dlc
        self._name = name

    def docs_iter(self):
        return iter(self.docs_store())

    def parse_sections(self, doc):
        """Replace the raw section offset records in ``doc['sections']`` with
        a ``{section title: section text}`` mapping sliced from ``doc['text']``."""
        sections = {}
        for s in doc["sections"]:
            sections[s["section"]] = doc["text"][s["start"]:s["end"]]
        doc["sections"] = sections
        return doc

    def _docs_iter(self):
        # Stream the corpus lazily; used only to seed the pickle store.
        with self._dlc.stream() as stream:
            for line in stream:
                yield TrecToTDoc(**self.parse_sections(json.loads(line)))

    def docs_cls(self):
        return TrecToTDoc

    def docs_store(self, field='doc_id'):
        return PickleLz4FullStore(
            path=f'{ir_datasets.util.home_path()}/trec-tot24/docs.pklz4',
            init_iter_fn=self._docs_iter,
            data_cls=self.docs_cls(),
            lookup_field=field,
            index_fields=[field],
            count_hint=N_DOCS,
        )

    def docs_count(self):
        return self.docs_store().count()

    def docs_namespace(self):
        # Bug fix: self._name was read here but never assigned in __init__,
        # so every call raised AttributeError. It is now set in __init__.
        return f'{NAME}/{self._name}'

    def docs_lang(self):
        return 'en'
class LocalFileStream:
    """Minimal adapter that gives a local file the ``stream()`` interface
    expected from ir_datasets download-content objects."""

    def __init__(self, path):
        # Accepts either a str or a pathlib.Path; open() handles both.
        self._location = path

    def stream(self):
        # Binary mode: consumers iterate raw byte lines and json.loads them.
        return open(self._location, "rb")
class TrecToTQueries(BaseQueries):
    """Query collection backed by a JSONL file of topics (one per line)."""

    def __init__(self, name, dlc):
        super().__init__()
        self._name = name
        self._dlc = dlc

    def queries_iter(self):
        # One topic per line; the JSON object's keys must match the
        # TrecToTQuery fields exactly (query_id, query).
        with self._dlc.stream() as stream:
            for raw in stream:
                yield TrecToTQuery(**json.loads(raw))

    def queries_cls(self):
        return TrecToTQuery

    def queries_namespace(self):
        return f'{NAME}/{self._name}'

    def queries_lang(self):
        return 'en'
def register(path):
    """Register the four 2024 TREC-ToT splits with ir_datasets.

    Expects the following layout under ``path``:
      corpus.jsonl            shared document collection for all splits
      <split>/queries.jsonl   topics for each split
      <split>/qrel.txt        relevance judgments (absent for the test split)

    Splits whose queries file is missing are skipped with a warning.
    """
    qrel_defs = {
        1: 'answer',
        0: 'not answer',
    }
    path = Path(path)
    # Shared corpus file used by every split's docs component.
    corpus = path / "corpus.jsonl"
    for split in {"train-2024", "dev1-2024", "dev2-2024", "test-2024"}:
        name = split
        queries = path / split / "queries.jsonl"
        if not queries.exists():
            log.warning(f"not loading '{split}' split: {queries} not found")
            continue
        components = [
            TrecToTDocs(LocalFileStream(corpus)),
            TrecToTQueries(name, LocalFileStream(queries)),
        ]
        # The test split ships without relevance judgments.
        has_qrel = split != "test-2024"
        if has_qrel:
            qrel = path / split / "qrel.txt"
            components.append(TrecQrels(LocalFileStream(qrel), qrel_defs))
        ir_datasets.registry.register(
            NAME + ":" + name, ir_datasets.Dataset(*components))
        log.info(f"registered: {NAME}:{name} [qrel={has_qrel}]")
if __name__ == '__main__':
    # Smoke test: register the splits from a user-supplied directory, then
    # iterate every registered set once, sanity-checking query/qrel counts.
    path = input("Enter data path:")
    register(path.strip())

    sets = [NAME + ":" + split
            for split in {"train-2024", "dev1-2024", "dev2-2024", "test-2024"}]
    print(f"available sets: {sets}")

    q = None
    for name in sets:
        try:
            dataset = ir_datasets.load(name)
        except KeyError:
            print(f"error loading {name}, skipping!")
            continue
        # Count queries; q keeps the last one seen for the example printout.
        n_q = 0
        for n_q, q in enumerate(dataset.queries_iter(), start=1):
            pass
        if "test" not in name:
            # Each judged split should have exactly one qrel per query.
            n_qrel = sum(1 for _ in dataset.qrels_iter())
            assert n_qrel == n_q
        print(name)
        print(f"n queries: {n_q}")
        print()

    print(f"example query: {q}")

    dataset = ir_datasets.load("trec-tot:train-2024")
    doc = None
    n_docs = 0
    for n_docs, doc in enumerate(dataset.docs_iter(), start=1):
        pass
    print(f"example doc: {doc}")
    print("corpus size: ", n_docs)