aboutsummaryrefslogtreecommitdiffstats
path: root/termcast_server
diff options
context:
space:
mode:
authorJesse Luehrs <doy@tozt.net>2014-09-23 14:01:46 -0400
committerJesse Luehrs <doy@tozt.net>2014-09-23 14:01:46 -0400
commit70cd632270083acc2bb835069d6e00f07a728cb3 (patch)
tree37298ef15ceed88a147ec668ce6977a9f1a5b942 /termcast_server
parentf44c44324c24aa667bfd56bb7b9ee2e96ac93bd4 (diff)
downloadpython-termcast-server-70cd632270083acc2bb835069d6e00f07a728cb3.tar.gz
python-termcast-server-70cd632270083acc2bb835069d6e00f07a728cb3.zip
reorganize
Diffstat (limited to 'termcast_server')
-rw-r--r--termcast_server/__init__.py78
-rw-r--r--termcast_server/__main__.py5
-rw-r--r--termcast_server/pubsub.py31
-rw-r--r--termcast_server/ssh.py206
-rw-r--r--termcast_server/termcast.py198
5 files changed, 518 insertions, 0 deletions
diff --git a/termcast_server/__init__.py b/termcast_server/__init__.py
new file mode 100644
index 0000000..3a8dfee
--- /dev/null
+++ b/termcast_server/__init__.py
@@ -0,0 +1,78 @@
+import signal
+import socket
+import sys
+import threading
+import uuid
+
+from . import pubsub
+from . import ssh
+from . import termcast
+
+class Server(object):
+ def __init__(self, keyfile):
+ self.publisher = pubsub.Publisher()
+ self.keyfile = keyfile
+
+ def listen(self):
+ ssh_sock = self._open_socket(2200)
+ termcast_sock = self._open_socket(2201)
+
+ threading.Thread(
+ target=lambda: self.wait_for_ssh_connection(ssh_sock)
+ ).start()
+ threading.Thread(
+ target=lambda: self.wait_for_termcast_connection(termcast_sock)
+ ).start()
+
+ def wait_for_ssh_connection(self, sock):
+ self._wait_for_connection(
+ sock,
+ lambda client: self.handle_ssh_connection(client)
+ )
+
+ def wait_for_termcast_connection(self, sock):
+ self._wait_for_connection(
+ sock,
+ lambda client: self.handle_termcast_connection(client)
+ )
+
+ def handle_ssh_connection(self, client):
+ self._handle_connection(
+ client,
+ lambda client, connection_id: ssh.Connection(
+ client, connection_id, self.publisher, self.keyfile
+ )
+ )
+
+ def handle_termcast_connection(self, client):
+ self._handle_connection(
+ client,
+ lambda client, connection_id: termcast.Connection(
+ client, connection_id, self.publisher
+ )
+ )
+
+ def _wait_for_connection(self, sock, cb):
+ while True:
+ try:
+ sock.listen(100)
+ client, addr = sock.accept()
+ except Exception as e:
+ print('*** Listen/accept failed: ' + str(e))
+ traceback.print_exc()
+ continue
+
+ threading.Thread(target=cb, args=(client,)).start()
+
+ def _handle_connection(self, client, cb):
+ connection_id = uuid.uuid4().hex
+ connection = cb(client, connection_id)
+ self.publisher.subscribe(connection)
+ connection.run()
+ self.publisher.unsubscribe(connection)
+
+ def _open_socket(self, port):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.bind(('', port))
+ return sock
diff --git a/termcast_server/__main__.py b/termcast_server/__main__.py
new file mode 100644
index 0000000..d842d0a
--- /dev/null
+++ b/termcast_server/__main__.py
@@ -0,0 +1,5 @@
+import sys
+import termcast_server
+
+server = termcast_server.Server(sys.argv[1])
+server.listen()
diff --git a/termcast_server/pubsub.py b/termcast_server/pubsub.py
new file mode 100644
index 0000000..b5faf22
--- /dev/null
+++ b/termcast_server/pubsub.py
@@ -0,0 +1,31 @@
+class Publisher(object):
+ def __init__(self):
+ self.subscribers = []
+
+ def subscribe(self, who):
+ if who not in self.subscribers:
+ self.subscribers.append(who)
+
+ def unsubscribe(self, who):
+ if who in self.subscribers:
+ self.subscribers.remove(who)
+
+ def request_all(self, message, *args):
+ ret = []
+ for subscriber in self.subscribers:
+ method = "request_" + message
+ if hasattr(subscriber, method):
+ ret.append(getattr(subscriber, method)(*args))
+ return ret
+
+ def request_one(self, message, *args):
+ for subscriber in self.subscribers:
+ method = "request_" + message
+ if hasattr(subscriber, method):
+ return getattr(subscriber, method)(*args)
+
+ def notify(self, message, *args):
+ for subscriber in self.subscribers:
+ method = "msg_" + message
+ if hasattr(subscriber, method):
+ getattr(subscriber, method)(*args)
diff --git a/termcast_server/ssh.py b/termcast_server/ssh.py
new file mode 100644
index 0000000..6256f22
--- /dev/null
+++ b/termcast_server/ssh.py
@@ -0,0 +1,206 @@
+import multiprocessing
+import paramiko
+import select
+import threading
+import time
+
+class Connection(object):
+ def __init__(self, client, connection_id, publisher, keyfile):
+ self.transport = paramiko.Transport(client)
+
+ key = None
+ with open(keyfile) as f:
+ header = f.readline()
+ if header == "-----BEGIN DSA PRIVATE KEY-----\n":
+ key = paramiko.DSSKey(filename=keyfile)
+ elif header == "-----BEGIN RSA PRIVATE KEY-----\n":
+ key = paramiko.RSAKey(filename=keyfile)
+ if key is None:
+ raise Exception("%s doesn't appear to be an SSH keyfile" % keyfile)
+ self.transport.add_server_key(key)
+
+ self.connection_id = connection_id
+ self.publisher = publisher
+ self.initialized = False
+ self.watching_id = None
+
+ self.rpipe, self.wpipe = multiprocessing.Pipe(False)
+
+ def run(self):
+ self.server = Server()
+ self.transport.start_server(server=self.server)
+ self.chan = self.transport.accept(10)
+
+ if self.chan is not None:
+ self.server.pty_event.wait()
+
+ while True:
+ self.initialized = False
+ self.watching_id = None
+
+ streamer = self.select_stream()
+ if streamer is None:
+ break
+ self.watching_id = streamer["id"]
+
+ print(
+ "new viewer watching %s (%s)" % (
+ streamer["name"], streamer["id"]
+ )
+ )
+ self._send_all(
+ "\033[1;%d;1;%dr\033[m\033[H\033[2J" % (
+ streamer["rows"], streamer["cols"]
+ )
+ )
+ self.publisher.notify("new_viewer", self.watching_id)
+
+ while True:
+ rout, wout, eout = select.select(
+ [self.chan, self.rpipe],
+ [],
+ []
+ )
+ if self.chan in rout:
+ c = self.chan.recv(1)
+ if c == b'q':
+ print(
+ "viewer stopped watching %s (%s)" % (
+ streamer["name"], streamer["id"]
+ )
+ )
+ self._cleanup_watcher()
+ break
+
+ if self.rpipe in rout:
+ self._cleanup_watcher()
+ break
+
+ if self.chan is not None:
+ self.chan.close()
+ self.transport.close()
+
+ def select_stream(self):
+ key_code = ord('a')
+ keymap = {}
+ streamers = self.publisher.request_all("get_streamers")
+ # XXX this will require pagination
+ for streamer in streamers:
+ key = chr(key_code)
+ if key == "q":
+ key_code += 1
+ key = chr(key_code)
+ streamer["key"] = key
+ keymap[key] = streamer
+ key_code += 1
+
+ self._display_streamer_screen(streamers)
+
+ c = self.chan.recv(1).decode('utf-8', 'ignore')
+ if c in keymap:
+ self._send_all("\033[2J\033[H")
+ return keymap[c]
+ elif c == 'q':
+ self._send_all("\r\n")
+ return None
+ else:
+ return self.select_stream()
+
+ def msg_new_data(self, connection_id, prev_buf, data):
+ if self.watching_id != connection_id:
+ return
+
+ if not self.initialized:
+ self._send_all(prev_buf)
+ self.initialized = True
+
+ self._send_all(data)
+
+ def msg_streamer_disconnect(self, connection_id):
+ if self.watching_id != connection_id:
+ return
+
+ self.wpipe.send("q")
+
+ def _send_all(self, data):
+ total_sent = 0
+ while total_sent < len(data):
+ total_sent += self.chan.send(data[total_sent:])
+
+ def _display_streamer_screen(self, streamers):
+ self._send_all("\033[H\033[2JWelcome to Termcast!")
+ self._send_all(
+ "\033[3H %-20s %-15s %-10s %-12s %-15s" % (
+ "User", "Terminal size", "Viewers", "Idle time", "Total time"
+ )
+ )
+ row = 4
+ for streamer in streamers:
+ key = streamer["key"]
+ name = streamer["name"].decode('utf-8', 'replace')
+ rows = streamer["rows"]
+ cols = streamer["cols"]
+ viewers = streamer["viewers"]
+ idle = streamer["idle_time"]
+ total = streamer["total_time"]
+ size = "(%dx%d)" % (cols, rows)
+ size_pre = ""
+ size_post = ""
+ if cols > self.server.cols or rows > self.server.rows:
+ size_pre = "\033[31m"
+ size_post = "\033[m"
+ self._send_all(
+ "\033[%dH%s) %-20s %s%-15s%s %-10s %-12s %-15s" % (
+ row, key, name, size_pre, size, size_post,
+ viewers, idle, total
+ )
+ )
+ row += 1
+ self._send_all("\033[%dHChoose a stream: " % (row + 1))
+
+ def _cleanup_watcher(self):
+ self.publisher.notify(
+ "viewer_disconnect", self.watching_id
+ )
+ self._send_all(
+ ("\033[1;%d;1;%dr"
+ + "\033[m"
+ + "\033[?9l\033[?1000l"
+ + "\033[H\033[2J") % (
+ self.server.rows, self.server.cols
+ )
+ )
+
+class Server(paramiko.ServerInterface):
+ def __init__(self):
+ super()
+ self.cols = 80
+ self.rows = 24
+ self.pty_event = threading.Event()
+
+ def check_channel_request(self, kind, chanid):
+ return paramiko.OPEN_SUCCEEDED
+
+ def check_channel_pty_request(
+ self, channel, term, width, height, pixelwidth, pixelheight, modes
+ ):
+ self.cols = width
+ self.rows = height
+ self.pty_event.set()
+ return True
+
+ def check_channel_window_change_request(
+ self, channel, width, height, pixelwidth, pixelheight
+ ):
+ self.cols = width
+ self.rows = height
+ return True
+
+ def check_channel_shell_request(self, channel):
+ return True
+
+ def check_auth_password(self, username, password):
+ return paramiko.AUTH_SUCCESSFUL
+
+ def get_allowed_auths(self, username):
+ return "password"
diff --git a/termcast_server/termcast.py b/termcast_server/termcast.py
new file mode 100644
index 0000000..e4eed6d
--- /dev/null
+++ b/termcast_server/termcast.py
@@ -0,0 +1,198 @@
+import time
+import json
+import re
+
+import vt100
+
+auth_re = re.compile(b'^hello ([^ ]+) ([^ ]+)$')
+extra_data_re = re.compile(b'\033\]499;([^\007]*)\007')
+
+clear_patterns = [
+ b"\033[H\033[J",
+ b"\033[H\033[2J",
+ b"\033[2J\033[H",
+ # this one is from tmux - can't possibly imagine why it would choose to do
+ # things this way, but i'm sure there's some kind of reason
+ # it's not perfect (it's not always followed by a \e[H, sometimes it just
+ # moves the cursor to wherever else directly), but it helps a bit
+ lambda handler: b"\033[H\033[K\r\n\033[K" + b"".join([b"\033[1B\033[K" for i in range(handler.rows - 2)]) + b"\033[H",
+]
+
+class Handler(object):
+ def __init__(self, rows, cols):
+ self.created_at = time.time()
+ self.idle_since = time.time()
+ self.rows = rows
+ self.cols = cols
+ self.buf = b''
+ self.prev_read = b''
+ self.vt = vt100.vt100(rows, cols)
+
+ def process(self, data):
+ to_process = self.prev_read + data
+ processed = self.vt.process(to_process)
+ self.prev_read = to_process[processed:]
+
+ self.buf += data
+
+ extra_data = {}
+ while True:
+ m = extra_data_re.search(self.buf)
+ if m is None:
+ break
+ try:
+ extra_data_json = m.group(1).decode('utf-8')
+ extra_data = json.loads(extra_data_json)
+ except Exception as e:
+ print("failed to parse metadata: %s" % e, file=sys.stderr)
+ pass
+ self.buf = self.buf[:m.start(0)] + self.buf[m.end(0):]
+ if "geometry" in extra_data:
+ self.rows = extra_data["geometry"][1]
+ self.cols = extra_data["geometry"][0]
+ self.vt.set_window_size(self.rows, self.cols)
+
+ for pattern in clear_patterns:
+ if type(pattern) == type(lambda x: x):
+ pattern = pattern(self)
+ clear = self.buf.rfind(pattern)
+ if clear != -1:
+ print("found a clear")
+ self.buf = self.buf[clear + len(pattern):]
+
+ self.idle_since = time.time()
+
+ def get_term(self):
+ term = ''
+ for i in range(0, self.rows):
+ for j in range(0, self.cols):
+ term += self.vt.cell(i, j).contents()
+ term += "\n"
+
+ return term[:-1]
+
+ def total_time(self):
+ return self._human_readable_duration(time.time() - self.created_at)
+
+ def idle_time(self):
+ return self._human_readable_duration(time.time() - self.idle_since)
+
+ def _human_readable_duration(self, duration):
+ days = 0
+ hours = 0
+ minutes = 0
+ seconds = 0
+
+ if duration > 60*60*24:
+ days = duration // (60*60*24)
+ duration -= days * 60*60*24
+ if duration > 60*60:
+ hours = duration // (60*60)
+ duration -= hours * 60*60
+ if duration > 60:
+ minutes = duration // 60
+ duration -= minutes * 60
+ seconds = duration
+
+ ret = "%02ds" % seconds
+ if minutes > 0 or hours > 0 or days > 0:
+ ret = ("%02dm" % minutes) + ret
+ if hours > 0 or days > 0:
+ ret = ("%02dh" % hours) + ret
+ if days > 0:
+ ret = ("%dd" % days) + ret
+
+ return ret
+
+class Connection(object):
+ def __init__(self, client, connection_id, publisher):
+ self.client = client
+ self.connection_id = connection_id
+ self.publisher = publisher
+ self.viewers = 0
+
+ def run(self):
+ buf = b''
+ while len(buf) < 1024 and b"\n" not in buf:
+ buf += self.client.recv(1024)
+
+ pos = buf.find(b"\n")
+ if pos == -1:
+ print("no authentication found")
+ return
+
+ auth = buf[:pos]
+ if auth[-1:] == b"\r":
+ auth = auth[:-1]
+
+ buf = buf[pos+1:]
+
+ m = auth_re.match(auth)
+ if m is None:
+ print("no authentication found (%s)" % auth)
+ return
+
+ print(b"got auth: " + auth)
+ self.name = m.group(1)
+ self.client.send(b"hello, " + self.name + b"\n")
+
+ extra_data = {}
+ m = extra_data_re.match(buf)
+ if m is not None:
+ try:
+ extra_data_json = m.group(1).decode('utf-8')
+ extra_data = json.loads(extra_data_json)
+ except Exception as e:
+ print("failed to parse metadata: %s" % e, file=sys.stderr)
+ pass
+ buf = buf[len(m.group(0)):]
+
+ if "geometry" in extra_data:
+ self.handler = Handler(
+ extra_data["geometry"][1], extra_data["geometry"][0]
+ )
+ else:
+ self.handler = Handler(24, 80)
+
+ self.handler.process(buf)
+ while True:
+ buf = b''
+ try:
+ buf = self.client.recv(1024)
+ except Exception as e:
+ print('*** recv failed: ' + str(e))
+
+ if len(buf) > 0:
+ self.publisher.notify(
+ "new_data", self.connection_id, self.handler.buf, buf
+ )
+ self.handler.process(buf)
+ else:
+ self.publisher.notify("streamer_disconnect", self.connection_id)
+ return
+
+ def msg_new_viewer(self, connection_id):
+ if connection_id != self.connection_id:
+ return
+ self.viewers += 1
+ self.publisher.notify(
+ "new_data", self.connection_id, self.handler.buf, b''
+ )
+ self.client.send(b"msg watcher connected\n")
+
+ def msg_viewer_disconnect(self, connection_id):
+ if connection_id != self.connection_id:
+ return
+ self.client.send(b"msg watcher disconnected\n")
+ self.viewers -= 1
+
+ def request_get_streamers(self):
+ return {
+ "name": self.name,
+ "id": self.connection_id,
+ "rows": self.handler.rows,
+ "cols": self.handler.cols,
+ "idle_time": self.handler.idle_time(),
+ "total_time": self.handler.total_time(),
+ "viewers": self.viewers,
+ }