Merge branch 'master' into senslab2
[sfa.git] / sfa / senslab / slabpostgres.py
1 import psycopg2
2 import psycopg2.extensions
3 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
4 # UNICODEARRAY not exported yet
5 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
6 from sfa.util.config import Config
7 from sfa.storage.table import SfaTable
8 from sfa.util.sfalogging import logger
9 # allow to run sfa2wsdl if this is missing (for mac)
10 import sys
11 try: import pgdb
12 except: print >> sys.stderr, "WARNING, could not import pgdb"
13
14 #Dict holding the columns names of the table as keys
15 #and their type, used for creation of the table
16 slice_table = {'record_id_user':'integer PRIMARY KEY references X ON DELETE CASCADE ON UPDATE CASCADE','oar_job_id':'integer DEFAULT -1',  'record_id_slice':'integer', 'slice_hrn':'text NOT NULL'}
17
18 #Dict with all the specific senslab tables
19 tablenames_dict = {'slice': slice_table}
20
21 class SlabDB:
22     def __init__(self):
23         self.config = Config()
24         self.connection = None
25         self.init_create_query()
26         
27     def init_create_query(self):
28         sfatable = SfaTable()
29         slice_table['record_id_user'] =  slice_table['record_id_user'].replace("X",sfatable.tablename)
30         print sys.stderr, " \r\n \r\n slice_table %s ",slice_table 
31         
32     def cursor(self):
33         if self.connection is None:
34             # (Re)initialize database connection
35             if psycopg2:
36                 try:
37                     # Try UNIX socket first                    
38                     self.connection = psycopg2.connect(user = 'sfa',
39                                                        password = 'sfa',
40                                                        database = 'sfa')
41                     #self.connection = psycopg2.connect(user = self.config.SFA_PLC_DB_USER,
42                                                        #password = self.config.SFA_PLC_DB_PASSWORD,
43                                                        #database = self.config.SFA_PLC_DB_NAME)
44                 except psycopg2.OperationalError:
45                     # Fall back on TCP
46                     self.connection = psycopg2.connect(user = self.config.SFA_PLC_DB_USER,
47                                                        password = self.config.SFA_PLC_DB_PASSWORD,
48                                                        database = self.config.SFA_PLC_DB_NAME,
49                                                        host = self.config.SFA_PLC_DB_HOST,
50                                                        port = self.config.SFA_PLC_DB_PORT)
51                 self.connection.set_client_encoding("UNICODE")
52             else:
53                 self.connection = pgdb.connect(user = self.config.SFA_PLC_DB_USER,
54                                                password = self.config.SFA_PLC_DB_PASSWORD,
55                                                host = "%s:%d" % (self.config.SFA_PLC_DB_HOST, self.config.SFA_PLC_DB_PORT),
56                                                database = self.config.SFA_PLC_DB_NAME)
57
58         return self.connection.cursor()
59         
60     #Close connection to database
61     def close(self):
62         if self.connection is not None:
63             self.connection.close()
64             self.connection = None
65             
66     def selectall(self, query,  hashref = True, key_field = None):
67         """
68         Return each row as a dictionary keyed on field name (like DBI
69         selectrow_hashref()). If key_field is specified, return rows
70         as a dictionary keyed on the specified field (like DBI
71         selectall_hashref()).
72
73         """
74         cursor = self.cursor()
75         cursor.execute(query)
76         rows = cursor.fetchall()
77         cursor.close()
78         self.connection.commit()
79
80         if hashref or key_field is not None:
81             # Return each row as a dictionary keyed on field name
82             # (like DBI selectrow_hashref()).
83             labels = [column[0] for column in cursor.description]
84             rows = [dict(zip(labels, row)) for row in rows]
85
86         if key_field is not None and key_field in labels:
87             # Return rows as a dictionary keyed on the specified field
88             # (like DBI selectall_hashref()).
89             return dict([(row[key_field], row) for row in rows])
90         else:
91             return rows
92         
93         
94     def exists(self, tablename):
95         """
96         Checks if the table specified as tablename exists.
97     
98         """
99         #mark = self.cursor()
100         sql = "SELECT * from pg_tables"
101         #mark.execute(sql)
102         #rows = mark.fetchall()
103         #mark.close()
104         #labels = [column[0] for column in mark.description]
105         #rows = [dict(zip(labels, row)) for row in rows]
106         rows = self.selectall(sql)
107         rows = filter(lambda row: row['tablename'].startswith(tablename), rows)
108         if rows:
109             return True
110         return False
111     
112     def createtable(self, tablename ):
113         """
114         Creates the specifed table. Uses the global dictionnary holding the tablenames and
115         the table schema.
116     
117         """
118         mark = self.cursor()
119         tablelist =[]
120         if tablename not in tablenames_dict:
121             logger.error("Tablename unknown - creation failed")
122             return
123             
124         T  = tablenames_dict[tablename]
125         
126         for k in T.keys(): 
127             tmp = str(k) +' ' + T[k]
128             tablelist.append(tmp)
129             
130         end_of_statement = ",".join(tablelist)
131         
132         statement = "CREATE TABLE " + tablename + " ("+ end_of_statement +");"
133      
134         #template = "CREATE INDEX %s_%s_idx ON %s (%s);"
135         #indexes = [template % ( self.tablename, field, self.tablename, field) \
136                     #for field in ['hrn', 'type', 'authority', 'peer_authority', 'pointer']]
137         # IF EXISTS doenst exist in postgres < 8.2
138         try:
139             mark.execute('DROP TABLE IF EXISTS ' + tablename +';')
140         except:
141             try:
142                 mark.execute('DROP TABLE' + tablename +';')
143             except:
144                 pass
145             
146         mark.execute(statement)
147         #for index in indexes:
148             #self.db.do(index)
149         self.connection.commit()
150         mark.close()
151         self.close()
152         return
153     
154
155
156
157     def insert(self, table, columns,values):
158         """
159         Inserts data (values) into the columns of the specified table. 
160     
161         """
162         mark = self.cursor()
163         statement = "INSERT INTO " + table + \
164                     "(" + ",".join(columns) + ") " + \
165                     "VALUES(" + ", ".join(values) + ");"
166
167         mark.execute(statement) 
168         self.connection.commit()
169         mark.close()
170         self.close()
171         return
172     
173     def insert_slab_slice(self, person_rec):
174         """
175         Inserts information about a user and his slice into the slice table. 
176     
177         """
178         sfatable = SfaTable()
179         keys = slice_table.keys()
180         
181         #returns a list of records from the sfa table (dicts)
182         #the filters specified will return only one matching record, into a list of dicts
183         #Finds the slice associated with the user (Senslabs slices  hrns contains the user hrn)
184
185         userrecord = sfatable.find({'hrn': person_rec['hrn'], 'type':'user'})
186         slicerec =  sfatable.find({'hrn': person_rec['hrn']+'_slice', 'type':'slice'})
187         if slicerec :
188             if (isinstance (userrecord, list)):
189                 userrecord = userrecord[0]
190             if (isinstance (slicerec, list)):
191                 slicerec = slicerec[0]
192                 
193             oar_dflt_jobid = -1
194             values = [ str(oar_dflt_jobid), ' \''+ str(slicerec['hrn']) + '\'', str(userrecord['record_id']), str( slicerec['record_id'])]
195     
196             self.insert('slice', keys, values)
197         else :
198             logger.error("Trying to import a not senslab slice")
199         return
200         
201         
202     def update(self, table, column_names, values, whereclause, valueclause):
203         """
204         Updates a record in a given table. 
205     
206         """
207         #Creates the values string for the update SQL command
208         vclause = valueclause
209         if len(column_names) is not len(values):
210             return
211         else:
212             valueslist = []
213             valuesdict = dict(zip(column_names,values))
214             for k in valuesdict.keys():
215                 valuesdict[k] = str(valuesdict[k])
216                 #v = ' \''+ str(k) + '\''+ '='+' \''+ valuesdict[k]+'\''
217                 v = str(k) + '=' + valuesdict[k]
218                 valueslist.append(v)
219         if isinstance(vclause,str):
220             vclause = '\''+ vclause + '\''
221         statement = "UPDATE %s SET %s WHERE %s = %s" % \
222                     (table, ", ".join(valueslist), whereclause, vclause)
223         print>>sys.stderr,"\r\n \r\n SLABPOSTGRES.PY update statement %s valuesdict %s valueslist %s" %(statement,valuesdict,valueslist)
224         mark = self.cursor()
225         mark.execute(statement) 
226         self.connection.commit()
227         mark.close()
228         self.close()
229
230         return
231
232     def update_senslab_slice(self, slice_rec):
233         sfatable = SfaTable()
234         hrn = str(slice_rec['hrn']) 
235         userhrn = hrn.rstrip('_slice')
236         userrecord = sfatable.find({'hrn': userhrn, 'type':'user'})
237         print>>sys.stderr, " \r\n \r\n \t SLABPOSTGRES.PY  update_senslab_slice : userrecord  %s slice_rec %s userhrn %s" %( userrecord, slice_rec, userhrn)
238         if (isinstance (userrecord, list)):
239                 userrecord = userrecord[0]
240         columns = [ 'record_id_user', 'oar_job_id']
241         values = [slice_rec['record_id_user'],slice_rec['oar_job_id']]
242         self.update('slice',columns, values,'record_id_slice', slice_rec['record_id_slice'])
243         return 
244         
245        
246     def find(self, tablename,record_filter = None, columns=None):  
247         print>>sys.stderr, " \r\n \r\n \t SLABPOSTGRES.PY find :  record_filter %s %s columns %s %s" %( record_filter , type(record_filter),columns , type(columns))
248         if not columns:
249             columns = "*"
250         else:
251             columns = ",".join(columns)
252         sql = "SELECT %s FROM %s WHERE True " % (columns, tablename)
253         
254         #if isinstance(record_filter, (list, tuple, set)):
255             #ints = filter(lambda x: isinstance(x, (int, long)), record_filter)
256             #strs = filter(lambda x: isinstance(x, StringTypes), record_filter)
257             #record_filter = Filter(SfaRecord.all_fields, {'record_id': ints, 'hrn': strs})
258             #sql += "AND (%s) %s " % record_filter.sql("OR") 
259         #elif isinstance(record_filter, dict):
260             #record_filter = Filter(SfaRecord.all_fields, record_filter)        
261             #sql += " AND (%s) %s" % record_filter.sql("AND")
262         #elif isinstance(record_filter, StringTypes):
263             #record_filter = Filter(SfaRecord.all_fields, {'hrn':[record_filter]})    
264             #sql += " AND (%s) %s" % record_filter.sql("AND")
265         #elif isinstance(record_filter, int):
266             #record_filter = Filter(SfaRecord.all_fields, {'record_id':[record_filter]})    
267             #sql += " AND (%s) %s" % record_filter.sql("AND")
268        
269         if isinstance(record_filter, dict):
270             for k in record_filter.keys():
271                 #sql += "AND "+' \''+ str(k) + '\''+ '='+' \''+ str(record_filter[k])+'\''
272                 #sql += "AND "+ str(k) + '=' + str(record_filter[k])
273                 sql += "AND "+ str(k) +'='+' \''+ str(record_filter[k])+'\''
274         elif isinstance(record_filter, str):
275             sql += "AND slice_hrn ="+ ' \''+record_filter+'\''
276
277         #elif isinstance(record_filter, int):
278             #record_filter = Filter(SfaRecord.all_fields, {'record_id':[record_filter]})    
279             #sql += " AND (%s) %s" % record_filter.sql("AND")
280         sql +=  ";"
281         print>>sys.stderr, " \r\n \r\n \t SLABPOSTGRES.PY find : sql %s record_filter  %s %s" %(sql, record_filter , type(record_filter))
282         results = self.selectall(sql)
283         if isinstance(results, dict):
284             results = [results]
285         return results
286