From 7fb52e82a3929596d655ee3007dc20c02bcf4a98 Mon Sep 17 00:00:00 2001 From: Zoe Roux Date: Sun, 15 Nov 2020 18:24:58 +0100 Subject: [PATCH] Fixing the downloader, adding a workdir argument --- Pipfile | 1 + autopipe/__init__.py | 13 +++++- autopipe/coordinators/download_example.py | 2 +- autopipe/pipe/downloader.py | 49 ++++++++--------------- 4 files changed, 30 insertions(+), 35 deletions(-) diff --git a/Pipfile b/Pipfile index 4756e3b..964f268 100644 --- a/Pipfile +++ b/Pipfile @@ -5,5 +5,6 @@ name = "pypi" [packages] feedparser = "*" +requests = "*" [dev-packages] diff --git a/autopipe/__init__.py b/autopipe/__init__.py index 8898c99..5cf9f78 100644 --- a/autopipe/__init__.py +++ b/autopipe/__init__.py @@ -15,6 +15,7 @@ autopipe: Autopipe def _parse_args(argv=None): from sys import argv as sysargv from argparse import ArgumentParser, HelpFormatter + import os class CustomHelpFormatter(HelpFormatter): # noinspection PyProtectedMember @@ -25,6 +26,11 @@ def _parse_args(argv=None): args_string = self._format_args(action, default) return ', '.join(action.option_strings) + ' ' + args_string + def dir_path(path): + if os.path.isdir(path): + return path + raise NotADirectoryError + # noinspection PyTypeChecker parser = ArgumentParser(description="Easily run advanced pipelines in a daemon or in one run sessions.", formatter_class=CustomHelpFormatter) @@ -35,7 +41,12 @@ def _parse_args(argv=None): help="Set the logging level. (default: warn ; available: %(choices)s)", type=LogLevel.parse) parser.add_argument("-d", "--daemon", help="Enable the daemon mode (rerun input generators after a sleep cooldown)", action="store_true") - return parser.parse_args(argv if argv is not None else sysargv[1:]) + parser.add_argument("-w", "--workdir", help="Change the workdir, default is the pwd.", type=dir_path, metavar="dir") + + args = parser.parse_args(argv if argv is not None else sysargv[1:]) + if args.workdir is not None: + os.chdir(args.workdir) + return args def main(argv=None): diff --git a/autopipe/coordinators/download_example.py b/autopipe/coordinators/download_example.py index 22f35b8..c5bf98e 100644 --- a/autopipe/coordinators/download_example.py +++ b/autopipe/coordinators/download_example.py @@ -16,7 +16,7 @@ class DownloadExample(Coordinator): @property def input(self): return RssInput(f"http://www.obsrv.com/General/ImageFeed.aspx?{self.query}", - lambda x: FileData(x.title, x["media_content"][0]["url"], False)) + lambda x: FileData(None, x["media_content"][0]["url"], False)) @property def pipeline(self) -> List[Union[Pipe, Callable[[APData], Union[APData, Pipe]]]]: diff --git a/autopipe/pipe/downloader.py b/autopipe/pipe/downloader.py index 530f7a1..f9348f8 100644 --- a/autopipe/pipe/downloader.py +++ b/autopipe/pipe/downloader.py @@ -1,4 +1,7 @@ -import logging +import os +from pathlib import Path + +import requests from autopipe import Pipe, APData @@ -15,6 +18,10 @@ class FileData(APData): class DownloaderPipe(Pipe): + def __init__(self, cwd=None): + self.cwd = cwd if cwd is not None else os.getcwd() + Path(self.cwd).mkdir(parents=True, exist_ok=True) + @property def name(self): return "Downloader" @@ -23,37 +30,13 @@ class DownloaderPipe(Pipe): super().pipe(data) if data.is_local: return data - # if not force_refresh and os.path.isfile(path): - # if not read: - # return - # with open(path, "r") as f: - # return StringIO(f.read()) - # - # if message: - # print(message) - # r = requests.get(url, stream=progress) - # try: - # Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True) - # with open(path, "wb") as f: - # length = r.headers.get("content-length") - # if progress and length: - # local = 0 - # length = int(length) - # for chunk in r.iter_content(chunk_size=4096): - # f.write(chunk) - # local += len(chunk) - # per = 50 * local // length - # print(f"\r [{'#' * per}{'-' * (50 - per)}] ({sizeof_fmt(local)}/{sizeof_fmt(length)}) \r", - # end='', flush=True) - # else: - # f.write(r.content) - # if read: - # return StringIO(r.content.decode(encoding)) - # except KeyboardInterrupt: - # os.remove(path) - # if progress: - # print() - # print("Download cancelled") - # raise + path = os.path.join(self.cwd, data.name if data.name is not None else data.link.split('/')[-1]) + try: + r = requests.get(data.link) + with open(path, "wb") as f: + f.write(r.content) + except KeyboardInterrupt: + os.remove(path) + raise data.is_local = True return data