push trunk version to branch
authorMarc Fiuczynski <mef@cs.princeton.edu>
Wed, 15 Apr 2009 18:17:53 +0000 (18:17 +0000)
committerMarc Fiuczynski <mef@cs.princeton.edu>
Wed, 15 Apr 2009 18:17:53 +0000 (18:17 +0000)
modprobe.py

index b56f5d5..5e4053d 100644 (file)
@@ -10,42 +10,80 @@ import tempfile
 class Modprobe:
     def __init__(self,filename="/etc/modprobe.conf"):
         self.conffile = {}
-        self.origconffile = {}
-        for keyword in ("alias","options","install","remove","blacklist","MODULES","#"):
+        for keyword in ("include","alias","options","install","remove","blacklist","MODULES"):
             self.conffile[keyword]={}
         self.filename = filename
 
     def input(self,filename=None):
         if filename==None: filename=self.filename
-        fb = file(filename,"r")
-        for line in fb.readlines():
-            parts = line.split()
-            command = parts[0].lower()
-
-            table = self.conffile.get(command,None)
-            if table == None:
-                print "WARNING: command %s not recognized. Ignoring!" % command
+
+        # list of file names; loop itself might add filenames
+        filenames = [filename]
+
+        for filename in filenames:
+            if not os.path.exists(filename):
                 continue
 
-            if command == "alias":
-                wildcard=parts[1]
-                modulename=parts[2]
-                self.aliasset(wildcard,modulename)
-                options=''
-                if len(parts)>3:
-                    options=" ".join(parts[3:])
-                    self.optionsset(modulename,options)
-                    self.conffile['MODULES']={}
-                self.conffile['MODULES'][modulename]=options
-            else:
-                modulename=parts[1]
-                rest=" ".join(parts[2:])
-                self._set(command,modulename,rest)
-                if command == "options":
+            fb = file(filename,"r")
+            for line in fb.readlines():
+                def __default():
+                    modulename=parts[1]
+                    rest=" ".join(parts[2:])
+                    self._set(command,modulename,rest)
+
+                def __alias():
+                    wildcard=parts[1]
+                    modulename=parts[2]
+                    self.aliasset(wildcard,modulename)
+                    options=''
+                    if len(parts)>3:
+                        options=" ".join(parts[3:])
+                        self.optionsset(modulename,options)
+                    self.conffile['MODULES'][modulename]=options
+
+                def __options():
+                    modulename=parts[1]
+                    rest=" ".join(parts[2:])
                     self.conffile['MODULES'][modulename]=rest
+                    __default()
+
+                def __blacklist():
+                    modulename=parts[1]
+                    self.blacklistset(modulename,'')
+
+                def __include():
+                    newfilename = parts[1]
+                    if os.path.exists(newfilename):
+                        if os.path.isdir(newfilename):
+                            for e in os.listdir(newfilename):
+                                filenames.append("%s/%s"%(newfilename,e))
+                        else:
+                            filenames.append(newfilename)
+
+                funcs = {"alias":__alias,
+                         "options":__options,
+                         "blacklis":__blacklist,
+                         "include":__include}
+
+                parts = line.split()
+
+                # skip empty lines or those that are comments
+                if len(parts) == 0 or parts[0] == "#":
+                    continue
+
+                # lower case first word
+                command = parts[0].lower()
+
+                # check if its a command we support
+                if not self.conffile.has_key(command):
+                    print "WARNING: command %s not recognized." % command
+                    continue
+
+                func = funcs.get(command,__default)
+                func()
+            
+            fb.close()
 
-        self.origconffile = self.conffile.copy()
-                
     def _get(self,command,key):
         return self.conffile[command].get(key,None)
 
@@ -58,11 +96,17 @@ class Modprobe:
     def optionsget(self,key):
         return self._get('options',key)
 
+    def blacklistget(self,key):
+        return self._get('blacklist',key)
+
     def aliasset(self,key,value):
         self._set("alias",key,value)
 
     def optionsset(self,key,value):
         self._set("options",key,value)
+
+    def blacklistset(self,key,value):
+        self._set("blacklist",key,value)
         
     def _comparefiles(self,a,b):
         try:
@@ -121,12 +165,19 @@ class Modprobe:
          
 if __name__ == '__main__':
     import sys
+    m = Modprobe()
     if len(sys.argv)>1:
-        m = Modprobe(sys.argv[1])
+        fn = sys.argv[1]
     else:
-        m = Modprobe()
+        fn = "/etc/modprobe.conf"
 
     m.input()
-    m.aliasset("bond0","bonding")
-    m.optionsset("bond0","miimon=100")
-    m.output("/tmp/x")
+
+    blacklist = Modprobe()
+    blacklistfiles = ("blacklist","blacklist-compat","blacklist-firewire")
+    for blf in blacklistfiles:
+        if os.path.exists("/etc/modprobe.d/%s"%blf):
+            blacklist.input("/etc/modprobe.d/%s"%blf)
+
+    m.output("/tmp/%s-tmp"%os.path.basename(fn),"TEST")
+    blacklist.output("/tmp/blacklist-tmp.txt","TEST")