Fix get_tag error. Resolve #2128

This commit is contained in:
Griatch 2020-05-18 20:25:26 +02:00
parent 703b307c40
commit ad1169d900
2 changed files with 115 additions and 44 deletions

View file

@ -31,8 +31,14 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager):
# Attribute manager methods # Attribute manager methods
def get_attribute( def get_attribute(
self, key=None, category=None, value=None, strvalue=None, obj=None, attrtype=None, **kwargs self,
key=None,
category=None,
value=None,
strvalue=None,
obj=None,
attrtype=None,
**kwargs
): ):
""" """
Return Attribute objects by key, by category, by value, by Return Attribute objects by key, by category, by value, by
@ -76,9 +82,9 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager):
# no reason to make strvalue/value mutually exclusive at this level # no reason to make strvalue/value mutually exclusive at this level
query.append(("attribute__db_value", value)) query.append(("attribute__db_value", value))
return Attribute.objects.filter( return Attribute.objects.filter(
pk__in=self.model.db_attributes.through.objects.filter(**dict(query)).values_list( pk__in=self.model.db_attributes.through.objects.filter(
"attribute_id", flat=True **dict(query)
) ).values_list("attribute_id", flat=True)
) )
def get_nick(self, key=None, category=None, value=None, strvalue=None, obj=None): def get_nick(self, key=None, category=None, value=None, strvalue=None, obj=None):
@ -104,8 +110,15 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager):
key=key, category=category, value=value, strvalue=strvalue, obj=obj key=key, category=category, value=value, strvalue=strvalue, obj=obj
) )
def get_by_attribute(self, key=None, category=None, value=None, def get_by_attribute(
strvalue=None, attrtype=None, **kwargs): self,
key=None,
category=None,
value=None,
strvalue=None,
attrtype=None,
**kwargs
):
""" """
Return objects having attributes with the given key, category, Return objects having attributes with the given key, category,
value, strvalue or combination of those criteria. value, strvalue or combination of those criteria.
@ -132,7 +145,10 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager):
""" """
dbmodel = self.model.__dbclass__.__name__.lower() dbmodel = self.model.__dbclass__.__name__.lower()
query = [("db_attributes__db_attrtype", attrtype), ("db_attributes__db_model", dbmodel)] query = [
("db_attributes__db_attrtype", attrtype),
("db_attributes__db_model", dbmodel),
]
if key: if key:
query.append(("db_attributes__db_key", key)) query.append(("db_attributes__db_key", key))
if category: if category:
@ -158,11 +174,15 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager):
obj (list): Objects having the matching Nicks. obj (list): Objects having the matching Nicks.
""" """
return self.get_by_attribute(key=key, category=category, strvalue=nick, attrtype="nick") return self.get_by_attribute(
key=key, category=category, strvalue=nick, attrtype="nick"
)
# Tag manager methods # Tag manager methods
def get_tag(self, key=None, category=None, obj=None, tagtype=None, global_search=False): def get_tag(
self, key=None, category=None, obj=None, tagtype=None, global_search=False
):
""" """
Return Tag objects by key, by category, by object (it is Return Tag objects by key, by category, by object (it is
stored on) or with a combination of those criteria. stored on) or with a combination of those criteria.
@ -206,9 +226,9 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager):
if category: if category:
query.append(("tag__db_category", category)) query.append(("tag__db_category", category))
return Tag.objects.filter( return Tag.objects.filter(
pk__in=self.model.db_tags.through.objects.filter(**dict(query)).values_list( pk__in=self.model.db_tags.through.objects.filter(
"tag_id", flat=True **dict(query)
) ).values_list("tag_id", flat=True)
) )
def get_permission(self, key=None, category=None, obj=None): def get_permission(self, key=None, category=None, obj=None):
@ -279,7 +299,7 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager):
if not _Tag: if not _Tag:
from evennia.typeclasses.models import Tag as _Tag from evennia.typeclasses.models import Tag as _Tag
match = kwargs.get("match", "all").lower().strip() anymatch = "any" == kwargs.get("match", "all").lower().strip()
keys = make_iter(key) if key else [] keys = make_iter(key) if key else []
categories = make_iter(category) if category else [] categories = make_iter(category) if category else []
@ -290,7 +310,9 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager):
dbmodel = self.model.__dbclass__.__name__.lower() dbmodel = self.model.__dbclass__.__name__.lower()
query = ( query = (
self.filter(db_tags__db_tagtype__iexact=tagtype, db_tags__db_model__iexact=dbmodel) self.filter(
db_tags__db_tagtype__iexact=tagtype, db_tags__db_model__iexact=dbmodel
)
.distinct() .distinct()
.order_by("id") .order_by("id")
) )
@ -309,28 +331,30 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager):
) )
clauses = Q() clauses = Q()
for ikey, key in enumerate(keys): for ikey, key in enumerate(keys):
# Keep each key and category together, grouped by AND # ANY mode; must match any one of the given tags/categories
clauses |= Q(db_key__iexact=key, db_category__iexact=categories[ikey]) clauses |= Q(
db_key__iexact=key, db_category__iexact=categories[ikey]
)
else: else:
# only one or more categories given # only one or more categories given
# import evennia;evennia.set_trace()
clauses = Q() clauses = Q()
# ANY mode; must match any one of them
for category in unique_categories: for category in unique_categories:
clauses |= Q(db_category__iexact=category) clauses |= Q(db_category__iexact=category)
tags = _Tag.objects.filter(clauses) tags = _Tag.objects.filter(clauses)
query = query.filter(db_tags__in=tags).annotate( query = query.filter(db_tags__in=tags).annotate(
matches=Count("db_tags__pk", filter=Q(db_tags__in=tags), distinct=True) matches=Count("db_tags__pk", filter=Q(db_tags__in=tags),
distinct=True)
) )
# Default ALL: Match all of the tags and optionally more if anymatch:
if match == "all":
n_req_tags = tags.count() if n_keys > 0 else n_unique_categories
query = query.filter(matches__gte=n_req_tags)
# ANY: Match any single tag, ordered by weight # ANY: Match any single tag, ordered by weight
elif match == "any":
query = query.order_by("-matches") query = query.order_by("-matches")
else:
# Default ALL: Match all of the tags and optionally more
n_req_tags = n_keys if n_keys > 0 else n_unique_categories
query = query.filter(matches__gte=n_req_tags)
return query return query
@ -388,7 +412,9 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager):
# try to get old tag # try to get old tag
dbmodel = self.model.__dbclass__.__name__.lower() dbmodel = self.model.__dbclass__.__name__.lower()
tag = self.get_tag(key=key, category=category, tagtype=tagtype, global_search=True) tag = self.get_tag(
key=key, category=category, tagtype=tagtype, global_search=True
)
if tag and data is not None: if tag and data is not None:
# get tag from list returned by get_tag # get tag from list returned by get_tag
tag = tag[0] tag = tag[0]
@ -402,7 +428,9 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager):
from evennia.typeclasses.models import Tag as _Tag from evennia.typeclasses.models import Tag as _Tag
tag = _Tag.objects.create( tag = _Tag.objects.create(
db_key=key.strip().lower() if key is not None else None, db_key=key.strip().lower() if key is not None else None,
db_category=category.strip().lower() if category and key is not None else None, db_category=category.strip().lower()
if category and key is not None
else None,
db_data=data, db_data=data,
db_model=dbmodel, db_model=dbmodel,
db_tagtype=tagtype.strip().lower() if tagtype is not None else None, db_tagtype=tagtype.strip().lower() if tagtype is not None else None,
@ -511,7 +539,8 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager):
typeclass=F("db_typeclass_path"), typeclass=F("db_typeclass_path"),
# Calculate this class' percentage of total composition # Calculate this class' percentage of total composition
percent=ExpressionWrapper( percent=ExpressionWrapper(
((F("count") / float(self.count())) * 100.0), output_field=FloatField() ((F("count") / float(self.count())) * 100.0),
output_field=FloatField(),
), ),
) )
.values("typeclass", "count", "percent") .values("typeclass", "count", "percent")
@ -531,7 +560,9 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager):
stats = self.get_typeclass_totals().order_by("typeclass") stats = self.get_typeclass_totals().order_by("typeclass")
return {x.get("typeclass"): x.get("count") for x in stats} return {x.get("typeclass"): x.get("count") for x in stats}
def typeclass_search(self, typeclass, include_children=False, include_parents=False): def typeclass_search(
self, typeclass, include_children=False, include_parents=False
):
""" """
Searches through all objects returning those which has a Searches through all objects returning those which has a
certain typeclass. If location is set, limit search to objects certain typeclass. If location is set, limit search to objects
@ -806,7 +837,8 @@ class TypeclassManager(TypedObjectManager):
""" """
paths = [self.model.path] + [ paths = [self.model.path] + [
"%s.%s" % (cls.__module__, cls.__name__) for cls in self._get_subclasses(self.model) "%s.%s" % (cls.__module__, cls.__name__)
for cls in self._get_subclasses(self.model)
] ]
kwargs.update({"db_typeclass_path__in": paths}) kwargs.update({"db_typeclass_path__in": paths})
return super().get(**kwargs) return super().get(**kwargs)
@ -828,7 +860,8 @@ class TypeclassManager(TypedObjectManager):
""" """
# query, including all subclasses # query, including all subclasses
paths = [self.model.path] + [ paths = [self.model.path] + [
"%s.%s" % (cls.__module__, cls.__name__) for cls in self._get_subclasses(self.model) "%s.%s" % (cls.__module__, cls.__name__)
for cls in self._get_subclasses(self.model)
] ]
kwargs.update({"db_typeclass_path__in": paths}) kwargs.update({"db_typeclass_path__in": paths})
return super().filter(*args, **kwargs) return super().filter(*args, **kwargs)
@ -843,6 +876,7 @@ class TypeclassManager(TypedObjectManager):
""" """
paths = [self.model.path] + [ paths = [self.model.path] + [
"%s.%s" % (cls.__module__, cls.__name__) for cls in self._get_subclasses(self.model) "%s.%s" % (cls.__module__, cls.__name__)
for cls in self._get_subclasses(self.model)
] ]
return super().all().filter(db_typeclass_path__in=paths) return super().all().filter(db_typeclass_path__in=paths)

