Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/natcap/invest/reports/raster_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,14 +419,14 @@ def _configure_special_values(
text_specs = []

if lower_threshold is not None:
cmap.set_under(lower_color)
cmap = cmap.with_extremes(under=lower_color)
extend = 'min'
thresholds.append(lower_threshold)
labels.append(lower_label)
text_specs.append((0, -0.05, 'top'))

if upper_threshold is not None:
cmap.set_over(upper_color)
cmap = cmap.with_extremes(over=upper_color)
extend = 'max' if extend == 'neither' else 'both'
thresholds.append(upper_threshold)
labels.append(upper_label)
Expand Down
140 changes: 139 additions & 1 deletion tests/reports/test_raster_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def save_figure(fig, filepath):
def make_simple_raster(target_filepath, shape):
array = numpy.linspace(0, 1, num=numpy.multiply(*shape)).reshape(*shape)
pygeoprocessing.numpy_array_to_raster(
array, target_nodata=None, pixel_size=(1, 1), origin=(0, 0),
array, target_nodata=None, pixel_size=(1, -1), origin=(0, 0),
projection_wkt=PROJ_WKT, target_path=target_filepath)


Expand Down Expand Up @@ -325,6 +325,18 @@ def test_plot_raster_list_different_transforms(self):
save_figure(fig, actual_png)
compare_snapshots(reference, actual_png)


class RasterSpecialValueConfigTests(unittest.TestCase):
"""Snapshot tests for special values in RasterConfig."""

def setUp(self):
"""Override setUp function to create temp workspace directory."""
self.workspace_dir = tempfile.mkdtemp()

def tearDown(self):
"""Override tearDown function to remove temporary directory."""
shutil.rmtree(self.workspace_dir)

def test_special_value_config(self):
"""Should pass when only index 0 (lower bound) is fully populated."""
config = SpecialValueConfig(
Expand Down Expand Up @@ -361,6 +373,132 @@ def test_special_value_config(self):
"If index 1 is `None` in any of the special config tuples" in
str(context.exception))

def test_special_values_rejected_for_nominal(self):
"""RasterPlotConfig does not allow special values for nominal raster"""
with self.assertRaisesRegex(
ValueError, '`special_values` may only be defined'):
raster_utils.RasterPlotConfig(
raster_path=os.path.join(self.workspace_dir, 'foo.tif'),
datatype=raster_utils.RasterDatatype.nominal,
spec=spec.Output(id='foo', about='foo output'),
special_values=raster_utils.SpecialValueConfig(
thresholds=(-1, 1),
labels=('low', 'high'),
colors=('red', 'blue')))

def test_special_values_rejected_for_binary(self):
"""RasterPlotConfig does not allow special values for binary raster"""
with self.assertRaisesRegex(
ValueError, '`special_values` may only be defined'):
raster_utils.RasterPlotConfig(
raster_path=os.path.join(self.workspace_dir, 'foo.tif'),
datatype=raster_utils.RasterDatatype.binary,
spec=spec.Output(id='foo', about='foo output'),
special_values=raster_utils.SpecialValueConfig(
thresholds=(-1, 1),
labels=('low', 'high'),
colors=('red', 'blue')))

def test_configure_special_values_both_bounds(self):
"""_configure_special_values configures both colorbar extensions."""
cmap = matplotlib.colormaps['viridis'].copy()
special_values = raster_utils.SpecialValueConfig(
thresholds=(-1, 1),
labels=('low', 'high'),
colors=('red', 'blue'))

extend, thresholds, labels, text_specs = (
raster_utils._configure_special_values(cmap, special_values))

self.assertEqual(extend, 'both')
self.assertEqual(thresholds, [-1, 1])
self.assertEqual(labels, ['low', 'high'])
self.assertEqual(
text_specs, [(0, -0.05, 'top'), (0, 1.05, 'bottom')])

def test_plot_divergent_log_raster_requires_symmetric_thresholds(
self):
"""Divergent log special values must be symmetric around 0."""
shape = (4, 4)
raster_config = raster_utils.RasterPlotConfig(
raster_path=os.path.join(self.workspace_dir, 'foo.tif'),
datatype=raster_utils.RasterDatatype.divergent,
transform="log",
spec=spec.Output(id='foo', about='foo output'),
special_values=raster_utils.SpecialValueConfig(
thresholds=(0.4, 1),
labels=('low', 'high'),
colors=('black', 'orange')))
make_simple_raster(raster_config.raster_path, shape)

with self.assertRaisesRegex(
UserWarning, 'To ensure that 0 falls at the logical break'):
raster_utils.plot_raster_list([raster_config])

def test_plot_continuous_raster_special_values(self):
"""Test correct plot for continuous raster with special values"""
figname = 'plot_raster_list_special_values.png'
reference = os.path.join(REFS_DIR, figname)
shape = (4, 4)
raster_config = raster_utils.RasterPlotConfig(
raster_path=os.path.join(self.workspace_dir, 'foo.tif'),
datatype=raster_utils.RasterDatatype.continuous,
spec=spec.Output(id='foo', about='foo output'),
special_values=raster_utils.SpecialValueConfig(
thresholds=(0.4, 1),
labels=('low', 'high'),
colors=('black', 'orange')))
make_simple_raster(raster_config.raster_path, shape)

config_list = [raster_config]
fig = raster_utils.plot_raster_list(config_list)
actual_png = os.path.join(self.workspace_dir, figname)
save_figure(fig, actual_png)
compare_snapshots(reference, actual_png)

def test_plot_divergent_raster_max_special_value(self):
"""Test divergent raster plot w special value has a correct colorbar"""
figname = 'plot_raster_list_special_max_value.png'
reference = os.path.join(REFS_DIR, figname)
shape = (4, 4)
raster_config = raster_utils.RasterPlotConfig(
raster_path=os.path.join(self.workspace_dir, 'foo.tif'),
datatype=raster_utils.RasterDatatype.divergent,
spec=spec.Output(id='foo', about='foo output'),
special_values=raster_utils.SpecialValueConfig(
thresholds=(None, 0.8),
labels=(None, 'high'),
colors=(None, 'darkblue')))
# Note this raster doesn't actually have negative values
make_simple_raster(raster_config.raster_path, shape)

config_list = [raster_config]
fig = raster_utils.plot_raster_list(config_list)
actual_png = os.path.join(self.workspace_dir, figname)
save_figure(fig, actual_png)
compare_snapshots(reference, actual_png)

def test_plot_raster_list_special_values_adds_threshold_ticks(self):
"""Test plot_raster_list adds special values as colorbar ticks."""
thresholds = (-.8, .9)
shape = (4, 4)
raster_config = raster_utils.RasterPlotConfig(
raster_path=os.path.join(self.workspace_dir, 'foo.tif'),
datatype=raster_utils.RasterDatatype.continuous,
spec=spec.Output(id='foo', about='foo output'),
special_values=raster_utils.SpecialValueConfig(
thresholds=thresholds,
labels=('low', 'high'),
colors=('red', 'blue')))
make_simple_raster(raster_config.raster_path, shape)

fig = raster_utils.plot_raster_list([raster_config])
colorbar_ax = fig.axes[1]
ticks = list(colorbar_ax.get_yticks())

self.assertIn(thresholds[0], ticks)
self.assertIn(thresholds[1], ticks)


class RasterPlotLegendTests(unittest.TestCase):
"""Snapshot tests for legend placement on nominal rasters."""
Expand Down
Loading