patch for php-5.3 (the one in f12)
[plcapi.git] / PLC / Method.py
index 9b4a440..ed03974 100644 (file)
@@ -4,7 +4,8 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: Method.py,v 1.9 2006/10/19 17:02:42 tmack Exp $
+# $Id$
+# $URL$
 #
 
 import xmlrpclib
 #
 
 import xmlrpclib
@@ -12,13 +13,20 @@ from types import *
 import textwrap
 import os
 import time
 import textwrap
 import os
 import time
+import pprint
+
+from types import StringTypes
 
 from PLC.Faults import *
 
 from PLC.Faults import *
-from PLC.Parameter import Parameter, Mixed
+from PLC.Parameter import Parameter, Mixed, python_type, xmlrpc_type
 from PLC.Auth import Auth
 from PLC.Debug import profile, log
 from PLC.Auth import Auth
 from PLC.Debug import profile, log
+from PLC.Events import Event, Events
+from PLC.Nodes import Node, Nodes
+from PLC.Persons import Person, Persons
 
 
-class Method:
+# we inherit object because we use new-style classes for legacy methods
+class Method (object):
     """
     Base class for all PLCAPI functions. At a minimum, all PLCAPI
     functions must define:
     """
     Base class for all PLCAPI functions. At a minimum, all PLCAPI
     functions must define:
@@ -65,7 +73,6 @@ class Method:
 
         # API may set this to a (addr, port) tuple if known
         self.source = None
 
         # API may set this to a (addr, port) tuple if known
         self.source = None
-
        
     def __call__(self, *args, **kwds):
         """
        
     def __call__(self, *args, **kwds):
         """
@@ -75,93 +82,111 @@ class Method:
 
         try:
            start = time.time()
 
         try:
            start = time.time()
-            (min_args, max_args, defaults) = self.args()
+
+            # legacy code cannot be type-checked, due to the way Method.args() works
+            if not hasattr(self,"skip_typecheck"):
+                (min_args, max_args, defaults) = self.args()
                                
                                
-           # Check that the right number of arguments were passed in
-            if len(args) < len(min_args) or len(args) > len(max_args):
-                raise PLCInvalidArgumentCount(len(args), len(min_args), len(max_args))
+                # Check that the right number of arguments were passed in
+                if len(args) < len(min_args) or len(args) > len(max_args):
+                    raise PLCInvalidArgumentCount(len(args), len(min_args), len(max_args))
 
 
-            for name, value, expected in zip(max_args, args, self.accepts):
-                self.type_check(name, value, expected)
+                for name, value, expected in zip(max_args, args, self.accepts):
+                    self.type_check(name, value, expected, args)
        
        
-            # The first argument to all methods that require
-            # authentication, should be an Auth structure. The rest of the
-            # arguments to the call may also be used in the authentication
-            # check. For example, calls made by the Boot Manager are
-            # verified by comparing a hash of the message parameters to
-            # the value in the authentication structure.        
-
-            if len(self.accepts):
-                auth = None
-                if isinstance(self.accepts[0], Auth):
-                    auth = self.accepts[0]
-                elif isinstance(self.accepts[0], Mixed):
-                    for auth in self.accepts[0]:
-                        if isinstance(auth, Auth):
-                            break
-                if isinstance(auth, Auth):
-                    auth.check(self, *args)
-          
            result = self.call(*args, **kwds)
            runtime = time.time() - start
            result = self.call(*args, **kwds)
            runtime = time.time() - start
-
-           if self.api.config.PLC_API_DEBUG:
-               self.log(0, runtime, *args)
+       
+            if self.api.config.PLC_API_DEBUG or hasattr(self, 'message'):
+               self.log(None, runtime, *args)
                
            return result
 
         except PLCFault, fault:
                
            return result
 
         except PLCFault, fault:
