added Row class with validate methods
[sfa.git] / sfa / util / genitable.py
1 # genitable.py
2 #
3 # implements support for geni records stored in db tables
4 #
5 # TODO: Use existing PLC database methods? or keep this separate?
6
7 ### $Id$
8 ### $URL$
9
10 import report
11 import  pgdb
12 from pg import DB, ProgrammingError
13 from sfa.util.PostgreSQL import *
14 from sfa.trust.gid import *
15 from sfa.util.record import *
16 from sfa.util.debug import *
17 from sfa.util.config import *
18 from sfa.util.filter import *
19
20 class Row(dict):
21
22     # Set this to the name of the table that stores the row.
23     # e.g. table_name = "nodes"
24     table_name = None
25
26     # Set this to the name of the primary key of the table. It is
27     # assumed that the this key is a sequence if it is not set when
28     # sync() is called.
29     # e.g. primary_key="record_id"
30     primary_key = None
31
32     # Set this to the names of tables that reference this table's
33     # primary key.
34     join_tables = []
35     
36     def validate(self):
37         """
38         Validates values. Will validate a value with a custom function
39         if a function named 'validate_[key]' exists.
40         """
41         # Warn about mandatory fields
42         # XX TODO: Support checking for mandatory fields later
43         #mandatory_fields = self.db.fields(self.table_name, notnull = True, hasdef = False)
44         #for field in mandatory_fields:
45         #    if not self.has_key(field) or self[field] is None:
46         #        raise GeniInvalidArgument, field + " must be specified and cannot be unset in class %s"%self.__class__.__name__
47
48         # Validate values before committing
49         for key, value in self.iteritems():
50             if value is not None and hasattr(self, 'validate_' + key):
51                 validate = getattr(self, 'validate_' + key)
52                 self[key] = validate(value)
53
54
55     def validate_timestamp(self, timestamp, check_future = False):
56         """
57         Validates the specified GMT timestamp string (must be in
58         %Y-%m-%d %H:%M:%S format) or number (seconds since UNIX epoch,
59         i.e., 1970-01-01 00:00:00 GMT). If check_future is True,
60         raises an exception if timestamp is not in the future. Returns
61         a GMT timestamp string.
62         """
63
64         time_format = "%Y-%m-%d %H:%M:%S"
65
66         if isinstance(timestamp, StringTypes):
67             # calendar.timegm() is the inverse of time.gmtime()
68             timestamp = calendar.timegm(time.strptime(timestamp, time_format))
69
70         # Human readable timestamp string
71         human = time.strftime(time_format, time.gmtime(timestamp))
72
73         if check_future and timestamp < time.time():
74             raise GeniInvalidArgument, "'%s' not in the future" % human
75
76         return human
77
78 class GeniTable(list):
79
80     GENI_TABLE_PREFIX = "sfa"
81
82     def __init__(self, record_filter = None):
83
84         # pgsql doesn't like table names with "." in them, to replace it with "$"
85         self.tablename = GeniTable.GENI_TABLE_PREFIX
86         self.config = Config()
87         self.db = PostgreSQL(self.config)
88         # establish a connection to the pgsql server
89         cninfo = self.config.get_plc_dbinfo()     
90         self.cnx = DB(cninfo['dbname'], cninfo['address'], port=cninfo['port'], user=cninfo['user'], passwd=cninfo['password'])
91
92         if record_filter:
93             records = self.find(record_filter)
94             for record in reocrds:
95                 self.append(record)             
96
97     def exists(self):
98         tableList = self.cnx.get_tables()
99         if 'public.' + self.tablename in tableList:
100             return True
101         if 'public."' + self.tablename + '"' in tableList:
102             return True
103         return False
104
105     def db_fields(self, obj=None):
106         
107         db_fields = self.db.fields(self.GENI_TABLE_PREFIX)
108         return dict( [ (key,value) for (key, value) in obj.items() \
109                         if key in db_fields and
110                         self.is_writable(key, value, GeniRecord.fields)] )      
111
112     @staticmethod
113     def is_writable (key,value,dict):
114         # if not mentioned, assume it's writable (e.g. deleted ...)
115         if key not in dict: return True
116         # if mentioned but not linked to a Parameter object, idem
117         if not isinstance(dict[key], Parameter): return True
118         # if not marked ro, it's writable
119         if not dict[key].ro: return True
120
121         return False
122
123
124     def create(self):
125         
126         querystr = "CREATE TABLE " + self.tablename + " ( \
127                 record_id serial PRIMARY KEY , \
128                 hrn text NOT NULL, \
129                 authority text NOT NULL, \
130                 peer_authority text, \
131                 gid text, \
132                 type text NOT NULL, \
133                 pointer integer, \
134                 date_created timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP, \
135                 last_updated timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP);"
136         template = "CREATE INDEX %s_%s_idx ON %s (%s);"
137         indexes = [template % ( self.tablename, field, self.tablename, field) \
138                    for field in ['hrn', 'type', 'authority', 'peer_authority', 'pointer']]
139         # IF EXISTS doenst exist in postgres < 8.2
140         try:
141             self.cnx.query('DROP TABLE IF EXISTS ' + self.tablename)
142         except ProgrammingError:
143             try:
144                 self.cnx.query('DROP TABLE ' + self.tablename)
145             except ProgrammingError:
146                 pass
147          
148         self.cnx.query(querystr)
149         for index in indexes:
150             self.cnx.query(index)
151
152     def remove(self, record):
153         query_str = "DELETE FROM %s WHERE record_id = %s" % (self.tablename, record['record_id']) 
154         self.cnx.query(query_str)
155
156     def insert(self, record):
157         db_fields = self.db_fields(record)
158         keys = db_fields.keys()
159         values = [self.db.param(key, value) for (key, value) in db_fields.items()]
160         query_str = "INSERT INTO " + self.tablename + \
161                        "(" + ",".join(keys) + ") " + \
162                        "VALUES(" + ",".join(values) + ")"
163         self.db.do(query_str, db_fields)
164         self.db.commit()
165         result = self.find({'hrn': record['hrn'], 'type': record['type'], 'peer_authority': record['peer_authority']})
166         if not result:
167             record_id = None
168         elif isinstance(result, list):
169             record_id = result[0]['record_id']
170         else:
171             record_id = result['record_id']
172
173         return record_id
174
175     def update(self, record):
176         db_fields = self.db_fields(record)
177         keys = db_fields.keys()
178         values = [self.db.param(key, value) for (key, value) in db_fields.items()]
179         columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
180         query_str = "UPDATE %s SET %s WHERE record_id = %s" % \
181                     (self.tablename, ", ".join(columns), record['record_id'])
182         self.db.do(query_str, db_fields)
183         self.db.commit()
184
185     def quote(self, value):
186         """
187         Returns quoted version of the specified value.
188         """
189
190         # The pgdb._quote function is good enough for general SQL
191         # quoting, except for array types.
192         if isinstance(value, (list, tuple, set)):
193             return "ARRAY[%s]" % ", ".join(map, self.quote, value)
194         else:
195             return pgdb._quote(value)
196
197     def find(self, record_filter = None):
198         sql = "SELECT * FROM %s WHERE True " % self.tablename
199         
200         if isinstance(record_filter, (list, tuple, set)):
201             ints = filter(lambda x: isinstance(x, (int, long)), record_filter)
202             strs = filter(lambda x: isinstance(x, StringTypes), record_filter)
203             record_filter = Filter(GeniRecord.all_fields, {'record_id': ints, 'hrn': strs})
204             sql += "AND (%s) %s " % record_filter.sql("OR") 
205         elif isinstance(record_filter, dict):
206             record_filter = Filter(GeniRecord.all_fields, record_filter)        
207             sql += " AND (%s) %s" % record_filter.sql("AND")
208         elif isinstance(record_filter, StringTypes):
209             record_filter = Filter(GeniRecord.all_fields, {'hrn':[record_filter]})    
210             sql += " AND (%s) %s" % record_filter.sql("AND")
211         elif isinstance(record_filter, int):
212             record_filter = Filter(GeniRecord.all_fields, {'record_id':[record_filter]})    
213             sql += " AND (%s) %s" % record_filter.sql("AND")
214
215         results = self.cnx.query(sql).dictresult()
216         if isinstance(results, dict):
217             results = [results]
218         return results
219
220     def findObjects(self, record_filter = None):
221         
222         results = self.find(record_filter) 
223         result_rec_list = []
224         for result in results:
225             if result['type'] in ['authority']:
226                 result_rec_list.append(AuthorityRecord(dict=result))
227             elif result['type'] in ['node']:
228                 result_rec_list.append(NodeRecord(dict=result))
229             elif result['type'] in ['slice']:
230                 result_rec_list.append(SliceRecord(dict=result))
231             elif result['type'] in ['user']:
232                 result_rec_list.append(UserRecord(dict=result))
233             else:
234                 result_rec_list.append(GeniRecord(dict=result))
235         return result_rec_list
236
237
238     def drop(self):
239         try:
240             self.cnx.query('DROP TABLE IF EXISTS ' + self.tablename)
241         except ProgrammingError:
242             try:
243                 self.cnx.query('DROP TABLE ' + self.tablename)
244             except ProgrammingError:
245                 pass
246     
247     @staticmethod
248     def geni_records_purge(cninfo):
249
250         cnx = DB(cninfo['dbname'], cninfo['address'], 
251                  port=cninfo['port'], user=cninfo['user'], passwd=cninfo['password'])
252         tableList = cnx.get_tables()
253         for table in tableList:
254             if table.startswith(GeniTable.GENI_TABLE_PREFIX) or \
255                     table.startswith('public.' + GeniTable.GENI_TABLE_PREFIX) or \
256                     table.startswith('public."' + GeniTable.GENI_TABLE_PREFIX):
257                 report.trace("dropping table " + table)
258                 cnx.query("DROP TABLE " + table)