From ff66df14867807c2b888efe0d7bed4eccf49786b Mon Sep 17 00:00:00 2001
From: Heiner Lohaus
Date: Sun, 15 Dec 2024 23:22:36 +0100
Subject: Improved ignored providers support, Add get_models to OpenaiAPI,
 HuggingFace and Groq

Add xAI provider
---
 g4f/Provider/needs_auth/Cerebras.py       | 23 ++++----------------
 g4f/Provider/needs_auth/Groq.py           |  7 +++---
 g4f/Provider/needs_auth/HuggingFace.py    | 36 ++++++++++++++++++++-----------
 g4f/Provider/needs_auth/HuggingFaceAPI.py |  3 ++-
 g4f/Provider/needs_auth/OpenaiAPI.py      | 32 +++++++++++++++++++++++----
 g4f/Provider/needs_auth/__init__.py       |  1 +
 g4f/Provider/needs_auth/xAI.py            | 22 +++++++++++++++++++
 7 files changed, 85 insertions(+), 39 deletions(-)
 create mode 100644 g4f/Provider/needs_auth/xAI.py

(limited to 'g4f/Provider/needs_auth')

diff --git a/g4f/Provider/needs_auth/Cerebras.py b/g4f/Provider/needs_auth/Cerebras.py
index 0f94c476..df34db0e 100644
--- a/g4f/Provider/needs_auth/Cerebras.py
+++ b/g4f/Provider/needs_auth/Cerebras.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import requests
 from aiohttp import ClientSession
 
 from .OpenaiAPI import OpenaiAPI
@@ -11,35 +10,21 @@ from ...cookies import get_cookies
 class Cerebras(OpenaiAPI):
     label = "Cerebras Inference"
     url = "https://inference.cerebras.ai/"
+    api_base = "https://api.cerebras.ai/v1"
     working = True
     default_model = "llama3.1-70b"
-    fallback_models = [
+    models = [
         "llama3.1-70b",
         "llama3.1-8b",
     ]
     model_aliases = {"llama-3.1-70b": "llama3.1-70b", "llama-3.1-8b": "llama3.1-8b"}
 
-    @classmethod
-    def get_models(cls, api_key: str = None):
-        if not cls.models:
-            try:
-                headers = {}
-                if api_key:
-                    headers["authorization"] = f"Bearer ${api_key}"
-                response = requests.get(f"https://api.cerebras.ai/v1/models", headers=headers)
-                raise_for_status(response)
-                data = response.json()
-                cls.models = [model.get("model") for model in data.get("models")]
-            except Exception:
-                cls.models = cls.fallback_models
-        return cls.models
-
     @classmethod
     async def create_async_generator(
         cls,
         model: str,
         messages: Messages,
-        api_base: str = "https://api.cerebras.ai/v1",
+        api_base: str = api_base,
         api_key: str = None,
         cookies: Cookies = None,
         **kwargs
@@ -62,4 +47,4 @@ class Cerebras(OpenaiAPI):
             },
             **kwargs
         ):
-            yield chunk
+            yield chunk
\ No newline at end of file
diff --git a/g4f/Provider/needs_auth/Groq.py b/g4f/Provider/needs_auth/Groq.py
index 943fc81a..e9f3fad9 100644
--- a/g4f/Provider/needs_auth/Groq.py
+++ b/g4f/Provider/needs_auth/Groq.py
@@ -6,9 +6,10 @@ from ...typing import AsyncResult, Messages
 class Groq(OpenaiAPI):
     label = "Groq"
     url = "https://console.groq.com/playground"
+    api_base = "https://api.groq.com/openai/v1"
     working = True
     default_model = "mixtral-8x7b-32768"
-    models = [
+    fallback_models = [
         "distil-whisper-large-v3-en",
         "gemma2-9b-it",
         "gemma-7b-it",
@@ -35,9 +36,9 @@ class Groq(OpenaiAPI):
         cls,
         model: str,
         messages: Messages,
-        api_base: str = "https://api.groq.com/openai/v1",
+        api_base: str = api_base,
         **kwargs
     ) -> AsyncResult:
         return super().create_async_generator(
             model, messages, api_base=api_base, **kwargs
-        )
+        )
\ No newline at end of file
diff --git a/g4f/Provider/needs_auth/HuggingFace.py b/g4f/Provider/needs_auth/HuggingFace.py
index 3fa5a624..6887ac4d 100644
--- a/g4f/Provider/needs_auth/HuggingFace.py
+++ b/g4f/Provider/needs_auth/HuggingFace.py
@@ -6,8 +6,8 @@ import random
 import requests
 
 from ...typing import AsyncResult, Messages
-from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ...errors import ModelNotFoundError, ModelNotSupportedError
+from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_prompt
+from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError
 from ...requests import StreamSession, raise_for_status
 from ...image import ImageResponse
 
@@ -28,9 +28,11 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
             cls.models = [model["id"] for model in requests.get(url).json()]
             cls.models.append("meta-llama/Llama-3.2-11B-Vision-Instruct")
             cls.models.append("nvidia/Llama-3.1-Nemotron-70B-Instruct-HF")
+            cls.models.sort()
         if not cls.image_models:
             url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
             cls.image_models = [model["id"] for model in requests.get(url).json() if model["trendingScore"] >= 20]
+            cls.image_models.sort()
         cls.models.extend(cls.image_models)
         return cls.models
 
@@ -89,19 +91,27 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
         ) as session:
             if payload is None:
                 async with session.get(f"https://huggingface.co/api/models/{model}") as response:
+                    await raise_for_status(response)
                     model_data = await response.json()
-                if "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]:
+                model_type = None
+                if "config" in model_data and "model_type" in model_data["config"]:
+                    model_type = model_data["config"]["model_type"]
+                if model_type in ("gpt2", "gpt_neo", "gemma", "gemma2"):
+                    inputs = format_prompt(messages)
+                elif "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]:
                     eos_token = model_data["config"]["tokenizer_config"]["eos_token"]
-                    if eos_token == "</s>":
-                        inputs = format_prompt_mistral(messages)
+                    if eos_token in ("<|endoftext|>", "</s>", "<eos>"):
+                        inputs = format_prompt_custom(messages, eos_token)
                     elif eos_token == "<|im_end|>":
                         inputs = format_prompt_qwen(messages)
                     elif eos_token == "<|eot_id|>":
                         inputs = format_prompt_llama(messages)
                     else:
-                        inputs = format_prompt(messages)
+                        inputs = format_prompt_default(messages)
                 else:
-                    inputs = format_prompt(messages)
+                    inputs = format_prompt_default(messages)
+                if model_type == "gpt2" and max_new_tokens >= 1024:
+                    params["max_new_tokens"] = 512
                 payload = {"inputs": inputs, "parameters": params, "stream": stream}
 
             async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
@@ -113,6 +123,8 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
                     async for line in response.iter_lines():
                         if line.startswith(b"data:"):
                             data = json.loads(line[5:])
+                            if "error" in data:
+                                raise ResponseError(data["error"])
                             if not data["token"]["special"]:
                                 chunk = data["token"]["text"]
                                 if first:
@@ -128,7 +140,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
                 else:
                     yield (await response.json())[0]["generated_text"].strip()
 
-def format_prompt(messages: Messages) -> str:
+def format_prompt_default(messages: Messages) -> str:
     system_messages = [message["content"] for message in messages if message["role"] == "system"]
     question = " ".join([messages[-1]["content"], *system_messages])
     history = "".join([
@@ -146,9 +158,9 @@ def format_prompt_qwen(messages: Messages) -> str:
 def format_prompt_llama(messages: Messages) -> str:
     return "<|begin_of_text|>" + "".join([
         f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}\n<|eot_id|>\n" for message in messages
-    ]) + "<|start_header_id|>assistant<|end_header_id|>\\n\\n"
-
-def format_prompt_mistral(messages: Messages) -> str:
+    ]) + "<|start_header_id|>assistant<|end_header_id|>\n\n"
+
+def format_prompt_custom(messages: Messages, end_token: str = "</s>") -> str:
     return "".join([
-        f"<|{message['role']}|>\n{message['content']}'\n" for message in messages
+        f"<|{message['role']}|>\n{message['content']}{end_token}\n" for message in messages
     ]) + "<|assistant|>\n"
\ No newline at end of file
diff --git a/g4f/Provider/needs_auth/HuggingFaceAPI.py b/g4f/Provider/needs_auth/HuggingFaceAPI.py
index a93ab3a6..661491b2 100644
--- a/g4f/Provider/needs_auth/HuggingFaceAPI.py
+++ b/g4f/Provider/needs_auth/HuggingFaceAPI.py
@@ -7,6 +7,7 @@ class HuggingFaceAPI(OpenaiAPI):
     label = "HuggingFace (Inference API)"
     url = "https://api-inference.huggingface.co"
+    api_base = "https://api-inference.huggingface.co/v1"
     working = True
     default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
     default_vision_model = default_model
 
@@ -19,7 +20,7 @@ class HuggingFaceAPI(OpenaiAPI):
         cls,
         model: str,
         messages: Messages,
-        api_base: str = "https://api-inference.huggingface.co/v1",
+        api_base: str = api_base,
         max_tokens: int = 500,
         **kwargs
     ) -> AsyncResult:
diff --git a/g4f/Provider/needs_auth/OpenaiAPI.py b/g4f/Provider/needs_auth/OpenaiAPI.py
index e4731ae2..ebc4d519 100644
--- a/g4f/Provider/needs_auth/OpenaiAPI.py
+++ b/g4f/Provider/needs_auth/OpenaiAPI.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import json
+import requests
 
 from ..helper import filter_none
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
@@ -8,15 +9,35 @@ from ...typing import Union, Optional, AsyncResult, Messages, ImagesType
 from ...requests import StreamSession, raise_for_status
 from ...errors import MissingAuthError, ResponseError
 from ...image import to_data_uri
+from ... import debug
 
 class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
     label = "OpenAI API"
     url = "https://platform.openai.com"
+    api_base = "https://api.openai.com/v1"
     working = True
     needs_auth = True
     supports_message_history = True
     supports_system_message = True
     default_model = ""
+    fallback_models = []
+
+    @classmethod
+    def get_models(cls, api_key: str = None):
+        if not cls.models:
+            try:
+                headers = {}
+                if api_key is not None:
+                    headers["authorization"] = f"Bearer {api_key}"
+                response = requests.get(f"{cls.api_base}/models", headers=headers)
+                raise_for_status(response)
+                data = response.json()
+                cls.models = [model.get("id") for model in data.get("data")]
+                cls.models.sort()
+            except Exception as e:
+                debug.log(e)
+                cls.models = cls.fallback_models
+        return cls.models
 
     @classmethod
     async def create_async_generator(
@@ -27,7 +48,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
         timeout: int = 120,
         images: ImagesType = None,
         api_key: str = None,
-        api_base: str = "https://api.openai.com/v1",
+        api_base: str = api_base,
         temperature: float = None,
         max_tokens: int = None,
         top_p: float = None,
@@ -47,14 +68,14 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
                 *[{
                     "type": "image_url",
                     "image_url": {"url": to_data_uri(image)}
-                } for image, image_name in images],
+                } for image, _ in images],
                 {
                     "type": "text",
                     "text": messages[-1]["content"]
                 }
             ]
         async with StreamSession(
-            proxies={"all": proxy},
+            proxy=proxy,
             headers=cls.get_headers(stream, api_key, headers),
             timeout=timeout,
             impersonate=impersonate,
@@ -111,7 +132,10 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
         if "error_message" in data:
             raise ResponseError(data["error_message"])
         elif "error" in data:
-            raise ResponseError(f'Error {data["error"]["code"]}: {data["error"]["message"]}')
+            if "code" in data["error"]:
+                raise ResponseError(f'Error {data["error"]["code"]}: {data["error"]["message"]}')
+            else:
+                raise ResponseError(data["error"]["message"])
 
     @classmethod
     def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
diff --git a/g4f/Provider/needs_auth/__init__.py b/g4f/Provider/needs_auth/__init__.py
index d79e7e3d..94285361 100644
--- a/g4f/Provider/needs_auth/__init__.py
+++ b/g4f/Provider/needs_auth/__init__.py
@@ -26,3 +26,4 @@ from .Replicate import Replicate
 from .Theb import Theb
 from .ThebApi import ThebApi
 from .WhiteRabbitNeo import WhiteRabbitNeo
+from .xAI import xAI
\ No newline at end of file
diff --git a/g4f/Provider/needs_auth/xAI.py b/g4f/Provider/needs_auth/xAI.py
new file mode 100644
index 00000000..0ffeff3b
--- /dev/null
+++ b/g4f/Provider/needs_auth/xAI.py
@@ -0,0 +1,22 @@
+from __future__ import annotations
+
+from .OpenaiAPI import OpenaiAPI
+from ...typing import AsyncResult, Messages
+
+class xAI(OpenaiAPI):
+    label = "xAI"
+    url = "https://console.x.ai"
+    api_base = "https://api.x.ai/v1"
+    working = True
+
+    @classmethod
+    def create_async_generator(
+        cls,
+        model: str,
+        messages: Messages,
+        api_base: str = api_base,
+        **kwargs
+    ) -> AsyncResult:
+        return super().create_async_generator(
+            model, messages, api_base=api_base, **kwargs
+        )
\ No newline at end of file
-- 
cgit v1.2.3
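
Note: the snippet below is a stand-alone sketch of the fetch-or-fallback pattern this
commit moves into OpenaiAPI.get_models (and that Groq, Cerebras, HuggingFaceAPI and
xAI now inherit via the new api_base class attribute): query the provider's
OpenAI-compatible /models endpoint, sort the ids, and fall back to a static list on
any failure. The endpoint URL and fallback model name here are illustrative
placeholders, not g4f code.

import requests

API_BASE = "https://api.openai.com/v1"  # any OpenAI-compatible base URL
FALLBACK_MODELS = ["gpt-4o-mini"]       # illustrative static fallback list

def get_models(api_key: str = None) -> list:
    """Fetch available model ids, falling back to a static list on failure."""
    headers = {}
    if api_key is not None:
        # Plain f-string interpolation; the removed Cerebras code had a
        # "Bearer ${api_key}" bug that kept the literal "$" in the header.
        headers["authorization"] = f"Bearer {api_key}"
    try:
        response = requests.get(f"{API_BASE}/models", headers=headers, timeout=30)
        response.raise_for_status()
        models = [model.get("id") for model in response.json().get("data", [])]
        models.sort()
        return models
    except Exception:
        return FALLBACK_MODELS

print(get_models())  # without a valid key most endpoints return 401 -> fallback list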
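
Similarly, the renamed format_prompt_custom helper can be exercised on its own. The
copy below is lifted from the HuggingFace.py hunk above, with the Messages type alias
simplified to list; the "</s>" default mirrors the Mistral-style EOS token that the
replaced format_prompt_mistral hard-coded, and the two-message conversation is a
hypothetical example.

def format_prompt_custom(messages: list, end_token: str = "</s>") -> str:
    # Wrap each message in <|role|> ... end_token markers, then cue the assistant.
    return "".join([
        f"<|{message['role']}|>\n{message['content']}{end_token}\n" for message in messages
    ]) + "<|assistant|>\n"

messages = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Hi!"},
]
print(format_prompt_custom(messages))
# <|system|>
# You are concise.</s>
# <|user|>
# Hi!</s>
# <|assistant|>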