-            # Prepend method name to expected faults
-            fault.faultString = self.name + ": " + fault.faultString
+       
+           caller = ""
+           if isinstance(self.caller, Person):
+               caller = 'person_id %s'  % self.caller['person_id']
+            elif isinstance(self.caller, Node):
+                caller = 'node_id %s'  % self.caller['node_id']
+
+            # Prepend caller and method name to expected faults
+            fault.faultString = caller + ": " +  self.name + ": " + fault.faultString
            runtime = time.time() - start
            runtime = time.time() - start
-           self.log(fault.faultCode, runtime, *args)
-            raise fault
-
+           
+           if self.api.config.PLC_API_DEBUG:
+               self.log(fault, runtime, *args)
+            
+           raise fault
 
 
-    def log(self, fault_code, runtime, *args):
+    def log(self, fault, runtime, *args):
         """
         Log the transaction 
         """    
         """
         Log the transaction 
         """    
-       # Gather necessary logging variables
-       event_type = 'Unknown'
-       object_type = 'Unknown'
-       person_id = None
-       object_ids = []
-       call_name = self.name
-       call_args = ", ".join([unicode(arg) for arg in list(args)[1:]])
-       call = "%s(%s)" % (call_name, call_args)
-               
-       if hasattr(self, 'event_type'):
-               event_type = self.event_type
-        if hasattr(self, 'object_type'):
-               object_type = self.object_type
-       if self.caller:
-               person_id = self.caller['person_id']
-       if hasattr(self, 'object_ids'):
-               object_ids = self.object_ids 
-
-       # do not log system calls
-        if call_name.startswith('system'):
-               return False
-       # do not log get calls
-       if call_name.startswith('Get'):
-               return False
+
+       # Do not log system or Get calls
+        #if self.name.startswith('system') or self.name.startswith('Get'):
+        #    return False
+        # Do not log ReportRunlevel 
+        if self.name.startswith('system'):
+            return False
+        if self.name.startswith('ReportRunlevel'):
+            return False
+
+        # Create a new event
+        event = Event(self.api)
+       event['fault_code'] = 0
+       if fault:
+            event['fault_code'] = fault.faultCode
+        event['runtime'] = runtime
+
+        # Redact passwords and sessions
+        newargs = args
+        if args:
+            newargs = []
+            for arg in args:
+                if not isinstance(arg, dict):
+                    newargs.append(arg)
+                    continue
+                # what type of auth this is
+                if arg.has_key('AuthMethod'):
+                    auth_methods = ['session', 'password', 'capability', 'gpg', 'hmac','anonymous']
+                    auth_method = arg['AuthMethod']
+                    if auth_method in auth_methods:
+                        event['auth_type'] = auth_method
+                for password in 'AuthString', 'session', 'password':
+                    if arg.has_key(password):
+                        arg = arg.copy()
+                        arg[password] = "Removed by API"
+                newargs.append(arg)
+
+        # Log call representation
+        # XXX Truncate to avoid DoS
+        event['call'] = self.name + pprint.saferepr(newargs)
+       event['call_name'] = self.name
+
+        # Both users and nodes can call some methods
+        if isinstance(self.caller, Person):
+            event['person_id'] = self.caller['person_id']
+        elif isinstance(self.caller, Node):
+            event['node_id'] = self.caller['node_id']
+
+        event.sync(commit = False)
+
+        if hasattr(self, 'event_objects') and isinstance(self.event_objects, dict):
+            for key in self.event_objects.keys():
+               for object_id in self.event_objects[key]:
+                    event.add_object(key, object_id, commit = False)
        
        
-       sql_event = "INSERT INTO events " \
-              " (person_id, event_type, object_type, fault_code, call, runtime) VALUES" \
-              " (%(person_id)s, %(event_type)s, %(object_type)s," \
-             "  %(fault_code)d, %(call)s, %(runtime)f)" 
-       self.api.db.do(sql_event, locals())     
-
-       # log objects affected
-       for object_id in object_ids:
-               event_id =  self.api.db.last_insert_id('events', 'event_id')
-               sql_objects = "INSERT INTO event_object (event_id, object_id) VALUES" \
-                        " (%(event_id)d, %(object_id)d) "  % (locals()) 
-               self.api.db.do(sql_objects)
-                       
-        self.api.db.commit()           
+
+       # Set the message for this event
+       if fault:
+           event['message'] = fault.faultString
+       elif hasattr(self, 'message'):
+            event['message'] = self.message    
        
        
+        # Commit
+        event.sync()
 
     def help(self, indent = "  "):
         """
 
     def help(self, indent = "  "):
         """
