refactor: enforce device labels, unify screenshot, remove deprecated commands, session-id-less design

- Device labels: lowercase, no whitespace, only a-z 0-9 - _ (enforced at config time)
- Session IDs removed: device label is the sole identifier
- Routes changed: /sessions/:id → /devices/:label
- Removed commands: click, type, find-window, wait-for-window, label, old version, server-version
- Renamed: status → version (compares relay/remote.py/client commits)
- Unified screenshot: takes 'screen' or a window label as argument
- Windows listed with human-readable labels (same format as device labels)
- Single instance enforcement via PID lock file
- Removed input.rs (click/type functionality)
- All docs and code in English
- Protocol: Hello.label is now required (String, not Option<String>)
- Client auto-migrates invalid labels on startup
This commit is contained in:
Helios 2026-03-06 01:55:28 +01:00
parent 5fd01a423d
commit 0b4a6de8ae
No known key found for this signature in database
GPG key ID: C8259547CD8309B5
14 changed files with 736 additions and 1180 deletions

View file

@ -24,6 +24,7 @@ base64 = "0.22"
png = "0.17"
futures-util = "0.3"
colored = "2"
scopeguard = "1"
terminal_size = "0.3"
unicode-width = "0.1"

View file

@ -1,36 +1,40 @@
# helios-client (Phase 2 — not yet implemented)
# helios-client
This crate will contain the Windows remote-control client for `helios-remote`.
Windows client for helios-remote. Connects to the relay server via WebSocket and executes commands.
## Planned Features
## Features
- Connects to the relay server via WebSocket (`wss://`)
- Sends a `Hello` message on connect with an optional display label
- Handles incoming `ServerMessage` commands:
- `ScreenshotRequest` → captures the primary display (Windows GDI or `windows-capture`) and responds with base64 PNG
- `ExecRequest` → runs a shell command in a persistent `cmd.exe` / PowerShell session and returns stdout/stderr/exit-code
- `ClickRequest` → simulates a mouse click via `SendInput` Win32 API
- `TypeRequest` → types text via `SendInput` (virtual key events)
- Persistent shell session so `cd C:\Users` persists across `exec` calls
- Auto-reconnect with exponential backoff
- Configurable via environment variables or a `client.toml` config file
- Full-screen and per-window screenshots
- Shell command execution (persistent PowerShell session)
- Window management (list, focus, maximize, minimize)
- File upload/download
- Clipboard get/set
- Program launch (fire-and-forget)
- User prompts (MessageBox)
- Single instance enforcement (PID lock file)
## Planned Tech Stack
## Configuration
| Crate | Purpose |
|---|---|
| `tokio` | Async runtime |
| `tokio-tungstenite` | WebSocket client |
| `serde_json` | Protocol serialization |
| `windows` / `winapi` | Screen capture, mouse/keyboard input |
| `base64` | PNG encoding for screenshots |
On first run, the client prompts for:
- **Relay URL** (default: `wss://remote.agent-helios.me/ws`)
- **API Key**
- **Device label** — must be lowercase, no whitespace, only `a-z 0-9 - _`
## Build Target
Config is saved to `%APPDATA%/helios-remote/config.toml`.
## Device Labels
The device label is the sole identifier for this machine. It must follow these rules:
- Lowercase only
- No whitespace
- Only characters: `a-z`, `0-9`, `-`, `_`
Examples: `moritz_pc`, `work-desktop`, `gaming-rig`
If an existing config has an invalid label, it will be automatically migrated on next startup.
## Build
```bash
cargo build -p helios-client --release
```
cargo build --target x86_64-pc-windows-gnu
```
## App Icon
The file `assets/logo.ico` in the repository root is the application icon intended for the Windows `.exe`. It can be embedded at compile time using a build script (e.g. via the `winres` crate).

View file

@ -1,154 +0,0 @@
/// Mouse click and keyboard input via Windows SendInput (or stub on non-Windows).
use helios_common::MouseButton;
#[cfg(windows)]
pub fn click(x: i32, y: i32, button: &MouseButton) -> Result<(), String> {
use windows::Win32::UI::Input::KeyboardAndMouse::{
SendInput, INPUT, INPUT_MOUSE, MOUSEEVENTF_ABSOLUTE, MOUSEEVENTF_LEFTDOWN,
MOUSEEVENTF_LEFTUP, MOUSEEVENTF_MIDDLEDOWN, MOUSEEVENTF_MIDDLEUP, MOUSEEVENTF_MOVE,
MOUSEEVENTF_RIGHTDOWN, MOUSEEVENTF_RIGHTUP, MOUSEINPUT,
};
use windows::Win32::UI::WindowsAndMessaging::{GetSystemMetrics, SM_CXSCREEN, SM_CYSCREEN};
unsafe {
let screen_w = GetSystemMetrics(SM_CXSCREEN) as i32;
let screen_h = GetSystemMetrics(SM_CYSCREEN) as i32;
if screen_w == 0 || screen_h == 0 {
return Err(format!(
"Could not get screen dimensions: {screen_w}x{screen_h}"
));
}
// Convert pixel coords to absolute 0-65535 range
let abs_x = ((x * 65535) / screen_w) as i32;
let abs_y = ((y * 65535) / screen_h) as i32;
let (down_flag, up_flag) = match button {
MouseButton::Left => (MOUSEEVENTF_LEFTDOWN, MOUSEEVENTF_LEFTUP),
MouseButton::Right => (MOUSEEVENTF_RIGHTDOWN, MOUSEEVENTF_RIGHTUP),
MouseButton::Middle => (MOUSEEVENTF_MIDDLEDOWN, MOUSEEVENTF_MIDDLEUP),
};
// Move to position
let move_input = INPUT {
r#type: INPUT_MOUSE,
Anonymous: windows::Win32::UI::Input::KeyboardAndMouse::INPUT_0 {
mi: MOUSEINPUT {
dx: abs_x,
dy: abs_y,
mouseData: 0,
dwFlags: MOUSEEVENTF_MOVE | MOUSEEVENTF_ABSOLUTE,
time: 0,
dwExtraInfo: 0,
},
},
};
let down_input = INPUT {
r#type: INPUT_MOUSE,
Anonymous: windows::Win32::UI::Input::KeyboardAndMouse::INPUT_0 {
mi: MOUSEINPUT {
dx: abs_x,
dy: abs_y,
mouseData: 0,
dwFlags: down_flag | MOUSEEVENTF_ABSOLUTE,
time: 0,
dwExtraInfo: 0,
},
},
};
let up_input = INPUT {
r#type: INPUT_MOUSE,
Anonymous: windows::Win32::UI::Input::KeyboardAndMouse::INPUT_0 {
mi: MOUSEINPUT {
dx: abs_x,
dy: abs_y,
mouseData: 0,
dwFlags: up_flag | MOUSEEVENTF_ABSOLUTE,
time: 0,
dwExtraInfo: 0,
},
},
};
let inputs = [move_input, down_input, up_input];
let result = SendInput(&inputs, std::mem::size_of::<INPUT>() as i32);
if result != inputs.len() as u32 {
return Err(format!(
"SendInput for click at ({x},{y}) sent {result}/{} events — some may have been blocked by UIPI",
inputs.len()
));
}
Ok(())
}
}
#[cfg(windows)]
pub fn type_text(text: &str) -> Result<(), String> {
use windows::Win32::UI::Input::KeyboardAndMouse::{
SendInput, INPUT, INPUT_KEYBOARD, KEYBDINPUT, KEYEVENTF_UNICODE,
};
if text.is_empty() {
return Ok(());
}
unsafe {
let mut inputs: Vec<INPUT> = Vec::with_capacity(text.len() * 2);
for ch in text.encode_utf16() {
// Key down
inputs.push(INPUT {
r#type: INPUT_KEYBOARD,
Anonymous: windows::Win32::UI::Input::KeyboardAndMouse::INPUT_0 {
ki: KEYBDINPUT {
wVk: windows::Win32::UI::Input::KeyboardAndMouse::VIRTUAL_KEY(0),
wScan: ch,
dwFlags: KEYEVENTF_UNICODE,
time: 0,
dwExtraInfo: 0,
},
},
});
// Key up
inputs.push(INPUT {
r#type: INPUT_KEYBOARD,
Anonymous: windows::Win32::UI::Input::KeyboardAndMouse::INPUT_0 {
ki: KEYBDINPUT {
wVk: windows::Win32::UI::Input::KeyboardAndMouse::VIRTUAL_KEY(0),
wScan: ch,
dwFlags: KEYEVENTF_UNICODE
| windows::Win32::UI::Input::KeyboardAndMouse::KEYEVENTF_KEYUP,
time: 0,
dwExtraInfo: 0,
},
},
});
}
let result = SendInput(&inputs, std::mem::size_of::<INPUT>() as i32);
if result != inputs.len() as u32 {
return Err(format!(
"SendInput for type_text sent {result}/{} events — some may have been blocked (UIPI or secure desktop)",
inputs.len()
));
}
Ok(())
}
}
#[cfg(not(windows))]
pub fn click(_x: i32, _y: i32, _button: &MouseButton) -> Result<(), String> {
Err("click() is only supported on Windows".to_string())
}
#[cfg(not(windows))]
pub fn type_text(_text: &str) -> Result<(), String> {
Err("type_text() is only supported on Windows".to_string())
}

View file

@ -11,27 +11,23 @@ use tokio_tungstenite::{connect_async_tls_with_config, tungstenite::Message, Con
use base64::Engine;
use helios_common::{ClientMessage, ServerMessage};
use uuid::Uuid;
use helios_common::protocol::{is_valid_label, sanitize_label};
mod display;
mod logger;
mod shell;
mod screenshot;
mod input;
mod windows_mgmt;
// Re-export trunc for use in this file
use display::trunc;
fn banner() {
println!();
// Use same column layout as info_line: 2sp + emoji_cell(2w) + 2sp + name(14) + 2sp + value
// ☀ is 1-wide → emoji_cell pads to 2 → need 1 extra space to match
println!(" {} {}", "".yellow().bold(), "HELIOS REMOTE".bold());
display::info_line("🔗", "commit:", &env!("GIT_COMMIT").dimmed().to_string());
}
fn print_session_info(label: &str, sid: &uuid::Uuid) {
fn print_device_info(label: &str) {
#[cfg(windows)]
{
let admin = is_admin();
@ -46,7 +42,6 @@ fn print_session_info(label: &str, sid: &uuid::Uuid) {
display::info_line("👤", "privileges:", &"no admin".dimmed().to_string());
display::info_line("🖥", "device:", &label.dimmed().to_string());
display::info_line("🪪", "session:", &sid.to_string().dimmed().to_string());
println!();
}
@ -72,14 +67,69 @@ fn enable_ansi() {
}
}
// ────────────────────────────────────────────────────────────────────────────
// ── Single instance enforcement ─────────────────────────────────────────────
fn lock_file_path() -> PathBuf {
let base = dirs::config_dir()
.or_else(|| dirs::home_dir())
.unwrap_or_else(|| PathBuf::from("."));
base.join("helios-remote").join("instance.lock")
}
/// Try to acquire a single-instance lock. Returns true if we got it.
fn acquire_instance_lock() -> bool {
let path = lock_file_path();
if let Some(parent) = path.parent() {
let _ = std::fs::create_dir_all(parent);
}
// Check if another instance is running
if path.exists() {
if let Ok(content) = std::fs::read_to_string(&path) {
if let Ok(pid) = content.trim().parse::<u32>() {
// Check if process is still alive
#[cfg(windows)]
{
use windows::Win32::System::Threading::{OpenProcess, PROCESS_QUERY_LIMITED_INFORMATION};
let alive = unsafe {
OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, false, pid).is_ok()
};
if alive {
return false;
}
}
#[cfg(not(windows))]
{
use std::process::Command;
let alive = Command::new("kill")
.args(["-0", &pid.to_string()])
.status()
.map(|s| s.success())
.unwrap_or(false);
if alive {
return false;
}
}
}
}
}
// Write our PID
let pid = std::process::id();
std::fs::write(&path, pid.to_string()).is_ok()
}
fn release_instance_lock() {
let _ = std::fs::remove_file(lock_file_path());
}
// ── Config ──────────────────────────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Config {
relay_url: String,
api_key: String,
label: Option<String>,
session_id: Option<String>, // persistent UUID
label: String,
}
impl Config {
@ -131,39 +181,73 @@ fn prompt_config() -> Config {
};
let label = {
let default_label = hostname();
print!(" {} Label for this PC [{}]: ", "".cyan().bold(), default_label);
std::io::stdout().flush().unwrap();
let mut input = String::new();
std::io::stdin().read_line(&mut input).unwrap();
let trimmed = input.trim().to_string();
if trimmed.is_empty() {
Some(default_label)
} else {
Some(trimmed)
let default_label = sanitize_label(&hostname());
loop {
print!(" {} Device label [{}]: ", "".cyan().bold(), default_label);
std::io::stdout().flush().unwrap();
let mut input = String::new();
std::io::stdin().read_line(&mut input).unwrap();
let trimmed = input.trim();
let candidate = if trimmed.is_empty() {
default_label.clone()
} else {
trimmed.to_string()
};
if is_valid_label(&candidate) {
break candidate;
}
println!(" {} Label must be lowercase, no spaces. Only a-z, 0-9, '-', '_'.",
"".red().bold());
println!(" Suggestion: {}", sanitize_label(&candidate).cyan());
}
};
Config { relay_url, api_key, label, session_id: None }
Config { relay_url, api_key, label }
}
#[tokio::main]
async fn main() {
// Enable ANSI color codes on Windows (required when running as admin)
#[cfg(windows)]
enable_ansi();
logger::init();
// Suppress tracing output by default
if std::env::var("RUST_LOG").is_err() {
unsafe { std::env::set_var("RUST_LOG", "off"); }
}
banner();
// Single instance check
if !acquire_instance_lock() {
display::err("", "Another instance of helios-remote is already running.");
display::err("", "Only one instance per device is allowed.");
std::process::exit(1);
}
// Clean up lock on exit
let _guard = scopeguard::guard((), |_| release_instance_lock());
// Load or prompt for config
let config = match Config::load() {
Some(c) => c,
Some(c) => {
// Validate existing label
if !is_valid_label(&c.label) {
let new_label = sanitize_label(&c.label);
display::info_line("", "migrate:", &format!(
"Label '{}' is invalid, migrating to '{}'", c.label, new_label
));
let mut cfg = c;
cfg.label = new_label;
if let Err(e) = cfg.save() {
display::err("", &format!("Failed to save config: {e}"));
}
cfg
} else {
c
}
}
None => {
display::info_line("", "setup:", "No config found — first-time setup");
println!();
@ -178,22 +262,8 @@ async fn main() {
}
};
// Resolve or generate persistent session UUID
let sid: Uuid = match &config.session_id {
Some(id) => Uuid::parse_str(id).unwrap_or_else(|_| Uuid::new_v4()),
None => {
let id = Uuid::new_v4();
let mut cfg = config.clone();
cfg.session_id = Some(id.to_string());
if let Err(e) = cfg.save() {
display::err("", &format!("Failed to save session_id: {e}"));
}
id
}
};
let label = config.label.clone().unwrap_or_else(|| hostname());
print_session_info(&label, &sid);
let label = config.label.clone();
print_device_info(&label);
let config = Arc::new(config);
let shell = Arc::new(Mutex::new(shell::PersistentShell::new()));
@ -225,9 +295,9 @@ async fn main() {
let (mut write, mut read) = ws_stream.split();
// Send Hello
// Send Hello with device label
let hello = ClientMessage::Hello {
label: config.label.clone(),
label: label.clone(),
};
let hello_json = serde_json::to_string(&hello).unwrap();
if let Err(e) = write.send(Message::Text(hello_json)).await {
@ -254,9 +324,6 @@ async fn main() {
let shell_clone = Arc::clone(&shell);
tokio::spawn(async move {
// tokio isolates panics per task — a panic here won't kill
// the main loop. handle_message uses map_err everywhere so
// it should never panic in practice.
let response = handle_message(server_msg, shell_clone).await;
let json = match serde_json::to_string(&response) {
Ok(j) => j,
@ -343,14 +410,14 @@ async fn handle_message(
}
ServerMessage::ScreenshotRequest { request_id } => {
display::cmd_start("📷", "screenshot", "");
display::cmd_start("📷", "screenshot", "screen");
match screenshot::take_screenshot() {
Ok((image_base64, width, height)) => {
display::cmd_done("📷", "screenshot", "", true, &format!("{width}×{height}"));
display::cmd_done("📷", "screenshot", "screen", true, &format!("{width}×{height}"));
ClientMessage::ScreenshotResponse { request_id, image_base64, width, height }
}
Err(e) => {
display::cmd_done("📷", "screenshot", "", false, &format!("{e}"));
display::cmd_done("📷", "screenshot", "screen", false, &format!("{e}"));
ClientMessage::Error { request_id, message: format!("Screenshot failed: {e}") }
}
}
@ -358,7 +425,6 @@ async fn handle_message(
ServerMessage::PromptRequest { request_id, message, title: _ } => {
display::prompt_waiting(&message);
// Read user input from stdin (blocking)
let answer = tokio::task::spawn_blocking(|| {
let mut input = String::new();
std::io::stdin().read_line(&mut input).ok();
@ -375,8 +441,6 @@ async fn handle_message(
match sh.run(&command, timeout_ms).await {
Ok((stdout, stderr, exit_code)) => {
let result = if exit_code != 0 {
// For errors: use first non-empty stderr line (actual error message),
// ignoring PowerShell boilerplate like "+ CategoryInfo", "In Zeile:", etc.
let err_line = stderr.lines()
.map(|l| l.trim())
.find(|l| !l.is_empty()
@ -388,7 +452,6 @@ async fn handle_message(
.to_string();
err_line
} else {
// Success: first stdout line, no exit code
stdout.trim().lines().next().unwrap_or("").to_string()
};
display::cmd_done("", "execute", &payload, exit_code == 0, &result);
@ -401,36 +464,6 @@ async fn handle_message(
}
}
ServerMessage::ClickRequest { request_id, x, y, button } => {
let payload = format!("({x}, {y}) {button:?}");
display::cmd_start("🖱", "click", &payload);
match input::click(x, y, &button) {
Ok(()) => {
display::cmd_done("🖱", "click", &payload, true, "done");
ClientMessage::Ack { request_id }
}
Err(e) => {
display::cmd_done("🖱", "click", &payload, false, &format!("{e}"));
ClientMessage::Error { request_id, message: format!("Click at ({x},{y}) failed: {e}") }
}
}
}
ServerMessage::TypeRequest { request_id, text } => {
let payload = format!("{} chars", text.len());
display::cmd_start("", "type", &payload);
match input::type_text(&text) {
Ok(()) => {
display::cmd_done("", "type", &payload, true, "done");
ClientMessage::Ack { request_id }
}
Err(e) => {
display::cmd_done("", "type", &payload, false, &format!("{e}"));
ClientMessage::Error { request_id, message: format!("Type failed: {e}") }
}
}
}
ServerMessage::ListWindowsRequest { request_id } => {
display::cmd_start("🪟", "list windows", "");
match windows_mgmt::list_windows() {
@ -610,7 +643,7 @@ async fn handle_message(
if let Some(rid) = request_id {
ClientMessage::Ack { request_id: rid }
} else {
ClientMessage::Hello { label: None }
ClientMessage::Hello { label: String::new() }
}
}
}

View file

@ -1,4 +1,4 @@
use helios_common::protocol::WindowInfo;
use helios_common::protocol::{sanitize_label, WindowInfo};
// ── Windows implementation ──────────────────────────────────────────────────
@ -14,7 +14,6 @@ mod win_impl {
keybd_event, KEYEVENTF_KEYUP, VK_MENU,
};
// Collect HWNDs via EnumWindows callback
unsafe extern "system" fn enum_callback(hwnd: HWND, lparam: LPARAM) -> BOOL {
let list = &mut *(lparam.0 as *mut Vec<HWND>);
list.push(hwnd);
@ -38,19 +37,29 @@ mod win_impl {
String::from_utf16_lossy(&buf[..len as usize])
}
/// Generate a human-readable label from a window title.
/// E.g. "Google Chrome" -> "google_chrome", "Discord" -> "discord"
fn window_label(title: &str) -> String {
sanitize_label(title)
}
pub fn list_windows() -> Result<Vec<WindowInfo>, String> {
let hwnds = get_all_hwnds();
let mut windows = Vec::new();
for hwnd in hwnds {
let visible = unsafe { IsWindowVisible(hwnd).as_bool() };
let title = hwnd_title(hwnd);
// Only return visible windows with a non-empty title
if !visible || title.is_empty() {
continue;
}
let label = window_label(&title);
if label.is_empty() {
continue;
}
windows.push(WindowInfo {
id: hwnd.0 as u64,
title,
label,
visible: true,
});
}
@ -71,9 +80,6 @@ mod win_impl {
Ok(())
}
/// Bypass Windows Focus Stealing Prevention by sending a fake Alt keypress
/// before calling SetForegroundWindow. Without this, SetForegroundWindow
/// silently fails when the calling thread is not in the foreground.
unsafe fn force_foreground(hwnd: HWND) {
keybd_event(VK_MENU.0 as u8, 0, Default::default(), 0);
keybd_event(VK_MENU.0 as u8, 0, KEYEVENTF_KEYUP, 0);

View file

@ -1,27 +1,58 @@
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Information about a single window on the client machine
/// Information about a single window on the client machine.
/// `label` is a human-readable, lowercase identifier (e.g. "google_chrome", "discord").
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WindowInfo {
pub id: u64,
pub title: String,
pub label: String,
pub visible: bool,
}
/// Validate a device/window label: lowercase, no whitespace, only a-z 0-9 - _
pub fn is_valid_label(s: &str) -> bool {
!s.is_empty()
&& s.chars()
.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-' || c == '_')
}
/// Convert an arbitrary string into a valid label.
/// Lowercase, replace whitespace and invalid chars with '_', collapse runs.
pub fn sanitize_label(s: &str) -> String {
let mut result = String::with_capacity(s.len());
let mut prev_underscore = false;
for c in s.chars() {
if c.is_ascii_alphanumeric() {
result.push(c.to_ascii_lowercase());
prev_underscore = false;
} else if c == '-' {
result.push('-');
prev_underscore = false;
} else {
// Replace whitespace and other chars with _
if !prev_underscore && !result.is_empty() {
result.push('_');
prev_underscore = true;
}
}
}
// Trim trailing _
result.trim_end_matches('_').to_string()
}
/// Messages sent from the relay server to a connected client
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServerMessage {
/// Request a screenshot from the client
/// Request a full-screen screenshot
ScreenshotRequest { request_id: Uuid },
/// Capture a specific window by its HWND (works even if behind other windows)
/// Capture a specific window by its HWND
WindowScreenshotRequest { request_id: Uuid, window_id: u64 },
/// Fetch the last N lines of the client log file
LogsRequest { request_id: Uuid, lines: u32 },
/// Show a MessageBox on the client asking the user to do something.
/// Blocks until the user clicks OK — use this when you need the user
/// to perform a manual action before continuing.
/// Show a MessageBox on the client asking the user to do something
PromptRequest {
request_id: Uuid,
message: String,
@ -31,21 +62,8 @@ pub enum ServerMessage {
ExecRequest {
request_id: Uuid,
command: String,
/// Timeout in milliseconds. None = use client default (30s)
timeout_ms: Option<u64>,
},
/// Simulate a mouse click
ClickRequest {
request_id: Uuid,
x: i32,
y: i32,
button: MouseButton,
},
/// Type text on the client
TypeRequest {
request_id: Uuid,
text: String,
},
/// Acknowledge a client message
Ack { request_id: Uuid },
/// Server-side error response
@ -90,8 +108,8 @@ pub enum ServerMessage {
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ClientMessage {
/// Client registers itself with optional display name
Hello { label: Option<String> },
/// Client registers itself with its device label
Hello { label: String },
/// Response to a screenshot request — base64-encoded PNG
ScreenshotResponse {
request_id: Uuid,
@ -106,7 +124,7 @@ pub enum ClientMessage {
stderr: String,
exit_code: i32,
},
/// Generic acknowledgement for click/type/minimize-all/focus/maximize
/// Generic acknowledgement
Ack { request_id: Uuid },
/// Client error response
Error {
@ -137,29 +155,33 @@ pub enum ClientMessage {
},
/// Response to a clipboard-get request
ClipboardGetResponse { request_id: Uuid, text: String },
/// Response to a prompt request — contains the user's typed answer
/// Response to a prompt request
PromptResponse { request_id: Uuid, answer: String },
}
/// Mouse button variants
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MouseButton {
Left,
Right,
Middle,
}
impl Default for MouseButton {
fn default() -> Self {
MouseButton::Left
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_valid_labels() {
assert!(is_valid_label("moritz_pc"));
assert!(is_valid_label("my-desktop"));
assert!(is_valid_label("pc01"));
assert!(!is_valid_label("Moritz PC"));
assert!(!is_valid_label(""));
assert!(!is_valid_label("has spaces"));
assert!(!is_valid_label("UPPER"));
}
#[test]
fn test_sanitize_label() {
assert_eq!(sanitize_label("Moritz PC"), "moritz_pc");
assert_eq!(sanitize_label("My Desktop!!"), "my_desktop");
assert_eq!(sanitize_label("hello-world"), "hello-world");
assert_eq!(sanitize_label("DESKTOP-ABC123"), "desktop-abc123");
}
#[test]
fn test_server_message_serialization() {
let msg = ServerMessage::ExecRequest {
@ -174,25 +196,9 @@ mod tests {
#[test]
fn test_client_message_serialization() {
let msg = ClientMessage::Hello { label: Some("test-pc".into()) };
let msg = ClientMessage::Hello { label: "test-pc".into() };
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains("hello"));
assert!(json.contains("test-pc"));
}
#[test]
fn test_roundtrip() {
let msg = ClientMessage::ExecResponse {
request_id: Uuid::nil(),
stdout: "hello\n".into(),
stderr: String::new(),
exit_code: 0,
};
let json = serde_json::to_string(&msg).unwrap();
let decoded: ClientMessage = serde_json::from_str(&json).unwrap();
match decoded {
ClientMessage::ExecResponse { exit_code, .. } => assert_eq!(exit_code, 0),
_ => panic!("wrong variant"),
}
}
}

View file

@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize};
use uuid::Uuid;
use tracing::error;
use helios_common::protocol::{ClientMessage, MouseButton, ServerMessage};
use helios_common::protocol::{ClientMessage, ServerMessage};
use crate::AppState;
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
@ -21,33 +21,29 @@ pub struct ErrorBody {
pub error: String,
}
fn not_found(session_id: &str) -> (StatusCode, Json<ErrorBody>) {
fn not_found(label: &str) -> (StatusCode, Json<ErrorBody>) {
(
StatusCode::NOT_FOUND,
Json(ErrorBody {
error: format!("Session '{session_id}' not found or not connected"),
error: format!("Device '{label}' not found or not connected"),
}),
)
}
fn timeout_error(session_id: &str, op: &str) -> (StatusCode, Json<ErrorBody>) {
fn timeout_error(label: &str, op: &str) -> (StatusCode, Json<ErrorBody>) {
(
StatusCode::GATEWAY_TIMEOUT,
Json(ErrorBody {
error: format!(
"Timed out waiting for client response (session='{session_id}', op='{op}')"
),
error: format!("Timed out waiting for client response (device='{label}', op='{op}')"),
}),
)
}
fn send_error(session_id: &str, op: &str) -> (StatusCode, Json<ErrorBody>) {
fn send_error(label: &str, op: &str) -> (StatusCode, Json<ErrorBody>) {
(
StatusCode::BAD_GATEWAY,
Json(ErrorBody {
error: format!(
"Failed to send command to client — client may have disconnected (session='{session_id}', op='{op}')"
),
error: format!("Failed to send command to client — may have disconnected (device='{label}', op='{op}')"),
}),
)
}
@ -56,19 +52,19 @@ fn send_error(session_id: &str, op: &str) -> (StatusCode, Json<ErrorBody>) {
async fn dispatch<F>(
state: &AppState,
session_id: &str,
label: &str,
op: &str,
make_msg: F,
) -> Result<ClientMessage, (StatusCode, Json<ErrorBody>)>
where
F: FnOnce(Uuid) -> ServerMessage,
{
dispatch_with_timeout(state, session_id, op, make_msg, REQUEST_TIMEOUT).await
dispatch_with_timeout(state, label, op, make_msg, REQUEST_TIMEOUT).await
}
async fn dispatch_with_timeout<F>(
state: &AppState,
session_id: &str,
label: &str,
op: &str,
make_msg: F,
timeout: Duration,
@ -76,50 +72,62 @@ async fn dispatch_with_timeout<F>(
where
F: FnOnce(Uuid) -> ServerMessage,
{
let id = session_id.parse::<Uuid>().map_err(|_| {
(
StatusCode::BAD_REQUEST,
Json(ErrorBody {
error: format!("Invalid session id: '{session_id}'"),
}),
)
})?;
let tx = state
.sessions
.get_cmd_tx(&id)
.ok_or_else(|| not_found(session_id))?;
.get_cmd_tx(label)
.ok_or_else(|| not_found(label))?;
let request_id = Uuid::new_v4();
let rx = state.sessions.register_pending(request_id);
let msg = make_msg(request_id);
tx.send(msg).await.map_err(|e| {
error!("Channel send failed for session={session_id}, op={op}: {e}");
send_error(session_id, op)
error!("Channel send failed for device={label}, op={op}: {e}");
send_error(label, op)
})?;
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(response)) => Ok(response),
Ok(Err(_)) => Err(send_error(session_id, op)),
Err(_) => Err(timeout_error(session_id, op)),
Ok(Err(_)) => Err(send_error(label, op)),
Err(_) => Err(timeout_error(label, op)),
}
}
// ── Handlers ─────────────────────────────────────────────────────────────────
/// GET /sessions — list all connected clients
pub async fn list_sessions(State(state): State<AppState>) -> Json<serde_json::Value> {
let sessions = state.sessions.list();
Json(serde_json::json!({ "sessions": sessions }))
/// GET /devices — list all connected clients
pub async fn list_devices(State(state): State<AppState>) -> Json<serde_json::Value> {
let devices = state.sessions.list();
Json(serde_json::json!({ "devices": devices }))
}
/// POST /sessions/:id/windows/:window_id/screenshot
pub async fn window_screenshot(
Path((session_id, window_id)): Path<(String, u64)>,
/// POST /devices/:label/screenshot — full screen screenshot
pub async fn request_screenshot(
Path(label): Path<String>,
State(state): State<AppState>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "window_screenshot", |rid| {
match dispatch(&state, &label, "screenshot", |rid| {
ServerMessage::ScreenshotRequest { request_id: rid }
}).await {
Ok(ClientMessage::ScreenshotResponse { image_base64, width, height, .. }) => (
StatusCode::OK,
Json(serde_json::json!({ "image_base64": image_base64, "width": width, "height": height })),
).into_response(),
Ok(ClientMessage::Error { message, .. }) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": message })),
).into_response(),
Ok(_) => (StatusCode::BAD_GATEWAY, Json(serde_json::json!({ "error": "Unexpected response" }))).into_response(),
Err(e) => e.into_response(),
}
}
/// POST /devices/:label/windows/:window_id/screenshot
pub async fn window_screenshot(
Path((label, window_id)): Path<(String, u64)>,
State(state): State<AppState>,
) -> impl IntoResponse {
match dispatch(&state, &label, "window_screenshot", |rid| {
ServerMessage::WindowScreenshotRequest { request_id: rid, window_id }
}).await {
Ok(ClientMessage::ScreenshotResponse { image_base64, width, height, .. }) => (
@ -135,14 +143,14 @@ pub async fn window_screenshot(
}
}
/// GET /sessions/:id/logs?lines=100
/// GET /devices/:label/logs?lines=100
pub async fn logs(
Path(session_id): Path<String>,
Path(label): Path<String>,
Query(params): Query<std::collections::HashMap<String, String>>,
State(state): State<AppState>,
) -> impl IntoResponse {
let lines: u32 = params.get("lines").and_then(|v| v.parse().ok()).unwrap_or(100);
match dispatch(&state, &session_id, "logs", |rid| {
match dispatch(&state, &label, "logs", |rid| {
ServerMessage::LogsRequest { request_id: rid, lines }
}).await {
Ok(ClientMessage::LogsResponse { content, log_path, .. }) => (
@ -158,216 +166,95 @@ pub async fn logs(
}
}
/// POST /sessions/:id/screenshot
pub async fn request_screenshot(
Path(session_id): Path<String>,
State(state): State<AppState>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "screenshot", |rid| {
ServerMessage::ScreenshotRequest { request_id: rid }
})
.await
{
Ok(ClientMessage::ScreenshotResponse {
image_base64,
width,
height,
..
}) => (
StatusCode::OK,
Json(serde_json::json!({
"image_base64": image_base64,
"width": width,
"height": height,
})),
)
.into_response(),
Ok(ClientMessage::Error { message, .. }) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": message })),
)
.into_response(),
Ok(_) => (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({ "error": "Unexpected response from client" })),
)
.into_response(),
Err(e) => e.into_response(),
}
}
/// POST /sessions/:id/exec
/// POST /devices/:label/exec
#[derive(Deserialize)]
pub struct ExecBody {
pub command: String,
/// Optional timeout in milliseconds (default: 30000). Use higher values for
/// long-running commands like downloads.
pub timeout_ms: Option<u64>,
}
pub async fn request_exec(
Path(session_id): Path<String>,
Path(label): Path<String>,
State(state): State<AppState>,
Json(body): Json<ExecBody>,
) -> impl IntoResponse {
// Server-side wait must be at least as long as the client timeout + buffer
let server_timeout = body.timeout_ms
.map(|ms| std::time::Duration::from_millis(ms + 5_000))
.map(|ms| Duration::from_millis(ms + 5_000))
.unwrap_or(REQUEST_TIMEOUT);
match dispatch_with_timeout(&state, &session_id, "exec", |rid| ServerMessage::ExecRequest {
match dispatch_with_timeout(&state, &label, "exec", |rid| ServerMessage::ExecRequest {
request_id: rid,
command: body.command.clone(),
timeout_ms: body.timeout_ms,
}, server_timeout)
.await
{
Ok(ClientMessage::ExecResponse {
stdout,
stderr,
exit_code,
..
}) => (
}, server_timeout).await {
Ok(ClientMessage::ExecResponse { stdout, stderr, exit_code, .. }) => (
StatusCode::OK,
Json(serde_json::json!({
"stdout": stdout,
"stderr": stderr,
"exit_code": exit_code,
})),
)
.into_response(),
Json(serde_json::json!({ "stdout": stdout, "stderr": stderr, "exit_code": exit_code })),
).into_response(),
Ok(ClientMessage::Error { message, .. }) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": message })),
)
.into_response(),
Ok(_) => (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({ "error": "Unexpected response from client" })),
)
.into_response(),
).into_response(),
Ok(_) => (StatusCode::BAD_GATEWAY, Json(serde_json::json!({ "error": "Unexpected response" }))).into_response(),
Err(e) => e.into_response(),
}
}
/// POST /sessions/:id/click
#[derive(Deserialize)]
pub struct ClickBody {
pub x: i32,
pub y: i32,
#[serde(default)]
pub button: MouseButton,
}
pub async fn request_click(
Path(session_id): Path<String>,
State(state): State<AppState>,
Json(body): Json<ClickBody>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "click", |rid| ServerMessage::ClickRequest {
request_id: rid,
x: body.x,
y: body.y,
button: body.button.clone(),
})
.await
{
Ok(_) => (StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response(),
Err(e) => e.into_response(),
}
}
/// POST /sessions/:id/type
#[derive(Deserialize)]
pub struct TypeBody {
pub text: String,
}
pub async fn request_type(
Path(session_id): Path<String>,
State(state): State<AppState>,
Json(body): Json<TypeBody>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "type", |rid| ServerMessage::TypeRequest {
request_id: rid,
text: body.text.clone(),
})
.await
{
Ok(_) => (StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response(),
Err(e) => e.into_response(),
}
}
/// GET /sessions/:id/windows
/// GET /devices/:label/windows
pub async fn list_windows(
Path(session_id): Path<String>,
Path(label): Path<String>,
State(state): State<AppState>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "list_windows", |rid| {
match dispatch(&state, &label, "list_windows", |rid| {
ServerMessage::ListWindowsRequest { request_id: rid }
})
.await
{
}).await {
Ok(ClientMessage::ListWindowsResponse { windows, .. }) => (
StatusCode::OK,
Json(serde_json::json!({ "windows": windows })),
)
.into_response(),
).into_response(),
Ok(ClientMessage::Error { message, .. }) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": message })),
)
.into_response(),
Ok(_) => (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({ "error": "Unexpected response from client" })),
)
.into_response(),
).into_response(),
Ok(_) => (StatusCode::BAD_GATEWAY, Json(serde_json::json!({ "error": "Unexpected response" }))).into_response(),
Err(e) => e.into_response(),
}
}
/// POST /sessions/:id/windows/minimize-all
/// POST /devices/:label/windows/minimize-all
pub async fn minimize_all(
Path(session_id): Path<String>,
Path(label): Path<String>,
State(state): State<AppState>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "minimize_all", |rid| {
match dispatch(&state, &label, "minimize_all", |rid| {
ServerMessage::MinimizeAllRequest { request_id: rid }
})
.await
{
}).await {
Ok(_) => (StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response(),
Err(e) => e.into_response(),
}
}
/// POST /sessions/:id/windows/:window_id/focus
/// POST /devices/:label/windows/:window_id/focus
pub async fn focus_window(
Path((session_id, window_id)): Path<(String, u64)>,
Path((label, window_id)): Path<(String, u64)>,
State(state): State<AppState>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "focus_window", |rid| {
match dispatch(&state, &label, "focus_window", |rid| {
ServerMessage::FocusWindowRequest { request_id: rid, window_id }
})
.await
{
}).await {
Ok(_) => (StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response(),
Err(e) => e.into_response(),
}
}
/// POST /sessions/:id/windows/:window_id/maximize
/// POST /devices/:label/windows/:window_id/maximize
pub async fn maximize_and_focus(
Path((session_id, window_id)): Path<(String, u64)>,
Path((label, window_id)): Path<(String, u64)>,
State(state): State<AppState>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "maximize_and_focus", |rid| {
match dispatch(&state, &label, "maximize_and_focus", |rid| {
ServerMessage::MaximizeAndFocusRequest { request_id: rid, window_id }
})
.await
{
}).await {
Ok(_) => (StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response(),
Err(e) => e.into_response(),
}
@ -376,41 +263,32 @@ pub async fn maximize_and_focus(
/// GET /version — server version (public, no auth)
pub async fn server_version() -> impl IntoResponse {
Json(serde_json::json!({
"version": env!("CARGO_PKG_VERSION"),
"commit": env!("GIT_COMMIT"),
}))
}
/// GET /sessions/:id/version — client version
/// GET /devices/:label/version — client version
pub async fn client_version(
Path(session_id): Path<String>,
Path(label): Path<String>,
State(state): State<AppState>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "version", |rid| {
match dispatch(&state, &label, "version", |rid| {
ServerMessage::VersionRequest { request_id: rid }
})
.await
{
}).await {
Ok(ClientMessage::VersionResponse { version, commit, .. }) => (
StatusCode::OK,
Json(serde_json::json!({ "version": version, "commit": commit })),
)
.into_response(),
).into_response(),
Ok(ClientMessage::Error { message, .. }) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": message })),
)
.into_response(),
Ok(_) => (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({ "error": "Unexpected response from client" })),
)
.into_response(),
).into_response(),
Ok(_) => (StatusCode::BAD_GATEWAY, Json(serde_json::json!({ "error": "Unexpected response" }))).into_response(),
Err(e) => e.into_response(),
}
}
/// POST /sessions/:id/upload
/// POST /devices/:label/upload
#[derive(Deserialize)]
pub struct UploadBody {
pub path: String,
@ -418,59 +296,49 @@ pub struct UploadBody {
}
pub async fn upload_file(
Path(session_id): Path<String>,
Path(label): Path<String>,
State(state): State<AppState>,
Json(body): Json<UploadBody>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "upload", |rid| ServerMessage::UploadRequest {
match dispatch(&state, &label, "upload", |rid| ServerMessage::UploadRequest {
request_id: rid,
path: body.path.clone(),
content_base64: body.content_base64.clone(),
})
.await
{
}).await {
Ok(_) => (StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response(),
Err(e) => e.into_response(),
}
}
/// GET /sessions/:id/download?path=...
/// GET /devices/:label/download?path=...
#[derive(Deserialize)]
pub struct DownloadQuery {
pub path: String,
}
pub async fn download_file(
Path(session_id): Path<String>,
Path(label): Path<String>,
State(state): State<AppState>,
Query(query): Query<DownloadQuery>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "download", |rid| ServerMessage::DownloadRequest {
match dispatch(&state, &label, "download", |rid| ServerMessage::DownloadRequest {
request_id: rid,
path: query.path.clone(),
})
.await
{
}).await {
Ok(ClientMessage::DownloadResponse { content_base64, size, .. }) => (
StatusCode::OK,
Json(serde_json::json!({ "content_base64": content_base64, "size": size })),
)
.into_response(),
).into_response(),
Ok(ClientMessage::Error { message, .. }) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": message })),
)
.into_response(),
Ok(_) => (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({ "error": "Unexpected response from client" })),
)
.into_response(),
).into_response(),
Ok(_) => (StatusCode::BAD_GATEWAY, Json(serde_json::json!({ "error": "Unexpected response" }))).into_response(),
Err(e) => e.into_response(),
}
}
/// POST /sessions/:id/run
/// POST /devices/:label/run
#[derive(Deserialize)]
pub struct RunBody {
pub program: String,
@ -479,73 +347,61 @@ pub struct RunBody {
}
pub async fn run_program(
Path(session_id): Path<String>,
Path(label): Path<String>,
State(state): State<AppState>,
Json(body): Json<RunBody>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "run", |rid| ServerMessage::RunRequest {
match dispatch(&state, &label, "run", |rid| ServerMessage::RunRequest {
request_id: rid,
program: body.program.clone(),
args: body.args.clone(),
})
.await
{
}).await {
Ok(_) => (StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response(),
Err(e) => e.into_response(),
}
}
/// GET /sessions/:id/clipboard
/// GET /devices/:label/clipboard
pub async fn clipboard_get(
Path(session_id): Path<String>,
Path(label): Path<String>,
State(state): State<AppState>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "clipboard_get", |rid| {
match dispatch(&state, &label, "clipboard_get", |rid| {
ServerMessage::ClipboardGetRequest { request_id: rid }
})
.await
{
}).await {
Ok(ClientMessage::ClipboardGetResponse { text, .. }) => (
StatusCode::OK,
Json(serde_json::json!({ "text": text })),
)
.into_response(),
).into_response(),
Ok(ClientMessage::Error { message, .. }) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": message })),
)
.into_response(),
Ok(_) => (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({ "error": "Unexpected response from client" })),
)
.into_response(),
).into_response(),
Ok(_) => (StatusCode::BAD_GATEWAY, Json(serde_json::json!({ "error": "Unexpected response" }))).into_response(),
Err(e) => e.into_response(),
}
}
/// POST /sessions/:id/clipboard
/// POST /devices/:label/clipboard
#[derive(Deserialize)]
pub struct ClipboardSetBody {
pub text: String,
}
pub async fn clipboard_set(
Path(session_id): Path<String>,
Path(label): Path<String>,
State(state): State<AppState>,
Json(body): Json<ClipboardSetBody>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "clipboard_set", |rid| {
match dispatch(&state, &label, "clipboard_set", |rid| {
ServerMessage::ClipboardSetRequest { request_id: rid, text: body.text.clone() }
})
.await
{
}).await {
Ok(_) => (StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response(),
Err(e) => e.into_response(),
}
}
/// POST /sessions/:id/prompt
/// POST /devices/:label/prompt
#[derive(Deserialize)]
pub struct PromptBody {
pub message: String,
@ -553,17 +409,15 @@ pub struct PromptBody {
}
pub async fn prompt_user(
Path(session_id): Path<String>,
Path(label): Path<String>,
State(state): State<AppState>,
Json(body): Json<PromptBody>,
) -> impl IntoResponse {
match dispatch(&state, &session_id, "prompt", |rid| ServerMessage::PromptRequest {
match dispatch(&state, &label, "prompt", |rid| ServerMessage::PromptRequest {
request_id: rid,
message: body.message.clone(),
title: body.title.clone(),
})
.await
{
}).await {
Ok(ClientMessage::PromptResponse { answer, .. }) => {
(StatusCode::OK, Json(serde_json::json!({ "ok": true, "answer": answer }))).into_response()
}
@ -571,32 +425,3 @@ pub async fn prompt_user(
Err(e) => e.into_response(),
}
}
/// POST /sessions/:id/label
#[derive(Deserialize)]
pub struct LabelBody {
pub label: String,
}
pub async fn set_label(
Path(session_id): Path<String>,
State(state): State<AppState>,
Json(body): Json<LabelBody>,
) -> impl IntoResponse {
let id = match session_id.parse::<Uuid>() {
Ok(id) => id,
Err(_) => {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": format!("Invalid session id: '{session_id}'") })),
)
.into_response()
}
};
if state.sessions.set_label(&id, body.label.clone()) {
(StatusCode::OK, Json(serde_json::json!({ "ok": true }))).into_response()
} else {
not_found(&session_id).into_response()
}
}

View file

@ -31,7 +31,7 @@ async fn main() -> anyhow::Result<()> {
.init();
const GIT_COMMIT: &str = env!("GIT_COMMIT");
info!("helios-server v{} ({})", env!("CARGO_PKG_VERSION"), GIT_COMMIT);
info!("helios-server ({GIT_COMMIT})");
let api_key = std::env::var("HELIOS_API_KEY")
.unwrap_or_else(|_| "dev-secret".to_string());
@ -45,25 +45,22 @@ async fn main() -> anyhow::Result<()> {
};
let protected = Router::new()
.route("/sessions", get(api::list_sessions))
.route("/sessions/:id/screenshot", post(api::request_screenshot))
.route("/sessions/:id/exec", post(api::request_exec))
.route("/sessions/:id/click", post(api::request_click))
.route("/sessions/:id/type", post(api::request_type))
.route("/sessions/:id/label", post(api::set_label))
.route("/sessions/:id/prompt", post(api::prompt_user))
.route("/sessions/:id/windows", get(api::list_windows))
.route("/sessions/:id/windows/minimize-all", post(api::minimize_all))
.route("/sessions/:id/logs", get(api::logs))
.route("/sessions/:id/windows/:window_id/screenshot", post(api::window_screenshot))
.route("/sessions/:id/windows/:window_id/focus", post(api::focus_window))
.route("/sessions/:id/windows/:window_id/maximize", post(api::maximize_and_focus))
.route("/sessions/:id/version", get(api::client_version))
.route("/sessions/:id/upload", post(api::upload_file))
.route("/sessions/:id/download", get(api::download_file))
.route("/sessions/:id/run", post(api::run_program))
.route("/sessions/:id/clipboard", get(api::clipboard_get))
.route("/sessions/:id/clipboard", post(api::clipboard_set))
.route("/devices", get(api::list_devices))
.route("/devices/:label/screenshot", post(api::request_screenshot))
.route("/devices/:label/exec", post(api::request_exec))
.route("/devices/:label/prompt", post(api::prompt_user))
.route("/devices/:label/windows", get(api::list_windows))
.route("/devices/:label/windows/minimize-all", post(api::minimize_all))
.route("/devices/:label/logs", get(api::logs))
.route("/devices/:label/windows/:window_id/screenshot", post(api::window_screenshot))
.route("/devices/:label/windows/:window_id/focus", post(api::focus_window))
.route("/devices/:label/windows/:window_id/maximize", post(api::maximize_and_focus))
.route("/devices/:label/version", get(api::client_version))
.route("/devices/:label/upload", post(api::upload_file))
.route("/devices/:label/download", get(api::download_file))
.route("/devices/:label/run", post(api::run_program))
.route("/devices/:label/clipboard", get(api::clipboard_get))
.route("/devices/:label/clipboard", post(api::clipboard_set))
.layer(middleware::from_fn_with_state(state.clone(), require_api_key));
let app = Router::new()

View file

@ -4,11 +4,11 @@ use uuid::Uuid;
use serde::Serialize;
use helios_common::protocol::{ClientMessage, ServerMessage};
/// Represents one connected remote client
/// Represents one connected remote client.
/// The device label is the sole identifier — no session UUIDs exposed externally.
#[derive(Debug, Clone)]
pub struct Session {
pub id: Uuid,
pub label: Option<String>,
pub label: String,
/// Channel to send commands to the WS handler for this session
pub cmd_tx: mpsc::Sender<ServerMessage>,
}
@ -16,22 +16,20 @@ pub struct Session {
/// Serializable view of a session for the REST API
#[derive(Debug, Serialize)]
pub struct SessionInfo {
pub id: Uuid,
pub label: Option<String>,
pub label: String,
}
impl From<&Session> for SessionInfo {
fn from(s: &Session) -> Self {
SessionInfo {
id: s.id,
label: s.label.clone(),
}
}
}
pub struct SessionStore {
/// Active sessions by ID
sessions: DashMap<Uuid, Session>,
/// Active sessions keyed by device label
sessions: DashMap<String, Session>,
/// Pending request callbacks by request_id
pending: DashMap<Uuid, oneshot::Sender<ClientMessage>>,
}
@ -45,24 +43,15 @@ impl SessionStore {
}
pub fn insert(&self, session: Session) {
self.sessions.insert(session.id, session);
self.sessions.insert(session.label.clone(), session);
}
pub fn remove(&self, id: &Uuid) {
self.sessions.remove(id);
pub fn remove(&self, label: &str) {
self.sessions.remove(label);
}
pub fn get_cmd_tx(&self, id: &Uuid) -> Option<mpsc::Sender<ServerMessage>> {
self.sessions.get(id).map(|s| s.cmd_tx.clone())
}
pub fn set_label(&self, id: &Uuid, label: String) -> bool {
if let Some(mut s) = self.sessions.get_mut(id) {
s.label = Some(label);
true
} else {
false
}
pub fn get_cmd_tx(&self, label: &str) -> Option<mpsc::Sender<ServerMessage>> {
self.sessions.get(label).map(|s| s.cmd_tx.clone())
}
pub fn list(&self) -> Vec<SessionInfo> {
@ -77,7 +66,6 @@ impl SessionStore {
}
/// Deliver a client response to the waiting request handler.
/// Returns true if the request was found and resolved.
pub fn resolve_pending(&self, request_id: Uuid, msg: ClientMessage) -> bool {
if let Some((_, tx)) = self.pending.remove(&request_id) {
let _ = tx.send(msg);

View file

@ -19,32 +19,57 @@ pub async fn ws_upgrade(
}
async fn handle_socket(socket: WebSocket, state: AppState) {
let session_id = Uuid::new_v4();
let (cmd_tx, mut cmd_rx) = mpsc::channel::<ServerMessage>(64);
let (mut ws_tx, mut ws_rx) = socket.split();
// Register session (label filled in on Hello)
// Wait for the Hello message to get the device label
let label = loop {
match ws_rx.next().await {
Some(Ok(Message::Text(text))) => {
match serde_json::from_str::<ClientMessage>(&text) {
Ok(ClientMessage::Hello { label }) => {
if label.is_empty() {
warn!("Client sent empty label, disconnecting");
return;
}
break label;
}
Ok(_) => {
warn!("Expected Hello as first message, got something else");
return;
}
Err(e) => {
warn!("Invalid JSON on handshake: {e}");
return;
}
}
}
Some(Ok(Message::Close(_))) | None => return,
_ => continue,
}
};
// Register session by label
let session = Session {
id: session_id,
label: None,
label: label.clone(),
cmd_tx,
};
state.sessions.insert(session);
info!("Client connected: session={session_id}");
let (mut ws_tx, mut ws_rx) = socket.split();
info!("Client connected: device={label}");
// Spawn task: forward server commands → WS
let label_clone = label.clone();
let send_task = tokio::spawn(async move {
while let Some(msg) = cmd_rx.recv().await {
match serde_json::to_string(&msg) {
Ok(json) => {
if let Err(e) = ws_tx.send(Message::Text(json.into())).await {
error!("WS send error for session={session_id}: {e}");
error!("WS send error for device={label_clone}: {e}");
break;
}
}
Err(e) => {
error!("Serialization error for session={session_id}: {e}");
error!("Serialization error for device={label_clone}: {e}");
}
}
}
@ -55,36 +80,33 @@ async fn handle_socket(socket: WebSocket, state: AppState) {
match result {
Ok(Message::Text(text)) => {
match serde_json::from_str::<ClientMessage>(&text) {
Ok(msg) => handle_client_message(session_id, msg, &state).await,
Ok(msg) => handle_client_message(&label, msg, &state).await,
Err(e) => {
warn!("Invalid JSON from session={session_id}: {e} | raw={text}");
warn!("Invalid JSON from device={label}: {e}");
}
}
}
Ok(Message::Close(_)) => {
info!("Client disconnected gracefully: session={session_id}");
info!("Client disconnected gracefully: device={label}");
break;
}
Ok(Message::Ping(_)) | Ok(Message::Pong(_)) | Ok(Message::Binary(_)) => {}
Err(e) => {
error!("WS receive error for session={session_id}: {e}");
error!("WS receive error for device={label}: {e}");
break;
}
}
}
send_task.abort();
state.sessions.remove(&session_id);
info!("Session cleaned up: session={session_id}");
state.sessions.remove(&label);
info!("Session cleaned up: device={label}");
}
async fn handle_client_message(session_id: Uuid, msg: ClientMessage, state: &AppState) {
async fn handle_client_message(label: &str, msg: ClientMessage, state: &AppState) {
match &msg {
ClientMessage::Hello { label } => {
if let Some(lbl) = label {
state.sessions.set_label(&session_id, lbl.clone());
}
debug!("Hello from session={session_id}, label={label:?}");
ClientMessage::Hello { .. } => {
debug!("Duplicate Hello from device={label}, ignoring");
}
ClientMessage::ScreenshotResponse { request_id, .. }
| ClientMessage::ExecResponse { request_id, .. }
@ -98,7 +120,7 @@ async fn handle_client_message(session_id: Uuid, msg: ClientMessage, state: &App
| ClientMessage::Error { request_id, .. } => {
let rid = *request_id;
if !state.sessions.resolve_pending(rid, msg) {
warn!("No pending request for request_id={rid} (session={session_id})");
warn!("No pending request for request_id={rid} (device={label})");
}
}
}