Setting tag plcapi-5.4-2
[plcapi.git] / tools / upgrade-db.py
1 #!/usr/bin/python
2 #
3 # Tool for upgrading/converting a db
4 # Requirements:
5 # 1) Database Schema - schema for the new database you want to upgrade to
6 # 2) Config File - the config file that describes how to convert the db
7 #
8 # Notes:
9 # 1) Will attempt to convert the db defined in  /etc/planetlab/plc_config
10 # 2) Does not automatically drop archived databases. They must be removed
11 #    manually
12
import getopt
import numbers
import os
import subprocess
import sys

import pgdb
17
# The site-wide PLC configuration (PLC_DB_NAME, PLC_DB_USER, ...) is
# executed as Python source into this dictionary.
config_file = "/etc/planetlab/plc_config"
config = {}
execfile(config_file, config)

# Defaults; the command-line handling below may override each of them.
upgrade_config_file, schema_file, temp_dir = "plcdb.3-4.conf", "planetlab4.sql", "/tmp"
24
25
def usage():
    """Print the command-line help text and exit with status 1."""
    help_lines = [
        "Usage: %s [OPTION] UPGRADE_CONFIG_FILE " % sys.argv[0],
        "Options:",
        "     -s, --schema=FILE       Upgraded Database Schema",
        "     -t, --temp-dir=DIR      Temp Directory",
        "     --help                  This message",
    ]
    for help_line in help_lines:
        print(help_line)
    sys.exit(1)
33
# Parse the command line.  usage() documents -t for the temp directory;
# -d is also accepted so existing invocations keep working.
try:
    (opts, argv) = getopt.getopt(sys.argv[1:],
                                 "s:t:d:",
                                 ["schema=",
                                  "temp-dir=",
                                  "help"])
except getopt.GetoptError as err:
    print("Error: %s" % err.msg)
    usage()

for (opt, optval) in opts:
    if opt in ("-s", "--schema"):
        schema_file = optval
    elif opt in ("-t", "-d", "--temp-dir"):
        temp_dir = optval
    elif opt == "--help":
        usage()

# The single positional argument names the upgrade config file.
if not argv:
    print("Error: too few arguments")
    usage()
upgrade_config_file = argv[0]
56
# Containers filled while parsing the target schema and dumping old data.
schema = {}                # item name -> (item type, definition line or list of lines)
inserts = []               # INSERT statements found in the schema file
schema_items_ordered = []  # item names in the order they appear in the schema
sequences = {}             # table name -> name of the column declared 'serial'
temp_tables = {}           # table name -> path of the COPY data file written for it


# load conf file for this upgrade
upgrade_config = {}
try:
    execfile(upgrade_config_file, upgrade_config)
except IOError as fault:
    print("Error: upgrade config file (%s) could not be read: %s. Exiting" %
          (upgrade_config_file, fault))
    sys.exit(1)
# execfile() injects __builtins__ into the namespace; drop it so only the
# config's own definitions remain (tolerate it being absent).
upgrade_config.pop('__builtins__', None)

try:
    db_version_previous = upgrade_config['DB_VERSION_PREVIOUS']
    db_version_new = upgrade_config['DB_VERSION_NEW']
except KeyError as fault:
    print("Error: %s not set in upgrade config (%s). Exiting" %
          (fault, upgrade_config_file))
    sys.exit(1)
80
81
82
83
def connect():
    """Open and return a new pgdb connection to PLC_DB_NAME as PLC_DB_USER."""
    connection_args = dict(user=config['PLC_DB_USER'],
                           database=config['PLC_DB_NAME'])
    return pgdb.connect(**connection_args)
