Structure handlers to allow .get() to return lists

See #1154. In the end I didn't modify the Attributehandler and
TagHandler like this, instead I added the `return_list` argument
for cases when one wants a guaranteed return.
This commit is contained in:
Griatch 2017-08-27 14:56:05 +02:00
parent 05a3d0435d
commit 92df3ce5ae
13 changed files with 86 additions and 1608 deletions

View file

@ -544,9 +544,9 @@ class CmdSetHandler(object):
# legacy alias # legacy alias
delete_default = remove_default delete_default = remove_default
def all(self): def get(self):
""" """
Show all cmdsets. Get all cmdsets.
Returns: Returns:
cmdsets (list): All the command sets currently in the handler. cmdsets (list): All the command sets currently in the handler.
@ -554,6 +554,9 @@ class CmdSetHandler(object):
""" """
return self.cmdset_stack return self.cmdset_stack
# backwards-compatible alias
all = get
def clear(self): def clear(self):
""" """
Removes all Command Sets from the handler except the default one Removes all Command Sets from the handler except the default one

View file

@ -1841,7 +1841,7 @@ class CmdLock(ObjManipCommand):
obj = caller.search(self.lhs) obj = caller.search(self.lhs)
if not obj: if not obj:
return return
caller.msg(obj.locks.all()) caller.msg("\n".join(obj.locks.all()))
class CmdExamine(ObjManipCommand): class CmdExamine(ObjManipCommand):

View file

@ -172,23 +172,25 @@ class ChannelHandler(object):
Initializes the channel handler's internal state. Initializes the channel handler's internal state.
""" """
self.cached_channel_cmds = {} self._cached_channel_cmds = {}
self.cached_cmdsets = {} self._cached_cmdsets = {}
self._cached_channels = {}
def __str__(self): def __str__(self):
""" """
Returns the string representation of the handler Returns the string representation of the handler
""" """
return ", ".join(str(cmd) for cmd in self.cached_channel_cmds) return ", ".join(str(cmd) for cmd in self._cached_channel_cmds)
def clear(self): def clear(self):
""" """
Reset the cache storage. Reset the cache storage.
""" """
self.cached_channel_cmds = {} self._cached_channel_cmds = {}
self.cached_cmdsets = {} self._cached_cmdsets = {}
self._cached_channels = {}
def add(self, channel): def add(self, channel):
""" """
@ -221,9 +223,11 @@ class ChannelHandler(object):
key = channel.key key = channel.key
cmd.__doc__ = cmd.__doc__.format(channelkey=key, cmd.__doc__ = cmd.__doc__.format(channelkey=key,
lower_channelkey=key.strip().lower(), lower_channelkey=key.strip().lower(),
channeldesc=channel.attributes.get("desc", default="").strip()) channeldesc=channel.attributes.get(
self.cached_channel_cmds[channel] = cmd "desc", default="").strip())
self.cached_cmdsets = {} self._cached_channel_cmds[channel] = cmd
self._cached_channels[key] = channel
self._cached_cmdsets = {}
add_channel = add # legacy alias add_channel = add # legacy alias
def remove(self, channel): def remove(self, channel):
@ -247,11 +251,28 @@ class ChannelHandler(object):
global _CHANNELDB global _CHANNELDB
if not _CHANNELDB: if not _CHANNELDB:
from evennia.comms.models import ChannelDB as _CHANNELDB from evennia.comms.models import ChannelDB as _CHANNELDB
self.cached_channel_cmds = {} self._cached_channel_cmds = {}
self.cached_cmdsets = {} self._cached_cmdsets = {}
self._cached_channels = {}
for channel in _CHANNELDB.objects.get_all_channels(): for channel in _CHANNELDB.objects.get_all_channels():
self.add(channel) self.add(channel)
def get(self, channelname=None):
"""
Get a channel from the handler, or all channels
Args:
channelame (str, optional): Channel key, case insensitive.
Returns
channels (list): The matching channels in a list, or all
channels in the handler.
"""
if channelname:
channel = self._cached_channels.get(channelname.lower(), None)
return [channel] if channel else []
return self._cached_channels.values()
def get_cmdset(self, source_object): def get_cmdset(self, source_object):
""" """
Retrieve cmdset for channels this source_object has Retrieve cmdset for channels this source_object has
@ -266,12 +287,12 @@ class ChannelHandler(object):
access to. access to.
""" """
if source_object in self.cached_cmdsets: if source_object in self._cached_cmdsets:
return self.cached_cmdsets[source_object] return self._cached_cmdsets[source_object]
else: else:
# create a new cmdset holding all viable channels # create a new cmdset holding all viable channels
chan_cmdset = None chan_cmdset = None
chan_cmds = [channelcmd for channel, channelcmd in self.cached_channel_cmds.iteritems() chan_cmds = [channelcmd for channel, channelcmd in self._cached_channel_cmds.iteritems()
if channel.subscriptions.has(source_object) and if channel.subscriptions.has(source_object) and
channelcmd.access(source_object, 'send')] channelcmd.access(source_object, 'send')]
if chan_cmds: if chan_cmds:
@ -281,7 +302,7 @@ class ChannelHandler(object):
chan_cmdset.duplicates = True chan_cmdset.duplicates = True
for cmd in chan_cmds: for cmd in chan_cmds:
chan_cmdset.add(cmd) chan_cmdset.add(cmd)
self.cached_cmdsets[source_object] = chan_cmdset self._cached_cmdsets[source_object] = chan_cmdset
return chan_cmdset return chan_cmdset

