commit 7285a33cff04347be3820769197bbbce38824a73 Author: Helios Date: Mon Mar 2 18:03:46 2026 +0100 Initial implementation: relay server + common protocol + client stub diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..13da6f8 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,25 @@ +name: CI + +on: + push: + branches: ["main", "master"] + pull_request: + +jobs: + build-and-test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install Rust (stable) + uses: dtolnay/rust-toolchain@stable + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + + - name: Build + run: cargo build --workspace --verbose + + - name: Test + run: cargo test --workspace --verbose diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6e65eaf --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +/target +Cargo.lock +**/*.rs.bk +.env +*.pdb diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..40cb1dd --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,7 @@ +[workspace] +members = [ + "crates/common", + "crates/server", + "crates/client", +] +resolver = "2" diff --git a/README.md b/README.md new file mode 100644 index 0000000..00f512a --- /dev/null +++ b/README.md @@ -0,0 +1,110 @@ +# helios-remote + +**AI-first remote control tool** — a relay server + Windows client written in Rust. Lets an AI agent (or any HTTP client) take full control of a remote Windows machine via a lightweight WebSocket relay. 
+ +## Architecture + +``` +helios-remote/ +├── crates/ +│ ├── common/ # Shared protocol types, WebSocket message definitions +│ ├── server/ # Relay server (REST API + WebSocket hub) +│ └── client/ # Windows client — Phase 2 (stub only) +├── Cargo.toml # Workspace root +└── README.md +``` + +### How It Works + +``` +AI Agent + │ REST API (X-Api-Key) + ▼ +helios-server ──WebSocket── helios-client (Windows) + │ │ +POST /sessions/:id/screenshot │ Captures screen → base64 PNG +POST /sessions/:id/exec │ Runs command in persistent shell +POST /sessions/:id/click │ Simulates mouse click +POST /sessions/:id/type │ Types text +``` + +1. The **Windows client** connects to the relay server via WebSocket and sends a `Hello` message. +2. The **AI agent** calls the REST API to issue commands. +3. The relay server forwards commands to the correct client session and streams back responses. + +## Server + +### REST API + +All endpoints require the `X-Api-Key` header. + +| Method | Path | Description | +|---|---|---| +| `GET` | `/sessions` | List all connected clients | +| `POST` | `/sessions/:id/screenshot` | Request a screenshot (returns base64 PNG) | +| `POST` | `/sessions/:id/exec` | Execute a shell command | +| `POST` | `/sessions/:id/click` | Simulate a mouse click | +| `POST` | `/sessions/:id/type` | Type text | +| `POST` | `/sessions/:id/label` | Rename a session | + +### WebSocket + +Clients connect to `ws://host:3000/ws`. No auth required at the transport layer — the server trusts all WS connections as client agents. 
+
+### Running the Server
+
+```bash
+HELIOS_API_KEY=your-secret-key HELIOS_BIND=0.0.0.0:3000 cargo run -p helios-server
+```
+
+Environment variables:
+
+| Variable | Default | Description |
+|---|---|---|
+| `HELIOS_API_KEY` | `dev-secret` | API key for REST endpoints |
+| `HELIOS_BIND` | `0.0.0.0:3000` | Listen address |
+| `RUST_LOG` | `helios_server=debug,tower_http=info` | Log level |
+
+### Example API Usage
+
+```bash
+# List sessions
+curl -H "X-Api-Key: your-secret-key" http://localhost:3000/sessions
+
+# Take a screenshot
+curl -s -X POST -H "X-Api-Key: your-secret-key" \
+  http://localhost:3000/sessions/<session-id>/screenshot
+
+# Run a command
+curl -s -X POST -H "X-Api-Key: your-secret-key" \
+  -H "Content-Type: application/json" \
+  -d '{"command": "whoami"}' \
+  http://localhost:3000/sessions/<session-id>/exec
+
+# Click at coordinates
+curl -s -X POST -H "X-Api-Key: your-secret-key" \
+  -H "Content-Type: application/json" \
+  -d '{"x": 100, "y": 200, "button": "left"}' \
+  http://localhost:3000/sessions/<session-id>/click
+```
+
+## Client (Phase 2)
+
+See [`crates/client/README.md`](crates/client/README.md) for the planned Windows client implementation.
+
+## Development
+
+```bash
+# Build everything
+cargo build
+
+# Run tests
+cargo test
+
+# Run server in dev mode
+RUST_LOG=debug cargo run -p helios-server
+```
+
+## License
+
+MIT
diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml
new file mode 100644
index 0000000..c1b8ecb
--- /dev/null
+++ b/crates/client/Cargo.toml
@@ -0,0 +1,10 @@
+[package]
+name = "helios-client"
+version = "0.1.0"
+edition = "2021"
+
+# Phase 2 — Windows client (not yet implemented)
+# See README.md in this crate for the planned implementation.
+
+[dependencies]
+helios-common = { path = "../common" }
diff --git a/crates/client/README.md b/crates/client/README.md
new file mode 100644
index 0000000..3089c3e
--- /dev/null
+++ b/crates/client/README.md
@@ -0,0 +1,32 @@
+# helios-client (Phase 2 — not yet implemented)
+
+This crate will contain the Windows remote-control client for `helios-remote`.
+
+## Planned Features
+
+- Connects to the relay server via WebSocket (`wss://`)
+- Sends a `Hello` message on connect with an optional display label
+- Handles incoming `ServerMessage` commands:
+  - `ScreenshotRequest` → captures the primary display (Windows GDI or `windows-capture`) and responds with base64 PNG
+  - `ExecRequest` → runs a shell command in a persistent `cmd.exe` / PowerShell session and returns stdout/stderr/exit-code
+  - `ClickRequest` → simulates a mouse click via `SendInput` Win32 API
+  - `TypeRequest` → types text via `SendInput` (virtual key events)
+- Persistent shell session so `cd C:\Users` persists across `exec` calls
+- Auto-reconnect with exponential backoff
+- Configurable via environment variables or a `client.toml` config file
+
+## Planned Tech Stack
+
+| Crate | Purpose |
+|---|---|
+| `tokio` | Async runtime |
+| `tokio-tungstenite` | WebSocket client |
+| `serde_json` | Protocol serialization |
+| `windows` / `winapi` | Screen capture, mouse/keyboard input |
+| `base64` | PNG encoding for screenshots |
+
+## Build Target
+
+```
+cargo build --target x86_64-pc-windows-gnu
+```
diff --git a/crates/client/src/main.rs b/crates/client/src/main.rs
new file mode 100644
index 0000000..8740dca
--- /dev/null
+++ b/crates/client/src/main.rs
@@ -0,0 +1,7 @@
+// helios-client — Phase 2 (not yet implemented)
+// See crates/client/README.md for the planned implementation.
+
+fn main() {
+    eprintln!("helios-client is not yet implemented. See crates/client/README.md.");
+    std::process::exit(1);
+}
diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml
new file mode 100644
index 0000000..23ee522
--- /dev/null
+++ b/crates/common/Cargo.toml
@@ -0,0 +1,9 @@
+[package]
+name = "helios-common"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+serde = { version = "1", features = ["derive"] }
+serde_json = "1"
+uuid = { version = "1", features = ["v4", "serde"] }
diff --git a/crates/common/src/error.rs b/crates/common/src/error.rs
new file mode 100644
index 0000000..7146e95
--- /dev/null
+++ b/crates/common/src/error.rs
@@ -0,0 +1,37 @@
+use std::fmt;
+
+/// Errors shared by the helios-remote crates.
+#[derive(Debug)]
+pub enum HeliosError {
+    /// WebSocket protocol error
+    Protocol(String),
+    /// JSON serialization/deserialization error
+    Serialization(String),
+    /// Session not found
+    SessionNotFound(String),
+    /// Request timed out waiting for client response
+    Timeout(String),
+    /// Generic internal error
+    Internal(String),
+}
+
+impl fmt::Display for HeliosError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            HeliosError::Protocol(msg) => write!(f, "Protocol error: {msg}"),
+            HeliosError::Serialization(msg) => write!(f, "Serialization error: {msg}"),
+            HeliosError::SessionNotFound(id) => write!(f, "Session not found: {id}"),
+            HeliosError::Timeout(msg) => write!(f, "Request timed out: {msg}"),
+            HeliosError::Internal(msg) => write!(f, "Internal error: {msg}"),
+        }
+    }
+}
+
+impl std::error::Error for HeliosError {}
+
+/// Lets `?` turn a JSON (de)serialization failure into [`HeliosError::Serialization`].
+impl From<serde_json::Error> for HeliosError {
+    fn from(e: serde_json::Error) -> Self {
+        HeliosError::Serialization(e.to_string())
+    }
+}
diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs
new file mode 100644
index 0000000..3402b30
--- /dev/null
+++ b/crates/common/src/lib.rs
@@ -0,0 +1,5 @@
+pub mod protocol;
+pub mod error;
+
+pub use protocol::*;
+pub use error::*;
diff --git a/crates/common/src/protocol.rs b/crates/common/src/protocol.rs
new file mode 100644
index 0000000..98db984
--- /dev/null
+++ b/crates/common/src/protocol.rs
@@ -0,0 +1,120 @@
+use serde::{Deserialize, Serialize};
+use uuid::Uuid;
+
+/// Messages sent from the relay server to a connected client
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ServerMessage {
+    /// Request a screenshot from the client
+    ScreenshotRequest { request_id: Uuid },
+    /// Execute a shell command on the client
+    ExecRequest {
+        request_id: Uuid,
+        command: String,
+    },
+    /// Simulate a mouse click
+    ClickRequest {
+        request_id: Uuid,
+        x: i32,
+        y: i32,
+        button: MouseButton,
+    },
+    /// Type text on the client
+    TypeRequest {
+        request_id: Uuid,
+        text: String,
+    },
+    /// Acknowledge a client message
+    Ack { request_id: Uuid },
+    /// Server-side error response
+    Error {
+        request_id: Option<Uuid>,
+        message: String,
+    },
+}
+
+/// Messages sent from the client to the relay server
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ClientMessage {
+    /// Client registers itself with optional display name
+    Hello { label: Option<String> },
+    /// Response to a screenshot request — base64-encoded PNG
+    ScreenshotResponse {
+        request_id: Uuid,
+        image_base64: String,
+        width: u32,
+        height: u32,
+    },
+    /// Response to an exec request
+    ExecResponse {
+        request_id: Uuid,
+        stdout: String,
+        stderr: String,
+        exit_code: i32,
+    },
+    /// Generic acknowledgement for click/type
+    Ack { request_id: Uuid },
+    /// Client error response
+    Error {
+        request_id: Uuid,
+        message: String,
+    },
+}
+
+/// Mouse button variants (serialized in lowercase, e.g. `"left"`)
+#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum MouseButton {
+    /// Used when a request omits the button
+    #[default]
+    Left,
+    Right,
+    Middle,
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_server_message_serialization() {
+        let msg = ServerMessage::ExecRequest {
+            request_id: Uuid::nil(),
+            command: "echo hello".into(),
+        };
+        let json = serde_json::to_string(&msg).unwrap();
+        assert!(json.contains("exec_request"));
+        assert!(json.contains("echo hello"));
+    }
+
+    #[test]
+    fn test_client_message_serialization() {
+        let msg = ClientMessage::Hello { label: Some("test-pc".into()) };
+        let json = serde_json::to_string(&msg).unwrap();
+        assert!(json.contains("hello"));
+        assert!(json.contains("test-pc"));
+    }
+
+    #[test]
+    fn test_roundtrip() {
+        let msg = ClientMessage::ExecResponse {
+            request_id: Uuid::nil(),
+            stdout: "hello\n".into(),
+            stderr: String::new(),
+            exit_code: 0,
+        };
+        let json = serde_json::to_string(&msg).unwrap();
+        let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
+        match decoded {
+            ClientMessage::ExecResponse { exit_code, .. } => assert_eq!(exit_code, 0),
+            _ => panic!("wrong variant"),
+        }
+    }
+
+    #[test]
+    fn test_mouse_button_default_and_wire_format() {
+        assert_eq!(MouseButton::default(), MouseButton::Left);
+        assert_eq!(serde_json::to_string(&MouseButton::Middle).unwrap(), "\"middle\"");
+    }
+}
diff --git a/crates/server/Cargo.toml b/crates/server/Cargo.toml
new file mode 100644
index 0000000..14a4f8b
--- /dev/null
+++ b/crates/server/Cargo.toml
@@ -0,0 +1,24 @@
+[package]
+name = "helios-server"
+version = "0.1.0"
+edition = "2021"
+
+[[bin]]
+name = "helios-server"
+path = "src/main.rs"
+
+[dependencies]
+helios-common = { path = "../common" }
+tokio = { version = "1", features = ["full"] }
+axum = { version = "0.7", features = ["ws"] }
+tower = "0.4"
+tower-http = { version = "0.5", features = ["trace"] }
+serde = { version = "1", features = ["derive"] }
+serde_json = "1"
+uuid = { version = "1", features = ["v4", "serde"] }
+tracing = "0.1"
+tracing-subscriber = { version = "0.3", features = ["env-filter"] }
+tokio-tungstenite = "0.21"
+futures-util = "0.3"
+dashmap = "5"
+anyhow = "1"
diff --git a/crates/server/src/api.rs b/crates/server/src/api.rs
new file mode 100644
index 0000000..e590500
--- /dev/null
+++ b/crates/server/src/api.rs
@@ -0,0 +1,312 @@
+use std::time::Duration;
+use axum::{
+    extract::{Path, State},
+    http::StatusCode,
+    response::{IntoResponse, Response},
+    Json,
+};
+use serde::{Deserialize, Serialize};
+use uuid::Uuid;
+use tracing::error;
+
+use helios_common::protocol::{ClientMessage, MouseButton, ServerMessage};
+use crate::{session::SessionStore, AppState};
+
+/// How long a REST call waits for the client's reply before answering 504.
+const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
+
+// ── Response types ──────────────────────────────────────────────────────────
+
+/// JSON body sent with the error responses built by the helpers below.
+#[derive(Serialize)]
+pub struct ErrorBody {
+    pub error: String,
+}
+
+/// Error half of [`dispatch`]: an HTTP status plus a JSON error body.
+type ApiError = (StatusCode, Json<ErrorBody>);
+
+/// 404 — the session id is unknown or its client is no longer connected.
+fn not_found(session_id: &str) -> ApiError {
+    (
+        StatusCode::NOT_FOUND,
+        Json(ErrorBody {
+            error: format!("Session '{session_id}' not found or not connected"),
+        }),
+    )
+}
+
+/// 504 — the client did not answer within [`REQUEST_TIMEOUT`].
+fn timeout_error(session_id: &str, op: &str) -> ApiError {
+    (
+        StatusCode::GATEWAY_TIMEOUT,
+        Json(ErrorBody {
+            error: format!(
+                "Timed out waiting for client response (session='{session_id}', op='{op}')"
+            ),
+        }),
+    )
+}
+
+/// 502 — the command could not be handed to the client, or its reply channel closed.
+fn send_error(session_id: &str, op: &str) -> ApiError {
+    (
+        StatusCode::BAD_GATEWAY,
+        Json(ErrorBody {
+            error: format!(
+                "Failed to send command to client — client may have disconnected (session='{session_id}', op='{op}')"
+            ),
+        }),
+    )
+}
+
+/// 500 — the client handled the request and reported a failure.
+fn client_error(message: String) -> Response {
+    (
+        StatusCode::INTERNAL_SERVER_ERROR,
+        Json(serde_json::json!({ "error": message })),
+    )
+        .into_response()
+}
+
+/// 502 — the client answered with a message type that does not fit the request.
+fn unexpected_response() -> Response {
+    (
+        StatusCode::BAD_GATEWAY,
+        Json(serde_json::json!({ "error": "Unexpected response from client" })),
+    )
+        .into_response()
+}
+
+/// Maps the reply to a click/type command onto an HTTP response.
+///
+/// Only `Ack` counts as success; a client-side `Error` must not be reported as `ok`.
+fn ack_response(result: Result<ClientMessage, ApiError>) -> Response {
+    match result {
+        Ok(ClientMessage::Ack { .. }) => {
+            (StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response()
+        }
+        Ok(ClientMessage::Error { message, .. }) => client_error(message),
+        Ok(_) => unexpected_response(),
+        Err(e) => e.into_response(),
+    }
+}
+
+// ── Helper to send a command and await the response ─────────────────────────
+
+/// Removes a pending-request entry when dropped.
+///
+/// Covers every exit from [`dispatch`] — send failure, timeout, and the handler
+/// future being dropped because the HTTP caller went away — so abandoned
+/// requests do not accumulate in the store. After a delivered reply the entry
+/// is already gone and the removal is a no-op.
+struct PendingGuard<'a> {
+    sessions: &'a SessionStore,
+    request_id: Uuid,
+}
+
+impl Drop for PendingGuard<'_> {
+    fn drop(&mut self) {
+        self.sessions.cancel_pending(&self.request_id);
+    }
+}
+
+/// Sends one command to the client behind `session_id` and waits for its reply.
+///
+/// `make_msg` receives the freshly generated request id and builds the command.
+///
+/// # Errors
+///
+/// * 400 — `session_id` is not a UUID
+/// * 404 — no such session
+/// * 502 — the command could not be queued, or the reply channel closed
+/// * 504 — no reply within [`REQUEST_TIMEOUT`]
+async fn dispatch<F>(
+    state: &AppState,
+    session_id: &str,
+    op: &str,
+    make_msg: F,
+) -> Result<ClientMessage, ApiError>
+where
+    F: FnOnce(Uuid) -> ServerMessage,
+{
+    let id = session_id.parse::<Uuid>().map_err(|_| {
+        (
+            StatusCode::BAD_REQUEST,
+            Json(ErrorBody {
+                error: format!("Invalid session id: '{session_id}'"),
+            }),
+        )
+    })?;
+
+    let tx = state
+        .sessions
+        .get_cmd_tx(&id)
+        .ok_or_else(|| not_found(session_id))?;
+
+    let request_id = Uuid::new_v4();
+    let rx = state.sessions.register_pending(request_id);
+    let _pending = PendingGuard {
+        sessions: state.sessions.as_ref(),
+        request_id,
+    };
+    let msg = make_msg(request_id);
+
+    tx.send(msg).await.map_err(|e| {
+        error!("Channel send failed for session={session_id}, op={op}: {e}");
+        send_error(session_id, op)
+    })?;
+
+    match tokio::time::timeout(REQUEST_TIMEOUT, rx).await {
+        Ok(Ok(response)) => Ok(response),
+        Ok(Err(_)) => Err(send_error(session_id, op)),
+        Err(_) => Err(timeout_error(session_id, op)),
+    }
+}
+
+// ── Handlers ─────────────────────────────────────────────────────────────────
+
+/// GET /sessions — list all connected clients
+pub async fn list_sessions(State(state): State<AppState>) -> Json<serde_json::Value> {
+    let sessions = state.sessions.list();
+    Json(serde_json::json!({ "sessions": sessions }))
+}
+
+/// POST /sessions/:id/screenshot
+pub async fn request_screenshot(
+    Path(session_id): Path<String>,
+    State(state): State<AppState>,
+) -> impl IntoResponse {
+    match dispatch(&state, &session_id, "screenshot", |rid| {
+        ServerMessage::ScreenshotRequest { request_id: rid }
+    })
+    .await
+    {
+        Ok(ClientMessage::ScreenshotResponse {
+            image_base64,
+            width,
+            height,
+            ..
+        }) => (
+            StatusCode::OK,
+            Json(serde_json::json!({
+                "image_base64": image_base64,
+                "width": width,
+                "height": height,
+            })),
+        )
+            .into_response(),
+        Ok(ClientMessage::Error { message, .. }) => client_error(message),
+        Ok(_) => unexpected_response(),
+        Err(e) => e.into_response(),
+    }
+}
+
+/// POST /sessions/:id/exec
+#[derive(Deserialize)]
+pub struct ExecBody {
+    pub command: String,
+}
+
+pub async fn request_exec(
+    Path(session_id): Path<String>,
+    State(state): State<AppState>,
+    Json(body): Json<ExecBody>,
+) -> impl IntoResponse {
+    match dispatch(&state, &session_id, "exec", |rid| ServerMessage::ExecRequest {
+        request_id: rid,
+        command: body.command,
+    })
+    .await
+    {
+        Ok(ClientMessage::ExecResponse {
+            stdout,
+            stderr,
+            exit_code,
+            ..
+        }) => (
+            StatusCode::OK,
+            Json(serde_json::json!({
+                "stdout": stdout,
+                "stderr": stderr,
+                "exit_code": exit_code,
+            })),
+        )
+            .into_response(),
+        Ok(ClientMessage::Error { message, .. }) => client_error(message),
+        Ok(_) => unexpected_response(),
+        Err(e) => e.into_response(),
+    }
+}
+
+/// POST /sessions/:id/click
+#[derive(Deserialize)]
+pub struct ClickBody {
+    pub x: i32,
+    pub y: i32,
+    #[serde(default)]
+    pub button: MouseButton,
+}
+
+pub async fn request_click(
+    Path(session_id): Path<String>,
+    State(state): State<AppState>,
+    Json(body): Json<ClickBody>,
+) -> impl IntoResponse {
+    let reply = dispatch(&state, &session_id, "click", |rid| ServerMessage::ClickRequest {
+        request_id: rid,
+        x: body.x,
+        y: body.y,
+        button: body.button,
+    })
+    .await;
+    ack_response(reply)
+}
+
+/// POST /sessions/:id/type
+#[derive(Deserialize)]
+pub struct TypeBody {
+    pub text: String,
+}
+
+pub async fn request_type(
+    Path(session_id): Path<String>,
+    State(state): State<AppState>,
+    Json(body): Json<TypeBody>,
+) -> impl IntoResponse {
+    let reply = dispatch(&state, &session_id, "type", |rid| ServerMessage::TypeRequest {
+        request_id: rid,
+        text: body.text,
+    })
+    .await;
+    ack_response(reply)
+}
+
+/// POST /sessions/:id/label
+#[derive(Deserialize)]
+pub struct LabelBody {
+    pub label: String,
+}
+
+pub async fn set_label(
+    Path(session_id): Path<String>,
+    State(state): State<AppState>,
+    Json(body): Json<LabelBody>,
+) -> impl IntoResponse {
+    let id = match session_id.parse::<Uuid>() {
+        Ok(id) => id,
+        Err(_) => {
+            return (
+                StatusCode::BAD_REQUEST,
+                Json(serde_json::json!({ "error": format!("Invalid session id: '{session_id}'") })),
+            )
+                .into_response()
+        }
+    };
+
+    if state.sessions.set_label(&id, body.label) {
+        (StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response()
+    } else {
+        not_found(&session_id).into_response()
+    }
+}
diff --git a/crates/server/src/auth.rs b/crates/server/src/auth.rs
new file mode 100644
index 0000000..c36c34c
--- /dev/null
+++ b/crates/server/src/auth.rs
@@ -0,0 +1,41 @@
+use axum::{
+    extract::{Request, State},
+    http::StatusCode,
+    middleware::Next,
+    response::Response,
+};
+use crate::AppState;
+
+/// Compares two byte strings without stopping at the first differing byte.
+///
+/// The length check still returns early, so only the key's length can be
+/// learned from response timing. Best-effort: nothing stops the optimizer
+/// from reshaping the loop.
+fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
+    if a.len() != b.len() {
+        return false;
+    }
+    a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
+}
+
+/// Axum middleware that checks the `X-Api-Key` header.
+///
+/// Responds with 401 when the header is absent or does not match the
+/// configured key; otherwise the request continues down the stack.
+pub async fn require_api_key(
+    State(state): State<AppState>,
+    req: Request,
+    next: Next,
+) -> Result<Response, StatusCode> {
+    // A missing header is rejected outright rather than treated as an empty key.
+    let authorized = req
+        .headers()
+        .get("X-Api-Key")
+        .is_some_and(|v| constant_time_eq(v.as_bytes(), state.api_key.as_bytes()));
+
+    if !authorized {
+        return Err(StatusCode::UNAUTHORIZED);
+    }
+
+    Ok(next.run(req).await)
+}
diff --git a/crates/server/src/main.rs b/crates/server/src/main.rs
new file mode 100644
index 0000000..d9a997d
--- /dev/null
+++ b/crates/server/src/main.rs
@@ -0,0 +1,73 @@
+mod session;
+mod ws_handler;
+mod api;
+mod auth;
+
+use std::sync::Arc;
+use axum::{
+    Router,
+    routing::{get, post},
+    middleware,
+};
+use tokio::net::TcpListener;
+use tracing::{info, warn};
+use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
+
+use session::SessionStore;
+use auth::require_api_key;
+
+/// API key used when `HELIOS_API_KEY` is unset or empty. Development only.
+const DEFAULT_API_KEY: &str = "dev-secret";
+
+/// State shared by every HTTP and WebSocket handler.
+#[derive(Clone)]
+pub struct AppState {
+    pub sessions: Arc<SessionStore>,
+    pub api_key: String,
+}
+
+#[tokio::main]
+async fn main() -> anyhow::Result<()> {
+    tracing_subscriber::registry()
+        .with(tracing_subscriber::EnvFilter::try_from_default_env()
+            .unwrap_or_else(|_| "helios_server=debug,tower_http=info".into()))
+        .with(tracing_subscriber::fmt::layer())
+        .init();
+
+    // An empty key would let an empty `X-Api-Key` header through, so it is
+    // treated the same as an unset variable.
+    let api_key = match std::env::var("HELIOS_API_KEY") {
+        Ok(key) if !key.is_empty() => key,
+        _ => {
+            warn!("HELIOS_API_KEY is not set; using the built-in development key");
+            DEFAULT_API_KEY.to_string()
+        }
+    };
+
+    let bind_addr = std::env::var("HELIOS_BIND")
+        .unwrap_or_else(|_| "0.0.0.0:3000".to_string());
+
+    let state = AppState {
+        sessions: Arc::new(SessionStore::new()),
+        api_key,
+    };
+
+    let protected = Router::new()
+        .route("/sessions", get(api::list_sessions))
+        .route("/sessions/:id/screenshot", post(api::request_screenshot))
+        .route("/sessions/:id/exec", post(api::request_exec))
+        .route("/sessions/:id/click", post(api::request_click))
+        .route("/sessions/:id/type", post(api::request_type))
+        .route("/sessions/:id/label", post(api::set_label))
+        .layer(middleware::from_fn_with_state(state.clone(), require_api_key));
+
+    let app = Router::new()
+        .route("/ws", get(ws_handler::ws_upgrade))
+        .merge(protected)
+        .with_state(state);
+
+    info!("helios-server listening on {bind_addr}");
+    let listener = TcpListener::bind(&bind_addr).await?;
+    axum::serve(listener, app).await?;
+    Ok(())
+}
diff --git a/crates/server/src/session.rs b/crates/server/src/session.rs
new file mode 100644
index 0000000..c844373
--- /dev/null
+++ b/crates/server/src/session.rs
@@ -0,0 +1,104 @@
+use dashmap::DashMap;
+use tokio::sync::{mpsc, oneshot};
+use uuid::Uuid;
+use serde::Serialize;
+use helios_common::protocol::{ClientMessage, ServerMessage};
+
+/// Represents one connected remote client
+#[derive(Debug, Clone)]
+pub struct Session {
+    pub id: Uuid,
+    pub label: Option<String>,
+    /// Channel to send commands to the WS handler for this session
+    pub cmd_tx: mpsc::Sender<ServerMessage>,
+}
+
+/// Serializable view of a session for the REST API
+#[derive(Debug, Serialize)]
+pub struct SessionInfo {
+    pub id: Uuid,
+    pub label: Option<String>,
+}
+
+impl From<&Session> for SessionInfo {
+    fn from(s: &Session) -> Self {
+        SessionInfo {
+            id: s.id,
+            label: s.label.clone(),
+        }
+    }
+}
+
+/// Registry of connected sessions and of requests still awaiting a client reply.
+pub struct SessionStore {
+    /// Active sessions by ID
+    sessions: DashMap<Uuid, Session>,
+    /// Pending request callbacks by request_id
+    pending: DashMap<Uuid, oneshot::Sender<ClientMessage>>,
+}
+
+impl Default for SessionStore {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SessionStore {
+    pub fn new() -> Self {
+        Self {
+            sessions: DashMap::new(),
+            pending: DashMap::new(),
+        }
+    }
+
+    pub fn insert(&self, session: Session) {
+        self.sessions.insert(session.id, session);
+    }
+
+    pub fn remove(&self, id: &Uuid) {
+        self.sessions.remove(id);
+    }
+
+    /// Returns a clone of the session's command sender, or `None` if the id is unknown.
+    pub fn get_cmd_tx(&self, id: &Uuid) -> Option<mpsc::Sender<ServerMessage>> {
+        self.sessions.get(id).map(|s| s.cmd_tx.clone())
+    }
+
+    /// Sets the session's label. Returns `false` if the id is unknown.
+    pub fn set_label(&self, id: &Uuid, label: String) -> bool {
+        if let Some(mut s) = self.sessions.get_mut(id) {
+            s.label = Some(label);
+            true
+        } else {
+            false
+        }
+    }
+
+    pub fn list(&self) -> Vec<SessionInfo> {
+        self.sessions.iter().map(|e| SessionInfo::from(e.value())).collect()
+    }
+
+    /// Register a pending request. Returns the receiver to await the client response.
+    pub fn register_pending(&self, request_id: Uuid) -> oneshot::Receiver<ClientMessage> {
+        let (tx, rx) = oneshot::channel();
+        self.pending.insert(request_id, tx);
+        rx
+    }
+
+    /// Deliver a client response to the waiting request handler.
+    /// Returns true if the request was found and resolved.
+    pub fn resolve_pending(&self, request_id: Uuid, msg: ClientMessage) -> bool {
+        if let Some((_, tx)) = self.pending.remove(&request_id) {
+            let _ = tx.send(msg);
+            true
+        } else {
+            false
+        }
+    }
+
+    /// Forget a pending request that will never be answered (send failure,
+    /// timeout, or an abandoned HTTP call). Does nothing if it is already gone.
+    pub fn cancel_pending(&self, request_id: &Uuid) {
+        self.pending.remove(request_id);
+    }
+}
diff --git a/crates/server/src/ws_handler.rs b/crates/server/src/ws_handler.rs
new file mode 100644
index 0000000..cdbeeba
--- /dev/null
+++ b/crates/server/src/ws_handler.rs
@@ -0,0 +1,109 @@
+use axum::{
+    extract::{State, WebSocketUpgrade},
+    response::IntoResponse,
+};
+use axum::extract::ws::{Message, WebSocket};
+use futures_util::{SinkExt, StreamExt};
+use tokio::sync::mpsc;
+use uuid::Uuid;
+use tracing::{debug, error, info, warn};
+
+use helios_common::protocol::{ClientMessage, ServerMessage};
+use crate::{AppState, session::Session};
+
+/// Upper bound on how much of an unparseable frame is copied into the log.
+const MAX_LOGGED_PAYLOAD_CHARS: usize = 256;
+
+/// GET /ws — upgrades the connection and runs one client session on it.
+pub async fn ws_upgrade(
+    ws: WebSocketUpgrade,
+    State(state): State<AppState>,
+) -> impl IntoResponse {
+    ws.on_upgrade(move |socket| handle_socket(socket, state))
+}
+
+/// Runs a session: registers it, relays queued commands to the socket, feeds
+/// client frames to [`handle_client_message`], and unregisters it on exit.
+async fn handle_socket(socket: WebSocket, state: AppState) {
+    let session_id = Uuid::new_v4();
+    let (cmd_tx, mut cmd_rx) = mpsc::channel::<ServerMessage>(64);
+
+    // Register session (label filled in on Hello)
+    let session = Session {
+        id: session_id,
+        label: None,
+        cmd_tx,
+    };
+    state.sessions.insert(session);
+    info!("Client connected: session={session_id}");
+
+    let (mut ws_tx, mut ws_rx) = socket.split();
+
+    // Spawn task: forward server commands → WS
+    let send_task = tokio::spawn(async move {
+        while let Some(msg) = cmd_rx.recv().await {
+            match serde_json::to_string(&msg) {
+                Ok(json) => {
+                    if let Err(e) = ws_tx.send(Message::Text(json.into())).await {
+                        error!("WS send error for session={session_id}: {e}");
+                        break;
+                    }
+                }
+                Err(e) => {
+                    error!("Serialization error for session={session_id}: {e}");
+                }
+            }
+        }
+    });
+
+    // Main loop: receive client messages
+    while let Some(result) = ws_rx.next().await {
+        match result {
+            Ok(Message::Text(text)) => {
+                match serde_json::from_str::<ClientMessage>(&text) {
+                    Ok(msg) => handle_client_message(session_id, msg, &state).await,
+                    Err(e) => {
+                        // A frame may carry a whole base64 screenshot; log only a prefix.
+                        let raw: String = text.chars().take(MAX_LOGGED_PAYLOAD_CHARS).collect();
+                        warn!("Invalid JSON from session={session_id}: {e} | raw={raw}");
+                    }
+                }
+            }
+            Ok(Message::Close(_)) => {
+                info!("Client disconnected gracefully: session={session_id}");
+                break;
+            }
+            Ok(Message::Ping(_)) | Ok(Message::Pong(_)) | Ok(Message::Binary(_)) => {}
+            Err(e) => {
+                error!("WS receive error for session={session_id}: {e}");
+                break;
+            }
+        }
+    }
+
+    send_task.abort();
+    state.sessions.remove(&session_id);
+    info!("Session cleaned up: session={session_id}");
+}
+
+/// Handles one decoded client frame: `Hello` updates the session label, every
+/// other variant is delivered to the REST handler waiting on its `request_id`.
+async fn handle_client_message(session_id: Uuid, msg: ClientMessage, state: &AppState) {
+    match &msg {
+        ClientMessage::Hello { label } => {
+            if let Some(lbl) = label {
+                state.sessions.set_label(&session_id, lbl.clone());
+            }
+            debug!("Hello from session={session_id}, label={label:?}");
+        }
+        ClientMessage::ScreenshotResponse { request_id, .. }
+        | ClientMessage::ExecResponse { request_id, .. }
+        | ClientMessage::Ack { request_id }
+        | ClientMessage::Error { request_id, .. } => {
+            let rid = *request_id;
+            if !state.sessions.resolve_pending(rid, msg) {
+                warn!("No pending request for request_id={rid} (session={session_id})");
+            }
+        }
+    }
+}