Initial implementation: relay server + common protocol + client stub

This commit is contained in:
Helios 2026-03-02 18:03:46 +01:00
commit 7285a33cff
No known key found for this signature in database
GPG key ID: C8259547CD8309B5
17 changed files with 926 additions and 0 deletions

24
crates/server/Cargo.toml Normal file
View file

@ -0,0 +1,24 @@
[package]
name = "helios-server"
version = "0.1.0"
edition = "2021"
[[bin]]
name = "helios-server"
path = "src/main.rs"
[dependencies]
helios-common = { path = "../common" }
tokio = { version = "1", features = ["full"] }
axum = { version = "0.7", features = ["ws"] }
tower = "0.4"
tower-http = { version = "0.5", features = ["trace"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
uuid = { version = "1", features = ["v4", "serde"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tokio-tungstenite = "0.21"
futures-util = "0.3"
dashmap = "5"
anyhow = "1"

263
crates/server/src/api.rs Normal file
View file

@ -0,0 +1,263 @@
use std::time::Duration;
use axum::{
extract::{Path, State},
http::StatusCode,
response::IntoResponse,
Json,
};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use tracing::error;
use helios_common::protocol::{ClientMessage, MouseButton, ServerMessage};
use crate::AppState;
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
// ── Response types ──────────────────────────────────────────────────────────
#[derive(Serialize)]
pub struct ErrorBody {
pub error: String,
}
fn not_found(session_id: &str) -> (StatusCode, Json<ErrorBody>) {
(
StatusCode::NOT_FOUND,
Json(ErrorBody {
error: format!("Session '{session_id}' not found or not connected"),
}),
)
}
fn timeout_error(session_id: &str, op: &str) -> (StatusCode, Json<ErrorBody>) {
(
StatusCode::GATEWAY_TIMEOUT,
Json(ErrorBody {
error: format!(
"Timed out waiting for client response (session='{session_id}', op='{op}')"
),
}),
)
}
fn send_error(session_id: &str, op: &str) -> (StatusCode, Json<ErrorBody>) {
(
StatusCode::BAD_GATEWAY,
Json(ErrorBody {
error: format!(
"Failed to send command to client — client may have disconnected (session='{session_id}', op='{op}')"
),
}),
)
}
// ── Helper to send a command and await the response ─────────────────────────
async fn dispatch<F>(
state: &AppState,
session_id: &str,
op: &str,
make_msg: F,
) -> Result<ClientMessage, (StatusCode, Json<ErrorBody>)>
where
F: FnOnce(Uuid) -> ServerMessage,
{
let id = session_id.parse::<Uuid>().map_err(|_| {
(
StatusCode::BAD_REQUEST,
Json(ErrorBody {
error: format!("Invalid session id: '{session_id}'"),
}),
)
})?;
let tx = state
.sessions
.get_cmd_tx(&id)
.ok_or_else(|| not_found(session_id))?;
let request_id = Uuid::new_v4();
let rx = state.sessions.register_pending(request_id);
let msg = make_msg(request_id);
tx.send(msg).await.map_err(|e| {
error!("Channel send failed for session={session_id}, op={op}: {e}");
send_error(session_id, op)
})?;
match tokio::time::timeout(REQUEST_TIMEOUT, rx).await {
Ok(Ok(response)) => Ok(response),
Ok(Err(_)) => Err(send_error(session_id, op)),
Err(_) => Err(timeout_error(session_id, op)),
}
}
// ── Handlers ─────────────────────────────────────────────────────────────────
/// GET /sessions — list all connected clients
pub async fn list_sessions(State(state): State<AppState>) -> Json<serde_json::Value> {
let sessions = state.sessions.list();
Json(serde_json::json!({ "sessions": sessions }))
}
/// POST /sessions/:id/screenshot
pub async fn request_screenshot(
Path(session_id): Path<String>,
State(state): State<AppState>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "screenshot", |rid| {
ServerMessage::ScreenshotRequest { request_id: rid }
})
.await
{
Ok(ClientMessage::ScreenshotResponse {
image_base64,
width,
height,
..
}) => (
StatusCode::OK,
Json(serde_json::json!({
"image_base64": image_base64,
"width": width,
"height": height,
})),
)
.into_response(),
Ok(ClientMessage::Error { message, .. }) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": message })),
)
.into_response(),
Ok(_) => (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({ "error": "Unexpected response from client" })),
)
.into_response(),
Err(e) => e.into_response(),
}
}
/// POST /sessions/:id/exec
#[derive(Deserialize)]
pub struct ExecBody {
pub command: String,
}
pub async fn request_exec(
Path(session_id): Path<String>,
State(state): State<AppState>,
Json(body): Json<ExecBody>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "exec", |rid| ServerMessage::ExecRequest {
request_id: rid,
command: body.command.clone(),
})
.await
{
Ok(ClientMessage::ExecResponse {
stdout,
stderr,
exit_code,
..
}) => (
StatusCode::OK,
Json(serde_json::json!({
"stdout": stdout,
"stderr": stderr,
"exit_code": exit_code,
})),
)
.into_response(),
Ok(ClientMessage::Error { message, .. }) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": message })),
)
.into_response(),
Ok(_) => (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({ "error": "Unexpected response from client" })),
)
.into_response(),
Err(e) => e.into_response(),
}
}
/// POST /sessions/:id/click
#[derive(Deserialize)]
pub struct ClickBody {
pub x: i32,
pub y: i32,
#[serde(default)]
pub button: MouseButton,
}
pub async fn request_click(
Path(session_id): Path<String>,
State(state): State<AppState>,
Json(body): Json<ClickBody>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "click", |rid| ServerMessage::ClickRequest {
request_id: rid,
x: body.x,
y: body.y,
button: body.button.clone(),
})
.await
{
Ok(_) => (StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response(),
Err(e) => e.into_response(),
}
}
/// POST /sessions/:id/type
#[derive(Deserialize)]
pub struct TypeBody {
pub text: String,
}
pub async fn request_type(
Path(session_id): Path<String>,
State(state): State<AppState>,
Json(body): Json<TypeBody>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "type", |rid| ServerMessage::TypeRequest {
request_id: rid,
text: body.text.clone(),
})
.await
{
Ok(_) => (StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response(),
Err(e) => e.into_response(),
}
}
/// POST /sessions/:id/label
#[derive(Deserialize)]
pub struct LabelBody {
pub label: String,
}
pub async fn set_label(
Path(session_id): Path<String>,
State(state): State<AppState>,
Json(body): Json<LabelBody>,
) -> impl IntoResponse {
let id = match session_id.parse::<Uuid>() {
Ok(id) => id,
Err(_) => {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": format!("Invalid session id: '{session_id}'") })),
)
.into_response()
}
};
if state.sessions.set_label(&id, body.label.clone()) {
(StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response()
} else {
not_found(&session_id).into_response()
}
}

