support loadingg module files with multiple classes
authorTony Mack <tmack@cs.princeton.edu>
Tue, 8 Jan 2008 17:53:31 +0000 (17:53 +0000)
committerTony Mack <tmack@cs.princeton.edu>
Tue, 8 Jan 2008 17:53:31 +0000 (17:53 +0000)
qaapi/qa/QAAPI.py

index 5310895..532da93 100644 (file)
@@ -10,36 +10,26 @@ class QAAPI:
     methods = []
                
        
-    def __init__(self, globals = globals(), config = None, logging=None):
+    def __init__(self, globals, config = None, logging=None, verbose=None):
        if config is None: self.config = Config()
        else: self.config = Config(config)
 
-       # Load methods
-       real_files = lambda name: not name.startswith('__init__') \
-                                 and  name.endswith('.py')
-       remove_ext = lambda name: name.split(".py")[0]  
-       iterator = os.walk(self.modules_path)
-       (root, basenames, files) = iterator.next()
-       method_base = ""
-       self.methods.extend([method_base+file for file in map(remove_ext, filter(real_files, files))])  
-       for (root, dirs, files) in iterator:
-           parts = root.split(os.sep)  
-           for basename in basenames:
-               if basename in parts:
-                   method_base = ".".join(parts[parts.index(basename):])+"."
-           files = filter(real_files, files)
-           files = map(remove_ext, files)      
-           self.methods.extend([method_base+file for file in  files]) 
-
-       # Add methods to self and global environment 
-        for method in self.methods:
-           callable = self.callable(method)(self.config)
-           if logging: callable = log(callable, method)
+       module_files = self.module_files(self.modules_path)
+       callables = set()
+        # determine what is callable
+       for file in module_files:
+           callables.update(self.callables(file))
+               
+       # Add methods to self and global environemt     
+        for method in callables:           
+           if logging: method = log(method, method.mod_name)
            elif hasattr(self.config, 'log') and self.config.log:
-                callable = log(callable, method)
-           
+               method = log(method, method.mod_name)
+          
            class Dummy: pass
-            paths = method.split(".")
+            paths = method.mod_name.split(".")
+           print dir(method)
+           
             if len(paths) > 1:
                 first = paths.pop(0)
 
@@ -55,35 +45,63 @@ class QAAPI:
                 for path in paths:
                     if not hasattr(obj, path):
                         if path == paths[-1]:
-                            setattr(obj, path, callable)
-                           globals[method]=obj  
+                            setattr(obj, path, method)
+                           globals[method.mod_name]=obj  
                         else:
                             setattr(obj, path, Dummy())
                     obj = getattr(obj, path)
            else:
-               if not hasattr(self, method):
-                   setattr(self, method, callable)
                if globals is not None:
-                   globals[method] = callable          
-               
+                   globals[method.mod_name] = method           
+
+    def module_files(self, module_dir):
+       """
+       Build a list of files   
+       """     
+       
+       # Load files from modules direcotry
+        real_files = lambda name: not name.startswith('__init__') \
+                                  and  name.endswith('.py')
+        remove_ext = lambda name: name.split(".py")[0]
+        iterator = os.walk(module_dir)
+        (root, basenames, files) = iterator.next()
+        module_base = ""
+        module_files = []
+        module_files.extend([method_base+file for file in map(remove_ext, filter(real_files, files))])
+
+        # recurse through directory             
+        for (root, dirs, files) in iterator:
+            parts = root.split(os.sep)
+            for basename in basenames:
+                if basename in parts:
+                    module_base = ".".join(parts[parts.index(basename):])+"."
+            files = filter(real_files, files)
+            files = map(remove_ext, files)
+            module_files.extend([module_base+file for file in  files])
+
+       return module_files 
 
-    def callable(self, method):
+    def callables(self, module_file):
        """
        Return a new instance of the specified method. 
        """      
         
-       # Look up test  
-       if method not in self.methods:
-           raise Exception, "Invalid method: %s" % method
-
        # Get new instance of method
+       parts = module_file.split(".")
+       # add every part except for the last to name (filename)
+       module_dir =  "qa.modules."
+       module_basename = ".".join(parts[:-1])
+       module_path = module_dir + module_file
        try:
-           #classname = method.split(".")[-1]
-           module_name = "qa.modules."+method
-           module = __import__(module_name, globals(), locals(), module_name)
-           components = module_name.split('.')
-           module = getattr(module, components[-1:][0])        
-           return module
+           module = __import__(module_path, globals(), locals(), module_path)
+           callables = []
+
+           for attribute in dir(module):
+               attr = getattr(module, attribute)
+               if callable(attr):
+                   setattr(attr, 'mod_name', module_basename+"."+attribute)
+                   callables.append(attr(self.config))
+           return callables 
        except ImportError, AttributeError:
            raise