- re-enable return_fields specification
[plcapi.git] / PLC / Sessions.py
1 import random
2 import base64
3 import time
4
5 from PLC.Faults import *
6 from PLC.Parameter import Parameter
7 from PLC.Table import Row, Table
8 from PLC.Persons import Person, Persons
9 from PLC.Nodes import Node, Nodes
10
11 class Session(Row):
12     """
13     Representation of a row in the sessions table. To use, instantiate
14     with a dict of values.
15     """
16
17     table_name = 'sessions'
18     primary_key = 'session_id'
19     join_tables = ['person_session', 'node_session']
20     fields = {
21         'session_id': Parameter(str, "Session key"),
22         'person_id': Parameter(int, "Account identifier, if applicable"),
23         'node_id': Parameter(int, "Node identifier, if applicable"),
24         'expires': Parameter(int, "Date and time when session expires, in seconds since UNIX epoch"),
25         }
26
27     def validate_expires(self, expires):
28         if expires < time.time():
29             raise PLCInvalidArgument, "Expiration date must be in the future"
30
31         return time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(expires))
32
33     def add_person(self, person, commit = True):
34         """
35         Associate person with session.
36         """
37
38         assert 'session_id' in self
39         assert isinstance(person, Person)
40         assert 'person_id' in person
41
42         session_id = self['session_id']
43         person_id = person['person_id']
44
45         self.api.db.do("INSERT INTO person_session (session_id, person_id)" \
46                        " VALUES(%(session_id)s, %(person_id)d)",
47                        locals())
48
49         if commit:
50             self.api.db.commit()
51
52         self['person_id'] = person_id
53
54     def add_node(self, node, commit = True):
55         """
56         Associate node with session.
57         """
58
59         assert 'session_id' in self
60         assert isinstance(node, Node)
61         assert 'node_id' in node
62
63         session_id = self['session_id']
64         node_id = node['node_id']
65
66         # Nodes can have only one session at a time
67         self.api.db.do("DELETE FROM node_session WHERE node_id = %(node_id)d",
68                        locals())
69
70         self.api.db.do("INSERT INTO node_session (session_id, node_id)" \
71                        " VALUES(%(session_id)s, %(node_id)d)",
72                        locals())
73
74         if commit:
75             self.api.db.commit()
76
77         self['node_id'] = node_id
78
79     def sync(self, commit = True, insert = None):
80         if not self.has_key('session_id'):
81             # Before a new session is added, delete expired sessions
82             expired = Sessions(self.api, expires = -int(time.time()))
83             for session in expired:
84                 session.delete(commit)
85
86             # Generate 32 random bytes
87             bytes = random.sample(xrange(0, 256), 32)
88             # Base64 encode their string representation
89             self['session_id'] = base64.b64encode("".join(map(chr, bytes)))
90             # Force insert
91             insert = True
92
93         Row.sync(self, commit, insert)
94
95 class Sessions(Table):
96     """
97     Representation of row(s) from the session table in the database.
98     """
99
100     def __init__(self, api, session_ids = None, expires = int(time.time())):
101         Table.__init__(self, api, Session)
102
103         sql = "SELECT %s FROM view_sessions WHERE True" % \
104               ", ".join(Session.fields)
105
106         if session_ids:
107             sql += " AND session_id IN (%s)" % ", ".join(map(api.db.quote, session_ids))
108
109         if expires is not None:
110             if expires >= 0:
111                 sql += " AND expires > %(expires)d"
112             else:
113                 expires = -expires
114                 sql += " AND expires < %(expires)d"
115
116         self.selectall(sql, locals())