diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/base_provider.py | 36 |
1 files changed, 23 insertions, 13 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index d5f23931..def2cd6d 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -4,8 +4,7 @@ from ..typing import Any, CreateResult, AsyncGenerator, Union import browser_cookie3 import asyncio -from time import time -import math + class BaseProvider(ABC): url: str @@ -48,6 +47,17 @@ def get_cookies(cookie_domain: str) -> dict: return _cookies[cookie_domain] +def format_prompt(messages: list[dict[str, str]], add_special_tokens=False): + if add_special_tokens or len(messages) > 1: + formatted = "\n".join( + ["%s: %s" % ((message["role"]).capitalize(), message["content"]) for message in messages] + ) + return f"{formatted}\nAssistant:" + else: + return messages.pop()["content"] + + + class AsyncProvider(BaseProvider): @classmethod def create_completion( @@ -72,20 +82,19 @@ class AsyncGeneratorProvider(AsyncProvider): cls, model: str, messages: list[dict[str, str]], - stream: bool = True, **kwargs: Any) -> CreateResult: - - if stream: - yield from run_generator(cls.create_async_generator(model, messages, **kwargs)) - else: - yield from AsyncProvider.create_completion(cls=cls, model=model, messages=messages, **kwargs) + stream: bool = True, + **kwargs + ) -> CreateResult: + yield from run_generator(cls.create_async_generator(model, messages, stream=stream, **kwargs)) @classmethod async def create_async( cls, model: str, - messages: list[dict[str, str]], **kwargs: Any) -> str: - - chunks = [chunk async for chunk in cls.create_async_generator(model, messages, **kwargs)] + messages: list[dict[str, str]], + **kwargs + ) -> str: + chunks = [chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)] if chunks: return "".join(chunks) @@ -93,8 +102,9 @@ class AsyncGeneratorProvider(AsyncProvider): @abstractmethod def create_async_generator( model: str, - messages: list[dict[str, str]]) -> AsyncGenerator: - + messages: list[dict[str, str]], + **kwargs + ) -> AsyncGenerator: raise NotImplementedError() |