- fix db archive
[plcapi.git] / tools / upgrade-db.py
1 #!/usr/bin/python
2 #
3 # Tool for upgrading/converting a db
4 # Requirements:
5 # 1) Database Schema - schema for the new database you want to upgrade to
6 # 2) Config File - the config file that describes how to convert the db
7 #
8 # Notes:
9 # 1) Script will attempt to convert the db defined in /etc/planetlab/plc_config
10 import sys
11 import os
12 import getopt
13 import pgdb
14
# Defaults for the upgrade; each of these may be replaced from the command line.
upgrade_config_file = "plcdb.3-4.conf"
schema_file = "planetlab4.sql"
temp_dir = "/tmp"

# PLC settings (PLC_DB_NAME, PLC_DB_USER, ...) are loaded by executing this
# file into its own namespace dictionary.
config_file = "/etc/planetlab/plc_config"
config = dict()
execfile(config_file, config)
21
22
def usage():
    """Print the command-line help to stdout and exit with status 1."""
    help_lines = [
        "Usage: %s [OPTION] UPGRADE_CONFIG_FILE " % sys.argv[0],
        "Options:",
        "     -s, --schema=FILE       Upgraded Database Schema",
        "     -t, --temp-dir=DIR      Temp Directory",
        "     --help                  This message",
    ]
    for help_line in help_lines:
        print(help_line)
    sys.exit(1)
30
# Parse the command line.  usage() prints the help text and exits.
try:
    (opts, argv) = getopt.getopt(sys.argv[1:],
                                 "hs:t:d:",
                                 ["schema=",
                                  "temp-dir=",
                                  "help"])
except getopt.GetoptError as err:
    # same output as the original two-argument print statement
    print("Error:  %s" % err.msg)
    usage()

for (opt, optval) in opts:
    if opt in ("-s", "--schema"):
        schema_file = optval
    elif opt in ("-t", "-d", "--temp-dir"):
        # -t is the flag documented by usage(); -d is still accepted so
        # existing invocations keep working
        temp_dir = optval
    elif opt in ("-h", "--help"):
        usage()

# exactly one positional argument is expected: the upgrade config file
if not argv:
    print("Error: too few arguments")
    usage()
upgrade_config_file = argv[0]
53
schema = {}                 # item name -> (item type, CREATE statement line(s))
inserts = []                # INSERT statements found in the schema file
schema_items_ordered = []   # item names in the order they appear in the schema
sequences = {}              # table name -> name of its serial column
temp_tables = {}            # table name -> temp file holding its old data


# load conf file for this upgrade
upgrade_config = {}
try:
    execfile(upgrade_config_file, upgrade_config)
except IOError as fault:
    print("Error: upgrade config file (%s) not found. Exiting" % (fault))
    sys.exit(1)
# execfile() inserts '__builtins__' into the namespace; drop it so only the
# settings defined by the config file remain
upgrade_config.pop('__builtins__', None)

# only the version lookups are guarded, so a KeyError raised while executing
# the config file itself is not misreported as a missing setting
try:
    db_version_previous = upgrade_config['DB_VERSION_PREVIOUS']
    db_version_new = upgrade_config['DB_VERSION_NEW']
except KeyError as fault:
    print("Error: %s not set in upgrade config (%s). Exiting" %
          (fault, upgrade_config_file))
    sys.exit(1)
77
78
79
80
def connect():
    """Open and return a new pgdb connection to the PLC database.

    Database name and user are taken from the PLC config
    (config['PLC_DB_NAME'], config['PLC_DB_USER']).
    """
    return pgdb.connect(database=config['PLC_DB_NAME'],
                        user=config['PLC_DB_USER'])
