diff --git a/hloc/pairs_from_retrieval.py b/hloc/pairs_from_retrieval.py index 4329547a..b3c71cd7 100644 --- a/hloc/pairs_from_retrieval.py +++ b/hloc/pairs_from_retrieval.py @@ -67,9 +67,10 @@ def pairs_from_score_matrix(scores: torch.Tensor, return pairs -def main(descriptors, output, num_matched, +def main(descriptors: Path, output: Path, num_matched: int, query_prefix=None, query_list=None, - db_prefix=None, db_list=None, db_model=None, db_descriptors=None): + db_prefix=None, db_list=None, db_model=None, db_descriptors=None, + chunk_size: int = -1): logger.info('Extracting image pairs from a retrieval database.') # We handle multiple reference feature files. @@ -94,17 +95,27 @@ def main(descriptors, output, num_matched, device = 'cuda' if torch.cuda.is_available() else 'cpu' db_desc = get_descriptors(db_names, db_descriptors, name2db) + db_desc = db_desc.to(device) query_desc = get_descriptors(query_names, descriptors) - sim = torch.einsum('id,jd->ij', query_desc.to(device), db_desc.to(device)) - # Avoid self-matching - self = np.array(query_names)[:, None] == np.array(db_names)[None] - pairs = pairs_from_score_matrix(sim, self, num_matched, min_score=0) - pairs = [(query_names[i], db_names[j]) for i, j in pairs] + num_pairs = 0 + chunk_size = chunk_size if chunk_size > 0 else len(query_names) + with output.open('w') as f: + query_name_splits = [query_names[i:i + chunk_size] for i in range(0, len(query_names), chunk_size)] + query_splits = torch.split(query_desc, chunk_size) + for i, (names, query_desc_split) in enumerate(zip(query_name_splits, query_splits)): + if i != 0: + f.write('\n') + sim = torch.einsum('id,jd->ij', query_desc_split.to(device), db_desc) - logger.info(f'Found {len(pairs)} pairs.') - with open(output, 'w') as f: - f.write('\n'.join(' '.join([i, j]) for i, j in pairs)) + # Avoid self-matching + self = np.array(names)[:, None] == np.array(db_names)[None] + pairs = pairs_from_score_matrix(sim, self, num_matched, min_score=0) + pairs = [(names[i], db_names[j]) for i, j in pairs] + num_pairs += len(pairs) + f.write('\n'.join(' '.join([i, j]) for i, j in pairs)) + + logger.info(f'Found {num_pairs} pairs.') if __name__ == "__main__": @@ -118,5 +129,6 @@ def main(descriptors, output, num_matched, parser.add_argument('--db_list', type=Path) parser.add_argument('--db_model', type=Path) parser.add_argument('--db_descriptors', type=Path) + parser.add_argument('--chunk_size', type=int, default=-1) args = parser.parse_args() main(**args.__dict__)