Files
el/ui/crates/el-auth/src/jwt.rs
T

315 lines
11 KiB
Rust

//! JWT provider — sign and verify JSON Web Tokens.
//!
//! Uses HMAC-SHA256 (HS256) for signing. Does NOT use the `jsonwebtoken` crate
//! to keep dependencies minimal; instead it directly implements the subset of
//! the JWT spec it needs (HS256 compact serialization with a flat claims object).
//!
//! Format: base64url(header).base64url(payload).base64url(signature)
use crate::{AuthContext, AuthError, AuthProvider, AuthResult, AuthUser, RoleRegistry};
use hmac::{Hmac, Mac};
use sha2::Sha256;
type HmacSha256 = Hmac<Sha256>;
/// JWT claims payload.
///
/// Alongside the registered claims `sub`, `iat` and `exp`, the payload carries
/// the user's profile fields, their roles, and an optional `session_id`. The
/// session ID lets the Engram session node be validated on every request,
/// which enables server-side session invalidation even for stateless JWTs.
#[derive(Debug, Clone)]
pub struct JwtClaims {
    /// Subject claim: the user ID.
    pub sub: String,
    /// The user's email.
    pub email: String,
    /// The user's name.
    pub name: String,
    /// Role names carried by the token.
    pub roles: Vec<String>,
    /// The Engram Session node ID. Used by `EngramSessionStore` to validate
    /// the session graph node on every request, enabling server-side logout.
    pub session_id: Option<String>,
    /// Issued-at timestamp, in Unix seconds.
    pub iat: u64,
    /// Expiry timestamp, in Unix seconds.
    pub exp: u64,
}
impl JwtClaims {
    /// Build claims for `user` carrying `roles`, issued now and expiring
    /// `ttl_seconds` from now. No session ID is attached.
    ///
    /// The expiry uses saturating addition, so an enormous TTL yields a
    /// far-future `exp`. Plain `+` would overflow instead: a panic in debug
    /// builds, or a wrapped (already expired) timestamp in release builds.
    pub fn new(user: &AuthUser, roles: Vec<String>, ttl_seconds: u64) -> Self {
        let now = unix_now();
        Self {
            sub: user.id.clone(),
            email: user.email.clone(),
            name: user.name.clone(),
            roles,
            session_id: None,
            iat: now,
            exp: now.saturating_add(ttl_seconds),
        }
    }

    /// Create claims with an Engram session ID embedded.
    ///
    /// Identical to [`JwtClaims::new`] except that `session_id` is set.
    pub fn new_with_session(
        user: &AuthUser,
        roles: Vec<String>,
        session_id: impl Into<String>,
        ttl_seconds: u64,
    ) -> Self {
        let mut claims = Self::new(user, roles, ttl_seconds);
        claims.session_id = Some(session_id.into());
        claims
    }

    /// Return `true` once the current time has reached `exp`.
    ///
    /// RFC 7519 §4.1.4 accepts a token only while the current time is
    /// strictly *before* `exp`, so the `exp` second itself already counts as
    /// expired.
    pub fn is_expired(&self) -> bool {
        unix_now() >= self.exp
    }

    /// Serialize claims to compact JSON (manual, no serde dependency).
    ///
    /// `session_id` is emitted only when present. Every string value is
    /// JSON-escaped. Without escaping, a user-controlled field such as `name`
    /// or `email` could contain `"` and smuggle extra keys (e.g. a forged
    /// `"roles"` array or `"exp"`) into the payload. Those keys would appear
    /// before the genuine ones, and the first-match field extractors used by
    /// [`JwtClaims::from_json`] would pick them up.
    pub fn to_json(&self) -> String {
        /// Render `s` as a quoted, escaped JSON string literal.
        fn quoted(s: &str) -> String {
            let mut out = String::with_capacity(s.len() + 2);
            out.push('"');
            for c in s.chars() {
                match c {
                    '"' => out.push_str("\\\""),
                    '\\' => out.push_str("\\\\"),
                    '\n' => out.push_str("\\n"),
                    '\r' => out.push_str("\\r"),
                    '\t' => out.push_str("\\t"),
                    // JSON forbids raw control characters inside strings.
                    c if u32::from(c) < 0x20 => {
                        out.push_str(&format!("\\u{:04x}", u32::from(c)));
                    }
                    c => out.push(c),
                }
            }
            out.push('"');
            out
        }

        let roles_json = self
            .roles
            .iter()
            .map(|r| quoted(r))
            .collect::<Vec<_>>()
            .join(",");
        // Build the JSON manually, inserting session_id only when present.
        let mut json = format!(
            "{{\"sub\":{},\"email\":{},\"name\":{},\"roles\":[{}]",
            quoted(&self.sub),
            quoted(&self.email),
            quoted(&self.name),
            roles_json
        );
        if let Some(sid) = &self.session_id {
            json.push_str(&format!(",\"session_id\":{}", quoted(sid)));
        }
        json.push_str(&format!(",\"iat\":{},\"exp\":{}}}", self.iat, self.exp));
        json
    }

    /// Deserialize claims from JSON (manual parser).
    ///
    /// Only `sub` is required; `None` is returned when it is missing. Missing
    /// `email`/`name` default to empty strings, and missing `iat`/`exp`
    /// default to 0, so a payload without `exp` is reported as expired by
    /// [`JwtClaims::is_expired`]. `session_id` is optional.
    pub fn from_json(json: &str) -> Option<Self> {
        let sub = extract_str(json, "sub")?;
        let email = extract_str(json, "email").unwrap_or_default();
        let name = extract_str(json, "name").unwrap_or_default();
        let iat = extract_u64(json, "iat").unwrap_or(0);
        let exp = extract_u64(json, "exp").unwrap_or(0);
        let roles = extract_str_array(json, "roles");
        let session_id = extract_str(json, "session_id");
        Some(Self { sub, email, name, roles, session_id, iat, exp })
    }
}
/// JWT provider — issues and verifies HS256 JWTs.
///
/// Holds the HMAC signing secret together with the lifetime applied to
/// newly issued tokens.
pub struct JwtProvider {
    /// HMAC-SHA256 key used both to sign and to verify tokens.
    secret: Vec<u8>,
    /// Token TTL in seconds (default: 3600 = 1 hour).
    pub ttl_seconds: u64,
}
impl JwtProvider {
    /// Create a provider that signs with `secret` and issues tokens valid
    /// for 3600 seconds (one hour).
    pub fn new(secret: impl Into<Vec<u8>>) -> Self {
        Self { secret: secret.into(), ttl_seconds: 3600 }
    }

    /// Create a provider whose signing secret is read from the environment
    /// variable `env_var`.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::Config`] if the variable is unset or not valid
    /// Unicode, or if it is set to an empty string. An empty HMAC key would
    /// let anyone forge tokens, so it is rejected rather than silently used.
    pub fn from_env(env_var: &str) -> AuthResult<Self> {
        let secret = std::env::var(env_var)
            .map_err(|_| AuthError::Config(format!("env var {} not set", env_var)))?;
        if secret.is_empty() {
            return Err(AuthError::Config(format!("env var {} is empty", env_var)));
        }
        Ok(Self::new(secret.into_bytes()))
    }

    /// Return the provider with its token TTL replaced by `seconds`.
    #[must_use]
    pub fn with_ttl(mut self, seconds: u64) -> Self {
        self.ttl_seconds = seconds;
        self
    }

    /// Compute the base64url-encoded HMAC-SHA256 of `header_payload`.
    fn sign(&self, header_payload: &str) -> String {
        let mut mac = HmacSha256::new_from_slice(&self.secret)
            .expect("HMAC can take key of any size");
        mac.update(header_payload.as_bytes());
        base64url_encode(&mac.finalize().into_bytes())
    }

    /// Encode and sign a JWT token for `claims`.
    ///
    /// The header is always `{"alg":"HS256","typ":"JWT"}`.
    pub fn encode(&self, claims: &JwtClaims) -> String {
        let header = base64url_encode(b"{\"alg\":\"HS256\",\"typ\":\"JWT\"}");
        let payload = base64url_encode(claims.to_json().as_bytes());
        let header_payload = format!("{}.{}", header, payload);
        let signature = self.sign(&header_payload);
        format!("{}.{}", header_payload, signature)
    }

