diff --git a/DrissionPage/session_page.py b/DrissionPage/session_page.py
index 2647173..dc5d939 100644
--- a/DrissionPage/session_page.py
+++ b/DrissionPage/session_page.py
@@ -4,14 +4,13 @@
@Contact : g1879@qq.com
@File : session_page.py
"""
-from os import path as os_PATH
-from random import randint
from re import search
-from time import time, sleep
+from time import sleep
from typing import Union, List, Tuple
-from urllib.parse import urlparse, quote, unquote
+from urllib.parse import urlparse, quote
from requests import Session, Response
+from requests.structures import CaseInsensitiveDict
from tldextract import extract
from DownloadKit import DownloadKit
@@ -79,7 +78,8 @@ class SessionPage(BasePage):
raise ValueError('没有传入url。')
self._url = to_url
- self._response = self._try_to_connect(to_url, times=retry, interval=interval, show_errmsg=show_errmsg, **kwargs)
+ self._response, info = self._make_response(to_url, 'get', retry=retry, interval=interval,
+ show_errmsg=show_errmsg, **kwargs)
if self._response is None:
self._url_available = False
@@ -246,7 +246,8 @@ class SessionPage(BasePage):
raise ValueError('没有传入url。')
self._url = to_url
- self._response = self._try_to_connect(to_url, retry, interval, 'post', data, show_errmsg, **kwargs)
+ # self._response = self._try_to_connect(to_url, retry, interval, 'post', data, show_errmsg, **kwargs)
+ self._response, info = self._make_response(to_url, 'post', data, retry, interval, show_errmsg, **kwargs)
if self._response is None:
self._url_available = False
@@ -266,6 +267,8 @@ class SessionPage(BasePage):
url: str,
mode: str = 'get',
data: Union[dict, str] = None,
+ retry: int = None,
+ interval: float = None,
show_errmsg: bool = False,
**kwargs) -> tuple:
"""生成response对象 \n
@@ -276,114 +279,116 @@ class SessionPage(BasePage):
:param kwargs: 其它参数
:return: tuple,第一位为Response或None,第二位为出错信息或'Success'
"""
- if not url:
- if show_errmsg:
- raise ValueError('URL为空。')
- return None, 'URL为空。'
-
- if mode not in ('get', 'post'):
- raise ValueError("mode参数只能是'get'或'post'。")
-
- url = quote(url, safe='/:&?=%;#@+!')
+ kwargs = CaseInsensitiveDict(kwargs)
+ if 'headers' not in kwargs:
+ kwargs['headers'] = {}
+ else:
+ kwargs['headers'] = CaseInsensitiveDict(kwargs['headers'])
# 设置referer和host值
- kwargs_set = set(x.lower() for x in kwargs)
+ hostname = urlparse(url).hostname
+ scheme = urlparse(url).scheme
+ if not _check_headers(kwargs, self.session.headers, 'Referer'):
+ kwargs['headers']['Referer'] = self.url if self.url else f'{scheme}://{hostname}'
+ if 'Host' not in kwargs['headers']:
+ kwargs['headers']['Host'] = hostname
- if 'headers' in kwargs_set:
- header_set = set(x.lower() for x in kwargs['headers'])
-
- if self.url and 'referer' not in header_set:
- kwargs['headers']['Referer'] = self.url
-
- if 'host' not in header_set:
- kwargs['headers']['Host'] = urlparse(url).hostname
-
- else:
- kwargs['headers'] = self.session.headers
- kwargs['headers']['Host'] = urlparse(url).hostname
-
- if self.url:
- kwargs['headers']['Referer'] = self.url
-
- if 'timeout' not in kwargs_set:
+ if not _check_headers(kwargs, self.session.headers, 'timeout'):
kwargs['timeout'] = self.timeout
- try:
- r = None
+ r = None
+ retry = retry if retry is not None else self.retry_times
+ interval = interval if interval is not None else self.retry_interval
+ for i in range(retry + 1):
+ try:
+ if mode == 'get':
+ r = self.session.get(url, **kwargs)
+ elif mode == 'post':
+ r = self.session.post(url, data=data, **kwargs)
- if mode == 'get':
- r = self.session.get(url, **kwargs)
- elif mode == 'post':
- r = self.session.post(url, data=data, **kwargs)
+ print(r.url)
+ if r:
+ print(r.request.headers)
+ e = 'Success'
+ r = _set_charset(r)
+ return r, e
- except Exception as e:
- if show_errmsg:
- raise e
+ except Exception as e:
+ if show_errmsg:
+ raise e
- return None, e
+ if i < retry:
+ sleep(interval)
+ if r is None:
+ return None, '连接失败'
+
+ if not r.ok:
+ return r, f'状态码:{r.status_code}'
+
+ # try:
+ # r = None
+ # if mode == 'get':
+ # r = self.session.get(url, **kwargs)
+ # elif mode == 'post':
+ # r = self.session.post(url, data=data, **kwargs)
+ #
+ # if r is None:
+ # return None, '连接失败'
+ #
+ # except Exception as e:
+ # if show_errmsg:
+ # raise e
+ #
+ # return None, e
+ #
+ # else:
+ # # ----------------获取并设置编码开始-----------------
+ # # 在headers中获取编码
+ # content_type = r.headers.get('content-type', '').lower()
+ # charset = search(r'charset[=: ]*(.*)?[;]', content_type)
+ #
+ # if charset:
+ # r.encoding = charset.group(1)
+ #
+ # # 在headers中获取不到编码,且如果是网页
+ # elif content_type.replace(' ', '').startswith('text/html'):
+ # re_result = search(b']+).*?>', r.content)
+ #
+ # if re_result:
+ # charset = re_result.group(1).decode()
+ # else:
+ # charset = r.apparent_encoding
+ #
+ # r.encoding = charset
+ # # ----------------获取并设置编码结束-----------------
+ #
+ # return r, 'Success'
+
+
+def _check_headers(kwargs, headers: Union[dict, CaseInsensitiveDict], arg: str) -> bool:
+ """检查kwargs或headers中是否有arg所示属性"""
+ return arg in kwargs['headers'] or arg in headers
+
+
+def _set_charset(response) -> Response:
+ """设置Response对象的编码"""
+ # 在headers中获取编码
+ content_type = response.headers.get('content-type', '').lower()
+ charset = search(r'charset[=: ]*(.*)?[;]', content_type)
+
+ if charset:
+ response.encoding = charset.group(1)
+
+ # 在headers中获取不到编码,且如果是网页
+ elif content_type.replace(' ', '').startswith('text/html'):
+ re_result = search(b']+).*?>', response.content)
+
+ if re_result:
+ charset = re_result.group(1).decode()
else:
- # ----------------获取并设置编码开始-----------------
- # 在headers中获取编码
- content_type = r.headers.get('content-type', '').lower()
- charset = search(r'charset[=: ]*(.*)?[;]', content_type)
+ charset = response.apparent_encoding
- if charset:
- r.encoding = charset.group(1)
+ response.encoding = charset
- # 在headers中获取不到编码,且如果是网页
- elif content_type.replace(' ', '').startswith('text/html'):
- re_result = search(b']+).*?>', r.content)
-
- if re_result:
- charset = re_result.group(1).decode()
- else:
- charset = r.apparent_encoding
-
- r.encoding = charset
- # ----------------获取并设置编码结束-----------------
-
- return r, 'Success'
-
-
-def _get_download_file_name(url, response) -> str:
- """从headers或url中获取文件名,如果获取不到,生成一个随机文件名
- :param url: 文件url
- :param response: 返回的response
- :return: 下载文件的文件名
- """
- file_name = ''
- charset = ''
- content_disposition = response.headers.get('content-disposition', '').replace(' ', '')
-
- # 使用header里的文件名
- if content_disposition:
- txt = search(r'filename\*="?([^";]+)', content_disposition)
- if txt: # 文件名自带编码方式
- txt = txt.group(1).split("''", 1)
- if len(txt) == 2:
- charset, file_name = txt
- else:
- file_name = txt[0]
-
- else: # 文件名没带编码方式
- txt = search(r'filename="?([^";]+)', content_disposition)
- if txt:
- file_name = txt.group(1)
-
- # 获取编码(如有)
- charset = response.encoding
-
- file_name = file_name.strip("'")
-
- # 在url里获取文件名
- if not file_name and os_PATH.basename(url):
- file_name = os_PATH.basename(url).split("?")[0]
-
- # 找不到则用时间和随机数生成文件名
- if not file_name:
- file_name = f'untitled_{time()}_{randint(0, 100)}'
-
- # 去除非法字符
- charset = charset or 'utf-8'
- return unquote(file_name, charset)
+ return response