diff --git a/hda/api.py b/hda/api.py index e613031..1470a53 100644 --- a/hda/api.py +++ b/hda/api.py @@ -189,12 +189,12 @@ def __repr__(self): self.jobId, ) - def download(self): + def download(self, download_dir: str = "."): for result in self.results: query = {"jobId": self.jobId, "uri": result["url"]} self.debug(result) url = DataOrderRequest(self.client).run(query) - self.stream(result.get("filename"), result.get("size"), *url) + self.stream(result.get("filename"), result.get("size"), download_dir, *url) class Client(object): @@ -466,7 +466,7 @@ def post(self, message, *args): self.debug("<=== %s", shorten(result)) return result - def stream(self, target, size, *args): + def stream(self, target, size, download_dir, *args): full = self.full_url(*args) filename = target @@ -480,6 +480,9 @@ def stream(self, target, size, *args): # always safe - namely not for Cryosat or other ESA datasets. filename = None + if download_dir is None or not os.path.exists(download_dir): + download_dir = "." + self.info( "Downloading %s to %s (%s)", full, @@ -523,7 +526,7 @@ def stream(self, target, size, *args): leave=False, ) as pbar: pbar.update(total) - with open(filename, mode) as f: + with open(os.path.join(download_dir, filename), mode) as f: for chunk in r.iter_content(chunk_size=1024): if chunk: f.write(chunk)