85
def archive_db(database, archived_database):
    """Rename ``database`` to ``archived_database`` using the shell tools.

    A database already named ``archived_database`` is dropped first; the
    result of that drop is not checked (its output is discarded).  If the
    shell command line reports failure (its status is that of the final
    psql rename), an error is printed and the script exits with status 1.
    """
    drop_old_archive = " dropdb -U postgres %s > /dev/null 2>&1;" % archived_database
    rename_database = (" psql template1 postgres -qc "
                       " 'ALTER DATABASE %s RENAME TO %s;';" % (database, archived_database))
    if os.system(drop_old_archive + rename_database) != 0:
        print("Error: unable to archive database. Upgrade failed")
        sys.exit(1)
97
98
def encode_utf8(inputfile_name, outputfile_name):
    """Rewrite an ISO-8859-1 encoded file as UTF-8.

    Every line containing 'SET CLIENT_ENCODING' (case-insensitive) is left
    out, presumably because the dump's declared client encoding no longer
    matches the re-encoded data.  Line endings are copied unchanged.

    :param inputfile_name: path of the ISO-8859-1 input file
    :param outputfile_name: path of the UTF-8 output file (overwritten)
    :raises: whatever opening/reading/writing raised, after printing
             'error encoding file'
    """
    import io

    try:
        # newline='\n': split input only on '\n' and write line endings
        # untranslated, so '\r\n' survives exactly as in the input
        with io.open(inputfile_name, 'r', encoding='iso-8859-1',
                     newline='\n') as inputfile:
            with io.open(outputfile_name, 'w', encoding='utf-8',
                         newline='\n') as outputfile:
                for line in inputfile:
                    if line.upper().find('SET CLIENT_ENCODING') > -1:
                        continue
                    outputfile.write(line)
    except Exception:
        print('error encoding file')
        raise
113
def create_item_from_schema(item_name):
    """Create one parsed schema item in the new database via psql.

    The item's statement text (schema[item_name][1]) is run with
    ``psql <PLC_DB_NAME> <PLC_DB_USER> -qc <sql>``; psql output is discarded.

    :param item_name: key of the item in the global ``schema`` dict
    :returns: False (after printing an error) if ``item_name`` has no parsed
              definition; None otherwise
    Exits the script with status 1 if psql cannot be run or reports failure.
    """
    import subprocess

    # look the definition up separately so a missing item is reported as
    # such (the old 'except KeyError' sat behind 'except Exception' and
    # could never run)
    try:
        definition = schema[item_name][1]
    except KeyError:
        print("Error: cannot create %s. definition not found in %s" %
              (item_name, schema_file))
        return False

    # INDEX definitions are stored as one line, other items as a list of
    # lines; "".join handles both
    sql = "".join(definition)

    # pass the SQL as a single argv entry rather than inside a double-quoted
    # shell string, so quotes, backslashes and '$' reach psql unchanged
    with open(os.devnull, 'w') as devnull:
        try:
            exit_status = subprocess.call(
                ['psql', config['PLC_DB_NAME'], config['PLC_DB_USER'],
                 '-qc', sql],
                stdout=devnull, stderr=devnull)
        except OSError:
            # psql could not be started at all
            exit_status = 1
    if exit_status:
        print('Error: create %s failed. Check schema.' % item_name)
        sys.exit(1)
