Converting the AMP, not working yet

This commit is contained in:
Griatch 2018-10-04 21:46:16 +02:00
parent b4cc3d0ac2
commit f9369f2784
2 changed files with 26 additions and 14 deletions

View file

@ -9,7 +9,7 @@ from functools import wraps
import time import time
from twisted.protocols import amp from twisted.protocols import amp
from collections import defaultdict, namedtuple from collections import defaultdict, namedtuple
from io import StringIO from io import StringIO, BytesIO
from itertools import count from itertools import count
import zlib # Used in Compressed class import zlib # Used in Compressed class
import pickle import pickle
@ -41,8 +41,8 @@ SSHUTD = chr(17) # server shutdown
PSTATUS = chr(18) # ping server or portal status PSTATUS = chr(18) # ping server or portal status
SRESET = chr(19) # server shutdown in reset mode SRESET = chr(19) # server shutdown in reset mode
NUL = b'\0' NUL = b'\x00'
NULNUL = '\0\0' NULNUL = b'\x00\x00'
AMP_MAXLEN = amp.MAX_VALUE_LENGTH # max allowed data length in AMP protocol (cannot be changed) AMP_MAXLEN = amp.MAX_VALUE_LENGTH # max allowed data length in AMP protocol (cannot be changed)
@ -55,7 +55,7 @@ _MSGBUFFER = defaultdict(list)
DUMMYSESSION = namedtuple('DummySession', ['sessid'])(0) DUMMYSESSION = namedtuple('DummySession', ['sessid'])(0)
_HTTP_WARNING = """ _HTTP_WARNING = bytes("""
HTTP/1.1 200 OK HTTP/1.1 200 OK
Content-Type: text/html Content-Type: text/html
@ -67,7 +67,7 @@ Content-Type: text/html
<h3>This port should NOT be publicly visible.</h3> <h3>This port should NOT be publicly visible.</h3>
</p> </p>
</body> </body>
</html>""".strip() </html>""".strip(), 'utf-8')
# Helper functions for pickling. # Helper functions for pickling.
@ -113,7 +113,8 @@ class Compressed(amp.String):
put it back together here. put it back together here.
""" """
value = StringIO()
value = BytesIO()
value.write(self.fromStringProto(strings.get(name), proto)) value.write(self.fromStringProto(strings.get(name), proto))
for counter in count(2): for counter in count(2):
# count from 2 upwards # count from 2 upwards
@ -121,7 +122,7 @@ class Compressed(amp.String):
if chunk is None: if chunk is None:
break break
value.write(self.fromStringProto(chunk, proto)) value.write(self.fromStringProto(chunk, proto))
objects[name] = value.getvalue() objects[str(name, 'utf-8')] = value.getvalue()
def toBox(self, name, strings, objects, proto): def toBox(self, name, strings, objects, proto):
""" """
@ -129,8 +130,14 @@ class Compressed(amp.String):
we break up too-long data snippets into multiple batches here. we break up too-long data snippets into multiple batches here.
""" """
value = StringIO(objects[name])
# print("toBox: name={}, strings={}, objects={}, proto{}".format(name, strings, objects, proto))
value = BytesIO(objects[str(name, 'utf-8')])
strings[name] = self.toStringProto(value.read(AMP_MAXLEN), proto) strings[name] = self.toStringProto(value.read(AMP_MAXLEN), proto)
# print("toBox strings[name] = {}".format(strings[name]))
for counter in count(2): for counter in count(2):
chunk = value.read(AMP_MAXLEN) chunk = value.read(AMP_MAXLEN)
if not chunk: if not chunk:
@ -140,12 +147,16 @@ class Compressed(amp.String):
def toString(self, inObject): def toString(self, inObject):
""" """
Convert to send as a string on the wire, with compression. Convert to send as a string on the wire, with compression.
Note: In Py3 this is really a byte stream.
""" """
return zlib.compress(super(Compressed, self).toString(inObject), 9) return zlib.compress(super(Compressed, self).toString(inObject), 9)
def fromString(self, inString): def fromString(self, inString):
""" """
Convert (decompress) from the string-representation on the wire to Python. Convert (decompress) from the string-representation on the wire to Python.
""" """
return super(Compressed, self).fromString(zlib.decompress(inString)) return super(Compressed, self).fromString(zlib.decompress(inString))
@ -167,7 +178,7 @@ class MsgPortal2Server(amp.Command):
Message Portal -> Server Message Portal -> Server
""" """
key = "MsgPortal2Server" key = b"MsgPortal2Server"
arguments = [(b'packed_data', Compressed())] arguments = [(b'packed_data', Compressed())]
errors = {Exception: b'EXCEPTION'} errors = {Exception: b'EXCEPTION'}
response = [] response = []
@ -271,7 +282,7 @@ class AMPMultiConnectionProtocol(amp.AMP):
""" """
Handle non-AMP messages, such as HTTP communication. Handle non-AMP messages, such as HTTP communication.
""" """
if data[0] == NUL: if data[:1] == NUL:
# an AMP communication # an AMP communication
if data[-2:] != NULNUL: if data[-2:] != NULNUL:
# an incomplete AMP box means more batches are forthcoming. # an incomplete AMP box means more batches are forthcoming.
@ -287,7 +298,7 @@ class AMPMultiConnectionProtocol(amp.AMP):
# not an AMP communication, return warning # not an AMP communication, return warning
self.transport.write(_HTTP_WARNING) self.transport.write(_HTTP_WARNING)
self.transport.loseConnection() self.transport.loseConnection()
print("HTML received: %s" % data) print("HTTP received: %s" % data)
def makeConnection(self, transport): def makeConnection(self, transport):
""" """
@ -348,9 +359,9 @@ class AMPMultiConnectionProtocol(amp.AMP):
Process incoming packed data. Process incoming packed data.
Args: Args:
packed_data (bytes): Zip-packed data. packed_data (bytes): Pickled data.
Returns: Returns:
unpaced_data (any): Unpacked package unpaced_data (any): Unpickled package
""" """
return loads(packed_data) return loads(packed_data)

View file

@ -194,7 +194,8 @@ class SharedMemoryModelBase(ModelBase):
# exclude some models that should not auto-create wrapper fields # exclude some models that should not auto-create wrapper fields
if cls.__name__ in ("ServerConfig", "TypeNick"): if cls.__name__ in ("ServerConfig", "TypeNick"):
return return
# dynamically create the wrapper properties for all fields not already handled (manytomanyfields are always handlers) # dynamically create the wrapper properties for all fields not already handled
# (manytomanyfields are always handlers)
for fieldname, field in ((fname, field) for fname, field in listitems(attrs) for fieldname, field in ((fname, field) for fname, field in listitems(attrs)
if fname.startswith("db_") and type(field).__name__ != "ManyToManyField"): if fname.startswith("db_") and type(field).__name__ != "ManyToManyField"):
foreignkey = type(field).__name__ == "ForeignKey" foreignkey = type(field).__name__ == "ForeignKey"