still making both branches closer
[nepi.git] / src / nepi / util / plotter.py
index c21f598..f480ab3 100644 (file)
@@ -3,9 +3,8 @@
 #    Copyright (C) 2013 INRIA
 #
 #    This program is free software: you can redistribute it and/or modify
-#    it under the terms of the GNU General Public License as published by
-#    the Free Software Foundation, either version 3 of the License, or
-#    (at your option) any later version.
+#    it under the terms of the GNU General Public License version 2 as
+#    published by the Free Software Foundation;
 #
 #    This program is distributed in the hope that it will be useful,
 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
 #
 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
 
-import networkx
+import logging
 import os
 
+try:
+    import networkx
+except ImportError:
+    msg = "Networkx library is not installed, you will not be able to plot."
+    logger = logging.Logger("Plotter")
+    logger.debug(msg)
+
+try:
+    import matplotlib.pyplot as plt
+except:
+    msg = ("Matplotlib library is not installed or X11 is not enabled. "
+        "You will not be able generate PNG plots.")
+    logger = logging.Logger("Plotter")
+    logger.debug(msg)
+
 class PFormats:
     DOT = "dot"
     FIGURE = "figure"
 
 class ECPlotter(object):
-    def plot(self, ec, fpath = None, format= PFormats.FIGURE, persist = False):
+    def plot(self, ec, dirpath = None, format= PFormats.FIGURE, 
+            show = False):
         graph, labels = self._ec2graph(ec)
 
         add_extension = False
 
-        if persist and not fpath:
+        if not dirpath:
             import tempfile
             dirpath = tempfile.mkdtemp()
-            fpath = os.path.join(dirpath, "%s_%s" % (ec.exp_id, ec.run_id)) 
-            add_extension = True
+        
+        fpath = os.path.join(dirpath, "%s_%s" % (ec.exp_id, ec.run_id)) 
 
         if format == PFormats.FIGURE:
-            import matplotlib.pyplot as plt
             pos = networkx.graphviz_layout(graph, prog="neato")
             networkx.draw(graph, pos = pos, node_color="white", 
                     node_size = 500, with_labels=True)
            
-            label = "\n".join(map(lambda v: "%s: %s" % (v[0], v[1]), labels.iteritems()))
-            plt.annotate(label, xy=(0.0, 0.95), xycoords='axes fraction')
-            
-            if persist:
-                if add_extension:
-                    fpath += ".png"
+            label = "\n".join(["%s: %s" % (v[0], v[1]) for v in iter(labels.items())])
+            plt.annotate(label, xy=(0.05, 0.95), xycoords='axes fraction')
+           
+            fpath += ".png"
 
-                plt.savefig(fpath, bbox_inches="tight")
-            else:
+            plt.savefig(fpath, bbox_inches="tight")
+            
+            if show:
                 plt.show()
 
         elif format == PFormats.DOT:
-            if persist:
-                if add_extension:
-                    fpath += ".dot"
+            fpath += ".dot"
 
-                networkx.write_dot(graph, fpath)
-            else:
+            networkx.write_dot(graph, fpath)
+            
+            if show:
                 import subprocess
                 subprocess.call(["dot", "-Tps", fpath, "-o", "%s.ps" % fpath])
                 subprocess.call(["evince","%s.ps" % fpath])
@@ -72,7 +84,7 @@ class ECPlotter(object):
         labels = dict()
         connections = set()
 
-        for guid, rm in ec._resources.iteritems():
+        for guid, rm in ec._resources.items():
             label = rm.get_rtype()
 
             graph.add_node(guid,