From: Alina Quereilhac Date: Sun, 3 Aug 2014 19:00:45 +0000 (+0200) Subject: Bugfixes for EC serialization and plotting X-Git-Tag: nepi-3.2.0~104 X-Git-Url: http://git.onelab.eu/?p=nepi.git;a=commitdiff_plain;h=d5b781271af50ba526332809bc632fe17ef6d5e5 Bugfixes for EC serialization and plotting --- diff --git a/src/nepi/execution/ec.py b/src/nepi/execution/ec.py index 31e11ef5..2c6557ff 100644 --- a/src/nepi/execution/ec.py +++ b/src/nepi/execution/ec.py @@ -154,9 +154,9 @@ class ExperimentController(object): """ @classmethod - def load(cls, path, format = SFormats.XML): + def load(cls, filepath, format = SFormats.XML): serializer = ECSerializer() - ec = serializer.load(path) + ec = serializer.load(filepath) return ec def __init__(self, exp_id = None): @@ -382,10 +382,10 @@ class ExperimentController(object): time.sleep(0.5) - def plot(self, fpath = None, format= PFormats.FIGURE, persist = False): + def plot(self, dirpath = None, format= PFormats.FIGURE, show = False): plotter = ECPlotter() - fpath = plotter.plot(self, fpath = fpath, format= format, - persist = persist) + fpath = plotter.plot(self, dirpath = dirpath, format= format, + show = show) return fpath def serialize(self, format = SFormats.XML): @@ -393,9 +393,9 @@ class ExperimentController(object): sec = serializer.load(self, format = format) return sec - def save(self, path, format = SFormats.XML): + def save(self, dirpath = None, format = SFormats.XML): serializer = ECSerializer() - path = serializer.save(self, path, format = format) + path = serializer.save(self, dirpath = None, format = format) return path def get_task(self, tid): diff --git a/src/nepi/util/plot.py b/src/nepi/util/plot.py deleted file mode 100644 index 97b146be..00000000 --- a/src/nepi/util/plot.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# NEPI, a framework to manage network experiments -# Copyright (C) 2013 INRIA -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . -# -# Author: Alina Quereilhac - -import networkx -import tempfile - -class Plotter(object): - def __init__(self, box): - self._graph = networkx.Graph(graph = dict(overlap = "false")) - - traversed = set() - self._traverse_boxes(traversed, box) - - def _traverse_boxes(self, traversed, box): - traversed.add(box.guid) - - self._graph.add_node(box.label, - width = 50/72.0, # 1 inch = 72 points - height = 50/72.0, - shape = "circle") - - for b in box.connections: - self._graph.add_edge(box.label, b.label) - if b.guid not in traversed: - self._traverse_boxes(traversed, b) - - def plot(self): - f = tempfile.NamedTemporaryFile(delete=False) - networkx.draw_graphviz(self._graph) - networkx.write_dot(self._graph, f.name) - f.close() - return f.name - diff --git a/src/nepi/util/plotter.py b/src/nepi/util/plotter.py index c21f5988..f1c1748e 100644 --- a/src/nepi/util/plotter.py +++ b/src/nepi/util/plotter.py @@ -25,16 +25,17 @@ class PFormats: FIGURE = "figure" class ECPlotter(object): - def plot(self, ec, fpath = None, format= PFormats.FIGURE, persist = False): + def plot(self, ec, dirpath = None, format= PFormats.FIGURE, + show = False): graph, labels = self._ec2graph(ec) add_extension = False - if persist and not fpath: + if not dirpath: import tempfile dirpath = tempfile.mkdtemp() - fpath = os.path.join(dirpath, "%s_%s" % (ec.exp_id, ec.run_id)) - add_extension = True + + fpath = os.path.join(dirpath, "%s_%s" % (ec.exp_id, ec.run_id)) if format == PFormats.FIGURE: import matplotlib.pyplot as plt @@ -43,23 +44,21 @@ class ECPlotter(object): node_size = 500, with_labels=True) label = "\n".join(map(lambda v: "%s: %s" % (v[0], v[1]), labels.iteritems())) - plt.annotate(label, xy=(0.0, 0.95), xycoords='axes fraction') - - if persist: - if add_extension: - fpath += ".png" + plt.annotate(label, xy=(0.05, 0.95), xycoords='axes fraction') + + fpath += ".png" - plt.savefig(fpath, bbox_inches="tight") - else: + plt.savefig(fpath, bbox_inches="tight") + + if show: plt.show() elif format == PFormats.DOT: - if persist: - if add_extension: - fpath += ".dot" + fpath += ".dot" - networkx.write_dot(graph, fpath) - else: + networkx.write_dot(graph, fpath) + + if show: import subprocess subprocess.call(["dot", "-Tps", fpath, "-o", "%s.ps" % fpath]) subprocess.call(["evince","%s.ps" % fpath]) diff --git a/src/nepi/util/serializer.py b/src/nepi/util/serializer.py index 9ed216f2..aaf1aff4 100644 --- a/src/nepi/util/serializer.py +++ b/src/nepi/util/serializer.py @@ -24,12 +24,12 @@ class SFormats: XML = "xml" class ECSerializer(object): - def load(self, path, format = SFormats.XML): + def load(self, filepath, format = SFormats.XML): if format == SFormats.XML: from nepi.util.parsers.xml_parser import ECXMLParser parser = ECXMLParser() - f = open(path, "r") + f = open(filepath, "r") xml = f.read() f.close() @@ -46,16 +46,20 @@ class ECSerializer(object): return sec - def save(self, ec, path, format = SFormats.XML): + def save(self, ec, dirpath = None, format = SFormats.XML): + if not dirpath: + import tempfile + dirpath = tempfile.mkdtemp() + date = datetime.datetime.now().strftime('%Y%m%d%H%M%S') filename = "%s_%s" % (ec.exp_id, date) if format == SFormats.XML: - path = os.path.join(path, "%s.xml" % filename) + filepath = os.path.join(dirpath, "%s.xml" % filename) sec = self.serialize(ec, format = format) - f = open(path, "w") + f = open(filepath, "w") f.write(sec) f.close() - return path + return filepath diff --git a/test/util/plotter.py b/test/util/plotter.py index 4e1a553e..9ad5ce6e 100755 --- a/test/util/plotter.py +++ b/test/util/plotter.py @@ -118,20 +118,21 @@ class PlotterTestCase(unittest.TestCase): for iface in ifaces: ec.register_connection(link, iface) - fpath = ec.plot(persist = True) + fpath = ec.plot() statinfo = os.stat(fpath) size = statinfo.st_size self.assertTrue(size > 0) self.assertTrue(fpath.endswith(".png")) - fpath = ec.plot(persist = True, format = PFormats.DOT) + os.remove(fpath) + + fpath = ec.plot(format = PFormats.DOT) statinfo = os.stat(fpath) size = statinfo.st_size self.assertTrue(size > 0) self.assertTrue(fpath.endswith(".dot")) - print fpath - + os.remove(fpath) if __name__ == '__main__': unittest.main()