diff options
author | H Lohaus <hlohaus@users.noreply.github.com> | 2024-05-15 02:28:45 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-15 02:28:45 +0200 |
commit | 008ed60d980efcca18977da93caac14c6526e2bd (patch) | |
tree | 3dcae2aaf84bb7b7e0c44b76e563313e6ad364b1 /g4f | |
parent | gpt-4o (beta) (diff) | |
parent | Update chatgpt url, uvloop support (diff) | |
download | gpt4free-008ed60d980efcca18977da93caac14c6526e2bd.tar gpt4free-008ed60d980efcca18977da93caac14c6526e2bd.tar.gz gpt4free-008ed60d980efcca18977da93caac14c6526e2bd.tar.bz2 gpt4free-008ed60d980efcca18977da93caac14c6526e2bd.tar.lz gpt4free-008ed60d980efcca18977da93caac14c6526e2bd.tar.xz gpt4free-008ed60d980efcca18977da93caac14c6526e2bd.tar.zst gpt4free-008ed60d980efcca18977da93caac14c6526e2bd.zip |
Diffstat (limited to 'g4f')
-rw-r--r-- | g4f/Provider/needs_auth/Gemini.py | 25 | ||||
-rw-r--r-- | g4f/Provider/needs_auth/OpenaiChat.py | 15 | ||||
-rw-r--r-- | g4f/Provider/openai/har_file.py | 2 | ||||
-rw-r--r-- | g4f/Provider/openai/proofofwork.py | 7 | ||||
-rw-r--r-- | g4f/client/async_client.py | 11 | ||||
-rw-r--r-- | g4f/gui/client/index.html | 6 | ||||
-rw-r--r-- | g4f/gui/client/static/css/dracula.min.css | 7 | ||||
-rw-r--r-- | g4f/gui/client/static/css/style.css | 3 | ||||
-rw-r--r-- | g4f/gui/client/static/js/chat.v1.js | 46 | ||||
-rw-r--r-- | g4f/providers/retry_provider.py | 186 |
10 files changed, 222 insertions, 86 deletions
diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py index 75cdd199..25ad1c6e 100644 --- a/g4f/Provider/needs_auth/Gemini.py +++ b/g4f/Provider/needs_auth/Gemini.py @@ -17,12 +17,12 @@ except ImportError: pass from ... import debug -from ...typing import Messages, Cookies, ImageType, AsyncResult +from ...typing import Messages, Cookies, ImageType, AsyncResult, AsyncIterator from ..base_provider import AsyncGeneratorProvider from ..helper import format_prompt, get_cookies from ...requests.raise_for_status import raise_for_status from ...errors import MissingAuthError, MissingRequirementsError -from ...image import to_bytes, ImageResponse +from ...image import to_bytes, to_data_uri, ImageResponse from ...webdriver import get_browser, get_driver_cookies REQUEST_HEADERS = { @@ -59,7 +59,7 @@ class Gemini(AsyncGeneratorProvider): _cookies: Cookies = None @classmethod - async def nodriver_login(cls) -> Cookies: + async def nodriver_login(cls) -> AsyncIterator[str]: try: import nodriver as uc except ImportError: @@ -72,6 +72,9 @@ class Gemini(AsyncGeneratorProvider): if debug.logging: print(f"Open nodriver with user_dir: {user_data_dir}") browser = await uc.start(user_data_dir=user_data_dir) + login_url = os.environ.get("G4F_LOGIN_URL") + if login_url: + yield f"Please login: [Google Gemini]({login_url})\n\n" page = await browser.get(f"{cls.url}/app") await page.select("div.ql-editor.textarea", 240) cookies = {} @@ -79,10 +82,10 @@ class Gemini(AsyncGeneratorProvider): if c.domain.endswith(".google.com"): cookies[c.name] = c.value await page.close() - return cookies + cls._cookies = cookies @classmethod - async def webdriver_login(cls, proxy: str): + async def webdriver_login(cls, proxy: str) -> AsyncIterator[str]: driver = None try: driver = get_browser(proxy=proxy) @@ -131,13 +134,14 @@ class Gemini(AsyncGeneratorProvider): ) as session: snlm0e = await cls.fetch_snlm0e(session, cls._cookies) if cls._cookies else None if not snlm0e: - cls._cookies = await cls.nodriver_login(); + async for chunk in cls.nodriver_login(): + yield chunk if cls._cookies is None: async for chunk in cls.webdriver_login(proxy): yield chunk if not snlm0e: - if "__Secure-1PSID" not in cls._cookies: + if cls._cookies is None or "__Secure-1PSID" not in cls._cookies: raise MissingAuthError('Missing "__Secure-1PSID" cookie') snlm0e = await cls.fetch_snlm0e(session, cls._cookies) if not snlm0e: @@ -193,6 +197,13 @@ class Gemini(AsyncGeneratorProvider): image = fetch.headers["location"] resolved_images.append(image) preview.append(image.replace('=s512', '=s200')) + # preview_url = image.replace('=s512', '=s200') + # async with client.get(preview_url) as fetch: + # preview_data = to_data_uri(await fetch.content.read()) + # async with client.get(image) as fetch: + # data = to_data_uri(await fetch.content.read()) + # preview.append(preview_data) + # resolved_images.append(data) yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview}) def build_request( diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 056a3702..03ea4539 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -38,7 +38,7 @@ DEFAULT_HEADERS = { "accept": "*/*", "accept-encoding": "gzip, deflate, br, zstd", "accept-language": "en-US,en;q=0.5", - "referer": "https://chat.openai.com/", + "referer": "https://chatgpt.com/", "sec-ch-ua": "\"Brave\";v=\"123\", \"Not:A-Brand\";v=\"8\", \"Chromium\";v=\"123\"", "sec-ch-ua-mobile": "?0", "sec-ch-ua-platform": "\"Windows\"", @@ -53,15 +53,15 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): """A class for creating and managing conversations with OpenAI chat service""" label = "OpenAI ChatGPT" - url = "https://chat.openai.com" + url = "https://chatgpt.com" working = True supports_gpt_35_turbo = True supports_gpt_4 = True supports_message_history = True supports_system_message = True default_model = None - default_vision_model = "gpt-4-vision" - models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"] + default_vision_model = "gpt-4o" + models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo", "gpt-4o"] model_aliases = { "text-davinci-002-render-sha": "gpt-3.5-turbo", "": "gpt-3.5-turbo", @@ -442,6 +442,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): try: image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None except Exception as e: + image_request = None if debug.logging: print("OpenaiChat: Upload image failed") print(f"{e.__class__.__name__}: {e}") @@ -601,7 +602,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): this._fetch = this.fetch; this.fetch = async (url, options) => { const response = await this._fetch(url, options); - if (url == "https://chat.openai.com/backend-api/conversation") { + if (url == "https://chatgpt.com/backend-api/conversation") { this._headers = options.headers; return response; } @@ -637,7 +638,7 @@ this.fetch = async (url, options) => { if debug.logging: print(f"Open nodriver with user_dir: {user_data_dir}") browser = await uc.start(user_data_dir=user_data_dir) - page = await browser.get("https://chat.openai.com/") + page = await browser.get("https://chatgpt.com/") await page.select("[id^=headlessui-menu-button-]", 240) api_key = await page.evaluate( "(async () => {" @@ -652,7 +653,7 @@ this.fetch = async (url, options) => { ) cookies = {} for c in await page.browser.cookies.get_all(): - if c.domain.endswith("chat.openai.com"): + if c.domain.endswith("chatgpt.com"): cookies[c.name] = c.value user_agent = await page.evaluate("window.navigator.userAgent") await page.close() diff --git a/g4f/Provider/openai/har_file.py b/g4f/Provider/openai/har_file.py index 6a34c97a..220c20bf 100644 --- a/g4f/Provider/openai/har_file.py +++ b/g4f/Provider/openai/har_file.py @@ -26,7 +26,7 @@ class arkReq: self.userAgent = userAgent arkPreURL = "https://tcr9i.chat.openai.com/fc/gt2/public_key/35536E1E-65B4-4D96-9D97-6ADB7EFF8147" -sessionUrl = "https://chat.openai.com/api/auth/session" +sessionUrl = "https://chatgpt.com/api/auth/session" chatArk: arkReq = None accessToken: str = None cookies: dict = None diff --git a/g4f/Provider/openai/proofofwork.py b/g4f/Provider/openai/proofofwork.py index e44ef6f7..51d96bc4 100644 --- a/g4f/Provider/openai/proofofwork.py +++ b/g4f/Provider/openai/proofofwork.py @@ -16,12 +16,9 @@ def generate_proof_token(required: bool, seed: str, difficulty: str, user_agent: # Get current UTC time now_utc = datetime.now(timezone.utc) - # Convert UTC time to Eastern Time - now_et = now_utc.astimezone(timezone(timedelta(hours=-5))) + parse_time = now_utc.strftime('%a, %d %b %Y %H:%M:%S GMT') - parse_time = now_et.strftime('%a, %d %b %Y %H:%M:%S GMT') - - config = [core + screen, parse_time, 4294705152, 0, user_agent] + config = [core + screen, parse_time, None, 0, user_agent, "https://tcr9i.chat.openai.com/v2/35536E1E-65B4-4D96-9D97-6ADB7EFF8147/api.js","dpl=53d243de46ff04dadd88d293f088c2dd728f126f","en","en-US",442,"pluginsā[object PluginArray]","","alert"] diff_len = len(difficulty) // 2 diff --git a/g4f/client/async_client.py b/g4f/client/async_client.py index 8e1ee33c..07ad3357 100644 --- a/g4f/client/async_client.py +++ b/g4f/client/async_client.py @@ -11,10 +11,9 @@ from .types import AsyncIterResponse, ImageProvider from .image_models import ImageModels from .helper import filter_json, find_stop, filter_none, cast_iter_async from .service import get_last_provider, get_model_and_provider -from ..typing import Union, Iterator, Messages, AsyncIterator, ImageType +from ..typing import Union, Messages, AsyncIterator, ImageType from ..errors import NoImageResponseError from ..image import ImageResponse as ImageProviderResponse -from ..providers.base_provider import AsyncGeneratorProvider try: anext @@ -88,7 +87,7 @@ def create_response( api_key: str = None, **kwargs ): - has_asnyc = isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider) + has_asnyc = hasattr(provider, "create_async_generator") if has_asnyc: create = provider.create_async_generator else: @@ -157,7 +156,7 @@ class Chat(): def __init__(self, client: AsyncClient, provider: ProviderType = None): self.completions = Completions(client, provider) -async def iter_image_response(response: Iterator) -> Union[ImagesResponse, None]: +async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]: async for chunk in response: if isinstance(chunk, ImageProviderResponse): return ImagesResponse([Image(image) for image in chunk.get_list()]) @@ -182,7 +181,7 @@ class Images(): async def generate(self, prompt, model: str = "", **kwargs) -> ImagesResponse: provider = self.models.get(model, self.provider) - if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): + if hasattr(provider, "create_async_generator"): response = create_image(self.client, provider, prompt, **kwargs) else: response = await provider.create_async(prompt) @@ -195,7 +194,7 @@ class Images(): async def create_variation(self, image: ImageType, model: str = None, **kwargs): provider = self.models.get(model, self.provider) result = None - if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): + if hasattr(provider, "create_async_generator"): response = provider.create_async_generator( "", [{"role": "user", "content": "create a image like this"}], diff --git a/g4f/gui/client/index.html b/g4f/gui/client/index.html index 66bcaaab..064e4594 100644 --- a/g4f/gui/client/index.html +++ b/g4f/gui/client/index.html @@ -19,8 +19,7 @@ <script src="/static/js/highlightjs-copy.min.js"></script> <script src="/static/js/chat.v1.js" defer></script> <script src="https://cdn.jsdelivr.net/npm/markdown-it@13.0.1/dist/markdown-it.min.js"></script> - <link rel="stylesheet" - href="//cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.7.0/build/styles/base16/dracula.min.css"> + <link rel="stylesheet" href="/static/css/dracula.min.css"> <script> MathJax = { chtml: { @@ -244,8 +243,5 @@ <div class="mobile-sidebar"> <i class="fa-solid fa-bars"></i> </div> - <script> - </script> </body> - </html> diff --git a/g4f/gui/client/static/css/dracula.min.css b/g4f/gui/client/static/css/dracula.min.css new file mode 100644 index 00000000..729bbbfb --- /dev/null +++ b/g4f/gui/client/static/css/dracula.min.css @@ -0,0 +1,7 @@ +/*! + Theme: Dracula + Author: Mike Barkmin (http://github.com/mikebarkmin) based on Dracula Theme (http://github.com/dracula) + License: ~ MIT (or more permissive) [via base16-schemes-source] + Maintainer: @highlightjs/core-team + Version: 2021.09.0 +*/pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}.hljs{color:#e9e9f4;background:#282936}.hljs ::selection,.hljs::selection{background-color:#4d4f68;color:#e9e9f4}.hljs-comment{color:#626483}.hljs-tag{color:#62d6e8}.hljs-operator,.hljs-punctuation,.hljs-subst{color:#e9e9f4}.hljs-operator{opacity:.7}.hljs-bullet,.hljs-deletion,.hljs-name,.hljs-selector-tag,.hljs-template-variable,.hljs-variable{color:#ea51b2}.hljs-attr,.hljs-link,.hljs-literal,.hljs-number,.hljs-symbol,.hljs-variable.constant_{color:#b45bcf}.hljs-class .hljs-title,.hljs-title,.hljs-title.class_{color:#00f769}.hljs-strong{font-weight:700;color:#00f769}.hljs-addition,.hljs-code,.hljs-string,.hljs-title.class_.inherited__{color:#ebff87}.hljs-built_in,.hljs-doctag,.hljs-keyword.hljs-atrule,.hljs-quote,.hljs-regexp{color:#a1efe4}.hljs-attribute,.hljs-function .hljs-title,.hljs-section,.hljs-title.function_,.ruby .hljs-property{color:#62d6e8}.diff .hljs-meta,.hljs-keyword,.hljs-template-tag,.hljs-type{color:#b45bcf}.hljs-emphasis{color:#b45bcf;font-style:italic}.hljs-meta,.hljs-meta .hljs-keyword,.hljs-meta .hljs-string{color:#00f769}.hljs-meta .hljs-keyword,.hljs-meta-keyword{font-weight:700}
\ No newline at end of file diff --git a/g4f/gui/client/static/css/style.css b/g4f/gui/client/static/css/style.css index 979f9f96..01bc17fa 100644 --- a/g4f/gui/client/static/css/style.css +++ b/g4f/gui/client/static/css/style.css @@ -381,7 +381,8 @@ body { } .message .count .fa-clipboard, -.message .count .fa-volume-high { +.message .count .fa-volume-high, +.message .count .fa-rotate { z-index: 1000; cursor: pointer; } diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js index 23605ed4..a0178e63 100644 --- a/g4f/gui/client/static/js/chat.v1.js +++ b/g4f/gui/client/static/js/chat.v1.js @@ -109,8 +109,9 @@ const register_message_buttons = async () => { let playlist = []; function play_next() { const next = playlist.shift(); - if (next) + if (next && el.dataset.do_play) { next.play(); + } } if (el.dataset.stopped) { el.classList.remove("blink") @@ -179,6 +180,20 @@ const register_message_buttons = async () => { }); } }); + document.querySelectorAll(".message .fa-rotate").forEach(async (el) => { + if (!("click" in el.dataset)) { + el.dataset.click = "true"; + el.addEventListener("click", async () => { + const message_el = el.parentElement.parentElement.parentElement; + el.classList.add("clicked"); + setTimeout(() => el.classList.remove("clicked"), 1000); + prompt_lock = true; + await hide_message(window.conversation_id, message_el.dataset.index); + window.token = message_id(); + await ask_gpt(message_el.dataset.index); + }) + } + }); } const delete_conversations = async () => { @@ -257,9 +272,9 @@ const remove_cancel_button = async () => { }, 300); }; -const prepare_messages = (messages, filter_last_message=true) => { +const prepare_messages = (messages, message_index = -1) => { // Removes none user messages at end - if (filter_last_message) { + if (message_index == -1) { let last_message; while (last_message = messages.pop()) { if (last_message["role"] == "user") { @@ -267,14 +282,16 @@ const prepare_messages = (messages, filter_last_message=true) => { break; } } + } else if (message_index >= 0) { + messages = messages.filter((_, index) => message_index >= index); } // Remove history, if it's selected if (document.getElementById('history')?.checked) { - if (filter_last_message) { - messages = [messages.pop()]; - } else { + if (message_index == null) { messages = [messages.pop(), messages.pop()]; + } else { + messages = [messages.pop()]; } } @@ -361,11 +378,11 @@ imageInput?.addEventListener("click", (e) => { } }); -const ask_gpt = async () => { +const ask_gpt = async (message_index = -1) => { regenerate.classList.add(`regenerate-hidden`); messages = await get_messages(window.conversation_id); total_messages = messages.length; - messages = prepare_messages(messages); + messages = prepare_messages(messages, message_index); stop_generating.classList.remove(`stop_generating-hidden`); @@ -528,6 +545,7 @@ const hide_option = async (conversation_id) => { const span_el = document.createElement("span"); span_el.innerText = input_el.value; span_el.classList.add("convo-title"); + span_el.onclick = () => set_conversation(conversation_id); left_el.removeChild(input_el); left_el.appendChild(span_el); } @@ -616,7 +634,7 @@ const load_conversation = async (conversation_id, scroll=true) => { } if (window.GPTTokenizer_cl100k_base) { - const filtered = prepare_messages(messages, false); + const filtered = prepare_messages(messages, null); if (filtered.length > 0) { last_model = last_model?.startsWith("gpt-4") ? "gpt-4" : "gpt-3.5-turbo" let count_total = GPTTokenizer_cl100k_base?.encodeChat(filtered, last_model).length @@ -683,15 +701,15 @@ async function save_system_message() { await save_conversation(window.conversation_id, conversation); } } - -const hide_last_message = async (conversation_id) => { +const hide_message = async (conversation_id, message_index =- 1) => { const conversation = await get_conversation(conversation_id) - const last_message = conversation.items.pop(); + message_index = message_index == -1 ? conversation.items.length - 1 : message_index + const last_message = message_index in conversation.items ? conversation.items[message_index] : null; if (last_message !== null) { if (last_message["role"] == "assistant") { last_message["regenerate"] = true; } - conversation.items.push(last_message); + conversation.items[message_index] = last_message; } await save_conversation(conversation_id, conversation); }; @@ -790,7 +808,7 @@ document.getElementById("cancelButton").addEventListener("click", async () => { document.getElementById("regenerateButton").addEventListener("click", async () => { prompt_lock = true; - await hide_last_message(window.conversation_id); + await hide_message(window.conversation_id); window.token = message_id(); await ask_gpt(); }); diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py index d64e8471..e2520437 100644 --- a/g4f/providers/retry_provider.py +++ b/g4f/providers/retry_provider.py @@ -3,18 +3,16 @@ from __future__ import annotations import asyncio import random -from ..typing import Type, List, CreateResult, Messages, Iterator +from ..typing import Type, List, CreateResult, Messages, Iterator, AsyncResult from .types import BaseProvider, BaseRetryProvider from .. import debug from ..errors import RetryProviderError, RetryNoProviderError -class RetryProvider(BaseRetryProvider): +class NewBaseRetryProvider(BaseRetryProvider): def __init__( self, providers: List[Type[BaseProvider]], - shuffle: bool = True, - single_provider_retry: bool = False, - max_retries: int = 3, + shuffle: bool = True ) -> None: """ Initialize the BaseRetryProvider. @@ -26,8 +24,6 @@ class RetryProvider(BaseRetryProvider): """ self.providers = providers self.shuffle = shuffle - self.single_provider_retry = single_provider_retry - self.max_retries = max_retries self.working = True self.last_provider: Type[BaseProvider] = None @@ -56,7 +52,146 @@ class RetryProvider(BaseRetryProvider): exceptions = {} started: bool = False + for provider in providers: + self.last_provider = provider + try: + if debug.logging: + print(f"Using {provider.__name__} provider") + for token in provider.create_completion(model, messages, stream, **kwargs): + yield token + started = True + if started: + return + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + if started: + raise e + + raise_exceptions(exceptions) + + async def create_async( + self, + model: str, + messages: Messages, + **kwargs, + ) -> str: + """ + Asynchronously create a completion using available providers. + Args: + model (str): The model to be used for completion. + messages (Messages): The messages to be used for generating completion. + Returns: + str: The result of the asynchronous completion. + Raises: + Exception: Any exception encountered during the asynchronous completion process. + """ + providers = self.providers + if self.shuffle: + random.shuffle(providers) + + exceptions = {} + + for provider in providers: + self.last_provider = provider + try: + if debug.logging: + print(f"Using {provider.__name__} provider") + return await asyncio.wait_for( + provider.create_async(model, messages, **kwargs), + timeout=kwargs.get("timeout", 60), + ) + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + + raise_exceptions(exceptions) + + def get_providers(self, stream: bool): + providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers + if self.shuffle: + random.shuffle(providers) + return providers + + async def create_async_generator( + self, + model: str, + messages: Messages, + stream: bool = True, + **kwargs + ) -> AsyncResult: + exceptions = {} + started: bool = False + + for provider in self.get_providers(stream): + self.last_provider = provider + try: + if debug.logging: + print(f"Using {provider.__name__} provider") + if not stream: + yield await provider.create_async(model, messages, **kwargs) + elif hasattr(provider, "create_async_generator"): + async for token in provider.create_async_generator(model, messages, stream, **kwargs): + yield token + else: + for token in provider.create_completion(model, messages, stream, **kwargs): + yield token + started = True + if started: + return + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + if started: + raise e + + raise_exceptions(exceptions) + +class RetryProvider(NewBaseRetryProvider): + def __init__( + self, + providers: List[Type[BaseProvider]], + shuffle: bool = True, + single_provider_retry: bool = False, + max_retries: int = 3, + ) -> None: + """ + Initialize the BaseRetryProvider. + Args: + providers (List[Type[BaseProvider]]): List of providers to use. + shuffle (bool): Whether to shuffle the providers list. + single_provider_retry (bool): Whether to retry a single provider if it fails. + max_retries (int): Maximum number of retries for a single provider. + """ + super().__init__(providers, shuffle) + self.single_provider_retry = single_provider_retry + self.max_retries = max_retries + + def create_completion( + self, + model: str, + messages: Messages, + stream: bool = False, + **kwargs, + ) -> CreateResult: + """ + Create a completion using available providers, with an option to stream the response. + Args: + model (str): The model to be used for completion. + messages (Messages): The messages to be used for generating completion. + stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False. + Yields: + CreateResult: Tokens or results from the completion. + Raises: + Exception: Any exception encountered during the completion process. + """ + providers = self.get_providers(stream) if self.single_provider_retry and len(providers) == 1: + exceptions = {} + started: bool = False provider = providers[0] self.last_provider = provider for attempt in range(self.max_retries): @@ -74,25 +209,9 @@ class RetryProvider(BaseRetryProvider): print(f"{provider.__name__}: {e.__class__.__name__}: {e}") if started: raise e + raise_exceptions(exceptions) else: - for provider in providers: - self.last_provider = provider - try: - if debug.logging: - print(f"Using {provider.__name__} provider") - for token in provider.create_completion(model, messages, stream, **kwargs): - yield token - started = True - if started: - return - except Exception as e: - exceptions[provider.__name__] = e - if debug.logging: - print(f"{provider.__name__}: {e.__class__.__name__}: {e}") - if started: - raise e - - raise_exceptions(exceptions) + yield from super().create_completion(model, messages, stream, **kwargs) async def create_async( self, @@ -131,22 +250,9 @@ class RetryProvider(BaseRetryProvider): exceptions[provider.__name__] = e if debug.logging: print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + raise_exceptions(exceptions) else: - for provider in providers: - self.last_provider = provider - try: - if debug.logging: - print(f"Using {provider.__name__} provider") - return await asyncio.wait_for( - provider.create_async(model, messages, **kwargs), - timeout=kwargs.get("timeout", 60), - ) - except Exception as e: - exceptions[provider.__name__] = e - if debug.logging: - print(f"{provider.__name__}: {e.__class__.__name__}: {e}") - - raise_exceptions(exceptions) + return await super().create_async(model, messages, **kwargs) class IterProvider(BaseRetryProvider): __name__ = "IterProvider" |