- changes to addresses, site_address, address_address_type table config (person_addre...
[plcapi.git] / tools / upgrade-db.py
index d2aff3c..ad89fbb 100755 (executable)
@@ -48,6 +48,7 @@ except IndexError:
 schema = {}
 inserts = []
 schema_items_ordered = []
+sequences = {}
 temp_tables = {}
 
 
@@ -78,8 +79,7 @@ def connect():
 
 def archive_db(database, archived_database):
 
-       print "Status: archiving old database"
-        archive_db = " dropdb -U postgres %s; > /dev/null 2>&1" \
+        archive_db = " dropdb -U postgres %s > /dev/null 2>&1;" \
                     " psql template1 postgres -qc " \
                      " 'ALTER DATABASE %s RENAME TO %s;';" \
                      " createdb -U postgres %s > /dev/null; " % \
@@ -88,11 +88,11 @@ def archive_db(database, archived_database):
         if exit_status:
                 print "Error: unable to archive database. Upgrade failed"
                 sys.exit(1)
-        print "Status: %s has been archived. now named %s" % (database, archived_database)
+        #print "Status: %s has been archived. now named %s" % (database, archived_database)
 
 
 def encode_utf8(inputfile_name, outputfile_name):
-       # rewrite a iso-8859-1 encoded file and in utf8
+       # rewrite a iso-8859-1 encoded file in utf8
        try:
                inputfile = open(inputfile_name, 'r')
                outputfile = open(outputfile_name, 'w')
@@ -268,12 +268,13 @@ try:
        cursor.execute("SELECT relname from pg_class where relname = 'plc_db_version'")
        rows = cursor.fetchall()
        if not rows:
-               print "WARNING: current db has no version. Unable to validate config file."
+               print "Warning: current db has no version. Unable to validate config file."
        else:
                cursor.execute("SELECT version FROM plc_db_version")
                rows = cursor.fetchall()
-
-               if rows[0][0] == db_version_new:
+               if not rows or not rows[0]:
+                       print "Warning: current db has no version. Unable to validate config file."
+               elif rows[0][0] == db_version_new:
                                print "Status: Versions are the same. No upgrade necessary."
                        sys.exit()
                elif not rows[0][0] == db_version_previous:
@@ -297,8 +298,8 @@ try:
                # generate db dump
                dump_file = '%s/dump.sql' % (temp_dir)
                dump_file_encoded = dump_file + ".utf8"
-               dump_cmd = 'pg_dump -i %s -U %s -f %s > /dev/null 2>&1' % \
-                          (config['PLC_DB_NAME'], config['PLC_DB_USER'], dump_file)
+               dump_cmd = 'pg_dump -i %s -U postgres -f %s > /dev/null 2>&1' % \
+                          (config['PLC_DB_NAME'], dump_file)
                if os.system(dump_cmd):
                        print "ERROR: during db dump. Exiting."
                        sys.exit(1)
@@ -308,10 +309,10 @@ try:
                # archive original db
                archive_db(config['PLC_DB_NAME'], config['PLC_DB_NAME']+'_sqlascii_archived')
                # create a utf8 database and upload encoded data
-               recreate_cmd = 'createdb -U %s -E UTF8 %s > /dev/null 2>&1; ' \
+               recreate_cmd = 'createdb -U postgres -E UTF8 %s > /dev/null 2>&1; ' \
                               'psql -a -U  %s %s < %s > /dev/null 2>&1;'   % \
-                         (config['PLC_DB_USER'], config['PLC_DB_NAME'], \
-                          config['PLC_DB_USER'], config['PLC_DB_NAME'], dump_file_encoded) 
+                         (config['PLC_DB_NAME'], config['PLC_DB_USER'], \
+                          config['PLC_DB_NAME'], dump_file_encoded) 
                print "Status: recreating database as utf8"
                if os.system(recreate_cmd):
                        print "Error: database encoding failed. Aborting"
@@ -349,6 +350,9 @@ try:
                                while index < len(lines):
                                        index = index + 1
                                        nextline =lines[index]
+                                       # look for any sequences
+                                       if item_type in ['TABLE'] and nextline.find('serial') > -1:
+                                               sequences[item_name] = nextline.strip().split()[0]
                                        fields.append(nextline)
                                        if nextline.find(";") >= 0:
                                                break
@@ -385,30 +389,44 @@ try:
                create_item_from_schema(key)
                if type == 'TABLE':
                        if upgrade_config.has_key(key):                         
+                               # attempt to populate with temp table data
                                table_def = upgrade_config[key].replace('(', '').replace(')', '').split(',')
                                table_fields = [field.strip().split(':')[0] for field in table_def]
                                insert_cmd = "psql %s %s -c " \
                                              " 'COPY %s (%s) FROM stdin;' < %s " % \
-                                             (database, config['PLC_DB_USER'], key, ", ".join(table_fields), temp_tables[key] )
+                                             (config['PLC_DB_NAME'], config['PLC_DB_USER'], key, 
+                                             ", ".join(table_fields), temp_tables[key] )
                                exit_status = os.system(insert_cmd)
                                if exit_status:
                                        print "Error: upgrade %s failed" % key
-                                       raise
+                                       sys.exit(1)
+                               # update the primary key sequence
+                               if sequences.has_key(key):
+                                       sequence = key +"_"+ sequences[key] +"_seq"
+                                       update_seq = "psql %s %s -c " \
+                                            " \"select setval('%s', max(%s)) FROM %s;\" > /dev/null" % \
+                                            (config['PLC_DB_NAME'], config['PLC_DB_USER'], sequence, 
+                                             sequences[key], key)
+                                       exit_status = os.system(update_seq)
+                                       if exit_status:
+                                               print "Error: sequence %s update failed" % sequence
+                                               sys.exit(1)
                        else:
                                # check if there are any insert stmts in schema for this table
                                print "Warning: %s has no temp data file. Unable to populate with old data" % key
                                for insert_stmt in inserts:
                                        if insert_stmt.find(key) > -1:
                                                insert_cmd = 'psql %s postgres -qc "%s;" > /dev/null 2>&1' % \
-                                               (database, insert_stmt)
+                                               (config['PLC_DB_NAME'], insert_stmt)
                                                os.system(insert_cmd) 
 except:
        print "Error: failed to populate db. Unarchiving original database and aborting"
-       undo_command = "dropdb -U postgres %s; psql template1 postgres -qc" \
+       undo_command = "dropdb -U postgres %s > /dev/null; psql template1 postgres -qc" \
                        " 'ALTER DATABASE %s RENAME TO %s;';  > /dev/null" % \
-                       (database, archived_database, database)
+                       (config['PLC_DB_NAME'], config['PLC_DB_NAME']+'_archived', config['PLC_DB_NAME'])
        os.system(undo_command) 
        remove_temp_tables()
+       raise
        
 remove_temp_tables()