From 19c318e3910ae0ddccff06f5c58660003ed483ce Mon Sep 17 00:00:00 2001 From: hlohaus <983577+hlohaus@users.noreply.github.com> Date: Mon, 10 Feb 2025 20:58:37 +0100 Subject: Cache pipline tag in HuggingFaceAPI --- g4f/Provider/DDG.py | 4 +--- g4f/Provider/hf/HuggingFaceAPI.py | 43 ++++++++++++++++++++++----------------- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/g4f/Provider/DDG.py b/g4f/Provider/DDG.py index 2ac786be..0a08c936 100644 --- a/g4f/Provider/DDG.py +++ b/g4f/Provider/DDG.py @@ -173,12 +173,10 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin): n: c.value for n, c in session.cookie_jar.filter_cookies(cls.url).items() } + yield conversation if reason is not None: yield FinishReason(reason) - - if return_conversation: - yield conversation except asyncio.TimeoutError as e: raise TimeoutError(f"Request timed out: {str(e)}") \ No newline at end of file diff --git a/g4f/Provider/hf/HuggingFaceAPI.py b/g4f/Provider/hf/HuggingFaceAPI.py index ee737fed..cc0f794d 100644 --- a/g4f/Provider/hf/HuggingFaceAPI.py +++ b/g4f/Provider/hf/HuggingFaceAPI.py @@ -22,6 +22,8 @@ class HuggingFaceAPI(OpenaiTemplate): vision_models = vision_models model_aliases = model_aliases + pipeline_tag: dict[str, str] = {} + @classmethod def get_models(cls, **kwargs): if not cls.models: @@ -32,6 +34,20 @@ class HuggingFaceAPI(OpenaiTemplate): cls.models.append(model) return cls.models + @classmethod + async def get_pipline_tag(cls, model: str, api_key: str = None): + if model in cls.pipeline_tag: + return cls.pipeline_tag[model] + async with StreamSession( + timeout=30, + headers=cls.get_headers(False, api_key), + ) as session: + async with session.get(f"https://huggingface.co/api/models/{model}") as response: + await raise_for_status(response) + model_data = await response.json() + cls.pipeline_tag[model] = model_data.get("pipeline_tag") + return cls.pipeline_tag[model] + @classmethod async def create_async_generator( cls, @@ -44,25 +60,14 @@ class HuggingFaceAPI(OpenaiTemplate): images: ImagesType = None, **kwargs ): - if api_base is None: - model_name = model - if model in cls.model_aliases: - model_name = cls.model_aliases[model] - api_base = f"https://api-inference.huggingface.co/models/{model_name}/v1" - async with StreamSession( - timeout=30, - headers=cls.get_headers(False, api_key), - ) as session: - async with session.get(f"https://huggingface.co/api/models/{model}") as response: - if response.status == 404: - raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}") - await raise_for_status(response) - model_data = await response.json() - pipeline_tag = model_data.get("pipeline_tag") - if images is None and pipeline_tag not in ("text-generation", "image-text-to-text"): - raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}") - elif pipeline_tag != "image-text-to-text": - raise ModelNotSupportedError(f"Model does not support images: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}") + if model in cls.model_aliases: + model = cls.model_aliases[model] + api_base = f"https://api-inference.huggingface.co/models/{model}/v1" + pipeline_tag = await cls.get_pipline_tag(model, api_key) + if images is None and pipeline_tag not in ("text-generation", "image-text-to-text"): + raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}") + elif pipeline_tag != "image-text-to-text": + raise ModelNotSupportedError(f"Model does not support images: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}") start = calculate_lenght(messages) if start > max_inputs_lenght: if len(messages) > 6: -- cgit v1.2.3