315 lines
11 KiB
Rust
315 lines
11 KiB
Rust
//! JWT provider — signs and verifies JSON Web Tokens.
//!
//! Uses HMAC-SHA256 (HS256) for signing. Does NOT use the `jsonwebtoken` crate,
//! to keep dependencies minimal; the JWT format is implemented directly.
//!
//! Format: base64url(header).base64url(payload).base64url(signature)
|
|
|
|
use crate::{AuthContext, AuthError, AuthProvider, AuthResult, AuthUser, RoleRegistry};
|
|
use hmac::{Hmac, Mac};
|
|
use sha2::Sha256;
|
|
|
|
type HmacSha256 = Hmac<Sha256>;
|
|
|
|
/// Claims carried in the payload segment of a token.
///
/// `session_id` is optional. When it is present it names the Engram session
/// node, so that node can be checked on every request and a session can be
/// invalidated server-side even though the JWT itself is stateless.
#[derive(Debug, Clone)]
pub struct JwtClaims {
    /// Subject: the user ID.
    pub sub: String,
    /// Email address of the user.
    pub email: String,
    /// Name of the user.
    pub name: String,
    /// Role names attached to the token.
    pub roles: Vec<String>,
    /// The Engram Session node ID. `EngramSessionStore` uses it to validate
    /// the session graph node on every request, which enables server-side
    /// logout.
    pub session_id: Option<String>,
    /// Issued-at time, in unix seconds.
    pub iat: u64,
    /// Expiry time, in unix seconds.
    pub exp: u64,
}
|
|
|
|
impl JwtClaims {
|
|
pub fn new(user: &AuthUser, roles: Vec<String>, ttl_seconds: u64) -> Self {
|
|
let now = unix_now();
|
|
Self {
|
|
sub: user.id.clone(),
|
|
email: user.email.clone(),
|
|
name: user.name.clone(),
|
|
roles,
|
|
session_id: None,
|
|
iat: now,
|
|
exp: now + ttl_seconds,
|
|
}
|
|
}
|
|
|
|
/// Create claims with an Engram session ID embedded.
|
|
pub fn new_with_session(
|
|
user: &AuthUser,
|
|
roles: Vec<String>,
|
|
session_id: impl Into<String>,
|
|
ttl_seconds: u64,
|
|
) -> Self {
|
|
let mut claims = Self::new(user, roles, ttl_seconds);
|
|
claims.session_id = Some(session_id.into());
|
|
claims
|
|
}
|
|
|
|
pub fn is_expired(&self) -> bool {
|
|
unix_now() > self.exp
|
|
}
|
|
|
|
/// Serialize claims to JSON (manual, no serde dependency).
|
|
pub fn to_json(&self) -> String {
|
|
let roles_json = self
|
|
.roles
|
|
.iter()
|
|
.map(|r| format!("\"{}\"", r))
|
|
.collect::<Vec<_>>()
|
|
.join(",");
|
|
// Build the JSON manually, inserting session_id only when present.
|
|
let mut json = format!(
|
|
"{{\"sub\":\"{}\",\"email\":\"{}\",\"name\":\"{}\",\"roles\":[{}]",
|
|
self.sub, self.email, self.name, roles_json
|
|
);
|
|
if let Some(sid) = &self.session_id {
|
|
json.push_str(&format!(",\"session_id\":\"{}\"", sid));
|
|
}
|
|
json.push_str(&format!(",\"iat\":{},\"exp\":{}}}", self.iat, self.exp));
|
|
json
|
|
}
|
|
|
|
/// Deserialize claims from JSON (manual parser).
|
|
pub fn from_json(json: &str) -> Option<Self> {
|
|
let sub = extract_str(json, "sub")?;
|
|
let email = extract_str(json, "email").unwrap_or_default();
|
|
let name = extract_str(json, "name").unwrap_or_default();
|
|
let iat = extract_u64(json, "iat").unwrap_or(0);
|
|
let exp = extract_u64(json, "exp").unwrap_or(0);
|
|
let roles = extract_str_array(json, "roles");
|
|
let session_id = extract_str(json, "session_id");
|
|
Some(Self { sub, email, name, roles, session_id, iat, exp })
|
|
}
|
|
}
|
|
|
|
/// Issues and verifies HS256-signed JWTs.
///
/// `Debug` is intentionally not derived: it would print the signing secret.
pub struct JwtProvider {
    /// Raw bytes of the HMAC signing key.
    secret: Vec<u8>,
    /// Lifetime of issued tokens, in seconds (default: 3600 = 1 hour).
    pub ttl_seconds: u64,
}
|
|
|
|
impl JwtProvider {
|
|
pub fn new(secret: impl Into<Vec<u8>>) -> Self {
|
|
Self { secret: secret.into(), ttl_seconds: 3600 }
|
|
}
|
|
|
|
pub fn from_env(env_var: &str) -> AuthResult<Self> {
|
|
let secret = std::env::var(env_var).map_err(|_| {
|
|
AuthError::Config(format!("env var {} not set", env_var))
|
|
})?;
|
|
Ok(Self::new(secret.into_bytes()))
|
|
}
|
|
|
|
pub fn with_ttl(mut self, seconds: u64) -> Self {
|
|
self.ttl_seconds = seconds;
|
|
self
|
|
}
|
|
|
|
/// Sign a token with HMAC-SHA256.
|
|
fn sign(&self, header_payload: &str) -> String {
|
|
let mut mac = HmacSha256::new_from_slice(&self.secret)
|
|
.expect("HMAC can take key of any size");
|
|
mac.update(header_payload.as_bytes());
|
|
let result = mac.finalize();
|
|
base64url_encode(&result.into_bytes())
|
|
}
|
|
|
|
/// Encode a JWT token from claims.
|
|
pub fn encode(&self, claims: &JwtClaims) -> String {
|
|
let header = base64url_encode(b"{\"alg\":\"HS256\",\"typ\":\"JWT\"}");
|
|
let payload = base64url_encode(claims.to_json().as_bytes());
|
|
let header_payload = format!("{}.{}", header, payload);
|
|
let signature = self.sign(&header_payload);
|
|
format!("{}.{}", header_payload, signature)
|
|
}
|
|
|
|
/// Decode and verify a JWT token.
|
|
pub fn decode(&self, token: &str) -> AuthResult<JwtClaims> {
|
|
let parts: Vec<&str> = token.split('.').collect();
|
|
if parts.len() != 3 {
|
|
return Err(AuthError::TokenInvalid("not a valid JWT".into()));
|
|
}
|
|
|
|
let header_payload = format!("{}.{}", parts[0], parts[1]);
|
|
let expected_sig = self.sign(&header_payload);
|
|
if !constant_time_eq(parts[2], &expected_sig) {
|
|
return Err(AuthError::TokenInvalid("signature mismatch".into()));
|
|
}
|
|
|
|
let payload_bytes = base64url_decode(parts[1])
|
|
.ok_or_else(|| AuthError::TokenInvalid("payload decode failed".into()))?;
|
|
let payload_str = String::from_utf8(payload_bytes)
|
|
.map_err(|_| AuthError::TokenInvalid("payload not utf8".into()))?;
|
|
|
|
let claims = JwtClaims::from_json(&payload_str)
|
|
.ok_or_else(|| AuthError::TokenInvalid("claims parse failed".into()))?;
|
|
|
|
if claims.is_expired() {
|
|
return Err(AuthError::TokenExpired);
|
|
}
|
|
|
|
Ok(claims)
|
|
}
|
|
}
|
|
|
|
impl AuthProvider for JwtProvider {
|
|
fn name(&self) -> &'static str {
|
|
"jwt"
|
|
}
|
|
|
|
fn verify(&self, token: &str) -> AuthResult<AuthContext> {
|
|
let claims = self.decode(token)?;
|
|
let user = AuthUser::new(&claims.sub, &claims.email, &claims.name);
|
|
Ok(AuthContext::authenticated(user, claims.roles, token))
|
|
}
|
|
|
|
fn issue(&self, user: AuthUser, _role_registry: &RoleRegistry) -> AuthResult<String> {
|
|
let claims = JwtClaims::new(&user, Vec::new(), self.ttl_seconds);
|
|
Ok(self.encode(&claims))
|
|
}
|
|
|
|
fn revoke(&self, _token: &str) -> AuthResult<()> {
|
|
// JWTs are stateless — revocation requires a blocklist.
|
|
// TODO: maintain a revocation list (in-memory or Redis).
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
// ── Crypto helpers ─────────────────────────────────────────────────────────────
|
|
|
|
/// Encodes `input` with the URL-safe base64 alphabet (`-` and `_`), without
/// `=` padding.
///
/// Each group of up to three input bytes produces one more output symbol
/// than it has bytes (1 → 2, 2 → 3, 3 → 4).
fn base64url_encode(input: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    const SHIFTS: [u32; 4] = [18, 12, 6, 0];

    let mut encoded = String::new();
    for group in input.chunks(3) {
        // Pack the group into the top of a 24-bit word; absent bytes are zero.
        let first = u32::from(group[0]);
        let second = group.get(1).map_or(0, |&byte| u32::from(byte));
        let third = group.get(2).map_or(0, |&byte| u32::from(byte));
        let bits = (first << 16) | (second << 8) | third;

        encoded.extend(
            SHIFTS
                .iter()
                .take(group.len() + 1)
                .map(|&shift| ALPHABET[((bits >> shift) & 0x3f) as usize] as char),
        );
    }
    encoded
}
|
|
|
|
fn base64url_decode(input: &str) -> Option<Vec<u8>> {
|
|
// Pad if needed
|
|
let mut s = input.replace('-', "+").replace('_', "/");
|
|
while s.len() % 4 != 0 {
|
|
s.push('=');
|
|
}
|
|
base64_decode_standard(&s)
|
|
}
|
|
|
|
/// Decodes standard-alphabet base64 (`+` and `/`).
///
/// All trailing `=` characters are stripped first, so padding is optional.
///
/// Returns `None` when the input contains a character outside the alphabet
/// (including any non-ASCII byte or a `=` that is not at the end), or when
/// the unpadded length is 1 modulo 4. A single dangling symbol carries only
/// six bits and cannot encode a byte, so it is rejected rather than dropped.
fn base64_decode_standard(input: &str) -> Option<Vec<u8>> {
    const INVALID: u8 = 255;
    const TABLE: [u8; 128] = {
        let mut table = [INVALID; 128];
        let chars = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        let mut i = 0usize;
        while i < chars.len() {
            table[chars[i] as usize] = i as u8;
            i += 1;
        }
        table
    };

    // Maps one input byte to its 6-bit value, or `None` if it is not a
    // base64 symbol. `get` also covers bytes >= 128.
    let sextet = |byte: u8| -> Option<u32> {
        TABLE
            .get(usize::from(byte))
            .copied()
            .filter(|&value| value != INVALID)
            .map(u32::from)
    };

    let symbols = input.trim_end_matches('=').as_bytes();
    let mut out = Vec::with_capacity(symbols.len() / 4 * 3 + 2);

    // Only the final chunk can be shorter than four symbols.
    // The `as u8` casts below deliberately keep just the low eight bits.
    for group in symbols.chunks(4) {
        match *group {
            [a, b, c, d] => {
                let n = (sextet(a)? << 18) | (sextet(b)? << 12) | (sextet(c)? << 6) | sextet(d)?;
                out.extend_from_slice(&[(n >> 16) as u8, (n >> 8) as u8, n as u8]);
            }
            [a, b, c] => {
                // 18 bits available, the top 16 are data.
                let n = (sextet(a)? << 10) | (sextet(b)? << 4) | (sextet(c)? >> 2);
                out.extend_from_slice(&[(n >> 8) as u8, n as u8]);
            }
            [a, b] => {
                // 12 bits available, the top 8 are data.
                out.push(((sextet(a)? << 2) | (sextet(b)? >> 4)) as u8);
            }
            // A lone trailing symbol is malformed input.
            _ => return None,
        }
    }
    Some(out)
}
|
|
|
|
/// Compares two strings byte-wise without stopping at the first mismatch.
///
/// Strings of different length return `false` immediately. For equal
/// lengths, every byte pair is XOR-ed into an accumulator and the result is
/// `true` only if the accumulator stays zero.
fn constant_time_eq(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }

    let mut diff = 0u8;
    for (x, y) in lhs.iter().zip(rhs) {
        diff |= x ^ y;
    }
    diff == 0
}
|
|
|
|
/// Current time as whole seconds since the unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn unix_now() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
|
|
|
|
// ── Minimal JSON field extractors ─────────────────────────────────────────────
|
|
|
|
/// Extracts the string value of the first `"key":"` occurrence in `json`.
///
/// The value ends at the first unescaped `"`. JSON escape sequences
/// (`\"`, `\\`, `\/`, `\b`, `\f`, `\n`, `\r`, `\t` and `\uXXXX`, including
/// surrogate pairs) are decoded, so a value containing an escaped quote is
/// no longer cut short.
///
/// This is a pattern search, not a JSON parser: no whitespace is tolerated
/// between the key, the colon and the opening quote.
///
/// Returns `None` when the key is not found, the string is unterminated, or
/// it contains an invalid escape sequence.
fn extract_str(json: &str, key: &str) -> Option<String> {
    /// Reads exactly four hex digits as one UTF-16 code unit.
    fn hex4(chars: &mut impl Iterator<Item = char>) -> Option<u32> {
        let mut unit = 0u32;
        for _ in 0..4 {
            unit = unit * 16 + chars.next()?.to_digit(16)?;
        }
        Some(unit)
    }

    /// Decodes one escape sequence; `chars` is positioned just after the
    /// backslash.
    fn read_escape(chars: &mut impl Iterator<Item = char>) -> Option<char> {
        match chars.next()? {
            '"' => Some('"'),
            '\\' => Some('\\'),
            '/' => Some('/'),
            'b' => Some('\u{0008}'),
            'f' => Some('\u{000C}'),
            'n' => Some('\n'),
            'r' => Some('\r'),
            't' => Some('\t'),
            'u' => {
                let unit = hex4(chars)?;
                if (0xD800..0xDC00).contains(&unit) {
                    // A high surrogate must be followed by `\u` + low surrogate.
                    if chars.next()? != '\\' || chars.next()? != 'u' {
                        return None;
                    }
                    let low = hex4(chars)?;
                    if !(0xDC00..0xE000).contains(&low) {
                        return None;
                    }
                    char::from_u32(0x10000 + ((unit - 0xD800) << 10) + (low - 0xDC00))
                } else {
                    // `from_u32` rejects a lone low surrogate.
                    char::from_u32(unit)
                }
            }
            _ => None,
        }
    }

    let pattern = format!("\"{}\":\"", key);
    let start = json.find(&pattern)? + pattern.len();

    let mut chars = json[start..].chars();
    let mut value = String::new();
    loop {
        // Running out of input before the closing quote yields `None`.
        match chars.next()? {
            '"' => return Some(value),
            '\\' => value.push(read_escape(&mut chars)?),
            other => value.push(other),
        }
    }
}
|
|
|
|
/// Extracts the unsigned integer that directly follows the first `"key":`
/// occurrence in `json`.
///
/// Only the run of ASCII digits immediately after the colon is read.
/// Returns `None` when the key is missing, no digit follows the colon, or
/// the digits do not fit in a `u64`.
fn extract_u64(json: &str, key: &str) -> Option<u64> {
    let needle = format!("\"{}\":", key);
    let value_at = json.find(needle.as_str())? + needle.len();

    let digits: String = json[value_at..]
        .chars()
        .take_while(char::is_ascii_digit)
        .collect();
    digits.parse().ok()
}
|
|
|
|
/// Extracts the elements of the first `"key":[` array in `json` as strings.
///
/// Quoted elements are read up to their closing unescaped `"`, so commas
/// and `]` inside an element no longer split or truncate it, and JSON
/// escape sequences are decoded. Outside quotes, whitespace is ignored and
/// any other characters are kept as a bare element. Empty elements are
/// dropped.
///
/// Returns an empty vector when the key is not found. It also returns an
/// empty vector when an element contains an invalid escape sequence, so a
/// malformed payload yields no elements rather than partly decoded ones.
/// If the array is never closed, everything read up to the end of input is
/// returned.
fn extract_str_array(json: &str, key: &str) -> Vec<String> {
    /// Reads exactly four hex digits as one UTF-16 code unit.
    fn hex4(chars: &mut impl Iterator<Item = char>) -> Option<u32> {
        let mut unit = 0u32;
        for _ in 0..4 {
            unit = unit * 16 + chars.next()?.to_digit(16)?;
        }
        Some(unit)
    }

    /// Decodes one escape sequence; `chars` is positioned just after the
    /// backslash.
    fn read_escape(chars: &mut impl Iterator<Item = char>) -> Option<char> {
        match chars.next()? {
            '"' => Some('"'),
            '\\' => Some('\\'),
            '/' => Some('/'),
            'b' => Some('\u{0008}'),
            'f' => Some('\u{000C}'),
            'n' => Some('\n'),
            'r' => Some('\r'),
            't' => Some('\t'),
            'u' => {
                let unit = hex4(chars)?;
                if (0xD800..0xDC00).contains(&unit) {
                    // A high surrogate must be followed by `\u` + low surrogate.
                    if chars.next()? != '\\' || chars.next()? != 'u' {
                        return None;
                    }
                    let low = hex4(chars)?;
                    if !(0xDC00..0xE000).contains(&low) {
                        return None;
                    }
                    char::from_u32(0x10000 + ((unit - 0xD800) << 10) + (low - 0xDC00))
                } else {
                    // `from_u32` rejects a lone low surrogate.
                    char::from_u32(unit)
                }
            }
            _ => None,
        }
    }

    let pattern = format!("\"{}\":[", key);
    let start = match json.find(&pattern) {
        Some(at) => at + pattern.len(),
        None => return Vec::new(),
    };

    let mut chars = json[start..].chars();
    let mut items = Vec::new();
    let mut current = String::new();
    let mut in_string = false;

    while let Some(c) = chars.next() {
        if in_string {
            match c {
                '"' => in_string = false,
                '\\' => match read_escape(&mut chars) {
                    Some(decoded) => current.push(decoded),
                    None => return Vec::new(),
                },
                other => current.push(other),
            }
            continue;
        }

        match c {
            '"' => in_string = true,
            ',' | ']' => {
                if !current.is_empty() {
                    items.push(std::mem::take(&mut current));
                }
                if c == ']' {
                    return items;
                }
            }
            other if other.is_whitespace() => {}
            other => current.push(other),
        }
    }

    // Unterminated array: keep whatever was collected.
    if !current.is_empty() {
        items.push(current);
    }
    items
}
|