summaryrefslogblamecommitdiffstats
path: root/g4f/api/__init__.py
blob: d86364d182454635f80a37c3afffb25bfb2a9bb3 (plain) (tree)
1
2
3
4
5
6
7
8
9



                                                                   
          
           


             
















                            

 



























                                                                                                                                       
 















                                                                                                                            
    





































                                                                                                                                                   
 
                                                                                               
 
                    
            















                                                      
 


                                                                            
 
                                   
                                                  
                                                  


                                                
                     
                                   
                                    
                                                



                     

                                                                            
 

                             
 
                                                                                             
 





                                                                                                              
                                                                               
from fastapi                  import FastAPI, Response, Request
from fastapi.middleware.cors  import CORSMiddleware
from typing                   import List, Union, Any, Dict, AnyStr
from ._tokenizer              import tokenize
import g4f
import time
import json
import random
import string
import uvicorn
import nest_asyncio

app = FastAPI()
nest_asyncio.apply()

origins = [
    "http://localhost",
    "http://localhost:1337",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

JSONObject = Dict[AnyStr, Any]
JSONArray = List[Any]
JSONStructure = Union[JSONArray, JSONObject]

@app.get("/")
async def read_root():
    return Response(content=json.dumps({"info": "G4F API"}, indent=4), media_type="application/json")

@app.get("/v1")
async def read_root_v1():
    return Response(content=json.dumps({"info": "Go to /v1/chat/completions or /v1/models."}, indent=4), media_type="application/json")

@app.get("/v1/models")
async def models():
    model_list = [{
        'id': model,
        'object': 'model',
        'created': 0,
        'owned_by': 'g4f'} for model in g4f.Model.__all__()]

    return Response(content=json.dumps({
        'object': 'list',
        'data': model_list}, indent=4), media_type="application/json")

@app.get("/v1/models/{model_name}")
async def model_info(model_name: str):
    try:
        model_info = (g4f.ModelUtils.convert[model_name])

        return Response(content=json.dumps({
            'id': model_name,
            'object': 'model',
            'created': 0,
            'owned_by': model_info.base_provider
        }, indent=4), media_type="application/json")
    except:
        return Response(content=json.dumps({"error": "The model does not exist."}, indent=4), media_type="application/json")

@app.post("/v1/chat/completions")
async def chat_completions(request: Request, item: JSONStructure = None):

    item_data = {
        'model': 'gpt-3.5-turbo',
        'stream': False,
    }
    
    item_data.update(item or {})
    model = item_data.get('model')
    stream = item_data.get('stream')
    messages = item_data.get('messages')

    try:
        response = g4f.ChatCompletion.create(model=model, stream=stream, messages=messages)
    except:
        return Response(content=json.dumps({"error": "An error occurred while generating the response."}, indent=4), media_type="application/json")

    completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
    completion_timestamp = int(time.time())

    if not stream:
        prompt_tokens, _ = tokenize(''.join([message['content'] for message in messages]))
        completion_tokens, _ = tokenize(response)

        json_data = {
            'id': f'chatcmpl-{completion_id}',
            'object': 'chat.completion',
            'created': completion_timestamp,
            'model': model,
            'choices': [
                {
                    'index': 0,
                    'message': {
                        'role': 'assistant',
                        'content': response,
                    },
                    'finish_reason': 'stop',
                }
            ],
            'usage': {
                'prompt_tokens': prompt_tokens,
                'completion_tokens': completion_tokens,
                'total_tokens': prompt_tokens + completion_tokens,
            },
        }

        return Response(content=json.dumps(json_data, indent=4), media_type="application/json")

    def streaming():
        try:
            for chunk in response:
                completion_data = {
                    'id': f'chatcmpl-{completion_id}',
                    'object': 'chat.completion.chunk',
                    'created': completion_timestamp,
                    'model': model,
                    'choices': [
                        {
                            'index': 0,
                            'delta': {
                                'content': chunk,
                            },
                            'finish_reason': None,
                        }
                    ],
                }

                content = json.dumps(completion_data, separators=(',', ':'))
                yield f'data: {content}\n\n'
                time.sleep(0.03)

            end_completion_data = {
                'id': f'chatcmpl-{completion_id}',
                'object': 'chat.completion.chunk',
                'created': completion_timestamp,
                'model': model,
                'choices': [
                    {
                        'index': 0,
                        'delta': {},
                        'finish_reason': 'stop',
                    }
                ],
            }

            content = json.dumps(end_completion_data, separators=(',', ':'))
            yield f'data: {content}\n\n'

        except GeneratorExit:
            pass

    return Response(content=json.dumps(streaming(), indent=4), media_type="application/json")

@app.post("/v1/completions")
async def completions():
    return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json")

def run(ip):
    split_ip = ip.split(":")
    uvicorn.run(app, host=split_ip[0], port=int(split_ip[1]), use_colors=False)