View file

@ -533,7 +533,7 @@ class SubscriptionHandler(object):
self.obj.db_object_subscriptions.add(subscriber) self.obj.db_object_subscriptions.add(subscriber)
elif clsname == "AccountDB": elif clsname == "AccountDB":
self.obj.db_account_subscriptions.add(subscriber) self.obj.db_account_subscriptions.add(subscriber)
_CHANNELHANDLER.cached_cmdsets.pop(subscriber, None) _CHANNELHANDLER._cached_cmdsets.pop(subscriber, None)
self._recache() self._recache()
def remove(self, entity): def remove(self, entity):
@ -556,7 +556,7 @@ class SubscriptionHandler(object):
self.obj.db_account_subscriptions.remove(entity) self.obj.db_account_subscriptions.remove(entity)
elif clsname == "ObjectDB": elif clsname == "ObjectDB":
self.obj.db_object_subscriptions.remove(entity) self.obj.db_object_subscriptions.remove(entity)
_CHANNELHANDLER.cached_cmdsets.pop(subscriber, None) _CHANNELHANDLER._cached_cmdsets.pop(subscriber, None)
self._recache() self._recache()
def all(self): def all(self):
@ -571,6 +571,7 @@ class SubscriptionHandler(object):
if self._cache is None: if self._cache is None:
self._recache() self._recache()
return self._cache return self._cache
get = all # alias
def online(self): def online(self):
""" """

View file

@ -107,7 +107,6 @@ from __future__ import print_function
from builtins import object from builtins import object
import re import re
import inspect
from django.conf import settings from django.conf import settings
from evennia.utils import logger, utils from evennia.utils import logger, utils
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
@ -370,13 +369,13 @@ class LockHandler(object):
def all(self): def all(self):
""" """
Return all lockstrings. Return all lockstrings
Returns: Returns:
lockstring (str): The full lockstring lockstrings (list): All separate lockstrings
""" """
return self.get() return str(self).split(';')
def remove(self, access_type): def remove(self, access_type):
""" """

View file

@ -7,7 +7,7 @@ entities.
""" """
import time import time
from builtins import object from builtins import object
from future.utils import listvalues, with_metaclass from future.utils import with_metaclass
from django.conf import settings from django.conf import settings

View file

@ -93,6 +93,7 @@ class MonitorHandler(object):
def at_update(self, obj, fieldname): def at_update(self, obj, fieldname):
""" """
Called by the field as it saves. Called by the field as it saves.
""" """
to_delete = [] to_delete = []
if obj in self.monitors and fieldname in self.monitors[obj]: if obj in self.monitors and fieldname in self.monitors[obj]:
@ -175,6 +176,9 @@ class MonitorHandler(object):
""" """
List all monitors. List all monitors.
Returns:
monitors (list): The handled monitors.
""" """
output = [] output = []
for obj in self.monitors: for obj in self.monitors:

