Setting tag plcapi-5.4-2
[plcapi.git] / PLC / Sessions.py
index b91ca12..6a03068 100644 (file)
@@ -1,9 +1,12 @@
+from types import StringTypes
 import random
 import base64
 import time
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
+from PLC.Filter import Filter
+from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.Persons import Person, Persons
 from PLC.Nodes import Node, Nodes
@@ -30,56 +33,20 @@ class Session(Row):
 
         return time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(expires))
 
-    def add_person(self, person, commit = True):
-        """
-        Associate person with session.
-        """
-
-        assert 'session_id' in self
-        assert isinstance(person, Person)
-        assert 'person_id' in person
-
-        session_id = self['session_id']
-        person_id = person['person_id']
-
-        self.api.db.do("INSERT INTO person_session (session_id, person_id)" \
-                       " VALUES(%(session_id)s, %(person_id)d)",
-                       locals())
-
-        if commit:
-            self.api.db.commit()
-
-        self['person_id'] = person_id
+    add_person = Row.add_object(Person, 'person_session')
 
     def add_node(self, node, commit = True):
-        """
-        Associate node with session.
-        """
-
-        assert 'session_id' in self
-        assert isinstance(node, Node)
-        assert 'node_id' in node
-
-        session_id = self['session_id']
-        node_id = node['node_id']
-
         # Nodes can have only one session at a time
-        self.api.db.do("DELETE FROM node_session WHERE node_id = %(node_id)d",
-                       locals())
-
-        self.api.db.do("INSERT INTO node_session (session_id, node_id)" \
-                       " VALUES(%(session_id)s, %(node_id)d)",
-                       locals())
+        self.api.db.do("DELETE FROM node_session WHERE node_id = %d" % \
+                       node['node_id'])
 
-        if commit:
-            self.api.db.commit()
-
-        self['node_id'] = node_id
+        add = Row.add_object(Node, 'node_session')
+        add(self, node, commit = commit)
 
     def sync(self, commit = True, insert = None):
         if not self.has_key('session_id'):
             # Before a new session is added, delete expired sessions
-            expired = Sessions(self.api, expires = -int(time.time())).values()
+            expired = Sessions(self.api, expires = -int(time.time()))
             for session in expired:
                 session.delete(commit)
 
@@ -97,14 +64,30 @@ class Sessions(Table):
     Representation of row(s) from the session table in the database.
     """
 
-    def __init__(self, api, session_ids = None, expires = int(time.time())):
-       self.api = api
+    def __init__(self, api, session_filter = None, expires = int(time.time())):
+        Table.__init__(self, api, Session)
 
         sql = "SELECT %s FROM view_sessions WHERE True" % \
               ", ".join(Session.fields)
 
-        if session_ids:
-            sql += " AND session_id IN (%s)" % ", ".join(api.db.quote(session_ids))
+        if session_filter is not None:
+            if isinstance(session_filter, (list, tuple, set)):
+                # Separate the list into integers and strings
+                ints = filter(lambda x: isinstance(x, (int, long)), session_filter)
+                strs = filter(lambda x: isinstance(x, StringTypes), session_filter)
+                session_filter = Filter(Session.fields, {'person_id': ints, 'session_id': strs})
+                sql += " AND (%s) %s" % session_filter.sql(api, "OR")
+            elif isinstance(session_filter, dict):
+                session_filter = Filter(Session.fields, session_filter)
+                sql += " AND (%s) %s" % session_filter.sql(api, "AND")
+            elif isinstance(session_filter, (int, long)):
+                session_filter = Filter(Session.fields, {'person_id': session_filter})
+                sql += " AND (%s) %s" % session_filter.sql(api, "AND")
+            elif isinstance(session_filter, StringTypes):
+                session_filter = Filter(Session.fields, {'session_id': session_filter})
+                sql += " AND (%s) %s" % session_filter.sql(api, "AND")
+            else:
+                raise PLCInvalidArgument, "Wrong session filter"%session_filter
 
         if expires is not None:
             if expires >= 0:
@@ -113,7 +96,4 @@ class Sessions(Table):
                 expires = -expires
                 sql += " AND expires < %(expires)d"
 
-        rows = self.api.db.selectall(sql, locals())
-        for row in rows:
-            self[row['session_id']] = Session(api, row)
+        self.selectall(sql, locals())