- support joins and wheres in config file
authorTony Mack <tmack@cs.princeton.edu>
Tue, 21 Nov 2006 15:25:30 +0000 (15:25 +0000)
committerTony Mack <tmack@cs.princeton.edu>
Tue, 21 Nov 2006 15:25:30 +0000 (15:25 +0000)
tools/upgrade-db.py

index ad89fbb..8c66854 100755 (executable)
@@ -170,13 +170,14 @@ def generate_temp_table(table_name, db):
         try:
                 # get upgrade directions
                 table_def = upgrade_config[table_name].replace('(', '').replace(')', '').split(',')
-                table_fields, old_fields, required_joins = [], [], set()
+                table_fields, old_fields, joins, wheres = [], [], set(), set()
                 for field in table_def:
                         field_parts = field.strip().split(':')
                         table_fields.append(field_parts[0])
                         old_fields.append(field_parts[1])
-                        if field_parts[2:]:
-                               required_joins.update(set(field_parts[2:]))
+                        if field_parts[2:]:    
+                               joins.update(set(filter(lambda x: not x.find('=') > -1, field_parts[2:])))
+                               wheres.update(set(filter(lambda x: x.find('=') > -1, field_parts[2:])))
                
                # get indices of fields that cannot be null
                (type, body_list) = schema[table_name]
@@ -200,10 +201,13 @@ def generate_temp_table(table_name, db):
                 # get old data
                 get_old_data = "SELECT DISTINCT %s FROM %s" % \
                       (", ".join(old_fields), old_fields[0].split(".")[0])
-                for join in required_joins:
+                for join in joins:
                         get_old_data = get_old_data + " INNER JOIN %s USING (%s) " % \
                                        (join.split('.')[0], join.split('.')[1])
-
+               if wheres:      
+                       get_old_data = get_old_data + " WHERE " 
+               for where in wheres:
+                       get_old_data = get_old_data + " %s" % where 
                 cursor.execute(get_old_data)
                 rows = cursor.fetchall()