Skip to content

Commit

Permalink
Add flatten parameter to RetrieveTask, closes #569
Browse files (browse the repository at this point in the history)
  • Loading branch information
davidmezzetti committed Oct 3, 2023
1 parent 46a8ef8 commit f8a6c63
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
26 changes: 18 additions & 8 deletions src/python/txtai/workflow/task/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tempfile

from urllib.request import urlretrieve
from urllib.parse import urlparse

from .url import UrlTask

Expand All @@ -15,12 +16,13 @@ class RetrieveTask(UrlTask):
Task that retrieves urls (local or remote) to a local directory.
"""

def register(self, directory=None):
def register(self, directory=None, flatten=True):
"""
Adds retrieve parameters to task.
Args:
directory: local directory used to store retrieved files
flatten: flatten input directory structure, defaults to True
"""

# pylint: disable=W0201
Expand All @@ -32,20 +34,28 @@ def register(self, directory=None):
directory = self.tempdir.name

# Create output directory if necessary
if not os.path.exists(directory):
os.makedirs(directory)
os.makedirs(directory, exist_ok=True)

self.directory = directory
self.flatten = flatten

def prepare(self, element):
    """
    Retrieves a url (local or remote) and stores it in the local output directory.

    Args:
        element: url to retrieve

    Returns:
        local file path of retrieved content
    """

    # Extract file path from URL, this drops the scheme, host, query string and fragment
    path = urlparse(element).path

    if self.flatten:
        # Flatten directory structure (default), only the file name is kept
        path = os.path.join(self.directory, os.path.basename(path))
    else:
        # Normalize URL path and drop empty, "." and ".." components. normpath keeps leading ".."
        # entries on relative paths, which would otherwise resolve outside of the output directory.
        relative = os.path.normpath(path.lstrip("/"))
        parts = [part for part in relative.split(os.sep) if part not in ("", os.curdir, os.pardir)]

        # Derive output path, mirrors the URL directory structure under the output directory
        path = os.path.join(self.directory, *parts)

        # Create local directory, if necessary
        os.makedirs(os.path.dirname(path), exist_ok=True)

    # Retrieve URL
    urlretrieve(element, path)

    # Return new file path
    return path
5 changes: 5 additions & 0 deletions test/python/testworkflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,11 @@ def testRetrieveWorkflow(self):
results = list(workflow(["file://" + Utils.PATH + "/books.jpg"]))
self.assertTrue(results[0].endswith("books.jpg"))

# Test with directory structures
workflow = Workflow([RetrieveTask(flatten=False)])
results = list(workflow(["file://" + Utils.PATH + "/books.jpg"]))
self.assertTrue(results[0].endswith("books.jpg") and "txtai" in results[0])

def testScheduleWorkflow(self):
"""
Test workflow schedules
Expand Down

0 comments on commit f8a6c63

Please sign in to comment.