X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=tools%2Fupgrade-db.py;h=4c9d1d51511e5e9de7fea3f5cc126751a9401769;hb=refs%2Fheads%2Fremove-xmlrpc;hp=d307045ddcc1aebf694be561468cf7b537c735cb;hpb=6637dca8e74436d68989b9d055ac43a08c0907c6;p=plcapi.git diff --git a/tools/upgrade-db.py b/tools/upgrade-db.py index d307045..4c9d1d5 100755 --- a/tools/upgrade-db.py +++ b/tools/upgrade-db.py @@ -1,6 +1,15 @@ #!/usr/bin/python # -# Tool for upgrading a db based on db version # +# Tool for upgrading/converting a db +# Requirements: +# 1) Databse Schema - schema for the new database you what to upgrade to +# 2) Config File - the config file that describes how to convert the db +# +# Notes: +# 1) Will attempt to convert the db defined in /etc/planetlab/plc_config +# 2) Does not automatically drop archived database. They must be removed +# manually + import sys import os import getopt @@ -45,11 +54,10 @@ except IndexError: print "Error: too few arguments" usage() -database = config['PLC_DB_NAME'] -archived_database = database + "_archived" schema = {} inserts = [] schema_items_ordered = [] +sequences = {} temp_tables = {} @@ -62,11 +70,11 @@ try: db_version_new = upgrade_config['DB_VERSION_NEW'] except IOError, fault: - print "ERROR: upgrade config file (%s) not found. Exiting" % \ + print "Error: upgrade config file (%s) not found. Exiting" % \ (fault) sys.exit(1) except KeyError, fault: - print "ERROR: %s not set in upgrade confing (%s). Exiting" % \ + print "Error: %s not set in upgrade confing (%s). Exiting" % \ (fault, upgrade_config_file) sys.exit(1) @@ -78,27 +86,28 @@ def connect(): database = config['PLC_DB_NAME']) return db -def archive_db(): +def archive_db(database, archived_database): - print "STATUS: archiving old database" - archive_db = "psql template1 postgres -qc " \ - " 'ALTER DATABASE %s RENAME TO %s;';" \ - " createdb -U postgres %s > /dev/null; " % \ - (database, archived_database, database) - exit_status = os.system(archive_db) + archive_db = " dropdb -U postgres %s > /dev/null 2>&1;" \ + " psql template1 postgres -qc " \ + " 'ALTER DATABASE %s RENAME TO %s;';" % \ + (archived_database, database, archived_database) + exit_status = os.system(archive_db) if exit_status: - print "ERROR: unable to archive database. Upgrade failed" + print "Error: unable to archive database. Upgrade failed" sys.exit(1) - print "STATUS: %s has been archived. now named %s" % (database, archived_database) + #print "Status: %s has been archived. now named %s" % (database, archived_database) def encode_utf8(inputfile_name, outputfile_name): - # rewrite a iso-8859-1 encoded file and in utf8 + # rewrite a iso-8859-1 encoded file in utf8 try: inputfile = open(inputfile_name, 'r') outputfile = open(outputfile_name, 'w') for line in inputfile: - outputfile.write(unicode(line, 'iso-8859-1').encode('utf8')) + if line.upper().find('SET CLIENT_ENCODING') > -1: + continue + outputfile.write(unicode(line, 'iso-8859-1').encode('utf8')) inputfile.close() outputfile.close() except: @@ -114,18 +123,18 @@ def create_item_from_schema(item_name): if exit_status: raise Exception except Exception, fault: - print 'ERROR: create %s failed. Check schema.' % item_name + print 'Error: create %s failed. Check schema.' % item_name sys.exit(1) raise fault except KeyError: - print "ERROR: cannot create %s. definition not found in %s" % \ + print "Error: cannot create %s. definition not found in %s" % \ (key, schema_file) return False def fix_row(row, table_name, table_fields): - if table_name in ['nodenetworks']: + if table_name in ['interfaces']: # convert str bwlimit to bps int bwlimit_index = table_fields.index('bwlimit') if isinstance(row[bwlimit_index], int): @@ -143,9 +152,13 @@ def fix_row(row, table_name, table_fields): row[attribute_type_index] = 10016 elif row[attribute_type_index] == 10006: row[attribute_type_index] = 10017 + elif row[attribute_type_index] in [10031, 10033]: + row[attribute_type_index] = 10037 + elif row[attribute_type_index] in [10034, 10035]: + row[attribute_type_index] = 10036 elif table_name in ['slice_attribute_types']: type_id_index = table_fields.index('attribute_type_id') - if row[type_id_index] in [10004, 10006]: + if row[type_id_index] in [10004, 10006, 10031, 10033, 10034, 10035]: return None return row @@ -154,7 +167,7 @@ def fix_table(table, table_name, table_fields): # remove duplicate/redundant primary keys type_id_index = table_fields.index('attribute_type_id') for row in table: - if row[type_id_index] in [10004, 10006]: + if row[type_id_index] in [10004, 10006, 10031, 10033, 10034, 10035]: table.remove(row) return table @@ -171,13 +184,14 @@ def generate_temp_table(table_name, db): try: # get upgrade directions table_def = upgrade_config[table_name].replace('(', '').replace(')', '').split(',') - table_fields, old_fields, required_joins = [], [], set() + table_fields, old_fields, joins, wheres = [], [], set(), set() for field in table_def: field_parts = field.strip().split(':') table_fields.append(field_parts[0]) old_fields.append(field_parts[1]) - if field_parts[2:]: - required_joins.update(set(field_parts[2:])) + if field_parts[2:]: + joins.update(set(filter(lambda x: not x.find('=') > -1, field_parts[2:]))) + wheres.update(set(filter(lambda x: x.find('=') > -1, field_parts[2:]))) # get indices of fields that cannot be null (type, body_list) = schema[table_name] @@ -187,24 +201,26 @@ def generate_temp_table(table_name, db): if body_line.find(field) > -1 and \ body_line.upper().find("NOT NULL") > -1: not_null_indices.append(table_fields.index(field)) - # get index of primary key primary_key_indices = [] for body_line in body_list: if body_line.find("PRIMARY KEY") > -1: primary_key = body_line for field in table_fields: - if primary_key.find(field) > -1: + if primary_key.find(" "+field+" ") > -1: primary_key_indices.append(table_fields.index(field)) - break - + #break + # get old data get_old_data = "SELECT DISTINCT %s FROM %s" % \ (", ".join(old_fields), old_fields[0].split(".")[0]) - for join in required_joins: + for join in joins: get_old_data = get_old_data + " INNER JOIN %s USING (%s) " % \ (join.split('.')[0], join.split('.')[1]) - + if wheres: + get_old_data = get_old_data + " WHERE " + for where in wheres: + get_old_data = get_old_data + " %s" % where cursor.execute(get_old_data) rows = cursor.fetchall() @@ -246,12 +262,12 @@ def generate_temp_table(table_name, db): # (table_name) return False except IndexError, fault: - print "ERROR: error found in upgrade config file. " \ + print "Error: error found in upgrade config file. " \ "check %s configuration. Aborting " % \ (table_name) sys.exit(1) except: - print "ERROR: configuration for %s doesnt match db schema. " \ + print "Error: configuration for %s doesnt match db schema. " \ " Aborting" % (table_name) try: db.rollback() @@ -269,16 +285,17 @@ try: cursor.execute("SELECT relname from pg_class where relname = 'plc_db_version'") rows = cursor.fetchall() if not rows: - print "WARNING: current db has no version. Unable to validate config file." + print "Warning: current db has no version. Unable to validate config file." else: cursor.execute("SELECT version FROM plc_db_version") rows = cursor.fetchall() - - if rows[0][0] == db_version_new: - print "STATUS: Versions are the same. No upgrade necessary." + if not rows or not rows[0]: + print "Warning: current db has no version. Unable to validate config file." + elif rows[0][0] == db_version_new: + print "Status: Versions are the same. No upgrade necessary." sys.exit() elif not rows[0][0] == db_version_previous: - print "STATUS: DB_VERSION_PREVIOUS in config file (%s) does not" \ + print "Stauts: DB_VERSION_PREVIOUS in config file (%s) does not" \ " match current db version %d" % (upgrade_config_file, rows[0][0]) sys.exit() else: @@ -288,35 +305,33 @@ try: # check db encoding sql = " SELECT pg_catalog.pg_encoding_to_char(d.encoding)" \ " FROM pg_catalog.pg_database d " \ - " LEFT JOIN pg_catalog.pg_user u ON d.datdba = u.usesysid " \ " WHERE d.datname = '%s' " % config['PLC_DB_NAME'] cursor.execute(sql) rows = cursor.fetchall() - if rows[0][0] not in ['UTF8']: - print "WARNING: db encoding is not utf8. Must convert" + if rows[0][0] not in ['UTF8', 'UNICODE']: + print "WARNING: db encoding is not utf8. Attempting to encode" db.close() # generate db dump dump_file = '%s/dump.sql' % (temp_dir) dump_file_encoded = dump_file + ".utf8" - dump_cmd = 'pg_dump -i %s -U %s -f %s > /dev/null 2>&1' % \ - (config['PLC_DB_NAME'], config['PLC_DB_USER'], dump_file) - print dump_cmd + dump_cmd = 'pg_dump -i %s -U postgres -f %s > /dev/null 2>&1' % \ + (config['PLC_DB_NAME'], dump_file) if os.system(dump_cmd): print "ERROR: during db dump. Exiting." sys.exit(1) # encode dump to utf8 - print "STATUS: encoding database dump" + print "Status: encoding database dump" encode_utf8(dump_file, dump_file_encoded) # archive original db - archive_db() + archive_db(config['PLC_DB_NAME'], config['PLC_DB_NAME']+'_sqlascii_archived') # create a utf8 database and upload encoded data - recreate_cmd = 'createdb -U %s -E UTF8 %s > /dev/null 2>&1; ' \ + recreate_cmd = 'createdb -U postgres -E UTF8 %s > /dev/null; ' \ 'psql -a -U %s %s < %s > /dev/null 2>&1;' % \ - (config['PLC_DB_USER'], config['PLC_DB_NAME'], \ - config['PLC_DB_USER'], config['PLC_DB_NAME'], dump_file_encoded) - print "STATUS: recreating database as utf8" + (config['PLC_DB_NAME'], config['PLC_DB_USER'], \ + config['PLC_DB_NAME'], dump_file_encoded) + print "Status: recreating database as utf8" if os.system(recreate_cmd): - print "ERROR: database encoding failed. Aborting" + print "Error: database encoding failed. Aborting" sys.exit(1) os.remove(dump_file_encoded) @@ -335,9 +350,14 @@ try: lines = file.readlines() while index < len(lines): line = lines[index] + if line.find("--") > -1: + line_parts = line.split("--") + line = line_parts[0] # find all created objects if line.startswith("CREATE"): line_parts = line.split(" ") + if line_parts[1:3] == ['OR', 'REPLACE']: + line_parts = line_parts[2:] item_type = line_parts[1] item_name = line_parts[2] schema_items_ordered.append(item_name) @@ -351,12 +371,18 @@ try: while index < len(lines): index = index + 1 nextline =lines[index] + if nextline.find("--") > -1: + new_line_parts = nextline.split("--") + nextline = new_line_parts[0] + # look for any sequences + if item_type in ['TABLE'] and nextline.find('serial') > -1: + sequences[item_name] = nextline.strip().split()[0] fields.append(nextline) if nextline.find(";") >= 0: break schema[item_name] = (item_type, fields) else: - print "ERROR: unknown type %s" % item_type + print "Error: unknown type %s" % item_type elif line.startswith("INSERT"): inserts.append(line) index = index + 1 @@ -364,7 +390,7 @@ try: except: raise -print "STATUS: generating temp tables" +print "Status: generating temp tables" # generate all temp tables for key in schema_items_ordered: (type, body_list) = schema[key] @@ -375,11 +401,11 @@ for key in schema_items_ordered: cursor.close() db.close() -print "STATUS: archiving database" -archive_db() +print "Status: archiving database" +archive_db(config['PLC_DB_NAME'], config['PLC_DB_NAME']+'_archived') +os.system('createdb -U postgres -E UTF8 %s > /dev/null; ' % config['PLC_DB_NAME']) - -print "STATUS: upgrading database" +print "Status: upgrading database" # attempt to create and load all items from schema into temp db try: for key in schema_items_ordered: @@ -387,31 +413,45 @@ try: create_item_from_schema(key) if type == 'TABLE': if upgrade_config.has_key(key): + # attempt to populate with temp table data table_def = upgrade_config[key].replace('(', '').replace(')', '').split(',') table_fields = [field.strip().split(':')[0] for field in table_def] insert_cmd = "psql %s %s -c " \ " 'COPY %s (%s) FROM stdin;' < %s " % \ - (database, config['PLC_DB_USER'], key, ", ".join(table_fields), temp_tables[key] ) + (config['PLC_DB_NAME'], config['PLC_DB_USER'], key, + ", ".join(table_fields), temp_tables[key] ) exit_status = os.system(insert_cmd) if exit_status: - print "ERROR: upgrade %s failed" % key - raise + print "Error: upgrade %s failed" % key + sys.exit(1) + # update the primary key sequence + if sequences.has_key(key): + sequence = key +"_"+ sequences[key] +"_seq" + update_seq = "psql %s %s -c " \ + " \"select setval('%s', max(%s)) FROM %s;\" > /dev/null" % \ + (config['PLC_DB_NAME'], config['PLC_DB_USER'], sequence, + sequences[key], key) + exit_status = os.system(update_seq) + if exit_status: + print "Error: sequence %s update failed" % sequence + sys.exit(1) else: # check if there are any insert stmts in schema for this table - print "WARNING: %s has no temp data file. Unable to populate with old data" % key + print "Warning: %s has no temp data file. Unable to populate with old data" % key for insert_stmt in inserts: if insert_stmt.find(key) > -1: insert_cmd = 'psql %s postgres -qc "%s;" > /dev/null 2>&1' % \ - (database, insert_stmt) + (config['PLC_DB_NAME'], insert_stmt) os.system(insert_cmd) except: - print "ERROR: failed to populate db. Unarchiving original database and aborting" - undo_command = "dropdb -U postgres %s; psql template1 postgres -qc" \ + print "Error: failed to populate db. Unarchiving original database and aborting" + undo_command = "dropdb -U postgres %s > /dev/null; psql template1 postgres -qc" \ " 'ALTER DATABASE %s RENAME TO %s;'; > /dev/null" % \ - (database, archived_database, database) + (config['PLC_DB_NAME'], config['PLC_DB_NAME']+'_archived', config['PLC_DB_NAME']) os.system(undo_command) - remove_temp_tables() + #remove_temp_tables() + raise -remove_temp_tables() +#remove_temp_tables() print "upgrade complete"