    /// Decode and verify a JWT token.
    ///
    /// The signature is compared in constant time *before* the payload is
    /// decoded, so unauthenticated input never reaches the claims parser.
    /// The header segment is covered by the signature but not otherwise
    /// inspected; every token is verified as HS256.
    ///
    /// # Errors
    ///
    /// - [`AuthError::TokenInvalid`] if the token does not have exactly three
    ///   `.`-separated segments, if the signature does not match, or if the
    ///   payload is not valid base64url, UTF-8, or claims JSON.
    /// - [`AuthError::TokenExpired`] if the claims have expired.
    pub fn decode(&self, token: &str) -> AuthResult<JwtClaims> {
        let mut segments = token.split('.');
        let (header, payload, signature) =
            match (segments.next(), segments.next(), segments.next(), segments.next()) {
                (Some(h), Some(p), Some(s), None) => (h, p, s),
                _ => return Err(AuthError::TokenInvalid("not a valid JWT".into())),
            };
        // The signing input is exactly the "header.payload" prefix of the token;
        // slicing it avoids re-allocating it with `format!`.
        let signing_input = &token[..header.len() + 1 + payload.len()];
        if !constant_time_eq(signature, &self.sign(signing_input)) {
            return Err(AuthError::TokenInvalid("signature mismatch".into()));
        }
        let payload_bytes = base64url_decode(payload)
            .ok_or_else(|| AuthError::TokenInvalid("payload decode failed".into()))?;
        let payload_str = String::from_utf8(payload_bytes)
            .map_err(|_| AuthError::TokenInvalid("payload not utf8".into()))?;
        let claims = JwtClaims::from_json(&payload_str)
            .ok_or_else(|| AuthError::TokenInvalid("claims parse failed".into()))?;
        if claims.is_expired() {
            return Err(AuthError::TokenExpired);
        }
        Ok(claims)
    }
}
impl AuthProvider for JwtProvider {
fn name(&self) -> &'static str {
"jwt"
}
fn verify(&self, token: &str) -> AuthResult<AuthContext> {
let claims = self.decode(token)?;
let user = AuthUser::new(&claims.sub, &claims.email, &claims.name);
Ok(AuthContext::authenticated(user, claims.roles, token))
}
fn issue(&self, user: AuthUser, _role_registry: &RoleRegistry) -> AuthResult<String> {
let claims = JwtClaims::new(&user, Vec::new(), self.ttl_seconds);
Ok(self.encode(&claims))
}
fn revoke(&self, _token: &str) -> AuthResult<()> {
// JWTs are stateless — revocation requires a blocklist.
// TODO: maintain a revocation list (in-memory or Redis).
Ok(())
}
}
// ── Crypto helpers ─────────────────────────────────────────────────────────────
/// Encode `input` as unpadded base64url (RFC 4648 §5 alphabet, no `=`).
///
/// Each 3-byte group becomes 4 characters; a trailing group of `k` bytes
/// (1 or 2) becomes `k + 1` characters.
fn base64url_encode(input: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    // Map the 6-bit field of `word` starting at bit `shift` to its character.
    let sextet = |word: u32, shift: u32| char::from(ALPHABET[((word >> shift) & 0x3f) as usize]);

    let mut encoded = String::with_capacity((input.len() * 4 + 2) / 3);
    for group in input.chunks(3) {
        // Pack the group big-endian into the top 24 bits; missing bytes are 0.
        let word = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &byte)| acc | (u32::from(byte) << (16 - 8 * i as u32)));
        for k in 0..=group.len() {
            encoded.push(sextet(word, 18 - 6 * k as u32));
        }
    }
    encoded
}
/// Decode unpadded base64url by translating it to standard base64.
///
/// `-`/`_` are mapped to `+`/`/`, and `=` padding is appended until the
/// length is a multiple of 4, before delegating to `base64_decode_standard`.
fn base64url_decode(input: &str) -> Option<Vec<u8>> {
    let standard: String = input
        .chars()
        .map(|c| match c {
            '-' => '+',
            '_' => '/',
            other => other,
        })
        .collect();
    let padding = (4 - standard.len() % 4) % 4;
    let padded = standard + &"=".repeat(padding);
    base64_decode_standard(&padded)
}
/// Decode standard (`+`/`/`) base64, with or without trailing `=` padding.
///
/// Returns `None` for characters outside the alphabet (including any `=`
/// that is not trailing, and any non-ASCII byte). It also returns `None` for
/// a dangling final character: a lone sextet carries only 6 bits, which
/// cannot form a byte, so such input is never valid base64. That case was
/// previously dropped silently.
fn base64_decode_standard(input: &str) -> Option<Vec<u8>> {
    // ASCII -> sextet lookup table; 255 marks bytes outside the alphabet.
    const TABLE: [u8; 128] = {
        let mut t = [255u8; 128];
        let chars = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        let mut i = 0usize;
        while i < chars.len() {
            t[chars[i] as usize] = i as u8;
            i += 1;
        }
        t
    };
    // Bytes >= 128 fall outside TABLE, so `get` rejects them too.
    let sextet = |byte: u8| TABLE.get(usize::from(byte)).copied().filter(|&v| v != 255);

    let bytes = input.trim_end_matches('=').as_bytes();
    let mut out = Vec::with_capacity(bytes.len() / 4 * 3 + 2);
    for group in bytes.chunks(4) {
        if group.len() == 1 {
            return None;
        }
        // Accumulate the group's sextets, then left-align them within 24 bits
        // so the decoded bytes always sit at bit offsets 16, 8 and 0.
        let mut bits = 0u32;
        for &b in group {
            bits = (bits << 6) | u32::from(sextet(b)?);
        }
        bits <<= 6 * (4 - group.len() as u32);
        let decoded = [(bits >> 16) as u8, (bits >> 8) as u8, bits as u8];
        // n sextets (2..=4) encode n - 1 whole bytes.
        out.extend_from_slice(&decoded[..group.len() - 1]);
    }
    Some(out)
}
/// Compare two strings without short-circuiting on the first differing byte.
///
/// Length is compared up front (lengths are not secret here); equal-length
/// inputs are XOR-accumulated over every byte, so the running time does not
/// depend on where they differ.
fn constant_time_eq(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }
    let mut diff = 0u8;
    for (x, y) in lhs.iter().zip(rhs) {
        diff |= x ^ y;
    }
    diff == 0
}
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Falls back to 0 if the system clock reports a time before the epoch.
fn unix_now() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs())
}
// ── Minimal JSON field extractors ─────────────────────────────────────────────
/// Extract the string value of the first `"key":"..."` in a flat JSON object.
///
/// JSON escapes in the value are decoded. Returns `None` if the key is
/// absent, or if the string is unterminated or contains a malformed escape.
fn extract_str(json: &str, key: &str) -> Option<String> {
    let pattern = format!("\"{}\":\"", key);
    let start = json.find(&pattern)? + pattern.len();
    parse_json_string(&json[start..]).map(|(value, _)| value)
}

