Setting tag plcapi-5.4-2
[plcapi.git] / PLC / Sessions.py
1 from types import StringTypes
2 import random
3 import base64
4 import time
5
6 from PLC.Faults import *
7 from PLC.Parameter import Parameter
8 from PLC.Filter import Filter
9 from PLC.Debug import profile
10 from PLC.Table import Row, Table
11 from PLC.Persons import Person, Persons
12 from PLC.Nodes import Node, Nodes
13
14 class Session(Row):
15     """
16     Representation of a row in the sessions table. To use, instantiate
17     with a dict of values.
18     """
19
20     table_name = 'sessions'
21     primary_key = 'session_id'
22     join_tables = ['person_session', 'node_session']
23     fields = {
24         'session_id': Parameter(str, "Session key"),
25         'person_id': Parameter(int, "Account identifier, if applicable"),
26         'node_id': Parameter(int, "Node identifier, if applicable"),
27         'expires': Parameter(int, "Date and time when session expires, in seconds since UNIX epoch"),
28         }
29
30     def validate_expires(self, expires):
31         if expires < time.time():
32             raise PLCInvalidArgument, "Expiration date must be in the future"
33
34         return time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(expires))
35
36     add_person = Row.add_object(Person, 'person_session')
37
38     def add_node(self, node, commit = True):
39         # Nodes can have only one session at a time
40         self.api.db.do("DELETE FROM node_session WHERE node_id = %d" % \
41                        node['node_id'])
42
43         add = Row.add_object(Node, 'node_session')
44         add(self, node, commit = commit)
45
46     def sync(self, commit = True, insert = None):
47         if not self.has_key('session_id'):
48             # Before a new session is added, delete expired sessions
49             expired = Sessions(self.api, expires = -int(time.time()))
50             for session in expired:
51                 session.delete(commit)
52
53             # Generate 32 random bytes
54             bytes = random.sample(xrange(0, 256), 32)
55             # Base64 encode their string representation
56             self['session_id'] = base64.b64encode("".join(map(chr, bytes)))
57             # Force insert
58             insert = True
59
60         Row.sync(self, commit, insert)
61
62 class Sessions(Table):
63     """
64     Representation of row(s) from the session table in the database.
65     """
66
67     def __init__(self, api, session_filter = None, expires = int(time.time())):
68         Table.__init__(self, api, Session)
69
70         sql = "SELECT %s FROM view_sessions WHERE True" % \
71               ", ".join(Session.fields)
72
73         if session_filter is not None:
74             if isinstance(session_filter, (list, tuple, set)):
75                 # Separate the list into integers and strings
76                 ints = filter(lambda x: isinstance(x, (int, long)), session_filter)
77                 strs = filter(lambda x: isinstance(x, StringTypes), session_filter)
78                 session_filter = Filter(Session.fields, {'person_id': ints, 'session_id': strs})
79                 sql += " AND (%s) %s" % session_filter.sql(api, "OR")
80             elif isinstance(session_filter, dict):
81                 session_filter = Filter(Session.fields, session_filter)
82                 sql += " AND (%s) %s" % session_filter.sql(api, "AND")
83             elif isinstance(session_filter, (int, long)):
84                 session_filter = Filter(Session.fields, {'person_id': session_filter})
85                 sql += " AND (%s) %s" % session_filter.sql(api, "AND")
86             elif isinstance(session_filter, StringTypes):
87                 session_filter = Filter(Session.fields, {'session_id': session_filter})
88                 sql += " AND (%s) %s" % session_filter.sql(api, "AND")
89             else:
90                 raise PLCInvalidArgument, "Wrong session filter"%session_filter
91
92         if expires is not None:
93             if expires >= 0:
94                 sql += " AND expires > %(expires)d"
95             else:
96                 expires = -expires
97                 sql += " AND expires < %(expires)d"
98
99         self.selectall(sql, locals())