88
def archive_db(database, archived_database):
    """Rename `database` to `archived_database` using the postgres user.

    Any database already named `archived_database` is dropped first; the
    output and errors of that dropdb are discarded.  If the shell command
    returns a non-zero status, an error is printed and the script exits
    with status 1.
    """
    drop_old_archive = " dropdb -U postgres %s > /dev/null 2>&1;"
    rename = " psql template1 postgres -qc " \
             " 'ALTER DATABASE %s RENAME TO %s;';"
    shell_command = (drop_old_archive + rename) % \
                    (archived_database, database, archived_database)
    if os.system(shell_command) != 0:
        print("Error: unable to archive database. Upgrade failed")
        sys.exit(1)
100
101
def encode_utf8(inputfile_name, outputfile_name):
    """Rewrite an ISO-8859-1 encoded text file as UTF-8.

    Lines containing "SET CLIENT_ENCODING" (any case) are dropped;
    presumably so psql does not reinterpret the re-encoded dump using its
    original encoding setting.  Both files are closed even when an error
    occurs.  Errors are reported and re-raised.
    """
    try:
        # Binary mode: each line is decoded/encoded explicitly, which behaves
        # the same on Python 2 and Python 3.
        with open(inputfile_name, 'rb') as inputfile:
            with open(outputfile_name, 'wb') as outputfile:
                for line in inputfile:
                    if line.upper().find(b'SET CLIENT_ENCODING') > -1:
                        continue
                    outputfile.write(line.decode('iso-8859-1').encode('utf8'))
    except Exception:
        print('error encoding file')
        raise
116
def create_item_from_schema(item_name):
    """Create one item (table, view, aggregate or index) from `schema`.

    The item's SQL text is run with psql as PLC_DB_USER against
    PLC_DB_NAME.  psql's output is discarded.

    Returns True once the item has been created.  Returns False, after
    printing an error, when `item_name` has no definition in `schema`.
    Exits the script with status 1 if psql cannot be run or reports failure.
    """
    try:
        body = schema[item_name][1]
    except KeyError:
        print("Error: cannot create %s. definition not found in %s" %
              (item_name, schema_file))
        return False

    # INDEX definitions are stored as one string, multi-line items as a list
    # of lines; join() yields the full SQL text in both cases.
    sql = "".join(body)
    # Hand the SQL to psql as its own argument (no shell), so quotes, '$' or
    # backquotes in the definition reach the server unchanged.
    command = ['psql', config['PLC_DB_NAME'], config['PLC_DB_USER'], '-qc', sql]
    try:
        with open(os.devnull, 'w') as devnull:
            exit_status = subprocess.call(command, stdout=devnull, stderr=devnull)
    except OSError:
        # psql missing or not executable
        exit_status = -1
    if exit_status:
        print('Error: create %s failed. Check schema.' % item_name)
        sys.exit(1)
    return True
134
# Invalid/redundant slice attribute type ids -> the ids that replace them.
# slice_attribute rows are remapped; slice_attribute_types rows with these
# ids are dropped.
_ATTRIBUTE_TYPE_REPLACEMENTS = {
    10004: 10016,
    10006: 10017,
    10031: 10037,
    10033: 10037,
    10034: 10036,
    10035: 10036,
}


def fix_row(row, table_name, table_fields):
    """Adjust one row of old data so it fits the new schema.

    `row` holds the column values in the order of `table_fields`, which are
    the new column names.  A new list is returned, so callers may modify it
    even when the database driver returns tuples.

    - interfaces: a textual bwlimit such as '10mbit' or '512kbit' becomes an
      integer number of bits per second.  NULLs, numbers and other strings
      are left unchanged.
    - slice_attribute: invalid attribute_type_id values are remapped.
    - slice_attribute_types: rows for the invalid ids are dropped by
      returning None.

    Raises ValueError if a required column is missing from `table_fields`.
    """
    row = list(row)
    if table_name in ['interfaces']:
        # convert str bwlimit to bps int
        bwlimit_index = table_fields.index('bwlimit')
        bwlimit = row[bwlimit_index]
        if bwlimit is not None and not isinstance(bwlimit, numbers.Number):
            if bwlimit.find('mbit') > -1:
                row[bwlimit_index] = int(bwlimit.split('mbit')[0]) * 1000000
            elif bwlimit.find('kbit') > -1:
                row[bwlimit_index] = int(bwlimit.split('kbit')[0]) * 1000
    elif table_name in ['slice_attribute']:
        # modify some invalid foreign keys
        attribute_type_index = table_fields.index('attribute_type_id')
        attribute_type_id = row[attribute_type_index]
        if attribute_type_id in _ATTRIBUTE_TYPE_REPLACEMENTS:
            row[attribute_type_index] = _ATTRIBUTE_TYPE_REPLACEMENTS[attribute_type_id]
    elif table_name in ['slice_attribute_types']:
        type_id_index = table_fields.index('attribute_type_id')
        if row[type_id_index] in _ATTRIBUTE_TYPE_REPLACEMENTS:
            return None
    return row
