Merge commit 'origin/master'
[plcapi.git] / aspects / ratelimitaspects.py
index 8be22e5..ee2178b 100644 (file)
@@ -25,6 +25,12 @@ class BaseRateLimit(object):
 
         self.whitelist = []
 
+    def log(self, line)
+        log = open("/var/log/plc_ratelimit.log", "a")
+        date = datetime.now().strftime("%d/%m/%y %H:%M")
+        log.write("%s - %s\n" % (date, line))
+        log.flush()
+
     def before(self, wobj, data, *args, **kwargs):
         # ratelimit_128.112.139.115_201011091532 = 1
         # ratelimit_128.112.139.115_201011091533 = 14
@@ -34,19 +40,27 @@ class BaseRateLimit(object):
 
         api_method_name = wobj.name
         api_method_source = wobj.source
+        api_method_caller = args[0]["Username"]
 
         if api_method_source == None or api_method_source[0] == self.config.PLC_API_IP or api_method_source[0] in self.whitelist:
             return
 
+        if api_method_caller == None:
+            self.log("%s called with Username = None" % api_method_source[0])
+            return
+
         mc = memcache.Client(["%s:11211" % self.config.PLC_API_HOST])
         now = datetime.now()
-        current_key = "%s_%s_%s" % (self.prefix, api_method_source[0], now.strftime("%Y%m%d%H%M"))
+        current_key = "%s_%s_%s_%s" % (self.prefix, api_method_caller, api_method_source[0], now.strftime("%Y%m%d%H%M"))
 
-        keys_to_check = ["%s_%s_%s" % (self.prefix, api_method_source[0], (now - timedelta(minutes = minute)).strftime("%Y%m%d%H%M")) for minute in range(self.minutes + 1)]
+        keys_to_check = ["%s_%s_%s_%s" % (self.prefix, api_method_caller, api_method_source[0], (now - timedelta(minutes = minute)).strftime("%Y%m%d%H%M")) for minute in range(self.minutes + 1)]
 
         try:
-            mc.incr(current_key)
+            value = mc.incr(current_key)
         except ValueError:
+            value = None
+
+        if value == None:
             mc.set(current_key, 1, time=self.expire_after)
 
         result = mc.get_multi(keys_to_check)
@@ -55,10 +69,7 @@ class BaseRateLimit(object):
             total_requests += result[i]
 
         if total_requests > self.requests:
-            log = open("/var/log/plc_api_ratelimit.log", "a")
-            date = datetime.now().strftime("%d/%m/%y %H:%M")
-            log.write("%s - %s\n" % (date, api_method_source[0]))
-            log.flush()
+            self.log("%s - %s" % (api_method_source[0], api_method_caller))
             raise PLCPermissionDenied, "Maximum allowed number of API calls exceeded"
 
     def after(self, wobj, data, *args, **kwargs):