- fix db encoding check
[plcapi.git] / tools / upgrade-db.py
1 #!/usr/bin/python
2 #
3 # Tool for upgrading a db based on db version #
4 import sys
5 import os
6 import getopt
7 import pgdb
8
# Site-wide PLC settings: the config file is executed as Python code and
# every name it assigns ends up as a key of `config`.
config_file = "/etc/planetlab/plc_config"
config = dict()
execfile(config_file, config)

# Defaults; the command-line handling further down may replace them.
temp_dir = "/tmp"
schema_file = "planetlab4.sql"
upgrade_config_file = "plcdb.3-4.conf"
15
16
def usage():
    """Print the command-line synopsis to stdout and exit with status 1."""
    help_lines = [
        "Usage: %s [OPTION] UPGRADE_CONFIG_FILE " % sys.argv[0],
        "Options:",
        "     -s, --schema=FILE       Upgraded Database Schema",
        "     -t, --temp-dir=DIR      Temp Directory",
        "     --help                  This message",
    ]
    # One print of the joined text produces the same output as one print
    # per line.
    print("\n".join(help_lines))
    sys.exit(1)
24
# Parse the command line.  The usage text advertises "-t" as the short
# form of --temp-dir while older invocations use "-d", so both spellings
# are accepted.
try:
    (opts, argv) = getopt.getopt(sys.argv[1:],
                                 "s:d:t:",
                                 ["schema=",
                                  "temp-dir=",
                                  "help"])
except getopt.GetoptError as err:
    print("Error:  %s" % err.msg)
    usage()

for (opt, optval) in opts:
    if opt in ("-s", "--schema"):
        schema_file = optval
    elif opt in ("-d", "-t", "--temp-dir"):
        temp_dir = optval
    elif opt == "--help":
        usage()

# The upgrade config file is the one required positional argument.
try:
    upgrade_config_file = argv[0]
except IndexError:
    print("Error: too few arguments")
    usage()
47
# Module-level state shared by the functions and the script body below:
#   schema               - item name -> (item type, definition line(s))
#   inserts              - INSERT lines found in the schema file
#   schema_items_ordered - item names in the order the schema file lists them
#   sequences            - table name -> first token of its 'serial' line
#   temp_tables          - table name -> path of its exported data file
schema, sequences, temp_tables = {}, {}, {}
inserts, schema_items_ordered = [], []
53
54
# Load the settings for this upgrade by executing the upgrade config file;
# the names it assigns become the keys of `upgrade_config`.
try:
    upgrade_config = dict()
    execfile(upgrade_config_file, upgrade_config)
    # drop the '__builtins__' entry so only the file's own names remain
    upgrade_config.pop('__builtins__')
    # both version markers are mandatory
    db_version_previous = upgrade_config['DB_VERSION_PREVIOUS']
    db_version_new = upgrade_config['DB_VERSION_NEW']
except IOError as fault:
    print("Error: upgrade config file (%s) not found. Exiting" % (fault))
    sys.exit(1)
except KeyError as fault:
    print("Error: %s not set in upgrade confing (%s). Exiting" %
          (fault, upgrade_config_file))
    sys.exit(1)
71
72
73
74
def connect():
    """Open and return a new pgdb connection to the PLC database.

    The user and database name are taken from the PLC_DB_USER and
    PLC_DB_NAME entries of the global `config`.
    """
    return pgdb.connect(user=config['PLC_DB_USER'],
                        database=config['PLC_DB_NAME'])
79
def archive_db(database, archived_database):
    """Rename `database` to `archived_database` and create a fresh `database`.

    Runs, in one shell command line: dropdb of any existing
    `archived_database` (output discarded), an ALTER DATABASE ... RENAME
    through psql, then createdb of a new `database`.  If os.system()
    reports a non-zero status an error is printed and the process exits
    with status 1.
    """
    command_template = "".join([
        " dropdb -U postgres %s > /dev/null 2>&1;",
        " psql template1 postgres -qc ",
        " 'ALTER DATABASE %s RENAME TO %s;';",
        " createdb -U postgres %s > /dev/null; ",
    ])
    command = command_template % \
        (archived_database, database, archived_database, database)
    if os.system(command) != 0:
        print("Error: unable to archive database. Upgrade failed")
        sys.exit(1)
