remove some PL-specific details
[util-vserver.git] / python / bwlimit.py
index 27f2344..b8ce780 100644 (file)
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: bwlimit.py,v 1.6 2006/03/01 16:28:51 mlhuang Exp $
+# $Id: bwlimit.py,v 1.9 2006/03/01 22:37:24 mlhuang Exp $
 #
 
 import sys, os, re, getopt
+from sets import Set
+import pwd
 
 
 # Where the tc binary lives
@@ -58,10 +60,6 @@ TC = "/sbin/tc"
 # Default interface
 dev = "eth0"
 
-# For backward compatibility, if bwcap is not specified, attempt to
-# get it from here.
-bwcap_file = "/etc/planetlab/bwcap"
-
 # Verbosity level
 verbose = 0
 
@@ -201,9 +199,8 @@ def format_tc_rate(rate):
         return "%.0fbit" % rate
 
 
-# Parse /etc/planetlab/bwcap. XXX Should get this from the API
-# instead.
-def get_bwcap():
+# Parse /etc/planetlab/bwcap (or equivalent)
+def read_bwcap(bwcap_file):
     bwcap = bwmax
     try:
         fp = open(bwcap_file, "r")
@@ -217,42 +214,57 @@ def get_bwcap():
     return bwcap
 
 
+# Get current (live) value of bwcap
+def get_bwcap(dev = dev):
+
+    state = tc("-d class show dev %s" % dev)
+    base_re = re.compile(r"class htb 1:10 parent 1:1 .*ceil ([^ ]+) .*")
+    base_classes = filter(None, map(base_re.match, state))
+    if not base_classes:
+        return -1
+    if len(base_classes) > 1:
+        raise Exception, "unable to get current bwcap"
+    return get_tc_rate(base_classes[0].group(1))
+
+
 # Get slice xid (500) from slice name ("500" or "princeton_mlh") or
 # slice name ("princeton_mlh") from slice xid (500).
 def get_slice(xid_or_name):
-    labels = ['account', 'password', 'uid', 'gid', 'gecos', 'directory', 'shell']
 
-    for line in file("/etc/passwd"):
-        # Comment
-        if line.strip() == '' or line[0] in '#':
-            continue
-        # princeton_mlh:x:...
-        fields = line.strip().split(':')
-        if len(fields) < len(labels):
-            continue
-        # {'account': 'princeton_mlh', 'password': 'x', ...}
-        pw = dict(zip(labels, fields))
-        if xid_or_name == root_xid:
-            return "root"
-        if xid_or_name == default_xid:
-            return "default"
-        elif xid_or_name == int(pw['uid']):
-            # Convert xid into name
-            return pw['account']
-        elif pw['uid'] == xid_or_name or pw['account'] == xid_or_name:
-            # Convert name into xid
-            return int(pw['uid'])
+    if xid_or_name == root_xid:
+        return "root"
+    if xid_or_name == default_xid:
+        return "default"
+    if isinstance(xid_or_name, (int, long)):
+        try:
+            return pwd.getpwuid(xid_or_name).pw_name
+        except KeyError:
+            pass
+    else:
+        try:
+            try:
+                return int(xid_or_name)
+            except ValueError:
+                pass
+            return pwd.getpwnam(xid_or_name).pw_uid
+        except KeyError:
+            pass
 
     return None
 
 
-# Shortcut for running a tc command
-def tc(cmd):
+# Shortcut for running a command
+def run(cmd, input = None):
     try:
         if verbose:
-            sys.stderr.write("Executing: " + TC + " " + cmd + "\n")
-        fileobj = os.popen(TC + " " + cmd, "r")
-        output = fileobj.readlines()
+            sys.stderr.write("Executing: " + cmd + "\n")
+        if input is None:
+            fileobj = os.popen(cmd, "r")
+            output = fileobj.readlines()
+        else:
+            fileobj = os.popen(cmd, "w")
+            fileobj.write(input)
+            output = None
         if fileobj.close() is None:
             return output
     except Exception, e:
@@ -260,15 +272,13 @@ def tc(cmd):
     return None
 
 
+# Shortcut for running a tc command
+def tc(cmd):
+    return run(TC + " " + cmd)
+
+
 # (Re)initialize the bandwidth limits on this node
-def init(dev = dev, bwcap = None):
-    if bwcap is None:
-        # For backward compatibility, if bwcap is not specified,
-        # attempt to get it from /etc/planetlab/bwcap.
-        bwcap = get_bwcap()
-    else:
-        # Allow bwcap to be specified as a tc rate string
-        bwcap = get_tc_rate(bwcap)
+def init(dev, bwcap):
 
     # Delete root qdisc 1: if it exists. This will also automatically
     # delete any child classes.
@@ -427,8 +437,29 @@ def off(xid, dev = dev):
         tc("class del dev %s classid 1:%x" % (dev, exempt_minor | xid))
 
 
+def exempt_init(group_name, node_ips):
+
+    # Clean up
+    iptables = "/sbin/iptables -t vnet %s POSTROUTING" 
+    run(iptables % "-F")
+    run("/sbin/ipset -X " + group_name)
+
+    # Create a hashed IP set of all of these destinations
+    run("/sbin/modprobe ip_set_iphash")
+    lines = ["-N %s iphash" % group_name]
+    add_cmd = "-A %s " % group_name
+    lines += [(add_cmd + ip) for ip in node_ips]
+    lines += ["COMMIT"]
+    restore = "\n".join(lines) + "\n"
+    run("/sbin/ipset -R", restore)
+
+    # Add rule to match on destination IP set
+    run((iptables + " -m set --set %s dst -j CLASSIFY --set-class 1:%x") %
+        ("-A", group_name, exempt_minor))
+
+
 def usage():
-    bwcap_description = format_tc_rate(bwmax)
+    bwcap_description = format_tc_rate(get_bwcap())
         
     print """
 Usage:
@@ -482,7 +513,7 @@ def main():
     if len(argv):
         if argv[0] == "init" or (argv[0] == "on" and len(argv) == 1):
             # (Re)initialize
-            init(dev, bwcap)
+            init(dev, get_tc_rate(bwcap))
 
         elif argv[0] == "get" or argv[0] == "show":
             # Show
@@ -502,7 +533,7 @@ def main():
                 if slice is None:
                     # Orphaned (not associated with a slice) class
                     slice = "%d?" % xid
-                print "%s: share %d minrate %s maxrate %s" % \
+                print "%s %d %s %s" % \
                       (slice, share, format_tc_rate(minrate), format_tc_rate(maxrate))
 
         elif len(argv) >= 2: