Merge remote-tracking branch 'origin/pycurl' into planetlab-4_0-branch
[plcapi.git] / PLC / Method.py
index 14b73ae..5e7d09a 100644 (file)
@@ -4,7 +4,7 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: Method.py,v 1.2 2006/09/08 19:44:31 mlhuang Exp $
+# $Id: Method.py 5574 2007-10-25 20:33:17Z thierry $
 #
 
 import xmlrpclib
@@ -12,11 +12,17 @@ from types import *
 import textwrap
 import os
 import time
+import pprint
+
+from types import StringTypes
 
 from PLC.Faults import *
-from PLC.Parameter import Parameter, Mixed
+from PLC.Parameter import Parameter, Mixed, python_type, xmlrpc_type
 from PLC.Auth import Auth
 from PLC.Debug import profile, log
+from PLC.Events import Event, Events
+from PLC.Nodes import Node, Nodes
+from PLC.Persons import Person, Persons
 
 class Method:
     """
@@ -65,16 +71,15 @@ class Method:
 
         # API may set this to a (addr, port) tuple if known
         self.source = None
-       self.__call__ = self.log(self.__call__) 
-
        
-    def __call__(self, *args):
+    def __call__(self, *args, **kwds):
         """
         Main entry point for all PLCAPI functions. Type checks
         arguments, authenticates, and executes call().
         """
 
         try:
+           start = time.time()
             (min_args, max_args, defaults) = self.args()
                                
            # Check that the right number of arguments were passed in
@@ -82,91 +87,88 @@ class Method:
                 raise PLCInvalidArgumentCount(len(args), len(min_args), len(max_args))
 
             for name, value, expected in zip(max_args, args, self.accepts):
-                self.type_check(name, value, expected)
+                self.type_check(name, value, expected, args)
        
-            # The first argument to all methods that require
-            # authentication, should be an Auth structure. The rest of the
-            # arguments to the call may also be used in the authentication
-            # check. For example, calls made by the Boot Manager are
-            # verified by comparing a hash of the message parameters to
-            # the value in the authentication structure.        
-
-            if len(self.accepts):
-                auth = None
-                if isinstance(self.accepts[0], Auth):
-                    auth = self.accepts[0]
-                elif isinstance(self.accepts[0], Mixed):
-                    for auth in self.accepts[0]:
-                        if isinstance(auth, Auth):
-                            break
-                if isinstance(auth, Auth):
-                    auth.check(self, *args)
+           result = self.call(*args, **kwds)
+           runtime = time.time() - start
 
+            if self.api.config.PLC_API_DEBUG or hasattr(self, 'message'):
+               self.log(None, runtime, *args)
                
-           return self.call(*args)
+           return result
 
         except PLCFault, fault:
-            # Prepend method name to expected faults
-            fault.faultString = self.name + ": " + fault.faultString
+       
+           caller = ""
+           if isinstance(self.caller, Person):
+               caller = 'person_id %s'  % self.caller['person_id']
+            elif isinstance(self.caller, Node):
+                caller = 'node_id %s'  % self.caller['node_id']
+
+            # Prepend caller and method name to expected faults
+            fault.faultString = caller + ": " +  self.name + ": " + fault.faultString
+           runtime = time.time() - start
+           self.log(fault, runtime, *args)
             raise fault
 
-
-    def log(self, callable):
+    def log(self, fault, runtime, *args):
         """
         Log the transaction 
