trying out the hint from github issue
[plcapi.git] / tools / upgrade-db.py
index d307045..4c9d1d5 100755 (executable)
@@ -1,6 +1,15 @@
 #!/usr/bin/python
 #
-# Tool for upgrading a db based on db version #
+# Tool for upgrading/converting a db
+# Requirements:
+# 1) Database Schema - schema for the new database you want to upgrade to
+# 2) Config File - the config file that describes how to convert the db
+#
+# Notes:
+# 1) Will attempt to convert the db defined in /etc/planetlab/plc_config
+# 2) Does not automatically drop archived databases. They must be removed
+#    manually
+
 import sys
 import os
 import getopt
@@ -45,11 +54,10 @@ except IndexError:
        print "Error: too few arguments"
         usage()
 
-database = config['PLC_DB_NAME']
-archived_database = database + "_archived"
 schema = {}
 inserts = []
 schema_items_ordered = []
+sequences = {}
 temp_tables = {}
 
 
@@ -62,11 +70,11 @@ try:
         db_version_new = upgrade_config['DB_VERSION_NEW']
 
 except IOError, fault:
-        print "ERROR: upgrade config file (%s) not found. Exiting" % \
+        print "Error: upgrade config file (%s) not found. Exiting" % \
                (fault)
         sys.exit(1) 
 except KeyError, fault:
-       print "ERROR: %s not set in upgrade confing (%s). Exiting" % \
+       print "Error: %s not set in upgrade confing (%s). Exiting" % \
                (fault, upgrade_config_file)
        sys.exit(1)
 
@@ -78,27 +86,28 @@ def connect():
                  database = config['PLC_DB_NAME'])     
        return db
 
-def archive_db():
+def archive_db(database, archived_database):
 
-       print "STATUS: archiving old database"
-        archive_db = "psql template1 postgres -qc " \
-                    " 'ALTER DATABASE %s RENAME TO %s;';" \
-                    " createdb -U postgres %s > /dev/null; " % \
-                     (database, archived_database, database)
-        exit_status = os.system(archive_db)
+        archive_db = " dropdb -U postgres %s > /dev/null 2>&1;" \
+                    " psql template1 postgres -qc " \
+                     " 'ALTER DATABASE %s RENAME TO %s;';" % \
+                     (archived_database, database, archived_database)
+       exit_status = os.system(archive_db)
         if exit_status:
-                print "ERROR: unable to archive database. Upgrade failed"
+                print "Error: unable to archive database. Upgrade failed"
                 sys.exit(1)
-        print "STATUS: %s has been archived. now named %s" % (database, archived_database)
+        #print "Status: %s has been archived. now named %s" % (database, archived_database)
 
 
 def encode_utf8(inputfile_name, outputfile_name):
-       # rewrite a iso-8859-1 encoded file and in utf8
+       # rewrite a iso-8859-1 encoded file in utf8
        try:
                inputfile = open(inputfile_name, 'r')
                outputfile = open(outputfile_name, 'w')
                for line in inputfile:
-                       outputfile.write(unicode(line, 'iso-8859-1').encode('utf8'))
+                       if line.upper().find('SET CLIENT_ENCODING') > -1:
+                               continue
+                       outputfile.write(unicode(line, 'iso-8859-1').encode('utf8'))
                inputfile.close()
                outputfile.close()              
        except:
@@ -114,18 +123,18 @@ def create_item_from_schema(item_name):
                if exit_status:
                        raise Exception
         except Exception, fault:
-                print 'ERROR: create %s failed. Check schema.' % item_name
+                print 'Error: create %s failed. Check schema.' % item_name
                sys.exit(1)
                raise fault
 
         except KeyError:
-                print "ERROR: cannot create %s. definition not found in %s" % \
+                print "Error: cannot create %s. definition not found in %s" % \
                         (key, schema_file)
                 return False
 
 def fix_row(row, table_name, table_fields):
 
-       if table_name in ['nodenetworks']:
+       if table_name in ['interfaces']:
                # convert str bwlimit to bps int
                bwlimit_index = table_fields.index('bwlimit')
                if isinstance(row[bwlimit_index], int):