View file

@ -58,12 +58,16 @@ class TestTypedObjectManager(EvenniaTest):
self.obj2.tags.add("tag4") self.obj2.tags.add("tag4")
self.obj2.tags.add("tag2c") self.obj2.tags.add("tag2c")
self.assertEqual(self._manager("get_by_tag", "tag1"), [self.obj1]) self.assertEqual(self._manager("get_by_tag", "tag1"), [self.obj1])
self.assertEqual(set(self._manager("get_by_tag", "tag2")), set([self.obj1, self.obj2])) self.assertEqual(
set(self._manager("get_by_tag", "tag2")), set([self.obj1, self.obj2])
)
self.assertEqual(self._manager("get_by_tag", "tag2a"), [self.obj2]) self.assertEqual(self._manager("get_by_tag", "tag2a"), [self.obj2])
self.assertEqual(self._manager("get_by_tag", "tag3 with spaces"), [self.obj2]) self.assertEqual(self._manager("get_by_tag", "tag3 with spaces"), [self.obj2])
self.assertEqual(self._manager("get_by_tag", ["tag2a", "tag2b"]), [self.obj2]) self.assertEqual(self._manager("get_by_tag", ["tag2a", "tag2b"]), [self.obj2])
self.assertEqual(self._manager("get_by_tag", ["tag2a", "tag1"]), []) self.assertEqual(self._manager("get_by_tag", ["tag2a", "tag1"]), [])
self.assertEqual(self._manager("get_by_tag", ["tag2a", "tag4", "tag2c"]), [self.obj2]) self.assertEqual(
self._manager("get_by_tag", ["tag2a", "tag4", "tag2c"]), [self.obj2]
)
def test_get_by_tag_and_category(self): def test_get_by_tag_and_category(self):
self.obj1.tags.add("tag5", "category1") self.obj1.tags.add("tag5", "category1")
@ -79,24 +83,57 @@ class TestTypedObjectManager(EvenniaTest):
self.obj1.tags.add("tag8", "category6") self.obj1.tags.add("tag8", "category6")
self.obj2.tags.add("tag9", "category6") self.obj2.tags.add("tag9", "category6")
self.assertEqual(self._manager("get_by_tag", "tag5", "category1"), [self.obj1, self.obj2]) self.assertEqual(
self._manager("get_by_tag", "tag5", "category1"), [self.obj1, self.obj2]
)
self.assertEqual(self._manager("get_by_tag", "tag6", "category1"), []) self.assertEqual(self._manager("get_by_tag", "tag6", "category1"), [])
self.assertEqual(self._manager("get_by_tag", "tag6", "category3"), [self.obj1, self.obj2]) self.assertEqual(
self._manager("get_by_tag", "tag6", "category3"), [self.obj1, self.obj2]
)
self.assertEqual( self.assertEqual(
self._manager("get_by_tag", ["tag5", "tag6"], ["category1", "category3"]), self._manager("get_by_tag", ["tag5", "tag6"], ["category1", "category3"]),
[self.obj1, self.obj2], [self.obj1, self.obj2],
) )
self.assertEqual( self.assertEqual(
self._manager("get_by_tag", ["tag5", "tag7"], "category1"), [self.obj1, self.obj2] self._manager("get_by_tag", ["tag5", "tag7"], "category1"),
[self.obj1, self.obj2],
)
self.assertEqual(
self._manager("get_by_tag", category="category1"), [self.obj1, self.obj2]
) )
self.assertEqual(self._manager("get_by_tag", category="category1"), [self.obj1, self.obj2])
self.assertEqual(self._manager("get_by_tag", category="category2"), [self.obj2]) self.assertEqual(self._manager("get_by_tag", category="category2"), [self.obj2])
self.assertEqual( self.assertEqual(
self._manager("get_by_tag", category=["category1", "category3"]), [self.obj1, self.obj2] self._manager("get_by_tag", category=["category1", "category3"]),
[self.obj1, self.obj2],
) )
self.assertEqual( self.assertEqual(
self._manager("get_by_tag", category=["category1", "category2"]), [self.obj1, self.obj2] self._manager("get_by_tag", category=["category1", "category2"]),
[self.obj1, self.obj2],
)
self.assertEqual(
self._manager("get_by_tag", category=["category5", "category4"]), []
)
self.assertEqual(
self._manager("get_by_tag", category="category1"), [self.obj1, self.obj2]
)
self.assertEqual(
self._manager("get_by_tag", category="category6"), [self.obj1, self.obj2]
)
def test_get_tag_with_all(self):
self.obj1.tags.add("tagA", "categoryA")
self.assertEqual(
self._manager(
"get_by_tag", ["tagA", "tagB"], ["categoryA", "categoryB"], match="all"
),
[],
)
def test_get_tag_with_any(self):
self.obj1.tags.add("tagA", "categoryA")
self.assertEqual(
self._manager(
"get_by_tag", ["tagA", "tagB"], ["categoryA", "categoryB"], match="any"
),
[self.obj1],
) )
self.assertEqual(self._manager("get_by_tag", category=["category5", "category4"]), [])
self.assertEqual(self._manager("get_by_tag", category="category1"), [self.obj1, self.obj2])
self.assertEqual(self._manager("get_by_tag", category="category6"), [self.obj1, self.obj2])