author     H Lohaus <hlohaus@users.noreply.github.com>  2024-11-25 15:47:10 +0100
committer  GitHub <noreply@github.com>  2024-11-25 15:47:10 +0100
commit     f4ff7f8da56746d1b1fd87dc9c11b72b1664f399 (patch)
tree       d0e48eed6722761fe2018cfae777af306e45560d
parent     Merge pull request #2423 from hlohaus/model (diff)
parent     Fix provider selection in images generate (diff)
-rw-r--r--  g4f/Provider/Airforce.py  | 18
-rw-r--r--  g4f/client/__init__.py    | 55
2 files changed, 39 insertions, 34 deletions
diff --git a/g4f/Provider/Airforce.py b/g4f/Provider/Airforce.py
index f5bcfefa..283d561e 100644
--- a/g4f/Provider/Airforce.py
+++ b/g4f/Provider/Airforce.py
@@ -7,6 +7,7 @@ import re
 import requests
 from requests.packages.urllib3.exceptions import InsecureRequestWarning
 requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
+from urllib.parse import quote
 
 from ..typing import AsyncResult, Messages
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
@@ -95,14 +96,18 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
         model: str,
         messages: Messages,
         proxy: str = None,
+        prompt: str = None,
         seed: int = None,
         size: str = "1:1", # "1:1", "16:9", "9:16", "21:9", "9:21", "1:2", "2:1"
         stream: bool = False,
         **kwargs
     ) -> AsyncResult:
         model = cls.get_model(model)
+
         if model in cls.image_models:
-            return cls._generate_image(model, messages, proxy, seed, size)
+            if prompt is None:
+                prompt = messages[-1]['content']
+            return cls._generate_image(model, prompt, proxy, seed, size)
         else:
             return cls._generate_text(model, messages, proxy, stream, **kwargs)
 
@@ -110,7 +115,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
     async def _generate_image(
         cls,
         model: str,
-        messages: Messages,
+        prompt: str,
         proxy: str = None,
         seed: int = None,
         size: str = "1:1",
@@ -125,7 +130,6 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
         }
         if seed is None:
             seed = random.randint(0, 100000)
-        prompt = messages[-1]['content']
 
         async with StreamSession(headers=headers, proxy=proxy) as session:
             params = {
@@ -140,12 +144,8 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
 
                 if 'application/json' in content_type:
                     raise RuntimeError(await response.json().get("error", {}).get("message"))
-                elif 'image' in content_type:
-                    image_data = b""
-                    async for chunk in response.iter_content():
-                        if chunk:
-                            image_data += chunk
-                    image_url = f"{cls.api_endpoint_imagine}?model={model}&prompt={prompt}&size={size}&seed={seed}"
+                elif content_type.startswith("image/"):
+                    image_url = f"{cls.api_endpoint_imagine}?model={model}&prompt={quote(prompt)}&size={size}&seed={seed}"
                     yield ImageResponse(images=image_url, alt=prompt)
 
     @classmethod
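
Note on the Airforce.py change: instead of downloading the image bytes, the provider now yields the imagine URL directly, and the prompt is URL-encoded with quote() before being placed in the query string. A minimal sketch of why the encoding matters; the endpoint below is a placeholder, not the provider's real address:

from urllib.parse import quote

api_endpoint_imagine = "https://example.invalid/imagine"  # placeholder endpoint, not the real one
model, size, seed = "flux", "1:1", 42
prompt = "a cat & a dog, 100% watercolor"

# Unencoded, the '&' and '%' inside the prompt corrupt the query string:
raw_url = f"{api_endpoint_imagine}?model={model}&prompt={prompt}&size={size}&seed={seed}"
# quote() turns the prompt into a single, safe query value:
safe_url = f"{api_endpoint_imagine}?model={model}&prompt={quote(prompt)}&size={size}&seed={seed}"
print(raw_url)
print(safe_url)
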
diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py
index dcd408ab..86a81049 100644
--- a/g4f/client/__init__.py
+++ b/g4f/client/__init__.py
@@ -16,7 +16,7 @@ from ..providers.response import ResponseType, FinishReason, BaseConversation, S
 from ..errors import NoImageResponseError, ModelNotFoundError
 from ..providers.retry_provider import IterListProvider
 from ..providers.asyncio import get_running_loop, to_sync_generator, async_generator_to_list
-from ..Provider.needs_auth.BingCreateImages import BingCreateImages
+from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
 from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
 from .image_models import ImageModels
 from .types import IterResponse, ImageProvider, Client as BaseClient
@@ -264,28 +264,34 @@ class Images:
"""
return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs))
- async def async_generate(
- self,
- prompt: str,
- model: Optional[str] = None,
- provider: Optional[ProviderType] = None,
- response_format: Optional[str] = "url",
- proxy: Optional[str] = None,
- **kwargs
- ) -> ImagesResponse:
+ async def get_provider_handler(self, model: Optional[str], provider: Optional[ImageProvider], default: ImageProvider) -> ImageProvider:
if provider is None:
- provider_handler = self.models.get(model, provider or self.provider or BingCreateImages)
+ provider_handler = self.provider
+ if provider_handler is None:
+ provider_handler = self.models.get(model, default)
elif isinstance(provider, str):
provider_handler = convert_to_provider(provider)
else:
provider_handler = provider
if provider_handler is None:
- raise ModelNotFoundError(f"Unknown model: {model}")
+ return default
if isinstance(provider_handler, IterListProvider):
if provider_handler.providers:
provider_handler = provider_handler.providers[0]
else:
raise ModelNotFoundError(f"IterListProvider for model {model} has no providers")
+ return provider_handler
+
+ async def async_generate(
+ self,
+ prompt: str,
+ model: Optional[str] = None,
+ provider: Optional[ProviderType] = None,
+ response_format: Optional[str] = "url",
+ proxy: Optional[str] = None,
+ **kwargs
+ ) -> ImagesResponse:
+ provider_handler = await self.get_provider_handler(model, provider, BingCreateImages)
if proxy is None:
proxy = self.client.proxy
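
For reference, the new get_provider_handler centralizes provider selection: an explicit provider object wins, a provider name string goes through convert_to_provider, otherwise the client-level provider is used, then the model-to-provider mapping, then the supplied default (BingCreateImages for generate, OpenaiAccount for create_variation); an IterListProvider collapses to its first entry. A standalone sketch of that fallback order, using simplified stand-ins rather than the real g4f classes:

from typing import Optional

class StubProvider:
    """Simplified stand-in for a g4f provider class (sketch only)."""
    def __init__(self, name: str):
        self.__name__ = name

def resolve_provider(provider, client_provider, model: Optional[str],
                     model_map: dict, default: StubProvider) -> StubProvider:
    # Mirrors the fallback order of Images.get_provider_handler.
    if provider is None:
        handler = client_provider                    # provider configured on the client
        if handler is None:
            handler = model_map.get(model, default)  # model -> provider mapping, else default
    elif isinstance(provider, str):
        handler = model_map[provider]                # stand-in for convert_to_provider(provider)
    else:
        handler = provider                           # explicit provider object wins
    return default if handler is None else handler

model_map = {"flux": StubProvider("Airforce")}
default = StubProvider("BingCreateImages")
print(resolve_provider(None, None, "flux", model_map, default).__name__)     # Airforce
print(resolve_provider(None, None, "unknown", model_map, default).__name__)  # BingCreateImages
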
@@ -311,7 +317,7 @@ class Images:
                     response = item
                     break
         else:
-            raise ValueError(f"Provider {provider} does not support image generation")
+            raise ValueError(f"Provider {getattr(provider_handler, '__name__')} does not support image generation")
         if isinstance(response, ImageResponse):
             return await self._process_image_response(
                 response,
@@ -320,6 +326,8 @@ class Images:
                 model,
                 getattr(provider_handler, "__name__", None)
             )
+        if response is None:
+            raise NoImageResponseError(f"No image response from {getattr(provider_handler, '__name__')}")
         raise NoImageResponseError(f"Unexpected response type: {type(response)}")
 
     def create_variation(
@@ -343,31 +351,26 @@ class Images:
         proxy: Optional[str] = None,
         **kwargs
     ) -> ImagesResponse:
-        if provider is None:
-            provider = self.models.get(model, provider or self.provider or BingCreateImages)
-            if provider is None:
-                raise ModelNotFoundError(f"Unknown model: {model}")
-        if isinstance(provider, str):
-            provider = convert_to_provider(provider)
+        provider_handler = await self.get_provider_handler(model, provider, OpenaiAccount)
         if proxy is None:
             proxy = self.client.proxy
 
-        if hasattr(provider, "create_async_generator"):
+        if hasattr(provider_handler, "create_async_generator"):
             messages = [{"role": "user", "content": "create a variation of this image"}]
             generator = None
             try:
-                generator = provider.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs)
+                generator = provider_handler.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs)
                 async for chunk in generator:
                     if isinstance(chunk, ImageResponse):
                         response = chunk
                         break
             finally:
                 await safe_aclose(generator)
-        elif hasattr(provider, 'create_variation'):
-            if asyncio.iscoroutinefunction(provider.create_variation):
-                response = await provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
+        elif hasattr(provider_handler, 'create_variation'):
+            if asyncio.iscoroutinefunction(provider_handler.create_variation):
+                response = await provider_handler.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
             else:
-                response = provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
+                response = provider_handler.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
         else:
             raise NoImageResponseError(f"Provider {provider} does not support image variation")
 
@@ -375,6 +378,8 @@ class Images:
             response = ImageResponse([response])
         if isinstance(response, ImageResponse):
             return self._process_image_response(response, response_format, proxy, model, getattr(provider, "__name__", None))
+        if response is None:
+            raise NoImageResponseError(f"No image response from {getattr(provider, '__name__')}")
         raise NoImageResponseError(f"Unexpected response type: {type(response)}")
 
     async def _process_image_response(
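
Taken together, the changes let an explicit prompt keyword flow from the client into Airforce._generate_image and give generate and create_variation the same provider fallback logic. A hedged usage sketch against the public client API; the model name is illustrative and availability depends on the installed providers:

from g4f.client import Client

client = Client()
result = client.images.generate(
    prompt="a watercolor fox in the snow",
    model="flux",               # illustrative model name
    response_format="url",      # return a URL rather than base64 data
)
print(result.data[0].url)
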