diff --git a/neon_hana/app/routers/util.py b/neon_hana/app/routers/util.py index 3808201..2d62d94 100644 --- a/neon_hana/app/routers/util.py +++ b/neon_hana/app/routers/util.py @@ -24,6 +24,8 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import re + from fastapi import APIRouter, Request from starlette.responses import PlainTextResponse @@ -32,10 +34,19 @@ util_route = APIRouter(prefix="/util", tags=["utilities"]) +def _is_ipv4(address: str) -> bool: + ipv4_regex = re.compile( + r'^(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01' + r']?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|' + r'2[0-4][0-9]|[01]?[0-9][0-9]?)$') + return ipv4_regex.match(address) + + @util_route.get("/client_ip", response_class=PlainTextResponse) async def api_client_ip(request: Request) -> str: ip_addr = request.client.host if request.client else "127.0.0.1" - if len(ip_addr.split('.')) != 4: + + if not _is_ipv4(ip_addr): # Reported host is a hostname, not an IP address. Return a generic # loopback value ip_addr = "127.0.0.1" diff --git a/tests/test_app.py b/tests/test_app.py index 491b44c..ef546b1 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -532,6 +532,16 @@ def test_llm(self, send_request): headers={"Authorization": f"Bearer {token}"}) self.assertEqual(response.status_code, 422, response.text) + def test_util_is_ipv4(self): + from neon_hana.app.routers.util import _is_ipv4 + self.assertTrue(_is_ipv4("127.0.0.1")) + self.assertTrue(_is_ipv4("10.0.0.10")) + self.assertTrue(_is_ipv4("1.1.1.1")) + self.assertFalse(_is_ipv4("ai.neon.api.1")) + self.assertFalse(_is_ipv4("host.local")) + self.assertFalse(_is_ipv4("localhost")) + self.assertFalse(_is_ipv4("1.0.0.300")) + def test_util_client_ip(self): response = self.test_app.get("/util/client_ip") self.assertEqual(response.text, "127.0.0.1")