...
author    David E. Eisenstat <deisenst@cs.princeton.edu>
          Fri, 27 Oct 2006 19:24:42 +0000 (19:24 +0000)
committer David E. Eisenstat <deisenst@cs.princeton.edu>
          Fri, 27 Oct 2006 19:24:42 +0000 (19:24 +0000)
accounts.py
api.py
database.py
delegate.py
forward_api_calls.c
logger.py
nm.py
sliver_vs.py
tools.py

diff --git a/accounts.py b/accounts.py
index 1c495ee..5066b9e 100644
--- a/accounts.py
+++ b/accounts.py
@@ -1,16 +1,13 @@
 """Functionality common to all account classes.
 
-Each account class must provide five methods: create(), destroy(),
-configure(), start(), and stop().  In addition, it must provide static
-member variables SHELL, which contains the unique shell that it uses;
-and TYPE, which contains a description of the type that it uses.  TYPE
-is divided hierarchically by periods; at the moment the only
-convention is that all sliver accounts have type that begins with
+Each subclass of Account must provide five methods: create(),
+destroy(), configure(), start(), and stop().  In addition, it must
+provide static member variables SHELL, which contains the unique shell
+that it uses; and TYPE, which contains a description of the type that
+it uses.  TYPE is divided hierarchically by periods; at the moment the
+only convention is that all sliver accounts have a type that begins with
 sliver.
 
-Because Python does dynamic method lookup, we do not bother with a
-boilerplate abstract superclass.
-
 There are any number of race conditions that may result from the fact
 that account names are not unique over time.  Moreover, it's a bad
 idea to perform lengthy operations while holding the database lock.
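
(A minimal sketch of the contract this docstring describes; the class name, shell, and log messages here are hypothetical illustrations in accounts.py's namespace, not part of this commit:)

    class Dummy(Account):
        SHELL = '/bin/false'   # unique shell identifying this account type
        TYPE = 'dummy'         # hierarchical; sliver types begin with 'sliver.'

        @staticmethod
        def create(name): logger.log('%s: create' % name)
        @staticmethod
        def destroy(name): logger.log('%s: destroy' % name)
        # configure(), start(), and stop() are inherited from Account

    register_class(Dummy)
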
@@ -42,8 +39,8 @@ def register_class(acct_class):
 
 
 # private account name -> worker object association and associated lock
-_name_worker_lock = threading.Lock()
-_name_worker = {}
+name_worker_lock = threading.Lock()
+name_worker = {}
 
 def all():
     """Return the names of all accounts on the system with recognized shells."""
@@ -51,22 +48,38 @@ def all():
 
 def get(name):
     """Return the worker object for a particular username.  If no such object exists, create it first."""
-    _name_worker_lock.acquire()
+    name_worker_lock.acquire()
     try:
-        if name not in _name_worker: _name_worker[name] = Worker(name)
-        return _name_worker[name]
-    finally: _name_worker_lock.release()
-
-
-def install_keys(rec):
-    """Write <rec['keys']> to <rec['name']>'s authorized_keys file."""
-    name = rec['name']
-    dot_ssh = '/home/%s/.ssh' % name
-    def do_installation():
-        if not os.access(dot_ssh, os.F_OK): os.mkdir(dot_ssh)
-        tools.write_file(dot_ssh + '/authorized_keys', lambda thefile: thefile.write(rec['keys']))
-    logger.log('%s: installing ssh keys' % name)
-    tools.fork_as(name, do_installation)
+        if name not in name_worker: name_worker[name] = Worker(name)
+        return name_worker[name]
+    finally: name_worker_lock.release()
+
+
+class Account:
+    def __init__(self, rec):
+        self.name = rec['name']
+        self.keys = ''
+        self.configure(rec)
+
+    @staticmethod
+    def create(name): raise NotImplementedError
+    @staticmethod
+    def destroy(name): raise NotImplementedError
+
+    def configure(self, rec):
+        """Write <rec['keys']> to my authorized_keys file."""
+        new_keys = rec['keys']
+        if new_keys != self.keys:
+            self.keys = new_keys
+            dot_ssh = '/home/%s/.ssh' % self.name
+            def do_installation():
+                if not os.access(dot_ssh, os.F_OK): os.mkdir(dot_ssh)
+                tools.write_file(dot_ssh + '/authorized_keys', lambda f: f.write(new_keys))
+            logger.log('%s: installing ssh keys' % self.name)
+            tools.fork_as(self.name, do_installation)
+
+    def start(self): pass
+    def stop(self): pass
 
 
 class Worker:
@@ -89,28 +102,24 @@ class Worker:
 
     def _ensure_created(self, rec):
         curr_class = self._get_class()
-        next_class = type_acct_class[rec['account_type']]
+        next_class = type_acct_class[rec['type']]
         if next_class != curr_class:
             self._destroy(curr_class)
             self._create_sem.acquire()
             try: next_class.create(self.name)
             finally: self._create_sem.release()
-        self._make_acct_obj()
-        self._acct.configure(rec)
+        if not isinstance(self._acct, next_class): self._acct = next_class(rec)
+        else: self._acct.configure(rec)
         if next_class != curr_class: self._acct.start()
 
     def ensure_destroyed(self): self._q.put((self._ensure_destroyed,))
     def _ensure_destroyed(self): self._destroy(self._get_class())
 
     def start(self): self._q.put((self._start,))
-    def _start(self):
-        self._make_acct_obj()
-        self._acct.start()
+    def _start(self): self._acct.start()
 
     def stop(self): self._q.put((self._stop,))
-    def _stop(self):
-        self._make_acct_obj()
-        self._acct.stop()
+    def _stop(self): self._acct.stop()
 
     def _destroy(self, curr_class):
         self._acct = None
@@ -124,10 +133,6 @@ class Worker:
         except KeyError: return None
         return shell_acct_class[shell]
 
-    def _make_acct_obj(self):
-        curr_class = self._get_class()
-        if not isinstance(self._acct, curr_class): self._acct = curr_class(self.name)
-
     def _run(self):
         while True:
             try:
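
(_run is cut off by the hunk boundary above; given the (bound method, args...) tuples that ensure_destroyed(), start(), and stop() enqueue, the loop presumably continues along these lines; a sketch, not this commit's code:)

    while True:
        try:
            cmd = self._q.get()      # blocks until a request is enqueued
            cmd[0](*cmd[1:])         # invoke the bound method with its args
        except: logger.log_exc()     # log and keep the worker thread alive
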
diff --git a/api.py b/api.py
index 97d35d9..c4dec57 100644
--- a/api.py
+++ b/api.py
@@ -16,6 +16,7 @@ import tools
 
 
 API_SERVER_PORT = 812
+UNIX_ADDR = '/tmp/node_mgr.api'
 
 
 api_method_dict = {}
@@ -28,10 +29,6 @@ def export_to_api(nargs):
         return method
     return export
 
-@export_to_api(0)
-def DumpDatabase():
-    """DumpDatabase(): return the entire node manager DB, pickled"""
-    return cPickle.dumps(dict(database._db), 0)
 
 @export_to_api(0)
 def Help():
@@ -40,15 +37,13 @@ def Help():
 
 @export_to_api(1)
 def CreateSliver(rec):
-    """CreateSliver(sliver_name): set up a non-PLC-instantiated sliver"""
-    if not rec['plc_instantiated']:
-        accounts.get(rec['name']).ensure_created(rec)
+    """CreateSliver(sliver_name): create a non-PLC-instantiated sliver"""
+    if rec['instantiation'] == 'delegated': accounts.get(rec['name']).ensure_created(rec)
 
 @export_to_api(1)
-def DeleteSliver(rec):
-    """DeleteSliver(sliver_name): tear down a non-PLC-instantiated sliver"""
-    if not rec['plc_instantiated']:
-        accounts.get(rec['name']).ensure_destroyed()
+def DestroySliver(rec):
+    """DestroySliver(sliver_name): destroy a non-PLC-instantiated sliver"""
+    if rec['instantiation'] == 'delegated': accounts.get(rec['name']).ensure_destroyed()
 
 @export_to_api(1)
 def Start(rec):
@@ -63,98 +58,76 @@ def Stop(rec):
 @export_to_api(1)
 def GetEffectiveRSpec(rec):
     """GetEffectiveRSpec(sliver_name): return the RSpec allocated to the specified sliver, including loans"""
-    return tools.deepcopy(rec.get('eff_rspec', {}))
+    return rec.get('_rspec', {}).copy()
 
 @export_to_api(1)
 def GetRSpec(rec):
     """GetRSpec(sliver_name): return the RSpec allocated to the specified sliver, excluding loans"""
-    return tools.deepcopy(rec.get('rspec', {}))
+    return rec.get('rspec', {}).copy()
 
 @export_to_api(1)
 def GetLoans(rec):
     """GetLoans(sliver_name): return the list of loans made by the specified sliver"""
-    return tools.deepcopy(rec.get('loans', []))
+    return list(rec.get('_loans', []))
 
 def validate_loans(obj):
     """Check that <obj> is a valid loan specification."""
-    def validate_loan(obj):
-        return (type(obj)==list or type(obj)==tuple) and len(obj)==3 and \
-               type(obj[0])==str and \
-               type(obj[1])==str and obj[1] in LOANABLE_RESOURCES and \
-               type(obj[2])==int and obj[2]>0
+    def validate_loan(obj): return (type(obj)==list or type(obj)==tuple) and len(obj)==3 and type(obj[0])==str and type(obj[1])==str and obj[1] in database.LOANABLE_RESOURCES and type(obj[2])==int and obj[2]>0
     return type(obj)==list and False not in map(validate_loan, obj)
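
(Illustration, with made-up sliver names: a well-formed loan specification is a list of [recipient, resource, amount] triples, where the resource must appear in database.LOANABLE_RESOURCES and the amount must be a positive int:)

    loans = [['princeton_slice1', 'cpu_share', 16],
             ['mit_slice2', 'net_share', 1]]
    assert validate_loans(loans)
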
 
 @export_to_api(2)
 def SetLoans(rec, loans):
     """SetLoans(sliver_name, loans): overwrite the list of loans made by the specified sliver"""
-    if not validate_loans(loans):
-        raise xmlrpclib.Fault(102, 'Invalid argument: the second argument must be a well-formed loan specification')
-    rec['loans'] = loans
-    database.deliver_records([rec])
-
-api_method_list = api_method_dict.keys()
-api_method_list.sort()
+    if not validate_loans(loans): raise xmlrpclib.Fault(102, 'Invalid argument: the second argument must be a well-formed loan specification')
+    rec['_loans'] = loans
+    database.db.sync()
 
 
 class APIRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
     # overriding _dispatch to achieve this effect is officially deprecated,
-    # but I can't figure out how to get access to .request
-    # without duplicating SimpleXMLRPCServer code here,
-    # which is more likely to change than the deprecated behavior
-    # is to be broken
+    # but I can't figure out how to get access to .request without
+    # duplicating SimpleXMLRPCServer code here, which is more likely to
+    # change than the deprecated behavior is to be broken
 
     @database.synchronized
-    def _dispatch(self, method_name, args):
-        method_name = str(method_name)
+    def _dispatch(self, method_name_unicode, args):
+        method_name = str(method_name_unicode)
         try: method = api_method_dict[method_name]
         except KeyError:
+            api_method_list = api_method_dict.keys()
+            api_method_list.sort()
             raise xmlrpclib.Fault(100, 'Invalid API method %s.  Valid choices are %s' % (method_name, ', '.join(api_method_list)))
-
         expected_nargs = nargs_dict[method_name]
-        if len(args) != nargs_dict[method_name]:
-            raise xmlrpclib.Fault(101, 'Invalid argument count: got %d, expecting %d.' % (len(args), expected_nargs))
+        if len(args) != expected_nargs: raise xmlrpclib.Fault(101, 'Invalid argument count: got %d, expecting %d.' % (len(args), expected_nargs))
         else:
             # Figure out who's calling.
             # XXX - these ought to be imported directly from some .h file
             SO_PEERCRED = 17
             sizeof_struct_ucred = 12
-            ucred = self.request.getsockopt(socket.SOL_SOCKET, SO_PEERCRED,
-                                            sizeof_struct_ucred)
+            ucred = self.request.getsockopt(socket.SOL_SOCKET, SO_PEERCRED, sizeof_struct_ucred)
             xid = struct.unpack('3i', ucred)[2]
             caller_name = pwd.getpwuid(xid)[0]
-
             if expected_nargs >= 1:
                 target_name = args[0]
-                target_rec = database.get_sliver(target_name)
-                if not target_rec: raise xmlrpclib.Fault(102, 'Invalid argument: the first argument must be a sliver name.')
-
-                if caller_name not in (args[0], 'root') and \
-                       (caller_name, method_name) not in target_rec['delegations']:
-                    raise xmlrpclib.Fault(108, 'Permission denied.')
+                target_rec = database.db.get(target_name)
+                if not (target_rec and target_rec['type'].startswith('sliver.')): raise xmlrpclib.Fault(102, 'Invalid argument: the first argument must be a sliver name.')
+                if not (caller_name in (args[0], 'root') or (caller_name, method_name) in target_rec['delegations']): raise xmlrpclib.Fault(108, 'Permission denied.')
                 result = method(target_rec, *args[1:])
-            else:
-                if method_name == 'DumpDatabase' and caller_name != 'root':
-                    raise xmlrpclib.Fault(108, 'Permission denied.')
-                result = method()
+            else: result = method()
             if result == None: result = 1
             return result
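
(For reference: on Linux, struct ucred is three 32-bit ints, so the '3i' unpack above yields (pid, uid, gid).  Taking index [2] means the dispatcher identifies callers by gid and assumes uid == gid for the accounts involved, the usual per-user-group convention:)

    # pid, uid, gid = struct.unpack('3i', ucred)
    # caller_name = pwd.getpwuid(gid)[0]   # assumes uid == gid
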
 
-class APIServer_INET(SocketServer.ThreadingMixIn,
-                     SimpleXMLRPCServer.SimpleXMLRPCServer):
-    allow_reuse_address = True
+class APIServer_INET(SocketServer.ThreadingMixIn, SimpleXMLRPCServer.SimpleXMLRPCServer): allow_reuse_address = True
 
 class APIServer_UNIX(APIServer_INET): address_family = socket.AF_UNIX
 
 def start():
     """Start two XMLRPC interfaces: one bound to localhost, the other bound to a Unix domain socket."""
-    serv1 = APIServer_INET(('127.0.0.1', API_SERVER_PORT),
-                           requestHandler=APIRequestHandler, logRequests=0)
+    serv1 = APIServer_INET(('127.0.0.1', API_SERVER_PORT), requestHandler=APIRequestHandler, logRequests=0)
     tools.as_daemon_thread(serv1.serve_forever)
-    unix_addr = '/tmp/node_mgr_api'
-    try: os.unlink(unix_addr)
+    try: os.unlink(UNIX_ADDR)
     except OSError, e:
         if e.errno != errno.ENOENT: raise
-    serv2 = APIServer_UNIX(unix_addr,
-                           requestHandler=APIRequestHandler, logRequests=0)
+    serv2 = APIServer_UNIX(UNIX_ADDR, requestHandler=APIRequestHandler, logRequests=0)
     tools.as_daemon_thread(serv2.serve_forever)
-    os.chmod(unix_addr, 0666)
+    os.chmod(UNIX_ADDR, 0666)
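
(Once start() has run, the INET interface can be exercised with a stock xmlrpclib client; a sketch, with a hypothetical sliver name:)

    import xmlrpclib
    api = xmlrpclib.ServerProxy('http://127.0.0.1:812/')
    print api.Help()                        # 0-argument method: no target sliver
    print api.GetRSpec('princeton_slice1')  # allowed for the sliver itself, root, or a delegate
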
diff --git a/database.py b/database.py
index 4555d48..ede28d9 100644
--- a/database.py
+++ b/database.py
 import cPickle
-import sys
 import threading
 import time
 
+try: from bwlimit import bwmin, bwmax
+except ImportError: bwmin, bwmax = 8, 1000000000
 import accounts
-import bwcap
 import logger
 import tools
 
 
-DB_FILE = '/root/pl_node_mgr_db.pickle'
+DB_FILE = '/root/node_mgr_db.pickle'
 
+LOANABLE_RESOURCES = ['cpu_min', 'cpu_share', 'net_min', 'net_max', 'net2_min', 'net2_max', 'net_share', 'disk_max']
 
-class Database(dict):
-    def __init__(self): self.account_index = {}
-
-    def deliver_records(self, recs):
-        ts = self.get_timestamp()
-        for rec in recs:
-            old_rec = self.setdefault(rec['record_key'], {})
-            if rec['timestamp'] >= max(ts, old_rec.get('timestamp', 0)):
-                old_rec.update(rec, dirty=True)
-        self.compute_effective_rspecs()
-        if self.get_timestamp() > ts:
-            self.delete_old_records()
-            self.delete_old_accounts()
-            for rec in self.itervalues(): rec['dirty'] = True
-        self.create_new_accounts()
-        self.update_bwcap()
-
-    def compute_effective_rspecs(self):
-        """Apply loans to field 'rspec' to get field 'eff_rspec'."""
-        slivers = dict([(rec['name'], rec) for rec in self.itervalues() \
-                        if rec.get('account_type') == 'sliver.VServer'])
-
-        # Pass 1: copy 'rspec' to 'eff_rspec', saving the old value
-        for sliver in slivers.itervalues():
-            sliver['old_eff_rspec'] = sliver.get('eff_rspec')
-            sliver['eff_rspec'] = sliver['rspec'].copy()
-
-        # Pass 2: apply loans
-        for sliver in slivers.itervalues():
-            remaining_loanable_amount = sliver['rspec'].copy()
-            for other_name, resource, amount in sliver.get('loans', []):
-                if other_name in slivers and \
-                       0 < amount <= remaining_loanable_amount[resource]:
-                    sliver['eff_rspec'][resource] -= amount
-                    remaining_loanable_amount[resource] -= amount
-                    slivers[other_name]['eff_rspec'][resource] += amount
-
-        # Pass 3: mark changed rspecs dirty
-        for sliver in slivers.itervalues():
-            if sliver['eff_rspec'] != sliver['old_eff_rspec']:
-                sliver['needs_update'] = True
-            del sliver['old_eff_rspec']
-
-    def rebuild_account_index(self):
-        self.account_index.clear()
-        for rec in self.itervalues():
-            if 'account_type' in rec: self.account_index[rec['name']] = rec
-
-    def delete_stale_records(self, ts):
-        for key, rec in self.items():
-            if rec['timestamp'] < ts: del self[key]
-
-    def delete_expired_records(self):
-        for key, rec in self.items():
-            if rec.get('expires', sys.maxint) < time.time(): del self[key]
-
-    def destroy_old_accounts(self):
-        for name in accounts.all():
-            if name not in self.account_index: accounts.get(name).ensure_destroyed()
-
-    def create_new_accounts(self):
-        """Invoke the appropriate create() function for every dirty account."""
-        for rec in self.account_index.itervalues():
-            if rec['dirty'] and rec['plc_instantiated']: accounts.get(rec['name']).ensure_created(rec)
-            rec['dirty'] = False
-
-    def update_bwcap(self):
-        bwcap_rec = self.get('bwcap')
-        if bwcap_rec and bwcap_rec['dirty']:
-            bwcap.update(bwcap_rec)
-            bwcap_rec['dirty'] = False
+DEFAULT_ALLOCATIONS = {'enabled': 1, 'cpu_min': 0, 'cpu_share': 32, 'net_min': bwmin, 'net_max': bwmax, 'net2_min': bwmin, 'net2_max': bwmax, 'net_share': 1, 'disk_max': 5000000}
 
 
 # database object and associated lock
-_db_lock = threading.RLock()
-_db = Database()
-# these are used in tandem to request a database dump from the dumper daemon
-_db_cond = threading.Condition(_db_lock)
-_dump_requested = False
+db_lock = threading.RLock()
+db = None
 
+# these are used in tandem to request a database dump from the dumper daemon
+db_cond = threading.Condition(db_lock)
+dump_requested = False
 
 # decorator that acquires and releases the database lock before and after the decorated operation
-def synchronized(function):
-    def sync_fun(*args, **kw_args):
-        _db_lock.acquire()
-        try: return function(*args, **kw_args)
-        finally: _db_lock.release()
-    sync_fun.__doc__ = function.__doc__
-    sync_fun.__name__ = function.__name__
-    return sync_fun
+def synchronized(fn):
+    def sync_fn(*args, **kw_args):
+        db_lock.acquire()
+        try: return fn(*args, **kw_args)
+        finally: db_lock.release()
+    sync_fn.__doc__ = fn.__doc__
+    sync_fn.__name__ = fn.__name__
+    return sync_fn
+
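
(Usage sketch of the decorator, with a hypothetical helper; the manual __doc__ and __name__ copies above preserve the wrapped function's identity:)

    @synchronized
    def all_names():      # hypothetical: body runs with db_lock held
        return db.keys()
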
+
+class Database(dict):
+    def __init__(self):
+        self._min_timestamp = 0
+
+    def _compute_effective_rspecs(self):
+        """Calculate the effects of loans and store the result in field _rspec.  At the moment, we allow slivers to loan only those resources that they have received directly from PLC.  In order to do the accounting, we store three different rspecs: field 'rspec', which is the resources given by PLC; field '_rspec', which is the actual amount of resources the sliver has after all loans; and variable resid_rspec, which is the amount of resources the sliver has after giving out loans but not receiving any."""
+        slivers = {}
+        for name, rec in self.iteritems():
+            if 'rspec' in rec:
+                rec['_rspec'] = rec['rspec'].copy()
+                slivers[name] = rec
+        for rec in slivers.itervalues():
+            eff_rspec = rec['_rspec']
+            resid_rspec = rec['rspec'].copy()
+            for target, resname, amt in rec.get('_loans', []):
+                if target in slivers and amt < resid_rspec[resname]:
+                    eff_rspec[resname] -= amt
+                    resid_rspec[resname] -= amt
+                    slivers[target]['_rspec'][resname] += amt
+
+    def deliver_record(self, rec):
+        """A record is simply a dictionary with 'name' and 'timestamp' keys.  We keep some persistent private data in the records under keys that start with '_'; thus record updates should not displace such keys."""
+        name = rec['name']
+        old_rec = self.get(name)
+        if old_rec != None and rec['timestamp'] > old_rec['timestamp']:
+            for key in old_rec.keys():
+                if not key.startswith('_'): del old_rec[key]
+            old_rec.update(rec)
+        elif rec['timestamp'] >= self._min_timestamp: self[name] = rec
+
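
(A worked example of the merge rule, with made-up values: a newer record replaces all public keys but leaves private '_'-prefixed state alone:)

    db = Database()
    db.deliver_record({'name': 'foo', 'timestamp': 1, 'expires': 100})
    db['foo']['_loans'] = [['bar', 'cpu_share', 8]]
    db.deliver_record({'name': 'foo', 'timestamp': 2})
    # db['foo'] is now {'name': 'foo', 'timestamp': 2,
    #                   '_loans': [['bar', 'cpu_share', 8]]}; 'expires' is gone
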
+    def set_min_timestamp(self, ts):
+        self._min_timestamp = ts
+        for name, rec in self.items():
+            if rec['timestamp'] < ts: del self[name]
+
+    def sync(self):
+        # delete expired records
+        now = time.time()
+        for name, rec in self.items():
+            if rec.get('expires', now) < now: del self[name]
+
+        self._compute_effective_rspecs()
+
+        # create and destroy accounts as needed
+        existing_acct_names = accounts.all()
+        for name in existing_acct_names:
+            if name not in self: accounts.get(name).ensure_destroyed()
+        for name, rec in self.iteritems():
+            if rec['instantiation'] == 'plc-instantiated': accounts.get(name).ensure_created(rec)
+
+        # request a database dump
+        global dump_requested
+        dump_requested = True
+        db_cond.notify()
 
 
-# apply the given records to the database and request a dump
 @synchronized
-def deliver_records(recs):
-    global _dump_requested
-    _db.deliver_records(recs)
-    _dump_requested = True
-    _db_cond.notify()
+def GetSlivers_callback(data):
+    for d in data:
+        for sliver in d['slivers']:
+            rec = sliver.copy()
+            attr_dict = {}
+            for attr in rec.pop('attributes'): attr_dict[attr['name']] = attr['value']
+            keys = rec.pop('keys')
+            rec['keys'] = '\n'.join([key_struct['key'] for key_struct in keys])
+            rspec = {}
+            rec['rspec'] = rspec
+            for resname, default_amt in DEFAULT_ALLOCATIONS.iteritems():
+                try: amt = int(attr_dict[resname])
+                except (KeyError, ValueError): amt = default_amt
+                rspec[resname] = amt
+            rec.setdefault('timestamp', d['timestamp'])
+            db.deliver_record(rec)
+        db.set_min_timestamp(d['timestamp'])
+    db.sync()
+
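
(The payload shape GetSlivers_callback expects can be read off the code above; an illustrative, entirely made-up example:)

    data = [{'timestamp': 1161976000,
             'slivers': [{'name': 'princeton_slice1',
                          'type': 'sliver.VServer',
                          'instantiation': 'plc-instantiated',
                          'keys': [{'key': 'ssh-rsa AAAA... alice@host'}],
                          'attributes': [{'name': 'cpu_share', 'value': '64'}]}]}]
    GetSlivers_callback(data)  # attributes override DEFAULT_ALLOCATIONS in rspec
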
 
 def start():
     """The database dumper daemon.  When it starts up, it populates the database with the last dumped database.  It proceeds to handle dump requests forever."""
     def run():
-        global _dump_requested
-        _db_lock.acquire()
-        try:  # load the db
-            f = open(DB_FILE)
-            _db.update(cPickle.load(f))
-            f.close()
-        except: logger.log_exc()
-        while True:  # handle dump requests forever
-            while not _dump_requested: _db_cond.wait()
-            db_copy = tools.deepcopy(_db)
-            _dump_requested = False
-            _db_lock.release()
+        global dump_requested
+        while True:
+            db_lock.acquire()
+            while not dump_requested: db_cond.wait()
+            db_copy = tools.deepcopy(db)
+            dump_requested = False
+            db_lock.release()
             try: tools.write_file(DB_FILE, lambda f: cPickle.dump(db_copy, f, -1))
             except: logger.log_exc()
-            _db_lock.acquire()
+    global db
+    try:
+        f = open(DB_FILE)
+        try: db = cPickle.load(f)
+        finally: f.close()
+    except:
+        logger.log_exc()
+        db = Database()
     tools.as_daemon_thread(run)
diff --git a/delegate.py b/delegate.py
index ebea6d6..17bba7a 100644
--- a/delegate.py
+++ b/delegate.py
@@ -5,12 +5,10 @@ import logger
 import tools
 
 
-class Delegate:
+class Delegate(accounts.Account):
     SHELL = '/bin/forward_api_calls'  # tunneling shell
     TYPE = 'delegate'
 
-    def __init__(self, name): self.name = name
-
     @staticmethod
     def create(name):
         add_shell(Delegate.SHELL)
@@ -19,11 +17,6 @@ class Delegate:
     @staticmethod
     def destroy(name): logger.log_call('/usr/sbin/userdel', '-r', name)
 
-    def configure(self, rec): accounts.install_keys(rec)
-    def start(self): pass
-    def stop(self): pass
-
-
 def add_shell(shell):
     """Add <shell> to /etc/shells if it's not already there."""
     etc_shells = open('/etc/shells')
diff --git a/forward_api_calls.c b/forward_api_calls.c
index 0d4536b..52d7024 100644
--- a/forward_api_calls.c
+++ b/forward_api_calls.c
@@ -11,7 +11,7 @@
  */
 
 static const int TIMEOUT_SECS = 30;
-const char *API_addr = "/tmp/node_mgr_api";
+const char *API_addr = "/tmp/node_mgr.api";
 
 static const char *Header =
   "POST / HTTP/1.0\r\n"
diff --git a/logger.py b/logger.py
index eb99caf..e59ca1c 100644
--- a/logger.py
+++ b/logger.py
@@ -6,8 +6,7 @@ import time
 import traceback
 
 
-LOG_FILE = '/var/log/pl_node_mgr.log'
-
+LOG_FILE = '/root/node_mgr.log'
 
 def log(msg):
     """Write <msg> to the log file."""
diff --git a/nm.py b/nm.py
index 37aa904..b1bce16 100644
--- a/nm.py
+++ b/nm.py
@@ -3,7 +3,6 @@
 import optparse
 import time
 
-from config import *
 import accounts
 import api
 import database
@@ -15,12 +14,8 @@ import tools
 
 
 parser = optparse.OptionParser()
-parser.add_option('-d', '--daemon',
-                  action='store_true', dest='daemon', default=False,
-                  help='run daemonized')
-parser.add_option('-s', '--startup',
-                  action='store_true', dest='startup', default=False,
-                  help='run all sliver startup scripts')
+parser.add_option('-d', '--daemon', action='store_true', dest='daemon', default=False, help='run daemonized')
+parser.add_option('-s', '--startup', action='store_true', dest='startup', default=False, help='run all sliver startup scripts')
 (options, args) = parser.parse_args()
 
 def run():
diff --git a/sliver_vs.py b/sliver_vs.py
index 4292d1d..f6e9397 100644
--- a/sliver_vs.py
+++ b/sliver_vs.py
@@ -20,7 +20,6 @@ import errno
 import os
 import vserver
 
-from bwlimit import bwmin, bwmax
 import accounts
 import logger
 import tools
@@ -38,19 +37,19 @@ DEFAULTS = {'disk_max': 5000000,
             'keys':          '',
             'initscript':    ''}
 
-class Sliver_VS(vserver.VServer):
+class Sliver_VS(accounts.Account, vserver.VServer):
     """This class wraps vserver.VServer to make its interface closer to what we need for the Node Manager."""
 
     SHELL = '/bin/vsh'
     TYPE = 'sliver.VServer'
 
-    def __init__(self, name):
-        vserver.VServer.__init__(self, name)
-        self.current_keys = ''
-        self.current_initscript = ''
+    def __init__(self, rec):
+        vserver.VServer.__init__(self, rec['name'])
+        self.keys = ''
+        self.rspec = {}
+        self.initscript = ''
         self.disk_usage_initialized = False
-        self.rec = DEFAULTS.copy()
-
+        self.configure(rec)
 
     @staticmethod
     def create(name): logger.log_call('/usr/sbin/vuseradd', name)
@@ -58,20 +57,15 @@ class Sliver_VS(vserver.VServer):
     @staticmethod
     def destroy(name): logger.log_call('/usr/sbin/vuserdel', name)
 
-
     def configure(self, rec):
-        self.rec = DEFAULTS.copy()
-        self.rec.update(rec)
-
-        self.set_resources()
-
-        new_keys = self.rec['keys']
-        if new_keys != self.current_keys:
-            accounts.install_keys(rec)
-            self.current_keys = new_keys
-
-        new_initscript = self.rec['initscript']
-        if new_initscript != self.current_initscript:
+        new_rspec = rec['_rspec']
+        if new_rspec != self.rspec:
+            self.rspec = new_rspec
+            self.set_resources()
+
+        new_initscript = rec['initscript']
+        if new_initscript != self.initscript:
+            self.initscript = new_initscript
             logger.log('%s: installing initscript' % self.name)
             def install_initscript():
                 flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
@@ -80,11 +74,11 @@ class Sliver_VS(vserver.VServer):
                 os.close(fd)
             try: self.chroot_call(install_initscript)
             except: logger.log_exc()
-            self.current_initscript = new_initscript
 
+        accounts.Account.configure(self, rec)  # install ssh keys
 
     def start(self):
-        if self.rec['enabled']:
+        if self.rspec['enabled']:
             logger.log('%s: starting' % self.name)
             child_pid = os.fork()
             if child_pid == 0:
@@ -99,9 +93,8 @@ class Sliver_VS(vserver.VServer):
         logger.log('%s: stopping' % self.name)
         vserver.VServer.stop(self)
 
-
     def set_resources(self):
-        disk_max = int(self.rec['disk_max'])
+        disk_max = self.rspec['disk_max']
         logger.log('%s: setting max disk usage to %d KiB' % (self.name, disk_max))
         try:  # if the sliver is over quota, .set_disk_limit will throw an exception
             if not self.disk_usage_initialized:
@@ -112,26 +105,21 @@ class Sliver_VS(vserver.VServer):
             vserver.VServer.set_disklimit(self, disk_max_KiB)
         except OSError: logger.log_exc()
 
-        net_limits = (int(self.rec['net_min']),
-                      int(self.rec['net_max']),
-                      int(self.rec['net2_min']),
-                      int(self.rec['net2_max']),
-                      int(self.rec['net_share']))
+        net_limits = (self.rspec['net_min'], self.rspec['net_max'], self.rspec['net2_min'], self.rspec['net2_max'], self.rspec['net_share'])
         logger.log('%s: setting net limits to %s bps' % (self.name, net_limits[:-1]))
         logger.log('%s: setting net share to %d' % (self.name, net_limits[-1]))
         self.set_bwlimit(*net_limits)
 
-        cpu_min = int(self.rec['cpu_min'])
-        cpu_share = int(self.rec['cpu_share'])
-        if bool(self.rec['enabled']):
+        cpu_min = self.rspec['cpu_min']
+        cpu_share = self.rspec['cpu_share']
+        if self.rspec['enabled']:
             if cpu_min > 0:
                 logger.log('%s: setting cpu share to %d%% guaranteed' % (self.name, cpu_min/10.0))
                 self.set_sched_config(cpu_min, vserver.SCHED_CPU_GUARANTEED)
             else:
                 logger.log('%s: setting cpu share to %d' % (self.name, cpu_share))
                 self.set_sched_config(cpu_share, 0)
-        else:
-            # tell vsh to disable remote login by setting CPULIMIT to 0
+        else:  # tell vsh to disable remote login by setting CPULIMIT to 0
             logger.log('%s: disabling remote login' % self.name)
             self.set_sched_config(0, 0)
             self.stop()
diff --git a/tools.py b/tools.py
index 143e128..51527cc 100644
--- a/tools.py
+++ b/tools.py
@@ -8,8 +8,7 @@ import threading
 import logger
 
 
-PID_FILE = '/var/run/pl_node_mgr.pid'
-
+PID_FILE = '/var/run/node_mgr.pid'
 
 def as_daemon_thread(run):
     """Call function <run> with no arguments in its own thread."""
@@ -17,7 +16,6 @@ def as_daemon_thread(run):
     thr.setDaemon(True)
     thr.start()
 
-
 def close_nonstandard_fds():
     """Close all open file descriptors other than 0, 1, and 2."""
     _SC_OPEN_MAX = 4
@@ -25,7 +23,6 @@ def close_nonstandard_fds():
         try: os.close(fd)
         except OSError: pass  # most likely an fd that isn't open
 
-
 # after http://www.erlenstar.demon.co.uk/unix/faq_2.html
 def daemon():
     """Daemonize the current process."""
@@ -37,12 +34,10 @@ def daemon():
     devnull = os.open(os.devnull, os.O_RDWR)
     for fd in range(3): os.dup2(devnull, fd)
 
-
 def deepcopy(obj):
     """Return a deep copy of obj."""
     return cPickle.loads(cPickle.dumps(obj, -1))
 
-
 def fork_as(su, function, *args):
     """fork(), cd / to avoid keeping unused directories open, close all nonstandard file descriptors (to avoid capturing open sockets), fork() again (to avoid zombies) and call <function> with arguments <args> in the grandchild process.  If <su> is not None, set our group and user ids appropriately in the child process."""
     child_pid = os.fork()
@@ -62,7 +57,6 @@ def fork_as(su, function, *args):
         os._exit(0)
     else: os.waitpid(child_pid, 0)
 
-
 def pid_file():
     """We use a pid file to ensure that only one copy of NM is running at a given time.  If successful, this function will write a pid file containing the pid of the current process.  The return value is the pid of the other running process, or None otherwise."""
     other_pid = None
@@ -77,18 +71,16 @@ def pid_file():
             else: raise  # who knows
     if other_pid == None:
         # write a new pid file
-        write_file(PID_FILE, lambda thefile: thefile.write(str(os.getpid())))
+        write_file(PID_FILE, lambda f: f.write(str(os.getpid())))
     return other_pid
 
-
 def write_file(filename, do_write):
     """Write file <filename> atomically by opening a temporary file, using <do_write> to write that file, and then renaming the temporary file."""
     os.rename(write_temp_file(do_write), filename)
 
-
 def write_temp_file(do_write):
     fd, temporary_filename = tempfile.mkstemp()
-    thefile = os.fdopen(fd, 'w')
-    do_write(thefile)
-    thefile.close()
+    f = os.fdopen(fd, 'w')
+    try: do_write(f)
+    finally: f.close()
     return temporary_filename
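
(One caveat on write_file: os.rename is atomic only within a single filesystem, and tempfile.mkstemp defaults to the system temp directory, typically /tmp.  A sketch of a variant that keeps the temporary file next to its target; the dir argument is the standard tempfile parameter:)

    def write_file_same_fs(filename, do_write):
        fd, tmp = tempfile.mkstemp(dir=os.path.dirname(filename) or '.')
        f = os.fdopen(fd, 'w')
        try: do_write(f)
        finally: f.close()
        os.rename(tmp, filename)  # same filesystem, so the rename is atomic
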