@@ -143,9 +152,13 @@ def fix_row(row, table_name, table_fields):
                        row[attribute_type_index] = 10016
                elif row[attribute_type_index] == 10006:
                        row[attribute_type_index] = 10017
+               elif row[attribute_type_index] in [10031, 10033]:
+                       row[attribute_type_index] = 10037
+               elif row[attribute_type_index] in [10034, 10035]:
+                       row[attribute_type_index] = 10036
        elif table_name in ['slice_attribute_types']:
                type_id_index = table_fields.index('attribute_type_id')
-               if row[type_id_index] in [10004, 10006]:
+               if row[type_id_index] in [10004, 10006, 10031, 10033, 10034, 10035]:
                        return None
        return row
        
@@ -154,7 +167,7 @@ def fix_table(table, table_name, table_fields):
                # remove duplicate/redundant primary keys
                type_id_index = table_fields.index('attribute_type_id')
                for row in table:
-                       if row[type_id_index] in [10004, 10006]:
+                       if row[type_id_index] in [10004, 10006, 10031, 10033, 10034, 10035]:
                                table.remove(row)
        return table
 
@@ -171,13 +184,14 @@ def generate_temp_table(table_name, db):
         try:
                 # get upgrade directions
                 table_def = upgrade_config[table_name].replace('(', '').replace(')', '').split(',')
-                table_fields, old_fields, required_joins = [], [], set()
+                table_fields, old_fields, joins, wheres = [], [], set(), set()
                 for field in table_def:
                         field_parts = field.strip().split(':')
                         table_fields.append(field_parts[0])
                         old_fields.append(field_parts[1])
-                        if field_parts[2:]:
-                               required_joins.update(set(field_parts[2:]))
+                        if field_parts[2:]:    
+                               joins.update(set(filter(lambda x: not x.find('=') > -1, field_parts[2:])))
+                               wheres.update(set(filter(lambda x: x.find('=') > -1, field_parts[2:])))
                
                # get indices of fields that cannot be null
                (type, body_list) = schema[table_name]
@@ -187,24 +201,26 @@ def generate_temp_table(table_name, db):
                                if body_line.find(field) > -1 and \
                                   body_line.upper().find("NOT NULL") > -1:
                                        not_null_indices.append(table_fields.index(field))
-
                # get index of primary key
                primary_key_indices = []
                for body_line in body_list:
                        if body_line.find("PRIMARY KEY") > -1:
                                primary_key = body_line
                                for field in table_fields:
-                                       if primary_key.find(field) > -1:
+                                       if primary_key.find(" "+field+" ") > -1:
                                                primary_key_indices.append(table_fields.index(field))
-                               break
-
+                               #break
+       
                 # get old data
                 get_old_data = "SELECT DISTINCT %s FROM %s" % \
                       (", ".join(old_fields), old_fields[0].split(".")[0])
-                for join in required_joins:
+                for join in joins:
                         get_old_data = get_old_data + " INNER JOIN %s USING (%s) " % \
                                        (join.split('.')[0], join.split('.')[1])
-
+               if wheres:      
+                       get_old_data = get_old_data + " WHERE " 
+               for where in wheres:
+                       get_old_data = get_old_data + " %s" % where 
                 cursor.execute(get_old_data)
                 rows = cursor.fetchall()
 
@@ -246,12 +262,12 @@ def generate_temp_table(table_name, db):
                 #       (table_name)
                 return False
         except IndexError, fault:
-                print "ERROR: error found in upgrade config file. " \
+                print "Error: error found in upgrade config file. " \
                       "check %s configuration. Aborting " % \
                       (table_name)
                 sys.exit(1)
         except:
-                print "ERROR: configuration for %s doesnt match db schema. " \
+                print "Error: configuration for %s doesnt match db schema. " \
                      " Aborting" % (table_name)
                 try:
                         db.rollback()
@@ -269,16 +285,17 @@ try:
        cursor.execute("SELECT relname from pg_class where relname = 'plc_db_version'")
        rows = cursor.fetchall()
        if not rows:
-               print "WARNING: current db has no version. Unable to validate config file."
+               print "Warning: current db has no version. Unable to validate config file."
        else:
                cursor.execute("SELECT version FROM plc_db_version")
                rows = cursor.fetchall()
