From 6e0bc147b52cb1e52d7fb3f8dd01d1f33dae201e Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Fri, 3 Jan 2025 20:35:46 +0100 Subject: Support continue messages in Airforce Add auth caching for OpenAI ChatGPT Some provider improvments --- g4f/providers/base_provider.py | 203 ++++++++++++++++++++--------------------- 1 file changed, 100 insertions(+), 103 deletions(-) (limited to 'g4f/providers/base_provider.py') diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py index ec0f354b..82d1ea61 100644 --- a/g4f/providers/base_provider.py +++ b/g4f/providers/base_provider.py @@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor from abc import abstractmethod import json from inspect import signature, Parameter -from typing import Optional, Awaitable, _GenericAlias +from typing import Optional, _GenericAlias from pathlib import Path try: from types import NoneType @@ -16,11 +16,11 @@ except ImportError: from ..typing import CreateResult, AsyncResult, Messages from .types import BaseProvider -from .asyncio import get_running_loop, to_sync_generator +from .asyncio import get_running_loop, to_sync_generator, to_async_iterator from .response import BaseConversation, AuthResult from .helper import concat_chunks, async_concat_chunks from ..cookies import get_cookies_dir -from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError +from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError, NoValidHarFileError from .. import debug SAFE_PARAMETERS = [ @@ -31,7 +31,7 @@ SAFE_PARAMETERS = [ "temperature", "top_k", "top_p", "frequency_penalty", "presence_penalty", "max_tokens", "max_new_tokens", "stop", - "api_key", "seed", "width", "height", + "api_key", "api_base", "seed", "width", "height", "proof_token", "max_retries" ] @@ -63,9 +63,29 @@ PARAMETER_EXAMPLES = { } class AbstractProvider(BaseProvider): - """ - Abstract class for providing asynchronous functionality to derived classes. - """ + + @classmethod + @abstractmethod + def create_completion( + cls, + model: str, + messages: Messages, + stream: bool, + **kwargs + ) -> CreateResult: + """ + Create a completion with the given parameters. + + Args: + model (str): The model to use. + messages (Messages): The messages to process. + stream (bool): Whether to use streaming. + **kwargs: Additional keyword arguments. + + Returns: + CreateResult: The result of the creation process. + """ + raise NotImplementedError() @classmethod async def create_async( @@ -92,16 +112,24 @@ class AbstractProvider(BaseProvider): Returns: str: The created result as a string. """ - loop = loop or asyncio.get_running_loop() + loop = asyncio.get_running_loop() if loop is None else loop def create_func() -> str: - return concat_chunks(cls.create_completion(model, messages, False, **kwargs)) + return concat_chunks(cls.create_completion(model, messages, **kwargs)) return await asyncio.wait_for( loop.run_in_executor(executor, create_func), timeout=timeout ) - + + @classmethod + def get_create_function(cls) -> callable: + return cls.create_completion + + @classmethod + def get_async_create_function(cls) -> callable: + return cls.create_async + @classmethod def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]: params = {name: parameter for name, parameter in signature( @@ -149,7 +177,7 @@ class AbstractProvider(BaseProvider): ) for name, param in { **BASIC_PARAMETERS, **params, - **{"provider": cls.__name__, "stream": cls.supports_stream, "model": getattr(cls, "default_model", "")}, + **{"provider": cls.__name__, "model": getattr(cls, "default_model", ""), "stream": cls.supports_stream}, }.items()} return params @@ -233,6 +261,14 @@ class AsyncProvider(AbstractProvider): """ raise NotImplementedError() + @classmethod + def get_create_function(cls) -> callable: + return cls.create_completion + + @classmethod + def get_async_create_function(cls) -> callable: + return cls.create_async + class AsyncGeneratorProvider(AsyncProvider): """ Provides asynchronous generator functionality for streaming results. @@ -262,30 +298,10 @@ class AsyncGeneratorProvider(AsyncProvider): CreateResult: The result of the streaming completion creation. """ return to_sync_generator( - cls.create_async_generator(model, messages, stream=stream, **kwargs) + cls.create_async_generator(model, messages, stream=stream, **kwargs), + stream=stream ) - @classmethod - async def create_async( - cls, - model: str, - messages: Messages, - **kwargs - ) -> str: - """ - Asynchronously creates a result from a generator. - - Args: - cls (type): The class on which this method is called. - model (str): The model to use for creation. - messages (Messages): The messages to process. - **kwargs: Additional keyword arguments. - - Returns: - str: The created result as a string. - """ - return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs)) - @staticmethod @abstractmethod async def create_async_generator( @@ -311,11 +327,13 @@ class AsyncGeneratorProvider(AsyncProvider): """ raise NotImplementedError() - create_authed = create_completion - - create_authed_async = create_async + @classmethod + def get_create_function(cls) -> callable: + return cls.create_completion - create_async_authed = create_async_generator + @classmethod + def get_async_create_function(cls) -> callable: + return cls.create_async_generator class ProviderModelMixin: default_model: str = None @@ -357,97 +375,76 @@ class RaiseErrorMixin(): else: raise ResponseError(data["error"]) -class AuthedMixin(): +class AsyncAuthedProvider(AsyncGeneratorProvider): @classmethod - def on_auth(cls, **kwargs) -> Optional[AuthResult]: + async def on_auth_async(cls, **kwargs) -> AuthResult: if "api_key" not in kwargs: raise MissingAuthError(f"API key is required for {cls.__name__}") - return None + return AuthResult() @classmethod - def create_authed( - cls, - model: str, - messages: Messages, - **kwargs - ) -> CreateResult: - auth_result = {} - cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json" - if cache_file.exists(): - with cache_file.open("r") as f: - auth_result = json.load(f) - return cls.create_completion(model, messages, **kwargs, **auth_result) - auth_result = cls.on_auth(**kwargs) - try: - return cls.create_completion(model, messages, **kwargs) - finally: - cache_file.parent.mkdir(parents=True, exist_ok=True) - cache_file.write_text(json.dumps(auth_result.get_dict())) + def on_auth(cls, **kwargs) -> AuthResult: + return asyncio.run(cls.on_auth_async(**kwargs)) -class AsyncAuthedMixin(AuthedMixin): @classmethod - async def create_async_authed( - cls, - model: str, - messages: Messages, - **kwargs - ) -> str: - cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json" - if cache_file.exists(): - auth_result = {} - with cache_file.open("r") as f: - auth_result = json.load(f) - return cls.create_completion(model, messages, **kwargs, **auth_result) - auth_result = cls.on_auth(**kwargs) - try: - return await cls.create_async(model, messages, **kwargs) - finally: - if auth_result is not None: - cache_file.parent.mkdir(parents=True, exist_ok=True) - cache_file.write_text(json.dumps(auth_result.get_dict())) + def get_create_function(cls) -> callable: + return cls.create_completion -class AsyncAuthedGeneratorMixin(AsyncAuthedMixin): + @classmethod + def get_async_create_function(cls) -> callable: + return cls.create_async_generator @classmethod - async def create_async_authed( + def create_completion( cls, model: str, messages: Messages, **kwargs - ) -> str: - cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json" - if cache_file.exists(): - auth_result = {} - with cache_file.open("r") as f: - auth_result = json.load(f) - return cls.create_completion(model, messages, **kwargs, **auth_result) - auth_result = cls.on_auth(**kwargs) + ) -> CreateResult: try: - return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs)) + auth_result = AuthResult() + cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json" + if cache_file.exists(): + with cache_file.open("r") as f: + auth_result = AuthResult(**json.load(f)) + else: + auth_result = cls.on_auth(**kwargs) + return to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs)) + except (MissingAuthError, NoValidHarFileError): + if cache_file.exists(): + cache_file.unlink() + auth_result = cls.on_auth(**kwargs) + return to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs)) finally: - if auth_result is not None: cache_file.parent.mkdir(parents=True, exist_ok=True) cache_file.write_text(json.dumps(auth_result.get_dict())) - @classmethod - def create_async_authed_generator( + async def create_async_generator( cls, model: str, messages: Messages, - stream: bool = True, **kwargs - ) -> Awaitable[AsyncResult]: - cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json" - if cache_file.exists(): - auth_result = {} - with cache_file.open("r") as f: - auth_result = json.load(f) - return cls.create_completion(model, messages, **kwargs, **auth_result) - auth_result = cls.on_auth(**kwargs) + ) -> AsyncResult: try: - return cls.create_async_generator(model, messages, stream=stream, **kwargs) + auth_result = AuthResult() + cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json" + if cache_file.exists(): + with cache_file.open("r") as f: + auth_result = AuthResult(**json.load(f)) + else: + auth_result = await cls.on_auth_async(**kwargs) + response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result)) + async for chunk in response: + yield chunk + except (MissingAuthError, NoValidHarFileError): + if cache_file.exists(): + cache_file.unlink() + auth_result = await cls.on_auth_async(**kwargs) + response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result)) + async for chunk in response: + yield chunk finally: if auth_result is not None: cache_file.parent.mkdir(parents=True, exist_ok=True) - cache_file.write_text(json.dumps(auth_result.get_dict())) + cache_file.write_text(json.dumps(auth_result.get_dict())) \ No newline at end of file -- cgit v1.2.3