refactor: replace try-except blocks with contextlib.suppress for cleaner exception handling (#24284)
commit 1abf1240b2
parent ad8e82ee1d
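For readers unfamiliar with the pattern, here is a minimal standalone sketch of the transformation this commit applies throughout the codebase: `contextlib.suppress` is a standard-library context manager that swallows the listed exception types, so a `try: ... except SomeError: pass` block can be flattened into a `with` block. Like the bare `except ...: pass` it replaces, control simply resumes after the suppressed block when the exception is raised. The `parse_or_default_*` helpers below are hypothetical illustrations, not code from this repository.

import contextlib
import json


def parse_or_default_old(raw: str, default: dict) -> dict:
    # Pre-refactor style: swallow the error and fall through.
    try:
        return json.loads(raw)
    except Exception:
        pass
    return default


def parse_or_default_new(raw: str, default: dict) -> dict:
    # Post-refactor style: contextlib.suppress(Exception) silences the same
    # errors; execution resumes after the with-block when json.loads() fails.
    with contextlib.suppress(Exception):
        return json.loads(raw)
    return default


if __name__ == "__main__":
    assert parse_or_default_old("not json", {}) == {}
    assert parse_or_default_new("not json", {}) == {}
    assert parse_or_default_new('{"a": 1}', {}) == {"a": 1}

One reviewing note: `suppress(Exception)` is exactly as broad as the `except Exception: pass` it replaces, while the hunks below that use `suppress(ValueError)` or `suppress(FileNotFoundError)` keep the original, narrower scope.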
@@ -1,3 +1,4 @@
+import contextlib
 import json
 import os
 import time
@@ -178,7 +179,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
 def cloud_utm_record(view):
     @wraps(view)
     def decorated(*args, **kwargs):
-        try:
+        with contextlib.suppress(Exception):
             features = FeatureService.get_features(current_user.current_tenant_id)
 
             if features.billing.enabled:
@@ -187,8 +188,7 @@ def cloud_utm_record(view):
                 if utm_info:
                     utm_info_dict: dict = json.loads(utm_info)
                     OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
-        except Exception as e:
-            pass
+
         return view(*args, **kwargs)
 
     return decorated

@@ -1,3 +1,4 @@
+import contextlib
 import re
 from collections.abc import Mapping
 from typing import Any, Optional
@@ -97,10 +98,8 @@ def parse_traceparent_header(traceparent: str) -> Optional[str]:
     Reference:
         W3C Trace Context Specification: https://www.w3.org/TR/trace-context/
     """
-    try:
+    with contextlib.suppress(Exception):
         parts = traceparent.split("-")
         if len(parts) == 4 and len(parts[1]) == 32:
             return parts[1]
-    except Exception:
-        pass
     return None

@@ -1,3 +1,4 @@
+import contextlib
 import json
 from collections import defaultdict
 from json import JSONDecodeError
@@ -624,14 +625,12 @@ class ProviderManager:
 
                 for variable in provider_credential_secret_variables:
                     if variable in provider_credentials:
-                        try:
+                        with contextlib.suppress(ValueError):
                             provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
                                 provider_credentials.get(variable) or "",  # type: ignore
                                 self.decoding_rsa_key,
                                 self.decoding_cipher_rsa,
                             )
-                        except ValueError:
-                            pass
 
                 # cache provider credentials
                 provider_credentials_cache.set(credentials=provider_credentials)
@@ -672,14 +671,12 @@ class ProviderManager:
 
                     for variable in model_credential_secret_variables:
                         if variable in provider_model_credentials:
-                            try:
+                            with contextlib.suppress(ValueError):
                                 provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
                                     provider_model_credentials.get(variable),
                                     self.decoding_rsa_key,
                                     self.decoding_cipher_rsa,
                                 )
-                            except ValueError:
-                                pass
 
                     # cache provider model credentials
                     provider_model_credentials_cache.set(credentials=provider_model_credentials)

@@ -1,3 +1,4 @@
+import contextlib
 import json
 import logging
 import queue
@@ -214,10 +215,8 @@ class ClickzettaConnectionPool:
                     return connection
                 else:
                     # Connection expired or invalid, close it
-                    try:
+                    with contextlib.suppress(Exception):
                         connection.close()
-                    except Exception:
-                        pass
 
         # No valid connection found, create new one
        return self._create_connection(config)
@@ -228,10 +227,8 @@ class ClickzettaConnectionPool:
 
         if config_key not in self._pool_locks:
             # Pool was cleaned up, just close the connection
-            try:
+            with contextlib.suppress(Exception):
                 connection.close()
-            except Exception:
-                pass
             return
 
         with self._pool_locks[config_key]:
@@ -243,10 +240,8 @@ class ClickzettaConnectionPool:
                 logger.debug("Returned ClickZetta connection to pool")
             else:
                 # Pool full or connection invalid, close it
-                try:
+                with contextlib.suppress(Exception):
                     connection.close()
-                except Exception:
-                    pass
 
     def _cleanup_expired_connections(self) -> None:
         """Clean up expired connections from all pools."""
@@ -265,10 +260,8 @@ class ClickzettaConnectionPool:
                     if current_time - last_used < self._connection_timeout:
                         valid_connections.append((connection, last_used))
                     else:
-                        try:
+                        with contextlib.suppress(Exception):
                             connection.close()
-                        except Exception:
-                            pass
 
                 self._pools[config_key] = valid_connections
 
@@ -299,10 +292,8 @@ class ClickzettaConnectionPool:
             with self._pool_locks[config_key]:
                 pool = self._pools[config_key]
                 for connection, _ in pool:
-                    try:
+                    with contextlib.suppress(Exception):
                         connection.close()
-                    except Exception:
-                        pass
                 pool.clear()
 
 

@@ -1,5 +1,6 @@
 """Abstract interface for document loader implementations."""
 
+import contextlib
 from collections.abc import Iterator
 from typing import Optional, cast
 
@@ -25,12 +26,10 @@ class PdfExtractor(BaseExtractor):
     def extract(self) -> list[Document]:
         plaintext_file_exists = False
         if self._file_cache_key:
-            try:
+            with contextlib.suppress(FileNotFoundError):
                 text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
                 plaintext_file_exists = True
                 return [Document(page_content=text)]
-            except FileNotFoundError:
-                pass
         documents = list(self.load())
         text_list = []
         for document in documents:

@@ -1,4 +1,5 @@
 import base64
+import contextlib
 import logging
 from typing import Optional
 
@@ -33,7 +34,7 @@ class UnstructuredEmailExtractor(BaseExtractor):
         elements = partition_email(filename=self._file_path)
 
         # noinspection PyBroadException
-        try:
+        with contextlib.suppress(Exception):
             for element in elements:
                 element_text = element.text.strip()
 
@@ -43,8 +44,6 @@ class UnstructuredEmailExtractor(BaseExtractor):
                     element_decode = base64.b64decode(element_text)
                     soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser")
                     element.text = soup.get_text()
-        except Exception:
-            pass
 
         from unstructured.chunking.title import chunk_by_title
 

@@ -1,4 +1,5 @@
 import base64
+import contextlib
 import enum
 from collections.abc import Mapping
 from enum import Enum
@@ -227,10 +228,8 @@ class ToolInvokeMessage(BaseModel):
     @classmethod
     def decode_blob_message(cls, v):
         if isinstance(v, dict) and "blob" in v:
-            try:
+            with contextlib.suppress(Exception):
                 v["blob"] = base64.b64decode(v["blob"])
-            except Exception:
-                pass
         return v
 
     @field_serializer("message")

@@ -1,3 +1,4 @@
+import contextlib
 import json
 from collections.abc import Generator, Iterable
 from copy import deepcopy
@@ -69,10 +70,8 @@ class ToolEngine:
             if parameters and len(parameters) == 1:
                 tool_parameters = {parameters[0].name: tool_parameters}
             else:
-                try:
+                with contextlib.suppress(Exception):
                     tool_parameters = json.loads(tool_parameters)
-                except Exception:
-                    pass
                 if not isinstance(tool_parameters, dict):
                     raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
 
@@ -270,14 +269,12 @@ class ToolEngine:
                 if response.meta.get("mime_type"):
                     mimetype = response.meta.get("mime_type")
                 else:
-                    try:
+                    with contextlib.suppress(Exception):
                         url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
                         extension = url.suffix
                         guess_type_result, _ = guess_type(f"a{extension}")
                         if guess_type_result:
                             mimetype = guess_type_result
-                    except Exception:
-                        pass
 
                 if not mimetype:
                     mimetype = "image/jpeg"

@@ -1,3 +1,4 @@
+import contextlib
 from copy import deepcopy
 from typing import Any
 
@@ -137,11 +138,9 @@ class ToolParameterConfigurationManager:
                 and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
             ):
                 if parameter.name in parameters:
-                    try:
-                        has_secret_input = True
+                    has_secret_input = True
+                    with contextlib.suppress(Exception):
                         parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
-                    except Exception:
-                        pass
 
         if has_secret_input:
             cache.set(parameters)

@@ -1,3 +1,4 @@
+import contextlib
 from copy import deepcopy
 from typing import Any, Optional, Protocol
 
@@ -111,14 +112,12 @@ class ProviderConfigEncrypter:
         for field_name, field in fields.items():
             if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                 if field_name in data:
-                    try:
+                    with contextlib.suppress(Exception):
                         # if the value is None or empty string, skip decrypt
                         if not data[field_name]:
                             continue
 
                         data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
-                    except Exception:
-                        pass
 
         self.provider_config_cache.set(data)
         return data

@@ -1,3 +1,4 @@
+import contextlib
 import json
 import logging
 import uuid
@@ -666,10 +667,8 @@ class ParameterExtractorNode(BaseNode):
             if result[idx] == "{" or result[idx] == "[":
                 json_str = extract_json(result[idx:])
                 if json_str:
-                    try:
+                    with contextlib.suppress(Exception):
                         return cast(dict, json.loads(json_str))
-                    except Exception:
-                        pass
         logger.info("extra error: %s", result)
         return None
 
@@ -686,10 +685,9 @@ class ParameterExtractorNode(BaseNode):
             if result[idx] == "{" or result[idx] == "[":
                 json_str = extract_json(result[idx:])
                 if json_str:
-                    try:
+                    with contextlib.suppress(Exception):
                         return cast(dict, json.loads(json_str))
-                    except Exception:
-                        pass
+
         logger.info("extra error: %s", result)
         return None
 

@@ -1,3 +1,4 @@
+import contextlib
 import logging
 import time
 
@@ -38,12 +39,11 @@ def handle(sender, **kwargs):
        db.session.add(document)
    db.session.commit()
 
-    try:
-        indexing_runner = IndexingRunner()
-        indexing_runner.run(documents)
-        end_at = time.perf_counter()
-        logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
-    except DocumentIsPausedError as ex:
-        logging.info(click.style(str(ex), fg="yellow"))
-    except Exception:
-        pass
+    with contextlib.suppress(Exception):
+        try:
+            indexing_runner = IndexingRunner()
+            indexing_runner.run(documents)
+            end_at = time.perf_counter()
+            logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
+        except DocumentIsPausedError as ex:
+            logging.info(click.style(str(ex), fg="yellow"))

@@ -1,4 +1,5 @@
 import atexit
+import contextlib
 import logging
 import os
 import platform
@@ -106,7 +107,7 @@ def init_app(app: DifyApp):
         """Custom logging handler that creates spans for logging.exception() calls"""
 
         def emit(self, record: logging.LogRecord):
-            try:
+            with contextlib.suppress(Exception):
                 if record.exc_info:
                     tracer = get_tracer_provider().get_tracer("dify.exception.logging")
                     with tracer.start_as_current_span(
@@ -126,9 +127,6 @@ def init_app(app: DifyApp):
                         if record.exc_info[0]:
                             span.set_attribute("exception.type", record.exc_info[0].__name__)
 
-            except Exception:
-                pass
-
     from opentelemetry import trace
     from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
    from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter

@@ -1,3 +1,4 @@
+import contextlib
 from collections.abc import Callable, Sequence
 from typing import Any, Optional, Union
 
@@ -142,13 +143,11 @@ class ConversationService:
            raise MessageNotExistsError()
 
        # generate conversation name
-        try:
+        with contextlib.suppress(Exception):
            name = LLMGenerator.generate_conversation_name(
                app_model.tenant_id, message.query, conversation.id, app_model.id
            )
            conversation.name = name
-        except Exception:
-            pass
 
        db.session.commit()
 

@@ -1,3 +1,4 @@
+import contextlib
 import os
 
 import pytest
@@ -44,10 +45,8 @@ class TestClickzettaVector(AbstractVectorTest):
        yield vector
 
        # Cleanup: delete the test collection
-        try:
+        with contextlib.suppress(Exception):
            vector.delete()
-        except Exception:
-            pass
 
    def test_clickzetta_vector_basic_operations(self, vector_store):
        """Test basic CRUD operations on Clickzetta vector store."""

@@ -1,3 +1,4 @@
+import contextlib
 import json
 import queue
 import threading
@@ -124,13 +125,10 @@ def test_sse_client_connection_validation():
    mock_event_source.iter_sse.return_value = [endpoint_event]
 
    # Test connection
-    try:
+    with contextlib.suppress(Exception):
        with sse_client(test_url) as (read_queue, write_queue):
            assert read_queue is not None
            assert write_queue is not None
-    except Exception as e:
-        # Connection might fail due to mocking, but we're testing the validation logic
-        pass
 
 
 def test_sse_client_error_handling():
@@ -178,7 +176,7 @@ def test_sse_client_timeout_configuration():
        mock_event_source.iter_sse.return_value = []
        mock_sse_connect.return_value.__enter__.return_value = mock_event_source
 
-        try:
+        with contextlib.suppress(Exception):
            with sse_client(
                test_url, headers=custom_headers, timeout=custom_timeout, sse_read_timeout=custom_sse_timeout
            ) as (read_queue, write_queue):
@@ -190,9 +188,6 @@ def test_sse_client_timeout_configuration():
                assert call_args is not None
                timeout_arg = call_args[1]["timeout"]
                assert timeout_arg.read == custom_sse_timeout
-        except Exception:
-            # Connection might fail due to mocking, but we tested the configuration
-            pass
 
 
 def test_sse_transport_endpoint_validation():
@@ -251,12 +246,10 @@ def test_sse_client_queue_cleanup():
        # Mock connection that raises an exception
        mock_sse_connect.side_effect = Exception("Connection failed")
 
-        try:
+        with contextlib.suppress(Exception):
            with sse_client(test_url) as (rq, wq):
                read_queue = rq
                write_queue = wq
-        except Exception:
-            pass  # Expected to fail
 
        # Queues should be cleaned up even on exception
        # Note: In real implementation, cleanup should put None to signal shutdown
@@ -283,11 +276,9 @@ def test_sse_client_headers_propagation():
        mock_event_source.iter_sse.return_value = []
        mock_sse_connect.return_value.__enter__.return_value = mock_event_source
 
-        try:
+        with contextlib.suppress(Exception):
            with sse_client(test_url, headers=custom_headers):
                pass
-        except Exception:
-            pass  # Expected due to mocking
 
        # Verify headers were passed to client factory
        mock_client_factory.assert_called_with(headers=custom_headers)