131
def fix_row(row, table_name, table_fields):
    """Adjust one row of old data before it is written out.

    - nodenetworks: a string bwlimit containing 'mbit' or 'kbit' (e.g.
      '10mbit', '512kbit') becomes an int number of bits per second.
      Non-string values (ints, NULL/None) are left unchanged.
    - slice_attribute: attribute_type_id 10004 -> 10016 and
      10006 -> 10017 (invalid foreign keys).
    - slice_attribute_types: rows with attribute_type_id 10004 or 10006
      are dropped.

    :param row: mutable sequence of column values, modified in place
    :param table_name: name of the table the row belongs to
    :param table_fields: column names, in the same order as ``row``
    :returns: the (possibly modified) row, or None if it must be dropped
    """
    if table_name in ['nodenetworks']:
        # convert str bwlimit to bps int
        bwlimit_index = table_fields.index('bwlimit')
        bwlimit = row[bwlimit_index]
        # only string-like values carry a unit; calling .find on None
        # (a NULL bwlimit) or on a non-int number used to crash
        if hasattr(bwlimit, 'find'):
            # 'mbit' is checked before 'kbit', as before
            for unit, multiplier in (('mbit', 1000000), ('kbit', 1000)):
                if bwlimit.find(unit) > -1:
                    row[bwlimit_index] = int(bwlimit.split(unit)[0]) * multiplier
                    break
    elif table_name in ['slice_attribute']:
        # modify some invalid foreign keys
        attribute_type_index = table_fields.index('attribute_type_id')
        remapped_ids = {10004: 10016, 10006: 10017}
        if row[attribute_type_index] in remapped_ids:
            row[attribute_type_index] = remapped_ids[row[attribute_type_index]]
    elif table_name in ['slice_attribute_types']:
        type_id_index = table_fields.index('attribute_type_id')
        if row[type_id_index] in [10004, 10006]:
            return None
    return row
157         
def fix_table(table, table_name, table_fields):
    """Drop rows that must not be carried over from ``table``.

    For 'slice_attribute_types', every row whose attribute_type_id is 10004
    or 10006 is removed; other tables are returned untouched.

    :param table: list of rows, modified in place
    :param table_name: name of the table
    :param table_fields: column names, in row order
    :returns: ``table`` itself
    """
    if table_name in ['slice_attribute_types']:
        # remove duplicate/redundant primary keys
        type_id_index = table_fields.index('attribute_type_id')
        # rebuild through slice assignment: calling remove() while iterating
        # the same list skips the element right after each removed one
        table[:] = [row for row in table
                    if row[type_id_index] not in (10004, 10006)]
    return table
166
def remove_temp_tables():
    """Delete every temp data file recorded in the global ``temp_tables``.

    Files that no longer exist are skipped.  Every file is attempted even if
    an earlier removal fails; the first unexpected OSError is re-raised once
    all files have been tried.
    """
    import errno

    first_error = None
    for temp_file_name in temp_tables.values():
        try:
            os.remove(temp_file_name)
        except OSError as err:
            if err.errno == errno.ENOENT:
                # already gone: nothing to clean up
                continue
            if first_error is None:
                first_error = err
    if first_error is not None:
        raise first_error
174
def _copy_field(value, not_null):
    """Render one column value as a field of PostgreSQL COPY text format.

    None becomes the COPY null marker \\N, or an empty string when
    ``not_null`` is true (XX doesnt work if column is int type).  Other
    non-string values are converted with str().  Backslash, tab, newline
    and carriage return are escaped so they cannot break the row format.
    """
    if value is None:
        if not_null:
            return ""
        return "\\N"
    if not isinstance(value, str):
        value = str(value)
    # backslash first, so the backslashes added below are not doubled
    return (value.replace('\\', '\\\\')
                 .replace('\t', '\\t')
                 .replace('\n', '\\n')
                 .replace('\r', '\\r'))


