100 lines
3.4 KiB
Rust
100 lines
3.4 KiB
Rust
use axum::{
|
|
extract::{State, WebSocketUpgrade},
|
|
response::IntoResponse,
|
|
};
|
|
use axum::extract::ws::{Message, WebSocket};
|
|
use futures_util::{SinkExt, StreamExt};
|
|
use tokio::sync::mpsc;
|
|
use uuid::Uuid;
|
|
use tracing::{debug, error, info, warn};
|
|
|
|
use helios_common::protocol::{ClientMessage, ServerMessage};
|
|
use crate::{AppState, session::Session};
|
|
|
|
pub async fn ws_upgrade(
|
|
ws: WebSocketUpgrade,
|
|
State(state): State<AppState>,
|
|
) -> impl IntoResponse {
|
|
ws.on_upgrade(move |socket| handle_socket(socket, state))
|
|
}
|
|
|
|
async fn handle_socket(socket: WebSocket, state: AppState) {
|
|
let session_id = Uuid::new_v4();
|
|
let (cmd_tx, mut cmd_rx) = mpsc::channel::<ServerMessage>(64);
|
|
|
|
// Register session (label filled in on Hello)
|
|
let session = Session {
|
|
id: session_id,
|
|
label: None,
|
|
cmd_tx,
|
|
};
|
|
state.sessions.insert(session);
|
|
info!("Client connected: session={session_id}");
|
|
|
|
let (mut ws_tx, mut ws_rx) = socket.split();
|
|
|
|
// Spawn task: forward server commands → WS
|
|
let send_task = tokio::spawn(async move {
|
|
while let Some(msg) = cmd_rx.recv().await {
|
|
match serde_json::to_string(&msg) {
|
|
Ok(json) => {
|
|
if let Err(e) = ws_tx.send(Message::Text(json.into())).await {
|
|
error!("WS send error for session={session_id}: {e}");
|
|
break;
|
|
}
|
|
}
|
|
Err(e) => {
|
|
error!("Serialization error for session={session_id}: {e}");
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
// Main loop: receive client messages
|
|
while let Some(result) = ws_rx.next().await {
|
|
match result {
|
|
Ok(Message::Text(text)) => {
|
|
match serde_json::from_str::<ClientMessage>(&text) {
|
|
Ok(msg) => handle_client_message(session_id, msg, &state).await,
|
|
Err(e) => {
|
|
warn!("Invalid JSON from session={session_id}: {e} | raw={text}");
|
|
}
|
|
}
|
|
}
|
|
Ok(Message::Close(_)) => {
|
|
info!("Client disconnected gracefully: session={session_id}");
|
|
break;
|
|
}
|
|
Ok(Message::Ping(_)) | Ok(Message::Pong(_)) | Ok(Message::Binary(_)) => {}
|
|
Err(e) => {
|
|
error!("WS receive error for session={session_id}: {e}");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
send_task.abort();
|
|
state.sessions.remove(&session_id);
|
|
info!("Session cleaned up: session={session_id}");
|
|
}
|
|
|
|
async fn handle_client_message(session_id: Uuid, msg: ClientMessage, state: &AppState) {
|
|
match &msg {
|
|
ClientMessage::Hello { label } => {
|
|
if let Some(lbl) = label {
|
|
state.sessions.set_label(&session_id, lbl.clone());
|
|
}
|
|
debug!("Hello from session={session_id}, label={label:?}");
|
|
}
|
|
ClientMessage::ScreenshotResponse { request_id, .. }
|
|
| ClientMessage::ExecResponse { request_id, .. }
|
|
| ClientMessage::ListWindowsResponse { request_id, .. }
|
|
| ClientMessage::Ack { request_id }
|
|
| ClientMessage::Error { request_id, .. } => {
|
|
let rid = *request_id;
|
|
if !state.sessions.resolve_pending(rid, msg) {
|
|
warn!("No pending request for request_id={rid} (session={session_id})");
|
|
}
|
|
}
|
|
}
|
|
}
|