164         
def fix_table(table, table_name, table_fields):
    """Drop rows for redundant attribute type ids from slice_attribute_types.

    `table` is a list of rows whose values follow `table_fields`.  It is
    filtered in place and also returned.  Other tables are returned
    unchanged.  Raises ValueError if 'attribute_type_id' is missing from
    `table_fields` for slice_attribute_types.
    """
    if table_name in ['slice_attribute_types']:
        # remove duplicate/redundant primary keys
        type_id_index = table_fields.index('attribute_type_id')
        redundant_ids = [10004, 10006, 10031, 10033, 10034, 10035]
        # Rebuild the list in place rather than calling remove() while
        # iterating, which skips the row following each removed one.
        table[:] = [row for row in table
                    if row[type_id_index] not in redundant_ids]
    return table
173
def remove_temp_tables():
    """Delete every temp data file recorded in `temp_tables`.

    The first os.remove() that fails raises OSError, and any remaining
    files are left in place.  `temp_tables` itself is not modified.
    """
    for temp_file_name in temp_tables.values():
        os.remove(temp_file_name)
181
def _parse_upgrade_def(upgrade_def):
    """Split a table's upgrade definition into its components.

    `upgrade_def` is a comma-separated list of entries of the form
    ``new_column:old_table.old_column[:extra...]``.  Parentheses are
    ignored.  An extra item containing '=' is a WHERE condition; any other
    extra item (``table.column``) is a table to INNER JOIN USING that
    column.  Joins and conditions are de-duplicated, keeping the order in
    which they first appear.

    Returns (table_fields, old_fields, joins, wheres).
    Raises IndexError when an entry lacks the old column part.
    """
    table_fields, old_fields, joins, wheres = [], [], [], []
    for entry in upgrade_def.replace('(', '').replace(')', '').split(','):
        parts = entry.strip().split(':')
        table_fields.append(parts[0])
        old_fields.append(parts[1])
        for extra in parts[2:]:
            if '=' in extra:
                target = wheres
            else:
                target = joins
            if extra not in target:
                target.append(extra)
    return table_fields, old_fields, joins, wheres


def _build_select(old_fields, joins, wheres):
    """Return the SELECT DISTINCT query that reads the old data.

    The FROM table is the table part of the first old field.  Several WHERE
    conditions are combined with AND.  Raises IndexError for a join entry
    without a '.'.
    """
    query = "SELECT DISTINCT %s FROM %s" % \
            (", ".join(old_fields), old_fields[0].split(".")[0])
    for join in joins:
        join_parts = join.split('.')
        query += " INNER JOIN %s USING (%s) " % (join_parts[0], join_parts[1])
    if wheres:
        query += " WHERE " + " AND ".join(wheres)
    return query


def _constraint_indices(body_list, table_fields):
    """Find which `table_fields` are NOT NULL or in the PRIMARY KEY.

    These are plain substring tests on the schema definition lines.  A
    field is NOT NULL if any line mentions it together with "NOT NULL".  It
    belongs to the primary key if a "PRIMARY KEY" line contains it
    surrounded by spaces.  Returns two sets of indices into `table_fields`.
    """
    not_null_indices = set()
    primary_key_indices = set()
    for index, field in enumerate(table_fields):
        for body_line in body_list:
            if field in body_line and "NOT NULL" in body_line.upper():
                not_null_indices.add(index)
            if "PRIMARY KEY" in body_line and (" " + field + " ") in body_line:
                primary_key_indices.add(index)
    return not_null_indices, primary_key_indices


