2 from collections import defaultdict
4 from sfa.util.sfatime import utcparse, datetime_to_epoch
5 from sfa.util.sfalogging import logger
6 from sfa.util.xrn import Xrn, get_leaf, get_authority, urn_to_hrn
8 from sfa.rspecs.rspec import RSpec
9 from sfa.storage.model import SliverAllocation
11 from sfa.dummy.dummyxrn import DummyXrn, hrn_to_dummy_slicename
# DummySlices constructor (only the signature is visible in this excerpt).
# NOTE(review): the body is not shown. The methods below read
# self.driver.shell, self.driver.hrn and self.driver.api, so presumably the
# body stores `driver` as self.driver -- confirm against the full file.
17 def __init__(self, driver):
# get_slivers: look up the dummy slice named by `xrn`, gather its users and
# their keys, and build per-slice dicts carrying 'name', 'slice_id' and
# 'expires'.
#
# Args:
#   xrn:  slice URN; converted to an hrn, then to a dummy slice name.
#   node: not referenced on any visible line of this method.
# NOTE(review): this excerpt has gaps (the binding of `slice`, the
# `all_users` / `all_keys` initialisers, the loop that builds the result
# dicts and the return statement are not visible), so the exact return
# value cannot be documented from here.
20 def get_slivers(self, xrn, node=None):
21 hrn, type = urn_to_hrn(xrn)
23 slice_name = hrn_to_dummy_slicename(hrn)
25 slices = self.driver.shell.GetSlices({'slice_name': slice_name})
# NOTE(review): `slice` is bound on a line not shown here (presumably from
# `slices`); an empty GetSlices result would have to be handled there.
27 # Build up list of users and slice attributes
28 user_ids = slice['user_ids']
29 # Get user information
30 all_users_list = self.driver.shell.GetUsers({'user_id':user_ids})
# Index the fetched user records by their user_id.
32 for user in all_users_list:
33 all_users[user['user_id']] = user
35 # Build up list of keys
# NOTE(review): `.extend` only works if all_keys is a list; its initialiser
# is not visible here -- confirm it is not a set (sets have no extend()).
37 for user in all_users_list:
38 all_keys.extend(user['keys'])
43 # XXX Sanity check; though technically this should be a system invariant
44 # checked with an assertion
# Clamp the expiry to MAXINT. MAXINT is not defined on any visible line;
# presumably a module-level constant -- confirm.
45 if slice['expires'] > MAXINT: slice['expires']= MAXINT
49 'name': slice['name'],
50 'slice_id': slice['slice_id'],
51 'expires': slice['expires'],
# verify_slice_nodes: reconcile the slice's current nodes with the nodes
# requested in the RSpec (adding missing ones, removing extra ones), then
# record a 'geni_allocated' SliverAllocation for every node the slice ends
# up on.
#
# Args:
#   slice_urn:   URN stored on each SliverAllocation record.
#   slice:       slice dict; reads 'slice_id', 'slice_name' and 'node_ids'.
#   rspec_nodes: iterable of node dicts read via .get('component_name'),
#                .get('client_id') and .get('component_id').
# Returns: the node records fetched for the slice after the add/remove step.
# NOTE(review): several lines of this method (the `slivers` initialiser, the
# hostname if/else, the try/except around the shell calls) are not visible
# in this excerpt.
58 def verify_slice_nodes(self, slice_urn, slice, rspec_nodes):
61 for node in rspec_nodes:
62 hostname = node.get('component_name')
63 client_id = node.get('client_id')
# NOTE(review): .strip() raises AttributeError when a node has no
# 'component_id' (dict.get returns None).
64 component_id = node.get('component_id').strip()
66 hostname = hostname.strip()
# Presumably the fallback when no component_name was given (the controlling
# condition is not visible): derive the hostname from component_id.
# NOTE(review): xrn_to_hostname is not among the visible imports; confirm it
# is imported elsewhere in the file.
68 hostname = xrn_to_hostname(component_id)
70 slivers[hostname] = {'client_id': client_id, 'component_id': component_id}
# Translate the requested hostnames into backend node_ids.
71 all_nodes = self.driver.shell.GetNodes()
72 requested_slivers = []
73 for node in all_nodes:
74 if node['hostname'] in slivers.keys():
75 requested_slivers.append(node['node_id'])
# The body of this check is not visible (presumably it defaults
# slice['node_ids']).
77 if 'node_ids' not in slice.keys():
79 nodes = self.driver.shell.GetNodes({'node_ids': slice['node_ids']})
80 current_slivers = [node['node_id'] for node in nodes]
82 # remove nodes not in rspec
83 deleted_nodes = list(set(current_slivers).difference(requested_slivers))
85 # add nodes from rspec
86 added_nodes = list(set(requested_slivers).difference(current_slivers))
# Apply the diff. A failure is only logged below (presumably from a hidden
# except clause), not re-raised.
89 self.driver.shell.AddSliceToNodes({'slice_id': slice['slice_id'], 'node_ids': added_nodes})
90 self.driver.shell.DeleteSliceFromNodes({'slice_id': slice['slice_id'], 'node_ids': deleted_nodes})
93 logger.log_exc('Failed to add/remove slice from nodes')
# Re-read the slice so the allocation records reflect the nodes the backend
# actually holds rather than the ones requested.
95 slices = self.driver.shell.GetSlices({'slice_name': slice['slice_name']})
96 resulting_nodes = self.driver.shell.GetNodes({'node_ids': slices[0]['node_ids']})
98 # update sliver allocations
# NOTE(review): if the add/remove above failed, resulting_nodes may include
# hostnames that are not keys of `slivers`, and the lookups below would then
# raise KeyError.
99 for node in resulting_nodes:
100 client_id = slivers[node['hostname']]['client_id']
101 component_id = slivers[node['hostname']]['component_id']
# The sliver id is derived from the driver hrn, the slice_id and the node_id.
102 sliver_hrn = '%s.%s-%s' % (self.driver.hrn, slice['slice_id'], node['node_id'])
103 sliver_id = Xrn(sliver_hrn, type='sliver').urn
104 record = SliverAllocation(sliver_id=sliver_id, client_id=client_id,
105 component_id=component_id,
106 slice_urn = slice_urn,
107 allocation_state='geni_allocated')
108 record.sync(self.driver.api.dbsession())
109 return resulting_nodes
# verify_slice: look up the dummy slice for `slice_hrn`, creating it when
# needed. When slice_record carries an 'expires' value that differs from the
# slice's current 'expires', send an expiry update.
#
# Args:
#   slice_hrn:    hrn converted to the dummy slice name.
#   slice_record: dict-like record; only its 'expires' entry is read here.
#   expiration:   value sent to UpdateSlice as the new 'expires'.
#   options:      defaults to {}; not otherwise read on the visible lines.
# NOTE(review): the branch lines around the create path and the return
# statement are not visible in this excerpt.
112 def verify_slice(self, slice_hrn, slice_record, expiration, options=None):
113 if options is None: options={}
114 slicename = hrn_to_dummy_slicename(slice_hrn)
# NOTE(review): parts / login_base are not used on any visible line.
115 parts = slicename.split("_")
116 login_base = parts[0]
117 slices = self.driver.shell.GetSlices({'slice_name': slicename})
# Create path (presumably taken when GetSlices returned nothing): add the
# slice and start it with no nodes and no users.
119 slice = {'slice_name': slicename}
121 slice['slice_id'] = self.driver.shell.AddSlice(slice)
122 slice['node_ids'] = []
123 slice['user_ids'] = []
# NOTE(review): the dict built on the create path has no 'expires' key on the
# visible lines, so the comparison below would raise KeyError for a freshly
# created slice unless a hidden line sets it -- confirm.
126 if slice_record and slice_record.get('expires'):
127 requested_expires = int(datetime_to_epoch(utcparse(slice_record['expires'])))
# NOTE(review): the comparison uses requested_expires (derived from
# slice_record) but the update sends `expiration`; if the two can differ,
# the slice is set to a value other than the one compared -- confirm intent.
128 if requested_expires and slice['expires'] != requested_expires:
129 self.driver.shell.UpdateSlice( {'slice_id': slice['slice_id'], 'fields':{'expires' : expiration}})
# verify_users: normalise the requested users, compare them (by e-mail) with
# the users the backend already associates with the slice, and compute which
# e-mails should be added to or removed from the slice.
#
# Args:
#   slice_hrn:    hrn converted to the dummy slice name.
#   slice_record: not referenced on any visible line.
#   users:        requested user dicts with at least 'urn' and 'email'
#                 (presumably iterated on a line not shown).
#   options:      defaults to {}; not otherwise read on the visible lines.
# NOTE(review): the add/delete step at the end looks unfinished -- see below.
133 def verify_users(self, slice_hrn, slice_record, users, options=None):
134 if options is None: options={}
135 slice_name = hrn_to_dummy_slicename(slice_hrn)
# Normalise each requested user: lowercase urn and e-mail, take the username
# from the leaf of the hrn, and index the user by e-mail.
138 user['urn'] = user['urn'].lower()
139 hrn, type = urn_to_hrn(user['urn'])
140 username = get_leaf(hrn)
141 user['username'] = username
144 user['email'] = user['email'].lower()
145 users_by_email[user['email']] = user
147 # start building a list of existing users
148 existing_users_by_email = {}
149 existing_slice_users_by_email = {}
150 existing_users = self.driver.shell.GetUsers()
# NOTE(review): [0] raises IndexError if no slice matches slice_name.
151 existing_slice_users_ids = self.driver.shell.GetSlices({'slice_name': slice_name})[0]['user_ids']
152 for user in existing_users:
153 existing_users_by_email[user['email']] = user
154 if user['user_id'] in existing_slice_users_ids:
155 existing_slice_users_by_email[user['email']] = user
# NOTE(review): the next two lines reference `existing_slice_user_by_email`
# (singular "user"), but the dict built above is
# `existing_slice_users_by_email`. Unless the singular name is defined on a
# line not shown, this raises NameError -- very likely a typo.
157 add_users_by_email = set(users_by_email).difference(existing_slice_user_by_email)
158 delete_users_by_email = set(existing_slice_user_by_email).difference(users_by_email)
# NOTE(review): `user` here is an e-mail string. AddUser() is called with no
# arguments, so that e-mail is never passed to the backend.
# delete_users_by_email is not used on any visible line.
160 for user in add_users_by_email:
161 self.driver.shell.AddUser()
# verify_keys: push requested user keys that are not already known and,
# per the "only if we are not appending" comment below, remove existing keys
# that were not requested. The guarding condition for removal is not visible.
#
# Args:
#   old_users: existing user dicts; reads 'keys' and 'email'.
#   new_users: requested user dicts; reads 'keys' (default []) and 'user_id'.
#   options:   defaults to {}; 'append' (default True) governs key removal.
# NOTE(review): the initialisers of existing_keys, userdict, requested_keys
# and updated_users are not visible, and this method may continue past the
# end of this excerpt.
166 def verify_keys(self, old_users, new_users, options=None):
167 if options is None: options={}
# NOTE(review): append() stores each user's whole 'keys' value as a single
# element. If 'keys' is a list (get_slivers in this file extends with it),
# existing_keys becomes a list of lists. The `key_string not in
# existing_keys` test below is then always true, and set(existing_keys)
# raises TypeError (lists are unhashable). extend() is probably intended --
# confirm.
170 for user in old_users:
171 existing_keys.append(user['keys'])
# Index the existing users by e-mail (userdict is not used on any visible
# line).
173 for user in old_users:
174 userdict[user['email']] = user
# Record every requested key and push the ones not already present.
179 for user in new_users:
180 user_keys = user.get('keys', [])
181 updated_users.append(user)
182 for key_string in user_keys:
183 requested_keys.append(key_string)
184 if key_string not in existing_keys:
# NOTE(review): `key` is bound on a line not shown here (presumably from
# key_string).
187 self.driver.shell.AddUserKey({'user_id': user['user_id'], 'key':key})
191 # remove old keys (only if we are not appending)
192 append = options.get('append', True)
194 removed_keys = set(existing_keys).difference(requested_keys)
195 for key in removed_keys:
197 self.driver.shell.DeleteKey({'key': key})