-
-               if rows[0][0] == db_version_new:
-                               print "STATUS: Versions are the same. No upgrade necessary."
+               if not rows or not rows[0]:
+                       print "Warning: current db has no version. Unable to validate config file."
+               elif rows[0][0] == db_version_new:
+                               print "Status: Versions are the same. No upgrade necessary."
                        sys.exit()
                elif not rows[0][0] == db_version_previous:
-                       print "STATUS: DB_VERSION_PREVIOUS in config file (%s) does not" \
+                       print "Stauts: DB_VERSION_PREVIOUS in config file (%s) does not" \
                              " match current db version %d" % (upgrade_config_file, rows[0][0])
                        sys.exit()
                else:
@@ -288,35 +305,33 @@ try:
        # check db encoding
        sql = " SELECT pg_catalog.pg_encoding_to_char(d.encoding)" \
              " FROM pg_catalog.pg_database d " \
-             " LEFT JOIN pg_catalog.pg_user u ON d.datdba = u.usesysid " \
              " WHERE d.datname = '%s' " % config['PLC_DB_NAME']
        cursor.execute(sql)
        rows = cursor.fetchall()
-       if rows[0][0] not in ['UTF8']:
-               print "WARNING: db encoding is not utf8. Must convert"
+       if rows[0][0] not in ['UTF8', 'UNICODE']:
+               print "WARNING: db encoding is not utf8. Attempting to encode"
                db.close()
                # generate db dump
                dump_file = '%s/dump.sql' % (temp_dir)
                dump_file_encoded = dump_file + ".utf8"
-               dump_cmd = 'pg_dump -i %s -U %s -f %s > /dev/null 2>&1' % \
-                          (config['PLC_DB_NAME'], config['PLC_DB_USER'], dump_file)
-               print dump_cmd
+               dump_cmd = 'pg_dump -i %s -U postgres -f %s > /dev/null 2>&1' % \
+                          (config['PLC_DB_NAME'], dump_file)
                if os.system(dump_cmd):
                        print "ERROR: during db dump. Exiting."
                        sys.exit(1)
                # encode dump to utf8
-               print "STATUS: encoding database dump"
+               print "Status: encoding database dump"
                encode_utf8(dump_file, dump_file_encoded)
                # archive original db
-               archive_db()
+               archive_db(config['PLC_DB_NAME'], config['PLC_DB_NAME']+'_sqlascii_archived')
                # create a utf8 database and upload encoded data
-               recreate_cmd = 'createdb -U %s -E UTF8 %s > /dev/null 2>&1; ' \
+               recreate_cmd = 'createdb -U postgres -E UTF8 %s > /dev/null; ' \
                               'psql -a -U  %s %s < %s > /dev/null 2>&1;'   % \
-                         (config['PLC_DB_USER'], config['PLC_DB_NAME'], \
-                          config['PLC_DB_USER'], config['PLC_DB_NAME'], dump_file_encoded) 
-               print "STATUS: recreating database as utf8"
+                         (config['PLC_DB_NAME'], config['PLC_DB_USER'], \
+                          config['PLC_DB_NAME'], dump_file_encoded) 
+               print "Status: recreating database as utf8"
                if os.system(recreate_cmd):
-                       print "ERROR: database encoding failed. Aborting"
+                       print "Error: database encoding failed. Aborting"
                        sys.exit(1)
                
                os.remove(dump_file_encoded)
@@ -335,9 +350,14 @@ try:
        lines = file.readlines()
        while index < len(lines):
                line = lines[index] 
+               if line.find("--") > -1:
+                       line_parts = line.split("--")
+                        line = line_parts[0]
                # find all created objects
                if line.startswith("CREATE"):
                        line_parts = line.split(" ")
+                       if line_parts[1:3] == ['OR', 'REPLACE']:
+                               line_parts = line_parts[2:]
                        item_type = line_parts[1]
                        item_name = line_parts[2]
                        schema_items_ordered.append(item_name)
@@ -351,12 +371,18 @@ try:
                                while index < len(lines):
                                        index = index + 1
                                        nextline =lines[index]
