-
Notifications
You must be signed in to change notification settings - Fork 617
/
Copy pathdemo_agent.py
118 lines (101 loc) · 4.29 KB
/
demo_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
# Set CUDA_VISIBLE_DEVICES to '0' at import time, before the swift imports under
# the `__main__` guard are executed.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['SWIFT_DEBUG'] = '1'
def infer(engine: 'InferEngine', infer_request: 'InferRequest'):
    """Run a two-round, non-streaming tool-calling demo and print every step.

    Round one sends ``infer_request`` and prints the user query, the model's
    reply and the ``tool_calls`` attribute of the reply message. A hard-coded
    weather JSON string is then appended to the conversation as the tool
    result (after the assistant reply), and round two prints the follow-up
    reply.

    Args:
        engine: Inference engine exposing ``infer`` and
            ``default_template.agent_template.keyword.observation``.
        infer_request: Request whose first message holds the user query.
            Its ``messages`` list is extended in place.
    """
    # Stop at the agent template's observation keyword (compat react_en).
    observation_keyword = engine.default_template.agent_template.keyword.observation
    config = RequestConfig(max_tokens=512, temperature=0, stop=[observation_keyword])

    # Round 1: let the model answer the user query.
    round_one = engine.infer([infer_request], config)
    user_query = infer_request.messages[0]['content']
    first_message = round_one[0].choices[0].message
    first_reply = first_message.content
    print(f'query: {user_query}')
    print(f'response: {first_reply}')
    print(f'tool_calls: {first_message.tool_calls}')

    # Feed a canned tool result back into the conversation.
    tool_output = '{"temperature": 32, "condition": "Sunny", "humidity": 50}'
    print(f'tool_response: {tool_output}')
    assistant_turn = {'role': 'assistant', 'content': first_reply}
    tool_turn = {'role': 'tool', 'content': tool_output}
    infer_request.messages += [assistant_turn, tool_turn]

    # Round 2: let the model respond to the tool result.
    second_reply = engine.infer([infer_request], config)[0].choices[0].message.content
    print(f'response2: {second_reply}')
def infer_stream(engine: 'InferEngine', infer_request: 'InferRequest'):
    """Run a two-round, streaming tool-calling demo and print output as it arrives.

    Round one streams the model's reply to ``infer_request`` chunk by chunk,
    then prints the ``tool_calls`` attribute of the last non-``None`` chunk.
    A hard-coded weather JSON string is appended to the conversation as the
    tool result (after the accumulated assistant reply), and round two
    streams the follow-up reply.

    Args:
        engine: Inference engine exposing ``infer`` and
            ``default_template.agent_template.keyword.observation``.
        infer_request: Request whose first message holds the user query.
            Its ``messages`` list is extended in place.
    """
    stop = [engine.default_template.agent_template.keyword.observation]
    request_config = RequestConfig(max_tokens=512, temperature=0, stream=True, stop=stop)
    gen_list = engine.infer([infer_request], request_config)
    query = infer_request.messages[0]['content']
    chunks = []
    # Track the last real chunk explicitly: the loop variable may end up as
    # None (skipped items) or unbound (empty stream) once the loop finishes.
    last_resp = None
    print(f'query: {query}\nresponse: ', end='')
    for resp in gen_list[0]:
        if resp is None:
            continue
        last_resp = resp
        delta = resp.choices[0].delta.content
        if not delta:
            # A chunk without text must not break the accumulation.
            continue
        chunks.append(delta)
        print(delta, end='', flush=True)
    print()
    response = ''.join(chunks)
    tool_calls = last_resp.choices[0].delta.tool_calls if last_resp is not None else None
    print(f'tool_calls: {tool_calls}')

    tool = '{"temperature": 32, "condition": "Sunny", "humidity": 50}'
    print(f'tool_response: {tool}\nresponse2: ', end='')
    infer_request.messages += [{'role': 'assistant', 'content': response}, {'role': 'tool', 'content': tool}]
    gen_list = engine.infer([infer_request], request_config)
    for resp in gen_list[0]:
        if resp is None:
            continue
        delta = resp.choices[0].delta.content
        if delta:
            # Skip empty chunks so the literal text "None" is never printed.
            print(delta, end='', flush=True)
    print()
def get_infer_request():
    """Build the demo request: one user weather question plus one tool schema.

    Returns:
        An ``InferRequest`` holding a single user message and the
        ``get_current_weather`` tool definition, whose only required
        parameter is ``location``.
    """
    location_schema = {
        'type': 'string',
        'description': 'The city and state, e.g. San Francisco, CA',
    }
    unit_schema = {
        'type': 'string',
        'enum': ['celsius', 'fahrenheit'],
    }
    weather_tool = {
        'name': 'get_current_weather',
        'description': 'Get the current weather in a given location',
        'parameters': {
            'type': 'object',
            'properties': {
                'location': location_schema,
                'unit': unit_schema,
            },
            'required': ['location'],
        },
    }
    user_message = {'role': 'user', 'content': "How's the weather in Beijing today?"}
    return InferRequest(messages=[user_message], tools=[weather_tool])
def infer_continue_generate(engine):
    """Send a conversation with a partial assistant reply and print the response.

    The request ends with an assistant message containing the start of an
    answer, followed by an assistant message whose content is ``None``.

    Args:
        engine: Inference engine exposing ``infer``.
    """
    # Continue generating after the assistant message.
    conversation = [
        {'role': 'user', 'content': 'How is the weather today?'},
        {'role': 'assistant', 'content': 'It is sunny today, '},
        {'role': 'assistant', 'content': None},
    ]
    continue_request = InferRequest(messages=conversation)
    config = RequestConfig(max_tokens=512, temperature=0)
    first_choice = engine.infer([continue_request], config)[0].choices[0]
    response = first_choice.message.content
    print(f'response: {response}')
if __name__ == '__main__':
    # These imports are kept at module level inside the guard (not in a
    # function) so that the demo functions in this file can resolve
    # RequestConfig and InferRequest as module globals.
    from swift.llm import InferEngine, InferRequest, PtEngine, RequestConfig
    from swift.plugin import agent_templates
    model = 'Qwen/Qwen2.5-1.5B-Instruct'
    infer_backend = 'pt'  # supported values: 'pt', 'vllm', 'lmdeploy'

    if infer_backend == 'pt':
        engine = PtEngine(model, max_batch_size=64)
    elif infer_backend == 'vllm':
        from swift.llm import VllmEngine
        engine = VllmEngine(model, max_model_len=8192)
    elif infer_backend == 'lmdeploy':
        from swift.llm import LmdeployEngine
        engine = LmdeployEngine(model)
    else:
        # Fail early with a clear message instead of a NameError on `engine` below.
        raise ValueError(f'Unsupported infer_backend: {infer_backend!r}')

    # agent_template = agent_templates['hermes']()  # react_en/qwen_en/qwen_en_parallel
    # engine.default_template.agent_template = agent_template

    # Each demo gets a fresh request because the demos extend `messages` in place.
    infer(engine, get_infer_request())
    infer_stream(engine, get_infer_request())

    # infer_continue_generate(engine)