Bugfixes for EC serialization and plotting
authorAlina Quereilhac <alina.quereilhac@inria.fr>
Sun, 3 Aug 2014 19:00:45 +0000 (21:00 +0200)
committerAlina Quereilhac <alina.quereilhac@inria.fr>
Sun, 3 Aug 2014 19:00:45 +0000 (21:00 +0200)
src/nepi/execution/ec.py
src/nepi/util/plot.py [deleted file]
src/nepi/util/plotter.py
src/nepi/util/serializer.py
test/util/plotter.py

index 31e11ef..2c6557f 100644 (file)
@@ -154,9 +154,9 @@ class ExperimentController(object):
     """
 
     @classmethod
-    def load(cls, path, format = SFormats.XML):
+    def load(cls, filepath, format = SFormats.XML):
         serializer = ECSerializer()
-        ec = serializer.load(path)
+        ec = serializer.load(filepath)
         return ec
 
     def __init__(self, exp_id = None): 
@@ -382,10 +382,10 @@ class ExperimentController(object):
 
                 time.sleep(0.5)
 
-    def plot(self, fpath = None, format= PFormats.FIGURE, persist = False):
+    def plot(self, dirpath = None, format= PFormats.FIGURE, show = False):
         plotter = ECPlotter()
-        fpath = plotter.plot(self, fpath = fpath, format= format, 
-                persist = persist)
+        fpath = plotter.plot(self, dirpath = dirpath, format= format, 
+                show = show)
         return fpath
 
     def serialize(self, format = SFormats.XML):
@@ -393,9 +393,9 @@ class ExperimentController(object):
         sec = serializer.load(self, format = format)
         return sec
 
-    def save(self, path, format = SFormats.XML):
+    def save(self, dirpath = None, format = SFormats.XML):
         serializer = ECSerializer()
-        path = serializer.save(self, path, format = format)
+        path = serializer.save(self, dirpath  = None, format = format)
         return path
 
     def get_task(self, tid):
diff --git a/src/nepi/util/plot.py b/src/nepi/util/plot.py
deleted file mode 100644 (file)
index 97b146b..0000000
+++ /dev/null
@@ -1,49 +0,0 @@
-#
-#    NEPI, a framework to manage network experiments
-#    Copyright (C) 2013 INRIA
-#
-#    This program is free software: you can redistribute it and/or modify
-#    it under the terms of the GNU General Public License as published by
-#    the Free Software Foundation, either version 3 of the License, or
-#    (at your option) any later version.
-#
-#    This program is distributed in the hope that it will be useful,
-#    but WITHOUT ANY WARRANTY; without even the implied warranty of
-#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-#    GNU General Public License for more details.
-#
-#    You should have received a copy of the GNU General Public License
-#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
-#
-# Author: Alina Quereilhac <alina.quereilhac@inria.fr>
-
-import networkx
-import tempfile
-
-class Plotter(object):
-    def __init__(self, box):
-        self._graph = networkx.Graph(graph = dict(overlap = "false"))
-
-        traversed = set()
-        self._traverse_boxes(traversed, box)
-
-    def _traverse_boxes(self, traversed, box):
-        traversed.add(box.guid)
-
-        self._graph.add_node(box.label, 
-                width = 50/72.0, # 1 inch = 72 points
-                height = 50/72.0, 
-                shape = "circle")
-
-        for b in box.connections:
-            self._graph.add_edge(box.label, b.label)
-            if b.guid not in traversed:
-                self._traverse_boxes(traversed, b)
-
-    def plot(self):
-        f = tempfile.NamedTemporaryFile(delete=False)
-        networkx.draw_graphviz(self._graph)
-        networkx.write_dot(self._graph, f.name)
-        f.close()
-        return f.name
-
index c21f598..f1c1748 100644 (file)
@@ -25,16 +25,17 @@ class PFormats:
     FIGURE = "figure"
 
 class ECPlotter(object):
-    def plot(self, ec, fpath = None, format= PFormats.FIGURE, persist = False):
+    def plot(self, ec, dirpath = None, format= PFormats.FIGURE, 
+            show = False):
         graph, labels = self._ec2graph(ec)
 
         add_extension = False
 
-        if persist and not fpath:
+        if not dirpath:
             import tempfile
             dirpath = tempfile.mkdtemp()
-            fpath = os.path.join(dirpath, "%s_%s" % (ec.exp_id, ec.run_id)) 
-            add_extension = True
+        
+        fpath = os.path.join(dirpath, "%s_%s" % (ec.exp_id, ec.run_id)) 
 
         if format == PFormats.FIGURE:
             import matplotlib.pyplot as plt
@@ -43,23 +44,21 @@ class ECPlotter(object):
                     node_size = 500, with_labels=True)
            
             label = "\n".join(map(lambda v: "%s: %s" % (v[0], v[1]), labels.iteritems()))
-            plt.annotate(label, xy=(0.0, 0.95), xycoords='axes fraction')
-            
-            if persist:
-                if add_extension:
-                    fpath += ".png"
+            plt.annotate(label, xy=(0.05, 0.95), xycoords='axes fraction')
+           
+            fpath += ".png"
 
-                plt.savefig(fpath, bbox_inches="tight")
-            else:
+            plt.savefig(fpath, bbox_inches="tight")
+            
+            if show:
                 plt.show()
 
         elif format == PFormats.DOT:
-            if persist:
-                if add_extension:
-                    fpath += ".dot"
+            fpath += ".dot"
 
-                networkx.write_dot(graph, fpath)
-            else:
+            networkx.write_dot(graph, fpath)
+            
+            if show:
                 import subprocess
                 subprocess.call(["dot", "-Tps", fpath, "-o", "%s.ps" % fpath])
                 subprocess.call(["evince","%s.ps" % fpath])
index 9ed216f..aaf1aff 100644 (file)
@@ -24,12 +24,12 @@ class SFormats:
     XML = "xml"
     
 class ECSerializer(object):
-    def load(self, path, format = SFormats.XML):
+    def load(self, filepath, format = SFormats.XML):
         if format == SFormats.XML:
             from nepi.util.parsers.xml_parser import ECXMLParser
             
             parser = ECXMLParser()
-            f = open(path, "r")
+            f = open(filepath, "r")
             xml = f.read()
             f.close()
 
@@ -46,16 +46,20 @@ class ECSerializer(object):
 
         return sec
 
-    def save(self, ec, path, format = SFormats.XML):
+    def save(self, ec, dirpath = None, format = SFormats.XML):
+        if not dirpath:
+            import tempfile
+            dirpath = tempfile.mkdtemp()
         date = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
         filename = "%s_%s" % (ec.exp_id, date)
 
         if format == SFormats.XML:
-            path = os.path.join(path, "%s.xml" % filename)
+            filepath = os.path.join(dirpath, "%s.xml" % filename)
             sec = self.serialize(ec, format = format)
-            f = open(path, "w")
+            f = open(filepath, "w")
             f.write(sec)
             f.close()
 
-        return path
+        return filepath
 
index 4e1a553..9ad5ce6 100755 (executable)
@@ -118,20 +118,21 @@ class PlotterTestCase(unittest.TestCase):
         for iface in ifaces:
             ec.register_connection(link, iface)
        
-        fpath = ec.plot(persist = True)
+        fpath = ec.plot()
         statinfo = os.stat(fpath)
         size = statinfo.st_size
         self.assertTrue(size > 0)
         self.assertTrue(fpath.endswith(".png"))
 
-        fpath = ec.plot(persist = True, format = PFormats.DOT)
+        os.remove(fpath)
+
+        fpath = ec.plot(format = PFormats.DOT)
         statinfo = os.stat(fpath)
         size = statinfo.st_size
         self.assertTrue(size > 0)
         self.assertTrue(fpath.endswith(".dot"))
 
-        print fpath
-
+        os.remove(fpath)
 
 if __name__ == '__main__':
     unittest.main()