From: Alina Quereilhac Date: Mon, 1 Apr 2013 11:40:30 +0000 (+0200) Subject: Creating ssh api (unfinished) X-Git-Tag: nepi-3.0.0~122^2~20 X-Git-Url: http://git.onelab.eu/?a=commitdiff_plain;h=c7524b8d2bb83b5a74fcc5f58dfd025194d4836b;p=nepi.git Creating ssh api (unfinished) --- diff --git a/src/neco/execution/ec.py b/src/neco/execution/ec.py index 0b15c3a0..076e4dd5 100644 --- a/src/neco/execution/ec.py +++ b/src/neco/execution/ec.py @@ -178,7 +178,8 @@ class ExperimentController(object): self._thread.join() def schedule(self, date, callback, track = False): - """ + """ Schedule a callback to be executed at time date. + date string containing execution time for the task. It can be expressed as an absolute time, using timestamp format, or as a relative time matching diff --git a/src/neco/execution/resource.py b/src/neco/execution/resource.py index ebc537ce..62bf821c 100644 --- a/src/neco/execution/resource.py +++ b/src/neco/execution/resource.py @@ -110,7 +110,7 @@ class ResourceManager(object): def set_after(self, name, value, time, after_status, guid): pass - def next_step(self): + def start(self): pass def stop(self): @@ -138,7 +138,7 @@ class ResourceFactory(object): cls._resource_types[rclass.rtype()] = rclass @classmethod - def create(cls, rtype, ec, guid, creds): + def create(cls, rtype, ec, guid): rclass = cls._resource_types[rtype] - return rclass(ec, guid, creds) + return rclass(ec, guid) diff --git a/src/neco/resources/linux/node.py b/src/neco/resources/linux/node.py index feaad462..80bb103f 100644 --- a/src/neco/resources/linux/node.py +++ b/src/neco/resources/linux/node.py @@ -1,313 +1,51 @@ -from neco.execution.resource import ResourceManager -from neco.util.sshfuncs import eintr_retry, rexec, rcopy, \ - rspawn, rcheck_pid, rstatus, rkill, make_control_path, RUNNING - -import cStringIO -import logging -import os.path -import subprocess +from neco.execution.resource import ResourceManager, clsinit +from neco.execution.attribute import Attribute, Flags +@clsinit class LinuxNode(ResourceManager): - def __init__(self, ec, guid): - super(LinuxNode, self).__init__(ec, guid) - self.ip = None - self.host = None - self.user = None - self.port = None - self.identity_file = None - self.enable_x11 = False - self.forward_agent = True - - # packet management system - either yum or apt for now... - self._pm = None - - # Logging - loglevel = "debug" - self._logger = logging.getLogger("neco.resources.base.LinuxNode.%s" %\ - self.guid) - self._logger.setLevel(getattr(logging, loglevel.upper())) - - # For ssh connections we use the ControlMaster option which - # allows us to decrease the number of open ssh network connections. - # Subsequent ssh connections will reuse a same master connection. - # This might pose a problem when using X11 and ssh-agent, since - # display and agent forwarded will be those of the first connection, - # which created the master. - # To avoid reusing a master created by a previous LinuxNode instance, - # we explicitly erase the ControlPath socket. - control_path = make_control_path(self.user, self.host, self.port) - try: - os.remove(control_path) - except: - pass - - @property - def pm(self): - if self._pm: - return self._pm - - if (not (self.host or self.ip) or not self.user): - msg = "Can't resolve package management system. Insufficient data." - self._logger.error(msg) - raise RuntimeError(msg) - - out = self.execute("cat /etc/issue") - - if out.find("Fedora") == 0: - self._pm = "yum" - elif out.find("Debian") == 0 or out.find("Ubuntu") ==0: - self._pm = "apt-get" - else: - msg = "Can't resolve package management system. Unknown OS." - self._logger.error(msg) - raise RuntimeError(msg) - - return self._pm - - @property - def is_localhost(self): - return ( self.host or self.ip ) in ['localhost', '127.0.0.7', '::1'] - - def install(self, packages): - if not isinstance(packages, list): - packages = [packages] + _rtype = "LinuxNode" - for p in packages: - self.execute("%s -y install %s" % (self.pm, p), sudo = True, - tty = True) + @classmethod + def _register_attributes(cls): + hostname = Attribute("hostname", "Hostname of the machine") + username = Attribute("username", "Local account username", + flags = Flags.Credential) + password = Attribute("pasword", "Local account password", + flags = Flags.Credential) - def uninstall(self, packages): - if not isinstance(packages, list): - packages = [packages] + cls._register_attribute(hostname) + cls._register_attribute(username) + cls._register_attribute(password) - for p in packages: - self.execute("%s -y remove %s" % (self.pm, p), sudo = True, - tty = True) - - def upload(self, src, dst): - if not os.path.isfile(src): - src = cStringIO.StringIO(src) - - if not self.is_localhost: - # Build destination as @: - dst = "%s@%s:%s" % (self.user, self.host or self.ip, dst) - return self.copy(src, dst) - - def download(self, src, dst): - if not self.is_localhost: - # Build destination as @: - src = "%s@%s:%s" % (self.user, self.host or self.ip, src) - return self.copy(src, dst) - - def is_alive(self, verbose = False): - if self.is_localhost: - return True - - try: - out = self.execute("echo 'ALIVE'", - timeout = 60, - err_on_timeout = False, - persistent = False) - except: - if verbose: - self._logger.warn("Unresponsive node %s got:\n%s%s", self.host, out, err) - return False - - if out.strip().startswith('ALIVE'): - return True - else: - if verbose: - self._logger.warn("Unresponsive node %s got:\n%s%s", self.host, out, err) - return False - - def mkdir(self, path, clean = True): - if clean: - self.rmdir(path) - - return self.execute( - "mkdir -p %s" % path, - timeout = 120, - retry = 3 - ) - - def rmdir(self, path): - return self.execute( - "rm -rf %s" % path, - timeout = 120, - retry = 3 - ) - - def copy(self, src, dst): - if self.is_localhost: - command = ["cp", "-R", src, dst] - p = subprocess.Popen(command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - out, err = p.communicate() - else: - (out, err), proc = eintr_retry(rcopy)( - src, dst, - port = self.port, - agent = self.agent, - identity_file = self.identity_file) - - if proc.wait(): - msg = "Error uploading to %s got:\n%s%s" %\ - (self.host or self.ip, out, err) - self._logger.error(msg) - raise RuntimeError(msg) + def __init__(self, ec, guid): + super(LinuxNode, self).__init__(ec, guid) - return (out, err) + self._logger = logging.getLogger("neco.linux.Node.%d" % guid) + #elf._logger.setLevel(neco.LOGLEVEL) - def execute(self, command, - sudo = False, - stdin = None, - tty = False, - env = None, - timeout = None, - retry = 0, - err_on_timeout = True, - connect_timeout = 30, - persistent = True): - """ Notice that this invocation will block until the - execution finishes. If this is not the desired behavior, - use 'run' instead.""" + def deploy(self): + pass - if self.is_localhost: - if env: - export = '' - for envkey, envval in env.iteritems(): - export += '%s=%s ' % (envkey, envval) - command = export + command + def discover(self, filters): + pass - if sudo: - command = "sudo " + command + def provision(self, filters): + pass - p = subprocess.Popen(command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - out, err = p.communicate() - else: - (out, err), proc = eintr_retry(rexec)( - command, - self.host or self.ip, - self.user, - port = self.port, - agent = self.forward_agent, - sudo = sudo, - stdin = stdin, - identity_file = self.identity_file, - tty = tty, - x11 = self.enable_x11, - env = env, - timeout = timeout, - retry = retry, - err_on_timeout = err_on_timeout, - connect_timeout = connect_timeout, - persistent = persistent) + def start(self): + pass - if proc.wait(): - msg = "Failed to execute command %s at node %s: %s %s" % \ - (command, self.host or self.ip, out, err,) - self._logger.warn(msg) - raise RuntimeError(msg) + def stop(self): + pass - return (out, err) + def deploy(self, group = None): + pass - def run(self, command, home, - stdin = None, - stdout = 'stdout', - stderr = 'stderr', - sudo = False): - self._logger.info("Running %s", command) - - pidfile = './pid', + def release(self): + pass - if self.is_localhost: - if stderr == stdout: - stderr = '&1' - else: - stderr = ' ' + stderr - - daemon_command = '{ { %(command)s > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % { - 'command' : command, - 'pidfile' : pidfile, - - 'stdout' : stdout, - 'stderr' : stderr, - 'stdin' : stdin, - } - - cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % { - 'command' : daemon_command, - - 'sudo' : 'sudo -S' if sudo else '', - - 'pidfile' : pidfile, - 'gohome' : 'cd %s ; ' % home if home else '', - 'create' : 'mkdir -p %s ; ' % home if create_home else '', - } - p = subprocess.Popen(command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - out, err = p.communicate() - else: - # Start process in a "daemonized" way, using nohup and heavy - # stdin/out redirection to avoid connection issues - (out,err), proc = rspawn( - command, - pidfile = pidfile, - home = home, - stdin = stdin if stdin is not None else '/dev/null', - stdout = stdout if stdout else '/dev/null', - stderr = stderr if stderr else '/dev/null', - sudo = sudo, - host = self.host, - user = self.user, - port = self.port, - agent = self.forward_agent, - identity_file = self.identity_file - ) - - if proc.wait(): - raise RuntimeError, "Failed to set up application: %s %s" % (out,err,) + def _validate_connection(self, guid): + # TODO: Validate! + return True - return (out, err) - - def checkpid(self, path): - # Get PID/PPID - # NOTE: wait a bit for the pidfile to be created - pidtuple = rcheck_pid( - os.path.join(path, 'pid'), - host = self.host, - user = self.user, - port = self.port, - agent = self.forward_agent, - identity_file = self.identity_file - ) - - return pidtuple - - def status(self, pid, ppid): - status = rstatus( - pid, ppid, - host = self.host, - user = self.user, - port = self.port, - agent = self.forward_agent, - identity_file = self.identity_file - ) - - return status - - def kill(self, pid, ppid, sudo = False): - status = self.status(pid, ppid) - if status == RUNNING: - # kill by ppid+pid - SIGTERM first, then try SIGKILL - rkill( - pid, ppid, - host = self.host, - user = self.user, - port = self.port, - agent = self.forward_agent, - sudo = sudo, - identity_file = self.identity_file - ) diff --git a/src/neco/resources/linux/ssh_api.py b/src/neco/resources/linux/ssh_api.py new file mode 100644 index 00000000..93e895ec --- /dev/null +++ b/src/neco/resources/linux/ssh_api.py @@ -0,0 +1,323 @@ +import hashlib +import os +import re + +class SSHAPI(object): + def __init__(self, host, user, identity, port, agent, forward_x11): + self.host = host + self.user = user + # ssh identity file + self.identity = identity + self.port = port + # use ssh agent + self.agent = agent + # forward X11 + self.forward_x11 = forward_x11 + + # TODO: Investigate using http://nixos.org/nix/ + @property + def pm(self): + if self._pm: + return self._pm + + if (not self.host or not self.user): + msg = "Can't resolve package management system. Insufficient data." + self._logger.error(msg) + raise RuntimeError(msg) + + out = self.execute("cat /etc/issue") + + if out.find("Fedora") == 0: + self._pm = "yum" + elif out.find("Debian") == 0 or out.find("Ubuntu") ==0: + self._pm = "apt-get" + else: + msg = "Can't resolve package management system. Unknown OS." + self._logger.error(msg) + raise RuntimeError(msg) + + return self._pm + + @property + def is_localhost(self): + return ( self.host or self.ip ) in ['localhost', '127.0.0.7', '::1'] + + # TODO: Investigate using http://nixos.org/nix/ + def install(self, packages): + if not isinstance(packages, list): + packages = [packages] + + for p in packages: + self.execute("%s -y install %s" % (self.pm, p), sudo = True, + tty = True) + + # TODO: Investigate using http://nixos.org/nix/ + def uninstall(self, packages): + if not isinstance(packages, list): + packages = [packages] + + for p in packages: + self.execute("%s -y remove %s" % (self.pm, p), sudo = True, + tty = True) + + def upload(self, src, dst): + """ Copy content to destination + + src content to copy. Can be a local file, directory or text input + + dst destination path on the remote host (remote is always self.host) + """ + # If src is a string input + if not os.path.isfile(src) and not isdir: + # src is text input that should be uploaded as file + src = cStringIO.StringIO(src) + + if not self.is_localhost: + # Build destination as @: + dst = "%s@%s:%s" % (self.user, self.host, dst) + return self.copy(src, dst) + + def download(self, src, dst): + if not self.is_localhost: + # Build destination as @: + src = "%s@%s:%s" % (self.user, self.host or self.ip, src) + return self.copy(src, dst) + + def is_alive(self, verbose = False): + if self.is_localhost: + return True + + try: + out = self.execute("echo 'ALIVE'", + timeout = 60, + err_on_timeout = False, + persistent = False) + except: + if verbose: + self._logger.warn("Unresponsive node %s got:\n%s%s", self.host, out, err) + return False + + if out.strip().startswith('ALIVE'): + return True + else: + if verbose: + self._logger.warn("Unresponsive node %s got:\n%s%s", self.host, out, err) + return False + + def mkdir(self, path, clean = True): + if clean: + self.rmdir(path) + + return self.execute( + "mkdir -p %s" % path, + timeout = 120, + retry = 3 + ) + + def rmdir(self, path): + return self.execute( + "rm -rf %s" % path, + timeout = 120, + retry = 3 + ) + + def copy(self, src, dst): + if self.is_localhost: + command = ["cp", "-R", src, dst] + p = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + out, err = p.communicate() + else: + (out, err), proc = eintr_retry(rcopy)( + src, dst, + port = self.port, + agent = self.agent, + identity_file = self.identity_file) + + if proc.wait(): + msg = "Error uploading to %s got:\n%s%s" %\ + (self.host or self.ip, out, err) + self._logger.error(msg) + raise RuntimeError(msg) + + return (out, err) + + def execute(self, command, + sudo = False, + stdin = None, + tty = False, + env = None, + timeout = None, + retry = 0, + err_on_timeout = True, + connect_timeout = 30, + persistent = True): + """ Notice that this invocation will block until the + execution finishes. If this is not the desired behavior, + use 'run' instead.""" + + if self.is_localhost: + if env: + export = '' + for envkey, envval in env.iteritems(): + export += '%s=%s ' % (envkey, envval) + command = export + command + + if sudo: + command = "sudo " + command + + p = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + out, err = p.communicate() + else: + (out, err), proc = eintr_retry(rexec)( + command, + self.host or self.ip, + self.user, + port = self.port, + agent = self.forward_agent, + sudo = sudo, + stdin = stdin, + identity_file = self.identity_file, + tty = tty, + x11 = self.enable_x11, + env = env, + timeout = timeout, + retry = retry, + err_on_timeout = err_on_timeout, + connect_timeout = connect_timeout, + persistent = persistent) + + if proc.wait(): + msg = "Failed to execute command %s at node %s: %s %s" % \ + (command, self.host or self.ip, out, err,) + self._logger.warn(msg) + raise RuntimeError(msg) + + return (out, err) + + def run(self, command, home, + stdin = None, + stdout = 'stdout', + stderr = 'stderr', + sudo = False): + self._logger.info("Running %s", command) + + pidfile = './pid', + + if self.is_localhost: + if stderr == stdout: + stderr = '&1' + else: + stderr = ' ' + stderr + + daemon_command = '{ { %(command)s > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % { + 'command' : command, + 'pidfile' : pidfile, + + 'stdout' : stdout, + 'stderr' : stderr, + 'stdin' : stdin, + } + + cmd = "%(create)s%(gohome)s rm -f %(pidfile)s ; %(sudo)s nohup bash -c '%(command)s' " % { + 'command' : daemon_command, + + 'sudo' : 'sudo -S' if sudo else '', + + 'pidfile' : pidfile, + 'gohome' : 'cd %s ; ' % home if home else '', + 'create' : 'mkdir -p %s ; ' % home if create_home else '', + } + p = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + out, err = p.communicate() + else: + # Start process in a "daemonized" way, using nohup and heavy + # stdin/out redirection to avoid connection issues + (out,err), proc = rspawn( + command, + pidfile = pidfile, + home = home, + stdin = stdin if stdin is not None else '/dev/null', + stdout = stdout if stdout else '/dev/null', + stderr = stderr if stderr else '/dev/null', + sudo = sudo, + host = self.host, + user = self.user, + port = self.port, + agent = self.agent, + identity_file = self.file + ) + + if proc.wait(): + raise RuntimeError, "Failed to set up application: %s %s" % (out,err,) + + return (out, err) + + def checkpid(self, path): + # Get PID/PPID + # NOTE: wait a bit for the pidfile to be created + pidtuple = rcheck_pid( + os.path.join(path, 'pid'), + host = self.host, + user = self.user, + port = self.port, + agent = self.agent, + identity_file = self.identity + ) + + return pidtuple + + def status(self, pid, ppid): + status = rstatus( + pid, ppid, + host = self.host, + user = self.user, + port = self.port, + agent = self.agent, + identity_file = self.identity + ) + + return status + + def kill(self, pid, ppid, sudo = False): + status = self.status(pid, ppid) + if status == RUNNING: + # kill by ppid+pid - SIGTERM first, then try SIGKILL + rkill( + pid, ppid, + host = self.host, + user = self.user, + port = self.port, + agent = self.agent, + sudo = sudo, + identity_file = self.identity + ) + +class SSHAPIFactory(object): + _apis = dict() + + @classmethod + def get_api(cls, attributes): + host = attributes.get("hostname") + user = attributes.get("username") + identity = attributes.get("identity", "%s/.ssh/id_rsa" % os.environ['HOME']) + port = attributes.get("port", 22) + agent = attributes.get("agent", True) + forward_X11 = attributes.get("forwardX11", False) + + key = cls.make_key(host, user, identity, port, agent, forward_X11) + api = self._apis.get(key) + + if no api: + api = SSHAPI(host, user, identity, port, agent, forward_X11) + self._apis[key] = api + + return api + + @classmethod + def make_key(cls, *args): + skey = "".join(map(str, args)) + return hashlib.md5(skey).hexdigest() + diff --git a/src/neco/resources/omf/omf_api.py b/src/neco/resources/omf/omf_api.py index aca53e51..412e735d 100644 --- a/src/neco/resources/omf/omf_api.py +++ b/src/neco/resources/omf/omf_api.py @@ -150,6 +150,7 @@ class OMFAPI(object): class OMFAPIFactory(object): + # XXX: put '_apis' instead of '_Api' _Api = dict() @classmethod @@ -169,6 +170,16 @@ class OMFAPIFactory(object): cls._Api[key] = OmfApi return OmfApi + # XXX: this is not a hash :) + # From wikipedia: "A hash function is any algorithm or subroutine that maps large data + # sets of variable length to smaller data sets of a fixed length." + # The idea is to apply a function to get a smaller string. Use hashlib instead. + # e.g: + # import hashlib + # res = slice + "_" + host + "_" + port + # hashlib.md5(res).hexdigest() + # + # XXX: change method name for 'make_key' @classmethod def _hash_api(cls, slice, host, port): res = slice + "_" + host + "_" + port diff --git a/src/neco/resources/omf/omf_node.py b/src/neco/resources/omf/omf_node.py index 9bcffe99..064da473 100644 --- a/src/neco/resources/omf/omf_node.py +++ b/src/neco/resources/omf/omf_node.py @@ -17,6 +17,10 @@ class OMFNode(ResourceManager): hostname = Attribute("hostname", "Hostname of the machine") cpu = Attribute("cpu", "CPU of the node") ram = Attribute("ram", "RAM of the node") + # XXX: flags = "0x02" is not human readable. + # instead: + # from neco.execution.attribute import Attribute, Flags + # xmppSlice = Attribute("xmppSlice","Name of the slice", flags = Flags.Credential) xmppSlice = Attribute("xmppSlice","Name of the slice", flags = "0x02") xmppHost = Attribute("xmppHost", "Xmpp Server",flags = "0x02") xmppPort = Attribute("xmppPort", "Xmpp Port",flags = "0x02") @@ -27,7 +31,7 @@ class OMFNode(ResourceManager): cls._register_attribute(xmppSlice) cls._register_attribute(xmppHost) cls._register_attribute(xmppPort) - cls._register_attribute(xmppPassword) + ls._register_attribute(xmppPassword) @classmethod def _register_filters(cls): @@ -40,6 +44,10 @@ class OMFNode(ResourceManager): cls._register_filter(granularity) cls._register_filter(hardware_type) + # XXX: We don't necessary need to have the credentials at the + # moment we create the RM + # THE OMF API SHOULD BE CREATED ON THE DEPLOY METHOD, NOT NOW + # THIS FORCES MORE CONSTRAINES ON THE WAY WE WILL AUTHOMATE DEPLOYMENT! def __init__(self, ec, guid, creds): super(OMFNode, self).__init__(ec, guid) self.set('xmppSlice', creds['xmppSlice']) @@ -47,9 +55,12 @@ class OMFNode(ResourceManager): self.set('xmppPort', creds['xmppPort']) self.set('xmppPassword', creds['xmppPassword']) + # XXX: Lines should not be more than 80 characters! self._omf_api = OMFAPIFactory.get_api(self.get('xmppSlice'), self.get('xmppHost'), self.get('xmppPort'), self.get('xmppPassword')) self._logger = logging.getLogger("neco.omf.omfNode ") + + # XXX: TO DISCUSS self._logger.setLevel(neco.LOGLEVEL) def _validate_connection(self, guid): diff --git a/src/neco/util/sshfuncs.py b/src/neco/util/sshfuncs.py index aad6039c..f8d1cfc9 100644 --- a/src/neco/util/sshfuncs.py +++ b/src/neco/util/sshfuncs.py @@ -201,7 +201,6 @@ def rexec(command, host, user, def rcopy(source, dest, port = None, agent = True, - recursive = False, identity_file = None): """ Copies file from/to remote sites. @@ -209,214 +208,57 @@ def rcopy(source, dest, Source and destination should have the user and host encoded as per scp specs. - If source is a file object, a special mode will be used to - create the remote file with the same contents. - - If dest is a file object, the remote file (source) will be - read and written into dest. - - In these modes, recursive cannot be True. - Source can be a list of files to copy to a single destination, in which case it is advised that the destination be a folder. """ - if isinstance(source, file) and source.tell() == 0: - source = source.name - elif hasattr(source, 'read'): - tmp = tempfile.NamedTemporaryFile() - while True: - buf = source.read(65536) - if buf: - tmp.write(buf) - else: - break - tmp.seek(0) - source = tmp.name + # Parse destination as @: + if isinstance(dest, basestring) and ':' in dest: + remspec, path = dest.split(':',1) + elif isinstance(source, basestring) and ':' in source: + remspec, path = source.split(':',1) + else: + raise ValueError, "Both endpoints cannot be local" + user, host = remspec.rsplit('@',1) + + raw_string = r'''rsync -rlpcSz --timeout=900 ''' + raw_string += r''' -e 'ssh -o BatchMode=yes ''' + raw_string += r''' -o NoHostAuthenticationForLocalhost=yes ''' + raw_string += r''' -o StrictHostKeyChecking=no ''' + raw_string += r''' -o ConnectionAttempts=3 ''' + + if openssh_has_persist(): + control_path = make_control_path(user, host, port) + raw_string += r''' -o ControlMaster=auto ''' + raw_string += r''' -o ControlPath=%s ''' % control_path + + if port: + raw_string += r''' -p %d ''' % port - if isinstance(source, file) or isinstance(dest, file) \ - or hasattr(source, 'read') or hasattr(dest, 'write'): - assert not recursive + if identity_file: + raw_string += r''' -i "%s" ''' % identity_file - # Parse source/destination as @: - if isinstance(dest, basestring) and ':' in dest: - remspec, path = dest.split(':',1) - elif isinstance(source, basestring) and ':' in source: - remspec, path = source.split(':',1) - else: - raise ValueError, "Both endpoints cannot be local" - user,host = remspec.rsplit('@',1) - tmp_known_hosts = None - - args = ['ssh', '-l', user, '-C', - # Don't bother with localhost. Makes test easier - '-o', 'NoHostAuthenticationForLocalhost=yes', - # XXX: Possible security issue - # Avoid interactive requests to accept new host keys - '-o', 'StrictHostKeyChecking=no', - '-o', 'ConnectTimeout=30', - '-o', 'ConnectionAttempts=3', - '-o', 'ServerAliveInterval=30', - '-o', 'TCPKeepAlive=yes', - host ] - - if openssh_has_persist(): - control_path = make_control_path(user, host, port) - args.extend([ - '-o', 'ControlMaster=auto', - '-o', 'ControlPath=%s' % control_path, - '-o', 'ControlPersist=60' ]) - if port: - args.append('-P%d' % port) - if identity_file: - args.extend(('-i', identity_file)) - - if isinstance(source, file) or hasattr(source, 'read'): - args.append('cat > %s' % (shell_escape(path),)) - elif isinstance(dest, file) or hasattr(dest, 'write'): - args.append('cat %s' % (shell_escape(path),)) - else: - raise AssertionError, "Unreachable code reached! :-Q" + # closing -e 'ssh...' + raw_string += r''' ' ''' - # connects to the remote host and starts a remote connection - if isinstance(source, file): - proc = subprocess.Popen(args, - stdout = open('/dev/null','w'), - stderr = subprocess.PIPE, - stdin = source) - err = proc.stderr.read() - eintr_retry(proc.wait)() - return ((None,err), proc) - elif isinstance(dest, file): - proc = subprocess.Popen(args, - stdout = open('/dev/null','w'), - stderr = subprocess.PIPE, - stdin = source) - err = proc.stderr.read() - eintr_retry(proc.wait)() - return ((None,err), proc) - elif hasattr(source, 'read'): - # file-like (but not file) source - proc = subprocess.Popen(args, - stdout = open('/dev/null','w'), - stderr = subprocess.PIPE, - stdin = subprocess.PIPE) - - buf = None - err = [] - while True: - if not buf: - buf = source.read(4096) - if not buf: - #EOF - break - - rdrdy, wrdy, broken = select.select( - [proc.stderr], - [proc.stdin], - [proc.stderr,proc.stdin]) - - if proc.stderr in rdrdy: - # use os.read for fully unbuffered behavior - err.append(os.read(proc.stderr.fileno(), 4096)) - - if proc.stdin in wrdy: - proc.stdin.write(buf) - buf = None - - if broken: - break - proc.stdin.close() - err.append(proc.stderr.read()) - - eintr_retry(proc.wait)() - return ((None,''.join(err)), proc) - elif hasattr(dest, 'write'): - # file-like (but not file) dest - proc = subprocess.Popen(args, - stdout = subprocess.PIPE, - stderr = subprocess.PIPE, - stdin = open('/dev/null','w')) - - buf = None - err = [] - while True: - rdrdy, wrdy, broken = select.select( - [proc.stderr, proc.stdout], - [], - [proc.stderr, proc.stdout]) - - if proc.stderr in rdrdy: - # use os.read for fully unbuffered behavior - err.append(os.read(proc.stderr.fileno(), 4096)) - - if proc.stdout in rdrdy: - # use os.read for fully unbuffered behavior - buf = os.read(proc.stdout.fileno(), 4096) - dest.write(buf) - - if not buf: - #EOF - break - - if broken: - break - err.append(proc.stderr.read()) - - eintr_retry(proc.wait)() - return ((None,''.join(err)), proc) - else: - raise AssertionError, "Unreachable code reached! :-Q" + if isinstance(source,list): + source = ' '.join(source) else: - # Parse destination as @: - if isinstance(dest, basestring) and ':' in dest: - remspec, path = dest.split(':',1) - elif isinstance(source, basestring) and ':' in source: - remspec, path = source.split(':',1) - else: - raise ValueError, "Both endpoints cannot be local" - user,host = remspec.rsplit('@',1) - - # plain scp - args = ['scp', '-q', '-p', '-C', - # Don't bother with localhost. Makes test easier - '-o', 'NoHostAuthenticationForLocalhost=yes', - # XXX: Possible security issue - # Avoid interactive requests to accept new host keys - '-o', 'StrictHostKeyChecking=no', - '-o', 'ConnectTimeout=30', - '-o', 'ConnectionAttempts=3', - '-o', 'ServerAliveInterval=30', - '-o', 'TCPKeepAlive=yes' ] - - if port: - args.append('-P%d' % port) - if recursive: - args.append('-r') - if identity_file: - args.extend(('-i', identity_file)) - - if isinstance(source,list): - args.extend(source) - else: - if openssh_has_persist(): - control_path = make_control_path(user, host, port) - args.extend([ - '-o', 'ControlMaster=no', - '-o', 'ControlPath=%s' % control_path ]) - args.append(source) + source = '"%s"' % source - args.append(dest) + raw_string += r''' %s ''' % source + raw_string += r''' %s ''' % dest - # connects to the remote host and starts a remote connection - proc = subprocess.Popen(args, - stdout = subprocess.PIPE, - stdin = subprocess.PIPE, - stderr = subprocess.PIPE) - - comm = proc.communicate() - eintr_retry(proc.wait)() - return (comm, proc) + # connects to the remote host and starts a remote connection + proc = subprocess.Popen(raw_string, + shell=True, + stdout = subprocess.PIPE, + stdin = subprocess.PIPE, + stderr = subprocess.PIPE) + + comm = proc.communicate() + eintr_retry(proc.wait)() + return (comm, proc) def rspawn(command, pidfile, stdout = '/dev/null', @@ -465,7 +307,8 @@ def rspawn(command, pidfile, stderr = '&1' else: stderr = ' ' + stderr - + + #XXX: ppid is always 1!!! daemon_command = '{ { %(command)s > %(stdout)s 2>%(stderr)s < %(stdin)s & } ; echo $! 1 > %(pidfile)s ; }' % { 'command' : command, 'pidfile' : pidfile, @@ -562,6 +405,7 @@ def rstatus(pid, ppid, One of NOT_STARTED, RUNNING, FINISHED """ + # XXX: ppid unused (out,err),proc = rexec( "ps --pid %(pid)d -o pid | grep -c %(pid)d ; true" % { 'ppid' : ppid, @@ -590,7 +434,7 @@ def rstatus(pid, ppid, @eintr_retry -def rkill(pid, ppid, +def rkill(pid, ppid, host = None, port = None, user = None, diff --git a/test/util/sshfuncs.py b/test/util/sshfuncs.py new file mode 100644 index 00000000..756c79a1 --- /dev/null +++ b/test/util/sshfuncs.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python + +from neco.util.sshfuncs import * + +import getpass +import unittest +import os +import subprocess +import re +import signal +import shutil +import socket +import subprocess +import tempfile +import time + +def find_bin(name, extra_path = None): + search = [] + if "PATH" in os.environ: + search += os.environ["PATH"].split(":") + for pref in ("/", "/usr/", "/usr/local/"): + for d in ("bin", "sbin"): + search.append(pref + d) + if extra_path: + search += extra_path + + for d in search: + try: + os.stat(d + "/" + name) + return d + "/" + name + except OSError, e: + if e.errno != os.errno.ENOENT: + raise + return None + +def find_bin_or_die(name, extra_path = None): + r = find_bin(name) + if not r: + raise RuntimeError(("Cannot find `%s' command, impossible to " + + "continue.") % name) + return r + +def gen_ssh_keypair(filename): + ssh_keygen = find_bin_or_die("ssh-keygen") + args = [ssh_keygen, '-q', '-N', '', '-f', filename] + assert subprocess.Popen(args).wait() == 0 + return filename, "%s.pub" % filename + +def add_key_to_agent(filename): + ssh_add = find_bin_or_die("ssh-add") + args = [ssh_add, filename] + null = file("/dev/null", "w") + assert subprocess.Popen(args, stderr = null).wait() == 0 + null.close() + +def get_free_port(): + s = socket.socket() + s.bind(("127.0.0.1", 0)) + port = s.getsockname()[1] + return port + +_SSH_CONF = """ListenAddress 127.0.0.1:%d +Protocol 2 +HostKey %s +UsePrivilegeSeparation no +PubkeyAuthentication yes +PasswordAuthentication no +AuthorizedKeysFile %s +UsePAM no +AllowAgentForwarding yes +PermitRootLogin yes +StrictModes no +PermitUserEnvironment yes +""" + +def gen_sshd_config(filename, port, server_key, auth_keys): + conf = open(filename, "w") + text = _SSH_CONF % (port, server_key, auth_keys) + conf.write(text) + conf.close() + return filename + +def gen_auth_keys(pubkey, output, environ): + #opts = ['from="127.0.0.1/32"'] # fails in stupid yans setup + opts = [] + for k, v in environ.items(): + opts.append('environment="%s=%s"' % (k, v)) + + lines = file(pubkey).readlines() + pubkey = lines[0].split()[0:2] + out = file(output, "w") + out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1])) + out.close() + return output + +def start_ssh_agent(): + ssh_agent = find_bin_or_die("ssh-agent") + proc = subprocess.Popen([ssh_agent], stdout = subprocess.PIPE) + (out, foo) = proc.communicate() + assert proc.returncode == 0 + d = {} + for l in out.split("\n"): + match = re.search("^(\w+)=([^ ;]+);.*", l) + if not match: + continue + k, v = match.groups() + os.environ[k] = v + d[k] = v + return d + +def stop_ssh_agent(data): + # No need to gather the pid, ssh-agent knows how to kill itself; after we + # had set up the environment + ssh_agent = find_bin_or_die("ssh-agent") + null = file("/dev/null", "w") + proc = subprocess.Popen([ssh_agent, "-k"], stdout = null) + null.close() + assert proc.wait() == 0 + for k in data: + del os.environ[k] + +class test_environment(object): + def __init__(self): + sshd = find_bin_or_die("sshd") + environ = {} + self.dir = tempfile.mkdtemp() + self.server_keypair = gen_ssh_keypair( + os.path.join(self.dir, "server_key")) + self.client_keypair = gen_ssh_keypair( + os.path.join(self.dir, "client_key")) + self.authorized_keys = gen_auth_keys(self.client_keypair[1], + os.path.join(self.dir, "authorized_keys"), environ) + self.port = get_free_port() + self.sshd_conf = gen_sshd_config( + os.path.join(self.dir, "sshd_config"), + self.port, self.server_keypair[0], self.authorized_keys) + + self.sshd = subprocess.Popen([sshd, '-q', '-D', '-f', self.sshd_conf]) + self.ssh_agent_vars = start_ssh_agent() + add_key_to_agent(self.client_keypair[0]) + + def __del__(self): + if self.sshd: + os.kill(self.sshd.pid, signal.SIGTERM) + self.sshd.wait() + if self.ssh_agent_vars: + stop_ssh_agent(self.ssh_agent_vars) + shutil.rmtree(self.dir) + +class SSHfuncsTestCase(unittest.TestCase): + def test_rexec(self): + env = test_environment() + user = getpass.getuser() + host = "localhost" + + command = "hostname" + + plocal = subprocess.Popen(command, stdout=subprocess.PIPE, + stdin=subprocess.PIPE) + outlocal, errlocal = plocal.communicate() + + (outremote, errrmote), premote = rexec(command, host, user, + port = env.port, agent = True) + + self.assertEquals(outlocal, outremote) + + def test_rcopy(self): + env = test_environment() + user = getpass.getuser() + host = "localhost" + + # create some temp files and directories to copy + dirpath = tempfile.mkdtemp() + f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False) + f.close() + + f1 = tempfile.NamedTemporaryFile(delete=False) + f1.close() + f1.name + + source = [dirpath, f1.name] + destdir = tempfile.mkdtemp() + dest = "%s@%s:%s" % (user, host, destdir) + rcopy(source, dest, port = env.port, agent = True) + + files = [] + def recls(files, dirname, names): + files.extend(names) + os.path.walk(destdir, recls, files) + + origfiles = map(lambda s: os.path.basename(s), [dirpath, f.name, f1.name]) + + self.assertEquals(sorted(origfiles), sorted(files)) + + def test_rproc_manage(self): + env = test_environment() + user = getpass.getuser() + host = "localhost" + command = "ping localhost" + + f = tempfile.NamedTemporaryFile(delete=False) + pidfile = f.name + + (out,err), proc = rspawn( + command, + pidfile, + host = host, + user = user, + port = env.port, + agent = True) + + time.sleep(2) + + (pid, ppid) = rcheck_pid(pidfile, + host = host, + user = user, + port = env.port, + agent = True) + + status = rstatus(pid, ppid, + host = host, + user = user, + port = env.port, + agent = True) + + self.assertEquals(status, RUNNING) + + rkill(pid, ppid, + host = host, + user = user, + port = env.port, + agent = True) + + status = rstatus(pid, ppid, + host = host, + user = user, + port = env.port, + agent = True) + + self.assertEquals(status, FINISHED) + + +if __name__ == '__main__': + unittest.main() +