Skip to content

Commit

Permalink
Implemented authentication token TTL configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
themisir committed Oct 10, 2023
1 parent 36e3796 commit b64a6d4
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 31 deletions.
23 changes: 18 additions & 5 deletions src/app.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
use crate::issuer::Issuer;
use crate::proxy::ProxyClient;
use crate::store::UserStore;
use crate::utils::Duration;

use std::{collections::HashMap, path::Path, str::FromStr, sync::Arc};
use std::collections::HashSet;
use std::{
collections::{HashMap, HashSet},
path::Path,
str::FromStr,
sync::Arc,
};

use serde::{Deserialize, Serialize};
use log::info;
use serde::Deserialize;
use sqlx::sqlite::SqlitePoolOptions;
use url::Url;

Expand Down Expand Up @@ -59,7 +65,7 @@ impl AppState {
}
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Deserialize, Debug, Clone)]
pub struct AppConfig {
pub base_url: Url,
pub users_db: Url,
Expand All @@ -78,7 +84,7 @@ impl AppConfig {
}
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Deserialize, Debug, Clone)]
pub struct UpstreamConfig {
pub name: String,
pub claims: Vec<String>,
Expand All @@ -92,6 +98,8 @@ pub struct UpstreamConfig {
pub require_claims: Option<Vec<String>>,
pub require_authentication: bool,

pub cookie_ttl: Option<Duration>,

pub headers: Option<HashMap<String, String>>,
}

Expand Down Expand Up @@ -121,6 +129,11 @@ impl Upstreams {
host
);
}

info!(
"Added upstream '{}' ({} -> {}): {:?}",
cfg.name, cfg.origin_url, cfg.upstream_url, cfg
);
}

