Setting tag plcapi-7.0-0
[plcapi.git] / PLC / Sessions.py
1 import random
2 import base64
3 import time
4
5 from PLC.Faults import *
6 from PLC.Parameter import Parameter
7 from PLC.Filter import Filter
8 from PLC.Debug import profile
9 from PLC.Table import Row, Table
10 from PLC.Persons import Person, Persons
11 from PLC.Nodes import Node, Nodes
12
13 class Session(Row):
14     """
15     Representation of a row in the sessions table. To use, instantiate
16     with a dict of values.
17     """
18
19     table_name = 'sessions'
20     primary_key = 'session_id'
21     join_tables = ['person_session', 'node_session']
22     fields = {
23         'session_id': Parameter(str, "Session key"),
24         'person_id': Parameter(int, "Account identifier, if applicable"),
25         'node_id': Parameter(int, "Node identifier, if applicable"),
26         'expires': Parameter(int, "Date and time when session expires, in seconds since UNIX epoch"),
27         }
28
29     def validate_expires(self, expires):
30         if expires < time.time():
31             raise PLCInvalidArgument("Expiration date must be in the future")
32
33         return time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(expires))
34
35     add_person = Row.add_object(Person, 'person_session')
36
37     def add_node(self, node, commit = True):
38         # Nodes can have only one session at a time
39         self.api.db.do("DELETE FROM node_session WHERE node_id = %d" % \
40                        node['node_id'])
41
42         add = Row.add_object(Node, 'node_session')
43         add(self, node, commit = commit)
44
45     def sync(self, commit = True, insert = None):
46         if 'session_id' not in self:
47             # Before a new session is added, delete expired sessions
48             expired = Sessions(self.api, expires = -int(time.time()))
49             for session in expired:
50                 session.delete(commit)
51
52             # Generate 32 random bytes
53             int8s = random.sample(range(0, 256), 32)
54             # Base64 encode their string representation
55             self['session_id'] = base64.b64encode(bytes(int8s)).decode()
56             # Force insert
57             insert = True
58
59         Row.sync(self, commit, insert)
60
61 class Sessions(Table):
62     """
63     Representation of row(s) from the session table in the database.
64     """
65
66     def __init__(self, api, session_filter = None, expires = int(time.time())):
67         Table.__init__(self, api, Session)
68
69         sql = "SELECT %s FROM view_sessions WHERE True" % \
70               ", ".join(Session.fields)
71
72         if session_filter is not None:
73             if isinstance(session_filter, (list, tuple, set)):
74                 # Separate the list into integers and strings
75                 ints = [x for x in session_filter if isinstance(x, int)]
76                 strs = [x for x in session_filter if isinstance(x, str)]
77                 session_filter = Filter(Session.fields, {'person_id': ints, 'session_id': strs})
78                 sql += " AND (%s) %s" % session_filter.sql(api, "OR")
79             elif isinstance(session_filter, dict):
80                 session_filter = Filter(Session.fields, session_filter)
81                 sql += " AND (%s) %s" % session_filter.sql(api, "AND")
82             elif isinstance(session_filter, int):
83                 session_filter = Filter(Session.fields, {'person_id': session_filter})
84                 sql += " AND (%s) %s" % session_filter.sql(api, "AND")
85             elif isinstance(session_filter, str):
86                 session_filter = Filter(Session.fields, {'session_id': session_filter})
87                 sql += " AND (%s) %s" % session_filter.sql(api, "AND")
88             else:
89                 raise PLCInvalidArgument("Wrong session filter"%session_filter)
90
91         if expires is not None:
92             if expires >= 0:
93                 sql += " AND expires > %(expires)d"
94             else:
95                 expires = -expires
96                 sql += " AND expires < %(expires)d"
97
98         self.selectall(sql, locals())