Fix filtering
This commit is contained in:
parent
c83567c95e
commit
ec570a17cd
2 changed files with 7 additions and 1 deletions
|
|
@ -8,7 +8,7 @@ https://django-filter.readthedocs.io/en/latest/guide/rest_framework.html
|
||||||
"""
|
"""
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
from django_filters.rest_framework.filterset import FilterSet
|
from django_filters.rest_framework.filterset import FilterSet
|
||||||
from django_filters.filters import CharFilter
|
from django_filters.filters import CharFilter, EMPTY_VALUES
|
||||||
|
|
||||||
from evennia.objects.models import ObjectDB
|
from evennia.objects.models import ObjectDB
|
||||||
from evennia.accounts.models import AccountDB
|
from evennia.accounts.models import AccountDB
|
||||||
|
|
@ -22,6 +22,10 @@ class TagTypeFilter(CharFilter):
|
||||||
tag_type = None
|
tag_type = None
|
||||||
|
|
||||||
def filter(self, qs, value):
|
def filter(self, qs, value):
|
||||||
|
# if no value is specified, we don't use the filter
|
||||||
|
if value in EMPTY_VALUES:
|
||||||
|
return qs
|
||||||
|
# if they enter a value, we filter objects by having a tag of this type with the given name
|
||||||
return qs.filter(Q(db_tags__db_tagtype=self.tag_type) & Q(db_tags__db_key=value))
|
return qs.filter(Q(db_tags__db_tagtype=self.tag_type) & Q(db_tags__db_key=value))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ urlpatterns = [
|
||||||
)
|
)
|
||||||
class TestEvenniaRESTApi(EvenniaTest):
|
class TestEvenniaRESTApi(EvenniaTest):
|
||||||
client_class = APIClient
|
client_class = APIClient
|
||||||
|
maxDiff = None
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|
@ -96,6 +97,7 @@ class TestEvenniaRESTApi(EvenniaTest):
|
||||||
with self.subTest(msg=f"Testing {view.view_name} "):
|
with self.subTest(msg=f"Testing {view.view_name} "):
|
||||||
view_url = reverse(f"api:{view.view_name}")
|
view_url = reverse(f"api:{view.view_name}")
|
||||||
response = self.client.get(view_url)
|
response = self.client.get(view_url)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
self.assertCountEqual(response.data['results'], [view.serializer(obj).data for obj in view.list])
|
self.assertCountEqual(response.data['results'], [view.serializer(obj).data for obj in view.list])
|
||||||
|
|
||||||
def test_create(self):
|
def test_create(self):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue