try ast-grep (#24149)

This commit is contained in:
Asuka Minato 2025-08-19 14:41:52 +09:00 committed by GitHub
parent 75199442c1
commit 70da81d0e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 35 additions and 32 deletions

View File

@ -23,6 +23,9 @@ jobs:
uv run ruff check --fix-only . uv run ruff check --fix-only .
# Format code # Format code
uv run ruff format . uv run ruff format .
- name: ast-grep
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27

View File

@ -137,7 +137,7 @@ class InstructionGenerateApi(Resource):
from models import App, db from models import App, db
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
app = db.session.query(App).filter(App.id == args["flow_id"]).first() app = db.session.query(App).where(App.id == args["flow_id"]).first()
if not app: if not app:
return {"error": f"app {args['flow_id']} not found"}, 400 return {"error": f"app {args['flow_id']} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app) workflow = WorkflowService().get_draft_workflow(app_model=app)

View File

@ -39,7 +39,7 @@ class UploadFileApi(Resource):
data_source_info = document.data_source_info_dict data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info: if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file: if not upload_file:
raise NotFound("UploadFile not found.") raise NotFound("UploadFile not found.")
else: else:

View File

@ -181,7 +181,7 @@ class MessageCycleManager:
:param message_id: message id :param message_id: message id
:return: :return:
""" """
message_file = db.session.query(MessageFile).filter(MessageFile.id == message_id).first() message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first()
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse( return MessageStreamResponse(

View File

@ -399,9 +399,9 @@ class LLMGenerator:
def instruction_modify_legacy( def instruction_modify_legacy(
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
) -> dict: ) -> dict:
app: App | None = db.session.query(App).filter(App.id == flow_id).first() app: App | None = db.session.query(App).where(App.id == flow_id).first()
last_run: Message | None = ( last_run: Message | None = (
db.session.query(Message).filter(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
) )
if not last_run: if not last_run:
return LLMGenerator.__instruction_modify_common( return LLMGenerator.__instruction_modify_common(
@ -442,7 +442,7 @@ class LLMGenerator:
) -> dict: ) -> dict:
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
app: App | None = db.session.query(App).filter(App.id == flow_id).first() app: App | None = db.session.query(App).where(App.id == flow_id).first()
if not app: if not app:
raise ValueError("App not found.") raise ValueError("App not found.")
workflow = WorkflowService().get_draft_workflow(app_model=app) workflow = WorkflowService().get_draft_workflow(app_model=app)

View File

@ -37,7 +37,7 @@ def clean_workflow_runlogs_precise():
cutoff_date = datetime.datetime.now() - datetime.timedelta(days=retention_days) cutoff_date = datetime.datetime.now() - datetime.timedelta(days=retention_days)
try: try:
total_workflow_runs = db.session.query(WorkflowRun).filter(WorkflowRun.created_at < cutoff_date).count() total_workflow_runs = db.session.query(WorkflowRun).where(WorkflowRun.created_at < cutoff_date).count()
if total_workflow_runs == 0: if total_workflow_runs == 0:
_logger.info("No expired workflow run logs found") _logger.info("No expired workflow run logs found")
return return
@ -49,7 +49,7 @@ def clean_workflow_runlogs_precise():
while True: while True:
workflow_runs = ( workflow_runs = (
db.session.query(WorkflowRun.id).filter(WorkflowRun.created_at < cutoff_date).limit(BATCH_SIZE).all() db.session.query(WorkflowRun.id).where(WorkflowRun.created_at < cutoff_date).limit(BATCH_SIZE).all()
) )
if not workflow_runs: if not workflow_runs:
@ -99,52 +99,52 @@ def _delete_batch_with_retry(workflow_run_ids: list[str], attempt_count: int) ->
message_id_list = [msg.id for msg in message_data] message_id_list = [msg.id for msg in message_data]
conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id}) conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id})
if message_id_list: if message_id_list:
db.session.query(AppAnnotationHitHistory).filter( db.session.query(AppAnnotationHitHistory).where(
AppAnnotationHitHistory.message_id.in_(message_id_list) AppAnnotationHitHistory.message_id.in_(message_id_list)
).delete(synchronize_session=False) ).delete(synchronize_session=False)
db.session.query(MessageAgentThought).filter( db.session.query(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_id_list)).delete(
MessageAgentThought.message_id.in_(message_id_list)
).delete(synchronize_session=False)
db.session.query(MessageChain).filter(MessageChain.message_id.in_(message_id_list)).delete(
synchronize_session=False synchronize_session=False
) )
db.session.query(MessageFile).filter(MessageFile.message_id.in_(message_id_list)).delete( db.session.query(MessageChain).where(MessageChain.message_id.in_(message_id_list)).delete(
synchronize_session=False synchronize_session=False
) )
db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id.in_(message_id_list)).delete( db.session.query(MessageFile).where(MessageFile.message_id.in_(message_id_list)).delete(
synchronize_session=False synchronize_session=False
) )
db.session.query(MessageFeedback).filter(MessageFeedback.message_id.in_(message_id_list)).delete( db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_id_list)).delete(
synchronize_session=False synchronize_session=False
) )
db.session.query(Message).filter(Message.workflow_run_id.in_(workflow_run_ids)).delete( db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_id_list)).delete(
synchronize_session=False synchronize_session=False
) )
db.session.query(WorkflowAppLog).filter(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids)).delete( db.session.query(Message).where(Message.workflow_run_id.in_(workflow_run_ids)).delete(
synchronize_session=False
)
db.session.query(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids)).delete(
synchronize_session=False synchronize_session=False
) )
db.session.query(WorkflowNodeExecutionModel).filter( db.session.query(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids) WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids)
).delete(synchronize_session=False) ).delete(synchronize_session=False)
if conversation_id_list: if conversation_id_list:
db.session.query(ConversationVariable).filter( db.session.query(ConversationVariable).where(
ConversationVariable.conversation_id.in_(conversation_id_list) ConversationVariable.conversation_id.in_(conversation_id_list)
).delete(synchronize_session=False) ).delete(synchronize_session=False)
db.session.query(Conversation).filter(Conversation.id.in_(conversation_id_list)).delete( db.session.query(Conversation).where(Conversation.id.in_(conversation_id_list)).delete(
synchronize_session=False synchronize_session=False
) )
db.session.query(WorkflowRun).filter(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False) db.session.query(WorkflowRun).where(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False)
db.session.commit() db.session.commit()
return True return True

