aboutsummaryrefslogblamecommitdiffstats
path: root/termcast_server/ssh.py
blob: c1d5a8168eafe6ac333521c1a42d305102937ff2 (plain) (tree)
1
2
3
4
5
6
7
8
9
                      
               
             
                
           
                
 
                         
                                                                  
                                                   











                                                                               

                                          
                                
                               
 

                                                            
                  

                                                       
                                             
 

                                        
 


                                        
 



                                                 
 



                                                        
                 
                               


                                                            
                 
                                                                     
 
                           











                                                                     
                             




                                                   
                             
 


                                 
 
                            

                           
                                                               
                                          
                                  
                               


                                   
                                 
                                  

                         
                                                
 
                                                       
                       
                                           
                            
                      
                                  
                       


                                       
                                                                           

                                             

                                
                                    

                                   
                            
 





                                                     


                                     


                                                               
                                             

                                                   
 
                                                  

                                                           
                                                             
                                                                             


               
                                  
                                 
                                                              


                                         

                                                        
                                           

                          
                                                                  

                                     
                           
                                                                       
                                                              


                                                        

                 
                    
                                                               
 


























                                                 



                                                 
                       







                                                  
                                       





                                          


                                                  


                                                                          




                            


                                                             

                          




                                                   
                                        
                                       

                                          
                     
import multiprocessing
import paramiko
import select
import threading
import time
import traceback

