Skip to content

Commit

Permalink
Token based auth integration with core extension
Browse files Browse the repository at this point in the history
Provide a way for lettuce clients to use token-based authentication.
TOKENs come with a TTL. After a Redis client authenticates with a TOKEN, if they didn't renew their authentication we need to evict (close) them. The suggested approach is to leverage the existing CredentialsProvider and add support for streaming credentials to handle token refresh scenarios. Each time a new token is received connection is reauthenticated.
  • Loading branch information
ggivo committed Dec 3, 2024
1 parent 91871b6 commit 78a0aea
Show file tree
Hide file tree
Showing 6 changed files with 472 additions and 2 deletions.
35 changes: 33 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,19 @@
<tag>HEAD</tag>
</scm>

<repositories>
<repository>
<id>sonatype-snapshots</id>
<url>https://oss.sonatype.org/content/repositories/snapshots/</url>
<releases>
<enabled>false</enabled>
</releases>
<snapshots>
<enabled>true</enabled>
</snapshots>
</repository>
</repositories>

<distributionManagement>
<snapshotRepository>
<id>ossrh</id>
Expand Down Expand Up @@ -173,12 +186,30 @@
<type>pom</type>
<scope>import</scope>
</dependency>

<dependency>
<groupId>redis.clients.authentication</groupId>
<artifactId>redis-authx-core</artifactId>
<version>0.1.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>redis.clients.authentication</groupId>
<artifactId>redis-authx-entraid</artifactId>
<version>0.1.0-SNAPSHOT</version>
<scope>test</scope>
</dependency>
</dependencies>
</dependencyManagement>

<dependencies>

<dependency>
<groupId>redis.clients.authentication</groupId>
<artifactId>redis-authx-core</artifactId>
</dependency>
<dependency>
<groupId>redis.clients.authentication</groupId>
<artifactId>redis-authx-entraid</artifactId>
<scope>test</scope>
</dependency>
<!-- Start of core dependencies -->

<dependency>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package io.lettuce.authx;

import io.lettuce.core.RedisCredentials;
import io.lettuce.core.StreamingCredentialsProvider;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import redis.clients.authentication.core.Token;
import redis.clients.authentication.core.TokenAuthConfig;
import redis.clients.authentication.core.TokenListener;
import redis.clients.authentication.core.TokenManager;