@@ -218,7 +243,7 @@ class Method:
             elif isinstance(param, Mixed):
                 for subparam in param:
                     text += param_text(name, subparam, indent + step, step)
             elif isinstance(param, Mixed):
                 for subparam in param:
                     text += param_text(name, subparam, indent + step, step)
-            elif isinstance(param, (list, tuple)):
+            elif isinstance(param, (list, tuple, set)):
                 for subparam in param:
                     text += param_text("", subparam, indent + step, step)
 
                 for subparam in param:
                     text += param_text("", subparam, indent + step, step)
 
@@ -244,7 +269,7 @@ class Method:
         That represents the minimum and maximum sets of arguments that
         this function accepts and the defaults for the optional arguments.
         """
         That represents the minimum and maximum sets of arguments that
         this function accepts and the defaults for the optional arguments.
         """
-
+        
         # Inspect call. Remove self from the argument list.
         max_args = self.call.func_code.co_varnames[1:self.call.func_code.co_argcount]
         defaults = self.call.func_defaults
         # Inspect call. Remove self from the argument list.
         max_args = self.call.func_code.co_varnames[1:self.call.func_code.co_argcount]
         defaults = self.call.func_defaults
@@ -256,7 +281,7 @@ class Method:
         
         return (min_args, max_args, defaults)
 
         
         return (min_args, max_args, defaults)
 
-    def type_check(self, name, value, expected, min = None, max = None):
+    def type_check(self, name, value, expected, args):
         """
         Checks the type of the named value against the expected type,
         which may be a Python type, a typed value, a Parameter, a
         """
         Checks the type of the named value against the expected type,
         which may be a Python type, a typed value, a Parameter, a
@@ -276,26 +301,36 @@ class Method:
         if isinstance(expected, Mixed):
             for item in expected:
                 try:
         if isinstance(expected, Mixed):
             for item in expected:
                 try:
-                    self.type_check(name, value, item)
-                    expected = item
-                    break
+                    self.type_check(name, value, item, args)
+                    return
                 except PLCInvalidArgument, fault:
                     pass
                 except PLCInvalidArgument, fault:
                     pass
-            if expected != item:
-                xmlrpc_types = [xmlrpc_type(item) for item in expected]
-                raise PLCInvalidArgument("expected %s, got %s" % \
-                                         (" or ".join(xmlrpc_types),
-                                          xmlrpc_type(type(value))),
-                                         name)
+            raise fault
+
+        # If an authentication structure is expected, save it and
+        # authenticate after basic type checking is done.
+        if isinstance(expected, Auth):
+            auth = expected
+        else:
+            auth = None
 
         # Get actual expected type from within the Parameter structure
 
         # Get actual expected type from within the Parameter structure
-        elif isinstance(expected, Parameter):
+        if isinstance(expected, Parameter):
             min = expected.min
             max = expected.max
             min = expected.min
             max = expected.max
+            nullok = expected.nullok
             expected = expected.type
             expected = expected.type
+        else:
+            min = None
+            max = None
+            nullok = False
 
         expected_type = python_type(expected)
 
 
         expected_type = python_type(expected)
 
+        # If value can be NULL
+        if value is None and nullok:
+            return
+
         # Strings are a special case. Accept either unicode or str
         # types if a string is expected.
         if expected_type in StringTypes and isinstance(value, StringTypes):
         # Strings are a special case. Accept either unicode or str
         # types if a string is expected.
         if expected_type in StringTypes and isinstance(value, StringTypes):