def generate_temp_table(table_name, db):
    """Dump the old data for ``table_name`` into a COPY-format temp file.

    upgrade_config[table_name] is a comma separated list of
    ``new_field:old_table.old_field`` entries.  Any further ':'-separated
    parts are either WHERE conditions (when they contain '=') or
    ``table.column`` specs, which become ``INNER JOIN table USING (column)``.
    Rows are selected from the old database through ``db`` and passed
    through fix_row().  Rows that fix_row() drops, or that have a NULL
    primary key, are skipped.  The rest are written to
    ``<temp_dir>/<table_name>.tmp``, whose path is recorded in
    temp_tables[table_name].

    :param table_name: table name as parsed from the new schema
    :param db: open pgdb connection to the old database
    :returns: False if the table has no upgrade definition (any KeyError),
              otherwise None
    Exits with status 1 if the definition is malformed (IndexError).  Any
    other error rolls back ``db`` and is re-raised.
    """
    cursor = db.cursor()
    try:
        # get upgrade directions
        table_def = upgrade_config[table_name].replace('(', '').replace(')', '').split(',')
        table_fields, old_fields, joins, wheres = [], [], set(), set()
        for field in table_def:
            field_parts = field.strip().split(':')
            table_fields.append(field_parts[0])
            old_fields.append(field_parts[1])
            for extra in field_parts[2:]:
                if '=' in extra:
                    wheres.add(extra)
                else:
                    joins.add(extra)

        # get indices of fields that cannot be null (substring match against
        # the schema's definition lines)
        body_list = schema[table_name][1]
        not_null_indices = set()
        for field_index, field in enumerate(table_fields):
            for body_line in body_list:
                if body_line.find(field) > -1 and \
                   body_line.upper().find("NOT NULL") > -1:
                    not_null_indices.add(field_index)

        # get indices of the primary key fields from the first line that
        # mentions PRIMARY KEY
        primary_key_indices = []
        for body_line in body_list:
            if body_line.find("PRIMARY KEY") > -1:
                primary_key_indices = [i for i, field in enumerate(table_fields)
                                       if body_line.find(field) > -1]
                break

        # get old data
        get_old_data = "SELECT DISTINCT %s FROM %s" % \
            (", ".join(old_fields), old_fields[0].split(".")[0])
        for join in joins:
            get_old_data += " INNER JOIN %s USING (%s) " % \
                (join.split('.')[0], join.split('.')[1])
        if wheres:
            # several conditions must all hold; emitting them separated only
            # by spaces produced invalid SQL
            get_old_data += " WHERE " + " AND ".join(wheres)
        cursor.execute(get_old_data)
        rows = cursor.fetchall()

        # write data to a temp file in COPY text format
        temp_file_name = '%s/%s.tmp' % (temp_dir, table_name)
        with open(temp_file_name, 'w') as temp_file:
            for row in rows:
                # attempt to make any necessary fixes to data
                row = fix_row(row, table_name, table_fields)
                if row is None:
                    continue
                # do not attempt to write rows with null primary keys
                if any(row[i] is None for i in primary_key_indices):
                    continue
                fields = [_copy_field(value, i in not_null_indices)
                          for i, value in enumerate(row)]
                temp_file.write("\t".join(fields) + "\n")
            # COPY end-of-data marker
            temp_file.write("\\.\n")
        temp_tables[table_name] = temp_file_name

    except KeyError:
        # no upgrade definition for this table: skip it
        return False
    except IndexError:
        print("Error: error found in upgrade config file. "
              "check %s configuration. Aborting " % (table_name))
        sys.exit(1)
    except Exception:
        print("Error: configuration for %s doesnt match db schema. "
              " Aborting" % (table_name))
        try:
            db.rollback()
        except Exception:
            # best effort: report the original error, not a rollback failure
            pass
        raise
271
272
# Connect to current db
db = connect()
cursor = db.cursor()

# determine current db version
cursor.execute("SELECT relname from pg_class where relname = 'plc_db_version'")
rows = cursor.fetchall()
if not rows:
    print("Warning: current db has no version. Unable to validate config file.")
else:
    cursor.execute("SELECT version FROM plc_db_version")
    rows = cursor.fetchall()
    if not rows or not rows[0]:
        print("Warning: current db has no version. Unable to validate config file.")
    elif rows[0][0] == db_version_new:
        print("Status: Versions are the same. No upgrade necessary.")
        sys.exit()
    elif rows[0][0] != db_version_previous:
        # the upgrade cannot run from an unexpected version: report failure
        # (this used to exit with status 0)
        print("Error: DB_VERSION_PREVIOUS in config file (%s) does not"
              " match current db version %d" % (upgrade_config_file, rows[0][0]))
        sys.exit(1)
    else:
        print("STATUS: attempting upgrade from %d to %d" %
              (db_version_previous, db_version_new))

