Autogen_core 测试代码:test_cancellation.py
目录
- 第一段代码:定义 `LongRunningAgent` 类
- 主要逻辑:
- 完成的功能:
- 第二段代码:定义 `NestingLongRunningAgent` 类
- 主要逻辑:
- 完成的功能:
- 第三段代码:测试取消功能
- 主要逻辑:
- 完成的功能:
- 第四段代码:测试嵌套取消功能(仅外层调用)
- 主要逻辑:
- 完成的功能:
第一段代码:定义 LongRunningAgent
类
import asyncio
from dataclasses import dataclass
import pytest
from autogen_core import (
AgentId,
AgentInstantiationContext,
CancellationToken,
MessageContext,
RoutedAgent,
SingleThreadedAgentRuntime,
message_handler,
)
@dataclass
class MessageType: ...
# Note for future reader:
# To do cancellation, only the token should be interacted with as a user
# If you cancel a future, it may not work as you expect.
class LongRunningAgent(RoutedAgent):
def __init__(self) -> None:
super().__init__("A long running agent")
self.called = False
self.cancelled = False
@message_handler
async def on_new_message(self, message: MessageType, ctx: MessageContext) -> MessageType:
self.called = True
sleep = asyncio.ensure_future(asyncio.sleep(100))
ctx.cancellation_token.link_future(sleep)
try:
await sleep
return MessageType()
except asyncio.CancelledError:
self.cancelled = True
raise
class NestingLongRunningAgent(RoutedAgent):
def __init__(self, nested_agent: AgentId) -> None:
super().__init__("A nesting long running agent")
self.called = False
self.cancelled = False
self._nested_agent = nested_agent
@message_handler
async def on_new_message(self, message: MessageType, ctx: MessageContext) -> MessageType:
self.called = True
response = self.send_message(message, self._nested_agent, cancellation_token=ctx.cancellation_token)
try:
val = await response
assert isinstance(val, MessageType)
return val
except asyncio.CancelledError:
self.cancelled = True
raise
这段代码定义了一个名为 LongRunningAgent
的类,它继承自 RoutedAgent
。这个类的主要功能是处理长时间运行的任务,并支持取消操作。
主要逻辑:
-
初始化:在
__init__
方法中,设置了代理的名称,并初始化了两个标志called
和cancelled
,用于跟踪任务是否被调用和是否被取消。 -
消息处理:通过装饰器
@message_handler
定义了一个异步方法on_new_message
,用于处理接收到的消息。-
设置长时间运行任务:使用
asyncio.sleep(100)
创建一个长时间运行的任务,并将其封装为asyncio.Future
对象。 -
链接取消令牌:通过
ctx.cancellation_token.link_future(sleep)
将取消令牌与长时间运行的任务链接起来,这样当令牌被取消时,任务也会被取消。 -
等待任务完成:使用
await sleep
等待任务完成。 -
处理取消:如果任务在等待过程中被取消(即接收到
asyncio.CancelledError
),则设置cancelled
标志为True
并重新抛出异常。
-
完成的功能:
-
定义了一个能够处理长时间运行任务的代理。
-
支持通过取消令牌取消长时间运行的任务。
第二段代码:定义 NestingLongRunningAgent
类
async def test_cancellation_with_token() -> None:
runtime = SingleThreadedAgentRuntime()
await LongRunningAgent.register(runtime, "long_running", LongRunningAgent)
agent_id = AgentId("long_running", key="default")
token = CancellationToken()
response = asyncio.create_task(runtime.send_message(MessageType(), recipient=agent_id, cancellation_token=token))
assert not response.done()
while runtime.unprocessed_messages_count == 0:
await asyncio.sleep(0.01)
await runtime._process_next() # type: ignore
token.cancel()
with pytest.raises(asyncio.CancelledError):
await response
assert response.done()
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LongRunningAgent)
assert long_running_agent.called
assert long_running_agent.cancelled
await test_cancellation_with_token()
这段代码定义了一个名为 NestingLongRunningAgent
的类,它也继承自 RoutedAgent
。这个类的主要功能是嵌套调用另一个代理,并处理长时间运行的任务。
主要逻辑:
-
初始化:在
__init__
方法中,设置了代理的名称,并初始化了两个标志called
和cancelled
,同时接收一个嵌套代理的AgentId
。 -
消息处理:通过装饰器
@message_handler
定义了一个异步方法on_new_message
,用于处理接收到的消息。-
调用嵌套代理:使用
self.send_message
向嵌套代理发送消息,并将取消令牌传递给嵌套代理。 -
等待响应:使用
await response
等待嵌套代理的响应。 -
处理取消:如果嵌套代理在等待过程中被取消(即接收到
asyncio.CancelledError
),则设置cancelled
标志为True
并重新抛出异常。
-
完成的功能:
-
定义了一个能够嵌套调用另一个代理的代理。
-
支持通过取消令牌取消嵌套代理的任务。
第三段代码:测试取消功能
async def test_nested_cancellation_only_outer_called() -> None:
runtime = SingleThreadedAgentRuntime()
await LongRunningAgent.register(runtime, "long_running", LongRunningAgent)
await NestingLongRunningAgent.register(
runtime,
"nested",
lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)),
)
long_running_id = AgentId("long_running", key="default")
nested_id = AgentId("nested", key="default")
token = CancellationToken()
response = asyncio.create_task(runtime.send_message(MessageType(), nested_id, cancellation_token=token))
assert not response.done()
while runtime.unprocessed_messages_count == 0:
await asyncio.sleep(0.01)
await runtime._process_next() # type: ignore
token.cancel()
with pytest.raises(asyncio.CancelledError):
await response
assert response.done()
nested_agent = await runtime.try_get_underlying_agent_instance(nested_id, type=NestingLongRunningAgent)
assert nested_agent.called
assert nested_agent.cancelled
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running_id, type=LongRunningAgent)
assert long_running_agent.called is False
assert long_running_agent.cancelled is False
await test_nested_cancellation_only_outer_called()
这段代码定义了一个异步函数 test_cancellation_with_token
,用于测试 LongRunningAgent
的取消功能。
主要逻辑:
-
初始化运行时环境:创建一个
SingleThreadedAgentRuntime
实例。 -
注册代理:注册
LongRunningAgent
。 -
创建取消令牌:创建一个
CancellationToken
实例。 -
发送消息:使用
runtime.send_message
向代理发送消息,并将取消令牌传递给代理。 -
等待代理开始处理消息:使用
while
循环等待代理开始处理消息。 -
取消任务:调用
token.cancel()
取消任务。 -
检查结果:检查任务是否被取消,并验证代理的
called
和cancelled
标志。
完成的功能:
-
测试
LongRunningAgent
的取消功能。 -
验证代理在取消操作后,能够正确地抛出
asyncio.CancelledError
异常。 -
验证代理的
called
和cancelled
标志是否正确设置。
第四段代码:测试嵌套取消功能(仅外层调用)
async def test_nested_cancellation_inner_called() -> None:
runtime = SingleThreadedAgentRuntime()
await LongRunningAgent.register(runtime, "long_running", LongRunningAgent)
await NestingLongRunningAgent.register(
runtime,
"nested",
lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)),
)
long_running_id = AgentId("long_running", key="default")
nested_id = AgentId("nested", key="default")
token = CancellationToken()
response = asyncio.create_task(runtime.send_message(MessageType(), nested_id, cancellation_token=token))
assert not response.done()
while runtime.unprocessed_messages_count == 0:
await asyncio.sleep(0.01)
await runtime._process_next() # type: ignore
# allow the inner agent to process
await runtime._process_next() # type: ignore
token.cancel()
with pytest.raises(asyncio.CancelledError):
await response
assert response.done()
nested_agent = await runtime.try_get_underlying_agent_instance(nested_id, type=NestingLongRunningAgent)
assert nested_agent.called
assert nested_agent.cancelled
long_running_agent = await runtime.try_get_underlying_agent_instance(long_running_id, type=LongRunningAgent)
assert long_running_agent.called
assert long_running_agent.cancelled
await test_nested_cancellation_inner_called()
这段代码定义了一个异步函数 test_nested_cancellation_only_outer_called
,用于测试 NestingLongRunningAgent
的取消功能,仅取消外层代理。
主要逻辑:
-
初始化运行时环境:创建一个
SingleThreadedAgentRuntime
实例。 -
注册代理:注册
LongRunningAgent
和NestingLongRunningAgent
。 -
创建取消令牌:创建一个
CancellationToken
实例。 -
发送消息:使用
runtime.send_message
向嵌套代理发送消息,并将取消令牌传递给代理。 -
等待代理开始处理消息:使用
while
循环等待代理开始处理消息。 -
取消任务:调用
token.cancel()
取消任务。 -
检查结果:检查任务是否被取消,并验证代理的
called
和cancelled
标志。
完成的功能:
-
嵌套代理的取消传播
测试当一个父任务(NestingLongRunningAgent)被取消时,其内部的子任务(LongRunningAgent)是否也会被正确取消。-
父任务(nested)启动后触发子任务(long_running)。
-
通过 CancellationToken 取消父任务,观察子任务是否同步取消。
-
-
异步任务的状态验证
确保任务取消后,所有相关代理(Agent)的状态正确更新:-
called:标记代理是否被调用。
-
cancelled:标记代理是否被取消。
-