-        """
-       def __log__(vars):
-               """
-               Commit the transaction 
-               """
-
-               # Do not log listMethods call
-               if vars['call_name'] in ['listMethods']:
-                       return False
-        
-               sql = "INSERT INTO events " \
-                        " (person_id, event_type, object_type, fault_code, call, runtime)" \
-                        " VALUES (%d, '%s', '%s', %d, '%s', %f)" %  \
-                        (vars['person_id'], vars['event_type'], vars['object_type'], 
-                       vars['fault_code'], vars['call'], vars['runtime'])
-                self.api.db.do(sql)
-                self.api.db.commit()
-                       
-
-        def wrapper(*args, **kwds):
-               
-               fault_code = 0
-               # XX Get real person_id
-               person_id = 1
-               event_type = 'Unknown'
-               object_type = 'Unknown'
-               call_name = callable.im_class.__module__.split('.')[-1:][0]
-               call_args = ", ".join([str(arg) for arg in list(args)[1:]]).replace('\'', '\\\'')
-               call = "%s(%s)" % (call_name, call_args)
-               
-               if hasattr(self, 'event_type'):
-                       event_type = self.event_type
-               
-                if hasattr(self, 'object_type'):
-                       object_type = self.object_type
-               
-               start = time.time()
-                       
-               try:
-                       result =  callable(*args, **kwds)
-                       runtime =  time.time() - start
-                       __log__(locals())                       
-                       return result
-                                                                      
-                except PLCFault, fault:
-                       fault_code = fault.faultCode
-                       runtime =  time.time() - start
-                       __log__(locals())
-                       print fault
-                       
-       return wrapper
+        """    
+
+       # Do not log system or Get calls
+        #if self.name.startswith('system') or self.name.startswith('Get'):
+        #    return False
+
+        # Create a new event
+        event = Event(self.api)
+       event['fault_code'] = 0
+       if fault:
+            event['fault_code'] = fault.faultCode
+        event['runtime'] = runtime
+
+        # Redact passwords and sessions
+        if args and isinstance(args[0], dict):
+           # what type of auth this is
+           if args[0].has_key('AuthMethod'):
+               auth_methods = ['session', 'password', 'capability', 'gpg', 'hmac','anonymous']
+               auth_method = args[0]['AuthMethod']
+               if auth_method in auth_methods:
+                   event['auth_type'] = auth_method
+            for password in 'AuthString', 'session':
+                if args[0].has_key(password):
+                    auth = args[0].copy()
+                    auth[password] = "Removed by API"
+                    args = (auth,) + args[1:]
+
+        # Log call representation
+        # XXX Truncate to avoid DoS
+        event['call'] = self.name + pprint.saferepr(args)
+       event['call_name'] = self.name
+
+        # Both users and nodes can call some methods
+        if isinstance(self.caller, Person):
+            event['person_id'] = self.caller['person_id']
+        elif isinstance(self.caller, Node):
+            event['node_id'] = self.caller['node_id']
+
+        event.sync(commit = False)
+
+        if hasattr(self, 'event_objects') and isinstance(self.event_objects, dict):
+            for key in self.event_objects.keys():
+               for object_id in self.event_objects[key]:
+                    event.add_object(key, object_id, commit = False)
        
 
+       # Set the message for this event
+       if fault:
+           event['message'] = fault.faultString
+       elif hasattr(self, 'message'):
+            event['message'] = self.message    
+       
+        # Commit
+        event.sync()
+
     def help(self, indent = "  "):
         """
         Text documentation for the method.
@@ -222,7 +224,7 @@ class Method:
             elif isinstance(param, Mixed):
                 for subparam in param:
                     text += param_text(name, subparam, indent + step, step)
-            elif isinstance(param, (list, tuple)):
+            elif isinstance(param, (list, tuple, set)):
                 for subparam in param:
                     text += param_text("", subparam, indent + step, step)
 
@@ -260,7 +262,7 @@ class Method:
         
         return (min_args, max_args, defaults)
 
-    def type_check(self, name, value, expected, min = None, max = None):
+    def type_check(self, name, value, expected, args):
         """
         Checks the type of the named value against the expected type,
         which may be a Python type, a typed value, a Parameter, a
@@ -280,26 +282,36 @@ class Method:
         if isinstance(expected, Mixed):
             for item in expected:
                 try:
-                    self.type_check(name, value, item)
-                    expected = item
-                    break
+                    self.type_check(name, value, item, args)
+                    return
                 except PLCInvalidArgument, fault:
                     pass
