Merge pull request #3495 from jaborsh/TraitReloadFix

CounterTrait now checks for last_update before defaulting to current time.
This commit is contained in:
Griatch 2024-04-07 21:22:45 +02:00 committed by GitHub
commit f84dde8870
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -571,12 +571,16 @@ class TraitHandler:
# initialize any # initialize any
# Note that .trait_data retains the connection to the database, meaning every # Note that .trait_data retains the connection to the database, meaning every
# update we do to .trait_data automatically syncs with database. # update we do to .trait_data automatically syncs with database.
self.trait_data = obj.attributes.get(db_attribute_key, category=db_attribute_category) self.trait_data = obj.attributes.get(
db_attribute_key, category=db_attribute_category
)
if self.trait_data is None: if self.trait_data is None:
# no existing storage; initialize it, we then have to fetch it again # no existing storage; initialize it, we then have to fetch it again
# to retain the db connection # to retain the db connection
obj.attributes.add(db_attribute_key, {}, category=db_attribute_category) obj.attributes.add(db_attribute_key, {}, category=db_attribute_category)
self.trait_data = obj.attributes.get(db_attribute_key, category=db_attribute_category) self.trait_data = obj.attributes.get(
db_attribute_key, category=db_attribute_category
)
self._cache = {} self._cache = {}
def __len__(self): def __len__(self):
@ -595,7 +599,9 @@ class TraitHandler:
_SA(self, trait_key, value) _SA(self, trait_key, value)
else: else:
trait_cls = self._get_trait_class(trait_key=trait_key) trait_cls = self._get_trait_class(trait_key=trait_key)
valid_keys = list_to_string(list(trait_cls.default_keys.keys()), endsep="or") valid_keys = list_to_string(
list(trait_cls.default_keys.keys()), endsep="or"
)
raise TraitException( raise TraitException(
f"Trait object not settable directly. Assign to {trait_key}.{valid_keys}." f"Trait object not settable directly. Assign to {trait_key}.{valid_keys}."
) )
@ -627,7 +633,9 @@ class TraitHandler:
try: try:
trait_type = self.trait_data[trait_key]["trait_type"] trait_type = self.trait_data[trait_key]["trait_type"]
except KeyError: except KeyError:
raise TraitException(f"Trait class for Trait {trait_key} could not be found.") raise TraitException(
f"Trait class for Trait {trait_key} could not be found."
)
try: try:
return _TRAIT_CLASSES[trait_type] return _TRAIT_CLASSES[trait_type]
except KeyError: except KeyError:
@ -657,11 +665,18 @@ class TraitHandler:
if trait is None and trait_key in self.trait_data: if trait is None and trait_key in self.trait_data:
trait_type = self.trait_data[trait_key]["trait_type"] trait_type = self.trait_data[trait_key]["trait_type"]
trait_cls = self._get_trait_class(trait_type) trait_cls = self._get_trait_class(trait_type)
trait = self._cache[trait_key] = trait_cls(_GA(self, "trait_data")[trait_key]) trait = self._cache[trait_key] = trait_cls(
_GA(self, "trait_data")[trait_key]
)
return trait return trait
def add( def add(
self, trait_key, name=None, trait_type=DEFAULT_TRAIT_TYPE, force=True, **trait_properties self,
trait_key,
name=None,
trait_type=DEFAULT_TRAIT_TYPE,
force=True,
**trait_properties,
): ):
""" """
Create a new Trait and add it to the handler. Create a new Trait and add it to the handler.
@ -748,7 +763,9 @@ class TraitProperty:
""" """
def __init__(self, name=None, trait_type=DEFAULT_TRAIT_TYPE, force=True, **trait_properties): def __init__(
self, name=None, trait_type=DEFAULT_TRAIT_TYPE, force=True, **trait_properties
):
""" """
Initialize a TraitField. Mimics TraitHandler.add input except no `trait_key`. Initialize a TraitField. Mimics TraitHandler.add input except no `trait_key`.
@ -767,7 +784,9 @@ class TraitProperty:
""" """
self._traithandler_name = trait_properties.pop("traithandler_name", "traits") self._traithandler_name = trait_properties.pop("traithandler_name", "traits")
trait_properties.update({"name": name, "trait_type": trait_type, "force": force}) trait_properties.update(
{"name": name, "trait_type": trait_type, "force": force}
)
self._trait_properties = trait_properties self._trait_properties = trait_properties
self._cache = {} self._cache = {}
@ -807,7 +826,9 @@ class TraitProperty:
if trait is None: if trait is None:
# initialize the trait # initialize the trait
traithandler.add(self._trait_key, **self._trait_properties) traithandler.add(self._trait_key, **self._trait_properties)
trait = traithandler.get(self._trait_key) # caches it in the traithandler trait = traithandler.get(
self._trait_key
) # caches it in the traithandler
self._cache[instance] = trait self._cache[instance] = trait
return self._cache[instance] return self._cache[instance]
@ -915,13 +936,21 @@ class Trait:
if MandatoryTraitKey in unset_defaults.values(): if MandatoryTraitKey in unset_defaults.values():
# we have one or more unset keys that was mandatory # we have one or more unset keys that was mandatory
_raise_err([key for key, value in unset_defaults.items() if value == MandatoryTraitKey]) _raise_err(
[
key
for key, value in unset_defaults.items()
if value == MandatoryTraitKey
]
)
# apply the default values # apply the default values
trait_data.update(unset_defaults) trait_data.update(unset_defaults)
if not cls.allow_extra_properties: if not cls.allow_extra_properties:
# don't allow any extra properties - remove the extra data # don't allow any extra properties - remove the extra data
for key in (key for key in inp.difference(req) if key not in ("name", "trait_type")): for key in (
key for key in inp.difference(req) if key not in ("name", "trait_type")
):
del trait_data[key] del trait_data[key]
return trait_data return trait_data
@ -945,7 +974,12 @@ class Trait:
def __getattr__(self, key): def __getattr__(self, key):
"""Access extra parameters as attributes.""" """Access extra parameters as attributes."""
if key in ("default_keys", "data_default", "trait_type", "allow_extra_properties"): if key in (
"default_keys",
"data_default",
"trait_type",
"allow_extra_properties",
):
return _GA(self, key) return _GA(self, key)
try: try:
return self._data[key] return self._data[key]
@ -1276,7 +1310,7 @@ class CounterTrait(Trait):
) )
# set up rate # set up rate
if trait_data["rate"] != 0: if trait_data["rate"] != 0:
trait_data["last_update"] = time() trait_data["last_update"] = trait_data.get("last_update", time())
else: else:
trait_data["last_update"] = None trait_data["last_update"] = None
return trait_data return trait_data
@ -1310,7 +1344,8 @@ class CounterTrait(Trait):
"""Check if we passed the ratetarget in either direction.""" """Check if we passed the ratetarget in either direction."""
ratetarget = self._data["ratetarget"] ratetarget = self._data["ratetarget"]
return ratetarget is not None and ( return ratetarget is not None and (
(self.rate < 0 and value <= ratetarget) or (self.rate > 0 and value >= ratetarget) (self.rate < 0 and value <= ratetarget)
or (self.rate > 0 and value >= ratetarget)
) )
def _stop_timer(self): def _stop_timer(self):
@ -1435,7 +1470,9 @@ class CounterTrait(Trait):
@current.setter @current.setter
def current(self, value): def current(self, value):
if type(value) in (int, float): if type(value) in (int, float):
self._data["current"] = self._check_and_start_timer(self._enforce_boundaries(value)) self._data["current"] = self._check_and_start_timer(
self._enforce_boundaries(value)
)
@current.deleter @current.deleter
def current(self): def current(self):
@ -1552,6 +1589,7 @@ class GaugeTrait(CounterTrait):
rate = self.rate rate = self.rate
if rate != 0 and self._data["last_update"] is not None: if rate != 0 and self._data["last_update"] is not None:
now = time() now = time()
tdiff = now - self._data["last_update"] tdiff = now - self._data["last_update"]
current += rate * tdiff current += rate * tdiff
value = current value = current
@ -1657,13 +1695,17 @@ class GaugeTrait(CounterTrait):
def current(self): def current(self):
"""The `current` value of the gauge.""" """The `current` value of the gauge."""
return self._update_current( return self._update_current(
self._enforce_boundaries(self._data.get("current", (self.base + self.mod) * self.mult)) self._enforce_boundaries(
self._data.get("current", (self.base + self.mod) * self.mult)
)
) )
@current.setter @current.setter
def current(self, value): def current(self, value):
if type(value) in (int, float): if type(value) in (int, float):
self._data["current"] = self._check_and_start_timer(self._enforce_boundaries(value)) self._data["current"] = self._check_and_start_timer(
self._enforce_boundaries(value)
)
@current.deleter @current.deleter
def current(self): def current(self):