Add extra unit test to test #3271

This commit is contained in:
Griatch 2023-09-23 23:25:37 +02:00
parent 0f9d2beb09
commit 1add10bcb0

View file

@ -3,21 +3,20 @@ Unit tests for the scripts package
""" """
from unittest import TestCase, mock
from collections import defaultdict from collections import defaultdict
from unittest import TestCase, mock
from parameterized import parameterized
from evennia import DefaultScript from evennia import DefaultScript
from evennia.objects.objects import DefaultObject from evennia.objects.objects import DefaultObject
from evennia.scripts.models import ObjectDoesNotExist, ScriptDB
from evennia.scripts.scripts import DoNothing, ExtendedLoopingCall
from evennia.utils.create import create_script
from evennia.utils.test_resources import BaseEvenniaTest
from evennia.scripts.tickerhandler import TickerHandler
from evennia.scripts.monitorhandler import MonitorHandler
from evennia.scripts.manager import ScriptDBManager from evennia.scripts.manager import ScriptDBManager
from evennia.scripts.models import ObjectDoesNotExist, ScriptDB
from evennia.scripts.monitorhandler import MonitorHandler
from evennia.scripts.scripts import DoNothing, ExtendedLoopingCall
from evennia.scripts.tickerhandler import TickerHandler
from evennia.utils.create import create_script
from evennia.utils.dbserialize import dbserialize from evennia.utils.dbserialize import dbserialize
from evennia.utils.test_resources import BaseEvenniaTest
from parameterized import parameterized
class TestScript(BaseEvenniaTest): class TestScript(BaseEvenniaTest):
@ -29,34 +28,38 @@ class TestScript(BaseEvenniaTest):
self.assertFalse(errors, errors) self.assertFalse(errors, errors)
mockinit.assert_called() mockinit.assert_called()
class TestTickerHandler(TestCase): class TestTickerHandler(TestCase):
""" Test the TickerHandler class """ """Test the TickerHandler class"""
def test_store_key_raises_RunTimeError(self): def test_store_key_raises_RunTimeError(self):
""" Test _store_key method raises RuntimeError for interval < 1 """ """Test _store_key method raises RuntimeError for interval < 1"""
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
th=TickerHandler() th = TickerHandler()
th._store_key(None, None, 0, None) th._store_key(None, None, 0, None)
def test_remove_raises_RunTimeError(self): def test_remove_raises_RunTimeError(self):
""" Test remove method raises RuntimeError for catching old ordering of arguments """ """Test remove method raises RuntimeError for catching old ordering of arguments"""
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
th=TickerHandler() th = TickerHandler()
th.remove(callback=1) th.remove(callback=1)
class TestScriptDBManager(TestCase): class TestScriptDBManager(TestCase):
""" Test the ScriptDBManger class """ """Test the ScriptDBManger class"""
def test_not_obj_return_empty_list(self): def test_not_obj_return_empty_list(self):
""" Test get_all_scripts_on_obj returns empty list for falsy object """ """Test get_all_scripts_on_obj returns empty list for falsy object"""
manager_obj = ScriptDBManager() manager_obj = ScriptDBManager()
returned_list = manager_obj.get_all_scripts_on_obj(False) returned_list = manager_obj.get_all_scripts_on_obj(False)
self.assertEqual(returned_list, []) self.assertEqual(returned_list, [])
class TestingListIntervalScript(DefaultScript): class TestingListIntervalScript(DefaultScript):
""" """
A script that does nothing. Used to test listing of script with nonzero intervals. A script that does nothing. Used to test listing of script with nonzero intervals.
""" """
def at_script_creation(self): def at_script_creation(self):
""" """
Setup the script Setup the script
@ -66,11 +69,13 @@ class TestingListIntervalScript(DefaultScript):
self.interval = 1 self.interval = 1
self.repeats = 1 self.repeats = 1
class TestScriptHandler(BaseEvenniaTest): class TestScriptHandler(BaseEvenniaTest):
""" """
Test the ScriptHandler class. Test the ScriptHandler class.
""" """
def setUp(self): def setUp(self):
self.obj, self.errors = DefaultObject.create("test_object") self.obj, self.errors = DefaultObject.create("test_object")
@ -82,7 +87,7 @@ class TestScriptHandler(BaseEvenniaTest):
self.obj.scripts.add(TestingListIntervalScript) self.obj.scripts.add(TestingListIntervalScript)
self.num = self.obj.scripts.start(self.obj.scripts.all()[0].key) self.num = self.obj.scripts.start(self.obj.scripts.all()[0].key)
self.assertTrue(self.num == 1) self.assertTrue(self.num == 1)
def test_list_script_intervals(self): def test_list_script_intervals(self):
"Checks that Scripthandler __str__ function lists script intervals correctly" "Checks that Scripthandler __str__ function lists script intervals correctly"
self.obj.scripts.add(TestingListIntervalScript) self.obj.scripts.add(TestingListIntervalScript)
@ -90,6 +95,13 @@ class TestScriptHandler(BaseEvenniaTest):
self.assertTrue("None/1" in self.str) self.assertTrue("None/1" in self.str)
self.assertTrue("1 repeats" in self.str) self.assertTrue("1 repeats" in self.str)
def test_get_script(self):
"Checks that Scripthandler get function returns correct script"
self.obj.scripts.add(TestingListIntervalScript)
script = self.obj.scripts.get("interval_test")
self.assertTrue(bool(script))
class TestScriptDB(TestCase): class TestScriptDB(TestCase):
"Check the singleton/static ScriptDB object works correctly" "Check the singleton/static ScriptDB object works correctly"
@ -161,14 +173,14 @@ class TestExtendedLoopingCall(TestCase):
loopcall._scheduleFrom.assert_called_with(121) loopcall._scheduleFrom.assert_called_with(121)
def test_start_invalid_interval(self): def test_start_invalid_interval(self):
""" Test the .start method with interval less than zero """ """Test the .start method with interval less than zero"""
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
callback = mock.MagicMock() callback = mock.MagicMock()
loopcall = ExtendedLoopingCall(callback) loopcall = ExtendedLoopingCall(callback)
loopcall.start(-1, now=True, start_delay=None, count_start=1) loopcall.start(-1, now=True, start_delay=None, count_start=1)
def test__call__when_delay(self): def test__call__when_delay(self):
""" Test __call__ modifies start_delay and starttime if start_delay was previously set """ """Test __call__ modifies start_delay and starttime if start_delay was previously set"""
callback = mock.MagicMock() callback = mock.MagicMock()
loopcall = ExtendedLoopingCall(callback) loopcall = ExtendedLoopingCall(callback)
loopcall.clock.seconds = mock.MagicMock(return_value=1) loopcall.clock.seconds = mock.MagicMock(return_value=1)
@ -176,12 +188,12 @@ class TestExtendedLoopingCall(TestCase):
loopcall.starttime = 0 loopcall.starttime = 0
loopcall() loopcall()
self.assertEqual(loopcall.start_delay, None) self.assertEqual(loopcall.start_delay, None)
self.assertEqual(loopcall.starttime, 1) self.assertEqual(loopcall.starttime, 1)
def test_force_repeat(self): def test_force_repeat(self):
""" Test forcing script to run that is scheduled to run in the future """ """Test forcing script to run that is scheduled to run in the future"""
callback = mock.MagicMock() callback = mock.MagicMock()
loopcall = ExtendedLoopingCall(callback) loopcall = ExtendedLoopingCall(callback)
loopcall.clock.seconds = mock.MagicMock(return_value=0) loopcall.clock.seconds = mock.MagicMock(return_value=0)
@ -192,10 +204,12 @@ class TestExtendedLoopingCall(TestCase):
callback.assert_called_once() callback.assert_called_once()
def dummy_func(): def dummy_func():
""" Dummy function used as callback parameter """ """Dummy function used as callback parameter"""
return 0 return 0
class TestMonitorHandler(TestCase): class TestMonitorHandler(TestCase):
""" """
Test the MonitorHandler class. Test the MonitorHandler class.
@ -220,13 +234,13 @@ class TestMonitorHandler(TestCase):
def test_remove(self): def test_remove(self):
"""Tests that removing an object from the monitor handler works correctly""" """Tests that removing an object from the monitor handler works correctly"""
obj = mock.Mock() obj = mock.Mock()
fieldname = 'db_remove' fieldname = "db_remove"
callback = dummy_func callback = dummy_func
idstring = 'test_remove' idstring = "test_remove"
"""Add an object to the monitor handler and then remove it""" """Add an object to the monitor handler and then remove it"""
self.handler.add(obj,fieldname,callback,idstring=idstring) self.handler.add(obj, fieldname, callback, idstring=idstring)
self.handler.remove(obj,fieldname,idstring=idstring) self.handler.remove(obj, fieldname, idstring=idstring)
self.assertEquals(self.handler.monitors[obj][fieldname], {}) self.assertEquals(self.handler.monitors[obj][fieldname], {})
def test_add_with_invalid_function(self): def test_add_with_invalid_function(self):
@ -234,25 +248,29 @@ class TestMonitorHandler(TestCase):
"""Tests that add method rejects objects where callback is not a function""" """Tests that add method rejects objects where callback is not a function"""
fieldname = "db_key" fieldname = "db_key"
callback = "not_a_function" callback = "not_a_function"
self.handler.add(obj, fieldname, callback) self.handler.add(obj, fieldname, callback)
self.assertNotIn(fieldname, self.handler.monitors[obj]) self.assertNotIn(fieldname, self.handler.monitors[obj])
def test_all(self): def test_all(self):
"""Tests that all method correctly returns information about added objects""" """Tests that all method correctly returns information about added objects"""
obj = [mock.Mock(),mock.Mock()] obj = [mock.Mock(), mock.Mock()]
fieldname = ["db_all1","db_all2"] fieldname = ["db_all1", "db_all2"]
callback = dummy_func callback = dummy_func
idstring = ["test_all1","test_all2"] idstring = ["test_all1", "test_all2"]
self.handler.add(obj[0], fieldname[0], callback, idstring=idstring[0]) self.handler.add(obj[0], fieldname[0], callback, idstring=idstring[0])
self.handler.add(obj[1], fieldname[1], callback, idstring=idstring[1],persistent=True) self.handler.add(obj[1], fieldname[1], callback, idstring=idstring[1], persistent=True)
output = self.handler.all() output = self.handler.all()
self.assertEquals(output, self.assertEquals(
[(obj[0], fieldname[0], idstring[0], False, {}), output,
(obj[1], fieldname[1], idstring[1], True, {})]) [
(obj[0], fieldname[0], idstring[0], False, {}),
(obj[1], fieldname[1], idstring[1], True, {}),
],
)
def test_clear(self): def test_clear(self):
"""Tests that the clear function correctly clears the monitor handler""" """Tests that the clear function correctly clears the monitor handler"""
obj = mock.Mock() obj = mock.Mock()
@ -277,7 +295,7 @@ class TestMonitorHandler(TestCase):
category = "testattribute" category = "testattribute"
"""Add attribute to handler and assert that it has been added""" """Add attribute to handler and assert that it has been added"""
self.handler.add(obj, fieldname, callback, idstring=idstring,category=category) self.handler.add(obj, fieldname, callback, idstring=idstring, category=category)
index = obj.attributes.get(fieldname, return_obj=True) index = obj.attributes.get(fieldname, return_obj=True)
name = "db_value[testattribute]" name = "db_value[testattribute]"
@ -287,5 +305,5 @@ class TestMonitorHandler(TestCase):
self.assertEqual(self.handler.monitors[index][name][idstring], (callback, False, {})) self.assertEqual(self.handler.monitors[index][name][idstring], (callback, False, {}))
"""Remove attribute from the handler and assert that it is gone""" """Remove attribute from the handler and assert that it is gone"""
self.handler.remove(obj,fieldname,idstring=idstring,category=category) self.handler.remove(obj, fieldname, idstring=idstring, category=category)
self.assertEquals(self.handler.monitors[index][name], {}) self.assertEquals(self.handler.monitors[index][name], {})