From 804a80bc7cb788a77af169242b84c892ec4fc638 Mon Sep 17 00:00:00 2001 From: H Lohaus Date: Sun, 24 Nov 2024 17:43:45 +0100 Subject: Arm2 (#2414) * Fix arm v7 build / improve api * Update stubs.py * Fix unit tests --- g4f/api/__init__.py | 115 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 70 insertions(+), 45 deletions(-) (limited to 'g4f/api/__init__.py') diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index 94fc23a5..a8403e5c 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -8,21 +8,29 @@ import os import shutil import os.path -from fastapi import FastAPI, Response, Request, UploadFile +from fastapi import FastAPI, Response, Request, UploadFile, Depends +from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse from fastapi.exceptions import RequestValidationError from fastapi.security import APIKeyHeader from starlette.exceptions import HTTPException -from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN +from starlette.status import ( + HTTP_200_OK, + HTTP_422_UNPROCESSABLE_ENTITY, + HTTP_404_NOT_FOUND, + HTTP_401_UNAUTHORIZED, + HTTP_403_FORBIDDEN +) from fastapi.encoders import jsonable_encoder +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.middleware.cors import CORSMiddleware from starlette.responses import FileResponse -from pydantic import BaseModel -from typing import Union, Optional, List +from pydantic import BaseModel, Field +from typing import Union, Optional, List, Annotated import g4f import g4f.debug -from g4f.client import AsyncClient, ChatCompletion, convert_to_provider +from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_provider from g4f.providers.response import BaseConversation from g4f.client.helper import filter_none from g4f.image import is_accepted_format, images_dir @@ -30,6 +38,7 @@ from g4f.typing import Messages from g4f.errors import ProviderNotFoundError from g4f.cookies import read_cookie_files, get_cookies_dir from g4f.Provider import ProviderType, ProviderUtils, __providers__ +from g4f.gui import get_gui_app logger = logging.getLogger(__name__) @@ -50,6 +59,10 @@ def create_app(g4f_api_key: str = None): api.register_authorization() api.register_validation_exception_handler() + if AppConfig.gui: + gui_app = WSGIMiddleware(get_gui_app()) + app.mount("/", gui_app) + # Read cookie files if not ignored if not AppConfig.ignore_cookie_files: read_cookie_files() @@ -61,17 +74,17 @@ def create_app_debug(g4f_api_key: str = None): return create_app(g4f_api_key) class ChatCompletionsConfig(BaseModel): - messages: Messages - model: str - provider: Optional[str] = None + messages: Messages = Field(examples=[[{"role": "system", "content": ""}, {"role": "user", "content": ""}]]) + model: str = Field(default="") + provider: Optional[str] = Field(examples=[None]) stream: bool = False - temperature: Optional[float] = None - max_tokens: Optional[int] = None - stop: Union[list[str], str, None] = None - api_key: Optional[str] = None - web_search: Optional[bool] = None - proxy: Optional[str] = None - conversation_id: str = None + temperature: Optional[float] = Field(examples=[None]) + max_tokens: Optional[int] = Field(examples=[None]) + stop: Union[list[str], str, None] = Field(examples=[None]) + api_key: Optional[str] = Field(examples=[None]) + web_search: Optional[bool] = Field(examples=[None]) + proxy: Optional[str] = Field(examples=[None]) + conversation_id: Optional[str] = Field(examples=[None]) class ImageGenerationConfig(BaseModel): prompt: str @@ -101,6 +114,9 @@ class ModelResponseModel(BaseModel): created: int owned_by: Optional[str] +class ErrorResponseModel(BaseModel): + error: str + class AppConfig: ignored_providers: Optional[list[str]] = None g4f_api_key: Optional[str] = None @@ -109,6 +125,7 @@ class AppConfig: provider: str = None image_provider: str = None proxy: str = None + gui: bool = False @classmethod def set_config(cls, **data): @@ -129,6 +146,8 @@ class Api: self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key") self.conversations: dict[str, dict[str, BaseConversation]] = {} + security = HTTPBearer(auto_error=False) + def register_authorization(self): @self.app.middleware("http") async def authorization(request: Request, call_next): @@ -192,7 +211,7 @@ class Api: } for model_id, model in model_list.items()] @self.app.get("/v1/models/{model_name}") - async def model_info(model_name: str): + async def model_info(model_name: str) -> ModelResponseModel: if model_name in g4f.models.ModelUtils.convert: model_info = g4f.models.ModelUtils.convert[model_name] return JSONResponse({ @@ -201,20 +220,20 @@ class Api: 'created': 0, 'owned_by': model_info.base_provider }) - return JSONResponse({"error": "The model does not exist."}, 404) - - @self.app.post("/v1/chat/completions") - async def chat_completions(config: ChatCompletionsConfig, request: Request = None, provider: str = None): + return JSONResponse({"error": "The model does not exist."}, HTTP_404_NOT_FOUND) + + @self.app.post("/v1/chat/completions", response_model=ChatCompletion) + async def chat_completions( + config: ChatCompletionsConfig, + credentials: Annotated[HTTPAuthorizationCredentials, Depends(Api.security)] = None, + provider: str = None + ): try: config.provider = provider if config.provider is None else config.provider if config.provider is None: config.provider = AppConfig.provider - if config.api_key is None and request is not None: - auth_header = request.headers.get("Authorization") - if auth_header is not None: - api_key = auth_header.split(None, 1)[-1] - if api_key and api_key != "Bearer": - config.api_key = api_key + if credentials is not None: + config.api_key = credentials.credentials conversation = return_conversation = None if config.conversation_id is not None and config.provider is not None: @@ -242,8 +261,7 @@ class Api: ) if not config.stream: - response: ChatCompletion = await response - return JSONResponse(response.to_json()) + return await response async def streaming(): try: @@ -254,7 +272,7 @@ class Api: self.conversations[config.conversation_id] = {} self.conversations[config.conversation_id][config.provider] = chunk else: - yield f"data: {json.dumps(chunk.to_json())}\n\n" + yield f"data: {chunk.json()}\n\n" except GeneratorExit: pass except Exception as e: @@ -268,15 +286,15 @@ class Api: logger.exception(e) return Response(content=format_exception(e, config), status_code=500, media_type="application/json") - @self.app.post("/v1/images/generate") - @self.app.post("/v1/images/generations") - async def generate_image(config: ImageGenerationConfig, request: Request): - if config.api_key is None: - auth_header = request.headers.get("Authorization") - if auth_header is not None: - api_key = auth_header.split(None, 1)[-1] - if api_key and api_key != "Bearer": - config.api_key = api_key + @self.app.post("/v1/images/generate", response_model=ImagesResponse) + @self.app.post("/v1/images/generations", response_model=ImagesResponse) + async def generate_image( + request: Request, + config: ImageGenerationConfig, + credentials: Annotated[HTTPAuthorizationCredentials, Depends(Api.security)] = None + ): + if credentials is not None: + config.api_key = credentials.credentials try: response = await self.client.images.generate( prompt=config.prompt, @@ -291,7 +309,7 @@ class Api: for image in response.data: if hasattr(image, "url") and image.url.startswith("/"): image.url = f"{request.base_url}{image.url.lstrip('/')}" - return JSONResponse(response.to_json()) + return response except Exception as e: logger.exception(e) return Response(content=format_exception(e, config, True), status_code=500, media_type="application/json") @@ -342,22 +360,29 @@ class Api: file.file.close() return response_data - @self.app.get("/v1/synthesize/{provider}") + @self.app.get("/v1/synthesize/{provider}", responses={ + HTTP_200_OK: {"content": {"audio/*": {}}}, + HTTP_404_NOT_FOUND: {"model": ErrorResponseModel}, + HTTP_422_UNPROCESSABLE_ENTITY: {"model": ErrorResponseModel}, + }) async def synthesize(request: Request, provider: str): try: provider_handler = convert_to_provider(provider) except ProviderNotFoundError: - return Response("Provider not found", 404) + return JSONResponse({"error": "Provider not found"}, HTTP_404_NOT_FOUND) if not hasattr(provider_handler, "synthesize"): - return Response("Provider doesn't support synthesize", 500) + return JSONResponse({"error": "Provider doesn't support synthesize"}, HTTP_404_NOT_FOUND) if len(request.query_params) == 0: - return Response("Missing query params", 500) + return JSONResponse({"error": "Missing query params"}, HTTP_422_UNPROCESSABLE_ENTITY) response_data = provider_handler.synthesize({**request.query_params}) content_type = getattr(provider_handler, "synthesize_content_type", "application/octet-stream") return StreamingResponse(response_data, media_type=content_type) - @self.app.get("/images/{filename}") - async def get_image(filename) -> FileResponse: + @self.app.get("/images/{filename}", response_class=FileResponse, responses={ + HTTP_200_OK: {"content": {"image/*": {}}}, + HTTP_404_NOT_FOUND: {} + }) + async def get_image(filename): target = os.path.join(images_dir, filename) if not os.path.isfile(target): -- cgit v1.2.3