def _copy_field(value, not_null):
    r"""Render one value in PostgreSQL COPY text format.

    NULL becomes the \N marker, or an empty string for NOT NULL columns
    (XX: this does not work if the column is of int type).  Numbers are
    converted with str().  In strings, backslash, tab, newline and carriage
    return are escaped so they cannot be mistaken for COPY delimiters or
    escapes.  Other values are returned unchanged.
    """
    if value is None:
        if not_null:
            return ""
        return "\\N"
    if isinstance(value, numbers.Number):
        return str(value)
    if isinstance(value, str):
        # The backslash must be escaped first; otherwise the escapes added
        # below would themselves be doubled.
        return (value.replace('\\', '\\\\')
                .replace('\t', '\\t')
                .replace('\n', '\\n')
                .replace('\r', '\\r'))
    return value


def generate_temp_table(table_name, db):
    """Dump the old data for `table_name` into a COPY-format temp file.

    The upgrade directions come from upgrade_config[table_name].  Column
    constraints come from schema[table_name].  The old rows are selected
    from `db`, passed through fix_row(), and written to
    <temp_dir>/<table_name>.tmp.  That path is then recorded in
    temp_tables[table_name].  Rows with a NULL primary key value are
    skipped.

    Returns False if upgrade_config has no entry for the table (the table
    is skipped); otherwise returns None.  Exits the script (status 1) if
    the upgrade definition is malformed.  On any other error, prints a
    message, rolls back `db`, and re-raises.
    """
    try:
        upgrade_def = upgrade_config[table_name]
    except KeyError:
        # no upgrade directions for this table: nothing to carry over
        return False

    try:
        table_fields, old_fields, joins, wheres = _parse_upgrade_def(upgrade_def)
        get_old_data = _build_select(old_fields, joins, wheres)
    except IndexError:
        print("Error: error found in upgrade config file. "
              "check %s configuration. Aborting " % (table_name))
        sys.exit(1)

    cursor = db.cursor()
    try:
        body_list = schema[table_name][1]
        not_null_indices, primary_key_indices = \
            _constraint_indices(body_list, table_fields)

        cursor.execute(get_old_data)
        rows = cursor.fetchall()
        cursor.close()

        temp_file_name = '%s/%s.tmp' % (temp_dir, table_name)
        with open(temp_file_name, 'w') as temp_file:
            for row in rows:
                # attempt to make any necessary fixes to data
                row = fix_row(row, table_name, table_fields)
                if row is None:
                    continue
                # do not attempt to write rows with null primary keys
                if any(row[i] is None for i in primary_key_indices):
                    continue
                fields = [_copy_field(value, i in not_null_indices)
                          for (i, value) in enumerate(row)]
                temp_file.write("\t".join(fields) + "\n")
            # end-of-data marker for COPY ... FROM stdin
            temp_file.write("\\.\n")
        temp_tables[table_name] = temp_file_name
    except Exception:
        print("Error: configuration for %s doesnt match db schema. "
              " Aborting" % (table_name))
        try:
            db.rollback()
        except Exception:
            pass
        raise
277
278
# Connect to current db
db = connect()
cursor = db.cursor()

# determine current db version
cursor.execute("SELECT relname from pg_class where relname = 'plc_db_version'")
rows = cursor.fetchall()
if not rows:
    print("Warning: current db has no version. Unable to validate config file.")
