diff --git a/apps/CMakeLists.txt b/apps/CMakeLists.txt index 530bd8eecf07..fefb8d67791d 100644 --- a/apps/CMakeLists.txt +++ b/apps/CMakeLists.txt @@ -12,6 +12,7 @@ add_library( gdalalg_pipeline.cpp gdalalg_raster.cpp gdalalg_raster_info.cpp + gdalalg_raster_calc.cpp gdalalg_raster_clip.cpp gdalalg_raster_convert.cpp gdalalg_raster_edit.cpp diff --git a/apps/gdalalg_raster.cpp b/apps/gdalalg_raster.cpp index ba68d40d55ca..769c9a9b4850 100644 --- a/apps/gdalalg_raster.cpp +++ b/apps/gdalalg_raster.cpp @@ -13,6 +13,7 @@ #include "gdalalgorithm.h" #include "gdalalg_raster_info.h" +#include "gdalalg_raster_calc.h" #include "gdalalg_raster_clip.h" #include "gdalalg_raster_convert.h" #include "gdalalg_raster_edit.h" @@ -41,6 +42,7 @@ class GDALRasterAlgorithm final : public GDALAlgorithm GDALRasterAlgorithm() : GDALAlgorithm(NAME, DESCRIPTION, HELP_URL) { RegisterSubAlgorithm(); + RegisterSubAlgorithm(); RegisterSubAlgorithm(); RegisterSubAlgorithm(); RegisterSubAlgorithm(); diff --git a/apps/gdalalg_raster_calc.cpp b/apps/gdalalg_raster_calc.cpp new file mode 100644 index 000000000000..8b7650eb8442 --- /dev/null +++ b/apps/gdalalg_raster_calc.cpp @@ -0,0 +1,620 @@ +/****************************************************************************** + * + * Project: GDAL + * Purpose: "gdal raster calc" subcommand + * Author: Daniel Baston + * + ****************************************************************************** + * Copyright (c) 2025, ISciences LLC + * + * SPDX-License-Identifier: MIT + ****************************************************************************/ + +#include "gdalalg_raster_calc.h" + +#include "../frmts/vrt/gdal_vrt.h" +#include "../frmts/vrt/vrtdataset.h" + +#include "gdal_priv.h" +#include "gdal_utils.h" + +#include +#include +#include + +//! @cond Doxygen_Suppress + +#ifndef _ +#define _(x) (x) +#endif + +struct GDALCalcOptions +{ + bool checkSRS{true}; + bool checkExtent{true}; +}; + +static bool MatchIsCompleteVariableNameWithNoIndex(const std::string &str, + size_t from, size_t to) +{ + if (to < str.size()) + { + // If the character after the end of the match is: + // * alphanumeric or _ : we've matched only part of a variable name + // * [ : we've matched a variable that already has an index + // * ( : we've matched a function name + if (std::isalnum(str[to]) || str[to] == '_' || str[to] == '[' || + str[to] == '(') + { + return false; + } + } + if (from > 0) + { + // If the character before the start of the match is alphanumeric or _, + // we've matched only part of a variable name. + if (std::isalnum(str[from - 1]) || str[from - 1] == '_') + { + return false; + } + } + + return true; +} + +/** + * Add a band subscript to all instances of a specified variable that + * do not already have such a subscript. For example, "X" would be + * replaced with "X[3]" but "X[1]" would be left untouched. + */ +static std::string SetBandIndices(const std::string &origExpression, + const std::string &variable, int band, + bool &expressionChanged) +{ + std::string expression = origExpression; + expressionChanged = false; + + std::string::size_type seekPos = 0; + auto pos = expression.find(variable, seekPos); + while (pos != std::string::npos) + { + auto end = pos + variable.size(); + + if (MatchIsCompleteVariableNameWithNoIndex(expression, pos, end)) + { + // No index specified for variable + expression = expression.substr(0, pos + variable.size()) + '[' + + std::to_string(band) + ']' + expression.substr(end); + expressionChanged = true; + } + + seekPos = end; + pos = expression.find(variable, seekPos); + } + + return expression; +} + +struct SourceProperties +{ + int nBands{0}; + int nX{0}; + int nY{0}; + std::array gt{}; + std::unique_ptr srs{ + nullptr}; +}; + +static std::optional +UpdateSourceProperties(SourceProperties &out, const std::string &dsn, + const GDALCalcOptions &options) +{ + SourceProperties source; + bool srsMismatch = false; + bool extentMismatch = false; + bool dimensionMismatch = false; + + { + std::unique_ptr ds( + GDALDataset::Open(dsn.c_str(), GDAL_OF_RASTER)); + + if (!ds) + { + CPLError(CE_Failure, CPLE_AppDefined, "Failed to open %s", + dsn.c_str()); + return std::nullopt; + } + + source.nBands = ds->GetRasterCount(); + source.nX = ds->GetRasterXSize(); + source.nY = ds->GetRasterYSize(); + + if (options.checkExtent) + { + ds->GetGeoTransform(source.gt.data()); + } + + if (options.checkSRS && out.srs) + { + const OGRSpatialReference *srs = ds->GetSpatialRef(); + srsMismatch = srs && !srs->IsSame(out.srs.get()); + } + } + + if (source.nX != out.nX || source.nY != out.nY) + { + dimensionMismatch = true; + } + + if (source.gt[0] != out.gt[0] || source.gt[2] != out.gt[2] || + source.gt[3] != out.gt[3] || source.gt[4] != out.gt[4]) + { + extentMismatch = true; + } + if (source.gt[1] != out.gt[1] || source.gt[5] != out.gt[5]) + { + // Resolutions are different. Are the extents the same? + double xmaxOut = out.gt[0] + out.nX * out.gt[1] + out.nY * out.gt[2]; + double yminOut = out.gt[3] + out.nX * out.gt[4] + out.nY * out.gt[5]; + + double xmax = + source.gt[0] + source.nX * source.gt[1] + source.nY * source.gt[2]; + double ymin = + source.gt[3] + source.nX * source.gt[4] + source.nY * source.gt[5]; + + // TODO use a tolerance here + if (xmax != xmaxOut || ymin != yminOut) + { + extentMismatch = true; + } + } + + if (options.checkExtent && extentMismatch) + { + CPLError(CE_Failure, CPLE_AppDefined, + "Input extents are inconsistent."); + return std::nullopt; + } + + if (!options.checkExtent && dimensionMismatch) + { + CPLError(CE_Failure, CPLE_AppDefined, + "Inputs do not have the same dimensions."); + return std::nullopt; + } + + // Choose the finest resolution + // TODO modify to use common resolution (https://github.com/OSGeo/gdal/issues/11497) + if (source.nX > out.nX) + { + out.nX = source.nX; + out.gt[1] = source.gt[1]; + } + if (source.nY > out.nY) + { + out.nY = source.nY; + out.gt[5] = source.gt[5]; + } + + if (srsMismatch) + { + CPLError(CE_Failure, CPLE_AppDefined, + "Input spatial reference systems are inconsistent."); + return std::nullopt; + } + + return source; +} + +/** Create XML nodes for one or more derived bands resulting from the evaluation + * of a single expression + * + * @param root VRTDataset node to which the band nodes should be added + * @param nXOut Number of columns in VRT dataset + * @param nYOut Number of rows in VRT dataset + * @param expression Expression for which band(s) should be added + * @param sources Mapping of source names to DSNs + * @param sourceProps Mapping of source names to properties + * @return true if the band(s) were added, false otherwise + */ +static bool +CreateDerivedBandXML(CPLXMLNode *root, int nXOut, int nYOut, + const std::string &expression, + const std::map &sources, + const std::map &sourceProps) +{ + int nOutBands = 1; // By default, each expression produces a single output + // band. When processing the expression below, we may + // discover that the expression produces multiple bands, + // in which case this will be updated. + for (int nOutBand = 1; nOutBand <= nOutBands; nOutBand++) + { + // Copy the expression for each output band, because we may modify it + // when adding band indices (e.g., X -> X[1]) to the variables in the + // expression. + std::string bandExpression = expression; + + CPLXMLNode *band = CPLCreateXMLNode(root, CXT_Element, "VRTRasterBand"); + CPLAddXMLAttributeAndValue(band, "subClass", "VRTDerivedRasterBand"); + // TODO: Allow user specification of output data type? + CPLAddXMLAttributeAndValue(band, "dataType", "Float64"); + + CPLXMLNode *pixelFunctionType = + CPLCreateXMLNode(band, CXT_Element, "PixelFunctionType"); + CPLCreateXMLNode(pixelFunctionType, CXT_Text, "expression"); + CPLXMLNode *arguments = + CPLCreateXMLNode(band, CXT_Element, "PixelFunctionArguments"); + + for (const auto &[source_name, dsn] : sources) + { + auto it = sourceProps.find(source_name); + CPLAssert(it != sourceProps.end()); + const auto &props = it->second; + + { + const int nDefaultInBand = std::min(props.nBands, nOutBand); + + CPLString expressionBandVariable; + expressionBandVariable.Printf("%s[%d]", source_name.c_str(), + nDefaultInBand); + + bool expressionUsesAllBands = false; + bandExpression = + SetBandIndices(bandExpression, source_name, nDefaultInBand, + expressionUsesAllBands); + + if (expressionUsesAllBands) + { + if (nOutBands <= 1) + { + nOutBands = props.nBands; + } + else if (props.nBands != 1 && props.nBands != nOutBands) + { + CPLError(CE_Failure, CPLE_AppDefined, + "Expression cannot operate on all bands of " + "rasters with incompatible numbers of bands " + "(source %s has %d bands but expected to have " + "1 or %d bands).", + source_name.c_str(), props.nBands, nOutBands); + return false; + } + } + } + + // Create a for each input band that is used in + // the expression. + for (int nInBand = 1; nInBand <= props.nBands; nInBand++) + { + CPLString inBandVariable; + inBandVariable.Printf("%s[%d]", source_name.c_str(), nInBand); + if (bandExpression.find(inBandVariable) == std::string::npos) + { + continue; + } + + CPLXMLNode *source = + CPLCreateXMLNode(band, CXT_Element, "SimpleSource"); + CPLAddXMLAttributeAndValue(source, "name", + inBandVariable.c_str()); + + CPLXMLNode *sourceFilename = + CPLCreateXMLNode(source, CXT_Element, "SourceFilename"); + CPLAddXMLAttributeAndValue(sourceFilename, "relativeToVRT", + "0"); + CPLCreateXMLNode(sourceFilename, CXT_Text, dsn.c_str()); + + CPLXMLNode *sourceBand = + CPLCreateXMLNode(source, CXT_Element, "SourceBand"); + CPLCreateXMLNode(sourceBand, CXT_Text, + std::to_string(nInBand).c_str()); + + // TODO add ? + + CPLXMLNode *srcRect = + CPLCreateXMLNode(source, CXT_Element, "SrcRect"); + CPLAddXMLAttributeAndValue(srcRect, "xOff", "0"); + CPLAddXMLAttributeAndValue(srcRect, "yOff", "0"); + CPLAddXMLAttributeAndValue(srcRect, "xSize", + std::to_string(props.nX).c_str()); + CPLAddXMLAttributeAndValue(srcRect, "ySize", + std::to_string(props.nY).c_str()); + + CPLXMLNode *dstRect = + CPLCreateXMLNode(source, CXT_Element, "DstRect"); + CPLAddXMLAttributeAndValue(dstRect, "xOff", "0"); + CPLAddXMLAttributeAndValue(dstRect, "yOff", "0"); + CPLAddXMLAttributeAndValue(dstRect, "xSize", + std::to_string(nXOut).c_str()); + CPLAddXMLAttributeAndValue(dstRect, "ySize", + std::to_string(nYOut).c_str()); + } + } + + // Add the expression as a last step, because we may modify the + // expression as we iterate through the bands. + CPLAddXMLAttributeAndValue(arguments, "expression", + bandExpression.c_str()); + CPLAddXMLAttributeAndValue(arguments, "dialect", "muparser"); + } + + return true; +} + +static bool ParseSourceDescriptors(const std::vector &inputs, + std::map &datasets, + std::string &firstSourceName) +{ + bool isFirst = true; + + for (const auto &input : inputs) + { + std::string name = ""; + + auto pos = input.find('='); + if (pos == std::string::npos) + { + if (inputs.size() > 1) + { + CPLError(CE_Failure, CPLE_AppDefined, + "Inputs must be named when more than one input is " + "provided."); + return false; + } + name = "X"; + } + else + { + name = input.substr(0, pos); + } + + std::string dsn = + (pos == std::string::npos) ? input : input.substr(pos + 1); + datasets[name] = std::move(dsn); + + if (isFirst) + { + firstSourceName = name; + isFirst = false; + } + } + + return true; +} + +static bool ReadFileLists(std::vector &inputs) +{ + for (std::size_t i = 0; i < inputs.size(); i++) + { + const auto &input = inputs[i]; + if (input[0] == '@') + { + auto f = + VSIVirtualHandleUniquePtr(VSIFOpenL(input.c_str() + 1, "r")); + if (!f) + { + CPLError(CE_Failure, CPLE_FileIO, "Cannot open %s", + input.c_str() + 1); + return false; + } + std::vector sources; + while (const char *filename = CPLReadLineL(f.get())) + { + sources.push_back(filename); + } + inputs.erase(inputs.begin() + static_cast(i)); + inputs.insert(inputs.end(), sources.begin(), sources.end()); + } + } + + return true; +} + +/** Creates a VRT datasource with one or more derived raster bands containing + * results of an expression. + * + * To make this work with muparser (which does not support vector types), we + * do a simple parsing of the expression internally, transforming it into + * multiple expressions with explicit band indices. For example, for a two-band + * raster "X", the expression "X + 3" will be transformed into "X[1] + 3" and + * "X[2] + 3". The use of brackets is for readability only; as far as the + * expression engine is concerned, the variables "X[1]" and "X[2]" have nothing + * to do with each other. + * + * @param inputs A list of sources, expressed as NAME=DSN + * @param expressions A list of expressions to be evaluated + * @param options flags controlling which checks should be performed on the inputs + * + * @return a newly created VRTDataset, or nullptr on error + */ +static std::unique_ptr +GDALCalcCreateVRTDerived(const std::vector &inputs, + const std::vector &expressions, + const GDALCalcOptions &options) +{ + if (inputs.empty()) + { + return nullptr; + } + + std::map sources; + std::string firstSource; + if (!ParseSourceDescriptors(inputs, sources, firstSource)) + { + return nullptr; + } + + // Use the first source provided to determine properties of the output + const char *firstDSN = sources[firstSource].c_str(); + + // Read properties from the first source + SourceProperties out; + { + std::unique_ptr ds( + GDALDataset::Open(firstDSN, GDAL_OF_RASTER)); + + if (!ds) + { + CPLError(CE_Failure, CPLE_AppDefined, "Failed to open %s", + firstDSN); + return nullptr; + } + + out.nX = ds->GetRasterXSize(); + out.nY = ds->GetRasterYSize(); + out.nBands = 1; + out.srs.reset(ds->GetSpatialRef() ? ds->GetSpatialRef()->Clone() + : nullptr); + ds->GetGeoTransform(out.gt.data()); + } + + CPLXMLNode *root = CPLCreateXMLNode(nullptr, CXT_Element, "VRTDataset"); + + // Collect properties of the different sources, and verity them for + // consistency. + std::map sourceProps; + for (const auto &[source_name, dsn] : sources) + { + // TODO avoid opening the first source twice. + auto props = UpdateSourceProperties(out, dsn, options); + if (props.has_value()) + { + sourceProps[source_name] = std::move(props.value()); + } + else + { + return nullptr; + } + } + + for (const auto &origExpression : expressions) + { + if (!CreateDerivedBandXML(root, out.nX, out.nY, origExpression, sources, + sourceProps)) + { + return nullptr; + } + } + + //CPLDebug("VRT", "%s", CPLSerializeXMLTree(root)); + + auto ds = std::make_unique(out.nX, out.nY); + if (ds->XMLInit(root, "") != CE_None) + { + return nullptr; + }; + ds->SetGeoTransform(out.gt.data()); + if (out.srs) + { + ds->SetSpatialRef(OGRSpatialReference::FromHandle(out.srs.get())); + } + + return ds; +} + +/************************************************************************/ +/* GDALRasterEditAlgorithm::GDALRasterEditAlgorithm() */ +/************************************************************************/ + +GDALRasterCalcAlgorithm::GDALRasterCalcAlgorithm() noexcept + : GDALAlgorithm(NAME, DESCRIPTION, HELP_URL) +{ + AddProgressArg(); + + AddArg(GDAL_ARG_NAME_INPUT, 'i', _("Input raster datasets"), &m_inputs) + .SetMinCount(1) + .SetAutoOpenDataset(false) + .SetMetaVar("INPUTS"); + + AddOutputFormatArg(&m_format); + AddOutputDatasetArg(&m_outputDataset, GDAL_OF_RASTER); + AddCreationOptionsArg(&m_creationOptions); + AddOverwriteArg(&m_overwrite); + + AddArg("no-check-srs", 0, + _("Do not check consistency of input spatial reference systems"), + &m_NoCheckSRS); + AddArg("no-check-extent", 0, _("Do not check consistency of input extents"), + &m_NoCheckExtent); + + AddArg("calc", 0, _("Expression(s) to evaluate"), &m_expr).SetMinCount(1); +} + +/************************************************************************/ +/* GDALRasterCalcAlgorithm::RunImpl() */ +/************************************************************************/ + +bool GDALRasterCalcAlgorithm::RunImpl(GDALProgressFunc pfnProgress, + void *pProgressData) +{ + if (m_outputDataset.GetDatasetRef()) + { + CPLError(CE_Failure, CPLE_NotSupported, + "gdal raster calc does not support outputting to an " + "already opened output dataset"); + return false; + } + + VSIStatBufL sStat; + if (!m_overwrite && !m_outputDataset.GetName().empty() && + (VSIStatL(m_outputDataset.GetName().c_str(), &sStat) == 0 || + std::unique_ptr( + GDALDataset::Open(m_outputDataset.GetName().c_str())))) + { + ReportError(CE_Failure, CPLE_AppDefined, + "File '%s' already exists. Specify the --overwrite " + "option to overwrite it.", + m_outputDataset.GetName().c_str()); + return false; + } + + GDALCalcOptions options; + options.checkExtent = !m_NoCheckExtent; + options.checkSRS = !m_NoCheckSRS; + + if (!ReadFileLists(m_inputs)) + { + return false; + } + + auto vrt = GDALCalcCreateVRTDerived(m_inputs, m_expr, options); + + if (vrt == nullptr) + { + return false; + } + + CPLStringList translateArgs; + if (!m_format.empty()) + { + translateArgs.AddString("-of"); + translateArgs.AddString(m_format.c_str()); + } + for (const auto &co : m_creationOptions) + { + translateArgs.AddString("-co"); + translateArgs.AddString(co.c_str()); + } + + GDALTranslateOptions *translateOptions = + GDALTranslateOptionsNew(translateArgs.List(), nullptr); + GDALTranslateOptionsSetProgress(translateOptions, pfnProgress, + pProgressData); + + auto poOutDS = + std::unique_ptr(GDALDataset::FromHandle(GDALTranslate( + m_outputDataset.GetName().c_str(), GDALDataset::ToHandle(vrt.get()), + translateOptions, nullptr))); + GDALTranslateOptionsFree(translateOptions); + + if (!poOutDS) + { + return false; + } + + m_outputDataset.Set(std::move(poOutDS)); + + return true; +} + +//! @endcond diff --git a/apps/gdalalg_raster_calc.h b/apps/gdalalg_raster_calc.h new file mode 100644 index 000000000000..d99a2b20feed --- /dev/null +++ b/apps/gdalalg_raster_calc.h @@ -0,0 +1,54 @@ +/****************************************************************************** + * + * Project: GDAL + * Purpose: "calc" step of "raster pipeline" + * Author: Daniel Baston + * + ****************************************************************************** + * Copyright (c) 2025, ISciences LLC + * + * SPDX-License-Identifier: MIT + ****************************************************************************/ + +#ifndef GDALALG_RASTER_CALC_INCLUDED +#define GDALALG_RASTER_CALC_INCLUDED + +#include "gdalalg_raster_pipeline.h" + +//! @cond Doxygen_Suppress + +/************************************************************************/ +/* GDALRasterCalcAlgorithm */ +/************************************************************************/ + +class GDALRasterCalcAlgorithm : public GDALAlgorithm +{ + public: + explicit GDALRasterCalcAlgorithm() noexcept; + + static constexpr const char *NAME = "calc"; + static constexpr const char *DESCRIPTION = "Perform raster algebra"; + static constexpr const char *HELP_URL = "/programs/gdal_raster_calc.html"; + + static std::vector GetAliases() + { + return {}; + } + + private: + bool RunImpl(GDALProgressFunc pfnProgress, void *pProgressData) override; + + std::vector m_inputs{}; + GDALArgDatasetValue m_dataset{}; + std::vector m_expr{}; + GDALArgDatasetValue m_outputDataset{}; + std::string m_format{}; + std::vector m_creationOptions{}; + bool m_overwrite{false}; + bool m_NoCheckSRS{false}; + bool m_NoCheckExtent{false}; +}; + +//! @endcond + +#endif /* GDALALG_RASTER_CALC_INCLUDED */ diff --git a/autotest/gdrivers/vrtderived.py b/autotest/gdrivers/vrtderived.py index 90e955a8f806..52220fc4a09d 100755 --- a/autotest/gdrivers/vrtderived.py +++ b/autotest/gdrivers/vrtderived.py @@ -1222,6 +1222,13 @@ def vrt_expression_xml(tmpdir, expression, dialect, sources): ["exprtk"], id="expression returns nodata", ), + pytest.param( + "ZB[1] + B[1]", + [("ZB[1]", 7), ("B[1]", 3)], + 10, + ["muparser"], + id="index substitution works correctly", + ), ], ) @pytest.mark.parametrize("dialect", ("exprtk", "muparser")) @@ -1274,8 +1281,8 @@ def test_vrt_pixelfn_expression( id="expression is too long", ), pytest.param( - "B[1]", - [("B[1]", 3)], + "B@1", + [("B@1", 3)], "muparser", "Invalid variable name", id="invalid variable name", diff --git a/autotest/utilities/test_gdalalg_raster_calc.py b/autotest/utilities/test_gdalalg_raster_calc.py new file mode 100755 index 000000000000..00ae5b37332b --- /dev/null +++ b/autotest/utilities/test_gdalalg_raster_calc.py @@ -0,0 +1,431 @@ +#!/usr/bin/env pytest +# -*- coding: utf-8 -*- +############################################################################### +# Project: GDAL/OGR Test Suite +# Purpose: 'gdal raster calc' testing +# Author: Daniel Baston +# +############################################################################### +# Copyright (c) 2025, ISciences LLC +# +# SPDX-License-Identifier: MIT +############################################################################### + +import re + +import gdaltest +import pytest + +from osgeo import gdal + +gdal.UseExceptions() + + +@pytest.fixture(scope="module", autouse=True) +def require_muparser(): + if not gdaltest.gdal_has_vrt_expression_dialect("muparser"): + pytest.skip("muparser not available") + + +@pytest.fixture() +def calc(): + reg = gdal.GetGlobalAlgorithmRegistry() + raster = reg.InstantiateAlg("raster") + return raster.InstantiateSubAlgorithm("calc") + + +@pytest.mark.parametrize("output_format", ("tif", "vrt")) +def test_gdalalg_raster_calc_basic_1(calc, tmp_vsimem, output_format): + + np = pytest.importorskip("numpy") + + infile = "../gcore/data/rgbsmall.tif" + outfile = tmp_vsimem / "out.tif" + + calc["input"] = [infile] + calc["output"] = outfile + calc["calc"] = ["2 + X / (1 + sum(X[1], X[2], X[3]))"] + + assert calc.Run() + + with gdal.Open(infile) as src, gdal.Open(outfile) as dst: + srcval = src.ReadAsArray().astype("float64") + expected = np.apply_along_axis(lambda x: 2 + x / (1 + x.sum()), 0, srcval) + + np.testing.assert_array_equal(expected, dst.ReadAsArray()) + assert src.GetGeoTransform() == dst.GetGeoTransform() + assert src.GetSpatialRef().IsSame(dst.GetSpatialRef()) + + +@pytest.mark.parametrize("output_format", ("tif", "vrt")) +def test_gdalalg_raster_calc_basic_2(calc, tmp_vsimem, output_format): + + np = pytest.importorskip("numpy") + + infile = "../gcore/data/byte.tif" + outfile = tmp_vsimem / "out.tif" + + calc["input"] = [infile] + calc["output"] = outfile + calc["calc"] = ["X > 128 ? X + 3 : nan"] + + assert calc.Run() + + with gdal.Open(infile) as src, gdal.Open(outfile) as dst: + srcval = src.ReadAsArray().astype("float64") + expected = np.where(srcval > 128, srcval + 3, float("nan")) + + np.testing.assert_array_equal(expected, dst.ReadAsArray()) + assert src.GetGeoTransform() == dst.GetGeoTransform() + assert src.GetSpatialRef().IsSame(dst.GetSpatialRef()) + + +def test_gdalalg_raster_calc_creation_options(calc, tmp_vsimem): + + infile = "../gcore/data/byte.tif" + outfile = tmp_vsimem / "out.tif" + + calc["input"] = [infile] + calc["output"] = outfile + calc["creation-option"] = ["COMPRESS=LZW"] + calc["calc"] = ["X[1] + 3"] + + assert calc.Run() + + with gdal.Open(outfile) as dst: + assert dst.GetMetadata("IMAGE_STRUCTURE")["COMPRESSION"] == "LZW" + + +def test_gdalalg_raster_calc_output_format(calc, tmp_vsimem): + + infile = "../gcore/data/byte.tif" + outfile = tmp_vsimem / "out.unknown" + + calc["input"] = [infile] + calc["output"] = outfile + calc["output-format"] = "GTiff" + calc["calc"] = ["X + 3"] + + assert calc.Run() + + with gdal.Open(outfile) as dst: + assert dst.GetDriver().GetName() == "GTiff" + + +def test_gdalalg_raster_calc_overwrite(calc, tmp_vsimem): + + infile = "../gcore/data/byte.tif" + outfile = tmp_vsimem / "out.tif" + + gdal.CopyFile(infile, outfile) + + calc["input"] = [infile] + calc["output"] = outfile + calc["calc"] = ["X + 3"] + + with pytest.raises(Exception, match="already exists"): + assert not calc.Run() + + calc["overwrite"] = True + + assert calc.Run() + + +@pytest.mark.parametrize("expr", ("X + 3", "X[1] + 3")) +def test_gdalalg_raster_calc_basic_named_source(calc, tmp_vsimem, expr): + + np = pytest.importorskip("numpy") + + infile = "../gcore/data/byte.tif" + outfile = tmp_vsimem / "out.tif" + + calc["input"] = [f"X={infile}"] + calc["output"] = outfile + calc["calc"] = [expr] + + assert calc.Run() + + with gdal.Open(infile) as src, gdal.Open(outfile) as dst: + np.testing.assert_array_equal(src.ReadAsArray() + 3.0, dst.ReadAsArray()) + + +def test_gdalalg_raster_calc_multiple_calcs(calc, tmp_vsimem): + + np = pytest.importorskip("numpy") + + infile = "../gcore/data/byte.tif" + outfile = tmp_vsimem / "out.tif" + + calc["input"] = [infile] + calc["output"] = outfile + calc["calc"] = ["X + 3", "sqrt(X)"] + + assert calc.Run() + + with gdal.Open(infile) as src, gdal.Open(outfile) as dst: + src_dat = src.ReadAsArray() + dst_dat = dst.ReadAsArray() + + np.testing.assert_array_equal(src_dat + 3.0, dst_dat[0]) + np.testing.assert_array_equal(np.sqrt(src_dat.astype(np.float64)), dst_dat[1]) + + +@pytest.mark.parametrize( + "expr", + ( + "(A+B) / (A - B + 3)", + "A[2] + B", + ), +) +def test_gdalalg_raster_calc_multiple_inputs(calc, tmp_vsimem, expr): + + np = pytest.importorskip("numpy") + + # convert 1-based indices to 0-based indices to evaluate expression + # with numpy + numpy_expr = expr + for match in re.findall(r"(?<=\[)\d+(?=])", expr): + numpy_expr = re.sub(match, str(int(match) - 1), expr, count=1) + + nx = 3 + ny = 5 + nz = 2 + + input_1 = tmp_vsimem / "in1.tif" + input_2 = tmp_vsimem / "in2.tif" + outfile = tmp_vsimem / "out.tif" + + A = np.arange(nx * ny * nz, dtype=np.float32).reshape(nz, ny, nx) + B = np.sqrt(A) + + with gdal.GetDriverByName("GTiff").Create( + input_1, nx, ny, nz, eType=gdal.GDT_Float32 + ) as ds: + ds.WriteArray(A) + + with gdal.GetDriverByName("GTiff").Create( + input_2, nx, ny, nz, eType=gdal.GDT_Float32 + ) as ds: + ds.WriteArray(B) + + calc["input"] = [f"A={input_1}", f"B={input_2}"] + calc["output"] = outfile + calc["calc"] = [expr] + + assert calc.Run() + + with gdal.Open(outfile) as dst: + dat = dst.ReadAsArray() + np.testing.assert_allclose(dat, eval(numpy_expr), rtol=1e-6) + + +def test_gdalalg_raster_calc_inputs_from_file(calc, tmp_vsimem, tmp_path): + + np = pytest.importorskip("numpy") + + input_1 = tmp_vsimem / "in1.tif" + input_2 = tmp_vsimem / "in2.tif" + input_txt = tmp_path / "inputs.txt" + outfile = tmp_vsimem / "out.tif" + + with gdal.GetDriverByName("GTiff").Create(input_1, 2, 2) as ds: + ds.GetRasterBand(1).Fill(1) + + with gdal.GetDriverByName("GTIff").Create(input_2, 2, 2) as ds: + ds.GetRasterBand(1).Fill(2) + + with gdal.VSIFile(input_txt, "w") as txtfile: + txtfile.write(f"A={input_1}\n") + txtfile.write(f"B={input_2}\n") + + calc["input"] = [f"@{input_txt}"] + calc["output"] = outfile + calc["calc"] = ["A + B"] + + assert calc.Run() + + with gdal.Open(outfile) as dst: + assert np.all(dst.ReadAsArray() == 3) + + +def test_gdalalg_raster_calc_different_band_counts(calc, tmp_vsimem): + + np = pytest.importorskip("numpy") + + input_1 = tmp_vsimem / "in1.tif" + input_2 = tmp_vsimem / "in2.tif" + outfile = tmp_vsimem / "out.tif" + + with gdal.GetDriverByName("GTiff").Create(input_1, 2, 2, 2) as ds: + ds.GetRasterBand(1).Fill(1) + ds.GetRasterBand(2).Fill(2) + + with gdal.GetDriverByName("GTIff").Create(input_2, 2, 2, 3) as ds: + ds.GetRasterBand(1).Fill(3) + ds.GetRasterBand(2).Fill(4) + ds.GetRasterBand(3).Fill(5) + + calc["input"] = [f"A={input_1}", f"B={input_2}"] + calc["output"] = outfile + calc["calc"] = ["A[1] + A[2] + B[1] + B[2] + B[3]"] + + assert calc.Run() + + with gdal.Open(outfile) as dst: + assert np.all(dst.ReadAsArray() == (1 + 2 + 3 + 4 + 5)) + + +def test_gdalalg_calc_different_resolutions(calc, tmp_vsimem): + + np = pytest.importorskip("numpy") + + xmax = 60 + ymax = 60 + resolutions = [10, 20, 60] + + inputs = [tmp_vsimem / f"in_{i}.tif" for i in range(len(resolutions))] + outfile = tmp_vsimem / "out.tif" + + for res, fname in zip(resolutions, inputs): + with gdal.GetDriverByName("GTiff").Create( + fname, int(xmax / res), int(ymax / res), 1 + ) as ds: + ds.GetRasterBand(1).Fill(res) + ds.SetGeoTransform((0, res, 0, ymax, 0, -res)) + + calc["input"] = [f"A={inputs[0]}", f"B={inputs[1]}", f"C={inputs[2]}"] + calc["calc"] = ["A + B + C"] + calc["output"] = outfile + + calc["no-check-extent"] = True + with pytest.raises(Exception, match="Inputs do not have the same dimensions"): + calc.Run() + calc["no-check-extent"] = False + + assert calc.Run() + + with gdal.Open(outfile) as ds: + assert ds.RasterXSize == xmax / min(resolutions) + assert ds.RasterYSize == ymax / min(resolutions) + + assert np.all(ds.ReadAsArray() == sum(resolutions)) + + +def test_gdalalg_raster_calc_error_extent_mismatch(calc, tmp_vsimem): + + input_1 = tmp_vsimem / "in1.tif" + input_2 = tmp_vsimem / "in2.tif" + outfile = tmp_vsimem / "out.tif" + + with gdal.GetDriverByName("GTiff").Create(input_1, 2, 2) as ds: + ds.SetGeoTransform((0, 1, 0, 2, 0, -1)) + with gdal.GetDriverByName("GTIff").Create(input_2, 2, 2) as ds: + ds.SetGeoTransform((0, 2, 0, 4, 0, -2)) + + calc["input"] = [f"A={input_1}", f"B={input_2}"] + calc["output"] = outfile + calc["calc"] = ["A+B"] + + with pytest.raises(Exception, match="extents are inconsistent"): + calc.Run() + + calc["no-check-extent"] = True + assert calc.Run() + + with gdal.Open(input_1) as src, gdal.Open(outfile) as dst: + assert src.GetGeoTransform() == dst.GetGeoTransform() + + +def test_gdalalg_raster_calc_error_crs_mismatch(calc, tmp_vsimem): + + input_1 = tmp_vsimem / "in1.tif" + input_2 = tmp_vsimem / "in2.tif" + outfile = tmp_vsimem / "out.tif" + + with gdal.GetDriverByName("GTiff").Create(input_1, 2, 2) as ds: + ds.SetProjection("EPSG:4326") + with gdal.GetDriverByName("GTIff").Create(input_2, 2, 2) as ds: + ds.SetProjection("EPSG:4269") + + calc["input"] = [f"B={input_1}", f"A={input_2}"] + calc["output"] = outfile + calc["calc"] = ["A+B"] + + with pytest.raises(Exception, match="spatial reference systems are inconsistent"): + calc.Run() + + calc["no-check-srs"] = True + assert calc.Run() + + with gdal.Open(input_1) as src, gdal.Open(outfile) as dst: + assert src.GetSpatialRef().IsSame(dst.GetSpatialRef()) + + +@pytest.mark.parametrize("bands", [(2, 3), (2, 4)]) +def test_gdalalg_raster_calc_error_band_count_mismatch(calc, tmp_vsimem, bands): + + input_1 = tmp_vsimem / "in1.tif" + input_2 = tmp_vsimem / "in2.tif" + outfile = tmp_vsimem / "out.tif" + + gdal.GetDriverByName("GTiff").Create(input_1, 2, 2, bands[0]) + gdal.GetDriverByName("GTIff").Create(input_2, 2, 2, bands[1]) + + calc["input"] = [f"A={input_1}", f"B={input_2}"] + calc["output"] = outfile + calc["calc"] = ["A+B"] + + with pytest.raises(Exception, match="incompatible numbers of bands"): + calc.Run() + + calc["calc"] = ["A+B[1]"] + assert calc.Run() + + +@pytest.mark.parametrize( + "expr,source,bands,expected", + [ + ("aX + 2", "aX", 1, ["aX[1] + 2"]), + ("aX + 2", "aX", 2, ["aX[1] + 2", "aX[2] + 2"]), + ("aX + 2", "X", 1, ["aX + 2"]), + ("aX + 2", "a", 1, ["aX + 2"]), + ("2 + aX", "X", 1, ["2 + aX"]), + ("2 + aX", "aX", 1, ["2 + aX[1]"]), + ("B1 + B10", "B1", 1, ["B1[1] + B10"]), + ("B1[1] + B10", "B1", 2, ["B1[1] + B10"]), + ("B1[1] + B1", "B1", 2, ["B1[1] + B1[1]", "B1[1] + B1[2]"]), + ("SIN(N) + N", "N", 1, ["SIN(N[1]) + N[1]"]), + ("SUM(N,N2) + N", "N", 1, ["SUM(N[1],N2) + N[1]"]), + ("SUM(N,N2) + N", "N2", 1, ["SUM(N,N2[1]) + N"]), + ("A_X + X", "X", 1, ["A_X + X[1]"]), + ], +) +def test_gdalalg_raster_calc_expression_rewriting( + calc, tmp_vsimem, expr, source, bands, expected +): + # The expression rewriting isn't exposed to Python, so we + # create an VRT with an expression and a single source, and + # inspect the transformed expression in the VRT XML. + # The transformed expression need not be valid, because we + # don't actually read the VRT in GDAL. + + import xml.etree.ElementTree as ET + + outfile = tmp_vsimem / "out.vrt" + infile = tmp_vsimem / "input.tif" + + gdal.GetDriverByName("GTiff").Create(infile, 2, 2, bands) + + calc["input"] = [f"{source}={infile}"] + calc["output"] = outfile + calc["calc"] = [expr] + + assert calc.Run() + + with gdal.VSIFile(outfile, "r") as f: + root = ET.fromstring(f.read()) + + expr = [ + node.attrib["expression"] for node in root.findall(".//PixelFunctionArguments") + ] + assert expr == expected diff --git a/doc/source/programs/gdal_raster.rst b/doc/source/programs/gdal_raster.rst index 4f6190896d98..bcc331dce415 100644 --- a/doc/source/programs/gdal_raster.rst +++ b/doc/source/programs/gdal_raster.rst @@ -19,6 +19,7 @@ Synopsis Usage: gdal raster where is one of: + - calc: Perform raster algebra. - clip: Clip a raster dataset. - convert: Convert a raster dataset. - edit: Edit a raster dataset. @@ -34,6 +35,7 @@ Available sub-commands ---------------------- - :ref:`gdal_raster_info_subcommand` +- :ref:`gdal_raster_calc_subcommand` - :ref:`gdal_raster_clip_subcommand` - :ref:`gdal_raster_convert_subcommand` - :ref:`gdal_raster_mosaic_subcommand` diff --git a/doc/source/programs/gdal_raster_calc.rst b/doc/source/programs/gdal_raster_calc.rst new file mode 100644 index 000000000000..0d1b8c6092cb --- /dev/null +++ b/doc/source/programs/gdal_raster_calc.rst @@ -0,0 +1,104 @@ +.. _gdal_raster_calc_subcommand: + +================================================================================ +"gdal raster calc" sub-command +================================================================================ + +.. versionadded:: 3.11 + +.. only:: html + + Perform raster algebra + +.. Index:: gdal raster calc + +Synopsis +-------- + +.. program-output:: gdal raster calc --help + :ellipsis: -5 + + +Description +----------- + +:program:`gdal raster calc` performs pixel-wise calculations on one or more input GDAL datasets. Calculations +can be performed eagerly, writing results to a conventional raster format, +or lazily, written as a set of derived bands in a :ref:`VRT (Virtual Dataset) `. + +The list of input GDAL datasets can be specified at the end +of the command line or put in a text file (one input per line) for very long lists. +If more than one input dataset is used, it should be prefixed with a name by which it +will be referenced in the calculation, e.g. ``A=my_data.tif``. (If a single dataset is +used, it will be referred to with the variable ``X`` in formulas.) + +The inputs should have the same spatial reference system and should cover the same spatial extent but are not required to have the same +spatial resolution. The spatial extent check can be disabled with :option:`--no-check-extent`, +in which case the inputs must have the same dimensions. The spatial reference system check can be +disabled with :option:`--no-check-srs`. + +The following options are available: + +.. include:: gdal_options/of_raster_create_copy.rst + +.. include:: gdal_options/co.rst + +.. include:: gdal_options/overwrite.rst + +.. option:: -i [=] + + Select an input dataset to be processed. If more than one input dataset is provided, + each dataset must be prefixed with a name to which it will will be referenced in :option:`--calc`. + +.. option:: --calc + + An expression to be evaluated using the `muparser `__ math parser library. + The expression may refer to individual bands of each input (e.g., ``X[1] + 3``) or it may be applied to all bands + of an input (``X + 3``). If the expression contains a reference to all bands of multiple inputs, those inputs + must either have the same the number of bands, or a single band. + + For example, if inputs ``A`` and ``B`` each have three bands, and input ``C`` has a single band, then the argument + ``--calc "A + B + C"`` is equivalent to ``--calc "A[1] + B[1] + C[1]" --calc "A[2] + B[2] + C[1]" --calc "A[3] + B[3] + C[1]"``. + + Multiple calculations may be specified; output band(s) will be produced for each expression in the order they + are provided. + + Input rasters will be converted to 64-bit floating point numbers before performing calculations. + +.. option:: --no-check-extent + + Do not verify that the input rasters have the same spatial extent. The input rasters will instead be required to + have the same dimensions. The geotransform of the first input will be assigned to the output. + +.. option:: --no-check-srs + + Do not check the spatial reference systems of the inputs for consistency. All inputs will be assumed to have the + spatial reference system of the first input, and this spatial reference system will be used for the output. + +Examples +-------- + +.. example:: + :title: Per-band sum of three files + :id: simple-sum + + .. code-block:: bash + + gdal raster calc -i "A=file1.tif" -i "B=file2.tif" -i "C=file3.tif" --calc "A+B+C" -o out.tif + + +.. example:: + :title: Per-band maximum of three files + :id: simple-max + + .. code-block:: bash + + gdal raster calc -i "A=file1.tif" -i "B=file2.tif" -i "C=file3.tif" --calc "max(A,B,C)" -o out.tif + + +.. example:: + :title: Setting values of zero and below to NaN + + .. code-block:: bash + + gdal_calc -i "A=input.tif" -o=result.tif --calc="A > 0 ? A : NaN" diff --git a/doc/source/programs/index.rst b/doc/source/programs/index.rst index 23afc8cb9aab..7a7927d3a6cb 100644 --- a/doc/source/programs/index.rst +++ b/doc/source/programs/index.rst @@ -32,6 +32,7 @@ single :program:`gdal` program that accepts commands and subcommands. gdal_convert gdal_raster gdal_raster_info + gdal_raster_calc gdal_raster_clip gdal_raster_convert gdal_raster_edit @@ -60,6 +61,7 @@ single :program:`gdal` program that accepts commands and subcommands. - :ref:`gdal_convert_command`: Convert a dataset - :ref:`gdal_raster_command`: Entry point for raster commands - :ref:`gdal_raster_info_subcommand`: Get information on a raster dataset + - :ref:`gdal_raster_calc_subcommand`: Perform raster algebra - :ref:`gdal_raster_clip_subcommand`: Clip a raster dataset - :ref:`gdal_raster_convert_subcommand`: Convert a raster dataset - :ref:`gdal_raster_edit_subcommand`: Edit in place a raster dataset diff --git a/frmts/vrt/vrtexpression_muparser.cpp b/frmts/vrt/vrtexpression_muparser.cpp index d9e747302c21..3910db4573dd 100644 --- a/frmts/vrt/vrtexpression_muparser.cpp +++ b/frmts/vrt/vrtexpression_muparser.cpp @@ -13,19 +13,65 @@ #include "vrtexpression.h" #include "cpl_string.h" +#include #include +#include #include namespace gdal { /*! @cond Doxygen_Suppress */ + +static std::optional Sanitize(const std::string &osVariable) +{ + // muparser does not allow characters '[' or ']' which we use to emulate + // vectors. Replace these with a combination of underscores + auto from = osVariable.find('['); + if (from != std::string::npos) + { + auto to = osVariable.find(']'); + if (to != std::string::npos) + { + auto sanitized = std::string("__") + osVariable.substr(0, from) + + +"__" + + osVariable.substr(from + 1, to - from - 1) + "__"; + return sanitized; + } + } + + return std::nullopt; +} + +static void ReplaceVariable(std::string &expression, + const std::string &variable, + const std::string &sanitized) +{ + std::string::size_type seekPos = 0; + auto pos = expression.find(variable, seekPos); + while (pos != std::string::npos) + { + auto end = pos + variable.size(); + + if (pos == 0 || + (!std::isalnum(expression[pos - 1]) && expression[pos - 1] != '_')) + { + expression = + expression.substr(0, pos) + sanitized + expression.substr(end); + } + + seekPos = end; + pos = expression.find(variable, seekPos); + } +} + class MuParserExpression::Impl { public: explicit Impl(std::string_view osExpression) - : m_osExpression(std::string(osExpression)), m_oVectors{}, m_oParser{}, - m_adfResults{1}, m_bIsCompiled{false}, m_bCompileFailed{false} + : m_osExpression(std::string(osExpression)), m_oSubstitutions{}, + m_oParser{}, m_adfResults{1}, m_bIsCompiled{false}, m_bCompileFailed{ + false} { } @@ -50,13 +96,26 @@ class MuParserExpression::Impl return CE_Failure; } + // On some platforms muparser does not seem to parse "nan" as a floating + // point literal. try { - CPLString tmpExpression(m_osExpression); + m_oParser.DefineConst("nan", + std::numeric_limits::quiet_NaN()); + m_oParser.DefineConst("NaN", + std::numeric_limits::quiet_NaN()); + } + catch (const mu::Parser::exception_type &) + { + } + + try + { + std::string tmpExpression(m_osExpression); - for (const auto &[osVec, osElems] : m_oVectors) + for (const auto &[osFrom, osTo] : m_oSubstitutions) { - tmpExpression.replaceAll(osVec, osElems); + ReplaceVariable(tmpExpression, osFrom, osTo); } m_oParser.SetExpr(tmpExpression); @@ -108,8 +167,8 @@ class MuParserExpression::Impl return CE_None; } - CPLString m_osExpression; - std::map m_oVectors; + const CPLString m_osExpression; + std::map m_oSubstitutions; mu::Parser m_oParser; std::vector m_adfResults; bool m_bIsCompiled; @@ -134,7 +193,12 @@ CPLErr MuParserExpression::Compile() void MuParserExpression::RegisterVariable(std::string_view osVariable, double *pdfValue) { - m_pImpl->Register(osVariable, pdfValue); + auto sanitized = Sanitize(std::string(osVariable)); + if (sanitized.has_value()) + { + m_pImpl->m_oSubstitutions[std::string(osVariable)] = sanitized.value(); + } + m_pImpl->Register(sanitized.value_or(std::string(osVariable)), pdfValue); } void MuParserExpression::RegisterVector(std::string_view osVariable, @@ -155,8 +219,9 @@ void MuParserExpression::RegisterVector(std::string_view osVariable, for (std::size_t i = 0; i < padfValues->size(); i++) { - osElementVarName.Printf("__%s_%d", osVectorVarName.c_str(), + osElementVarName.Printf("%s[%d]", osVectorVarName.c_str(), static_cast(i)); + osElementVarName = Sanitize(osElementVarName).value(); RegisterVariable(osElementVarName, padfValues->data() + i); if (i > 0) @@ -166,7 +231,7 @@ void MuParserExpression::RegisterVector(std::string_view osVariable, osElementsList += osElementVarName; } - m_pImpl->m_oVectors[std::string(osVariable)] = osElementsList; + m_pImpl->m_oSubstitutions[std::string(osVariable)] = osElementsList; } CPLErr MuParserExpression::Evaluate() diff --git a/frmts/vrt/vrtsources.cpp b/frmts/vrt/vrtsources.cpp index c16b085b950e..64bd38932d5f 100644 --- a/frmts/vrt/vrtsources.cpp +++ b/frmts/vrt/vrtsources.cpp @@ -489,6 +489,11 @@ CPLXMLNode *VRTSimpleSource::SerializeToXML(const char *pszVRTPath) CXT_Text, m_osResampling.c_str()); } + if (!m_osName.empty()) + { + CPLAddXMLAttributeAndValue(psSrc, "name", m_osName); + } + if (m_bSrcDSNameFromVRT) { CPLAddXMLChild(psSrc, CPLParseXMLString(m_osSrcDSName.c_str())); diff --git a/gcore/gdalalgorithm.h b/gcore/gdalalgorithm.h index 2853d23bb350..9a0788c9822d 100644 --- a/gcore/gdalalgorithm.h +++ b/gcore/gdalalgorithm.h @@ -423,7 +423,7 @@ class CPL_DLL GDALArgDatasetValue final /** Set dataset name */ void Set(const std::string &name); - /** Transfer dataset to this instance (does not affect is reference + /** Transfer dataset to this instance (does not affect its reference * counter). */ void Set(std::unique_ptr poDS); diff --git a/swig/include/python/gdal_python.i b/swig/include/python/gdal_python.i index db9a0ac37dcb..60d628c62723 100644 --- a/swig/include/python/gdal_python.i +++ b/swig/include/python/gdal_python.i @@ -5365,6 +5365,10 @@ class VSIFile(BytesIO): raise Exception("Unhandled algorithm argument data type") def Set(self, value): + import os + if isinstance(value, os.PathLike): + value = str(value) + type = self.GetType() if type == GAAT_BOOLEAN: return self.SetAsBoolean(value)