/// Extract the unsigned integer value of the first `"key":<digits>`.
///
/// Returns `None` if the key is absent, is not immediately followed by
/// digits, or the number does not fit in a `u64`.
fn extract_u64(json: &str, key: &str) -> Option<u64> {
    let pattern = format!("\"{}\":", key);
    let start = json.find(&pattern)? + pattern.len();
    let rest = &json[start..];
    let end = rest.find(|c: char| !c.is_ascii_digit()).unwrap_or(rest.len());
    rest[..end].parse().ok()
}

/// Extract the first `"key":[...]` array of JSON strings.
///
/// Elements are fully decoded, so role names may safely contain `,`, `]` or
/// escaped quotes. Empty strings are skipped. Returns an empty list if the
/// key is absent or the array is malformed; this fails closed, yielding no
/// roles rather than a partially parsed list.
fn extract_str_array(json: &str, key: &str) -> Vec<String> {
    let pattern = format!("\"{}\":[", key);
    match json.find(&pattern) {
        Some(pos) => parse_str_list(&json[pos + pattern.len()..]).unwrap_or_default(),
        None => Vec::new(),
    }
}

/// Parse the elements of a JSON array of strings, starting just after `[`.
///
/// Returns `None` on any structural error; empty strings are dropped.
fn parse_str_list(input: &str) -> Option<Vec<String>> {
    let mut items = Vec::new();
    let mut rest = input.trim_start();
    if rest.starts_with(']') {
        return Some(items);
    }
    loop {
        let (item, after) = parse_json_string(rest.strip_prefix('"')?)?;
        if !item.is_empty() {
            items.push(item);
        }
        rest = after.trim_start();
        if rest.starts_with(']') {
            return Some(items);
        }
        rest = rest.strip_prefix(',')?.trim_start();
    }
}

/// Decode a JSON string body that begins just after its opening `"`.
///
/// Handles `\"`, `\\`, `\/`, `\b`, `\f`, `\n`, `\r`, `\t` and `\uXXXX`
/// (including UTF-16 surrogate pairs). Returns the decoded text plus the
/// input remaining after the closing quote. Returns `None` if the string is
/// unterminated, has an unknown escape, or has an unpaired surrogate.
fn parse_json_string(input: &str) -> Option<(String, &str)> {
    /// Read exactly four hex digits as a code unit.
    fn hex4(chars: &mut std::str::CharIndices<'_>) -> Option<u32> {
        (0..4).try_fold(0u32, |acc, _| Some(acc * 16 + chars.next()?.1.to_digit(16)?))
    }

    let mut out = String::new();
    let mut chars = input.char_indices();
    while let Some((idx, c)) = chars.next() {
        match c {
            // `"` is one byte, so `idx + 1` is a valid char boundary.
            '"' => return Some((out, &input[idx + 1..])),
            '\\' => match chars.next()?.1 {
                '"' => out.push('"'),
                '\\' => out.push('\\'),
                '/' => out.push('/'),
                'b' => out.push('\u{8}'),
                'f' => out.push('\u{c}'),
                'n' => out.push('\n'),
                'r' => out.push('\r'),
                't' => out.push('\t'),
                'u' => {
                    let hi = hex4(&mut chars)?;
                    let code = if (0xD800..0xDC00).contains(&hi) {
                        // High surrogate: must be followed by `\u` + low surrogate.
                        if chars.next()?.1 != '\\' || chars.next()?.1 != 'u' {
                            return None;
                        }
                        let lo = hex4(&mut chars)?;
                        if !(0xDC00..0xE000).contains(&lo) {
                            return None;
                        }
                        0x10000 + ((hi - 0xD800) << 10) + (lo - 0xDC00)
                    } else {
                        hi
                    };
                    // Rejects lone low surrogates.
                    out.push(char::from_u32(code)?);
                }
                _ => return None,
            },
            other => out.push(other),
        }
    }
    None
}