# check db encoding
sql = " SELECT pg_catalog.pg_encoding_to_char(d.encoding)" \
      " FROM pg_catalog.pg_database d " \
      " WHERE d.datname = '%s' " % config['PLC_DB_NAME']
cursor.execute(sql)
rows = cursor.fetchall()

# This connection is not used past this point.  Close it explicitly instead
# of relying on garbage collection: PostgreSQL refuses to rename (archive) a
# database while other sessions are connected to it.
cursor.close()
db.close()

if rows[0][0] not in ['UTF8', 'UNICODE']:
    print("WARNING: db encoding is not utf8. Attempting to encode")
    # generate db dump
    dump_file = '%s/dump.sql' % (temp_dir)
    dump_file_encoded = dump_file + ".utf8"
    dump_cmd = 'pg_dump -i %s -U postgres -f %s > /dev/null 2>&1' % \
               (config['PLC_DB_NAME'], dump_file)
    if os.system(dump_cmd):
        print("ERROR: during db dump. Exiting.")
        sys.exit(1)
    # encode dump to utf8
    print("Status: encoding database dump")
    encode_utf8(dump_file, dump_file_encoded)
    # archive original db
    sqlascii_archive = config['PLC_DB_NAME'] + '_sqlascii_archived'
    archive_db(config['PLC_DB_NAME'], sqlascii_archive)
    # create a utf8 database and upload encoded data
    recreate_cmd = 'createdb -U postgres -E UTF8 %s > /dev/null; ' \
                   'psql -a -U  %s %s < %s > /dev/null 2>&1;' % \
                   (config['PLC_DB_NAME'], config['PLC_DB_USER'],
                    config['PLC_DB_NAME'], dump_file_encoded)
    print("Status: recreating database as utf8")
    if os.system(recreate_cmd):
        print("Error: database encoding failed. Aborting")
        # drop the partial utf8 copy and give the original database its name
        # back, instead of leaving it only under the archive name
        restore_cmd = "dropdb -U postgres %s > /dev/null 2>&1; " \
                      "psql template1 postgres -qc " \
                      "'ALTER DATABASE %s RENAME TO %s;' > /dev/null" % \
                      (config['PLC_DB_NAME'], sqlascii_archive, config['PLC_DB_NAME'])
        os.system(restore_cmd)
        sys.exit(1)

    os.remove(dump_file_encoded)
    os.remove(dump_file)
335
336
# Reconnect to the current database to read the old data from.
db = connect()
cursor = db.cursor()

# parse the schema user wishes to upgrade to
with open(schema_file, 'r') as schema_fp:
    lines = schema_fp.readlines()

index = 0
while index < len(lines):
    line = lines[index]
    # find all created objects
    if line.startswith("CREATE"):
        # split() rather than split(" "): tolerates repeated blanks/tabs and
        # keeps a trailing newline out of the item name
        line_parts = line.split()
        item_type = line_parts[1]
        item_name = line_parts[2]
        if item_type in ['INDEX']:
            schema_items_ordered.append(item_name)
            schema[item_name] = (item_type, line)

        # functions, tables, views span over multiple lines
        # handle differently than indexes
        elif item_type in ['AGGREGATE', 'TABLE', 'VIEW']:
            schema_items_ordered.append(item_name)
            fields = [line]
            # Collect the following lines up to the one containing ';'.
            # Skip this if the CREATE line already ends the statement (it
            # used to swallow the next statements), and stop at end of file
            # if the ';' is missing (it used to index past the end).
            statement_done = line.find(";") >= 0
            while not statement_done and index + 1 < len(lines):
                index = index + 1
                nextline = lines[index]
                # look for any sequences
                if item_type in ['TABLE'] and nextline.find('serial') > -1:
                    sequences[item_name] = nextline.strip().split()[0]
                fields.append(nextline)
                statement_done = nextline.find(";") >= 0
            schema[item_name] = (item_type, fields)
        else:
            # Left out of schema_items_ordered: every ordered name is later
            # looked up in schema, which would fail with KeyError.
            print("Error: unknown type %s" % item_type)
    elif line.startswith("INSERT"):
        inserts.append(line)
    index = index + 1
