bug fixes
[plcapi.git] / PLC / Sessions.py
1 from types import StringTypes
2 import random
3 import base64
4 import time
5
6 from PLC.Faults import *
7 from PLC.Parameter import Parameter
8 from PLC.Debug import profile
9 from PLC.Storage.AlchemyObject import AlchemyObj
10 from PLC.Persons import Person, Persons
11 from PLC.Nodes import Node, Nodes
12
13 class Session(AlchemyObj):
14     """
15     Representation of a row in the sessions table. To use, instantiate
16     with a dict of values.
17     """
18
19     tablename = 'sessions'
20     join_tables = ['person_session', 'node_session']
21     fields = {
22         'session_id': Parameter(str, "Session key", primary_key=True),
23         'person_id': Parameter(int, "Account identifier, if applicable"),
24         'node_id': Parameter(int, "Node identifier, if applicable"),
25         'expires': Parameter(int, "Date and time when session expires, in seconds since UNIX epoch"),
26         }
27
28     def validate_expires(self, expires):
29         if expires < time.time():
30             raise PLCInvalidArgument, "Expiration date must be in the future"
31
32         return time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(expires))
33
34     #add_person = Row.add_object(Person, 'person_session')
35
36     def sync(self, commit = True, insert = None, validate=True):
37         AlchemyObj.sync(commit=commit, validate=validate)
38         if not self.has_key('session_id'):
39             # Before a new session is added, delete expired sessions
40             expired = Sessions(self.api, expires = -int(time.time()))
41             for session in expired:
42                 session.delete(commit)
43
44             # Generate 32 random bytes
45             bytes = random.sample(xrange(0, 256), 32)
46             # Base64 encode their string representation
47             self['session_id'] = base64.b64encode("".join(map(chr, bytes)))
48             # Force insert
49             AlchemyObj.insert(self, dict(self))
50
51 class Sessions(list):
52     """
53     Representation of row(s) from the session table in the database.
54     """
55
56     def __init__(self, api, session_filter = None, expires = int(time.time())):
57         if session_filter is not None:
58             sessions = Session().select()
59         if isinstance(session_filter, (list, tuple, set)):
60             # Separate the list into integers and strings
61             ints = filter(lambda x: isinstance(x, (int, long)), session_filter)
62             strs = filter(lambda x: isinstance(x, StringTypes), session_filter)
63             session_filter = {'person_id': ints, 'session_id': strs}
64             sessions = Session().select(filter=session_filter)
65         elif isinstance(session_filter, dict):
66             sessions = Session().select(filter=session_filter)
67         elif isinstance(session_filter, (int, long)):
68             sessions = Session().select(filter={'person_id': session_filter})
69         elif isinstance(session_filter, StringTypes):
70             sessions = Session().select(filter={'session_id': session_filter})
71         else:
72             raise PLCInvalidArgument, "Wrong session filter"%session_filter
73
74         if expires is not None:
75             if expires >= 0:
76                 sql += " AND expires > %(expires)d"
77             else:
78                 expires = -expires
79                 sql += " AND expires < %(expires)d"
80
81         self.extend(sessions)