class Connection(object):
    """Serve one SSH viewer: show a menu of streamers and relay the chosen one.

    ``publisher`` is used through ``notify(...)`` and ``request_all(...)``.
    The ``msg_*`` methods look like callbacks delivered by that publisher,
    possibly from another thread (NOTE(review): dispatch mechanism is not
    visible here -- confirm).
    """

    def __init__(self, client, connection_id, publisher, keyfile):
        """Load the host key and wrap ``client`` in a paramiko Transport.

        Args:
            client: accepted client connection, passed straight to
                ``paramiko.Transport``.
            connection_id: identifier stored on the instance.
            publisher: object exposing ``notify`` and ``request_all``.
            keyfile: path to a PEM private host key (DSA or RSA header).

        Raises:
            Exception: if the first line of ``keyfile`` is not a DSA or RSA
                private-key PEM header.
        """
        key = None
        with open(keyfile) as f:
            # Only the PEM header line is inspected. Stripping it tolerates
            # CRLF line endings and trailing whitespace.
            header = f.readline().strip()
        if header == "-----BEGIN DSA PRIVATE KEY-----":
            key = paramiko.DSSKey(filename=keyfile)
        elif header == "-----BEGIN RSA PRIVATE KEY-----":
            key = paramiko.RSAKey(filename=keyfile)
        if key is None:
            raise Exception("%s doesn't appear to be an SSH keyfile" % keyfile)

        self.transport = paramiko.Transport(client)
        self.transport.add_server_key(key)

        self.connection_id = connection_id
        self.publisher = publisher
        # Set once the streamer's backlog (prev_buf) has been sent.
        self.initialized = False
        # id of the streamer currently being watched, or None while in the menu.
        self.watching_id = None

        # One-way pipe: msg_streamer_disconnect writes to wpipe, and the watch
        # loop select()s on rpipe to learn that the stream ended.
        self.rpipe, self.wpipe = multiprocessing.Pipe(False)

    def run(self):
        """Drive the SSH session until the viewer quits or disconnects.

        The accepted channel (if any) and the transport are always closed,
        even if an exception escapes.
        """
        self.chan = None
        try:
            self.server = Server()
            self.transport.start_server(server=self.server)
            self.chan = self.transport.accept(10)
            if self.chan is None:
                return
            # Wait for the pty request so Server.rows/cols reflect the real
            # terminal. The wait is bounded so a client that never requests a
            # pty cannot pin this thread forever.
            if not self.server.pty_event.wait(10):
                return
            self._serve_viewer()
        finally:
            if self.chan is not None:
                self.chan.close()
            self.transport.close()

    def _serve_viewer(self):
        """Alternate between the stream menu and watching until the viewer leaves."""
        while True:
            self.initialized = False
            self.watching_id = None

            streamer = self.select_stream()
            if streamer is None:
                return

            # Drop disconnect notices left over from an earlier session. Left
            # in the pipe, they would end this session immediately.
            while self.rpipe.poll():
                self.rpipe.recv()

            self.watching_id = streamer["id"]
            print(
                "new viewer watching %s (%s)" % (
                    self._display_name(streamer), streamer["id"]
                )
            )
            # Set the scroll region to the streamer's terminal size, reset
            # attributes, home the cursor and clear the screen.
            self._send_all(
                "\033[1;%d;1;%dr\033[m\033[H\033[2J" % (
                    streamer["rows"], streamer["cols"]
                )
            )
            self.publisher.notify("new_viewer", self.watching_id)

            if not self._watch(streamer):
                return

    def _watch(self, streamer):
        """Block until the viewer presses 'q', the stream ends, or the viewer leaves.

        Stream data itself is pushed from msg_new_data. This loop only
        watches for the viewer's keystrokes and for disconnect notices.

        Returns:
            True to go back to the stream menu, or False if the viewer's
            channel reached EOF.
        """
        while True:
            rout, _, _ = select.select([self.chan, self.rpipe], [], [])
            if self.chan in rout:
                c = self.chan.recv(1)
                if not c:
                    # EOF: the viewer went away. Without this check select()
                    # keeps reporting the channel readable and the loop spins.
                    self._cleanup_watcher()
                    return False
                if c == b'q':
                    print(
                        "viewer stopped watching %s (%s)" % (
                            self._display_name(streamer), streamer["id"]
                        )
                    )
                    self._cleanup_watcher()
                    return True

            if self.rpipe in rout:
                # Consume the notice so it cannot fire again later.
                self.rpipe.recv()
                self._cleanup_watcher()
                return True

    @staticmethod
    def _display_name(streamer):
        """Return the streamer's name (stored as bytes) decoded for display."""
        return streamer["name"].decode('utf-8', 'replace')

    def select_stream(self):
        """Show the stream menu and return the streamer the viewer picks.

        The streamer list is re-fetched and the menu redrawn after every
        unrecognized key.

        Returns:
            The chosen streamer dict (with a ``"key"`` entry added), or None
            if the viewer pressed 'q' or the channel reached EOF.
        """
        while True:
            streamers = self.publisher.request_all("get_streamers")
            keymap = {}
            key_code = ord('a')
            # XXX this will require pagination
            for streamer in streamers:
                # 'q' is reserved for quitting, so it is never a stream key.
                if chr(key_code) == "q":
                    key_code += 1
                key = chr(key_code)
                streamer["key"] = key
                keymap[key] = streamer
                key_code += 1

            self._display_streamer_screen(streamers)

            raw = self.chan.recv(1)
            if not raw:
                # EOF from the client: nothing more can be selected.
                return None
            c = raw.decode('utf-8', 'ignore')
            if c in keymap:
                self._send_all("\033[2J\033[H")
                return keymap[c]
            if c == 'q':
                self._send_all("\r\n")
                return None
            # Any other key: loop and redraw with a fresh streamer list.

    def msg_new_data(self, connection_id, prev_buf, data, screen, updates):
        """Forward new output from ``connection_id`` if this viewer watches it.

        On the first call after a stream is selected, ``prev_buf`` is sent
        before ``data``. ``screen`` and ``updates`` are unused here.
        """
        if self.watching_id != connection_id:
            return

        if not self.initialized:
            self._send_all(prev_buf)
            self.initialized = True

        self._send_all(data)

    def msg_streamer_disconnect(self, connection_id):
        """Wake the watch loop if the streamer being watched has disconnected."""
        if self.watching_id != connection_id:
            return

        self.wpipe.send("q")

    def _send_all(self, data):
        """Write all of ``data`` to the channel, stopping early on error or EOF.

        ``str`` input is UTF-8 encoded first, so the loop compares byte
        counts from ``chan.send`` against a byte length.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')
        total_sent = 0
        while total_sent < len(data):
            try:
                sent = self.chan.send(data[total_sent:])
            except Exception as e:
                print(traceback.format_exc())
                print("*** send failed: " + str(e))
                break
            if sent == 0:
                # Nothing could be sent (channel closed/EOF). Retrying would
                # loop forever.
                break
            total_sent += sent

    def _display_streamer_screen(self, streamers):
        """Clear the viewer's screen and draw the table of available streams.

        A streamer's size is drawn in red when it is wider or taller than the
        viewer's terminal (from Server.cols/rows).
        """
        self._send_all("\033[H\033[2JWelcome to Termcast!")
        self._send_all(
            "\033[3H   %-20s  %-15s  %-10s  %-12s  %-15s" % (
                "User", "Terminal size", "Viewers", "Idle time", "Total time"
            )
        )
        now = time.time()
        row = 4
        for streamer in streamers:
            rows = streamer["rows"]
            cols = streamer["cols"]
            size_pre = ""
            size_post = ""
            if cols > self.server.cols or rows > self.server.rows:
                size_pre = "\033[31m"
                size_post = "\033[m"
            self._send_all(
                "\033[%dH%s) %-20s  %s%-15s%s  %-10s  %-12s  %-15s" % (
                    row,
                    streamer["key"],
                    self._display_name(streamer),
                    size_pre, "(%dx%d)" % (cols, rows), size_post,
                    streamer["viewers"],
                    self._human_readable_duration(now - streamer["idle_since"]),
                    self._human_readable_duration(now - streamer["created_at"]),
                )
            )
            row += 1
        self._send_all("\033[%dHChoose a stream: " % (row + 1))

    def _human_readable_duration(self, duration):
        """Format ``duration`` seconds compactly, e.g. 90061 -> "1d01h01m01s".

        Leading zero-valued units are omitted (5 -> "05s", 60 -> "01m00s").
        Fractional seconds are truncated by the %d formatting. Negative
        input, e.g. from clock skew, is treated as 0.
        """
        duration = max(0, duration)
        minutes, seconds = divmod(duration, 60)
        hours, minutes = divmod(minutes, 60)
        days, hours = divmod(hours, 24)

        ret = "%02ds" % seconds
        if minutes > 0 or hours > 0 or days > 0:
            ret = ("%02dm" % minutes) + ret
        if hours > 0 or days > 0:
            ret = ("%02dh" % hours) + ret
        if days > 0:
            ret = ("%dd" % days) + ret

        return ret

    def _cleanup_watcher(self):
        """Tell the publisher this viewer left and reset the viewer's terminal.

        Restores the scroll region to the viewer's own size, resets
        attributes, turns off mouse reporting modes 9/1000, and clears the
        screen.
        """
        self.publisher.notify(
            "viewer_disconnect", self.watching_id
        )
        self._send_all(
            ("\033[1;%d;1;%dr"
            + "\033[m"
            + "\033[?9l\033[?1000l"
            + "\033[H\033[2J") % (
                self.server.rows, self.server.cols
            )
        )

class Server(paramiko.ServerInterface):
    """paramiko server policy: anonymous login into one interactive session.

    Tracks the client's terminal size in ``cols``/``rows`` (default 80x24,
    updated on pty and window-change requests) and sets ``pty_event`` once
    the client has requested a pty.
    """

    def __init__(self):
        # Explicitly run the base initializer; a bare ``super()`` expression
        # would only build and discard a proxy object.
        super().__init__()
        self.cols = 80
        self.rows = 24
        self.pty_event = threading.Event()

    def check_channel_request(self, kind, chanid):
        """Allow "session" channels only; refuse every other channel type."""
        if kind == "session":
            return paramiko.OPEN_SUCCEEDED
        return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED

    def check_channel_pty_request(
        self, channel, term, width, height, pixelwidth, pixelheight, modes
    ):
        """Record the client's terminal size and signal that a pty was requested."""
        self.cols = width
        self.rows = height
        self.pty_event.set()
        return True

    def check_channel_window_change_request(
        self, channel, width, height, pixelwidth, pixelheight
    ):
        """Track terminal resizes from the client."""
        self.cols = width
        self.rows = height
        return True

    def check_channel_shell_request(self, channel):
        """Accept the shell request; the session itself is driven elsewhere."""
        return True

    def check_auth_none(self, username):
        """Let every user in without credentials (anonymous viewing)."""
        return paramiko.AUTH_SUCCESSFUL

    def get_allowed_auths(self, username):
        """Advertise "none" as the only authentication method."""
        return "none"