From 7893a0835e2c3f4e06c1ccfaec9baef1bdacea7d Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Wed, 1 Jan 2025 04:20:02 +0100 Subject: Add filessupport, scrape and refine your data Remove Webdriver usages Add continue messages for other providers --- etc/unittest/backend.py | 2 +- g4f/Provider/Blackbox.py | 2 +- g4f/Provider/Cloudflare.py | 19 +- g4f/Provider/bing/create_images.py | 34 --- g4f/Provider/needs_auth/HuggingFace.py | 44 +-- g4f/Provider/needs_auth/OpenaiChat.py | 5 +- g4f/Provider/needs_auth/Poe.py | 3 +- g4f/Provider/needs_auth/Theb.py | 3 +- g4f/Provider/not_working/Aura.py | 4 +- g4f/Provider/not_working/MyShell.py | 3 +- g4f/Provider/selenium/PerplexityAi.py | 3 +- g4f/Provider/selenium/Phind.py | 3 +- g4f/Provider/selenium/TalkAi.py | 3 +- g4f/api/__init__.py | 37 ++- g4f/api/stubs.py | 4 + g4f/client/__init__.py | 49 +--- g4f/cookies.py | 2 +- g4f/gui/__init__.py | 7 +- g4f/gui/client/index.html | 7 +- g4f/gui/client/static/css/style.css | 35 ++- g4f/gui/client/static/js/chat.v1.js | 249 ++++++++++++---- g4f/gui/run.py | 2 +- g4f/gui/server/api.py | 54 ++-- g4f/gui/server/app.py | 12 +- g4f/gui/server/backend.py | 180 ------------ g4f/gui/server/backend_api.py | 264 +++++++++++++++++ g4f/gui/server/internet.py | 2 +- g4f/requests/__init__.py | 65 ----- g4f/requests/raise_for_status.py | 2 +- g4f/tools/files.py | 511 +++++++++++++++++++++++++++++++++ g4f/tools/run_tools.py | 87 ++++++ g4f/tools/web_search.py | 230 +++++++++++++++ g4f/web_search.py | 172 ----------- g4f/webdriver.py | 257 ----------------- setup.py | 10 + 35 files changed, 1481 insertions(+), 885 deletions(-) delete mode 100644 g4f/gui/server/backend.py create mode 100644 g4f/gui/server/backend_api.py create mode 100644 g4f/tools/files.py create mode 100644 g4f/tools/run_tools.py create mode 100644 g4f/tools/web_search.py delete mode 100644 g4f/web_search.py delete mode 100644 g4f/webdriver.py diff --git a/etc/unittest/backend.py b/etc/unittest/backend.py index 75ab6b47..d3883156 100644 --- a/etc/unittest/backend.py +++ b/etc/unittest/backend.py @@ -5,7 +5,7 @@ import asyncio from unittest.mock import MagicMock from g4f.errors import MissingRequirementsError try: - from g4f.gui.server.backend import Backend_Api + from g4f.gui.server.backend_api import Backend_Api has_requirements = True except: has_requirements = False diff --git a/g4f/Provider/Blackbox.py b/g4f/Provider/Blackbox.py index c4b14e20..d3cd1ea1 100644 --- a/g4f/Provider/Blackbox.py +++ b/g4f/Provider/Blackbox.py @@ -14,7 +14,7 @@ from ..typing import AsyncResult, Messages, ImagesType from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..image import ImageResponse, to_data_uri from ..cookies import get_cookies_dir -from ..web_search import get_search_message +from ..tools.web_search import get_search_message from .helper import format_prompt from .. 
import debug diff --git a/g4f/Provider/Cloudflare.py b/g4f/Provider/Cloudflare.py index 132d4780..f69b8128 100644 --- a/g4f/Provider/Cloudflare.py +++ b/g4f/Provider/Cloudflare.py @@ -5,8 +5,10 @@ import json from ..typing import AsyncResult, Messages, Cookies from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop -from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies, DEFAULT_HEADERS, has_nodriver, has_curl_cffi -from ..errors import ResponseStatusError +from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies +from ..requests import DEFAULT_HEADERS, has_nodriver, has_curl_cffi +from ..providers.response import FinishReason +from ..errors import ResponseStatusError, ModelNotFoundError class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin): label = "Cloudflare AI" @@ -70,7 +72,10 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin): cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies) else: cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}} - model = cls.get_model(model) + try: + model = cls.get_model(model) + except ModelNotFoundError: + pass data = { "messages": messages, "lora": None, @@ -89,6 +94,7 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin): except ResponseStatusError: cls._args = None raise + reason = None async for line in response.iter_lines(): if line.startswith(b'data: '): if line == b'data: [DONE]': @@ -97,5 +103,10 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin): content = json.loads(line[6:].decode()) if content.get("response") and content.get("response") != '': yield content['response'] + reason = "max_tokens" + elif content.get("response") == '': + reason = "stop" except Exception: - continue \ No newline at end of file + continue + if reason is not None: + yield FinishReason(reason) \ No newline at end of file diff --git a/g4f/Provider/bing/create_images.py b/g4f/Provider/bing/create_images.py index 45ba30b6..bcc88d1f 100644 --- a/g4f/Provider/bing/create_images.py +++ b/g4f/Provider/bing/create_images.py @@ -15,7 +15,6 @@ except ImportError: from ..helper import get_connector from ...errors import MissingRequirementsError, RateLimitError -from ...webdriver import WebDriver, get_driver_cookies, get_browser BING_URL = "https://www.bing.com" TIMEOUT_LOGIN = 1200 @@ -31,39 +30,6 @@ BAD_IMAGES = [ "https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg", ] -def wait_for_login(driver: WebDriver, timeout: int = TIMEOUT_LOGIN) -> None: - """ - Waits for the user to log in within a given timeout period. - - Args: - driver (WebDriver): Webdriver for browser automation. - timeout (int): Maximum waiting time in seconds. - - Raises: - RuntimeError: If the login process exceeds the timeout. - """ - driver.get(f"{BING_URL}/") - start_time = time.time() - while not driver.get_cookie("_U"): - if time.time() - start_time > timeout: - raise RuntimeError("Timeout error") - time.sleep(0.5) - -def get_cookies_from_browser(proxy: str = None) -> dict[str, str]: - """ - Retrieves cookies from the browser using webdriver. - - Args: - proxy (str, optional): Proxy configuration. - - Returns: - dict[str, str]: Retrieved cookies. 
- """ - with get_browser(proxy=proxy) as driver: - wait_for_login(driver) - time.sleep(1) - return get_driver_cookies(driver) - def create_session(cookies: Dict[str, str], proxy: str = None, connector: BaseConnector = None) -> ClientSession: """ Creates a new client session with specified cookies and headers. diff --git a/g4f/Provider/needs_auth/HuggingFace.py b/g4f/Provider/needs_auth/HuggingFace.py index f77b68cb..05e69072 100644 --- a/g4f/Provider/needs_auth/HuggingFace.py +++ b/g4f/Provider/needs_auth/HuggingFace.py @@ -102,22 +102,15 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): if "config" in model_data and "model_type" in model_data["config"]: model_type = model_data["config"]["model_type"] debug.log(f"Model type: {model_type}") - if model_type in ("gpt2", "gpt_neo", "gemma", "gemma2"): - inputs = format_prompt(messages, do_continue=do_continue) - elif model_type in ("mistral"): - inputs = format_prompt_mistral(messages, do_continue) - elif "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]: - eos_token = model_data["config"]["tokenizer_config"]["eos_token"] - if eos_token in ("<|endoftext|>", "", ""): - inputs = format_prompt_custom(messages, eos_token, do_continue) - elif eos_token == "<|im_end|>": - inputs = format_prompt_qwen(messages, do_continue) - elif eos_token == "<|eot_id|>": - inputs = format_prompt_llama(messages, do_continue) + inputs = get_inputs(messages, model_data, model_type, do_continue) + debug.log(f"Inputs len: {len(inputs)}") + if len(inputs) > 4096: + if len(messages) > 6: + messages = messages[:3] + messages[-3:] else: - inputs = format_prompt(messages, do_continue=do_continue) - else: - inputs = format_prompt(messages, do_continue=do_continue) + messages = [m for m in messages if m["role"] == "system"] + [messages[-1]] + inputs = get_inputs(messages, model_data, model_type, do_continue) + debug.log(f"New len: {len(inputs)}") if model_type == "gpt2" and max_new_tokens >= 1024: params["max_new_tokens"] = 512 payload = {"inputs": inputs, "parameters": params, "stream": stream} @@ -187,4 +180,23 @@ def format_prompt_custom(messages: Messages, end_token: str = "", do_continu ]) + ("" if do_continue else "<|assistant|>\n") if do_continue: return prompt[:-len(end_token + "\n")] - return prompt \ No newline at end of file + return prompt + +def get_inputs(messages: Messages, model_data: dict, model_type: str, do_continue: bool = False) -> str: + if model_type in ("gpt2", "gpt_neo", "gemma", "gemma2"): + inputs = format_prompt(messages, do_continue=do_continue) + elif model_type in ("mistral"): + inputs = format_prompt_mistral(messages, do_continue) + elif "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]: + eos_token = model_data["config"]["tokenizer_config"]["eos_token"] + if eos_token in ("<|endoftext|>", "", ""): + inputs = format_prompt_custom(messages, eos_token, do_continue) + elif eos_token == "<|im_end|>": + inputs = format_prompt_qwen(messages, do_continue) + elif eos_token == "<|eot_id|>": + inputs = format_prompt_llama(messages, do_continue) + else: + inputs = format_prompt(messages, do_continue=do_continue) + else: + inputs = format_prompt(messages, do_continue=do_continue) + return inputs \ No newline at end of file diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 652cbfc2..f15431e4 100644 --- 
a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -404,7 +404,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): data["conversation_id"] = conversation.conversation_id debug.log(f"OpenaiChat: Use conversation: {conversation.conversation_id}") if action != "continue": - data["parent_message_id"] = conversation.parent_message_id + data["parent_message_id"] = getattr(conversation, "parent_message_id", conversation.message_id) conversation.parent_message_id = None messages = messages if conversation_id is None else [messages[-1]] data["messages"] = cls.create_messages(messages, image_requests, ["search"] if web_search else None) @@ -604,7 +604,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): "api_key": cls._api_key, "proof_token": RequestConfig.proof_token, "cookies": RequestConfig.cookies, - "headers": RequestConfig.headers }) @classmethod @@ -636,6 +635,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): page = await browser.get(cls.url) user_agent = await page.evaluate("window.navigator.userAgent") await page.select("#prompt-textarea", 240) + await page.evaluate("document.getElementById('prompt-textarea').innerText = 'Hello'") + await page.evaluate("document.querySelector('[data-testid=\"send-button\"]').click()") while True: if cls._api_key is not None: break diff --git a/g4f/Provider/needs_auth/Poe.py b/g4f/Provider/needs_auth/Poe.py index 46b998e8..a0ef7453 100644 --- a/g4f/Provider/needs_auth/Poe.py +++ b/g4f/Provider/needs_auth/Poe.py @@ -5,7 +5,6 @@ import time from ...typing import CreateResult, Messages from ..base_provider import AbstractProvider from ..helper import format_prompt -from ...webdriver import WebDriver, WebDriverSession, element_send_text models = { "meta-llama/Llama-2-7b-chat-hf": {"name": "Llama-2-7b"}, @@ -22,7 +21,7 @@ models = { class Poe(AbstractProvider): url = "https://poe.com" - working = True + working = False needs_auth = True supports_stream = True diff --git a/g4f/Provider/needs_auth/Theb.py b/g4f/Provider/needs_auth/Theb.py index 7d3de027..f0600e4b 100644 --- a/g4f/Provider/needs_auth/Theb.py +++ b/g4f/Provider/needs_auth/Theb.py @@ -5,7 +5,6 @@ import time from ...typing import CreateResult, Messages from ..base_provider import AbstractProvider from ..helper import format_prompt -from ...webdriver import WebDriver, WebDriverSession, element_send_text models = { "theb-ai": "TheB.AI", @@ -34,7 +33,7 @@ models = { class Theb(AbstractProvider): label = "TheB.AI" url = "https://beta.theb.ai" - working = True + working = False supports_stream = True models = models.keys() diff --git a/g4f/Provider/not_working/Aura.py b/g4f/Provider/not_working/Aura.py index e841d909..2881dd14 100644 --- a/g4f/Provider/not_working/Aura.py +++ b/g4f/Provider/not_working/Aura.py @@ -4,8 +4,6 @@ from aiohttp import ClientSession from ...typing import AsyncResult, Messages from ..base_provider import AsyncGeneratorProvider -from ...requests import get_args_from_browser -from ...webdriver import WebDriver class Aura(AsyncGeneratorProvider): url = "https://openchat.team" @@ -19,7 +17,7 @@ class Aura(AsyncGeneratorProvider): proxy: str = None, temperature: float = 0.5, max_tokens: int = 8192, - webdriver: WebDriver = None, + webdriver = None, **kwargs ) -> AsyncResult: args = get_args_from_browser(cls.url, webdriver, proxy) diff --git a/g4f/Provider/not_working/MyShell.py b/g4f/Provider/not_working/MyShell.py index 02e182d4..24dfca9d 100644 --- a/g4f/Provider/not_working/MyShell.py +++ 
b/g4f/Provider/not_working/MyShell.py @@ -5,7 +5,6 @@ import time, json from ...typing import CreateResult, Messages from ..base_provider import AbstractProvider from ..helper import format_prompt -from ...webdriver import WebDriver, WebDriverSession, bypass_cloudflare class MyShell(AbstractProvider): url = "https://app.myshell.ai/chat" @@ -21,7 +20,7 @@ class MyShell(AbstractProvider): stream: bool, proxy: str = None, timeout: int = 120, - webdriver: WebDriver = None, + webdriver = None, **kwargs ) -> CreateResult: with WebDriverSession(webdriver, "", proxy=proxy) as driver: diff --git a/g4f/Provider/selenium/PerplexityAi.py b/g4f/Provider/selenium/PerplexityAi.py index d965dbf7..0f6c3f68 100644 --- a/g4f/Provider/selenium/PerplexityAi.py +++ b/g4f/Provider/selenium/PerplexityAi.py @@ -12,7 +12,6 @@ except ImportError: from ...typing import CreateResult, Messages from ..base_provider import AbstractProvider from ..helper import format_prompt -from ...webdriver import WebDriver, WebDriverSession, element_send_text class PerplexityAi(AbstractProvider): url = "https://www.perplexity.ai" @@ -28,7 +27,7 @@ class PerplexityAi(AbstractProvider): stream: bool, proxy: str = None, timeout: int = 120, - webdriver: WebDriver = None, + webdriver = None, virtual_display: bool = True, copilot: bool = False, **kwargs diff --git a/g4f/Provider/selenium/Phind.py b/g4f/Provider/selenium/Phind.py index b6f7cc07..d17eb27e 100644 --- a/g4f/Provider/selenium/Phind.py +++ b/g4f/Provider/selenium/Phind.py @@ -6,7 +6,6 @@ from urllib.parse import quote from ...typing import CreateResult, Messages from ..base_provider import AbstractProvider from ..helper import format_prompt -from ...webdriver import WebDriver, WebDriverSession class Phind(AbstractProvider): url = "https://www.phind.com" @@ -22,7 +21,7 @@ class Phind(AbstractProvider): stream: bool, proxy: str = None, timeout: int = 120, - webdriver: WebDriver = None, + webdriver = None, creative_mode: bool = None, **kwargs ) -> CreateResult: diff --git a/g4f/Provider/selenium/TalkAi.py b/g4f/Provider/selenium/TalkAi.py index a7b63375..d722022d 100644 --- a/g4f/Provider/selenium/TalkAi.py +++ b/g4f/Provider/selenium/TalkAi.py @@ -4,7 +4,6 @@ import time, json, time from ...typing import CreateResult, Messages from ..base_provider import AbstractProvider -from ...webdriver import WebDriver, WebDriverSession class TalkAi(AbstractProvider): url = "https://talkai.info" @@ -19,7 +18,7 @@ class TalkAi(AbstractProvider): messages: Messages, stream: bool, proxy: str = None, - webdriver: WebDriver = None, + webdriver = None, **kwargs ) -> CreateResult: with WebDriverSession(webdriver, "", virtual_display=True, proxy=proxy) as driver: diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index 4b3f0580..ed28a427 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -41,11 +41,12 @@ from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthErr from g4f.cookies import read_cookie_files, get_cookies_dir from g4f.Provider import ProviderType, ProviderUtils, __providers__ from g4f.gui import get_gui_app +from g4f.tools.files import supports_filename, get_streaming from .stubs import ( ChatCompletionsConfig, ImageGenerationConfig, ProviderResponseModel, ModelResponseModel, ErrorResponseModel, ProviderResponseDetailModel, - FileResponseModel, Annotated + FileResponseModel, UploadResponseModel, Annotated ) logger = logging.getLogger(__name__) @@ -424,6 +425,40 @@ class Api: read_cookie_files() return response_data + @self.app.get("/v1/files/{bucket_id}", 
responses={ + HTTP_200_OK: {"content": { + "text/event-stream": {"schema": {"type": "string"}}, + "text/plain": {"schema": {"type": "string"}}, + }}, + HTTP_404_NOT_FOUND: {"model": ErrorResponseModel}, + }) + def read_files(request: Request, bucket_id: str, delete_files: bool = True, refine_chunks_with_spacy: bool = False): + bucket_dir = os.path.join(get_cookies_dir(), bucket_id) + event_stream = "text/event-stream" in request.headers.get("accept", "") + if not os.path.isdir(bucket_dir): + return ErrorResponse.from_message("Bucket dir not found", 404) + return StreamingResponse(get_streaming(bucket_dir, delete_files, refine_chunks_with_spacy, event_stream), media_type="text/plain") + + @self.app.post("/v1/files/{bucket_id}", responses={ + HTTP_200_OK: {"model": UploadResponseModel} + }) + def upload_files(bucket_id: str, files: List[UploadFile]): + bucket_dir = os.path.join(get_cookies_dir(), bucket_id) + os.makedirs(bucket_dir, exist_ok=True) + filenames = [] + for file in files: + try: + filename = os.path.basename(file.filename) + if file and supports_filename(filename): + with open(os.path.join(bucket_dir, filename), 'wb') as f: + shutil.copyfileobj(file.file, f) + filenames.append(filename) + finally: + file.file.close() + with open(os.path.join(bucket_dir, "files.txt"), 'w') as f: + [f.write(f"{filename}\n") for filename in filenames] + return {"bucket_id": bucket_id, "url": f"/v1/files/{bucket_id}", "files": filenames} + @self.app.get("/v1/synthesize/{provider}", responses={ HTTP_200_OK: {"content": {"audio/*": {}}}, HTTP_404_NOT_FOUND: {"model": ErrorResponseModel}, diff --git a/g4f/api/stubs.py b/g4f/api/stubs.py index 6f11b49c..f021079e 100644 --- a/g4f/api/stubs.py +++ b/g4f/api/stubs.py @@ -66,6 +66,10 @@ class ModelResponseModel(BaseModel): created: int owned_by: Optional[str] +class UploadResponseModel(BaseModel): + bucket_id: str + url: str + class ErrorResponseModel(BaseModel): error: ErrorResponseMessageModel model: Optional[str] = None diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py index 19e27619..95499455 100644 --- a/g4f/client/__init__.py +++ b/g4f/client/__init__.py @@ -7,7 +7,6 @@ import string import asyncio import aiohttp import base64 -import json from typing import Union, AsyncIterator, Iterator, Coroutine, Optional from ..image import ImageResponse, copy_images, images_dir @@ -17,13 +16,13 @@ from ..providers.response import ResponseType, FinishReason, BaseConversation, S from ..errors import NoImageResponseError from ..providers.retry_provider import IterListProvider from ..providers.asyncio import to_sync_generator, async_generator_to_list -from ..web_search import get_search_message, do_search from ..Provider.needs_auth import BingCreateImages, OpenaiAccount +from ..tools.run_tools import async_iter_run_tools, iter_run_tools from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse from .image_models import ImageModels from .types import IterResponse, ImageProvider, Client as BaseClient from .service import get_model_and_provider, convert_to_provider -from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator +from .helper import find_stop, filter_json, filter_none, safe_aclose from .. 
import debug ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]] @@ -38,47 +37,6 @@ except NameError: except StopAsyncIteration: raise StopIteration -def validate_arguments(data: dict) -> dict: - if "arguments" in data: - if isinstance(data["arguments"], str): - data["arguments"] = json.loads(data["arguments"]) - if not isinstance(data["arguments"], dict): - raise ValueError("Tool function arguments must be a dictionary or a json string") - else: - return filter_none(**data["arguments"]) - else: - return {} - -async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs): - if tool_calls is not None: - for tool in tool_calls: - if tool.get("type") == "function": - if tool.get("function", {}).get("name") == "search_tool": - tool["function"]["arguments"] = validate_arguments(tool["function"]) - messages = messages.copy() - messages[-1]["content"] = await do_search( - messages[-1]["content"], - **tool["function"]["arguments"] - ) - response = async_iter_callback(model=model, messages=messages, **kwargs) - if not hasattr(response, "__aiter__"): - response = to_async_iterator(response) - async for chunk in response: - yield chunk - -def iter_run_tools(iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs): - if tool_calls is not None: - for tool in tool_calls: - if tool.get("type") == "function": - if tool.get("function", {}).get("name") == "search_tool": - tool["function"]["arguments"] = validate_arguments(tool["function"]) - messages[-1]["content"] = get_search_message( - messages[-1]["content"], - raise_search_exceptions=True, - **tool["function"]["arguments"] - ) - return iter_callback(model=model, messages=messages, **kwargs) - # Synchronous iter_response function def iter_response( response: Union[Iterator[Union[str, ResponseType]]], @@ -131,7 +89,8 @@ def iter_response( break idx += 1 - + if usage is None: + usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx).get_dict() finish_reason = "stop" if finish_reason is None else finish_reason if stream: diff --git a/g4f/cookies.py b/g4f/cookies.py index afc9245b..2328f0f2 100644 --- a/g4f/cookies.py +++ b/g4f/cookies.py @@ -180,7 +180,7 @@ def read_cookie_files(dirPath: str = None): except json.JSONDecodeError: # Error: not a json file! 
continue - if not isinstance(cookieFile, list): + if not isinstance(cookieFile, list) or not isinstance(cookieFile[0], dict) or "domain" not in cookieFile[0]: continue debug.log(f"Read cookie file: {path}") new_cookies = {} diff --git a/g4f/gui/__init__.py b/g4f/gui/__init__.py index b5550631..4dc286ad 100644 --- a/g4f/gui/__init__.py +++ b/g4f/gui/__init__.py @@ -1,9 +1,9 @@ from ..errors import MissingRequirementsError try: - from .server.app import app from .server.website import Website - from .server.backend import Backend_Api + from .server.backend_api import Backend_Api + from .server.app import create_app import_error = None except ImportError as e: import_error = e @@ -11,6 +11,7 @@ except ImportError as e: def get_gui_app(): if import_error is not None: raise MissingRequirementsError(f'Install "gui" requirements | pip install -U g4f[gui]\n{import_error}') + app = create_app() site = Website(app) for route in site.routes: @@ -36,7 +37,7 @@ def run_gui(host: str = '0.0.0.0', port: int = 8080, debug: bool = False) -> Non 'debug': debug } - get_gui_app() + app = get_gui_app() print(f"Running on port {config['port']}") app.run(**config) diff --git a/g4f/gui/client/index.html b/g4f/gui/client/index.html index acd6b7e5..31ffc0b0 100644 --- a/g4f/gui/client/index.html +++ b/g4f/gui/client/index.html @@ -136,6 +136,11 @@ +
+ Refine files with spaCy + + +
@@ -258,7 +263,7 @@
${count_words_and_tokens(message, get_selected_model()?.value)} - - - - -
@@ -484,11 +498,15 @@ const prepare_messages = (messages, message_index = -1, do_continue = false) => } } - messages.forEach((new_message) => { + messages.forEach((new_message, i) => { + // Copy message first + new_message = { ...new_message }; + // Include last message, if do_continue + if (i + 1 == messages.length && do_continue) { + delete new_message.regenerate; + } // Include only not regenerated messages if (new_message && !new_message.regenerate) { - // Copy message first - new_message = { ...new_message }; // Remove generated images from history new_message.content = filter_message(new_message.content); // Remove internal fields @@ -707,8 +725,8 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi message_storage[message_id] = ""; stop_generating.classList.remove("stop_generating-hidden"); - if (message_index == -1 && !regenerate) { - await scroll_to_bottom(); + if (message_index == -1) { + await lazy_scroll_to_bottom(); } let count_total = message_box.querySelector('.count_total'); @@ -750,9 +768,10 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi inner: content_el.querySelector('.content_inner'), count: content_el.querySelector('.count'), update_timeouts: [], + message_index: message_index, } - if (message_index == -1 && !regenerate) { - await scroll_to_bottom(); + if (message_index == -1) { + await lazy_scroll_to_bottom(); } try { const input = imageInput && imageInput.files.length > 0 ? imageInput : cameraInput; @@ -813,7 +832,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi let cursorDiv = message_el.querySelector(".cursor"); if (cursorDiv) cursorDiv.parentNode.removeChild(cursorDiv); if (message_index == -1) { - await scroll_to_bottom(); + await lazy_scroll_to_bottom(); } await safe_remove_cancel_button(); await register_message_buttons(); @@ -826,6 +845,12 @@ async function scroll_to_bottom() { message_box.scrollTop = message_box.scrollHeight; } +async function lazy_scroll_to_bottom() { + if (message_box.scrollHeight - message_box.scrollTop < 2 * message_box.clientHeight) { + await scroll_to_bottom(); + } +} + const clear_conversations = async () => { const elements = box_conversations.childNodes; let index = elements.length; @@ -971,7 +996,23 @@ const load_conversation = async (conversation_id, scroll=true) => { } else { buffer = ""; } - buffer += item.content; + buffer = buffer.replace(/ \[aborted\]$/g, "").replace(/ \[error\]$/g, ""); + let lines = buffer.trim().split("\n"); + let lastLine = lines[lines.length - 1]; + let newContent = item.content; + if (newContent.startsWith("```\n")) { + newContent = item.content.substring(4); + } + if (newContent.startsWith(lastLine)) { + newContent = newContent.substring(lastLine.length); + } else { + let words = buffer.trim().split(" "); + let lastWord = words[words.length - 1]; + if (newContent.startsWith(lastWord)) { + newContent = newContent.substring(lastWord.length); + } + } + buffer += newContent; last_model = item.provider?.model; providers.push(item.provider?.name); let next_i = parseInt(i) + 1; @@ -993,28 +1034,74 @@ const load_conversation = async (conversation_id, scroll=true) => { synthesize_params = (new URLSearchParams(synthesize_params)).toString(); let synthesize_url = `/backend-api/v2/synthesize/${synthesize_provider}?${synthesize_params}`; + const file = new File([buffer], 'message.md', {type: 'text/plain'}); + const objectUrl = URL.createObjectURL(file); + let add_buttons = []; - // Always add regenerate button - 
add_buttons.push(``); // Add continue button if possible + actions = ["variant"] if (item.finish && item.finish.actions) { - item.finish.actions.forEach((action) => { - if (action == "continue") { - if (messages.length >= i - 1) { - add_buttons.push(``); - } + actions = item.finish.actions + } + if (!("continue" in actions)) { + let reason = "stop"; + // Read finish reason from conversation + if (item.finish && item.finish.reason) { + reason = item.finish.reason; + } + let lines = buffer.trim().split("\n"); + let lastLine = lines[lines.length - 1]; + // Has a stop or error token at the end + if (lastLine.endsWith("[aborted]") || lastLine.endsWith("[error]")) { + reason = "error"; + // Has an even number of start or end code tags + } else if (buffer.split("```").length - 1 % 2 === 1) { + reason = "error"; + // Has a end token at the end + } else if (lastLine.endsWith("```") || lastLine.endsWith(".") || lastLine.endsWith("?") || lastLine.endsWith("!") + || lastLine.endsWith('"') || lastLine.endsWith("'") || lastLine.endsWith(")") + || lastLine.endsWith(">") || lastLine.endsWith("]") || lastLine.endsWith("}") ) { + reason = "stop" + } else { + // Has an emoji at the end + const regex = /\p{Emoji}$/u; + if (regex.test(lastLine)) { + reason = "stop" } - }); + } + if (reason == "max_tokens" || reason == "error") { + actions.push("continue") + } + } + + add_buttons.push(``); + + if (actions.includes("variant")) { + add_buttons.push(``); + } + if (actions.includes("continue")) { + if (messages.length >= i - 1) { + add_buttons.push(``); + } } elements.push(` -
+
${item.role == "assistant" ? gpt_image : user_image} @@ -1028,12 +1115,6 @@ const load_conversation = async (conversation_id, scroll=true) => {
${markdown_render(buffer)}
${count_words_and_tokens(buffer, next_provider?.model)} - - - - - - ${add_buttons.join("")}
@@ -1444,11 +1525,8 @@ function update_message(content_map, message_id, content = null) { content_map.inner.innerHTML = html; content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model); highlight(content_map.inner); - if (!content_map.container.classList.contains("regenerate")) { - if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 200) { - window.scrollTo(0, 0); - message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" }); - } + if (content_map.message_index == -1) { + lazy_scroll_to_bottom(); } content_map.update_timeouts.forEach((timeoutId)=>clearTimeout(timeoutId)); content_map.update_timeouts = []; @@ -1711,19 +1789,76 @@ async function upload_cookies() { fileInput.value = ""; } +function formatFileSize(bytes) { + const units = ['B', 'KB', 'MB', 'GB']; + let unitIndex = 0; + while (bytes >= 1024 && unitIndex < units.length - 1) { + bytes /= 1024; + unitIndex++; + } + return `${bytes.toFixed(2)} ${units[unitIndex]}`; +} + +async function upload_files(fileInput) { + const paperclip = document.querySelector(".user-input .fa-paperclip"); + const bucket_id = uuid(); + + const formData = new FormData(); + Array.from(fileInput.files).forEach(file => { + formData.append('files[]', file); + }); + paperclip.classList.add("blink"); + await fetch("/backend-api/v2/files/" + bucket_id, { + method: 'POST', + body: formData + }); + let do_refine = document.getElementById("refine").checked; + function connectToSSE(url) { + const eventSource = new EventSource(url); + eventSource.onmessage = (event) => { + const data = JSON.parse(event.data); + if (data.error) { + inputCount.innerText = `Error: ${data.error.message}`; + } else if (data.action == "load") { + inputCount.innerText = `Read data: ${formatFileSize(data.size)}`; + } else if (data.action == "refine") { + inputCount.innerText = `Refine data: ${formatFileSize(data.size)}`; + } else if (data.action == "done") { + if (do_refine) { + do_refine = false; + connectToSSE(`/backend-api/v2/files/${bucket_id}?refine_chunks_with_spacy=true`); + return; + } + inputCount.innerText = "Files are loaded successfully"; + messageInput.value += (messageInput.value ? 
"\n" : "") + JSON.stringify({bucket_id: bucket_id}) + "\n"; + paperclip.classList.remove("blink"); + fileInput.value = ""; + delete fileInput.dataset.text; + } + }; + eventSource.onerror = (event) => { + eventSource.close(); + paperclip.classList.remove("blink"); + } + } + connectToSSE(`/backend-api/v2/files/${bucket_id}`); +} + fileInput.addEventListener('change', async (event) => { if (fileInput.files.length) { type = fileInput.files[0].name.split('.').pop() if (type == "har") { return await upload_cookies(); + } else if (type != "json") { + await upload_files(fileInput); } fileInput.dataset.type = type - const reader = new FileReader(); - reader.addEventListener('load', async (event) => { - fileInput.dataset.text = event.target.result; - if (type == "json") { + if (type == "json") { + const reader = new FileReader(); + reader.addEventListener('load', async (event) => { + fileInput.dataset.text = event.target.result; const data = JSON.parse(fileInput.dataset.text); - if ("g4f" in data.options) { + if (data.options && "g4f" in data.options) { let count = 0; Object.keys(data).forEach(key => { if (key != "options" && !localStorage.getItem(key)) { @@ -1736,11 +1871,23 @@ fileInput.addEventListener('change', async (event) => { fileInput.value = ""; inputCount.innerText = `${count} Conversations were imported successfully`; } else { - await upload_cookies(); + is_cookie_file = false; + if (Array.isArray(data)) { + data.forEach((item) => { + if (item.domain && item.name && item.value) { + is_cookie_file = true; + } + }); + } + if (is_cookie_file) { + await upload_cookies(); + } else { + await upload_files(fileInput); + } } - } - }); - reader.readAsText(fileInput.files[0]); + }); + reader.readAsText(fileInput.files[0]); + } } else { delete fileInput.dataset.text; } diff --git a/g4f/gui/run.py b/g4f/gui/run.py index 7acc5d9a..40cca6d5 100644 --- a/g4f/gui/run.py +++ b/g4f/gui/run.py @@ -1,5 +1,6 @@ from .gui_parser import gui_parser from ..cookies import read_cookie_files +from g4f.gui import run_gui import g4f.cookies import g4f.debug @@ -8,7 +9,6 @@ def run_gui_args(args): g4f.debug.logging = True if not args.ignore_cookie_files: read_cookie_files() - from g4f.gui import run_gui host = args.host port = args.port debug = args.debug diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index ca80f7f3..66c7390e 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -7,17 +7,17 @@ from typing import Iterator from flask import send_from_directory from inspect import signature -from g4f import version, models -from g4f import ChatCompletion, get_model_and_provider -from g4f.errors import VersionNotFoundError -from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir -from g4f.Provider import ProviderUtils, __providers__ -from g4f.providers.base_provider import ProviderModelMixin -from g4f.providers.retry_provider import IterListProvider -from g4f.providers.response import BaseConversation, JsonConversation, FinishReason -from g4f.providers.response import SynthesizeData, TitleGeneration, RequestLogin, Parameters -from g4f.client.service import convert_to_provider -from g4f import debug +from ...errors import VersionNotFoundError +from ...image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir +from ...tools.run_tools import iter_run_tools +from ...Provider import ProviderUtils, __providers__ +from ...providers.base_provider import ProviderModelMixin +from ...providers.retry_provider import IterListProvider +from 
...providers.response import BaseConversation, JsonConversation, FinishReason +from ...providers.response import SynthesizeData, TitleGeneration, RequestLogin, Parameters +from ... import version, models +from ... import ChatCompletion, get_model_and_provider +from ... import debug logger = logging.getLogger(__name__) conversations: dict[dict[str, BaseConversation]] = {} @@ -90,16 +90,28 @@ class Api: api_key = json_data.get("api_key") if api_key is not None: kwargs["api_key"] = api_key + kwargs["tool_calls"] = [{ + "function": { + "name": "bucket_tool" + }, + "type": "function" + }] do_web_search = json_data.get('web_search') if do_web_search and provider: - provider_handler = convert_to_provider(provider) - if hasattr(provider_handler, "get_parameters"): - if "web_search" in provider_handler.get_parameters(): - kwargs['web_search'] = True - do_web_search = False - if do_web_search: - from ...web_search import get_search_message - messages[-1]["content"] = get_search_message(messages[-1]["content"]) + kwargs["tool_calls"].append({ + "function": { + "name": "safe_search_tool" + }, + "type": "function" + }) + action = json_data.get('action') + if action == "continue": + kwargs["tool_calls"].append({ + "function": { + "name": "continue_tool" + }, + "type": "function" + }) conversation = json_data.get("conversation") if conversation is not None: kwargs["conversation"] = JsonConversation(**conversation) @@ -139,7 +151,7 @@ class Api: logging=False ) params = { - **provider_handler.get_parameters(as_json=True), + **(provider_handler.get_parameters(as_json=True) if hasattr(provider_handler, "get_parameters") else {}), "model": model, "messages": kwargs.get("messages"), "web_search": kwargs.get("web_search") @@ -153,7 +165,7 @@ class Api: yield self._format_json("parameters", params) first = True try: - result = ChatCompletion.create(**{**kwargs, "model": model, "provider": provider_handler}) + result = iter_run_tools(ChatCompletion.create, **{**kwargs, "model": model, "provider": provider_handler}) for chunk in result: if first: first = False diff --git a/g4f/gui/server/app.py b/g4f/gui/server/app.py index 869d3880..86ea40a3 100644 --- a/g4f/gui/server/app.py +++ b/g4f/gui/server/app.py @@ -1,9 +1,9 @@ import sys, os from flask import Flask -if getattr(sys, 'frozen', False): - template_folder = os.path.join(sys._MEIPASS, "client") -else: - template_folder = "../client" - -app = Flask(__name__, template_folder=template_folder, static_folder=f"{template_folder}/static") \ No newline at end of file +def create_app() -> Flask: + if getattr(sys, 'frozen', False): + template_folder = os.path.join(sys._MEIPASS, "client") + else: + template_folder = "../client" + return Flask(__name__, template_folder=template_folder, static_folder=f"{template_folder}/static") \ No newline at end of file diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py deleted file mode 100644 index 23e6ee21..00000000 --- a/g4f/gui/server/backend.py +++ /dev/null @@ -1,180 +0,0 @@ -import json -import flask -import os -import logging -import asyncio -from flask import Flask, request, jsonify -from typing import Generator -from werkzeug.utils import secure_filename - -from g4f.image import is_allowed_extension, to_image -from g4f.client.service import convert_to_provider -from g4f.providers.asyncio import to_sync_generator -from g4f.errors import ProviderNotFoundError -from g4f.cookies import get_cookies_dir -from .api import Api - -logger = logging.getLogger(__name__) - -def safe_iter_generator(generator: Generator) -> 
Generator: - start = next(generator) - def iter_generator(): - yield start - yield from generator - return iter_generator() - -class Backend_Api(Api): - """ - Handles various endpoints in a Flask application for backend operations. - - This class provides methods to interact with models, providers, and to handle - various functionalities like conversations, error handling, and version management. - - Attributes: - app (Flask): A Flask application instance. - routes (dict): A dictionary mapping API endpoints to their respective handlers. - """ - def __init__(self, app: Flask) -> None: - """ - Initialize the backend API with the given Flask application. - - Args: - app (Flask): Flask application instance to attach routes to. - """ - self.app: Flask = app - - def jsonify_models(**kwargs): - response = self.get_models(**kwargs) - if isinstance(response, list): - return jsonify(response) - return response - - def jsonify_provider_models(**kwargs): - response = self.get_provider_models(**kwargs) - if isinstance(response, list): - return jsonify(response) - return response - - def jsonify_providers(**kwargs): - response = self.get_providers(**kwargs) - if isinstance(response, list): - return jsonify(response) - return response - - self.routes = { - '/backend-api/v2/models': { - 'function': jsonify_models, - 'methods': ['GET'] - }, - '/backend-api/v2/models/': { - 'function': jsonify_provider_models, - 'methods': ['GET'] - }, - '/backend-api/v2/providers': { - 'function': jsonify_providers, - 'methods': ['GET'] - }, - '/backend-api/v2/version': { - 'function': self.get_version, - 'methods': ['GET'] - }, - '/backend-api/v2/conversation': { - 'function': self.handle_conversation, - 'methods': ['POST'] - }, - '/backend-api/v2/synthesize/': { - 'function': self.handle_synthesize, - 'methods': ['GET'] - }, - '/backend-api/v2/upload_cookies': { - 'function': self.upload_cookies, - 'methods': ['POST'] - }, - '/images/': { - 'function': self.serve_images, - 'methods': ['GET'] - } - } - - def upload_cookies(self): - file = None - if "file" in request.files: - file = request.files['file'] - if file.filename == '': - return 'No selected file', 400 - if file and file.filename.endswith(".json") or file.filename.endswith(".har"): - filename = secure_filename(file.filename) - file.save(os.path.join(get_cookies_dir(), filename)) - return "File saved", 200 - return 'Not supported file', 400 - - def handle_conversation(self): - """ - Handles conversation requests and streams responses back. - - Returns: - Response: A Flask response object for streaming. 
- """ - - kwargs = {} - if "files[]" in request.files: - images = [] - for file in request.files.getlist('files[]'): - if file.filename != '' and is_allowed_extension(file.filename): - images.append((to_image(file.stream, file.filename.endswith('.svg')), file.filename)) - kwargs['images'] = images - if "json" in request.form: - json_data = json.loads(request.form['json']) - else: - json_data = request.json - - kwargs = self._prepare_conversation_kwargs(json_data, kwargs) - - return self.app.response_class( - self._create_response_stream( - kwargs, - json_data.get("conversation_id"), - json_data.get("provider"), - json_data.get("download_images", True), - ), - mimetype='text/event-stream' - ) - - def handle_synthesize(self, provider: str): - try: - provider_handler = convert_to_provider(provider) - except ProviderNotFoundError: - return "Provider not found", 404 - if not hasattr(provider_handler, "synthesize"): - return "Provider doesn't support synthesize", 500 - response_data = provider_handler.synthesize({**request.args}) - if asyncio.iscoroutinefunction(provider_handler.synthesize): - response_data = asyncio.run(response_data) - else: - if hasattr(response_data, "__aiter__"): - response_data = to_sync_generator(response_data) - response_data = safe_iter_generator(response_data) - content_type = getattr(provider_handler, "synthesize_content_type", "application/octet-stream") - response = flask.Response(response_data, content_type=content_type) - response.headers['Cache-Control'] = "max-age=604800" - return response - - def get_provider_models(self, provider: str): - api_key = request.headers.get("x_api_key") - models = super().get_provider_models(provider, api_key) - if models is None: - return "Provider not found", 404 - return models - - def _format_json(self, response_type: str, content) -> str: - """ - Formats and returns a JSON response. - - Args: - response_type (str): The type of the response. - content: The content to be included in the response. - - Returns: - str: A JSON formatted string. - """ - return json.dumps(super()._format_json(response_type, content)) + "\n" \ No newline at end of file diff --git a/g4f/gui/server/backend_api.py b/g4f/gui/server/backend_api.py new file mode 100644 index 00000000..07505424 --- /dev/null +++ b/g4f/gui/server/backend_api.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import json +import flask +import os +import logging +import asyncio +import shutil +from flask import Flask, Response, request, jsonify +from typing import Generator +from pathlib import Path +from werkzeug.utils import secure_filename + +from ...image import is_allowed_extension, to_image +from ...client.service import convert_to_provider +from ...providers.asyncio import to_sync_generator +from ...tools.files import supports_filename, get_streaming, get_bucket_dir, get_buckets +from ...errors import ProviderNotFoundError +from ...cookies import get_cookies_dir +from .api import Api + +logger = logging.getLogger(__name__) + +def safe_iter_generator(generator: Generator) -> Generator: + start = next(generator) + def iter_generator(): + yield start + yield from generator + return iter_generator() + +class Backend_Api(Api): + """ + Handles various endpoints in a Flask application for backend operations. + + This class provides methods to interact with models, providers, and to handle + various functionalities like conversations, error handling, and version management. + + Attributes: + app (Flask): A Flask application instance. 
+ routes (dict): A dictionary mapping API endpoints to their respective handlers. + """ + def __init__(self, app: Flask) -> None: + """ + Initialize the backend API with the given Flask application. + + Args: + app (Flask): Flask application instance to attach routes to. + """ + self.app: Flask = app + + def jsonify_models(**kwargs): + response = self.get_models(**kwargs) + if isinstance(response, list): + return jsonify(response) + return response + + def jsonify_provider_models(**kwargs): + response = self.get_provider_models(**kwargs) + if isinstance(response, list): + return jsonify(response) + return response + + def jsonify_providers(**kwargs): + response = self.get_providers(**kwargs) + if isinstance(response, list): + return jsonify(response) + return response + + self.routes = { + '/backend-api/v2/models': { + 'function': jsonify_models, + 'methods': ['GET'] + }, + '/backend-api/v2/models/': { + 'function': jsonify_provider_models, + 'methods': ['GET'] + }, + '/backend-api/v2/providers': { + 'function': jsonify_providers, + 'methods': ['GET'] + }, + '/backend-api/v2/version': { + 'function': self.get_version, + 'methods': ['GET'] + }, + '/backend-api/v2/conversation': { + 'function': self.handle_conversation, + 'methods': ['POST'] + }, + '/backend-api/v2/synthesize/': { + 'function': self.handle_synthesize, + 'methods': ['GET'] + }, + '/backend-api/v2/upload_cookies': { + 'function': self.upload_cookies, + 'methods': ['POST'] + }, + '/images/': { + 'function': self.serve_images, + 'methods': ['GET'] + } + } + + @app.route('/backend-api/v2/buckets', methods=['GET']) + def list_buckets(): + try: + buckets = get_buckets() + if buckets is None: + return jsonify({"error": {"message": "Error accessing bucket directory"}}), 500 + sanitized_buckets = [secure_filename(b) for b in buckets] + return jsonify(sanitized_buckets), 200 + except Exception as e: + return jsonify({"error": {"message": str(e)}}), 500 + + @app.route('/backend-api/v2/files/', methods=['GET', 'DELETE']) + def manage_files(bucket_id: str): + bucket_id = secure_filename(bucket_id) + bucket_dir = get_bucket_dir(secure_filename(bucket_id)) + + if not os.path.isdir(bucket_dir): + return jsonify({"error": {"message": "Bucket directory not found"}}), 404 + + if request.method == 'DELETE': + try: + shutil.rmtree(bucket_dir) + return jsonify({"message": "Bucket deleted successfully"}), 200 + except OSError as e: + return jsonify({"error": {"message": f"Error deleting bucket: {str(e)}"}}), 500 + except Exception as e: + return jsonify({"error": {"message": str(e)}}), 500 + + delete_files = request.args.get('delete_files', True) + refine_chunks_with_spacy = request.args.get('refine_chunks_with_spacy', False) + event_stream = 'text/event-stream' in request.headers.get('Accept', '') + mimetype = "text/event-stream" if event_stream else "text/plain"; + return Response(get_streaming(bucket_dir, delete_files, refine_chunks_with_spacy, event_stream), mimetype=mimetype) + + @self.app.route('/backend-api/v2/files/', methods=['POST']) + def upload_files(bucket_id: str): + bucket_id = secure_filename(bucket_id) + bucket_dir = get_bucket_dir(bucket_id) + os.makedirs(bucket_dir, exist_ok=True) + filenames = [] + for file in request.files.getlist('files[]'): + try: + filename = secure_filename(file.filename) + if supports_filename(filename): + with open(os.path.join(bucket_dir, filename), 'wb') as f: + shutil.copyfileobj(file.stream, f) + filenames.append(filename) + finally: + file.stream.close() + with open(os.path.join(bucket_dir, 
"files.txt"), 'w') as f: + [f.write(f"{filename}\n") for filename in filenames] + return {"bucket_id": bucket_id, "files": filenames} + + @app.route('/backend-api/v2/files//', methods=['PUT']) + def upload_file(bucket_id, filename): + bucket_id = secure_filename(bucket_id) + bucket_dir = get_bucket_dir(bucket_id) + filename = secure_filename(filename) + bucket_path = Path(bucket_dir) + + if not supports_filename(filename): + return jsonify({"error": {"message": f"File type not allowed"}}), 400 + + if not bucket_path.exists(): + bucket_path.mkdir(parents=True, exist_ok=True) + + try: + file_path = bucket_path / filename + file_data = request.get_data() + if not file_data: + return jsonify({"error": {"message": "No file data received"}}), 400 + + with open(str(file_path), 'wb') as f: + f.write(file_data) + + return jsonify({"message": f"File '{filename}' uploaded successfully to bucket '{bucket_id}'"}), 201 + except Exception as e: + return jsonify({"error": {"message": f"Error uploading file: {str(e)}"}}), 500 + + def upload_cookies(self): + file = None + if "file" in request.files: + file = request.files['file'] + if file.filename == '': + return 'No selected file', 400 + if file and file.filename.endswith(".json") or file.filename.endswith(".har"): + filename = secure_filename(file.filename) + file.save(os.path.join(get_cookies_dir(), filename)) + return "File saved", 200 + return 'Not supported file', 400 + + def handle_conversation(self): + """ + Handles conversation requests and streams responses back. + + Returns: + Response: A Flask response object for streaming. + """ + + kwargs = {} + if "files[]" in request.files: + images = [] + for file in request.files.getlist('files[]'): + if file.filename != '' and is_allowed_extension(file.filename): + images.append((to_image(file.stream, file.filename.endswith('.svg')), file.filename)) + kwargs['images'] = images + if "json" in request.form: + json_data = json.loads(request.form['json']) + else: + json_data = request.json + + kwargs = self._prepare_conversation_kwargs(json_data, kwargs) + + return self.app.response_class( + self._create_response_stream( + kwargs, + json_data.get("conversation_id"), + json_data.get("provider"), + json_data.get("download_images", True), + ), + mimetype='text/event-stream' + ) + + def handle_synthesize(self, provider: str): + try: + provider_handler = convert_to_provider(provider) + except ProviderNotFoundError: + return "Provider not found", 404 + if not hasattr(provider_handler, "synthesize"): + return "Provider doesn't support synthesize", 500 + response_data = provider_handler.synthesize({**request.args}) + if asyncio.iscoroutinefunction(provider_handler.synthesize): + response_data = asyncio.run(response_data) + else: + if hasattr(response_data, "__aiter__"): + response_data = to_sync_generator(response_data) + response_data = safe_iter_generator(response_data) + content_type = getattr(provider_handler, "synthesize_content_type", "application/octet-stream") + response = flask.Response(response_data, content_type=content_type) + response.headers['Cache-Control'] = "max-age=604800" + return response + + def get_provider_models(self, provider: str): + api_key = request.headers.get("x_api_key") + models = super().get_provider_models(provider, api_key) + if models is None: + return "Provider not found", 404 + return models + + def _format_json(self, response_type: str, content) -> str: + """ + Formats and returns a JSON response. + + Args: + response_type (str): The type of the response. 
+ content: The content to be included in the response. + + Returns: + str: A JSON formatted string. + """ + return json.dumps(super()._format_json(response_type, content)) + "\n" \ No newline at end of file diff --git a/g4f/gui/server/internet.py b/g4f/gui/server/internet.py index 47a8556b..f4b76885 100644 --- a/g4f/gui/server/internet.py +++ b/g4f/gui/server/internet.py @@ -1,3 +1,3 @@ from __future__ import annotations -from ...web_search import SearchResults, search, get_search_message \ No newline at end of file +from ...tools.web_search import SearchResults, search, get_search_message \ No newline at end of file diff --git a/g4f/requests/__init__.py b/g4f/requests/__init__.py index 324675c0..75f64268 100644 --- a/g4f/requests/__init__.py +++ b/g4f/requests/__init__.py @@ -32,8 +32,6 @@ except ImportError: from .. import debug from .raise_for_status import raise_for_status -from ..webdriver import WebDriver, WebDriverSession -from ..webdriver import bypass_cloudflare, get_driver_cookies from ..errors import MissingRequirementsError from ..typing import Cookies from .defaults import DEFAULT_HEADERS, WEBVIEW_HAEDERS @@ -66,69 +64,6 @@ async def get_args_from_webview(url: str) -> dict: window.destroy() return {"headers": headers, "cookies": cookies} -def get_args_from_browser( - url: str, - webdriver: WebDriver = None, - proxy: str = None, - timeout: int = 120, - do_bypass_cloudflare: bool = True, - virtual_display: bool = False -) -> dict: - """ - Create a Session object using a WebDriver to handle cookies and headers. - - Args: - url (str): The URL to navigate to using the WebDriver. - webdriver (WebDriver, optional): The WebDriver instance to use. - proxy (str, optional): Proxy server to use for the Session. - timeout (int, optional): Timeout in seconds for the WebDriver. - - Returns: - Session: A Session object configured with cookies and headers from the WebDriver. 
- """ - with WebDriverSession(webdriver, "", proxy=proxy, virtual_display=virtual_display) as driver: - if do_bypass_cloudflare: - bypass_cloudflare(driver, url, timeout) - headers = { - **DEFAULT_HEADERS, - 'referer': url, - } - if not hasattr(driver, "requests"): - headers["user-agent"] = driver.execute_script("return navigator.userAgent") - else: - for request in driver.requests: - if request.url.startswith(url): - for key, value in request.headers.items(): - if key in ( - "accept-encoding", - "accept-language", - "user-agent", - "sec-ch-ua", - "sec-ch-ua-platform", - "sec-ch-ua-arch", - "sec-ch-ua-full-version", - "sec-ch-ua-platform-version", - "sec-ch-ua-bitness" - ): - headers[key] = value - break - cookies = get_driver_cookies(driver) - return { - 'cookies': cookies, - 'headers': headers, - } - -def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120) -> Session: - if not has_curl_cffi: - raise MissingRequirementsError('Install "curl_cffi" package | pip install -U curl_cffi') - args = get_args_from_browser(url, webdriver, proxy, timeout) - return Session( - **args, - proxies={"https": proxy, "http": proxy}, - timeout=timeout, - impersonate="chrome" - ) - def get_cookie_params_from_dict(cookies: Cookies, url: str = None, domain: str = None) -> list[CookieParam]: [CookieParam.from_json({ "name": key, diff --git a/g4f/requests/raise_for_status.py b/g4f/requests/raise_for_status.py index 3934220e..01a7d789 100644 --- a/g4f/requests/raise_for_status.py +++ b/g4f/requests/raise_for_status.py @@ -25,7 +25,7 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse] return text = await response.text() if message is None: - message = "HTML content" if response.headers.get("content-type").startswith("text/html") else text + message = "HTML content" if response.headers.get("content-type").startswith("text/html") else text if message == "HTML content": if response.status == 520: message = "Unknown error (Cloudflare)" diff --git a/g4f/tools/files.py b/g4f/tools/files.py new file mode 100644 index 00000000..d0d1c23b --- /dev/null +++ b/g4f/tools/files.py @@ -0,0 +1,511 @@ +from __future__ import annotations + +import os +import json +from pathlib import Path +from typing import Iterator, Optional +from aiohttp import ClientSession, ClientError, ClientResponse, ClientTimeout +import urllib.parse +import time +import zipfile +import asyncio +import hashlib +import base64 + +try: + from werkzeug.utils import secure_filename +except ImportError: + secure_filename = os.path.basename + +try: + import PyPDF2 + from PyPDF2.errors import PdfReadError + has_pypdf2 = True +except ImportError: + has_pypdf2 = False +try: + import pdfplumber + has_pdfplumber = True +except ImportError: + has_pdfplumber = False +try: + from pdfminer.high_level import extract_text + has_pdfminer = True +except ImportError: + has_pdfminer = False +try: + from docx import Document + has_docx = True +except ImportError: + has_docx = False +try: + import docx2txt + has_docx2txt = True +except ImportError: + has_docx2txt = False +try: + from odf.opendocument import load + from odf.text import P + has_odfpy = True +except ImportError: + has_odfpy = False +try: + import ebooklib + from ebooklib import epub + has_ebooklib = True +except ImportError: + has_ebooklib = False +try: + import pandas as pd + has_openpyxl = True +except ImportError: + has_openpyxl = False +try: + import spacy + has_spacy = True +except: + has_spacy = False +try: + from bs4 import 
BeautifulSoup + has_beautifulsoup4 = True +except ImportError: + has_beautifulsoup4 = False + +from .web_search import scrape_text +from ..cookies import get_cookies_dir +from ..requests.aiohttp import get_connector +from ..errors import MissingRequirementsError +from .. import debug + +PLAIN_FILE_EXTENSIONS = ["txt", "xml", "json", "js", "har", "sh", "py", "php", "css", "yaml", "sql", "log", "csv", "twig", "md"] +PLAIN_CACHE = "plain.cache" +DOWNLOADS_FILE = "downloads.json" +FILE_LIST = "files.txt" + +def supports_filename(filename: str): + if filename.endswith(".pdf"): + if has_pypdf2: + return True + elif has_pdfplumber: + return True + elif has_pdfminer: + return True + raise MissingRequirementsError(f'Install "pypdf2" requirements | pip install -U g4f[files]') + elif filename.endswith(".docx"): + if has_docx: + return True + elif has_docx2txt: + return True + raise MissingRequirementsError(f'Install "docx" requirements | pip install -U g4f[files]') + elif has_odfpy and filename.endswith(".odt"): + return True + elif has_ebooklib and filename.endswith(".epub"): + return True + elif has_openpyxl and filename.endswith(".xlsx"): + return True + elif filename.endswith(".html"): + if not has_beautifulsoup4: + raise MissingRequirementsError(f'Install "beautifulsoup4" requirements | pip install -U g4f[files]') + return True + elif filename.endswith(".zip"): + return True + elif filename.endswith("package-lock.json") and filename != FILE_LIST: + return False + else: + extension = os.path.splitext(filename)[1][1:] + if extension in PLAIN_FILE_EXTENSIONS: + return True + return False + +def get_bucket_dir(bucket_id: str): + bucket_dir = os.path.join(get_cookies_dir(), "buckets", bucket_id) + return bucket_dir + +def get_buckets(): + buckets_dir = os.path.join(get_cookies_dir(), "buckets") + try: + return [d for d in os.listdir(buckets_dir) if os.path.isdir(os.path.join(buckets_dir, d))] + except OSError as e: + return None + +def spacy_refine_chunks(source_iterator): + if not has_spacy: + raise MissingRequirementsError(f'Install "spacy" requirements | pip install -U g4f[files]') + + nlp = spacy.load("en_core_web_sm") + for page in source_iterator: + doc = nlp(page) + #for chunk in doc.noun_chunks: + # yield " ".join([token.lemma_ for token in chunk if not token.is_stop]) + # for token in doc: + # if not token.is_space: + # yield token.lemma_.lower() + # yield " " + sentences = list(doc.sents) + summary = sorted(sentences, key=lambda x: len(x.text), reverse=True)[:2] + for sent in summary: + yield sent.text + +def get_filenames(bucket_dir: Path): + files = bucket_dir / FILE_LIST + with files.open('r') as f: + return [filename.strip() for filename in f.readlines()] + +def stream_read_files(bucket_dir: Path, filenames: list) -> Iterator[str]: + for filename in filenames: + file_path: Path = bucket_dir / filename + if not file_path.exists() and 0 > file_path.lstat().st_size: + continue + extension = os.path.splitext(filename)[1][1:] + if filename.endswith(".zip"): + with zipfile.ZipFile(file_path, 'r') as zip_ref: + zip_ref.extractall(bucket_dir) + try: + yield from stream_read_files(bucket_dir, [f for f in zip_ref.namelist() if supports_filename(f)]) + except zipfile.BadZipFile: + pass + finally: + for unlink in zip_ref.namelist()[::-1]: + filepath = os.path.join(bucket_dir, unlink) + if os.path.exists(filepath): + if os.path.isdir(filepath): + os.rmdir(filepath) + else: + os.unlink(filepath) + continue + yield f"```{filename}\n" + if has_pypdf2 and filename.endswith(".pdf"): + try: + reader = 
PyPDF2.PdfReader(file_path) + for page_num in range(len(reader.pages)): + page = reader.pages[page_num] + yield page.extract_text() + except PdfReadError: + continue + if has_pdfplumber and filename.endswith(".pdf"): + with pdfplumber.open(file_path) as pdf: + for page in pdf.pages: + yield page.extract_text() + if has_pdfminer and filename.endswith(".pdf"): + yield extract_text(file_path) + elif has_docx and filename.endswith(".docx"): + doc = Document(file_path) + for para in doc.paragraphs: + yield para.text + elif has_docx2txt and filename.endswith(".docx"): + yield docx2txt.process(file_path) + elif has_odfpy and filename.endswith(".odt"): + textdoc = load(file_path) + allparas = textdoc.getElementsByType(P) + for p in allparas: + yield p.firstChild.data if p.firstChild else "" + elif has_ebooklib and filename.endswith(".epub"): + book = epub.read_epub(file_path) + for doc_item in book.get_items(): + if doc_item.get_type() == ebooklib.ITEM_DOCUMENT: + yield doc_item.get_content().decode(errors='ignore') + elif has_openpyxl and filename.endswith(".xlsx"): + df = pd.read_excel(file_path) + for row in df.itertuples(index=False): + yield " ".join(str(cell) for cell in row) + elif has_beautifulsoup4 and filename.endswith(".html"): + yield from scrape_text(file_path.read_text(errors="ignore")) + elif extension in PLAIN_FILE_EXTENSIONS: + yield file_path.read_text(errors="ignore") + yield f"\n```\n\n" + +def cache_stream(stream: Iterator[str], bucket_dir: Path) -> Iterator[str]: + cache_file = bucket_dir / PLAIN_CACHE + tmp_file = bucket_dir / f"{PLAIN_CACHE}.{time.time()}.tmp" + if cache_file.exists(): + for chunk in read_path_chunked(cache_file): + yield chunk + return + with open(tmp_file, "w") as f: + for chunk in stream: + f.write(chunk) + yield chunk + tmp_file.rename(cache_file) + +def is_complete(data: str): + return data.endswith("\n```\n\n") and data.count("```") % 2 == 0 + +def read_path_chunked(path: Path): + with path.open("r", encoding='utf-8') as f: + current_chunk_size = 0 + buffer = "" + for line in f: + current_chunk_size += len(line.encode('utf-8')) + buffer += line + if current_chunk_size >= 4096: + if is_complete(buffer) or current_chunk_size >= 8192: + yield buffer + buffer = "" + current_chunk_size = 0 + if current_chunk_size > 0: + yield buffer + +def read_bucket(bucket_dir: Path): + bucket_dir = Path(bucket_dir) + cache_file = bucket_dir / PLAIN_CACHE + spacy_file = bucket_dir / f"spacy_0001.cache" + if not spacy_file.exists(): + yield cache_file.read_text() + for idx in range(1, 1000): + spacy_file = bucket_dir / f"spacy_{idx:04d}.cache" + plain_file = bucket_dir / f"plain_{idx:04d}.cache" + if spacy_file.exists(): + yield spacy_file.read_text() + elif plain_file.exists(): + yield plain_file.read_text() + else: + break + +def stream_read_parts_and_refine(bucket_dir: Path, delete_files: bool = False) -> Iterator[str]: + cache_file = bucket_dir / PLAIN_CACHE + space_file = Path(bucket_dir) / f"spacy_0001.cache" + part_one = bucket_dir / f"plain_0001.cache" + if not space_file.exists() and not part_one.exists() and cache_file.exists(): + split_file_by_size_and_newline(cache_file, bucket_dir) + for idx in range(1, 1000): + part = bucket_dir / f"plain_{idx:04d}.cache" + tmp_file = Path(bucket_dir) / f"spacy_{idx:04d}.{time.time()}.tmp" + cache_file = Path(bucket_dir) / f"spacy_{idx:04d}.cache" + if cache_file.exists(): + with open(cache_file, "r") as f: + yield f.read() + continue + if not part.exists(): + break + with tmp_file.open("w") as f: + for chunk in 
spacy_refine_chunks(read_path_chunked(part)): + f.write(chunk) + yield chunk + tmp_file.rename(cache_file) + if delete_files: + part.unlink() + +def split_file_by_size_and_newline(input_filename, output_dir, chunk_size_bytes=1024*1024): # 1MB + """Splits a file into chunks of approximately chunk_size_bytes, splitting only at newline characters. + + Args: + input_filename: Path to the input file. + output_prefix: Prefix for the output files (e.g., 'output_part_'). + chunk_size_bytes: Desired size of each chunk in bytes. + """ + split_filename = os.path.splitext(os.path.basename(input_filename)) + output_prefix = os.path.join(output_dir, split_filename[0] + "_") + + with open(input_filename, 'r', encoding='utf-8') as infile: + chunk_num = 1 + current_chunk = "" + current_chunk_size = 0 + + for line in infile: + current_chunk += line + current_chunk_size += len(line.encode('utf-8')) + + if current_chunk_size >= chunk_size_bytes: + if is_complete(current_chunk) or current_chunk_size >= chunk_size_bytes * 2: + output_filename = f"{output_prefix}{chunk_num:04d}{split_filename[1]}" + with open(output_filename, 'w', encoding='utf-8') as outfile: + outfile.write(current_chunk) + current_chunk = "" + current_chunk_size = 0 + chunk_num += 1 + + # Write the last chunk + if current_chunk: + output_filename = f"{output_prefix}{chunk_num:04d}{split_filename[1]}" + with open(output_filename, 'w', encoding='utf-8') as outfile: + outfile.write(current_chunk) + +async def get_filename(response: ClientResponse): + """ + Attempts to extract a filename from an aiohttp response. Prioritizes Content-Disposition, then URL. + + Args: + response: The aiohttp ClientResponse object. + + Returns: + The filename as a string, or None if it cannot be determined. + """ + + content_disposition = response.headers.get('Content-Disposition') + if content_disposition: + try: + filename = content_disposition.split('filename=')[1].strip('"') + if filename: + return secure_filename(filename) + except IndexError: + pass + + content_type = response.headers.get('Content-Type') + url = str(response.url) + if content_type and url: + extension = await get_file_extension(response) + if extension: + parsed_url = urllib.parse.urlparse(url) + sha256_hash = hashlib.sha256(url.encode()).digest() + base64_encoded = base64.b32encode(sha256_hash).decode().lower() + return f"{parsed_url.netloc} {parsed_url.path[1:].replace('/', '_')} {base64_encoded[:6]}{extension}" + + return None + +async def get_file_extension(response: ClientResponse): + """ + Attempts to determine the file extension from an aiohttp response. Improved to handle more types. + + Args: + response: The aiohttp ClientResponse object. + + Returns: + The file extension (e.g., ".html", ".json", ".pdf", ".zip", ".md", ".txt") as a string, + or None if it cannot be determined. 
+ """ + + content_type = response.headers.get('Content-Type') + if content_type: + if "html" in content_type.lower(): + return ".html" + elif "json" in content_type.lower(): + return ".json" + elif "pdf" in content_type.lower(): + return ".pdf" + elif "zip" in content_type.lower(): + return ".zip" + elif "text/plain" in content_type.lower(): + return ".txt" + elif "markdown" in content_type.lower(): + return ".md" + + url = str(response.url) + if url: + return Path(url).suffix.lower() + + return None + +def read_links(html: str, base: str) -> set[str]: + soup = BeautifulSoup(html, "html.parser") + for selector in [ + "main", + ".main-content-wrapper", + ".main-content", + ".emt-container-inner", + ".content-wrapper", + "#content", + "#mainContent", + ]: + select = soup.select_one(selector) + if select: + soup = select + break + urls = [] + for link in soup.select("a"): + if "rel" not in link.attrs or "nofollow" not in link.attrs["rel"]: + url = link.attrs.get("href") + if url and url.startswith("https://"): + urls.append(url.split("#")[0]) + return set([urllib.parse.urljoin(base, link) for link in urls]) + +async def download_urls( + bucket_dir: Path, + urls: list[str], + max_depth: int = 2, + loaded_urls: set[str] = set(), + lock: asyncio.Lock = None, + delay: int = 3, + group_size: int = 5, + timeout: int = 10, + proxy: Optional[str] = None +) -> list[str]: + if lock is None: + lock = asyncio.Lock() + async with ClientSession( + connector=get_connector(proxy=proxy), + timeout=ClientTimeout(timeout) + ) as session: + async def download_url(url: str) -> str: + try: + async with session.get(url) as response: + response.raise_for_status() + filename = await get_filename(response) + if not filename: + print(f"Failed to get filename for {url}") + return None + newfiles = [filename] + if filename.endswith(".html") and max_depth > 0: + new_urls = read_links(await response.text(), str(response.url)) + async with lock: + new_urls = [new_url for new_url in new_urls if new_url not in loaded_urls] + [loaded_urls.add(url) for url in new_urls] + if new_urls: + for i in range(0, len(new_urls), group_size): + newfiles += await download_urls(bucket_dir, new_urls[i:i + group_size], max_depth - 1, loaded_urls, lock, delay + 1) + await asyncio.sleep(delay) + if supports_filename(filename) and filename != DOWNLOADS_FILE: + target = bucket_dir / filename + with target.open("wb") as f: + async for chunk in response.content.iter_chunked(4096): + f.write(chunk) + return newfiles + except (ClientError, asyncio.TimeoutError) as e: + debug.log(f"Download failed: {e.__class__.__name__}: {e}") + return None + files = set() + for results in await asyncio.gather(*[download_url(url) for url in urls]): + if results: + [files.add(url) for url in results] + return files + +def get_streaming(bucket_dir: str, delete_files = False, refine_chunks_with_spacy = False, event_stream: bool = False) -> Iterator[str]: + bucket_dir = Path(bucket_dir) + bucket_dir.mkdir(parents=True, exist_ok=True) + try: + download_file = bucket_dir / DOWNLOADS_FILE + if download_file.exists(): + urls = [] + with download_file.open('r') as f: + data = json.load(f) + download_file.unlink() + if isinstance(data, list): + for item in data: + if "url" in item: + urls.append(item["url"]) + if urls: + filenames = asyncio.run(download_urls(bucket_dir, urls)) + with open(os.path.join(bucket_dir, FILE_LIST), 'w') as f: + [f.write(f"{filename}\n") for filename in filenames if filename] + + if refine_chunks_with_spacy: + size = 0 + for chunk in 
stream_read_parts_and_refine(bucket_dir, delete_files): + if event_stream: + size += len(chunk) + yield f'data: {json.dumps({"action": "refine", "size": size})}\n\n' + else: + yield chunk + else: + streaming = stream_read_files(bucket_dir, get_filenames(bucket_dir)) + streaming = cache_stream(streaming, bucket_dir) + size = 0 + for chunk in streaming: + if event_stream: + size += len(chunk) + yield f'data: {json.dumps({"action": "load", "size": size})}\n\n' + else: + yield chunk + files_txt = os.path.join(bucket_dir, FILE_LIST) + if delete_files and os.path.exists(files_txt): + for filename in get_filenames(bucket_dir): + if os.path.exists(os.path.join(bucket_dir, filename)): + os.remove(os.path.join(bucket_dir, filename)) + os.remove(files_txt) + if event_stream: + yield f'data: {json.dumps({"action": "delete_files"})}\n\n' + if event_stream: + yield f'data: {json.dumps({"action": "done"})}\n\n' + except Exception as e: + if event_stream: + yield f'data: {json.dumps({"error": {"message": str(e)}})}\n\n' + raise e \ No newline at end of file diff --git a/g4f/tools/run_tools.py b/g4f/tools/run_tools.py new file mode 100644 index 00000000..21e9ec09 --- /dev/null +++ b/g4f/tools/run_tools.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import re +import json +import asyncio +from typing import Optional, Callable, AsyncIterator + +from ..typing import Messages +from ..providers.helper import filter_none +from ..client.helper import to_async_iterator +from .web_search import do_search, get_search_message +from .files import read_bucket, get_bucket_dir +from .. import debug + +def validate_arguments(data: dict) -> dict: + if "arguments" in data: + if isinstance(data["arguments"], str): + data["arguments"] = json.loads(data["arguments"]) + if not isinstance(data["arguments"], dict): + raise ValueError("Tool function arguments must be a dictionary or a json string") + else: + return filter_none(**data["arguments"]) + else: + return {} + +async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs): + if tool_calls is not None: + for tool in tool_calls: + if tool.get("type") == "function": + if tool.get("function", {}).get("name") == "search_tool": + tool["function"]["arguments"] = validate_arguments(tool["function"]) + messages = messages.copy() + messages[-1]["content"] = await do_search( + messages[-1]["content"], + **tool["function"]["arguments"] + ) + elif tool.get("function", {}).get("name") == "continue": + last_line = messages[-1]["content"].strip().splitlines()[-1] + content = f"Continue writing the story after this line start with a plus sign if you begin a new word.\n{last_line}" + messages.append({"role": "user", "content": content}) + response = async_iter_callback(model=model, messages=messages, **kwargs) + if not hasattr(response, "__aiter__"): + response = to_async_iterator(response) + async for chunk in response: + yield chunk + +def iter_run_tools( + iter_callback: Callable, + model: str, + messages: Messages, + provider: Optional[str] = None, + tool_calls: Optional[list] = None, + **kwargs +) -> AsyncIterator: + if tool_calls is not None: + for tool in tool_calls: + if tool.get("type") == "function": + if tool.get("function", {}).get("name") == "search_tool": + tool["function"]["arguments"] = validate_arguments(tool["function"]) + messages[-1]["content"] = get_search_message( + messages[-1]["content"], + raise_search_exceptions=True, + **tool["function"]["arguments"] + ) + elif tool.get("function", {}).get("name") == 
"safe_search_tool": + tool["function"]["arguments"] = validate_arguments(tool["function"]) + try: + messages[-1]["content"] = asyncio.run(do_search(messages[-1]["content"], **tool["function"]["arguments"])) + except Exception as e: + debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}") + # Enable provider native web search + kwargs["web_search"] = True + elif tool.get("function", {}).get("name") == "continue_tool": + if provider not in ("OpenaiAccount", "HuggingFace"): + last_line = messages[-1]["content"].strip().splitlines()[-1] + content = f"continue after this line:\n{last_line}" + messages.append({"role": "user", "content": content}) + else: + # Enable provider native continue + if "action" not in kwargs: + kwargs["action"] = "continue" + elif tool.get("function", {}).get("name") == "bucket_tool": + def on_bucket(match): + return "".join(read_bucket(get_bucket_dir(match.group(1)))) + messages[-1]["content"] = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, messages[-1]["content"]) + print(messages[-1]) + return iter_callback(model=model, messages=messages, provider=provider, **kwargs) \ No newline at end of file diff --git a/g4f/tools/web_search.py b/g4f/tools/web_search.py new file mode 100644 index 00000000..9033e0ad --- /dev/null +++ b/g4f/tools/web_search.py @@ -0,0 +1,230 @@ +from __future__ import annotations + +from aiohttp import ClientSession, ClientTimeout, ClientError +import json +import hashlib +from pathlib import Path +from collections import Counter +try: + from duckduckgo_search import DDGS + from duckduckgo_search.exceptions import DuckDuckGoSearchException + from bs4 import BeautifulSoup + has_requirements = True +except ImportError: + has_requirements = False +try: + import spacy + has_spacy = True +except: + has_spacy = False +from typing import Iterator +from ..cookies import get_cookies_dir +from ..errors import MissingRequirementsError +from .. import debug + +import asyncio + +DEFAULT_INSTRUCTIONS = """ +Using the provided web search results, to write a comprehensive reply to the user request. +Make sure to add the sources of cites using [[Number]](Url) notation after the reference. 
Example: [[0]](http://google.com) +""" + +class SearchResults(): + def __init__(self, results: list, used_words: int): + self.results = results + self.used_words = used_words + + def __iter__(self): + yield from self.results + + def __str__(self): + search = "" + for idx, result in enumerate(self.results): + if search: + search += "\n\n\n" + search += f"Title: {result.title}\n\n" + if result.text: + search += result.text + else: + search += result.snippet + search += f"\n\nSource: [[{idx}]]({result.url})" + return search + + def __len__(self) -> int: + return len(self.results) + +class SearchResultEntry(): + def __init__(self, title: str, url: str, snippet: str, text: str = None): + self.title = title + self.url = url + self.snippet = snippet + self.text = text + + def set_text(self, text: str): + self.text = text + +def scrape_text(html: str, max_words: int = None) -> Iterator[str]: + soup = BeautifulSoup(html, "html.parser") + for selector in [ + "main", + ".main-content-wrapper", + ".main-content", + ".emt-container-inner", + ".content-wrapper", + "#content", + "#mainContent", + ]: + select = soup.select_one(selector) + if select: + soup = select + break + # Zdnet + for remove in [".c-globalDisclosure"]: + select = soup.select_one(remove) + if select: + select.extract() + + for paragraph in soup.select("p, table, ul, h1, h2, h3, h4, h5, h6"): + for line in paragraph.text.splitlines(): + words = [word for word in line.replace("\t", " ").split(" ") if word] + count = len(words) + if not count: + continue + if max_words: + max_words -= count + if max_words <= 0: + break + yield " ".join(words) + "\n" + +async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None) -> str: + try: + bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "fetch_and_scrape" + bucket_dir.mkdir(parents=True, exist_ok=True) + md5_hash = hashlib.md5(url.encode()).hexdigest() + cache_file = bucket_dir / f"{url.split('/')[3]}.{md5_hash}.txt" + if cache_file.exists(): + return cache_file.read_text() + async with session.get(url) as response: + if response.status == 200: + html = await response.text() + text = "".join(scrape_text(html, max_words)) + with open(cache_file, "w") as f: + f.write(text) + return text + except ClientError: + return + +async def search(query: str, max_results: int = 5, max_words: int = 2500, backend: str = "auto", add_text: bool = True, timeout: int = 5, region: str = "wt-wt") -> SearchResults: + if not has_requirements: + raise MissingRequirementsError('Install "duckduckgo-search" and "beautifulsoup4" package | pip install -U g4f[search]') + with DDGS() as ddgs: + results = [] + for result in ddgs.text( + query, + region=region, + safesearch="moderate", + timelimit="y", + max_results=max_results, + backend=backend, + ): + results.append(SearchResultEntry( + result["title"], + result["href"], + result["body"] + )) + + if add_text: + requests = [] + async with ClientSession(timeout=ClientTimeout(timeout)) as session: + for entry in results: + requests.append(fetch_and_scrape(session, entry.url, int(max_words / (max_results - 1)))) + texts = await asyncio.gather(*requests) + + formatted_results = [] + used_words = 0 + left_words = max_words + for i, entry in enumerate(results): + if add_text: + entry.text = texts[i] + if left_words: + left_words -= entry.title.count(" ") + 5 + if entry.text: + left_words -= entry.text.count(" ") + else: + left_words -= entry.snippet.count(" ") + if 0 > left_words: + break + used_words = max_words - left_words + 
formatted_results.append(entry) + + return SearchResults(formatted_results, used_words) + +async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_INSTRUCTIONS, **kwargs) -> str: + if query is None: + query = spacy_get_keywords(prompt) + json_bytes = json.dumps({"query": query, **kwargs}, sort_keys=True).encode() + md5_hash = hashlib.md5(json_bytes).hexdigest() + bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "web_search" + bucket_dir.mkdir(parents=True, exist_ok=True) + cache_file = bucket_dir / f"{query[:20]}.{md5_hash}.txt" + if cache_file.exists(): + with open(cache_file, "r") as f: + search_results = f.read() + else: + search_results = await search(query, **kwargs) + with open(cache_file, "w") as f: + f.write(str(search_results)) + + new_prompt = f""" +{search_results} + +Instruction: {instructions} + +User request: +{prompt} +""" + debug.log(f"Web search: '{query.strip()[:50]}...' {len(search_results.results)} Results {search_results.used_words} Words") + return new_prompt + +def get_search_message(prompt: str, raise_search_exceptions=False, **kwargs) -> str: + try: + return asyncio.run(do_search(prompt, **kwargs)) + except (DuckDuckGoSearchException, MissingRequirementsError) as e: + if raise_search_exceptions: + raise e + debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}") + return prompt + +def spacy_get_keywords(text: str): + if not has_spacy: + return text + + # Load the spaCy language model + nlp = spacy.load("en_core_web_sm") + + # Process the query + doc = nlp(text) + + # Extract keywords based on POS and named entities + keywords = [] + for token in doc: + # Filter for nouns, proper nouns, and adjectives + if token.pos_ in {"NOUN", "PROPN", "ADJ"} and not token.is_stop: + keywords.append(token.lemma_) + + # Add named entities as keywords + for ent in doc.ents: + keywords.append(ent.text) + + # Remove duplicates and print keywords + keywords = list(set(keywords)) + #print("Keyword:", keywords) + + #keyword_freq = Counter(keywords) + #keywords = keyword_freq.most_common() + #print("Keyword Frequencies:", keywords) + + keywords = [chunk.text for chunk in doc.noun_chunks if not chunk.root.is_stop] + #print("Phrases:", keywords) + + return keywords \ No newline at end of file diff --git a/g4f/web_search.py b/g4f/web_search.py deleted file mode 100644 index 5d3c5659..00000000 --- a/g4f/web_search.py +++ /dev/null @@ -1,172 +0,0 @@ -from __future__ import annotations - -from aiohttp import ClientSession, ClientTimeout, ClientError -try: - from duckduckgo_search import DDGS - from duckduckgo_search.exceptions import DuckDuckGoSearchException - from bs4 import BeautifulSoup - has_requirements = True -except ImportError: - has_requirements = False -from .errors import MissingRequirementsError -from . import debug - -import asyncio - -DEFAULT_INSTRUCTIONS = """ -Using the provided web search results, to write a comprehensive reply to the user request. -Make sure to add the sources of cites using [[Number]](Url) notation after the reference. 
Example: [[0]](http://google.com) -""" - -class SearchResults(): - def __init__(self, results: list, used_words: int): - self.results = results - self.used_words = used_words - - def __iter__(self): - yield from self.results - - def __str__(self): - search = "" - for idx, result in enumerate(self.results): - if search: - search += "\n\n\n" - search += f"Title: {result.title}\n\n" - if result.text: - search += result.text - else: - search += result.snippet - search += f"\n\nSource: [[{idx}]]({result.url})" - return search - - def __len__(self) -> int: - return len(self.results) - -class SearchResultEntry(): - def __init__(self, title: str, url: str, snippet: str, text: str = None): - self.title = title - self.url = url - self.snippet = snippet - self.text = text - - def set_text(self, text: str): - self.text = text - -def scrape_text(html: str, max_words: int = None) -> str: - soup = BeautifulSoup(html, "html.parser") - for selector in [ - "main", - ".main-content-wrapper", - ".main-content", - ".emt-container-inner", - ".content-wrapper", - "#content", - "#mainContent", - ]: - select = soup.select_one(selector) - if select: - soup = select - break - # Zdnet - for remove in [".c-globalDisclosure"]: - select = soup.select_one(remove) - if select: - select.extract() - clean_text = "" - for paragraph in soup.select("p, h1, h2, h3, h4, h5, h6"): - text = paragraph.get_text() - for line in text.splitlines(): - words = [] - for word in line.replace("\t", " ").split(" "): - if word: - words.append(word) - count = len(words) - if not count: - continue - if max_words: - max_words -= count - if max_words <= 0: - break - if clean_text: - clean_text += "\n" - clean_text += " ".join(words) - - return clean_text - -async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None) -> str: - try: - async with session.get(url) as response: - if response.status == 200: - html = await response.text() - return scrape_text(html, max_words) - except ClientError: - return - -async def search(query: str, max_results: int = 5, max_words: int = 2500, backend: str = "auto", add_text: bool = True, timeout: int = 5, region: str = "wt-wt") -> SearchResults: - if not has_requirements: - raise MissingRequirementsError('Install "duckduckgo-search" and "beautifulsoup4" package | pip install -U g4f[search]') - with DDGS() as ddgs: - results = [] - for result in ddgs.text( - query, - region=region, - safesearch="moderate", - timelimit="y", - max_results=max_results, - backend=backend, # Changed from 'api' to 'auto' - ): - results.append(SearchResultEntry( - result["title"], - result["href"], - result["body"] - )) - - if add_text: - requests = [] - async with ClientSession(timeout=ClientTimeout(timeout)) as session: - for entry in results: - requests.append(fetch_and_scrape(session, entry.url, int(max_words / (max_results - 1)))) - texts = await asyncio.gather(*requests) - - formatted_results = [] - used_words = 0 - left_words = max_words - for i, entry in enumerate(results): - if add_text: - entry.text = texts[i] - if left_words: - left_words -= entry.title.count(" ") + 5 - if entry.text: - left_words -= entry.text.count(" ") - else: - left_words -= entry.snippet.count(" ") - if 0 > left_words: - break - used_words = max_words - left_words - formatted_results.append(entry) - - return SearchResults(formatted_results, used_words) - -async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_INSTRUCTIONS, **kwargs) -> str: - if query is None: - query = prompt - search_results = await 
search(query, **kwargs) - new_prompt = f""" -{search_results} - -Instruction: {instructions} - -User request: -{prompt} -""" - debug.log(f"Web search: '{query.strip()[:50]}...' {len(search_results.results)} Results {search_results.used_words} Words") - return new_prompt - -def get_search_message(prompt: str, raise_search_exceptions=False, **kwargs) -> str: - try: - return asyncio.run(do_search(prompt, **kwargs)) - except (DuckDuckGoSearchException, MissingRequirementsError) as e: - if raise_search_exceptions: - raise e - debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}") - return prompt diff --git a/g4f/webdriver.py b/g4f/webdriver.py deleted file mode 100644 index 022e7a9f..00000000 --- a/g4f/webdriver.py +++ /dev/null @@ -1,257 +0,0 @@ -from __future__ import annotations - -try: - from platformdirs import user_config_dir - from undetected_chromedriver import Chrome, ChromeOptions, find_chrome_executable - from selenium.webdriver.remote.webdriver import WebDriver - from selenium.webdriver.remote.webelement import WebElement - from selenium.webdriver.common.by import By - from selenium.webdriver.support.ui import WebDriverWait - from selenium.webdriver.support import expected_conditions as EC - from selenium.webdriver.common.keys import Keys - from selenium.common.exceptions import NoSuchElementException - has_requirements = True -except ImportError: - from typing import Type as WebDriver - has_requirements = False - -import time -from shutil import which -from os import path -from os import access, R_OK -from .typing import Cookies -from .errors import MissingRequirementsError -from . import debug - -try: - from pyvirtualdisplay import Display - has_pyvirtualdisplay = True -except ImportError: - has_pyvirtualdisplay = False - -try: - from undetected_chromedriver import Chrome as _Chrome, ChromeOptions - from seleniumwire.webdriver import InspectRequestsMixin, DriverCommonMixin - - class Chrome(InspectRequestsMixin, DriverCommonMixin, _Chrome): - def __init__(self, *args, options=None, seleniumwire_options={}, **kwargs): - if options is None: - options = ChromeOptions() - config = self._setup_backend(seleniumwire_options) - options.add_argument(f"--proxy-server={config['proxy']['httpProxy']}") - options.add_argument("--proxy-bypass-list=<-loopback>") - options.add_argument("--ignore-certificate-errors") - super().__init__(*args, options=options, **kwargs) - has_seleniumwire = True -except: - has_seleniumwire = False - -def get_browser( - user_data_dir: str = None, - headless: bool = False, - proxy: str = None, - options: ChromeOptions = None -) -> WebDriver: - """ - Creates and returns a Chrome WebDriver with specified options. - - Args: - user_data_dir (str, optional): Directory for user data. If None, uses default directory. - headless (bool, optional): Whether to run the browser in headless mode. Defaults to False. - proxy (str, optional): Proxy settings for the browser. Defaults to None. - options (ChromeOptions, optional): ChromeOptions object with specific browser options. Defaults to None. - - Returns: - WebDriver: An instance of WebDriver configured with the specified options. 
- """ - if not has_requirements: - raise MissingRequirementsError('Install Webdriver packages | pip install -U g4f[webdriver]') - browser = find_chrome_executable() - if browser is None: - raise MissingRequirementsError('Install "Google Chrome" browser') - if user_data_dir is None: - user_data_dir = user_config_dir("g4f") - if user_data_dir and debug.logging: - print("Open browser with config dir:", user_data_dir) - if not options: - options = ChromeOptions() - if proxy: - options.add_argument(f'--proxy-server={proxy}') - # Check for system driver in docker - driver = which('chromedriver') or '/usr/bin/chromedriver' - if not path.isfile(driver) or not access(driver, R_OK): - driver = None - return Chrome( - options=options, - user_data_dir=user_data_dir, - driver_executable_path=driver, - browser_executable_path=browser, - headless=headless, - patcher_force_close=True - ) - -def get_driver_cookies(driver: WebDriver) -> Cookies: - """ - Retrieves cookies from the specified WebDriver. - - Args: - driver (WebDriver): The WebDriver instance from which to retrieve cookies. - - Returns: - dict: A dictionary containing cookies with their names as keys and values as cookie values. - """ - return {cookie["name"]: cookie["value"] for cookie in driver.get_cookies()} - -def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None: - """ - Attempts to bypass Cloudflare protection when accessing a URL using the provided WebDriver. - - Args: - driver (WebDriver): The WebDriver to use for accessing the URL. - url (str): The URL to access. - timeout (int): Time in seconds to wait for the page to load. - - Raises: - Exception: If there is an error while bypassing Cloudflare or loading the page. - """ - driver.get(url) - if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js": - if debug.logging: - print("Cloudflare protection detected:", url) - - # Open website in a new tab - element = driver.find_element(By.ID, "challenge-body-text") - driver.execute_script(f""" - arguments[0].addEventListener('click', () => {{ - window.open(arguments[1]); - }}); - """, element, url) - element.click() - time.sleep(5) - - # Switch to the new tab and close the old tab - original_window = driver.current_window_handle - for window_handle in driver.window_handles: - if window_handle != original_window: - driver.close() - driver.switch_to.window(window_handle) - break - - # Click on the challenge button in the iframe - try: - driver.switch_to.frame(driver.find_element(By.CSS_SELECTOR, "#turnstile-wrapper iframe")) - WebDriverWait(driver, 5).until( - EC.presence_of_element_located((By.CSS_SELECTOR, "#challenge-stage input")) - ).click() - except NoSuchElementException: - ... - except Exception as e: - if debug.logging: - print(f"Error bypassing Cloudflare: {str(e).splitlines()[0]}") - #driver.switch_to.default_content() - driver.switch_to.window(window_handle) - driver.execute_script("document.href = document.href;") - WebDriverWait(driver, timeout).until( - EC.presence_of_element_located((By.CSS_SELECTOR, "body:not(.no-js)")) - ) - -class WebDriverSession: - """ - Manages a Selenium WebDriver session, including handling of virtual displays and proxies. - """ - - def __init__( - self, - webdriver: WebDriver = None, - user_data_dir: str = None, - headless: bool = False, - virtual_display: bool = False, - proxy: str = None, - options: ChromeOptions = None - ): - """ - Initializes a new instance of the WebDriverSession. 
- - Args: - webdriver (WebDriver, optional): A WebDriver instance for the session. Defaults to None. - user_data_dir (str, optional): Directory for user data. Defaults to None. - headless (bool, optional): Whether to run the browser in headless mode. Defaults to False. - virtual_display (bool, optional): Whether to use a virtual display. Defaults to False. - proxy (str, optional): Proxy settings for the browser. Defaults to None. - options (ChromeOptions, optional): ChromeOptions for the browser. Defaults to None. - """ - self.webdriver = webdriver - self.user_data_dir = user_data_dir - self.headless = headless - self.virtual_display = Display(size=(1920, 1080)) if has_pyvirtualdisplay and virtual_display else None - self.proxy = proxy - self.options = options - self.default_driver = None - - def reopen( - self, - user_data_dir: str = None, - headless: bool = False, - virtual_display: bool = False - ) -> WebDriver: - """ - Reopens the WebDriver session with new settings. - - Args: - user_data_dir (str, optional): Directory for user data. Defaults to current value. - headless (bool, optional): Whether to run the browser in headless mode. Defaults to current value. - virtual_display (bool, optional): Whether to use a virtual display. Defaults to current value. - - Returns: - WebDriver: The reopened WebDriver instance. - """ - user_data_dir = user_data_dir or self.user_data_dir - if self.default_driver: - self.default_driver.quit() - if not virtual_display and self.virtual_display: - self.virtual_display.stop() - self.virtual_display = None - self.default_driver = get_browser(user_data_dir, headless, self.proxy) - return self.default_driver - - def __enter__(self) -> WebDriver: - """ - Context management method for entering a session. Initializes and returns a WebDriver instance. - - Returns: - WebDriver: An instance of WebDriver for this session. - """ - if self.webdriver: - return self.webdriver - if self.virtual_display: - self.virtual_display.start() - self.default_driver = get_browser(self.user_data_dir, self.headless, self.proxy, self.options) - return self.default_driver - - def __exit__(self, exc_type, exc_val, exc_tb): - """ - Context management method for exiting a session. Closes and quits the WebDriver. - - Args: - exc_type: Exception type. - exc_val: Exception value. - exc_tb: Exception traceback. - - Note: - Closes the WebDriver and stops the virtual display if used. 
- """ - if self.default_driver: - try: - self.default_driver.close() - except Exception as e: - if debug.logging: - print(f"Error closing WebDriver: {str(e).splitlines()[0]}") - finally: - self.default_driver.quit() - if self.virtual_display: - self.virtual_display.stop() - -def element_send_text(element: WebElement, text: str) -> None: - script = "arguments[0].innerText = arguments[1];" - element.parent.execute_script(script, element, text) - element.send_keys(Keys.ENTER) \ No newline at end of file diff --git a/setup.py b/setup.py index d41e04bd..d16c19fb 100644 --- a/setup.py +++ b/setup.py @@ -84,6 +84,16 @@ EXTRA_REQUIRE = { ], "local": [ "gpt4all" + ], + "files": [ + "spacy", + "filesplit", + "beautifulsoup4", + "pypdf2", + "docx", + "odfpy", + "ebooklib", + "openpyxl", ] } -- cgit v1.2.3 From 0d59789eedf3784cf4c3aaf764785a4ad91723c4 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Wed, 1 Jan 2025 14:01:33 +0100 Subject: Add File API Documentation for Python and JS Format Bucket Placeholder in GUI --- README.md | 3 +- docs/file.md | 182 ++++++++++++++++++++++++++++++++++++ g4f/gui/client/static/js/chat.v1.js | 14 ++- g4f/tools/files.py | 103 +++++++++++--------- g4f/tools/run_tools.py | 18 +++- g4f/tools/web_search.py | 19 +++- 6 files changed, 285 insertions(+), 54 deletions(-) create mode 100644 docs/file.md diff --git a/README.md b/README.md index 52ccc2b8..b30e290f 100644 --- a/README.md +++ b/README.md @@ -243,7 +243,8 @@ print(f"Generated image URL: {image_url}") - **Requests API from G4F:** [/docs/requests](docs/requests.md) - **Client API from G4F:** [/docs/client](docs/client.md) - **AsyncClient API from G4F:** [/docs/async_client](docs/async_client.md) - + - **File API from G4F:** [/docs/file](docs/file.md) + - **Legacy:** - **Legacy API with python modules:** [/docs/legacy](docs/legacy.md) diff --git a/docs/file.md b/docs/file.md new file mode 100644 index 00000000..be20d8f0 --- /dev/null +++ b/docs/file.md @@ -0,0 +1,182 @@ +## G4F - File API Documentation with Web Download and Enhanced File Support + +This document details the enhanced G4F File API, allowing users to upload files, download files from web URLs, and process a wider range of file types for integration with language models. + +**Key Improvements:** + +* **Web URL Downloads:** Upload a `downloads.json` file to your bucket containing a list of URLs. The API will download and process these files. Example: `[{"url": "https://example.com/document.pdf"}]` + +* **Expanded File Support:** Added support for additional plain text file extensions: `.txt`, `.xml`, `.json`, `.js`, `.har`, `.sh`, `.py`, `.php`, `.css`, `.yaml`, `.sql`, `.log`, `.csv`, `.twig`, `.md`. Binary file support remains for `.pdf`, `.html`, `.docx`, `.odt`, `.epub`, `.xlsx`, and `.zip`. + +* **Server-Sent Events (SSE):** SSE are now used to provide asynchronous updates on file download and processing progress. This improves the user experience, particularly for large files and multiple downloads. + + +**API Endpoints:** + +* **Upload:** `/v1/files/{bucket_id}` (POST) + + * **Method:** POST + * **Path Parameters:** `bucket_id` (Generated by your own. For example a UUID) + * **Body:** Multipart/form-data with files OR a `downloads.json` file containing URLs. + * **Response:** JSON object with `bucket_id`, `url`, and a list of uploaded/downloaded filenames. 
+ + +* **Retrieve:** `/v1/files/{bucket_id}` (GET) + + * **Method:** GET + * **Path Parameters:** `bucket_id` + * **Query Parameters:** + * `delete_files`: (Optional, boolean, default `true`) Delete files after retrieval. + * `refine_chunks_with_spacy`: (Optional, boolean, default `false`) Apply spaCy-based refinement. + * **Response:** Streaming response with extracted text, separated by ``` markers. SSE updates are sent if the `Accept` header includes `text/event-stream`. + + +**Example Usage (Python):** + +```python +import requests +import uuid +import json + +def upload_and_process(files_or_urls, bucket_id=None): + if bucket_id is None: + bucket_id = str(uuid.uuid4()) + + if isinstance(files_or_urls, list): #URLs + files = {'files': ('downloads.json', json.dumps(files_or_urls), 'application/json')} + elif isinstance(files_or_urls, dict): #Files + files = files_or_urls + else: + raise ValueError("files_or_urls must be a list of URLs or a dictionary of files") + + upload_response = requests.post(f'http://localhost:1337/v1/files/{bucket_id}', files=files) + + if upload_response.status_code == 200: + upload_data = upload_response.json() + print(f"Upload successful. Bucket ID: {upload_data['bucket_id']}") + else: + print(f"Upload failed: {upload_response.status_code} - {upload_response.text}") + + response = requests.get(f'http://localhost:1337/v1/files/{bucket_id}', stream=True, headers={'Accept': 'text/event-stream'}) + for line in response.iter_lines(): + if line: + line = line.decode('utf-8') + if line.startswith('data:'): + try: + data = json.loads(line[5:]) #remove data: prefix + if "action" in data: + print(f"SSE Event: {data}") + elif "error" in data: + print(f"Error: {data['error']['message']}") + else: + print(f"File data received: {data}") #Assuming it's file content + except json.JSONDecodeError as e: + print(f"Error decoding JSON: {e}") + else: + print(f"Unhandled SSE event: {line}") + response.close() + +# Example with URLs +urls = [{"url": "https://github.com/xtekky/gpt4free/issues"}] +bucket_id = upload_and_process(urls) + +#Example with files +files = {'files': open('document.pdf', 'rb'), 'files': open('data.json', 'rb')} +bucket_id = upload_and_process(files) +``` + + +**Example Usage (JavaScript):** + +```javascript +function uuid() { + return ([1e7]+-1e3+-4e3+-8e3+-1e11).replace(/[018]/g, c => + (c ^ crypto.getRandomValues(new Uint8Array(1))[0] & 15 >> c / 4).toString(16) + ); +} + +async function upload_files_or_urls(data) { + let bucket_id = uuid(); // Use a random generated key for your bucket + + let formData = new FormData(); + if (typeof data === "object" && data.constructor === Array) { //URLs + const blob = new Blob([JSON.stringify(data)], { type: 'application/json' }); + const file = new File([blob], 'downloads.json', { type: 'application/json' }); // Create File object + formData.append('files', file); // Append as a file + } else { //Files + Array.from(data).forEach(file => { + formData.append('files', file); + }); + } + + await fetch("/v1/files/" + bucket_id, { + method: 'POST', + body: formData + }); + + function connectToSSE(url) { + const eventSource = new EventSource(url); + eventSource.onmessage = (event) => { + const data = JSON.parse(event.data); + if (data.error) { + console.error("Error:", data.error.message); + } else if (data.action === "done") { + console.log("Files loaded successfully. Bucket ID:", bucket_id); + // Use bucket_id in your LLM prompt. + const prompt = `Use files from bucket. 
${JSON.stringify({"bucket_id": bucket_id})} to answer this: ...your question...`; + // ... Send prompt to your language model ... + } else { + console.log("SSE Event:", data); // Update UI with progress as needed + } + }; + eventSource.onerror = (event) => { + console.error("SSE Error:", event); + eventSource.close(); + }; + } + + connectToSSE(`/v1/files/${bucket_id}`); //Retrieve and refine +} + +// Example with URLs +const urls = [{"url": "https://github.com/xtekky/gpt4free/issues"}]; +upload_files_or_urls(urls) + +// Example with files (using a file input element) +const fileInput = document.getElementById('fileInput'); +fileInput.addEventListener('change', () => { + upload_files_or_urls(fileInput.files); +}); +``` + +**Integrating with `ChatCompletion`:** + +To incorporate file uploads into your client applications, include the `tool_calls` parameter in your chat completion requests, using the `bucket_tool` function. The `bucket_id` is passed as a JSON object within your prompt. + + +```json +{ + "messages": [ + { + "role": "user", + "content": "Answer this question using the files in the specified bucket: ...your question...\n{\"bucket_id\": \"your_actual_bucket_id\"}" + } + ], + "tool_calls": [ + { + "function": { + "name": "bucket_tool" + }, + "type": "function" + } + ] +} +``` + +**Important Considerations:** + +* **Error Handling:** Implement robust error handling in both Python and JavaScript to gracefully manage potential issues during file uploads, downloads, and API interactions. +* **Dependencies:** Ensure all required packages are installed (`pip install -U g4f[files]` for Python). + +--- +[Return to Home](/) \ No newline at end of file diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js index ea1afe29..4afea101 100644 --- a/g4f/gui/client/static/js/chat.v1.js +++ b/g4f/gui/client/static/js/chat.v1.js @@ -59,6 +59,10 @@ if (window.markdownit) { return markdown.render(content .replaceAll(/|/gm, "") .replaceAll(//gm, "") + .replaceAll(/{"bucket_id":"([^"]+)"}/gm, (match, p1) => { + size = appStorage.getItem(`bucket:${p1}`); + return `**Bucket:** [[${p1}]](/backend-api/v2/files/${p1})${size ? 
` (${formatFileSize(size)})` : ""}`; + }) ) .replaceAll("', '') @@ -1802,16 +1806,18 @@ function formatFileSize(bytes) { async function upload_files(fileInput) { const paperclip = document.querySelector(".user-input .fa-paperclip"); const bucket_id = uuid(); + delete fileInput.dataset.text; + paperclip.classList.add("blink"); const formData = new FormData(); Array.from(fileInput.files).forEach(file => { formData.append('files[]', file); }); - paperclip.classList.add("blink"); await fetch("/backend-api/v2/files/" + bucket_id, { method: 'POST', body: formData }); + let do_refine = document.getElementById("refine").checked; function connectToSSE(url) { const eventSource = new EventSource(url); @@ -1819,21 +1825,25 @@ async function upload_files(fileInput) { const data = JSON.parse(event.data); if (data.error) { inputCount.innerText = `Error: ${data.error.message}`; + paperclip.classList.remove("blink"); + fileInput.value = ""; } else if (data.action == "load") { inputCount.innerText = `Read data: ${formatFileSize(data.size)}`; } else if (data.action == "refine") { inputCount.innerText = `Refine data: ${formatFileSize(data.size)}`; + } else if (data.action == "download") { + inputCount.innerText = `Download: ${data.count} files`; } else if (data.action == "done") { if (do_refine) { do_refine = false; connectToSSE(`/backend-api/v2/files/${bucket_id}?refine_chunks_with_spacy=true`); return; } + appStorage.setItem(`bucket:${bucket_id}`, data.size); inputCount.innerText = "Files are loaded successfully"; messageInput.value += (messageInput.value ? "\n" : "") + JSON.stringify({bucket_id: bucket_id}) + "\n"; paperclip.classList.remove("blink"); fileInput.value = ""; - delete fileInput.dataset.text; } }; eventSource.onerror = (event) => { diff --git a/g4f/tools/files.py b/g4f/tools/files.py index d0d1c23b..b06fe5ea 100644 --- a/g4f/tools/files.py +++ b/g4f/tools/files.py @@ -3,7 +3,7 @@ from __future__ import annotations import os import json from pathlib import Path -from typing import Iterator, Optional +from typing import Iterator, Optional, AsyncIterator from aiohttp import ClientSession, ClientError, ClientResponse, ClientTimeout import urllib.parse import time @@ -74,6 +74,7 @@ except ImportError: from .web_search import scrape_text from ..cookies import get_cookies_dir from ..requests.aiohttp import get_connector +from ..providers.asyncio import to_sync_generator from ..errors import MissingRequirementsError from .. 
import debug @@ -148,10 +149,12 @@ def spacy_refine_chunks(source_iterator): def get_filenames(bucket_dir: Path): files = bucket_dir / FILE_LIST - with files.open('r') as f: - return [filename.strip() for filename in f.readlines()] + if files.exists(): + with files.open('r') as f: + return [filename.strip() for filename in f.readlines()] + return [] -def stream_read_files(bucket_dir: Path, filenames: list) -> Iterator[str]: +def stream_read_files(bucket_dir: Path, filenames: list, delete_files: bool = False) -> Iterator[str]: for filename in filenames: file_path: Path = bucket_dir / filename if not file_path.exists() and 0 > file_path.lstat().st_size: @@ -161,17 +164,18 @@ def stream_read_files(bucket_dir: Path, filenames: list) -> Iterator[str]: with zipfile.ZipFile(file_path, 'r') as zip_ref: zip_ref.extractall(bucket_dir) try: - yield from stream_read_files(bucket_dir, [f for f in zip_ref.namelist() if supports_filename(f)]) + yield from stream_read_files(bucket_dir, [f for f in zip_ref.namelist() if supports_filename(f)], delete_files) except zipfile.BadZipFile: pass finally: - for unlink in zip_ref.namelist()[::-1]: - filepath = os.path.join(bucket_dir, unlink) - if os.path.exists(filepath): - if os.path.isdir(filepath): - os.rmdir(filepath) - else: - os.unlink(filepath) + if delete_files: + for unlink in zip_ref.namelist()[::-1]: + filepath = os.path.join(bucket_dir, unlink) + if os.path.exists(filepath): + if os.path.isdir(filepath): + os.rmdir(filepath) + else: + os.unlink(filepath) continue yield f"```{filename}\n" if has_pypdf2 and filename.endswith(".pdf"): @@ -320,7 +324,7 @@ def split_file_by_size_and_newline(input_filename, output_dir, chunk_size_bytes= with open(output_filename, 'w', encoding='utf-8') as outfile: outfile.write(current_chunk) -async def get_filename(response: ClientResponse): +async def get_filename(response: ClientResponse) -> str: """ Attempts to extract a filename from an aiohttp response. Prioritizes Content-Disposition, then URL. 
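
The next hunk reworks the URL-derived fallback filename to use a 24-character base32 digest prefix. A rough standalone sketch of that naming scheme follows; `fallback_filename` is a hypothetical helper for illustration only, since the patched `get_filename` is async, takes a `ClientResponse` and checks `Content-Disposition` first.

```python
# Sketch of the fallback naming scheme used when no Content-Disposition header is present:
# "<netloc> <path with '/' replaced by '_'> <24-char base32 of sha256(url)><extension>"
import base64
import hashlib
import urllib.parse

def fallback_filename(url: str, extension: str) -> str:
    parsed = urllib.parse.urlparse(url)
    sha256_hash = hashlib.sha256(url.encode()).digest()
    url_hash = base64.b32encode(sha256_hash).decode()[:24].lower()
    return f"{parsed.netloc} {parsed.path[1:].replace('/', '_')} {url_hash}{extension}"

print(fallback_filename("https://example.com/docs/page.html", ".html"))
```
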
@@ -347,8 +351,9 @@ async def get_filename(response: ClientResponse): if extension: parsed_url = urllib.parse.urlparse(url) sha256_hash = hashlib.sha256(url.encode()).digest() - base64_encoded = base64.b32encode(sha256_hash).decode().lower() - return f"{parsed_url.netloc} {parsed_url.path[1:].replace('/', '_')} {base64_encoded[:6]}{extension}" + base32_encoded = base64.b32encode(sha256_hash).decode() + url_hash = base32_encoded[:24].lower() + return f"{parsed_url.netloc} {parsed_url.path[1:].replace('/', '_')} {url_hash}{extension}" return None @@ -404,21 +409,22 @@ def read_links(html: str, base: str) -> set[str]: for link in soup.select("a"): if "rel" not in link.attrs or "nofollow" not in link.attrs["rel"]: url = link.attrs.get("href") - if url and url.startswith("https://"): + if url and url.startswith("https://") or url.startswith("/"): urls.append(url.split("#")[0]) return set([urllib.parse.urljoin(base, link) for link in urls]) async def download_urls( bucket_dir: Path, urls: list[str], - max_depth: int = 2, - loaded_urls: set[str] = set(), + max_depth: int = 1, + loading_urls: set[str] = set(), lock: asyncio.Lock = None, delay: int = 3, + new_urls: list[str] = list(), group_size: int = 5, timeout: int = 10, proxy: Optional[str] = None -) -> list[str]: +) -> AsyncIterator[str]: if lock is None: lock = asyncio.Lock() async with ClientSession( @@ -433,30 +439,37 @@ async def download_urls( if not filename: print(f"Failed to get filename for {url}") return None - newfiles = [filename] + if not supports_filename(filename) or filename == DOWNLOADS_FILE: + return None if filename.endswith(".html") and max_depth > 0: - new_urls = read_links(await response.text(), str(response.url)) - async with lock: - new_urls = [new_url for new_url in new_urls if new_url not in loaded_urls] - [loaded_urls.add(url) for url in new_urls] - if new_urls: - for i in range(0, len(new_urls), group_size): - newfiles += await download_urls(bucket_dir, new_urls[i:i + group_size], max_depth - 1, loaded_urls, lock, delay + 1) - await asyncio.sleep(delay) - if supports_filename(filename) and filename != DOWNLOADS_FILE: - target = bucket_dir / filename - with target.open("wb") as f: - async for chunk in response.content.iter_chunked(4096): - f.write(chunk) - return newfiles + add_urls = read_links(await response.text(), str(response.url)) + if add_urls: + async with lock: + add_urls = [add_url for add_url in add_urls if add_url not in loading_urls] + [loading_urls.add(add_url) for add_url in add_urls] + [new_urls.append(add_url) for add_url in add_urls if add_url not in new_urls] + target = bucket_dir / filename + with target.open("wb") as f: + async for chunk in response.content.iter_chunked(4096): + if b'', f'\n'.encode())) + return filename except (ClientError, asyncio.TimeoutError) as e: debug.log(f"Download failed: {e.__class__.__name__}: {e}") return None - files = set() - for results in await asyncio.gather(*[download_url(url) for url in urls]): - if results: - [files.add(url) for url in results] - return files + for filename in await asyncio.gather(*[download_url(url) for url in urls]): + if filename: + yield filename + else: + await asyncio.sleep(delay) + while new_urls: + next_urls = list() + for i in range(0, len(new_urls), group_size): + chunked_urls = new_urls[i:i + group_size] + async for filename in download_urls(bucket_dir, chunked_urls, max_depth - 1, loading_urls, lock, delay + 1, next_urls): + yield filename + await asyncio.sleep(delay) + new_urls = next_urls def get_streaming(bucket_dir: str, 
delete_files = False, refine_chunks_with_spacy = False, event_stream: bool = False) -> Iterator[str]: bucket_dir = Path(bucket_dir) @@ -473,9 +486,13 @@ def get_streaming(bucket_dir: str, delete_files = False, refine_chunks_with_spac if "url" in item: urls.append(item["url"]) if urls: - filenames = asyncio.run(download_urls(bucket_dir, urls)) + count = 0 with open(os.path.join(bucket_dir, FILE_LIST), 'w') as f: - [f.write(f"{filename}\n") for filename in filenames if filename] + for filename in to_sync_generator(download_urls(bucket_dir, urls)): + f.write(f"{filename}\n") + if event_stream: + count += 1 + yield f'data: {json.dumps({"action": "download", "count": count})}\n\n' if refine_chunks_with_spacy: size = 0 @@ -486,7 +503,7 @@ def get_streaming(bucket_dir: str, delete_files = False, refine_chunks_with_spac else: yield chunk else: - streaming = stream_read_files(bucket_dir, get_filenames(bucket_dir)) + streaming = stream_read_files(bucket_dir, get_filenames(bucket_dir), delete_files) streaming = cache_stream(streaming, bucket_dir) size = 0 for chunk in streaming: @@ -504,7 +521,7 @@ def get_streaming(bucket_dir: str, delete_files = False, refine_chunks_with_spac if event_stream: yield f'data: {json.dumps({"action": "delete_files"})}\n\n' if event_stream: - yield f'data: {json.dumps({"action": "done"})}\n\n' + yield f'data: {json.dumps({"action": "done", "size": size})}\n\n' except Exception as e: if event_stream: yield f'data: {json.dumps({"error": {"message": str(e)}})}\n\n' diff --git a/g4f/tools/run_tools.py b/g4f/tools/run_tools.py index 21e9ec09..b3febfcd 100644 --- a/g4f/tools/run_tools.py +++ b/g4f/tools/run_tools.py @@ -12,6 +12,10 @@ from .web_search import do_search, get_search_message from .files import read_bucket, get_bucket_dir from .. import debug +BUCKET_INSTRUCTIONS = """ +Instruction: Make sure to add the sources of cites using [[domain]](Url) notation after the reference. 
diff --git a/g4f/tools/run_tools.py b/g4f/tools/run_tools.py
index 21e9ec09..b3febfcd 100644
--- a/g4f/tools/run_tools.py
+++ b/g4f/tools/run_tools.py
@@ -12,6 +12,10 @@
 from .web_search import do_search, get_search_message
 from .files import read_bucket, get_bucket_dir
 from .. import debug
+BUCKET_INSTRUCTIONS = """
+Instruction: Make sure to add the sources of cites using [[domain]](Url) notation after the reference. Example: [[a-z0-9.]](http://example.com)
+"""
+
 def validate_arguments(data: dict) -> dict:
     if "arguments" in data:
         if isinstance(data["arguments"], str):
@@ -36,7 +40,7 @@ async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls:
             )
         elif tool.get("function", {}).get("name") == "continue":
             last_line = messages[-1]["content"].strip().splitlines()[-1]
-            content = f"Continue writing the story after this line start with a plus sign if you begin a new word.\n{last_line}"
+            content = f"Continue after this line.\n{last_line}"
             messages.append({"role": "user", "content": content})
     response = async_iter_callback(model=model, messages=messages, **kwargs)
     if not hasattr(response, "__aiter__"):
         response = to_async_iterator(response)
@@ -73,7 +77,7 @@ def iter_run_tools(
         elif tool.get("function", {}).get("name") == "continue_tool":
             if provider not in ("OpenaiAccount", "HuggingFace"):
                 last_line = messages[-1]["content"].strip().splitlines()[-1]
-                content = f"continue after this line:\n{last_line}"
+                content = f"Continue after this line:\n{last_line}"
                 messages.append({"role": "user", "content": content})
             else:
                 # Enable provider native continue
@@ -82,6 +86,14 @@ def iter_run_tools(
         elif tool.get("function", {}).get("name") == "bucket_tool":
             def on_bucket(match):
                 return "".join(read_bucket(get_bucket_dir(match.group(1))))
-            messages[-1]["content"] = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, messages[-1]["content"])
+            has_bucket = False
+            for message in messages:
+                if "content" in message and isinstance(message["content"], str):
+                    new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
+                    if new_message_content != message["content"]:
+                        has_bucket = True
+                        message["content"] = new_message_content
+            if has_bucket and isinstance(messages[-1]["content"], str):
+                messages[-1]["content"] += BUCKET_INSTRUCTIONS
     print(messages[-1])
     return iter_callback(model=model, messages=messages, provider=provider, **kwargs)
\ No newline at end of file
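The bucket_tool branch above expands every {"bucket_id":"..."} placeholder in the conversation with the cached bucket contents and, if at least one placeholder was found, appends BUCKET_INSTRUCTIONS to the last message so the model cites its sources. The same substitution idea in isolation (read_bucket_stub stands in for read_bucket(get_bucket_dir(...)); the bucket id is made up):

import re

# Stand-in for read_bucket(get_bucket_dir(...)) from g4f.tools.files.
def read_bucket_stub(bucket_id: str) -> list:
    return [f"(contents of bucket {bucket_id})"]

def expand_buckets(messages: list) -> bool:
    """Replace {"bucket_id":"..."} placeholders in place; return True if any were found."""
    has_bucket = False

    def on_bucket(match: re.Match) -> str:
        return "".join(read_bucket_stub(match.group(1)))

    for message in messages:
        if isinstance(message.get("content"), str):
            new_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
            if new_content != message["content"]:
                has_bucket = True
                message["content"] = new_content
    return has_bucket

messages = [{"role": "user", "content": 'Summarize {"bucket_id":"demo123"} please.'}]
print(expand_buckets(messages), messages[-1]["content"])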
diff --git a/g4f/tools/web_search.py b/g4f/tools/web_search.py
index 9033e0ad..780e45df 100644
--- a/g4f/tools/web_search.py
+++ b/g4f/tools/web_search.py
@@ -4,7 +4,10 @@ from aiohttp import ClientSession, ClientTimeout, ClientError
 import json
 import hashlib
 from pathlib import Path
-from collections import Counter
+from urllib.parse import urlparse
+import datetime
+import asyncio
+
 try:
     from duckduckgo_search import DDGS
     from duckduckgo_search.exceptions import DuckDuckGoSearchException
@@ -17,13 +20,12 @@ try:
     has_spacy = True
 except:
     has_spacy = False
+
 from typing import Iterator
 from ..cookies import get_cookies_dir
 from ..errors import MissingRequirementsError
 from .. import debug
-import asyncio
-
 DEFAULT_INSTRUCTIONS = """
 Using the provided web search results, to write a comprehensive reply to the user request.
 Make sure to add the sources of cites using [[Number]](Url) notation after the reference. Example: [[0]](http://google.com)
@@ -64,7 +66,8 @@ class SearchResultEntry():
         self.text = text
 
 def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
-    soup = BeautifulSoup(html, "html.parser")
+    source = BeautifulSoup(html, "html.parser")
+    soup = source
     for selector in [
             "main",
             ".main-content-wrapper",
@@ -96,12 +99,18 @@ def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
                 break
             yield " ".join(words) + "\n"
 
+    canonical_link = source.find("link", rel="canonical")
+    if canonical_link and "href" in canonical_link.attrs:
+        link = canonical_link["href"]
+        domain = urlparse(link).netloc
+        yield f"\nSource: [{domain}]({link})"
+
 async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None) -> str:
     try:
         bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "fetch_and_scrape"
         bucket_dir.mkdir(parents=True, exist_ok=True)
         md5_hash = hashlib.md5(url.encode()).hexdigest()
-        cache_file = bucket_dir / f"{url.split('/')[3]}.{md5_hash}.txt"
+        cache_file = bucket_dir / f"{url.split('/')[3]}.{datetime.date.today()}.{md5_hash}.txt"
         if cache_file.exists():
             return cache_file.read_text()
         async with session.get(url) as response:
-- 
cgit v1.2.3

From fe88b57dfa06ff392f2045400e1a4fe71c5a2237 Mon Sep 17 00:00:00 2001
From: Heiner Lohaus
Date: Wed, 1 Jan 2025 14:10:48 +0100
Subject: Add File API Documentation for Python and JS

---
 g4f/tools/run_tools.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/g4f/tools/run_tools.py b/g4f/tools/run_tools.py
index b3febfcd..c283786e 100644
--- a/g4f/tools/run_tools.py
+++ b/g4f/tools/run_tools.py
@@ -42,6 +42,19 @@ async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls:
             last_line = messages[-1]["content"].strip().splitlines()[-1]
             content = f"Continue after this line.\n{last_line}"
             messages.append({"role": "user", "content": content})
+        elif tool.get("function", {}).get("name") == "bucket_tool":
+            def on_bucket(match):
+                return "".join(read_bucket(get_bucket_dir(match.group(1))))
+            has_bucket = False
+            for message in messages:
+                if "content" in message and isinstance(message["content"], str):
+                    new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
+                    if new_message_content != message["content"]:
+                        has_bucket = True
+                        message["content"] = new_message_content
+            if has_bucket and isinstance(messages[-1]["content"], str):
+                messages[-1]["content"] += BUCKET_INSTRUCTIONS
+
     response = async_iter_callback(model=model, messages=messages, **kwargs)
     if not hasattr(response, "__aiter__"):
         response = to_async_iterator(response)
@@ -95,5 +108,5 @@ def iter_run_tools(
                     message["content"] = new_message_content
         if has_bucket and isinstance(messages[-1]["content"], str):
             messages[-1]["content"] += BUCKET_INSTRUCTIONS
-    print(messages[-1])
+
     return iter_callback(model=model, messages=messages, provider=provider, **kwargs)
\ No newline at end of file
-- 
cgit v1.2.3
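The scrape_text change pairs with the canonical link that download_urls now writes into saved HTML: after yielding the text chunks, the scraper emits a trailing Source: line built from that canonical URL, which is what the [[domain]](Url) citation instructions refer to. A small standalone illustration of the lookup; the HTML snippet is made up, and bs4 is the same dependency the module already uses:

from urllib.parse import urlparse

from bs4 import BeautifulSoup

HTML = """<html><head>
<link rel="canonical" href="https://example.com/article">
</head><body><main><p>Body text.</p></main></body></html>"""

soup = BeautifulSoup(HTML, "html.parser")
canonical_link = soup.find("link", rel="canonical")
if canonical_link and "href" in canonical_link.attrs:
    link = canonical_link["href"]
    domain = urlparse(link).netloc
    # Mirrors the trailing line scrape_text now yields.
    print(f"\nSource: [{domain}]({link})")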