a4b82cc2ba9efc0d003b369098cf5788480cb0fb
[plcapi.git] / PLC / Sessions.py
1 # $Id$
2 from types import StringTypes
3 import random
4 import base64
5 import time
6
7 from PLC.Faults import *
8 from PLC.Parameter import Parameter
9 from PLC.Filter import Filter
10 from PLC.Debug import profile
11 from PLC.Table import Row, Table
12 from PLC.Persons import Person, Persons
13 from PLC.Nodes import Node, Nodes
14
15 class Session(Row):
16     """
17     Representation of a row in the sessions table. To use, instantiate
18     with a dict of values.
19     """
20
21     table_name = 'sessions'
22     primary_key = 'session_id'
23     join_tables = ['person_session', 'node_session']
24     fields = {
25         'session_id': Parameter(str, "Session key"),
26         'person_id': Parameter(int, "Account identifier, if applicable"),
27         'node_id': Parameter(int, "Node identifier, if applicable"),
28         'expires': Parameter(int, "Date and time when session expires, in seconds since UNIX epoch"),
29         }
30
31     def validate_expires(self, expires):
32         if expires < time.time():
33             raise PLCInvalidArgument, "Expiration date must be in the future"
34
35         return time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(expires))
36
37     add_person = Row.add_object(Person, 'person_session')
38
39     def add_node(self, node, commit = True):
40         # Nodes can have only one session at a time
41         self.api.db.do("DELETE FROM node_session WHERE node_id = %d" % \
42                        node['node_id'])
43
44         add = Row.add_object(Node, 'node_session')
45         add(self, node, commit = commit)
46
47     def sync(self, commit = True, insert = None):
48         if not self.has_key('session_id'):
49             # Before a new session is added, delete expired sessions
50             expired = Sessions(self.api, expires = -int(time.time()))
51             for session in expired:
52                 session.delete(commit)
53
54             # Generate 32 random bytes
55             bytes = random.sample(xrange(0, 256), 32)
56             # Base64 encode their string representation
57             self['session_id'] = base64.b64encode("".join(map(chr, bytes)))
58             # Force insert
59             insert = True
60
61         Row.sync(self, commit, insert)
62
63 class Sessions(Table):
64     """
65     Representation of row(s) from the session table in the database.
66     """
67
68     def __init__(self, api, session_filter = None, expires = int(time.time())):
69         Table.__init__(self, api, Session)
70
71         sql = "SELECT %s FROM view_sessions WHERE True" % \
72               ", ".join(Session.fields)
73
74         if session_filter is not None:
75             if isinstance(session_filter, (list, tuple, set)):
76                 # Separate the list into integers and strings
77                 ints = filter(lambda x: isinstance(x, (int, long)), session_filter)
78                 strs = filter(lambda x: isinstance(x, StringTypes), session_filter)
79                 session_filter = Filter(Session.fields, {'person_id': ints, 'session_id': strs})
80                 sql += " AND (%s) %s" % session_filter.sql(api, "OR")
81             elif isinstance(session_filter, dict):
82                 session_filter = Filter(Session.fields, session_filter)
83                 sql += " AND (%s) %s" % session_filter.sql(api, "AND")
84
85         if expires is not None:
86             if expires >= 0:
87                 sql += " AND expires > %(expires)d"
88             else:
89                 expires = -expires
90                 sql += " AND expires < %(expires)d"
91
92         self.selectall(sql, locals())