From f028acfefebbda617d866ab09bfcbe574fd2be80 Mon Sep 17 00:00:00 2001 From: kqlio67 <> Date: Fri, 3 Jan 2025 12:31:49 +0200 Subject: Update g4f/Provider/DDG.py --- g4f/Provider/DDG.py | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/g4f/Provider/DDG.py b/g4f/Provider/DDG.py index f04b647d..fb29203d 100644 --- a/g4f/Provider/DDG.py +++ b/g4f/Provider/DDG.py @@ -3,9 +3,16 @@ from __future__ import annotations from aiohttp import ClientSession, ClientTimeout, ClientError import json from ..typing import AsyncResult, Messages -from .base_provider import AsyncGeneratorProvider, ProviderModelMixin +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, BaseConversation from .helper import format_prompt +class Conversation(BaseConversation): + vqd: str = None + message_history: Messages = [] + + def __init__(self, model: str): + self.model = model + class DDG(AsyncGeneratorProvider, ProviderModelMixin): label = "DuckDuckGo AI Chat" url = "https://duckduckgo.com/aichat" @@ -55,6 +62,8 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin): cls, model: str, messages: Messages, + conversation: Conversation = None, + return_conversation: bool = False, proxy: str = None, **kwargs ) -> AsyncResult: @@ -63,16 +72,30 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin): } async with ClientSession(headers=headers, timeout=ClientTimeout(total=30)) as session: # Fetch VQD token - vqd = await cls.fetch_vqd(session) - headers["x-vqd-4"] = vqd + if conversation is None: + conversation = Conversation(model) + + if conversation.vqd is None: + conversation.vqd = await cls.fetch_vqd(session) + + headers["x-vqd-4"] = conversation.vqd + + if return_conversation: + yield conversation + + if len(messages) >= 2: + conversation.message_history.extend([messages[-2], messages[-1]]) + elif len(messages) == 1: + conversation.message_history.append(messages[-1]) payload = { - "model": model, - "messages": [{"role": "user", "content": format_prompt(messages)}], + "model": conversation.model, + "messages": conversation.message_history, } try: async with session.post(cls.api_endpoint, headers=headers, json=payload, proxy=proxy) as response: + conversation.vqd = response.headers.get("x-vqd-4") response.raise_for_status() async for line in response.content: line = line.decode("utf-8").strip() -- cgit v1.2.3