use axum::{ extract::{State, WebSocketUpgrade}, response::IntoResponse, }; use axum::extract::ws::{Message, WebSocket}; use futures_util::{SinkExt, StreamExt}; use tokio::sync::mpsc; use uuid::Uuid; use tracing::{debug, error, info, warn}; use helios_common::protocol::{ClientMessage, ServerMessage}; use crate::{AppState, session::Session}; pub async fn ws_upgrade( ws: WebSocketUpgrade, State(state): State, ) -> impl IntoResponse { ws.on_upgrade(move |socket| handle_socket(socket, state)) } async fn handle_socket(socket: WebSocket, state: AppState) { let (cmd_tx, mut cmd_rx) = mpsc::channel::(64); let (mut ws_tx, mut ws_rx) = socket.split(); // Wait for the Hello message to get the device label let label = loop { match ws_rx.next().await { Some(Ok(Message::Text(text))) => { match serde_json::from_str::(&text) { Ok(ClientMessage::Hello { label }) => { if label.is_empty() { warn!("Client sent empty label, disconnecting"); return; } break label; } Ok(_) => { warn!("Expected Hello as first message, got something else"); return; } Err(e) => { warn!("Invalid JSON on handshake: {e}"); return; } } } Some(Ok(Message::Close(_))) | None => return, _ => continue, } }; // Register session by label let session = Session { label: label.clone(), cmd_tx, }; state.sessions.insert(session); info!("Client connected: device={label}"); // Spawn task: forward server commands → WS let label_clone = label.clone(); let send_task = tokio::spawn(async move { while let Some(msg) = cmd_rx.recv().await { match serde_json::to_string(&msg) { Ok(json) => { if let Err(e) = ws_tx.send(Message::Text(json.into())).await { error!("WS send error for device={label_clone}: {e}"); break; } } Err(e) => { error!("Serialization error for device={label_clone}: {e}"); } } } }); // Main loop: receive client messages while let Some(result) = ws_rx.next().await { match result { Ok(Message::Text(text)) => { match serde_json::from_str::(&text) { Ok(msg) => handle_client_message(&label, msg, &state).await, Err(e) => { warn!("Invalid JSON from device={label}: {e}"); } } } Ok(Message::Close(_)) => { info!("Client disconnected gracefully: device={label}"); break; } Ok(Message::Ping(_)) | Ok(Message::Pong(_)) | Ok(Message::Binary(_)) => {} Err(e) => { error!("WS receive error for device={label}: {e}"); break; } } } send_task.abort(); state.sessions.remove(&label); info!("Session cleaned up: device={label}"); } async fn handle_client_message(label: &str, msg: ClientMessage, state: &AppState) { match &msg { ClientMessage::Hello { .. } => { debug!("Duplicate Hello from device={label}, ignoring"); } ClientMessage::ScreenshotResponse { request_id, .. } | ClientMessage::ExecResponse { request_id, .. } | ClientMessage::ListWindowsResponse { request_id, .. } | ClientMessage::VersionResponse { request_id, .. } | ClientMessage::LogsResponse { request_id, .. } | ClientMessage::DownloadResponse { request_id, .. } | ClientMessage::ClipboardGetResponse { request_id, .. } | ClientMessage::PromptResponse { request_id, .. } | ClientMessage::Ack { request_id } | ClientMessage::Error { request_id, .. } => { let rid = *request_id; if !state.sessions.resolve_pending(rid, msg) { warn!("No pending request for request_id={rid} (device={label})"); } } } }