Skip to content

Commit 17cfe20

Browse files
hjlarryjiangzhijie
authored and
jiangzhijie
committed
feat: add the audio tool (langgenius#10695)
1 parent de6065b commit 17cfe20

File tree

7 files changed

+224
-0
lines changed

7 files changed

+224
-0
lines changed
Loading
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
2+
3+
4+
class AudioToolProvider(BuiltinToolProviderController):
5+
def _validate_credentials(self, credentials: dict) -> None:
6+
pass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
identity:
2+
author: hjlarry
3+
name: audio
4+
label:
5+
en_US: Audio
6+
description:
7+
en_US: A tool for tts and asr.
8+
zh_Hans: 一个用于文本转语音和语音转文本的工具。
9+
icon: icon.svg
10+
tags:
11+
- utilities
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import io
2+
from typing import Any
3+
4+
from core.file.enums import FileType
5+
from core.file.file_manager import download
6+
from core.model_manager import ModelManager
7+
from core.model_runtime.entities.model_entities import ModelType
8+
from core.tools.entities.common_entities import I18nObject
9+
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
10+
from core.tools.tool.builtin_tool import BuiltinTool
11+
from services.model_provider_service import ModelProviderService
12+
13+
14+
class ASRTool(BuiltinTool):
15+
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
16+
file = tool_parameters.get("audio_file")
17+
if file.type != FileType.AUDIO:
18+
return [self.create_text_message("not a valid audio file")]
19+
audio_binary = io.BytesIO(download(file))
20+
audio_binary.name = "temp.mp3"
21+
provider, model = tool_parameters.get("model").split("#")
22+
model_manager = ModelManager()
23+
model_instance = model_manager.get_model_instance(
24+
tenant_id=self.runtime.tenant_id,
25+
provider=provider,
26+
model_type=ModelType.SPEECH2TEXT,
27+
model=model,
28+
)
29+
text = model_instance.invoke_speech2text(
30+
file=audio_binary,
31+
user=user_id,
32+
)
33+
return [self.create_text_message(text)]
34+
35+
def get_available_models(self) -> list[tuple[str, str]]:
36+
model_provider_service = ModelProviderService()
37+
models = model_provider_service.get_models_by_model_type(
38+
tenant_id=self.runtime.tenant_id, model_type="speech2text"
39+
)
40+
items = []
41+
for provider_model in models:
42+
provider = provider_model.provider
43+
for model in provider_model.models:
44+
items.append((provider, model.model))
45+
return items
46+
47+
def get_runtime_parameters(self) -> list[ToolParameter]:
48+
parameters = []
49+
50+
options = []
51+
for provider, model in self.get_available_models():
52+
option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
53+
options.append(option)
54+
55+
parameters.append(
56+
ToolParameter(
57+
name="model",
58+
label=I18nObject(en_US="Model", zh_Hans="Model"),
59+
human_description=I18nObject(
60+
en_US="All available ASR models",
61+
zh_Hans="所有可用的 ASR 模型",
62+
),
63+
type=ToolParameter.ToolParameterType.SELECT,
64+
form=ToolParameter.ToolParameterForm.FORM,
65+
required=True,
66+
default=options[0].value,
67+
options=options,
68+
)
69+
)
70+
return parameters
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
identity:
2+
name: asr
3+
author: hjlarry
4+
label:
5+
en_US: Speech To Text
6+
description:
7+
human:
8+
en_US: Convert audio file to text.
9+
zh_Hans: 将音频文件转换为文本。
10+
llm: Convert audio file to text.
11+
parameters:
12+
- name: audio_file
13+
type: file
14+
required: true
15+
label:
16+
en_US: Audio File
17+
zh_Hans: 音频文件
18+
human_description:
19+
en_US: The audio file to be converted.
20+
zh_Hans: 要转换的音频文件。
21+
llm_description: The audio file to be converted.
22+
form: llm
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import io
2+
from typing import Any
3+
4+
from core.model_manager import ModelManager
5+
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
6+
from core.tools.entities.common_entities import I18nObject
7+
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
8+
from core.tools.tool.builtin_tool import BuiltinTool
9+
from services.model_provider_service import ModelProviderService
10+
11+
12+
class TTSTool(BuiltinTool):
13+
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
14+
provider, model = tool_parameters.get("model").split("#")
15+
voice = tool_parameters.get(f"voice#{provider}#{model}")
16+
model_manager = ModelManager()
17+
model_instance = model_manager.get_model_instance(
18+
tenant_id=self.runtime.tenant_id,
19+
provider=provider,
20+
model_type=ModelType.TTS,
21+
model=model,
22+
)
23+
tts = model_instance.invoke_tts(
24+
content_text=tool_parameters.get("text"),
25+
user=user_id,
26+
tenant_id=self.runtime.tenant_id,
27+
voice=voice,
28+
)
29+
buffer = io.BytesIO()
30+
for chunk in tts:
31+
buffer.write(chunk)
32+
33+
wav_bytes = buffer.getvalue()
34+
return [
35+
self.create_text_message("Audio generated successfully"),
36+
self.create_blob_message(
37+
blob=wav_bytes,
38+
meta={"mime_type": "audio/x-wav"},
39+
save_as=self.VariableKey.AUDIO,
40+
),
41+
]
42+
43+
def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
44+
model_provider_service = ModelProviderService()
45+
models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
46+
items = []
47+
for provider_model in models:
48+
provider = provider_model.provider
49+
for model in provider_model.models:
50+
voices = model.model_properties.get(ModelPropertyKey.VOICES, [])
51+
items.append((provider, model.model, voices))
52+
return items
53+
54+
def get_runtime_parameters(self) -> list[ToolParameter]:
55+
parameters = []
56+
57+
options = []
58+
for provider, model, voices in self.get_available_models():
59+
option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
60+
options.append(option)
61+
parameters.append(
62+
ToolParameter(
63+
name=f"voice#{provider}#{model}",
64+
label=I18nObject(en_US=f"Voice of {model}({provider})"),
65+
type=ToolParameter.ToolParameterType.SELECT,
66+
form=ToolParameter.ToolParameterForm.FORM,
67+
options=[
68+
ToolParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name")))
69+
for voice in voices
70+
],
71+
)
72+
)
73+
74+
parameters.insert(
75+
0,
76+
ToolParameter(
77+
name="model",
78+
label=I18nObject(en_US="Model", zh_Hans="Model"),
79+
human_description=I18nObject(
80+
en_US="All available TTS models",
81+
zh_Hans="所有可用的 TTS 模型",
82+
),
83+
type=ToolParameter.ToolParameterType.SELECT,
84+
form=ToolParameter.ToolParameterForm.FORM,
85+
required=True,
86+
default=options[0].value,
87+
options=options,
88+
),
89+
)
90+
return parameters
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
identity:
2+
name: tts
3+
author: hjlarry
4+
label:
5+
en_US: Text To Speech
6+
description:
7+
human:
8+
en_US: Convert text to audio file.
9+
zh_Hans: 将文本转换为音频文件。
10+
llm: Convert text to audio file.
11+
parameters:
12+
- name: text
13+
type: string
14+
required: true
15+
label:
16+
en_US: Text
17+
zh_Hans: 文本
18+
human_description:
19+
en_US: The text to be converted.
20+
zh_Hans: 要转换的文本。
21+
llm_description: The text to be converted.
22+
form: llm

0 commit comments

Comments
 (0)