Merge branch 'master' into senslab2
[sfa.git] / sfa / senslab / slabpostgres.py
1 import psycopg2
2 import psycopg2.extensions
3 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
4 # UNICODEARRAY not exported yet
5 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
6 from sfa.util.config import Config
7 from sfa.util.table import SfaTable
8 from sfa.util.sfalogging import logger
9 # allow to run sfa2wsdl if this is missing (for mac)
10 import sys
11 try: import pgdb
12 except: print >> sys.stderr, "WARNING, could not import pgdb"
13
14 #Dict holding the columns names of the table as keys
15 #and their type, used for creation of the table
16 slice_table = {'record_id_user':'integer PRIMARY KEY references sfa ON DELETE CASCADE ON UPDATE CASCADE','oar_job_id':'integer DEFAULT -1',  'record_id_slice':'integer', 'slice_hrn':'text NOT NULL'}
17
18 #Dict with all the specific senslab tables
19 tablenames_dict = {'slice': slice_table}
20
21 class SlabDB:
22     def __init__(self):
23         self.config = Config()
24         self.connection = None
25
26     def cursor(self):
27         if self.connection is None:
28             # (Re)initialize database connection
29             if psycopg2:
30                 try:
31                     # Try UNIX socket first                    
32                     self.connection = psycopg2.connect(user = 'sfa',
33                                                        password = 'sfa',
34                                                        database = 'sfa')
35                     #self.connection = psycopg2.connect(user = self.config.SFA_PLC_DB_USER,
36                                                        #password = self.config.SFA_PLC_DB_PASSWORD,
37                                                        #database = self.config.SFA_PLC_DB_NAME)
38                 except psycopg2.OperationalError:
39                     # Fall back on TCP
40                     self.connection = psycopg2.connect(user = self.config.SFA_PLC_DB_USER,
41                                                        password = self.config.SFA_PLC_DB_PASSWORD,
42                                                        database = self.config.SFA_PLC_DB_NAME,
43                                                        host = self.config.SFA_PLC_DB_HOST,
44                                                        port = self.config.SFA_PLC_DB_PORT)
45                 self.connection.set_client_encoding("UNICODE")
46             else:
47                 self.connection = pgdb.connect(user = self.config.SFA_PLC_DB_USER,
48                                                password = self.config.SFA_PLC_DB_PASSWORD,
49                                                host = "%s:%d" % (self.config.SFA_PLC_DB_HOST, self.config.SFA_PLC_DB_PORT),
50                                                database = self.config.SFA_PLC_DB_NAME)
51
52         return self.connection.cursor()
53         
54     #Close connection to database
55     def close(self):
56         if self.connection is not None:
57             self.connection.close()
58             self.connection = None
59             
60     def selectall(self, query,  hashref = True, key_field = None):
61         """
62         Return each row as a dictionary keyed on field name (like DBI
63         selectrow_hashref()). If key_field is specified, return rows
64         as a dictionary keyed on the specified field (like DBI
65         selectall_hashref()).
66
67         """
68         cursor = self.cursor()
69         cursor.execute(query)
70         rows = cursor.fetchall()
71         cursor.close()
72         self.connection.commit()
73
74         if hashref or key_field is not None:
75             # Return each row as a dictionary keyed on field name
76             # (like DBI selectrow_hashref()).
77             labels = [column[0] for column in cursor.description]
78             rows = [dict(zip(labels, row)) for row in rows]
79
80         if key_field is not None and key_field in labels:
81             # Return rows as a dictionary keyed on the specified field
82             # (like DBI selectall_hashref()).
83             return dict([(row[key_field], row) for row in rows])
84         else:
85             return rows
86         
87         
88     def exists(self, tablename):
89         """
90         Checks if the table specified as tablename exists.
91     
92         """
93         #mark = self.cursor()
94         sql = "SELECT * from pg_tables"
95         #mark.execute(sql)
96         #rows = mark.fetchall()
97         #mark.close()
98         #labels = [column[0] for column in mark.description]
99         #rows = [dict(zip(labels, row)) for row in rows]
100         rows = self.selectall(sql)
101         rows = filter(lambda row: row['tablename'].startswith(tablename), rows)
102         if rows:
103             return True
104         return False
105     
106     def createtable(self, tablename ):
107         """
108         Creates the specifed table. Uses the global dictionnary holding the tablenames and
109         the table schema.
110     
111         """
112         mark = self.cursor()
113         tablelist =[]
114         if tablename not in tablenames_dict:
115             logger.error("Tablename unknown - creation failed")
116             return
117             
118         T  = tablenames_dict[tablename]
119         
120         for k in T.keys(): 
121             tmp = str(k) +' ' + T[k]
122             tablelist.append(tmp)
123             
124         end_of_statement = ",".join(tablelist)
125         
126         statement = "CREATE TABLE " + tablename + " ("+ end_of_statement +");"
127      
128         #template = "CREATE INDEX %s_%s_idx ON %s (%s);"
129         #indexes = [template % ( self.tablename, field, self.tablename, field) \
130                     #for field in ['hrn', 'type', 'authority', 'peer_authority', 'pointer']]
131         # IF EXISTS doenst exist in postgres < 8.2
132         try:
133             mark.execute('DROP TABLE IF EXISTS ' + tablename +';')
134         except:
135             try:
136                 mark.execute('DROP TABLE' + tablename +';')
137             except:
138                 pass
139             
140         mark.execute(statement)
141         #for index in indexes:
142             #self.db.do(index)
143         self.connection.commit()
144         mark.close()
145         self.close()
146         return
147     
148
149
150
151     def insert(self, table, columns,values):
152         """
153         Inserts data (values) into the columns of the specified table. 
154     
155         """
156         mark = self.cursor()
157         statement = "INSERT INTO " + table + \
158                     "(" + ",".join(columns) + ") " + \
159                     "VALUES(" + ", ".join(values) + ");"
160
161         mark.execute(statement) 
162         self.connection.commit()
163         mark.close()
164         self.close()
165         return
166     
167     def insert_slab_slice(self, person_rec):
168         """
169         Inserts information about a user and his slice into the slice table. 
170     
171         """
172         sfatable = SfaTable()
173         keys = slice_table.keys()
174         
175         #returns a list of records from the sfa table (dicts)
176         #the filters specified will return only one matching record, into a list of dicts
177         #Finds the slice associated with the user (Senslabs slices  hrns contains the user hrn)
178
179         userrecord = sfatable.find({'hrn': person_rec['hrn'], 'type':'user'})
180         slicerec =  sfatable.find({'hrn': person_rec['hrn']+'_slice', 'type':'slice'})
181         if slicerec :
182             if (isinstance (userrecord, list)):
183                 userrecord = userrecord[0]
184             if (isinstance (slicerec, list)):
185                 slicerec = slicerec[0]
186                 
187             oar_dflt_jobid = -1
188             values = [ str(oar_dflt_jobid), ' \''+ str(slicerec['hrn']) + '\'', str(userrecord['record_id']), str( slicerec['record_id'])]
189     
190             self.insert('slice', keys, values)
191         else :
192             logger.error("Trying to import a not senslab slice")
193         return
194         
195         
196     def update(self, table, column_names, values, whereclause, valueclause):
197         """
198         Updates a record in a given table. 
199     
200         """
201         #Creates the values string for the update SQL command
202         vclause = valueclause
203         if len(column_names) is not len(values):
204             return
205         else:
206             valueslist = []
207             valuesdict = dict(zip(column_names,values))
208             for k in valuesdict.keys():
209                 valuesdict[k] = str(valuesdict[k])
210                 #v = ' \''+ str(k) + '\''+ '='+' \''+ valuesdict[k]+'\''
211                 v = str(k) + '=' + valuesdict[k]
212                 valueslist.append(v)
213         if isinstance(vclause,str):
214             vclause = '\''+ vclause + '\''
215         statement = "UPDATE %s SET %s WHERE %s = %s" % \
216                     (table, ", ".join(valueslist), whereclause, vclause)
217         print>>sys.stderr,"\r\n \r\n SLABPOSTGRES.PY update statement %s valuesdict %s valueslist %s" %(statement,valuesdict,valueslist)
218         mark = self.cursor()
219         mark.execute(statement) 
220         self.connection.commit()
221         mark.close()
222         self.close()
223
224         return
225
226     def update_senslab_slice(self, slice_rec):
227         sfatable = SfaTable()
228         userhrn = slice_rec['hrn'].strip('_slice')
229         userrecord = sfatable.find({'hrn': userhrn, 'type':'user'})
230         if (isinstance (userrecord, list)):
231                 userrecord = userrecord[0]
232         columns = [ 'record_user_id', 'oar_job_id']
233         values = [slice_rec['record_user_id'],slice_rec['oar_job_id']]
234         self.update('slice',columns, values,'record_slice_id', slice_rec['record_slice_id'])
235         return 
236         
237        
238     def find(self, tablename,record_filter = None, columns=None):
239         if not columns:
240             columns = "*"
241         else:
242             columns = ",".join(columns)
243         sql = "SELECT %s FROM %s WHERE True " % (columns, tablename)
244         
245         #if isinstance(record_filter, (list, tuple, set)):
246             #ints = filter(lambda x: isinstance(x, (int, long)), record_filter)
247             #strs = filter(lambda x: isinstance(x, StringTypes), record_filter)
248             #record_filter = Filter(SfaRecord.all_fields, {'record_id': ints, 'hrn': strs})
249             #sql += "AND (%s) %s " % record_filter.sql("OR") 
250         #elif isinstance(record_filter, dict):
251             #record_filter = Filter(SfaRecord.all_fields, record_filter)        
252             #sql += " AND (%s) %s" % record_filter.sql("AND")
253         #elif isinstance(record_filter, StringTypes):
254             #record_filter = Filter(SfaRecord.all_fields, {'hrn':[record_filter]})    
255             #sql += " AND (%s) %s" % record_filter.sql("AND")
256         #elif isinstance(record_filter, int):
257             #record_filter = Filter(SfaRecord.all_fields, {'record_id':[record_filter]})    
258             #sql += " AND (%s) %s" % record_filter.sql("AND")
259        
260         if isinstance(record_filter, dict):
261             for k in record_filter.keys():
262                 sql += "AND "+' \''+ str(k) + '\''+ '='+' \''+ str(record_filter[k])+'\''
263             
264         elif isinstance(record_filter, str):
265             sql += "AND slice_hrn ="+ ' \''+record_filter+'\''
266
267         #elif isinstance(record_filter, int):
268             #record_filter = Filter(SfaRecord.all_fields, {'record_id':[record_filter]})    
269             #sql += " AND (%s) %s" % record_filter.sql("AND")
270         sql +=  ";"
271         print>>sys.stderr, " \r\n \r\n \t SLABPOSTGRES.PY find : sql %s record_filter  %s %s" %(sql, record_filter , type(record_filter))
272         results = self.selectall(sql)
273         if isinstance(results, dict):
274             results = [results]
275         return results
276