Fix version output when missing.
[plcapi.git] / PLC / Sessions.py
1 # $Id$
2 # $URL$
3 from types import StringTypes
4 import random
5 import base64
6 import time
7
8 from PLC.Faults import *
9 from PLC.Parameter import Parameter
10 from PLC.Filter import Filter
11 from PLC.Debug import profile
12 from PLC.Table import Row, Table
13 from PLC.Persons import Person, Persons
14 from PLC.Nodes import Node, Nodes
15
16 class Session(Row):
17     """
18     Representation of a row in the sessions table. To use, instantiate
19     with a dict of values.
20     """
21
22     table_name = 'sessions'
23     primary_key = 'session_id'
24     join_tables = ['person_session', 'node_session']
25     fields = {
26         'session_id': Parameter(str, "Session key"),
27         'person_id': Parameter(int, "Account identifier, if applicable"),
28         'node_id': Parameter(int, "Node identifier, if applicable"),
29         'expires': Parameter(int, "Date and time when session expires, in seconds since UNIX epoch"),
30         }
31
32     def validate_expires(self, expires):
33         if expires < time.time():
34             raise PLCInvalidArgument, "Expiration date must be in the future"
35
36         return time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(expires))
37
38     add_person = Row.add_object(Person, 'person_session')
39
40     def add_node(self, node, commit = True):
41         # Nodes can have only one session at a time
42         self.api.db.do("DELETE FROM node_session WHERE node_id = %d" % \
43                        node['node_id'])
44
45         add = Row.add_object(Node, 'node_session')
46         add(self, node, commit = commit)
47
48     def sync(self, commit = True, insert = None):
49         if not self.has_key('session_id'):
50             # Before a new session is added, delete expired sessions
51             expired = Sessions(self.api, expires = -int(time.time()))
52             for session in expired:
53                 session.delete(commit)
54
55             # Generate 32 random bytes
56             bytes = random.sample(xrange(0, 256), 32)
57             # Base64 encode their string representation
58             self['session_id'] = base64.b64encode("".join(map(chr, bytes)))
59             # Force insert
60             insert = True
61
62         Row.sync(self, commit, insert)
63
64 class Sessions(Table):
65     """
66     Representation of row(s) from the session table in the database.
67     """
68
69     def __init__(self, api, session_filter = None, expires = int(time.time())):
70         Table.__init__(self, api, Session)
71
72         sql = "SELECT %s FROM view_sessions WHERE True" % \
73               ", ".join(Session.fields)
74
75         if session_filter is not None:
76             if isinstance(session_filter, (list, tuple, set)):
77                 # Separate the list into integers and strings
78                 ints = filter(lambda x: isinstance(x, (int, long)), session_filter)
79                 strs = filter(lambda x: isinstance(x, StringTypes), session_filter)
80                 session_filter = Filter(Session.fields, {'person_id': ints, 'session_id': strs})
81                 sql += " AND (%s) %s" % session_filter.sql(api, "OR")
82             elif isinstance(session_filter, dict):
83                 session_filter = Filter(Session.fields, session_filter)
84                 sql += " AND (%s) %s" % session_filter.sql(api, "AND")
85             elif isinstance(session_filter, (int, long)):
86                 session_filter = Filter(Session.fields, {'person_id': session_filter})
87                 sql += " AND (%s) %s" % session_filter.sql(api, "AND")
88             elif isinstance(session_filter, StringTypes):
89                 session_filter = Filter(Session.fields, {'session_id': session_filter})
90                 sql += " AND (%s) %s" % session_filter.sql(api, "AND")
91             else:
92                 raise PLCInvalidArgument, "Wrong session filter"%session_filter
93
94         if expires is not None:
95             if expires >= 0:
96                 sql += " AND expires > %(expires)d"
97             else:
98                 expires = -expires
99                 sql += " AND expires < %(expires)d"
100
101         self.selectall(sql, locals())