use axum::{ extract::{State, WebSocketUpgrade}, response::IntoResponse, }; use axum::extract::ws::{Message, WebSocket}; use futures_util::{SinkExt, StreamExt}; use tokio::sync::mpsc; use uuid::Uuid; use tracing::{debug, error, info, warn}; use helios_common::protocol::{ClientMessage, ServerMessage}; use crate::{AppState, session::Session}; pub async fn ws_upgrade( ws: WebSocketUpgrade, State(state): State, ) -> impl IntoResponse { ws.on_upgrade(move |socket| handle_socket(socket, state)) } async fn handle_socket(socket: WebSocket, state: AppState) { let session_id = Uuid::new_v4(); let (cmd_tx, mut cmd_rx) = mpsc::channel::(64); // Register session (label filled in on Hello) let session = Session { id: session_id, label: None, cmd_tx, }; state.sessions.insert(session); info!("Client connected: session={session_id}"); let (mut ws_tx, mut ws_rx) = socket.split(); // Spawn task: forward server commands → WS let send_task = tokio::spawn(async move { while let Some(msg) = cmd_rx.recv().await { match serde_json::to_string(&msg) { Ok(json) => { if let Err(e) = ws_tx.send(Message::Text(json.into())).await { error!("WS send error for session={session_id}: {e}"); break; } } Err(e) => { error!("Serialization error for session={session_id}: {e}"); } } } }); // Main loop: receive client messages while let Some(result) = ws_rx.next().await { match result { Ok(Message::Text(text)) => { match serde_json::from_str::(&text) { Ok(msg) => handle_client_message(session_id, msg, &state).await, Err(e) => { warn!("Invalid JSON from session={session_id}: {e} | raw={text}"); } } } Ok(Message::Close(_)) => { info!("Client disconnected gracefully: session={session_id}"); break; } Ok(Message::Ping(_)) | Ok(Message::Pong(_)) | Ok(Message::Binary(_)) => {} Err(e) => { error!("WS receive error for session={session_id}: {e}"); break; } } } send_task.abort(); state.sessions.remove(&session_id); info!("Session cleaned up: session={session_id}"); } async fn handle_client_message(session_id: Uuid, msg: ClientMessage, state: &AppState) { match &msg { ClientMessage::Hello { label } => { if let Some(lbl) = label { state.sessions.set_label(&session_id, lbl.clone()); } debug!("Hello from session={session_id}, label={label:?}"); } ClientMessage::ScreenshotResponse { request_id, .. } | ClientMessage::ExecResponse { request_id, .. } | ClientMessage::ListWindowsResponse { request_id, .. } | ClientMessage::Ack { request_id } | ClientMessage::Error { request_id, .. } => { let rid = *request_id; if !state.sessions.resolve_pending(rid, msg) { warn!("No pending request for request_id={rid} (session={session_id})"); } } } }