From c31f5435c43ede7847dae0f3ed007357e7ff198c Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Thu, 28 Nov 2024 17:46:46 +0100 Subject: Fix api with default providers, add unittests for RetryProvider --- etc/unittest/__main__.py | 3 ++- etc/unittest/mocks.py | 32 ++++++++++++++++++++-- etc/unittest/retry_provider.py | 60 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 3 deletions(-) create mode 100644 etc/unittest/retry_provider.py (limited to 'etc') diff --git a/etc/unittest/__main__.py b/etc/unittest/__main__.py index 0acc5865..e49dec30 100644 --- a/etc/unittest/__main__.py +++ b/etc/unittest/__main__.py @@ -6,5 +6,6 @@ from .main import * from .model import * from .client import * from .include import * +from .retry_provider import * -unittest.main() +unittest.main() \ No newline at end of file diff --git a/etc/unittest/mocks.py b/etc/unittest/mocks.py index 102730fa..c2058e34 100644 --- a/etc/unittest/mocks.py +++ b/etc/unittest/mocks.py @@ -34,9 +34,37 @@ class ModelProviderMock(AbstractProvider): class YieldProviderMock(AsyncGeneratorProvider): working = True - + async def create_async_generator( model, messages, stream, **kwargs ): for message in messages: - yield message["content"] \ No newline at end of file + yield message["content"] + +class RaiseExceptionProviderMock(AbstractProvider): + working = True + + @classmethod + def create_completion( + cls, model, messages, stream, **kwargs + ): + raise RuntimeError(cls.__name__) + yield cls.__name__ + +class AsyncRaiseExceptionProviderMock(AsyncGeneratorProvider): + working = True + + @classmethod + async def create_async_generator( + cls, model, messages, stream, **kwargs + ): + raise RuntimeError(cls.__name__) + yield cls.__name__ + +class YieldNoneProviderMock(AsyncGeneratorProvider): + working = True + + async def create_async_generator( + model, messages, stream, **kwargs + ): + yield None \ No newline at end of file diff --git a/etc/unittest/retry_provider.py b/etc/unittest/retry_provider.py new file mode 100644 index 00000000..6d41ef94 --- /dev/null +++ b/etc/unittest/retry_provider.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import unittest + +from g4f.client import AsyncClient, ChatCompletion, ChatCompletionChunk +from g4f.providers.retry_provider import IterListProvider +from .mocks import YieldProviderMock, RaiseExceptionProviderMock, AsyncRaiseExceptionProviderMock, YieldNoneProviderMock + +DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}] + +class TestIterListProvider(unittest.IsolatedAsyncioTestCase): + + async def test_skip_provider(self): + client = AsyncClient(provider=IterListProvider([RaiseExceptionProviderMock, YieldProviderMock], False)) + response = await client.chat.completions.create(DEFAULT_MESSAGES, "") + self.assertIsInstance(response, ChatCompletion) + self.assertEqual("Hello", response.choices[0].message.content) + + async def test_only_one_result(self): + client = AsyncClient(provider=IterListProvider([YieldProviderMock, YieldProviderMock])) + response = await client.chat.completions.create(DEFAULT_MESSAGES, "") + self.assertIsInstance(response, ChatCompletion) + self.assertEqual("Hello", response.choices[0].message.content) + + async def test_stream_skip_provider(self): + client = AsyncClient(provider=IterListProvider([AsyncRaiseExceptionProviderMock, YieldProviderMock], False)) + messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]] + response = client.chat.completions.create(messages, "Hello", stream=True) + async for chunk in response: + chunk: ChatCompletionChunk = chunk + self.assertIsInstance(chunk, ChatCompletionChunk) + if chunk.choices[0].delta.content is not None: + self.assertIsInstance(chunk.choices[0].delta.content, str) + + async def test_stream_only_one_result(self): + client = AsyncClient(provider=IterListProvider([YieldProviderMock, YieldProviderMock], False)) + messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You "]] + response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2) + response_list = [] + async for chunk in response: + response_list.append(chunk) + self.assertEqual(len(response_list), 3) + for chunk in response_list: + if chunk.choices[0].delta.content is not None: + self.assertEqual(chunk.choices[0].delta.content, "You ") + + async def test_skip_none(self): + client = AsyncClient(provider=IterListProvider([YieldNoneProviderMock, YieldProviderMock], False)) + response = await client.chat.completions.create(DEFAULT_MESSAGES, "") + self.assertIsInstance(response, ChatCompletion) + self.assertEqual("Hello", response.choices[0].message.content) + + async def test_stream_skip_none(self): + client = AsyncClient(provider=IterListProvider([YieldNoneProviderMock, YieldProviderMock], False)) + response = client.chat.completions.create(DEFAULT_MESSAGES, "", stream=True) + response_list = [chunk async for chunk in response] + self.assertEqual(len(response_list), 2) + for chunk in response_list: + if chunk.choices[0].delta.content is not None: + self.assertEqual(chunk.choices[0].delta.content, "Hello") \ No newline at end of file -- cgit v1.2.3