====
[plcapi.git] / PLC / Sessions.py
index 8c9701b..6a03068 100644 (file)
@@ -1,9 +1,12 @@
+from types import StringTypes
 import random
 import base64
 import time
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
 import random
 import base64
 import time
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
+from PLC.Filter import Filter
+from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.Persons import Person, Persons
 from PLC.Nodes import Node, Nodes
 from PLC.Table import Row, Table
 from PLC.Persons import Person, Persons
 from PLC.Nodes import Node, Nodes
@@ -37,8 +40,8 @@ class Session(Row):
         self.api.db.do("DELETE FROM node_session WHERE node_id = %d" % \
                        node['node_id'])
 
         self.api.db.do("DELETE FROM node_session WHERE node_id = %d" % \
                        node['node_id'])
 
-        add = Row.add_object(Node, 'person_session')
-        add(self, node, commit)
+        add = Row.add_object(Node, 'node_session')
+        add(self, node, commit = commit)
 
     def sync(self, commit = True, insert = None):
         if not self.has_key('session_id'):
 
     def sync(self, commit = True, insert = None):
         if not self.has_key('session_id'):
@@ -61,14 +64,30 @@ class Sessions(Table):
     Representation of row(s) from the session table in the database.
     """
 
     Representation of row(s) from the session table in the database.
     """
 
-    def __init__(self, api, session_ids = None, expires = int(time.time())):
-       Table.__init__(self, api, Session)
+    def __init__(self, api, session_filter = None, expires = int(time.time())):
+        Table.__init__(self, api, Session)
 
         sql = "SELECT %s FROM view_sessions WHERE True" % \
               ", ".join(Session.fields)
 
 
         sql = "SELECT %s FROM view_sessions WHERE True" % \
               ", ".join(Session.fields)
 
-        if session_ids:
-            sql += " AND session_id IN (%s)" % ", ".join(map(api.db.quote, session_ids))
+        if session_filter is not None:
+            if isinstance(session_filter, (list, tuple, set)):
+                # Separate the list into integers and strings
+                ints = filter(lambda x: isinstance(x, (int, long)), session_filter)
+                strs = filter(lambda x: isinstance(x, StringTypes), session_filter)
+                session_filter = Filter(Session.fields, {'person_id': ints, 'session_id': strs})
+                sql += " AND (%s) %s" % session_filter.sql(api, "OR")
+            elif isinstance(session_filter, dict):
+                session_filter = Filter(Session.fields, session_filter)
+                sql += " AND (%s) %s" % session_filter.sql(api, "AND")
+            elif isinstance(session_filter, (int, long)):
+                session_filter = Filter(Session.fields, {'person_id': session_filter})
+                sql += " AND (%s) %s" % session_filter.sql(api, "AND")
+            elif isinstance(session_filter, StringTypes):
+                session_filter = Filter(Session.fields, {'session_id': session_filter})
+                sql += " AND (%s) %s" % session_filter.sql(api, "AND")
+            else:
+                raise PLCInvalidArgument, "Wrong session filter"%session_filter
 
         if expires is not None:
             if expires >= 0:
 
         if expires is not None:
             if expires >= 0: