Synchronization fixes:
[nepi.git] / src / nepi / testbeds / planetlab / node.py
index f5d159a..1e52a51 100644 (file)
@@ -7,6 +7,10 @@ import operator
 import rspawn
 import time
 import os
+import collections
+import cStringIO
+
+from nepi.util import server
 
 class Node(object):
     BASEFILTERS = {
@@ -27,6 +31,9 @@ class Node(object):
         'max_bandwidth' : ('bw%(timeframe)s', '[value'),
     }    
     
+    DEPENDS_PIDFILE = '/tmp/nepi-depends.pid'
+    DEPENDS_LOGFILE = '/tmp/nepi-depends.log'
+    
     def __init__(self, api=None):
         if not api:
             api = plcapi.PLCAPI()
@@ -47,17 +54,36 @@ class Node(object):
         self.max_num_external_ifaces = None
         self.timeframe = 'm'
         
-        # Applications add requirements to connected nodes
+        # Applications and routes add requirements to connected nodes
         self.required_packages = set()
+        self.required_vsys = set()
+        self.pythonpath = []
+        self.env = collections.defaultdict(list)
         
         # Testbed-derived attributes
         self.slicename = None
         self.ident_path = None
+        self.server_key = None
         self.home_path = None
         
         # Those are filled when an actual node is allocated
         self._node_id = None
     
+    @property
+    def _nepi_testbed_environment_setup(self):
+        command = cStringIO.StringIO()
+        command.write('export PYTHONPATH=$PYTHONPATH:%s' % (
+            ':'.join(["${HOME}/"+server.shell_escape(s) for s in self.pythonpath])
+        ))
+        command.write(' ; export PATH=$PATH:%s' % (
+            ':'.join(["${HOME}/"+server.shell_escape(s) for s in self.pythonpath])
+        ))
+        if self.env:
+            for envkey, envvals in self.env.iteritems():
+                for envval in envvals:
+                    command.write(' ; export %s=%s' % (envkey, envval))
+        return command.getvalue()
+    
     def build_filters(self, target_filters, filter_map):
         for attr, tag in filter_map.iteritems():
             value = getattr(self, attr, None)
@@ -105,6 +131,31 @@ class Node(object):
                 candidates &= set(map(operator.itemgetter('node_id'),
                     self._api.GetNodeTags(filters=tagfilter, fields=fields)))
         
+        # filter by vsys tags - special case since it doesn't follow
+        # the usual semantics
+        if self.required_vsys:
+            newcandidates = collections.defaultdict(set)
+            
+            vsys_tags = self._api.GetNodeTags(
+                tagname='vsys', 
+                node_id = list(candidates), 
+                fields = ['node_id','value'])
+            
+            vsys_tags = map(
+                operator.itemgetter(['node_id','value']),
+                vsys_tags)
+            
+            required_vsys = self.required_vsys
+            for node_id, value in vsys_tags:
+                if value in required_vsys:
+                    newcandidates[value].add(node_id)
+            
+            # take only those that have all the required vsys tags
+            newcandidates = reduce(
+                lambda accum, new : accum & new,
+                newcandidates.itervalues(),
+                candidates)
+        
         # filter by iface count
         if self.min_num_external_ifaces is not None or self.max_num_external_ifaces is not None:
             # fetch interfaces for all, in one go
@@ -133,7 +184,7 @@ class Node(object):
         self.fetch_node_info()
     
     def fetch_node_info(self):
-        info = self._api.GetNodes(self._node_id)
+        info = self._api.GetNodes(self._node_id)[0]
         tags = dict( (t['tagname'],t['value'])
                      for t in self._api.GetNodeTags(node_id=self._node_id, fields=('tagname','value')) )
 
@@ -158,6 +209,9 @@ class Node(object):
         if 'interface_ids' in info:
             self.min_num_external_ifaces = \
             self.max_num_external_ifaces = len(info['interface_ids'])