Ok(Self { by_name, by_host })
Expand Down
9 changes: 2 additions & 7 deletions src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::app::AppState;
use crate::http::{get_header, AppError, Cookies, Either, SetCookie};
use crate::proxy::{
ProxyClient, UpstreamAuthorizeParams, PROXY_AUTHORIZE_ENDPOINT, PROXY_TOKEN_TTL,
ProxyClient, UpstreamAuthorizeParams, PROXY_AUTHORIZE_ENDPOINT,
};
use crate::store::User;
use crate::uri::UriBuilder;
Expand Down Expand Up @@ -65,12 +65,7 @@ pub async fn issue_upstream_token(
if let Some(claims) = upstream.filter_claims(claims) {
let token = state
.issuer()
.create_token(
upstream.name(),
&user,
claims.as_ref(),
(*PROXY_TOKEN_TTL).into(),
)
.create_token(upstream.name(), &user, claims.as_ref(), upstream.token_ttl())
.map_err(|err| anyhow::format_err!("failed to create token: {}", err))?;

Ok(Some(token))
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async fn start_server(state: AppState, args: &ListenArgs) -> anyhow::Result<()>
.layer(TraceLayer::new_for_http())
.with_state(state);

println!("Binding on {}", args.bind);
info!("Binding on {}", args.bind);

axum::Server::bind(&args.bind)
.serve(router.into_make_service())
Expand Down
53 changes: 36 additions & 17 deletions src/proxy.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use crate::app::{AppState, UpstreamConfig};
use crate::auth::{AuthorizeParams, issue_upstream_token};
use crate::auth::{issue_upstream_token, AuthorizeParams};
use crate::http::{Cookies, SetCookie};
use crate::issuer::Claims;
use crate::store::UserClaim;
use crate::uri::UriBuilder;
use crate::utils;

use std::{
borrow::Cow,
Expand All @@ -11,7 +13,6 @@ use std::{
str::FromStr,
};

use crate::utils;
use axum::{
extract::{FromRequest, FromRequestParts, Host, OriginalUri, Query, State},
http::{
Expand All @@ -28,7 +29,6 @@ use hyper_tls::HttpsConnector;
use log::{error, info, warn};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use crate::issuer::Claims;

pub async fn middleware(
State(state): State<AppState>,
Expand Down Expand Up @@ -66,6 +66,7 @@ pub struct ProxyClient {
origin: String,
claims: HashSet<String>,
modified_headers: Option<Vec<(HeaderName, HeaderValue)>>,
token_ttl: Duration,
}

#[derive(Serialize, Deserialize)]
Expand Down Expand Up @@ -129,13 +130,16 @@ impl ProxyClient {
}
};

let token_ttl = cfg.cookie_ttl.map_or(*PROXY_TOKEN_TTL, |v| v.into());

Ok(Self {
config: cfg.clone(),
client,
upstream_uri,
origin,
claims,
modified_headers,
token_ttl,
})
}

Expand Down Expand Up @@ -173,11 +177,23 @@ impl ProxyClient {
}
}

async fn refresh_token_if_needed(&self, state: &AppState, claims: &Claims) -> anyhow::Result<Option<String>> {
pub fn token_ttl(&self) -> Duration {
self.token_ttl
}

async fn refresh_token_if_needed(
&self,
state: &AppState,
claims: &Claims,
) -> anyhow::Result<Option<String>> {
let refresh_needed = claims.valid_for() < *PROXY_TOKEN_REFRESH_THRESHOLD;
if refresh_needed {
let user_id = i32::from_str(claims.sub.as_str())?;
let user = state.store().find_user_by_id(user_id).await?.ok_or(anyhow::format_err!("user by id {} not found", claims.sub))?;
let user = state
.store()
.find_user_by_id(user_id)
.await?
.ok_or(anyhow::format_err!("user by id {} not found", claims.sub))?;

issue_upstream_token(state, self, &user).await
} else {
Expand Down Expand Up @@ -209,13 +225,14 @@ impl ProxyClient {
let value: HeaderValue = format!("Bearer {}", cookie.value()).try_into()?;
request.headers_mut().append(AUTHORIZATION, value);


return if let Ok(Some(token)) = self.refresh_token_if_needed(state, &claims).await {
let response= self.forward(request).await?;
Ok((Self::set_cookie(token), response).into_response())
return if let Ok(Some(token)) =
self.refresh_token_if_needed(state, &claims).await
{
let response = self.forward(request).await?;
Ok((self.set_cookie(token), response).into_response())
} else {
self.forward(request).await
}
};
}
}
}
Expand All @@ -227,15 +244,17 @@ impl ProxyClient {
}
}

fn set_cookie<'c, T>(token: T) -> SetCookie<'c>
fn set_cookie<'c, T>(&self, token: T) -> SetCookie<'c>
where
T: Into<Cow<'c, str>>,
{
SetCookie(Cookie::build(PROXY_COOKIE_NAME, token)
.path("/")
.http_only(true)
.max_age(utils::Duration::from(*PROXY_TOKEN_TTL).into())
.finish())
SetCookie(
Cookie::build(PROXY_COOKIE_NAME, token)
.path("/")
.http_only(true)
.max_age(utils::Duration::from(self.token_ttl).into())
.finish(),
)
}

async fn authorize(
Expand All @@ -246,7 +265,7 @@ impl ProxyClient {
let query = Query::<UpstreamAuthorizeParams>::from_request(request, state).await?;

Ok((
Self::set_cookie(query.token.clone()),
self.set_cookie(query.token.clone()),
Redirect::to(query.redirect_to.as_str()),
))
}
Expand Down
88 changes: 87 additions & 1 deletion src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
#[derive(Copy, Clone)]
use std::fmt::Formatter;

use serde::{de, Deserialize, Deserializer};

#[derive(Copy, Clone, Debug)]
pub struct Duration(chrono::Duration);

impl From<chrono::Duration> for Duration {
Expand Down Expand Up @@ -26,3 +30,85 @@ impl From<Duration> for cookie::time::Duration {
cookie::time::Duration::milliseconds(value.0.num_milliseconds())
}
}

impl<'de> Deserialize<'de> for Duration {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct DurationVisitor;

impl<'de> de::Visitor<'de> for DurationVisitor {
type Value = chrono::Duration;

fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
write!(formatter, "Duration")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
if v.is_empty() {
return Err(de::Error::custom("empty duration value"));
}

let mut total_seconds = 0i64;
let mut i = 0;

while let Some(ch) = v.chars().nth(i) {
if ch.is_numeric() {
let mut unit_value = 0u32;
let mut num = ch;
loop {
unit_value = unit_value * 10 + (num as u8 - b'0') as u32;

i += 1;
if let Some(next) = v.chars().nth(i) {
num = next;
if !next.is_numeric() {
let coefficient = match next {
's' => 1,
'm' => 60,
'h' => 60 * 60,
'd' => 60 * 60 * 24,
_ => {
return Err(de::Error::custom(format!(
"invalid unit: {}",
next
)));
}
};

total_seconds += unit_value as i64 * coefficient;
i += 1;

break;
}
} else {
return Err(de::Error::custom("expected duration unit"));
}
}
} else {
return Err(de::Error::custom(format!(
"expected duration value but reached {}",
ch
)));
}
}

Ok(chrono::Duration::seconds(total_seconds))
}

fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: de::Error,
{
self.visit_str(&v)
}
}

let value = deserializer.deserialize_str(DurationVisitor)?;
Ok(Duration(value))
}
}

0 comments on commit b64a6d4

Please sign in to comment.