diff --git a/DrissionPage/chromium_element.py b/DrissionPage/chromium_element.py index 395622b..634dde2 100644 --- a/DrissionPage/chromium_element.py +++ b/DrissionPage/chromium_element.py @@ -9,19 +9,19 @@ from os import sep from os.path import basename from pathlib import Path from re import search -from typing import Union, Tuple, List, Any from time import perf_counter, sleep +from typing import Union, Tuple, List, Any from urllib.parse import urlparse -from pychrome import Tab from requests import Session from requests.cookies import RequestsCookieJar +from .base import DrissionElement, BaseElement, BasePage +from .common import make_absolute_link, get_loc, get_ele_txt, format_html, is_js_func, _location_in_viewport from .config import DriverOptions, _cookies_to_tuple from .keys import _keys_to_typing, _keyDescriptionForString, _keyDefinitions from .session_element import make_session_ele, SessionElement -from .base import DrissionElement, BaseElement, BasePage -from .common import make_absolute_link, get_loc, get_ele_txt, format_html, is_js_func, _location_in_viewport +from .tab import Tab class ChromiumElement(DrissionElement): diff --git a/DrissionPage/chromium_page.py b/DrissionPage/chromium_page.py index 70c2c81..2ad9b1f 100644 --- a/DrissionPage/chromium_page.py +++ b/DrissionPage/chromium_page.py @@ -6,13 +6,13 @@ from re import search from time import perf_counter, sleep from typing import Union, Tuple, List -from pychrome import Tab from requests import Session from .chromium_element import Timeout, ChromiumBase from .chromium_tab import ChromiumTab from .config import DriverOptions from .drission import connect_chrome +from .tab import Tab class ChromiumPage(ChromiumBase): @@ -129,8 +129,12 @@ class ChromiumPage(ChromiumBase): self._window_setter = WindowSizeSetter(self) return self._window_setter - def get_tab(self, tab_id: str) -> ChromiumTab: - """获取一个标签页对象""" + def get_tab(self, tab_id: str = None) -> ChromiumTab: + """获取一个标签页对象 \n + :param 
# Module-level logger; used to report exceptions raised inside event callbacks.
logger = getLogger(__name__)


class GenericAttr(object):
    """Proxy for one protocol domain (e.g. ``Page``) of a :class:`Tab`.

    Reading ``domain.name`` returns the listener registered for
    ``"<domain>.<name>"`` when one exists, otherwise a callable bound to
    ``Tab.call_method`` for that method name. Assigning
    ``domain.name = callback`` forwards to ``Tab.set_listener``.
    """

    def __init__(self, name, tab):
        # Write through __dict__ because __setattr__ is overridden below and
        # would otherwise try to register these values as event listeners.
        self.__dict__['name'] = name
        self.__dict__['tab'] = tab

    def __getattr__(self, item):
        """Return the registered listener for ``name.item`` or a method caller."""
        method_name = "%s.%s" % (self.name, item)
        event_listener = self.tab.get_listener(method_name)

        if event_listener:
            return event_listener

        return partial(self.tab.call_method, method_name)

    def __setattr__(self, key, value):
        """Register ``value`` as the listener for the event ``name.key``."""
        self.tab.set_listener("%s.%s" % (self.name, key), value)


class Tab(object):
    """Websocket client for a single browser tab.

    Messages are sent as JSON over the tab's ``webSocketDebuggerUrl``. One
    daemon thread receives messages and routes them either to the queue of the
    pending call with the same ``id`` or to the event queue; a second daemon
    thread pops events and invokes the listener registered for their method.

    Any attribute that is not otherwise defined is resolved to a
    :class:`GenericAttr`, so ``tab.Page.enable()`` calls ``"Page.enable"`` and
    ``tab.Page.loadEventFired = cb`` registers an event listener.
    """

    status_initial = 'initial'
    status_started = 'started'
    status_stopped = 'stopped'

    def __init__(self, **kwargs):
        """Store the tab description; no connection is opened until start().

        :param kwargs: tab description; the keys ``id``, ``type`` and
            ``webSocketDebuggerUrl`` are read, the full mapping is kept as-is.
        """
        self.id = kwargs.get("id")
        self.type = kwargs.get("type")
        # getenv returns the raw string, so any non-empty DEBUG value enables it.
        self.debug = getenv("DEBUG", False)

        self._websocket_url = kwargs.get("webSocketDebuggerUrl")
        self._kwargs = kwargs

        # Counter used to assign an id to outgoing messages that lack one.
        self._cur_id = 1000

        self._ws = None

        self._recv_th = Thread(target=self._recv_loop)
        self._recv_th.daemon = True
        self._handle_event_th = Thread(target=self._handle_event_loop)
        self._handle_event_th.daemon = True

        self._stopped = Event()
        self._started = False
        self.status = self.status_initial

        self.event_handlers = {}   # event name -> callback
        self.method_results = {}   # message id -> queue.Queue holding the reply
        self.event_queue = queue.Queue()

    def _send(self, message, timeout=None):
        """Send ``message`` and block until its reply arrives.

        :param message: dict to serialise; an ``id`` is assigned if missing.
        :param timeout: seconds to wait, or a non-number to wait until stopped.
        :return: the reply message (dict) delivered by the receive thread.
        :raises TimeoutException: no reply arrived within ``timeout``.
        :raises UserAbortException: the tab was stopped while waiting.
        """
        if 'id' not in message:
            self._cur_id += 1
            message['id'] = self._cur_id

        message_json = dumps(message)

        if self.debug:  # pragma: no cover
            print("SEND > %s" % message_json)

        # Poll in slices of at most one second so a stop() is noticed promptly.
        if not isinstance(timeout, (int, float)) or timeout > 1:
            q_timeout = 1
        else:
            q_timeout = timeout / 2.0

        try:
            self.method_results[message['id']] = queue.Queue()

            # Errors from the websocket are deliberately propagated to the caller.
            self._ws.send(message_json)

            while not self._stopped.is_set():
                try:
                    if isinstance(timeout, (int, float)):
                        # Never wait longer than the time that is left.
                        if timeout < q_timeout:
                            q_timeout = timeout

                        timeout -= q_timeout

                    return self.method_results[message['id']].get(timeout=q_timeout)
                except queue.Empty:
                    if isinstance(timeout, (int, float)) and timeout <= 0:
                        raise TimeoutException("Calling %s timeout" % message['method'])

                    continue

            raise UserAbortException("User abort, call stop() when calling %s" % message['method'])
        finally:
            # Always drop the per-call queue, whether we returned or raised.
            self.method_results.pop(message['id'], None)

    def _recv_loop(self):
        """Receive messages until stopped and route them to the right queue."""
        while not self._stopped.is_set():
            try:
                # Short timeout so the loop re-checks the stop flag regularly.
                self._ws.settimeout(1)
                message_json = self._ws.recv()
                message = loads(message_json)
            except WebSocketTimeoutException:
                continue
            except (WebSocketException, OSError, WebSocketConnectionClosedException):
                # A broken connection ends the tab: flag it stopped so that
                # pending _send() calls and the event thread terminate too.
                if not self._stopped.is_set():
                    self._stopped.set()
                return

            if self.debug:  # pragma: no cover
                print('< RECV %s' % message_json)

            if "method" in message:
                # Messages carrying a method name are events.
                self.event_queue.put(message)

            elif "id" in message:
                # Replies are matched to the waiting call by id; replies whose
                # caller has already given up are dropped.
                if message["id"] in self.method_results:
                    self.method_results[message['id']].put(message)
            else:  # pragma: no cover
                warn("unknown message: %s" % message)

    def _handle_event_loop(self):
        """Pop queued events until stopped and invoke their listeners."""
        while not self._stopped.is_set():
            try:
                event = self.event_queue.get(timeout=1)
            except queue.Empty:
                continue

            if event['method'] in self.event_handlers:
                try:
                    self.event_handlers[event['method']](**event['params'])
                except Exception:
                    # A failing callback must not kill the dispatcher thread.
                    logger.error("callback %s exception", event['method'], exc_info=True)

            self.event_queue.task_done()

    def __getattr__(self, item):
        """Create, cache and return the domain proxy for ``item``.

        Only reached for names that are not real attributes. Dunder names are
        refused so that protocol probes (``hasattr``, ``copy``, ``pickle``)
        see a normal AttributeError instead of a fabricated proxy.
        """
        if item.startswith('__') and item.endswith('__'):
            raise AttributeError(item)

        attr = GenericAttr(item, self)
        setattr(self, item, attr)
        return attr

    def call_method(self, _method, *args, **kwargs):
        """Call the remote method ``_method`` with keyword parameters.

        :param _method: full method name, e.g. ``"Page.navigate"``.
        :param kwargs: method parameters; ``_timeout`` is removed from them
            and used as the reply timeout in seconds.
        :return: the ``result`` member of the reply.
        :raises RuntimeException: the tab is not started or already stopped.
        :raises CallMethodException: positional arguments were given, or the
            reply carries an ``error`` and no ``result``.
        """
        if not self._started:
            raise RuntimeException("Cannot call method before it is started")

        if args:
            raise CallMethodException("the params should be key=value format")

        if self._stopped.is_set():
            raise RuntimeException("Tab has been stopped")

        timeout = kwargs.pop("_timeout", None)
        result = self._send({"method": _method, "params": kwargs}, timeout=timeout)
        if 'result' not in result and 'error' in result:
            warn("%s error: %s" % (_method, result['error']['message']))
            raise CallMethodException("calling method: %s error: %s" % (_method, result['error']['message']))

        return result['result']

    def set_listener(self, event, callback):
        """Register ``callback`` for ``event``; a falsy callback removes it.

        :return: True after registering; when removing, the callback that was
            registered before (or None).
        :raises RuntimeException: ``callback`` is truthy but not callable.
        """
        if not callback:
            return self.event_handlers.pop(event, None)

        if not callable(callback):
            raise RuntimeException("callback should be callable")

        self.event_handlers[event] = callback
        return True

    def get_listener(self, event):
        """Return the callback registered for ``event``, or None."""
        return self.event_handlers.get(event, None)

    def del_all_listeners(self):
        """Remove every registered event callback and return True."""
        self.event_handlers = {}
        return True

    def start(self):
        """Open the websocket and start the worker threads.

        :return: False if the tab was already started, otherwise True.
        :raises RuntimeException: the tab description had no websocket URL.
        """
        if self._started:
            return False

        if not self._websocket_url:
            raise RuntimeException("Already has another client connect to this tab")

        # Connect before touching any state: if the connection fails the tab
        # stays in its initial state and start() can simply be retried.
        self._ws = create_connection(self._websocket_url, enable_multithread=True)

        self._started = True
        self.status = self.status_started
        self._stopped.clear()
        self._recv_th.start()
        self._handle_event_th.start()
        return True

    def stop(self):
        """Flag the tab as stopped and close the websocket.

        :return: False if the tab was already stopped, otherwise True.
        :raises RuntimeException: the tab was never started.
        """
        if self._stopped.is_set():
            return False

        if not self._started:
            raise RuntimeException("Tab is not running")

        self.status = self.status_stopped
        self._stopped.set()
        if self._ws:
            self._ws.close()
        return True

    def wait(self, timeout=None):
        """Block until the tab is stopped.

        :param timeout: seconds to wait; falsy means wait for both worker
            threads to finish.
        :return: with a timeout, whether the tab stopped in time; else True.
        :raises RuntimeException: the tab was never started.
        """
        if not self._started:
            raise RuntimeException("Tab is not running")

        if timeout:
            return self._stopped.wait(timeout)

        self._recv_th.join()
        self._handle_event_th.join()
        return True

    def __str__(self):
        """Return a short description containing the tab id."""
        return "<Tab [%s]>" % self.id

    __repr__ = __str__


class PyChromeException(Exception):
    """Base class of every exception defined for :class:`Tab`."""


class UserAbortException(PyChromeException):
    """Raised by a pending call when the tab is stopped before the reply."""


class TabConnectionException(PyChromeException):
    """Connection-related error; not raised by the code visible here."""


class CallMethodException(PyChromeException):
    """Raised when a method call is malformed or the reply is an error."""


class TimeoutException(PyChromeException):
    """Raised when no reply to a call arrives within the given timeout."""


class RuntimeException(PyChromeException):
    """Raised when the tab is used in a state that does not allow the call."""
setup( name="DrissionPage", - version="3.0.8", + version="3.0.9", author="g1879", author_email="g1879@qq.com", description="A module that integrates selenium and requests session, encapsulates common page operations.", @@ -24,7 +24,7 @@ setup( "requests", "DownloadKit", "FlowViewer", - "pychrome" + "websocket-client" ], classifiers=[ "Programming Language :: Python :: 3.6",