fix bug in insert()
[sfa.git] / sfa / util / genitable.py
1 # genitable.py
2 #
3 # implements support for geni records stored in db tables
4 #
5 # TODO: Use existing PLC database methods? or keep this separate?
6
7 ### $Id$
8 ### $URL$
9
10 import report
11 import  pgdb
12 from pg import DB, ProgrammingError
13 from sfa.trust.gid import *
14 from sfa.util.record import *
15 from sfa.util.debug import *
16 from sfa.util.config import *
17 from sfa.util.filter import *
18
class GeniTable(list):
    """
    A list of GENI registry records backed by a single PostgreSQL table.

    The table is named GENI_TABLE_PREFIX ("sfa").  Connection parameters come
    from Config().get_plc_dbinfo().  When constructed with a record_filter,
    the matching rows (dicts, as returned by find()) are loaded into the list.
    """

    GENI_TABLE_PREFIX = "sfa"

    def __init__(self, record_filter = None):
        """
        Open a connection to the PLC database and optionally preload records.

        @param record_filter: optional filter in any form accepted by find();
            when truthy, every matching row is appended to this list.
        """
        # NOTE(review): an older comment claimed "." in table names is replaced
        # with "$" because pgsql rejects it; no such replacement happens here --
        # the prefix is used verbatim as the table name.
        self.tablename = GeniTable.GENI_TABLE_PREFIX

        # establish a connection to the pgsql server
        cninfo = Config().get_plc_dbinfo()
        self.cnx = DB(cninfo['dbname'], cninfo['address'], port=cninfo['port'],
                      user=cninfo['user'], passwd=cninfo['password'])

        if record_filter:
            # previously iterated over the misspelled name 'reocrds', which
            # raised NameError whenever a filter was supplied
            self.extend(self.find(record_filter))

    def exists(self):
        """
        Return True if the records table exists in the 'public' schema.

        get_tables() may list the name bare or double-quoted, so both
        spellings are checked.
        """
        table_list = self.cnx.get_tables()
        candidates = ('public.' + self.tablename,
                      'public."' + self.tablename + '"')
        return any(name in table_list for name in candidates)

    def create(self):
        """
        (Re)create the records table and its indexes.

        Any existing table of the same name is dropped first via drop(), so
        all previously stored records are lost.
        """
        querystr = ("CREATE TABLE " + self.tablename + " ( "
                    "record_id serial PRIMARY KEY, "
                    "hrn text NOT NULL, "
                    "authority text NOT NULL, "
                    "gid text, "
                    "type text NOT NULL, "
                    "pointer integer, "
                    "date_created timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP, "
                    "last_updated timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP);")
        template = "CREATE INDEX %s_%s_idx ON %s (%s);"
        indexes = [template % (self.tablename, field, self.tablename, field)
                   for field in ('hrn', 'type', 'authority', 'pointer')]

        # reuse drop() rather than duplicating its pre-8.2 fallback logic
        self.drop()
        self.cnx.query(querystr)
        for index in indexes:
            self.cnx.query(index)

    def remove(self, record):
        """Delete the row whose record_id equals record['record_id']."""
        query_str = "DELETE FROM %s WHERE record_id = %s" % (self.tablename, record['record_id'])
        self.cnx.query(query_str)

    def insert(self, record):
        """
        Insert a record and return the record_id assigned by the database.

        date_created, last_updated and record_id are left to their column
        defaults.  The new id is recovered by looking rows up again by
        (hrn, type).

        @return: the record_id of the inserted row, or None if the lookup
            finds no matching row.
        """
        dont_insert = ['date_created', 'last_updated', 'record_id']
        fieldnames = [field for field in record.all_fields.keys()
                      if field not in dont_insert]
        fieldvals = record.get_field_value_strings(fieldnames)
        query_str = "INSERT INTO " + self.tablename + \
                    "(" + ",".join(fieldnames) + ") " + \
                    "VALUES(" + ",".join(fieldvals) + ")"
        self.cnx.query(query_str)

        result = self.find({'hrn': record['hrn'], 'type': record['type']})
        if not result:
            return None
        if isinstance(result, list):
            # record_id is a serial column, so if several rows share this
            # (hrn, type) the one just inserted has the largest id; taking
            # result[0] would return an arbitrary row.
            return max(row['record_id'] for row in result)
        return result['record_id']

    def update(self, record):
        """
        Overwrite the stored row identified by record['record_id'] with the
        record's current field values.

        date_created, last_updated and record_id are never written.
        """
        dont_update = ['date_created', 'last_updated', 'record_id']
        fields = [field for field in record.all_fields.keys()
                  if field not in dont_update]
        pairs = [field + " = " + record.get_field_value_string(field)
                 for field in fields]

        query_str = "UPDATE %s SET %s WHERE record_id = %s" % \
                    (self.tablename, ", ".join(pairs), record['record_id'])
        self.cnx.query(query_str)

    def quote(self, value):
        """
        Returns quoted version of the specified value.

        Lists, tuples and sets become an SQL ARRAY[...] literal whose
        elements are quoted recursively; anything else is quoted by
        pgdb._quote.
        """
        # The pgdb._quote function is good enough for general SQL
        # quoting, except for array types.
        if isinstance(value, (list, tuple, set)):
            # previously join(map, self.quote, value), which raised TypeError
            return "ARRAY[%s]" % ", ".join(map(self.quote, value))
        return pgdb._quote(value)

    def find(self, record_filter = None):
        """
        Return the rows matching record_filter as a list of dicts.

        Accepted filter forms:
          - None: every row in the table
          - list/tuple/set: integer items match record_id, string items
            match hrn, and the two are combined with OR
          - dict: field -> value constraints combined with AND
          - string: an hrn
          - int: a record_id
        """
        sql = "SELECT * FROM %s WHERE True " % self.tablename

        # Filter.sql() is expected to return a 2-tuple: it fills both %s
        # slots in the clauses below.
        if isinstance(record_filter, (list, tuple, set)):
            ints = [x for x in record_filter if isinstance(x, (int, long))]
            strs = [x for x in record_filter if isinstance(x, StringTypes)]
            record_filter = Filter(GeniRecord.all_fields, {'record_id': ints, 'hrn': strs})
            sql += "AND (%s) %s " % record_filter.sql("OR")
        elif isinstance(record_filter, dict):
            record_filter = Filter(GeniRecord.all_fields, record_filter)
            sql += " AND (%s) %s" % record_filter.sql("AND")
        elif isinstance(record_filter, StringTypes):
            record_filter = Filter(GeniRecord.all_fields, {'hrn': [record_filter]})
            sql += " AND (%s) %s" % record_filter.sql("AND")
        elif isinstance(record_filter, int):
            record_filter = Filter(GeniRecord.all_fields, {'record_id': [record_filter]})
            sql += " AND (%s) %s" % record_filter.sql("AND")
        return self.cnx.query(sql).dictresult()

    def findObjects(self, record_filter = None):
        """
        Like find(), but wrap each row in the record class for its 'type'.

        Unknown types fall back to GeniRecord.
        """
        record_classes = {
            'authority': AuthorityRecord,
            'node': NodeRecord,
            'slice': SliceRecord,
            'user': UserRecord,
        }
        return [record_classes.get(result['type'], GeniRecord)(dict=result)
                for result in self.find(record_filter)]

    def drop(self):
        """Drop the records table; a missing table is not an error."""
        # DROP TABLE IF EXISTS is unavailable before PostgreSQL 8.2, so fall
        # back to a plain DROP TABLE and ignore the error if the table is absent.
        try:
            self.cnx.query('DROP TABLE IF EXISTS ' + self.tablename)
        except ProgrammingError:
            try:
                self.cnx.query('DROP TABLE ' + self.tablename)
            except ProgrammingError:
                pass

    @staticmethod
    def geni_records_purge(cninfo):
        """
        Drop every table whose name starts with GENI_TABLE_PREFIX.

        Names are matched bare, 'public.'-qualified, or 'public.'-qualified
        and double-quoted.

        @param cninfo: dict with 'dbname', 'address', 'port', 'user' and
            'password' keys, as used to open the pg.DB connection.
        """
        prefix = GeniTable.GENI_TABLE_PREFIX
        prefixes = (prefix, 'public.' + prefix, 'public."' + prefix)

        cnx = DB(cninfo['dbname'], cninfo['address'],
                 port=cninfo['port'], user=cninfo['user'], passwd=cninfo['password'])
        try:
            for table in cnx.get_tables():
                if table.startswith(prefixes):
                    report.trace("dropping table " + table)
                    cnx.query("DROP TABLE " + table)
        finally:
            # the connection was previously never closed
            cnx.close()