diff --git a/DrissionPage/chromium_base.py b/DrissionPage/chromium_base.py index 9f48a4d..80e1c7a 100644 --- a/DrissionPage/chromium_base.py +++ b/DrissionPage/chromium_base.py @@ -19,7 +19,7 @@ from .chromium_element import ChromiumScroll, ChromiumElement, run_js, make_chro from .commons.constants import HANDLE_ALERT_METHOD, ERROR, NoneElement from .commons.locator import get_loc from .commons.tools import get_usable_path, clean_folder -from .commons.web import set_browser_cookies, ResponseData +from .commons.web import set_browser_cookies, DataPacket from .errors import ContextLossError, ElementLossError, AlertExistsError, CallMethodError, TabClosedError, \ NoRectError, BrowserConnectError from .session_element import make_session_ele @@ -41,6 +41,7 @@ class ChromiumBase(BasePage): self._tab_obj = None self._set = None self._screencast = None + self._listener = None if isinstance(address, int) or (isinstance(address, str) and address.isdigit()): address = f'127.0.0.1:{address}' @@ -360,6 +361,13 @@ class ChromiumBase(BasePage): self._screencast = Screencast(self) return self._screencast + @property + def listener(self): + """返回用于聆听数据包的对象""" + if self._listener is None: + self._listener = NetworkListener(self) + return self._listener + def run_cdp(self, cmd, **cmd_args): """执行Chrome DevTools Protocol语句 :param cmd: 协议项目 @@ -1024,32 +1032,6 @@ class ChromiumBaseWaiter(object): sleep(gap) return False - def set_targets(self, targets, is_regex=False): - """指定要等待的数据包 - :param targets: 要匹配的数据包url特征,可用list等传入多个 - :param is_regex: 设置的target是否正则表达式 - :return: None - """ - if not self._listener: - self._listener = NetworkListener(self._driver) - self._listener.set_targets(targets, is_regex) - - def data_packets(self, timeout=None, any_one=False): - """等待指定数据包加载完成 - :param timeout: 超时时间,为None则使用页面对象timeout - :param any_one: 多个target时,是否全部监听到才结束,为True时监听到一个目标就结束 - :return: ResponseData对象或监听结果字典 - """ - if not self._listener: - self._listener = NetworkListener(self._driver) - return self._listener.listen(timeout, any_one) - - def stop_listening(self): - """停止监听数据包""" - if not self._listener: - self._listener = NetworkListener(self._driver) - self._listener.stop() - class NetworkListener(object): def __init__(self, page): @@ -1058,39 +1040,73 @@ class NetworkListener(object): self._is_regex = False self._results = {} self._single = False + self._method = None self._requests = {} - def set_targets(self, targets, is_regex=False): + self.is_listening = False + self._count = None + self._caught = 0 # 已获取到的数量 + self._driver = self._page.driver + + def set_targets(self, targets=None, is_regex=False, count=None, method=None): """指定要等待的数据包 - :param targets: 要匹配的数据包url特征,可用list等传入多个 + :param targets: 要匹配的数据包url特征,可用list等传入多个,为None时获取所有 :param is_regex: 设置的target是否正则表达式 + :param count: 设置总共等待多少个数据包,为None时每个目标等待1个 + :param method: 设置监听的请求类型,可用list等指定多个,为None时监听全部 :return: None """ - if not isinstance(targets, (str, list, tuple, set)): - raise TypeError('targets只能是str、list、tuple、set。') + if not isinstance(targets, (str, list, tuple, set)) and targets is not None: + raise TypeError('targets只能是str、list、tuple、set、None。') + if targets is None: + targets = '' + self._is_regex = is_regex if isinstance(targets, str): self._targets = {targets} - self._single = True else: self._targets = set(targets) - self._single = False - self._page.run_cdp('Network.enable') - if targets is not None: - self._page.driver.Network.requestWillBeSent = self._requestWillBeSent - self._page.driver.Network.responseReceived = self._response_received - self._page.driver.Network.loadingFinished = self._loading_finished - else: - self.stop() + + self._count = len(self._targets) if not count else count + self._single = self._count == 1 + if method is not None: + if isinstance(method, str): + self._method = {method.upper()} + elif isinstance(method, (list, tuple, set)): + self._method = set(i.upper() for i in method) + else: + raise TypeError('method参数只能是str、list、tuple、set类型。') + self.start() + + def start(self): + self._driver.set_listener('Network.requestWillBeSent', self._requestWillBeSent) + self._driver.set_listener('Network.responseReceived', self._response_received) + self._driver.set_listener('Network.loadingFinished', self._loading_finished) + self._driver.set_listener('Network.loadingFailed', self._loading_failed) + self._driver.call_method('Network.enable') + self._requests = {} + # self._driver.set_listener('Fetch.requestPaused', self._request_paused) + # self._driver.call_method('Fetch.enable', patterns=[{'requestStage': 'Request'}, {'requestStage': 'Response'}]) def stop(self): """停止监听数据包""" - self._page.run_cdp('Network.disable') - self._page.driver.Network.requestWillBeSent = None - self._page.driver.Network.responseReceived = None - self._page.driver.Network.loadingFinished = None + self._driver.call_method('Network.disable') + self._driver.set_listener('Network.requestWillBeSent', None) + self._driver.set_listener('Network.responseReceived', None) + self._driver.set_listener('Network.loadingFinished', None) + self._driver.set_listener('Network.loadingFailed', None) + # self._driver.call_method('Fetch.disable') + # self._driver.set_listener('Fetch.requestPaused', None) - def listen(self, timeout=None, any_one=False): + def listen(self, timeout=None, any_one=False, asyn=False): + if asyn: + pass + else: + r = self._listen(timeout, any_one) + self._results = {} + return r + + def _listen(self, timeout=None, any_one=False): """等待指定数据包加载完成 :param timeout: 超时时间,为None则使用页面对象timeout :param any_one: 多个target时,是否全部监听到才结束,为True时监听到一个目标就结束 @@ -1099,24 +1115,45 @@ class NetworkListener(object): if self._targets is None: raise RuntimeError('必须先用set_targets()设置等待目标。') + self.is_listening = True timeout = timeout if timeout is not None else self._page.timeout end_time = perf_counter() + timeout - while perf_counter() < end_time: - if self._results and (any_one or set(self._results) == self._targets): - break + while perf_counter() < end_time and not ((any_one and self._caught) or self._caught >= self._count): sleep(.1) self._requests = {} - if not self._results: - return False - r = list(self._results.values())[0] if self._single else self._results + self.is_listening = False + return self.results() + + @property + def results(self): + """返沪监听到的数据""" + return list(self._results.values())[0][0] if self._results and self._single else self._results + + def clear(self): + """清空已监听到的数据""" self._results = {} - return r + + def _requestWillBeSent(self, **kwargs): + """接收到请求时的回调函数""" + for target in self._targets: + if ((self._is_regex and search(target, kwargs['request']['url'])) or + (not self._is_regex and target in kwargs['request']['url'])) and ( + not self._method or kwargs['request']['method'] in self._method): + self._requests[kwargs['requestId']] = DataPacket(self._page.tab_id, target, kwargs) + + if kwargs['request'].get('hasPostData', None) and not kwargs['request'].get('postData', None): + self._requests[kwargs['requestId']]._raw_post_data = \ + self._page.run_cdp('Network.getRequestPostData', requestId=kwargs['requestId'])['postData'] + + break def _response_received(self, **kwargs): """接收到返回信息时处理方法""" - if kwargs['requestId'] in self._requests: - self._requests[kwargs['requestId']]['response'] = kwargs['response'] + request_id = kwargs['requestId'] + if request_id in self._requests: + self._requests[request_id]._raw_response = kwargs['response'] + self._requests[request_id]._resource_type = kwargs['type'] def _loading_finished(self, **kwargs): """请求完成时处理方法""" @@ -1130,25 +1167,61 @@ class NetworkListener(object): body = '' is_base64 = False - request = self._requests[request_id] - target = request['target'] - rd = ResponseData(request_id, request['response'], body, self._page.tab_id, target) - rd.method = request['method'] - rd.postData = request['post_data'] - rd._base64_body = is_base64 - rd.requestHeaders = request['request_headers'] - self._results[target] = rd + dp = self._requests[request_id] + target = dp.target + dp._raw_body = body + dp._base64_body = is_base64 - def _requestWillBeSent(self, **kwargs): - """接收到请求时的回调函数""" - for target in self._targets: - if (self._is_regex and search(target, kwargs['request']['url'])) or ( - not self._is_regex and target in kwargs['request']['url']): - self._requests[kwargs['requestId']] = {'target': target, - 'method': kwargs['request']['method'], - 'post_data': kwargs['request'].get('postData', None), - 'request_headers': kwargs['request']['headers']} - break + if target in self._results: + self._results[target].append(dp) + else: + self._results[target] = [dp] + + self._caught += 1 + + def _loading_failed(self, **kwargs): + """请求失败时的回调方法""" + request_id = kwargs['requestId'] + if request_id in self._requests: + dp = self._requests[request_id] + target = dp.target + dp.errorText = kwargs['errorText'] + dp._resource_type = kwargs['type'] + + if target in self._results: + self._results[target].append(dp) + else: + self._results[target] = [dp] + + self._caught += 1 + + def _request_paused(self, **kwargs): + i = kwargs['requestId'] + if 'networkId' not in kwargs: + pass + # for target in self._targets: + # if (self._is_regex and search(target, kwargs['request']['url'])) or ( + # not self._is_regex and target in kwargs['request']['url']): + # dp = DataPacket(self._page.tab_id, target, kwargs) + # body = self._driver.call_method('Fetch.getResponseBody', requestId=i) + # dp._raw_body = body['body'] + # dp._base64_body = body['base64Encoded'] + # if 'networkId' in kwargs and kwargs['request'].get('hasPostData', None) \ + # and not kwargs['request'].get('postData', None): + # pd = self._driver.call_method('Network.getRequestPostData', requestId=kwargs['networkId']) + # if 'postData' in pd: + # dp._raw_post_data = pd['postData'] + # + # if target in self._results: + # self._results[target].append(dp) + # else: + # self._results[target] = [dp] + # + # self._caught += 1 + # break + + method = 'Request' if 'responseStatusCode' not in kwargs else 'Response' + self._driver.call_method(f'Fetch.continue{method}', requestId=i) class ChromiumPageScroll(ChromiumScroll): diff --git a/DrissionPage/chromium_base.pyi b/DrissionPage/chromium_base.pyi index 9638dc8..3511318 100644 --- a/DrissionPage/chromium_base.pyi +++ b/DrissionPage/chromium_base.pyi @@ -15,7 +15,7 @@ from .chromium_driver import ChromiumDriver from .chromium_element import ChromiumElement, ChromiumScroll from .chromium_frame import ChromiumFrame from .commons.constants import NoneElement -from .commons.web import ResponseData +from .commons.web import DataPacket from .session_element import SessionElement @@ -42,6 +42,7 @@ class ChromiumBase(BasePage): self._wait: ChromiumBaseWaiter = ... self._set: ChromiumBaseSetter = ... self._screencast: Screencast = ... + self._listener: NetworkListener = ... def _connect_browser(self, tab_id: str = None) -> None: ... @@ -129,37 +130,33 @@ class ChromiumBase(BasePage): @property def screencast(self) -> Screencast: ... + @property + def listener(self) -> NetworkListener: ... + def run_js(self, script: str, *args: Any, as_expr: bool = False) -> Any: ... def run_js_loaded(self, script: str, *args: Any, as_expr: bool = False) -> Any: ... def run_async_js(self, script: str, *args: Any, as_expr: bool = False) -> None: ... - def get(self, - url: str, - show_errmsg: bool = False, - retry: int = None, - interval: float = None, - timeout: float = None) -> Union[None, bool]: ... + def get(self, url: str, show_errmsg: bool = False, retry: int = None, + interval: float = None, timeout: float = None) -> Union[None, bool]: ... - def get_cookies(self, as_dict: bool = False, all_domains: bool = False, all_info: bool = False) -> Union[ - list, dict]: ... + def get_cookies(self, as_dict: bool = False, all_domains: bool = False, + all_info: bool = False) -> Union[list, dict]: ... - def ele(self, - loc_or_ele: Union[Tuple[str, str], str, ChromiumElement, ChromiumFrame], - timeout: float = None) -> ChromiumElement: ... + def ele(self, loc_or_ele: Union[Tuple[str, str], str, ChromiumElement, ChromiumFrame], + timeout: float = None) -> Union[ChromiumElement, str]: ... - def eles(self, - loc_or_str: Union[Tuple[str, str], str], - timeout: float = None) -> List[ChromiumElement]: ... + def eles(self, loc_or_str: Union[Tuple[str, str], str], + timeout: float = None) -> List[Union[ChromiumElement, str]]: ... def s_ele(self, loc_or_ele: Union[Tuple[str, str], str] = None) \ -> Union[SessionElement, str, NoneElement]: ... def s_eles(self, loc_or_str: Union[Tuple[str, str], str]) -> List[Union[SessionElement, str]]: ... - def _find_elements(self, - loc_or_ele: Union[Tuple[str, str], str, ChromiumElement, ChromiumFrame], + def _find_elements(self, loc_or_ele: Union[Tuple[str, str], str, ChromiumElement, ChromiumFrame], timeout: float = None, single: bool = True, relative: bool = False, raise_err: bool = None) \ -> Union[ChromiumElement, ChromiumFrame, NoneElement, List[Union[ChromiumElement, ChromiumFrame]]]: ... @@ -231,37 +228,50 @@ class ChromiumBaseWaiter(object): def load_complete(self, timeout: float = None) -> bool: ... - def set_targets(self, targets: Union[str, list, tuple, set], is_regex: bool = False) -> None: ... - - def stop_listening(self) -> None: ... - - def data_packets(self, timeout: float = None, - any_one: bool = False) -> Union[ResponseData, Dict[str, ResponseData], False]: ... - def upload_paths_inputted(self) -> None: ... class NetworkListener(object): - def __init__(self, page): + def __init__(self, page: ChromiumBase): self._page: ChromiumBase = ... + self._count: int = ... + self._caught: int = ... self._targets: Union[str, dict] = ... self._single: bool = ... - self._results: Union[ResponseData, Dict[str, ResponseData], False] = ... + self._method: set = ... + self._results: Union[DataPacket, Dict[str, List[DataPacket]], False] = ... self._is_regex: bool = ... + self._driver: ChromiumDriver = ... self._requests: dict = ... + self.is_listening: bool = ... - def set_targets(self, targets: Union[str, list, tuple, set], is_regex: bool = False) -> None: ... + def set_targets(self, targets: Union[str, list, tuple, set, None] = None, is_regex: bool = False, + count: int = None, method: Union[str, list, tuple, set] = None) -> None: ... + + def start(self) -> None: ... def stop(self) -> None: ... - def listen(self, timeout: float = None, - any_one: bool = False) -> Union[ResponseData, Dict[str, ResponseData], False]: ... + @property + def results(self) -> Union[DataPacket, Dict[str, List[DataPacket]], False]: ... + + def clear(self) -> None: ... + + def listen(self, timeout: float = None, any_one: bool = False, + asyn: bool = False) -> Union[DataPacket, Dict[str, List[DataPacket]], False]: ... + + def _listen(self, timeout: float = None, + any_one: bool = False) -> Union[DataPacket, Dict[str, List[DataPacket]], False]: ... + + def _requestWillBeSent(self, **kwargs) -> None: ... def _response_received(self, **kwargs) -> None: ... def _loading_finished(self, **kwargs) -> None: ... - def _requestWillBeSent(self, **kwargs) -> None: ... + def _loading_failed(self, **kwargs) -> None: ... + + def _request_paused(self, **kwargs) -> None: ... class ChromiumPageScroll(ChromiumScroll): @@ -366,4 +376,4 @@ class ScreencastMode(object): def frugal_imgs_mode(self) -> None: ... - def imgs_mode(self) -> None: ... \ No newline at end of file + def imgs_mode(self) -> None: ... diff --git a/DrissionPage/commons/web.py b/DrissionPage/commons/web.py index 0a7cd14..cd5b16e 100644 --- a/DrissionPage/commons/web.py +++ b/DrissionPage/commons/web.py @@ -6,7 +6,7 @@ from base64 import b64decode from html import unescape from http.cookiejar import Cookie -from json import loads, JSONDecodeError +from json import JSONDecodeError, loads from re import sub from urllib.parse import urlparse, urljoin, urlunparse @@ -15,87 +15,123 @@ from requests.structures import CaseInsensitiveDict from tldextract import extract -class ResponseData(object): +class DataPacket(object): """返回的数据包管理类""" - __slots__ = ('requestId', 'response', 'rawBody', 'tab', 'target', 'url', 'status', 'statusText', 'securityDetails', - 'headersText', 'mimeType', 'requestHeadersText', 'connectionReused', 'connectionId', 'remoteIPAddress', - 'remotePort', 'fromDiskCache', 'fromServiceWorker', 'fromPrefetchCache', 'encodedDataLength', 'timing', - 'serviceWorkerResponseSource', 'responseTime', 'cacheStorageCacheName', 'protocol', 'securityState', - '_requestHeaders', '_body', '_base64_body', '_rawPostData', '_postData', 'method') - def __init__(self, request_id, response, body, tab, target): + def __init__(self, tab, target, raw_request): """ - :param response: response的数据 - :param body: response包含的内容 :param tab: 产生这个数据包的tab的id :param target: 监听目标 + :param raw_request: 原始request数据,从cdp获得 """ - self.requestId = request_id - self.response = CaseInsensitiveDict(response) - self.rawBody = body self.tab = tab self.target = target - self._requestHeaders = None - self._postData = None - self._body = None + + self._raw_request = raw_request + self._raw_post_data = None + + self._raw_response = None + self._raw_body = None self._base64_body = False - self._rawPostData = None + + self._request = None + self._response = None + self.errorText = None + self._resource_type = None + + @property + def url(self): + return self.request.url + + @property + def method(self): + return self.request.method + + @property + def frameId(self): + return self._raw_request.get('frameId') + + @property + def resourceType(self): + return self._resource_type + + @property + def request(self): + if self._request is None: + self._request = Request(self._raw_request['request'], self._raw_post_data) + return self._request + + @property + def response(self): + if self._response is None: + self._response = Response(self._raw_response, self._raw_body, self._base64_body) + return self._response + + +class Request(object): + def __init__(self, raw_request, post_data): + self._request = raw_request + self._raw_post_data = post_data + self._postData = None + self._headers = None def __getattr__(self, item): - return self.response.get(item, None) - - def __getitem__(self, item): - return self.response.get(item, None) - - def __repr__(self): - return f'' + return self._request.get(item, None) @property def headers(self): """以大小写不敏感字典返回headers数据""" - headers = self.response.get('headers', None) - return CaseInsensitiveDict(headers) if headers else None - - @property - def requestHeaders(self): - """以大小写不敏感字典返回requestHeaders数据""" - if self._requestHeaders: - return self._requestHeaders - headers = self.response.get('requestHeaders', None) - return CaseInsensitiveDict(headers) if headers else None - - @requestHeaders.setter - def requestHeaders(self, val): - """设置requestHeaders""" - self._requestHeaders = val + if self._headers is None: + self._headers = CaseInsensitiveDict(self._request['headers']) + return self._headers @property def postData(self): """返回postData数据""" - if self._postData is None and self._rawPostData: + if self._postData is None: + if self._raw_post_data: + postData = self._raw_post_data + elif self._request.get('postData', None): + postData = self._request['postData'] + else: + postData = False try: - self._postData = loads(self._rawPostData) + self._postData = loads(postData) except (JSONDecodeError, TypeError): - self._postData = self._rawPostData + self._postData = postData return self._postData - @postData.setter - def postData(self, val): - """设置postData""" - self._rawPostData = val + +class Response(object): + def __init__(self, raw_response, raw_body, base64_body): + self._response = raw_response + self._raw_body = raw_body + self._is_base64_body = base64_body + self._body = None + self._headers = None + + def __getattr__(self, item): + return self._response.get(item, None) + + @property + def headers(self): + """以大小写不敏感字典返回headers数据""" + if self._headers is None: + self._headers = CaseInsensitiveDict(self._response['headers']) + return self._headers @property def body(self): """返回body内容,如果是json格式,自动进行转换,如果时图片格式,进行base64转换,其它格式直接返回文本""" if self._body is None: - if self._base64_body: - self._body = b64decode(self.rawBody) + if self._is_base64_body: + self._body = b64decode(self._raw_body) else: try: - self._body = loads(self.rawBody) + self._body = loads(self._raw_body) except (JSONDecodeError, TypeError): - self._body = self.rawBody + self._body = self._raw_body return self._body diff --git a/DrissionPage/commons/web.pyi b/DrissionPage/commons/web.pyi index 115fd08..9c87f1a 100644 --- a/DrissionPage/commons/web.pyi +++ b/DrissionPage/commons/web.pyi @@ -15,64 +15,84 @@ from DrissionPage.chromium_element import ChromiumElement from DrissionPage.chromium_base import ChromiumBase -class ResponseData(object): +class DataPacket(object): + """返回的数据包管理类""" - def __init__(self, request_id: str, response: dict, body: str, tab: str, target: str): - self.requestId: str = ... - self.response: CaseInsensitiveDict = ... - self.rawBody: str = ... - self._body: Union[str, dict, bytes] = ... - self._base64_body: bool = ... + def __init__(self, tab: str, target: str, raw_info: dict): self.tab: str = ... self.target: str = ... - self.method: str = ... - self._postData: dict = ... - self._rawPostData: str = ... - self.url: str = ... - self.status: str = ... - self.statusText: str = ... - self.headersText: str = ... - self.mimeType: str = ... - self.requestHeadersText: str = ... - self.connectionReused: str = ... - self.connectionId: str = ... - self.remoteIPAddress: str = ... - self.remotePort: str = ... - self.fromDiskCache: str = ... - self.fromServiceWorker: str = ... - self.fromPrefetchCache: str = ... - self.encodedDataLength: str = ... - self.timing: str = ... - self.serviceWorkerResponseSource: str = ... - self.responseTime: str = ... - self.cacheStorageCacheName: str = ... - self.protocol: str = ... - self.securityState: str = ... - self.securityDetails: str = ... - - def __getattr__(self, item: str) -> Union[str, None]: ... - - def __getitem__(self, item: str) -> Union[str, None]: ... - - def __repr__(self) -> str: ... + self._raw_request: dict = ... + self._raw_response: dict = ... + self._raw_post_data: str = ... + self._raw_body: str = ... + self._base64_body: bool = ... + self._request: Request = ... + self._response: Response = ... + self.errorText: str = ... + self._resource_type: str = ... @property - def headers(self) -> Union[CaseInsensitiveDict, None]: ... + def url(self) -> str: ... @property - def requestHeaders(self) -> Union[CaseInsensitiveDict, None]: ... - - @requestHeaders.setter - def requestHeaders(self, val: dict) -> None: ... + def method(self) -> str: ... @property - def postData(self) -> Union[dict, str, None]: ... - - @postData.setter - def postData(self, val: Union[str, dict]) -> None: ... + def frameId(self) -> str: ... @property - def body(self) -> Union[str, dict, bytes]: ... + def resourceType(self) -> str: ... + + @property + def request(self) -> Request: ... + + @property + def response(self) -> Response: ... + + +class Request(object): + url: str = ... + _headers: Union[CaseInsensitiveDict, None] = ... + method: str = ... + + # urlFragment: str = ... + # postDataEntries: list = ... + # mixedContentType: str = ... + # initialPriority: str = ... + # referrerPolicy: str = ... + # isLinkPreload: bool = ... + # trustTokenParams: dict = ... + # isSameSite: bool = ... + + def __init__(self, raw_request: dict, post_data: str): + self._request: dict = ... + self._raw_post_data: str = ... + self._postData: str = ... + + @property + def headers(self) -> dict: ... + + @property + def postData(self) -> Union[str, dict]: ... + + +class Response(object): + status: str = ... + statusText: int = ... + mimeType: str = ... + + def __init__(self, raw_response: dict, raw_body: str, base64_body: bool): + self._response: dict = ... + self._raw_body: str = ... + self._is_base64_body: bool = ... + self._body: Union[str, dict] = ... + self._headers: dict = ... + + @property + def headers(self) -> CaseInsensitiveDict: ... + + @property + def body(self) -> Union[str, dict, bool]: ... def get_ele_txt(e: DrissionElement) -> str: ...