91         #print "Status: %s has been archived. now named %s" % (database, archived_database)
92
93
def encode_utf8(inputfile_name, outputfile_name):
    """Rewrite an iso-8859-1 encoded file as utf8.

    inputfile_name  -- path of the iso-8859-1 file to read
    outputfile_name -- path of the utf8 file to (over)write

    Both files are handled as raw bytes, so the only transformation is
    the character re-encoding.  Both handles are closed even when reading,
    converting or writing fails.  On any error 'error encoding file' is
    printed and the original exception is re-raised.
    """
    try:
        inputfile = open(inputfile_name, 'rb')
        try:
            outputfile = open(outputfile_name, 'wb')
            try:
                # iso-8859-1 is a single-byte encoding, so converting the
                # file one line at a time never splits a character
                for line in inputfile:
                    outputfile.write(line.decode('iso-8859-1').encode('utf8'))
            finally:
                outputfile.close()
        finally:
            inputfile.close()
    except:
        print('error encoding file')
        raise
106
def create_item_from_schema(item_name):
    """Create one schema item in the PLC database by feeding it to psql.

    item_name -- key into the global `schema`, whose value is a pair
                 (item type, definition line(s))

    Returns False, after printing an error, when `item_name` has no
    definition in `schema`.  Prints an error and exits the process with
    status 1 when the psql command cannot be built or reports failure.
    Returns None on success.
    """
    # Look the definition up on its own so that a missing item is reported
    # as such instead of as a generic creation failure.
    try:
        (item_type, body_list) = schema[item_name]
    except KeyError:
        print("Error: cannot create %s. definition not found in %s" %
              (item_name, schema_file))
        return False

    try:
        exit_status = os.system('psql %s %s -qc "%s" > /dev/null 2>&1' %
                                (config['PLC_DB_NAME'], config['PLC_DB_USER'],
                                 "".join(body_list)))
    except Exception:
        # e.g. a missing config entry; handled like a failed psql run
        exit_status = 1
    if exit_status:
        print('Error: create %s failed. Check schema.' % item_name)
        sys.exit(1)
124
def fix_row(row, table_name, table_fields):
    """Apply table-specific data fixes to one row.

    row          -- mutable sequence of column values; modified in place
    table_name   -- name of the table the row belongs to
    table_fields -- column names, in the same order as the values in `row`

    Returns the (possibly modified) row, or None when the row must be
    dropped.  Rows of tables without special handling are returned as is.
    """
    if table_name in ['nodenetworks']:
        # convert str bwlimit to bps int
        bwlimit_index = table_fields.index('bwlimit')
        bwlimit = row[bwlimit_index]
        if bwlimit is None or isinstance(bwlimit, int):
            # already numeric, or NULL: nothing to convert
            pass
        elif bwlimit.find('mbit') > -1:
            row[bwlimit_index] = int(bwlimit.split('mbit')[0]) * 1000000
        elif bwlimit.find('kbit') > -1:
            row[bwlimit_index] = int(bwlimit.split('kbit')[0]) * 1000
    elif table_name in ['slice_attribute']:
        # modify some invalid foreign keys
        attribute_type_index = table_fields.index('attribute_type_id')
        if row[attribute_type_index] == 10004:
            row[attribute_type_index] = 10016
        elif row[attribute_type_index] == 10006:
            row[attribute_type_index] = 10017
    elif table_name in ['slice_attribute_types']:
        # the ids remapped above are not carried over as types
        type_id_index = table_fields.index('attribute_type_id')
        if row[type_id_index] in [10004, 10006]:
            return None
    return row
150         
def fix_table(table, table_name, table_fields):
    """Drop redundant rows from a whole table, in place.

    table        -- list of rows (sequences of column values)
    table_name   -- name of the table the rows belong to
    table_fields -- column names, in the same order as the values in a row

    Only 'slice_attribute_types' is touched: every row whose
    attribute_type_id is 10004 or 10006 is removed.  Returns the same
    list object that was passed in.
    """
    if table_name in ['slice_attribute_types']:
        # remove duplicate/redundant primary keys
        type_id_index = table_fields.index('attribute_type_id')
        # Rebuild the contents in one step: removing entries while looping
        # over the same list skips the row following each removed one.
        table[:] = [row for row in table
                    if row[type_id_index] not in [10004, 10006]]
    return table
