diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 5d7febb..0e7044d 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -1,8 +1,6 @@
 name: Tests
-on: push
-# TODO: Use the line below after we make this repo public
-# on: [push, pull_request]
+on: [push, pull_request]

 # When this workflow is queued, automatically cancel any previous running
 # or pending jobs from the same branch
diff --git a/.gitignore b/.gitignore
index f91c775..57d2317 100644
--- a/.gitignore
+++ b/.gitignore
@@ -119,6 +119,9 @@ venv.bak/
 # Rope project settings
 .ropeproject

+# PyCharm project settings
+.idea
+
 # mkdocs documentation
 /site

diff --git a/dask_mongo/core.py b/dask_mongo/core.py
index d9265e9..010be9b 100644
--- a/dask_mongo/core.py
+++ b/dask_mongo/core.py
@@ -1,131 +1,121 @@
-from copy import deepcopy
+from __future__ import annotations
+
 from math import ceil
-from typing import Dict
+from typing import Any

-import dask
-import dask.bag as db
 import pymongo
-from dask import delayed
-from distributed import get_client
+from bson import ObjectId
+from dask.bag import Bag
+from dask.base import tokenize
+from dask.graph_manipulation import checkpoint


-@delayed
 def write_mongo(
-    values,
-    connection_args,
-    database,
-    collection,
-):
-    with pymongo.MongoClient(**connection_args) as mongo_client:
-        database_ = mongo_client.get_database(database)
-        # NOTE: `insert_many` will mutate its input by inserting a "_id" entry.
-        # This can lead to confusing results, so we make a deepcopy to avoid this.
-        database_[collection].insert_many(deepcopy(values))
+    values: list[dict],
+    connection_kwargs: dict[str, Any],
+    database: str,
+    collection: str,
+) -> None:
+    with pymongo.MongoClient(**connection_kwargs) as mongo_client:
+        coll = mongo_client[database][collection]
+        # `insert_many` will mutate its input by inserting a "_id" entry.
+        # This can lead to confusing results; pass copies to it to preserve the input.
+        values = [v.copy() for v in values]
+        coll.insert_many(values)


 def to_mongo(
-    bag: db.Bag,
-    *,
-    connection_args: Dict,
+    bag: Bag,
     database: str,
     collection: str,
-    compute_options: Dict = None,
-):
+    *,
+    connection_kwargs: dict[str, Any] | None = None,
+    compute: bool = True,
+    compute_kwargs: dict[str, Any] | None = None,
+) -> Any:
     """Write a Dask Bag to a Mongo database.

     Parameters
     ----------
     bag:
         Dask Bag to write into the database.
-    connection_args:
-        Connection arguments to pass to ``MongoClient``.
-    database:
+    database : str
         Name of the database to write to. If it does not exists it will be created.
-    collection:
+    collection : str
         Name of the collection within the database to write to.
         If it does not exists it will be created.
-    compute_options:
-        Keyword arguments to be forwarded to ``dask.compute()``.
+    connection_kwargs : dict
+        Arguments to pass to ``MongoClient``.
+    compute : bool, optional
+        If True, immediately executes. If False, returns a delayed
+        object, which can be computed at a later time.
+    compute_kwargs : dict, optional
+        Options to pass to the compute method.
+
+    Returns
+    -------
+    If compute=True, block until computation is done, then return None.
+    If compute=False, immediately return a dask.delayed object.
     """
-    if compute_options is None:
-        compute_options = {}
-
-    partitions = [
-        write_mongo(partition, connection_args, database, collection)
-        for partition in bag.to_delayed()
-    ]
-
-    try:
-        client = get_client()
-    except ValueError:
-        # Using single-machine scheduler
-        dask.compute(partitions, **compute_options)
+    partials = bag.map_partitions(
+        write_mongo, connection_kwargs or {}, database, collection
+    )
+    collect = checkpoint(partials)
+    if compute:
+        return collect.compute(**compute_kwargs or {})
     else:
-        return client.compute(partitions, **compute_options)
+        return collect


-@delayed
 def fetch_mongo(
-    connection_args,
-    database,
-    collection,
-    id_min,
-    id_max,
-    match,
-    include_last=False,
-):
-    with pymongo.MongoClient(**connection_args) as mongo_client:
-        database_ = mongo_client.get_database(database)
-
-        results = list(
-            database_[collection].aggregate(
-                [
-                    {"$match": match},
-                    {
-                        "$match": {
-                            "_id": {
-                                "$gte": id_min,
-                                "$lte" if include_last else "$lt": id_max,
-                            }
-                        }
-                    },
-                ]
-            )
-        )
-
-        return results
+    connection_kwargs: dict[str, Any],
+    database: str,
+    collection: str,
+    match: dict[str, Any],
+    id_min: ObjectId,
+    id_max: ObjectId,
+    include_last: bool,
+) -> list[dict[str, Any]]:
+    match2 = {"_id": {"$gte": id_min, "$lte" if include_last else "$lt": id_max}}
+    with pymongo.MongoClient(**connection_kwargs) as mongo_client:
+        coll = mongo_client[database][collection]
+        return list(coll.aggregate([{"$match": match}, {"$match": match2}]))


 def read_mongo(
-    connection_args: Dict,
     database: str,
     collection: str,
     chunksize: int,
-    match: Dict = {},
+    *,
+    connection_kwargs: dict[str, Any] | None = None,
+    match: dict[str, Any] | None = None,
 ):
     """Read data from a Mongo database into a Dask Bag.

     Parameters
     ----------
-    connection_args:
-        Connection arguments to pass to ``MongoClient``.
     database:
-        Name of the database to write to. If it does not exists it will be created.
+        Name of the database to read from.
     collection:
-        Name of the collection within the database to write to.
-        If it does not exists it will be created.
+        Name of the collection within the database to read from.
     chunksize:
         Number of elements desired per partition.
+    connection_kwargs:
+        Connection arguments to pass to ``MongoClient``.
     match:
-        Dictionary with match expression. By default it will bring all the documents in the collection.
+        MongoDB match query, used to filter the documents in the collection. If omitted,
+        this function will load all the documents in the collection.
     """
+    if not connection_kwargs:
+        connection_kwargs = {}
+    if not match:
+        match = {}

-    with pymongo.MongoClient(**connection_args) as mongo_client:
-        database_ = mongo_client.get_database(database)
+    with pymongo.MongoClient(**connection_kwargs) as mongo_client:
+        coll = mongo_client[database][collection]
         nrows = next(
             (
-                database_[collection].aggregate(
+                coll.aggregate(
                     [
                         {"$match": match},
                         {"$count": "count"},
@@ -137,7 +127,7 @@ def read_mongo(
         npartitions = int(ceil(nrows / chunksize))

         partitions_ids = list(
-            database_[collection].aggregate(
+            coll.aggregate(
                 [
                     {"$match": match},
                     {"$bucketAuto": {"groupBy": "$_id", "buckets": npartitions}},
@@ -146,17 +136,16 @@
             )
         )

-    partitions = [
-        fetch_mongo(
-            connection_args,
-            database,
-            collection,
+    common_args = (connection_kwargs, database, collection, match)
+    name = "read_mongo-" + tokenize(common_args, chunksize)
+    dsk = {
+        (name, i): (
+            fetch_mongo,
+            *common_args,
             partition["_id"]["min"],
             partition["_id"]["max"],
-            match,
-            include_last=idx == len(partitions_ids) - 1,
+            i == len(partitions_ids) - 1,
         )
-        for idx, partition in enumerate(partitions_ids)
-    ]
-
-    return db.from_delayed(partitions)
+        for i, partition in enumerate(partitions_ids)
+    }
+    return Bag(dsk, name, len(partitions_ids))
diff --git a/dask_mongo/tests/test_core.py b/dask_mongo/tests/test_core.py
index 9630142..c1d0db7 100644
--- a/dask_mongo/tests/test_core.py
+++ b/dask_mongo/tests/test_core.py
@@ -6,7 +6,6 @@
 import pymongo
 import pytest
 from dask.bag.utils import assert_eq
-from distributed import wait
 from distributed.utils_test import cluster_fixture  # noqa: F401
 from distributed.utils_test import client, gen_cluster, loop  # noqa: F401

@@ -14,16 +13,16 @@


 @pytest.fixture
-def connection_args(tmp_path):
+def connection_kwargs(tmp_path):
     port = 27016
     with subprocess.Popen(
         ["mongod", f"--dbpath={str(tmp_path)}", f"--port={port}"]
     ) as proc:
-        connection_args = {
+        connection_kwargs = {
             "host": "localhost",
             "port": port,
         }
-        yield connection_args
+        yield connection_kwargs
         proc.terminate()


@@ -40,51 +39,65 @@
 @gen_cluster(client=True, clean_kwargs={"threads": False})
-async def test_to_mongo(c, s, a, b, connection_args):
+async def test_to_mongo_distributed_async(c, s, a, b, connection_kwargs):
     records = gen_data(size=10)
     npartitions = 3
     b = db.from_sequence(records, npartitions=npartitions)

-    with pymongo.MongoClient(**connection_args) as mongo_client:
-        database = "test-db"
-        assert database not in mongo_client.list_database_names()
-        collection = "test-collection"
-
-        partitions = to_mongo(
-            b,
-            connection_args=connection_args,
-            database=database,
-            collection=collection,
-        )
-        assert len(partitions) == npartitions
-        await wait(partitions)
+    database = "test-db"
+    collection = "test-collection"
+    delayed_ = to_mongo(
+        b,
+        database,
+        collection,
+        connection_kwargs=connection_kwargs,
+        # DaskMethodsMixin.compute() does not work with async distributed.Client
+        compute=False,
+    )
+    with pymongo.MongoClient(**connection_kwargs) as mongo_client:
+        assert database not in mongo_client.list_database_names()
+        await c.compute(delayed_)
         assert database in mongo_client.list_database_names()
         assert [collection] == mongo_client[database].list_collection_names()
-
         results = list(mongo_client[database][collection].find())
+        # Drop "_id" and sort by "idx" for comparison
         results = [{k: v for k, v in result.items() if k != "_id"} for result in results]
         results = sorted(results, key=lambda x: x["idx"])
         assert_eq(b, results)


-def test_to_mongo_single_machine_scheduler(connection_args):
+@pytest.mark.parametrize(
+    "compute,compute_kwargs",
+    [
+        (False, None),
+        (True, None),
+        (True, dict(scheduler="sync")),
+    ],
+)
+def test_to_mongo(connection_kwargs, compute, compute_kwargs):
     records = gen_data(size=10)
     npartitions = 3
     b = db.from_sequence(records, npartitions=npartitions)
+    database = "test-db"
+    collection = "test-collection"

-    with pymongo.MongoClient(**connection_args) as mongo_client:
-        database = "test-db"
+    with pymongo.MongoClient(**connection_kwargs) as mongo_client:
         assert database not in mongo_client.list_database_names()
-        collection = "test-collection"

-        to_mongo(
+        out = to_mongo(
             b,
-            connection_args=connection_args,
-            database=database,
-            collection=collection,
+            database,
+            collection,
+            connection_kwargs=connection_kwargs,
+            compute=compute,
+            compute_kwargs=compute_kwargs,
         )
+        if compute:
+            assert out is None
+        else:
+            assert out.compute() is None

         assert database in mongo_client.list_database_names()
         assert [collection] == mongo_client[database].list_collection_names()
@@ -96,40 +109,40 @@ def test_to_mongo_single_machine_scheduler(connection_args):
     assert_eq(b, results)


-def test_read_mongo(connection_args, client):
+def test_read_mongo(connection_kwargs, client):
     records = gen_data(size=10)
     database = "test-db"
     collection = "test-collection"

-    with pymongo.MongoClient(**connection_args) as mongo_client:
+    with pymongo.MongoClient(**connection_kwargs) as mongo_client:
         database_ = mongo_client.get_database(database)
         database_[collection].insert_many(deepcopy(records))

     b = read_mongo(
-        connection_args=connection_args,
-        database=database,
-        collection=collection,
+        database,
+        collection,
         chunksize=5,
+        connection_kwargs=connection_kwargs,
     )

     b = b.map(lambda x: {k: v for k, v in x.items() if k != "_id"})
     assert_eq(b, records)


-def test_read_mongo_match(connection_args):
+def test_read_mongo_match(connection_kwargs):
     records = gen_data(size=10)
     database = "test-db"
     collection = "test-collection"

-    with pymongo.MongoClient(**connection_args) as mongo_client:
+    with pymongo.MongoClient(**connection_kwargs) as mongo_client:
         database_ = mongo_client.get_database(database)
         database_[collection].insert_many(deepcopy(records))

     b = read_mongo(
-        connection_args=connection_args,
-        database=database,
-        collection=collection,
+        database,
+        collection,
         chunksize=5,
+        connection_kwargs=connection_kwargs,
         match={"idx": {"$gte": 2, "$lte": 7}},
     )

@@ -138,21 +151,21 @@
     assert_eq(b, expected)


-def test_read_mongo_chunksize(connection_args):
+def test_read_mongo_chunksize(connection_kwargs):
     records = gen_data(size=10)
     database = "test-db"
     collection = "test-collection"

-    with pymongo.MongoClient(**connection_args) as mongo_client:
+    with pymongo.MongoClient(**connection_kwargs) as mongo_client:
         database_ = mongo_client.get_database(database)
         database_[collection].insert_many(deepcopy(records))

     # divides evenly total nrows, 10/5 = 2
     b = read_mongo(
-        connection_args=connection_args,
-        database=database,
-        collection=collection,
+        database,
+        collection,
         chunksize=5,
+        connection_kwargs=connection_kwargs,
     )

     assert b.npartitions == 2
@@ -160,10 +173,10 @@

     # does not divides evenly total nrows, 10/3 -> 4
     b = read_mongo(
-        connection_args=connection_args,
-        database=database,
-        collection=collection,
+        database,
+        collection,
         chunksize=3,
+        connection_kwargs=connection_kwargs,
     )

     assert b.npartitions == 4
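
For reference, a minimal usage sketch of the API as reworked by this diff. The database and collection names, sample records, and connection details are all illustrative, and it assumes a MongoDB server is reachable at localhost:27017:

```python
import dask.bag as db

from dask_mongo.core import read_mongo, to_mongo

# Illustrative connection details; any MongoClient kwargs are accepted.
connection_kwargs = {"host": "localhost", "port": 27017}

records = [{"idx": i, "value": 2 * i} for i in range(10)]
bag = db.from_sequence(records, npartitions=2)

# Eager write: blocks until the inserts finish and returns None.
# Passing compute=False instead returns a delayed object, which can be
# computed later (e.g. with client.compute() on a distributed cluster).
to_mongo(bag, "demo-db", "demo-collection", connection_kwargs=connection_kwargs)

# Read back as a Bag of dicts, filtering with a MongoDB match query;
# chunksize controls roughly how many documents land in each partition.
result = read_mongo(
    "demo-db",
    "demo-collection",
    chunksize=5,
    connection_kwargs=connection_kwargs,
    match={"idx": {"$gte": 2}},
)
print(result.count().compute())  # -> 8
```

With `compute=False`, the returned object is the `checkpoint` of all `write_mongo` partitions, so computing it resolves to None only once every partition has been written.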