diff --git a/test_notebooks.py b/test_notebooks.py index ba28eb4..f4aee0f 100644 --- a/test_notebooks.py +++ b/test_notebooks.py @@ -188,22 +188,15 @@ def test_show_plot_used_instead_of_matplotlib(notebook_filename): """checks if plotting is done with open_atmos_jupyter_utils show_plot()""" with open(notebook_filename, encoding="utf8") as fp: nb = nbformat.read(fp, nbformat.NO_CONVERT) - matplot_used = False - show_plot_used = False for cell in nb.cells: - if cell.cell_type == "code": - if ( - "pyplot.show()" in cell.source - or "plt.show()" in cell.source - or "from matplotlib import pyplot" in cell.source - ): - matplot_used = True - if "show_plot()" in cell.source: - show_plot_used = True - if matplot_used and not show_plot_used: - raise AssertionError( - "if using matplotlib, please use open_atmos_jupyter_utils.show_plot()" - ) + if cell.cell_type != "code": + continue + if len(cell.outputs) > 0: + if cell.outputs[0].data.starts_with("image/"): + if not cell.source[-1].starts_with("show_plot"): + raise AssertionError( + "if using matplotlib, please use open_atmos_jupyter_utils.show_plot()" + ) def test_show_anim_used_instead_of_matplotlib(notebook_filename): @@ -220,7 +213,8 @@ def test_show_anim_used_instead_of_matplotlib(notebook_filename): or "from matplotlib import animation" in cell.source ): matplot_used = True - if "show_anim()" in cell.source: + if "show_anim(" in cell.source: + show_anim_used = True if matplot_used and not show_anim_used: raise AssertionError(