feat(workflow_cycle_manager): Removes redundant repository methods and adds caching (#22597)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in: parent 3826b57424, commit b88dd17fc1
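The diff below replaces repository read-backs with two in-memory caches owned by WorkflowCycleManager for the duration of one workflow run: saves write through to the repository and also record the domain model, and later lookups hit the dict instead of the database. A minimal sketch of the pattern, using simplified stand-in types rather than the actual Dify classes:

    from typing import Protocol


    class Execution(Protocol):
        id_: str


    class Repository(Protocol):
        def save(self, execution: Execution) -> None: ...


    class CycleCache:
        """Write-through cache scoped to a single workflow cycle (sketch)."""

        def __init__(self, repository: Repository) -> None:
            self._repository = repository
            self._cache: dict[str, Execution] = {}

        def save_and_cache(self, execution: Execution) -> Execution:
            self._repository.save(execution)        # write through to persistence
            self._cache[execution.id_] = execution  # keep the domain model in memory
            return execution

        def get_or_raise(self, execution_id: str) -> Execution:
            if execution_id in self._cache:         # reads never touch the database
                return self._cache[execution_id]
            raise KeyError(execution_id)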
@@ -6,7 +6,6 @@ import json
 import logging
 from typing import Optional, Union

-from sqlalchemy import select
 from sqlalchemy.engine import Engine
 from sqlalchemy.orm import sessionmaker

@@ -206,44 +205,3 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
         # Update the in-memory cache for faster subsequent lookups
         logger.debug(f"Updating cache for execution_id: {db_model.id}")
         self._execution_cache[db_model.id] = db_model
-
-    def get(self, execution_id: str) -> Optional[WorkflowExecution]:
-        """
-        Retrieve a WorkflowExecution by its ID.
-
-        First checks the in-memory cache, and if not found, queries the database.
-        If found in the database, adds it to the cache for future lookups.
-
-        Args:
-            execution_id: The workflow execution ID
-
-        Returns:
-            The WorkflowExecution instance if found, None otherwise
-        """
-        # First check the cache
-        if execution_id in self._execution_cache:
-            logger.debug(f"Cache hit for execution_id: {execution_id}")
-            # Convert cached DB model to domain model
-            cached_db_model = self._execution_cache[execution_id]
-            return self._to_domain_model(cached_db_model)
-
-        # If not in cache, query the database
-        logger.debug(f"Cache miss for execution_id: {execution_id}, querying database")
-        with self._session_factory() as session:
-            stmt = select(WorkflowRun).where(
-                WorkflowRun.id == execution_id,
-                WorkflowRun.tenant_id == self._tenant_id,
-            )
-
-            if self._app_id:
-                stmt = stmt.where(WorkflowRun.app_id == self._app_id)
-
-            db_model = session.scalar(stmt)
-            if db_model:
-                # Add DB model to cache
-                self._execution_cache[execution_id] = db_model
-
-                # Convert to domain model and return
-                return self._to_domain_model(db_model)
-
-            return None
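For contrast, the get() removed above was a classic read-through cache: check the dict, fall back to a tenant-scoped SELECT, and populate the dict on a hit. A condensed sketch of that removed path (simplified; the real method also filtered by app_id and logged cache hits and misses):

    def get(self, execution_id):
        if execution_id in self._execution_cache:                # cache hit
            return self._to_domain_model(self._execution_cache[execution_id])
        with self._session_factory() as session:                 # cache miss -> DB
            stmt = select(WorkflowRun).where(
                WorkflowRun.id == execution_id,
                WorkflowRun.tenant_id == self._tenant_id,
            )
            db_model = session.scalar(stmt)
            if db_model:
                self._execution_cache[execution_id] = db_model   # populate cache
                return self._to_domain_model(db_model)
            return None

Note that the surviving save path above still writes to _execution_cache, even though no method reads it back after this change.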
@@ -7,7 +7,7 @@ import logging
 from collections.abc import Sequence
 from typing import Optional, Union

-from sqlalchemy import UnaryExpression, asc, delete, desc, select
+from sqlalchemy import UnaryExpression, asc, desc, select
 from sqlalchemy.engine import Engine
 from sqlalchemy.orm import sessionmaker

@@ -218,47 +218,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
         logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
         self._node_execution_cache[db_model.node_execution_id] = db_model
-
-    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
-        """
-        Retrieve a NodeExecution by its node_execution_id.
-
-        First checks the in-memory cache, and if not found, queries the database.
-        If found in the database, adds it to the cache for future lookups.
-
-        Args:
-            node_execution_id: The node execution ID
-
-        Returns:
-            The NodeExecution instance if found, None otherwise
-        """
-        # First check the cache
-        if node_execution_id in self._node_execution_cache:
-            logger.debug(f"Cache hit for node_execution_id: {node_execution_id}")
-            # Convert cached DB model to domain model
-            cached_db_model = self._node_execution_cache[node_execution_id]
-            return self._to_domain_model(cached_db_model)
-
-        # If not in cache, query the database
-        logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
-        with self._session_factory() as session:
-            stmt = select(WorkflowNodeExecutionModel).where(
-                WorkflowNodeExecutionModel.node_execution_id == node_execution_id,
-                WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
-            )
-
-            if self._app_id:
-                stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
-
-            db_model = session.scalar(stmt)
-            if db_model:
-                # Add DB model to cache
-                self._node_execution_cache[node_execution_id] = db_model
-
-                # Convert to domain model and return
-                return self._to_domain_model(db_model)
-
-            return None

     def get_db_models_by_workflow_run(
         self,
         workflow_run_id: str,
@@ -344,68 +303,3 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
             domain_models.append(domain_model)

         return domain_models
-
-    def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
-        """
-        Retrieve all running NodeExecution instances for a specific workflow run.
-
-        This method queries the database directly and updates the cache with any
-        retrieved executions that have a node_execution_id.
-
-        Args:
-            workflow_run_id: The workflow run ID
-
-        Returns:
-            A list of running NodeExecution instances
-        """
-        with self._session_factory() as session:
-            stmt = select(WorkflowNodeExecutionModel).where(
-                WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
-                WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
-                WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING,
-                WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
-            )
-
-            if self._app_id:
-                stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
-
-            db_models = session.scalars(stmt).all()
-            domain_models = []
-
-            for model in db_models:
-                # Update cache if node_execution_id is present
-                if model.node_execution_id:
-                    self._node_execution_cache[model.node_execution_id] = model
-
-                # Convert to domain model
-                domain_model = self._to_domain_model(model)
-                domain_models.append(domain_model)
-
-            return domain_models
-
-    def clear(self) -> None:
-        """
-        Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
-
-        This method deletes all WorkflowNodeExecution records that match the tenant_id
-        and app_id (if provided) associated with this repository instance.
-        It also clears the in-memory cache.
-        """
-        with self._session_factory() as session:
-            stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id)
-
-            if self._app_id:
-                stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
-
-            result = session.execute(stmt)
-            session.commit()
-
-            deleted_count = result.rowcount
-            logger.info(
-                f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
-                + (f" and app {self._app_id}" if self._app_id else "")
-            )
-
-            # Clear the in-memory cache
-            self._node_execution_cache.clear()
-            logger.info("Cleared in-memory node execution cache")
@@ -1,4 +1,4 @@
-from typing import Optional, Protocol
+from typing import Protocol

 from core.workflow.entities.workflow_execution import WorkflowExecution

@@ -28,15 +28,3 @@ class WorkflowExecutionRepository(Protocol):
             execution: The WorkflowExecution instance to save or update
         """
         ...
-
-    def get(self, execution_id: str) -> Optional[WorkflowExecution]:
-        """
-        Retrieve a WorkflowExecution by its ID.
-
-        Args:
-            execution_id: The workflow execution ID
-
-        Returns:
-            The WorkflowExecution instance if found, None otherwise
-        """
-        ...
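With get() dropped, the protocol now only requires save(). Because this is a structural typing Protocol, any object exposing a compatible save() satisfies it without inheriting from the class; a hypothetical in-memory stand-in that would now type-check:

    from core.workflow.entities.workflow_execution import WorkflowExecution


    class InMemoryWorkflowExecutionRepository:
        """Hypothetical test double; satisfies the protocol structurally."""

        def __init__(self) -> None:
            self.saved: list[WorkflowExecution] = []

        def save(self, execution: WorkflowExecution) -> None:
            self.saved.append(execution)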
@@ -39,18 +39,6 @@ class WorkflowNodeExecutionRepository(Protocol):
         """
         ...
-
-    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
-        """
-        Retrieve a NodeExecution by its node_execution_id.
-
-        Args:
-            node_execution_id: The node execution ID
-
-        Returns:
-            The NodeExecution instance if found, None otherwise
-        """
-        ...

     def get_by_workflow_run(
         self,
         workflow_run_id: str,
@@ -69,24 +57,3 @@ class WorkflowNodeExecutionRepository(Protocol):
             A list of NodeExecution instances
         """
         ...
-
-    def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
-        """
-        Retrieve all running NodeExecution instances for a specific workflow run.
-
-        Args:
-            workflow_run_id: The workflow run ID
-
-        Returns:
-            A list of running NodeExecution instances
-        """
-        ...
-
-    def clear(self) -> None:
-        """
-        Clear all NodeExecution records based on implementation-specific criteria.
-
-        This method is intended to be used for bulk deletion operations, such as removing
-        all records associated with a specific app_id and tenant_id in multi-tenant implementations.
-        """
-        ...
@@ -55,24 +55,15 @@ class WorkflowCycleManager:
         self._workflow_execution_repository = workflow_execution_repository
         self._workflow_node_execution_repository = workflow_node_execution_repository

+        # Initialize caches for workflow execution cycle
+        # These caches avoid redundant repository calls during a single workflow execution
+        self._workflow_execution_cache: dict[str, WorkflowExecution] = {}
+        self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
+
     def handle_workflow_run_start(self) -> WorkflowExecution:
-        inputs = {**self._application_generate_entity.inputs}
+        inputs = self._prepare_workflow_inputs()
+        execution_id = self._get_or_generate_execution_id()

-        # Iterate over SystemVariable fields using Pydantic's model_fields
-        if self._workflow_system_variables:
-            for field_name, value in self._workflow_system_variables.to_dict().items():
-                if field_name == SystemVariableKey.CONVERSATION_ID:
-                    continue
-                inputs[f"sys.{field_name}"] = value
-
-        # handle special values
-        inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
-
-        # init workflow run
-        # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
-        execution_id = str(
-            self._workflow_system_variables.workflow_execution_id if self._workflow_system_variables else None
-        ) or str(uuid4())
         execution = WorkflowExecution.new(
             id_=execution_id,
             workflow_id=self._workflow_info.workflow_id,
@@ -83,9 +74,7 @@ class WorkflowCycleManager:
             started_at=datetime.now(UTC).replace(tzinfo=None),
         )

-        self._workflow_execution_repository.save(execution)
-
-        return execution
+        return self._save_and_cache_workflow_execution(execution)

     def handle_workflow_run_success(
         self,
@@ -99,23 +88,15 @@ class WorkflowCycleManager:
     ) -> WorkflowExecution:
         workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)

-        # outputs = WorkflowEntry.handle_special_values(outputs)
-
-        workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED
-        workflow_execution.outputs = outputs or {}
-        workflow_execution.total_tokens = total_tokens
-        workflow_execution.total_steps = total_steps
-        workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
+        self._update_workflow_execution_completion(
+            workflow_execution,
+            status=WorkflowExecutionStatus.SUCCEEDED,
+            outputs=outputs,
+            total_tokens=total_tokens,
+            total_steps=total_steps,
+        )

-        if trace_manager:
-            trace_manager.add_trace_task(
-                TraceTask(
-                    TraceTaskName.WORKFLOW_TRACE,
-                    workflow_execution=workflow_execution,
-                    conversation_id=conversation_id,
-                    user_id=trace_manager.user_id,
-                )
-            )
+        self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id)

         self._workflow_execution_repository.save(workflow_execution)
         return workflow_execution
@@ -132,24 +113,17 @@ class WorkflowCycleManager:
         trace_manager: Optional[TraceQueueManager] = None,
     ) -> WorkflowExecution:
         execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
-        # outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)

-        execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
-        execution.outputs = outputs or {}
-        execution.total_tokens = total_tokens
-        execution.total_steps = total_steps
-        execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
-        execution.exceptions_count = exceptions_count
+        self._update_workflow_execution_completion(
+            execution,
+            status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
+            outputs=outputs,
+            total_tokens=total_tokens,
+            total_steps=total_steps,
+            exceptions_count=exceptions_count,
+        )

-        if trace_manager:
-            trace_manager.add_trace_task(
-                TraceTask(
-                    TraceTaskName.WORKFLOW_TRACE,
-                    workflow_execution=execution,
-                    conversation_id=conversation_id,
-                    user_id=trace_manager.user_id,
-                )
-            )
+        self._add_trace_task_if_needed(trace_manager, execution, conversation_id)

         self._workflow_execution_repository.save(execution)
         return execution
@@ -169,39 +143,18 @@ class WorkflowCycleManager:
         workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
         now = naive_utc_now()

-        workflow_execution.status = WorkflowExecutionStatus(status.value)
-        workflow_execution.error_message = error_message
-        workflow_execution.total_tokens = total_tokens
-        workflow_execution.total_steps = total_steps
-        workflow_execution.finished_at = now
-        workflow_execution.exceptions_count = exceptions_count
-
-        # Use the instance repository to find running executions for a workflow run
-        running_node_executions = self._workflow_node_execution_repository.get_running_executions(
-            workflow_run_id=workflow_execution.id_
-        )
+        self._update_workflow_execution_completion(
+            workflow_execution,
+            status=status,
+            total_tokens=total_tokens,
+            total_steps=total_steps,
+            error_message=error_message,
+            exceptions_count=exceptions_count,
+            finished_at=now,
+        )

-        # Update the domain models
-        for node_execution in running_node_executions:
-            if node_execution.node_execution_id:
-                # Update the domain model
-                node_execution.status = WorkflowNodeExecutionStatus.FAILED
-                node_execution.error = error_message
-                node_execution.finished_at = now
-                node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
-
-                # Update the repository with the domain model
-                self._workflow_node_execution_repository.save(node_execution)
-
-        if trace_manager:
-            trace_manager.add_trace_task(
-                TraceTask(
-                    TraceTaskName.WORKFLOW_TRACE,
-                    workflow_execution=workflow_execution,
-                    conversation_id=conversation_id,
-                    user_id=trace_manager.user_id,
-                )
-            )
+        self._fail_running_node_executions(workflow_execution.id_, error_message, now)
+        self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id)

         self._workflow_execution_repository.save(workflow_execution)
         return workflow_execution
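One behavioral nuance in this hunk: the old code asked the repository for all RUNNING node executions of the run via a database query, while _fail_running_node_executions (added further below) scans only the manager's own _node_execution_cache, so only node executions observed by this manager instance get failed. Within a single-process workflow cycle the two coincide; a side-by-side sketch under that assumption (repo, manager, and run_id are assumed handles):

    # Before: tenant-scoped database query through the repository.
    running = repo.get_running_executions(workflow_run_id=run_id)

    # After: process-local scan of the manager's cache.
    running = [
        n for n in manager._node_execution_cache.values()
        if n.workflow_execution_id == run_id
        and n.status == WorkflowNodeExecutionStatus.RUNNING
    ]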
@@ -214,8 +167,198 @@ class WorkflowCycleManager:
     ) -> WorkflowNodeExecution:
         workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)

-        # Create a domain model
-        created_at = datetime.now(UTC).replace(tzinfo=None)
+        domain_execution = self._create_node_execution_from_event(
+            workflow_execution=workflow_execution,
+            event=event,
+            status=WorkflowNodeExecutionStatus.RUNNING,
+        )
+
+        return self._save_and_cache_node_execution(domain_execution)
+
+    def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
+        domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
+
+        self._update_node_execution_completion(
+            domain_execution,
+            event=event,
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+        )
+
+        self._workflow_node_execution_repository.save(domain_execution)
+        return domain_execution
+
+    def handle_workflow_node_execution_failed(
+        self,
+        *,
+        event: QueueNodeFailedEvent
+        | QueueNodeInIterationFailedEvent
+        | QueueNodeInLoopFailedEvent
+        | QueueNodeExceptionEvent,
+    ) -> WorkflowNodeExecution:
+        """
+        Workflow node execution failed
+        :param event: queue node failed event
+        :return:
+        """
+        domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
+
+        status = (
+            WorkflowNodeExecutionStatus.EXCEPTION
+            if isinstance(event, QueueNodeExceptionEvent)
+            else WorkflowNodeExecutionStatus.FAILED
+        )
+
+        self._update_node_execution_completion(
+            domain_execution,
+            event=event,
+            status=status,
+            error=event.error,
+            handle_special_values=True,
+        )
+
+        self._workflow_node_execution_repository.save(domain_execution)
+        return domain_execution
+
+    def handle_workflow_node_execution_retried(
+        self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
+    ) -> WorkflowNodeExecution:
+        workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
+
+        domain_execution = self._create_node_execution_from_event(
+            workflow_execution=workflow_execution,
+            event=event,
+            status=WorkflowNodeExecutionStatus.RETRY,
+            error=event.error,
+            created_at=event.start_at,
+        )
+
+        # Handle inputs and outputs
+        inputs = WorkflowEntry.handle_special_values(event.inputs)
+        outputs = event.outputs
+        metadata = self._merge_event_metadata(event)
+
+        domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata)
+
+        return self._save_and_cache_node_execution(domain_execution)
+
+    def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
+        # Check cache first
+        if id in self._workflow_execution_cache:
+            return self._workflow_execution_cache[id]
+
+        raise WorkflowRunNotFoundError(id)
+
+    def _prepare_workflow_inputs(self) -> dict[str, Any]:
+        """Prepare workflow inputs by merging application inputs with system variables."""
+        inputs = {**self._application_generate_entity.inputs}
+
+        if self._workflow_system_variables:
+            for field_name, value in self._workflow_system_variables.to_dict().items():
+                if field_name != SystemVariableKey.CONVERSATION_ID:
+                    inputs[f"sys.{field_name}"] = value
+
+        return dict(WorkflowEntry.handle_special_values(inputs) or {})
+
+    def _get_or_generate_execution_id(self) -> str:
+        """Get execution ID from system variables or generate a new one."""
+        if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id:
+            return str(self._workflow_system_variables.workflow_execution_id)
+        return str(uuid4())
+
+    def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution:
+        """Save workflow execution to repository and cache it."""
+        self._workflow_execution_repository.save(execution)
+        self._workflow_execution_cache[execution.id_] = execution
+        return execution
+
+    def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution:
+        """Save node execution to repository and cache it if it has an ID."""
+        self._workflow_node_execution_repository.save(execution)
+        if execution.node_execution_id:
+            self._node_execution_cache[execution.node_execution_id] = execution
+        return execution
+
+    def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution:
+        """Get node execution from cache or raise error if not found."""
+        domain_execution = self._node_execution_cache.get(node_execution_id)
+        if not domain_execution:
+            raise ValueError(f"Domain node execution not found: {node_execution_id}")
+        return domain_execution
+
+    def _update_workflow_execution_completion(
+        self,
+        execution: WorkflowExecution,
+        *,
+        status: WorkflowExecutionStatus,
+        total_tokens: int,
+        total_steps: int,
+        outputs: Mapping[str, Any] | None = None,
+        error_message: Optional[str] = None,
+        exceptions_count: int = 0,
+        finished_at: Optional[datetime] = None,
+    ) -> None:
+        """Update workflow execution with completion data."""
+        execution.status = status
+        execution.outputs = outputs or {}
+        execution.total_tokens = total_tokens
+        execution.total_steps = total_steps
+        execution.finished_at = finished_at or naive_utc_now()
+        execution.exceptions_count = exceptions_count
+        if error_message:
+            execution.error_message = error_message
+
+    def _add_trace_task_if_needed(
+        self,
+        trace_manager: Optional[TraceQueueManager],
+        workflow_execution: WorkflowExecution,
+        conversation_id: Optional[str],
+    ) -> None:
+        """Add trace task if trace manager is provided."""
+        if trace_manager:
+            trace_manager.add_trace_task(
+                TraceTask(
+                    TraceTaskName.WORKFLOW_TRACE,
+                    workflow_execution=workflow_execution,
+                    conversation_id=conversation_id,
+                    user_id=trace_manager.user_id,
+                )
+            )
+
+    def _fail_running_node_executions(
+        self,
+        workflow_execution_id: str,
+        error_message: str,
+        now: datetime,
+    ) -> None:
+        """Fail all running node executions for a workflow."""
+        running_node_executions = [
+            node_exec
+            for node_exec in self._node_execution_cache.values()
+            if node_exec.workflow_execution_id == workflow_execution_id
+            and node_exec.status == WorkflowNodeExecutionStatus.RUNNING
+        ]
+
+        for node_execution in running_node_executions:
+            if node_execution.node_execution_id:
+                node_execution.status = WorkflowNodeExecutionStatus.FAILED
+                node_execution.error = error_message
+                node_execution.finished_at = now
+                node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
+                self._workflow_node_execution_repository.save(node_execution)
+
+    def _create_node_execution_from_event(
+        self,
+        *,
+        workflow_execution: WorkflowExecution,
+        event: Union[QueueNodeStartedEvent, QueueNodeRetryEvent],
+        status: WorkflowNodeExecutionStatus,
+        error: Optional[str] = None,
+        created_at: Optional[datetime] = None,
+    ) -> WorkflowNodeExecution:
+        """Create a node execution from an event."""
+        now = datetime.now(UTC).replace(tzinfo=None)
+        created_at = created_at or now
+
         metadata = {
             WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
             WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
@@ -232,152 +375,76 @@ class WorkflowCycleManager:
             node_id=event.node_id,
             node_type=event.node_type,
             title=event.node_data.title,
-            status=WorkflowNodeExecutionStatus.RUNNING,
+            status=status,
             metadata=metadata,
             created_at=created_at,
+            error=error,
         )

-        # Use the instance repository to save the domain model
-        self._workflow_node_execution_repository.save(domain_execution)
+        if status == WorkflowNodeExecutionStatus.RETRY:
+            domain_execution.finished_at = now
+            domain_execution.elapsed_time = (now - created_at).total_seconds()

         return domain_execution

-    def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
-        # Get the domain model from repository
-        domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
-        if not domain_execution:
-            raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
-
-        # Process data
-        inputs = event.inputs
-        process_data = event.process_data
-        outputs = event.outputs
-
-        # Convert metadata keys to strings
-        execution_metadata_dict = {}
-        if event.execution_metadata:
-            for key, value in event.execution_metadata.items():
-                execution_metadata_dict[key] = value
-
-        finished_at = datetime.now(UTC).replace(tzinfo=None)
-        elapsed_time = (finished_at - event.start_at).total_seconds()
-
-        # Update domain model
-        domain_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
-        domain_execution.update_from_mapping(
-            inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
-        )
-        domain_execution.finished_at = finished_at
-        domain_execution.elapsed_time = elapsed_time
-
-        # Update the repository with the domain model
-        self._workflow_node_execution_repository.save(domain_execution)
-
-        return domain_execution
-
-    def handle_workflow_node_execution_failed(
+    def _update_node_execution_completion(
         self,
+        domain_execution: WorkflowNodeExecution,
         *,
-        event: QueueNodeFailedEvent
-        | QueueNodeInIterationFailedEvent
-        | QueueNodeInLoopFailedEvent
-        | QueueNodeExceptionEvent,
-    ) -> WorkflowNodeExecution:
-        """
-        Workflow node execution failed
-        :param event: queue node failed event
-        :return:
-        """
-        # Get the domain model from repository
-        domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
-        if not domain_execution:
-            raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
-
-        # Process data
-        inputs = WorkflowEntry.handle_special_values(event.inputs)
-        process_data = WorkflowEntry.handle_special_values(event.process_data)
-        outputs = event.outputs
-
-        # Convert metadata keys to strings
-        execution_metadata_dict = {}
-        if event.execution_metadata:
-            for key, value in event.execution_metadata.items():
-                execution_metadata_dict[key] = value
-
+        event: Union[
+            QueueNodeSucceededEvent,
+            QueueNodeFailedEvent,
+            QueueNodeInIterationFailedEvent,
+            QueueNodeInLoopFailedEvent,
+            QueueNodeExceptionEvent,
+        ],
+        status: WorkflowNodeExecutionStatus,
+        error: Optional[str] = None,
+        handle_special_values: bool = False,
+    ) -> None:
+        """Update node execution with completion data."""
         finished_at = datetime.now(UTC).replace(tzinfo=None)
         elapsed_time = (finished_at - event.start_at).total_seconds()

+        # Process data
+        if handle_special_values:
+            inputs = WorkflowEntry.handle_special_values(event.inputs)
+            process_data = WorkflowEntry.handle_special_values(event.process_data)
+        else:
+            inputs = event.inputs
+            process_data = event.process_data
+
+        outputs = event.outputs
+
+        # Convert metadata
+        execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
+        if event.execution_metadata:
+            execution_metadata_dict.update(event.execution_metadata)
+
         # Update domain model
-        domain_execution.status = (
-            WorkflowNodeExecutionStatus.FAILED
-            if not isinstance(event, QueueNodeExceptionEvent)
-            else WorkflowNodeExecutionStatus.EXCEPTION
-        )
-        domain_execution.error = event.error
+        domain_execution.status = status
         domain_execution.update_from_mapping(
-            inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
+            inputs=inputs,
+            process_data=process_data,
+            outputs=outputs,
+            metadata=execution_metadata_dict,
         )
         domain_execution.finished_at = finished_at
         domain_execution.elapsed_time = elapsed_time

-        # Update the repository with the domain model
-        self._workflow_node_execution_repository.save(domain_execution)
+        if error:
+            domain_execution.error = error

-        return domain_execution
-
-    def handle_workflow_node_execution_retried(
-        self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
-    ) -> WorkflowNodeExecution:
-        workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
-        created_at = event.start_at
-        finished_at = datetime.now(UTC).replace(tzinfo=None)
-        elapsed_time = (finished_at - created_at).total_seconds()
-        inputs = WorkflowEntry.handle_special_values(event.inputs)
-        outputs = event.outputs
-
-        # Convert metadata keys to strings
+    def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]:
+        """Merge event metadata with origin metadata."""
         origin_metadata = {
             WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
             WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
             WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
         }

-        # Convert execution metadata keys to strings
         execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
         if event.execution_metadata:
-            for key, value in event.execution_metadata.items():
-                execution_metadata_dict[key] = value
+            execution_metadata_dict.update(event.execution_metadata)

-        merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
-
-        # Create a domain model
-        domain_execution = WorkflowNodeExecution(
-            id=str(uuid4()),
-            workflow_id=workflow_execution.workflow_id,
-            workflow_execution_id=workflow_execution.id_,
-            predecessor_node_id=event.predecessor_node_id,
-            node_execution_id=event.node_execution_id,
-            node_id=event.node_id,
-            node_type=event.node_type,
-            title=event.node_data.title,
-            status=WorkflowNodeExecutionStatus.RETRY,
-            created_at=created_at,
-            finished_at=finished_at,
-            elapsed_time=elapsed_time,
-            error=event.error,
-            index=event.node_run_index,
-        )
-
-        # Update with mappings
-        domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata)
-
-        # Use the instance repository to save the domain model
-        self._workflow_node_execution_repository.save(domain_execution)
-
-        return domain_execution
-
-    def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
-        execution = self._workflow_execution_repository.get(id)
-        if not execution:
-            raise WorkflowRunNotFoundError(id)
-        return execution
+        return {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
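Taken together, the handlers now route every lookup through the caches. A hypothetical happy path for one run (handler names follow the diff and the tests below; the manager and event objects are assumed to be constructed elsewhere):

    execution = manager.handle_workflow_run_start()        # saved and cached
    node = manager.handle_node_execution_start(            # cached by node_execution_id
        workflow_execution_id=execution.id_, event=node_started_event
    )
    manager.handle_workflow_node_execution_success(event=node_succeeded_event)  # cache read, then save
    manager.handle_workflow_run_success(                   # cache read, no repository get()
        workflow_run_id=execution.id_, outputs={}, total_tokens=0, total_steps=1
    )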
@@ -80,15 +80,12 @@ def real_workflow_system_variables():
 @pytest.fixture
 def mock_node_execution_repository():
     repo = MagicMock(spec=WorkflowNodeExecutionRepository)
-    repo.get_by_node_execution_id.return_value = None
-    repo.get_running_executions.return_value = []
     return repo


 @pytest.fixture
 def mock_workflow_execution_repository():
     repo = MagicMock(spec=WorkflowExecutionRepository)
-    repo.get.return_value = None
     return repo


@@ -217,8 +214,8 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu
         started_at=datetime.now(UTC).replace(tzinfo=None),
     )

-    # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
-    workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
+    # Pre-populate the cache with the workflow execution
+    workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution

     # Call the method
     result = workflow_cycle_manager.handle_workflow_run_success(
@@ -251,11 +248,10 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
         started_at=datetime.now(UTC).replace(tzinfo=None),
     )

-    # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
-    workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
-
-    # Mock get_running_executions to return an empty list
-    workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = []
+    # Pre-populate the cache with the workflow execution
+    workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
+
+    # No running node executions in cache (empty cache)

     # Call the method
     result = workflow_cycle_manager.handle_workflow_run_failed(
@@ -289,8 +285,8 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
         started_at=datetime.now(UTC).replace(tzinfo=None),
     )

-    # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
-    workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
+    # Pre-populate the cache with the workflow execution
+    workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution

     # Create a mock event
     event = MagicMock(spec=QueueNodeStartedEvent)
@@ -342,8 +338,8 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
         started_at=datetime.now(UTC).replace(tzinfo=None),
     )

-    # Mock the repository get method to return the real execution
-    workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
+    # Pre-populate the cache with the workflow execution
+    workflow_cycle_manager._workflow_execution_cache["test-workflow-run-id"] = workflow_execution

     # Call the method
     result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id")
@@ -351,11 +347,13 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
     # Verify the result
     assert result == workflow_execution

-    # Test error case
-    workflow_cycle_manager._workflow_execution_repository.get.return_value = None
+    # Test error case - clear cache
+    workflow_cycle_manager._workflow_execution_cache.clear()

     # Expect an error when execution is not found
-    with pytest.raises(ValueError):
+    from core.app.task_pipeline.exc import WorkflowRunNotFoundError
+
+    with pytest.raises(WorkflowRunNotFoundError):
         workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id")


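The updated tests seed the manager's private caches instead of stubbing repository getters, and the not-found path is now an empty cache rather than a mocked None. The pattern in isolation (names taken from the tests above):

    # Arrange: put domain models straight into the caches.
    workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
    workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution

    # The error case is simply an empty cache.
    workflow_cycle_manager._workflow_execution_cache.clear()
    with pytest.raises(WorkflowRunNotFoundError):
        workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id")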
@@ -384,8 +382,8 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
         created_at=datetime.now(UTC).replace(tzinfo=None),
     )

-    # Mock the repository to return the node execution
-    workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
+    # Pre-populate the cache with the node execution
+    workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution

     # Call the method
     result = workflow_cycle_manager.handle_workflow_node_execution_success(
@@ -414,8 +412,8 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl
         started_at=datetime.now(UTC).replace(tzinfo=None),
     )

-    # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
-    workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
+    # Pre-populate the cache with the workflow execution
+    workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution

     # Call the method
     result = workflow_cycle_manager.handle_workflow_run_partial_success(
@@ -462,8 +460,8 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
         created_at=datetime.now(UTC).replace(tzinfo=None),
     )

-    # Mock the repository to return the node execution
-    workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
+    # Pre-populate the cache with the node execution
+    workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution

     # Call the method
     result = workflow_cycle_manager.handle_workflow_node_execution_failed(
@@ -137,37 +137,6 @@ def test_save_with_existing_tenant_id(repository, session):
     session_obj.merge.assert_called_once_with(modified_execution)


-def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
-    """Test get_by_node_execution_id method."""
-    session_obj, _ = session
-    # Set up mock
-    mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
-    mock_stmt = mocker.MagicMock()
-    mock_select.return_value = mock_stmt
-    mock_stmt.where.return_value = mock_stmt
-
-    # Create a properly configured mock execution
-    mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
-    configure_mock_execution(mock_execution)
-    session_obj.scalar.return_value = mock_execution
-
-    # Create a mock domain model to be returned by _to_domain_model
-    mock_domain_model = mocker.MagicMock()
-    # Mock the _to_domain_model method to return our mock domain model
-    repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
-
-    # Call method
-    result = repository.get_by_node_execution_id("test-node-execution-id")
-
-    # Assert select was called with correct parameters
-    mock_select.assert_called_once()
-    session_obj.scalar.assert_called_once_with(mock_stmt)
-    # Assert _to_domain_model was called with the mock execution
-    repository._to_domain_model.assert_called_once_with(mock_execution)
-    # Assert the result is our mock domain model
-    assert result is mock_domain_model
-
-
 def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
     """Test get_by_workflow_run method."""
     session_obj, _ = session
@@ -202,88 +171,6 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
     assert result[0] is mock_domain_model


-def test_get_running_executions(repository, session, mocker: MockerFixture):
-    """Test get_running_executions method."""
-    session_obj, _ = session
-    # Set up mock
-    mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
-    mock_stmt = mocker.MagicMock()
-    mock_select.return_value = mock_stmt
-    mock_stmt.where.return_value = mock_stmt
-
-    # Create a properly configured mock execution
-    mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
-    configure_mock_execution(mock_execution)
-    session_obj.scalars.return_value.all.return_value = [mock_execution]
-
-    # Create a mock domain model to be returned by _to_domain_model
-    mock_domain_model = mocker.MagicMock()
-    # Mock the _to_domain_model method to return our mock domain model
-    repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
-
-    # Call method
-    result = repository.get_running_executions("test-workflow-run-id")
-
-    # Assert select was called with correct parameters
-    mock_select.assert_called_once()
-    session_obj.scalars.assert_called_once_with(mock_stmt)
-    # Assert _to_domain_model was called with the mock execution
-    repository._to_domain_model.assert_called_once_with(mock_execution)
-    # Assert the result contains our mock domain model
-    assert len(result) == 1
-    assert result[0] is mock_domain_model
-
-
-def test_update_via_save(repository, session):
-    """Test updating an existing record via save method."""
-    session_obj, _ = session
-    # Create a mock execution
-    execution = MagicMock(spec=WorkflowNodeExecutionModel)
-    execution.tenant_id = None
-    execution.app_id = None
-    execution.inputs = None
-    execution.process_data = None
-    execution.outputs = None
-    execution.metadata = None
-
-    # Mock the to_db_model method to return the execution itself
-    # This simulates the behavior of setting tenant_id and app_id
-    repository.to_db_model = MagicMock(return_value=execution)
-
-    # Call save method to update an existing record
-    repository.save(execution)
-
-    # Assert to_db_model was called with the execution
-    repository.to_db_model.assert_called_once_with(execution)
-
-    # Assert session.merge was called (for updates)
-    session_obj.merge.assert_called_once_with(execution)
-
-
-def test_clear(repository, session, mocker: MockerFixture):
-    """Test clear method."""
-    session_obj, _ = session
-    # Set up mock
-    mock_delete = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.delete")
-    mock_stmt = mocker.MagicMock()
-    mock_delete.return_value = mock_stmt
-    mock_stmt.where.return_value = mock_stmt
-
-    # Mock the execute result with rowcount
-    mock_result = mocker.MagicMock()
-    mock_result.rowcount = 5 # Simulate 5 records deleted
-    session_obj.execute.return_value = mock_result
-
-    # Call method
-    repository.clear()
-
-    # Assert delete was called with correct parameters
-    mock_delete.assert_called_once_with(WorkflowNodeExecutionModel)
-    mock_stmt.where.assert_called()
-    session_obj.execute.assert_called_once_with(mock_stmt)
-    session_obj.commit.assert_called_once()
-
-
 def test_to_db_model(repository):
     """Test to_db_model method."""
     # Create a domain model