feat: implement CSRF protection on RPC layer

Double-submit cookie pattern: backend generates csrf_token cookie on login
(non-HttpOnly so JS can read it), validates X-CSRF-Token header matches
cookie on all authenticated RPC calls. Returns 403 if missing/mismatched.
Frontend reads cookie and sends header automatically.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Dorian
2026-03-11 00:46:52 +00:00
parent b9efd1b3d0
commit a7653d4c8b
3 changed files with 100 additions and 14 deletions

View File

@@ -32,6 +32,7 @@ use serde::{Deserialize, Serialize};
use std::net::IpAddr;
use std::sync::{Arc, Mutex};
use tracing::{debug, error};
use rand::Rng;
#[derive(Debug, Deserialize)]
struct RpcRequest {
@@ -147,6 +148,37 @@ impl RpcHandler {
}
}
// CSRF protection: validate X-CSRF-Token header for authenticated methods
if !is_unauthenticated {
let csrf_cookie = extract_csrf_cookie(&parts.headers);
let csrf_header = parts
.headers
.get("x-csrf-token")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
match (&csrf_cookie, &csrf_header) {
(Some(cookie), Some(header)) if cookie == header => { /* valid */ }
_ => {
let rpc_resp = RpcResponse {
result: None,
error: Some(RpcError {
code: 403,
message: "CSRF token missing or invalid".to_string(),
data: None,
}),
};
let resp_body = serde_json::to_vec(&rpc_resp)
.context("Failed to serialize response")?;
return Ok(Response::builder()
.status(StatusCode::FORBIDDEN)
.header("Content-Type", "application/json")
.body(hyper::Body::from(resp_body))
.unwrap());
}
}
}
// Rate limit login attempts
if rpc_req.method == "auth.login" {
let client_ip = extract_client_ip(&parts.headers);
@@ -388,12 +420,19 @@ impl RpcHandler {
if let Ok(Some(totp_data)) = self.auth_manager.get_totp_data().await {
if let Ok(secret) = crate::totp::decrypt_secret(&totp_data, password) {
let token = self.session_store.create_pending(secret).await;
response.headers_mut().insert(
let csrf_token = generate_csrf_token();
response.headers_mut().append(
"Set-Cookie",
format!("session={}; HttpOnly; SameSite=Strict; Path=/{}", token, self.cookie_suffix())
.parse()
.unwrap(),
);
response.headers_mut().append(
"Set-Cookie",
format!("csrf_token={}; SameSite=Strict; Path=/{}", csrf_token, self.cookie_suffix())
.parse()
.unwrap(),
);
// Override the response body to indicate TOTP is required
let totp_body = serde_json::json!({
"result": { "requires_totp": true },
@@ -407,31 +446,42 @@ impl RpcHandler {
} else {
// No 2FA: create a full session immediately
let token = self.session_store.create().await;
response.headers_mut().insert(
let csrf_token = generate_csrf_token();
response.headers_mut().append(
"Set-Cookie",
format!("session={}; HttpOnly; SameSite=Strict; Path=/{}", token, self.cookie_suffix())
.parse()
.unwrap(),
);
response.headers_mut().append(
"Set-Cookie",
format!("csrf_token={}; SameSite=Strict; Path=/{}", csrf_token, self.cookie_suffix())
.parse()
.unwrap(),
);
}
}
// On successful TOTP verification, the session is already upgraded to full
// (handled inside handle_login_totp/handle_login_backup)
// On logout, invalidate session and expire the cookie
// On logout, invalidate session and expire cookies
if rpc_req.method == "auth.logout" {
if let Some(token) = &session_token {
self.session_store.remove(token).await;
}
let logout_cookie = if self.config.dev_mode {
"session=; HttpOnly; SameSite=Strict; Path=/; Max-Age=0".to_string()
} else {
"session=; HttpOnly; SameSite=Strict; Path=/; Max-Age=0; Secure".to_string()
};
response.headers_mut().insert(
let secure_suffix = if self.config.dev_mode { "" } else { "; Secure" };
response.headers_mut().append(
"Set-Cookie",
logout_cookie.parse().unwrap(),
format!("session=; HttpOnly; SameSite=Strict; Path=/; Max-Age=0{}", secure_suffix)
.parse()
.unwrap(),
);
response.headers_mut().append(
"Set-Cookie",
format!("csrf_token=; SameSite=Strict; Path=/; Max-Age=0{}", secure_suffix)
.parse()
.unwrap(),
);
}
@@ -448,6 +498,31 @@ impl RpcHandler {
}
}
/// Generate a random CSRF token (32-byte hex string).
fn generate_csrf_token() -> String {
let mut bytes = [0u8; 32];
rand::thread_rng().fill(&mut bytes);
hex::encode(bytes)
}
/// Extract the csrf_token cookie value from headers.
fn extract_csrf_cookie(headers: &hyper::HeaderMap) -> Option<String> {
for value in headers.get_all("cookie") {
if let Ok(s) = value.to_str() {
for part in s.split(';') {
let part = part.trim();
if let Some(val) = part.strip_prefix("csrf_token=") {
let val = val.trim();
if !val.is_empty() {
return Some(val.to_string());
}
}
}
}
}
None
}
/// Extract the client IP from request headers (X-Real-IP or X-Forwarded-For).
fn extract_client_ip(headers: &hyper::HeaderMap) -> IpAddr {
headers