mirror of
https://gitee.com/g1879/DrissionPage.git
synced 2024-12-10 04:00:23 +08:00
273 lines
11 KiB
Python
273 lines
11 KiB
Python
# -*- coding:utf-8 -*-
|
||
"""
|
||
@Author : g1879
|
||
@Contact : g1879@qq.com
|
||
@File : session_page.py
|
||
"""
|
||
import os
|
||
import re
|
||
from pathlib import Path
|
||
from random import randint
|
||
from time import time
|
||
from typing import Union, List
|
||
from urllib.parse import urlparse, quote
|
||
|
||
from requests_html import HTMLSession, HTMLResponse
|
||
|
||
from .common import get_loc_from_str, translate_loc_to_xpath, avoid_duplicate_name
|
||
from .config import OptionsManager
|
||
from .session_element import SessionElement, execute_session_find
|
||
|
||
|
||
class SessionPage(object):
    """Wrap common page operations on top of a ``requests_html`` session.

    The page keeps the last requested URL, the last response and a flag telling
    whether that response was considered usable.
    """

    def __init__(self, session: HTMLSession, timeout: float = 10):
        """Store the session and the default request timeout.

        :param session: session object used for every request
        :param timeout: default timeout (seconds) applied when a request does not specify one
        """
        self._session = session
        self.timeout = timeout
        self._url = None
        self._url_available = None
        self._response = None

    @property
    def session(self) -> HTMLSession:
        """The session object passed to the constructor."""
        return self._session

    @property
    def response(self) -> HTMLResponse:
        """The response produced by the last ``get``/``post`` call (``None`` before any call)."""
        return self._response

    @property
    def url(self) -> str:
        """The URL of the last ``get``/``post`` call."""
        return self._url

    @property
    def url_available(self) -> bool:
        """Whether the last ``get``/``post`` call produced a usable response."""
        return self._url_available

    @property
    def cookies(self) -> dict:
        """Cookies of the current session as a plain dict."""
        return self.session.cookies.get_dict()

    @property
    def title(self) -> str:
        """Text of the ``<title>`` element of the current page.

        NOTE(review): if ``ele`` returns ``None`` (no title found) the attribute
        access below fails - confirm whether callers rely on that.
        """
        return self.ele(('css selector', 'title')).text

    @property
    def html(self) -> str:
        """Full HTML source of the current response."""
        return self.response.html.html

    def ele(self,
            loc_or_ele: Union[tuple, str, SessionElement],
            mode: str = None,
            show_errmsg: bool = False) -> Union[SessionElement, List[SessionElement], None]:
        """Find element(s) in the current response.

        :param loc_or_ele: a locator tuple, a locator string, or an element (returned unchanged)
        :param mode: lookup mode forwarded to the finder; 'single' or 'all'
        :param show_errmsg: whether the finder should report errors
        :return: an element, a list of elements, or None
        """
        if isinstance(loc_or_ele, SessionElement):
            return loc_or_ele
        elif isinstance(loc_or_ele, str):
            loc = get_loc_from_str(loc_or_ele)
        else:
            loc = translate_loc_to_xpath(loc_or_ele)

        return execute_session_find(self.response.html, loc, mode, show_errmsg)

    def eles(self, loc: Union[tuple, str], show_errmsg: bool = False) -> List[SessionElement]:
        """Find all elements matching the locator.

        :param loc: a locator tuple or locator string
        :param show_errmsg: whether the finder should report errors
        :return: list of matching elements
        """
        # Forward the caller's choice instead of forcing error output on.
        return self.ele(loc, mode='all', show_errmsg=show_errmsg)

    def get(self, url: str, go_anyway: bool = False, **kwargs) -> Union[bool, None]:
        """Request ``url`` with GET and store the resulting response.

        :param url: target URL
        :param go_anyway: request even if ``url`` equals the current URL
        :param kwargs: extra arguments forwarded to the session request
        :return: None when nothing was requested (empty URL, or same URL without
                 ``go_anyway``); otherwise whether the response exists and is ``ok``
        """
        # Check for an empty URL first: quoting None would raise.
        if not url:
            return
        to_url = quote(url, safe='/:&?=%;#@')
        if not go_anyway and self.url == to_url:
            return

        self._url = to_url
        self._response = self._make_response(to_url, **kwargs)
        if self._response:
            self._sync_html_encoding()
        self._url_available = bool(self._response and self._response.ok)
        return self._url_available

    def post(self, url: str, data: dict = None, go_anyway: bool = False, **kwargs) -> Union[bool, None]:
        """Request ``url`` with POST and store the resulting response.

        :param url: target URL
        :param data: payload to submit
        :param go_anyway: request even if ``url`` equals the current URL
        :param kwargs: extra arguments forwarded to the session request
        :return: None when nothing was requested (empty URL, or same URL without
                 ``go_anyway``); otherwise whether the response exists and has status 200
        """
        # Check for an empty URL first: quoting None would raise.
        if not url:
            return
        to_url = quote(url, safe='/:&?=%;#@')
        if not go_anyway and self._url == to_url:
            return

        self._url = to_url
        self._response = self._make_response(to_url, mode='post', data=data, **kwargs)
        if self._response:
            self._sync_html_encoding()
        # NOTE(review): unlike get(), only status 200 counts as available here.
        self._url_available = bool(self._response and self._response.status_code == 200)
        return self._url_available

    def _sync_html_encoding(self) -> None:
        """Copy the response encoding onto its parsed HTML object.

        Works around requests_html losing the encoding. Deliberately best effort:
        a failure here must not abort the page load.
        """
        try:
            self._response.html.encoding = self._response.encoding
        except Exception:
            pass

    def download(self,
                 file_url: str,
                 goal_path: str = None,
                 rename: str = None,
                 file_exists: str = 'rename',
                 show_msg: bool = False,
                 **kwargs) -> tuple:
        """Download a file.

        The response created here is temporary and is not stored in ``self._response``.

        :param file_url: URL of the file
        :param goal_path: directory to save into; falls back to the configured
                          'global_tmp_path' when empty
        :param rename: new file name; the original extension is kept
        :param file_exists: what to do when the target exists: 'rename', 'overwrite' or 'skip'
        :param show_msg: whether to print download information and progress
        :param kwargs: extra arguments forwarded to the session request
        :return: tuple of (success flag, info); on success info is the saved file path
        :raises IOError: when no target directory can be determined
        :raises ValueError: when the target exists and ``file_exists`` is not a valid option
        Any exception raised while writing the file propagates after the partial
        file has been removed.
        """
        goal_path = goal_path or OptionsManager().get_value('paths', 'global_tmp_path')
        if not goal_path:
            raise IOError('No path specified.')

        kwargs['stream'] = True
        kwargs.setdefault('timeout', 20)

        r = self._make_response(file_url, mode='get', **kwargs)
        if not r:
            # A response with an error status is falsy but still holds a connection.
            if hasattr(r, 'close'):
                r.close()
            if show_msg:
                print('Invalid link')
            return False, 'Invalid link'

        # Everything below runs under try/finally so the streamed response is
        # released on every exit path (skip, invalid option, write error, success).
        try:
            # ------------------- work out the file name -------------------
            file_name = self._resolve_file_name(r, file_url)

            if rename:  # rename the file but keep its extension
                ext_name = file_name.split('.')[-1]
                if rename.lower().endswith(f'.{ext_name}'.lower()) or ext_name == file_name:
                    full_name = rename
                else:
                    full_name = f'{rename}.{ext_name}'
            else:
                full_name = file_name

            full_name = re.sub(r'[\\/*:|<>?"]', '', full_name).strip()
            target_dir = self._clean_dir(goal_path)
            goal_path = str(target_dir)
            full_path = target_dir / full_name

            if full_path.exists():
                if file_exists == 'skip':
                    return False, 'A file with the same name already exists.'
                elif file_exists == 'overwrite':
                    pass
                elif file_exists == 'rename':
                    full_name = avoid_duplicate_name(goal_path, full_name)
                    full_path = target_dir / full_name
                else:
                    raise ValueError("file_exists can only be selected in 'skip', 'overwrite', 'rename'")

            target_dir.mkdir(parents=True, exist_ok=True)

            if show_msg:
                print_txt = full_name if file_name == full_name else f'{file_name} -> {full_name}'
                print(print_txt)
                print(f'Downloading to: {goal_path}')

            # ------------------- download -------------------
            # Remote size, if the server reported one (used for progress only).
            file_size = int(r.headers['Content-Length']) if 'Content-Length' in r.headers else None
            downloaded_size, download_status = 0, False

            try:
                with open(str(full_path), 'wb') as out_file:
                    for chunk in r.iter_content(chunk_size=1024):
                        if not chunk:
                            continue
                        out_file.write(chunk)
                        if show_msg and file_size:
                            # Count the bytes actually received; the last chunk is usually short.
                            downloaded_size += len(chunk)
                            rate = downloaded_size / file_size if downloaded_size < file_size else 1
                            print('\r {:.0%} '.format(rate), end="")

                if full_path.stat().st_size == 0:
                    download_status, info = False, 'File size is 0.'
                else:
                    download_status, info = True, 'Success.'
            finally:
                # Remove a failed or empty download; errors keep propagating.
                if not download_status and full_path.exists():
                    full_path.unlink()
        finally:
            r.close()

        # ------------------- report and return -------------------
        if show_msg:
            print(info)
        return download_status, (str(full_path) if download_status else info)

    @staticmethod
    def _resolve_file_name(response, file_url: str) -> str:
        """Pick a file name from the Content-Disposition header, the URL, or a generated fallback."""
        file_name = ''
        disposition = response.headers.get('Content-disposition')
        if disposition:
            quoted = disposition.split('"')
            if len(quoted) > 1:
                file_name = quoted[1]
            else:
                # Unquoted form, e.g. ``attachment; filename=report.pdf``.
                found = re.search(r'filename\s*=\s*([^;]+)', disposition, re.I)
                file_name = found.group(1).strip() if found else ''
            try:
                # Header text is re-interpreted as UTF-8 bytes, as the original code did.
                file_name = file_name.encode('ISO-8859-1').decode('utf-8')
            except UnicodeError:
                pass  # keep the name as received when it cannot be re-decoded

        if not file_name:
            # The URL is not guaranteed to contain a file name.
            file_name = os.path.basename(file_url).split('?')[0]
        if not file_name:
            file_name = f'untitled_{time()}_{randint(0, 100)}'
        return file_name

    @staticmethod
    def _clean_dir(goal_path: str) -> Path:
        """Strip illegal characters from every directory part and rebuild the path."""
        raw = Path(goal_path)
        parts = raw.parts
        cleaned = ''
        for index, part in enumerate(parts):
            if index == 0 and raw.drive:
                cleaned += raw.drive  # keep the drive untouched (it contains ':')
            else:
                cleaned += re.sub(r'[*:|<>?"]', '', part).strip()
            # Use the platform separator; a root part already ends the segment.
            if part != os.sep and index < len(parts) - 1:
                cleaned += os.sep
        return Path(cleaned)

    def _make_response(self, url: str, mode: str = 'get', data: dict = None, **kwargs) -> Union[HTMLResponse, bool]:
        """Send a request and return the response with its encoding resolved.

        :param url: URL to request
        :param mode: 'get' or 'post'
        :param data: payload for POST requests
        :param kwargs: extra arguments forwarded to the session request
        :return: the response object, or False when the request raised
        :raises ValueError: when ``mode`` is neither 'get' nor 'post'
        """
        if mode not in ['get', 'post']:
            raise ValueError("mode must be 'get' or 'post'.")
        url = quote(url, safe='/:&?=%;#@')

        # Fill in Referer and Host on a copy, so neither the caller's dict nor
        # the session's shared headers are modified by this request.
        user_headers = kwargs.get('headers')
        if user_headers is not None:
            headers = dict(user_headers)
            present = {name.lower() for name in headers}
            if self._url and 'referer' not in present:
                headers['Referer'] = self._url
            if 'host' not in present:
                headers['Host'] = urlparse(url).hostname
        else:
            headers = self.session.headers.copy()
            headers['Host'] = urlparse(url).hostname
            if self._url:
                headers['Referer'] = self._url
        kwargs['headers'] = headers

        kwargs.setdefault('timeout', self.timeout)

        try:
            if mode == 'get':
                r = self.session.get(url, **kwargs)
            else:
                r = self.session.post(url, data=data, **kwargs)
        except Exception:
            # Any request failure is reported as False; interpreter-exit
            # signals such as KeyboardInterrupt are no longer swallowed.
            return False

        # Prefer the charset announced in the Content-Type header; look it up on
        # the response headers directly so the header-name case does not matter.
        content_type = r.headers.get('Content-Type') or ''
        declared = re.search(r'charset\s*=\s*["\']?([^"\';\s]+)', content_type, re.I)
        if declared:
            charset = declared.group(1)
        else:
            # Otherwise look for a <meta charset> declaration, defaulting to utf-8.
            meta = re.search(r'<meta.*?charset=[ \'"]*([^"\' />]+).*?>', r.text)
            charset = meta.group(1) if meta else 'utf-8'

        # NOTE(review): reading r.content loads the whole body even when
        # stream=True was requested - kept as in the original behaviour.
        body = r.content
        # Escape backspace characters to avoid garbled text or parse errors;
        # streamed (download) bodies are left byte-for-byte intact.
        r._content = body if kwargs.get('stream') else body.replace(b'\x08', b'\\b')
        r.encoding = charset
        return r
|