Clean-up of slabpostgres.py
[sfa.git] / sfa / senslab / slabpostgres.py
1 import psycopg2
2 import psycopg2.extensions
3 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
4 # UNICODEARRAY not exported yet
5 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
6 from sfa.util.config import Config
7 from sfa.util.table import SfaTable
8 from sfa.util.sfalogging import logger
9 # allow to run sfa2wsdl if this is missing (for mac)
10 import sys
11 try: import pgdb
12 except: print >> sys.stderr, "WARNING, could not import pgdb"
13
14 #Dict holding the columns names of the table as keys
15 #and their type, used for creation of the table
16 slice_table = {'record_id_user':'integer PRIMARY KEY references sfa ON DELETE CASCADE ON UPDATE CASCADE','oar_job_id':'integer DEFAULT -1',  'record_id_slice':'integer', 'slice_hrn':'text NOT NULL'}
17
18 #Dict with all the specific senslab tables
19 tablenames_dict = {'slice': slice_table}
20
21 class SlabDB:
22     def __init__(self):
23         self.config = Config()
24         self.connection = None
25
26     def cursor(self):
27         if self.connection is None:
28             # (Re)initialize database connection
29             if psycopg2:
30                 try:
31                     # Try UNIX socket first                    
32                     self.connection = psycopg2.connect(user = 'sfa',
33                                                        password = 'sfa',
34                                                        database = 'sfa')
35                     #self.connection = psycopg2.connect(user = self.config.SFA_PLC_DB_USER,
36                                                        #password = self.config.SFA_PLC_DB_PASSWORD,
37                                                        #database = self.config.SFA_PLC_DB_NAME)
38                 except psycopg2.OperationalError:
39                     # Fall back on TCP
40                     self.connection = psycopg2.connect(user = self.config.SFA_PLC_DB_USER,
41                                                        password = self.config.SFA_PLC_DB_PASSWORD,
42                                                        database = self.config.SFA_PLC_DB_NAME,
43                                                        host = self.config.SFA_PLC_DB_HOST,
44                                                        port = self.config.SFA_PLC_DB_PORT)
45                 self.connection.set_client_encoding("UNICODE")
46             else:
47                 self.connection = pgdb.connect(user = self.config.SFA_PLC_DB_USER,
48                                                password = self.config.SFA_PLC_DB_PASSWORD,
49                                                host = "%s:%d" % (self.config.SFA_PLC_DB_HOST, self.config.SFA_PLC_DB_PORT),
50                                                database = self.config.SFA_PLC_DB_NAME)
51
52         (self.rowcount, self.description, self.lastrowid) = \
53                         (None, None, None)
54
55         return self.connection.cursor()
56         
57     #Close connection to database
58     def close(self):
59         if self.connection is not None:
60             self.connection.close()
61             self.connection = None
62             
63     def exists(self, tablename):
64         """
65         Checks if the table specified as tablename exists.
66     
67         """
68         mark = self.cursor()
69         sql = "SELECT * from pg_tables"
70         mark.execute(sql)
71         rows = mark.fetchall()
72         mark.close()
73         labels = [column[0] for column in mark.description]
74         rows = [dict(zip(labels, row)) for row in rows]
75
76         rows = filter(lambda row: row['tablename'].startswith(tablename), rows)
77         if rows:
78             return True
79         return False
80     
81     def createtable(self, tablename ):
82         """
83         Creates the specifed table. Uses the global dictionnary holding the tablenames and
84         the table schema.
85     
86         """
87         mark = self.cursor()
88         tablelist =[]
89         if tablename not in tablenames_dict:
90             logger.error("Tablename unknown - creation failed")
91             return
92             
93         T  = tablenames_dict[tablename]
94         
95         for k in T.keys(): 
96             tmp = str(k) +' ' + T[k]
97             tablelist.append(tmp)
98             
99         end_of_statement = ",".join(tablelist)
100         
101         statement = "CREATE TABLE " + tablename + " ("+ end_of_statement +");"
102      
103         #template = "CREATE INDEX %s_%s_idx ON %s (%s);"
104         #indexes = [template % ( self.tablename, field, self.tablename, field) \
105                     #for field in ['hrn', 'type', 'authority', 'peer_authority', 'pointer']]
106         # IF EXISTS doenst exist in postgres < 8.2
107         try:
108             mark.execute('DROP TABLE IF EXISTS ' + tablename +';')
109         except:
110             try:
111                 mark.execute('DROP TABLE' + tablename +';')
112             except:
113                 pass
114             
115         mark.execute(statement)
116         #for index in indexes:
117             #self.db.do(index)
118         self.connection.commit()
119         mark.close()
120         self.close()
121         return
122     
123
124
125
126     def insert(self, table, columns,values):
127         """
128         Inserts data (values) into the columns of the specified table. 
129     
130         """
131         mark = self.cursor()
132         statement = "INSERT INTO " + table + \
133                     "(" + ",".join(columns) + ") " + \
134                     "VALUES(" + ", ".join(values) + ");"
135
136         mark.execute(statement) 
137         self.connection.commit()
138         mark.close()
139         self.close()
140         return
141     
142     def insert_slab_slice(self, person_rec):
143         """
144         Inserts information about a user and his slice into the slice table. 
145     
146         """
147         sfatable = SfaTable()
148         keys = slice_table.keys()
149         
150         #returns a list of records from the sfa table (dicts)
151         #the filters specified will return only one matching record, into a list of dicts
152         #Finds the slice associated with the user (Senslabs slices  hrns contains the user hrn)
153
154         userrecord = sfatable.find({'hrn': person_rec['hrn'], 'type':'user'})
155         slicerec =  sfatable.find({'hrn': person_rec['hrn']+'_slice', 'type':'slice'})
156         if slicerec :
157             if (isinstance (userrecord, list)):
158                 userrecord = userrecord[0]
159             if (isinstance (slicerec, list)):
160                 slicerec = slicerec[0]
161                 
162             oar_dflt_jobid = -1
163             values = [ str(oar_dflt_jobid), ' \''+ str(slicerec['hrn']) + '\'', str(userrecord['record_id']), str( slicerec['record_id'])]
164     
165             self.insert('slice', keys, values)
166         else :
167             logger.error("Trying to import a not senslab slice")
168         return
169         
170         
171     def update(self, table, column_names, values, whereclause, valueclause):
172         """
173         Updates a record in a given table. 
174     
175         """
176         #Creates the values string for the update SQL command
177         if len(column_names) is not len(values):
178             return
179         else:
180             valueslist = []
181             valuesdict = dict(zip(column_names,values))
182             for k in valuesdict.keys():
183                 valuesdict[k] = str(valuesdict[k])
184                 v = ' \''+ str(k) + '\''+ '='+' \''+ valuesdict[k]+'\''
185                 valueslist.append(v)
186                 
187         statement = "UPDATE %s SET %s WHERE %s = %s" % \
188                     (table, ", ".join(valueslist), whereclause, valueclause)
189
190         mark = self.cursor()
191         mark.execute(statement) 
192         self.connection.commit()
193         mark.close()
194         self.close()
195
196         return
197
198     def update_senslab_slice(self, slice_rec):
199         sfatable = SfaTable()
200         userhrn = slice_rec['hrn'].strip('_slice')
201         userrecord = sfatable.find({'hrn': userhrn, 'type':'user'})
202         if (isinstance (userrecord, list)):
203                 userrecord = userrecord[0]
204         columns = [ 'record_user_id', 'oar_job_id']
205         values = [slice_rec['record_user_id'],slice_rec['oar_job_id']]
206         self.update('slice',columns, values,'record_slice_id', slice_rec['record_slice_id'])
207         return 
208         
209        
210