261 lines
8.3 KiB
Python
261 lines
8.3 KiB
Python
from gevent.pywsgi import WSGIHandler, _InvalidClientInput
|
|
from gevent.queue import Queue
|
|
import gevent
|
|
import hashlib
|
|
import base64
|
|
import struct
|
|
import socket
|
|
import time
|
|
import sys
|
|
|
|
|
|
# Outgoing data messages are split into fragments of at most this many bytes.
SEND_PACKET_SIZE = 1300

# Frame opcodes (RFC 6455 section 5.2).
OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2
OPCODE_CLOSE = 0x8
OPCODE_PING = 0x9
OPCODE_PONG = 0xA

# Close status codes (RFC 6455 section 7.4.1).
STATUS_OK = 1000
STATUS_PROTOCOL_ERROR = 1002
STATUS_DATA_ERROR = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_TOO_LONG = 1009
|
|
|
|
|
|
class WebSocket:
    """Server side of a single WebSocket connection (RFC 6455).

    A background greenlet (``_listen``) reads frames from ``socket``,
    reassembles fragmented messages and puts each complete message on an
    internal queue that ``receive`` / ``receive_nowait`` read from.  Text
    messages are delivered as ``str``, binary messages as ``bytes``.  Pings
    are answered automatically; a close frame from the peer is answered and
    ends the stream.
    """

    def __init__(self, socket):
        """Wrap an already-upgraded socket and start the reader greenlet.

        :param socket: connected socket whose HTTP handshake has completed.
        """
        self.socket = socket
        # True once either side started closing; send() refuses to run after that.
        self.closed = False
        # Status code from the peer's close frame, if it sent one.
        self.status = None
        # Exception that ended the reader greenlet (EOFError for a normal end).
        self._receive_error = None
        # Ensures the close frame is written and the socket closed only once.
        self._socket_closed = False
        # Complete messages (str/bytes); a None item marks end of stream.
        self._queue = Queue()
        # Upper bound, in bytes, for one reassembled incoming message.
        self.max_length = 10 * 1024 * 1024

        gevent.spawn(self._listen)

    def set_max_message_length(self, length):
        """Set the largest incoming message (in bytes) that will be accepted."""
        self.max_length = length

    def _listen(self):
        """Reader loop, run in its own greenlet.

        Queues every complete message until the connection ends.  The
        exception that ends it is stored in ``_receive_error`` and a ``None``
        sentinel is queued so a blocked ``receive`` wakes up and reports it.
        """
        try:
            while True:
                self._queue.put(self._read_message())
        except Exception as e:
            self.closed = True
            self._receive_error = e
            self._queue.put(None)  # To make sure the error is read

    def _read_message(self):
        """Read frames until one complete data message has been assembled.

        :returns: ``str`` for a text message, ``bytes`` for a binary one.
        :raises EOFError: on close, protocol violation, oversize message or
            invalid UTF-8 in a text message (after sending a close frame).
        """
        message = bytearray()
        start_opcode = None
        fin = False
        while not fin:
            payload, opcode, fin = self._get_frame(max_length=self.max_length - len(message))
            if start_opcode is None:
                # The first frame has to open a text or binary message.
                if opcode not in (OPCODE_TEXT, OPCODE_BINARY):
                    self._error(STATUS_PROTOCOL_ERROR)
                start_opcode = opcode
            elif opcode != 0:
                # Every later fragment must be a continuation frame (opcode 0).
                self._error(STATUS_PROTOCOL_ERROR)
            message += payload

        if start_opcode == OPCODE_TEXT:
            try:
                return bytes(message).decode("utf-8")
            except UnicodeDecodeError:
                self._error(STATUS_DATA_ERROR)
        return bytes(message)

    def receive(self):
        """Block until the next message is available and return it.

        :returns: the message (``str`` or ``bytes``), or ``None`` once the
            connection has ended normally.
        :raises Exception: the error that ended the connection when it was
            not a normal end (``EOFError``).
        """
        return self._unwrap(self._queue.get())

    def receive_nowait(self):
        """Like ``receive`` but never blocks.

        :raises gevent.queue.Empty: (from ``Queue.get_nowait``) when no message
            is queued yet.
        """
        return self._unwrap(self._queue.get_nowait())

    def _unwrap(self, item):
        """Turn an item taken from the queue into ``receive``'s result."""
        if item is None:
            # End-of-stream sentinel: put it back so every later or concurrent
            # receive call also observes the end instead of blocking forever.
            self._queue.put(None)
            error = self._receive_error
            if error is not None and not isinstance(error, EOFError):
                raise error
        # Messages queued before a failure are still delivered in order.
        return item

    def send(self, data):
        """Send one message: ``str`` as text, bytes-like data as binary.

        :raises EOFError: if the connection is closed or closing.
        :raises TypeError: for any other payload type.
        """
        if self.closed:
            raise EOFError()
        if isinstance(data, str):
            self._send_frame(OPCODE_TEXT, data.encode())
        elif isinstance(data, (bytes, bytearray, memoryview)):
            self._send_frame(OPCODE_BINARY, bytes(data))
        else:
            raise TypeError("Expected str or bytes, got " + repr(type(data)))

    # Reads a frame from the socket. Pings, pongs and close packets are handled
    # automatically
    def _get_frame(self, max_length):
        """Return the next data/continuation frame as ``(payload, opcode, fin)``.

        Pings are answered with a pong carrying the same payload, pongs are
        ignored, and a close frame records the peer's status code, is answered
        with a close frame (unless we were already closing) and ends the stream.

        :raises EOFError: when a close frame arrives or the frame is invalid.
        """
        while True:
            payload, opcode, fin = self._read_frame(max_length=max_length)
            if opcode == OPCODE_PING:
                self._send_frame(OPCODE_PONG, payload)
            elif opcode == OPCODE_PONG:
                continue
            elif opcode == OPCODE_CLOSE:
                if len(payload) >= 2:
                    self.status = struct.unpack("!H", payload[:2])[0]
                was_closed = self.closed
                self.closed = True
                if not was_closed:
                    # Send a close frame in response
                    self.close(STATUS_OK)
                raise EOFError()
            else:
                return payload, opcode, fin

    # Low-level function, use _get_frame instead
    def _read_frame(self, max_length):
        """Read and unmask one raw frame of any type.

        :param max_length: largest payload accepted for a data frame; control
            frames are always capped at 125 bytes.
        :returns: ``(payload, opcode, fin)``.
        :raises EOFError: for an unmasked frame, non-zero RSV bits, a
            fragmented or oversized control frame, an oversized data frame, or
            a connection dropped mid-frame.
        """
        header = self._recv_exactly(2)

        # Clients must mask every frame they send (RFC 6455 section 5.1).
        if not header[1] & 0x80:
            self._error(STATUS_POLICY_VIOLATION)

        # No extension is negotiated during the handshake, so RSV1-3 must be 0.
        if header[0] & 0x70:
            self._error(STATUS_PROTOCOL_ERROR)

        fin = bool(header[0] & 0x80)
        opcode = header[0] & 0x0F

        payload_length = header[1] & 0x7F
        if payload_length == 126:
            payload_length = struct.unpack("!H", self._recv_exactly(2))[0]
        elif payload_length == 127:
            payload_length = struct.unpack("!Q", self._recv_exactly(8))[0]

        # Control frames (opcode >= 8: close, ping, pong) must not be fragmented
        # and carry at most 125 bytes, independent of the message budget.
        if opcode & 0x08:
            if not fin:
                self._error(STATUS_PROTOCOL_ERROR)
            max_length = 125

        if payload_length > max_length:
            self._error(STATUS_TOO_LONG)

        mask = self._recv_exactly(4)
        payload = self._recv_exactly(payload_length)
        return self._unmask(payload, mask), opcode, fin

    def _recv_exactly(self, length):
        """Read exactly ``length`` bytes from the socket.

        :raises EOFError: if the peer closes the connection first.
        """
        buf = bytearray()
        while len(buf) < length:
            block = self.socket.recv(min(4096, length - len(buf)))
            if block == b"":
                raise EOFError()
            buf += block
        return bytes(buf)

    def _unmask(self, payload, mask):
        """XOR ``payload`` with the repeating 4-byte ``mask``."""
        unmasked = bytearray(payload)
        for offset, key in enumerate(mask):
            # Byte i of the payload is XORed with mask[i % 4]; translate() does
            # one whole residue class at a time.
            table = bytes(value ^ key for value in range(256))
            unmasked[offset::4] = unmasked[offset::4].translate(table)
        return bytes(unmasked)

    @staticmethod
    def _frame_header(first_byte, length):
        """Encode an unmasked frame header for a payload of ``length`` bytes."""
        if length < 126:
            return bytes([first_byte, length])
        if length < 0x10000:
            return bytes([first_byte, 126]) + struct.pack("!H", length)
        return bytes([first_byte, 127]) + struct.pack("!Q", length)

    def _send_frame(self, opcode, data):
        """Send ``data`` as one message split into SEND_PACKET_SIZE fragments.

        The first fragment carries ``opcode``, later ones are continuation
        frames (opcode 0), and only the last one has FIN set.  An empty payload
        is still sent as one empty frame (e.g. the pong for an empty ping).
        Server frames are not masked.
        """
        chunks = [
            data[start:start + SEND_PACKET_SIZE]
            for start in range(0, len(data), SEND_PACKET_SIZE)
        ] or [b""]
        last = len(chunks) - 1
        # NOTE(review): each fragment is a separate sendall(), and the reader
        # greenlet also sends pongs/close frames through this method, so
        # concurrent senders may interleave if sendall yields; confirm whether
        # callers need to serialise sends.
        for index, chunk in enumerate(chunks):
            first_byte = (opcode if index == 0 else 0) | (0x80 if index == last else 0)
            self.socket.sendall(self._frame_header(first_byte, len(chunk)) + chunk)

    def _error(self, status):
        """Close the connection with ``status`` and abort reading.

        :raises EOFError: always.
        """
        self.close(status)
        raise EOFError()

    def close(self, status=STATUS_OK):
        """Send a close frame carrying ``status`` and close the socket.

        Safe to call more than once: only the first call writes the close frame
        and closes the socket.  The socket is closed even when writing the
        close frame fails; that failure still propagates to the caller.
        """
        self.closed = True
        if self._socket_closed:
            return
        self._socket_closed = True
        try:
            self._send_frame(OPCODE_CLOSE, struct.pack("!H", status))
        finally:
            self.socket.close()
