Skip to content

Commit

Permalink
refactor: Used DEEP_STUBS for chained mocking (#238)
Browse files Browse the repository at this point in the history
Refs: #237
  • Loading branch information
nirikash authored Jan 7, 2025
1 parent 21c37d0 commit b520587
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
import ch.sbb.polarion.extension.generic.auth.ValidatorType;
import com.polarion.core.config.Configuration;
import com.polarion.core.config.IConfiguration;
import com.polarion.core.config.IRestConfiguration;
import com.polarion.platform.security.AuthenticationFailedException;
import com.polarion.platform.security.ISecurityService;
import com.polarion.platform.security.login.AccessToken;
import com.polarion.platform.security.login.ILogin;
import com.polarion.platform.security.login.IToken;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Answers;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
Expand All @@ -34,22 +33,14 @@ class AuthenticationFilterTest {

@Mock
private ContainerRequestContext requestContext;
@Mock
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private ISecurityService securityService;
@Mock
private ILogin login;
@Mock
private ILogin.IBase base;
@Mock
private ILogin.IUsingAuthenticator authenticator;
@Mock
private ILogin.IFinal loginFinal;
@Mock
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private HttpServletRequest httpServletRequest;
@Mock
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private IConfiguration configuration;
@Mock
private IRestConfiguration restConfiguration;

@Test
void filterRequestWithoutAuthorizationHeaderAndXsrfHeader() {
Expand All @@ -76,14 +67,15 @@ void filterRequestWithoutBearerInAuthorizationHeader() {
}

@Test
@SuppressWarnings("unchecked")
void filterRequestWithValidBearerToken() throws IOException, AuthenticationFailedException {
when(requestContext.getHeaderString(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer token");
when(requestContext.getHeaderString(AuthenticationFilter.X_POLARION_REST_TOKEN_HEADER)).thenReturn(null);

when(securityService.login()).thenReturn(login);
when(login.from("REST")).thenReturn(base);
when(base.authenticator(any())).thenReturn(authenticator);
when(authenticator.with((IToken<AccessToken>) any())).thenReturn(loginFinal);
when(securityService.login()
.from("REST")
.authenticator(any())
.with(any(IToken.class))).thenReturn(loginFinal);

Subject subject = new Subject();
when(loginFinal.perform()).thenReturn(subject);
Expand All @@ -95,14 +87,15 @@ void filterRequestWithValidBearerToken() throws IOException, AuthenticationFaile


@Test
@SuppressWarnings("unchecked")
void filterRequestWithFailedAuthentication() throws AuthenticationFailedException {
when(requestContext.getHeaderString(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer failed_token");
when(requestContext.getHeaderString(AuthenticationFilter.X_POLARION_REST_TOKEN_HEADER)).thenReturn(null);

when(securityService.login()).thenReturn(login);
when(login.from("REST")).thenReturn(base);
when(base.authenticator(any())).thenReturn(authenticator);
when(authenticator.with((IToken<AccessToken>) any())).thenReturn(loginFinal);
when(securityService.login()
.from("REST")
.authenticator(any())
.with(any(IToken.class))).thenReturn(loginFinal);

when(loginFinal.perform()).thenThrow(new AuthenticationFailedException("Something went wrong"));

Expand All @@ -114,6 +107,7 @@ void filterRequestWithFailedAuthentication() throws AuthenticationFailedExceptio
}

@Test
@SuppressWarnings("unused")
void filterRequestWithValidXsrfToken() throws IOException, AuthenticationFailedException {
when(requestContext.getHeaderString(HttpHeaders.AUTHORIZATION)).thenReturn(null);
when(requestContext.getHeaderString(AuthenticationFilter.X_POLARION_REST_TOKEN_HEADER)).thenReturn("validXsrfToken");
Expand Down Expand Up @@ -144,12 +138,13 @@ void filterRequestWithInvalidXsrfToken() {

try (MockedStatic<Configuration> configurationMockedStatic = mockStatic(Configuration.class)) {
configurationMockedStatic.when(Configuration::getInstance).thenReturn(configuration);
when(configuration.rest()).thenReturn(restConfiguration);
when(restConfiguration.restApiTokenEnabled()).thenReturn(true);
when(configuration
.rest()
.restApiTokenEnabled()).thenReturn(true);

Principal userPrincipal = mock(Principal.class);
when(httpServletRequest.getUserPrincipal()).thenReturn(userPrincipal);
when(userPrincipal.getName()).thenReturn("user");
when(httpServletRequest
.getUserPrincipal()
.getName()).thenReturn("user");

AuthenticationFilter filter = new AuthenticationFilter(securityService, httpServletRequest);

Expand All @@ -165,12 +160,13 @@ void filterRequestWithXsrfTokenButConfigurationIsDisabled() {
when(requestContext.getHeaderString(AuthenticationFilter.X_POLARION_REST_TOKEN_HEADER)).thenReturn("xsrf_token");
try (MockedStatic<Configuration> configurationMockedStatic = mockStatic(Configuration.class)) {
configurationMockedStatic.when(Configuration::getInstance).thenReturn(configuration);
when(configuration.rest()).thenReturn(restConfiguration);
when(restConfiguration.restApiTokenEnabled()).thenReturn(false);
when(configuration
.rest()
.restApiTokenEnabled()).thenReturn(false);

Principal userPrincipal = mock(Principal.class);
when(httpServletRequest.getUserPrincipal()).thenReturn(userPrincipal);
when(userPrincipal.getName()).thenReturn("user");
when(httpServletRequest
.getUserPrincipal()
.getName()).thenReturn("user");

AuthenticationFilter filter = new AuthenticationFilter(securityService, httpServletRequest);

Expand All @@ -186,12 +182,13 @@ void filterRequestWithXsrfTokenForDifferentUser() {
when(requestContext.getHeaderString(AuthenticationFilter.X_POLARION_REST_TOKEN_HEADER)).thenReturn("xsrf_token_for_different_user");
try (MockedStatic<Configuration> configurationMockedStatic = mockStatic(Configuration.class)) {
configurationMockedStatic.when(Configuration::getInstance).thenReturn(configuration);
when(configuration.rest()).thenReturn(restConfiguration);
when(restConfiguration.restApiTokenEnabled()).thenReturn(true);
when(configuration
.rest()
.restApiTokenEnabled()).thenReturn(true);

Principal userPrincipal = mock(Principal.class);
when(httpServletRequest.getUserPrincipal()).thenReturn(userPrincipal);
when(userPrincipal.getName()).thenReturn("different_user");
when(httpServletRequest
.getUserPrincipal()
.getName()).thenReturn("different_user");

AuthenticationFilter filter = new AuthenticationFilter(securityService, httpServletRequest);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Answers;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockedStatic;
Expand Down Expand Up @@ -45,16 +44,14 @@ class CorsFilterTest {
private ContainerRequestContext requestContext;
@Mock
private ContainerResponseContext responseContext;

@Mock
private IConfiguration configuration;
@Mock
private IRestConfiguration restConfiguration;
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
@Mock
MockedStatic<Configuration> configurationMockedStatic;

@BeforeEach
void setUp() throws MalformedURLException {
IConfiguration configuration = mock(IConfiguration.class);
configurationMockedStatic.when(Configuration::getInstance).thenReturn(configuration);
lenient().when(configuration.rest()).thenReturn(restConfiguration);
lenient().when(configuration.getBaseURL()).thenReturn(URI.create(LOCALHOST_8080).toURL());
Expand Down Expand Up @@ -99,8 +96,7 @@ void requestOriginNotAllowed(String input) throws URISyntaxException {
when(requestContext.getMethod()).thenReturn(HttpMethod.GET);

// no CORS enabled
HashSet<String> corsAllowedOrigins = new HashSet<>();
Arrays.stream(input.split( "," )).forEach(o -> corsAllowedOrigins.add(o));
HashSet<String> corsAllowedOrigins = new HashSet<>(Arrays.asList(input.split(",")));
when(restConfiguration.corsAllowedOrigins()).thenReturn(corsAllowedOrigins);

CorsFilter corsFilter = new CorsFilter();
Expand All @@ -122,8 +118,7 @@ void requestOriginAllowed(String input) throws URISyntaxException {
when(requestContext.getHeaderString(CorsFilter.ORIGIN)).thenReturn(LOCALHOST_1111);
when(requestContext.getMethod()).thenReturn(HttpMethod.GET);

HashSet<String> corsAllowedOrigins = new HashSet<>();
Arrays.stream(input.split( "," )).forEach(o -> corsAllowedOrigins.add(o));
HashSet<String> corsAllowedOrigins = new HashSet<>(Arrays.asList(input.split(",")));
when(restConfiguration.corsAllowedOrigins()).thenReturn(corsAllowedOrigins);

CorsFilter corsFilter = new CorsFilter();
Expand All @@ -139,8 +134,7 @@ void testContainerResponseFilter(String input) throws URISyntaxException {
when(uriInfo.getRequestUri()).thenReturn(new URI(LOCALHOST_8080 + "/some-extension"));
when(requestContext.getHeaderString(CorsFilter.ORIGIN)).thenReturn(LOCALHOST_1111);

HashSet<String> corsAllowedOrigins = new HashSet<>();
Arrays.stream(input.split( "," )).forEach(o -> corsAllowedOrigins.add(o));
HashSet<String> corsAllowedOrigins = new HashSet<>(Arrays.asList(input.split(",")));
when(restConfiguration.corsAllowedOrigins()).thenReturn(corsAllowedOrigins);

MultivaluedMap<String, Object> responseHeaders = new MultivaluedHashMap<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package ch.sbb.polarion.extension.generic.util;

import org.junit.jupiter.api.Test;
import org.mockito.Answers;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.security.auth.Subject;
import javax.servlet.http.HttpServletRequest;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand All @@ -17,11 +17,11 @@ class RequestContextUtilTest {
@Test
void shouldReturnUserSubject() {
// Arrange
ServletRequestAttributes requestAttributes = mock(ServletRequestAttributes.class);
HttpServletRequest request = mock(HttpServletRequest.class);
ServletRequestAttributes requestAttributes = mock(ServletRequestAttributes.class, Answers.RETURNS_DEEP_STUBS);
Subject subject = mock(Subject.class);
when(requestAttributes.getRequest()).thenReturn(request);
when(request.getAttribute("user_subject")).thenReturn(subject);
when(requestAttributes
.getRequest()
.getAttribute("user_subject")).thenReturn(subject);
RequestContextHolder.setRequestAttributes(requestAttributes);

// Act
Expand Down

0 comments on commit b520587

Please sign in to comment.