diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 86e49fa6cf72..9b51b12d6296 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -57,7 +57,7 @@ public static class AuthConstants public static readonly RangeConstant ARGON2_ITERATIONS = new(2, 10, 3); public static readonly RangeConstant ARGON2_MEMORY = new(15, 1024, 64); public static readonly RangeConstant ARGON2_PARALLELISM = new(1, 16, 4); - + public static readonly string NewDeviceVerificationExceptionCacheKeyFormat = "NewDeviceVerificationException_{0}"; } public class RangeConstant diff --git a/src/Identity/IdentityServer/RequestValidators/DeviceValidator.cs b/src/Identity/IdentityServer/RequestValidators/DeviceValidator.cs index 2a048bcb2aab..d59417bfa72d 100644 --- a/src/Identity/IdentityServer/RequestValidators/DeviceValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/DeviceValidator.cs @@ -10,6 +10,7 @@ using Bit.Core.Settings; using Bit.Identity.IdentityServer.Enums; using Duende.IdentityServer.Validation; +using Microsoft.Extensions.Caching.Distributed; namespace Bit.Identity.IdentityServer.RequestValidators; @@ -20,6 +21,8 @@ public class DeviceValidator( IMailService mailService, ICurrentContext currentContext, IUserService userService, + IDistributedCache distributedCache, + ILogger logger, IFeatureService featureService) : IDeviceValidator { private readonly IDeviceService _deviceService = deviceService; @@ -28,6 +31,8 @@ public class DeviceValidator( private readonly IMailService _mailService = mailService; private readonly ICurrentContext _currentContext = currentContext; private readonly IUserService _userService = userService; + private readonly IDistributedCache distributedCache = distributedCache; + private readonly ILogger _logger = logger; private readonly IFeatureService _featureService = featureService; public async Task ValidateRequestDeviceAsync(ValidatedTokenRequest request, CustomValidatorRequestContext context) @@ -67,7 +72,6 @@ public async Task ValidateRequestDeviceAsync(ValidatedTokenRequest request !context.SsoRequired && _globalSettings.EnableNewDeviceVerification) { - // We only want to return early if the device is invalid or there is an error var validationResult = await HandleNewDeviceVerificationAsync(context.User, request); if (validationResult != DeviceValidationResultType.Success) { @@ -121,6 +125,18 @@ private async Task HandleNewDeviceVerificationAsync( return DeviceValidationResultType.InvalidUser; } + // CS exception flow + // Check cache for user information + var cacheKey = string.Format(AuthConstants.NewDeviceVerificationExceptionCacheKeyFormat, user.Id.ToString()); + var cacheValue = await distributedCache.GetAsync(cacheKey); + if (cacheValue != null) + { + // if found in cache return success result and remove from cache + await distributedCache.RemoveAsync(cacheKey); + _logger.LogInformation("New device verification exception for user {UserId} found in cache", user.Id); + return DeviceValidationResultType.Success; + } + // parse request for NewDeviceOtp to validate var newDeviceOtp = request.Raw["NewDeviceOtp"]?.ToString(); // we only check null here since an empty OTP will be considered an incorrect OTP diff --git a/test/Identity.Test/IdentityServer/DeviceValidatorTests.cs b/test/Identity.Test/IdentityServer/DeviceValidatorTests.cs index 304715b68cb1..105267ea305d 100644 --- a/test/Identity.Test/IdentityServer/DeviceValidatorTests.cs +++ b/test/Identity.Test/IdentityServer/DeviceValidatorTests.cs @@ -10,6 +10,8 @@ using Bit.Identity.IdentityServer.RequestValidators; using Bit.Test.Common.AutoFixture.Attributes; using Duende.IdentityServer.Validation; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; using AuthFixtures = Bit.Identity.Test.AutoFixture; @@ -24,6 +26,8 @@ public class DeviceValidatorTests private readonly IMailService _mailService; private readonly ICurrentContext _currentContext; private readonly IUserService _userService; + private readonly IDistributedCache _distributedCache; + private readonly Logger _logger; private readonly IFeatureService _featureService; private readonly DeviceValidator _sut; @@ -35,6 +39,8 @@ public DeviceValidatorTests() _mailService = Substitute.For(); _currentContext = Substitute.For(); _userService = Substitute.For(); + _distributedCache = Substitute.For(); + _logger = new Logger(Substitute.For()); _featureService = Substitute.For(); _sut = new DeviceValidator( _deviceService, @@ -43,6 +49,8 @@ public DeviceValidatorTests() _mailService, _currentContext, _userService, + _distributedCache, + _logger, _featureService); } @@ -51,7 +59,7 @@ public async void GetKnownDeviceAsync_UserNull_ReturnsFalse( Device device) { // Arrange - // AutoData arrages + // AutoData arranges // Act var result = await _sut.GetKnownDeviceAsync(null, device); @@ -421,6 +429,30 @@ public async void HandleNewDeviceVerificationAsync_UserNull_ContextModified_Retu Assert.Equal(expectedErrorMessage, actualResponse.Message); } + [Theory, BitAutoData] + public async void HandleNewDeviceVerificationAsync_UserHasCacheValue_ReturnsSuccess( + CustomValidatorRequestContext context, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request) + { + // Arrange + ArrangeForHandleNewDeviceVerificationTest(context, request); + _featureService.IsEnabled(FeatureFlagKeys.NewDeviceVerification).Returns(true); + _globalSettings.EnableNewDeviceVerification = true; + _distributedCache.GetAsync(Arg.Any()).Returns([1]); + + // Act + var result = await _sut.ValidateRequestDeviceAsync(request, context); + + // Assert + await _userService.Received(0).SendOTPAsync(context.User); + await _deviceService.Received(1).SaveAsync(Arg.Any()); + + Assert.True(result); + Assert.False(context.CustomResponse.ContainsKey("ErrorModel")); + Assert.Equal(context.User.Id, context.Device.UserId); + Assert.NotNull(context.Device); + } + [Theory, BitAutoData] public async void HandleNewDeviceVerificationAsync_NewDeviceOtpValid_ReturnsSuccess( CustomValidatorRequestContext context, @@ -430,6 +462,7 @@ public async void HandleNewDeviceVerificationAsync_NewDeviceOtpValid_ReturnsSucc ArrangeForHandleNewDeviceVerificationTest(context, request); _featureService.IsEnabled(FeatureFlagKeys.NewDeviceVerification).Returns(true); _globalSettings.EnableNewDeviceVerification = true; + _distributedCache.GetAsync(Arg.Any()).Returns(null as byte[]); var newDeviceOtp = "123456"; request.Raw.Add("NewDeviceOtp", newDeviceOtp); @@ -461,6 +494,7 @@ public async void HandleNewDeviceVerificationAsync_NewDeviceOtpInvalid_ReturnsIn ArrangeForHandleNewDeviceVerificationTest(context, request); _featureService.IsEnabled(FeatureFlagKeys.NewDeviceVerification).Returns(true); _globalSettings.EnableNewDeviceVerification = true; + _distributedCache.GetAsync(Arg.Any()).Returns(null as byte[]); request.Raw.Add("NewDeviceOtp", newDeviceOtp); @@ -489,6 +523,7 @@ public async void HandleNewDeviceVerificationAsync_UserHasNoDevices_ReturnsSucce ArrangeForHandleNewDeviceVerificationTest(context, request); _featureService.IsEnabled(FeatureFlagKeys.NewDeviceVerification).Returns(true); _globalSettings.EnableNewDeviceVerification = true; + _distributedCache.GetAsync(Arg.Any()).Returns([1]); _deviceRepository.GetManyByUserIdAsync(context.User.Id).Returns([]); // Act @@ -515,6 +550,7 @@ public async void HandleNewDeviceVerificationAsync_NewDeviceOtpEmpty_UserHasDevi _featureService.IsEnabled(FeatureFlagKeys.NewDeviceVerification).Returns(true); _globalSettings.EnableNewDeviceVerification = true; _deviceRepository.GetManyByUserIdAsync(context.User.Id).Returns([new Device()]); + _distributedCache.GetAsync(Arg.Any()).Returns(null as byte[]); // Act var result = await _sut.ValidateRequestDeviceAsync(request, context);