View file

@ -109,7 +109,7 @@ class ScriptHandler(object):
scripts (list): The found scripts matching `key`. scripts (list): The found scripts matching `key`.
""" """
return ScriptDB.objects.get_all_scripts_on_obj(self.obj, key=key) return list(ScriptDB.objects.get_all_scripts_on_obj(self.obj, key=key))
def delete(self, key=None): def delete(self, key=None):
""" """

View file

@ -6,7 +6,7 @@ from datetime import datetime, timedelta
from twisted.internet import reactor, task from twisted.internet import reactor, task
from evennia.server.models import ServerConfig from evennia.server.models import ServerConfig
from evennia.utils.logger import log_trace, log_err from evennia.utils.logger import log_err
from evennia.utils.dbserialize import dbserialize, dbunserialize from evennia.utils.dbserialize import dbserialize, dbunserialize
TASK_HANDLER = None TASK_HANDLER = None
@ -113,9 +113,9 @@ class TaskHandler(object):
try: try:
dbserialize(arg) dbserialize(arg)
except (TypeError, AttributeError): except (TypeError, AttributeError):
logger.log_err("The positional argument {} cannot be " log_err("The positional argument {} cannot be "
"pickled and will not be present in the arguments " "pickled and will not be present in the arguments "
"fed to the callback {}".format(arg, callback)) "fed to the callback {}".format(arg, callback))
else: else:
safe_args.append(arg) safe_args.append(arg)
@ -123,9 +123,9 @@ class TaskHandler(object):
try: try:
dbserialize(value) dbserialize(value)
except (TypeError, AttributeError): except (TypeError, AttributeError):
logger.log_err("The {} keyword argument {} cannot be " log_err("The {} keyword argument {} cannot be "
"pickled and will not be present in the arguments " "pickled and will not be present in the arguments "
"fed to the callback {}".format(key, value, callback)) "fed to the callback {}".format(key, value, callback))
else: else:
safe_kwargs[key] = value safe_kwargs[key] = value

View file

@ -387,7 +387,7 @@ class AttributeHandler(object):
def get(self, key=None, default=None, category=None, return_obj=False, def get(self, key=None, default=None, category=None, return_obj=False,
strattr=False, raise_exception=False, accessing_obj=None, strattr=False, raise_exception=False, accessing_obj=None,
default_access=True): default_access=True, return_list=False):
""" """
Get the Attribute. Get the Attribute.
@ -398,7 +398,8 @@ class AttributeHandler(object):
category (str, optional): the category within which to category (str, optional): the category within which to
retrieve attribute(s). retrieve attribute(s).
default (any, optional): The value to return if an default (any, optional): The value to return if an
Attribute was not defined. Attribute was not defined. If set, it will be returned in
a one-item list.
return_obj (bool, optional): If set, the return is not the value of the return_obj (bool, optional): If set, the return is not the value of the
Attribute but the Attribute object itself. Attribute but the Attribute object itself.
strattr (bool, optional): Return the `strvalue` field of strattr (bool, optional): Return the `strvalue` field of
@ -410,13 +411,15 @@ class AttributeHandler(object):
accessing_obj (object, optional): If set, an `attrread` accessing_obj (object, optional): If set, an `attrread`
permission lock will be checked before returning each permission lock will be checked before returning each
looked-after Attribute. looked-after Attribute.
default_access (bool, optional): default_access (bool, optional): If no `attrread` lock is set on
object, this determines if the lock should then be passed or not.
return_list (bool, optional):
Returns: Returns:
result (any, Attribute or list): This will be the value of the found result (any or list): One or more matches for keys and/or categories. Each match will be
Attribute unless `return_obj` is True, at which point it will be the value of the found Attribute(s) unless `return_obj` is True, at which point it
the attribute object or None. If multiple keys are given, this will be the attribute object itself or None. If `return_list` is True, this will
will be a list of values or attribute objects/None. always be a list, regardless of the number of elements.
Raises: Raises:
AttributeError: If `raise_exception` is set and no matching Attribute AttributeError: If `raise_exception` is set and no matching Attribute
@ -453,7 +456,10 @@ class AttributeHandler(object):
ret = ret if return_obj else [attr.strvalue for attr in ret if attr] ret = ret if return_obj else [attr.strvalue for attr in ret if attr]
else: else:
ret = ret if return_obj else [attr.value for attr in ret if attr] ret = ret if return_obj else [attr.value for attr in ret if attr]
if not ret:
if return_list:
return ret if ret else [default] if default is not None else []
elif not ret:
return ret if len(key) > 1 else default return ret if len(key) > 1 else default
return ret[0] if len(ret) == 1 else ret return ret[0] if len(ret) == 1 else ret

