Setting tag plcapi-5.4-2
[plcapi.git] / PLC / Sessions.py
index b91ca12..6a03068 100644 (file)
@@ -1,9 +1,12 @@
+from types import StringTypes
 import random
 import base64
 import time
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
 import random
 import base64
 import time
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
+from PLC.Filter import Filter
+from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.Persons import Person, Persons
 from PLC.Nodes import Node, Nodes
 from PLC.Table import Row, Table
 from PLC.Persons import Person, Persons
 from PLC.Nodes import Node, Nodes
@@ -30,56 +33,20 @@ class Session(Row):
 
         return time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(expires))
 
 
         return time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(expires))
 
-    def add_person(self, person, commit = True):
-        """
-        Associate person with session.
-        """
-
-        assert 'session_id' in self
-        assert isinstance(person, Person)
-        assert 'person_id' in person
-
-        session_id = self['session_id']
-        person_id = person['person_id']
-
-        self.api.db.do("INSERT INTO person_session (session_id, person_id)" \
-                       " VALUES(%(session_id)s, %(person_id)d)",
-                       locals())
-
-        if commit:
-            self.api.db.commit()
-
-        self['person_id'] = person_id
+    add_person = Row.add_object(Person, 'person_session')
 
     def add_node(self, node, commit = True):
 
     def add_node(self, node, commit = True):
-        """
-        Associate node with session.
-        """
-
-        assert 'session_id' in self
-        assert isinstance(node, Node)
-        assert 'node_id' in node
-
-        session_id = self['session_id']
-        node_id = node['node_id']
-
         # Nodes can have only one session at a time
         # Nodes can have only one session at a time
-        self.api.db.do("DELETE FROM node_session WHERE node_id = %(node_id)d",
-                       locals())
-
-        self.api.db.do("INSERT INTO node_session (session_id, node_id)" \
-                       " VALUES(%(session_id)s, %(node_id)d)",
-                       locals())
+        self.api.db.do("DELETE FROM node_session WHERE node_id = %d" % \
+                       node['node_id'])
 
 
-        if commit:
-            self.api.db.commit()
-
-        self['node_id'] = node_id
+        add = Row.add_object(Node, 'node_session')
+        add(self, node, commit = commit)
 
     def sync(self, commit = True, insert = None):
         if not self.has_key('session_id'):
             # Before a new session is added, delete expired sessions
 
     def sync(self, commit = True, insert = None):
         if not self.has_key('session_id'):
             # Before a new session is added, delete expired sessions
-            expired = Sessions(self.api, expires = -int(time.time())).values()
+            expired = Sessions(self.api, expires = -int(time.time()))
             for session in expired:
                 session.delete(commit)
 
             for session in expired:
                 session.delete(commit)
 
@@ -97,14 +64,30 @@ class Sessions(Table):
     Representation of row(s) from the session table in the database.
     """
 
     Representation of row(s) from the session table in the database.
     """
 
-    def __init__(self, api, session_ids = None, expires = int(time.time())):
-       self.api = api
+    def __init__(self, api, session_filter = None, expires = int(time.time())):
+        Table.__init__(self, api, Session)
 
         sql = "SELECT %s FROM view_sessions WHERE True" % \
               ", ".join(Session.fields)
 
 
         sql = "SELECT %s FROM view_sessions WHERE True" % \
               ", ".join(Session.fields)
 
-        if session_ids:
-            sql += " AND session_id IN (%s)" % ", ".join(api.db.quote(session_ids))
+        if session_filter is not None:
+            if isinstance(session_filter, (list, tuple, set)):
+                # Separate the list into integers and strings
+                ints = filter(lambda x: isinstance(x, (int, long)), session_filter)
+                strs = filter(lambda x: isinstance(x, StringTypes), session_filter)
+                session_filter = Filter(Session.fields, {'person_id': ints, 'session_id': strs})
+                sql += " AND (%s) %s" % session_filter.sql(api, "OR")
+            elif isinstance(session_filter, dict):
+                session_filter = Filter(Session.fields, session_filter)
+                sql += " AND (%s) %s" % session_filter.sql(api, "AND")
+            elif isinstance(session_filter, (int, long)):
+                session_filter = Filter(Session.fields, {'person_id': session_filter})
+                sql += " AND (%s) %s" % session_filter.sql(api, "AND")
+            elif isinstance(session_filter, StringTypes):
+                session_filter = Filter(Session.fields, {'session_id': session_filter})
+                sql += " AND (%s) %s" % session_filter.sql(api, "AND")
+            else:
+                raise PLCInvalidArgument, "Wrong session filter"%session_filter
 
         if expires is not None:
             if expires >= 0:
 
         if expires is not None:
             if expires >= 0:
@@ -113,7 +96,4 @@ class Sessions(Table):
                 expires = -expires
                 sql += " AND expires < %(expires)d"
 
                 expires = -expires
                 sql += " AND expires < %(expires)d"
 
-        rows = self.api.db.selectall(sql, locals())
-        for row in rows:
-            self[row['session_id']] = Session(api, row)
+        self.selectall(sql, locals())