remove some PL-specific details
authorSteve Muir <smuir@cs.princeton.edu>
Tue, 14 Mar 2006 22:57:50 +0000 (22:57 +0000)
committerSteve Muir <smuir@cs.princeton.edu>
Tue, 14 Mar 2006 22:57:50 +0000 (22:57 +0000)
python/bwlimit.py

index c906b2a..b8ce780 100644 (file)
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: bwlimit.py,v 1.8 2006/03/01 22:02:52 mlhuang Exp $
+# $Id: bwlimit.py,v 1.9 2006/03/01 22:37:24 mlhuang Exp $
 #
 
 import sys, os, re, getopt
 from sets import Set
-import plcapi
+import pwd
 
 
 # Where the tc binary lives
@@ -60,10 +60,6 @@ TC = "/sbin/tc"
 # Default interface
 dev = "eth0"
 
-# For backward compatibility, if bwcap is not specified, attempt to
-# get it from here.
-bwcap_file = "/etc/planetlab/bwcap"
-
 # Verbosity level
 verbose = 0
 
@@ -203,9 +199,8 @@ def format_tc_rate(rate):
         return "%.0fbit" % rate
 
 
-# Parse /etc/planetlab/bwcap. XXX Should get this from the API
-# instead.
-def get_bwcap():
+# Parse /etc/planetlab/bwcap (or equivalent)
+def read_bwcap(bwcap_file):
     bwcap = bwmax
     try:
         fp = open(bwcap_file, "r")
@@ -219,31 +214,41 @@ def get_bwcap():
     return bwcap
 
 
+# Get current (live) value of bwcap
+def get_bwcap(dev = dev):
+
+    state = tc("-d class show dev %s" % dev)
+    base_re = re.compile(r"class htb 1:10 parent 1:1 .*ceil ([^ ]+) .*")
+    base_classes = filter(None, map(base_re.match, state))
+    if not base_classes:
+        return -1
+    if len(base_classes) > 1:
+        raise Exception, "unable to get current bwcap"
+    return get_tc_rate(base_classes[0].group(1))
+
+
 # Get slice xid (500) from slice name ("500" or "princeton_mlh") or
 # slice name ("princeton_mlh") from slice xid (500).
 def get_slice(xid_or_name):
-    labels = ['account', 'password', 'uid', 'gid', 'gecos', 'directory', 'shell']
 
-    for line in file("/etc/passwd"):
-        # Comment
-        if line.strip() == '' or line[0] in '#':
-            continue
-        # princeton_mlh:x:...
-        fields = line.strip().split(':')
-        if len(fields) < len(labels):
-            continue
-        # {'account': 'princeton_mlh', 'password': 'x', ...}
-        pw = dict(zip(labels, fields))
-        if xid_or_name == root_xid:
-            return "root"
-        if xid_or_name == default_xid:
-            return "default"
-        elif xid_or_name == int(pw['uid']):
-            # Convert xid into name
-            return pw['account']
-        elif pw['uid'] == xid_or_name or pw['account'] == xid_or_name:
-            # Convert name into xid
-            return int(pw['uid'])
+    if xid_or_name == root_xid:
+        return "root"
+    if xid_or_name == default_xid:
+        return "default"
+    if isinstance(xid_or_name, (int, long)):
+        try:
+            return pwd.getpwuid(xid_or_name).pw_name
+        except KeyError:
+            pass
+    else:
+        try:
+            try:
+                return int(xid_or_name)
+            except ValueError:
+                pass
+            return pwd.getpwnam(xid_or_name).pw_uid
+        except KeyError:
+            pass
 
     return None
 
@@ -273,14 +278,7 @@ def tc(cmd):
 
 
 # (Re)initialize the bandwidth limits on this node
-def init(dev = dev, bwcap = None):
-    if bwcap is None:
-        # For backward compatibility, if bwcap is not specified,
-        # attempt to get it from /etc/planetlab/bwcap.
-        bwcap = get_bwcap()
-    else:
-        # Allow bwcap to be specified as a tc rate string
-        bwcap = get_tc_rate(bwcap)
+def init(dev, bwcap):
 
     # Delete root qdisc 1: if it exists. This will also automatically
     # delete any child classes.
@@ -323,9 +321,6 @@ def init(dev = dev, bwcap = None):
     # up here.
     on(default_xid, dev, share = default_share)
 
-    # Set up exemptions.
-    exempt_init()
-
 
 # Get the bandwidth limits for a particular slice xid as a tuple (xid,
 # share, minrate, maxrate), or all classes as a list of tuples.
@@ -442,46 +437,25 @@ def off(xid, dev = dev):
         tc("class del dev %s classid 1:%x" % (dev, exempt_minor | xid))
 
 
-def exempt_init():
-    # Who are we?
-    try:
-        node_id = int(file('/etc/planetlab/node_id').readline().strip())
-    except:
-        return False
-
-    api = plcapi.PLCAPI()
-
-    # All nodes that have access to Internet2
-    node_ids = []
-    for node_group in api.AnonAdmGetNodeGroups(api.auth):
-        if node_group['name'] == "Internet2":
-            node_ids += api.AnonAdmGetNodeGroupNodes(api.auth, node_group['nodegroup_id'])
-
-    # Remove duplicates
-    node_ids = list(Set(node_ids))
-
-    # Continue only if we ourselves have access to Internet2
-    if node_id not in node_ids:
-        return True
-
-    # Exempt the following destinations from the node bandwidth cap
-    node_ips = [node['ip'] for node in api.AnonAdmGetNodes(api.auth, node_ids, ['ip'])]
+def exempt_init(group_name, node_ips):
 
     # Clean up
-    run("/sbin/iptables -t vnet -F POSTROUTING")
-    run("/sbin/ipset -X Internet2")
+    iptables = "/sbin/iptables -t vnet %s POSTROUTING" 
+    run(iptables % "-F")
+    run("/sbin/ipset -X " + group_name)
 
     # Create a hashed IP set of all of these destinations
     run("/sbin/modprobe ip_set_iphash")
-    lines = ["-N Internet2 iphash"]
-    lines += ["-A Internet2 " + ip for ip in node_ips]
+    lines = ["-N %s iphash" % group_name]
+    add_cmd = "-A %s " % group_name
+    lines += [(add_cmd + ip) for ip in node_ips]
     lines += ["COMMIT"]
     restore = "\n".join(lines) + "\n"
     run("/sbin/ipset -R", restore)
 
     # Add rule to match on destination IP set
-    run("/sbin/iptables -t vnet -A POSTROUTING -m set --set Internet2 dst -j CLASSIFY --set-class 1:%x" %
-        exempt_minor)
+    run((iptables + " -m set --set %s dst -j CLASSIFY --set-class 1:%x") %
+        ("-A", group_name, exempt_minor))
 
 
 def usage():
@@ -539,7 +513,7 @@ def main():
     if len(argv):
         if argv[0] == "init" or (argv[0] == "on" and len(argv) == 1):
             # (Re)initialize
-            init(dev, bwcap)
+            init(dev, get_tc_rate(bwcap))
 
         elif argv[0] == "get" or argv[0] == "show":
             # Show