From 9239cadd8b3e539a7d5da0eb22b2b047417fb426 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Sun, 19 Nov 2023 05:36:04 +0100 Subject: Add Response Handler to OpenaiChat Update Providers with WebDriver Add WebDriverSession helper Use native streaming in curl_cffi --- g4f/Provider/AItianhuSpace.py | 58 ++++--- g4f/Provider/MyShell.py | 34 ++--- g4f/Provider/PerplexityAi.py | 62 ++++---- g4f/Provider/Phind.py | 15 +- g4f/Provider/TalkAi.py | 21 +-- g4f/Provider/helper.py | 52 +++++++ g4f/Provider/needs_auth/Bard.py | 45 +++--- g4f/Provider/needs_auth/OpenaiChat.py | 274 +++++++++++++++++++++++++--------- g4f/Provider/needs_auth/Poe.py | 93 ++++++------ g4f/Provider/needs_auth/Theb.py | 92 ++++++------ g4f/requests.py | 163 +++----------------- 11 files changed, 461 insertions(+), 448 deletions(-) (limited to 'g4f') diff --git a/g4f/Provider/AItianhuSpace.py b/g4f/Provider/AItianhuSpace.py index 312cb3b3..fabe6b47 100644 --- a/g4f/Provider/AItianhuSpace.py +++ b/g4f/Provider/AItianhuSpace.py @@ -5,7 +5,7 @@ import random from ..typing import CreateResult, Messages from .base_provider import BaseProvider -from .helper import WebDriver, format_prompt, get_browser, get_random_string +from .helper import WebDriver, WebDriverSession, format_prompt, get_random_string from .. import debug class AItianhuSpace(BaseProvider): @@ -24,7 +24,7 @@ class AItianhuSpace(BaseProvider): domain: str = None, proxy: str = None, timeout: int = 120, - browser: WebDriver = None, + web_driver: WebDriver = None, headless: bool = True, **kwargs ) -> CreateResult: @@ -38,36 +38,35 @@ class AItianhuSpace(BaseProvider): print(f"AItianhuSpace | using domain: {domain}") url = f"https://{domain}" prompt = format_prompt(messages) - driver = browser if browser else get_browser("", headless, proxy) - from selenium.webdriver.common.by import By - from selenium.webdriver.support.ui import WebDriverWait - from selenium.webdriver.support import expected_conditions as EC + with WebDriverSession(web_driver, "", headless=headless, proxy=proxy) as driver: + from selenium.webdriver.common.by import By + from selenium.webdriver.support.ui import WebDriverWait + from selenium.webdriver.support import expected_conditions as EC - wait = WebDriverWait(driver, timeout) + wait = WebDriverWait(driver, timeout) - # Bypass devtools detection - driver.get("https://blank.page/") - wait.until(EC.visibility_of_element_located((By.ID, "sheet"))) - driver.execute_script(f""" -document.getElementById('sheet').addEventListener('click', () => {{ - window.open('{url}', '_blank'); -}}); -""") - driver.find_element(By.ID, "sheet").click() - time.sleep(10) + # Bypass devtools detection + driver.get("https://blank.page/") + wait.until(EC.visibility_of_element_located((By.ID, "sheet"))) + driver.execute_script(f""" + document.getElementById('sheet').addEventListener('click', () => {{ + window.open('{url}', '_blank'); + }}); + """) + driver.find_element(By.ID, "sheet").click() + time.sleep(10) - original_window = driver.current_window_handle - for window_handle in driver.window_handles: - if window_handle != original_window: - driver.close() - driver.switch_to.window(window_handle) - break + original_window = driver.current_window_handle + for window_handle in driver.window_handles: + if window_handle != original_window: + driver.close() + driver.switch_to.window(window_handle) + break - # Wait for page load - wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea.n-input__textarea-el"))) + # Wait for page load + wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea.n-input__textarea-el"))) - try: # Register hook in XMLHttpRequest script = """ const _http_request_open = XMLHttpRequest.prototype.open; @@ -114,9 +113,4 @@ return ""; elif chunk != "": break else: - time.sleep(0.1) - finally: - if not browser: - driver.close() - time.sleep(0.1) - driver.quit() \ No newline at end of file + time.sleep(0.1) \ No newline at end of file diff --git a/g4f/Provider/MyShell.py b/g4f/Provider/MyShell.py index 548c4be1..a1c8d335 100644 --- a/g4f/Provider/MyShell.py +++ b/g4f/Provider/MyShell.py @@ -4,7 +4,7 @@ import time, json from ..typing import CreateResult, Messages from .base_provider import BaseProvider -from .helper import WebDriver, format_prompt, get_browser +from .helper import WebDriver, WebDriverSession, format_prompt class MyShell(BaseProvider): url = "https://app.myshell.ai/chat" @@ -20,22 +20,27 @@ class MyShell(BaseProvider): stream: bool, proxy: str = None, timeout: int = 120, - browser: WebDriver = None, + web_driver: WebDriver = None, **kwargs ) -> CreateResult: - driver = browser if browser else get_browser("", False, proxy) + with WebDriverSession(web_driver, "", proxy=proxy) as driver: + from selenium.webdriver.common.by import By + from selenium.webdriver.support.ui import WebDriverWait + from selenium.webdriver.support import expected_conditions as EC - from selenium.webdriver.common.by import By - from selenium.webdriver.support.ui import WebDriverWait - from selenium.webdriver.support import expected_conditions as EC + driver.get(cls.url) - driver.get(cls.url) - try: # Wait for page load and cloudflare validation WebDriverWait(driver, timeout).until( EC.presence_of_element_located((By.CSS_SELECTOR, "body:not(.no-js)")) ) # Send request with message + data = { + "botId": "4738", + "conversation_scenario": 3, + "message": format_prompt(messages), + "messageType": 1 + } script = """ response = await fetch("https://api.myshell.ai/v1/bot/chat/send_message", { "headers": { @@ -49,12 +54,6 @@ response = await fetch("https://api.myshell.ai/v1/bot/chat/send_message", { }) window.reader = response.body.getReader(); """ - data = { - "botId": "4738", - "conversation_scenario": 3, - "message": format_prompt(messages), - "messageType": 1 - } driver.execute_script(script.replace("{body}", json.dumps(data))) script = """ chunk = await window.reader.read(); @@ -80,9 +79,4 @@ return content; elif chunk != "": break else: - time.sleep(0.1) - finally: - if not browser: - driver.close() - time.sleep(0.1) - driver.quit() \ No newline at end of file + time.sleep(0.1) \ No newline at end of file diff --git a/g4f/Provider/PerplexityAi.py b/g4f/Provider/PerplexityAi.py index bce77715..c0b2412e 100644 --- a/g4f/Provider/PerplexityAi.py +++ b/g4f/Provider/PerplexityAi.py @@ -4,7 +4,7 @@ import time from ..typing import CreateResult, Messages from .base_provider import BaseProvider -from .helper import WebDriver, format_prompt, get_browser +from .helper import WebDriver, WebDriverSession, format_prompt class PerplexityAi(BaseProvider): url = "https://www.perplexity.ai" @@ -20,27 +20,27 @@ class PerplexityAi(BaseProvider): stream: bool, proxy: str = None, timeout: int = 120, - browser: WebDriver = None, + web_driver: WebDriver = None, + virtual_display: bool = True, copilot: bool = False, **kwargs ) -> CreateResult: - driver = browser if browser else get_browser("", False, proxy) + with WebDriverSession(web_driver, "", virtual_display=virtual_display, proxy=proxy) as driver: + from selenium.webdriver.common.by import By + from selenium.webdriver.support.ui import WebDriverWait + from selenium.webdriver.support import expected_conditions as EC + from selenium.webdriver.common.keys import Keys - from selenium.webdriver.common.by import By - from selenium.webdriver.support.ui import WebDriverWait - from selenium.webdriver.support import expected_conditions as EC - from selenium.webdriver.common.keys import Keys + prompt = format_prompt(messages) - prompt = format_prompt(messages) + driver.get(f"{cls.url}/") + wait = WebDriverWait(driver, timeout) - driver.get(f"{cls.url}/") - wait = WebDriverWait(driver, timeout) + # Is page loaded? + wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']"))) - # Is page loaded? - wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']"))) - - # Register WebSocket hook - script = """ + # Register WebSocket hook + script = """ window._message = window._last_message = ""; window._message_finished = false; const _socket_send = WebSocket.prototype.send; @@ -67,22 +67,21 @@ WebSocket.prototype.send = function(...args) { return _socket_send.call(this, ...args); }; """ - driver.execute_script(script) + driver.execute_script(script) - if copilot: - try: - # Check for account - driver.find_element(By.CSS_SELECTOR, "img[alt='User avatar']") - # Enable copilot - driver.find_element(By.CSS_SELECTOR, "button[data-testid='copilot-toggle']").click() - except: - raise RuntimeError("You need a account for copilot") + if copilot: + try: + # Check for account + driver.find_element(By.CSS_SELECTOR, "img[alt='User avatar']") + # Enable copilot + driver.find_element(By.CSS_SELECTOR, "button[data-testid='copilot-toggle']").click() + except: + raise RuntimeError("You need a account for copilot") - # Submit prompt - driver.find_element(By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']").send_keys(prompt) - driver.find_element(By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']").send_keys(Keys.ENTER) + # Submit prompt + driver.find_element(By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']").send_keys(prompt) + driver.find_element(By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']").send_keys(Keys.ENTER) - try: # Stream response script = """ if(window._message && window._message != window._last_message) { @@ -104,9 +103,4 @@ if(window._message && window._message != window._last_message) { elif chunk != "": break else: - time.sleep(0.1) - finally: - if not browser: - driver.close() - time.sleep(0.1) - driver.quit() \ No newline at end of file + time.sleep(0.1) \ No newline at end of file diff --git a/g4f/Provider/Phind.py b/g4f/Provider/Phind.py index 34abbe35..32f63665 100644 --- a/g4f/Provider/Phind.py +++ b/g4f/Provider/Phind.py @@ -5,7 +5,7 @@ from urllib.parse import quote from ..typing import CreateResult, Messages from .base_provider import BaseProvider -from .helper import WebDriver, format_prompt, get_browser +from .helper import WebDriver, WebDriverSession, format_prompt class Phind(BaseProvider): url = "https://www.phind.com" @@ -21,13 +21,11 @@ class Phind(BaseProvider): stream: bool, proxy: str = None, timeout: int = 120, - browser: WebDriver = None, + web_driver: WebDriver = None, creative_mode: bool = None, **kwargs ) -> CreateResult: - try: - driver = browser if browser else get_browser("", False, proxy) - + with WebDriverSession(web_driver, "", proxy=proxy) as driver: from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support import expected_conditions as EC @@ -102,9 +100,4 @@ if(window._reader) { elif chunk != "": break else: - time.sleep(0.1) - finally: - if not browser: - driver.close() - time.sleep(0.1) - driver.quit() \ No newline at end of file + time.sleep(0.1) \ No newline at end of file diff --git a/g4f/Provider/TalkAi.py b/g4f/Provider/TalkAi.py index 5b03b91e..20ba65b5 100644 --- a/g4f/Provider/TalkAi.py +++ b/g4f/Provider/TalkAi.py @@ -4,7 +4,7 @@ import time, json, time from ..typing import CreateResult, Messages from .base_provider import BaseProvider -from .helper import WebDriver, get_browser +from .helper import WebDriver, WebDriverSession class TalkAi(BaseProvider): url = "https://talkai.info" @@ -19,16 +19,14 @@ class TalkAi(BaseProvider): messages: Messages, stream: bool, proxy: str = None, - browser: WebDriver = None, + web_driver: WebDriver = None, **kwargs ) -> CreateResult: - driver = browser if browser else get_browser("", False, proxy) + with WebDriverSession(web_driver, "", virtual_display=True, proxy=proxy) as driver: + from selenium.webdriver.common.by import By + from selenium.webdriver.support.ui import WebDriverWait + from selenium.webdriver.support import expected_conditions as EC - from selenium.webdriver.common.by import By - from selenium.webdriver.support.ui import WebDriverWait - from selenium.webdriver.support import expected_conditions as EC - - try: driver.get(f"{cls.url}/chat/") # Wait for page load @@ -86,9 +84,4 @@ return content; elif chunk != "": break else: - time.sleep(0.1) - finally: - if not browser: - driver.close() - time.sleep(0.1) - driver.quit() \ No newline at end of file + time.sleep(0.1) \ No newline at end of file diff --git a/g4f/Provider/helper.py b/g4f/Provider/helper.py index c420dee3..03e9ba94 100644 --- a/g4f/Provider/helper.py +++ b/g4f/Provider/helper.py @@ -6,6 +6,7 @@ import webbrowser import random import string import secrets +import time from os import path from asyncio import AbstractEventLoop from platformdirs import user_config_dir @@ -34,6 +35,10 @@ except ImportError: class ChromeOptions(): def add_argument(): pass +try: + from pyvirtualdisplay import Display +except ImportError: + pass from ..typing import Dict, Messages, Union, Tuple from .. import debug @@ -144,6 +149,53 @@ def get_browser( options.add_argument(f'--proxy-server={proxy}') return Chrome(options=options, user_data_dir=user_data_dir, headless=headless) +class WebDriverSession(): + def __init__( + self, + web_driver: WebDriver = None, + user_data_dir: str = None, + headless: bool = False, + virtual_display: bool = False, + proxy: str = None, + options: ChromeOptions = None + ): + self.web_driver = web_driver + self.user_data_dir = user_data_dir + self.headless = headless + self.virtual_display = virtual_display + self.proxy = proxy + self.options = options + + def reopen( + self, + user_data_dir: str = None, + headless: bool = False, + virtual_display: bool = False + ) -> WebDriver: + if user_data_dir == None: + user_data_dir = self.user_data_dir + self.default_driver.quit() + if not virtual_display and self.virtual_display: + self.virtual_display.stop() + self.default_driver = get_browser(user_data_dir, headless, self.proxy) + return self.default_driver + + def __enter__(self) -> WebDriver: + if self.web_driver: + return self.web_driver + if self.virtual_display == True: + self.virtual_display = Display(size=(1920,1080)) + self.virtual_display.start() + self.default_driver = get_browser(self.user_data_dir, self.headless, self.proxy, self.options) + return self.default_driver + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.default_driver: + self.default_driver.close() + time.sleep(0.1) + self.default_driver.quit() + if self.virtual_display: + self.virtual_display.stop() def get_random_string(length: int = 10) -> str: return ''.join( diff --git a/g4f/Provider/needs_auth/Bard.py b/g4f/Provider/needs_auth/Bard.py index b1df6909..77c029b8 100644 --- a/g4f/Provider/needs_auth/Bard.py +++ b/g4f/Provider/needs_auth/Bard.py @@ -4,7 +4,7 @@ import time from ...typing import CreateResult, Messages from ..base_provider import BaseProvider -from ..helper import WebDriver, format_prompt, get_browser +from ..helper import WebDriver, WebDriverSession, format_prompt class Bard(BaseProvider): url = "https://bard.google.com" @@ -18,34 +18,32 @@ class Bard(BaseProvider): messages: Messages, stream: bool, proxy: str = None, - browser: WebDriver = None, + web_driver: WebDriver = None, user_data_dir: str = None, headless: bool = True, **kwargs ) -> CreateResult: prompt = format_prompt(messages) - driver = browser if browser else get_browser(user_data_dir, headless, proxy) + session = WebDriverSession(web_driver, user_data_dir, headless, proxy=proxy) + with session as driver: + from selenium.webdriver.common.by import By + from selenium.webdriver.support.ui import WebDriverWait + from selenium.webdriver.support import expected_conditions as EC - from selenium.webdriver.common.by import By - from selenium.webdriver.support.ui import WebDriverWait - from selenium.webdriver.support import expected_conditions as EC - - try: - driver.get(f"{cls.url}/chat") - wait = WebDriverWait(driver, 10 if headless else 240) - wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "div.ql-editor.textarea"))) - except: - # Reopen browser for login - if not browser: - driver.quit() - driver = get_browser(None, False, proxy) + try: driver.get(f"{cls.url}/chat") - wait = WebDriverWait(driver, 240) + wait = WebDriverWait(driver, 10 if headless else 240) wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "div.ql-editor.textarea"))) - else: - raise RuntimeError("Prompt textarea not found. You may not be logged in.") + except: + # Reopen browser for login + if not web_driver: + driver = session.reopen(headless=False) + driver.get(f"{cls.url}/chat") + wait = WebDriverWait(driver, 240) + wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "div.ql-editor.textarea"))) + else: + raise RuntimeError("Prompt textarea not found. You may not be logged in.") - try: # Add hook in XMLHttpRequest script = """ const _http_request_open = XMLHttpRequest.prototype.open; @@ -72,9 +70,4 @@ XMLHttpRequest.prototype.open = function(method, url) { yield chunk return else: - time.sleep(0.1) - finally: - if not browser: - driver.close() - time.sleep(0.1) - driver.quit() \ No newline at end of file + time.sleep(0.1) \ No newline at end of file diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index eccf2bd7..9fd90812 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -1,20 +1,64 @@ from __future__ import annotations -import uuid, json, time, asyncio +import uuid, json, asyncio from py_arkose_generator.arkose import get_values_for_request +from asyncstdlib.itertools import tee +from async_property import async_cached_property from ..base_provider import AsyncGeneratorProvider -from ..helper import get_browser, get_cookies, format_prompt, get_event_loop +from ..helper import get_browser, get_event_loop from ...typing import AsyncResult, Messages from ...requests import StreamSession -from ... import debug + +models = { + "gpt-3.5": "text-davinci-002-render-sha", + "gpt-3.5-turbo": "text-davinci-002-render-sha", + "gpt-4": "gpt-4", + "gpt-4-gizmo": "gpt-4-gizmo" +} class OpenaiChat(AsyncGeneratorProvider): url = "https://chat.openai.com" - needs_auth = True working = True + needs_auth = True supports_gpt_35_turbo = True - _access_token = None + supports_gpt_4 = True + _access_token: str = None + + @classmethod + async def create( + cls, + prompt: str = None, + model: str = "", + messages: Messages = [], + history_disabled: bool = False, + action: str = "next", + conversation_id: str = None, + parent_id: str = None, + **kwargs + ) -> Response: + if prompt: + messages.append({"role": "user", "content": prompt}) + generator = cls.create_async_generator( + model, + messages, + history_disabled=history_disabled, + action=action, + conversation_id=conversation_id, + parent_id=parent_id, + response_fields=True, + **kwargs + ) + fields: ResponseFields = await anext(generator) + if "access_token" not in kwargs: + kwargs["access_token"] = cls._access_token + return Response( + generator, + fields, + action, + messages, + kwargs + ) @classmethod async def create_async_generator( @@ -25,50 +69,56 @@ class OpenaiChat(AsyncGeneratorProvider): timeout: int = 120, access_token: str = None, auto_continue: bool = False, - cookies: dict = None, + history_disabled: bool = True, + action: str = "next", + conversation_id: str = None, + parent_id: str = None, + response_fields: bool = False, **kwargs ) -> AsyncResult: - proxies = {"https": proxy} + if not model: + model = "gpt-3.5" + elif model not in models: + raise ValueError(f"Model are not supported: {model}") + if not parent_id: + parent_id = str(uuid.uuid4()) if not access_token: - access_token = await cls.get_access_token(cookies, proxies) + access_token = await cls.get_access_token(proxy) headers = { "Accept": "text/event-stream", "Authorization": f"Bearer {access_token}", + "Cookie": 'intercom-device-id-dgkjq2bp=0f047573-a750-46c8-be62-6d54b56e7bf0; ajs_user_id=user-iv3vxisaoNodwWpxmNpMfekH; ajs_anonymous_id=fd91be0b-0251-4222-ac1e-84b1071e9ec1; __Host-next-auth.csrf-token=d2b5f67d56f7dd6a0a42ae4becf2d1a6577b820a5edc88ab2018a59b9b506886%7Ce5c33eecc460988a137cbc72d90ee18f1b4e2f672104f368046df58e364376ac; _cfuvid=gt_mA.q6rue1.7d2.AR0KHpbVBS98i_ppfi.amj2._o-1700353424353-0-604800000; cf_clearance=GkHCfPSFU.NXGcHROoe4FantnqmnNcluhTNHz13Tk.M-1700353425-0-1-dfe77f81.816e9bc2.714615da-0.2.1700353425; __Secure-next-auth.callback-url=https%3A%2F%2Fchat.openai.com; intercom-session-dgkjq2bp=UWdrS1hHazk5VXN1c0V5Q1F0VXdCQmsyTU9pVjJMUkNpWnFnU3dKWmtIdGwxTC9wbjZuMk5hcEc0NWZDOGdndS0tSDNiaDNmMEdIL1RHU1dFWDBwOHFJUT09--f754361b91fddcd23a13b288dcb2bf8c7f509e91; _uasid="Z0FBQUFBQmxXVnV0a3dmVno4czRhcDc2ZVcwaUpSNUdZejlDR25YSk5NYTJQQkpyNmRvOGxjTHMyTlAxWmJhaURrMVhjLXZxQXdZeVpBbU1aczA5WUpHT2dwaS1MOWc4MnhyNWFnbGRzeGdJcGFKT0ZRdnBTMVJHcGV2MGNTSnVQY193c0hqUWIycHhQRVF4dENlZ3phcDdZeHgxdVhoalhrZmtZME9NbWhMQjdVR3Vzc3FRRk0ybjJjNWMwTWtIRjdPb19lUkFtRmV2MDVqd1kwWU11QTYtQkdZenEzVHhLMGplY1hZM3FlYUt1cVZaNWFTRldleEJETzJKQjk1VTJScy1GUnMxUVZWMnVxYklxMjdockVZbkZyd1R4U1RtMnA1ZzlSeXphdmVOVk9xeEdrRkVOSjhwTVd1QzFtQjhBcWdDaE92Q1VlM2pwcjFQTXRuLVJNRVlZSGpIdlZ0aGV3PT0="; _dd_s=rum=0&expire=1700356244884; __Secure-next-auth.session-token=eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..3aK6Fbdy2_8f07bf.8eT2xgonrCnz7ySY6qXFsg3kzL6UQfXKAYaw3tyn-6_X9657zy47k9qGvmi9mF0QKozj5jau3_Ca62AQQ7FmeC6Y2F1urtzqrXqwTTsQ2LuzFPIQkx6KKb2DXc8zW2-oyEzJ_EY5yxfLB2RlRkSh3M7bYNZh4_ltEcfkj38s_kIPGMxv34udtPWGWET99MCjkdwQWXylJag4s0fETA0orsBAKnGCyqAUNJbb_D7BYtGSV-MQ925kZMG6Di_QmfO0HQWURDYjmdRNcuy1PT_xJ1DJko8sjL42i4j3RhkNDkhqCIqyYImz2eHFWHW7rYKxTkrBhlCPMS5hRdcCswD7JYPcSBiwnVRYgyOocFGXoFvQgIZ2FX9NiZ3SMEVM1VwIGSE-qH0H2nMa8_iBvsOgOWJgKjVAvzzyzZvRVDUUHzJrikSFPNONVDU3h-04c1kVL4qIu9DfeTPN7n8AvNmYwMbro0L9-IUAeXNo4-pwF0Kt-AtTsamqWvMqnK4O_YOyLnDDlvkmnOvDC2d5uinwlQIxr6APO6qFfGLlHiLZemKoekxEE1Fx70dl-Ouhk1VIzbF3OC6XNNxeBm9BUYUiHdL0wj2H9rHgX4cz6ZmS_3VTgpD6UJh-evu5KJ2gIvjYmVbyzEN0aPNDxfvBaOm-Ezpy4bUJ2bUrOwNn-0knWkDiTvjYmNhCyefPCtCF6rpKNay8PCw_yh79C4SdEP6Q4V7LI0Tvdi5uz7kLCiBC4AT9L0ao1WDX03mkUOpjvzHDvPLmj8chW3lTVm_kA0eYGQY4wT0jzleWlfV0Q8rB2oYECNLWksA3F1zlGfcl4lQjprvTXRePkvAbMpoJEsZD3Ylq7-foLDLk4-M2LYAFZDs282AY04sFjAjQBxTELFCCuDgTIgTXSIskY_XCxpVXDbdLlbCJY7XVK45ybwtfqwlKRp8Mo0B131uQAFc-migHaUaoGujxJJk21bP8F0OmhNYHBo4FQqE1rQm2JH5bNM7txKeh5KXdJgVUVbRSr7OIp_OF5-Bx_v9eRBGAIDkue26E2-O8Rnrp5zQ5TnvecQLDaUzWavCLPwsZ0_gsOLBxNOmauNYZtF8IElCsQSFDdhoiMxXsYUm4ZYKEAy3GWq8HGTAvBhNkh1hvnI7y-d8-DOaZf_D_D98-olZfm-LUkeosLNpPB9rxYMqViCiW3KrXE9Yx0wlFm5ePKaVvR7Ym_EPhSOhJBKFPCvdTdMZSNPUcW0ZJBVByq0A9sxD51lYq3gaFyqh94S4s_ox182AQ3szGzHkdgLcnQmJG9OYvKxAVcd43eg6_gODAYhx02GjbMw-7JTAhyXSeCrlMteHyOXl8hai-3LilC3PmMzi7Vbu49dhF1s4LcVlUowen5ira44rQQaB26mdaOUoQfodgt66M3RTWGPXyK1Nb72AzSXsCKyaQPbzeb6cN0fdGSdG4ktwvR04eFNEkquo_3aKu2GmUKTD0XcRx9dYrfXjgY-X1DDTVs1YND2gRhdx7FFEeBVjtbj2UqmG3Rvd4IcHGe7OnYWw2MHDcol68SsR1KckXWwWREz7YTGUnDB2M1kx_H4W2mjclytnlHOnYU3RflegRPeSTbdzUZJvGKXCCz45luHkQWN_4DExE76D-9YqbFIz-RY5yL4h-Zs-i2xjm2K-4xCMM9nQIOqhLMqixIZQ2ldDAidKoYtbs5ppzbcBLyrZM96bq9DwRBY3aacqWdlRd-TfX0wv5KO4fo0sSh5FsuhuN0zcEV_NNXgqIEM_p14EcPqgbrAvCBQ8os70TRBQLXiF0EniSofGjxwF8kQvUk3C6Wfc8cTTeN-E6GxCVTn91HBwA1iSEZlRLMVb8_BcRJNqwbgnb_07jR6-eo42u88CR3KQdAWwbQRdMxsURFwZ0ujHXVGG0Ll6qCFBcHXWyDO1x1yHdHnw8_8yF26pnA2iPzrFR-8glMgIA-639sLuGAxjO1_ZuvJ9CAB41Az9S_jaZwaWy215Hk4-BRYD-MKmHtonwo3rrxhE67WJgbbu14efsw5nT6ow961pffgwXov5VA1Rg7nv1E8RvQOx7umWW6o8R4W6L8f2COsmPTXfgwIjoJKkjhUqAQ8ceG7cM0ET-38yaC0ObU8EkXfdGGgxI28qTEZWczG66_iM4hw7QEGCY5Cz2kbO6LETAiw9OsSigtBvDS7f0Ou0bZ41pdK7G3FmvdZAnjWPjObnDF4k4uWfn7mzt0fgj3FyqK20JezRDyGuAbUUhOvtZpc9sJpzxR34eXEZTouuALrHcGuNij4z6rx51FrQsaMtiup8QVrhtZbXtKLMYnWYSbkhuTeN2wY-xV1ZUsQlakIZszzGF7kuIG87KKWMpuPMvbXjz6Pp_gWJiIC6aQuk8xl5g0iBPycf_6Q-MtpuYxzNE2TpI1RyR9mHeXmteoRzrFiWp7yEC-QGNFyAJgxTqxM3CjHh1Jt6IddOsmn89rUo1dZM2Smijv_fbIv3avXLkIPX1KZjILeJCtpU0wAdsihDaRiRgDdx8fG__F8zuP0n7ziHas73cwrfg-Ujr6DhC0gTNxyd9dDA_oho9N7CQcy6EFmfNF2te7zpLony0859jtRv2t1TnpzAa1VvMK4u6mXuJ2XDo04_6GzLO3aPHinMdl1BcIAWnqAqWAu3euGFLTHOhXlfijut9N1OCifd_zWjhVtzlR39uFeCQBU5DyQArzQurdoMx8U1ETsnWgElxGSStRW-YQoPsAJ87eg9trqKspFpTVlAVN3t1GtoEAEhcwhe81SDssLmKGLc.7PqS6jRGTIfgTPlO7Ognvg; __cf_bm=VMWoAKEB45hQSwxXtnYXcurPaGZDJS4dMi6dIMFLwdw-1700355394-0-ATVsbq97iCaTaJbtYr8vtg1Zlbs3nLrJLKVBHYa2Jn7hhkGclqAy8Gbyn5ePEhDRqj93MsQmtayfYLqY5n4WiLY=; __cflb=0H28vVfF4aAyg2hkHFH9CkdHRXPsfCUf6VpYf2kz3RX' } - messages = [ - { - "id": str(uuid.uuid4()), - "author": {"role": "user"}, - "content": {"content_type": "text", "parts": [format_prompt(messages)]}, - }, - ] - message_id = str(uuid.uuid4()) - data = { - "action": "next", - "arkose_token": await get_arkose_token(proxy), - "messages": messages, - "conversation_id": None, - "parent_message_id": message_id, - "model": "text-davinci-002-render-sha", - "history_and_training_disabled": not auto_continue, - } - conversation_id = None - end_turn = False - while not end_turn: - if not auto_continue: - end_turn = True - async with StreamSession( - proxies=proxies, - headers=headers, - impersonate="chrome107", - timeout=timeout - ) as session: + async with StreamSession( + proxies={"https": proxy}, + impersonate="chrome110", + headers=headers, + timeout=timeout + ) as session: + data = { + "action": action, + "arkose_token": await get_arkose_token(proxy, timeout), + "conversation_id": conversation_id, + "parent_message_id": parent_id, + "model": models[model], + "history_and_training_disabled": history_disabled and not auto_continue, + } + if action != "continue": + data["messages"] = [{ + "id": str(uuid.uuid4()), + "author": {"role": "user"}, + "content": {"content_type": "text", "parts": [messages[-1]["content"]]}, + }] + first = True + end_turn = EndTurn() + while first or auto_continue and not end_turn.is_end: + first = False async with session.post(f"{cls.url}/backend-api/conversation", json=data) as response: try: response.raise_for_status() except: - raise RuntimeError(f"Response: {await response.text()}") - last_message = "" + raise RuntimeError(f"Error {response.status_code}: {await response.text()}") + last_message = 0 async for line in response.iter_lines(): if line.startswith(b"data: "): line = line[6:] @@ -82,50 +132,52 @@ class OpenaiChat(AsyncGeneratorProvider): continue if "error" in line and line["error"]: raise RuntimeError(line["error"]) - end_turn = line["message"]["end_turn"] - message_id = line["message"]["id"] - if line["conversation_id"]: - conversation_id = line["conversation_id"] if "message_type" not in line["message"]["metadata"]: continue - if line["message"]["metadata"]["message_type"] in ("next", "continue"): + if line["message"]["author"]["role"] != "assistant": + continue + if line["message"]["metadata"]["message_type"] in ("next", "continue", "variant"): + conversation_id = line["conversation_id"] + parent_id = line["message"]["id"] + if response_fields: + response_fields = False + yield ResponseFields(conversation_id, parent_id, end_turn) new_message = line["message"]["content"]["parts"][0] - yield new_message[len(last_message):] - last_message = new_message - if end_turn: - return + yield new_message[last_message:] + last_message = len(new_message) + if "finish_details" in line["message"]["metadata"]: + if line["message"]["metadata"]["finish_details"]["type"] == "max_tokens": + end_turn.end() + data = { "action": "continue", - "arkose_token": await get_arkose_token(proxy), + "arkose_token": await get_arkose_token(proxy, timeout), "conversation_id": conversation_id, - "parent_message_id": message_id, - "model": "text-davinci-002-render-sha", + "parent_message_id": parent_id, + "model": models[model], "history_and_training_disabled": False, } await asyncio.sleep(5) @classmethod - async def browse_access_token(cls) -> str: + async def browse_access_token(cls, proxy: str = None) -> str: def browse() -> str: try: from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support import expected_conditions as EC - driver = get_browser() + driver = get_browser("~/openai", proxy=proxy) except ImportError: return - - driver.get(f"{cls.url}/") try: + driver.get(f"{cls.url}/") WebDriverWait(driver, 1200).until( EC.presence_of_element_located((By.ID, "prompt-textarea")) ) javascript = "return (await (await fetch('/api/auth/session')).json())['accessToken']" return driver.execute_script(javascript) finally: - driver.close() - time.sleep(0.1) driver.quit() loop = get_event_loop() return await loop.run_in_executor( @@ -134,22 +186,9 @@ class OpenaiChat(AsyncGeneratorProvider): ) @classmethod - async def fetch_access_token(cls, cookies: dict, proxies: dict = None) -> str: - async with StreamSession(proxies=proxies, cookies=cookies, impersonate="chrome107") as session: - async with session.get(f"{cls.url}/api/auth/session") as response: - response.raise_for_status() - auth = await response.json() - if "accessToken" in auth: - return auth["accessToken"] - - @classmethod - async def get_access_token(cls, cookies: dict = None, proxies: dict = None) -> str: - if not cls._access_token: - cookies = cookies if cookies else get_cookies("chat.openai.com") - if cookies: - cls._access_token = await cls.fetch_access_token(cookies, proxies) + async def get_access_token(cls, proxy: str = None) -> str: if not cls._access_token: - cls._access_token = await cls.browse_access_token() + cls._access_token = await cls.browse_access_token(proxy) if not cls._access_token: raise RuntimeError("Read access token failed") return cls._access_token @@ -163,12 +202,11 @@ class OpenaiChat(AsyncGeneratorProvider): ("stream", "bool"), ("proxy", "str"), ("access_token", "str"), - ("cookies", "dict[str, str]") ] param = ", ".join([": ".join(p) for p in params]) return f"g4f.provider.{cls.__name__} supports: ({param})" -async def get_arkose_token(proxy: str = None) -> str: +async def get_arkose_token(proxy: str = None, timeout: int = None) -> str: config = { "pkey": "3D86FBBA-9D22-402A-B512-3420086BA6CC", "surl": "https://tcr9i.chat.openai.com", @@ -181,10 +219,98 @@ async def get_arkose_token(proxy: str = None) -> str: async with StreamSession( proxies={"https": proxy}, impersonate="chrome107", + timeout=timeout ) as session: async with session.post(**args_for_request) as response: response.raise_for_status() decoded_json = await response.json() if "token" in decoded_json: return decoded_json["token"] - raise RuntimeError(f"Response: {decoded_json}") \ No newline at end of file + raise RuntimeError(f"Response: {decoded_json}") + +class EndTurn(): + def __init__(self): + self.is_end = False + + def end(self): + self.is_end = True + +class ResponseFields(): + def __init__( + self, + conversation_id: str, + message_id: str, + end_turn: EndTurn + ): + self.conversation_id = conversation_id + self.message_id = message_id + self._end_turn = end_turn + +class Response(): + def __init__( + self, + generator: AsyncResult, + fields: ResponseFields, + action: str, + messages: Messages, + options: dict + ): + self.aiter, self.copy = tee(generator) + self.fields = fields + self.action = action + self._messages = messages + self._options = options + + def __aiter__(self): + return self.aiter + + @async_cached_property + async def message(self) -> str: + return "".join([chunk async for chunk in self.copy]) + + async def next(self, prompt: str, **kwargs) -> Response: + return await OpenaiChat.create( + **self._options, + prompt=prompt, + messages=await self.messages, + action="next", + conversation_id=self.fields.conversation_id, + parent_id=self.fields.message_id, + **kwargs + ) + + async def do_continue(self, **kwargs) -> Response: + if self.end_turn: + raise RuntimeError("Can't continue message. Message already finished.") + return await OpenaiChat.create( + **self._options, + messages=await self.messages, + action="continue", + conversation_id=self.fields.conversation_id, + parent_id=self.fields.message_id, + **kwargs + ) + + async def variant(self, **kwargs) -> Response: + if self.action != "next": + raise RuntimeError("Can't create variant with continue or variant request.") + return await OpenaiChat.create( + **self._options, + messages=self._messages, + action="variant", + conversation_id=self.fields.conversation_id, + parent_id=self.fields.message_id, + **kwargs + ) + + @async_cached_property + async def messages(self): + messages = self._messages + messages.append({ + "role": "assistant", "content": await self.message + }) + return messages + + @property + def end_turn(self): + return self.fields._end_turn.is_end \ No newline at end of file diff --git a/g4f/Provider/needs_auth/Poe.py b/g4f/Provider/needs_auth/Poe.py index a894bcb1..1c8c97d7 100644 --- a/g4f/Provider/needs_auth/Poe.py +++ b/g4f/Provider/needs_auth/Poe.py @@ -4,7 +4,7 @@ import time from ...typing import CreateResult, Messages from ..base_provider import BaseProvider -from ..helper import WebDriver, format_prompt, get_browser +from ..helper import WebDriver, WebDriverSession, format_prompt models = { "meta-llama/Llama-2-7b-chat-hf": {"name": "Llama-2-7b"}, @@ -33,7 +33,7 @@ class Poe(BaseProvider): messages: Messages, stream: bool, proxy: str = None, - browser: WebDriver = None, + web_driver: WebDriver = None, user_data_dir: str = None, headless: bool = True, **kwargs @@ -43,56 +43,54 @@ class Poe(BaseProvider): elif model not in models: raise ValueError(f"Model are not supported: {model}") prompt = format_prompt(messages) - driver = browser if browser else get_browser(user_data_dir, headless, proxy) - script = """ -window._message = window._last_message = ""; -window._message_finished = false; -class ProxiedWebSocket extends WebSocket { - constructor(url, options) { - super(url, options); - this.addEventListener("message", (e) => { - const data = JSON.parse(JSON.parse(e.data)["messages"][0])["payload"]["data"]; - if ("messageAdded" in data) { - if (data["messageAdded"]["author"] != "human") { - window._message = data["messageAdded"]["text"]; - if (data["messageAdded"]["state"] == "complete") { - window._message_finished = true; + session = WebDriverSession(web_driver, user_data_dir, headless, proxy=proxy) + with session as driver: + from selenium.webdriver.common.by import By + from selenium.webdriver.support.ui import WebDriverWait + from selenium.webdriver.support import expected_conditions as EC + + driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", { + "source": """ + window._message = window._last_message = ""; + window._message_finished = false; + class ProxiedWebSocket extends WebSocket { + constructor(url, options) { + super(url, options); + this.addEventListener("message", (e) => { + const data = JSON.parse(JSON.parse(e.data)["messages"][0])["payload"]["data"]; + if ("messageAdded" in data) { + if (data["messageAdded"]["author"] != "human") { + window._message = data["messageAdded"]["text"]; + if (data["messageAdded"]["state"] == "complete") { + window._message_finished = true; + } } } - } - }); - } -} -window.WebSocket = ProxiedWebSocket; -""" - driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", { - "source": script - }) - - from selenium.webdriver.common.by import By - from selenium.webdriver.support.ui import WebDriverWait - from selenium.webdriver.support import expected_conditions as EC + }); + } + } + window.WebSocket = ProxiedWebSocket; + """ + }) - try: - driver.get(f"{cls.url}/{models[model]['name']}") - wait = WebDriverWait(driver, 10 if headless else 240) - wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']"))) - except: - # Reopen browser for login - if not browser: - driver.quit() - driver = get_browser(None, False, proxy) + try: driver.get(f"{cls.url}/{models[model]['name']}") - wait = WebDriverWait(driver, 240) + wait = WebDriverWait(driver, 10 if headless else 240) wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']"))) - else: - raise RuntimeError("Prompt textarea not found. You may not be logged in.") + except: + # Reopen browser for login + if not web_driver: + driver = session.reopen(headless=False) + driver.get(f"{cls.url}/{models[model]['name']}") + wait = WebDriverWait(driver, 240) + wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']"))) + else: + raise RuntimeError("Prompt textarea not found. You may not be logged in.") - driver.find_element(By.CSS_SELECTOR, "footer textarea[class^='GrowingTextArea']").send_keys(prompt) - driver.find_element(By.CSS_SELECTOR, "footer button[class*='ChatMessageSendButton']").click() + driver.find_element(By.CSS_SELECTOR, "footer textarea[class^='GrowingTextArea']").send_keys(prompt) + driver.find_element(By.CSS_SELECTOR, "footer button[class*='ChatMessageSendButton']").click() - try: script = """ if(window._message && window._message != window._last_message) { try { @@ -113,9 +111,4 @@ if(window._message && window._message != window._last_message) { elif chunk != "": break else: - time.sleep(0.1) - finally: - if not browser: - driver.close() - time.sleep(0.1) - driver.quit() \ No newline at end of file + time.sleep(0.1) \ No newline at end of file diff --git a/g4f/Provider/needs_auth/Theb.py b/g4f/Provider/needs_auth/Theb.py index 89c69727..cf33f0c6 100644 --- a/g4f/Provider/needs_auth/Theb.py +++ b/g4f/Provider/needs_auth/Theb.py @@ -4,7 +4,7 @@ import time from ...typing import CreateResult, Messages from ..base_provider import BaseProvider -from ..helper import WebDriver, format_prompt, get_browser +from ..helper import WebDriver, WebDriverSession, format_prompt models = { "theb-ai": "TheB.AI", @@ -44,26 +44,60 @@ class Theb(BaseProvider): messages: Messages, stream: bool, proxy: str = None, - browser: WebDriver = None, - headless: bool = True, + web_driver: WebDriver = None, + virtual_display: bool = True, **kwargs ) -> CreateResult: if model in models: model = models[model] prompt = format_prompt(messages) - driver = browser if browser else get_browser(None, headless, proxy) + web_session = WebDriverSession(web_driver, virtual_display=virtual_display, proxy=proxy) + with web_session as driver: + from selenium.webdriver.common.by import By + from selenium.webdriver.support.ui import WebDriverWait + from selenium.webdriver.support import expected_conditions as EC + from selenium.webdriver.common.keys import Keys - from selenium.webdriver.common.by import By - from selenium.webdriver.support.ui import WebDriverWait - from selenium.webdriver.support import expected_conditions as EC - from selenium.webdriver.common.keys import Keys + # Register fetch hook + script = """ +window._fetch = window.fetch; +window.fetch = (url, options) => { + // Call parent fetch method + const result = window._fetch(url, options); + if (!url.startsWith("/api/conversation")) { + return result; + } + // Load response reader + result.then((response) => { + if (!response.body.locked) { + window._reader = response.body.getReader(); + } + }); + // Return dummy response + return new Promise((resolve, reject) => { + resolve(new Response(new ReadableStream())) + }); +} +window._last_message = ""; +""" + driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", { + "source": script + }) + + try: + driver.get(f"{cls.url}/home") + wait = WebDriverWait(driver, 5) + wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize"))) + except: + driver = web_session.reopen() + driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", { + "source": script + }) + driver.get(f"{cls.url}/home") + wait = WebDriverWait(driver, 240) + wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize"))) - - try: - driver.get(f"{cls.url}/home") - wait = WebDriverWait(driver, 10 if headless else 240) - wait.until(EC.visibility_of_element_located((By.TAG_NAME, "body"))) - time.sleep(0.1) + time.sleep(200) try: driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click() driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click() @@ -87,29 +121,6 @@ class Theb(BaseProvider): button = container.find_element(By.CSS_SELECTOR, "button.btn-blue.btn-small.border") button.click() - # Register fetch hook - script = """ -window._fetch = window.fetch; -window.fetch = (url, options) => { - // Call parent fetch method - const result = window._fetch(url, options); - if (!url.startsWith("/api/conversation")) { - return result; - } - // Load response reader - result.then((response) => { - if (!response.body.locked) { - window._reader = response.body.getReader(); - } - }); - // Return dummy response - return new Promise((resolve, reject) => { - resolve(new Response(new ReadableStream())) - }); -} -window._last_message = ""; -""" - driver.execute_script(script) # Submit prompt wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize"))) @@ -150,9 +161,4 @@ return ''; elif chunk != "": break else: - time.sleep(0.1) - finally: - if not browser: - driver.close() - time.sleep(0.1) - driver.quit() \ No newline at end of file + time.sleep(0.1) \ No newline at end of file diff --git a/g4f/requests.py b/g4f/requests.py index 92165c64..b70789d4 100644 --- a/g4f/requests.py +++ b/g4f/requests.py @@ -1,24 +1,15 @@ from __future__ import annotations -import warnings import json -import asyncio +from contextlib import asynccontextmanager from functools import partialmethod -from asyncio import Future, Queue -from typing import AsyncGenerator, Union, Optional +from typing import AsyncGenerator from curl_cffi.requests import AsyncSession, Response -import curl_cffi - -is_newer_0_5_8: bool = hasattr(AsyncSession, "_set_cookies") or hasattr(curl_cffi.requests.Cookies, "get_cookies_for_curl") -is_newer_0_5_9: bool = hasattr(curl_cffi.AsyncCurl, "remove_handle") -is_newer_0_5_10: bool = hasattr(AsyncSession, "release_curl") - class StreamResponse: - def __init__(self, inner: Response, queue: Queue[bytes]) -> None: + def __init__(self, inner: Response) -> None: self.inner: Response = inner - self.queue: Queue[bytes] = queue self.request = inner.request self.status_code: int = inner.status_code self.reason: str = inner.reason @@ -27,148 +18,32 @@ class StreamResponse: self.cookies = inner.cookies async def text(self) -> str: - content: bytes = await self.read() - return content.decode() + return await self.inner.atext() def raise_for_status(self) -> None: - if not self.ok: - raise RuntimeError(f"HTTP Error {self.status_code}: {self.reason}") + self.inner.raise_for_status() async def json(self, **kwargs) -> dict: - return json.loads(await self.read(), **kwargs) - - async def iter_lines( - self, chunk_size: Optional[int] = None, decode_unicode: bool = False, delimiter: Optional[str] = None - ) -> AsyncGenerator[bytes, None]: - """ - Copied from: https://requests.readthedocs.io/en/latest/_modules/requests/models/ - which is under the License: Apache 2.0 - """ - - pending: bytes = None - - async for chunk in self.iter_content( - chunk_size=chunk_size, decode_unicode=decode_unicode - ): - if pending is not None: - chunk = pending + chunk - lines = chunk.split(delimiter) if delimiter else chunk.splitlines() - if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]: - pending = lines.pop() - else: - pending = None + return json.loads(await self.inner.acontent(), **kwargs) - for line in lines: - yield line + async def iter_lines(self) -> AsyncGenerator[bytes, None]: + async for line in self.inner.aiter_lines(): + yield line - if pending is not None: - yield pending - - async def iter_content( - self, chunk_size: Optional[int] = None, decode_unicode: bool = False - ) -> AsyncGenerator[bytes, None]: - if chunk_size: - warnings.warn("chunk_size is ignored, there is no way to tell curl that.") - if decode_unicode: - raise NotImplementedError() - while True: - chunk = await self.queue.get() - if chunk is None: - return + async def iter_content(self) -> AsyncGenerator[bytes, None]: + async for chunk in self.inner.aiter_content(): yield chunk - async def read(self) -> bytes: - return b"".join([chunk async for chunk in self.iter_content()]) - - -class StreamRequest: - def __init__(self, session: AsyncSession, method: str, url: str, **kwargs: Union[bool, int, str]) -> None: - self.session: AsyncSession = session - self.loop: asyncio.AbstractEventLoop = session.loop if session.loop else asyncio.get_running_loop() - self.queue: Queue[bytes] = Queue() - self.method: str = method - self.url: str = url - self.options: dict = kwargs - self.handle: Optional[curl_cffi.AsyncCurl] = None - - def _on_content(self, data: bytes) -> None: - if not self.enter.done(): - self.enter.set_result(None) - self.queue.put_nowait(data) - - def _on_done(self, task: Future) -> None: - if not self.enter.done(): - self.enter.set_result(None) - self.queue.put_nowait(None) - - self.loop.call_soon(self.release_curl) - - async def fetch(self) -> StreamResponse: - if self.handle: - raise RuntimeError("Request already started") - self.curl: curl_cffi.AsyncCurl = await self.session.pop_curl() - self.enter: asyncio.Future = self.loop.create_future() - if is_newer_0_5_10: - request, _, header_buffer, _, _ = self.session._set_curl_options( - self.curl, - self.method, - self.url, - content_callback=self._on_content, - **self.options - ) - else: - request, _, header_buffer = self.session._set_curl_options( - self.curl, - self.method, - self.url, - content_callback=self._on_content, - **self.options - ) - if is_newer_0_5_9: - self.handle = self.session.acurl.add_handle(self.curl) - else: - await self.session.acurl.add_handle(self.curl, False) - self.handle = self.session.acurl._curl2future[self.curl] - self.handle.add_done_callback(self._on_done) - # Wait for headers - await self.enter - # Raise exceptions - if self.handle.done(): - self.handle.result() - if is_newer_0_5_8: - response = self.session._parse_response(self.curl, _, header_buffer) - response.request = request - else: - response = self.session._parse_response(self.curl, request, _, header_buffer) - return StreamResponse(response, self.queue) - - async def __aenter__(self) -> StreamResponse: - return await self.fetch() - - async def __aexit__(self, *args) -> None: - self.release_curl() - - def release_curl(self) -> None: - if is_newer_0_5_10: - self.session.release_curl(self.curl) - return - if not self.curl: - return - self.curl.clean_after_perform() - if is_newer_0_5_9: - self.session.acurl.remove_handle(self.curl) - elif not self.handle.done() and not self.handle.cancelled(): - self.session.acurl.set_result(self.curl) - self.curl.reset() - self.session.push_curl(self.curl) - self.curl = None - - class StreamSession(AsyncSession): - def request( + @asynccontextmanager + async def request( self, method: str, url: str, **kwargs - ) -> StreamRequest: - return StreamRequest(self, method, url, **kwargs) + ) -> AsyncGenerator[StreamResponse]: + response = await super().request(method, url, stream=True, **kwargs) + try: + yield StreamResponse(response) + finally: + await response.aclose() head = partialmethod(request, "HEAD") get = partialmethod(request, "GET") -- cgit v1.2.3