required_fields = ['hrn', 'addr', 'port']
- def __init__(self, api):
+ def __init__(self, api, file = "/etc/geni/aggregates.xml"):
dict.__init__(self, {})
self.api = api
- aggregates_file = self.api.server_basedir + os.sep + 'aggregates.xml'
+
+ # create default connection dict
connection_dict = {}
for field in self.required_fields:
connection_dict[field] = ''
- self.aggregate_info = XmlStorage(aggregates_file, {'aggregates': {'aggregate': [connection_dict]}})
+
+ # get possible config file locations
+ loaded = False
+ path = os.path.dirname(os.path.abspath(__file__))
+ filename = file.split(os.sep)[-1]
+ alt_file = path + os.sep + filename
+ files = [file, alt_file]
+
+ for f in files:
+ try:
+ if os.path.isfile(f):
+ self.aggregate_info = XmlStorage(f, {'aggregates': {'aggregate': [connection_dict]}})
+ loaded = True
+ except: pass
+
+ # if file is missing, just recreate it in the right place
+ if not loaded:
+ self.aggregate_info = XmlStorage(file, {'aggregates': {'aggregate': [connection_dict]}})
self.aggregate_info.load()
self.connectAggregates()