-            if expected != item:
-                xmlrpc_types = [xmlrpc_type(item) for item in expected]
-                raise PLCInvalidArgument("expected %s, got %s" % \
-                                         (" or ".join(xmlrpc_types),
-                                          xmlrpc_type(type(value))),
-                                         name)
+            raise fault
+
+        # If an authentication structure is expected, save it and
+        # authenticate after basic type checking is done.
+        if isinstance(expected, Auth):
+            auth = expected
+        else:
+            auth = None
 
         # Get actual expected type from within the Parameter structure
-        elif isinstance(expected, Parameter):
+        if isinstance(expected, Parameter):
             min = expected.min
             max = expected.max
+            nullok = expected.nullok
             expected = expected.type
+        else:
+            min = None
+            max = None
+            nullok = False
 
         expected_type = python_type(expected)
 
+        # If value can be NULL
+        if value is None and nullok:
+            return
+
         # Strings are a special case. Accept either unicode or str
         # types if a string is expected.
         if expected_type in StringTypes and isinstance(value, StringTypes):
@@ -324,6 +336,11 @@ class Method:
             if max is not None and \
                len(value.encode(self.api.encoding)) > max:
                 raise PLCInvalidArgument, "%s must be at most %d bytes long" % (name, max)
+        elif expected_type in (list, tuple, set):
+            if min is not None and len(value) < min:
+                raise PLCInvalidArgument, "%s must contain at least %d items" % (name, min)
+            if max is not None and len(value) > max:
+                raise PLCInvalidArgument, "%s must contain at most %d items" % (name, max)
         else:
             if min is not None and value < min:
                 raise PLCInvalidArgument, "%s must be > %s" % (name, str(min))
@@ -331,62 +348,25 @@ class Method:
                 raise PLCInvalidArgument, "%s must be < %s" % (name, str(max))
 
         # If a list with particular types of items is expected
-        if isinstance(expected, (list, tuple)):
+        if isinstance(expected, (list, tuple, set)):
             for i in range(len(value)):
                 if i >= len(expected):
-                    i = len(expected) - 1
-                self.type_check(name + "[]", value[i], expected[i])
+                    j = len(expected) - 1
+                else:
+                    j = i
+                self.type_check(name + "[]", value[i], expected[j], args)
 
         # If a struct with particular (or required) types of items is
         # expected.
         elif isinstance(expected, dict):
             for key in value.keys():
                 if key in expected:
-                    self.type_check(name + "['%s']" % key, value[key], expected[key])
+                    self.type_check(name + "['%s']" % key, value[key], expected[key], args)
             for key, subparam in expected.iteritems():
                 if isinstance(subparam, Parameter) and \
+                   subparam.optional is not None and \
                    not subparam.optional and key not in value.keys():
                     raise PLCInvalidArgument("'%s' not specified" % key, name)
 
-def python_type(arg):
-    """
-    Returns the Python type of the specified argument, which may be a
-    Python type, a typed value, or a Parameter.
-    """
-
-    if isinstance(arg, Parameter):
-        arg = arg.type
-
-    if isinstance(arg, type):
-        return arg
-    else:
-        return type(arg)
-
-def xmlrpc_type(arg):
-    """
-    Returns the XML-RPC type of the specified argument, which may be a
-    Python type, a typed value, or a Parameter.
-    """
-
-    arg_type = python_type(arg)
-
-    if arg_type == NoneType:
-        return "nil"
-    elif arg_type == IntType or arg_type == LongType:
-        return "int"
-    elif arg_type == bool:
-        return "boolean"
-    elif arg_type == FloatType:
-        return "double"
-    elif arg_type in StringTypes:
-        return "string"
-    elif arg_type == ListType or arg_type == TupleType:
-        return "array"
-    elif arg_type == DictType:
-        return "struct"
-    elif arg_type == Mixed:
-        # Not really an XML-RPC type but return "mixed" for
-        # documentation purposes.
-        return "mixed"
-    else:
-        raise PLCAPIError, "XML-RPC cannot marshal %s objects" % arg_type
+        if auth is not None:
+            auth.check(self, *args)