Adds remove() method and renames record/unrecord methods to reflect expected input.

This commit is contained in:
Johnny 2020-10-20 21:00:03 +00:00
parent ee95fca6cf
commit 3444cce6e9

View file

@ -125,22 +125,58 @@ class Throttle(object):
if not previously_throttled and currently_throttled: if not previously_throttled and currently_throttled:
logger.log_sec(f"Throttle Activated: {failmsg} (IP: {ip}, {self.limit} hits in {self.timeout} seconds.)") logger.log_sec(f"Throttle Activated: {failmsg} (IP: {ip}, {self.limit} hits in {self.timeout} seconds.)")
self.record_key(ip) self.record_ip(ip)
def record_key(self, key, *args, **kwargs): def remove(self, ip, *args, **kwargs):
"""
Clears data stored for an IP from the throttle.
Args:
ip(str): IP to clear.
"""
exists = self.get(ip)
if not exists: return False
cache_key = self.get_cache_key(ip)
self.storage.delete(cache_key)
self.unrecord_ip(ip)
# Return True if NOT exists
return ~bool(self.get(ip))
def record_ip(self, ip, *args, **kwargs):
""" """
Tracks keys as they are added to the cache (since there is no way to Tracks keys as they are added to the cache (since there is no way to
get a list of keys after-the-fact). get a list of keys after-the-fact).
Args: Args:
key(str): Key being added to cache. This should be the original ip(str): IP being added to cache. This should be the original
key, not the cache-prefixed version. IP, not the cache-prefixed key.
""" """
keys_key = self.get_cache_key('keys') keys_key = self.get_cache_key('keys')
keys = self.storage.get(keys_key, set()) keys = self.storage.get(keys_key, set())
keys.add(key) keys.add(ip)
self.storage.set(keys_key, keys, self.timeout) self.storage.set(keys_key, keys, self.timeout)
return True
def unrecord_ip(self, ip, *args, **kwargs):
"""
Forces removal of a key from the key registry.
Args:
ip(str): IP to remove from list of keys.
"""
keys_key = self.get_cache_key('keys')
keys = self.storage.get(keys_key, set())
try:
keys.remove(ip)
self.storage.set(keys_key, keys, self.timeout)
return True
except KeyError:
return False
def check(self, ip): def check(self, ip):
""" """
@ -171,7 +207,7 @@ class Throttle(object):
return True return True
else: else:
# timeout has passed. clear faillist # timeout has passed. clear faillist
self.storage.delete(cache_key) self.remove(ip)
return False return False
else: else:
return False return False