Skip to content

Commit

Permalink
Add initial read_mongo functionality (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
ncclementi authored Jul 16, 2021
1 parent d711e97 commit e6d2488
Show file tree
Hide file tree
Showing 8 changed files with 258 additions and 65 deletions.
8 changes: 1 addition & 7 deletions ci/environment-3.7.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,6 @@ channels:
dependencies:
- python=3.7
- dask
- pandas
- pymongo
- mongodb
- pytest
# Temporarily need dev version of distributed to use
# @gen_cluster with pytest fixtures
- pip
- pip:
- git+https://github.com/dask/distributed
- pytest
8 changes: 1 addition & 7 deletions ci/environment-3.8.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,6 @@ channels:
dependencies:
- python=3.8
- dask
- pandas
- pymongo
- mongodb
- pytest
# Temporarily need dev version of distributed to use
# @gen_cluster with pytest fixtures
- pip
- pip:
- git+https://github.com/dask/distributed
- pytest
8 changes: 1 addition & 7 deletions ci/environment-3.9.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,6 @@ channels:
dependencies:
- python=3.9
- dask
- pandas
- pymongo
- mongodb
- pytest
# Temporarily need dev version of distributed to use
# @gen_cluster with pytest fixtures
- pip
- pip:
- git+https://github.com/dask/distributed
- pytest
2 changes: 1 addition & 1 deletion dask_mongo/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .core import to_mongo
from .core import read_mongo, to_mongo
129 changes: 123 additions & 6 deletions dask_mongo/core.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,58 @@
from copy import deepcopy
from math import ceil
from typing import Dict, Optional

import dask
import dask.bag as db
import pandas as pd
import pymongo
from dask import delayed
from distributed import get_client


@delayed
def write_mongo(
    values,
    connection_args,
    database,
    collection,
):
    """Insert one partition's records into a Mongo collection.

    Parameters
    ----------
    values:
        List of dict-like records (one Bag partition) to insert.
    connection_args:
        Connection arguments to pass to ``MongoClient``.
    database:
        Name of the database to write to.
    collection:
        Name of the collection within the database to write to.
    """
    with pymongo.MongoClient(**connection_args) as mongo_client:
        database_ = mongo_client.get_database(database)
        # NOTE: `insert_many` will mutate its input by inserting a "_id" entry.
        # This can lead to confusing results, so we make a deepcopy to avoid this.
        database_[collection].insert_many(deepcopy(values))


def to_mongo(
df,
bag: db.Bag,
*,
connection_args: Dict,
database: str,
collection: str,
compute_options: Dict = None,
):
"""Write a Dask Bag to a Mongo database.
Parameters
----------
bag:
Dask Bag to write into the database.
connection_args:
Connection arguments to pass to ``MongoClient``.
database:
Name of the database to write to. If it does not exist, it will be created.
collection:
Name of the collection within the database to write to.
If it does not exist, it will be created.
compute_options:
Keyword arguments to be forwarded to ``dask.compute()``.
"""
if compute_options is None:
compute_options = {}

partitions = [
write_mongo(partition, connection_args, database, collection)
for partition in df.to_delayed()
for partition in bag.to_delayed()
]

try:
Expand All @@ -43,3 +62,101 @@ def to_mongo(
dask.compute(partitions, **compute_options)
else:
return client.compute(partitions, **compute_options)


@delayed
def fetch_mongo(
    connection_args,
    database,
    collection,
    id_min,
    id_max,
    match,
    include_last=False,
):
    """Fetch the documents of one partition from a Mongo collection.

    Returns the list of documents that satisfy ``match`` and whose ``_id``
    lies in ``[id_min, id_max)`` — or ``[id_min, id_max]`` when
    ``include_last`` is true (used for the final partition so its upper
    boundary document is not dropped).
    """
    # The upper bound is exclusive for every partition except the last one,
    # so adjacent partitions never fetch the same document twice.
    upper_bound_op = "$lte" if include_last else "$lt"
    pipeline = [
        {"$match": match},
        {"$match": {"_id": {"$gte": id_min, upper_bound_op: id_max}}},
    ]

    with pymongo.MongoClient(**connection_args) as mongo_client:
        database_ = mongo_client.get_database(database)
        return list(database_[collection].aggregate(pipeline))


def read_mongo(
    connection_args: Dict,
    database: str,
    collection: str,
    chunksize: int,
    match: Optional[Dict] = None,
):
    """Read data from a Mongo database into a Dask Bag.

    Parameters
    ----------
    connection_args:
        Connection arguments to pass to ``MongoClient``.
    database:
        Name of the database to read from.
    collection:
        Name of the collection within the database to read from.
    chunksize:
        Number of elements desired per partition.
    match:
        Dictionary with a match expression. By default all the documents in
        the collection are read.
    """
    # NOTE: a mutable default (``match={}``) would be shared across calls and
    # could be mutated by callers; normalize ``None`` to a fresh dict instead.
    if match is None:
        match = {}

    with pymongo.MongoClient(**connection_args) as mongo_client:
        database_ = mongo_client.get_database(database)

        # Count the documents satisfying ``match``. The ``$count`` stage
        # yields *no* document when nothing matches, so ``next`` needs a
        # default to avoid raising a bare ``StopIteration``.
        count_doc = next(
            database_[collection].aggregate(
                [
                    {"$match": match},
                    {"$count": "count"},
                ]
            ),
            None,
        )
        if count_doc is None:
            # Nothing matched: return an empty bag rather than crashing.
            return db.from_sequence([])
        nrows = count_doc["count"]

        npartitions = int(ceil(nrows / chunksize))

        # ``$bucketAuto`` splits the matched ``_id`` range into roughly
        # evenly sized buckets, one per partition. ``allowDiskUse`` lets the
        # server spill to disk for large collections.
        partitions_ids = list(
            database_[collection].aggregate(
                [
                    {"$match": match},
                    {"$bucketAuto": {"groupBy": "$_id", "buckets": npartitions}},
                ],
                allowDiskUse=True,
            )
        )

    partitions = [
        fetch_mongo(
            connection_args,
            database,
            collection,
            partition["_id"]["min"],
            partition["_id"]["max"],
            match,
            # Only the last partition includes its upper bound; see
            # ``fetch_mongo`` for why.
            include_last=idx == len(partitions_ids) - 1,
        )
        for idx, partition in enumerate(partitions_ids)
    ]

    return db.from_delayed(partitions)
Loading

0 comments on commit e6d2488

Please sign in to comment.