@@ -320,6 +355,11 @@ class Method:
             if max is not None and \
                len(value.encode(self.api.encoding)) > max:
                 raise PLCInvalidArgument, "%s must be at most %d bytes long" % (name, max)
             if max is not None and \
                len(value.encode(self.api.encoding)) > max:
                 raise PLCInvalidArgument, "%s must be at most %d bytes long" % (name, max)
+        elif expected_type in (list, tuple, set):
+            if min is not None and len(value) < min:
+                raise PLCInvalidArgument, "%s must contain at least %d items" % (name, min)
+            if max is not None and len(value) > max:
+                raise PLCInvalidArgument, "%s must contain at most %d items" % (name, max)
         else:
             if min is not None and value < min:
                 raise PLCInvalidArgument, "%s must be > %s" % (name, str(min))
         else:
             if min is not None and value < min:
                 raise PLCInvalidArgument, "%s must be > %s" % (name, str(min))
@@ -327,62 +367,25 @@ class Method:
                 raise PLCInvalidArgument, "%s must be < %s" % (name, str(max))
 
         # If a list with particular types of items is expected
                 raise PLCInvalidArgument, "%s must be < %s" % (name, str(max))
 
         # If a list with particular types of items is expected
-        if isinstance(expected, (list, tuple)):
+        if isinstance(expected, (list, tuple, set)):
             for i in range(len(value)):
                 if i >= len(expected):
             for i in range(len(value)):
                 if i >= len(expected):
-                    i = len(expected) - 1
-                self.type_check(name + "[]", value[i], expected[i])
+                    j = len(expected) - 1
+                else:
+                    j = i
+                self.type_check(name + "[]", value[i], expected[j], args)
 
         # If a struct with particular (or required) types of items is
         # expected.
         elif isinstance(expected, dict):
             for key in value.keys():
                 if key in expected:
 
         # If a struct with particular (or required) types of items is
         # expected.
         elif isinstance(expected, dict):
             for key in value.keys():
                 if key in expected:
-                    self.type_check(name + "['%s']" % key, value[key], expected[key])
+                    self.type_check(name + "['%s']" % key, value[key], expected[key], args)
             for key, subparam in expected.iteritems():
                 if isinstance(subparam, Parameter) and \
             for key, subparam in expected.iteritems():
                 if isinstance(subparam, Parameter) and \
+                   subparam.optional is not None and \
                    not subparam.optional and key not in value.keys():
                     raise PLCInvalidArgument("'%s' not specified" % key, name)
 
                    not subparam.optional and key not in value.keys():
                     raise PLCInvalidArgument("'%s' not specified" % key, name)
 
-def python_type(arg):
-    """
-    Returns the Python type of the specified argument, which may be a
-    Python type, a typed value, or a Parameter.
-    """
-
-    if isinstance(arg, Parameter):
-        arg = arg.type
-
-    if isinstance(arg, type):
-        return arg
-    else:
-        return type(arg)
-
-def xmlrpc_type(arg):
-    """
-    Returns the XML-RPC type of the specified argument, which may be a
-    Python type, a typed value, or a Parameter.
-    """
-
-    arg_type = python_type(arg)
-
-    if arg_type == NoneType:
-        return "nil"
-    elif arg_type == IntType or arg_type == LongType:
-        return "int"
-    elif arg_type == bool:
-        return "boolean"
-    elif arg_type == FloatType:
-        return "double"
-    elif arg_type in StringTypes:
-        return "string"
-    elif arg_type == ListType or arg_type == TupleType:
-        return "array"
-    elif arg_type == DictType:
-        return "struct"
-    elif arg_type == Mixed:
-        # Not really an XML-RPC type but return "mixed" for
-        # documentation purposes.
-        return "mixed"
-    else:
-        raise PLCAPIError, "XML-RPC cannot marshal %s objects" % arg_type
+        if auth is not None:
+            auth.check(self, *args)