diff --git a/DrissionPage/session_page.py b/DrissionPage/session_page.py index 7049b4f..030ec4f 100644 --- a/DrissionPage/session_page.py +++ b/DrissionPage/session_page.py @@ -16,7 +16,7 @@ from requests import Session, Response from tldextract import extract from .base import BasePage -from .common import get_available_file_name, format_html +from .common import get_usable_path, format_html, make_valid_name from .config import _cookie_to_dict from .session_element import SessionElement, make_session_ele @@ -256,7 +256,7 @@ class SessionPage(BasePage): goal_path: str, rename: str = None, file_exists: str = 'rename', - post_data: dict = None, + post_data: Union[str, dict] = None, show_msg: bool = False, show_errmsg: bool = False, retry: int = None, @@ -273,126 +273,81 @@ class SessionPage(BasePage): :param retry: 重试次数 :param interval: 重试间隔时间 :param kwargs: 连接参数 - :return: 下载是否成功(bool)和状态信息(成功时信息为文件路径)的元组 + :return: 下载是否成功(bool)和状态信息(成功时信息为文件路径)的元组,跳过时第一位为None """ if file_exists == 'skip' and Path(f'{goal_path}{sep}{rename}').exists(): if show_msg: print(f'{file_url}\n{goal_path}{sep}{rename}\n已跳过。\n') + return None, 'Skipped because a file with the same name already exists.' - return False, 'Skipped because a file with the same name already exists.' + def do() -> tuple: + kwargs['stream'] = True + if 'timeout' not in kwargs: + kwargs['timeout'] = 20 - def do(url: str, - goal: str, - new_name: str = None, - exists: str = 'rename', - data: dict = None, - msg: bool = False, - errmsg: bool = False, - **args) -> tuple: - args['stream'] = True - - if 'timeout' not in args: - args['timeout'] = 20 - - mode = 'post' if data else 'get' + mode = 'post' if post_data else 'get' # 生成的response不写入self._response,是临时的 - r, info = self._make_response(url, mode=mode, data=data, show_errmsg=errmsg, **args) + r, info = self._make_response(file_url, mode=mode, data=post_data, show_errmsg=show_errmsg, **kwargs) if r is None: - if msg: + if show_msg: print(info) - return False, info if not r.ok: - if errmsg: + if show_errmsg: raise ConnectionError(f'连接状态码:{r.status_code}.') - return False, f'Status code: {r.status_code}.' # -------------------获取文件名------------------- - file_name = '' - content_disposition = r.headers.get('content-disposition', '').replace(' ', '') - - # 使用header里的文件名 - if content_disposition: - # TODO: 待测试 - txt = search(r'filename\*="?([^";]+)', content_disposition) - if txt: - charset, file_name = txt.group(1).split("''", 1) - file_name = unquote(content_disposition, charset) - else: - txt = search(r'filename="?([^";]+)', content_disposition) - if txt: - file_name = unquote(txt.group(1)) - - if file_name and file_name[0] == file_name[-1] == "'": - file_name = file_name[1:-1] - - # 在url里获取文件名 - if not file_name and os_PATH.basename(url): - file_name = os_PATH.basename(url).split("?")[0] - - # 找不到则用时间和随机数生成文件名 - if not file_name: - file_name = f'untitled_{time()}_{randint(0, 100)}' - - # 去除非法字符 - file_name = sub(r'[\\/*:|<>?"]', '', file_name).strip() - file_name = unquote(file_name) + file_name = _get_download_file_name(file_url, r.headers) # -------------------重命名,不改变扩展名------------------- - if new_name: - new_name = sub(r'[\\/*:|<>?"]', '', new_name).strip() + if rename: ext_name = file_name.split('.')[-1] - - if '.' in new_name or ext_name == file_name: - full_name = new_name + if '.' in rename or ext_name == file_name: # 新文件名带后缀或原文件名没有后缀 + full_name = rename else: - full_name = f'{new_name}.{ext_name}' - + full_name = f'{rename}.{ext_name}' else: full_name = file_name + full_name = make_valid_name(full_name) + # -------------------生成路径------------------- - goal_Path = Path(goal) - goal = '' + goal_Path = Path(goal_path) skip = False - for key, p in enumerate(goal_Path.parts): # 去除路径中的非法字符 - goal += goal_Path.drive if key == 0 and goal_Path.drive else sub(r'[*:|<>?"]', '', p).strip() - goal += '\\' if p != '\\' and key < len(goal_Path.parts) - 1 else '' - - goal_Path = Path(goal).absolute() - goal_Path.mkdir(parents=True, exist_ok=True) + # 按windows规则去除路径中的非法字符 + goal = goal_Path.anchor + sub(r'[*:|<>?"]', '', goal_path.lstrip(goal_Path.anchor)).strip() + Path(goal).absolute().mkdir(parents=True, exist_ok=True) full_path = Path(f'{goal}{sep}{full_name}') if full_path.exists(): if file_exists == 'rename': - full_name = get_available_file_name(goal, full_name) - full_path = Path(f'{goal}{sep}{full_name}') + full_path = get_usable_path(f'{goal}{sep}{full_name}') + full_name = full_path.name - elif exists == 'skip': + elif file_exists == 'skip': skip = True - elif exists == 'overwrite': + elif file_exists == 'overwrite': pass else: raise ValueError("file_exists参数只能是'skip'、'overwrite' 或 'rename'。") # -------------------打印要下载的文件------------------- - if msg: + if show_msg: print(file_url) print(full_name if file_name == full_name else f'{file_name} -> {full_name}') print(f'正在下载到:{goal}') - if skip: print('已跳过。\n') # -------------------开始下载------------------- if skip: - return False, 'Skipped because a file with the same name already exists.' + return None, 'Skipped because a file with the same name already exists.' # 获取远程文件大小 content_length = r.headers.get('content-length') @@ -408,20 +363,20 @@ class SessionPage(BasePage): tmpFile.write(chunk) # 如表头有返回文件大小,显示进度 - if msg and file_size: + if show_msg and file_size: downloaded_size += 1024 rate = downloaded_size / file_size if downloaded_size < file_size else 1 print('\r {:.0%} '.format(rate), end="") except Exception as e: - if errmsg: + if show_errmsg: raise ConnectionError(e) download_status, info = False, f'Download failed.\n{e}' else: if full_path.stat().st_size == 0: - if errmsg: + if show_errmsg: raise ValueError('文件大小为0。') download_status, info = False, 'File size is 0.' @@ -430,14 +385,13 @@ class SessionPage(BasePage): download_status, info = True, str(full_path) finally: - # 删除下载出错文件 if not download_status and full_path.exists(): - full_path.unlink() + full_path.unlink() # 删除下载出错文件 r.close() # -------------------显示并返回值------------------- - if msg: + if show_msg: print(info, '\n') info = f'{goal}{sep}{full_name}' if download_status else info @@ -445,15 +399,15 @@ class SessionPage(BasePage): retry_times = retry or self.retry_times retry_interval = interval or self.retry_interval - result = do(file_url, goal_path, rename, file_exists, post_data, show_msg, show_errmsg, **kwargs) + result = do() - if not result[0] and not str(result[1]).startswith('Skipped'): + if result[0] is False: # 第一位为None表示跳过的情况 for i in range(retry_times): sleep(retry_interval) - print(f'重试 {file_url}') - result = do(file_url, goal_path, rename, file_exists, post_data, show_msg, show_errmsg, **kwargs) - if result[0]: + + result = do() + if result[0] is not False: break return result @@ -540,3 +494,49 @@ class SessionPage(BasePage): # ----------------获取并设置编码结束----------------- return r, 'Success' + + +def _get_download_file_name(url, headers) -> str: + """从headers或url中获取文件名,如果获取不到,生成一个随机文件名 + :param url: 文件url + :param headers: 返回的headers + :return: 下载文件的文件名 + """ + file_name = '' + charset = '' + content_disposition = headers.get('content-disposition', '').replace(' ', '') + + # 使用header里的文件名 + if content_disposition: + txt = search(r'filename\*="?([^";]+)', content_disposition) + if txt: # 文件名自带编码方式 + txt = txt.group(1).split("''", 1) + if len(txt) == 2: + charset, file_name = txt + else: + file_name = txt[0] + + else: # 文件名没有编码方式 + txt = search(r'filename="?([^";]+)', content_disposition) + if txt: + file_name = txt.group(1) + + # 获取编码(如有) + content_type = headers.get('content-type', '').lower() + charset = search(r'charset[=: ]*(.*)?[;]', content_type) + if charset: + charset = charset.group(1) + + file_name = file_name.strip("'") + + # 在url里获取文件名 + if not file_name and os_PATH.basename(url): + file_name = os_PATH.basename(url).split("?")[0] + + # 找不到则用时间和随机数生成文件名 + if not file_name: + file_name = f'untitled_{time()}_{randint(0, 100)}' + + # 去除非法字符 + charset = charset or 'utf-8' + return unquote(file_name, charset)