public class TokenBasedRedisCredentialsProvider implements StreamingCredentialsProvider {

private final TokenManager tokenManager;

private final Sinks.Many<RedisCredentials> credentialsSink = Sinks.many().replay().latest();

public TokenBasedRedisCredentialsProvider(TokenAuthConfig tokenAuthConfig) {
this(new TokenManager(tokenAuthConfig.getIdentityProviderConfig().getProvider(),
tokenAuthConfig.getTokenManagerConfig()));

}

public TokenBasedRedisCredentialsProvider(TokenManager tokenManager) {
this.tokenManager = tokenManager;
initializeTokenManager();
}

/**
* Initialize the TokenManager and subscribe to token renewal events.
*/
private void initializeTokenManager() {
TokenListener listener = new TokenListener() {

@Override
public void onTokenRenewed(Token token) {
String username = token.tryGet("oid");
char[] pass = token.getValue().toCharArray();
RedisCredentials credentials = RedisCredentials.just(username, pass);
credentialsSink.tryEmitNext(credentials);
}

@Override
public void onError(Exception exception) {
credentialsSink.tryEmitError(exception);
}

};

try {
tokenManager.start(listener, false);
} catch (Exception e) {
credentialsSink.tryEmitError(e);
}
}

/**
* Resolve the latest available credentials as a Mono.
* <p>
* This method returns a Mono that emits the most recent set of Redis credentials. The Mono will complete once the
* credentials are emitted. If no credentials are available at the time of subscription, the Mono will wait until
* credentials are available.
*
* @return a Mono that emits the latest Redis credentials
*/
@Override
public Mono<RedisCredentials> resolveCredentials() {

return credentialsSink.asFlux().next();
}

/**
* Expose the Flux for all credential updates.
* <p>
* This method returns a Flux that emits all updates to the Redis credentials. Subscribers will receive the latest
* credentials whenever they are updated. The Flux will continue to emit updates until the provider is shut down.
*
* @return a Flux that emits all updates to the Redis credentials
*/
@Override
public Flux<RedisCredentials> credentials() {

return credentialsSink.asFlux().onBackpressureLatest(); // Provide a continuous stream of credentials
}

/**
* Stop the credentials provider and clean up resources.
* <p>
* This method stops the TokenManager and completes the credentials sink, ensuring that all resources are properly released.
* It should be called when the credentials provider is no longer needed.
*/
public void shutdown() {
credentialsSink.tryEmitComplete();
tokenManager.stop();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package io.lettuce.authx;

import io.lettuce.core.RedisCredentials;
import io.lettuce.core.TestTokenManager;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import redis.clients.authentication.core.SimpleToken;

import java.time.Duration;
import java.util.Collections;

import static org.assertj.core.api.Assertions.assertThat;

public class TokenBasedRedisCredentialsProviderTest {

private TestTokenManager tokenManager;

private TokenBasedRedisCredentialsProvider credentialsProvider;

@BeforeEach
public void setUp() {
// Use TestToken manager to emit tokens/errors on request
tokenManager = new TestTokenManager(null, null);
credentialsProvider = new TokenBasedRedisCredentialsProvider(tokenManager);
}

@Test
public void shouldReturnPreviouslyEmittedTokenWhenResolved() {
tokenManager.emitToken(testToken("test-user", "token-1"));

Mono<RedisCredentials> credentials = credentialsProvider.resolveCredentials();

StepVerifier.create(credentials).assertNext(actual -> {
assertThat(actual.getUsername()).isEqualTo("test-user");
assertThat(new String(actual.getPassword())).isEqualTo("token-1");
}).verifyComplete();
}

@Test
public void shouldReturnLatestEmittedTokenWhenResolved() {
tokenManager.emitToken(testToken("test-user", "token-2"));
tokenManager.emitToken(testToken("test-user", "token-3")); // Latest token

Mono<RedisCredentials> credentials = credentialsProvider.resolveCredentials();

StepVerifier.create(credentials).assertNext(actual -> {
assertThat(actual.getUsername()).isEqualTo("test-user");
assertThat(new String(actual.getPassword())).isEqualTo("token-3");
}).verifyComplete();
}

@Test
public void shouldReturnTokenEmittedBeforeSubscription() {

tokenManager.emitToken(testToken("test-user", "token-1"));

// Test resolveCredentials
Mono<RedisCredentials> credentials1 = credentialsProvider.resolveCredentials();

StepVerifier.create(credentials1).assertNext(actual -> {
assertThat(actual.getUsername()).isEqualTo("test-user");
assertThat(new String(actual.getPassword())).isEqualTo("token-1");
}).verifyComplete();

// Emit second token and subscribe another
tokenManager.emitToken(testToken("test-user", "token-2"));
tokenManager.emitToken(testToken("test-user", "token-3"));
Mono<RedisCredentials> credentials2 = credentialsProvider.resolveCredentials();
StepVerifier.create(credentials2).assertNext(actual -> {
assertThat(actual.getUsername()).isEqualTo("test-user");
assertThat(new String(actual.getPassword())).isEqualTo("token-3");
}).verifyComplete();
}

@Test
public void shouldWaitForAndReturnTokenWhenEmittedLater() {
Mono<RedisCredentials> result = credentialsProvider.resolveCredentials();

tokenManager.emitTokenWithDelay(testToken("test-user", "delayed-token"), 100); // Emit token after 100ms
StepVerifier.create(result)
.assertNext(credentials -> assertThat(String.valueOf(credentials.getPassword())).isEqualTo("delayed-token"))
.verifyComplete();
}

@Test
public void shouldCompleteAllSubscribersOnStop() {
Flux<RedisCredentials> credentialsFlux1 = credentialsProvider.credentials();
Flux<RedisCredentials> credentialsFlux2 = credentialsProvider.credentials();

Disposable subscription1 = credentialsFlux1.subscribe();
Disposable subscription2 = credentialsFlux2.subscribe();

tokenManager.emitToken(testToken("test-user", "token-1"));

new Thread(() -> {
try {
Thread.sleep(100); // Delay of 100 milliseconds
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
credentialsProvider.shutdown();
}).start();

StepVerifier.create(credentialsFlux1)
.assertNext(credentials -> assertThat(String.valueOf(credentials.getPassword())).isEqualTo("token-1"))
.verifyComplete();

StepVerifier.create(credentialsFlux2)
.assertNext(credentials -> assertThat(String.valueOf(credentials.getPassword())).isEqualTo("token-1"))
.verifyComplete();
}

@Test
public void shouldPropagateMultipleTokensOnStream() {

Flux<RedisCredentials> result = credentialsProvider.credentials();
StepVerifier.create(result).then(() -> tokenManager.emitToken(testToken("test-user", "token1")))
.then(() -> tokenManager.emitToken(testToken("test-user", "token2")))
.assertNext(credentials -> assertThat(String.valueOf(credentials.getPassword())).isEqualTo("token1"))
.assertNext(credentials -> assertThat(String.valueOf(credentials.getPassword())).isEqualTo("token2"))
.thenCancel().verify(Duration.ofMillis(100));
}

@Test
public void shouldHandleTokenRequestErrorGracefully() {
Exception simulatedError = new RuntimeException("Token request failed");
tokenManager.emitError(simulatedError);

Flux<RedisCredentials> result = credentialsProvider.credentials();

StepVerifier.create(result).expectErrorMatches(
throwable -> throwable instanceof RuntimeException && "Token request failed".equals(throwable.getMessage()))
.verify();
}

private SimpleToken testToken(String username, String value) {
return new SimpleToken(value, System.currentTimeMillis() + 5000, // expires in 5 seconds
System.currentTimeMillis(), Collections.singletonMap("oid", username));

}

}
57 changes: 57 additions & 0 deletions src/test/java/io/lettuce/core/AuthenticationIntegrationTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

import javax.inject.Inject;

import io.lettuce.authx.TokenBasedRedisCredentialsProvider;
import io.lettuce.core.event.command.CommandListener;
import io.lettuce.core.event.command.CommandSucceededEvent;
import io.lettuce.core.protocol.RedisCommand;
import io.lettuce.test.Delay;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
Expand All @@ -26,8 +28,10 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import redis.clients.authentication.core.SimpleToken;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -170,6 +174,59 @@ private boolean isAuthCommandWithCredentials(RedisCommand<?, ?, ?> command, Stri
return false;
}

@Test
@Inject
void tokenBasedCredentialProvider(RedisClient client) {

ClientOptions clientOptions = ClientOptions.builder()
.disconnectedBehavior(ClientOptions.DisconnectedBehavior.REJECT_COMMANDS).build();
client.setOptions(clientOptions);
// Connection used to simulate test user credential rotation
StatefulRedisConnection<String, String> defaultConnection = client.connect();

String testUser = "streaming_cred_test_user";
char[] testPassword1 = "token_1".toCharArray();
char[] testPassword2 = "token_2".toCharArray();

TestTokenManager tokenManager = new TestTokenManager(null, null);

// streaming credentials provider that emits redis credentials which will trigger connection re-authentication
// token manager is used to emit updated credentials
TokenBasedRedisCredentialsProvider credentialsProvider = new TokenBasedRedisCredentialsProvider(tokenManager);

RedisURI uri = RedisURI.builder().withTimeout(Duration.ofSeconds(1)).withClientName("streaming_cred_test")
.withHost(TestSettings.host()).withPort(TestSettings.port()).withAuthentication(credentialsProvider).build();

// create test user with initial credentials set to 'testPassword1'
createTestUser(defaultConnection, testUser, testPassword1);
tokenManager.emitToken(testToken(testUser, testPassword1));

StatefulRedisConnection<String, String> connection = client.connect(StringCodec.UTF8, uri);
assertThat(connection.sync().aclWhoami()).isEqualTo(testUser);

// update test user credentials in Redis server (password changed to testPassword2)
// then emit updated credentials trough streaming credentials provider
// and trigger re-connect to force re-authentication
// updated credentials should be used for re-authentication
updateTestUser(defaultConnection, testUser, testPassword2);
tokenManager.emitToken(testToken(testUser, testPassword2));
connection.sync().quit();

Delay.delay(Duration.ofMillis(100));
assertThat(connection.sync().ping()).isEqualTo("PONG");

String res = connection.sync().aclWhoami();
assertThat(res).isEqualTo(testUser);

defaultConnection.close();
connection.close();
}

private SimpleToken testToken(String username, char[] password) {
return new SimpleToken(String.valueOf(password), Instant.now().plusMillis(500).toEpochMilli(),
Instant.now().toEpochMilli(), Collections.singletonMap("oid", username));
}

static class RenewableRedisCredentialsProvider implements StreamingCredentialsProvider {

private final Sinks.Many<RedisCredentials> credentialsSink = Sinks.many().replay().latest();
Expand Down
Loading

0 comments on commit 78a0aea

Please sign in to comment.