159
def remove_temp_tables():
    """Delete every data file recorded in the global `temp_tables`.

    Any error raised by os.remove() propagates to the caller.
    """
    for temp_file_name in temp_tables.values():
        os.remove(temp_file_name)
167
def _copy_escape(value, not_null):
    """Render one column value as text for a COPY ... FROM stdin data file.

    None becomes "" for a NOT NULL column and the \\N null marker
    otherwise.  Ints and floats are stringified.  In strings the
    backslash, tab, newline and carriage return are backslash-escaped.
    Values of any other type are returned unchanged.
    """
    if value is None:
        if not_null:
            # XX doesnt work if column is int type
            return ""
        return "\\N"
    if isinstance(value, (int, float)):
        return str(value)
    if isinstance(value, str):
        # the backslash must be escaped first, otherwise the escapes
        # added below would themselves be doubled
        value = value.replace('\\', '\\\\')
        value = value.replace('\t', '\\t')
        value = value.replace('\n', '\\n')
        value = value.replace('\r', '\\r')
    return value


def generate_temp_table(table_name, db):
    """Export the old data for `table_name` into a COPY-format temp file.

    table_name -- table of the new schema; its upgrade directions are read
                  from upgrade_config[table_name], a comma separated list
                  of "new_field:old_table.old_field[:extra...]" entries.
                  An extra containing '=' is used as a WHERE condition,
                  any other extra ("table.column") as an INNER JOIN ...
                  USING target.
    db         -- open database connection used to read the old data

    On success the file <temp_dir>/<table_name>.tmp is written and its
    path is recorded in temp_tables[table_name].  Returns False when a
    KeyError occurs (e.g. no upgrade directions for the table).  Exits the
    process with status 1 on an IndexError (malformed directions).  Any
    other error is reported, a rollback is attempted, and the error is
    re-raised.
    """
    cursor = db.cursor()
    try:
        # get upgrade directions
        table_def = upgrade_config[table_name].replace('(', '').replace(')', '').split(',')
        table_fields, old_fields, joins, wheres = [], [], set(), set()
        for field in table_def:
            field_parts = field.strip().split(':')
            table_fields.append(field_parts[0])
            old_fields.append(field_parts[1])
            for extra in field_parts[2:]:
                if extra.find('=') > -1:
                    wheres.add(extra)
                else:
                    joins.add(extra)

        # get indices of fields that cannot be null
        (item_type, body_list) = schema[table_name]
        not_null_indices = set()
        for index, field in enumerate(table_fields):
            for body_line in body_list:
                if body_line.find(field) > -1 and \
                   body_line.upper().find("NOT NULL") > -1:
                    not_null_indices.add(index)

        # get index of primary key (first definition line mentioning it)
        primary_key_indices = []
        for body_line in body_list:
            if body_line.find("PRIMARY KEY") > -1:
                for index, field in enumerate(table_fields):
                    if body_line.find(field) > -1:
                        primary_key_indices.append(index)
                break

        # get old data
        get_old_data = "SELECT DISTINCT %s FROM %s" % \
            (", ".join(old_fields), old_fields[0].split(".")[0])
        for join in joins:
            join_parts = join.split('.')
            get_old_data = get_old_data + " INNER JOIN %s USING (%s) " % \
                (join_parts[0], join_parts[1])
        if wheres:
            # several conditions must be combined explicitly
            get_old_data = get_old_data + " WHERE " + " AND ".join(wheres)
        cursor.execute(get_old_data)
        rows = cursor.fetchall()

        # write data to a temp file
        temp_file_name = '%s/%s.tmp' % (temp_dir, table_name)
        temp_file = open(temp_file_name, 'w')
        try:
            for row in rows:
                # attempt to make any necessary fixes to data
                row = fix_row(row, table_name, table_fields)
                # do not attempt to write null rows
                if row is None:
                    continue
                # do not attempt to write rows with null primary keys
                if any([row[index] is None for index in primary_key_indices]):
                    continue
                data_row = "\t".join([_copy_escape(value, index in not_null_indices)
                                      for index, value in enumerate(row)])
                temp_file.write(data_row + "\n")
            # end-of-data marker
            temp_file.write("\\.\n")
        finally:
            temp_file.close()
        temp_tables[table_name] = temp_file_name

    except KeyError:
        #print "WARNING: cannot upgrade %s. upgrade def not found. skipping" % \
        #       (table_name)
        return False
    except IndexError:
        print("Error: error found in upgrade config file. "
              "check %s configuration. Aborting " %
              (table_name))
        sys.exit(1)
    except:
        print("Error: configuration for %s doesnt match db schema. "
              " Aborting" % (table_name))
        # best effort only: a failing rollback is deliberately ignored
        try:
            db.rollback()
        except Exception:
            pass
        raise
