Use new Msgpack library for unpacking

This commit is contained in:
shortcutme 2019-03-16 02:30:54 +01:00
parent 20806a8c97
commit edd3f35790
No known key found for this signature in database
GPG key ID: 5B63BAE6CB9613AE

View file

@ -2,8 +2,6 @@ import socket
import time import time
import gevent import gevent
import msgpack
import msgpack.fallback
try: try:
from gevent.coros import RLock from gevent.coros import RLock
except: except:
@ -11,17 +9,17 @@ except:
from Config import config from Config import config
from Debug import Debug from Debug import Debug
from util import StreamingMsgpack from util import Msgpack
from Crypt import CryptConnection from Crypt import CryptConnection
from util import helper from util import helper
class Connection(object): class Connection(object):
__slots__ = ( __slots__ = (
"sock", "sock_wrapped", "ip", "port", "cert_pin", "target_onion", "id", "protocol", "type", "server", "unpacker", "req_id", "ip_type", "sock", "sock_wrapped", "ip", "port", "cert_pin", "target_onion", "id", "protocol", "type", "server", "unpacker", "unpacker_bytes", "req_id", "ip_type",
"handshake", "crypt", "connected", "event_connected", "closed", "start_time", "handshake_time", "last_recv_time", "is_private_ip", "is_tracker_connection", "handshake", "crypt", "connected", "event_connected", "closed", "start_time", "handshake_time", "last_recv_time", "is_private_ip", "is_tracker_connection",
"last_message_time", "last_send_time", "last_sent_time", "incomplete_buff_recv", "bytes_recv", "bytes_sent", "cpu_time", "send_lock", "last_message_time", "last_send_time", "last_sent_time", "incomplete_buff_recv", "bytes_recv", "bytes_sent", "cpu_time", "send_lock",
"last_ping_delay", "last_req_time", "last_cmd_sent", "last_cmd_recv", "bad_actions", "sites", "name", "updateName", "waiting_requests", "waiting_streams" "last_ping_delay", "last_req_time", "last_cmd_sent", "last_cmd_recv", "bad_actions", "sites", "name", "waiting_requests", "waiting_streams"
) )
def __init__(self, server, ip, port, sock=None, target_onion=None, is_tracker_connection=False): def __init__(self, server, ip, port, sock=None, target_onion=None, is_tracker_connection=False):
@ -46,6 +44,7 @@ class Connection(object):
self.server = server self.server = server
self.unpacker = None # Stream incoming socket messages here self.unpacker = None # Stream incoming socket messages here
self.unpacker_bytes = 0 # How many bytes the unpacker received
self.req_id = 0 # Last request id self.req_id = 0 # Last request id
self.handshake = {} # Handshake info got from peer self.handshake = {} # Handshake info got from peer
self.crypt = None # Connection encryption method self.crypt = None # Connection encryption method
@ -102,7 +101,7 @@ class Connection(object):
return "<%s>" % self.__str__() return "<%s>" % self.__str__()
def log(self, text): def log(self, text):
self.server.log.debug("%s > %s" % (self.name, text.decode("utf8", "ignore"))) self.server.log.debug("%s > %s" % (self.name, text))
def getValidSites(self): def getValidSites(self):
return [key for key, val in self.server.tor_manager.site_onions.items() if val == self.target_onion] return [key for key, val in self.server.tor_manager.site_onions.items() if val == self.target_onion]
@ -162,7 +161,7 @@ class Connection(object):
self.sock.do_handshake() self.sock.do_handshake()
self.crypt = "tls-rsa" self.crypt = "tls-rsa"
self.sock_wrapped = True self.sock_wrapped = True
except Exception, err: except Exception as err:
if not config.force_encryption: if not config.force_encryption:
self.log("Crypt connection error: %s, adding ip %s as broken ssl." % (err, self.ip)) self.log("Crypt connection error: %s, adding ip %s as broken ssl." % (err, self.ip))
self.server.broken_ssl_ips[self.ip] = True self.server.broken_ssl_ips[self.ip] = True
@ -194,10 +193,16 @@ class Connection(object):
self.sock = CryptConnection.manager.wrapSocket(self.sock, "tls-rsa", True) self.sock = CryptConnection.manager.wrapSocket(self.sock, "tls-rsa", True)
self.sock_wrapped = True self.sock_wrapped = True
self.crypt = "tls-rsa" self.crypt = "tls-rsa"
except Exception, err: except Exception as err:
self.log("Socket peek error: %s" % Debug.formatException(err)) self.log("Socket peek error: %s" % Debug.formatException(err))
self.messageLoop() self.messageLoop()
def getMsgpackUnpacker(self):
if self.handshake and self.handshake.get("use_bin_type"):
return Msgpack.getUnpacker(fallback=True, decode=False)
else: # Backward compatibility for <0.7.0
return Msgpack.getUnpacker(fallback=True, decode=True)
# Message loop for connection # Message loop for connection
def messageLoop(self): def messageLoop(self):
if not self.sock: if not self.sock:
@ -208,7 +213,7 @@ class Connection(object):
self.connected = True self.connected = True
buff_len = 0 buff_len = 0
req_len = 0 req_len = 0
unpacker_bytes = 0 self.unpacker_bytes = 0
try: try:
while not self.closed: while not self.closed:
@ -225,15 +230,15 @@ class Connection(object):
req_len += buff_len req_len += buff_len
if not self.unpacker: if not self.unpacker:
self.unpacker = msgpack.fallback.Unpacker() self.unpacker = self.getMsgpackUnpacker()
unpacker_bytes = 0 self.unpacker_bytes = 0
self.unpacker.feed(buff) self.unpacker.feed(buff)
unpacker_bytes += buff_len self.unpacker_bytes += buff_len
while True: while True:
try: try:
message = self.unpacker.next() message = next(self.unpacker)
except StopIteration: except StopIteration:
break break
if not type(message) is dict: if not type(message) is dict:
@ -257,10 +262,10 @@ class Connection(object):
# Handle message # Handle message
if "stream_bytes" in message: if "stream_bytes" in message:
buff_left = self.handleStream(message, self.unpacker, buff, unpacker_bytes) buff_left = self.handleStream(message, buff)
self.unpacker = msgpack.fallback.Unpacker() self.unpacker = self.getMsgpackUnpacker()
self.unpacker.feed(buff_left) self.unpacker.feed(buff_left)
unpacker_bytes = len(buff_left) self.unpacker_bytes = len(buff_left)
if config.debug_socket: if config.debug_socket:
self.log("Start new unpacker with buff_left: %r" % buff_left) self.log("Start new unpacker with buff_left: %r" % buff_left)
else: else:
@ -274,19 +279,23 @@ class Connection(object):
self.server.stat_recv["error: %s" % err]["num"] += 1 self.server.stat_recv["error: %s" % err]["num"] += 1
self.close("MessageLoop ended (closed: %s)" % self.closed) # MessageLoop ended, close connection self.close("MessageLoop ended (closed: %s)" % self.closed) # MessageLoop ended, close connection
def getUnpackerUnprocessedBytesNum(self):
if "tell" in dir(self.unpacker):
bytes_num = self.unpacker_bytes - self.unpacker.tell()
else:
bytes_num = self.unpacker._fb_buf_n - self.unpacker._fb_buf_o
return bytes_num
# Stream socket directly to a file # Stream socket directly to a file
def handleStream(self, message, unpacker, buff, unpacker_bytes): def handleStream(self, message, buff):
stream_bytes_left = message["stream_bytes"] stream_bytes_left = message["stream_bytes"]
file = self.waiting_streams[message["to"]] file = self.waiting_streams[message["to"]]
if "tell" in dir(unpacker): unprocessed_bytes_num = self.getUnpackerUnprocessedBytesNum()
unpacker_unprocessed_bytes = unpacker_bytes - unpacker.tell()
else:
unpacker_unprocessed_bytes = unpacker._fb_buf_n - unpacker._fb_buf_o
if unpacker_unprocessed_bytes: # Found stream bytes in unpacker if unprocessed_bytes_num: # Found stream bytes in unpacker
unpacker_stream_bytes = min(unpacker_unprocessed_bytes, stream_bytes_left) unpacker_stream_bytes = min(unprocessed_bytes_num, stream_bytes_left)
buff_stream_start = len(buff) - unpacker_unprocessed_bytes buff_stream_start = len(buff) - unprocessed_bytes_num
file.write(buff[buff_stream_start:buff_stream_start + unpacker_stream_bytes]) file.write(buff[buff_stream_start:buff_stream_start + unpacker_stream_bytes])
stream_bytes_left -= unpacker_stream_bytes stream_bytes_left -= unpacker_stream_bytes
else: else:
@ -295,7 +304,7 @@ class Connection(object):
if config.debug_socket: if config.debug_socket:
self.log( self.log(
"Starting stream %s: %s bytes (%s from unpacker, buff size: %s, unprocessed: %s)" % "Starting stream %s: %s bytes (%s from unpacker, buff size: %s, unprocessed: %s)" %
(message["to"], message["stream_bytes"], unpacker_stream_bytes, len(buff), unpacker_unprocessed_bytes) (message["to"], message["stream_bytes"], unpacker_stream_bytes, len(buff), unprocessed_bytes_num)
) )
try: try:
@ -351,6 +360,7 @@ class Connection(object):
handshake = { handshake = {
"version": config.version, "version": config.version,
"protocol": "v2", "protocol": "v2",
"use_bin_type": True,
"peer_id": peer_id, "peer_id": peer_id,
"fileserver_port": self.server.port, "fileserver_port": self.server.port,
"port_opened": self.server.port_opened.get(self.ip_type, None), "port_opened": self.server.port_opened.get(self.ip_type, None),
@ -390,6 +400,9 @@ class Connection(object):
# Check if we can encrypt the connection # Check if we can encrypt the connection
if handshake.get("crypt_supported") and self.ip not in self.server.broken_ssl_ips: if handshake.get("crypt_supported") and self.ip not in self.server.broken_ssl_ips:
if type(handshake["crypt_supported"][0]) is bytes:
handshake["crypt_supported"] = [item.decode() for item in handshake["crypt_supported"]] # Backward compatibility
if self.ip_type == "onion" or self.ip in config.ip_local: if self.ip_type == "onion" or self.ip in config.ip_local:
crypt = None crypt = None
elif handshake.get("crypt"): # Recommended crypt by server elif handshake.get("crypt"): # Recommended crypt by server
@ -513,13 +526,13 @@ class Connection(object):
self.server.stat_sent[stat_key]["num"] += 1 self.server.stat_sent[stat_key]["num"] += 1
if streaming: if streaming:
with self.send_lock: with self.send_lock:
bytes_sent = StreamingMsgpack.stream(message, self.sock.sendall) bytes_sent = Msgpack.stream(message, self.sock.sendall)
self.bytes_sent += bytes_sent self.bytes_sent += bytes_sent
self.server.bytes_sent += bytes_sent self.server.bytes_sent += bytes_sent
self.server.stat_sent[stat_key]["bytes"] += bytes_sent self.server.stat_sent[stat_key]["bytes"] += bytes_sent
message = None message = None
else: else:
data = msgpack.packb(message) data = Msgpack.pack(message)
self.bytes_sent += len(data) self.bytes_sent += len(data)
self.server.bytes_sent += len(data) self.server.bytes_sent += len(data)
self.server.stat_sent[stat_key]["bytes"] += len(data) self.server.stat_sent[stat_key]["bytes"] += len(data)