do not depend on types.StringTypes anymore
[sfa.git] / sfa / dummy / dummyslices.py
1 import time
2 from collections import defaultdict
3
4 from sfa.util.sfatime import utcparse, datetime_to_epoch
5 from sfa.util.sfalogging import logger
6 from sfa.util.xrn import Xrn, get_leaf, get_authority, urn_to_hrn
7
8 from sfa.rspecs.rspec import RSpec
9 from sfa.storage.model import SliverAllocation
10
11 from sfa.dummy.dummyxrn import DummyXrn, hrn_to_dummy_slicename
12
# Upper bound used to clamp slice 'expires' values (largest signed 32-bit int).
# Written without the Python 2-only ``L`` suffix so the module also parses on
# Python 3; on Python 2, 2**31 is promoted to a long automatically.
MAXINT = 2**31 - 1
14
class DummySlices:
    """Slice-level operations of the SFA dummy driver.

    Every testbed interaction goes through ``self.driver.shell`` (GetSlices,
    GetNodes, GetUsers, AddSlice, UpdateSlice, AddSliceToNodes, ...).
    """

    def __init__(self, driver):
        # The driver must expose ``shell`` (testbed calls), ``hrn`` (used to
        # build sliver HRNs) and ``api`` (for the allocation DB session).
        self.driver = driver

    @staticmethod
    def _component_id_to_hostname(component_id):
        """Derive a node hostname from a node component URN.

        Takes the leaf of the URN's HRN and unescapes its dots.
        NOTE(review): assumes dummy node URNs carry the hostname as the
        (dot-escaped) leaf, as ``xrn_to_hostname`` does in other SFA
        drivers - confirm against sfa.dummy.dummyxrn.
        """
        hrn, _ = urn_to_hrn(component_id)
        return get_leaf(hrn).replace('\\.', '.')

    def get_slivers(self, xrn, node=None):
        """Describe the slivers of the slice identified by *xrn*.

        :param xrn: slice xrn; converted to an HRN, then to a dummy slice name.
        :param node: unused; kept for interface compatibility.
        :returns: a list of dicts with keys 'hrn', 'name', 'slice_id',
            'expires' and 'keys' (de-duplicated keys of the slice's users).
            Empty when the testbed has no such slice.
        """
        hrn, _ = urn_to_hrn(xrn)
        slice_name = hrn_to_dummy_slicename(hrn)

        slices = self.driver.shell.GetSlices({'slice_name': slice_name})
        if not slices:
            return []

        # Gather the keys of every user attached to the slice, dropping
        # duplicates while keeping first-seen order.
        user_ids = slices[0]['user_ids']
        users = self.driver.shell.GetUsers({'user_id': user_ids})
        all_keys = []
        seen = set()
        for user in users:
            for key in user.get('keys', []):
                if key not in seen:
                    seen.add(key)
                    all_keys.append(key)

        slivers = []
        for record in slices:
            # Sanity clamp; technically an invariant the testbed should hold.
            if record['expires'] > MAXINT:
                record['expires'] = MAXINT
            slivers.append({
                'hrn': hrn,
                # Slices are keyed by 'slice_name' everywhere else in this
                # class; fall back to 'name' for records that use it.
                'name': record.get('slice_name', record.get('name')),
                'slice_id': record['slice_id'],
                'expires': record['expires'],
                'keys': list(all_keys),
            })
        return slivers

    def verify_slice_nodes(self, slice_urn, slice, rspec_nodes):
        """Make the slice's node set match the nodes requested in an RSpec.

        For each RSpec node the hostname is taken from 'component_name', or
        derived from 'component_id' when that is absent. The slice is added
        to requested nodes it is not on yet and removed from nodes that were
        not requested. A SliverAllocation record in state 'geni_allocated' is
        then synced for every resulting node that was requested.

        :param slice_urn: URN stored on the allocation records.
        :param slice: slice dict; must contain 'slice_id' and 'slice_name'.
            'node_ids' is set to [] when missing.
        :param rspec_nodes: iterable of dict-like RSpec node descriptions.
        :returns: node records attached to the slice after the update.
        """
        # hostname -> identifiers supplied by the request
        requested = {}
        for rspec_node in rspec_nodes:
            client_id = rspec_node.get('client_id')
            component_id = rspec_node.get('component_id')
            if component_id:
                component_id = component_id.strip()
            hostname = rspec_node.get('component_name')
            if hostname:
                hostname = hostname.strip()
            elif component_id:
                hostname = self._component_id_to_hostname(component_id)
            if hostname:
                requested[hostname] = {'client_id': client_id,
                                       'component_id': component_id}

        requested_node_ids = [n['node_id'] for n in self.driver.shell.GetNodes()
                              if n['hostname'] in requested]

        if 'node_ids' not in slice:
            slice['node_ids'] = []
        current_nodes = self.driver.shell.GetNodes({'node_ids': slice['node_ids']})
        current_node_ids = [n['node_id'] for n in current_nodes]

        added_nodes = list(set(requested_node_ids).difference(current_node_ids))
        deleted_nodes = list(set(current_node_ids).difference(requested_node_ids))

        # Separate attempts so a failed addition does not prevent removals.
        try:
            self.driver.shell.AddSliceToNodes({'slice_id': slice['slice_id'],
                                               'node_ids': added_nodes})
        except Exception:
            logger.log_exc('Failed to add slice to nodes')
        try:
            self.driver.shell.DeleteSliceFromNodes({'slice_id': slice['slice_id'],
                                                    'node_ids': deleted_nodes})
        except Exception:
            logger.log_exc('Failed to remove slice from nodes')

        slices = self.driver.shell.GetSlices({'slice_name': slice['slice_name']})
        resulting_nodes = self.driver.shell.GetNodes({'node_ids': slices[0]['node_ids']})

        # update sliver allocations
        for node in resulting_nodes:
            ids = requested.get(node['hostname'])
            if ids is None:
                # Still attached but not part of this request (e.g. its
                # removal failed above): nothing to allocate.
                continue
            sliver_hrn = '%s.%s-%s' % (self.driver.hrn, slice['slice_id'], node['node_id'])
            sliver_id = Xrn(sliver_hrn, type='sliver').urn
            record = SliverAllocation(sliver_id=sliver_id,
                                      client_id=ids['client_id'],
                                      component_id=ids['component_id'],
                                      slice_urn=slice_urn,
                                      allocation_state='geni_allocated')
            record.sync(self.driver.api.dbsession())
        return resulting_nodes

    def verify_slice(self, slice_hrn, slice_record, expiration, options=None):
        """Return the testbed slice for *slice_hrn*, creating it when absent.

        :param slice_hrn: HRN of the slice; mapped with hrn_to_dummy_slicename.
        :param slice_record: SFA slice record. When it has 'expires' and the
            slice already exists, that expiry (as epoch) is compared to the
            stored one.
        :param expiration: value written to the slice's 'expires' field when
            the requested expiry differs from the stored one.
        :param options: unused.
        :returns: the slice dict. A newly created slice holds 'slice_name',
            'slice_id' (returned by AddSlice) and empty 'node_ids'/'user_ids'.
        """
        slicename = hrn_to_dummy_slicename(slice_hrn)
        slices = self.driver.shell.GetSlices({'slice_name': slicename})
        if slices:
            existing = slices[0]
            if slice_record and slice_record.get('expires'):
                requested_expires = int(datetime_to_epoch(utcparse(slice_record['expires'])))
                if requested_expires and existing['expires'] != requested_expires:
                    # NOTE(review): writes `expiration`, not the compared
                    # `requested_expires`; confirm the caller passes a value
                    # in the format the testbed stores.
                    self.driver.shell.UpdateSlice({'slice_id': existing['slice_id'],
                                                   'fields': {'expires': expiration}})
            return existing

        new_slice = {'slice_name': slicename}
        new_slice['slice_id'] = self.driver.shell.AddSlice(new_slice)
        new_slice['node_ids'] = []
        new_slice['user_ids'] = []
        return new_slice

    def verify_users(self, slice_hrn, slice_record, users, options=None):
        """Register requested users that are not yet attached to the slice.

        Each entry of *users* is normalised in place: 'urn' and 'email' are
        lower-cased and 'username' is set to the leaf of the user's HRN.

        :param slice_hrn: HRN of the slice (must already exist on the testbed).
        :param slice_record: unused.
        :param users: list of user dicts, each with at least 'urn'.
        :param options: unused.
        """
        slice_name = hrn_to_dummy_slicename(slice_hrn)

        users_by_email = {}
        for user in users:
            user['urn'] = user['urn'].lower()
            hrn, _ = urn_to_hrn(user['urn'])
            user['username'] = get_leaf(hrn)
            if 'email' in user:
                user['email'] = user['email'].lower()
                users_by_email[user['email']] = user

        existing_users = self.driver.shell.GetUsers()
        slices = self.driver.shell.GetSlices({'slice_name': slice_name})
        slice_user_ids = slices[0]['user_ids'] if slices else []
        existing_slice_users_by_email = {}
        for user in existing_users:
            if user['user_id'] in slice_user_ids:
                existing_slice_users_by_email[user['email']] = user

        # TODO(review): users attached to the slice but no longer requested
        # are not removed, and created users are not attached to the slice.
        for email in set(users_by_email).difference(existing_slice_users_by_email):
            user = users_by_email[email]
            try:
                # NOTE(review): field names assumed from the dummy testbed's
                # user records - confirm against its AddUser.
                self.driver.shell.AddUser({'user_name': user['username'],
                                           'email': email,
                                           'keys': user.get('keys', [])})
            except Exception:
                logger.log_exc('Failed to add user %s' % email)

    def verify_keys(self, old_users, new_users, options=None):
        """Upload keys requested for *new_users* and optionally prune old ones.

        Every key of a new user that none of *old_users* already has is sent
        to AddUserKey. When ``options['append']`` is falsy, keys held by old
        users but not requested by any new user are deleted with DeleteKey.
        Failures are logged and do not stop processing.

        :param old_users: user dicts whose 'keys' list the current keys.
        :param new_users: user dicts with 'user_id' and optional 'keys'.
        :param options: dict; 'append' defaults to True.
        """
        if options is None:
            options = {}

        existing_keys = set()
        for user in old_users:
            existing_keys.update(user.get('keys', []))

        requested_keys = []
        for user in new_users:
            for key in user.get('keys', []):
                requested_keys.append(key)
                if key not in existing_keys:
                    try:
                        self.driver.shell.AddUserKey({'user_id': user['user_id'], 'key': key})
                    except Exception:
                        logger.log_exc('Failed to add key for user %s' % user.get('user_id'))

        # remove old keys (only if we are not appending)
        if not options.get('append', True):
            for key in existing_keys.difference(requested_keys):
                try:
                    self.driver.shell.DeleteKey({'key': key})
                except Exception:
                    logger.log_exc('Failed to delete key')
200