else:
    cursor.execute("SELECT version FROM plc_db_version")
    rows = cursor.fetchall()
    if not rows or not rows[0]:
        print("Warning: current db has no version. Unable to validate config file.")
    elif rows[0][0] == db_version_new:
        print("Status: Versions are the same. No upgrade necessary.")
        sys.exit()
    elif rows[0][0] != db_version_previous:
        # a mismatch is a failure: report it with a non-zero exit status
        print("Error: DB_VERSION_PREVIOUS in config file (%s) does not"
              " match current db version %s" % (upgrade_config_file, rows[0][0]))
        sys.exit(1)
    else:
        print("STATUS: attempting upgrade from %d to %d" %
              (db_version_previous, db_version_new))

# check db encoding
sql = " SELECT pg_catalog.pg_encoding_to_char(d.encoding)" \
      " FROM pg_catalog.pg_database d " \
      " WHERE d.datname = '%s' " % config['PLC_DB_NAME']
cursor.execute(sql)
rows = cursor.fetchall()
db_encoding = rows[0][0]

# Release this session before the database may be dumped and archived; a
# fresh connection is opened afterwards.
cursor.close()
db.close()

if db_encoding not in ['UTF8', 'UNICODE']:
    print("WARNING: db encoding is not utf8. Attempting to encode")
    # generate db dump
    dump_file = '%s/dump.sql' % (temp_dir)
    dump_file_encoded = dump_file + ".utf8"
    dump_cmd = 'pg_dump -i %s -U postgres -f %s > /dev/null 2>&1' % \
               (config['PLC_DB_NAME'], dump_file)
    if os.system(dump_cmd):
        print("ERROR: during db dump. Exiting.")
        sys.exit(1)
    # encode dump to utf8
    print("Status: encoding database dump")
    encode_utf8(dump_file, dump_file_encoded)
    # archive original db
    archive_db(config['PLC_DB_NAME'], config['PLC_DB_NAME']+'_sqlascii_archived')
    # create a utf8 database and upload encoded data
    recreate_cmd = 'createdb -U postgres -E UTF8 %s > /dev/null; ' \
                   'psql -a -U  %s %s < %s > /dev/null 2>&1;'   % \
                   (config['PLC_DB_NAME'], config['PLC_DB_USER'],
                    config['PLC_DB_NAME'], dump_file_encoded)
    print("Status: recreating database as utf8")
    if os.system(recreate_cmd):
        print("Error: database encoding failed. Aborting")
        sys.exit(1)

    os.remove(dump_file_encoded)
    os.remove(dump_file)
341
342
db = connect()
cursor = db.cursor()

# parse the schema user wishes to upgrade to
with open(schema_file, 'r') as schema_fh:
    lines = schema_fh.readlines()

index = 0
while index < len(lines):
    # drop trailing SQL comments
    line = lines[index].split("--")[0]
    # find all created objects
    if line.startswith("CREATE"):
        line_parts = line.split(" ")
        # "CREATE OR REPLACE <type> <name>": shift so parts[1] is the type
        if line_parts[1:3] == ['OR', 'REPLACE']:
            line_parts = line_parts[2:]
        # "CREATE UNIQUE INDEX <name>" is handled like a plain index
        if line_parts[1:2] == ['UNIQUE']:
            line_parts = line_parts[1:]
        item_type = line_parts[1]
        item_name = line_parts[2]
        if item_type in ['INDEX']:
            # indexes are expected to fit on a single line
            schema[item_name] = (item_type, line)
            schema_items_ordered.append(item_name)
        elif item_type in ['AGGREGATE', 'TABLE', 'VIEW']:
            # these span several lines: collect them up to and including
            # the first line containing ';' (or the end of the file)
            fields = [line]
            finished = line.find(";") >= 0
            while not finished and index + 1 < len(lines):
                index = index + 1
                nextline = lines[index].split("--")[0]
                # look for any sequences
                if item_type in ['TABLE'] and nextline.find('serial') > -1:
                    sequences[item_name] = nextline.strip().split()[0]
                fields.append(nextline)
                finished = nextline.find(";") >= 0
            schema[item_name] = (item_type, fields)
            schema_items_ordered.append(item_name)
        else:
            # Not recorded in schema_items_ordered: the upgrade loops look
            # every listed name up in `schema`.
            print("Error: unknown type %s" % item_type)
    elif line.startswith("INSERT"):
        inserts.append(line)
    index = index + 1
