Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General code review #6

Merged
merged 6 commits into from
Jul 16, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
name: Tests

on: push
# TODO: Use the line below after we make this repo public
# on: [push, pull_request]
on: [push, pull_request]

# When this workflow is queued, automatically cancel any previous running
# or pending jobs from the same branch
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ venv.bak/
# Rope project settings
.ropeproject

# PyCharm project settings
.idea

# mkdocs documentation
/site

Expand Down
130 changes: 57 additions & 73 deletions dask_mongo/core.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,35 @@
from copy import deepcopy
from __future__ import annotations

from math import ceil
from typing import Dict

import dask
import dask.bag as db
import pymongo
from dask import delayed
from distributed import get_client
from bson import ObjectId
from dask.bag import Bag
from dask.base import tokenize
from dask.delayed import Delayed
from dask.graph_manipulation import checkpoint


@delayed
def write_mongo(
values,
connection_args,
database,
collection,
):
values: list[dict],
connection_args: dict,
database: str,
collection: str,
) -> None:
with pymongo.MongoClient(**connection_args) as mongo_client:
database_ = mongo_client.get_database(database)
# NOTE: `insert_many` will mutate its input by inserting a "_id" entry.
# This can lead to confusing results, so we make a deepcopy to avoid this.
database_[collection].insert_many(deepcopy(values))
coll = mongo_client[database][collection]
# `insert_many` will mutate its input by inserting a "_id" entry.
# This can lead to confusing results; pass copies to it to preserve the input.
values = [v.copy() for v in values]
coll.insert_many(values)


def to_mongo(
bag: db.Bag,
*,
connection_args: Dict,
bag: Bag,
connection_args: dict,
database: str,
collection: str,
compute_options: Dict = None,
):
) -> Delayed:
"""Write a Dask Bag to a Mongo database.

Parameters
Expand All @@ -44,41 +43,27 @@ def to_mongo(
collection:
Name of the collection within the database to write to.
If it does not exist it will be created.
compute_options:
Keyword arguments to be forwarded to ``dask.compute()``.
Returns
-------
dask.delayed
"""
if compute_options is None:
compute_options = {}

partitions = [
write_mongo(partition, connection_args, database, collection)
for partition in bag.to_delayed()
]
partials = bag.map_partitions(write_mongo, connection_args, database, collection)
return checkpoint(partials)

try:
client = get_client()
except ValueError:
# Using single-machine scheduler
dask.compute(partitions, **compute_options)
else:
return client.compute(partitions, **compute_options)


@delayed
def fetch_mongo(
connection_args,
database,
collection,
id_min,
id_max,
match,
include_last=False,
connection_args: dict,
database: str,
collection: str,
match: dict,
id_min: ObjectId,
id_max: ObjectId,
include_last: bool,
):
with pymongo.MongoClient(**connection_args) as mongo_client:
database_ = mongo_client.get_database(database)

results = list(
database_[collection].aggregate(
coll = mongo_client[database][collection]
return list(
coll.aggregate(
[
{"$match": match},
{
Expand All @@ -93,39 +78,39 @@ def fetch_mongo(
)
)

return results


def read_mongo(
connection_args: Dict,
connection_args: dict,
database: str,
collection: str,
chunksize: int,
match: Dict = {},
match: dict = None,
):
"""Read data from a Mongo database into a Dask Bag.

Parameters
----------
connection_args:
Connection arguments to pass to ``MongoClient``.
Connection arguments to pass to ``MongoClient``
database:
Name of the database to write to. If it does not exists it will be created.
Name of the database to read from
collection:
Name of the collection within the database to write to.
If it does not exists it will be created.
Name of the collection within the database to read from
chunksize:
Number of elements desired per partition.
match:
Dictionary with match expression. By default it will bring all the documents in the collection.
MongoDB match query, used to filter the documents in the collection. If omitted,
this function will load all the documents in the collection.
"""
if not match:
match = {}

with pymongo.MongoClient(**connection_args) as mongo_client:
database_ = mongo_client.get_database(database)
coll = mongo_client[database][collection]

nrows = next(
(
database_[collection].aggregate(
coll.aggregate(
[
{"$match": match},
{"$count": "count"},
Expand All @@ -137,7 +122,7 @@ def read_mongo(
npartitions = int(ceil(nrows / chunksize))

partitions_ids = list(
database_[collection].aggregate(
coll.aggregate(
[
{"$match": match},
{"$bucketAuto": {"groupBy": "$_id", "buckets": npartitions}},
Expand All @@ -146,17 +131,16 @@ def read_mongo(
)
)

partitions = [
fetch_mongo(
connection_args,
database,
collection,
common_args = (connection_args, database, collection, match)
name = "read_mongo-" + tokenize(common_args)
dsk = {
(name, i): (
fetch_mongo,
*common_args,
partition["_id"]["min"],
partition["_id"]["max"],
match,
include_last=idx == len(partitions_ids) - 1,
i == len(partitions_ids) - 1,
)
for idx, partition in enumerate(partitions_ids)
]

return db.from_delayed(partitions)
for i, partition in enumerate(partitions_ids)
}
return Bag(dsk, name, len(partitions_ids))
48 changes: 22 additions & 26 deletions dask_mongo/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pymongo
import pytest
from dask.bag.utils import assert_eq
from distributed import wait
from distributed.utils_test import cluster_fixture # noqa: F401
from distributed.utils_test import client, gen_cluster, loop # noqa: F401

Expand Down Expand Up @@ -40,52 +39,49 @@ def gen_data(size=10):


@gen_cluster(client=True, clean_kwargs={"threads": False})
async def test_to_mongo(c, s, a, b, connection_args):
async def test_to_mongo_distributed(c, s, a, b, connection_args):
records = gen_data(size=10)
npartitions = 3
b = db.from_sequence(records, npartitions=npartitions)

database = "test-db"
collection = "test-collection"
delayed_ = to_mongo(
b,
connection_args=connection_args,
database=database,
collection=collection,
)

with pymongo.MongoClient(**connection_args) as mongo_client:
database = "test-db"
assert database not in mongo_client.list_database_names()
collection = "test-collection"

partitions = to_mongo(
b,
connection_args=connection_args,
database=database,
collection=collection,
)
assert len(partitions) == npartitions
await wait(partitions)

await c.compute(delayed_)
assert database in mongo_client.list_database_names()
assert [collection] == mongo_client[database].list_collection_names()

results = list(mongo_client[database][collection].find())

# Drop "_id" and sort by "idx" for comparison
results = [{k: v for k, v in result.items() if k != "_id"} for result in results]
results = sorted(results, key=lambda x: x["idx"])
assert_eq(b, results)


def test_to_mongo_single_machine_scheduler(connection_args):
def test_to_mongo(connection_args):
records = gen_data(size=10)
npartitions = 3
b = db.from_sequence(records, npartitions=npartitions)
database = "test-db"
collection = "test-collection"
delayed_ = to_mongo(
b,
connection_args=connection_args,
database=database,
collection=collection,
)

with pymongo.MongoClient(**connection_args) as mongo_client:
database = "test-db"
assert database not in mongo_client.list_database_names()
collection = "test-collection"

to_mongo(
b,
connection_args=connection_args,
database=database,
collection=collection,
)

delayed_.compute()
assert database in mongo_client.list_database_names()
assert [collection] == mongo_client[database].list_collection_names()

Expand Down