General code review #6

Merged · 6 commits · Jul 16, 2021
4 changes: 1 addition & 3 deletions .github/workflows/tests.yml
@@ -1,8 +1,6 @@
 name: Tests
 
-on: push
-# TODO: Use the line below after we make this repo public
-# on: [push, pull_request]
+on: [push, pull_request]
 
 # When this workflow is queued, automatically cancel any previous running
 # or pending jobs from the same branch
3 changes: 3 additions & 0 deletions .gitignore
@@ -119,6 +119,9 @@ venv.bak/
 # Rope project settings
 .ropeproject
 
+# PyCharm project settings
+.idea
+
 # mkdocs documentation
 /site
 
177 changes: 83 additions & 94 deletions dask_mongo/core.py
@@ -1,131 +1,121 @@
-from copy import deepcopy
+from __future__ import annotations
+
 from math import ceil
-from typing import Dict
+from typing import Any
 
-import dask
-import dask.bag as db
 import pymongo
-from dask import delayed
-from distributed import get_client
+from bson import ObjectId
+from dask.bag import Bag
+from dask.base import tokenize
+from dask.graph_manipulation import checkpoint
 
 
-@delayed
 def write_mongo(
-    values,
-    connection_args,
-    database,
-    collection,
-):
-    with pymongo.MongoClient(**connection_args) as mongo_client:
-        database_ = mongo_client.get_database(database)
-        # NOTE: `insert_many` will mutate its input by inserting a "_id" entry.
-        # This can lead to confusing results, so we make a deepcopy to avoid this.
-        database_[collection].insert_many(deepcopy(values))
+    values: list[dict],
+    connection_kwargs: dict[str, Any],
+    database: str,
+    collection: str,
+) -> None:
+    with pymongo.MongoClient(**connection_kwargs) as mongo_client:
+        coll = mongo_client[database][collection]
+        # `insert_many` will mutate its input by inserting a "_id" entry.
+        # This can lead to confusing results; pass copies to it to preserve the input.
+        values = [v.copy() for v in values]
+        coll.insert_many(values)
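Reviewer note: the shallow-copy line above is load-bearing — pymongo's `insert_many` really does add an `_id` to each input dict in place. A minimal sketch of the behavior it guards against (assumes a reachable MongoDB at `localhost:27017`; the database and collection names are placeholders):

```python
import pymongo

with pymongo.MongoClient(host="localhost", port=27017) as client:
    coll = client["demo_db"]["demo_collection"]

    records = [{"name": "a"}, {"name": "b"}]
    coll.insert_many(records)
    # pymongo mutated the caller's dicts:
    print(records[0])  # {'name': 'a', '_id': ObjectId('...')}

    records = [{"name": "a"}, {"name": "b"}]
    coll.insert_many([r.copy() for r in records])
    # The copies absorbed the mutation; the originals are untouched:
    print(records[0])  # {'name': 'a'}
```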


 def to_mongo(
-    bag: db.Bag,
-    *,
-    connection_args: Dict,
+    bag: Bag,
     database: str,
     collection: str,
-    compute_options: Dict = None,
-):
+    *,
+    connection_kwargs: dict[str, Any] = None,
+    compute: bool = True,
+    compute_kwargs: dict[str, Any] = None,
+) -> Any:
     """Write a Dask Bag to a Mongo database.
 
     Parameters
     ----------
     bag:
       Dask Bag to write into the database.
-    connection_args:
-      Connection arguments to pass to ``MongoClient``.
-    database:
+    database : str
       Name of the database to write to. If it does not exist it will be created.
-    collection:
+    collection : str
       Name of the collection within the database to write to.
       If it does not exist it will be created.
-    compute_options:
-      Keyword arguments to be forwarded to ``dask.compute()``.
+    connection_kwargs : dict
+      Arguments to pass to ``MongoClient``.
+    compute : bool, optional
+      If True, immediately execute the write. If False, return a delayed
+      object which can be computed at a later time.
+    compute_kwargs : dict, optional
+      Options to pass to the compute method.
+
+    Returns
+    -------
+    If compute=True, block until the computation is done, then return None.
+    If compute=False, immediately return a dask.delayed object.
     """
-    if compute_options is None:
-        compute_options = {}
-
-    partitions = [
-        write_mongo(partition, connection_args, database, collection)
-        for partition in bag.to_delayed()
-    ]
-
-    try:
-        client = get_client()
-    except ValueError:
-        # Using single-machine scheduler
-        dask.compute(partitions, **compute_options)
+    partials = bag.map_partitions(
+        write_mongo, connection_kwargs or {}, database, collection
+    )
+    collect = checkpoint(partials)
+    if compute:
+        return collect.compute(**compute_kwargs or {})
     else:
-        return client.compute(partitions, **compute_options)
+        return collect
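Reviewer note: for anyone checking the new `compute`/`compute_kwargs` semantics, here is a hedged usage sketch. The connection details and names are placeholders, and it assumes `to_mongo` is re-exported from the package root:

```python
import dask.bag as db

from dask_mongo import to_mongo

connection_kwargs = {"host": "localhost", "port": 27017}  # placeholder deployment
bag = db.from_sequence([{"x": i} for i in range(100)], npartitions=4)

# Default: write immediately, block until done, return None.
to_mongo(bag, "demo_db", "demo_collection", connection_kwargs=connection_kwargs)

# compute=False: get a lazy object back and trigger the write later,
# e.g. alongside other work in a single scheduler pass.
writes = to_mongo(
    bag,
    "demo_db",
    "demo_collection",
    connection_kwargs=connection_kwargs,
    compute=False,
)
writes.compute()
```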


-@delayed
 def fetch_mongo(
-    connection_args,
-    database,
-    collection,
-    id_min,
-    id_max,
-    match,
-    include_last=False,
-):
-    with pymongo.MongoClient(**connection_args) as mongo_client:
-        database_ = mongo_client.get_database(database)
-
-        results = list(
-            database_[collection].aggregate(
-                [
-                    {"$match": match},
-                    {
-                        "$match": {
-                            "_id": {
-                                "$gte": id_min,
-                                "$lte" if include_last else "$lt": id_max,
-                            }
-                        }
-                    },
-                ]
-            )
-        )
-
-    return results
+    connection_kwargs: dict[str, Any],
+    database: str,
+    collection: str,
+    match: dict[str, Any],
+    id_min: ObjectId,
+    id_max: ObjectId,
+    include_last: bool,
+) -> list[dict[str, Any]]:
+    match2 = {"_id": {"$gte": id_min, "$lte" if include_last else "$lt": id_max}}
+    with pymongo.MongoClient(**connection_kwargs) as mongo_client:
+        coll = mongo_client[database][collection]
+        return list(coll.aggregate([{"$match": match}, {"$match": match2}]))
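Reviewer note: the conditional dict key packs a lot into one line. Expanded, the boundary logic reads like this (a sketch with hypothetical ObjectId bounds):

```python
from bson import ObjectId

id_min = ObjectId("0" * 24)  # hypothetical bucket boundaries
id_max = ObjectId("f" * 24)

# Non-final partitions use an exclusive upper bound ($lt), so a document
# sitting exactly on a shared boundary is fetched by only one partition.
include_last = False
bounds = {"_id": {"$gte": id_min, "$lte" if include_last else "$lt": id_max}}
assert bounds == {"_id": {"$gte": id_min, "$lt": id_max}}

# The final partition flips to $lte so the last document is not dropped.
include_last = True
bounds = {"_id": {"$gte": id_min, "$lte" if include_last else "$lt": id_max}}
assert bounds == {"_id": {"$gte": id_min, "$lte": id_max}}
```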


 def read_mongo(
-    connection_args: Dict,
     database: str,
     collection: str,
     chunksize: int,
-    match: Dict = {},
+    *,
+    connection_kwargs: dict[str, Any] = None,
+    match: dict[str, Any] = None,
 ):
     """Read data from a Mongo database into a Dask Bag.
 
     Parameters
     ----------
-    connection_args:
-      Connection arguments to pass to ``MongoClient``.
     database:
-      Name of the database to write to. If it does not exists it will be created.
+      Name of the database to read from.
     collection:
-      Name of the collection within the database to write to.
-      If it does not exists it will be created.
+      Name of the collection within the database to read from.
     chunksize:
       Number of elements desired per partition.
+    connection_kwargs:
+      Connection arguments to pass to ``MongoClient``.
     match:
-      Dictionary with match expression. By default it will bring all the documents in the collection.
+      MongoDB match query used to filter the documents in the collection.
+      If omitted, all the documents in the collection are loaded.
     """
+    if not connection_kwargs:
+        connection_kwargs = {}
+    if not match:
+        match = {}
+
-    with pymongo.MongoClient(**connection_args) as mongo_client:
-        database_ = mongo_client.get_database(database)
+    with pymongo.MongoClient(**connection_kwargs) as mongo_client:
+        coll = mongo_client[database][collection]
 
         nrows = next(
             (
-                database_[collection].aggregate(
+                coll.aggregate(
                     [
                         {"$match": match},
                         {"$count": "count"},
@@ -137,7 +127,7 @@ def read_mongo(
         npartitions = int(ceil(nrows / chunksize))
 
         partitions_ids = list(
-            database_[collection].aggregate(
+            coll.aggregate(
                 [
                     {"$match": match},
                     {"$bucketAuto": {"groupBy": "$_id", "buckets": npartitions}},
@@ -146,17 +136,16 @@
             )
         )
 
-    partitions = [
-        fetch_mongo(
-            connection_args,
-            database,
-            collection,
+    common_args = (connection_kwargs, database, collection, match)
+    name = "read_mongo-" + tokenize(common_args, chunksize)
+    dsk = {
+        (name, i): (
+            fetch_mongo,
+            *common_args,
             partition["_id"]["min"],
             partition["_id"]["max"],
-            match,
-            include_last=idx == len(partitions_ids) - 1,
+            i == len(partitions_ids) - 1,
         )
-        for idx, partition in enumerate(partitions_ids)
-    ]
-
-    return db.from_delayed(partitions)
+        for i, partition in enumerate(partitions_ids)
+    }
+    return Bag(dsk, name, len(partitions_ids))
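Reviewer note: nice simplification — building the `dsk` mapping of `(name, i)` keys to `(fetch_mongo, ...)` task tuples and handing it straight to `Bag(dsk, name, npartitions)` drops the `delayed` indirection entirely; each `$bucketAuto` result carries the `{'_id': {'min': ..., 'max': ...}}` bounds consumed above. A hedged end-to-end sketch of the reader (placeholder connection details and names):

```python
from dask_mongo import read_mongo

connection_kwargs = {"host": "localhost", "port": 27017}  # placeholder deployment

bag = read_mongo(
    "demo_db",
    "demo_collection",
    chunksize=25,
    connection_kwargs=connection_kwargs,
    match={"name": {"$exists": True}},  # optional server-side filter
)
print(bag.npartitions)  # roughly ceil(matched documents / chunksize)
print(bag.take(3))      # first few documents as plain dicts
```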