refactor: simplify variable pool key structure and improve type safety (#23732)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
223c1a8089
commit
577062b93a
@ -4,4 +4,4 @@
|
|||||||
#
|
#
|
||||||
# If the selector length is more than 2, the remaining parts are the keys / indexes paths used
|
# If the selector length is more than 2, the remaining parts are the keys / indexes paths used
|
||||||
# to extract part of the variable value.
|
# to extract part of the variable value.
|
||||||
MIN_SELECTORS_LENGTH = 2
|
SELECTORS_LENGTH = 2
|
||||||
|
|||||||
@ -7,8 +7,8 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
from core.file import File, FileAttribute, file_manager
|
from core.file import File, FileAttribute, file_manager
|
||||||
from core.variables import Segment, SegmentGroup, Variable
|
from core.variables import Segment, SegmentGroup, Variable
|
||||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
from core.variables.consts import SELECTORS_LENGTH
|
||||||
from core.variables.segments import FileSegment, NoneSegment
|
from core.variables.segments import FileSegment, ObjectSegment
|
||||||
from core.variables.variables import VariableUnion
|
from core.variables.variables import VariableUnion
|
||||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||||
from core.workflow.system_variable import SystemVariable
|
from core.workflow.system_variable import SystemVariable
|
||||||
@ -24,7 +24,7 @@ class VariablePool(BaseModel):
|
|||||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||||
# elements of the selector except the first one.
|
# elements of the selector except the first one.
|
||||||
variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field(
|
variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
|
||||||
description="Variables mapping",
|
description="Variables mapping",
|
||||||
default=defaultdict(dict),
|
default=defaultdict(dict),
|
||||||
)
|
)
|
||||||
@ -36,6 +36,7 @@ class VariablePool(BaseModel):
|
|||||||
)
|
)
|
||||||
system_variables: SystemVariable = Field(
|
system_variables: SystemVariable = Field(
|
||||||
description="System variables",
|
description="System variables",
|
||||||
|
default_factory=SystemVariable.empty,
|
||||||
)
|
)
|
||||||
environment_variables: Sequence[VariableUnion] = Field(
|
environment_variables: Sequence[VariableUnion] = Field(
|
||||||
description="Environment variables.",
|
description="Environment variables.",
|
||||||
@ -58,23 +59,29 @@ class VariablePool(BaseModel):
|
|||||||
|
|
||||||
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
||||||
"""
|
"""
|
||||||
Adds a variable to the variable pool.
|
Add a variable to the variable pool.
|
||||||
|
|
||||||
NOTE: You should not add a non-Segment value to the variable pool
|
This method accepts a selector path and a value, converting the value
|
||||||
even if it is allowed now.
|
to a Variable object if necessary before storing it in the pool.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
selector (Sequence[str]): The selector for the variable.
|
selector: A two-element sequence containing [node_id, variable_name].
|
||||||
value (VariableValue): The value of the variable.
|
The selector must have exactly 2 elements to be valid.
|
||||||
|
value: The value to store. Can be a Variable, Segment, or any value
|
||||||
|
that can be converted to a Segment (str, int, float, dict, list, File).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the selector is invalid.
|
ValueError: If selector length is not exactly 2 elements.
|
||||||
|
|
||||||
Returns:
|
Note:
|
||||||
None
|
While non-Segment values are currently accepted and automatically
|
||||||
|
converted, it's recommended to pass Segment or Variable objects directly.
|
||||||
"""
|
"""
|
||||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
if len(selector) != SELECTORS_LENGTH:
|
||||||
raise ValueError("Invalid selector")
|
raise ValueError(
|
||||||
|
f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), "
|
||||||
|
f"got {len(selector)} elements"
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(value, Variable):
|
if isinstance(value, Variable):
|
||||||
variable = value
|
variable = value
|
||||||
@ -84,57 +91,85 @@ class VariablePool(BaseModel):
|
|||||||
segment = variable_factory.build_segment(value)
|
segment = variable_factory.build_segment(value)
|
||||||
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
|
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
|
||||||
|
|
||||||
key, hash_key = self._selector_to_keys(selector)
|
node_id, name = self._selector_to_keys(selector)
|
||||||
# Based on the definition of `VariableUnion`,
|
# Based on the definition of `VariableUnion`,
|
||||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||||
self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
|
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]:
|
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
|
||||||
return selector[0], hash(tuple(selector[1:]))
|
return selector[0], selector[1]
|
||||||
|
|
||||||
def _has(self, selector: Sequence[str]) -> bool:
|
def _has(self, selector: Sequence[str]) -> bool:
|
||||||
key, hash_key = self._selector_to_keys(selector)
|
node_id, name = self._selector_to_keys(selector)
|
||||||
if key not in self.variable_dictionary:
|
if node_id not in self.variable_dictionary:
|
||||||
return False
|
return False
|
||||||
if hash_key not in self.variable_dictionary[key]:
|
if name not in self.variable_dictionary[node_id]:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get(self, selector: Sequence[str], /) -> Segment | None:
|
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||||
"""
|
"""
|
||||||
Retrieves the value from the variable pool based on the given selector.
|
Retrieve a variable's value from the pool as a Segment.
|
||||||
|
|
||||||
|
This method supports both simple selectors [node_id, variable_name] and
|
||||||
|
extended selectors that include attribute access for FileSegment and
|
||||||
|
ObjectSegment types.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
selector (Sequence[str]): The selector used to identify the variable.
|
selector: A sequence with at least 2 elements:
|
||||||
|
- [node_id, variable_name]: Returns the full segment
|
||||||
|
- [node_id, variable_name, attr, ...]: Returns a nested value
|
||||||
|
from FileSegment (e.g., 'url', 'name') or ObjectSegment
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Any: The value associated with the given selector.
|
The Segment associated with the selector, or None if not found.
|
||||||
|
Returns None if selector has fewer than 2 elements.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the selector is invalid.
|
ValueError: If attempting to access an invalid FileAttribute.
|
||||||
"""
|
"""
|
||||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
if len(selector) < SELECTORS_LENGTH:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
key, hash_key = self._selector_to_keys(selector)
|
node_id, name = self._selector_to_keys(selector)
|
||||||
value: Segment | None = self.variable_dictionary[key].get(hash_key)
|
segment: Segment | None = self.variable_dictionary[node_id].get(name)
|
||||||
|
|
||||||
if value is None:
|
if segment is None:
|
||||||
selector, attr = selector[:-1], selector[-1]
|
return None
|
||||||
|
|
||||||
|
if len(selector) == 2:
|
||||||
|
return segment
|
||||||
|
|
||||||
|
if isinstance(segment, FileSegment):
|
||||||
|
attr = selector[2]
|
||||||
# Python support `attr in FileAttribute` after 3.12
|
# Python support `attr in FileAttribute` after 3.12
|
||||||
if attr not in {item.value for item in FileAttribute}:
|
if attr not in {item.value for item in FileAttribute}:
|
||||||
return None
|
return None
|
||||||
value = self.get(selector)
|
|
||||||
if not isinstance(value, FileSegment | NoneSegment):
|
|
||||||
return None
|
|
||||||
if isinstance(value, FileSegment):
|
|
||||||
attr = FileAttribute(attr)
|
attr = FileAttribute(attr)
|
||||||
attr_value = file_manager.get_attr(file=value.value, attr=attr)
|
attr_value = file_manager.get_attr(file=segment.value, attr=attr)
|
||||||
return variable_factory.build_segment(attr_value)
|
return variable_factory.build_segment(attr_value)
|
||||||
return value
|
|
||||||
|
|
||||||
return value
|
# Navigate through nested attributes
|
||||||
|
result: Any = segment
|
||||||
|
for attr in selector[2:]:
|
||||||
|
result = self._extract_value(result)
|
||||||
|
result = self._get_nested_attribute(result, attr)
|
||||||
|
if result is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Return result as Segment
|
||||||
|
return result if isinstance(result, Segment) else variable_factory.build_segment(result)
|
||||||
|
|
||||||
|
def _extract_value(self, obj: Any) -> Any:
|
||||||
|
"""Extract the actual value from an ObjectSegment."""
|
||||||
|
return obj.value if isinstance(obj, ObjectSegment) else obj
|
||||||
|
|
||||||
|
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Any:
|
||||||
|
"""Get a nested attribute from a dictionary-like object."""
|
||||||
|
if not isinstance(obj, dict):
|
||||||
|
return None
|
||||||
|
return obj.get(attr)
|
||||||
|
|
||||||
def remove(self, selector: Sequence[str], /):
|
def remove(self, selector: Sequence[str], /):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from configs import dify_config
|
|||||||
from core.app.apps.exc import GenerateTaskStoppedError
|
from core.app.apps.exc import GenerateTaskStoppedError
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
|
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
|
||||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
@ -51,7 +51,6 @@ from core.workflow.nodes.base import BaseNode
|
|||||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||||
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
|
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
|
||||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||||
from core.workflow.utils import variable_utils
|
|
||||||
from libs.flask_utils import preserve_flask_contexts
|
from libs.flask_utils import preserve_flask_contexts
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
@ -701,11 +700,9 @@ class GraphEngine:
|
|||||||
route_node_state.status = RouteNodeState.Status.EXCEPTION
|
route_node_state.status = RouteNodeState.Status.EXCEPTION
|
||||||
if run_result.outputs:
|
if run_result.outputs:
|
||||||
for variable_key, variable_value in run_result.outputs.items():
|
for variable_key, variable_value in run_result.outputs.items():
|
||||||
# append variables to variable pool recursively
|
# Add variables to variable pool
|
||||||
self._append_variables_recursively(
|
self.graph_runtime_state.variable_pool.add(
|
||||||
node_id=node.node_id,
|
[node.node_id, variable_key], variable_value
|
||||||
variable_key_list=[variable_key],
|
|
||||||
variable_value=variable_value,
|
|
||||||
)
|
)
|
||||||
yield NodeRunExceptionEvent(
|
yield NodeRunExceptionEvent(
|
||||||
error=run_result.error or "System Error",
|
error=run_result.error or "System Error",
|
||||||
@ -758,11 +755,9 @@ class GraphEngine:
|
|||||||
# append node output variables to variable pool
|
# append node output variables to variable pool
|
||||||
if run_result.outputs:
|
if run_result.outputs:
|
||||||
for variable_key, variable_value in run_result.outputs.items():
|
for variable_key, variable_value in run_result.outputs.items():
|
||||||
# append variables to variable pool recursively
|
# Add variables to variable pool
|
||||||
self._append_variables_recursively(
|
self.graph_runtime_state.variable_pool.add(
|
||||||
node_id=node.node_id,
|
[node.node_id, variable_key], variable_value
|
||||||
variable_key_list=[variable_key],
|
|
||||||
variable_value=variable_value,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# When setting metadata, convert to dict first
|
# When setting metadata, convert to dict first
|
||||||
@ -851,21 +846,6 @@ class GraphEngine:
|
|||||||
logger.exception("Node %s run failed", node.title)
|
logger.exception("Node %s run failed", node.title)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
|
||||||
"""
|
|
||||||
Append variables recursively
|
|
||||||
:param node_id: node id
|
|
||||||
:param variable_key_list: variable key list
|
|
||||||
:param variable_value: variable value
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
variable_utils.append_variables_recursively(
|
|
||||||
self.graph_runtime_state.variable_pool,
|
|
||||||
node_id,
|
|
||||||
variable_key_list,
|
|
||||||
variable_value,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
|
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
|
||||||
"""
|
"""
|
||||||
Check timeout
|
Check timeout
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from typing import Any, TypeVar
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.variables import Segment
|
from core.variables import Segment
|
||||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
from core.variables.consts import SELECTORS_LENGTH
|
||||||
from core.variables.types import SegmentType
|
from core.variables.types import SegmentType
|
||||||
|
|
||||||
# Use double underscore (`__`) prefix for internal variables
|
# Use double underscore (`__`) prefix for internal variables
|
||||||
@ -23,7 +23,7 @@ _T = TypeVar("_T", bound=MutableMapping[str, Any])
|
|||||||
|
|
||||||
|
|
||||||
def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
|
def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
|
||||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
if len(selector) < SELECTORS_LENGTH:
|
||||||
raise Exception("selector too short")
|
raise Exception("selector too short")
|
||||||
node_id, var_name = selector[:2]
|
node_id, var_name = selector[:2]
|
||||||
return UpdatedVariable(
|
return UpdatedVariable(
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from typing import Any, Optional, cast
|
|||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.variables import SegmentType, Variable
|
from core.variables import SegmentType, Variable
|
||||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
from core.variables.consts import SELECTORS_LENGTH
|
||||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
@ -46,7 +46,7 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
|
|||||||
selector = item.value
|
selector = item.value
|
||||||
if not isinstance(selector, list):
|
if not isinstance(selector, list):
|
||||||
raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}")
|
raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}")
|
||||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
if len(selector) < SELECTORS_LENGTH:
|
||||||
raise InvalidDataError(f"selector too short, {node_id=}, {item=}")
|
raise InvalidDataError(f"selector too short, {node_id=}, {item=}")
|
||||||
selector_str = ".".join(selector)
|
selector_str = ".".join(selector)
|
||||||
key = f"{node_id}.#{selector_str}#"
|
key = f"{node_id}.#{selector_str}#"
|
||||||
|
|||||||
@ -1,29 +0,0 @@
|
|||||||
from core.variables.segments import ObjectSegment, Segment
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
|
||||||
|
|
||||||
|
|
||||||
def append_variables_recursively(
|
|
||||||
pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Append variables recursively
|
|
||||||
:param pool: variable pool to append variables to
|
|
||||||
:param node_id: node id
|
|
||||||
:param variable_key_list: variable key list
|
|
||||||
:param variable_value: variable value
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
pool.add([node_id] + variable_key_list, variable_value)
|
|
||||||
|
|
||||||
# if variable_value is a dict, then recursively append variables
|
|
||||||
if isinstance(variable_value, ObjectSegment):
|
|
||||||
variable_dict = variable_value.value
|
|
||||||
elif isinstance(variable_value, dict):
|
|
||||||
variable_dict = variable_value
|
|
||||||
else:
|
|
||||||
return
|
|
||||||
|
|
||||||
for key, value in variable_dict.items():
|
|
||||||
# construct new key list
|
|
||||||
new_key_list = variable_key_list + [key]
|
|
||||||
append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value)
|
|
||||||
@ -3,9 +3,8 @@ from collections.abc import Mapping, Sequence
|
|||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
|
|
||||||
from core.variables import Variable
|
from core.variables import Variable
|
||||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
from core.variables.consts import SELECTORS_LENGTH
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.utils import variable_utils
|
|
||||||
|
|
||||||
|
|
||||||
class VariableLoader(Protocol):
|
class VariableLoader(Protocol):
|
||||||
@ -78,7 +77,7 @@ def load_into_variable_pool(
|
|||||||
variables_to_load.append(list(selector))
|
variables_to_load.append(list(selector))
|
||||||
loaded = variable_loader.load_variables(variables_to_load)
|
loaded = variable_loader.load_variables(variables_to_load)
|
||||||
for var in loaded:
|
for var in loaded:
|
||||||
assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}"
|
assert len(var.selector) >= SELECTORS_LENGTH, f"Invalid variable {var}"
|
||||||
variable_utils.append_variables_recursively(
|
# Add variable directly to the pool
|
||||||
variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var
|
# The variable pool expects 2-element selectors [node_id, variable_name]
|
||||||
)
|
variable_pool.add([var.selector[0], var.selector[1]], var)
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from sqlalchemy.sql.expression import and_, or_
|
|||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
from core.variables import Segment, StringSegment, Variable
|
from core.variables import Segment, StringSegment, Variable
|
||||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
from core.variables.consts import SELECTORS_LENGTH
|
||||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||||
from core.variables.types import SegmentType
|
from core.variables.types import SegmentType
|
||||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||||
@ -147,7 +147,7 @@ class WorkflowDraftVariableService:
|
|||||||
) -> list[WorkflowDraftVariable]:
|
) -> list[WorkflowDraftVariable]:
|
||||||
ors = []
|
ors = []
|
||||||
for selector in selectors:
|
for selector in selectors:
|
||||||
assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}"
|
assert len(selector) >= SELECTORS_LENGTH, f"Invalid selector to get: {selector}"
|
||||||
node_id, name = selector[:2]
|
node_id, name = selector[:2]
|
||||||
ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name))
|
ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name))
|
||||||
|
|
||||||
@ -608,7 +608,7 @@ class DraftVariableSaver:
|
|||||||
|
|
||||||
for item in updated_variables:
|
for item in updated_variables:
|
||||||
selector = item.selector
|
selector = item.selector
|
||||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
if len(selector) < SELECTORS_LENGTH:
|
||||||
raise Exception("selector too short")
|
raise Exception("selector too short")
|
||||||
# NOTE(QuantumGhost): only the following two kinds of variable could be updated by
|
# NOTE(QuantumGhost): only the following two kinds of variable could be updated by
|
||||||
# VariableAssigner: ConversationVariable and iteration variable.
|
# VariableAssigner: ConversationVariable and iteration variable.
|
||||||
|
|||||||
@ -69,8 +69,12 @@ def test_get_file_attribute(pool, file):
|
|||||||
|
|
||||||
|
|
||||||
def test_use_long_selector(pool):
|
def test_use_long_selector(pool):
|
||||||
pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value"))
|
# The add method now only accepts 2-element selectors (node_id, variable_name)
|
||||||
|
# Store nested data as an ObjectSegment instead
|
||||||
|
nested_data = {"part_2": "test_value"}
|
||||||
|
pool.add(("node_1", "part_1"), ObjectSegment(value=nested_data))
|
||||||
|
|
||||||
|
# The get method supports longer selectors for nested access
|
||||||
result = pool.get(("node_1", "part_1", "part_2"))
|
result = pool.get(("node_1", "part_1", "part_2"))
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.value == "test_value"
|
assert result.value == "test_value"
|
||||||
@ -280,8 +284,10 @@ class TestVariablePoolSerialization:
|
|||||||
pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
|
pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
|
||||||
pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))
|
pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))
|
||||||
|
|
||||||
# Add nested variables
|
# Add nested variables as ObjectSegment
|
||||||
pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value"))
|
# The add method only accepts 2-element selectors
|
||||||
|
nested_obj = {"deep": {"var": "deep_value"}}
|
||||||
|
pool.add((self._NODE3_ID, "nested"), ObjectSegment(value=nested_obj))
|
||||||
|
|
||||||
def test_system_variables(self):
|
def test_system_variables(self):
|
||||||
sys_vars = SystemVariable(
|
sys_vars = SystemVariable(
|
||||||
|
|||||||
@ -1,148 +0,0 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.variables.segments import ObjectSegment, StringSegment
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
|
||||||
from core.workflow.utils.variable_utils import append_variables_recursively
|
|
||||||
|
|
||||||
|
|
||||||
class TestAppendVariablesRecursively:
|
|
||||||
"""Test cases for append_variables_recursively function"""
|
|
||||||
|
|
||||||
def test_append_simple_dict_value(self):
|
|
||||||
"""Test appending a simple dictionary value"""
|
|
||||||
pool = VariablePool.empty()
|
|
||||||
node_id = "test_node"
|
|
||||||
variable_key_list = ["output"]
|
|
||||||
variable_value = {"name": "John", "age": 30}
|
|
||||||
|
|
||||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
|
||||||
|
|
||||||
# Check that the main variable is added
|
|
||||||
main_var = pool.get([node_id] + variable_key_list)
|
|
||||||
assert main_var is not None
|
|
||||||
assert main_var.value == variable_value
|
|
||||||
|
|
||||||
# Check that nested variables are added recursively
|
|
||||||
name_var = pool.get([node_id] + variable_key_list + ["name"])
|
|
||||||
assert name_var is not None
|
|
||||||
assert name_var.value == "John"
|
|
||||||
|
|
||||||
age_var = pool.get([node_id] + variable_key_list + ["age"])
|
|
||||||
assert age_var is not None
|
|
||||||
assert age_var.value == 30
|
|
||||||
|
|
||||||
def test_append_object_segment_value(self):
|
|
||||||
"""Test appending an ObjectSegment value"""
|
|
||||||
pool = VariablePool.empty()
|
|
||||||
node_id = "test_node"
|
|
||||||
variable_key_list = ["result"]
|
|
||||||
|
|
||||||
# Create an ObjectSegment
|
|
||||||
obj_data = {"status": "success", "code": 200}
|
|
||||||
variable_value = ObjectSegment(value=obj_data)
|
|
||||||
|
|
||||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
|
||||||
|
|
||||||
# Check that the main variable is added
|
|
||||||
main_var = pool.get([node_id] + variable_key_list)
|
|
||||||
assert main_var is not None
|
|
||||||
assert isinstance(main_var, ObjectSegment)
|
|
||||||
assert main_var.value == obj_data
|
|
||||||
|
|
||||||
# Check that nested variables are added recursively
|
|
||||||
status_var = pool.get([node_id] + variable_key_list + ["status"])
|
|
||||||
assert status_var is not None
|
|
||||||
assert status_var.value == "success"
|
|
||||||
|
|
||||||
code_var = pool.get([node_id] + variable_key_list + ["code"])
|
|
||||||
assert code_var is not None
|
|
||||||
assert code_var.value == 200
|
|
||||||
|
|
||||||
def test_append_nested_dict_value(self):
|
|
||||||
"""Test appending a nested dictionary value"""
|
|
||||||
pool = VariablePool.empty()
|
|
||||||
node_id = "test_node"
|
|
||||||
variable_key_list = ["data"]
|
|
||||||
|
|
||||||
variable_value = {
|
|
||||||
"user": {
|
|
||||||
"profile": {"name": "Alice", "email": "alice@example.com"},
|
|
||||||
"settings": {"theme": "dark", "notifications": True},
|
|
||||||
},
|
|
||||||
"metadata": {"version": "1.0", "timestamp": 1234567890},
|
|
||||||
}
|
|
||||||
|
|
||||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
|
||||||
|
|
||||||
# Check deeply nested variables
|
|
||||||
name_var = pool.get([node_id] + variable_key_list + ["user", "profile", "name"])
|
|
||||||
assert name_var is not None
|
|
||||||
assert name_var.value == "Alice"
|
|
||||||
|
|
||||||
email_var = pool.get([node_id] + variable_key_list + ["user", "profile", "email"])
|
|
||||||
assert email_var is not None
|
|
||||||
assert email_var.value == "alice@example.com"
|
|
||||||
|
|
||||||
theme_var = pool.get([node_id] + variable_key_list + ["user", "settings", "theme"])
|
|
||||||
assert theme_var is not None
|
|
||||||
assert theme_var.value == "dark"
|
|
||||||
|
|
||||||
notifications_var = pool.get([node_id] + variable_key_list + ["user", "settings", "notifications"])
|
|
||||||
assert notifications_var is not None
|
|
||||||
assert notifications_var.value == 1 # Boolean True is converted to integer 1
|
|
||||||
|
|
||||||
version_var = pool.get([node_id] + variable_key_list + ["metadata", "version"])
|
|
||||||
assert version_var is not None
|
|
||||||
assert version_var.value == "1.0"
|
|
||||||
|
|
||||||
def test_append_non_dict_value(self):
|
|
||||||
"""Test appending a non-dictionary value (should not recurse)"""
|
|
||||||
pool = VariablePool.empty()
|
|
||||||
node_id = "test_node"
|
|
||||||
variable_key_list = ["simple"]
|
|
||||||
variable_value = "simple_string"
|
|
||||||
|
|
||||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
|
||||||
|
|
||||||
# Check that only the main variable is added
|
|
||||||
main_var = pool.get([node_id] + variable_key_list)
|
|
||||||
assert main_var is not None
|
|
||||||
assert main_var.value == variable_value
|
|
||||||
|
|
||||||
# Ensure no additional variables are created
|
|
||||||
assert len(pool.variable_dictionary[node_id]) == 1
|
|
||||||
|
|
||||||
def test_append_segment_non_object_value(self):
|
|
||||||
"""Test appending a Segment that is not ObjectSegment (should not recurse)"""
|
|
||||||
pool = VariablePool.empty()
|
|
||||||
node_id = "test_node"
|
|
||||||
variable_key_list = ["text"]
|
|
||||||
variable_value = StringSegment(value="Hello World")
|
|
||||||
|
|
||||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
|
||||||
|
|
||||||
# Check that only the main variable is added
|
|
||||||
main_var = pool.get([node_id] + variable_key_list)
|
|
||||||
assert main_var is not None
|
|
||||||
assert isinstance(main_var, StringSegment)
|
|
||||||
assert main_var.value == "Hello World"
|
|
||||||
|
|
||||||
# Ensure no additional variables are created
|
|
||||||
assert len(pool.variable_dictionary[node_id]) == 1
|
|
||||||
|
|
||||||
def test_append_empty_dict_value(self):
|
|
||||||
"""Test appending an empty dictionary value"""
|
|
||||||
pool = VariablePool.empty()
|
|
||||||
node_id = "test_node"
|
|
||||||
variable_key_list = ["empty"]
|
|
||||||
variable_value: dict[str, Any] = {}
|
|
||||||
|
|
||||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
|
||||||
|
|
||||||
# Check that the main variable is added
|
|
||||||
main_var = pool.get([node_id] + variable_key_list)
|
|
||||||
assert main_var is not None
|
|
||||||
assert main_var.value == {}
|
|
||||||
|
|
||||||
# Ensure only the main variable is created (no recursion for empty dict)
|
|
||||||
assert len(pool.variable_dictionary[node_id]) == 1
|
|
||||||
Loading…
x
Reference in New Issue
Block a user