392
print("Status: generating temp tables")
# generate all temp tables
for key in schema_items_ordered:
    (item_type, body_list) = schema[key]
    if item_type == 'TABLE':
        generate_temp_table(key, db)

# disconnect from current database and archive it
cursor.close()
db.close()

print("Status: archiving database")
archive_db(config['PLC_DB_NAME'], config['PLC_DB_NAME']+'_archived')
if os.system('createdb -U postgres -E UTF8 %s > /dev/null; ' % config['PLC_DB_NAME']):
    # Nothing below can work without the new database; give the original
    # database its name back instead of leaving it archived.
    print("Error: unable to create new database. Restoring original database and aborting")
    os.system("psql template1 postgres -qc 'ALTER DATABASE %s RENAME TO %s;' > /dev/null" %
              (config['PLC_DB_NAME']+'_archived', config['PLC_DB_NAME']))
    sys.exit(1)
407
def _insert_target(insert_stmt):
    """Return the table written by an "INSERT INTO <table> ..." statement.

    Any schema prefix, column list and surrounding double quotes are
    stripped.  Returns None when the statement does not start with
    INSERT INTO <name>.
    """
    words = insert_stmt.split(None, 3)
    if len(words) < 3 or words[1].upper() != 'INTO':
        return None
    return words[2].split('(')[0].split('.')[-1].strip('"')


print("Status: upgrading database")
# attempt to create and load all items from schema into the new db
try:
    for key in schema_items_ordered:
        (item_type, body_list) = schema[key]
        create_item_from_schema(key)
        if item_type != 'TABLE':
            continue
        if key in upgrade_config:
            # attempt to populate with temp table data
            table_def = upgrade_config[key].replace('(', '').replace(')', '').split(',')
            table_fields = [field.strip().split(':')[0] for field in table_def]
            insert_cmd = "psql %s %s -c " \
                         " 'COPY %s (%s) FROM stdin;' < %s " % \
                         (config['PLC_DB_NAME'], config['PLC_DB_USER'], key,
                          ", ".join(table_fields), temp_tables[key])
            if os.system(insert_cmd):
                print("Error: upgrade %s failed" % key)
                sys.exit(1)
            # update the primary key sequence
            if key in sequences:
                sequence = key + "_" + sequences[key] + "_seq"
                update_seq = "psql %s %s -c " \
                             " \"select setval('%s', max(%s)) FROM %s;\" > /dev/null" % \
                             (config['PLC_DB_NAME'], config['PLC_DB_USER'], sequence,
                              sequences[key], key)
                if os.system(update_seq):
                    print("Error: sequence %s update failed" % sequence)
                    sys.exit(1)
        else:
            print("Warning: %s has no temp data file. Unable to populate with old data" % key)
            # run the schema's own INSERT statements that target exactly this
            # table (a substring match would also hit e.g. node_types for node)
            for insert_stmt in inserts:
                if _insert_target(insert_stmt) == key:
                    insert_cmd = 'psql %s postgres -qc "%s;" > /dev/null 2>&1' % \
                                 (config['PLC_DB_NAME'], insert_stmt)
                    os.system(insert_cmd)
except:
    # Deliberately bare: the sys.exit() calls on the failure paths raise
    # SystemExit, and that must trigger the restore as well.
    print("Error: failed to populate db. Unarchiving original database and aborting")
    undo_command = "dropdb -U postgres %s > /dev/null; psql template1 postgres -qc" \
                   " 'ALTER DATABASE %s RENAME TO %s;' > /dev/null" % \
                   (config['PLC_DB_NAME'], config['PLC_DB_NAME']+'_archived', config['PLC_DB_NAME'])
    os.system(undo_command)
    raise

print("upgrade complete")
456
457 print "upgrade complete"