Overhauled IP matching feature.

This commit is contained in:
Andrew Bastien 2023-11-24 19:40:08 -05:00
parent 923ec28ec4
commit 44ebefe561
5 changed files with 107 additions and 14 deletions

View file

@ -11,6 +11,7 @@ import importlib
import importlib.machinery
import importlib.util
import inspect
import ipaddress
import math
import os
import random
@ -2942,3 +2943,78 @@ def str2int(number):
# invalid number-word, raise ValueError
raise ValueError(f"String {original_input} cannot be converted to int.")
return sum(sums)
def match_ip(address, pattern) -> bool:
"""
Check if an IP address matches a given pattern. The pattern can be a single IP address
such as 8.8.8.8 or a CIDR-formatted subnet like 10.0.0.0/8
IPv6 is supported to, with CIDR-subnets looking like 2001:db8::/48
Args:
address (str): The source address being checked.
pattern (str): The single IP address or subnet to check against.
Returns:
result (bool): Whether it was a match or not.
"""
try:
# Convert the given IP address to an IPv4Address or IPv6Address object
ip_obj = ipaddress.ip_address(address)
except ValueError:
# Invalid IP address format
return False
try:
# Check if pattern is a single IP or a subnet
if "/" in pattern:
# It's (hopefully) a subnet in CIDR notation
network = ipaddress.ip_network(pattern, strict=False)
if ip_obj in network:
return True
else:
# It's a single IP address
if ip_obj == ipaddress.ip_address(pattern):
return True
except ValueError:
return False
return False
def ip_from_request(request, exclude=None) -> str:
"""
Retrieves the IP address from a web Request, while respecting X-Forwarded-For and
settings.UPSTREAM_IPS.
Args:
request (django Request or twisted.web.http.Request): The web request.
exclude: (list, optional): A list of IP addresses to exclude from the check. If left none,
then settings.UPSTREAM_IPS will be used.
Returns:
ip (str): The IP address the request originated from.
"""
if exclude is None:
exclude = settings.UPSTREAM_IPS
if hasattr(request, "getClientIP"):
# It's a twisted request.
remote_addr = request.getClientIP()
forwarded = request.getHeader("x-forwarded-for")
else:
# it's a Django request.
remote_addr = request.META.get("REMOTE_ADDR")
forwarded = request.META.get("HTTP_X_FORWARDED_FOR")
addresses = [remote_addr]
if forwarded:
addresses.extend(x.strip() for x in forwarded.split(","))
for addr in reversed(addresses):
if all(not match_ip(addr, pattern) for pattern in exclude):
return addr
logger.log_warn("ip_from_request: No valid IP address found in request. Using remote_addr.")
return remote_addr