X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=PLC%2FTable.py;h=236596588e644c7db1a1b7005d320aef09fe50b1;hb=19d4a01ccf66af9e00914351b3eacd5fc880f988;hp=eeac6c1cfe0471c0f7c11c394bfc71a0d008cf67;hpb=a12987e598653f7fb2b8000d1826dab0f8e154a1;p=plcapi.git diff --git a/PLC/Table.py b/PLC/Table.py index eeac6c1..2365965 100644 --- a/PLC/Table.py +++ b/PLC/Table.py @@ -1,11 +1,12 @@ -# $Id$ from types import StringTypes, IntType, LongType import time import calendar +from PLC.Timestamp import Timestamp from PLC.Faults import * from PLC.Parameter import Parameter + class Row(dict): """ Representation of a row in a database table. To use, optionally @@ -32,10 +33,6 @@ class Row(dict): # sync(). fields = {} - # Set this to the name of the view that gathers the row and its relations - # e.g. view_name = "view_nodes" - view_name = None - # The name of the view that extends objects with tags # e.g. view_tags_name = "view_node_tags" view_tags_name = None @@ -46,6 +43,11 @@ class Row(dict): def __init__(self, api, fields = {}): dict.__init__(self, fields) self.api = api + # run the class_init initializer once + cls=self.__class__ + if not hasattr(cls,'class_inited'): + cls.class_init (api) + cls.class_inited=True # actual value does not matter def validate(self): """ @@ -64,58 +66,38 @@ class Row(dict): if value is not None and hasattr(self, 'validate_' + key): validate = getattr(self, 'validate_' + key) self[key] = validate(value) - - def separate_types(self, items): - """ - Separate a list of different typed objects. - Return a list for each type (ints, strs and dicts) - """ - - if isinstance(items, (list, tuple, set)): - ints = filter(lambda x: isinstance(x, (int, long)), items) - strs = filter(lambda x: isinstance(x, StringTypes), items) - dicts = filter(lambda x: isinstance(x, dict), items) - return (ints, strs, dicts) - else: - raise PLCInvalidArgument, "Can only separate list types" - - - def associate(self, *args): - """ - Provides a means for high level api calls to associate objects - using low level calls. - """ - - if len(args) < 3: - raise PLCInvalidArgumentCount, "auth, field, value must be specified" - elif hasattr(self, 'associate_' + args[1]): - associate = getattr(self, 'associate_'+args[1]) - associate(*args) - else: - raise PLCInvalidArguemnt, "No such associate function associate_%s" % args[1] - def validate_timestamp(self, timestamp, check_future = False): + def separate_types(self, items): """ - Validates the specified GMT timestamp string (must be in - %Y-%m-%d %H:%M:%S format) or number (seconds since UNIX epoch, - i.e., 1970-01-01 00:00:00 GMT). If check_future is True, - raises an exception if timestamp is not in the future. Returns - a GMT timestamp string. + Separate a list of different typed objects. + Return a list for each type (ints, strs and dicts) """ - time_format = "%Y-%m-%d %H:%M:%S" + if isinstance(items, (list, tuple, set)): + ints = filter(lambda x: isinstance(x, (int, long)), items) + strs = filter(lambda x: isinstance(x, StringTypes), items) + dicts = filter(lambda x: isinstance(x, dict), items) + return (ints, strs, dicts) + else: + raise PLCInvalidArgument, "Can only separate list types" - if isinstance(timestamp, StringTypes): - # calendar.timegm() is the inverse of time.gmtime() - timestamp = calendar.timegm(time.strptime(timestamp, time_format)) - # Human readable timestamp string - human = time.strftime(time_format, time.gmtime(timestamp)) + def associate(self, *args): + """ + Provides a means for high level api calls to associate objects + using low level calls. + """ - if check_future and timestamp < time.time(): - raise PLCInvalidArgument, "'%s' not in the future" % human + if len(args) < 3: + raise PLCInvalidArgumentCount, "auth, field, value must be specified" + elif hasattr(self, 'associate_' + args[1]): + associate = getattr(self, 'associate_'+args[1]) + associate(*args) + else: + raise PLCInvalidArguemnt, "No such associate function associate_%s" % args[1] - return human + def validate_timestamp (self, timestamp): + return Timestamp.sql_validate(timestamp) def add_object(self, classobj, join_table, columns = None): """ @@ -135,7 +117,7 @@ class Row(dict): assert isinstance(obj, classobj) assert isinstance(obj, Row) assert obj.primary_key in obj - assert join_table in obj.join_tables + assert join_table in obj.join_tables # By default, just insert the primary keys of each object # into the join table. @@ -153,7 +135,7 @@ class Row(dict): if commit: self.api.db.commit() - + return add add_object = classmethod(add_object) @@ -168,7 +150,7 @@ class Row(dict): """ Disassociate from the specified object. """ - + assert isinstance(self, Row) assert self.primary_key in self assert join_table in self.join_tables @@ -176,10 +158,10 @@ class Row(dict): assert isinstance(obj, Row) assert obj.primary_key in obj assert join_table in obj.join_tables - + self_id = self[self.primary_key] obj_id = obj[obj.primary_key] - + self.api.db.do("DELETE FROM %s WHERE %s = %s AND %s = %s" % \ (join_table, self.primary_key, self.api.db.param('self_id', self_id), @@ -211,65 +193,114 @@ class Row(dict): for this object, and are not marked as a read-only Parameter. """ - if obj is None: obj = self + if obj is None: + obj = self db_fields = self.api.db.fields(self.table_name) - return dict ( [ (key,value) for (key,value) in obj.items() + return dict ( [ (key, value) for (key, value) in obj.items() if key in db_fields and - Row.is_writable(key,value,self.fields) ] ) + Row.is_writable(key, value, self.fields) ] ) def tag_fields (self, obj=None): """ Return the fields of obj that are mentioned in tags """ if obj is None: obj=self - - return dict ( [ (key,value) for (key,value) in obj.iteritems() + + return dict ( [ (key,value) for (key,value) in obj.iteritems() if key in self.tags and Row.is_writable(key,value,self.tags) ] ) - - # takes in input a list of columns, returns three lists - # fields, tags, rejected + + # takes as input a list of columns, sort native fields from tags + # returns 2 dicts and one list : fields, tags, rejected @classmethod def parse_columns (cls, columns): - (fields,tags,rejected)=({},{},{}) + (fields,tags,rejected)=({},{},[]) for column in columns: if column in cls.fields: fields[column]=cls.fields[column] elif column in cls.tags: tags[column]=cls.tags[column] else: rejected.append(column) return (fields,tags,rejected) + # compute the 'accepts' part of a method, from a list of column names, and a fields dict + # use exclude=True to exclude the column names instead + # typically accepted_fields (Node.fields,['hostname','model',...]) + @staticmethod + def accepted_fields (update_columns, fields_dict, exclude=False): + result={} + for (k,v) in fields_dict.iteritems(): + if (not exclude and k in update_columns) or (exclude and k not in update_columns): + result[k]=v + return result + + # filter out user-provided fields that are not part of the declared acceptance list + # keep it separate from split_fields for simplicity + # typically check_fields (,{'hostname':Parameter(str,...),'model':Parameter(..)...}) + @staticmethod + def check_fields (user_dict, accepted_fields): +# avoid the simple, but silent, version +# return dict ([ (k,v) for (k,v) in user_dict.items() if k in accepted_fields ]) + result={} + for (k,v) in user_dict.items(): + if k in accepted_fields: result[k]=v + else: raise PLCInvalidArgument ('Trying to set/change unaccepted key %s'%k) + return result + + # given a dict (typically passed to an Update method), we check and sort + # them against a list of dicts, e.g. [Node.fields, Node.related_fields] + # return is a list that contains n+1 dicts, last one has the rejected fields + @staticmethod + def split_fields (fields, dicts): + result=[] + for x in dicts: result.append({}) + rejected={} + for (field,value) in fields.iteritems(): + found=False + for i in range(len(dicts)): + candidate_dict=dicts[i] + if field in candidate_dict.keys(): + result[i][field]=value + found=True + break + if not found: rejected[field]=value + result.append(rejected) + return result + + ### class initialization : create tag-dependent cross view if needed @classmethod def tagvalue_view_name (cls, tagname): return "tagvalue_view_%s_%s"%(cls.primary_key,tagname) @classmethod - def tagvalue_view_create (cls,tagname): + def tagvalue_view_create_sql (cls,tagname): """ - returns an SQL sentence that creates a view named after the primary_key and tagname, + returns a SQL sentence that creates a view named after the primary_key and tagname, with 2 columns - (*) column 1: name=self.primary_key - (*) column 2: name=tagname value=tagvalue + (*) column 1: primary_key + (*) column 2: actual tag value, renamed into tagname """ - if not cls.view_tags_name: return "" + if not cls.view_tags_name: + raise Exception, 'WARNING: class %s needs to set view_tags_name'%cls.__name__ table_name=cls.table_name primary_key=cls.primary_key view_tags_name=cls.view_tags_name tagvalue_view_name=cls.tagvalue_view_name(tagname) return 'CREATE OR REPLACE VIEW %(tagvalue_view_name)s ' \ - 'as SELECT %(table_name)s.%(primary_key)s,%(view_tags_name)s.tagvalue as "%(tagname)s" ' \ + 'as SELECT %(table_name)s.%(primary_key)s,%(view_tags_name)s.value as "%(tagname)s" ' \ 'from %(table_name)s right join %(view_tags_name)s using (%(primary_key)s) ' \ 'WHERE tagname = \'%(tagname)s\';'%locals() @classmethod - def tagvalue_views_create (cls): + def class_init (cls,api): + cls.tagvalue_views_create (api) + + @classmethod + def tagvalue_views_create (cls,api): if not cls.tags: return - sql = [] - for (type,type_dict) in cls.tags.iteritems(): - for (tagname,details) in type_dict.iteritems(): - sql.append(cls.tagvalue_view_create (tagname)) - return sql + for tagname in cls.tags.keys(): + api.db.do(cls.tagvalue_view_create_sql (tagname)) + api.db.commit() def __eq__(self, y): """ @@ -283,13 +314,15 @@ class Row(dict): y = self.db_fields(y) return dict.__eq__(x, y) - def sync(self, commit = True, insert = None): + # validate becomes optional on sept. 2010 + # we find it useful to use DeletePerson on duplicated entries + def sync(self, commit = True, insert = None, validate=True): """ Flush changes back to the database. """ # Validate all specified fields - self.validate() + if validate: self.validate() # Filter out fields that cannot be set or updated directly db_fields = self.db_fields() @@ -304,26 +337,26 @@ class Row(dict): if not self.has_key(self.primary_key) or \ keys == [self.primary_key] or \ insert is True: - - # If primary key id is a serial int and it isnt included, get next id - if self.fields[self.primary_key].type in (IntType, LongType) and \ - self.primary_key not in self: - pk_id = self.api.db.next_id(self.table_name, self.primary_key) - self[self.primary_key] = pk_id - db_fields[self.primary_key] = pk_id - keys = db_fields.keys() - values = [self.api.db.param(key, value) for (key, value) in db_fields.items()] + + # If primary key id is a serial int and it isnt included, get next id + if self.fields[self.primary_key].type in (IntType, LongType) and \ + self.primary_key not in self: + pk_id = self.api.db.next_id(self.table_name, self.primary_key) + self[self.primary_key] = pk_id + db_fields[self.primary_key] = pk_id + keys = db_fields.keys() + values = [self.api.db.param(key, value) for (key, value) in db_fields.items()] # Insert new row sql = "INSERT INTO %s (%s) VALUES (%s)" % \ (self.table_name, ", ".join(keys), ", ".join(values)) else: # Update existing row columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)] - sql = "UPDATE %s SET " % self.table_name + \ - ", ".join(columns) + \ - " WHERE %s = %s" % \ - (self.primary_key, - self.api.db.param(self.primary_key, self[self.primary_key])) + sql = "UPDATE {} SET {} WHERE {} = {}"\ + .format(self.table_name, + ", ".join(columns), + self.primary_key, + self.api.db.param(self.primary_key, self[self.primary_key])) self.api.db.do(sql, db_fields) @@ -369,7 +402,7 @@ class Table(list): tag_columns={} else: (columns,tag_columns,rejected) = classobj.parse_columns(columns) - if not columns: + if not columns and not tag_columns: raise PLCInvalidArgument, "No valid return fields specified for class %s"%classobj.__name__ if rejected: raise PLCInvalidArgument, "unknown column(s) specified %r in %s"%(rejected,classobj.__name__)