From f2b2b81ca7227b3f9cbee2b7e9e5f6ff4e2867e8 Mon Sep 17 00:00:00 2001 From: Jesse Luehrs Date: Thu, 16 Oct 2014 13:07:32 -0400 Subject: add ssl support --- termcast_server/termcast.py | 97 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 75 insertions(+), 22 deletions(-) (limited to 'termcast_server/termcast.py') diff --git a/termcast_server/termcast.py b/termcast_server/termcast.py index f99971d..a2318d6 100644 --- a/termcast_server/termcast.py +++ b/termcast_server/termcast.py @@ -1,6 +1,7 @@ import time import json import re +import ssl import vt100 @@ -132,27 +133,26 @@ class Handler(object): return ret class Connection(object): - def __init__(self, client, connection_id, publisher): + def __init__(self, client, connection_id, publisher, pemfile): self.client = client self.connection_id = connection_id self.publisher = publisher + self.pemfile = pemfile self.viewers = 0 + self.context = None def run(self): - buf = b'' - while len(buf) < 1024 and b"\n" not in buf: - buf += self.client.recv(1024) - - pos = buf.find(b"\n") - if pos == -1: + auth = self._readline() + if auth is None: print("no authentication found") return + print(auth) - auth = buf[:pos] - if auth[-1:] == b"\r": - auth = auth[:-1] - - buf = buf[pos+1:] + if auth == b"starttls": + if not self._starttls(): + print("TLS connection failed") + return + auth = self._readline() m = auth_re.match(auth) if m is None: @@ -163,16 +163,7 @@ class Connection(object): self.name = m.group(1) self.client.send(b"hello, " + self.name + b"\n") - extra_data = {} - m = extra_data_re.match(buf) - if m is not None: - try: - extra_data_json = m.group(1).decode('utf-8') - extra_data = json.loads(extra_data_json) - except Exception as e: - print("failed to parse metadata: %s" % e, file=sys.stderr) - pass - buf = buf[len(m.group(0)):] + extra_data, buf = self._try_read_metadata() if "geometry" in extra_data: self.handler = Handler( @@ -234,3 +225,65 @@ class Connection(object): "total_time": self.handler.total_time(), "viewers": self.viewers, } + + def _readline(self): + buf = b'' + while len(buf) < 1024 and b"\n" not in buf: + buf += self.client.recv(1) + + pos = buf.find(b"\n") + if pos == -1: + return + + line = buf[:pos] + if line[-1:] == b"\r": + line = line[:-1] + + return line + + def _starttls(self): + if self.context is None: + self.context = ssl.create_default_context( + purpose=ssl.Purpose.CLIENT_AUTH + ) + self.context.load_cert_chain(certfile=self.pemfile) + try: + self.client = self.context.wrap_socket( + self.client, server_side=True + ) + except Exception as e: + print('*** TLS connection failed: ' + str(e)) + return False + + return True + + def _try_read_metadata(self): + buf = b'' + while len(buf) < 6: + more = self.client.recv(6 - len(buf)) + if len(more) > 0: + buf += more + else: + return {}, buf + + if buf != b'\033]499;': + return {}, buf + + while len(buf) < 4096 and b"\007" not in buf: + buf += self.client.recv(1) + + if b"\007" not in buf: + return {}, buf + + extra_data = {} + m = extra_data_re.match(buf) + if m is not None: + try: + extra_data_json = m.group(1).decode('utf-8') + extra_data = json.loads(extra_data_json) + except Exception as e: + print("failed to parse metadata: %s" % e, file=sys.stderr) + pass + buf = buf[len(m.group(0)):] + + return extra_data, buf -- cgit v1.2.3-54-g00ecf