+        
+        if 'ssh_rsa_key' in info:
+            self.server_key = info['ssh_rsa_key']
 
     def validate(self):
         if self.home_path is None:
@@ -170,8 +224,8 @@ class Node(object):
     def install_dependencies(self):
         if self.required_packages:
             # TODO: make dependant on the experiment somehow...
-            pidfile = '/tmp/nepi-depends.pid'
-            logfile = '/tmp/nepi-depends.log'
+            pidfile = self.DEPENDS_PIDFILE
+            logfile = self.DEPENDS_LOGFILE
             
             # Start process in a "daemonized" way, using nohup and heavy
             # stdin/out redirection to avoid connection issues
@@ -188,14 +242,17 @@ class Node(object):
                 user = self.slicename,
                 agent = None,
                 ident_key = self.ident_path,
+                server_key = self.server_key,
                 sudo = True
                 )
             
             if proc.wait():
                 raise RuntimeError, "Failed to set up application: %s %s" % (out,err,)
     
-    def wait_dependencies(self, pidprobe=1, probe=10, pidmax=10):
+    def wait_dependencies(self, pidprobe=1, probe=0.5, pidmax=10, probemax=10):
         if self.required_packages:
+            pidfile = self.DEPENDS_PIDFILE
+            
             # get PID
             pid = ppid = None
             for probenum in xrange(pidmax):
@@ -205,7 +262,8 @@ class Node(object):
                     port = None,
                     user = self.slicename,
                     agent = None,
-                    ident_key = self.ident_path
+                    ident_key = self.ident_path,
+                    server_key = self.server_key
                     )
                 if pidtuple:
                     pid, ppid = pidtuple
@@ -222,9 +280,73 @@ class Node(object):
                     port = None,
                     user = self.slicename,
                     agent = None,
-                    ident_key = self.ident_path
+                    ident_key = self.ident_path,
+                    server_key = self.server_key
                     ):
                 time.sleep(probe)
+                probe = min(probemax, 1.5*probe)
         
+    def is_alive(self):
+        # Make sure all the paths are created where 
+        # they have to be created for deployment
+        (out,err),proc = server.popen_ssh_command(
+            "echo 'ALIVE'",
+            host = self.hostname,
+            port = None,
+            user = self.slicename,
+            agent = None,
+            ident_key = self.ident_path,
+            server_key = self.server_key
+            )
+        
+        if proc.wait():
+            return False
+        elif not err and out.strip() == 'ALIVE':
+            return True
+        else:
+            return False
+    
 
+    def configure_routes(self, routes, devs):
+        """
+        Add the specified routes to the node's routing table
+        """
+        rules = []
+        
+        for route in routes:
+            for dev in devs:
+                if dev.routes_here(route):
+                    # Schedule rule
+                    dest, prefix, nexthop = route
+                    rules.append(
+                        "add %s%s gw %s %s" % (
+                            dest,
+                            (("/%d" % (prefix,)) if prefix and prefix != 32 else ""),
+                            nexthop,
+                            dev.if_name,
+                        )
+                    )
+                    
+                    # Stop checking
+                    break
+            else:
+                raise RuntimeError, "Route %s cannot be bound to any virtual interface " \
+                    "- PL can only handle rules over virtual interfaces. Candidates are: %s" % (route,devs)
+        
+        (out,err),proc = server.popen_ssh_command(
+            "( sudo -S bash -c 'cat /vsys/vroute.out >&2' & ) ; sudo -S bash -c 'cat > /vsys/vroute.in'" % dict(
+                home = server.shell_escape(self.home_path)),
+            host = self.hostname,
+            port = None,
+            user = self.slicename,
+            agent = None,
+            ident_key = self.ident_path,
+            server_key = self.server_key,
+            stdin = '\n'.join(rules)
+            )
+        
+        if proc.wait() or err:
+            raise RuntimeError, "Could not set routes: %s%s" % (out,err)
+        
+