From 79c407b9397c9e10807dcda4d9df166609284b64 Mon Sep 17 00:00:00 2001 From: H Lohaus Date: Fri, 29 Nov 2024 13:56:11 +0100 Subject: IterListProvider support for generating images (#2441) * IterListProvider support for generating images * Add missing get_har_files import in Copilot * Fix typo in dall-e-3 model name * Add image client unittests * Add MicrosoftDesigner provider * Import MicrosoftDesigner and add it to the model list --- etc/unittest/__main__.py | 1 + etc/unittest/image_client.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ etc/unittest/mocks.py | 21 +++++++++++++++++++++ 3 files changed, 66 insertions(+) create mode 100644 etc/unittest/image_client.py (limited to 'etc') diff --git a/etc/unittest/__main__.py b/etc/unittest/__main__.py index e49dec30..3719c374 100644 --- a/etc/unittest/__main__.py +++ b/etc/unittest/__main__.py @@ -5,6 +5,7 @@ from .backend import * from .main import * from .model import * from .client import * +from .image_client import * from .include import * from .retry_provider import * diff --git a/etc/unittest/image_client.py b/etc/unittest/image_client.py new file mode 100644 index 00000000..b52ba8b0 --- /dev/null +++ b/etc/unittest/image_client.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import asyncio +import unittest + +from g4f.client import AsyncClient, ImagesResponse +from g4f.providers.retry_provider import IterListProvider +from .mocks import ( + YieldImageResponseProviderMock, + MissingAuthProviderMock, + AsyncRaiseExceptionProviderMock, + YieldNoneProviderMock +) + +DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}] + +class TestIterListProvider(unittest.IsolatedAsyncioTestCase): + + async def test_skip_provider(self): + client = AsyncClient(image_provider=IterListProvider([MissingAuthProviderMock, YieldImageResponseProviderMock], False)) + response = await client.images.generate("Hello", "", response_format="orginal") + self.assertIsInstance(response, ImagesResponse) + self.assertEqual("Hello", response.data[0].url) + + async def test_only_one_result(self): + client = AsyncClient(image_provider=IterListProvider([YieldImageResponseProviderMock, YieldImageResponseProviderMock], False)) + response = await client.images.generate("Hello", "", response_format="orginal") + self.assertIsInstance(response, ImagesResponse) + self.assertEqual("Hello", response.data[0].url) + + async def test_skip_none(self): + client = AsyncClient(image_provider=IterListProvider([YieldNoneProviderMock, YieldImageResponseProviderMock], False)) + response = await client.images.generate("Hello", "", response_format="orginal") + self.assertIsInstance(response, ImagesResponse) + self.assertEqual("Hello", response.data[0].url) + + def test_raise_exception(self): + async def run_exception(): + client = AsyncClient(image_provider=IterListProvider([YieldNoneProviderMock, AsyncRaiseExceptionProviderMock], False)) + await client.images.generate("Hello", "") + self.assertRaises(RuntimeError, asyncio.run, run_exception()) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/etc/unittest/mocks.py b/etc/unittest/mocks.py index c2058e34..c43d98cc 100644 --- a/etc/unittest/mocks.py +++ b/etc/unittest/mocks.py @@ -1,4 +1,6 @@ from g4f.providers.base_provider import AbstractProvider, AsyncProvider, AsyncGeneratorProvider +from g4f.image import ImageResponse +from g4f.errors import MissingAuthError class ProviderMock(AbstractProvider): working = True @@ -41,6 +43,25 @@ class YieldProviderMock(AsyncGeneratorProvider): for message in messages: yield message["content"] +class YieldImageResponseProviderMock(AsyncGeneratorProvider): + working = True + + @classmethod + async def create_async_generator( + cls, model, messages, stream, prompt: str, **kwargs + ): + yield ImageResponse(prompt, "") + +class MissingAuthProviderMock(AbstractProvider): + working = True + + @classmethod + def create_completion( + cls, model, messages, stream, **kwargs + ): + raise MissingAuthError(cls.__name__) + yield cls.__name__ + class RaiseExceptionProviderMock(AbstractProvider): working = True -- cgit v1.2.3