From 7b45a5d4525a3745f2dac1a11b26c88629413aea Mon Sep 17 00:00:00 2001 From: Weishan-0 <29089406+Weishan-0@users.noreply.github.com> Date: Thu, 18 Jul 2024 19:37:04 +0800 Subject: [PATCH] fix: Unable to display images generated by Dall-E 3 (#6155) --- .../provider/builtin/dalle/tools/dalle3.py | 47 +++++++++++++++++-- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index 61609947f..f985deade 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -1,5 +1,5 @@ +import base64 import random -from base64 import b64decode from typing import Any, Union from openai import OpenAI @@ -69,11 +69,50 @@ class DallE3Tool(BuiltinTool): result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value)) + mime_type, blob_image = DallE3Tool._decode_image(image.b64_json) + blob_message = self.create_blob_message(blob=blob_image, + meta={'mime_type': mime_type}, + save_as=self.VARIABLE_KEY.IMAGE.value) + result.append(blob_message) return result + @staticmethod + def _decode_image(base64_image: str) -> tuple[str, bytes]: + """ + Decode a base64 encoded image. If the image is not prefixed with a MIME type, + it assumes 'image/png' as the default. + + :param base64_image: Base64 encoded image string + :return: A tuple containing the MIME type and the decoded image bytes + """ + if DallE3Tool._is_plain_base64(base64_image): + return 'image/png', base64.b64decode(base64_image) + else: + return DallE3Tool._extract_mime_and_data(base64_image) + + @staticmethod + def _is_plain_base64(encoded_str: str) -> bool: + """ + Check if the given encoded string is plain base64 without a MIME type prefix. + + :param encoded_str: Base64 encoded image string + :return: True if the string is plain base64, False otherwise + """ + return not encoded_str.startswith('data:image') + + @staticmethod + def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]: + """ + Extract MIME type and image data from a base64 encoded string with a MIME type prefix. + + :param encoded_str: Base64 encoded image string with MIME type prefix + :return: A tuple containing the MIME type and the decoded image bytes + """ + mime_type = encoded_str.split(';')[0].split(':')[1] + image_data_base64 = encoded_str.split(',')[1] + decoded_data = base64.b64decode(image_data_base64) + return mime_type, decoded_data + @staticmethod def _generate_random_id(length=8): characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'