View file

@ -251,6 +251,8 @@ class TagHandler(object):
""" """
if not tag: if not tag:
return return
if not self._cache_complete:
self._fullcache()
for tagstr in make_iter(tag): for tagstr in make_iter(tag):
if not tagstr: if not tagstr:
continue continue
@ -265,7 +267,7 @@ class TagHandler(object):
getattr(self.obj, self._m2m_fieldname).add(tagobj) getattr(self.obj, self._m2m_fieldname).add(tagobj)
self._setcache(tagstr, category, tagobj) self._setcache(tagstr, category, tagobj)
def get(self, key=None, default=None, category=None, return_tagobj=False): def get(self, key=None, default=None, category=None, return_tagobj=False, return_list=False):
""" """
Get the tag for the given key or list of tags. Get the tag for the given key or list of tags.
@ -277,11 +279,14 @@ class TagHandler(object):
category. category.
return_tagobj (bool, optional): Return the Tag object itself return_tagobj (bool, optional): Return the Tag object itself
instead of a string representation of the Tag. instead of a string representation of the Tag.
return_list (bool, optional): Always return a list, regardless
of number of matches.
Returns: Returns:
tags (str, TagObject or list): The matches, either string tags (list): The matches, either string
representations of the tags or the Tag objects themselves representations of the tags or the Tag objects themselves
depending on `return_tagobj`. depending on `return_tagobj`. If 'default' is set, this
will be a list with the default value as its only element.
""" """
ret = [] ret = []
@ -289,6 +294,8 @@ class TagHandler(object):
# note - the _getcache call removes case sensitivity for us # note - the _getcache call removes case sensitivity for us
ret.extend([tag if return_tagobj else to_str(tag.db_key) ret.extend([tag if return_tagobj else to_str(tag.db_key)
for tag in self._getcache(keystr, category)]) for tag in self._getcache(keystr, category)])
if return_list:
return ret if ret else [default] if default is not None else []
return ret[0] if len(ret) == 1 else (ret if ret else default) return ret[0] if len(ret) == 1 else (ret if ret else default)
def remove(self, key, category=None): def remove(self, key, category=None):
@ -327,6 +334,8 @@ class TagHandler(object):
category. category.
""" """
if not self._cache_complete:
self._fullcache()
query = {"%s__id" % self._model: self._objid, "tag__db_model": self._model, "tag__db_tagtype": self._tagtype} query = {"%s__id" % self._model: self._objid, "tag__db_model": self._model, "tag__db_tagtype": self._tagtype}
if category: if category:
query["tag__db_category"] = category.strip().lower() query["tag__db_category"] = category.strip().lower()

View file

@ -167,7 +167,6 @@ class _SaverMutable(object):
non_saver_name = cls_name non_saver_name = cls_name
raise ValueError(_ERROR_DELETED_ATTR.format(cls_name=cls_name, obj=self, raise ValueError(_ERROR_DELETED_ATTR.format(cls_name=cls_name, obj=self,
non_saver_name=non_saver_name)) non_saver_name=non_saver_name))
print("self._db_obj.pk")
self._db_obj.value = self self._db_obj.value = self
else: else:
logger.log_err("_SaverMutable %s has no root Attribute to save to." % self) logger.log_err("_SaverMutable %s has no root Attribute to save to." % self)

File diff suppressed because it is too large Load diff