little changes
[nepi.git] / src / neco / resources / linux / ssh_api.py
index 93e895e..7f5009e 100644 (file)
@@ -1,9 +1,15 @@
+
+from neco.util.sshfuncs import eintr_retry, rexec, rcopy, rspawn, \
+        rcheckpid, rstatus, rkill, RUNNING, FINISHED 
+
 import hashlib
+import logging
 import os
 import re
+import tempfile
 
-class SSHAPI(object):
-    def __init__(self, host, user, identity, port, agent, forward_x11):
+class SSHApi(object):
+    def __init__(self, host, user, port, identity, agent, forward_x11):
         self.host = host
         self.user = user
         # ssh identity file
@@ -14,6 +20,10 @@ class SSHAPI(object):
         # forward X11 
         self.forward_x11 = forward_x11
 
+        self._pm = None
+        
+        self._logger = logging.getLogger("neco.linux.SSHApi")
+
     # TODO: Investigate using http://nixos.org/nix/
     @property
     def pm(self):
@@ -25,7 +35,7 @@ class SSHAPI(object):
             self._logger.error(msg)
             raise RuntimeError(msg)
 
-        out = self.execute("cat /etc/issue")
+        out, err = self.execute("cat /etc/issue")
 
         if out.find("Fedora") == 0:
             self._pm = "yum"
@@ -40,7 +50,7 @@ class SSHAPI(object):
 
     @property
     def is_localhost(self):
-        return ( self.host or self.ip ) in ['localhost', '127.0.0.7', '::1']
+        return self.host in ['localhost', '127.0.0.7', '::1']
 
     # TODO: Investigate using http://nixos.org/nix/
     def install(self, packages):
@@ -67,20 +77,27 @@ class SSHAPI(object):
 
            dst  destination path on the remote host (remote is always self.host)
         """
-        # If src is a string input 
-        if not os.path.isfile(src) and not isdir:
-            # src is text input that should be uploaded as file           
-            src = cStringIO.StringIO(src)
+        # If source is a string input 
+        if not os.path.isfile(src):
+            # src is text input that should be uploaded as file
+            # create a temporal file with the content to upload
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(src)
+            f.close()
+            src = f.name
 
         if not self.is_localhost:
             # Build destination as <user>@<server>:<path>
             dst = "%s@%s:%s" % (self.user, self.host, dst)
-        return self.copy(src, dst)
+
+        ret = self.copy(src, dst)
+
+        return ret
 
     def download(self, src, dst):
         if not self.is_localhost:
             # Build destination as <user>@<server>:<path>
-            src = "%s@%s:%s" % (self.user, self.host or self.ip, src)
+            src = "%s@%s:%s" % (self.user, self.host, src)
         return self.copy(src, dst)
         
     def is_alive(self, verbose = False):
@@ -88,7 +105,7 @@ class SSHAPI(object):
             return True
 
         try:
-            out = self.execute("echo 'ALIVE'",
+            (out, err) = self.execute("echo 'ALIVE'",
                 timeout = 60,
                 err_on_timeout = False,
                 persistent = False)
@@ -132,11 +149,11 @@ class SSHAPI(object):
                 src, dst, 
                 port = self.port,
                 agent = self.agent,
-                identity_file = self.identity_file)
+                identity = self.identity)
 
             if proc.wait():
                 msg = "Error uploading to %s got:\n%s%s" %\
-                        (self.host or self.ip, out, err)
+                        (self.host, out, err)
                 self._logger.error(msg)
                 raise RuntimeError(msg)
 
@@ -172,15 +189,15 @@ class SSHAPI(object):
         else:
             (out, err), proc = eintr_retry(rexec)(
                     command, 
-                    self.host or self.ip
+                    self.host, 
                     self.user,
                     port = self.port, 
-                    agent = self.forward_agent,
+                    agent = self.agent,
                     sudo = sudo,
                     stdin = stdin, 
-                    identity_file = self.identity_file,
+                    identity = self.identity,
                     tty = tty,
-                    x11 = self.enable_x11,
+                    x11 = self.forward_x11,
                     env = env,
                     timeout = timeout,
                     retry = retry,
@@ -190,10 +207,9 @@ class SSHAPI(object):
 
             if proc.wait():
                 msg = "Failed to execute command %s at node %s: %s %s" % \
-                        (command, self.host or self.ip, out, err,)
+                        (command, self.host, out, err,)
                 self._logger.warn(msg)
                 raise RuntimeError(msg)
-
         return (out, err)
 
     def run(self, command, home, 
@@ -203,7 +219,7 @@ class SSHAPI(object):
             sudo = False):
         self._logger.info("Running %s", command)
         
-        pidfile = './pid',
+        pidfile = './pid'
 
         if self.is_localhost:
             if stderr == stdout:
@@ -247,7 +263,7 @@ class SSHAPI(object):
                 user = self.user,
                 port = self.port,
                 agent = self.agent,
-                identity_file = self.file
+                identity = self.identity
                 )
             
             if proc.wait():
@@ -258,13 +274,13 @@ class SSHAPI(object):
     def checkpid(self, path):            
         # Get PID/PPID
         # NOTE: wait a bit for the pidfile to be created
-        pidtuple = rcheck_pid(
+        pidtuple = rcheckpid(
             os.path.join(path, 'pid'),
             host = self.host,
             user = self.user,
             port = self.port,
             agent = self.agent,
-            identity_file = self.identity
+            identity = self.identity
             )
         
         return pidtuple
@@ -276,7 +292,7 @@ class SSHAPI(object):
                 user = self.user,
                 port = self.port,
                 agent = self.agent,
-                identity_file = self.identity
+                identity = self.identity
                 )
            
         return status
@@ -292,27 +308,22 @@ class SSHAPI(object):
                 port = self.port,
                 agent = self.agent,
                 sudo = sudo,
-                identity_file = self.identity
+                identity = self.identity
                 )
 
-class SSHAPIFactory(object):
+class SSHApiFactory(object):
     _apis = dict()
 
     @classmethod 
-    def get_api(cls, attributes):
-        host = attributes.get("hostname")
-        user = attributes.get("username")
-        identity = attributes.get("identity", "%s/.ssh/id_rsa" % os.environ['HOME'])
-        port = attributes.get("port", 22)
-        agent = attributes.get("agent", True)
-        forward_X11 = attributes.get("forwardX11", False)
-
-        key = cls.make_key(host, user, identity, port, agent, forward_X11)
-        api = self._apis.get(key)
-
-        if no api:
-            api = SSHAPI(host, user, identity, port, agent, forward_X11)
-            self._apis[key] = api
+    def get_api(cls, host, user, port = 22, identity = None, 
+            agent = True, forward_X11 = False):
+        key = cls.make_key(host, user, port, agent, forward_X11)
+        api = cls._apis.get(key)
+
+        if not api:
+            api = SSHApi(host, user, port, identity, agent, forward_X11)
+            cls._apis[key] = api
 
         return api