264
265
# Connect to current db
db = connect()
cursor = db.cursor()

# determine current db version and validate it against the upgrade config
cursor.execute("SELECT relname from pg_class where relname = 'plc_db_version'")
rows = cursor.fetchall()
if not rows:
    print("Warning: current db has no version. Unable to validate config file.")
else:
    cursor.execute("SELECT version FROM plc_db_version")
    rows = cursor.fetchall()
    if not rows or not rows[0]:
        print("Warning: current db has no version. Unable to validate config file.")
    elif rows[0][0] == db_version_new:
        print("Status: Versions are the same. No upgrade necessary.")
        sys.exit()
    elif not rows[0][0] == db_version_previous:
        print("Stauts: DB_VERSION_PREVIOUS in config file (%s) does not"
              " match current db version %d" % (upgrade_config_file, rows[0][0]))
        sys.exit()
    else:
        print("STATUS: attempting upgrade from %d to %d" %
              (db_version_previous, db_version_new))

# check db encoding
sql = " SELECT pg_catalog.pg_encoding_to_char(d.encoding)" \
      " FROM pg_catalog.pg_database d " \
      " WHERE d.datname = '%s' " % config['PLC_DB_NAME']
cursor.execute(sql)
rows = cursor.fetchall()
if rows[0][0] not in ['UTF8', 'UNICODE']:
    print("WARNING: db encoding is not utf8. Attempting to encode")
    db.close()
    # generate db dump
    dump_file = '%s/dump.sql' % (temp_dir)
    dump_file_encoded = dump_file + ".utf8"
    dump_cmd = 'pg_dump -i %s -U postgres -f %s > /dev/null 2>&1' % \
               (config['PLC_DB_NAME'], dump_file)
    if os.system(dump_cmd):
        print("ERROR: during db dump. Exiting.")
        sys.exit(1)
    # encode dump to utf8
    print("Status: encoding database dump")
    encode_utf8(dump_file, dump_file_encoded)
    # archive original db
    archive_db(config['PLC_DB_NAME'], config['PLC_DB_NAME']+'_sqlascii_archived')
    # create a utf8 database and upload encoded data.  archive_db() has
    # just created an empty database under the original name, so that one
    # is dropped first; otherwise createdb -E UTF8 fails because the name
    # is taken and the dump is loaded into the non-utf8 database.
    recreate_cmd = 'dropdb -U postgres %s > /dev/null 2>&1; ' \
                   'createdb -U postgres -E UTF8 %s > /dev/null 2>&1; ' \
                   'psql -a -U  %s %s < %s > /dev/null 2>&1;' % \
                   (config['PLC_DB_NAME'], config['PLC_DB_NAME'],
                    config['PLC_DB_USER'], config['PLC_DB_NAME'],
                    dump_file_encoded)
    print("Status: recreating database as utf8")
    if os.system(recreate_cmd):
        print("Error: database encoding failed. Aborting")
        sys.exit(1)

    os.remove(dump_file_encoded)
    os.remove(dump_file)
328
329
# (re)connect: the encoding step above may have closed the connection
db = connect()
cursor = db.cursor()

# parse the schema user wishes to upgrade to
schema_fh = open(schema_file, 'r')
try:
    lines = schema_fh.readlines()