26
crates/server/src/auth.rs Normal file
View file

@ -0,0 +1,26 @@
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use crate::AppState;
/// Axum middleware that checks the `X-Api-Key` header.
pub async fn require_api_key(
State(state): State<AppState>,
req: Request,
next: Next,
) -> Result<Response, StatusCode> {
let key = req
.headers()
.get("X-Api-Key")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if key != state.api_key {
return Err(StatusCode::UNAUTHORIZED);
}
Ok(next.run(req).await)
}

62
crates/server/src/main.rs Normal file
View file

@ -0,0 +1,62 @@
mod session;
mod ws_handler;
mod api;
mod auth;
use std::sync::Arc;
use axum::{
Router,
routing::{get, post},
middleware,
};
use tokio::net::TcpListener;
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use session::SessionStore;
use auth::require_api_key;
#[derive(Clone)]
pub struct AppState {
pub sessions: Arc<SessionStore>,
pub api_key: String,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "helios_server=debug,tower_http=info".into()))
.with(tracing_subscriber::fmt::layer())
.init();
let api_key = std::env::var("HELIOS_API_KEY")
.unwrap_or_else(|_| "dev-secret".to_string());
let bind_addr = std::env::var("HELIOS_BIND")
.unwrap_or_else(|_| "0.0.0.0:3000".to_string());
let state = AppState {
sessions: Arc::new(SessionStore::new()),
api_key,
};
let protected = Router::new()
.route("/sessions", get(api::list_sessions))
.route("/sessions/:id/screenshot", post(api::request_screenshot))
.route("/sessions/:id/exec", post(api::request_exec))
.route("/sessions/:id/click", post(api::request_click))
.route("/sessions/:id/type", post(api::request_type))
.route("/sessions/:id/label", post(api::set_label))
.layer(middleware::from_fn_with_state(state.clone(), require_api_key));
let app = Router::new()
.route("/ws", get(ws_handler::ws_upgrade))
.merge(protected)
.with_state(state);
info!("helios-server listening on {bind_addr}");
let listener = TcpListener::bind(&bind_addr).await?;
axum::serve(listener, app).await?;
Ok(())
}

View file

