fix bug in update method
[sfa.git] / sfa / util / genitable.py
1 # genitable.py
2 #
3 # implements support for geni records stored in db tables
4 #
5 # TODO: Use the existing PLC database methods, or keep this separate?
6
7 ### $Id$
8 ### $URL$
9
10 import report
11
12 from pg import DB, ProgrammingError
13
14 from sfa.trust.gid import *
15 from sfa.util.record import *
16 from sfa.util.debug import *
17
class GeniTable:
    """Store GENI records for one registry in a dedicated PostgreSQL table.

    Each instance is bound to a single table whose name is derived from the
    registry ``hrn`` (see ``__init__``).  Queries go through a PyGreSQL
    ``pg.DB`` connection opened in the constructor.
    """

    # Prefix shared by every table this class manages; geni_records_purge()
    # uses it to recognise the tables it should drop.
    GENI_TABLE_PREFIX = "sfa$"

    def __init__(self, create=False, hrn="unspecified.default.registry", cninfo=None):
        """Open a database connection and bind to the table for ``hrn``.

        :param create: if true, (re)create the table via create(), which
            drops any existing table of the same name first.
        :param hrn: registry hrn; determines the table name.
        :param cninfo: mapping with 'dbname', 'address', 'port', 'user' and
            'password' keys.  Required despite the ``None`` default: passing
            ``None`` fails with ``TypeError`` when the keys are looked up.
        """
        self.hrn = hrn

        # pgsql doesn't like table names with "." in them, so replace it with "$"
        self.tablename = GeniTable.GENI_TABLE_PREFIX + self.hrn.replace(".", "$")

        # establish a connection to the pgsql server
        self.cnx = GeniTable._connect(cninfo)

        # if asked to create the table, then create it
        if create:
            self.create()

    @staticmethod
    def _connect(cninfo):
        """Open a pg.DB connection described by the ``cninfo`` mapping."""
        return DB(cninfo['dbname'], cninfo['address'], port=cninfo['port'],
                  user=cninfo['user'], passwd=cninfo['password'])

    @staticmethod
    def _quote(text):
        """Return ``text`` wrapped as a single-quoted SQL string literal.

        Embedded single quotes are doubled, so a value such as ``o'brien``
        cannot terminate the literal early and inject SQL.

        NOTE(review): backslashes are left untouched.  If the server runs
        with standard_conforming_strings off (the default on older
        PostgreSQL releases), a backslash is also an escape character inside
        '...' literals and would need escaping too.  Confirm the server
        setting.
        """
        return "'" + text.replace("'", "''") + "'"

    @staticmethod
    def _storable_fields(record):
        """Return the record's field names, minus the timestamp columns.

        date_created and last_updated are never written by insert() or
        update().  On insert they take the column DEFAULT set in create().
        """
        dont_store = ['date_created', 'last_updated']
        return [field for field in record.fields.keys() if field not in dont_store]

    def exists(self):
        """Return True if this instance's table is listed by get_tables().

        Both the bare and the double-quoted spelling of the schema-qualified
        name are accepted, since get_tables() may report either.
        """
        tables = self.cnx.get_tables()
        candidates = ('public.' + self.tablename,
                      'public."' + self.tablename + '"')
        return any(name in tables for name in candidates)

    def create(self):
        """(Re)create the table and its indexes.

        Any existing table with the same name is dropped first, so every
        record stored in it is lost.
        """
        columns = [
            "key text",
            "hrn text",
            "gid text",
            "type text",
            "pointer integer",
            "date_created timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP",
            "last_updated timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP",
        ]
        querystr = "CREATE TABLE %s (%s);" % (self.tablename, ", ".join(columns))
        template = "CREATE INDEX %s_%s_idx ON %s (%s);"
        indexes = [template % (self.tablename, field, self.tablename, field)
                   for field in ['key', 'hrn', 'type', 'pointer']]

        # drop() already copes with servers lacking DROP TABLE IF EXISTS.
        self.drop()

        self.cnx.query(querystr)
        for index in indexes:
            self.cnx.query(index)

    def remove(self, record):
        """Delete the row(s) whose key column equals ``record.get_key()``."""
        query_str = "DELETE FROM " + self.tablename + " WHERE key = " + \
            GeniTable._quote(record.get_key())
        self.cnx.query(query_str)

    def insert(self, record):
        """Insert ``record`` as a new row.

        The 'key' and 'pointer' columns are always written first, followed
        by the record's own fields (minus the timestamp columns).  Values
        come from ``record.get_field_value_strings`` and are presumably
        already SQL-quoted.  TODO: confirm against the record classes.
        """
        fieldnames = ["key", "pointer"] + GeniTable._storable_fields(record)
        fieldvals = record.get_field_value_strings(fieldnames)
        query_str = "INSERT INTO " + self.tablename + \
            "(" + ",".join(fieldnames) + ") " + \
            "VALUES(" + ",".join(fieldvals) + ")"
        self.cnx.query(query_str)

    def update(self, record):
        """Overwrite the stored row whose key equals ``record.get_key()``.

        Every field of the record except date_created and last_updated is
        written.  Values come from ``record.get_field_value_string`` and are
        presumably already SQL-quoted.  TODO: confirm.

        NOTE(review): last_updated is never refreshed here, so it keeps the
        row's insertion time.  Consider adding
        "last_updated = CURRENT_TIMESTAMP" if it is meant to track updates.
        """
        pairs = [field + " = " + record.get_field_value_string(field)
                 for field in GeniTable._storable_fields(record)]
        query_str = "UPDATE " + self.tablename + " SET " + ", ".join(pairs) + \
            " WHERE key = " + GeniTable._quote(record.get_key())
        self.cnx.query(query_str)

    def find_dict(self, type, value, searchfield):
        """Return the rows (as dicts) whose ``searchfield`` column equals ``value``.

        :param type: record type to keep, or "*" to keep rows of every type.
        :param value: value to match; converted with str() and quoted.
        :param searchfield: column name.  It is interpolated verbatim into
            the SQL, so it must come from trusted code (e.g. "hrn"), never
            from a request.
        """
        query_str = "SELECT * FROM " + self.tablename + " WHERE " + searchfield + \
            " = " + GeniTable._quote(str(value))
        rows = self.cnx.query(query_str).dictresult()
        return [row for row in rows if type == "*" or row['type'] == type]

    def find(self, type, value, searchfield):
        """Like find_dict(), but wrap each row in the record class for its type.

        Rows whose type is not authority/node/slice/user become a plain
        GeniRecord.
        """
        record_classes = {
            'authority': AuthorityRecord,
            'node': NodeRecord,
            'slice': SliceRecord,
            'user': UserRecord,
        }
        return [record_classes.get(row['type'], GeniRecord)(dict=row)
                for row in self.find_dict(type, value, searchfield)]

    def resolve_dict(self, type, hrn):
        """Return the rows (as dicts) with the given hrn; see find_dict()."""
        return self.find_dict(type, hrn, "hrn")

    def resolve(self, type, hrn):
        """Return record objects with the given hrn; see find()."""
        return self.find(type, hrn, "hrn")

    def list_dict(self):
        """Return every row in the table as a dict."""
        query_str = "SELECT * FROM " + self.tablename
        return self.cnx.query(query_str).dictresult()

    def list(self):
        """Return every row, passed through GeniRecord(dict=row).as_dict()."""
        return [GeniRecord(dict=row).as_dict() for row in self.list_dict()]

    def drop(self):
        """Drop the table.  A missing table is silently ignored.

        DROP TABLE IF EXISTS is unavailable before PostgreSQL 8.2, so if it
        fails a plain DROP TABLE is tried and its error is ignored.
        """
        try:
            self.cnx.query('DROP TABLE IF EXISTS ' + self.tablename)
        except ProgrammingError:
            # IF EXISTS doesn't exist in postgres < 8.2
            try:
                self.cnx.query('DROP TABLE ' + self.tablename)
            except ProgrammingError:
                pass

    @staticmethod
    def geni_records_purge(cninfo):
        """Drop every table whose name starts with GENI_TABLE_PREFIX.

        The prefix may appear bare, schema-qualified, or schema-qualified
        and double-quoted.  Opens its own connection from ``cninfo`` and
        always closes it afterwards.
        """
        prefix = GeniTable.GENI_TABLE_PREFIX
        prefixes = (prefix, 'public.' + prefix, 'public."' + prefix)
        cnx = GeniTable._connect(cninfo)
        try:
            for table in cnx.get_tables():
                if table.startswith(prefixes):
                    report.trace("dropping table " + table)
                    cnx.query("DROP TABLE " + table)
        finally:
            cnx.close()