diff options
author | H Lohaus <hlohaus@users.noreply.github.com> | 2024-11-25 15:47:10 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-25 15:47:10 +0100 |
commit | f4ff7f8da56746d1b1fd87dc9c11b72b1664f399 (patch) | |
tree | d0e48eed6722761fe2018cfae777af306e45560d | |
parent | Merge pull request #2423 from hlohaus/model (diff) | |
parent | Fix provider selection in images generate (diff) | |
download | gpt4free-f4ff7f8da56746d1b1fd87dc9c11b72b1664f399.tar gpt4free-f4ff7f8da56746d1b1fd87dc9c11b72b1664f399.tar.gz gpt4free-f4ff7f8da56746d1b1fd87dc9c11b72b1664f399.tar.bz2 gpt4free-f4ff7f8da56746d1b1fd87dc9c11b72b1664f399.tar.lz gpt4free-f4ff7f8da56746d1b1fd87dc9c11b72b1664f399.tar.xz gpt4free-f4ff7f8da56746d1b1fd87dc9c11b72b1664f399.tar.zst gpt4free-f4ff7f8da56746d1b1fd87dc9c11b72b1664f399.zip |
-rw-r--r-- | g4f/Provider/Airforce.py | 18 | ||||
-rw-r--r-- | g4f/client/__init__.py | 55 |
2 files changed, 39 insertions, 34 deletions
diff --git a/g4f/Provider/Airforce.py b/g4f/Provider/Airforce.py index f5bcfefa..283d561e 100644 --- a/g4f/Provider/Airforce.py +++ b/g4f/Provider/Airforce.py @@ -7,6 +7,7 @@ import re import requests from requests.packages.urllib3.exceptions import InsecureRequestWarning requests.packages.urllib3.disable_warnings(InsecureRequestWarning) +from urllib.parse import quote from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, ProviderModelMixin @@ -95,14 +96,18 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin): model: str, messages: Messages, proxy: str = None, + prompt: str = None, seed: int = None, size: str = "1:1", # "1:1", "16:9", "9:16", "21:9", "9:21", "1:2", "2:1" stream: bool = False, **kwargs ) -> AsyncResult: model = cls.get_model(model) + if model in cls.image_models: - return cls._generate_image(model, messages, proxy, seed, size) + if prompt is None: + prompt = messages[-1]['content'] + return cls._generate_image(model, prompt, proxy, seed, size) else: return cls._generate_text(model, messages, proxy, stream, **kwargs) @@ -110,7 +115,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin): async def _generate_image( cls, model: str, - messages: Messages, + prompt: str, proxy: str = None, seed: int = None, size: str = "1:1", @@ -125,7 +130,6 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin): } if seed is None: seed = random.randint(0, 100000) - prompt = messages[-1]['content'] async with StreamSession(headers=headers, proxy=proxy) as session: params = { @@ -140,12 +144,8 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin): if 'application/json' in content_type: raise RuntimeError(await response.json().get("error", {}).get("message")) - elif 'image' in content_type: - image_data = b"" - async for chunk in response.iter_content(): - if chunk: - image_data += chunk - image_url = f"{cls.api_endpoint_imagine}?model={model}&prompt={prompt}&size={size}&seed={seed}" + elif content_type.startswith("image/"): + image_url = f"{cls.api_endpoint_imagine}?model={model}&prompt={quote(prompt)}&size={size}&seed={seed}" yield ImageResponse(images=image_url, alt=prompt) @classmethod diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py index dcd408ab..86a81049 100644 --- a/g4f/client/__init__.py +++ b/g4f/client/__init__.py @@ -16,7 +16,7 @@ from ..providers.response import ResponseType, FinishReason, BaseConversation, S from ..errors import NoImageResponseError, ModelNotFoundError from ..providers.retry_provider import IterListProvider from ..providers.asyncio import get_running_loop, to_sync_generator, async_generator_to_list -from ..Provider.needs_auth.BingCreateImages import BingCreateImages +from ..Provider.needs_auth import BingCreateImages, OpenaiAccount from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse from .image_models import ImageModels from .types import IterResponse, ImageProvider, Client as BaseClient @@ -264,28 +264,34 @@ class Images: """ return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs)) - async def async_generate( - self, - prompt: str, - model: Optional[str] = None, - provider: Optional[ProviderType] = None, - response_format: Optional[str] = "url", - proxy: Optional[str] = None, - **kwargs - ) -> ImagesResponse: + async def get_provider_handler(self, model: Optional[str], provider: Optional[ImageProvider], default: ImageProvider) -> ImageProvider: if provider is None: - provider_handler = self.models.get(model, provider or self.provider or BingCreateImages) + provider_handler = self.provider + if provider_handler is None: + provider_handler = self.models.get(model, default) elif isinstance(provider, str): provider_handler = convert_to_provider(provider) else: provider_handler = provider if provider_handler is None: - raise ModelNotFoundError(f"Unknown model: {model}") + return default if isinstance(provider_handler, IterListProvider): if provider_handler.providers: provider_handler = provider_handler.providers[0] else: raise ModelNotFoundError(f"IterListProvider for model {model} has no providers") + return provider_handler + + async def async_generate( + self, + prompt: str, + model: Optional[str] = None, + provider: Optional[ProviderType] = None, + response_format: Optional[str] = "url", + proxy: Optional[str] = None, + **kwargs + ) -> ImagesResponse: + provider_handler = await self.get_provider_handler(model, provider, BingCreateImages) if proxy is None: proxy = self.client.proxy @@ -311,7 +317,7 @@ class Images: response = item break else: - raise ValueError(f"Provider {provider} does not support image generation") + raise ValueError(f"Provider {getattr(provider_handler, '__name__')} does not support image generation") if isinstance(response, ImageResponse): return await self._process_image_response( response, @@ -320,6 +326,8 @@ class Images: model, getattr(provider_handler, "__name__", None) ) + if response is None: + raise NoImageResponseError(f"No image response from {getattr(provider_handler, '__name__')}") raise NoImageResponseError(f"Unexpected response type: {type(response)}") def create_variation( @@ -343,31 +351,26 @@ class Images: proxy: Optional[str] = None, **kwargs ) -> ImagesResponse: - if provider is None: - provider = self.models.get(model, provider or self.provider or BingCreateImages) - if provider is None: - raise ModelNotFoundError(f"Unknown model: {model}") - if isinstance(provider, str): - provider = convert_to_provider(provider) + provider_handler = await self.get_provider_handler(model, provider, OpenaiAccount) if proxy is None: proxy = self.client.proxy - if hasattr(provider, "create_async_generator"): + if hasattr(provider_handler, "create_async_generator"): messages = [{"role": "user", "content": "create a variation of this image"}] generator = None try: - generator = provider.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs) + generator = provider_handler.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs) async for chunk in generator: if isinstance(chunk, ImageResponse): response = chunk break finally: await safe_aclose(generator) - elif hasattr(provider, 'create_variation'): - if asyncio.iscoroutinefunction(provider.create_variation): - response = await provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs) + elif hasattr(provider_handler, 'create_variation'): + if asyncio.iscoroutinefunction(provider.provider_handler): + response = await provider_handler.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs) else: - response = provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs) + response = provider_handler.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs) else: raise NoImageResponseError(f"Provider {provider} does not support image variation") @@ -375,6 +378,8 @@ class Images: response = ImageResponse([response]) if isinstance(response, ImageResponse): return self._process_image_response(response, response_format, proxy, model, getattr(provider, "__name__", None)) + if response is None: + raise NoImageResponseError(f"No image response from {getattr(provider, '__name__')}") raise NoImageResponseError(f"Unexpected response type: {type(response)}") async def _process_image_response( |