Added support for new slice table for Senslab.
[sfa.git] / sfa / senslab / slabpostgres.py
1 ###########################################################################
2 #    Copyright (C) 2011 by                                       
3 #    <savakian@sfa2.grenoble.senslab.info>                                                             
4 #
5 # Copyright: See COPYING file that comes with this distribution
6 #
7 ###########################################################################
8 import psycopg2
9 import psycopg2.extensions
10 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
11 # UNICODEARRAY not exported yet
12 psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY)
13 from sfa.util.config import Config
14 from sfa.util.table import SfaTable
15 # allow to run sfa2wsdl if this is missing (for mac)
16 import sys
17 try: import pgdb
18 except: print >> sys.stderr, "WARNING, could not import pgdb"
19
20 slice_table = {'oar_job_id':'integer DEFAULT -1', 'record_id_user':'integer PRIMARY KEY references sfa ON DELETE CASCADE ON UPDATE CASCADE', 'record_id_slice':'integer', 'slice_hrn':'text NOT NULL'}
21 tablenames_dict = {'slice': slice_table}
22
23 class SlabDB:
24     def __init__(self):
25         self.config = Config()
26         self.debug = False
27
28         self.connection = None
29
30     #@handle_exception
31     def cursor(self):
32         if self.connection is None:
33             # (Re)initialize database connection
34             if psycopg2:
35                 try:
36                     # Try UNIX socket first                    
37                     self.connection = psycopg2.connect(user = 'sfa',
38                                                        password = 'sfa',
39                                                        database = 'sfa')
40                     #self.connection = psycopg2.connect(user = self.config.SFA_PLC_DB_USER,
41                                                        #password = self.config.SFA_PLC_DB_PASSWORD,
42                                                        #database = self.config.SFA_PLC_DB_NAME)
43                 except psycopg2.OperationalError:
44                     # Fall back on TCP
45                     self.connection = psycopg2.connect(user = self.config.SFA_PLC_DB_USER,
46                                                        password = self.config.SFA_PLC_DB_PASSWORD,
47                                                        database = self.config.SFA_PLC_DB_NAME,
48                                                        host = self.config.SFA_PLC_DB_HOST,
49                                                        port = self.config.SFA_PLC_DB_PORT)
50                 self.connection.set_client_encoding("UNICODE")
51             else:
52                 self.connection = pgdb.connect(user = self.config.SFA_PLC_DB_USER,
53                                                password = self.config.SFA_PLC_DB_PASSWORD,
54                                                host = "%s:%d" % (self.config.SFA_PLC_DB_HOST, self.config.SFA_PLC_DB_PORT),
55                                                database = self.config.SFA_PLC_DB_NAME)
56
57         (self.rowcount, self.description, self.lastrowid) = \
58                         (None, None, None)
59
60         return self.connection.cursor()
61         
62     def close(self):
63         if self.connection is not None:
64             self.connection.close()
65             self.connection = None
66             
67     def exists(self, tablename):
68         mark = self.cursor()
69         sql = "SELECT * from pg_tables"
70         mark.execute(sql)
71         rows = mark.fetchall()
72         mark.close()
73         labels = [column[0] for column in mark.description]
74         rows = [dict(zip(labels, row)) for row in rows]
75
76         rows = filter(lambda row: row['tablename'].startswith(tablename), rows)
77         if rows:
78             return True
79         return False
80     
81     def createtable(self, tablename ):
82         mark = self.cursor()
83         tablelist =[]
84         T  = tablenames_dict[tablename]
85         for k in T.keys(): 
86             tmp = str(k) +' ' + T[k]
87             tablelist.append(tmp)
88         end = ",".join(tablelist)
89         
90         statement = "CREATE TABLE " + tablename + " ("+ end +");"
91      
92         #template = "CREATE INDEX %s_%s_idx ON %s (%s);"
93         #indexes = [template % ( self.tablename, field, self.tablename, field) \
94                     #for field in ['hrn', 'type', 'authority', 'peer_authority', 'pointer']]
95         # IF EXISTS doenst exist in postgres < 8.2
96         try:
97             mark.execute('DROP TABLE IF EXISTS ' + tablename +';')
98         except:
99             try:
100                 mark.execute('DROP TABLE' + tablename +';')
101             except:
102                 pass
103             
104         mark.execute(statement)
105         #for index in indexes:
106             #self.db.do(index)
107         self.connection.commit()
108         mark.close()
109         #self.connection.close()
110         self.close()
111         return
112     
113         
114     def findRecords(self,table, column, operator, string):
115         mark = self.cursor()
116     
117         statement =  'SELECT * FROM ' + table + ' WHERE ' + column + ' ' + operator + ' ' + ' \'' + string +'\''
118         mark.execute(statement) 
119         record = mark.fetchall() 
120         mark.close()
121         self.connection.close()
122         return record
123
124
125     def insert(self, table, columns,values):
126          mark = self.cursor()
127          statement = "INSERT INTO " + table + \
128                        "(" + ",".join(columns) + ") " + \
129                        "VALUES(" + ", ".join(values) + ");"
130
131          #statement = 'INSERT INTO ' + table + ' (' + columns + ') VALUES (' + values + ')' 
132          print>>sys.stderr, " \r\n insert statement", statement
133          mark.execute(statement) 
134          self.connection.commit()
135          mark.close()
136          #self.connection.close()
137          self.close()
138          return
139     
140     def insert_slice(self, person_rec):
141         sfatable = SfaTable()
142         keys = slice_table.keys()
143         
144         #returns a list of records (dicts)
145         #the filters specified will return only one matching record, into a list of dicts
146
147         userrecord = sfatable.find({'hrn': person_rec['hrn'], 'type':'user'})
148
149         slicerec =  sfatable.find({'hrn': person_rec['hrn']+'_slice', 'type':'slice'})
150         if (isinstance (userrecord, list)):
151             userrecord = userrecord[0]
152         if (isinstance (slicerec, list)):
153             slicerec = slicerec[0]
154         
155         values = [ '-1', ' \''+ str(slicerec['hrn']) + '\'', str(userrecord['record_id']), str( slicerec['record_id'])]
156
157         self.insert('slice', keys, values)
158         return
159         
160     def update(self, table, column_names, values, whereclause, valueclause):
161
162         #Creates the values string for the update SQL command
163         if len(column_names) is not len(values):
164             return
165         else:
166             valueslist = []
167             valuesdict = dict(zip(column_names,values))
168             for k in valuesdict.keys():
169                 valuesdict[k] = str(valuesdict[k])
170                 v = ' \''+ str(k) + '\''+ '='+' \''+ valuesdict[k]+'\''
171                 valueslist.append(v)
172                 
173         statement = "UPDATE %s SET %s WHERE %s = %s" % \
174                     (table, ", ".join(valueslist), whereclause, valueclause)
175         print >>sys.stderr, "\r\n \r\n \t SLABPOSTGRES.PY UPDATE statement    ", statement
176         mark = self.cursor()
177         mark.execute(statement) 
178         self.connection.commit()
179         mark.close()
180         self.close()
181         #self.connection.close()
182         return
183
184     def update_slice(self, slice_rec):
185         sfatable = SfaTable()
186         userhrn = slice_rec['hrn'].strip('_slice')
187         userrecords = sfatable.find({'hrn': userhrn, 'type':'user'})
188         columns = [ 'record_user_id', 'oar_job_id']
189         values = [slice_rec['record_user_id'],slice_rec['oar_job_id']]
190         self.update('slice',columns, values,'record_slice_id', slice_rec['record_slice_id'])
191         return 
192         
193        
194