@ -0,0 +1,89 @@
use dashmap::DashMap;
use tokio::sync::{mpsc, oneshot};
use uuid::Uuid;
use serde::Serialize;
use helios_common::protocol::{ClientMessage, ServerMessage};
/// Represents one connected remote client
#[derive(Debug, Clone)]
pub struct Session {
pub id: Uuid,
pub label: Option<String>,
/// Channel to send commands to the WS handler for this session
pub cmd_tx: mpsc::Sender<ServerMessage>,
}
/// Serializable view of a session for the REST API
#[derive(Debug, Serialize)]
pub struct SessionInfo {
pub id: Uuid,
pub label: Option<String>,
}
impl From<&Session> for SessionInfo {
fn from(s: &Session) -> Self {
SessionInfo {
id: s.id,
label: s.label.clone(),
}
}
}
pub struct SessionStore {
/// Active sessions by ID
sessions: DashMap<Uuid, Session>,
/// Pending request callbacks by request_id
pending: DashMap<Uuid, oneshot::Sender<ClientMessage>>,
}
impl SessionStore {
pub fn new() -> Self {
Self {
sessions: DashMap::new(),
pending: DashMap::new(),
}
}
pub fn insert(&self, session: Session) {
self.sessions.insert(session.id, session);
}
pub fn remove(&self, id: &Uuid) {
self.sessions.remove(id);
}
pub fn get_cmd_tx(&self, id: &Uuid) -> Option<mpsc::Sender<ServerMessage>> {
self.sessions.get(id).map(|s| s.cmd_tx.clone())
}
pub fn set_label(&self, id: &Uuid, label: String) -> bool {
if let Some(mut s) = self.sessions.get_mut(id) {
s.label = Some(label);
true
} else {
false
}
}
pub fn list(&self) -> Vec<SessionInfo> {
self.sessions.iter().map(|e| SessionInfo::from(e.value())).collect()
}
/// Register a pending request. Returns the receiver to await the client response.
pub fn register_pending(&self, request_id: Uuid) -> oneshot::Receiver<ClientMessage> {
let (tx, rx) = oneshot::channel();
self.pending.insert(request_id, tx);
rx
}
/// Deliver a client response to the waiting request handler.
/// Returns true if the request was found and resolved.
pub fn resolve_pending(&self, request_id: Uuid, msg: ClientMessage) -> bool {
if let Some((_, tx)) = self.pending.remove(&request_id) {
let _ = tx.send(msg);
true
} else {
false
}
}
}

View file

@ -0,0 +1,99 @@
use axum::{
extract::{State, WebSocketUpgrade},
response::IntoResponse,
};
use axum::extract::ws::{Message, WebSocket};
use futures_util::{SinkExt, StreamExt};
use tokio::sync::mpsc;
use uuid::Uuid;
use tracing::{debug, error, info, warn};
use helios_common::protocol::{ClientMessage, ServerMessage};
use crate::{AppState, session::Session};
pub async fn ws_upgrade(
ws: WebSocketUpgrade,
State(state): State<AppState>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(socket, state))
}
async fn handle_socket(socket: WebSocket, state: AppState) {
let session_id = Uuid::new_v4();
let (cmd_tx, mut cmd_rx) = mpsc::channel::<ServerMessage>(64);
// Register session (label filled in on Hello)
let session = Session {
id: session_id,
label: None,
cmd_tx,
};
state.sessions.insert(session);
info!("Client connected: session={session_id}");
let (mut ws_tx, mut ws_rx) = socket.split();
// Spawn task: forward server commands → WS
let send_task = tokio::spawn(async move {
while let Some(msg) = cmd_rx.recv().await {
match serde_json::to_string(&msg) {
Ok(json) => {
if let Err(e) = ws_tx.send(Message::Text(json.into())).await {
error!("WS send error for session={session_id}: {e}");
break;
}
}
Err(e) => {
error!("Serialization error for session={session_id}: {e}");
}
}
}
});
// Main loop: receive client messages
while let Some(result) = ws_rx.next().await {
match result {
Ok(Message::Text(text)) => {
match serde_json::from_str::<ClientMessage>(&text) {
Ok(msg) => handle_client_message(session_id, msg, &state).await,
Err(e) => {
warn!("Invalid JSON from session={session_id}: {e} | raw={text}");
}
}
}
Ok(Message::Close(_)) => {
info!("Client disconnected gracefully: session={session_id}");
break;
}
Ok(Message::Ping(_)) | Ok(Message::Pong(_)) | Ok(Message::Binary(_)) => {}
Err(e) => {
error!("WS receive error for session={session_id}: {e}");
break;
}
}
}
send_task.abort();
state.sessions.remove(&session_id);
info!("Session cleaned up: session={session_id}");
}
async fn handle_client_message(session_id: Uuid, msg: ClientMessage, state: &AppState) {
match &msg {
ClientMessage::Hello { label } => {
if let Some(lbl) = label {
state.sessions.set_label(&session_id, lbl.clone());
}
debug!("Hello from session={session_id}, label={label:?}");
}
ClientMessage::ScreenshotResponse { request_id, .. }
| ClientMessage::ExecResponse { request_id, .. }
| ClientMessage::Ack { request_id }
| ClientMessage::Error { request_id, .. } => {
let rid = *request_id;
if !state.sessions.resolve_pending(rid, msg) {
warn!("No pending request for request_id={rid} (session={session_id})");
}
}
}
}