|
|
|
|
|
class WebSocketHandler(WSGIHandler):
    """gevent WSGI handler that performs the WebSocket upgrade handshake.

    For an upgrade request, the handshake is answered with
    ``101 Switching Protocols``.  A ``WebSocket`` wrapping the connection socket
    is stored in ``environ["wsgi.websocket"]`` before the application runs.
    Every other request goes through the normal ``WSGIHandler`` logic.
    """

    # GUID appended to the client key to derive Sec-WebSocket-Accept (RFC 6455).
    _ACCEPT_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

    def handle_one_response(self):
        """Serve one request, upgrading it to a WebSocket when asked to."""
        self.time_start = time.time()
        self.status = None
        self.headers_sent = False

        self.result = None
        self.response_use_chunked = False
        self.response_length = 0

        http_connection = [s.strip().lower() for s in self.environ.get("HTTP_CONNECTION", "").split(",")]
        # The Upgrade value is an ASCII case-insensitive token (RFC 6455 section 4.2.1).
        http_upgrade = self.environ.get("HTTP_UPGRADE", "").strip().lower()
        if "upgrade" not in http_connection or http_upgrade != "websocket":
            # Not my problem
            return super(WebSocketHandler, self).handle_one_response()

        if "HTTP_SEC_WEBSOCKET_KEY" not in self.environ:
            # start_response() alone writes nothing (the 101 path below has to
            # call the returned writer), so send a complete 400 response
            # instead of leaving the client waiting.
            self._send_error_response_if_possible(400)
            return

        # Generate Sec-Websocket-Accept header
        key = self.environ["HTTP_SEC_WEBSOCKET_KEY"].encode()
        accept = base64.b64encode(hashlib.sha1(key + self._ACCEPT_GUID).digest()).decode()

        # Accept
        self.start_response("101 Switching Protocols", [
            ("Upgrade", "websocket"),
            ("Connection", "Upgrade"),
            ("Sec-Websocket-Accept", accept),
        ])(b"")

        self.environ["wsgi.websocket"] = WebSocket(self.socket)

        # Can't call super because it sets invalid flags like "status"
        try:
            try:
                self.run_application()
            finally:
                try:
                    self.wsgi_input._discard()
                except (socket.error, IOError):
                    pass
        except _InvalidClientInput:
            self._send_error_response_if_possible(400)
        except socket.error as ex:
            if ex.args[0] in self.ignored_socket_errors:
                self.close_connection = True
            else:
                self.handle_error(*sys.exc_info())
        except:  # pylint:disable=bare-except
            self.handle_error(*sys.exc_info())
        finally:
            # The connection now carries WebSocket frames, not HTTP: never try
            # to parse another HTTP request from it.
            self.close_connection = True
            self.time_finish = time.time()
            self.log_request()

    def process_result(self):
        """Write the application's result, except for upgraded connections.

        When ``wsgi.websocket`` is present, the application talked to the
        client through the WebSocket, so its return value is not written.
        """
        if "wsgi.websocket" not in self.environ:
            super(WebSocketHandler, self).process_result()