-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
6 changed files
with
92 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import asyncio | ||
import json | ||
|
||
from google.cloud import storage | ||
from smart_open import open | ||
|
||
|
||
def parse_bucket(path: str) -> str: | ||
""" | ||
Parse bucket name from a GCS path. For example given the path | ||
gs://bucket-name/path/to/file, return bucket-name | ||
""" | ||
return path.split("/")[2] | ||
|
||
|
||
def convert_path_to_list(path: str) -> list[str]: | ||
if path.startswith("gs://"): | ||
bucket_name = parse_bucket(path) | ||
paths = [] | ||
client = storage.Client() | ||
for blob in client.list_blobs(bucket_name, prefix=path): | ||
paths.append(f"gs://{bucket_name}/{blob.name}") | ||
return paths | ||
return [path] | ||
|
||
|
||
async def read_file_and_enqueue(path, queue: asyncio.Queue): | ||
paths = convert_path_to_list(path) | ||
for path in paths: | ||
with open(path, mode="r") as file: | ||
print(f"Sending request to Queue from file {path}") | ||
for line in file.readlines(): | ||
request = json.loads(line) | ||
await queue.put(request) | ||
await queue.put(None) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
pytest | ||
pytest-asyncio | ||
pytest-httpserver | ||
pytest-mock |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import asyncio | ||
from dataclasses import dataclass | ||
from unittest.mock import Mock | ||
|
||
from google.cloud import storage | ||
import pytest | ||
|
||
from batchelor.reader import parse_bucket, convert_path_to_list | ||
|
||
|
||
@pytest.fixture | ||
def mock_client(mocker): | ||
return mocker.patch("google.cloud.storage.Client", autospec=True) | ||
|
||
|
||
def test_parse_bucket(mock_client): | ||
input = "gs://bucket-name/path/to/file" | ||
expected = "bucket-name" | ||
assert parse_bucket(input) == expected | ||
|
||
input = "gcss://bucket-name/path/to/file" | ||
expected = "bucket-name" | ||
assert parse_bucket(input) == expected | ||
|
||
|
||
@dataclass | ||
class Blob: | ||
name: str | ||
|
||
|
||
def test_convert_path_to_list_single(mocker, mock_client): | ||
path = "gs://bucket-name/path/to/file.json" | ||
mock_client.return_value.list_blobs.return_value = [Blob(name="path/to/file.json")] | ||
|
||
output = convert_path_to_list(path) | ||
assert mock_client.return_value.list_blobs.call_count == 1 | ||
assert len(output) == 1 | ||
assert output[0] == path | ||
|
||
|
||
def test_convert_path_to_list_multiple(mocker, mock_client): | ||
path = "gs://bucket-name/path" | ||
mock_client.return_value.list_blobs.return_value = [ | ||
Blob(name="path/file1.jsonl"), | ||
Blob(name="path/file2.jsonl"), | ||
] | ||
|
||
output = convert_path_to_list(path) | ||
assert mock_client.return_value.list_blobs.call_count == 1 | ||
assert len(output) == 2 | ||
assert output[0] == path + "/file1.jsonl" | ||
assert output[1] == path + "/file2.jsonl" |