From f2b2b81ca7227b3f9cbee2b7e9e5f6ff4e2867e8 Mon Sep 17 00:00:00 2001 From: Jesse Luehrs Date: Thu, 16 Oct 2014 13:07:32 -0400 Subject: add ssl support --- termcast_server/__init__.py | 9 +++-- termcast_server/termcast.py | 97 +++++++++++++++++++++++++++++++++++---------- termcast_server/web.py | 3 +- 3 files changed, 82 insertions(+), 27 deletions(-) diff --git a/termcast_server/__init__.py b/termcast_server/__init__.py index 35ff278..e570d7e 100644 --- a/termcast_server/__init__.py +++ b/termcast_server/__init__.py @@ -10,9 +10,10 @@ from . import termcast from . import web class Server(object): - def __init__(self, keyfile): + def __init__(self, keyfile, pemfile): self.publisher = pubsub.Publisher() self.keyfile = keyfile + self.pemfile = pemfile def listen(self): ssh_sock = self._open_socket(2200) @@ -44,7 +45,7 @@ class Server(object): def wait_for_web_connection(self, sock): sock.setblocking(0) sock.listen(100) - web.start_server(sock, self.publisher) + web.start_server(sock, self.publisher, self.pemfile) def handle_ssh_connection(self, client): self._handle_connection( @@ -58,7 +59,7 @@ class Server(object): self._handle_connection( client, lambda client, connection_id: termcast.Connection( - client, connection_id, self.publisher + client, connection_id, self.publisher, self.pemfile ) ) @@ -87,5 +88,5 @@ class Server(object): return sock def main(): - server = Server(sys.argv[1]) + server = Server(sys.argv[1], sys.argv[2]) server.listen() diff --git a/termcast_server/termcast.py b/termcast_server/termcast.py index f99971d..a2318d6 100644 --- a/termcast_server/termcast.py +++ b/termcast_server/termcast.py @@ -1,6 +1,7 @@ import time import json import re +import ssl import vt100 @@ -132,27 +133,26 @@ class Handler(object): return ret class Connection(object): - def __init__(self, client, connection_id, publisher): + def __init__(self, client, connection_id, publisher, pemfile): self.client = client self.connection_id = connection_id self.publisher = publisher + self.pemfile = pemfile self.viewers = 0 + self.context = None def run(self): - buf = b'' - while len(buf) < 1024 and b"\n" not in buf: - buf += self.client.recv(1024) - - pos = buf.find(b"\n") - if pos == -1: + auth = self._readline() + if auth is None: print("no authentication found") return + print(auth) - auth = buf[:pos] - if auth[-1:] == b"\r": - auth = auth[:-1] - - buf = buf[pos+1:] + if auth == b"starttls": + if not self._starttls(): + print("TLS connection failed") + return + auth = self._readline() m = auth_re.match(auth) if m is None: @@ -163,16 +163,7 @@ class Connection(object): self.name = m.group(1) self.client.send(b"hello, " + self.name + b"\n") - extra_data = {} - m = extra_data_re.match(buf) - if m is not None: - try: - extra_data_json = m.group(1).decode('utf-8') - extra_data = json.loads(extra_data_json) - except Exception as e: - print("failed to parse metadata: %s" % e, file=sys.stderr) - pass - buf = buf[len(m.group(0)):] + extra_data, buf = self._try_read_metadata() if "geometry" in extra_data: self.handler = Handler( @@ -234,3 +225,65 @@ class Connection(object): "total_time": self.handler.total_time(), "viewers": self.viewers, } + + def _readline(self): + buf = b'' + while len(buf) < 1024 and b"\n" not in buf: + buf += self.client.recv(1) + + pos = buf.find(b"\n") + if pos == -1: + return + + line = buf[:pos] + if line[-1:] == b"\r": + line = line[:-1] + + return line + + def _starttls(self): + if self.context is None: + self.context = ssl.create_default_context( + purpose=ssl.Purpose.CLIENT_AUTH + ) + self.context.load_cert_chain(certfile=self.pemfile) + try: + self.client = self.context.wrap_socket( + self.client, server_side=True + ) + except Exception as e: + print('*** TLS connection failed: ' + str(e)) + return False + + return True + + def _try_read_metadata(self): + buf = b'' + while len(buf) < 6: + more = self.client.recv(6 - len(buf)) + if len(more) > 0: + buf += more + else: + return {}, buf + + if buf != b'\033]499;': + return {}, buf + + while len(buf) < 4096 and b"\007" not in buf: + buf += self.client.recv(1) + + if b"\007" not in buf: + return {}, buf + + extra_data = {} + m = extra_data_re.match(buf) + if m is not None: + try: + extra_data_json = m.group(1).decode('utf-8') + extra_data = json.loads(extra_data_json) + except Exception as e: + print("failed to parse metadata: %s" % e, file=sys.stderr) + pass + buf = buf[len(m.group(0)):] + + return extra_data, buf diff --git a/termcast_server/web.py b/termcast_server/web.py index 1c64fd4..57ddf14 100644 --- a/termcast_server/web.py +++ b/termcast_server/web.py @@ -60,7 +60,8 @@ def make_app(publisher): ('/-/', WebSocketHandler, dict(publisher=publisher)), ]) -def start_server(sock, publisher): +def start_server(sock, publisher, pemfile): + # XXX set up ssl with pemfile server = tornado.httpserver.HTTPServer(make_app(publisher)) server.add_socket(sock) tornado.ioloop.IOLoop.instance().start() -- cgit v1.2.3-54-g00ecf