First version. Most definitely a work in progress.
author David E. Eisenstat <deisenst@cs.princeton.edu>
Fri, 22 Sep 2006 22:25:54 +0000 (22:25 +0000)
committer David E. Eisenstat <deisenst@cs.princeton.edu>
Fri, 22 Sep 2006 22:25:54 +0000 (22:25 +0000)
15 files changed:
Makefile [new file with mode: 0644]
README.txt [new file with mode: 0644]
accounts.py [new file with mode: 0644]
api.py [new file with mode: 0644]
bwcap.py [new file with mode: 0644]
config.py [new file with mode: 0644]
database.py [new file with mode: 0644]
delegate.py [new file with mode: 0644]
forward_api_calls.c [new file with mode: 0644]
logger.py [new file with mode: 0644]
nm.py [new file with mode: 0644]
plc.py [new file with mode: 0644]
sliver.py [new file with mode: 0644]
ticket.py [new file with mode: 0644]
tools.py [new file with mode: 0644]

diff --git a/Makefile b/Makefile
new file mode 100644 (file)
index 0000000..5ee8c18
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,6 @@
+forward_api_calls: forward_api_calls.c
+       $(CC) -Wall -Os -o $@ $?
+       strip $@
+clean:
+       rm -f forward_api_calls
+.PHONY: clean
diff --git a/README.txt b/README.txt
new file mode 100644 (file)
index 0000000..f532a85
--- /dev/null
+++ b/README.txt
@@ -0,0 +1,109 @@
+THE NEW NODE MANAGER
+====================
+
+This is a very preliminary version of the new node manager.  Currently
+it is set up to download slices.xml; however, not all of the
+implemented functionality is accessible via slices.xml.
+
+FILES
+=====
+
+accounts.py - Account management functionality shared between
+delegate accounts and VServers.
+
+api.py - XMLRPC interface to Node Manager functionality.  Runs on port
+812 and supports a Help() call that returns documentation for all calls.
+
+bwcap.py - Sets the bandwidth cap via the bwlimit module.  The bwlimit
+calls are commented out because they've been giving me a bunch of
+errors.
+
+config.py - Configuration parameters.  You'll probably want to change
+SA_HOSTNAME to the PLC address.
+
+database.py - The dreaded NM database.  The main class defined is a
+dict subclass, which both indexes and stores various records.  These
+records include the sliver/delegate records, as well as the timestamp,
+node bw cap, and any other crap PLC wants to put there.
+
+delegate.py - Create and delete delegate accounts.  These accounts
+have low space overhead (unlike a VServer) and serve to authenticate
+remote NM users.
+
+forward_api_calls.c - The forward_api_calls program proxies stdin to
+the Unix domain socket /tmp/node_mgr_api, letting Node Manager take
+advantage of ssh authentication.  It is intended for use as a shell on
+a special delegate account.
+
+logger.py - This is a very basic logger.
+
+Makefile - For compiling forward_api_calls.
+
+nm.py - The main program.
+
+plc.py - Downloads and parses slices.xml, reads the node id file.
+
+README.txt - Duh.
+
+sliver.py - Handles all VServer functionality.
+
+ticket.py - Not used at the moment; contains a demonstration of
+xmlsec1.
+
+tools.py - Various convenience functions for functionality provided by
+Linux.
+
+RUNNING
+=======
+
+Change SA_HOSTNAME in config.py and run nm.py.  No bootstrapping
+required.
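+
+To sanity-check a running node manager, you can speak XMLRPC to the
+Unix domain socket by hand, the same way forward_api_calls does.  A
+sketch (it assumes the default socket path and that you run it as root
+on the node):
+
+    import socket, xmlrpclib
+    request = xmlrpclib.dumps((), 'Help')
+    s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+    s.connect('/tmp/node_mgr_api')
+    s.sendall('POST / HTTP/1.0\r\nContent-Type: text/xml\r\n'
+              'Content-Length: %d\r\n\r\n%s' % (len(request), request))
+    s.shutdown(socket.SHUT_WR)
+    print s.makefile().read()  # raw HTTP response containing the Help() text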
+
+INTERNALS
+=========
+
+At the moment, the main thread loops forever, fetching slices.xml and
+updating the database.  Other threads handle incoming API connections
+(each connection gets its own thread) and database dumping.  There is
+also one thread per account, which supervises creation, deletion, and
+resource initialization for that account.  The other threads request
+operations of these account threads by means of a queue.
+
+Other than the queues, the threads synchronize by acquiring a global
+database lock before reading/writing the database.  The database
+itself is a collection of records, which are just Python dicts with
+certain required fields.  The most important of these fields are
+'timestamp', 'expiry', and 'record_key'.  'record_key' serves to
+uniquely identify a particular record; the only naming convention
+followed is that account records have record_key <account
+type>_<account name>.  Thus the sliver princeton_sirius has record_key
+'sliver_princeton_sirius'.
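+
+For instance, a sliver record as assembled by plc.py looks roughly
+like this (the values are only illustrative):
+
+    {'record_key': 'sliver_princeton_sirius',
+     'account_type': 'sliver', 'name': 'princeton_sirius',
+     'timestamp': 1158962754.0, 'expiry': 1161554754,
+     'delegations': [('del_snoop', 'GetRSpec')],
+     'plc_instantiated': True,
+     'rspec': {'nm_cpu_share': 32, 'nm_disk_quota': 5000000, ...},
+     'ssh_keys': '...', 'initscript': '...'}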
+
+The two main features that will not be familiar from the old node
+manager are delegates and loans.  Delegates, as described above, are
+lightweight accounts whose sole purpose is to proxy NM API calls from
+outside.  The current code makes a delegate account 'del_snoop' that's
+allowed to spy on everyone's RSpec; you'll need to change the key in
+plc.py in order to use it.  Loans are resource transfers from one sliver
+to another; the format for loans is a list of triples: recipient
+sliver, resource type, amount.  Thus for princeton_sirius to give 20%
+guaranteed CPU to princeton_eisentest, it would call
+
+api.SetLoans('princeton_sirius',
+             [['princeton_eisentest', 'nm_cpu_guaranteed_share', 200]])
+
+provided, of course, that it has 200 guaranteed shares :)
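+
+(You can observe the effect through the API: after the call above,
+GetEffectiveRSpec('princeton_eisentest') should report 200 more
+nm_cpu_guaranteed_share than GetRSpec('princeton_eisentest') does.)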
+
+POSTSCRIPT
+==========
+
+The log file will be of great use when attempting to run and debug the
+node manager; it lives at /var/log/pl_node_mgr.log.  If you break the
+DB, you should delete the pickled copy, which lives at
+<config.py:DB_FILE>.
+
+I have been refactoring the code constantly in an attempt to keep the
+amount of glue to a minimum; unfortunately comments quickly grow stale
+in such an environment, and I have not yet made any attempt to comment
+reasonably.  Until such time as I do, I'm on the hook for limited
+support of this thing.  Please feel free to contact me at
+deisenst@cs.princeton.edu.
diff --git a/accounts.py b/accounts.py
new file mode 100644 (file)
index 0000000..b22a4ba
--- /dev/null
+++ b/accounts.py
@@ -0,0 +1,109 @@
+import Queue
+import os
+import pwd
+import threading
+
+import logger
+import tools
+
+
+_name_worker_lock = threading.Lock()
+_name_worker = {}
+
+def all():
+    pw_ents = pwd.getpwall()
+    for pw_ent in pw_ents:
+        if pw_ent[6] in acct_class_by_shell:
+            yield acct_class_by_shell[pw_ent[6]].TYPE, pw_ent[0]
+
+def get(name):
+    _name_worker_lock.acquire()
+    try:
+        if name not in _name_worker: _name_worker[name] = Worker(name)
+        return _name_worker[name]
+    finally: _name_worker_lock.release()
+
+
+def install_ssh_keys(rec):
+    """Write <rec['ssh_keys']> to <rec['name']>'s authorized_keys file."""
+    dot_ssh = '/home/%s/.ssh' % rec['name']
+    def do_installation():
+        if not os.access(dot_ssh, os.F_OK): os.mkdir(dot_ssh)
+        tools.write_file(dot_ssh + '/authorized_keys',
+                         lambda thefile: thefile.write(rec['ssh_keys']))
+    logger.log('%s: installing ssh keys' % rec['name'])
+    tools.fork_as(rec['name'], do_installation)
+
+
+TYPES = []
+acct_class_by_shell = {}
+acct_class_by_type = {}
+
+def register_account_type(acct_class):
+    TYPES.append(acct_class.TYPE)
+    acct_class_by_shell[acct_class.SHELL] = acct_class
+    acct_class_by_type[acct_class.TYPE] = acct_class
+
+
+class Worker:
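+    """One Worker owns all operations on one account.
+
+    Other threads enqueue operations via the public methods; _run(),
+    executing in a daemon thread, applies them serially."""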
+    # these semaphores are acquired before creating/destroying an account
+    _create_sem = threading.Semaphore(1)
+    _destroy_sem = threading.Semaphore(1)
+
+    def __init__(self, name):
+        self.name = name
+        self._acct = None
+        self._q = Queue.Queue()
+        tools.as_daemon_thread(self._run)
+
+    def ensure_created(self, rec):
+        self._q.put((self._ensure_created, tools.deepcopy(rec)))
+
+    def _ensure_created(self, rec):
+        curr_class = self._get_class()
+        next_class = acct_class_by_type[rec['account_type']]
+        if next_class != curr_class:
+            self._destroy(curr_class)
+            self._create_sem.acquire()
+            try: next_class.create(self.name)
+            finally: self._create_sem.release()
+        self._make_acct_obj()
+        self._acct.configure(rec)
+        if next_class != curr_class: self._acct.start()
+
+    def ensure_destroyed(self): self._q.put((self._ensure_destroyed,))
+    def _ensure_destroyed(self): self._destroy(self._get_class())
+
+    def start(self): self._q.put((self._start,))
+    def _start(self):
+        self._make_acct_obj()
+        self._acct.start()
+
+    def stop(self): self._q.put((self._stop,))
+    def _stop(self):
+        self._make_acct_obj()
+        self._acct.stop()
+
+    def _destroy(self, curr_class):
+        self._acct = None
+        if curr_class:
+            self._destroy_sem.acquire()
+            try: curr_class.destroy(self.name)
+            finally: self._destroy_sem.release()
+
+    def _get_class(self):
+        try: shell = pwd.getpwnam(self.name)[6]
+        except KeyError: return None
+        return acct_class_by_shell[shell]
+
+    def _make_acct_obj(self):
+        curr_class = self._get_class()
+        if not isinstance(self._acct, curr_class):
+            self._acct = curr_class(self.name)
+
+    def _run(self):
+        while True:
+            try:
+                cmd = self._q.get()
+                cmd[0](*cmd[1:])
+            except: logger.log_exc()
diff --git a/api.py b/api.py
new file mode 100644 (file)
index 0000000..73f38e7
--- /dev/null
+++ b/api.py
@@ -0,0 +1,156 @@
+import SimpleXMLRPCServer
+import SocketServer
+import cPickle
+import errno
+import os
+import pwd
+import socket
+import struct
+import threading
+import xmlrpclib
+
+from config import *
+import accounts
+import database
+import logger
+import tools
+
+
+api_method_dict = {}
+nargs_dict = {}
+
+def export_to_api(nargs):
+    def export(method):
+        nargs_dict[method.__name__] = nargs
+        api_method_dict[method.__name__] = method
+        return method
+    return export
+
+@export_to_api(0)
+def DumpDatabase():
+    """DumpDatabase(): return the entire node manager DB, pickled"""
+    return cPickle.dumps(dict(database._db), 0)
+
+@export_to_api(0)
+def Help():
+    """Help(): get help"""
+    return ''.join([method.__doc__ + '\n' for method in api_method_dict.itervalues()])
+
+@export_to_api(1)
+def CreateSliver(rec):
+    """CreateSliver(sliver_name): set up a non-PLC-instantiated sliver"""
+    if not rec['plc_instantiated']:
+        accounts.get(rec['name']).ensure_created(rec)
+
+@export_to_api(1)
+def DeleteSliver(rec):
+    """DeleteSliver(sliver_name): tear down a non-PLC-instantiated sliver"""
+    if not rec['plc_instantiated']:
+        accounts.get(rec['name']).ensure_destroyed()
+
+@export_to_api(1)
+def Start(rec):
+    """Start(sliver_name): run start scripts belonging to the specified sliver"""
+    accounts.get(rec['name']).start()
+
+@export_to_api(1)
+def Stop(rec):
+    """Stop(sliver_name): kill all processes belonging to the specified sliver"""
+    accounts.get(rec['name']).stop()
+
+@export_to_api(1)
+def GetEffectiveRSpec(rec):
+    """GetEffectiveRSpec(sliver_name): return the RSpec allocated to the specified sliver, including loans"""
+    return tools.deepcopy(rec.get('eff_rspec', {}))
+
+@export_to_api(1)
+def GetRSpec(rec):
+    """GetRSpec(sliver_name): return the RSpec allocated to the specified sliver, excluding loans"""
+    return tools.deepcopy(rec.get('rspec', {}))
+
+@export_to_api(1)
+def GetLoans(rec):
+    """GetLoans(sliver_name): return the list of loans made by the specified sliver"""
+    return tools.deepcopy(rec.get('loans', []))
+
+def validate_loans(obj):
+    """Check that <obj> is a valid loan specification."""
+    def validate_loan(obj):
+        return (type(obj)==list or type(obj)==tuple) and len(obj)==3 and \
+               type(obj[0])==str and \
+               type(obj[1])==str and obj[1] in LOANABLE_RESOURCES and \
+               type(obj[2])==int and obj[2]>0
+    return type(obj)==list and False not in map(validate_loan, obj)
+
+@export_to_api(2)
+def SetLoans(rec, loans):
+    """SetLoans(sliver_name, loans): overwrite the list of loans made by the specified sliver"""
+    if not validate_loans(loans):
+        raise xmlrpclib.Fault(102, 'Invalid argument: the second argument must be a well-formed loan specification')
+    rec['loans'] = loans
+    database.deliver_records([rec])
+
+api_method_list = api_method_dict.keys()
+api_method_list.sort()
+
+
+class APIRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
+    # overriding _dispatch to achieve this effect is officially deprecated,
+    # but I can't figure out how to get access to .request
+    # without duplicating SimpleXMLRPCServer code here,
+    # which is more likely to change than the deprecated behavior
+    # is to be broken
+
+    @database.synchronized
+    def _dispatch(self, method_name, args):
+        method_name = str(method_name)
+        try: method = api_method_dict[method_name]
+        except KeyError:
+            raise xmlrpclib.Fault(100, 'Invalid API method %s.  Valid choices are %s' % (method_name, ', '.join(api_method_list)))
+
+        expected_nargs = nargs_dict[method_name]
+        if len(args) != expected_nargs:
+            raise xmlrpclib.Fault(101, 'Invalid argument count: got %d, expecting %d.' % (len(args), expected_nargs))
+        else:
+            # XXX - these ought to be imported directly from some .h file
+            SO_PEERCRED = 17
+            sizeof_struct_ucred = 12
+            ucred = self.request.getsockopt(socket.SOL_SOCKET, SO_PEERCRED,
+                                            sizeof_struct_ucred)
+            xid = struct.unpack('3i', ucred)[2]
+            caller_name = pwd.getpwuid(xid)[0]
+
+            if expected_nargs >= 1:
+                target_name = args[0]
+                target_rec = database.get_sliver(target_name)
+                if not target_rec: raise xmlrpclib.Fault(102, 'Invalid argument: the first argument must be a sliver name.')
+
+                if caller_name not in (args[0], 'root') and \
+                       (caller_name, method_name) not in target_rec['delegations']:
+                    raise xmlrpclib.Fault(108, 'Permission denied.')
+                result = method(target_rec, *args[1:])
+            else:
+                if method_name == 'DumpDatabase' and caller_name != 'root':
+                    raise xmlrpclib.Fault(108, 'Permission denied.')
+                result = method()
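+            # XMLRPC cannot marshal None, so replace a None result with 1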
+            if result == None: result = 1
+            return result
+
+class APIServer_INET(SocketServer.ThreadingMixIn,
+                     SimpleXMLRPCServer.SimpleXMLRPCServer):
+    allow_reuse_address = True
+
+class APIServer_UNIX(APIServer_INET): address_family = socket.AF_UNIX
+
+def start():
+    serv1 = APIServer_INET(('127.0.0.1', API_SERVER_PORT),
+                           requestHandler=APIRequestHandler, logRequests=0)
+    tools.as_daemon_thread(serv1.serve_forever)
+    unix_addr = '/tmp/node_mgr_api'
+    try: os.unlink(unix_addr)
+    except OSError, e:
+        if e.errno != errno.ENOENT: raise
+    serv2 = APIServer_UNIX(unix_addr,
+                           requestHandler=APIRequestHandler, logRequests=0)
+    tools.as_daemon_thread(serv2.serve_forever)
+    os.chmod(unix_addr, 0666)
diff --git a/bwcap.py b/bwcap.py
new file mode 100644 (file)
index 0000000..3d95db0
--- /dev/null
+++ b/bwcap.py
@@ -0,0 +1,18 @@
+import bwlimit
+
+import logger
+import tools
+
+
+_old_rec = {}
+
+def update(rec):
+    global _old_rec
+    if rec != _old_rec:
+        if rec['cap'] != _old_rec.get('cap'):
+            logger.log('setting node bw cap to %d' % rec['cap'])
+#             bwlimit.init('eth0', rec['cap'])
+        if rec['exempt_ips'] != _old_rec.get('exempt_ips'):
+            logger.log('initializing exempt ips to %s' % rec['exempt_ips'])
+#             bwlimit.exempt_init('Internet2', rec['exempt_ips'])
+        _old_rec = tools.deepcopy(rec)
diff --git a/config.py b/config.py
new file mode 100644 (file)
index 0000000..3b41326
--- /dev/null
+++ b/config.py
@@ -0,0 +1,33 @@
+"""Global parameters and configuration."""
+
+try:
+    from bwlimit import bwmin, bwmax
+
+    DEFAULT_RSPEC = {'nm_cpu_share': 32, 'nm_cpu_guaranteed_share': 0,
+                     'nm_disk_quota': 5000000,
+                     'nm_enabled': 1,
+                     'nm_net_min_rate': bwmin, 'nm_net_max_rate': bwmax,
+                     'nm_net_exempt_min_rate': bwmin,
+                     'nm_net_exempt_max_rate': bwmax,
+                     'nm_net_share': 1}
+except ImportError: pass
+
+API_SERVER_PORT = 812
+
+DB_FILE = '/root/pl_node_mgr_db.pickle'
+
+KEY_FILE = '/home/deisenst/nm/key.pem'
+
+LOANABLE_RESOURCES = set(['nm_cpu_share', 'nm_cpu_guaranteed_share',
+                          'nm_net_max_rate', 'nm_net_exempt_max_rate',
+                          'nm_net_share'])
+
+LOG_FILE = '/var/log/pl_node_mgr.log'
+
+PID_FILE = '/var/run/pl_node_mgr.pid'
+
+SA_HOSTNAME = 'plc-a.demo.vmware'
+
+START_DELAY_SECS = 10
+
+TICKET_SERVER_PORT = 1813
diff --git a/database.py b/database.py
new file mode 100644 (file)
index 0000000..bc1155e
--- /dev/null
+++ b/database.py
@@ -0,0 +1,133 @@
+import cPickle
+import sys
+import threading
+import time
+
+from config import DB_FILE
+import accounts
+import bwcap
+import logger
+import tools
+
+
+_db_lock = threading.RLock()
+_db_cond = threading.Condition(_db_lock)
+_dump_requested = False
+
+
+def synchronized(function):
+    def sync_fun(*args, **kw_args):
+        _db_lock.acquire()
+        try: return function(*args, **kw_args)
+        finally: _db_lock.release()
+    sync_fun.__doc__ = function.__doc__
+    sync_fun.__name__ = function.__name__
+    return sync_fun
+
+
+class Database(dict):
+    def deliver_records(self, recs):
+        ts = self.get_timestamp()
+        for rec in recs:
+            old_rec = self.setdefault(rec['record_key'], {})
+            if rec['timestamp'] >= max(ts, old_rec.get('timestamp', 0)):
+                old_rec.update(rec, dirty=True)
+        self.compute_effective_rspecs()
+        if self.get_timestamp() > ts:
+            self.delete_old_records()
+            self.delete_old_accounts()
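+            # new snapshot: mark everything dirty so accounts get reconfigured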
+            for rec in self.itervalues(): rec['dirty'] = True
+        self.create_new_accounts()
+        self.update_bwcap()
+
+    def get_timestamp(self):
+        return self.get('timestamp', {'timestamp': 0})['timestamp']
+
+
+    def compute_effective_rspecs(self):
+        """Apply loans to field 'rspec' to get field 'eff_rspec'."""
+        slivers = dict([(rec['name'], rec) for rec in self.itervalues() \
+                        if rec.get('account_type') == 'sliver'])
+
+        # Pass 1: copy 'rspec' to 'eff_rspec', saving the old value
+        for sliver in slivers.itervalues():
+            sliver['old_eff_rspec'] = sliver.get('eff_rspec')
+            sliver['eff_rspec'] = sliver['rspec'].copy()
+
+        # Pass 2: apply loans
+        for sliver in slivers.itervalues():
+            remaining_loanable_amount = sliver['rspec'].copy()
+            for other_name, resource, amount in sliver.get('loans', []):
+                if other_name in slivers and \
+                       0 < amount <= remaining_loanable_amount[resource]:
+                    sliver['eff_rspec'][resource] -= amount
+                    remaining_loanable_amount[resource] -= amount
+                    slivers[other_name]['eff_rspec'][resource] += amount
+
+        # Pass 3: mark changed rspecs dirty
+        for sliver in slivers.itervalues():
+            if sliver['eff_rspec'] != sliver['old_eff_rspec']:
+                sliver['needs_update'] = True
+            del sliver['old_eff_rspec']
+
+
+    def delete_old_records(self):
+        ts = self.get_timestamp()
+        now = time.time()
+        for key in self.keys():
+            rec = self[key]
+            if rec['timestamp'] < ts or rec.get('expiry', sys.maxint) < now:
+                del self[key]
+
+    def delete_old_accounts(self):
+        for acct_type, name in accounts.all():
+            if ('%s_%s' % (acct_type, name)) not in self:
+                accounts.get(name).ensure_destroyed()
+
+    def create_new_accounts(self):
+        """Invoke the appropriate create() function for every dirty account."""
+        for rec in self.itervalues():
+            if 'account_type' not in rec: continue
+            if rec['dirty'] and rec['plc_instantiated']:
+                accounts.get(rec['name']).ensure_created(rec)
+            rec['dirty'] = False
+
+    def update_bwcap(self):
+        bwcap_rec = self.get('bwcap')
+        if bwcap_rec and bwcap_rec['dirty']:
+            bwcap.update(bwcap_rec)
+            bwcap_rec['dirty'] = False
+
+
+_db = Database()
+
+@synchronized
+def deliver_records(recs):
+    global _dump_requested
+    _db.deliver_records(recs)
+    _dump_requested = True
+    _db_cond.notify()
+
+@synchronized
+def get_sliver(name): return _db.get('sliver_'+name)
+
+def start():
+    def run():
+        global _dump_requested
+        _db_lock.acquire()
+        try:  # load the db
+            f = open(DB_FILE)
+            _db.update(cPickle.load(f))
+            f.close()
+        except: logger.log_exc()
+        while True:  # handle dump requests forever
+            while not _dump_requested:
+                _db_cond.wait()
+            db_copy = tools.deepcopy(_db)
+            _dump_requested = False
+            _db_lock.release()
+            try: tools.write_file(DB_FILE,
+                                  lambda f: cPickle.dump(db_copy, f, -1))
+            except: logger.log_exc()
+            _db_lock.acquire()
+    tools.as_daemon_thread(run)
diff --git a/delegate.py b/delegate.py
new file mode 100644 (file)
index 0000000..6dd85e8
--- /dev/null
+++ b/delegate.py
@@ -0,0 +1,34 @@
+import accounts
+import logger
+import tools
+
+
+class Delegate:
+    SHELL = '/bin/forward_api_calls'
+    TYPE = 'delegate'
+
+    def __init__(self, name): self.name = name
+
+    @staticmethod
+    def create(name):
+        add_shell(Delegate.SHELL)
+        logger.log_call('/usr/sbin/useradd',
+                        '-p', '*', '-s', Delegate.SHELL, name)
+
+    @staticmethod
+    def destroy(name): logger.log_call('/usr/sbin/userdel', '-r', name)
+
+    def configure(self, rec): accounts.install_ssh_keys(rec)
+    def start(self): pass
+    def stop(self): pass
+
+
+def add_shell(shell):
+    """Add <shell> to /etc/shells if it's not already there."""
+    etc_shells = open('/etc/shells')
+    valid_shells = etc_shells.read().split()
+    etc_shells.close()
+    if shell not in valid_shells:
+        etc_shells = open('/etc/shells', 'a')
+        print >>etc_shells, shell
+        etc_shells.close()
diff --git a/forward_api_calls.c b/forward_api_calls.c
new file mode 100644 (file)
index 0000000..0d4536b
--- /dev/null
+++ b/forward_api_calls.c
@@ -0,0 +1,124 @@
+/* forward_api_calls.c: forward XMLRPC calls to the Node Manager
+ * Used as a shell, this code works in tandem with sshd
+ * to allow authenticated remote access to a localhost-only service.
+ *
+ * Bugs:
+ * Doesn't handle Unicode properly.  UTF-8 is probably OK.
+ *
+ * Change History:
+ * 2006/09/14: [deisenst] Switched to PF_UNIX sockets so that SO_PEERCRED works
+ * 2006/09/08: [deisenst] First version.
+ */
+
+static const int TIMEOUT_SECS = 30;
+const char *API_addr = "/tmp/node_mgr_api";
+
+static const char *Header =
+  "POST / HTTP/1.0\r\n"
+  "Content-Type: text/xml\r\n"
+  "Content-Length: %d\r\n"
+  "\r\n%n";
+
+static const char *Error_template =
+  "<?xml version='1.0'?>\n"
+  "<methodResponse>\n"
+  "<fault>\n"
+  "<value><struct>\n"
+  "<member>\n"
+  "<name>faultCode</name>\n"
+  "<value><int>1</int></value>\n"
+  "</member>\n"
+  "<member>\n"
+  "<name>faultString</name>\n"
+  "<value><string>%s: %s</string></value>\n"
+  "</member>\n"
+  "</struct></value>\n"
+  "</fault>\n"
+  "</methodResponse>\n";
+
+#include <ctype.h>
+#include <errno.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+static void ERROR(const char *s) {
+  printf(Error_template, s, strerror(errno));
+  exit(1);
+}
+
+int main(int argc, char **argv, char **envp) {
+  ssize_t len;
+  char header_buf[4096];
+  char content_buf[4096];
+  size_t content_len;
+  int sockfd;
+  int header_len;  /* filled in via the %n in Header */
+  struct sockaddr_un addr;
+  int consecutive_newlines;
+
+  alarm(TIMEOUT_SECS);
+
+  /* read xmlrpc request from stdin
+   * 4 KiB ought to be enough for anyone
+   */
+  content_len = 0;
+  while(content_len < sizeof content_buf) {
+    len = read(0,
+              content_buf + content_len,
+              sizeof content_buf - content_len);
+    if(len < 0) ERROR("read()");
+    else if(0 == len) break;
+    content_len += len;
+  }
+
+  /* connect to the API server */
+  sockfd = socket(PF_UNIX, SOCK_STREAM, 0);
+  if(sockfd < 0)
+    ERROR("socket()");
+  memset(&addr, 0, sizeof addr);
+  addr.sun_family = AF_UNIX;
+  strncpy(addr.sun_path, API_addr, sizeof addr.sun_path);
+  if(connect(sockfd, (struct sockaddr *)&addr, sizeof addr) < 0)
+    ERROR("connect()");
+
+  /* send the request */
+  snprintf(header_buf, sizeof header_buf, Header,
+          (int)content_len, &header_len);
+  write(sockfd, header_buf, header_len);
+  write(sockfd, content_buf, content_len);
+  shutdown(sockfd, SHUT_WR);
+
+  /* forward the response */
+  consecutive_newlines = 0;
+  while((len = read(sockfd, content_buf, sizeof content_buf)) != 0) {
+    size_t processed_len = 0;
+    if(len < 0) {
+      /* "Connection reset by peer" is not worth bothering the user. */
+      if(ECONNRESET == errno) break;
+      else ERROR("read()");
+    }
+    content_len = len;
+
+    while(consecutive_newlines < 2 && processed_len < content_len) {
+      char ch = content_buf[processed_len++];
+      if(ch == '\n') consecutive_newlines++;
+      else if(!isspace(ch)) consecutive_newlines = 0;
+    }
+
+    if(processed_len < content_len) {
+      len = fwrite(content_buf + processed_len, sizeof (char),
+                  content_len - processed_len, stdout);
+      /* make sure faults don't mess up previously sent xml */
+      if(len < content_len - processed_len) ERROR("fwrite()");
+    }
+  }
+
+  /* goodbye */
+  shutdown(sockfd, SHUT_RD);
+  close(sockfd);
+
+  return 0;
+}
diff --git a/logger.py b/logger.py
new file mode 100644 (file)
index 0000000..3411df5
--- /dev/null
+++ b/logger.py
@@ -0,0 +1,27 @@
+import fcntl
+import os
+import subprocess
+import time
+import traceback
+
+from config import LOG_FILE
+
+
+def log(msg):
+    """Write <msg> to the log file."""
+    # the next three lines ought to be an atomic operation but aren't
+    fd = os.open(LOG_FILE, os.O_WRONLY | os.O_CREAT | os.O_APPEND, 0600)
+    flags = fcntl.fcntl(fd, fcntl.F_GETFD)
+    fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
+    if not msg.endswith('\n'): msg += '\n'
+    os.write(fd, '%s: %s' % (time.asctime(time.gmtime()), msg))
+    os.close(fd)
+
+def log_call(*args):
+    log('running command %s' % ' '.join(args))
+    try: subprocess.call(args)
+    except: log_exc()
+
+def log_exc():
+    """Log the traceback resulting from an exception."""
+    log(traceback.format_exc())
diff --git a/nm.py b/nm.py
new file mode 100644 (file)
index 0000000..fa527da
--- /dev/null
+++ b/nm.py
@@ -0,0 +1,50 @@
+"""Node Manager"""
+
+import optparse
+import time
+
+from config import *
+import accounts
+import api
+import database
+import delegate
+import logger
+import plc
+import sliver
+import tools
+
+
+parser = optparse.OptionParser()
+parser.add_option('-d', '--daemon',
+                  action='store_true', dest='daemon', default=False,
+                  help='run daemonized')
+parser.add_option('-s', '--startup',
+                  action='store_true', dest='startup', default=False,
+                  help='run all sliver startup scripts')
+(options, args) = parser.parse_args()
+
+def run():
+    try:
+        if options.daemon: tools.daemon()
+
+        accounts.register_account_type(sliver.Sliver)
+        accounts.register_account_type(delegate.Delegate)
+
+        other_pid = tools.pid_file()
+        if other_pid != None:
+            print """There might be another instance of the node manager running as pid %d.  If this is not the case, please remove the pid file %s""" % (other_pid, PID_FILE)
+            return
+
+        database.start()
+        api.start()
+        while True:
+            try: plc.fetch_and_update()
+            except: logger.log_exc()
+            time.sleep(10)
+    except: logger.log_exc()
+
+
+if __name__ == '__main__': run()
+else:
+    # This is for debugging purposes.  Open a copy of Python and import nm
+    tools.as_daemon_thread(run)
diff --git a/plc.py b/plc.py
new file mode 100644 (file)
index 0000000..0d705c0
--- /dev/null
+++ b/plc.py
@@ -0,0 +1,123 @@
+"""Parse slices.xml.  This file will become obsolete when the new API comes online."""
+
+import base64
+import sys
+sys.path.append('/usr/local/planetlab/bin')
+import SslFetch
+import time
+import xml.parsers.expat
+
+from config import *
+import database
+import logger
+
+
+_worker = SslFetch.Worker(SA_HOSTNAME, cacert_file='/usr/boot/cacert.pem')
+
+def fetch(filename):
+    logger.log('fetching %s' % filename)
+    (rc, data) = _worker.fetch(filename)
+    if rc == 0:
+        logger.log('fetch succeeded')
+        return data
+    else:
+        # XXX - should get a better error message from SslFetch/libcurl
+        curl_doc = 'http://curl.haxx.se/libcurl/c/libcurl-errors.html'
+        raise IOError('fetch failed, rc=%d (see %s)' % (rc, curl_doc))
+
+
+delegate_key = 'ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEAzNQIrVC9ZV9iDgu5/WXxcH/SyGdLG45CWXoWWh37UNA4dCVVlxtQ96xF7poolnxnM1irKUiXx85FsjA37z6m7IWl1h9uMYEJEvYkkxApsCmwm8C02m/BsOWK4Zjh4sv7QTeDgDnqhwnBw/U4jnkt8yKfVTBTNUY01dESzOgBfBc= root@yankee.cs.princeton.edu'
+
+def fetch_and_update():
+    sx = slices_xml(fetch('/xml/slices-0.5.xml'))
+    # sx = slices_xml(open('/root/slices-0.5.xml').read())
+    recs = [{'record_key': 'timestamp', 'type': 'timestamp', 'timestamp': time.time()}]
+    recs.append({'record_key': 'delegate_del_snoop', 'timestamp': time.time(), 'account_type': 'delegate', 'name': 'del_snoop', 'ssh_keys': delegate_key, 'plc_instantiated': True})
+    recs.append({'record_key': 'bwcap', 'timestamp': time.time(), 'cap': 5000000000, 'exempt_ips': ['127.0.0.1']})
+    for id, name in sx.id_name.iteritems():
+        rec = {}
+        rec['record_key'] = 'sliver_' + name
+        rec['account_type'] = 'sliver'
+        rec['name'] = name
+        rec['expiry'] = sx.id_expiry[id]
+        rec['timestamp'] = sx.id_ts.get(id) or time.time()
+        rec['delegations'] = [('del_snoop', 'GetRSpec')]
+        rec['id'] = id
+        rec['rspec'] = sx.get_rspec(id)
+        ssh_keys = []
+        for uid in sx.id_uids[id]: ssh_keys.extend(sx.uid_keys[uid])
+        rec['ssh_keys'] = '\n'.join(ssh_keys)
+        rec['plc_instantiated'] = True
+        rec['initscript'] = base64.b64encode('#!/bin/sh\n/bin/echo hello >/world.txt')
+        recs.append(rec)
+    database.deliver_records(recs)
+
+
+node_id = None
+
+def get_node_id():
+    global node_id
+    if node_id == None:
+        filename = '/etc/planetlab/node_id'
+        logger.log('reading node id from %s' % filename)
+        id_file = open(filename)
+        node_id = int(id_file.readline())
+        id_file.close()
+    return node_id
+
+
+class slices_xml:
+    def __init__(self, data):
+        self.node_id = get_node_id()
+        self.id_name = {}
+        self.id_expiry = {}
+        self.id_uids = {}
+        self.uid_keys = {}
+        self.id_rspec = {}
+        self.id_ts = {}
+        parser = xml.parsers.expat.ParserCreate()
+        parser.StartElementHandler = self._start_element
+        parser.CharacterDataHandler = self._char_data
+        isfinal = True
+        parser.Parse(data, isfinal)
+
+    def get_rspec(self, id):
+        rspec = DEFAULT_RSPEC.copy()
+        rspec.update(self.id_rspec[id])
+        return rspec
+
+    def _start_element(self, name, attrs):
+        self.last_tag = name
+        if   name == u'slice':
+            self.id = int(attrs[u'id'])
+            self.name = str(attrs[u'name'])
+            self.expiry = int(attrs[u'expiry'])
+        elif name == u'timestamp':
+            self.id_ts[self.id] = int(attrs[u'value'])
+        elif name == u'node':
+            # remember slices with slivers on us
+            nid = int(attrs[u'id'])
+            if nid == self.node_id:
+                self.id_name[self.id] = self.name
+                self.id_expiry[self.id] = self.expiry
+                self.id_uids[self.id] = []
+                self.id_rspec[self.id] = {}
+        elif name == u'user':
+            # remember users with slices with slivers on us
+            if self.id in self.id_name:
+                uid = int(attrs[u'person_id'])
+                self.id_uids[self.id].append(uid)
+                self.uid_keys[uid] = []
+        elif name == u'resource':
+            self.rname = str(attrs[u'name'])
+        elif name == u'key':
+            # remember keys of users with slices with slivers on us
+            uid = int(attrs[u'person_id'])
+            if uid in self.uid_keys:
+                self.uid_keys[uid].append(str(attrs[u'value']))
+
+    def _char_data(self, data):
+        if self.last_tag == u'value' and self.id in self.id_name:
+            try: self.id_rspec[self.id][self.rname] = int(data)
+            except ValueError: pass
+        self.last_tag = u''
diff --git a/sliver.py b/sliver.py
new file mode 100644 (file)
index 0000000..0a299da
--- /dev/null
+++ b/sliver.py
@@ -0,0 +1,109 @@
+import base64
+import errno
+import os
+import vserver
+
+from config import DEFAULT_RSPEC
+import accounts
+import logger
+import tools
+
+
+class Sliver(vserver.VServer):
+    """This class wraps vserver.VServer to make its interface closer to what we need for the Node Manager."""
+
+    SHELL = '/bin/vsh'
+    TYPE = 'sliver'
+
+    def __init__(self, name):
+        vserver.VServer.__init__(self, name, vm_running=True)
+        self.disk_limit_has_been_set = False
+        self.rspec = DEFAULT_RSPEC.copy()
+        self.ssh_keys = None
+        self.initscript = ''
+
+    @staticmethod
+    def create(name): logger.log_call('/usr/sbin/vuseradd', name)
+
+    @staticmethod
+    def destroy(name): logger.log_call('/usr/sbin/vuserdel', name)
+
+    def configure(self, rec):
+        self.rspec.update(rec['eff_rspec'])
+        self.set_resources()
+        if rec['ssh_keys'] != self.ssh_keys:
+            accounts.install_ssh_keys(rec)
+            self.ssh_keys = rec['ssh_keys']
+        if rec['initscript'] != self.initscript:
+            logger.log('%s: installing initscript' % self.name)
+            def install_initscript():
+                flags = os.O_WRONLY|os.O_CREAT|os.O_TRUNC
+                fd = os.open('/etc/rc.vinit', flags, 0755)
+                os.write(fd, base64.b64decode(rec['initscript']))
+                os.close(fd)
+            try: self.chroot_call(install_initscript)
+            except OSError, e:
+                if e.errno != errno.EEXIST: logger.log_exc()
+            self.initscript = rec['initscript']
+
+    def start(self):
+        if self.rspec['nm_enabled']:
+            logger.log('%s: starting' % self.name)
+            child_pid = os.fork()
+            if child_pid == 0:
+                vserver.VServer.start(self, True)
+                os._exit(0)
+            else: os.waitpid(child_pid, 0)
+        else: logger.log('%s: not starting, is not enabled' % self.name)
+
+    def stop(self):
+        logger.log('%s: stopping' % self.name)
+        vserver.VServer.stop(self)
+        # make sure we always make the syscalls when setting resource limits
+        self.vm_running = True
+
+    def set_resources(self):
+        """Set the resource limits of sliver <self.name>."""
+        # disk limits
+        disk_max_KiB = self.rspec['nm_disk_quota']
+        logger.log('%s: setting max disk usage to %d KiB' %
+                   (self.name, disk_max_KiB))
+        try:  # don't let slivers over quota escape other limits
+            if not self.disk_limit_has_been_set:
+                self.vm_running = False
+                logger.log('%s: computing disk usage' % self.name)
+                self.init_disk_info()
+                # even if set_disklimit() triggers an exception,
+                # the kernel probably knows the disk usage
+                self.disk_limit_has_been_set = True
+            vserver.VServer.set_disklimit(self, disk_max_KiB)
+            self.vm_running = True
+        except OSError: logger.log_exc()
+
+        # bw limits
+        bw_fields = ['nm_net_min_rate', 'nm_net_max_rate',
+                     'nm_net_exempt_min_rate', 'nm_net_exempt_max_rate',
+                     'nm_net_share']
+        args = tuple(map(self.rspec.__getitem__, bw_fields))
+        logger.log('%s: setting bw share to %d' % (self.name, args[-1]))
+        logger.log('%s: setting bw limits to %s bps' % (self.name, args[:-1]))
+        self.set_bwlimit(*args)
+
+        # cpu limits / remote login
+        cpu_guaranteed_shares = self.rspec['nm_cpu_guaranteed_share']
+        cpu_shares = self.rspec['nm_cpu_share']
+        if self.rspec['nm_enabled']:
+            if cpu_guaranteed_shares > 0:
+                logger.log('%s: setting cpu share to %d%% guaranteed' %
+                           (self.name, cpu_guaranteed_shares/10.0))
+                self.set_sched_config(cpu_guaranteed_shares,
+                                      vserver.SCHED_CPU_GUARANTEED)
+            else:
+                logger.log('%s: setting cpu share to %d' %
+                           (self.name, cpu_shares))
+                self.set_sched_config(cpu_shares, 0)
+        else:
+            # tell vsh to disable remote login by setting CPULIMIT to 0
+            logger.log('%s: disabling remote login' % self.name)
+            self.set_sched_config(0, 0)
+            self.stop()
diff --git a/ticket.py b/ticket.py
new file mode 100644 (file)
index 0000000..3389027
--- /dev/null
+++ b/ticket.py
@@ -0,0 +1,55 @@
+import SocketServer
+import os
+import subprocess
+
+from config import KEY_FILE, TICKET_SERVER_PORT
+import tools
+
+
+class TicketServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
+    allow_reuse_address = True
+
+
+class TicketRequestHandler(SocketServer.StreamRequestHandler):
+    def handle(self):
+        data = self.rfile.read()
+        filename = tools.write_temp_file(lambda thefile:
+                                         thefile.write(TEMPLATE % data))
+        result = subprocess.Popen([XMLSEC1, '--sign',
+                                   '--privkey-pem', KEY_FILE, filename],
+                                  stdout=subprocess.PIPE).stdout
+        self.wfile.write(result.read())
+        result.close()
+#         os.unlink(filename)
+
+
+def start():
+    tools.as_daemon_thread(TicketServer(('', TICKET_SERVER_PORT),
+                                        TicketRequestHandler).serve_forever)
+
+
+XMLSEC1 = '/usr/bin/xmlsec1'
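+# a signed ticket should be verifiable by hand with something like
+#   xmlsec1 --verify --pubkey-pem <public key file> <signed ticket>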
+
+TEMPLATE = '''<?xml version="1.0" encoding="UTF-8"?>
+<Envelope xmlns="urn:envelope">
+  <Data>%s</Data>
+  <Signature xmlns="http://www.w3.org/2000/09/xmldsig#">
+    <SignedInfo>
+      <CanonicalizationMethod Algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315" />
+      <SignatureMethod Algorithm="http://www.w3.org/2000/09/xmldsig#rsa-sha1" />
+      <Reference URI="">
+        <Transforms>
+          <Transform Algorithm="http://www.w3.org/2000/09/xmldsig#enveloped-signature" />
+        </Transforms>
+        <DigestMethod Algorithm="http://www.w3.org/2000/09/xmldsig#sha1" />
+        <DigestValue></DigestValue>
+      </Reference>
+    </SignedInfo>
+    <SignatureValue/>
+    <KeyInfo>
+       <KeyName/>
+    </KeyInfo>
+  </Signature>
+</Envelope>
+'''
+
diff --git a/tools.py b/tools.py
new file mode 100644 (file)
index 0000000..bc391a9
--- /dev/null
+++ b/tools.py
@@ -0,0 +1,89 @@
+import cPickle
+import errno
+import os
+import pwd
+import tempfile
+import threading
+
+from config import PID_FILE
+import logger
+
+
+def as_daemon_thread(run):
+    thr = threading.Thread(target=run)
+    thr.setDaemon(True)
+    thr.start()
+
+
+# after http://www.erlenstar.demon.co.uk/unix/faq_2.html
+def daemon():
+    """Daemonize the current process."""
+    if os.fork() != 0: os._exit(0)
+    os.setsid()
+    if os.fork() != 0: os._exit(0)
+    os.chdir('/')
+    os.umask(0)
+    devnull = os.open(os.devnull, os.O_RDWR)
+    for fd in range(3): os.dup2(devnull, fd)
+
+
+def deepcopy(obj):
+    """Return a deep copy of obj."""
+    return cPickle.loads(cPickle.dumps(obj, -1))
+
+
+def fork_as(su, function, *args):
+    """fork(), cd / to avoid keeping unused directories open, close all nonstandard file descriptors (to avoid capturing open sockets), fork() again (to avoid zombies) and call <function> with arguments <args> in the grandchild process.  If <su> is not None, set our group and user ids appropriately in the child process."""
+    child_pid = os.fork()
+    if child_pid == 0:
+        try:
+            os.chdir('/')
+            # close all nonstandard file descriptors
+            for fd in range(3, os.sysconf('SC_OPEN_MAX')):
+                try: os.close(fd)
+                except OSError: pass  # most likely an fd that isn't open
+            pw_ent = pwd.getpwnam(su)
+            os.setegid(pw_ent[3])
+            os.seteuid(pw_ent[2])
+            child_pid = os.fork()
+            if child_pid == 0: function(*args)
+        except:
+            os.seteuid(os.getuid())  # undo su so we can write the log file
+            os.setegid(os.getgid())
+            logger.log_exc()
+        os._exit(0)
+    else: os.waitpid(child_pid, 0)
+
+
+def pid_file():
+    """We use a pid file to ensure that only one copy of NM is running at a given time.  If successful, this function will write a pid file containing the pid of the current process.  The return value is the pid of the other running process, or None otherwise."""
+    other_pid = None
+    # check for a pid file
+    if os.access(PID_FILE, os.F_OK):
+        # pid file exists, read it
+        handle = open(PID_FILE)
+        other_pid = int(handle.read())
+        handle.close()
+        # check for a process with that pid by sending signal 0
+        try: os.kill(other_pid, 0)
+        except OSError, e:
+            if e.errno == errno.ESRCH: other_pid = None  # doesn't exist
+            else: raise  # who knows
+    if other_pid == None:
+        # write a new pid file
+        write_file(PID_FILE, lambda thefile: thefile.write(str(os.getpid())))
+    return other_pid
+
+
+def write_file(filename, do_write):
+    """Write file <filename> atomically by opening a temporary file, using <do_write> to write that file, and then renaming the temporary file."""
+    # create the temporary file in the target directory so that os.rename()
+    # never crosses a filesystem boundary
+    os.rename(write_temp_file(do_write, dir=os.path.dirname(filename)),
+              filename)
+
+
+def write_temp_file(do_write, dir=None):
+    fd, temporary_filename = tempfile.mkstemp(dir=dir)
+    thefile = os.fdopen(fd, 'w')
+    do_write(thefile)
+    thefile.close()
+    return temporary_filename