aboutsummaryrefslogtreecommitdiffstats
path: root/termcast_server/termcast.py
diff options
context:
space:
mode:
Diffstat (limited to 'termcast_server/termcast.py')
-rw-r--r--termcast_server/termcast.py97
1 files changed, 75 insertions, 22 deletions
diff --git a/termcast_server/termcast.py b/termcast_server/termcast.py
index f99971d..a2318d6 100644
--- a/termcast_server/termcast.py
+++ b/termcast_server/termcast.py
@@ -1,6 +1,7 @@
import time
import json
import re
+import ssl
import vt100
@@ -132,27 +133,26 @@ class Handler(object):
return ret
class Connection(object):
- def __init__(self, client, connection_id, publisher):
+ def __init__(self, client, connection_id, publisher, pemfile):
self.client = client
self.connection_id = connection_id
self.publisher = publisher
+ self.pemfile = pemfile
self.viewers = 0
+ self.context = None
def run(self):
- buf = b''
- while len(buf) < 1024 and b"\n" not in buf:
- buf += self.client.recv(1024)
-
- pos = buf.find(b"\n")
- if pos == -1:
+ auth = self._readline()
+ if auth is None:
print("no authentication found")
return
+ print(auth)
- auth = buf[:pos]
- if auth[-1:] == b"\r":
- auth = auth[:-1]
-
- buf = buf[pos+1:]
+ if auth == b"starttls":
+ if not self._starttls():
+ print("TLS connection failed")
+ return
+ auth = self._readline()
m = auth_re.match(auth)
if m is None:
@@ -163,16 +163,7 @@ class Connection(object):
self.name = m.group(1)
self.client.send(b"hello, " + self.name + b"\n")
- extra_data = {}
- m = extra_data_re.match(buf)
- if m is not None:
- try:
- extra_data_json = m.group(1).decode('utf-8')
- extra_data = json.loads(extra_data_json)
- except Exception as e:
- print("failed to parse metadata: %s" % e, file=sys.stderr)
- pass
- buf = buf[len(m.group(0)):]
+ extra_data, buf = self._try_read_metadata()
if "geometry" in extra_data:
self.handler = Handler(
@@ -234,3 +225,65 @@ class Connection(object):
"total_time": self.handler.total_time(),
"viewers": self.viewers,
}
+
+ def _readline(self):
+ buf = b''
+ while len(buf) < 1024 and b"\n" not in buf:
+ buf += self.client.recv(1)
+
+ pos = buf.find(b"\n")
+ if pos == -1:
+ return
+
+ line = buf[:pos]
+ if line[-1:] == b"\r":
+ line = line[:-1]
+
+ return line
+
+ def _starttls(self):
+ if self.context is None:
+ self.context = ssl.create_default_context(
+ purpose=ssl.Purpose.CLIENT_AUTH
+ )
+ self.context.load_cert_chain(certfile=self.pemfile)
+ try:
+ self.client = self.context.wrap_socket(
+ self.client, server_side=True
+ )
+ except Exception as e:
+ print('*** TLS connection failed: ' + str(e))
+ return False
+
+ return True
+
+ def _try_read_metadata(self):
+ buf = b''
+ while len(buf) < 6:
+ more = self.client.recv(6 - len(buf))
+ if len(more) > 0:
+ buf += more
+ else:
+ return {}, buf
+
+ if buf != b'\033]499;':
+ return {}, buf
+
+ while len(buf) < 4096 and b"\007" not in buf:
+ buf += self.client.recv(1)
+
+ if b"\007" not in buf:
+ return {}, buf
+
+ extra_data = {}
+ m = extra_data_re.match(buf)
+ if m is not None:
+ try:
+ extra_data_json = m.group(1).decode('utf-8')
+ extra_data = json.loads(extra_data_json)
+ except Exception as e:
+ print("failed to parse metadata: %s" % e, file=sys.stderr)
+ pass
+ buf = buf[len(m.group(0)):]
+
+ return extra_data, buf