remove simplejson dependency
[plcapi.git] / tools / upgrade-db.py
index e38a8b4..4c9d1d5 100755 (executable)
@@ -1,6 +1,15 @@
 #!/usr/bin/python
 #
-# Tool for upgrading a db based on db version #
+# Tool for upgrading/converting a db
+# Requirements:
+# 1) Database Schema - schema for the new database you want to upgrade to
+# 2) Config File - the config file that describes how to convert the db
+#
+# Notes:
+# 1) Will attempt to convert the db defined in /etc/planetlab/plc_config
+# 2) Does not automatically drop archived database. They must be removed
+#    manually
+
 import sys
 import os
 import getopt
@@ -81,10 +90,9 @@ def archive_db(database, archived_database):
 
         archive_db = " dropdb -U postgres %s > /dev/null 2>&1;" \
                     " psql template1 postgres -qc " \
-                     " 'ALTER DATABASE %s RENAME TO %s;';" \
-                     " createdb -U postgres %s > /dev/null; " % \
-                     (archived_database, database, archived_database, database)
-        exit_status = os.system(archive_db)
+                     " 'ALTER DATABASE %s RENAME TO %s;';" % \
+                     (archived_database, database, archived_database)
+       exit_status = os.system(archive_db)
         if exit_status:
                 print "Error: unable to archive database. Upgrade failed"
                 sys.exit(1)
@@ -97,7 +105,9 @@ def encode_utf8(inputfile_name, outputfile_name):
                inputfile = open(inputfile_name, 'r')
                outputfile = open(outputfile_name, 'w')
                for line in inputfile:
-                       outputfile.write(unicode(line, 'iso-8859-1').encode('utf8'))
+                       if line.upper().find('SET CLIENT_ENCODING') > -1:
+                               continue
+                       outputfile.write(unicode(line, 'iso-8859-1').encode('utf8'))
                inputfile.close()
                outputfile.close()              
        except:
@@ -124,7 +134,7 @@ def create_item_from_schema(item_name):
 
 def fix_row(row, table_name, table_fields):
 
-       if table_name in ['nodenetworks']:
+       if table_name in ['interfaces']:
                # convert str bwlimit to bps int
                bwlimit_index = table_fields.index('bwlimit')
                if isinstance(row[bwlimit_index], int):
@@ -142,9 +152,13 @@ def fix_row(row, table_name, table_fields):
                        row[attribute_type_index] = 10016
                elif row[attribute_type_index] == 10006:
                        row[attribute_type_index] = 10017
+               elif row[attribute_type_index] in [10031, 10033]:
+                       row[attribute_type_index] = 10037
+               elif row[attribute_type_index] in [10034, 10035]:
+                       row[attribute_type_index] = 10036
        elif table_name in ['slice_attribute_types']:
                type_id_index = table_fields.index('attribute_type_id')
-               if row[type_id_index] in [10004, 10006]:
+               if row[type_id_index] in [10004, 10006, 10031, 10033, 10034, 10035]:
                        return None
        return row
        
@@ -153,7 +167,7 @@ def fix_table(table, table_name, table_fields):
                # remove duplicate/redundant primary keys
                type_id_index = table_fields.index('attribute_type_id')
                for row in table:
-                       if row[type_id_index] in [10004, 10006]:
+                       if row[type_id_index] in [10004, 10006, 10031, 10033, 10034, 10035]:
                                table.remove(row)
        return table
 
@@ -187,17 +201,16 @@ def generate_temp_table(table_name, db):
                                if body_line.find(field) > -1 and \
                                   body_line.upper().find("NOT NULL") > -1:
                                        not_null_indices.append(table_fields.index(field))
-
                # get index of primary key
                primary_key_indices = []
                for body_line in body_list:
                        if body_line.find("PRIMARY KEY") > -1:
                                primary_key = body_line
                                for field in table_fields:
-                                       if primary_key.find(field) > -1:
+                                       if primary_key.find(" "+field+" ") > -1:
                                                primary_key_indices.append(table_fields.index(field))
-                               break
-
+                               #break
+       
                 # get old data
                 get_old_data = "SELECT DISTINCT %s FROM %s" % \
                       (", ".join(old_fields), old_fields[0].split(".")[0])
@@ -312,7 +325,7 @@ try:
                # archive original db
                archive_db(config['PLC_DB_NAME'], config['PLC_DB_NAME']+'_sqlascii_archived')
                # create a utf8 database and upload encoded data
-               recreate_cmd = 'createdb -U postgres -E UTF8 %s > /dev/null 2>&1; ' \
+               recreate_cmd = 'createdb -U postgres -E UTF8 %s > /dev/null; ' \
                               'psql -a -U  %s %s < %s > /dev/null 2>&1;'   % \
                          (config['PLC_DB_NAME'], config['PLC_DB_USER'], \
                           config['PLC_DB_NAME'], dump_file_encoded) 
@@ -337,9 +350,14 @@ try:
        lines = file.readlines()
        while index < len(lines):
                line = lines[index] 
+               if line.find("--") > -1:
+                       line_parts = line.split("--")
+                        line = line_parts[0]
                # find all created objects
                if line.startswith("CREATE"):
                        line_parts = line.split(" ")
+                       if line_parts[1:3] == ['OR', 'REPLACE']:
+                               line_parts = line_parts[2:]
                        item_type = line_parts[1]
                        item_name = line_parts[2]
                        schema_items_ordered.append(item_name)
@@ -353,6 +371,9 @@ try:
                                while index < len(lines):
                                        index = index + 1
                                        nextline =lines[index]
+                                       if nextline.find("--") > -1:
+                                                new_line_parts = nextline.split("--")
+                                                nextline = new_line_parts[0]
                                        # look for any sequences
                                        if item_type in ['TABLE'] and nextline.find('serial') > -1:
                                                sequences[item_name] = nextline.strip().split()[0]
@@ -382,7 +403,7 @@ db.close()
 
 print "Status: archiving database"
 archive_db(config['PLC_DB_NAME'], config['PLC_DB_NAME']+'_archived')
-
+os.system('createdb -U postgres -E UTF8 %s > /dev/null; ' % config['PLC_DB_NAME'])
 
 print "Status: upgrading database"
 # attempt to create and load all items from schema into temp db
@@ -428,9 +449,9 @@ except:
                        " 'ALTER DATABASE %s RENAME TO %s;';  > /dev/null" % \
                        (config['PLC_DB_NAME'], config['PLC_DB_NAME']+'_archived', config['PLC_DB_NAME'])
        os.system(undo_command) 
-       remove_temp_tables()
+       #remove_temp_tables()
        raise
        
-remove_temp_tables()
+#remove_temp_tables()
 
 print "upgrade complete"