finally:
    schema_fh.close()

index = 0
while index < len(lines):
    line = lines[index]
    # find all created objects
    if line.startswith("CREATE"):
        line_parts = line.split(" ")
        item_type = line_parts[1]
        item_name = line_parts[2]
        if item_type in ['INDEX']:
            # indexes are defined on a single line
            schema_items_ordered.append(item_name)
            schema[item_name] = (item_type, line)

        # functions, tables, views span over multiple lines
        # handle differently than indexes
        elif item_type in ['AGGREGATE', 'TABLE', 'VIEW']:
            schema_items_ordered.append(item_name)
            fields = [line]
            # collect lines up to and including the one holding the
            # terminating ';' (which may be the CREATE line itself),
            # without ever reading past the end of the file
            while fields[-1].find(";") < 0 and index + 1 < len(lines):
                index = index + 1
                nextline = lines[index]
                # look for any sequences
                if item_type in ['TABLE'] and nextline.find('serial') > -1:
                    sequences[item_name] = nextline.strip().split()[0]
                fields.append(nextline)
            schema[item_name] = (item_type, fields)
        else:
            # not recorded anywhere: every name in schema_items_ordered
            # must have an entry in schema for the loops further down
            print("Error: unknown type %s" % item_type)
    elif line.startswith("INSERT"):
        inserts.append(line)
    index = index + 1
371
print("Status: generating temp tables")
# write a temp data file for every table the new schema defines
for key in schema_items_ordered:
    if schema[key][0] == 'TABLE':
        generate_temp_table(key, db)

# disconnect from current database and archive it
cursor.close()
db.close()

print("Status: archiving database")
archive_db(config['PLC_DB_NAME'], config['PLC_DB_NAME'] + '_archived')
385
386
print("Status: upgrading database")
# attempt to create and load all items from schema into temp db
try:
    for key in schema_items_ordered:
        item_kind = schema[key][0]
        create_item_from_schema(key)
        # only tables carry data
        if item_kind != 'TABLE':
            continue

        if key not in upgrade_config:
            # check if there are any insert stmts in schema for this table
            print("Warning: %s has no temp data file. Unable to populate with old data" % key)
            for insert_stmt in inserts:
                if insert_stmt.find(key) > -1:
                    insert_cmd = 'psql %s postgres -qc "%s;" > /dev/null 2>&1' % \
                                 (config['PLC_DB_NAME'], insert_stmt)
                    os.system(insert_cmd)
            continue

        # attempt to populate with temp table data
        table_def = upgrade_config[key].replace('(', '').replace(')', '').split(',')
        table_fields = [field.strip().split(':')[0] for field in table_def]
        insert_cmd = "psql %s %s -c " \
                     " 'COPY %s (%s) FROM stdin;' < %s " % \
                     (config['PLC_DB_NAME'], config['PLC_DB_USER'], key,
                      ", ".join(table_fields), temp_tables[key])
        if os.system(insert_cmd):
            print("Error: upgrade %s failed" % key)
            sys.exit(1)

        # update the primary key sequence
        if key in sequences:
            sequence = key + "_" + sequences[key] + "_seq"
            update_seq = "psql %s %s -c " \
                         " \"select setval('%s', max(%s)) FROM %s;\" > /dev/null" % \
                         (config['PLC_DB_NAME'], config['PLC_DB_USER'], sequence,
                          sequences[key], key)
            if os.system(update_seq):
                print("Error: sequence %s update failed" % sequence)
                sys.exit(1)
except:
    # Deliberately a bare except: the sys.exit() calls above raise
    # SystemExit, which must also trigger the restore below before it is
    # re-raised.
    print("Error: failed to populate db. Unarchiving original database and aborting")
    undo_command = "dropdb -U postgres %s > /dev/null; psql template1 postgres -qc" \
                   " 'ALTER DATABASE %s RENAME TO %s;';  > /dev/null" % \
                   (config['PLC_DB_NAME'], config['PLC_DB_NAME'] + '_archived', config['PLC_DB_NAME'])
    os.system(undo_command)
    remove_temp_tables()
    raise

remove_temp_tables()

print("upgrade complete")