From 6ce493d4dfc2884832ff5b5be4479a55818b2fe7 Mon Sep 17 00:00:00 2001 From: H Lohaus Date: Sat, 16 Nov 2024 13:19:51 +0100 Subject: Fix api streaming, fix AsyncClient (#2357) * Fix api streaming, fix AsyncClient, Improve Client class, Some providers fixes, Update models list, Fix some tests, Update model list in Airforce provid er, Add OpenAi image generation url to api, Fix reload and debug in api arguments, Fix websearch in gui * Fix Cloadflare and Pi and AmigoChat provider * Fix conversation support in DDG provider, Add cloudflare bypass with nodriver * Fix unittests without curl_cffi --- g4f/Provider/Cloudflare.py | 136 ++++++++++++--------------------------------- 1 file changed, 37 insertions(+), 99 deletions(-) (limited to 'g4f/Provider/Cloudflare.py') diff --git a/g4f/Provider/Cloudflare.py b/g4f/Provider/Cloudflare.py index 8fb37bef..825c5027 100644 --- a/g4f/Provider/Cloudflare.py +++ b/g4f/Provider/Cloudflare.py @@ -1,72 +1,52 @@ from __future__ import annotations -from aiohttp import ClientSession import asyncio import json import uuid -import cloudscraper -from typing import AsyncGenerator -from ..typing import AsyncResult, Messages -from .base_provider import AsyncGeneratorProvider, ProviderModelMixin -from .helper import format_prompt +from ..typing import AsyncResult, Messages, Cookies +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop +from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin): label = "Cloudflare AI" url = "https://playground.ai.cloudflare.com" api_endpoint = "https://playground.ai.cloudflare.com/api/inference" + models_url = "https://playground.ai.cloudflare.com/api/models" working = True supports_stream = True supports_system_message = True supports_message_history = True - - default_model = '@cf/meta/llama-3.1-8b-instruct-awq' - models = [ - '@cf/meta/llama-2-7b-chat-fp16', - '@cf/meta/llama-2-7b-chat-int8', - - '@cf/meta/llama-3-8b-instruct', - '@cf/meta/llama-3-8b-instruct-awq', - '@hf/meta-llama/meta-llama-3-8b-instruct', - - default_model, - '@cf/meta/llama-3.1-8b-instruct-fp8', - - '@cf/meta/llama-3.2-1b-instruct', - - '@hf/mistral/mistral-7b-instruct-v0.2', - - '@cf/qwen/qwen1.5-7b-chat-awq', - - '@cf/defog/sqlcoder-7b-2', - ] - + default_model = "@cf/meta/llama-3.1-8b-instruct" model_aliases = { "llama-2-7b": "@cf/meta/llama-2-7b-chat-fp16", "llama-2-7b": "@cf/meta/llama-2-7b-chat-int8", - "llama-3-8b": "@cf/meta/llama-3-8b-instruct", "llama-3-8b": "@cf/meta/llama-3-8b-instruct-awq", "llama-3-8b": "@hf/meta-llama/meta-llama-3-8b-instruct", - "llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-awq", "llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-fp8", - "llama-3.2-1b": "@cf/meta/llama-3.2-1b-instruct", - "qwen-1.5-7b": "@cf/qwen/qwen1.5-7b-chat-awq", - - #"sqlcoder-7b": "@cf/defog/sqlcoder-7b-2", } + _args: dict = None @classmethod - def get_model(cls, model: str) -> str: - if model in cls.models: - return model - elif model in cls.model_aliases: - return cls.model_aliases[model] - else: - return cls.default_model + def get_models(cls) -> str: + if not cls.models: + if cls._args is None: + get_running_loop(check_nested=True) + args = get_args_from_nodriver(cls.url, cookies={ + '__cf_bm': uuid.uuid4().hex, + }) + cls._args = asyncio.run(args) + with Session(**cls._args) as session: + response = session.get(cls.models_url) + raise_for_status(response) + json_data = response.json() + cls.models = [model.get("name") for model in json_data.get("models")] + cls._args["cookies"] = merge_cookies(cls._args["cookies"] , response) + return cls.models @classmethod async def create_async_generator( @@ -75,76 +55,34 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin): messages: Messages, proxy: str = None, max_tokens: int = 2048, + cookies: Cookies = None, + timeout: int = 300, **kwargs ) -> AsyncResult: model = cls.get_model(model) - - headers = { - 'Accept': 'text/event-stream', - 'Accept-Language': 'en-US,en;q=0.9', - 'Cache-Control': 'no-cache', - 'Content-Type': 'application/json', - 'Origin': cls.url, - 'Pragma': 'no-cache', - 'Referer': f'{cls.url}/', - 'Sec-Ch-Ua': '"Chromium";v="129", "Not=A?Brand";v="8"', - 'Sec-Ch-Ua-Mobile': '?0', - 'Sec-Ch-Ua-Platform': '"Linux"', - 'Sec-Fetch-Dest': 'empty', - 'Sec-Fetch-Mode': 'cors', - 'Sec-Fetch-Site': 'same-origin', - 'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36', - } - - cookies = { - '__cf_bm': uuid.uuid4().hex, - } - - scraper = cloudscraper.create_scraper() - + if cls._args is None: + cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies) data = { - "messages": [ - {"role": "user", "content": format_prompt(messages)} - ], + "messages": messages, "lora": None, "model": model, "max_tokens": max_tokens, "stream": True } - - max_retries = 3 - full_response = "" - - for attempt in range(max_retries): - try: - response = scraper.post( - cls.api_endpoint, - headers=headers, - cookies=cookies, - json=data, - stream=True, - proxies={'http': proxy, 'https': proxy} if proxy else None - ) - - if response.status_code == 403: - await asyncio.sleep(2 ** attempt) - continue - - response.raise_for_status() - - for line in response.iter_lines(): + async with StreamSession(**cls._args) as session: + async with session.post( + cls.api_endpoint, + json=data, + ) as response: + await raise_for_status(response) + cls._args["cookies"] = merge_cookies(cls._args["cookies"] , response) + async for line in response.iter_lines(): if line.startswith(b'data: '): if line == b'data: [DONE]': - if full_response: - yield full_response break try: - content = json.loads(line[6:].decode('utf-8')) - if 'response' in content and content['response'] != '': + content = json.loads(line[6:].decode()) + if content.get("response") and content.get("response") != '': yield content['response'] except Exception: - continue - break - except Exception as e: - if attempt == max_retries - 1: - raise + continue \ No newline at end of file -- cgit v1.2.3