378
print("Status: generating temp tables")
# dump the old data of every table in the new schema to temp files
for key in schema_items_ordered:
    if schema[key][0] == 'TABLE':
        generate_temp_table(key, db)

# disconnect from current database and archive it
cursor.close()
db.close()

print("Status: archiving database")
archived_name = config['PLC_DB_NAME'] + '_archived'
archive_db(config['PLC_DB_NAME'], archived_name)
if os.system('createdb -U postgres -E UTF8 %s > /dev/null; ' % config['PLC_DB_NAME']):
    # The createdb status used to be ignored.  Without a new database nothing
    # can be loaded, so give the original database its name back and stop.
    print("Error: unable to create database %s. Restoring original database and aborting"
          % config['PLC_DB_NAME'])
    os.system("psql template1 postgres -qc 'ALTER DATABASE %s RENAME TO %s;' > /dev/null"
              % (archived_name, config['PLC_DB_NAME']))
    remove_temp_tables()
    sys.exit(1)
393
print("Status: upgrading database")
# attempt to create and load all items from schema into the new db
try:
    for key in schema_items_ordered:
        item_type = schema[key][0]
        create_item_from_schema(key)
        if item_type != 'TABLE':
            continue
        if key in upgrade_config:
            # attempt to populate with temp table data
            table_def = upgrade_config[key].replace('(', '').replace(')', '').split(',')
            table_fields = [field.strip().split(':')[0] for field in table_def]
            insert_cmd = "psql %s %s -c " \
                         " 'COPY %s (%s) FROM stdin;' < %s " % \
                         (config['PLC_DB_NAME'], config['PLC_DB_USER'], key,
                          ", ".join(table_fields), temp_tables[key])
            if os.system(insert_cmd):
                print("Error: upgrade %s failed" % key)
                sys.exit(1)
            # update the primary key sequence
            if key in sequences:
                sequence = key + "_" + sequences[key] + "_seq"
                update_seq = "psql %s %s -c " \
                             " \"select setval('%s', max(%s)) FROM %s;\" > /dev/null" % \
                             (config['PLC_DB_NAME'], config['PLC_DB_USER'], sequence,
                              sequences[key], key)
                if os.system(update_seq):
                    print("Error: sequence %s update failed" % sequence)
                    sys.exit(1)
        else:
            # check if there are any insert stmts in schema for this table
            print("Warning: %s has no temp data file. Unable to populate with old data" % key)
            for insert_stmt in inserts:
                # "INSERT INTO <table>[(cols)] ...": compare the exact table
                # name; a substring test also ran inserts for other tables
                # whose names merely contain this one
                insert_parts = insert_stmt.split()
                if len(insert_parts) < 3 or insert_parts[2].split('(')[0] != key:
                    continue
                insert_cmd = 'psql %s postgres -qc "%s;" > /dev/null 2>&1' % \
                             (config['PLC_DB_NAME'], insert_stmt)
                os.system(insert_cmd)
except BaseException:
    # BaseException on purpose: the steps above call sys.exit() on failure,
    # and the original database must be restored in that case too
    print("Error: failed to populate db. Unarchiving original database and aborting")
    # the redirect belongs to psql, before the ';' (it used to follow it and
    # redirect an empty command)
    undo_command = "dropdb -U postgres %s > /dev/null; psql template1 postgres -qc" \
                   " 'ALTER DATABASE %s RENAME TO %s;' > /dev/null" % \
                   (config['PLC_DB_NAME'], config['PLC_DB_NAME'] + '_archived',
                    config['PLC_DB_NAME'])
    os.system(undo_command)
    remove_temp_tables()
    raise

remove_temp_tables()

print("upgrade complete")