Skip to content

Commit

Permalink
Add more precise ipv4 validation
Browse files Browse the repository at this point in the history
  • Loading branch information
NeonDaniel committed Nov 18, 2024
1 parent d713345 commit 47cf170
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
13 changes: 12 additions & 1 deletion neon_hana/app/routers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import re

from fastapi import APIRouter, Request
from starlette.responses import PlainTextResponse

Expand All @@ -32,10 +34,19 @@
util_route = APIRouter(prefix="/util", tags=["utilities"])


def _is_ipv4(address: str) -> bool:
ipv4_regex = re.compile(
r'^(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01'
r']?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|'
r'2[0-4][0-9]|[01]?[0-9][0-9]?)$')
return ipv4_regex.match(address)


@util_route.get("/client_ip", response_class=PlainTextResponse)
async def api_client_ip(request: Request) -> str:
ip_addr = request.client.host if request.client else "127.0.0.1"
if len(ip_addr.split('.')) != 4:

if not _is_ipv4(ip_addr):
# Reported host is a hostname, not an IP address. Return a generic
# loopback value
ip_addr = "127.0.0.1"
Expand Down
10 changes: 10 additions & 0 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,16 @@ def test_llm(self, send_request):
headers={"Authorization": f"Bearer {token}"})
self.assertEqual(response.status_code, 422, response.text)

def test_util_is_ipv4(self):
from neon_hana.app.routers.util import _is_ipv4
self.assertTrue(_is_ipv4("127.0.0.1"))
self.assertTrue(_is_ipv4("10.0.0.10"))
self.assertTrue(_is_ipv4("1.1.1.1"))
self.assertFalse(_is_ipv4("ai.neon.api.1"))
self.assertFalse(_is_ipv4("host.local"))
self.assertFalse(_is_ipv4("localhost"))
self.assertFalse(_is_ipv4("1.0.0.300"))

def test_util_client_ip(self):
response = self.test_app.get("/util/client_ip")
self.assertEqual(response.text, "127.0.0.1")
Expand Down

0 comments on commit 47cf170

Please sign in to comment.