+                                       if nextline.find("--") > -1:
+                                                new_line_parts = nextline.split("--")
+                                                nextline = new_line_parts[0]
+                                       # look for any sequences
+                                       if item_type in ['TABLE'] and nextline.find('serial') > -1:
+                                               sequences[item_name] = nextline.strip().split()[0]
                                        fields.append(nextline)
                                        if nextline.find(";") >= 0:
                                                break
                                schema[item_name] = (item_type, fields)
                        else:
-                               print "ERROR: unknown type %s" % item_type
+                               print "Error: unknown type %s" % item_type
                elif line.startswith("INSERT"):
                        inserts.append(line)
                index = index + 1
@@ -364,7 +390,7 @@ try:
 except:
        raise
 
-print "STATUS: generating temp tables"
+print "Status: generating temp tables"
 # generate all temp tables
 for key in schema_items_ordered:
        (type, body_list) = schema[key]
@@ -375,11 +401,11 @@ for key in schema_items_ordered:
 cursor.close()
 db.close()
 
-print "STATUS: archiving database"
-archive_db()
+print "Status: archiving database"
+archive_db(config['PLC_DB_NAME'], config['PLC_DB_NAME']+'_archived')
+os.system('createdb -U postgres -E UTF8 %s > /dev/null; ' % config['PLC_DB_NAME'])
 
-
-print "STATUS: upgrading database"
+print "Status: upgrading database"
 # attempt to create and load all items from schema into temp db
 try:
        for key in schema_items_ordered:
@@ -387,31 +413,45 @@ try:
                create_item_from_schema(key)
                if type == 'TABLE':
                        if upgrade_config.has_key(key):                         
+                               # attempt to populate with temp table data
                                table_def = upgrade_config[key].replace('(', '').replace(')', '').split(',')
                                table_fields = [field.strip().split(':')[0] for field in table_def]
                                insert_cmd = "psql %s %s -c " \
                                              " 'COPY %s (%s) FROM stdin;' < %s " % \
-                                             (database, config['PLC_DB_USER'], key, ", ".join(table_fields), temp_tables[key] )
+                                             (config['PLC_DB_NAME'], config['PLC_DB_USER'], key, 
+                                             ", ".join(table_fields), temp_tables[key] )
                                exit_status = os.system(insert_cmd)
                                if exit_status:
-                                       print "ERROR: upgrade %s failed" % key
-                                       raise
+                                       print "Error: upgrade %s failed" % key
+                                       sys.exit(1)
+                               # update the primary key sequence
+                               if sequences.has_key(key):
+                                       sequence = key +"_"+ sequences[key] +"_seq"
+                                       update_seq = "psql %s %s -c " \
+                                            " \"select setval('%s', max(%s)) FROM %s;\" > /dev/null" % \
+                                            (config['PLC_DB_NAME'], config['PLC_DB_USER'], sequence, 
+                                             sequences[key], key)
+                                       exit_status = os.system(update_seq)
+                                       if exit_status:
+                                               print "Error: sequence %s update failed" % sequence
+                                               sys.exit(1)
                        else:
                                # check if there are any insert stmts in schema for this table
-                               print "WARNING: %s has no temp data file. Unable to populate with old data" % key
+                               print "Warning: %s has no temp data file. Unable to populate with old data" % key
                                for insert_stmt in inserts:
                                        if insert_stmt.find(key) > -1:
                                                insert_cmd = 'psql %s postgres -qc "%s;" > /dev/null 2>&1' % \
-                                               (database, insert_stmt)
+                                               (config['PLC_DB_NAME'], insert_stmt)
                                                os.system(insert_cmd) 
 except:
-       print "ERROR: failed to populate db. Unarchiving original database and aborting"
-       undo_command = "dropdb -U postgres %s; psql template1 postgres -qc" \
+       print "Error: failed to populate db. Unarchiving original database and aborting"
+       undo_command = "dropdb -U postgres %s > /dev/null; psql template1 postgres -qc" \
                        " 'ALTER DATABASE %s RENAME TO %s;';  > /dev/null" % \
-                       (database, archived_database, database)
+                       (config['PLC_DB_NAME'], config['PLC_DB_NAME']+'_archived', config['PLC_DB_NAME'])
        os.system(undo_command) 
-       remove_temp_tables()
+       #remove_temp_tables()
+       raise
        
-remove_temp_tables()
+#remove_temp_tables()
 
 print "upgrade complete"