Skip to content

Commit

Permalink
Add testing.check_figures_equal to avoid storing baseline images
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman committed Aug 6, 2020
1 parent 9bc577f commit 8b78614
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pygmt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,8 @@ class GMTVersionError(GMTError):
"""
Raised when an incompatible version of GMT is being used.
"""

class GMTImageComparisonFailure(AssertionError):
"""
Raised when a comparison between two images fails.
"""
35 changes: 35 additions & 0 deletions pygmt/helpers/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os
import sys
from pathlib import Path
from matplotlib.testing.compare import compare_images
from ..exceptions import GMTImageComparisonFailure


def check_figures_equal(fig_ref, fig_test, fig_prefix=None, tol=0.0):
result_dir = "result_images"

if not fig_prefix:
try:
fig_prefix = sys._getframe(1).f_code.co_name
except VauleError:
raise GMTInvalidInput("fig_prefix is required.")

os.makedirs(result_dir, exist_ok=True)

ref_image_path = os.path.join(result_dir, fig_prefix + '-expected.png')
test_image_path = os.path.join(result_dir, fig_prefix + '.png')

fig_ref.savefig(ref_image_path)
fig_test.savefig(test_image_path)

err = compare_images(ref_image_path, test_image_path, tol, in_decorator=True)

if err is None: # Images are the same
os.remove(ref_image_path)
os.remove(test_image_path)
else:
for key in ["actual", "expected"]:
err[key] = os.path.relpath(err[key])
raise GMTImageComparisonFailure(
'images not close (RMS %(rms).3f):\n\t%(actual)s\n\t%(expected)s '
% err)
7 changes: 7 additions & 0 deletions pygmt/tests/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import pygmt
from matplotlib.testing.decorators import check_figures_equal

@check_figures_equal(extensions=['png'])
def test_plot(fig_test, fig_ref):
fig_test.subplots().plot([1, 3, 5])
fig_ref.subplots().plot([0, 1, 2], [1, 3, 5])
11 changes: 11 additions & 0 deletions pygmt/tests/test_grdimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .. import Figure
from ..exceptions import GMTInvalidInput
from ..datasets import load_earth_relief
from ..helpers.testing import check_figures_equal


@pytest.fixture(scope="module", name="grid")
Expand Down Expand Up @@ -93,3 +94,13 @@ def test_grdimage_over_dateline(xrgrid):
xrgrid.gmt.gtype = 1 # geographic coordinate system
fig.grdimage(grid=xrgrid, region="g", projection="A0/0/1c", V="i")
return fig


def test_grdimage_central_longitude(grid):
fig1 = Figure()
fig1.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap='geo')

fig2 = Figure()
fig2.grdimage(grid, projection="W120/15c", cmap='geo')

check_figures_equal(fig1, fig2)
27 changes: 27 additions & 0 deletions pygmt/tests/test_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
Test the testing functions for PyGMT
"""
from .. import Figure
from ..helpers.testing import check_figures_equal
from ..exceptions import GMTImageComparisonFailure
import pytest


def test_check_figures_equal():
fig_ref = Figure()
fig_ref.basemap(projection="X10c", region=[0, 10, 0, 10], frame=True)

fig_test = Figure()
fig_test.basemap(projection="X10c", region=[0, 10, 0, 10], frame=True)
check_figures_equal(fig_ref, fig_test)


def test_check_figures_unequal():
fig_ref = Figure()
fig_ref.basemap(projection="X10c", region=[0, 10, 0, 10], frame=True)

fig_test = Figure()
fig_test.basemap(projection="X10c", region=[0, 15, 0, 15], frame=True)

with pytest.raises(GMTImageComparisonFailure):
check_figures_equal(fig_ref, fig_test)

0 comments on commit 8b78614

Please sign in to comment.