From 75fe95cbedd06d86fb64245f154480b8b731aef8 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Mon, 30 Dec 2024 02:51:36 +0100 Subject: Add continue messages support, Remove old text_to_speech service from gui Update gui and client readmes, Add HuggingSpaces group provider; Add providers parameters config forms to gui --- g4f/Provider/ClaudeSon.py | 11 +-- g4f/Provider/Copilot.py | 69 +++++++++------- g4f/Provider/DDG.py | 1 - g4f/Provider/DeepInfraChat.py | 19 ++--- g4f/Provider/Liaobots.py | 1 - g4f/Provider/ReplicateHome.py | 2 - g4f/Provider/__init__.py | 2 +- g4f/Provider/hf_space/BlackForestLabsFlux1Dev.py | 19 ++++- .../hf_space/BlackForestLabsFlux1Schnell.py | 48 +++++------ g4f/Provider/hf_space/VoodoohopFlux1Schnell.py | 50 +++++------- g4f/Provider/hf_space/__init__.py | 57 +++++++++++++ g4f/Provider/needs_auth/DeepInfra.py | 2 +- g4f/Provider/needs_auth/HuggingChat.py | 6 +- g4f/Provider/needs_auth/HuggingFace.py | 57 +++++++++---- g4f/Provider/needs_auth/OpenaiChat.py | 94 +++++++++++++++------- 15 files changed, 281 insertions(+), 157 deletions(-) (limited to 'g4f/Provider') diff --git a/g4f/Provider/ClaudeSon.py b/g4f/Provider/ClaudeSon.py index 5adc4f38..2dffd24b 100644 --- a/g4f/Provider/ClaudeSon.py +++ b/g4f/Provider/ClaudeSon.py @@ -3,18 +3,15 @@ from __future__ import annotations from aiohttp import ClientSession from ..typing import AsyncResult, Messages +from ..requests.raise_for_status import raise_for_status from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .helper import format_prompt - class ClaudeSon(AsyncGeneratorProvider, ProviderModelMixin): url = "https://claudeson.net" api_endpoint = "https://claudeson.net/api/coze/chat" working = True - - supports_system_message = True - supports_message_history = True - + default_model = 'claude-3.5-sonnet' models = [default_model] @@ -40,7 +37,7 @@ class ClaudeSon(AsyncGeneratorProvider, ProviderModelMixin): "type": "company" } async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response: - response.raise_for_status() + await raise_for_status(response) async for chunk in response.content: if chunk: - yield chunk.decode() + yield chunk.decode(errors="ignore") \ No newline at end of file diff --git a/g4f/Provider/Copilot.py b/g4f/Provider/Copilot.py index 7999c0f4..bc28fd60 100644 --- a/g4f/Provider/Copilot.py +++ b/g4f/Provider/Copilot.py @@ -24,7 +24,7 @@ from .openai.har_file import get_headers, get_har_files from ..typing import CreateResult, Messages, ImagesType from ..errors import MissingRequirementsError, NoValidHarFileError from ..requests.raise_for_status import raise_for_status -from ..providers.response import JsonConversation, RequestLogin +from ..providers.response import BaseConversation, JsonConversation, RequestLogin, Parameters from ..providers.asyncio import get_running_loop from ..requests import get_nodriver from ..image import ImageResponse, to_bytes, is_accepted_format @@ -61,10 +61,12 @@ class Copilot(AbstractProvider, ProviderModelMixin): stream: bool = False, proxy: str = None, timeout: int = 900, + prompt: str = None, images: ImagesType = None, - conversation: Conversation = None, + conversation: BaseConversation = None, return_conversation: bool = False, - web_search: bool = True, + api_key: str = None, + web_search: bool = False, **kwargs ) -> CreateResult: if not has_curl_cffi: @@ -73,6 +75,8 @@ class Copilot(AbstractProvider, ProviderModelMixin): websocket_url = cls.websocket_url headers = None if cls.needs_auth or images is not None: + if api_key is not None: + cls._access_token = api_key if cls._access_token is None: try: cls._access_token, cls._cookies = readHAR(cls.url) @@ -86,7 +90,7 @@ class Copilot(AbstractProvider, ProviderModelMixin): cls._access_token, cls._cookies = asyncio.run(get_access_token_and_cookies(cls.url, proxy)) else: raise h - debug.log(f"Copilot: Access token: {cls._access_token[:7]}...{cls._access_token[-5:]}") + yield Parameters(**{"api_key": cls._access_token, "cookies": cls._cookies}) websocket_url = f"{websocket_url}&accessToken={quote(cls._access_token)}" headers = {"authorization": f"Bearer {cls._access_token}"} @@ -124,13 +128,17 @@ class Copilot(AbstractProvider, ProviderModelMixin): raise_for_status(response) conversation_id = response.json().get("id") if return_conversation: - yield Conversation(conversation_id) - prompt = format_prompt_max_length(messages, 10000) + conversation = Conversation(conversation_id) + yield conversation + if prompt is None: + prompt = format_prompt_max_length(messages, 10000) debug.log(f"Copilot: Created conversation: {conversation_id}") else: conversation_id = conversation.conversation_id - prompt = messages[-1]["content"] + if prompt is None: + prompt = messages[-1]["content"] debug.log(f"Copilot: Use conversation: {conversation_id}") + yield Parameters(**{"conversation": conversation.get_dict(), "user": user, "prompt": prompt}) uploaded_images = [] if images is not None: @@ -166,28 +174,31 @@ class Copilot(AbstractProvider, ProviderModelMixin): msg = None image_prompt: str = None last_msg = None - while True: - try: - msg = wss.recv()[0] - msg = json.loads(msg) - except: - break - last_msg = msg - if msg.get("event") == "appendText": - is_started = True - yield msg.get("text") - elif msg.get("event") == "generatingImage": - image_prompt = msg.get("prompt") - elif msg.get("event") == "imageGenerated": - yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")}) - elif msg.get("event") == "done": - break - elif msg.get("event") == "error": - raise RuntimeError(f"Error: {msg}") - elif msg.get("event") not in ["received", "startMessage", "citation", "partCompleted"]: - debug.log(f"Copilot Message: {msg}") - if not is_started: - raise RuntimeError(f"Invalid response: {last_msg}") + try: + while True: + try: + msg = wss.recv()[0] + msg = json.loads(msg) + except: + break + last_msg = msg + if msg.get("event") == "appendText": + is_started = True + yield msg.get("text") + elif msg.get("event") == "generatingImage": + image_prompt = msg.get("prompt") + elif msg.get("event") == "imageGenerated": + yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")}) + elif msg.get("event") == "done": + break + elif msg.get("event") == "error": + raise RuntimeError(f"Error: {msg}") + elif msg.get("event") not in ["received", "startMessage", "citation", "partCompleted"]: + debug.log(f"Copilot Message: {msg}") + if not is_started: + raise RuntimeError(f"Invalid response: {last_msg}") + finally: + yield Parameters(**{"cookies": {c.name: c.value for c in session.cookies.jar}}) async def get_access_token_and_cookies(url: str, proxy: str = None, target: str = "ChatAI",): browser = await get_nodriver(proxy=proxy, user_data_dir="copilot") diff --git a/g4f/Provider/DDG.py b/g4f/Provider/DDG.py index 71c39bcd..ae418c16 100644 --- a/g4f/Provider/DDG.py +++ b/g4f/Provider/DDG.py @@ -9,7 +9,6 @@ from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, BaseConve from .helper import format_prompt from ..requests.aiohttp import get_connector from ..requests.raise_for_status import raise_for_status -from .. import debug MODELS = [ {"model":"gpt-4o","modelName":"GPT-4o","modelVariant":None,"modelStyleId":"gpt-4o-mini","createdBy":"OpenAI","moderationLevel":"HIGH","isAvailable":1,"inputCharLimit":16e3,"settingId":"4"}, diff --git a/g4f/Provider/DeepInfraChat.py b/g4f/Provider/DeepInfraChat.py index 48b87b9b..e947383d 100644 --- a/g4f/Provider/DeepInfraChat.py +++ b/g4f/Provider/DeepInfraChat.py @@ -1,9 +1,10 @@ from __future__ import annotations -import json +import json from aiohttp import ClientSession from ..typing import AsyncResult, Messages +from ..requests.raise_for_status import raise_for_status from .base_provider import AsyncGeneratorProvider, ProviderModelMixin class DeepInfraChat(AsyncGeneratorProvider, ProviderModelMixin): @@ -14,7 +15,7 @@ class DeepInfraChat(AsyncGeneratorProvider, ProviderModelMixin): supports_stream = True supports_system_message = True supports_message_history = True - + default_model = 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo' models = [ 'meta-llama/Llama-3.3-70B-Instruct', @@ -48,7 +49,7 @@ class DeepInfraChat(AsyncGeneratorProvider, ProviderModelMixin): **kwargs ) -> AsyncResult: model = cls.get_model(model) - + headers = { 'Accept-Language': 'en-US,en;q=0.9', 'Content-Type': 'application/json', @@ -57,31 +58,31 @@ class DeepInfraChat(AsyncGeneratorProvider, ProviderModelMixin): 'X-Deepinfra-Source': 'web-page', 'accept': 'text/event-stream', } - async with ClientSession(headers=headers) as session: data = { "model": model, "messages": messages, "stream": True } - async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response: - response.raise_for_status() + await raise_for_status(response) async for chunk in response.content: if chunk: - chunk_text = chunk.decode() + chunk_text = chunk.decode(errors="ignore") try: # Handle streaming response if chunk_text.startswith("data: "): if chunk_text.strip() == "data: [DONE]": continue chunk_data = json.loads(chunk_text[6:]) - if content := chunk_data["choices"][0]["delta"].get("content"): + content = chunk_data["choices"][0]["delta"].get("content") + if content: yield content # Handle non-streaming response else: chunk_data = json.loads(chunk_text) - if content := chunk_data["choices"][0]["message"].get("content"): + content = chunk_data["choices"][0]["message"].get("content") + if content: yield content except (json.JSONDecodeError, KeyError): continue diff --git a/g4f/Provider/Liaobots.py b/g4f/Provider/Liaobots.py index 1e8131f8..2f032434 100644 --- a/g4f/Provider/Liaobots.py +++ b/g4f/Provider/Liaobots.py @@ -192,7 +192,6 @@ class Liaobots(AsyncGeneratorProvider, ProviderModelMixin): cls, model: str, messages: Messages, - auth: str = None, proxy: str = None, connector: BaseConnector = None, **kwargs diff --git a/g4f/Provider/ReplicateHome.py b/g4f/Provider/ReplicateHome.py index 351c2cf6..0f5452cd 100644 --- a/g4f/Provider/ReplicateHome.py +++ b/g4f/Provider/ReplicateHome.py @@ -16,8 +16,6 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): api_endpoint = "https://homepage.replicate.com/api/prediction" working = True supports_stream = True - supports_system_message = True - supports_message_history = True default_model = 'google-deepmind/gemma-2b-it' default_image_model = 'stability-ai/stable-diffusion-3' diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py index 82e8da06..09a53e35 100644 --- a/g4f/Provider/__init__.py +++ b/g4f/Provider/__init__.py @@ -10,7 +10,7 @@ from .selenium import * from .needs_auth import * from .not_working import * from .local import * -from .hf_space import * +from .hf_space import HuggingSpace from .Airforce import Airforce from .AmigoChat import AmigoChat diff --git a/g4f/Provider/hf_space/BlackForestLabsFlux1Dev.py b/g4f/Provider/hf_space/BlackForestLabsFlux1Dev.py index 7987cc1b..74c9502d 100644 --- a/g4f/Provider/hf_space/BlackForestLabsFlux1Dev.py +++ b/g4f/Provider/hf_space/BlackForestLabsFlux1Dev.py @@ -5,6 +5,7 @@ from aiohttp import ClientSession from ...typing import AsyncResult, Messages from ...image import ImageResponse, ImagePreview +from ...errors import ResponseError from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin): @@ -12,14 +13,24 @@ class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin): api_endpoint = "/gradio_api/call/infer" working = True - + default_model = 'flux-dev' models = [default_model] image_models = [default_model] @classmethod async def create_async_generator( - cls, model: str, messages: Messages, prompt: str = None, api_key: str = None, proxy: str = None, **kwargs + cls, model: str, messages: Messages, + prompt: str = None, + api_key: str = None, + proxy: str = None, + width: int = 1024, + height: int = 1024, + guidance_scale: float = 3.5, + num_inference_steps: int = 28, + seed: int = 0, + randomize_seed: bool = True, + **kwargs ) -> AsyncResult: headers = { "Content-Type": "application/json", @@ -30,7 +41,7 @@ class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin): async with ClientSession(headers=headers) as session: prompt = messages[-1]["content"] if prompt is None else prompt data = { - "data": [prompt, 0, True, 1024, 1024, 3.5, 28] + "data": [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps] } async with session.post(f"{cls.url}{cls.api_endpoint}", json=data, proxy=proxy) as response: response.raise_for_status() @@ -43,7 +54,7 @@ class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin): event = chunk[7:].decode(errors="replace").strip() if chunk.startswith(b"data: "): if event == "error": - raise RuntimeError(f"GPU token limit exceeded: {chunk.decode(errors='replace')}") + raise ResponseError(f"GPU token limit exceeded: {chunk.decode(errors='replace')}") if event in ("complete", "generating"): try: data = json.loads(chunk[6:]) diff --git a/g4f/Provider/hf_space/BlackForestLabsFlux1Schnell.py b/g4f/Provider/hf_space/BlackForestLabsFlux1Schnell.py index 7b29b7af..2dd129d2 100644 --- a/g4f/Provider/hf_space/BlackForestLabsFlux1Schnell.py +++ b/g4f/Provider/hf_space/BlackForestLabsFlux1Schnell.py @@ -2,11 +2,11 @@ from __future__ import annotations from aiohttp import ClientSession import json -import random -from typing import Optional from ...typing import AsyncResult, Messages from ...image import ImageResponse +from ...errors import ResponseError +from ...requests.raise_for_status import raise_for_status from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin class BlackForestLabsFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin): @@ -19,6 +19,7 @@ class BlackForestLabsFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin): default_image_model = default_model image_models = [default_image_model] models = [*image_models] + model_aliases = {"flux-schnell-black-forest-labs": default_model} @classmethod async def create_async_generator( @@ -26,21 +27,21 @@ class BlackForestLabsFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin): model: str, messages: Messages, proxy: str = None, + prompt: str = None, width: int = 768, height: int = 768, num_inference_steps: int = 2, - seed: Optional[int] = None, - randomize_seed: bool = False, + seed: int = 0, + randomize_seed: bool = True, **kwargs ) -> AsyncResult: - if seed is None: - seed = random.randint(0, 10000) - + width = max(32, width - (width % 8)) height = max(32, height - (height % 8)) - - prompt = messages[-1]["content"] - + + if prompt is None: + prompt = messages[-1]["content"] + payload = { "data": [ prompt, @@ -51,31 +52,26 @@ class BlackForestLabsFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin): num_inference_steps ] } - async with ClientSession() as session: async with session.post(cls.api_endpoint, json=payload, proxy=proxy) as response: - response.raise_for_status() + await raise_for_status(response) response_data = await response.json() event_id = response_data['event_id'] - while True: async with session.get(f"{cls.api_endpoint}/{event_id}", proxy=proxy) as status_response: - status_response.raise_for_status() - events = (await status_response.text()).split('\n\n') - - for event in events: - if event.startswith('event:'): - event_parts = event.split('\ndata: ') + await raise_for_status(status_response) + while not status_response.content.at_eof(): + event = await status_response.content.readuntil(b'\n\n') + if event.startswith(b'event:'): + event_parts = event.split(b'\ndata: ') if len(event_parts) < 2: continue - - event_type = event_parts[0].split(': ')[1] + event_type = event_parts[0].split(b': ')[1] data = event_parts[1] - - if event_type == 'error': - raise Exception(f"Error generating image: {data}") - elif event_type == 'complete': + if event_type == b'error': + raise ResponseError(f"Error generating image: {data}") + elif event_type == b'complete': json_data = json.loads(data) image_url = json_data[0]['url'] yield ImageResponse(images=[image_url], alt=prompt) - return + return \ No newline at end of file diff --git a/g4f/Provider/hf_space/VoodoohopFlux1Schnell.py b/g4f/Provider/hf_space/VoodoohopFlux1Schnell.py index bd55b20b..c1778d94 100644 --- a/g4f/Provider/hf_space/VoodoohopFlux1Schnell.py +++ b/g4f/Provider/hf_space/VoodoohopFlux1Schnell.py @@ -2,23 +2,23 @@ from __future__ import annotations from aiohttp import ClientSession import json -import random -from typing import Optional from ...typing import AsyncResult, Messages from ...image import ImageResponse +from ...errors import ResponseError +from ...requests.raise_for_status import raise_for_status from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin class VoodoohopFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin): url = "https://voodoohop-flux-1-schnell.hf.space" api_endpoint = "https://voodoohop-flux-1-schnell.hf.space/call/infer" - working = True - + default_model = "flux-schnell" default_image_model = default_model image_models = [default_image_model] models = [*image_models] + model_aliases = {"flux-schnell-voodoohop": default_model} @classmethod async def create_async_generator( @@ -26,21 +26,20 @@ class VoodoohopFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin): model: str, messages: Messages, proxy: str = None, + prompt: str = None, width: int = 768, height: int = 768, num_inference_steps: int = 2, - seed: Optional[int] = None, - randomize_seed: bool = False, + seed: int = 0, + randomize_seed: bool = True, **kwargs ) -> AsyncResult: - if seed is None: - seed = random.randint(0, 10000) - width = max(32, width - (width % 8)) height = max(32, height - (height % 8)) - - prompt = messages[-1]["content"] - + + if prompt is None: + prompt = messages[-1]["content"] + payload = { "data": [ prompt, @@ -51,31 +50,26 @@ class VoodoohopFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin): num_inference_steps ] } - async with ClientSession() as session: async with session.post(cls.api_endpoint, json=payload, proxy=proxy) as response: - response.raise_for_status() + await raise_for_status(response) response_data = await response.json() event_id = response_data['event_id'] - while True: async with session.get(f"{cls.api_endpoint}/{event_id}", proxy=proxy) as status_response: - status_response.raise_for_status() - events = (await status_response.text()).split('\n\n') - - for event in events: - if event.startswith('event:'): - event_parts = event.split('\ndata: ') + await raise_for_status(status_response) + while not status_response.content.at_eof(): + event = await status_response.content.readuntil(b'\n\n') + if event.startswith(b'event:'): + event_parts = event.split(b'\ndata: ') if len(event_parts) < 2: continue - - event_type = event_parts[0].split(': ')[1] + event_type = event_parts[0].split(b': ')[1] data = event_parts[1] - - if event_type == 'error': - raise Exception(f"Error generating image: {data}") - elif event_type == 'complete': + if event_type == b'error': + raise ResponseError(f"Error generating image: {data}") + elif event_type == b'complete': json_data = json.loads(data) image_url = json_data[0]['url'] yield ImageResponse(images=[image_url], alt=prompt) - return + return \ No newline at end of file diff --git a/g4f/Provider/hf_space/__init__.py b/g4f/Provider/hf_space/__init__.py index 94524e35..87dfb32b 100644 --- a/g4f/Provider/hf_space/__init__.py +++ b/g4f/Provider/hf_space/__init__.py @@ -1,3 +1,60 @@ +from __future__ import annotations + +from ...typing import AsyncResult, Messages +from ...errors import ResponseError +from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin + from .BlackForestLabsFlux1Dev import BlackForestLabsFlux1Dev from .BlackForestLabsFlux1Schnell import BlackForestLabsFlux1Schnell from .VoodoohopFlux1Schnell import VoodoohopFlux1Schnell + +class HuggingSpace(AsyncGeneratorProvider, ProviderModelMixin): + url = "https://huggingface.co/spaces" + working = True + default_model = BlackForestLabsFlux1Dev.default_model + providers = [BlackForestLabsFlux1Dev, BlackForestLabsFlux1Schnell, VoodoohopFlux1Schnell] + + @classmethod + def get_parameters(cls, **kwargs) -> dict: + parameters = {} + for provider in cls.providers: + parameters = {**parameters, **provider.get_parameters(**kwargs)} + return parameters + + @classmethod + def get_models(cls, **kwargs) -> list[str]: + if not cls.models: + for provider in cls.providers: + cls.models.extend(provider.get_models(**kwargs)) + cls.models.extend(provider.model_aliases.keys()) + cls.models = list(set(cls.models)) + cls.models.sort() + return cls.models + + @classmethod + async def create_async_generator( + cls, model: str, messages: Messages, **kwargs + ) -> AsyncResult: + is_started = False + for provider in cls.providers: + if model in provider.model_aliases: + async for chunk in provider.create_async_generator(provider.model_aliases[model], messages, **kwargs): + is_started = True + yield chunk + if is_started: + return + error = None + for provider in cls.providers: + if model in provider.get_models(): + try: + async for chunk in provider.create_async_generator(model, messages, **kwargs): + is_started = True + yield chunk + if is_started: + break + except ResponseError as e: + if is_started: + raise e + error = e + if not is_started and error is not None: + raise error \ No newline at end of file diff --git a/g4f/Provider/needs_auth/DeepInfra.py b/g4f/Provider/needs_auth/DeepInfra.py index 3b5b6227..c5ebac1e 100644 --- a/g4f/Provider/needs_auth/DeepInfra.py +++ b/g4f/Provider/needs_auth/DeepInfra.py @@ -54,4 +54,4 @@ class DeepInfra(OpenaiAPI): max_tokens=max_tokens, headers=headers, **kwargs - ) + ) \ No newline at end of file diff --git a/g4f/Provider/needs_auth/HuggingChat.py b/g4f/Provider/needs_auth/HuggingChat.py index c261595b..9557a6f9 100644 --- a/g4f/Provider/needs_auth/HuggingChat.py +++ b/g4f/Provider/needs_auth/HuggingChat.py @@ -32,7 +32,8 @@ class HuggingChat(AbstractProvider, ProviderModelMixin): default_model = "Qwen/Qwen2.5-72B-Instruct" default_image_model = "black-forest-labs/FLUX.1-dev" image_models = [ - "black-forest-labs/FLUX.1-dev" + "black-forest-labs/FLUX.1-dev", + "black-forest-labs/FLUX.1-schnell", ] models = [ default_model, @@ -59,9 +60,10 @@ class HuggingChat(AbstractProvider, ProviderModelMixin): "hermes-3": "NousResearch/Hermes-3-Llama-3.1-8B", "mistral-nemo": "mistralai/Mistral-Nemo-Instruct-2407", "phi-3.5-mini": "microsoft/Phi-3.5-mini-instruct", - + ### Image ### "flux-dev": "black-forest-labs/FLUX.1-dev", + "flux-schnell": "black-forest-labs/FLUX.1-schnell", } @classmethod diff --git a/g4f/Provider/needs_auth/HuggingFace.py b/g4f/Provider/needs_auth/HuggingFace.py index 19f33fd0..f77b68cb 100644 --- a/g4f/Provider/needs_auth/HuggingFace.py +++ b/g4f/Provider/needs_auth/HuggingFace.py @@ -9,7 +9,9 @@ from ...typing import AsyncResult, Messages from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_prompt from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError from ...requests import StreamSession, raise_for_status +from ...providers.response import FinishReason from ...image import ImageResponse +from ... import debug from .HuggingChat import HuggingChat @@ -48,6 +50,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): max_new_tokens: int = 1024, temperature: float = 0.7, prompt: str = None, + action: str = None, extra_data: dict = {}, **kwargs ) -> AsyncResult: @@ -85,6 +88,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): "temperature": temperature, **extra_data } + do_continue = action == "continue" async with StreamSession( headers=headers, proxy=proxy, @@ -97,20 +101,23 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): model_type = None if "config" in model_data and "model_type" in model_data["config"]: model_type = model_data["config"]["model_type"] + debug.log(f"Model type: {model_type}") if model_type in ("gpt2", "gpt_neo", "gemma", "gemma2"): - inputs = format_prompt(messages) + inputs = format_prompt(messages, do_continue=do_continue) + elif model_type in ("mistral"): + inputs = format_prompt_mistral(messages, do_continue) elif "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]: eos_token = model_data["config"]["tokenizer_config"]["eos_token"] if eos_token in ("<|endoftext|>", "", ""): - inputs = format_prompt_custom(messages, eos_token) + inputs = format_prompt_custom(messages, eos_token, do_continue) elif eos_token == "<|im_end|>": - inputs = format_prompt_qwen(messages) + inputs = format_prompt_qwen(messages, do_continue) elif eos_token == "<|eot_id|>": - inputs = format_prompt_llama(messages) + inputs = format_prompt_llama(messages, do_continue) else: - inputs = format_prompt_default(messages) + inputs = format_prompt(messages, do_continue=do_continue) else: - inputs = format_prompt_default(messages) + inputs = format_prompt(messages, do_continue=do_continue) if model_type == "gpt2" and max_new_tokens >= 1024: params["max_new_tokens"] = 512 payload = {"inputs": inputs, "parameters": params, "stream": stream} @@ -121,6 +128,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): await raise_for_status(response) if stream: first = True + is_special = False async for line in response.iter_lines(): if line.startswith(b"data:"): data = json.loads(line[5:]) @@ -128,11 +136,15 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): raise ResponseError(data["error"]) if not data["token"]["special"]: chunk = data["token"]["text"] - if first: + if first and not do_continue: first = False chunk = chunk.lstrip() if chunk: yield chunk + else: + is_special = True + debug.log(f"Special token: {is_special}") + yield FinishReason("stop" if is_special else "max_tokens", actions=["variant"] if is_special else ["continue", "variant"]) else: if response.headers["content-type"].startswith("image/"): base64_data = base64.b64encode(b"".join([chunk async for chunk in response.iter_content()])) @@ -141,7 +153,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): else: yield (await response.json())[0]["generated_text"].strip() -def format_prompt_default(messages: Messages) -> str: +def format_prompt_mistral(messages: Messages, do_continue: bool = False) -> str: system_messages = [message["content"] for message in messages if message["role"] == "system"] question = " ".join([messages[-1]["content"], *system_messages]) history = "".join([ @@ -149,19 +161,30 @@ def format_prompt_default(messages: Messages) -> str: for idx, message in enumerate(messages) if message["role"] == "assistant" ]) + if do_continue: + return history[:-len('')] return f"{history}[INST] {question} [/INST]" -def format_prompt_qwen(messages: Messages) -> str: - return "".join([ +def format_prompt_qwen(messages: Messages, do_continue: bool = False) -> str: + prompt = "".join([ f"<|im_start|>{message['role']}\n{message['content']}\n<|im_end|>\n" for message in messages - ]) + "<|im_start|>assistant\n" + ]) + ("" if do_continue else "<|im_start|>assistant\n") + if do_continue: + return prompt[:-len("\n<|im_end|>\n")] + return prompt -def format_prompt_llama(messages: Messages) -> str: - return "<|begin_of_text|>" + "".join([ +def format_prompt_llama(messages: Messages, do_continue: bool = False) -> str: + prompt = "<|begin_of_text|>" + "".join([ f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}\n<|eot_id|>\n" for message in messages - ]) + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ]) + ("" if do_continue else "<|start_header_id|>assistant<|end_header_id|>\n\n") + if do_continue: + return prompt[:-len("\n<|eot_id|>\n")] + return prompt -def format_prompt_custom(messages: Messages, end_token: str = "") -> str: - return "".join([ +def format_prompt_custom(messages: Messages, end_token: str = "", do_continue: bool = False) -> str: + prompt = "".join([ f"<|{message['role']}|>\n{message['content']}{end_token}\n" for message in messages - ]) + "<|assistant|>\n" \ No newline at end of file + ]) + ("" if do_continue else "<|assistant|>\n") + if do_continue: + return prompt[:-len(end_token + "\n")] + return prompt \ No newline at end of file diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 9c869fef..652cbfc2 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -7,7 +7,6 @@ import uuid import json import base64 import time -import requests import random from typing import AsyncIterator, Iterator, Optional, Generator, Dict, List from copy import copy @@ -21,11 +20,12 @@ except ImportError: from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ...typing import AsyncResult, Messages, Cookies, ImagesType from ...requests.raise_for_status import raise_for_status -from ...requests import StreamSession, Session +from ...requests import StreamSession from ...requests import get_nodriver from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format from ...errors import MissingAuthError, NoValidHarFileError -from ...providers.response import JsonConversation, FinishReason, SynthesizeData, Sources, TitleGeneration, RequestLogin, quote_url +from ...providers.response import JsonConversation, FinishReason, SynthesizeData +from ...providers.response import Sources, TitleGeneration, RequestLogin, Parameters from ..helper import format_cookies from ..openai.har_file import get_request_config from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url @@ -272,13 +272,11 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): messages: Messages, proxy: str = None, timeout: int = 180, - cookies: Cookies = None, auto_continue: bool = False, history_disabled: bool = False, action: str = "next", conversation_id: str = None, conversation: Conversation = None, - parent_id: str = None, images: ImagesType = None, return_conversation: bool = False, max_retries: int = 3, @@ -294,12 +292,10 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): proxy (str): Proxy to use for requests. timeout (int): Timeout for requests. api_key (str): Access token for authentication. - cookies (dict): Cookies to use for authentication. auto_continue (bool): Flag to automatically continue the conversation. history_disabled (bool): Flag to disable history and training. action (str): Type of action ('next', 'continue', 'variant'). conversation_id (str): ID of the conversation. - parent_id (str): ID of the parent message. images (ImagesType): Images to include in the conversation. return_conversation (bool): Flag to include response fields in the output. **kwargs: Additional keyword arguments. @@ -311,7 +307,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): RuntimeError: If an error occurs during processing. """ if cls.needs_auth: - async for message in cls.login(proxy): + async for message in cls.login(proxy, **kwargs): yield message async with StreamSession( proxy=proxy, @@ -321,11 +317,12 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): image_requests = None if not cls.needs_auth: if cls._headers is None: - cls._create_request_args(cookies) + cls._create_request_args(cls._cookies) async with session.get(cls.url, headers=INIT_HEADERS) as response: cls._update_request_args(session) await raise_for_status(response) else: + print(cls._headers) async with session.get(cls.url, headers=cls._headers) as response: cls._update_request_args(session) await raise_for_status(response) @@ -336,7 +333,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): debug.log(f"{e.__class__.__name__}: {e}") model = cls.get_model(model) if conversation is None: - conversation = Conversation(conversation_id, str(uuid.uuid4()) if parent_id is None else parent_id) + conversation = Conversation(conversation_id, str(uuid.uuid4())) else: conversation = copy(conversation) if cls._api_key is None: @@ -363,7 +360,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): if need_arkose and RequestConfig.arkose_token is None: await get_request_config(proxy) - cls._create_request_args(RequestConfig,cookies, RequestConfig.headers) + cls._create_request_args(RequestConfig.cookies, RequestConfig.headers) cls._set_api_key(RequestConfig.access_token) if RequestConfig.arkose_token is None: raise MissingAuthError("No arkose token found in .har file") @@ -407,6 +404,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): data["conversation_id"] = conversation.conversation_id debug.log(f"OpenaiChat: Use conversation: {conversation.conversation_id}") if action != "continue": + data["parent_message_id"] = conversation.parent_message_id + conversation.parent_message_id = None messages = messages if conversation_id is None else [messages[-1]] data["messages"] = cls.create_messages(messages, image_requests, ["search"] if web_search else None) headers = { @@ -475,7 +474,16 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): await asyncio.sleep(5) else: break - yield FinishReason(conversation.finish_reason) + yield Parameters(**{ + "action": "continue" if conversation.finish_reason == "max_tokens" else "variant", + "conversation": conversation.get_dict(), + "proof_token": RequestConfig.proof_token, + "cookies": cls._cookies, + "headers": cls._headers, + "web_search": web_search, + }) + actions = ["variant", "continue"] if conversation.finish_reason == "max_tokens" else ["variant"] + yield FinishReason(conversation.finish_reason, actions=actions) @classmethod async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: Conversation, sources: Sources) -> AsyncIterator: @@ -530,6 +538,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): if image_response is not None: yield image_response if m.get("author", {}).get("role") == "assistant": + if fields.parent_message_id is None: + fields.parent_message_id = v.get("message", {}).get("id") fields.message_id = v.get("message", {}).get("id") return if "error" in line and line.get("error"): @@ -553,24 +563,49 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): yield chunk @classmethod - async def login(cls, proxy: str = None) -> AsyncIterator[str]: - if cls._expires is not None and cls._expires < time.time(): + async def login( + cls, + proxy: str = None, + api_key: str = None, + proof_token: str = None, + cookies: Cookies = None, + headers: dict = None, + **kwargs + ) -> AsyncIterator[str]: + if cls._expires is not None and (cls._expires - 60*10) < time.time(): cls._headers = cls._api_key = None - try: - await get_request_config(proxy) + if cls._headers is None or headers is not None: + cls._headers = {} if headers is None else headers + if proof_token is not None: + RequestConfig.proof_token = proof_token + if cookies is not None: + RequestConfig.cookies = cookies + if api_key is not None: cls._create_request_args(RequestConfig.cookies, RequestConfig.headers) - if RequestConfig.access_token is not None or cls.needs_auth: - if not cls._set_api_key(RequestConfig.access_token): - raise NoValidHarFileError(f"Access token is not valid: {RequestConfig.access_token}") - except NoValidHarFileError: - if has_nodriver: - if cls._api_key is None: - login_url = os.environ.get("G4F_LOGIN_URL") - if login_url: - yield RequestLogin(cls.label, login_url) - await cls.nodriver_auth(proxy) - else: - raise + cls._set_api_key(api_key) + else: + try: + await get_request_config(proxy) + cls._create_request_args(RequestConfig.cookies, RequestConfig.headers) + print(RequestConfig.access_token) + if RequestConfig.access_token is not None or cls.needs_auth: + if not cls._set_api_key(RequestConfig.access_token): + raise NoValidHarFileError(f"Access token is not valid: {RequestConfig.access_token}") + except NoValidHarFileError: + if has_nodriver: + if cls._api_key is None: + login_url = os.environ.get("G4F_LOGIN_URL") + if login_url: + yield RequestLogin(cls.label, login_url) + await cls.nodriver_auth(proxy) + else: + raise + yield Parameters(**{ + "api_key": cls._api_key, + "proof_token": RequestConfig.proof_token, + "cookies": RequestConfig.cookies, + "headers": RequestConfig.headers + }) @classmethod async def nodriver_auth(cls, proxy: str = None): @@ -666,11 +701,12 @@ class Conversation(JsonConversation): """ Class to encapsulate response fields. """ - def __init__(self, conversation_id: str = None, message_id: str = None, finish_reason: str = None): + def __init__(self, conversation_id: str = None, message_id: str = None, finish_reason: str = None, parent_message_id: str = None): self.conversation_id = conversation_id self.message_id = message_id self.finish_reason = finish_reason self.is_recipient = False + self.parent_message_id = message_id if parent_message_id is None else parent_message_id def get_cookies( urls: Optional[Iterator[str]] = None -- cgit v1.2.3