# Directory for storing generated images
images_dir = "./generated_images"


def get_image_extension(image: str) -> str:
    """Extract the image file extension from a URL or filename.

    Matches .jpg/.jpeg/.png/.webp case-insensitively at the end of the
    path component — a trailing query string or fragment is ignored — and
    falls back to ".jpg" when nothing recognizable is found.
    """
    # The (?=[?#]|$) lookahead lets "photo.png?w=100" resolve to ".png"
    # instead of hitting the ".jpg" fallback, while still refusing to
    # match an extension embedded in the middle of the path.
    match = re.search(r"\.(jpe?g|png|webp)(?=[?#]|$)", image, re.IGNORECASE)
    return f".{match.group(1).lower()}" if match else ".jpg"


def ensure_images_dir():
    """Create the images directory if it doesn't exist."""
    os.makedirs(images_dir, exist_ok=True)


def get_source_url(image: str, default: str = None) -> str:
    """Extract the original source URL from a "...url=..." parameter.

    Returns `default` when the parameter is absent or does not decode to
    an http(s) URL.
    """
    if "url=" in image:
        decoded_url = unquote(image.split("url=", 1)[1])
        if decoded_url.startswith(("http://", "https://")):
            return decoded_url
    return default


async def copy_images(
    images: list[str],
    cookies: Optional[Cookies] = None,
    headers: Optional[dict] = None,
    proxy: Optional[str] = None,
    alt: str = None,
    add_url: bool = True,
    target: str = None,
    ssl: bool = None
) -> list[str]:
    """
    Download and store images locally with Unicode-safe filenames.

    Returns a list of relative image URLs ("/images/<filename>", optionally
    suffixed with "?url=<encoded source>"). Any image that cannot be saved
    is represented by its original source URL instead.
    """
    if add_url:
        # Don't leak the source URL back to the client when the request
        # carried cookies.
        add_url = not cookies
    ensure_images_dir()

    async with ClientSession(
        connector=get_connector(proxy=proxy),
        cookies=cookies,
        headers=headers,
    ) as session:
        async def copy_image(image: str, target: str = None) -> str:
            """Process one image and return its local URL (or the source URL on failure)."""
            # NOTE(review): `target` is accepted for backward compatibility but
            # is currently ignored — filenames are always generated. Confirm
            # no caller still relies on an explicit target path.
            target_path = None
            try:
                # Filename components: timestamp + sanitized alt text + content hash.
                file_hash = hashlib.sha256(image.encode()).hexdigest()[:16]
                timestamp = int(time.time())

                if alt:
                    # Keep Unicode word characters, whitespace, dots and
                    # hyphens; replace everything else with "_", then collapse
                    # runs of whitespace/underscores and cap the length.
                    clean_alt = re.sub(
                        r'[^\w\s.-]',
                        '_',
                        unquote(alt).strip(),
                        flags=re.UNICODE
                    )
                    clean_alt = re.sub(r'[\s_]+', '_', clean_alt)[:100]
                else:
                    clean_alt = "image"

                extension = get_image_extension(image)
                filename = f"{timestamp}_{clean_alt}_{file_hash}{extension}"
                target_path = os.path.join(images_dir, filename)

                if image.startswith("data:"):
                    # Inline data URI: decode and write directly to disk.
                    with open(target_path, "wb") as f:
                        f.write(extract_data_uri(image))
                else:
                    # Apply BackendApi settings for its own URLs; caller-supplied
                    # headers still take precedence.
                    if BackendApi.working and image.startswith(BackendApi.url):
                        request_headers = BackendApi.headers if headers is None else headers
                        request_ssl = BackendApi.ssl
                    else:
                        request_headers = headers
                        request_ssl = ssl

                    async with session.get(image, ssl=request_ssl, headers=request_headers) as response:
                        response.raise_for_status()
                        with open(target_path, "wb") as f:
                            async for chunk in response.content.iter_chunked(4096):
                                f.write(chunk)

                # Safety net: if the generated name somehow lacks an extension,
                # sniff the format from the file header and rename accordingly.
                if not os.path.splitext(target_path)[1]:
                    with open(target_path, "rb") as f:
                        detected_type = is_accepted_format(f.read(12))
                    if detected_type:
                        new_ext = f".{detected_type.split('/')[-1]}"
                        os.rename(target_path, f"{target_path}{new_ext}")
                        target_path = f"{target_path}{new_ext}"

                # Percent-encode so Unicode filenames and source URLs survive
                # inside the returned URL.
                url_filename = quote(os.path.basename(target_path))
                suffix = f"?url={quote(image)}" if add_url and not image.startswith("data:") else ""
                return f"/images/{url_filename}{suffix}"

            except (ClientError, OSError) as e:
                # IOError is an alias of OSError since Python 3.3, so OSError
                # alone covers all file-system failures here.
                debug.log(f"Image processing failed: {e.__class__.__name__}: {e}")
                # Remove any partially written file before falling back.
                if target_path and os.path.exists(target_path):
                    os.unlink(target_path)
                return get_source_url(image, image)

        return await asyncio.gather(*[copy_image(img, target) for img in images])