diff --git a/DrissionPage/chromium_base.py b/DrissionPage/chromium_base.py index ff55aee..1ddc43d 100644 --- a/DrissionPage/chromium_base.py +++ b/DrissionPage/chromium_base.py @@ -1024,15 +1024,17 @@ class ChromiumBaseWaiter(object): sleep(gap) return False - def set_targets(self, targets, is_regex=False): + def set_targets(self, targets=None, is_regex=False, count=None): """指定要等待的数据包 - :param targets: 要匹配的数据包url特征,可用list等传入多个 + :param targets: 要匹配的数据包url特征,可用list等传入多个,为None时获取所有 :param is_regex: 设置的target是否正则表达式 + :param count: 设置总共等待多少个数据包,为None时每个目标等待1个 :return: None """ if not self._listener: self._listener = NetworkListener(self._driver) - self._listener.set_targets(targets, is_regex) + self._listener.set_targets(targets, is_regex, count=count) + self._listener.start() def data_packets(self, timeout=None, any_one=False): """等待指定数据包加载完成 @@ -1064,25 +1066,26 @@ class NetworkListener(object): self._caught = 0 # 已获取到的数量 self._driver = self._page.driver - def set_targets(self, targets, is_regex=False, count=None): + def set_targets(self, targets=None, is_regex=False, count=None): """指定要等待的数据包 - :param targets: 要匹配的数据包url特征,可用list等传入多个 + :param targets: 要匹配的数据包url特征,可用list等传入多个,为None时获取所有 :param is_regex: 设置的target是否正则表达式 :param count: 设置总共等待多少个数据包,为None时每个目标等待1个 :return: None """ - if not isinstance(targets, (str, list, tuple, set)): - raise TypeError('targets只能是str、list、tuple、set。') + if not isinstance(targets, (str, list, tuple, set)) and targets is not None: + raise TypeError('targets只能是str、list、tuple、set、None。') + if targets is None: + targets = '' self._is_regex = is_regex if isinstance(targets, str): self._targets = {targets} - self._single = True else: self._targets = set(targets) - self._single = False if count is None: self._count = len(self._targets) + self._single = self._count == 1 def start(self): self._driver.set_listener('Fetch.requestPaused', self._request_paused) @@ -1120,14 +1123,27 @@ class NetworkListener(object): self._requests = {} if not self._results: return False - r = list(self._results.values())[0] if self._single else self._results + r = list(self._results.values())[0][0] if self._single else self._results self._results = {} return r + def _requestWillBeSent(self, **kwargs): + """接收到请求时的回调函数""" + for target in self._targets: + if (self._is_regex and search(target, kwargs['request']['url'])) or ( + not self._is_regex and target in kwargs['request']['url']): + self._requests[kwargs['requestId']] = DataPacket(self._page.tab_id, target, kwargs) + + if kwargs['request'].get('hasPostData', None) and not kwargs['request'].get('postData', None): + self._requests[kwargs['requestId']]._raw_post_data = \ + self._page.run_cdp('Network.getRequestPostData', requestId=kwargs['requestId'])['postData'] + + break + def _response_received(self, **kwargs): """接收到返回信息时处理方法""" if kwargs['requestId'] in self._requests: - self._requests[kwargs['requestId']]['response'] = kwargs['response'] + self._requests[kwargs['requestId']]._raw_response = kwargs['response'] def _loading_finished(self, **kwargs): """请求完成时处理方法""" @@ -1141,23 +1157,17 @@ class NetworkListener(object): body = '' is_base64 = False - request = self._requests[request_id] - target = request['target'] - rd = ResponseData(request_id, request['response'], body, self._page.tab_id, target) - rd.postData = request['post_data'] - rd._base64_body = is_base64 - rd.requestHeaders = request['request_headers'] - self._results[target] = rd + dp = self._requests[request_id] + target = dp.target + dp._raw_body = body + dp._base64_body = is_base64 - def _requestWillBeSent(self, **kwargs): - """接收到请求时的回调函数""" - for target in self._targets: - if (self._is_regex and search(target, kwargs['request']['url'])) or ( - not self._is_regex and target in kwargs['request']['url']): - self._requests[kwargs['requestId']] = {'target': target, - 'post_data': kwargs['request'].get('postData', None), - 'request_headers': kwargs['request']['headers']} - break + if target in self._results: + self._results[target].append(dp) + else: + self._results[target] = [dp] + + self._caught += 1 def _request_paused(self, **kwargs): i = kwargs['requestId'] diff --git a/DrissionPage/chromium_base.pyi b/DrissionPage/chromium_base.pyi index 809617d..97ed3bd 100644 --- a/DrissionPage/chromium_base.pyi +++ b/DrissionPage/chromium_base.pyi @@ -226,7 +226,8 @@ class ChromiumBaseWaiter(object): def upload_paths_inputted(self) -> None: ... - def set_targets(self, targets: Union[str, list, tuple, set], is_regex: bool = False) -> None: ... + def set_targets(self, targets: Union[str, list, tuple, set, None] = None, is_regex: bool = False, + count: int = None) -> None: ... def stop_listening(self) -> None: ... @@ -246,7 +247,8 @@ class NetworkListener(object): self._driver: ChromiumDriver = ... self._requests: dict = ... - def set_targets(self, targets: Union[str, list, tuple, set], is_regex: bool = False, count: int = None) -> None: ... + def set_targets(self, targets: Union[str, list, tuple, set, None] = None, is_regex: bool = False, + count: int = None) -> None: ... def start(self) -> None: ... diff --git a/DrissionPage/chromium_page.pyi b/DrissionPage/chromium_page.pyi index f05a3bd..60315f1 100644 --- a/DrissionPage/chromium_page.pyi +++ b/DrissionPage/chromium_page.pyi @@ -118,13 +118,6 @@ class ChromiumPageWaiter(ChromiumBaseWaiter): def new_tab(self, timeout: float = None) -> bool: ... - def set_targets(self, targets: Union[str, list, tuple, set], is_regex: bool = False) -> None: ... - - def stop_listening(self) -> None: ... - - def data_packets(self, timeout: float = None, - any_one: bool = False) -> Union[DataPacket, Dict[str, List[DataPacket]], False]: ... - class ChromiumTabRect(object): def __init__(self, page: ChromiumPage): diff --git a/DrissionPage/commons/web.py b/DrissionPage/commons/web.py index c782f88..771c98f 100644 --- a/DrissionPage/commons/web.py +++ b/DrissionPage/commons/web.py @@ -18,30 +18,31 @@ from tldextract import extract class DataPacket(object): """返回的数据包管理类""" - def __init__(self, tab, target, raw_info): + def __init__(self, tab, target, raw_request): """ :param tab: 产生这个数据包的tab的id :param target: 监听目标 - :param raw_info: 原始request数据,从cdp获得 + :param raw_request: 原始request数据,从cdp获得 """ self.tab = tab self.target = target - self._raw_info = raw_info - self._raw_post_data = None + self._raw_request = raw_request + self._raw_response = None + self._raw_post_data = None self._raw_body = None self._base64_body = False self._request = None self._response = None - def __repr__(self): - return f'' - - @property - def requestId(self): - return self._raw_info['requestId'] + # def __repr__(self): + # return f'' + # + # @property + # def requestId(self): + # return self._raw_info['requestId'] @property def url(self): @@ -53,28 +54,28 @@ class DataPacket(object): @property def frameId(self): - return self._raw_info['frameId'] + return self._raw_request['frameId'] - @property - def resourceType(self): - return self._raw_info['resourceType'] + # @property + # def resourceType(self): + # return self._raw_request['resourceType'] @property def request(self): if self._request is None: - self._request = Request(self._raw_info['request'], self._raw_post_data) + self._request = Request(self._raw_request['request'], self._raw_post_data) return self._request - @property - def response(self): - if self._response is None: - self._response = Response(self._raw_info, self._raw_body, self._base64_body) - return self._response + # @property + # def response(self): + # if self._response is None: + # self._response = Response(self._raw_info, self._raw_body, self._base64_body) + # return self._response class Request(object): __slots__ = ('url', 'urlFragment', 'postDataEntries', 'mixedContentType', 'initialPriority', - 'referrerPolicy', 'isLinkPreload', 'trustTokenParams', 'isSameSite', + 'referrerPolicy', 'isLinkPreload', 'trustTokenParams', 'isSameSite', 'method', '_request', '_raw_post_data', '_postData') def __init__(self, raw_request, post_data): diff --git a/DrissionPage/commons/web.pyi b/DrissionPage/commons/web.pyi index 606ab87..2598d90 100644 --- a/DrissionPage/commons/web.pyi +++ b/DrissionPage/commons/web.pyi @@ -21,7 +21,7 @@ class DataPacket(object): def __init__(self, tab: str, target: str, raw_info: dict): self.tab: str = ... self.target: str = ... - self._raw_info: dict = ... + self._raw_request: dict = ... self._raw_post_data: str = ... self._raw_body: str = ... self._base64_body: bool = ... @@ -55,6 +55,7 @@ class DataPacket(object): class Request(object): url: str = ... urlFragment: str = ... + method:str = ... postDataEntries: list = ... mixedContentType: str = ... initialPriority: str = ...