diff --git a/tests/test_backends.py b/tests/test_backends.py index 540e01820..fd19183e0 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -1,5 +1,7 @@ +import builtins import uuid from datetime import datetime, timedelta +from importlib import reload from json import JSONEncoder from unittest import mock from unittest.mock import patch @@ -45,6 +47,7 @@ def default(self, obj): class TestTokenBackend(TestCase): def setUp(self): + self.realimport = builtins.__import__ self.hmac_token_backend = TokenBackend("HS256", SECRET) self.hmac_leeway_token_backend = TokenBackend("HS256", SECRET, leeway=LEEWAY) self.rsa_token_backend = TokenBackend("RS256", PRIVATE_KEY, PUBLIC_KEY) @@ -76,6 +79,28 @@ def test_init_fails_for_rs_algorithms_when_crypto_not_installed(self): ): TokenBackend(algo, "not_secret") + def test_jwk_client_not_available(self): + from rest_framework_simplejwt import backends + + def myimport(name, globals=None, locals=None, fromlist=(), level=0): + if name == "jwt" and fromlist == ("PyJWKClient", "PyJWKClientError"): + raise ImportError + return self.realimport(name, globals, locals, fromlist, level) + + builtins.__import__ = myimport + + # Reload backends, mock jwk client is not available + reload(backends) + + self.assertEqual(backends.JWK_CLIENT_AVAILABLE, False) + self.assertEqual(backends.TokenBackend("HS256").jwks_client, None) + + builtins.__import__ = self.realimport + + @patch("jwt.encode", mock.Mock(return_value=b"test")) + def test_token_encode_should_return_str_for_old_PyJWT(self): + self.assertIsInstance(TokenBackend("HS256").encode({}), str) + def test_encode_hmac(self): # Should return a JSON web token for the given payload payload = {"exp": make_utc(datetime(year=2000, month=1, day=1))} diff --git a/tests/test_init.py b/tests/test_init.py new file mode 100644 index 000000000..b68363189 --- /dev/null +++ b/tests/test_init.py @@ -0,0 +1,19 @@ +from importlib import reload +from unittest.mock import Mock, patch + +from django.test import SimpleTestCase +from pkg_resources import DistributionNotFound + + +class TestInit(SimpleTestCase): + def test_package_is_not_installed(self): + with patch( + "pkg_resources.get_distribution", Mock(side_effect=DistributionNotFound) + ): + # Import package mock package is not installed + import rest_framework_simplejwt.__init__ + + self.assertEqual(rest_framework_simplejwt.__init__.__version__, None) + + # Restore origin package without mock + reload(rest_framework_simplejwt.__init__) diff --git a/tests/test_models.py b/tests/test_models.py index 3a017fbfe..c757eea75 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,3 +1,6 @@ +from importlib import reload +from unittest.mock import patch + from django.test import TestCase from rest_framework_simplejwt.models import TokenUser @@ -15,6 +18,18 @@ def setUp(self): self.user = TokenUser(self.token) + def test_type_checking(self): + from rest_framework_simplejwt import models + + with patch("typing.TYPE_CHECKING", True): + # Reload models, mock type checking + reload(models) + + self.assertEqual(models.TYPE_CHECKING, True) + + # Restore origin module without mock + reload(models) + def test_username(self): self.assertEqual(self.user.username, "deep-thought") @@ -60,6 +75,12 @@ def test_eq(self): self.assertNotEqual(user1, user2) self.assertEqual(user1, user3) + def test_eq_not_implemented(self): + user1 = TokenUser({api_settings.USER_ID_CLAIM: 1}) + user2 = "user2" + + self.assertFalse(user1 == user2) + def test_hash(self): self.assertEqual(hash(self.user), hash(self.user.id)) @@ -105,3 +126,6 @@ def test_is_authenticated(self): def test_get_username(self): self.assertEqual(self.user.get_username(), "deep-thought") + + def test_get_custom_claims_through_backup_getattr(self): + self.assertEqual(self.user.some_other_stuff, "arstarst") diff --git a/tests/test_serializers.py b/tests/test_serializers.py index 6db0e3998..322d1cd9d 100644 --- a/tests/test_serializers.py +++ b/tests/test_serializers.py @@ -1,6 +1,8 @@ from datetime import timedelta +from importlib import reload from unittest.mock import MagicMock, patch +from django.conf import settings from django.contrib.auth import get_user_model from django.test import TestCase from rest_framework import exceptions as drf_exceptions @@ -76,6 +78,17 @@ def test_it_should_not_validate_if_user_not_found(self): with self.assertRaises(drf_exceptions.AuthenticationFailed): s.is_valid() + def test_it_should_pass_validate_if_request_not_in_context(self): + s = TokenObtainSerializer( + context={}, + data={ + "username": self.username, + "password": self.password, + }, + ) + + s.is_valid() + def test_it_should_raise_if_user_not_active(self): self.user.is_active = False self.user.save() @@ -372,6 +385,32 @@ def test_it_should_blacklist_refresh_token_if_tokens_should_be_rotated_and_black # Assert old refresh token is blacklisted self.assertEqual(BlacklistedToken.objects.first().token.jti, old_jti) + @override_api_settings( + ROTATE_REFRESH_TOKENS=True, + BLACKLIST_AFTER_ROTATION=True, + ) + def test_blacklist_app_not_installed_should_pass(self): + from rest_framework_simplejwt import serializers, tokens + + # Remove blacklist app + new_apps = list(settings.INSTALLED_APPS) + new_apps.remove("rest_framework_simplejwt.token_blacklist") + + with self.settings(INSTALLED_APPS=tuple(new_apps)): + # Reload module that blacklist app not installed + reload(tokens) + reload(serializers) + + refresh = tokens.RefreshToken() + + # Serializer validates + ser = serializers.TokenRefreshSerializer(data={"refresh": str(refresh)}) + ser.validate({"refresh": str(refresh)}) + + # Restore origin module without mock + reload(tokens) + reload(serializers) + class TestTokenVerifySerializer(TestCase): def test_it_should_raise_token_error_if_token_invalid(self): @@ -489,3 +528,25 @@ def test_it_should_blacklist_refresh_token_if_everything_ok(self): # Assert old refresh token is blacklisted self.assertEqual(BlacklistedToken.objects.first().token.jti, old_jti) + + def test_blacklist_app_not_installed_should_pass(self): + from rest_framework_simplejwt import serializers, tokens + + # Remove blacklist app + new_apps = list(settings.INSTALLED_APPS) + new_apps.remove("rest_framework_simplejwt.token_blacklist") + + with self.settings(INSTALLED_APPS=tuple(new_apps)): + # Reload module that blacklist app not installed + reload(tokens) + reload(serializers) + + refresh = tokens.RefreshToken() + + # Serializer validates + ser = serializers.TokenBlacklistSerializer(data={"refresh": str(refresh)}) + ser.validate({"refresh": str(refresh)}) + + # Restore origin module without mock + reload(tokens) + reload(serializers) diff --git a/tests/test_token_blacklist.py b/tests/test_token_blacklist.py index 824808145..b4d576b68 100644 --- a/tests/test_token_blacklist.py +++ b/tests/test_token_blacklist.py @@ -1,9 +1,11 @@ +from importlib import reload from unittest.mock import patch from django.contrib.auth.models import User from django.core.management import call_command from django.db.models import BigAutoField from django.test import TestCase +from django.utils import timezone from rest_framework_simplejwt.exceptions import TokenError from rest_framework_simplejwt.serializers import TokenVerifySerializer @@ -25,6 +27,19 @@ def setUp(self): password="test_password", ) + def test_token_blacklist_old_django(self): + with patch("django.VERSION", (3, 1)): + # Import package mock blacklist old django + import rest_framework_simplejwt.token_blacklist.__init__ as blacklist + + self.assertEqual( + blacklist.default_app_config, + ("rest_framework_simplejwt.token_blacklist.apps.TokenBlacklistConfig"), + ) + + # Restore origin module without mock + reload(blacklist) + def test_sliding_tokens_are_added_to_outstanding_list(self): token = SlidingToken.for_user(self.user) @@ -114,6 +129,23 @@ def test_tokens_can_be_manually_blacklisted(self): self.assertEqual(OutstandingToken.objects.count(), 2) + def test_outstanding_token_and_blacklisted_token_expected_str(self): + outstanding = OutstandingToken.objects.create( + user=self.user, + jti="abc", + token="xyz", + expires_at=timezone.now(), + ) + blacklisted = BlacklistedToken.objects.create(token=outstanding) + + expected_outstanding_str = "Token for {} ({})".format( + outstanding.user, outstanding.jti + ) + expected_blacklisted_str = f"Blacklisted token for {blacklisted.token.user}" + + self.assertEqual(str(outstanding), expected_outstanding_str) + self.assertEqual(str(blacklisted), expected_blacklisted_str) + class TestTokenBlacklistFlushExpiredTokens(TestCase): def setUp(self): diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 9f81a1a76..47702e33a 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from importlib import reload from unittest.mock import patch from django.contrib.auth import get_user_model @@ -39,6 +40,18 @@ def setUpTestData(cls): password="test_password", ) + def test_type_checking(self): + from rest_framework_simplejwt import tokens + + with patch("typing.TYPE_CHECKING", True): + # Reload tokens, mock type checking + reload(tokens) + + self.assertEqual(tokens.TYPE_CHECKING, True) + + # Restore origin module without mock + reload(tokens) + def test_init_no_token_type_or_lifetime(self): class MyTestToken(Token): pass @@ -377,6 +390,11 @@ def test_for_user_with_username(self): token = MyToken.for_user(self.user) self.assertEqual(token[api_settings.USER_ID_CLAIM], self.username) + @override_api_settings(CHECK_REVOKE_TOKEN=True) + def test_revoke_token_claim_included_in_authorization_token(self): + token = MyToken.for_user(self.user) + self.assertIn(api_settings.REVOKE_TOKEN_CLAIM, token) + def test_get_token_backend(self): token = MyToken() diff --git a/tests/test_views.py b/tests/test_views.py index b1fc80113..4927e6ffd 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -440,3 +440,16 @@ class CustomTokenView(TokenViewBase): request = factory.post("/", {}, format="json") res = view(request) self.assertEqual(res.status_code, 400) + + +class TestTokenViewBase(APIViewTestCase): + def test_serializer_class_not_set_in_settings_and_class_attribute_or_wrong_path( + self, + ): + view = TokenViewBase() + msg = "Could not import serializer '%s'" % view._serializer_class + + with self.assertRaises(ImportError) as e: + view.get_serializer_class() + + self.assertEqual(e.exception.msg, msg)