View File

@ -293,7 +293,7 @@ class AppAnnotationService:
annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete] annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete]
# Step 2: Bulk delete hit histories in a single query # Step 2: Bulk delete hit histories in a single query
db.session.query(AppAnnotationHitHistory).filter( db.session.query(AppAnnotationHitHistory).where(
AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete) AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete)
).delete(synchronize_session=False) ).delete(synchronize_session=False)
@ -307,7 +307,7 @@ class AppAnnotationService:
# Step 4: Bulk delete annotations in a single query # Step 4: Bulk delete annotations in a single query
deleted_count = ( deleted_count = (
db.session.query(MessageAnnotation) db.session.query(MessageAnnotation)
.filter(MessageAnnotation.id.in_(annotation_ids_to_delete)) .where(MessageAnnotation.id.in_(annotation_ids_to_delete))
.delete(synchronize_session=False) .delete(synchronize_session=False)
) )
@ -505,9 +505,9 @@ class AppAnnotationService:
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
) )
annotations_query = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id) annotations_query = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id)
for annotation in annotations_query.yield_per(100): for annotation in annotations_query.yield_per(100):
annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).filter( annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).where(
AppAnnotationHitHistory.annotation_id == annotation.id AppAnnotationHitHistory.annotation_id == annotation.id
) )
for annotation_hit_history in annotation_hit_histories_query.yield_per(100): for annotation_hit_history in annotation_hit_histories_query.yield_per(100):

View File

@ -471,7 +471,7 @@ class TestAnnotationService:
# Verify annotation was deleted # Verify annotation was deleted
from extensions.ext_database import db from extensions.ext_database import db
deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
assert deleted_annotation is None assert deleted_annotation is None
# Verify delete_annotation_index_task was called (when annotation setting exists) # Verify delete_annotation_index_task was called (when annotation setting exists)
@ -1175,7 +1175,7 @@ class TestAnnotationService:
AppAnnotationService.delete_app_annotation(app.id, annotation_id) AppAnnotationService.delete_app_annotation(app.id, annotation_id)
# Verify annotation was deleted # Verify annotation was deleted
deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
assert deleted_annotation is None assert deleted_annotation is None
# Verify delete_annotation_index_task was called # Verify delete_annotation_index_task was called

View File

@ -234,7 +234,7 @@ class TestAPIBasedExtensionService:
# Verify extension was deleted # Verify extension was deleted
from extensions.ext_database import db from extensions.ext_database import db
deleted_extension = db.session.query(APIBasedExtension).filter(APIBasedExtension.id == extension_id).first() deleted_extension = db.session.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first()
assert deleted_extension is None assert deleted_extension is None
def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):

View File

@ -484,7 +484,7 @@ class TestMessageService:
# Verify feedback was deleted # Verify feedback was deleted
from extensions.ext_database import db from extensions.ext_database import db
deleted_feedback = db.session.query(MessageFeedback).filter(MessageFeedback.id == feedback.id).first() deleted_feedback = db.session.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first()
assert deleted_feedback is None assert deleted_feedback is None
def test_create_feedback_no_rating_when_not_exists( def test_create_feedback_no_rating_when_not_exists(

View File

@ -469,6 +469,6 @@ class TestModelLoadBalancingService:
# Verify inherit config was created in database # Verify inherit config was created in database
inherit_configs = ( inherit_configs = (
db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.name == "__inherit__").all() db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all()
) )
assert len(inherit_configs) == 1 assert len(inherit_configs) == 1