Skip to content

Commit f528ffe

Browse files
committed
chores
1 parent f04b8b8 commit f528ffe

File tree

3 files changed

+65
-19
lines changed

3 files changed

+65
-19
lines changed

.github/workflows/test.yml

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,12 @@ jobs:
88
runs-on: ubuntu-latest
99
steps:
1010
- uses: actions/checkout@v4
11-
12-
- name: Create test file
13-
run: echo "This is a test file for R2 verification" > test-file.txt
14-
15-
- name: Copy file to R2
16-
uses: prewk/s3-cp-action@v2
11+
- uses: prewk/s3-cp-action@v2
1712
with:
1813
aws_access_key_id: ${{ secrets.CLOUDFLARE_R2_ACCESS_KEY_ID }}
1914
aws_secret_access_key: ${{ secrets.CLOUDFLARE_R2_SECRET_ACCESS_KEY }}
20-
source: test-file.txt
21-
dest: s3://hyprnote-cache2/test-file.txt
15+
source: "s3://hyprnote-cache2/v0/binaries/stt"
16+
dest: "./stt"
2217
aws_s3_endpoint: ${{ secrets.CLOUDFLARE_R2_ENDPOINT_URL }}
2318
aws_region: auto
19+
- run: ls -l

owhisper/owhisper-server/src/commands/run/mod.rs

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ use crate::{misc::shutdown_signal, Server};
1616

1717
#[derive(clap::Parser)]
1818
pub struct RunArgs {
19+
/// Model ID from the config file
20+
#[arg(value_parser = validate_model_from_config)]
1921
pub model: String,
2022

2123
/// Audio file path, '-' for stdin, or omit for microphone
@@ -32,13 +34,6 @@ pub async fn handle_run(args: RunArgs) -> anyhow::Result<()> {
3234
log::set_max_level(log::LevelFilter::Off);
3335

3436
let config = owhisper_config::Config::new(args.config.clone())?;
35-
if !config.models.iter().any(|m| m.id() == args.model) {
36-
return Err(anyhow::anyhow!(
37-
"'{}' not found in '{:?}'",
38-
args.model,
39-
owhisper_config::global_config_path()
40-
));
41-
}
4237

4338
let api_key = config.general.as_ref().and_then(|g| g.api_key.clone());
4439
let server = Server::new(config, None);
@@ -84,6 +79,29 @@ pub async fn handle_run(args: RunArgs) -> anyhow::Result<()> {
8479
Ok(())
8580
}
8681

82+
fn validate_model_from_config(s: &str) -> Result<String, String> {
83+
let config =
84+
owhisper_config::Config::new(None).map_err(|e| format!("Failed to load config: {}", e))?;
85+
86+
let model_ids: Vec<String> = config.models.iter().map(|m| m.id().to_string()).collect();
87+
88+
if model_ids.contains(&s.to_string()) {
89+
Ok(s.to_string())
90+
} else {
91+
let available = if model_ids.is_empty() {
92+
"No models found in config".to_string()
93+
} else {
94+
format!("Available models: {}", model_ids.join(", "))
95+
};
96+
Err(format!(
97+
"'{}' not found in config at '{:?}'. {}",
98+
s,
99+
owhisper_config::global_config_path(),
100+
available
101+
))
102+
}
103+
}
104+
87105
enum InputMode {
88106
File(String),
89107
Stdin,

owhisper/owhisper-server/src/server.rs

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ impl Server {
4444
Self { config, port }
4545
}
4646

47-
pub async fn build_router(&self) -> anyhow::Result<Router> {
47+
pub async fn build_router(&self) -> anyhow::Result<Router<()>> {
4848
let api_key = self.config.general.as_ref().and_then(|g| g.api_key.clone());
4949

5050
let mut services = HashMap::new();
@@ -77,9 +77,13 @@ impl Server {
7777
let app_state = Arc::new(AppState { api_key, services });
7878

7979
let stt_router = self.build_stt_router(app_state.clone()).await;
80-
81-
let app = Router::new()
80+
let other_router = Router::new()
8281
.route("/health", axum::routing::get(health))
82+
.route("/models", axum::routing::get(list_models))
83+
.route("/v1/models", axum::routing::get(list_models))
84+
.with_state(app_state.clone());
85+
86+
let app = other_router
8387
.merge(stt_router)
8488
// .layer(middleware::from_fn_with_state(
8589
// app_state.clone(),
@@ -125,7 +129,7 @@ impl Server {
125129
Ok(addr.port())
126130
}
127131

128-
async fn build_stt_router(&self, app_state: Arc<AppState>) -> Router {
132+
async fn build_stt_router(&self, app_state: Arc<AppState>) -> Router<()> {
129133
Router::new()
130134
.route("/listen", axum::routing::any(handle_transcription))
131135
.route("/v1/listen", axum::routing::any(handle_transcription))
@@ -249,6 +253,34 @@ async fn health() -> &'static str {
249253
"OK"
250254
}
251255

256+
#[derive(serde::Serialize)]
257+
struct ModelInfo {
258+
id: String,
259+
object: String,
260+
}
261+
262+
#[derive(serde::Serialize)]
263+
struct ModelsResponse {
264+
object: String,
265+
data: Vec<ModelInfo>,
266+
}
267+
268+
async fn list_models(State(state): State<Arc<AppState>>) -> axum::Json<ModelsResponse> {
269+
let models: Vec<ModelInfo> = state
270+
.services
271+
.keys()
272+
.map(|id| ModelInfo {
273+
id: id.clone(),
274+
object: "model".to_string(),
275+
})
276+
.collect();
277+
278+
axum::Json(ModelsResponse {
279+
object: "list".to_string(),
280+
data: models,
281+
})
282+
}
283+
252284
async fn auth_middleware(
253285
State(state): State<Arc<AppState>>,
254286
token_header: Option<TypedHeader<Authorization<Token>>>,

0 commit comments

Comments
 (0)