diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..0b9a4028 --- /dev/null +++ b/.clang-format @@ -0,0 +1,134 @@ +BasedOnStyle: Mozilla +Language: Cpp +AccessModifierOffset: -2 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: true +AlignEscapedNewlines: Left +AlignOperands: true +AlignTrailingComments: true +# clang 9.0 AllowAllArgumentsOnNextLine: true +# clang 9.0 AllowAllConstructorInitializersOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Inline +# clang 9.0 AllowShortLambdasOnASingleLine: All +# clang 9.0 features AllowShortIfStatementsOnASingleLine: Never +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: All +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: Yes +BinPackArguments: false +BinPackParameters: false +BreakBeforeBraces: Custom +BraceWrapping: + # clang 9.0 feature AfterCaseLabel: false + AfterClass: true + AfterControlStatement: true + AfterEnum: true + AfterFunction: true + AfterNamespace: true + AfterObjCDeclaration: true + AfterStruct: true + AfterUnion: true + AfterExternBlock: true + BeforeCatch: true + BeforeElse: true +## This is the big change from historical ITK formatting! +# Historically ITK used a style similar to https://en.wikipedia.org/wiki/Indentation_style#Whitesmiths_style +# with indented braces, and not indented code. This style is very difficult to automatically +# maintain with code beautification tools. Not indenting braces is more common among +# formatting tools. + IndentBraces: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +BreakBeforeBinaryOperators: None +#clang 6.0 BreakBeforeInheritanceComma: true +BreakInheritanceList: BeforeComma +BreakBeforeTernaryOperators: true +#clang 6.0 BreakConstructorInitializersBeforeComma: true +BreakConstructorInitializers: BeforeComma +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +## The following line allows larger lines in non-documentation code +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: false +ConstructorInitializerIndentWidth: 2 +ContinuationIndentWidth: 2 +Cpp11BracedListStyle: false +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '^"(llvm|llvm-c|clang|clang-c)/' + Priority: 2 + - Regex: '^(<|"(gtest|gmock|isl|json)/)' + Priority: 3 + - Regex: '.*' + Priority: 1 +IncludeIsMainRegex: '(Test)?$' +IndentCaseLabels: true +IndentPPDirectives: AfterHash +IndentWidth: 2 +IndentWrappedFunctionNames: false +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: true +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 2 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: true +ObjCSpaceBeforeProtocolList: false +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +## The following line allows larger lines in non-documentation code +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Middle +ReflowComments: true +# We may want to sort the includes as a separate pass +SortIncludes: false +# We may want to revisit this later +SortUsingDeclarations: false +SpaceAfterCStyleCast: false +# SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: false +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION +TabWidth: 2 +UseTab: Never +... diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..4fb66285 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,8 @@ +*.png filter=lfs diff=lfs merge=lfs -text +*.jp2 filter=lfs diff=lfs merge=lfs -text +*.tif filter=lfs diff=lfs merge=lfs -text +*.data-* filter=lfs diff=lfs merge=lfs -text +*.gpkg filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.jpeg filter=lfs diff=lfs merge=lfs -text diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100644 index 00000000..9a9bd290 --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,119 @@ +image: gitlab-registry.irstea.fr/remi.cresson/otbtf:2.4-cpu-basic-testing + +variables: + OTB_BUILD: /src/otb/build/OTB/build # Local OTB build directory + OTBTF_SRC: /src/otbtf # Local OTBTF source directory + OTB_TEST_DIR: $OTB_BUILD/Testing/Temporary # OTB testing directory + ARTIFACT_TEST_DIR: $CI_PROJECT_DIR/testing + CRC_BOOK_TMP: /tmp/crc_book_tests_tmp + +workflow: + rules: + - if: $CI_MERGE_REQUEST_ID # Execute jobs in merge request context + - if: $CI_COMMIT_BRANCH == 'develop' # Execute jobs when a new commit is pushed to develop branch + +stages: + - Build + - Static Analysis + - Test + - Applications Test + +.update_otbtf_src: &update_otbtf_src + - sudo rm -rf $OTBTF_SRC && sudo ln -s $PWD $OTBTF_SRC # Replace local OTBTF source directory + +.compile_otbtf: &compile_otbtf + - cd $OTB_BUILD && sudo make install -j$(nproc --all) # Rebuild OTB with new OTBTF sources + +.install_pytest: &install_pytest + - pip3 install pytest pytest-cov pytest-order # Install pytest stuff + +before_script: + - *update_otbtf_src + +build: + stage: Build + allow_failure: false + script: + - *compile_otbtf + +flake8: + stage: Static Analysis + allow_failure: true + script: + - sudo apt update && sudo apt install flake8 -y + - python -m flake8 --max-line-length=120 $OTBTF_SRC/python + +pylint: + stage: Static Analysis + allow_failure: true + script: + - sudo apt update && sudo apt install pylint -y + - pylint --disable=too-many-nested-blocks,too-many-locals,too-many-statements,too-few-public-methods,too-many-instance-attributes,too-many-arguments --ignored-modules=tensorflow --max-line-length=120 --logging-format-style=new $OTBTF_SRC/python + +codespell: + stage: Static Analysis + allow_failure: true + script: + - sudo pip install codespell && codespell + +cppcheck: + stage: Static Analysis + allow_failure: true + script: + - sudo apt update && sudo apt install cppcheck -y + - cd $OTBTF_SRC/ && cppcheck --enable=all --error-exitcode=1 -I include/ --suppress=missingInclude --suppress=unusedFunction . + +ctest: + stage: Test + script: + - *compile_otbtf + - sudo rm -rf $OTB_TEST_DIR/* # Empty testing temporary folder (old files here) + - cd $OTB_BUILD/ && sudo ctest -L OTBTensorflow # Run ctest + after_script: + - cp -r $OTB_TEST_DIR $ARTIFACT_TEST_DIR + artifacts: + paths: + - $ARTIFACT_TEST_DIR/*.* + expire_in: 1 week + when: on_failure + +.applications_test_base: + stage: Applications Test + rules: + # Only for MR targeting 'develop' branch because applications tests are slow + - if: $CI_MERGE_REQUEST_ID && $CI_MERGE_REQUEST_TARGET_BRANCH_NAME == 'develop' + artifacts: + when: on_failure + paths: + - $CI_PROJECT_DIR/report_*.xml + - $ARTIFACT_TEST_DIR/*.* + expire_in: 1 week + +crc_book: + extends: .applications_test_base + script: + - *compile_otbtf + - *install_pytest + - cd $CI_PROJECT_DIR + - mkdir -p $CRC_BOOK_TMP + - TMPDIR=$CRC_BOOK_TMP DATADIR=$CI_PROJECT_DIR/test/data python -m pytest --junitxml=$CI_PROJECT_DIR/report_tutorial.xml $OTBTF_SRC/test/tutorial_unittest.py + after_script: + - mkdir -p $ARTIFACT_TEST_DIR + - cp $CRC_BOOK_TMP/*.* $ARTIFACT_TEST_DIR/ + +sr4rs: + extends: .applications_test_base + script: + - *compile_otbtf + - *install_pytest + - cd $CI_PROJECT_DIR + - wget -O sr4rs_sentinel2_bands4328_france2020_savedmodel.zip + https://nextcloud.inrae.fr/s/boabW9yCjdpLPGX/download/sr4rs_sentinel2_bands4328_france2020_savedmodel.zip + - unzip -o sr4rs_sentinel2_bands4328_france2020_savedmodel.zip + - wget -O sr4rs_data.zip https://nextcloud.inrae.fr/s/qMLLyKCDieqmgWz/download + - unzip -o sr4rs_data.zip + - rm -rf sr4rs + - git clone https://github.com/remicres/sr4rs.git + - export PYTHONPATH=$PYTHONPATH:$PWD/sr4rs + - python -m pytest --junitxml=$CI_PROJECT_DIR/report_sr4rs.xml $OTBTF_SRC/test/sr4rs_unittest.py + diff --git a/CMakeLists.txt b/CMakeLists.txt index f0bc92a6..0a4d646a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,7 +15,7 @@ if(OTB_USE_TENSORFLOW) find_library(TENSORFLOW_FRAMEWORK_LIB NAMES libtensorflow_framework) set(TENSORFLOW_LIBS "${TENSORFLOW_CC_LIB}" "${TENSORFLOW_FRAMEWORK_LIB}") - + set(OTBTensorflow_THIRD_PARTY "this is a hack to skip header_tests") else() message("Tensorflow support disabled") endif() diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md new file mode 100644 index 00000000..cd39540e --- /dev/null +++ b/CONTRIBUTORS.md @@ -0,0 +1,8 @@ +- Remi Cresson +- Nicolas Narcon +- Benjamin Commandre +- Vincent Delbar +- Loic Lozac'h +- Pratyush Das +- Doctor Who +- Jordi Inglada diff --git a/Dockerfile b/Dockerfile index 53e3f777..c9aa6e9f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -47,7 +47,7 @@ RUN wget -qO /opt/otbtf/bin/bazelisk https://github.com/bazelbuild/bazelisk/rele && ln -s /opt/otbtf/bin/bazelisk /opt/otbtf/bin/bazel ARG BZL_TARGETS="//tensorflow:libtensorflow_cc.so //tensorflow/tools/pip_package:build_pip_package" -# "--config=opt" will enable 'march=native' (otherwise read comments about CPU compatibilty and edit CC_OPT_FLAGS in build-env-tf.sh) +# "--config=opt" will enable 'march=native' (otherwise read comments about CPU compatibility and edit CC_OPT_FLAGS in build-env-tf.sh) ARG BZL_CONFIGS="--config=nogcp --config=noaws --config=nohdfs --config=opt" # "--compilation_mode opt" is already enabled by default (see tf repo .bazelrc and configure.py) ARG BZL_OPTIONS="--verbose_failures --remote_cache=http://localhost:9090" @@ -77,7 +77,7 @@ RUN git clone --single-branch -b $TF https://github.com/tensorflow/tensorflow.gi # Symlink external libs (required for MKL - libiomp5) && for f in $(find -L /opt/otbtf/include/tf -wholename "*/external/*/*.so"); do ln -s $f /opt/otbtf/lib/; done \ # Compress and save TF binaries - && ( ! $ZIP_TF_BIN || zip -9 -j --symlinks /opt/otbtf/tf-$TF.zip tensorflow/cc/saved_model/tag_constants.h bazel-bin/tensorflow/libtensorflow_cc.so* /tmp/tensorflow_pkg/tensorflow*.whl ) \ + && ( ! $ZIP_TF_BIN || zip -9 -j --symlinks /opt/otbtf/tf-$TF.zip tensorflow/cc/saved_model/tag_constants.h tensorflow/cc/saved_model/signature_constants.h bazel-bin/tensorflow/libtensorflow_cc.so* /tmp/tensorflow_pkg/tensorflow*.whl ) \ # Cleaning && rm -rf bazel-* /src/tf /root/.cache/ /tmp/* diff --git a/LICENSE b/LICENSE index a1a86c14..236ca028 100644 --- a/LICENSE +++ b/LICENSE @@ -188,7 +188,7 @@ identification within third-party archives. Copyright 2018-2019 Rémi Cresson (IRSTEA) - Copyright 2020 Rémi Cresson (INRAE) + Copyright 2020-2021 Rémi Cresson (INRAE) Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md index a18961b3..d29b7a64 100644 --- a/README.md +++ b/README.md @@ -18,21 +18,23 @@ Applications can be used to build OTB pipelines from Python or C++ APIs. ### Python -This is a work in progress. For now, `tricks.py` provides a set of helpers to build deep nets, and `otbtf.py` provides datasets which can be used in Tensorflow pipelines to train networks from python. +`otbtf.py` targets python developers that want to train their own model from python with TensorFlow or Keras. +It provides various classes for datasets and iterators to handle the _patches images_ generated from the `PatchesExtraction` OTB application. +For instance, the `otbtf.Dataset` class provides a method `get_tf_dataset()` which returns a `tf.dataset` that can be used in your favorite TensorFlow pipelines, or convert your patches into TFRecords. -## Portfolio +`tricks.py` is here for backward compatibility with codes based on OTBTF 1.x and 2.x. -Below are some screen captures of deep learning applications performed at large scale with OTBTF. - - Image to image translation (Spot-7 image --> Wikimedia Map using CGAN) - +## Examples +Below are some screen captures of deep learning applications performed at large scale with OTBTF. - Landcover mapping (Spot-7 images --> Building map using semantic segmentation) - - Image enhancement (Enhancement of Sentinel-2 images at 1.5m using SRGAN) + - Super resolution (Sentinel-2 images upsampled with the [SR4RS software](https://github.com/remicres/sr4rs), which is based on OTBTF) -You can read more details about these applications on [this blog](https://mdl4eo.irstea.fr/2019/) + - Image to image translation (Spot-7 image --> Wikimedia Map using CGAN. So unnecessary but fun!) + ## How to install @@ -42,8 +44,8 @@ For now you have two options: either use the existing **docker image**, or build Use the latest image from dockerhub: ``` -docker pull mdl4eo/otbtf2.5:cpu -docker run -u otbuser -v $(pwd):/home/otbuser mdl4eo/otbtf2.5:cpu otbcli_PatchesExtraction -help +docker pull mdl4eo/otbtf3.0:cpu +docker run -u otbuser -v $(pwd):/home/otbuser mdl4eo/otbtf3.0:cpu otbcli_PatchesExtraction -help ``` Read more in the [docker use documentation](doc/DOCKERUSE.md). @@ -59,10 +61,11 @@ Read more in the [build from sources documentation](doc/HOWTOBUILD.md). - in the `python` folder are provided some [ready-to-use deep networks, with documentation and scientific references](doc/EXAMPLES.md). - A book: *Cresson, R. (2020). Deep Learning for Remote Sensing Images with Open Source Software. CRC Press.* Use QGIS, OTB and Tensorflow to perform various kind of deep learning sorcery on remote sensing images (patch-based classification for landcover mapping, semantic segmentation of buildings, optical image restoration from joint SAR/Optical time series). - Check [our repository](https://github.com/remicres/otbtf_tutorials_resources) containing stuff (data and models) to begin with with! +- Finally, take a look in the `test` folder. You will find plenty of command lines for applications tests! ## Contribute -Every one can **contribute** to OTBTF! Don't be shy. +Every one can **contribute** to OTBTF. Just open a PR :) ## Cite diff --git a/RELEASE_NOTES.txt b/RELEASE_NOTES.txt new file mode 100644 index 00000000..76403305 --- /dev/null +++ b/RELEASE_NOTES.txt @@ -0,0 +1,99 @@ +Version 3.0.0-beta (20 nov 2021) +---------------------------------------------------------------- +* Use Tensorflow 2 API everywhere. Everything is backward compatible (old models can still be used). +* Support models with no-named inputs and outputs. OTBTF now can resolve the names! :) Just in the same order as they are defined in the computational graph. +* Support user placeholders of type vector (int, float or bool) +* More unit tests, spell check, better static analysis of C++ and python code +* Improve the handling of 3-dimensional output tensors, + more explanation in error messages about output tensors dimensions. +* Improve `PatchesSelection` to locate patches centers with corners or pixels centers depending if the patch size is odd or even. + +Version 2.5 (20 oct 2021) +---------------------------------------------------------------- +* Fix a bug in otbtf.py. The `PatchesImagesReader` wasn't working properly when the streaming was disabled. +* Improve the documentation on docker build and docker use (Thanks to Vincent@LaTelescop and Doctor-Who). + +Version 2.4 (11 apr 2021) +---------------------------------------------------------------- +* Fix a bug: The output image origin was sometimes shifted from a fraction of pixel. This issue happened only with multi-inputs models that have inputs of different spacing. +* Improvement: The output image largest possible region is now computed on the maximum possible area within the expression field. Before that, the largest possible region was too much cropped when an expression field > 1 was used. Now output images are a bit larger when a non unitary expression field is used. + +Version 2.3 (30 mar 2021) +---------------------------------------------------------------- +* More supported numeric types for tensors: + * `tensorflow::DT_FLOAT` + * `tensorflow::DT_DOUBLE` + * `tensorflow::DT_UINT64` + * `tensorflow::DT_INT64` + * `tensorflow::DT_UINT32` + * `tensorflow::DT_INT32` + * `tensorflow::DT_UINT16` + * `tensorflow::DT_INT16` + * `tensorflow::DT_UINT8` +* Update instructions to use docker + +Version 2.2 (29 jan 2021) +---------------------------------------------------------------- +* Huge enhancement of the docker image build (from Vincent@LaTeleScop) + +Version 2.1 (17 nov 2020) +---------------------------------------------------------------- +* New OTBTF python classes to train the models: + * `PatchesReaderBase`: base abstract class for patches readers. Users/developers can implement their own from it! + * `PatchesImagesReader`: a class implementing `PatchesReaderBase` to access the patches images, as they are produced by the OTBTF PatchesExtraction application. + * `IteratorBase`: base class to iterate on `PatchesReaderBase`-derived readers. + * `RandomIterator`: an iterator implementing `IteratorBase` designed to randomly access elements. + * `Dataset`: generic class to build datasets, consisting essentially of the assembly of a `PatchesReaderBase`-derived reader, and a `IteratorBase`-derived iterator. The `Dataset` handles the gathering of the data using a thread. It can be used as a `tf.dataset` to feed computational graphs. + * `DatasetFromPatchesImages`: a `Dataset` that uses a `PatchesImagesReader` to allow users/developers to stream their patches generated using the OTBTF PatchesExtraction through a `tf.dataset` which implements a streaming mechanism, enabling low memory footprint and high performance I/O thank to a threaded reading mechanism. +* Fix in dockerfile (from Pratyush Das) to work with WSL2 + +Version 2.0 (29 may 2020) +---------------------------------------------------------------- +* Now using TensorFlow 2.0! Some minor migration of python models, because we stick with `tf.compat.v1`. +* Python functions to read patches now use GDAL +* Lighter docker images (thanks to Vincent@LaTeleScop) + +Version 1.8.0 (14 jan 2020) +---------------------------------------------------------------- +* PatchesExtraction supports no-data (a different value for each source can be set) +* New sampling strategy available in PatchesSelection (balanced strategy) + +Version 1.7.0 (15 oct 2019) +---------------------------------------------------------------- +* Add a new application for patches selection (experimental) +* New docker images that are GPU-enabled using NVIDIA runtime + +Version 1.6.0 (18 jul 2019) +---------------------------------------------------------------- +* Fix a bug related to coordinates tolerance (TensorflowModelTrain can now use patches that do not occupy physically the same space) +* Fix dockerfile (add important environment variables, add a non-root user, add an example how to run the docker image) +* Document the provided Gaetano et al. two-branch CNN + +Version 1.5.1 (18 jun 2019) +---------------------------------------------------------------- +* Ubuntu bionic dockerfile + instructions +* Doc tags for QGIS3 integration +* Add cmake tests (3 models are tested in various configuration on Pan/XS images) +* PatchesExtraction writes patches images with physical spacing + +Version 1.3.0 (18 nov 2018) +---------------------------------------------------------------- +* Add 3 models that can be directly trained with TensorflowModelTrain (one CNN net, one FCN net, one 2-branch CNN net performing separately on PAN and XS images) +* Fix a bug occurring when using a scale factor <> 1 with a non-unit expression field +* Fix incorrect batch size in learning filters when batch size was not a multiple of number of batches +* Add some documentation + +Version 1.2.0 (29 sep 2018) +---------------------------------------------------------------- +* Fix typos in documentation +* Add a new application for dense polygon classes statistics +* Fix a bug in validation step +* Add streaming option of training/validation +* Document filters classes +* Change applications parameters names and roles +* Add a python application that converts a graph into a savedmodel +* Adjust tiling to expression field +* Update license + +Version 1.0.0 (16 may 2018) +---------------------------------------------------------------- +* First release of OTBTF! diff --git a/app/otbDensePolygonClassStatistics.cxx b/app/otbDensePolygonClassStatistics.cxx index 8ee2283e..fa7c2701 100644 --- a/app/otbDensePolygonClassStatistics.cxx +++ b/app/otbDensePolygonClassStatistics.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -8,18 +9,24 @@ PURPOSE. See the above copyright notices for more information. =========================================================================*/ -#include "otbWrapperApplication.h" +#include "itkFixedArray.h" +#include "itkObjectFactory.h" #include "otbWrapperApplicationFactory.h" +// Application engine +#include "otbStandardFilterWatcher.h" +#include "itkFixedArray.h" + +// Filters #include "otbStatisticsXMLFileWriter.h" #include "otbWrapperElevationParametersHandler.h" - #include "otbVectorDataToLabelImageFilter.h" #include "otbImageToNoDataMaskFilter.h" #include "otbStreamingStatisticsMapFromLabelImageFilter.h" #include "otbVectorDataIntoImageProjectionFilter.h" #include "otbImageToVectorImageCastFilter.h" +// OGR #include "otbOGR.h" namespace otb @@ -36,15 +43,14 @@ class DensePolygonClassStatistics : public Application { public: /** Standard class typedefs. */ - typedef DensePolygonClassStatistics Self; + typedef DensePolygonClassStatistics Self; typedef Application Superclass; typedef itk::SmartPointer Pointer; typedef itk::SmartPointer ConstPointer; /** Standard macro */ itkNewMacro(Self); - - itkTypeMacro(DensePolygonClassStatistics, otb::Application); + itkTypeMacro(DensePolygonClassStatistics, Application); /** DataObjects typedef */ typedef UInt32ImageType LabelImageType; @@ -66,14 +72,7 @@ class DensePolygonClassStatistics : public Application typedef otb::StatisticsXMLFileWriter StatWriterType; - -private: - DensePolygonClassStatistics() - { - - } - - void DoInit() override + void DoInit() { SetName("DensePolygonClassStatistics"); SetDescription("Computes statistics on a training polygon set."); @@ -87,7 +86,6 @@ class DensePolygonClassStatistics : public Application " - number of samples per geometry\n"); SetDocLimitations("None"); SetDocAuthors("Remi Cresson"); - SetDocSeeAlso(" "); AddDocTag(Tags::Learning); @@ -114,67 +112,11 @@ class DensePolygonClassStatistics : public Application SetDocExampleParameterValue("field", "label"); SetDocExampleParameterValue("out","polygonStat.xml"); - SetOfficialDocLink(); - } - - void DoUpdateParameters() override - { - if ( HasValue("vec") ) - { - std::string vectorFile = GetParameterString("vec"); - ogr::DataSource::Pointer ogrDS = - ogr::DataSource::New(vectorFile, ogr::DataSource::Modes::Read); - ogr::Layer layer = ogrDS->GetLayer(0); - ogr::Feature feature = layer.ogr().GetNextFeature(); - - ClearChoices("field"); - - for(int iField=0; iFieldGetNameRef(); - key = item; - std::string::iterator end = std::remove_if(key.begin(),key.end(),IsNotAlphaNum); - std::transform(key.begin(), end, key.begin(), tolower); - - OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(iField)->GetType(); - - if(fieldType == OFTString || fieldType == OFTInteger || fieldType == OFTInteger64) - { - std::string tmpKey="field."+key.substr(0, end - key.begin()); - AddChoice(tmpKey,item); - } - } - } - - // Check that the extension of the output parameter is XML (mandatory for - // StatisticsXMLFileWriter) - // Check it here to trigger the error before polygons analysis - - if ( HasValue("out") ) - { - // Store filename extension - // Check that the right extension is given : expected .xml - const std::string extension = itksys::SystemTools::GetFilenameLastExtension(this->GetParameterString("out")); - - if (itksys::SystemTools::LowerCase(extension) != ".xml") - { - otbAppLogFATAL( << extension << " is a wrong extension for parameter \"out\": Expected .xml" ); - } - } } - void DoExecute() override + void DoExecute() { - // Filters - VectorDataReprojFilterType::Pointer m_VectorDataReprojectionFilter; - RasterizeFilterType::Pointer m_RasterizeFIDFilter; - RasterizeFilterType::Pointer m_RasterizeClassFilter; - NoDataMaskFilterType::Pointer m_NoDataFilter; - CastFilterType::Pointer m_NoDataCastFilter; - StatsFilterType::Pointer m_FIDStatsFilter; - StatsFilterType::Pointer m_ClassStatsFilter; - // Retrieve the field name std::vector selectedCFieldIdx = GetSelectedItems("field"); @@ -253,14 +195,72 @@ class DensePolygonClassStatistics : public Application fidMap.erase(intNoData); classMap.erase(intNoData); - StatWriterType::Pointer statWriter = StatWriterType::New(); - statWriter->SetFileName(this->GetParameterString("out")); - statWriter->AddInputMap("samplesPerClass",classMap); - statWriter->AddInputMap("samplesPerVector",fidMap); - statWriter->Update(); + m_StatWriter = StatWriterType::New(); + m_StatWriter->SetFileName(this->GetParameterString("out")); + m_StatWriter->AddInputMap("samplesPerClass", classMap); + m_StatWriter->AddInputMap("samplesPerVector", fidMap); + m_StatWriter->Update(); } - + + void DoUpdateParameters() + { + if (HasValue("vec")) + { + std::string vectorFile = GetParameterString("vec"); + ogr::DataSource::Pointer ogrDS = + ogr::DataSource::New(vectorFile, ogr::DataSource::Modes::Read); + ogr::Layer layer = ogrDS->GetLayer(0); + ogr::Feature feature = layer.ogr().GetNextFeature(); + + ClearChoices("field"); + + for(int iField=0; iFieldGetNameRef(); + key = item; + std::string::iterator end = std::remove_if(key.begin(),key.end(),IsNotAlphaNum); + std::transform(key.begin(), end, key.begin(), tolower); + + OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(iField)->GetType(); + + if(fieldType == OFTString || fieldType == OFTInteger || fieldType == OFTInteger64) + { + std::string tmpKey="field."+key.substr(0, end - key.begin()); + AddChoice(tmpKey,item); + } + } + } + + // Check that the extension of the output parameter is XML (mandatory for + // StatisticsXMLFileWriter) + // Check it here to trigger the error before polygons analysis + + if (HasValue("out")) + { + // Store filename extension + // Check that the right extension is given : expected .xml + const std::string extension = itksys::SystemTools::GetFilenameLastExtension(this->GetParameterString("out")); + + if (itksys::SystemTools::LowerCase(extension) != ".xml") + { + otbAppLogFATAL( << extension << " is a wrong extension for parameter \"out\": Expected .xml" ); + } + } + } + + + +private: + // Filters + VectorDataReprojFilterType::Pointer m_VectorDataReprojectionFilter; + RasterizeFilterType::Pointer m_RasterizeFIDFilter; + RasterizeFilterType::Pointer m_RasterizeClassFilter; + NoDataMaskFilterType::Pointer m_NoDataFilter; + CastFilterType::Pointer m_NoDataCastFilter; + StatsFilterType::Pointer m_FIDStatsFilter; + StatsFilterType::Pointer m_ClassStatsFilter; + StatWriterType::Pointer m_StatWriter; }; diff --git a/app/otbImageClassifierFromDeepFeatures.cxx b/app/otbImageClassifierFromDeepFeatures.cxx index f861da41..f3ffd273 100644 --- a/app/otbImageClassifierFromDeepFeatures.cxx +++ b/app/otbImageClassifierFromDeepFeatures.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -60,7 +61,6 @@ class ImageClassifierFromDeepFeatures : public CompositeApplication // Populate group ShareParameter(ss_key_group.str(), "tfmodel." + ss_key_group.str(), ss_desc_group.str()); - } @@ -106,10 +106,8 @@ class ImageClassifierFromDeepFeatures : public CompositeApplication ShareParameter("out" , "classif.out" , "Output image" , "Output image" ); ShareParameter("confmap" , "classif.confmap" , "Confidence map image", "Confidence map image"); ShareParameter("ram" , "classif.ram" , "Ram" , "Ram" ); - } - void DoUpdateParameters() { UpdateInternalParameters("classif"); @@ -121,12 +119,8 @@ class ImageClassifierFromDeepFeatures : public CompositeApplication GetInternalApplication("classif")->SetParameterInputImage("in", GetInternalApplication("tfmodel")->GetParameterOutputImage("out")); UpdateInternalParameters("classif"); ExecuteInternal("classif"); - } // DOExecute() - - void AfterExecuteAndWriteOutputs() - { - // Nothing to do } + }; } // namespace Wrapper } // namespace otb diff --git a/app/otbLabelImageSampleSelection.cxx b/app/otbLabelImageSampleSelection.cxx index 5364453d..50396fa0 100644 --- a/app/otbLabelImageSampleSelection.cxx +++ b/app/otbLabelImageSampleSelection.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -34,7 +35,7 @@ class LabelImageSampleSelection : public Application { public: /** Standard class typedefs. */ - typedef LabelImageSampleSelection Self; + typedef LabelImageSampleSelection Self; typedef Application Superclass; typedef itk::SmartPointer Pointer; typedef itk::SmartPointer ConstPointer; @@ -384,4 +385,4 @@ class LabelImageSampleSelection : public Application } // end namespace wrapper } // end namespace otb -OTB_APPLICATION_EXPORT( otb::Wrapper::LabelImageSampleSelection ) +OTB_APPLICATION_EXPORT(otb::Wrapper::LabelImageSampleSelection) diff --git a/app/otbPatchesExtraction.cxx b/app/otbPatchesExtraction.cxx index df13cbc3..7b0ce456 100644 --- a/app/otbPatchesExtraction.cxx +++ b/app/otbPatchesExtraction.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -55,8 +56,6 @@ class PatchesExtraction : public Application TFSourceType m_ImageSource; // Image source FloatVectorImageType::SizeType m_PatchSize; // Patch size - unsigned int m_NumberOfElements; // Number of output samples - std::string m_KeyIn; // Key of input image list std::string m_KeyOut; // Key of output samples image std::string m_KeyPszX; // Key for samples sizes X @@ -137,10 +136,6 @@ class PatchesExtraction : public Application } } - void DoUpdateParameters() - { - } - void DoInit() { @@ -238,6 +233,12 @@ class PatchesExtraction : public Application } } + + + void DoUpdateParameters() + { + } + private: std::vector m_Bundles; @@ -246,4 +247,4 @@ class PatchesExtraction : public Application } // end namespace wrapper } // end namespace otb -OTB_APPLICATION_EXPORT( otb::Wrapper::PatchesExtraction ) +OTB_APPLICATION_EXPORT(otb::Wrapper::PatchesExtraction) diff --git a/app/otbPatchesSelection.cxx b/app/otbPatchesSelection.cxx index 170fc9b6..5d8165a0 100644 --- a/app/otbPatchesSelection.cxx +++ b/app/otbPatchesSelection.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -67,7 +68,7 @@ class PatchesSelection : public Application { public: /** Standard class typedefs. */ - typedef PatchesSelection Self; + typedef PatchesSelection Self; typedef Application Superclass; typedef itk::SmartPointer Pointer; typedef itk::SmartPointer ConstPointer; @@ -99,11 +100,6 @@ class PatchesSelection : public Application typedef itk::MaskImageFilter MaskImageFilterType; - void DoUpdateParameters() - { - } - - void DoInit() { @@ -166,22 +162,15 @@ class PatchesSelection : public Application { public: SampleBundle(){} - SampleBundle(unsigned int nClasses){ - dist = DistributionType(nClasses); - id = 0; + explicit SampleBundle(unsigned int nClasses): dist(DistributionType(nClasses)), id(0), black(true){ (void) point; - black = true; (void) index; } ~SampleBundle(){} - SampleBundle(const SampleBundle & other){ - dist = other.GetDistribution(); - id = other.GetSampleID(); - point = other.GetPosition(); - black = other.GetBlack(); - index = other.GetIndex(); - } + SampleBundle(const SampleBundle & other): dist(other.GetDistribution()), id(other.GetSampleID()), + point(other.GetPosition()), black(other.GetBlack()), index(other.GetIndex()) + {} DistributionType GetDistribution() const { @@ -252,6 +241,10 @@ class PatchesSelection : public Application int userOffX = GetParameterInt("grid.offsetx"); int userOffY = GetParameterInt("grid.offsety"); + // Tell if the patch size is odd or even + const bool isEven = GetParameterInt("grid.psize") % 2 == 0; + otbAppLogINFO("Patch size is even: " << isEven); + // Explicit streaming over the morphed mask, based on the RAM parameter typedef otb::RAMDrivenStrippedStreamingManager StreamingManagerType; StreamingManagerType::Pointer m_StreamingManager = StreamingManagerType::New(); @@ -327,9 +320,13 @@ class PatchesSelection : public Application // Compute coordinates UInt8ImageType::PointType geo; inputImage->TransformIndexToPhysicalPoint(inIt.GetIndex(), geo); - DataNodeType::PointType point; - point[0] = geo[0]; - point[1] = geo[1]; + + // Update geo if we want the corner or the center + if (isEven) + { + geo[0] -= 0.5 * std::abs(inputImage->GetSpacing()[0]); + geo[1] -= 0.5 * std::abs(inputImage->GetSpacing()[1]); + } // Lambda call lambda(pos, geo); @@ -538,7 +535,7 @@ class PatchesSelection : public Application PopulateVectorData(seed); } - void PopulateVectorData(std::vector & samples) + void PopulateVectorData(const std::vector & samples) { // Get data tree DataTreeType::Pointer treeTrain = m_OutVectorDataTrain->GetDataTree(); @@ -656,6 +653,11 @@ class PatchesSelection : public Application } + + void DoUpdateParameters() + { + } + private: RadiusType m_Radius; IsNoDataFilterType::Pointer m_NoDataFilter; diff --git a/app/otbTensorflowModelServe.cxx b/app/otbTensorflowModelServe.cxx index 0aa9d71e..47a8c957 100644 --- a/app/otbTensorflowModelServe.cxx +++ b/app/otbTensorflowModelServe.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -16,9 +17,8 @@ #include "otbStandardFilterWatcher.h" #include "itkFixedArray.h" -// Tensorflow stuff -#include "tensorflow/core/public/session.h" -#include "tensorflow/core/platform/env.h" +// Tensorflow SavedModel +#include "tensorflow/cc/saved_model/loader.h" // Tensorflow model filter #include "otbTensorflowMultisourceModelFilter.h" @@ -62,10 +62,6 @@ class TensorflowModelServe : public Application /** Typedefs for images */ typedef FloatVectorImageType::SizeType SizeType; - void DoUpdateParameters() - { - } - // // Store stuff related to one source // @@ -120,9 +116,12 @@ class TensorflowModelServe : public Application AddParameter(ParameterType_InputImageList, ss_key_in.str(), ss_desc_in.str() ); AddParameter(ParameterType_Int, ss_key_dims_x.str(), ss_desc_dims_x.str()); SetMinimumParameterIntValue (ss_key_dims_x.str(), 1); + SetDefaultParameterInt (ss_key_dims_x.str(), 1); AddParameter(ParameterType_Int, ss_key_dims_y.str(), ss_desc_dims_y.str()); SetMinimumParameterIntValue (ss_key_dims_y.str(), 1); + SetDefaultParameterInt (ss_key_dims_y.str(), 1); AddParameter(ParameterType_String, ss_key_ph.str(), ss_desc_ph.str()); + MandatoryOff (ss_key_ph.str()); // Add a new bundle ProcessObjectsBundle bundle; @@ -166,7 +165,7 @@ class TensorflowModelServe : public Application // Input model AddParameter(ParameterType_Group, "model", "model parameters"); - AddParameter(ParameterType_Directory, "model.dir", "TensorFlow model_save directory"); + AddParameter(ParameterType_Directory, "model.dir", "TensorFlow SavedModel directory"); MandatoryOn ("model.dir"); SetParameterDescription ("model.dir", "The model directory should contains the model Google Protobuf (.pb) and variables"); @@ -175,6 +174,8 @@ class TensorflowModelServe : public Application SetParameterDescription ("model.userplaceholders", "Syntax to use is \"placeholder_1=value_1 ... placeholder_N=value_N\""); AddParameter(ParameterType_Bool, "model.fullyconv", "Fully convolutional"); MandatoryOff ("model.fullyconv"); + AddParameter(ParameterType_StringList, "model.tagsets", "Which tags (i.e. v1.MetaGraphDefs) to load from the saved model. Currently, only one tag is supported. Can be retrieved by running `saved_model_cli show --dir your_model_dir --all`"); + MandatoryOff ("model.tagsets"); // Output tensors parameters AddParameter(ParameterType_Group, "output", "Output tensors parameters"); @@ -182,7 +183,7 @@ class TensorflowModelServe : public Application SetDefaultParameterFloat ("output.spcscale", 1.0); SetParameterDescription ("output.spcscale", "The output image size/scale and spacing*scale where size and spacing corresponds to the first input"); AddParameter(ParameterType_StringList, "output.names", "Names of the output tensors"); - MandatoryOn ("output.names"); + MandatoryOff ("output.names"); // Output Field of Expression AddParameter(ParameterType_Int, "output.efieldx", "The output expression field (width)"); @@ -246,15 +247,14 @@ class TensorflowModelServe : public Application { // Load the Tensorflow bundle - tf::LoadModel(GetParameterAsString("model.dir"), m_SavedModel); + tf::LoadModel(GetParameterAsString("model.dir"), m_SavedModel, GetParameterStringList("model.tagsets")); // Prepare inputs PrepareInputs(); // Setup filter m_TFFilter = TFModelFilterType::New(); - m_TFFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def()); - m_TFFilter->SetSession(m_SavedModel.session.get()); + m_TFFilter->SetSavedModel(&m_SavedModel); m_TFFilter->SetOutputTensors(GetParameterStringList("output.names")); m_TFFilter->SetOutputSpacingScale(GetParameterFloat("output.spcscale")); otbAppLogINFO("Output spacing ratio: " << m_TFFilter->GetOutputSpacingScale()); @@ -328,6 +328,11 @@ class TensorflowModelServe : public Application SetParameterOutputImage("out", m_TFFilter->GetOutput()); } } + + + void DoUpdateParameters() + { + } private: @@ -335,7 +340,7 @@ class TensorflowModelServe : public Application StreamingFilterType::Pointer m_StreamFilter; tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application ! - std::vector m_Bundles; + std::vector m_Bundles; }; // end of class diff --git a/app/otbTensorflowModelTrain.cxx b/app/otbTensorflowModelTrain.cxx index b37c72c3..e7901998 100644 --- a/app/otbTensorflowModelTrain.cxx +++ b/app/otbTensorflowModelTrain.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -16,9 +17,8 @@ #include "otbStandardFilterWatcher.h" #include "itkFixedArray.h" -// Tensorflow stuff -#include "tensorflow/core/public/session.h" -#include "tensorflow/core/platform/env.h" +// Tensorflow SavedModel +#include "tensorflow/cc/saved_model/loader.h" // Tensorflow model train #include "otbTensorflowMultisourceModelTrain.h" @@ -185,6 +185,8 @@ class TensorflowModelTrain : public Application MandatoryOff ("model.restorefrom"); AddParameter(ParameterType_String, "model.saveto", "Save model to path"); MandatoryOff ("model.saveto"); + AddParameter(ParameterType_StringList, "model.tagsets", "Which tags (i.e. v1.MetaGraphDefs) to load from the saved model. Currently, only one tag is supported. Can be retrieved by running `saved_model_cli show --dir your_model_dir --all`"); + MandatoryOff ("model.tagsets"); // Training parameters group AddParameter(ParameterType_Group, "training", "Training parameters"); @@ -360,7 +362,7 @@ class TensorflowModelTrain : public Application // // Get user placeholders // - TrainModelFilterType::DictType GetUserPlaceholders(const std::string key) + TrainModelFilterType::DictType GetUserPlaceholders(const std::string & key) { TrainModelFilterType::DictType dict; TrainModelFilterType::StringList expressions = GetParameterStringList(key); @@ -407,13 +409,15 @@ class TensorflowModelTrain : public Application { // Load the Tensorflow bundle - tf::LoadModel(GetParameterAsString("model.dir"), m_SavedModel); + tf::LoadModel(GetParameterAsString("model.dir"), m_SavedModel, GetParameterStringList("model.tagsets")); - // Check if we have to restore variables from somewhere + // Check if we have to restore variables from somewhere else if (HasValue("model.restorefrom")) { const std::string path = GetParameterAsString("model.restorefrom"); otbAppLogINFO("Restoring model from " + path); + + // Load SavedModel variables tf::RestoreModel(path, m_SavedModel); } @@ -422,8 +426,7 @@ class TensorflowModelTrain : public Application // Setup training filter m_TrainModelFilter = TrainModelFilterType::New(); - m_TrainModelFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def()); - m_TrainModelFilter->SetSession(m_SavedModel.session.get()); + m_TrainModelFilter->SetSavedModel(&m_SavedModel); m_TrainModelFilter->SetOutputTensors(GetParameterStringList("training.outputtensors")); m_TrainModelFilter->SetTargetNodesNames(GetParameterStringList("training.targetnodes")); m_TrainModelFilter->SetBatchSize(GetParameterInt("training.batchsize")); @@ -446,8 +449,7 @@ class TensorflowModelTrain : public Application otbAppLogINFO("Set validation mode to classification validation"); m_ValidateModelFilter = ValidateModelFilterType::New(); - m_ValidateModelFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def()); - m_ValidateModelFilter->SetSession(m_SavedModel.session.get()); + m_ValidateModelFilter->SetSavedModel(&m_SavedModel); m_ValidateModelFilter->SetBatchSize(GetParameterInt("training.batchsize")); m_ValidateModelFilter->SetUserPlaceholders(GetUserPlaceholders("validation.userplaceholders")); m_ValidateModelFilter->SetInputPlaceholders(m_InputPlaceholdersForValidation); diff --git a/app/otbTrainClassifierFromDeepFeatures.cxx b/app/otbTrainClassifierFromDeepFeatures.cxx index ae1e5d94..39ac4189 100644 --- a/app/otbTrainClassifierFromDeepFeatures.cxx +++ b/app/otbTrainClassifierFromDeepFeatures.cxx @@ -1,6 +1,7 @@ /*========================================================================= - Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -120,11 +121,6 @@ class TrainClassifierFromDeepFeatures : public CompositeApplication GetInternalApplication("train")->AddImageToParameterInputImageList("io.il", GetInternalApplication("tfmodel")->GetParameterOutputImage("out")); UpdateInternalParameters("train"); ExecuteInternal("train"); - } // DOExecute() - - void AfterExecuteAndWriteOutputs() - { - // Nothing to do } }; diff --git a/doc/APPLICATIONS.md b/doc/APPLICATIONS.md index 69282cd5..25d739fb 100644 --- a/doc/APPLICATIONS.md +++ b/doc/APPLICATIONS.md @@ -59,7 +59,7 @@ When using a model in OTBTF, the important thing is to know the following parame ![Schema](doc/images/schema.png) -The **scale factor** descibes the physical change of spacing of the outputs, typically introduced in the model by non unitary strides in pooling or convolution operators. +The **scale factor** describes the physical change of spacing of the outputs, typically introduced in the model by non unitary strides in pooling or convolution operators. For each output, it is expressed relatively to one single input of the model called the _reference input source_. Additionally, the names of the _target nodes_ must be known (e.g. "optimizer"). Also, the names of _user placeholders_, typically scalars placeholders that are used to control some parameters of the model, must be know (e.g. "dropout_rate"). @@ -356,7 +356,7 @@ But here, we will just perform some fine tuning of our model. The **SavedModel** is located in the `outmodel` directory. Our model is quite basic: it has two input placeholders, **x1** and **y1** respectively for input patches (with size 16x16) and input reference labels (with size 1x1). We named **prediction** the tensor that predict the labels and the optimizer that perform the stochastic gradient descent is an operator named **optimizer**. -We perform the fine tuning and we export the new model variables directly in the _outmodel/variables_ folder, overwritting the existing variables of the model. +We perform the fine tuning and we export the new model variables directly in the _outmodel/variables_ folder, overwriting the existing variables of the model. We use the **TensorflowModelTrain** application to perform the training of this existing model. ``` otbcli_TensorflowModelTrain -model.dir /path/to/oursavedmodel -training.targetnodesnames optimizer -training.source1.il samp_patches.tif -training.source1.patchsizex 16 -training.source1.patchsizey 16 -training.source1.placeholder x1 -training.source2.il samp_labels.tif -training.source2.patchsizex 1 -training.source2.patchsizey 1 -training.source2.placeholder y1 -model.saveto /path/to/oursavedmodel/variables/variables diff --git a/doc/DOCKERUSE.md b/doc/DOCKERUSE.md index 8d326c27..34e58cc3 100644 --- a/doc/DOCKERUSE.md +++ b/doc/DOCKERUSE.md @@ -20,6 +20,10 @@ Here is the list of OTBTF docker images hosted on [dockerhub](https://hub.docker | **mdl4eo/otbtf2.5:cpu** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, few optimization | no | 5.2,6.1,7.0,7.5,8.6| | **mdl4eo/otbtf2.5:gpu** | Ubuntu Focal | r2.5 | 7.4.0 | GPU | no | 5.2,6.1,7.0,7.5,8.6| | **mdl4eo/otbtf2.5:gpu-dev** | Ubuntu Focal | r2.5 | 7.4.0 | GPU (dev) | yes | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf3.0:cpu-basic** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf3.0:cpu-basic-dev** | Ubuntu Focal | r2.5 | 7.4.0 | CPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf3.0:gpu** | Ubuntu Focal | r2.5 | 7.4.0 | GPU | yes | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf3.0:gpu-dev** | Ubuntu Focal | r2.5 | 7.4.0 | GPU (dev) | yes | 5.2,6.1,7.0,7.5,8.6| - `cpu` tagged docker images are compiled without optimization. - `gpu` tagged docker images are suited for **NVIDIA GPUs**. They use CUDA/CUDNN support. @@ -107,7 +111,7 @@ This section is largely inspired from the [moringa docker help](https://gitlab.i ## Useful diagnostic commands -Here are some usefull commands. +Here are some useful commands. ```bash docker info # System info diff --git a/doc/images/classif_map.png b/doc/images/classif_map.png index bc2a1079..017118c7 100644 Binary files a/doc/images/classif_map.png and b/doc/images/classif_map.png differ diff --git a/doc/images/docker_desktop_1.jpeg b/doc/images/docker_desktop_1.jpeg index 9a03bd58..21902e77 100644 Binary files a/doc/images/docker_desktop_1.jpeg and b/doc/images/docker_desktop_1.jpeg differ diff --git a/doc/images/docker_desktop_2.jpeg b/doc/images/docker_desktop_2.jpeg index e393cdb4..ec9f5632 100644 Binary files a/doc/images/docker_desktop_2.jpeg and b/doc/images/docker_desktop_2.jpeg differ diff --git a/doc/images/landcover.png b/doc/images/landcover.png index 913f3600..0fecc763 100644 Binary files a/doc/images/landcover.png and b/doc/images/landcover.png differ diff --git a/doc/images/logo.png b/doc/images/logo.png index 13aeaffd..13f91d5a 100644 Binary files a/doc/images/logo.png and b/doc/images/logo.png differ diff --git a/doc/images/model_training.png b/doc/images/model_training.png index d385dbb6..3e7d1689 100644 Binary files a/doc/images/model_training.png and b/doc/images/model_training.png differ diff --git a/doc/images/patches_extraction.png b/doc/images/patches_extraction.png index 024beb69..bd04ff59 100644 Binary files a/doc/images/patches_extraction.png and b/doc/images/patches_extraction.png differ diff --git a/doc/images/pix2pix.png b/doc/images/pix2pix.png index 1c79652f..5df6b349 100644 Binary files a/doc/images/pix2pix.png and b/doc/images/pix2pix.png differ diff --git a/doc/images/savedmodel_simple_cnn.png b/doc/images/savedmodel_simple_cnn.png index d494e99e..7ebe28ee 100644 Binary files a/doc/images/savedmodel_simple_cnn.png and b/doc/images/savedmodel_simple_cnn.png differ diff --git a/doc/images/savedmodel_simple_fcnn.png b/doc/images/savedmodel_simple_fcnn.png index bc9e7aab..d2e13fc3 100644 Binary files a/doc/images/savedmodel_simple_fcnn.png and b/doc/images/savedmodel_simple_fcnn.png differ diff --git a/doc/images/savedmodel_simple_pxs_fcn.png b/doc/images/savedmodel_simple_pxs_fcn.png index 169e9703..bfb72eab 100644 Binary files a/doc/images/savedmodel_simple_pxs_fcn.png and b/doc/images/savedmodel_simple_pxs_fcn.png differ diff --git a/doc/images/schema.png b/doc/images/schema.png index df80fcc1..f5788e33 100644 Binary files a/doc/images/schema.png and b/doc/images/schema.png differ diff --git a/doc/images/supresol.png b/doc/images/supresol.png index efff2fd2..310d36ab 100644 Binary files a/doc/images/supresol.png and b/doc/images/supresol.png differ diff --git a/include/otbTensorflowCommon.cxx b/include/otbTensorflowCommon.cxx index b93717ed..662c9d3e 100644 --- a/include/otbTensorflowCommon.cxx +++ b/include/otbTensorflowCommon.cxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowCommon.h b/include/otbTensorflowCommon.h index e8db4b8b..fbd72810 100644 --- a/include/otbTensorflowCommon.h +++ b/include/otbTensorflowCommon.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -18,6 +18,9 @@ #include #include #include +#include "itkMacro.h" +#include "itkImageRegionConstIterator.h" +#include "itkImageRegionIterator.h" namespace otb { namespace tf { diff --git a/include/otbTensorflowCopyUtils.cxx b/include/otbTensorflowCopyUtils.cxx index 116aafb0..e9690511 100644 --- a/include/otbTensorflowCopyUtils.cxx +++ b/include/otbTensorflowCopyUtils.cxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -11,27 +11,31 @@ =========================================================================*/ #include "otbTensorflowCopyUtils.h" -namespace otb { -namespace tf { +namespace otb +{ +namespace tf +{ // // Display a TensorShape // -std::string PrintTensorShape(const tensorflow::TensorShape & shp) +std::string +PrintTensorShape(const tensorflow::TensorShape & shp) { std::stringstream s; - unsigned int nDims = shp.dims(); + unsigned int nDims = shp.dims(); s << "{" << shp.dim_size(0); - for (unsigned int d = 1 ; d < nDims ; d++) + for (unsigned int d = 1; d < nDims; d++) s << ", " << shp.dim_size(d); - s << "}" ; + s << "}"; return s.str(); } // // Display infos about a tensor // -std::string PrintTensorInfos(const tensorflow::Tensor & tensor) +std::string +PrintTensorInfos(const tensorflow::Tensor & tensor) { std::stringstream s; s << "Tensor "; @@ -39,17 +43,19 @@ std::string PrintTensorInfos(const tensorflow::Tensor & tensor) s << "shape is " << PrintTensorShape(tensor.shape()); // Data type s << " data type is " << tensor.dtype(); + s << " (" << tf::GetDataTypeAsString(tensor.dtype()) << ")"; return s.str(); } // // Create a tensor with the good datatype // -template -tensorflow::Tensor CreateTensor(tensorflow::TensorShape & shape) +template +tensorflow::Tensor +CreateTensor(tensorflow::TensorShape & shape) { tensorflow::DataType ts_dt = GetTensorflowDataType(); - tensorflow::Tensor out_tensor(ts_dt, shape); + tensorflow::Tensor out_tensor(ts_dt, shape); return out_tensor; } @@ -58,32 +64,35 @@ tensorflow::Tensor CreateTensor(tensorflow::TensorShape & shape) // Populate a tensor with the buffered region of a vector image using std::copy // Warning: tensor datatype must be consistent with the image value type // -template -void PopulateTensorFromBufferedVectorImage(const typename TImage::Pointer bufferedimagePtr, tensorflow::Tensor & out_tensor) +template +void +PopulateTensorFromBufferedVectorImage(const typename TImage::Pointer bufferedimagePtr, tensorflow::Tensor & out_tensor) { - size_t n_elem = bufferedimagePtr->GetNumberOfComponentsPerPixel() * - bufferedimagePtr->GetBufferedRegion().GetNumberOfPixels(); - std::copy_n(bufferedimagePtr->GetBufferPointer(), - n_elem, - out_tensor.flat().data()); + size_t n_elem = + bufferedimagePtr->GetNumberOfComponentsPerPixel() * bufferedimagePtr->GetBufferedRegion().GetNumberOfPixels(); + std::copy_n( + bufferedimagePtr->GetBufferPointer(), n_elem, out_tensor.flat().data()); } // // Recopy an VectorImage region into a 4D-shaped tensorflow::Tensor ({-1, sz_y, sz_x, sz_bands}) // -template -void RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region, - tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension +template +void +RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, + const typename TImage::RegionType & region, + tensorflow::Tensor & tensor, + unsigned int elemIdx) // element position along the 1st dimension { typename itk::ImageRegionConstIterator inIt(inputPtr, region); - unsigned int nBands = inputPtr->GetNumberOfComponentsPerPixel(); - auto tMap = tensor.tensor(); + unsigned int nBands = inputPtr->GetNumberOfComponentsPerPixel(); + auto tMap = tensor.tensor(); for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt) { const int y = inIt.GetIndex()[1] - region.GetIndex()[1]; const int x = inIt.GetIndex()[0] - region.GetIndex()[0]; - for (unsigned int band = 0 ; band < nBands ; band++) + for (unsigned int band = 0; band < nBands; band++) tMap(elemIdx, y, x, band) = inIt.Get()[band]; } } @@ -92,9 +101,12 @@ void RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, const ty // Type-agnostic version of the 'RecopyImageRegionToTensor' function // TODO: add some numeric types // -template -void RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region, - tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension +template +void +RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, + const typename TImage::RegionType & region, + tensorflow::Tensor & tensor, + unsigned int elemIdx) // element position along the 1st dimension { tensorflow::DataType dt = tensor.dtype(); if (dt == tensorflow::DT_FLOAT) @@ -110,21 +122,25 @@ void RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, else if (dt == tensorflow::DT_INT32) RecopyImageRegionToTensor(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_UINT16) - RecopyImageRegionToTensor (inputPtr, region, tensor, elemIdx); + RecopyImageRegionToTensor(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_INT16) RecopyImageRegionToTensor(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_UINT8) - RecopyImageRegionToTensor (inputPtr, region, tensor, elemIdx); + RecopyImageRegionToTensor(inputPtr, region, tensor, elemIdx); else - itkGenericExceptionMacro("TF DataType "<< dt << " not currently implemented !"); + itkGenericExceptionMacro("TF DataType " << dt << " not currently implemented !"); } // // Sample a centered patch (from index) // -template -void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::IndexType & centerIndex, const typename TImage::SizeType & patchSize, - tensorflow::Tensor & tensor, unsigned int elemIdx) +template +void +SampleCenteredPatch(const typename TImage::Pointer inputPtr, + const typename TImage::IndexType & centerIndex, + const typename TImage::SizeType & patchSize, + tensorflow::Tensor & tensor, + unsigned int elemIdx) { typename TImage::IndexType regionStart; regionStart[0] = centerIndex[0] - patchSize[0] / 2; @@ -136,9 +152,13 @@ void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename // // Sample a centered patch (from coordinates) // -template -void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::PointType & centerCoord, const typename TImage::SizeType & patchSize, - tensorflow::Tensor & tensor, unsigned int elemIdx) +template +void +SampleCenteredPatch(const typename TImage::Pointer inputPtr, + const typename TImage::PointType & centerCoord, + const typename TImage::SizeType & patchSize, + tensorflow::Tensor & tensor, + unsigned int elemIdx) { // Assuming tensor is of shape {-1, sz_y, sz_x, sz_bands} // Get the index of the center @@ -147,41 +167,48 @@ void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename SampleCenteredPatch(inputPtr, centerIndex, patchSize, tensor, elemIdx); } -// Return the number of channels that the output tensor will occupy in the output image // +// Return the number of channels from the TensorShapeProto // shape {n} --> 1 (e.g. a label) -// shape {n, c} --> c (e.g. a vector) -// shape {x, y, c} --> c (e.g. a patch) -// shape {n, x, y, c} --> c (e.g. some patches) +// shape {n, c} --> c (e.g. a pixel) +// shape {n, x, y} --> 1 (e.g. a mono-channel patch) +// shape {n, x, y, c} --> c (e.g. a multi-channel patch) // -tensorflow::int64 GetNumberOfChannelsForOutputTensor(const tensorflow::Tensor & tensor) +tensorflow::int64 +GetNumberOfChannelsFromShapeProto(const tensorflow::TensorShapeProto & proto) { - const tensorflow::TensorShape shape = tensor.shape(); - const int nDims = shape.dims(); + const int nDims = proto.dim_size(); if (nDims == 1) + // e.g. a batch prediction, as flat tensor + return 1; + if (nDims == 3) + // typically when the last dimension in squeezed following a + // computation that does not keep dimensions (e.g. reduce_sum, etc.) return 1; - return shape.dim_size(nDims - 1); + // any other dimension: we assume that the last dimension represent the + // number of channels in the output image. + return proto.dim(nDims - 1).size(); } // // Copy a tensor into the image region -// TODO: Enable to change mapping from source tensor to image to make it more generic -// -// Right now, only the following output tensor shapes can be processed: -// shape {n} --> 1 (e.g. a label) -// shape {n, c} --> c (e.g. a vector) -// shape {x, y, c} --> c (e.g. a multichannel image) // -template -void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion, - typename TImage::Pointer outputPtr, const typename TImage::RegionType & outputRegion, int & channelOffset) +template +void +CopyTensorToImageRegion(const tensorflow::Tensor & tensor, + const typename TImage::RegionType & bufferRegion, + typename TImage::Pointer outputPtr, + const typename TImage::RegionType & outputRegion, + int & channelOffset) { // Flatten the tensor auto tFlat = tensor.flat(); - // Get the size of the last component of the tensor (see 'GetNumberOfChannelsForOutputTensor(...)') - const tensorflow::int64 outputDimSize_C = GetNumberOfChannelsForOutputTensor(tensor); + // Get the number of component of the output image + tensorflow::TensorShapeProto proto; + tensor.shape().AsProto(&proto); + const tensorflow::int64 outputDimSize_C = GetNumberOfChannelsFromShapeProto(proto); // Number of columns (size x of the buffer) const tensorflow::int64 nCols = bufferRegion.GetSize(0); @@ -191,15 +218,15 @@ void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename T const tensorflow::int64 nElmI = bufferRegion.GetNumberOfPixels() * outputDimSize_C; if (nElmI != nElmT) { - itkGenericExceptionMacro("Number of elements in the tensor is " << nElmT << - " but image outputRegion has " << nElmI << - " values to fill.\nBuffer region:\n" << bufferRegion << - "\nNumber of components: " << outputDimSize_C << - "\nTensor shape:\n " << PrintTensorShape(tensor.shape()) << - "\nPlease check the input(s) field of view (FOV), " << - "the output field of expression (FOE), and the " << - "output spacing scale if you run the model in fully " << - "convolutional mode (how many strides in your model?)"); + itkGenericExceptionMacro("Number of elements in the tensor is " << nElmT + << " but image outputRegion has " << nElmI << " values to fill.\n" + << "Buffer region is: \n" << bufferRegion << "\n" + << "Number of components in the output image: " << outputDimSize_C << "\n" + << "Tensor shape: " << PrintTensorShape(tensor.shape()) << "\n" + << "Please check the input(s) field of view (FOV), " + << "the output field of expression (FOE), and the " + << "output spacing scale if you run the model in fully " + << "convolutional mode (how many strides in your model?)"); } // Iterate over the image @@ -212,145 +239,217 @@ void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename T // TODO: it could be useful to change the tensor-->image mapping here. // e.g use a lambda for "pos" calculation const int pos = outputDimSize_C * (y * nCols + x); - for (unsigned int c = 0 ; c < outputDimSize_C ; c++) - outIt.Get()[channelOffset + c] = tFlat( pos + c); + for (unsigned int c = 0; c < outputDimSize_C; c++) + outIt.Get()[channelOffset + c] = tFlat(pos + c); } // Update the offset channelOffset += outputDimSize_C; - } // // Type-agnostic version of the 'CopyTensorToImageRegion' function -// TODO: add some numeric types // -template -void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion, - typename TImage::Pointer outputPtr, const typename TImage::RegionType & region, int & channelOffset) +template +void +CopyTensorToImageRegion(const tensorflow::Tensor & tensor, + const typename TImage::RegionType & bufferRegion, + typename TImage::Pointer outputPtr, + const typename TImage::RegionType & region, + int & channelOffset) { tensorflow::DataType dt = tensor.dtype(); if (dt == tensorflow::DT_FLOAT) - CopyTensorToImageRegion (tensor, bufferRegion, outputPtr, region, channelOffset); + CopyTensorToImageRegion(tensor, bufferRegion, outputPtr, region, channelOffset); else if (dt == tensorflow::DT_DOUBLE) - CopyTensorToImageRegion (tensor, bufferRegion, outputPtr, region, channelOffset); + CopyTensorToImageRegion(tensor, bufferRegion, outputPtr, region, channelOffset); + else if (dt == tensorflow::DT_UINT64) + CopyTensorToImageRegion(tensor, bufferRegion, outputPtr, region, channelOffset); else if (dt == tensorflow::DT_INT64) CopyTensorToImageRegion(tensor, bufferRegion, outputPtr, region, channelOffset); + else if (dt == tensorflow::DT_UINT32) + CopyTensorToImageRegion(tensor, bufferRegion, outputPtr, region, channelOffset); else if (dt == tensorflow::DT_INT32) - CopyTensorToImageRegion (tensor, bufferRegion, outputPtr, region, channelOffset); + CopyTensorToImageRegion(tensor, bufferRegion, outputPtr, region, channelOffset); + else if (dt == tensorflow::DT_UINT16) + CopyTensorToImageRegion(tensor, bufferRegion, outputPtr, region, channelOffset); + else if (dt == tensorflow::DT_INT16) + CopyTensorToImageRegion(tensor, bufferRegion, outputPtr, region, channelOffset); + else if (dt == tensorflow::DT_UINT8) + CopyTensorToImageRegion(tensor, bufferRegion, outputPtr, region, channelOffset); else - itkGenericExceptionMacro("TF DataType "<< dt << " not currently implemented !"); - + itkGenericExceptionMacro("TF DataType " << dt << " not currently implemented !"); } // // Compare two string lowercase // -bool iequals(const std::string& a, const std::string& b) +bool +iequals(const std::string & a, const std::string & b) { - return std::equal(a.begin(), a.end(), - b.begin(), b.end(), - [](char cha, char chb) { - return tolower(cha) == tolower(chb); - }); + return std::equal( + a.begin(), a.end(), b.begin(), b.end(), [](char cha, char chb) { return tolower(cha) == tolower(chb); }); } -// Convert an expression into a dict -// +// Convert a value into a tensor // Following types are supported: // -bool // -int // -float +// -vector of float +// +// e.g. "true", "0.2", "14", "(1.2, 4.2, 4)" // -// e.g. is_training=true, droptout=0.2, nfeat=14 -std::pair ExpressionToTensor(std::string expression) +// TODO: we could add some other types (e.g. string) +tensorflow::Tensor +ValueToTensor(std::string value) { - std::pair dict; + std::vector values; - std::size_t found = expression.find("="); - if (found != std::string::npos) - { - // Find name and value - std::string name = expression.substr(0, found); - std::string value = expression.substr(found+1); + // Check if value is a vector or a scalar + const bool has_left = (value[0] == '('); + const bool has_right = value[value.size() - 1] == ')'; - dict.first = name; + // Check consistency + bool is_vec = false; + if (has_left || has_right) + { + is_vec = true; + if (!has_left || !has_right) + itkGenericExceptionMacro("Error parsing vector expression (missing parentheses ?)" << value); + } + + // Scalar --> Vector for generic processing + if (!is_vec) + { + values.push_back(value); + } + else + { + // Remove "(" and ")" chars + std::string trimmed_value = value.substr(1, value.size() - 2); + + // Split string into vector using "," delimiter + std::regex rgx("\\s*,\\s*"); + std::sregex_token_iterator iter{ trimmed_value.begin(), trimmed_value.end(), rgx, -1 }; + std::sregex_token_iterator end; + values = std::vector({ iter, end }); + } + + // Find type + bool has_dot = false; + bool is_digit = true; + for (auto & val : values) + { + has_dot = has_dot || val.find(".") != std::string::npos; + is_digit = is_digit && val.find_first_not_of("-0123456789.") == std::string::npos; + } + + // Create tensor + tensorflow::TensorShape shape({values.size()}); + tensorflow::Tensor out(tensorflow::DT_BOOL, shape); + if (is_digit) + { + if (has_dot) + out = tensorflow::Tensor(tensorflow::DT_FLOAT, shape); + else + out = tensorflow::Tensor(tensorflow::DT_INT32, shape); + } + + // Fill tensor + unsigned int idx = 0; + for (auto & val : values) + { - // Find type - std::size_t found_dot = value.find(".") != std::string::npos; - std::size_t is_digit = value.find_first_not_of("0123456789.") == std::string::npos; - if (is_digit) + if (is_digit) + { + if (has_dot) { - if (found_dot) + // FLOAT + try { - // FLOAT - try - { - float val = std::stof(value); - tensorflow::Tensor out(tensorflow::DT_FLOAT, tensorflow::TensorShape()); - out.scalar()() = val; - dict.second = out; - - } - catch(...) - { - itkGenericExceptionMacro("Error parsing name=" - << name << " with value=" << value << " as float"); - } - + out.flat()(idx) = std::stof(val); } - else + catch (...) { - // INT - try - { - int val = std::stoi(value); - tensorflow::Tensor out(tensorflow::DT_INT32, tensorflow::TensorShape()); - out.scalar()() = val; - dict.second = out; - - } - catch(...) - { - itkGenericExceptionMacro("Error parsing name=" - << name << " with value=" << value << " as int"); - } - + itkGenericExceptionMacro("Error parsing value \"" << val << "\" as float"); } } else { - // BOOL - bool val = true; - if (iequals(value, "true")) - { - val = true; - } - else if (iequals(value, "false")) + // INT + try { - val = false; + out.flat()(idx) = std::stoi(val); } - else + catch (...) { - itkGenericExceptionMacro("Error parsing name=" - << name << " with value=" << value << " as bool"); + itkGenericExceptionMacro("Error parsing value \"" << val << "\" as int"); } - tensorflow::Tensor out(tensorflow::DT_BOOL, tensorflow::TensorShape()); - out.scalar()() = val; - dict.second = out; } - } else { - itkGenericExceptionMacro("The following expression is not valid: " - << "\n\t" << expression - << ".\nExpression must be in the form int_value=1 or float_value=1.0 or bool_value=true."); + // BOOL + bool ret = true; + if (iequals(val, "true")) + { + ret = true; + } + else if (iequals(val, "false")) + { + ret = false; + } + else + { + itkGenericExceptionMacro("Error parsing value \"" << val << "\" as bool"); + } + out.flat()(idx) = ret; } + idx++; + } + otbLogMacro(Debug, << "Returning tensor: "<< out.DebugString()); + + return out; +} + +// Convert an expression into a dict +// +// Following types are supported: +// -bool +// -int +// -float +// -vector of float +// +// e.g. is_training=true, droptout=0.2, nfeat=14, x=(1.2, 4.2, 4) +std::pair +ExpressionToTensor(std::string expression) +{ + std::pair dict; - return dict; + std::size_t found = expression.find("="); + if (found != std::string::npos) + { + // Find name and value + std::string name = expression.substr(0, found); + std::string value = expression.substr(found + 1); + + dict.first = name; + + // Transform value into tensorflow::Tensor + dict.second = ValueToTensor(value); + } + else + { + itkGenericExceptionMacro("The following expression is not valid: " + << "\n\t" << expression << ".\nExpression must be in one of the following form:" + << "\n- int32_value=1 \n- float_value=1.0 \n- bool_value=true." + << "\n- float_vec=(1.0, 5.253, 2)"); + } + + return dict; } } // end namespace tf diff --git a/include/otbTensorflowCopyUtils.h b/include/otbTensorflowCopyUtils.h index 47ad6cf2..17458791 100644 --- a/include/otbTensorflowCopyUtils.h +++ b/include/otbTensorflowCopyUtils.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -15,18 +15,24 @@ // ITK exception #include "itkMacro.h" +// OTB log +#include "otbMacro.h" + // ITK image iterators #include "itkImageRegionIterator.h" #include "itkImageRegionConstIterator.h" // tensorflow::tensor #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" // tensorflow::datatype <--> ImageType::InternalPixelType #include "otbTensorflowDataTypeBridge.h" // STD #include +#include namespace otb { namespace tf { @@ -63,8 +69,8 @@ void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename template void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::PointType & centerCoord, const typename TImage::SizeType & patchSize, tensorflow::Tensor & tensor, unsigned int elemIdx); -// Return the number of channels that the output tensor will occupy in the output image -tensorflow::int64 GetNumberOfChannelsForOutputTensor(const tensorflow::Tensor & tensor); +// Return the number of channels from the TensorflowShapeProto +tensorflow::int64 GetNumberOfChannelsFromShapeProto(const tensorflow::TensorShapeProto & proto); // Copy a tensor into the image region template @@ -74,6 +80,9 @@ void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, typename TImage: template void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion, typename TImage::Pointer outputPtr, const typename TImage::RegionType & outputRegion, int & channelOffset); +// Convert a value into a tensor +tensorflow::Tensor ValueToTensor(std::string value); + // Convert an expression into a dict std::pair ExpressionToTensor(std::string expression); diff --git a/include/otbTensorflowDataTypeBridge.cxx b/include/otbTensorflowDataTypeBridge.cxx index 5a421c92..a510cb4e 100644 --- a/include/otbTensorflowDataTypeBridge.cxx +++ b/include/otbTensorflowDataTypeBridge.cxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -80,5 +80,13 @@ bool HasSameDataType(const tensorflow::Tensor & tensor) return GetTensorflowDataType() == tensor.dtype(); } +// +// Return the datatype as string +// +tensorflow::string GetDataTypeAsString(tensorflow::DataType dt) +{ + return tensorflow::DataTypeString(dt); +} + } // end namespace tf } // end namespace otb diff --git a/include/otbTensorflowDataTypeBridge.h b/include/otbTensorflowDataTypeBridge.h index 16e9dd23..af6be18d 100644 --- a/include/otbTensorflowDataTypeBridge.h +++ b/include/otbTensorflowDataTypeBridge.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -27,6 +27,9 @@ tensorflow::DataType GetTensorflowDataType(); template bool HasSameDataType(const tensorflow::Tensor & tensor); +// Return datatype as string +tensorflow::string GetDataTypeAsString(tensorflow::DataType dt); + } // end namespace tf } // end namespace otb diff --git a/include/otbTensorflowGraphOperations.cxx b/include/otbTensorflowGraphOperations.cxx index 16300f6c..d40c4da6 100644 --- a/include/otbTensorflowGraphOperations.cxx +++ b/include/otbTensorflowGraphOperations.cxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -11,161 +11,180 @@ =========================================================================*/ #include "otbTensorflowGraphOperations.h" -namespace otb { -namespace tf { +namespace otb +{ +namespace tf +{ + // -// Restore a model from a path +// Load SavedModel variables // -void RestoreModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle) +void +RestoreModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle) { tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape()); checkpointPathTensor.scalar()() = path; - std::vector> feed_dict = - {{bundle.meta_graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}}; - auto status = bundle.session->Run(feed_dict, {}, {bundle.meta_graph_def.saver_def().restore_op_name()}, nullptr); + std::vector> feed_dict = { + { bundle.meta_graph_def.saver_def().filename_tensor_name(), checkpointPathTensor } + }; + auto status = bundle.session->Run(feed_dict, {}, { bundle.meta_graph_def.saver_def().restore_op_name() }, nullptr); if (!status.ok()) - { - itkGenericExceptionMacro("Can't restore the input model: " << status.ToString() ); - } + { + itkGenericExceptionMacro("Can't restore the input model: " << status.ToString()); + } } // -// Restore a model from a path +// Save SavedModel variables // -void SaveModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle) +void +SaveModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle) { tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape()); checkpointPathTensor.scalar()() = path; - std::vector> feed_dict = - {{bundle.meta_graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}}; - auto status = bundle.session->Run(feed_dict, {}, {bundle.meta_graph_def.saver_def().save_tensor_name()}, nullptr); + std::vector> feed_dict = { + { bundle.meta_graph_def.saver_def().filename_tensor_name(), checkpointPathTensor } + }; + auto status = bundle.session->Run(feed_dict, {}, { bundle.meta_graph_def.saver_def().save_tensor_name() }, nullptr); if (!status.ok()) - { - itkGenericExceptionMacro("Can't restore the input model: " << status.ToString() ); - } + { + itkGenericExceptionMacro("Can't restore the input model: " << status.ToString()); + } } // -// Load a session and a graph from a folder +// Load a SavedModel // -void LoadModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle) +void +LoadModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle, std::vector tagList) { + // If the tag list is empty, we push back the default tag for model serving + if (tagList.size() == 0) + tagList.push_back(tensorflow::kSavedModelTagServe); + // std::vector --> std::unordered_list + std::unordered_set tagSets; + std::copy(tagList.begin(), tagList.end(), std::inserter(tagSets, tagSets.end())); // copy in unordered_set + + // Call to tensorflow::LoadSavedModel tensorflow::RunOptions runoptions; runoptions.set_trace_level(tensorflow::RunOptions_TraceLevel_FULL_TRACE); - auto status = tensorflow::LoadSavedModel(tensorflow::SessionOptions(), runoptions, - path, {tensorflow::kSavedModelTagServe}, &bundle); + auto status = tensorflow::LoadSavedModel(tensorflow::SessionOptions(), runoptions, path, tagSets, &bundle); if (!status.ok()) - { - itkGenericExceptionMacro("Can't load the input model: " << status.ToString() ); - } - + { + itkGenericExceptionMacro("Can't load the input model: " << status.ToString()); + } } -// -// Load a graph from a .meta file -// -tensorflow::GraphDef LoadGraph(std::string filename) -{ - tensorflow::MetaGraphDef meta_graph_def; - auto status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(), filename, &meta_graph_def); - if (!status.ok()) - { - itkGenericExceptionMacro("Can't load the input model: " << status.ToString() ); - } - - return meta_graph_def.graph_def(); -} -// // Get the following attributes of the specified tensors (by name) of a graph: +// - layer name, as specified in the model // - shape // - datatype -// Here we assume that the node's output is a tensor -// -void GetTensorAttributes(const tensorflow::GraphDef & graph, std::vector & tensorsNames, - std::vector & shapes, std::vector & dataTypes) +void +GetTensorAttributes(const tensorflow::protobuf::Map layers, + std::vector & tensorsNames, + std::vector & layerNames, + std::vector & shapes, + std::vector & dataTypes) { // Allocation shapes.clear(); - shapes.reserve(tensorsNames.size()); dataTypes.clear(); - dataTypes.reserve(tensorsNames.size()); + layerNames.clear(); - // Get infos - for (std::vector::iterator nameIt = tensorsNames.begin(); - nameIt != tensorsNames.end(); ++nameIt) - { - bool found = false; - for (int i = 0 ; i < graph.node_size() ; i++) - { - tensorflow::NodeDef node = graph.node(i); + // Debug infos + otbLogMacro(Debug, << "Nodes contained in the model: "); + for (auto const & layer : layers) + otbLogMacro(Debug, << "\t" << layer.first); + // When the user doesn't specify output.names, m_OutputTensors defaults to an empty list that we can not iterate over. + // We change it to a list containing an empty string [""] + if (tensorsNames.size() == 0) + { + otbLogMacro(Debug, << "No output.name specified. Using a default list with one empty string."); + tensorsNames.push_back(""); + } - if (node.name().compare((*nameIt)) == 0) - { - found = true; + // Next, we fill layerNames + int k = 0; // counter used for tensorsNames + for (auto const & name: tensorsNames) + { + bool found = false; + tensorflow::TensorInfo tensor_info; - // Set default to DT_FLOAT - tensorflow::DataType ts_dt = tensorflow::DT_FLOAT; + // If the user didn't specify the placeholdername, choose the kth layer inside the model + if (name.size() == 0) + { + found = true; + // select the k-th element of `layers` + auto it = layers.begin(); + std::advance(it, k); + layerNames.push_back(it->second.name()); + tensor_info = it->second; + otbLogMacro(Debug, << "Input " << k << " corresponds to " << it->first << " in the model"); + } - // Default (input?) tensor type - auto test_is_output = node.attr().find("T"); - if (test_is_output != node.attr().end()) - { - ts_dt = node.attr().at("T").type(); - } - auto test_has_dtype = node.attr().find("dtype"); - if (test_has_dtype != node.attr().end()) - { - ts_dt = node.attr().at("dtype").type(); - } - auto test_output_type = node.attr().find("output_type"); - if (test_output_type != node.attr().end()) + // Else, if the user specified the placeholdername, find the corresponding layer inside the model + else + { + otbLogMacro(Debug, << "Searching for corresponding node of: " << name << "... "); + for (auto const & layer : layers) + { + // layer is a pair (name, tensor_info) + // cf https://stackoverflow.com/questions/63181951/how-to-get-graph-or-graphdef-from-a-given-model + std::string layername = layer.first; + if (layername.substr(0, layername.find(":")).compare(name) == 0) { - // if there is an output type, we take it instead of the - // datatype of the input tensor - ts_dt = node.attr().at("output_type").type(); + found = true; + layerNames.push_back(layer.second.name()); + tensor_info = layer.second; + otbLogMacro(Debug, << "Found: " << layer.second.name() << " in the model"); } - dataTypes.push_back(ts_dt); + } // next layer + } // end else - // Get the tensor's shape - // Here we assure it's a tensor, with 1 shape - tensorflow::TensorShapeProto ts_shp = node.attr().at("_output_shapes").list().shape(0); - shapes.push_back(ts_shp); - } - } + k += 1; if (!found) { - itkGenericExceptionMacro("Tensor name \"" << (*nameIt) << "\" not found" ); + itkGenericExceptionMacro("Tensor name \"" << name << "\" not found. \n" + << "You can list all inputs/outputs of your SavedModel by " + << "running: \n\t `saved_model_cli show --dir your_model_dir --all`"); } - } + // Default tensor type + tensorflow::DataType ts_dt = tensor_info.dtype(); + dataTypes.push_back(ts_dt); + // Get the tensor's shape + // Here we assure it's a tensor, with 1 shape + tensorflow::TensorShapeProto ts_shp = tensor_info.tensor_shape(); + shapes.push_back(ts_shp); + } // next tensor name } // // Print a lot of stuff about the specified nodes of the graph // -void PrintNodeAttributes(const tensorflow::GraphDef & graph, std::vector & nodesNames) +void +PrintNodeAttributes(const tensorflow::GraphDef & graph, const std::vector & nodesNames) { std::cout << "Go through graph:" << std::endl; std::cout << "#\tname" << std::endl; - for (int i = 0 ; i < graph.node_size() ; i++) + for (int i = 0; i < graph.node_size(); i++) { tensorflow::NodeDef node = graph.node(i); std::cout << i << "\t" << node.name() << std::endl; - for (std::vector::iterator nameIt = nodesNames.begin(); - nameIt != nodesNames.end(); ++nameIt) + for (auto const & name: nodesNames) { - if (node.name().compare((*nameIt)) == 0) + if (node.name().compare(name) == 0) { std::cout << "Node " << i << " : " << std::endl; std::cout << "\tName: " << node.name() << std::endl; - std::cout << "\tinput_size() : " << node.input_size() << std::endl; + std::cout << "\tinput_size(): " << node.input_size() << std::endl; std::cout << "\tPrintDebugString --------------------------------"; std::cout << std::endl; node.PrintDebugString(); @@ -173,20 +192,19 @@ void PrintNodeAttributes(const tensorflow::GraphDef & graph, std::vectorfirst << std::endl; - std::cout << "\t\tValue.value_case() :" << attr->second.value_case() << std::endl; + std::cout << "\t\tKey: " << attr->first << std::endl; + std::cout << "\t\tValue.value_case(): " << attr->second.value_case() << std::endl; std::cout << "\t\tPrintDebugString --------------------------------"; std::cout << std::endl; attr->second.PrintDebugString(); std::cout << "\t\t-------------------------------------------------" << std::endl; std::cout << std::endl; } // next attribute - } // node name match - } // next node name - } // next node of the graph - + } // node name match + } // next node name + } // next node of the graph } } // end namespace tf diff --git a/include/otbTensorflowGraphOperations.h b/include/otbTensorflowGraphOperations.h index 4b4e93c0..6ad4a4e2 100644 --- a/include/otbTensorflowGraphOperations.h +++ b/include/otbTensorflowGraphOperations.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -24,30 +24,30 @@ // ITK exception #include "itkMacro.h" +// OTB log +#include "otbMacro.h" + namespace otb { namespace tf { -// Restore a model from a path +// Load SavedModel variables void RestoreModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle); -// Restore a model from a path +// Save SavedModel variables void SaveModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle); -// Load a session and a graph from a folder -void LoadModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle); - -// Load a graph from a .meta file -tensorflow::GraphDef LoadGraph(std::string filename); +// Load SavedModel +void LoadModel(const tensorflow::tstring path, tensorflow::SavedModelBundle & bundle, std::vector tagList); // Get the following attributes of the specified tensors (by name) of a graph: // - shape // - datatype // Here we assume that the node's output is a tensor -void GetTensorAttributes(const tensorflow::GraphDef & graph, std::vector & tensorsNames, +void GetTensorAttributes(const tensorflow::protobuf::Map layers, std::vector & tensorsNames, std::vector & shapes, std::vector & dataTypes); // Print a lot of stuff about the specified nodes of the graph -void PrintNodeAttributes(const tensorflow::GraphDef & graph, std::vector & nodesNames); +void PrintNodeAttributes(const tensorflow::GraphDef & graph, const std::vector & nodesNames); } // end namespace tf } // end namespace otb diff --git a/include/otbTensorflowMultisourceModelBase.h b/include/otbTensorflowMultisourceModelBase.h index 9fb0a79c..d10648ea 100644 --- a/include/otbTensorflowMultisourceModelBase.h +++ b/include/otbTensorflowMultisourceModelBase.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -15,10 +15,12 @@ #include "itkProcessObject.h" #include "itkNumericTraits.h" #include "itkSimpleDataObjectDecorator.h" +#include "itkImageToImageFilter.h" // Tensorflow #include "tensorflow/core/public/session.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/cc/saved_model/signature_constants.h" // Tensorflow helpers #include "otbTensorflowGraphOperations.h" @@ -45,8 +47,7 @@ namespace otb * be the same. If not, an exception will be thrown during the method * GenerateOutputInformation(). * - * The TensorFlow graph and session must be set using the SetGraph() and - * SetSession() methods. + * The TensorFlow SavedModel pointer must be set using the SetSavedModel() method. * * Target nodes names of the TensorFlow graph that must be triggered can be set * with the SetTargetNodesNames. @@ -103,10 +104,11 @@ public itk::ImageToImageFilter typedef std::vector TensorListType; /** Set and Get the Tensorflow session and graph */ - void SetGraph(tensorflow::GraphDef graph) { m_Graph = graph; } - tensorflow::GraphDef GetGraph() { return m_Graph ; } - void SetSession(tensorflow::Session * session) { m_Session = session; } - tensorflow::Session * GetSession() { return m_Session; } + void SetSavedModel(tensorflow::SavedModelBundle * saved_model) {m_SavedModel = saved_model;} + tensorflow::SavedModelBundle * GetSavedModel() {return m_SavedModel;} + + /** Get the SignatureDef */ + tensorflow::SignatureDef GetSignatureDef(); /** Model parameters */ void PushBackInputTensorBundle(std::string name, SizeType receptiveField, ImagePointerType image); @@ -129,8 +131,8 @@ public itk::ImageToImageFilter itkGetMacro(OutputExpressionFields, SizeListType); /** User placeholders */ - void SetUserPlaceholders(DictType dict) { m_UserPlaceholders = dict; } - DictType GetUserPlaceholders() { return m_UserPlaceholders; } + void SetUserPlaceholders(const DictType & dict) {m_UserPlaceholders = dict;} + DictType GetUserPlaceholders() {return m_UserPlaceholders;} /** Target nodes names */ itkSetMacro(TargetNodesNames, StringList); @@ -157,8 +159,7 @@ public itk::ImageToImageFilter void operator=(const Self&); //purposely not implemented // Tensorflow graph and session - tensorflow::GraphDef m_Graph; // The TensorFlow graph - tensorflow::Session * m_Session; // The TensorFlow session + tensorflow::SavedModelBundle * m_SavedModel; // The TensorFlow model // Model parameters StringList m_InputPlaceholders; // Input placeholders names @@ -174,6 +175,10 @@ public itk::ImageToImageFilter TensorShapeProtoList m_InputTensorsShapes; // Input tensors shapes TensorShapeProtoList m_OutputTensorsShapes; // Output tensors shapes + // Layer names inside the model corresponding to inputs and outputs + StringList m_InputLayers; // List of input names, as contained in the model + StringList m_OutputLayers; // List of output names, as contained in the model + }; // end class diff --git a/include/otbTensorflowMultisourceModelBase.hxx b/include/otbTensorflowMultisourceModelBase.hxx index 02c98bae..752b7c9d 100644 --- a/include/otbTensorflowMultisourceModelBase.hxx +++ b/include/otbTensorflowMultisourceModelBase.hxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -20,27 +20,56 @@ namespace otb template TensorflowMultisourceModelBase ::TensorflowMultisourceModelBase() - { - m_Session = nullptr; +{ Superclass::SetCoordinateTolerance(itk::NumericTraits::max() ); Superclass::SetDirectionTolerance(itk::NumericTraits::max() ); - } + + m_SavedModel = NULL; +} + +template +tensorflow::SignatureDef +TensorflowMultisourceModelBase +::GetSignatureDef() +{ + auto signatures = this->GetSavedModel()->GetSignatures(); + tensorflow::SignatureDef signature_def; + + if (signatures.size() == 0) + { + itkExceptionMacro("There are no available signatures for this tag-set. \n" << + "Please check which tag-set to use by running "<< + "`saved_model_cli show --dir your_model_dir --all`"); + } + + // If serving_default key exists (which is the default for TF saved model), choose it as signature + // Else, choose the first one + if (signatures.contains(tensorflow::kDefaultServingSignatureDefKey)) + { + signature_def = signatures.at(tensorflow::kDefaultServingSignatureDefKey); + } + else + { + signature_def = signatures.begin()->second; + } + return signature_def; +} template void TensorflowMultisourceModelBase ::PushBackInputTensorBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image) - { +{ Superclass::PushBackInput(image); m_InputReceptiveFields.push_back(receptiveField); m_InputPlaceholders.push_back(placeholder); - } +} template std::stringstream TensorflowMultisourceModelBase ::GenerateDebugReport(DictType & inputs) - { +{ // Create a debug report std::stringstream debugReport; @@ -51,43 +80,53 @@ TensorflowMultisourceModelBase // Describe inputs for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++) - { + { const ImagePointerType inputPtr = const_cast(this->GetInput(i)); const RegionType reqRegion = inputPtr->GetRequestedRegion(); debugReport << "Input #" << i << ":\n"; debugReport << "Requested region: " << reqRegion << "\n"; - debugReport << "Tensor shape (\"" << inputs[i].first << "\": " << tf::PrintTensorShape(inputs[i].second.shape()) << "\n"; - } + debugReport << "Tensor \"" << inputs[i].first << "\": " << tf::PrintTensorInfos(inputs[i].second) << "\n"; + } // Show user placeholders debugReport << "User placeholders:\n" ; for (auto& dict: this->GetUserPlaceholders()) - { - debugReport << dict.first << " " << tf::PrintTensorInfos(dict.second) << "\n" << std::endl; - } + { + debugReport << "Tensor \"" << dict.first << "\": " << tf::PrintTensorInfos(dict.second) << "\n" << std::endl; + } return debugReport; - } +} template void TensorflowMultisourceModelBase ::RunSession(DictType & inputs, TensorListType & outputs) - { +{ // Add the user's placeholders - for (auto& dict: this->GetUserPlaceholders()) - { - inputs.push_back(dict); - } + std::copy(this->GetUserPlaceholders().begin(), this->GetUserPlaceholders().end(), std::back_inserter(inputs)); // Run the TF session here // The session will initialize the outputs + // `inputs` corresponds to a mapping {name, tensor}, with the name being specified by the user when calling TensorFlowModelServe + // we must adapt it to `inputs_new`, that corresponds to a mapping {layerName, tensor}, with the layerName being from the model + DictType inputs_new; + int k = 0; + for (auto& dict: inputs) + { + DictElementType element = {m_InputLayers[k], dict.second}; + inputs_new.push_back(element); + k+=1; + } + // Run the session, evaluating our output tensors from the graph - auto status = this->GetSession()->Run(inputs, m_OutputTensors, m_TargetNodesNames, &outputs); - if (!status.ok()) { + auto status = this->GetSavedModel()->session.get()->Run(inputs_new, m_OutputLayers, m_TargetNodesNames, &outputs); + + if (!status.ok()) + { // Create a debug report std::stringstream debugReport = GenerateDebugReport(inputs); @@ -96,16 +135,14 @@ TensorflowMultisourceModelBase itkExceptionMacro("Can't run the tensorflow session !\n" << "Tensorflow error message:\n" << status.ToString() << "\n" "OTB Filter debug message:\n" << debugReport.str() ); - } - - } +} template void TensorflowMultisourceModelBase ::GenerateOutputInformation() - { +{ // Check that the number of the following is the same // - input placeholders names @@ -113,30 +150,27 @@ TensorflowMultisourceModelBase // - input images const unsigned int nbInputs = this->GetNumberOfInputs(); if (nbInputs != m_InputReceptiveFields.size() || nbInputs != m_InputPlaceholders.size()) - { + { itkExceptionMacro("Number of input images is " << nbInputs << " but the number of input patches size is " << m_InputReceptiveFields.size() << " and the number of input tensors names is " << m_InputPlaceholders.size()); - } - - // Check that the number of the following is the same - // - output tensors names - // - output expression fields - if (m_OutputExpressionFields.size() != m_OutputTensors.size()) - { - itkExceptionMacro("Number of output tensors names is " << m_OutputTensors.size() << - " but the number of output fields of expression is " << m_OutputExpressionFields.size()); - } + } ////////////////////////////////////////////////////////////////////////////////////////// // Get tensors information ////////////////////////////////////////////////////////////////////////////////////////// - - // Get input and output tensors datatypes and shapes - tf::GetTensorAttributes(m_Graph, m_InputPlaceholders, m_InputTensorsShapes, m_InputTensorsDataTypes); - tf::GetTensorAttributes(m_Graph, m_OutputTensors, m_OutputTensorsShapes, m_OutputTensorsDataTypes); - - } + // Set all subelement of the model + auto signaturedef = this->GetSignatureDef(); + + // Given the inputs/outputs names that the user specified, get the names of the inputs/outputs contained in the model + // and other infos (shapes, dtypes) + // For example, for output names specified by the user m_OutputTensors = ['s2t', 's2t_pad'], + // this will return m_OutputLayers = ['PartitionedCall:0', 'PartitionedCall:1'] + // In case the user hasn't named the output, e.g. m_OutputTensors = [''], + // this will return the first output m_OutputLayers = ['PartitionedCall:0'] + tf::GetTensorAttributes(signaturedef.inputs(), m_InputPlaceholders, m_InputLayers, m_InputTensorsShapes, m_InputTensorsDataTypes); + tf::GetTensorAttributes(signaturedef.outputs(), m_OutputTensors, m_OutputLayers, m_OutputTensorsShapes, m_OutputTensorsDataTypes); +} } // end namespace otb diff --git a/include/otbTensorflowMultisourceModelFilter.h b/include/otbTensorflowMultisourceModelFilter.h index 46a273af..36d781dd 100644 --- a/include/otbTensorflowMultisourceModelFilter.h +++ b/include/otbTensorflowMultisourceModelFilter.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -17,15 +17,13 @@ // Iterator #include "itkImageRegionConstIteratorWithOnlyIndex.h" -// Tensorflow helpers -#include "otbTensorflowGraphOperations.h" -#include "otbTensorflowDataTypeBridge.h" -#include "otbTensorflowCopyUtils.h" - // Tile hint #include "itkMetaDataObject.h" #include "otbMetaDataKey.h" +// OTB log +#include "otbMacro.h" + namespace otb { diff --git a/include/otbTensorflowMultisourceModelFilter.hxx b/include/otbTensorflowMultisourceModelFilter.hxx index 91c5384e..d208f01a 100644 --- a/include/otbTensorflowMultisourceModelFilter.hxx +++ b/include/otbTensorflowMultisourceModelFilter.hxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -201,15 +201,19 @@ TensorflowMultisourceModelFilter ////////////////////////////////////////////////////////////////////////////////////////// // If the output spacing is not specified, we use the first input image as grid reference - m_OutputSpacing = this->GetInput(0)->GetSignedSpacing(); + // OTBTF assumes that the output image has the following geometric properties: + // (1) Image origin is the top-left pixel + // (2) Image pixel spacing has positive x-spacing and negative y-spacing + m_OutputSpacing = this->GetInput(0)->GetSpacing(); // GetSpacing() returns abs. spacing + m_OutputSpacing[1] *= -1.0; // Force negative y-spacing m_OutputSpacing[0] *= m_OutputSpacingScale; m_OutputSpacing[1] *= m_OutputSpacingScale; - PointType extentInf, extentSup; - extentSup.Fill(itk::NumericTraits::max()); - extentInf.Fill(itk::NumericTraits::NonpositiveMin()); // Compute the extent of each input images and update the extent or the output image. // The extent of the output image is the intersection of all input images extents. + PointType extentInf, extentSup; + extentSup.Fill(itk::NumericTraits::max()); + extentInf.Fill(itk::NumericTraits::NonpositiveMin()); for (unsigned int imageIndex = 0 ; imageIndex < this->GetNumberOfInputs() ; imageIndex++) { ImageType * currentImage = static_cast( @@ -267,17 +271,22 @@ TensorflowMultisourceModelFilter unsigned int outputPixelSize = 0; for (auto& protoShape: this->GetOutputTensorsShapes()) { - // The number of components per pixel is the last dimension of the tensor - int dim_size = protoShape.dim_size(); - unsigned int nComponents = 1; - if (1 < dim_size && dim_size <= 4) - { - nComponents = protoShape.dim(dim_size-1).size(); - } - else if (dim_size > 4) + // Find the number of components + if (protoShape.dim_size() > 4) { - itkExceptionMacro("Dim_size=" << dim_size << " currently not supported."); + itkExceptionMacro("dim_size=" << protoShape.dim_size() << " currently not supported. " + "Keep in mind that output tensors must have 1, 2, 3 or 4 dimensions. " + "In the case of 1-dimensional tensor, the first dimension is for the batch, " + "and we assume that the output tensor has 1 channel. " + "In the case of 2-dimensional tensor, the first dimension is for the batch, " + "and the second is the number of components. " + "In the case of 3-dimensional tensor, the first dimension is for the batch, " + "and other dims are for (x, y). " + "In the case of 4-dimensional tensor, the first dimension is for the batch, " + "and the second and the third are for (x, y). The last is for the number of " + "channels. "); } + unsigned int nComponents = tf::GetNumberOfChannelsFromShapeProto(protoShape); outputPixelSize += nComponents; } @@ -329,7 +338,7 @@ TensorflowMultisourceModelFilter if (!OutputRegionToInputRegion(requestedRegion, inRegion, inputImage) ) { // Image does not overlap requested region: set requested region to null - itkDebugMacro( << "Image #" << i << " :\n" << inRegion << " is outside the requested region"); + otbLogMacro(Debug, << "Image #" << i << " :\n" << inRegion << " is outside the requested region"); inRegion.GetModifiableIndex().Fill(0); inRegion.GetModifiableSize().Fill(0); } @@ -393,6 +402,7 @@ TensorflowMultisourceModelFilter // Create input tensors list DictType inputs; + // Populate input tensors for (unsigned int i = 0 ; i < nInputs ; i++) { @@ -462,6 +472,7 @@ TensorflowMultisourceModelFilter } // next input tensor // Run session + // TODO: see if we print some info about inputs/outputs of the model e.g. m_OutputTensors TensorListType outputs; this->RunSession(inputs, outputs); @@ -485,7 +496,7 @@ TensorflowMultisourceModelFilter catch( itk::ExceptionObject & err ) { std::stringstream debugMsg = this->GenerateDebugReport(inputs); - itkExceptionMacro("Error occured during tensor to image conversion.\n" + itkExceptionMacro("Error occurred during tensor to image conversion.\n" << "Context: " << debugMsg.str() << "Error:" << err); } diff --git a/include/otbTensorflowMultisourceModelLearningBase.h b/include/otbTensorflowMultisourceModelLearningBase.h index ba130453..0663f17a 100644 --- a/include/otbTensorflowMultisourceModelLearningBase.h +++ b/include/otbTensorflowMultisourceModelLearningBase.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -100,7 +100,7 @@ public TensorflowMultisourceModelBase TensorflowMultisourceModelLearningBase(); virtual ~TensorflowMultisourceModelLearningBase() {}; - virtual void GenerateOutputInformation(void); + virtual void GenerateOutputInformation(void) override; virtual void GenerateInputRequestedRegion(); diff --git a/include/otbTensorflowMultisourceModelLearningBase.hxx b/include/otbTensorflowMultisourceModelLearningBase.hxx index 35347829..28b2328b 100644 --- a/include/otbTensorflowMultisourceModelLearningBase.hxx +++ b/include/otbTensorflowMultisourceModelLearningBase.hxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowMultisourceModelTrain.h b/include/otbTensorflowMultisourceModelTrain.h index 8f8983ca..8ec4c38c 100644 --- a/include/otbTensorflowMultisourceModelTrain.h +++ b/include/otbTensorflowMultisourceModelTrain.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowMultisourceModelTrain.hxx b/include/otbTensorflowMultisourceModelTrain.hxx index e7b68dac..272dd639 100644 --- a/include/otbTensorflowMultisourceModelTrain.hxx +++ b/include/otbTensorflowMultisourceModelTrain.hxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowMultisourceModelValidate.h b/include/otbTensorflowMultisourceModelValidate.h index f4a95406..322f6a24 100644 --- a/include/otbTensorflowMultisourceModelValidate.h +++ b/include/otbTensorflowMultisourceModelValidate.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowMultisourceModelValidate.hxx b/include/otbTensorflowMultisourceModelValidate.hxx index c7673c8f..8ec685ba 100644 --- a/include/otbTensorflowMultisourceModelValidate.hxx +++ b/include/otbTensorflowMultisourceModelValidate.hxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -98,7 +98,7 @@ TensorflowMultisourceModelValidate /** * Perform the validation * The session is ran over the entire set of batches. - * Output is then validated agains the references images, + * Output is then validated against the references images, * and a confusion matrix is built. */ template diff --git a/include/otbTensorflowSampler.h b/include/otbTensorflowSampler.h index d71eba7a..bd363bc8 100644 --- a/include/otbTensorflowSampler.h +++ b/include/otbTensorflowSampler.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowSampler.hxx b/include/otbTensorflowSampler.hxx index 9611d934..8c0ea745 100644 --- a/include/otbTensorflowSampler.hxx +++ b/include/otbTensorflowSampler.hxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -154,7 +154,7 @@ TensorflowSampler m_OutputPatchImages.push_back(newImage); } - itk::ProgressReporter progess(this, 0, nTotal); + itk::ProgressReporter progress(this, 0, nTotal); // Iterate on the vector data itVector.GoToBegin(); @@ -218,8 +218,8 @@ TensorflowSampler rejected++; } - // Update progres - progess.CompletedPixel(); + // Update progress + progress.CompletedPixel(); } diff --git a/include/otbTensorflowSamplingUtils.cxx b/include/otbTensorflowSamplingUtils.cxx index 5a8b8e3c..5cf88f6b 100644 --- a/include/otbTensorflowSamplingUtils.cxx +++ b/include/otbTensorflowSamplingUtils.cxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowSamplingUtils.h b/include/otbTensorflowSamplingUtils.h index 93879301..585f9013 100644 --- a/include/otbTensorflowSamplingUtils.h +++ b/include/otbTensorflowSamplingUtils.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -27,23 +27,17 @@ class Distribution typedef typename TImage::PixelType ValueType; typedef vnl_vector CountsType; - Distribution(unsigned int nClasses){ - m_NbOfClasses = nClasses; - m_Dist = CountsType(nClasses, 0); - + explicit Distribution(unsigned int nClasses): m_NbOfClasses(nClasses), m_Dist(CountsType(nClasses, 0)) + { } - Distribution(unsigned int nClasses, float fillValue){ - m_NbOfClasses = nClasses; - m_Dist = CountsType(nClasses, fillValue); - + Distribution(unsigned int nClasses, float fillValue): m_NbOfClasses(nClasses), m_Dist(CountsType(nClasses, fillValue)) + { } - Distribution(){ - m_NbOfClasses = 2; - m_Dist = CountsType(m_NbOfClasses, 0); + Distribution(): m_NbOfClasses(2), m_Dist(CountsType(m_NbOfClasses, 0)) + { } - Distribution(const Distribution & other){ - m_Dist = other.Get(); - m_NbOfClasses = m_Dist.size(); + Distribution(const Distribution & other): m_Dist(other.Get()), m_NbOfClasses(m_Dist.size()) + { } ~Distribution(){} diff --git a/include/otbTensorflowSource.h b/include/otbTensorflowSource.h index f569c720..1556997f 100644 --- a/include/otbTensorflowSource.h +++ b/include/otbTensorflowSource.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -25,9 +25,9 @@ namespace otb { /* - * This is a simple helper to create images concatenation. + * This is a helper for images concatenation. * Images must have the same size. - * This is basically the common input type used in every OTB-TF applications. + * This is the common input type used in every OTB-TF applications. */ template class TensorflowSource @@ -60,7 +60,7 @@ class TensorflowSource // Get the source output FloatVectorImagePointerType Get(); - TensorflowSource(){}; + TensorflowSource(); virtual ~TensorflowSource (){}; private: diff --git a/include/otbTensorflowSource.hxx b/include/otbTensorflowSource.hxx index bb3de775..2ad57586 100644 --- a/include/otbTensorflowSource.hxx +++ b/include/otbTensorflowSource.hxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even @@ -17,12 +17,21 @@ namespace otb { +// +// Constructor +// +template +TensorflowSource +::TensorflowSource() +{} + // // Prepare the big stack of images // template void -TensorflowSource::Set(FloatVectorImageListType * inputList) +TensorflowSource +::Set(FloatVectorImageListType * inputList) { // Create one stack for input images list m_Concatener = ListConcatenerFilterType::New(); diff --git a/include/otbTensorflowStreamerFilter.h b/include/otbTensorflowStreamerFilter.h index eee73f87..4730d369 100644 --- a/include/otbTensorflowStreamerFilter.h +++ b/include/otbTensorflowStreamerFilter.h @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/include/otbTensorflowStreamerFilter.hxx b/include/otbTensorflowStreamerFilter.hxx index 5e5622be..59904a54 100644 --- a/include/otbTensorflowStreamerFilter.hxx +++ b/include/otbTensorflowStreamerFilter.hxx @@ -1,7 +1,7 @@ /*========================================================================= - Copyright (c) 2018-2019 Remi Cresson (IRSTEA) - Copyright (c) 2020-2021 Remi Cresson (INRAE) + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE This software is distributed WITHOUT ANY WARRANTY; without even diff --git a/python/ckpt2savedmodel.py b/python/ckpt2savedmodel.py index cbb72bb9..117203ba 100755 --- a/python/ckpt2savedmodel.py +++ b/python/ckpt2savedmodel.py @@ -2,8 +2,8 @@ # -*- coding: utf-8 -*- # ========================================================================== # -# Copyright 2018-2019 Remi Cresson (IRSTEA) -# Copyright 2020 Remi Cresson (INRAE) +# Copyright 2018-2019 IRSTEA +# Copyright 2020-2021 INRAE # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,23 +18,37 @@ # limitations under the License. # # ==========================================================================*/ +""" +This application converts a checkpoint into a SavedModel, that can be used in +TensorflowModelTrain or TensorflowModelServe OTB applications. +This is intended to work mostly with tf.v1 models, since the models in tf.v2 +can be more conveniently exported as SavedModel (see how to build a model with +keras in Tensorflow 2). +""" import argparse from tricks import ckpt_to_savedmodel -# Parser -parser = argparse.ArgumentParser() -parser.add_argument("--ckpt", help="Checkpoint file (without the \".meta\" extension)", required=True) -parser.add_argument("--inputs", help="Inputs names (e.g. [\"x_cnn_1:0\", \"x_cnn_2:0\"])", required=True, nargs='+') -parser.add_argument("--outputs", help="Outputs names (e.g. [\"prediction:0\", \"features:0\"])", required=True, - nargs='+') -parser.add_argument("--model", help="Output directory for SavedModel", required=True) -parser.add_argument('--clear_devices', dest='clear_devices', action='store_true') -parser.set_defaults(clear_devices=False) -params = parser.parse_args() -if __name__ == "__main__": +def main(): + """ + Main function + """ + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt", help="Checkpoint file (without the \".meta\" extension)", required=True) + parser.add_argument("--inputs", help="Inputs names (e.g. [\"x_cnn_1:0\", \"x_cnn_2:0\"])", required=True, nargs='+') + parser.add_argument("--outputs", help="Outputs names (e.g. [\"prediction:0\", \"features:0\"])", required=True, + nargs='+') + parser.add_argument("--model", help="Output directory for SavedModel", required=True) + parser.add_argument('--clear_devices', dest='clear_devices', action='store_true') + parser.set_defaults(clear_devices=False) + params = parser.parse_args() + ckpt_to_savedmodel(ckpt_path=params.ckpt, inputs=params.inputs, outputs=params.outputs, savedmodel_path=params.model, clear_devices=params.clear_devices) + + +if __name__ == "__main__": + main() diff --git a/python/create_savedmodel_ienco-m3_patchbased.py b/python/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py similarity index 99% rename from python/create_savedmodel_ienco-m3_patchbased.py rename to python/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py index fdb77227..2a3ad56f 100755 --- a/python/create_savedmodel_ienco-m3_patchbased.py +++ b/python/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# ========================================================================== +# ========================================================================= # # Copyright 2018-2019 Remi Cresson, Dino Ienco (IRSTEA) # Copyright 2020-2021 Remi Cresson, Dino Ienco (INRAE) @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# ==========================================================================*/ +# ========================================================================= # Reference: # @@ -26,12 +26,12 @@ # Satellite Data Fusion. IEEE Journal of Selected Topics in Applied Earth # Observations and Remote Sensing, 11(12), 4939-4949. +import argparse from tricks import create_savedmodel import tensorflow.compat.v1 as tf import tensorflow.compat.v1.nn.rnn_cell as rnn -tf.disable_v2_behavior() -import argparse +tf.disable_v2_behavior() parser = argparse.ArgumentParser() parser.add_argument("--nunits", type=int, default=1024, help="number of units") @@ -63,7 +63,7 @@ def RnnAttention(x, nunits, nlayer, n_dims, n_timetamps, is_training_ph): cell = rnn.GRUCell(nunits) cells.append(cell) cell = tf.compat.v1.contrib.rnn.MultiRNNCell(cells) - # SIGNLE LAYER: single GRUCell, nunits hidden units each + # SINGLE LAYER: single GRUCell, nunits hidden units each else: cell = rnn.GRUCell(nunits) outputs, _ = tf.compat.v1.nn.static_rnn(cell, x, dtype="float32") diff --git a/python/create_savedmodel_maggiori17_fullyconv.py b/python/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py similarity index 95% rename from python/create_savedmodel_maggiori17_fullyconv.py rename to python/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py index 32843e76..7c2bed5c 100755 --- a/python/create_savedmodel_maggiori17_fullyconv.py +++ b/python/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -#========================================================================== +# ========================================================================= # # Copyright 2018-2019 Remi Cresson (IRSTEA) # Copyright 2020-2021 Remi Cresson (INRAE) @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -#==========================================================================*/ +# ========================================================================= # Reference: # @@ -79,7 +79,7 @@ activation=tf.nn.crelu) # Deconv = conv on the padded/strided input, that is an (5+1)*4 - deconv1 = tf.compat.v1.layers.conv2d_transpose(inputs=conv4, filters=1, strides=(4,4), kernel_size=[8, 8], + deconv1 = tf.compat.v1.layers.conv2d_transpose(inputs=conv4, filters=1, strides=(4, 4), kernel_size=[8, 8], padding="valid", activation=tf.nn.sigmoid) n = tf.shape(deconv1)[0] diff --git a/python/create_savedmodel_pxs_fcn.py b/python/examples/tensorflow_v1x/create_savedmodel_pxs_fcn.py similarity index 100% rename from python/create_savedmodel_pxs_fcn.py rename to python/examples/tensorflow_v1x/create_savedmodel_pxs_fcn.py diff --git a/python/create_savedmodel_simple_cnn.py b/python/examples/tensorflow_v1x/create_savedmodel_simple_cnn.py similarity index 100% rename from python/create_savedmodel_simple_cnn.py rename to python/examples/tensorflow_v1x/create_savedmodel_simple_cnn.py diff --git a/python/create_savedmodel_simple_fcn.py b/python/examples/tensorflow_v1x/create_savedmodel_simple_fcn.py similarity index 100% rename from python/create_savedmodel_simple_fcn.py rename to python/examples/tensorflow_v1x/create_savedmodel_simple_fcn.py diff --git a/python/otbtf.py b/python/otbtf.py index 6a5e35d9..a23d5237 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- # ========================================================================== # -# Copyright 2018-2019 Remi Cresson (IRSTEA) -# Copyright 2020 Remi Cresson (INRAE) +# Copyright 2018-2019 IRSTEA +# Copyright 2020-2021 INRAE # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,57 +17,57 @@ # limitations under the License. # # ==========================================================================*/ +""" +Contains stuff to help working with TensorFlow and geospatial data in the +OTBTF framework. +""" import threading import multiprocessing import time +import logging +from abc import ABC, abstractmethod import numpy as np import tensorflow as tf import gdal -import logging -from abc import ABC, abstractmethod -""" -------------------------------------------------------- Helpers -------------------------------------------------------- -""" +# ----------------------------------------------------- Helpers -------------------------------------------------------- def gdal_open(filename): """ Open a GDAL raster :param filename: raster file - :return: a GDAL ds instance + :return: a GDAL dataset instance """ - ds = gdal.Open(filename) - if ds is None: + gdal_ds = gdal.Open(filename) + if gdal_ds is None: raise Exception("Unable to open file {}".format(filename)) - return ds + return gdal_ds -def read_as_np_arr(ds, as_patches=True): +def read_as_np_arr(gdal_ds, as_patches=True): """ Read a GDAL raster as numpy array - :param ds: GDAL ds instance + :param gdal_ds: a GDAL dataset instance :param as_patches: if True, the returned numpy array has the following shape (n, psz_x, psz_x, nb_channels). If False, the shape is (1, psz_y, psz_x, nb_channels) :return: Numpy array of dim 4 """ - buffer = ds.ReadAsArray() - szx = ds.RasterXSize + buffer = gdal_ds.ReadAsArray() + size_x = gdal_ds.RasterXSize if len(buffer.shape) == 3: buffer = np.transpose(buffer, axes=(1, 2, 0)) if not as_patches: - n = 1 - szy = ds.RasterYSize + n_elems = 1 + size_y = gdal_ds.RasterYSize else: - n = int(ds.RasterYSize / szx) - szy = szx - return np.float32(buffer.reshape((n, szy, szx, ds.RasterCount))) + n_elems = int(gdal_ds.RasterYSize / size_x) + size_y = size_x + return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))) -""" ----------------------------------------------------- Buffer class ------------------------------------------------------ -""" +# -------------------------------------------------- Buffer class ------------------------------------------------------ class Buffer: @@ -80,19 +80,27 @@ def __init__(self, max_length): self.container = [] def size(self): + """ + Returns the buffer size + """ return len(self.container) - def add(self, x): - self.container.append(x) - assert (self.size() <= self.max_length) + def add(self, new_element): + """ + Add an element in the buffer + :param new_element: new element to add + """ + self.container.append(new_element) + assert self.size() <= self.max_length def is_complete(self): + """ + Return True if the buffer is at full capacity + """ return self.size() == self.max_length -""" ------------------------------------------------- PatchesReaderBase class ----------------------------------------------- -""" +# ---------------------------------------------- PatchesReaderBase class ----------------------------------------------- class PatchesReaderBase(ABC): @@ -106,7 +114,6 @@ def get_sample(self, index): Return one sample. :return One sample instance, whatever the sample structure is (dict, numpy array, ...) """ - pass @abstractmethod def get_stats(self) -> dict: @@ -129,7 +136,6 @@ def get_stats(self) -> dict: "std": np.array([...])}, } """ - pass @abstractmethod def get_size(self): @@ -137,12 +143,9 @@ def get_size(self): Returns the total number of samples :return: number of samples (int) """ - pass -""" ------------------------------------------------ PatchesImagesReader class ---------------------------------------------- -""" +# --------------------------------------------- PatchesImagesReader class ---------------------------------------------- class PatchesImagesReader(PatchesReaderBase): @@ -174,66 +177,64 @@ def __init__(self, filenames_dict: dict, use_streaming=False): :param use_streaming: if True, the patches are read on the fly from the disc, nothing is kept in memory. """ - assert (len(filenames_dict.values()) > 0) + assert len(filenames_dict.values()) > 0 - # ds dict - self.ds = dict() - for src_key, src_filenames in filenames_dict.items(): - self.ds[src_key] = [] - for src_filename in src_filenames: - self.ds[src_key].append(gdal_open(src_filename)) + # gdal_ds dict + self.gdal_ds = {key: [gdal_open(src_fn) for src_fn in src_fns] for key, src_fns in filenames_dict.items()} - if len(set([len(ds_list) for ds_list in self.ds.values()])) != 1: + # check number of patches in each sources + if len({len(ds_list) for ds_list in self.gdal_ds.values()}) != 1: raise Exception("Each source must have the same number of patches images") # streaming on/off self.use_streaming = use_streaming - # ds check - nb_of_patches = {key: 0 for key in self.ds} + # gdal_ds check + nb_of_patches = {key: 0 for key in self.gdal_ds} self.nb_of_channels = dict() - for src_key, ds_list in self.ds.items(): - for ds in ds_list: - nb_of_patches[src_key] += self._get_nb_of_patches(ds) + for src_key, ds_list in self.gdal_ds.items(): + for gdal_ds in ds_list: + nb_of_patches[src_key] += self._get_nb_of_patches(gdal_ds) if src_key not in self.nb_of_channels: - self.nb_of_channels[src_key] = ds.RasterCount + self.nb_of_channels[src_key] = gdal_ds.RasterCount else: - if self.nb_of_channels[src_key] != ds.RasterCount: + if self.nb_of_channels[src_key] != gdal_ds.RasterCount: raise Exception("All patches images from one source must have the same number of channels!" "Error happened for source: {}".format(src_key)) if len(set(nb_of_patches.values())) != 1: raise Exception("Sources must have the same number of patches! Number of patches: {}".format(nb_of_patches)) - # ds sizes - src_key_0 = list(self.ds)[0] # first key - self.ds_sizes = [self._get_nb_of_patches(ds) for ds in self.ds[src_key_0]] + # gdal_ds sizes + src_key_0 = list(self.gdal_ds)[0] # first key + self.ds_sizes = [self._get_nb_of_patches(ds) for ds in self.gdal_ds[src_key_0]] self.size = sum(self.ds_sizes) # if use_streaming is False, we store in memory all patches images if not self.use_streaming: - patches_list = {src_key: [read_as_np_arr(ds) for ds in self.ds[src_key]] for src_key in self.ds} - self.patches_buffer = {src_key: np.concatenate(patches_list[src_key], axis=0) for src_key in self.ds} + patches_list = {src_key: [read_as_np_arr(ds) for ds in self.gdal_ds[src_key]] for src_key in self.gdal_ds} + self.patches_buffer = {src_key: np.concatenate(patches_list[src_key], axis=0) for src_key in self.gdal_ds} def _get_ds_and_offset_from_index(self, index): offset = index - for i, ds_size in enumerate(self.ds_sizes): + idx = None + for idx, ds_size in enumerate(self.ds_sizes): if offset < ds_size: break offset -= ds_size - return i, offset + return idx, offset @staticmethod - def _get_nb_of_patches(ds): - return int(ds.RasterYSize / ds.RasterXSize) + def _get_nb_of_patches(gdal_ds): + return int(gdal_ds.RasterYSize / gdal_ds.RasterXSize) @staticmethod - def _read_extract_as_np_arr(ds, offset): - assert (ds is not None) - psz = ds.RasterXSize + def _read_extract_as_np_arr(gdal_ds, offset): + assert gdal_ds is not None + psz = gdal_ds.RasterXSize yoff = int(offset * psz) - assert (yoff + psz <= ds.RasterYSize) - buffer = ds.ReadAsArray(0, yoff, psz, psz) + assert yoff + psz <= gdal_ds.RasterYSize + buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz) if len(buffer.shape) == 3: buffer = np.transpose(buffer, axes=(1, 2, 0)) return np.float32(buffer) @@ -248,14 +249,14 @@ def get_sample(self, index): ... "src_key_M": np.array((psz_y_M, psz_x_M, nb_ch_M))} """ - assert (0 <= index) - assert (index < self.size) + assert index >= 0 + assert index < self.size if not self.use_streaming: - res = {src_key: self.patches_buffer[src_key][index, :, :, :] for src_key in self.ds} + res = {src_key: self.patches_buffer[src_key][index, :, :, :] for src_key in self.gdal_ds} else: i, offset = self._get_ds_and_offset_from_index(index) - res = {src_key: self._read_extract_as_np_arr(self.ds[src_key][i], offset) for src_key in self.ds} + res = {src_key: self._read_extract_as_np_arr(self.gdal_ds[src_key][i], offset) for src_key in self.gdal_ds} return res @@ -278,7 +279,7 @@ def get_stats(self): axis = (0, 1) # (row, col) def _filled(value): - return {src_key: value * np.ones((self.nb_of_channels[src_key])) for src_key in self.ds} + return {src_key: value * np.ones((self.nb_of_channels[src_key])) for src_key in self.gdal_ds} _maxs = _filled(0.0) _mins = _filled(float("inf")) @@ -298,17 +299,15 @@ def _filled(value): "max": _maxs[src_key], "mean": rsize * _sums[src_key], "std": np.sqrt(rsize * _sqsums[src_key] - np.square(rsize * _sums[src_key])) - } for src_key in self.ds} - logging.info("Stats: {}".format(stats)) + } for src_key in self.gdal_ds} + logging.info("Stats: {}", stats) return stats def get_size(self): return self.size -""" -------------------------------------------------- IteratorBase class --------------------------------------------------- -""" +# ----------------------------------------------- IteratorBase class --------------------------------------------------- class IteratorBase(ABC): @@ -320,9 +319,7 @@ def __init__(self, patches_reader: PatchesReaderBase): pass -""" ------------------------------------------------- RandomIterator class -------------------------------------------------- -""" +# ---------------------------------------------- RandomIterator class -------------------------------------------------- class RandomIterator(IteratorBase): @@ -352,9 +349,7 @@ def _shuffle(self): np.random.shuffle(self.indices) -""" ---------------------------------------------------- Dataset class ------------------------------------------------------ -""" +# ------------------------------------------------- Dataset class ------------------------------------------------------ class Dataset: @@ -389,10 +384,12 @@ def __init__(self, patches_reader: PatchesReaderBase, buffer_length: int = 128, self.output_shapes[src_key] = np_arr.shape self.output_types[src_key] = tf.dtypes.as_dtype(np_arr.dtype) - logging.info("output_types: {}".format(self.output_types)) - logging.info("output_shapes: {}".format(self.output_shapes)) + logging.info("output_types: {}", self.output_types) + logging.info("output_shapes: {}", self.output_shapes) # buffers + if self.size <= buffer_length: + buffer_length = self.size self.miner_buffer = Buffer(buffer_length) self.mining_lock = multiprocessing.Lock() self.consumer_buffer = Buffer(buffer_length) @@ -434,12 +431,12 @@ def _dump(self): This function dumps the miner_buffer into the consumer_buffer, and restart the miner_thread """ # Wait for miner to finish his job - t = time.time() + date_t = time.time() self.miner_thread.join() - self.tot_wait += time.time() - t + self.tot_wait += time.time() - date_t # Copy miner_buffer.container --> consumer_buffer.container - self.consumer_buffer.container = [elem for elem in self.miner_buffer.container] + self.consumer_buffer.container = self.miner_buffer.container.copy() # Clear miner_buffer.container self.miner_buffer.container.clear() @@ -454,27 +451,24 @@ def _collect(self): """ # Fill the miner_container until it's full while not self.miner_buffer.is_complete(): - try: - index = next(self.iterator) - with self.mining_lock: - new_sample = self.patches_reader.get_sample(index=index) - self.miner_buffer.add(new_sample) - except Exception as e: - logging.warning("Error during collecting samples: {}".format(e)) + index = next(self.iterator) + with self.mining_lock: + new_sample = self.patches_reader.get_sample(index=index) + self.miner_buffer.add(new_sample) def _summon_miner_thread(self): """ Create and starts the thread for the data collect """ - t = threading.Thread(target=self._collect) - t.start() - return t + new_thread = threading.Thread(target=self._collect) + new_thread.start() + return new_thread def _generator(self): """ Generator function, used for the tf dataset """ - for elem in range(self.size): + for _ in range(self.size): yield self.read_one_sample() def get_tf_dataset(self, batch_size, drop_remainder=True): @@ -486,7 +480,7 @@ def get_tf_dataset(self, batch_size, drop_remainder=True): """ if batch_size <= 2 * self.miner_buffer.max_length: logging.warning("Batch size is {} but dataset buffer has {} elements. Consider using a larger dataset " - "buffer to avoid I/O bottleneck".format(batch_size, self.miner_buffer.max_length)) + "buffer to avoid I/O bottleneck", batch_size, self.miner_buffer.max_length) return self.tf_dataset.batch(batch_size, drop_remainder=drop_remainder) def get_total_wait_in_seconds(self): @@ -497,9 +491,7 @@ def get_total_wait_in_seconds(self): return self.tot_wait -""" -------------------------------------------- DatasetFromPatchesImages class --------------------------------------------- -""" +# ----------------------------------------- DatasetFromPatchesImages class --------------------------------------------- class DatasetFromPatchesImages(Dataset): diff --git a/python/tricks.py b/python/tricks.py index fe4c5dea..b31b14c3 100644 --- a/python/tricks.py +++ b/python/tricks.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- # ========================================================================== # -# Copyright 2018-2019 Remi Cresson (IRSTEA) -# Copyright 2020 Remi Cresson (INRAE) +# Copyright 2018-2019 IRSTEA +# Copyright 2020-2021 INRAE # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,54 +17,42 @@ # limitations under the License. # # ==========================================================================*/ -import gdal -import numpy as np +""" +This module contains a set of python functions to interact with geospatial data +and TensorFlow models. +Starting from OTBTF >= 3.0.0, tricks is only used as a backward compatible stub +for TF 1.X versions. +""" import tensorflow.compat.v1 as tf from deprecated import deprecated - +from otbtf import gdal_open, read_as_np_arr as read_as_np_arr_from_gdal_ds tf.disable_v2_behavior() +@deprecated(version="3.0.0", reason="Please use otbtf.read_image_as_np() instead") def read_image_as_np(filename, as_patches=False): """ - Read an image as numpy array. - @param filename File name of patches image - @param as_patches True if the image must be read as patches - @return 4D numpy array [batch, h, w, c] + Read a patches-image as numpy array. + :param filename: File name of the patches-image + :param as_patches: True if the image must be read as patches + :return 4D numpy array [batch, h, w, c] (batch = 1 when as_patches is False) """ # Open a GDAL dataset - ds = gdal.Open(filename) - if ds is None: - raise Exception("Unable to open file {}".format(filename)) - - # Raster infos - n_bands = ds.RasterCount - szx = ds.RasterXSize - szy = ds.RasterYSize - - # Raster array - myarray = ds.ReadAsArray() - - # Re-order bands (when there is > 1 band) - if (len(myarray.shape) == 3): - axes = (1, 2, 0) - myarray = np.transpose(myarray, axes=axes) + gdal_ds = gdal_open(filename) - if (as_patches): - n = int(szy / szx) - return myarray.reshape((n, szx, szx, n_bands)) - - return myarray.reshape((1, szy, szx, n_bands)) + # Return patches + return read_as_np_arr_from_gdal_ds(gdal_ds=gdal_ds, as_patches=as_patches) +@deprecated(version="3.0.0", reason="Please consider using TensorFlow >= 2 to build your nets") def create_savedmodel(sess, inputs, outputs, directory): """ - Create a SavedModel - @param sess TF session - @param inputs List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"]) - @param outputs List of outputs names (e.g. ["prediction:0", "features:0"]) - @param directory Path for the generated SavedModel + Create a SavedModel from TF 1.X graphs + :param sess: The Tensorflow V1 session + :param inputs: List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"]) + :param outputs: List of outputs names (e.g. ["prediction:0", "features:0"]) + :param directory: Path for the generated SavedModel """ print("Create a SavedModel in " + directory) graph = tf.compat.v1.get_default_graph() @@ -72,14 +60,16 @@ def create_savedmodel(sess, inputs, outputs, directory): outputs_names = {o: graph.get_tensor_by_name(o) for o in outputs} tf.compat.v1.saved_model.simple_save(sess, directory, inputs=inputs_names, outputs=outputs_names) + +@deprecated(version="3.0.0", reason="Please consider using TensorFlow >= 2 to build and save your nets") def ckpt_to_savedmodel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices=False): """ - Read a Checkpoint and build a SavedModel - @param ckpt_path Path to the checkpoint file (without the ".meta" extension) - @param inputs List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"]) - @param outputs List of outputs names (e.g. ["prediction:0", "features:0"]) - @param savedmodel_path Path for the generated SavedModel - @param clear_devices Clear TF devices positionning (True/False) + Read a Checkpoint and build a SavedModel for some TF 1.X graph + :param ckpt_path: Path to the checkpoint file (without the ".meta" extension) + :param inputs: List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"]) + :param outputs: List of outputs names (e.g. ["prediction:0", "features:0"]) + :param savedmodel_path: Path for the generated SavedModel + :param clear_devices: Clear TensorFlow devices positioning (True/False) """ tf.compat.v1.reset_default_graph() with tf.compat.v1.Session() as sess: @@ -90,33 +80,17 @@ def ckpt_to_savedmodel(ckpt_path, inputs, outputs, savedmodel_path, clear_device # Create a SavedModel create_savedmodel(sess, inputs=inputs, outputs=outputs, directory=savedmodel_path) -@deprecated + +@deprecated(version="3.0.0", reason="Please use otbtf.read_image_as_np() instead") def read_samples(filename): - """ + """ Read a patches image. @param filename: raster file name """ - return read_image_as_np(filename, as_patches=True) + return read_image_as_np(filename, as_patches=True) -@deprecated -def CreateSavedModel(sess, inputs, outputs, directory): - """ - Create a SavedModel - @param sess TF session - @param inputs List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"]) - @param outputs List of outputs names (e.g. ["prediction:0", "features:0"]) - @param directory Path for the generated SavedModel - """ - create_savedmodel(sess, inputs, outputs, directory) -@deprecated -def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices=False): - """ - Read a Checkpoint and build a SavedModel - @param ckpt_path Path to the checkpoint file (without the ".meta" extension) - @param inputs List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"]) - @param outputs List of outputs names (e.g. ["prediction:0", "features:0"]) - @param savedmodel_path Path for the generated SavedModel - @param clear_devices Clear TF devices positionning (True/False) - """ - ckpt_to_savedmodel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices) +# Aliases for backward compatibility +# pylint: disable=invalid-name +CreateSavedModel = create_savedmodel +CheckpointToSavedModel = ckpt_to_savedmodel diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f2f176f4..a4260dfe 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,5 +1,24 @@ otb_module_test() +# Unit tests +set(${otb-module}Tests + otbTensorflowTests.cxx + otbTensorflowCopyUtilsTests.cxx) + +add_executable(otbTensorflowTests ${${otb-module}Tests}) + +target_include_directories(otbTensorflowTests PRIVATE ${tensorflow_include_dir}) +target_link_libraries(otbTensorflowTests ${${otb-module}-Test_LIBRARIES} ${TENSORFLOW_CC_LIB} ${TENSORFLOW_FRAMEWORK_LIB}) +otb_module_target_label(otbTensorflowTests) + +# CopyUtilsTests +otb_add_test(NAME floatValueToTensorTest COMMAND otbTensorflowTests floatValueToTensorTest) +otb_add_test(NAME intValueToTensorTest COMMAND otbTensorflowTests intValueToTensorTest) +otb_add_test(NAME boolValueToTensorTest COMMAND otbTensorflowTests boolValueToTensorTest) +otb_add_test(NAME floatVecValueToTensorTest COMMAND otbTensorflowTests floatVecValueToTensorTest) +otb_add_test(NAME intVecValueToTensorTest COMMAND otbTensorflowTests intVecValueToTensorTest) +otb_add_test(NAME boolVecValueToTensorTest COMMAND otbTensorflowTests boolVecValueToTensorTest) + # Directories set(DATADIR ${CMAKE_CURRENT_SOURCE_DIR}/data) set(MODELSDIR ${CMAKE_CURRENT_SOURCE_DIR}/models) @@ -9,12 +28,19 @@ set(IMAGEXS ${DATADIR}/xs_subset.tif) set(IMAGEPAN ${DATADIR}/pan_subset.tif) set(IMAGEPXS ${DATADIR}/pxs_subset.tif) set(IMAGEPXS2 ${DATADIR}/pxs_subset2.tif) +set(PATCHESA ${DATADIR}/Sentinel-2_B4328_10m_patches_A.jp2) +set(PATCHESB ${DATADIR}/Sentinel-2_B4328_10m_patches_B.jp2) +set(LABELSA ${DATADIR}/Sentinel-2_B4328_10m_labels_A.tif) +set(LABELSB ${DATADIR}/Sentinel-2_B4328_10m_labels_B.tif) +set(PATCHES01 ${DATADIR}/Sentinel-2_B4328_10m_patches_A.jp2) +set(PATCHES11 ${DATADIR}/Sentinel-2_B4328_10m_patches_B.jp2) # Input models set(MODEL1 ${MODELSDIR}/model1) set(MODEL2 ${MODELSDIR}/model2) set(MODEL3 ${MODELSDIR}/model3) set(MODEL4 ${MODELSDIR}/model4) +set(MODEL5 ${MODELSDIR}/model5) # Output images and baselines set(MODEL1_PB_OUT apTvClTensorflowModelServeCNN16x16PB.tif) @@ -23,6 +49,91 @@ set(MODEL2_FC_OUT apTvClTensorflowModelServeCNN8x8_32x32FC.tif) set(MODEL3_PB_OUT apTvClTensorflowModelServeFCNN16x16PB.tif) set(MODEL3_FC_OUT apTvClTensorflowModelServeFCNN16x16FC.tif) set(MODEL4_FC_OUT apTvClTensorflowModelServeFCNN64x64to32x32.tif) +set(MODEL1_SAVED model1_updated) +set(PATCHESIMG_01 patchimg_01.tif) +set(PATCHESIMG_11 patchimg_11.tif) +set(MODEL5_OUT reduce_sum.tif) + +#----------- Patches selection ---------------- +set(PATCHESPOS_01 ${TEMP}/out_train_32.gpkg) +set(PATCHESPOS_02 ${TEMP}/out_valid_32.gpkg) +set(PATCHESPOS_11 ${TEMP}/out_train_33.gpkg) +set(PATCHESPOS_12 ${TEMP}/out_valid_33.gpkg) +# Even patches +otb_test_application(NAME PatchesSelectionEven + APP PatchesSelection + OPTIONS + -in ${IMAGEPXS2} + -grid.step 32 + -grid.psize 32 + -outtrain ${PATCHESPOS_01} + -outvalid ${PATCHESPOS_02} + ) + +# Odd patches +otb_test_application(NAME PatchesSelectionOdd + APP PatchesSelection + OPTIONS + -in ${IMAGEPXS2} + -grid.step 32 + -grid.psize 33 + -outtrain ${PATCHESPOS_11} + -outvalid ${PATCHESPOS_12} + ) + +#----------- Patches extraction ---------------- +# Even patches +otb_test_application(NAME PatchesExtractionEven + APP PatchesExtraction + OPTIONS + -source1.il ${IMAGEPXS2} + -source1.patchsizex 32 + -source1.patchsizey 32 + -source1.out ${TEMP}/${PATCHESIMG_01} + -vec ${PATCHESPOS_01} + -field id + VALID --compare-image ${EPSILON_6} + ${DATADIR}/${PATCHESIMG_01} + ${TEMP}/${PATCHESIMG_01} + ) + +# Odd patches +otb_test_application(NAME PatchesExtractionOdd + APP PatchesExtraction + OPTIONS + -source1.il ${IMAGEPXS2} + -source1.patchsizex 33 + -source1.patchsizey 33 + -source1.out ${TEMP}/${PATCHESIMG_11} + -vec ${PATCHESPOS_11} + -field id + VALID --compare-image ${EPSILON_6} + ${DATADIR}/${PATCHESIMG_11} + ${TEMP}/${PATCHESIMG_11} + ) + +#----------- Model training : 1-branch CNN (16x16) Patch-Based ---------------- +set(ENV{OTB_LOGGER_LEVEL} DEBUG) +otb_test_application(NAME TensorflowModelTrainCNN16x16PB + APP TensorflowModelTrain + OPTIONS + -training.epochs 10 + -training.source1.il ${PATCHESA} + -training.source1.placeholder "x" + -training.source2.il ${LABELSA} + -training.source2.placeholder "y" + -validation.source1.il ${PATCHESB} + -validation.source2.il ${LABELSB} + -validation.source1.name "x" + -validation.source2.name "prediction" + -training.source1.patchsizex 16 -training.source1.patchsizey 16 + -training.source2.patchsizex 1 -training.source2.patchsizey 1 + -model.dir ${MODEL1} + -model.saveto ${MODEL1_SAVED} + -training.targetnodes "optimizer" + -validation.mode "class" + ) +set_tests_properties(TensorflowModelTrainCNN16x16PB PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG;$ENV{OTB_LOGGER_LEVEL}") #----------- Model serving : 1-branch CNN (16x16) Patch-Based ---------------- otb_test_application(NAME TensorflowModelServeCNN16x16PB @@ -34,6 +145,7 @@ otb_test_application(NAME TensorflowModelServeCNN16x16PB VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL1_PB_OUT} ${TEMP}/${MODEL1_PB_OUT}) +set_tests_properties(TensorflowModelServeCNN16x16PB PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG}") #----------- Model serving : 2-branch CNN (8x8, 32x32) Patch-Based ---------------- otb_test_application(NAME apTvClTensorflowModelServeCNN8x8_32x32PB @@ -47,7 +159,8 @@ otb_test_application(NAME apTvClTensorflowModelServeCNN8x8_32x32PB VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL2_PB_OUT} ${TEMP}/${MODEL2_PB_OUT}) -set_tests_properties(apTvClTensorflowModelServeCNN8x8_32x32PB PROPERTIES ENVIRONMENT "OTB_TF_NSOURCES=2;$ENV{OTB_TF_NSOURCES}") +set_tests_properties(apTvClTensorflowModelServeCNN8x8_32x32PB PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG;OTB_TF_NSOURCES=2;$ENV{OTB_TF_NSOURCES}") + #----------- Model serving : 2-branch CNN (8x8, 32x32) Fully-Conv ---------------- set(ENV{OTB_TF_NSOURCES} 2) @@ -62,7 +175,7 @@ otb_test_application(NAME apTvClTensorflowModelServeCNN8x8_32x32FC VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL2_FC_OUT} ${TEMP}/${MODEL2_FC_OUT}) -set_tests_properties(apTvClTensorflowModelServeCNN8x8_32x32FC PROPERTIES ENVIRONMENT "OTB_TF_NSOURCES=2;$ENV{OTB_TF_NSOURCES}") +set_tests_properties(apTvClTensorflowModelServeCNN8x8_32x32FC PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG;OTB_TF_NSOURCES=2;$ENV{OTB_TF_NSOURCES}") #----------- Model serving : 1-branch FCNN (16x16) Patch-Based ---------------- set(ENV{OTB_TF_NSOURCES} 1) @@ -75,6 +188,8 @@ otb_test_application(NAME apTvClTensorflowModelServeFCNN16x16PB VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL3_PB_OUT} ${TEMP}/${MODEL3_PB_OUT}) +set_tests_properties(apTvClTensorflowModelServeFCNN16x16PB PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG}") + #----------- Model serving : 1-branch FCNN (16x16) Fully-conv ---------------- set(ENV{OTB_TF_NSOURCES} 1) @@ -87,10 +202,11 @@ otb_test_application(NAME apTvClTensorflowModelServeFCNN16x16FC VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL3_FC_OUT} ${TEMP}/${MODEL3_FC_OUT}) +set_tests_properties(apTvClTensorflowModelServeFCNN16x16FC PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG}") #----------- Model serving : 1-branch FCNN (64x64)-->(32x32), Fully-conv ---------------- set(ENV{OTB_TF_NSOURCES} 1) -otb_test_application(NAME apTvClTensorflowModelServeFCNN64x64to32x32.tif +otb_test_application(NAME apTvClTensorflowModelServeFCNN64x64to32x32 APP TensorflowModelServe OPTIONS -source1.il ${IMAGEPXS2} -source1.rfieldx 64 -source1.rfieldy 64 -source1.placeholder x @@ -100,5 +216,46 @@ otb_test_application(NAME apTvClTensorflowModelServeFCNN64x64to32x32.tif VALID --compare-image ${EPSILON_6} ${DATADIR}/${MODEL4_FC_OUT} ${TEMP}/${MODEL4_FC_OUT}) +set_tests_properties(apTvClTensorflowModelServeFCNN64x64to32x32 PROPERTIES ENVIRONMENT "OTB_LOGGER_LEVEL=DEBUG}") + +#----------- Test various output tensor shapes ---------------- +# We test the following output shapes on one monochannel image: +# [None] +# [None, 1] +# [None, None, None] +# [None, None, None, 1] +set(ENV{OTB_TF_NSOURCES} 1) +otb_test_application(NAME outputTensorShapesTest1pb + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -out ${TEMP}/${MODEL5_OUT} -output.names "tf.reshape" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) +otb_test_application(NAME outputTensorShapesTest1fc + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -model.fullyconv on -out ${TEMP}/${MODEL5_OUT} -output.names "tf.reshape" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) +otb_test_application(NAME outputTensorShapesTest2pb + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -out ${TEMP}/${MODEL5_OUT} -output.names "tf.reshape_1" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) +otb_test_application(NAME outputTensorShapesTest2fc + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -model.fullyconv on -out ${TEMP}/${MODEL5_OUT} -output.names "tf.reshape_1" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) +otb_test_application(NAME outputTensorShapesTest3pb + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -out ${TEMP}/${MODEL5_OUT} -output.names "tf.reshape_2" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) +otb_test_application(NAME outputTensorShapesTest3fc + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -model.fullyconv on -out ${TEMP}/${MODEL5_OUT} -output.names "tf.reshape_2" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) +otb_test_application(NAME outputTensorShapesTest4pb + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -out ${TEMP}/${MODEL5_OUT} -output.names "tf.reshape_3" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) +otb_test_application(NAME outputTensorShapesTest4fc + APP TensorflowModelServe + OPTIONS -source1.il ${IMAGEPAN} -model.dir ${MODEL5} -model.fullyconv on -out ${TEMP}/${MODEL5_OUT} -output.names "tf.reshape_3" + VALID --compare-image ${EPSILON_6} ${IMAGEPAN} ${TEMP}/${MODEL5_OUT}) diff --git a/test/data/RF_model_from_deep_features_map.tif b/test/data/RF_model_from_deep_features_map.tif new file mode 100644 index 00000000..8d01ec19 --- /dev/null +++ b/test/data/RF_model_from_deep_features_map.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c603c85a80e6d59d18fe686a827e77066dbedb76c53fb6c35c7ed291d20cf34 +size 1316 diff --git a/test/data/Sentinel-2_B4328_10m_labels_A.tif b/test/data/Sentinel-2_B4328_10m_labels_A.tif new file mode 100644 index 00000000..63fa4e2e --- /dev/null +++ b/test/data/Sentinel-2_B4328_10m_labels_A.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0501bdcb5d8580ba9202aa2fb5fd057234387a0ffff12fbccc71d9899dd67463 +size 3286 diff --git a/test/data/Sentinel-2_B4328_10m_labels_B.tif b/test/data/Sentinel-2_B4328_10m_labels_B.tif new file mode 100644 index 00000000..0b0edfa0 --- /dev/null +++ b/test/data/Sentinel-2_B4328_10m_labels_B.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d430731c2560e40094cf9fe2d8d7c55ead665229f7ffef30b564258c1944beaf +size 3286 diff --git a/test/data/Sentinel-2_B4328_10m_patches_A.jp2 b/test/data/Sentinel-2_B4328_10m_patches_A.jp2 new file mode 100644 index 00000000..fb686f59 --- /dev/null +++ b/test/data/Sentinel-2_B4328_10m_patches_A.jp2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d14d17d4a90d0d989623a5798056ba0b2074e6f595a408af45a9c35afbfdd84d +size 756503 diff --git a/test/data/Sentinel-2_B4328_10m_patches_B.jp2 b/test/data/Sentinel-2_B4328_10m_patches_B.jp2 new file mode 100644 index 00000000..e34e865f --- /dev/null +++ b/test/data/Sentinel-2_B4328_10m_patches_B.jp2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50d7fcfb6de506824143f998cdac3bde1921e1805cb2cc741d8145006722d4e3 +size 758859 diff --git a/test/data/amsterdam_labels_A.tif b/test/data/amsterdam_labels_A.tif new file mode 100644 index 00000000..a7b0dbcf --- /dev/null +++ b/test/data/amsterdam_labels_A.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bdd9ecd2e54992d712ee44fc05a91bc38f25bfa6c6cbcdc11874fe9070a4fc0f +size 11411 diff --git a/test/data/amsterdam_labels_B.tif b/test/data/amsterdam_labels_B.tif new file mode 100644 index 00000000..fff31bcf --- /dev/null +++ b/test/data/amsterdam_labels_B.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ce48d4d57f46ca39d940ead160d7968f9b6c2a62b4f6b49132372d1a8916622 +size 11362 diff --git a/test/data/amsterdam_patches_A.tif b/test/data/amsterdam_patches_A.tif new file mode 100644 index 00000000..41323cf2 --- /dev/null +++ b/test/data/amsterdam_patches_A.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8d606712ba9b1493ffd02c08135c5d1072d2bc1e64228d00425f17cc62280c9 +size 3840365 diff --git a/test/data/amsterdam_patches_B.tif b/test/data/amsterdam_patches_B.tif new file mode 100644 index 00000000..ace9beeb --- /dev/null +++ b/test/data/amsterdam_patches_B.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df49a0a8135854191006b0828d99fa494b059cfa15ef801b263d8520bbf8566e +size 3817177 diff --git a/test/data/apTvClTensorflowModelServeCNN16x16PB.tif b/test/data/apTvClTensorflowModelServeCNN16x16PB.tif index 8d936fcb..586d3609 100644 Binary files a/test/data/apTvClTensorflowModelServeCNN16x16PB.tif and b/test/data/apTvClTensorflowModelServeCNN16x16PB.tif differ diff --git a/test/data/apTvClTensorflowModelServeCNN8x8_32x32FC.tif b/test/data/apTvClTensorflowModelServeCNN8x8_32x32FC.tif index f6155aa6..4cad559f 100644 Binary files a/test/data/apTvClTensorflowModelServeCNN8x8_32x32FC.tif and b/test/data/apTvClTensorflowModelServeCNN8x8_32x32FC.tif differ diff --git a/test/data/apTvClTensorflowModelServeCNN8x8_32x32PB.tif b/test/data/apTvClTensorflowModelServeCNN8x8_32x32PB.tif index ec9c9d9f..459f5ad9 100644 Binary files a/test/data/apTvClTensorflowModelServeCNN8x8_32x32PB.tif and b/test/data/apTvClTensorflowModelServeCNN8x8_32x32PB.tif differ diff --git a/test/data/apTvClTensorflowModelServeFCNN16x16FC.tif b/test/data/apTvClTensorflowModelServeFCNN16x16FC.tif index 14a5d11a..5da8fc25 100644 Binary files a/test/data/apTvClTensorflowModelServeFCNN16x16FC.tif and b/test/data/apTvClTensorflowModelServeFCNN16x16FC.tif differ diff --git a/test/data/apTvClTensorflowModelServeFCNN16x16PB.tif b/test/data/apTvClTensorflowModelServeFCNN16x16PB.tif index 14a5d11a..5da8fc25 100644 Binary files a/test/data/apTvClTensorflowModelServeFCNN16x16PB.tif and b/test/data/apTvClTensorflowModelServeFCNN16x16PB.tif differ diff --git a/test/data/apTvClTensorflowModelServeFCNN64x64to32x32.tif b/test/data/apTvClTensorflowModelServeFCNN64x64to32x32.tif index 1d22a3b9..08b8b6d7 100644 Binary files a/test/data/apTvClTensorflowModelServeFCNN64x64to32x32.tif and b/test/data/apTvClTensorflowModelServeFCNN64x64to32x32.tif differ diff --git a/test/data/classif_model1.tif b/test/data/classif_model1.tif new file mode 100644 index 00000000..467f0133 --- /dev/null +++ b/test/data/classif_model1.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac1e1b76123ee614062e445b173fd27f2e590782f7f653a32440d2e626ae09d9 +size 1082 diff --git a/test/data/classif_model2.tif b/test/data/classif_model2.tif new file mode 100644 index 00000000..d78613fd --- /dev/null +++ b/test/data/classif_model2.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f498681e24a61d796a307f5f2696468888995249ced64a875e06a406c3540128 +size 1510 diff --git a/test/data/classif_model3_fcn.tif b/test/data/classif_model3_fcn.tif new file mode 100644 index 00000000..9b6f406c --- /dev/null +++ b/test/data/classif_model3_fcn.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97b84882b61de5d8a965503a30872a171f6548dabb302701d76a12459cdd2148 +size 852 diff --git a/test/data/classif_model3_pb.tif b/test/data/classif_model3_pb.tif new file mode 100644 index 00000000..8abc3c9d --- /dev/null +++ b/test/data/classif_model3_pb.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c067f03b3b0b7edfde73ccf53fd07f7e104082dea6754733917051c3a54937c6 +size 857 diff --git a/test/data/classif_model4.tif b/test/data/classif_model4.tif new file mode 100644 index 00000000..76269c5b --- /dev/null +++ b/test/data/classif_model4.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2496ffad2c40ac61d8ca5903aa8fe15af65994ae1479513a199ddbb6104978af +size 7951 diff --git a/test/data/fake_spot6.jp2 b/test/data/fake_spot6.jp2 new file mode 100644 index 00000000..0fc2b9d4 --- /dev/null +++ b/test/data/fake_spot6.jp2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aece496d23eada3e3dac20c7285230c7ab9e49d96bb58b23b1f2dad5054d5002 +size 2099513 diff --git a/test/data/outvec_A.gpkg b/test/data/outvec_A.gpkg new file mode 100644 index 00000000..9f042b76 --- /dev/null +++ b/test/data/outvec_A.gpkg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6c4ef5c6cabfffc74fbdd716fbba17627c0916ae1dd9dab556e9d7589ef66ee +size 663552 diff --git a/test/data/outvec_B.gpkg b/test/data/outvec_B.gpkg new file mode 100644 index 00000000..10e824d6 --- /dev/null +++ b/test/data/outvec_B.gpkg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b8b9f30ed4acdf05e90c35f1dc12d44815dfa871527d4ebb9b5cb6614e3cbdd +size 663552 diff --git a/test/data/pan_subset.tif b/test/data/pan_subset.tif index ba380b4d..2e4a2f17 100644 Binary files a/test/data/pan_subset.tif and b/test/data/pan_subset.tif differ diff --git a/test/data/patchimg_01.tif b/test/data/patchimg_01.tif new file mode 100644 index 00000000..209aa93a --- /dev/null +++ b/test/data/patchimg_01.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e651c7fcc960ad900493161164a6429e3e2d66816c33b8e94f871a01756f6d34 +size 98776 diff --git a/test/data/patchimg_11.tif b/test/data/patchimg_11.tif new file mode 100644 index 00000000..95413dc6 --- /dev/null +++ b/test/data/patchimg_11.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d2a41d26a211570a8ad45f79e2b87d533838cf296062d5317aa2bb2cb1b464d +size 105028 diff --git a/test/data/pxs_subset.tif b/test/data/pxs_subset.tif index 89ee7990..8909cb41 100644 Binary files a/test/data/pxs_subset.tif and b/test/data/pxs_subset.tif differ diff --git a/test/data/pxs_subset2.tif b/test/data/pxs_subset2.tif index 64991c05..c74443e6 100644 Binary files a/test/data/pxs_subset2.tif and b/test/data/pxs_subset2.tif differ diff --git a/test/data/s2_10m_labels_A.tif b/test/data/s2_10m_labels_A.tif new file mode 100644 index 00000000..4ed438fa --- /dev/null +++ b/test/data/s2_10m_labels_A.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef0d65837c47f9d7904299c0e2e24077a3325483665e20bfeab0d90fc2cd6fec +size 2574 diff --git a/test/data/s2_10m_labels_B.tif b/test/data/s2_10m_labels_B.tif new file mode 100644 index 00000000..1b1e8a03 --- /dev/null +++ b/test/data/s2_10m_labels_B.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2e2785ae30e7a7788ad57007c48e791d54b1a7548bf08ad43d33acc25725c50 +size 2585 diff --git a/test/data/s2_10m_patches_A.tif b/test/data/s2_10m_patches_A.tif new file mode 100644 index 00000000..d8158f8a --- /dev/null +++ b/test/data/s2_10m_patches_A.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ab610d8a0e08216e16f7a30e5cb55c3d5ddc60415f2a2475adcf33336c22aa9 +size 9679811 diff --git a/test/data/s2_10m_patches_B.tif b/test/data/s2_10m_patches_B.tif new file mode 100644 index 00000000..3e913acc --- /dev/null +++ b/test/data/s2_10m_patches_B.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a3ba3be2e55f9b21a6d98f545e6f9b3d7bcc29bb1b93d5a421baf67adf99638 +size 9729377 diff --git a/test/data/s2_20m_patches_A.tif b/test/data/s2_20m_patches_A.tif new file mode 100644 index 00000000..a1fbc0f8 --- /dev/null +++ b/test/data/s2_20m_patches_A.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c12c791b081c86c14656d939be06e21a4b5498b7a15032e08bcfb9d579c3f877 +size 3505547 diff --git a/test/data/s2_20m_patches_B.tif b/test/data/s2_20m_patches_B.tif new file mode 100644 index 00000000..b2cf78ab --- /dev/null +++ b/test/data/s2_20m_patches_B.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4923867f94fbb3d3cd2708161bb1f0c19348b75b4a632d2d67a3f152238a3fca +size 3522046 diff --git a/test/data/s2_20m_stack.jp2 b/test/data/s2_20m_stack.jp2 new file mode 100644 index 00000000..15c2267b --- /dev/null +++ b/test/data/s2_20m_stack.jp2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c62408e325e5208be7c3ff5c5ff541f349e599121aae1f93c46b46a041111fe +size 12584832 diff --git a/test/data/s2_labels_A.tif b/test/data/s2_labels_A.tif new file mode 100644 index 00000000..ebe395d8 --- /dev/null +++ b/test/data/s2_labels_A.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b01c44ca2bfe1f2710cc70bbaaca67321aa935a8cc112f28b501dcd4261a0fe7 +size 2474 diff --git a/test/data/s2_labels_B.tif b/test/data/s2_labels_B.tif new file mode 100644 index 00000000..aa2a2872 --- /dev/null +++ b/test/data/s2_labels_B.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e8bc2b9acc30f2c44e4fbd3702cfb8dccef1845d3d6c5d84b39328f6170d0e3 +size 2485 diff --git a/test/data/s2_patches_A.tif b/test/data/s2_patches_A.tif new file mode 100644 index 00000000..6ae3bb06 --- /dev/null +++ b/test/data/s2_patches_A.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:473e6a13ff57123af49bc4b8437db96907dc33c7cc2a76d99bed0d2392c8098d +size 9679709 diff --git a/test/data/s2_patches_B.tif b/test/data/s2_patches_B.tif new file mode 100644 index 00000000..05820924 --- /dev/null +++ b/test/data/s2_patches_B.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d240e18c87a9f015dfda0600c8359dd20fc20eae66336905230c6848c482400e +size 9729275 diff --git a/test/data/s2_stack.jp2 b/test/data/s2_stack.jp2 new file mode 100644 index 00000000..630659fd --- /dev/null +++ b/test/data/s2_stack.jp2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:091a491118dbee603f8d9cac66c130cc31758a4d5745e7877ad1ee760ad5224e +size 33554488 diff --git a/test/data/terrain_truth_epsg32654_A.tif b/test/data/terrain_truth_epsg32654_A.tif new file mode 100644 index 00000000..e738b06e --- /dev/null +++ b/test/data/terrain_truth_epsg32654_A.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c26a60d41c0d4fff3edc79e3c5fefef88354bd7b3e845073ab7d087da67d3ff +size 517796 diff --git a/test/data/terrain_truth_epsg32654_B.tif b/test/data/terrain_truth_epsg32654_B.tif new file mode 100644 index 00000000..31e013f1 --- /dev/null +++ b/test/data/terrain_truth_epsg32654_B.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60f1a69609fbc09c8b1007c40adebd9e01aee96a9d3bf4c32aade7e16c29e784 +size 507062 diff --git a/test/data/xs_subset.tif b/test/data/xs_subset.tif index 36305231..153a2bbd 100644 Binary files a/test/data/xs_subset.tif and b/test/data/xs_subset.tif differ diff --git a/test/models/model1/SavedModel_cnn/saved_model.pb b/test/models/model1/SavedModel_cnn/saved_model.pb deleted file mode 100644 index 7674fb36..00000000 Binary files a/test/models/model1/SavedModel_cnn/saved_model.pb and /dev/null differ diff --git a/test/models/model1/SavedModel_cnn/variables/variables.data-00000-of-00001 b/test/models/model1/SavedModel_cnn/variables/variables.data-00000-of-00001 deleted file mode 100644 index 437b9924..00000000 Binary files a/test/models/model1/SavedModel_cnn/variables/variables.data-00000-of-00001 and /dev/null differ diff --git a/test/models/model1/SavedModel_cnn/variables/variables.index b/test/models/model1/SavedModel_cnn/variables/variables.index deleted file mode 100644 index 759c02f3..00000000 Binary files a/test/models/model1/SavedModel_cnn/variables/variables.index and /dev/null differ diff --git a/test/models/model1/saved_model.pb b/test/models/model1/saved_model.pb index 7674fb36..48fd1b8c 100644 Binary files a/test/models/model1/saved_model.pb and b/test/models/model1/saved_model.pb differ diff --git a/test/models/model1/variables/variables.data-00000-of-00001 b/test/models/model1/variables/variables.data-00000-of-00001 index 2aba6b57..16e70843 100644 Binary files a/test/models/model1/variables/variables.data-00000-of-00001 and b/test/models/model1/variables/variables.data-00000-of-00001 differ diff --git a/test/models/model2/saved_model.pb b/test/models/model2/saved_model.pb index 7269539b..345bf207 100644 Binary files a/test/models/model2/saved_model.pb and b/test/models/model2/saved_model.pb differ diff --git a/test/models/model2/variables/variables.data-00000-of-00001 b/test/models/model2/variables/variables.data-00000-of-00001 index 60cb472a..18add85a 100644 Binary files a/test/models/model2/variables/variables.data-00000-of-00001 and b/test/models/model2/variables/variables.data-00000-of-00001 differ diff --git a/test/models/model3/saved_model.pb b/test/models/model3/saved_model.pb index 099d0ce6..a48e85ad 100644 Binary files a/test/models/model3/saved_model.pb and b/test/models/model3/saved_model.pb differ diff --git a/test/models/model3/variables/variables.data-00000-of-00001 b/test/models/model3/variables/variables.data-00000-of-00001 index 6d507b96..aa215b47 100644 Binary files a/test/models/model3/variables/variables.data-00000-of-00001 and b/test/models/model3/variables/variables.data-00000-of-00001 differ diff --git a/test/models/model4/saved_model.pb b/test/models/model4/saved_model.pb index 77c215d9..b9e1fa1f 100644 Binary files a/test/models/model4/saved_model.pb and b/test/models/model4/saved_model.pb differ diff --git a/test/models/model4/variables/variables.data-00000-of-00001 b/test/models/model4/variables/variables.data-00000-of-00001 index 2837c8c9..bc1f29ba 100644 Binary files a/test/models/model4/variables/variables.data-00000-of-00001 and b/test/models/model4/variables/variables.data-00000-of-00001 differ diff --git a/test/models/model5.py b/test/models/model5.py new file mode 100644 index 00000000..cc17d52e --- /dev/null +++ b/test/models/model5.py @@ -0,0 +1,25 @@ +""" +This test checks that the output tensor shapes are supported. +The input of this model must be a mono channel image. +All 4 different output shapes supported in OTBTF are tested. + +""" +import tensorflow as tf + +# Input +x = tf.keras.Input(shape=[None, None, None], name="x") # [b, h, w, c=1] + +# Create reshaped outputs +shape = tf.shape(x) +b = shape[0] +h = shape[1] +w = shape[2] +y1 = tf.reshape(x, shape=(b*h*w,)) # [b*h*w] +y2 = tf.reshape(x, shape=(b*h*w, 1)) # [b*h*w, 1] +y3 = tf.reshape(x, shape=(b, h, w)) # [b, h, w] +y4 = tf.reshape(x, shape=(b, h, w, 1)) # [b, h, w, 1] + +# Create model +model = tf.keras.Model(inputs={"x": x}, outputs={"y1": y1, "y2": y2, "y3": y3, "y4": y4}) +model.save("model5") + diff --git a/test/models/model5/saved_model.pb b/test/models/model5/saved_model.pb new file mode 100644 index 00000000..b173da41 --- /dev/null +++ b/test/models/model5/saved_model.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7577f7a8810d9b9250e38f6e43f91751a81b9e350849df9a9e092a72bf43b4d7 +size 78685 diff --git a/test/models/model5/variables/variables.data-00000-of-00001 b/test/models/model5/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000..a29d56e6 --- /dev/null +++ b/test/models/model5/variables/variables.data-00000-of-00001 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8300c7141e98e177cfb572ae2635516cadf56b54d7ef2a98112d14a59b87e06b +size 797 diff --git a/test/models/model5/variables/variables.index b/test/models/model5/variables/variables.index new file mode 100644 index 00000000..b5dcd195 Binary files /dev/null and b/test/models/model5/variables/variables.index differ diff --git a/test/otbTensorflowCopyUtilsTests.cxx b/test/otbTensorflowCopyUtilsTests.cxx new file mode 100644 index 00000000..5b958646 --- /dev/null +++ b/test/otbTensorflowCopyUtilsTests.cxx @@ -0,0 +1,116 @@ +/*========================================================================= + + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#include "otbTensorflowCopyUtils.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "itkMacro.h" + +template +int compare(tensorflow::Tensor & t1, tensorflow::Tensor & t2) +{ + std::cout << "Compare " << t1.DebugString() << " and " << t2.DebugString() << std::endl; + if (t1.dims() != t2.dims()) + { + std::cout << "dims() differ!" << std::endl; + return EXIT_FAILURE; + } + if (t1.dtype() != t2.dtype()) + { + std::cout << "dtype() differ!" << std::endl; + return EXIT_FAILURE; + } + if (t1.NumElements() != t2.NumElements()) + { + std::cout << "NumElements() differ!" << std::endl; + return EXIT_FAILURE; + } + for (unsigned int i = 0; i < t1.NumElements(); i++) + if (t1.flat()(i) != t2.flat()(i)) + { + std::cout << "scalar " << i << " differ!" << std::endl; + return EXIT_FAILURE; + } + // Else + std::cout << "Tensors are equals :)" << std::endl; + return EXIT_SUCCESS; +} + +template +int genericValueToTensorTest(tensorflow::DataType dt, std::string expr, T value) +{ + tensorflow::Tensor t = otb::tf::ValueToTensor(expr); + tensorflow::Tensor t_ref(dt, tensorflow::TensorShape({1})); + t_ref.scalar()() = value; + + return compare(t, t_ref); +} + +int floatValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericValueToTensorTest(tensorflow::DT_FLOAT, "0.1234", 0.1234) + && genericValueToTensorTest(tensorflow::DT_FLOAT, "-0.1234", -0.1234) ; +} + +int intValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericValueToTensorTest(tensorflow::DT_INT32, "1234", 1234) + && genericValueToTensorTest(tensorflow::DT_INT32, "-1234", -1234); +} + +int boolValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericValueToTensorTest(tensorflow::DT_BOOL, "true", true) + && genericValueToTensorTest(tensorflow::DT_BOOL, "True", true) + && genericValueToTensorTest(tensorflow::DT_BOOL, "False", false) + && genericValueToTensorTest(tensorflow::DT_BOOL, "false", false); +} + +template +int genericVecValueToTensorTest(tensorflow::DataType dt, std::string expr, std::vector values, std::size_t size) +{ + tensorflow::Tensor t = otb::tf::ValueToTensor(expr); + tensorflow::Tensor t_ref(dt, tensorflow::TensorShape({size})); + unsigned int i = 0; + for (auto value: values) + { + t_ref.flat()(i) = value; + i++; + } + + return compare(t, t_ref); +} + +int floatVecValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericVecValueToTensorTest(tensorflow::DT_FLOAT, + "(0.1234, -1,-20,2.56 ,3.5)", + std::vector({0.1234, -1, -20, 2.56 ,3.5}), + 5); +} + +int intVecValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericVecValueToTensorTest(tensorflow::DT_INT32, + "(1234, -1,-20,256 ,35)", + std::vector({1234, -1, -20, 256 ,35}), + 5); +} + +int boolVecValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericVecValueToTensorTest(tensorflow::DT_BOOL, + "(true, false,True, False)", + std::vector({true, false, true, false}), + 4); +} + + diff --git a/test/otbTensorflowTests.cxx b/test/otbTensorflowTests.cxx new file mode 100644 index 00000000..50e9a91a --- /dev/null +++ b/test/otbTensorflowTests.cxx @@ -0,0 +1,23 @@ +/*========================================================================= + + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2021 INRAE + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#include "otbTestMain.h" + +void RegisterTests() +{ + REGISTER_TEST(floatValueToTensorTest); + REGISTER_TEST(intValueToTensorTest); + REGISTER_TEST(boolValueToTensorTest); + REGISTER_TEST(floatVecValueToTensorTest); + REGISTER_TEST(intVecValueToTensorTest); + REGISTER_TEST(boolVecValueToTensorTest); +} + diff --git a/test/sr4rs_unittest.py b/test/sr4rs_unittest.py new file mode 100644 index 00000000..89c3945f --- /dev/null +++ b/test/sr4rs_unittest.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import unittest +import os +from pathlib import Path +import test_utils + + +def command_train_succeed(extra_opts=""): + root_dir = os.environ["CI_PROJECT_DIR"] + ckpt_dir = "/tmp/" + + def _input(file_name): + return "{}/sr4rs_data/input/{}".format(root_dir, file_name) + + command = "python {}/sr4rs/code/train.py ".format(root_dir) + command += "--lr_patches " + command += _input("DIM_SPOT6_MS_202007290959110_ORT_ORTHO-MS-193_posA_s2.jp2 ") + command += _input("DIM_SPOT7_MS_202004111036186_ORT_ORTHO-MS-081_posA_s2.jp2 ") + command += _input("DIM_SPOT7_MS_202006201000507_ORT_ORTHO-MS-054_posA_s2.jp2 ") + command += "--hr_patches " + command += _input("DIM_SPOT6_MS_202007290959110_ORT_ORTHO-MS-193_posA_s6_cal.jp2 ") + command += _input("DIM_SPOT7_MS_202004111036186_ORT_ORTHO-MS-081_posA_s6_cal.jp2 ") + command += _input("DIM_SPOT7_MS_202006201000507_ORT_ORTHO-MS-054_posA_s6_cal.jp2 ") + command += "--save_ckpt {} ".format(ckpt_dir) + command += "--depth 4 " + command += "--nresblocks 1 " + command += "--epochs 1 " + command += extra_opts + os.system(command) + file = Path("{}/checkpoint".format(ckpt_dir)) + return file.is_file() + + +class SR4RSv1Test(unittest.TestCase): + + def test_train_nostream(self): + self.assertTrue(command_train_succeed()) + + def test_train_stream(self): + self.assertTrue(command_train_succeed(extra_opts="--streaming")) + + def test_inference(self): + root_dir = os.environ["CI_PROJECT_DIR"] + out_img = "/tmp/sr4rs.tif" + baseline = "{}/sr4rs_data/baseline/sr4rs.tif".format(root_dir) + + command = "python {}/sr4rs/code/sr.py ".format(root_dir) + command += "--input {}/sr4rs_data/input/".format(root_dir) + command += "SENTINEL2B_20200929-104857-489_L2A_T31TEJ_C_V2-2_FRE_10m.tif " + command += "--savedmodel {}/sr4rs_sentinel2_bands4328_france2020_savedmodel/ ".format(root_dir) + command += "--output '{}?&box=256:256:512:512'".format(out_img) + os.system(command) + + self.assertTrue(test_utils.compare(baseline, out_img)) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 00000000..c07301e9 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,56 @@ +import otbApplication +import os + + +def get_nb_of_channels(raster): + """ + Return the number of channels in the input raster + :param raster: raster filename (str) + :return the number of channels in the image (int) + """ + info = otbApplication.Registry.CreateApplication("ReadImageInfo") + info.SetParameterString("in", raster) + info.ExecuteAndWriteOutput() + return info.GetParameterInt('numberbands') + + +def compare(raster1, raster2, tol=0.01): + """ + Return True if the two rasters have the same contents in each bands + :param raster1: raster 1 filename (str) + :param raster2: raster 2 filename (str) + :param tol: tolerance (float) + """ + n_bands1 = get_nb_of_channels(raster1) + n_bands2 = get_nb_of_channels(raster2) + if n_bands1 != n_bands2: + print("The images have not the same number of channels") + return False + + for i in range(1, 1 + n_bands1): + comp = otbApplication.Registry.CreateApplication('CompareImages') + comp.SetParameterString('ref.in', raster1) + comp.SetParameterInt('ref.channel', i) + comp.SetParameterString('meas.in', raster2) + comp.SetParameterInt('meas.channel', i) + comp.Execute() + mae = comp.GetParameterFloat('mae') + if mae > tol: + print("The images have not the same content in channel {} " + "(Mean average error: {})".format(i, mae)) + return False + return True + + +def resolve_paths(filename, var_list): + """ + Retrieve environment variables in paths + :param filename: file name + :params var_list: variable list + :return filename with retrieved environment variables + """ + new_filename = filename + for var in var_list: + new_filename = new_filename.replace("${}".format(var), os.environ[var]) + print("Resolve filename...\n\tfilename: {}, \n\tnew filename: {}".format(filename, new_filename)) + return new_filename diff --git a/test/tutorial_unittest.py b/test/tutorial_unittest.py new file mode 100644 index 00000000..7934862f --- /dev/null +++ b/test/tutorial_unittest.py @@ -0,0 +1,513 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import pytest +import unittest +import os +from pathlib import Path +import test_utils + +INFERENCE_MAE_TOL = 10.0 # Dummy value: we don't really care of the mae value but rather the image size etc + + +def resolve_paths(path): + """ + Resolve a path with the environment variables + """ + return test_utils.resolve_paths(path, var_list=["TMPDIR", "DATADIR"]) + + +def run_command(command): + """ + Run a command + :param command: the command to run + """ + full_command = resolve_paths(command) + print("Running command: \n\t {}".format(full_command)) + os.system(full_command) + + +def run_command_and_test_exist(command, file_list): + """ + :param command: the command to run (str) + :param file_list: list of files to check + :return True or False + """ + run_command(command) + print("Checking if files exist...") + for file in file_list: + print("\t{}".format(file)) + path = Path(resolve_paths(file)) + if not path.is_file(): + print("File {} does not exist!".format(file)) + return False + print("\tOk") + return True + + +def run_command_and_compare(command, to_compare_dict, tol=0.01): + """ + :param command: the command to run (str) + :param to_compare_dict: a dict of {baseline1: output1, ..., baselineN: outputN} + :param tol: tolerance (float) + :return True or False + """ + + run_command(command) + for baseline, output in to_compare_dict.items(): + if not test_utils.compare(resolve_paths(baseline), resolve_paths(output), tol): + print("Baseline {} and output {} differ.".format(baseline, output)) + return False + return True + + +class TutorialTest(unittest.TestCase): + + @pytest.mark.order(1) + def test_sample_selection(self): + self.assertTrue( + run_command_and_test_exist( + command="otbcli_LabelImageSampleSelection " + "-inref $DATADIR/terrain_truth_epsg32654_A.tif " + "-nodata 255 " + "-outvec $TMPDIR/outvec_A.gpkg", + file_list=["$TMPDIR/outvec_A.gpkg"])) + self.assertTrue( + run_command_and_test_exist( + command="otbcli_LabelImageSampleSelection " + "-inref $DATADIR/terrain_truth_epsg32654_B.tif " + "-nodata 255 " + "-outvec $TMPDIR/outvec_B.gpkg", + file_list=["$TMPDIR/outvec_B.gpkg"])) + + @pytest.mark.order(2) + def test_patches_extraction(self): + self.assertTrue( + run_command_and_compare( + command="otbcli_PatchesExtraction " + "-source1.il $DATADIR/s2_stack.jp2 " + "-source1.out $TMPDIR/s2_patches_A.tif " + "-source1.patchsizex 16 " + "-source1.patchsizey 16 " + "-vec $TMPDIR/outvec_A.gpkg " + "-field class " + "-outlabels $TMPDIR/s2_labels_A.tif", + to_compare_dict={"$DATADIR/s2_patches_A.tif": "$TMPDIR/s2_patches_A.tif", + "$DATADIR/s2_labels_A.tif": "$TMPDIR/s2_labels_A.tif"})) + self.assertTrue( + run_command_and_compare( + command="otbcli_PatchesExtraction " + "-source1.il $DATADIR/s2_stack.jp2 " + "-source1.out $TMPDIR/s2_patches_B.tif " + "-source1.patchsizex 16 " + "-source1.patchsizey 16 " + "-vec $TMPDIR/outvec_B.gpkg " + "-field class " + "-outlabels $TMPDIR/s2_labels_B.tif", + to_compare_dict={"$DATADIR/s2_patches_B.tif": "$TMPDIR/s2_patches_B.tif", + "$DATADIR/s2_labels_B.tif": "$TMPDIR/s2_labels_B.tif"})) + + @pytest.mark.order(3) + def test_generate_model1(self): + run_command("git clone https://github.com/remicres/otbtf_tutorials_resources.git " + "$TMPDIR/otbtf_tuto_repo") + self.assertTrue( + run_command_and_test_exist( + command="python $TMPDIR/otbtf_tuto_repo/01_patch_based_classification/models/create_model1.py " + "$TMPDIR/model1", + file_list=["$TMPDIR/model1/saved_model.pb"])) + + @pytest.mark.order(4) + def test_model1_train(self): + self.assertTrue( + run_command_and_test_exist( + command="otbcli_TensorflowModelTrain " + "-training.source1.il $DATADIR/s2_patches_A.tif " + "-training.source1.patchsizex 16 " + "-training.source1.patchsizey 16 " + "-training.source1.placeholder x " + "-training.source2.il $DATADIR/s2_labels_A.tif " + "-training.source2.patchsizex 1 " + "-training.source2.patchsizey 1 " + "-training.source2.placeholder y " + "-model.dir $TMPDIR/model1 " + "-training.targetnodes optimizer " + "-training.epochs 10 " + "-validation.mode class " + "-validation.source1.il $DATADIR/s2_patches_B.tif " + "-validation.source1.name x " + "-validation.source2.il $DATADIR/s2_labels_B.tif " + "-validation.source2.name prediction " + "-model.saveto $TMPDIR/model1/variables/variables", + file_list=["$TMPDIR/model1/variables/variables.index"] + ) + ) + + @pytest.mark.order(5) + def test_model1_inference_pb(self): + self.assertTrue( + run_command_and_compare( + command="otbcli_TensorflowModelServe " + "-source1.il $DATADIR/s2_stack.jp2 " + "-source1.rfieldx 16 " + "-source1.rfieldy 16 " + "-source1.placeholder x " + "-model.dir $TMPDIR/model1 " + "-output.names prediction " + "-out \"$TMPDIR/classif_model1.tif?&box=4000:4000:1000:1000\" uint8", + to_compare_dict={"$DATADIR/classif_model1.tif": "$TMPDIR/classif_model1.tif"}, + tol=INFERENCE_MAE_TOL)) + + @pytest.mark.order(6) + def test_model1_inference_fcn(self): + self.assertTrue( + run_command_and_compare( + command="otbcli_TensorflowModelServe " + "-source1.il $DATADIR/s2_stack.jp2 " + "-source1.rfieldx 16 " + "-source1.rfieldy 16 " + "-source1.placeholder x " + "-model.dir $TMPDIR/model1 " + "-output.names prediction " + "-model.fullyconv on " + "-output.spcscale 4 " + "-out \"$TMPDIR/classif_model1.tif?&box=1000:1000:256:256\" uint8", + to_compare_dict={"$DATADIR/classif_model1.tif": "$TMPDIR/classif_model1.tif"}, + tol=INFERENCE_MAE_TOL)) + + @pytest.mark.order(7) + def test_rf_sampling(self): + self.assertTrue( + run_command_and_test_exist( + command="otbcli_SampleExtraction " + "-in $DATADIR/s2_stack.jp2 " + "-vec $TMPDIR/outvec_A.gpkg " + "-field class " + "-out $TMPDIR/pixelvalues_A.gpkg", + file_list=["$TMPDIR/pixelvalues_A.gpkg"])) + self.assertTrue( + run_command_and_test_exist( + command="otbcli_SampleExtraction " + "-in $DATADIR/s2_stack.jp2 " + "-vec $TMPDIR/outvec_B.gpkg " + "-field class " + "-out $TMPDIR/pixelvalues_B.gpkg", + file_list=["$TMPDIR/pixelvalues_B.gpkg"])) + + @pytest.mark.order(8) + def test_rf_training(self): + self.assertTrue( + run_command_and_test_exist( + command="otbcli_TrainVectorClassifier " + "-io.vd $TMPDIR/pixelvalues_A.gpkg " + "-valid.vd $TMPDIR/pixelvalues_B.gpkg " + "-feat value_0 value_1 value_2 value_3 " + "-cfield class " + "-classifier rf " + "-io.out $TMPDIR/randomforest_model.yaml ", + file_list=["$TMPDIR/randomforest_model.yaml"])) + + @pytest.mark.order(9) + def test_generate_model2(self): + self.assertTrue( + run_command_and_test_exist( + command="python $TMPDIR/otbtf_tuto_repo/01_patch_based_classification/models/create_model2.py " + "$TMPDIR/model2", + file_list=["$TMPDIR/model2/saved_model.pb"])) + + @pytest.mark.order(10) + def test_model2_train(self): + self.assertTrue( + run_command_and_test_exist( + command="otbcli_TensorflowModelTrain " + "-training.source1.il $DATADIR/s2_patches_A.tif " + "-training.source1.patchsizex 16 " + "-training.source1.patchsizey 16 " + "-training.source1.placeholder x " + "-training.source2.il $DATADIR/s2_labels_A.tif " + "-training.source2.patchsizex 1 " + "-training.source2.patchsizey 1 " + "-training.source2.placeholder y " + "-model.dir $TMPDIR/model2 " + "-training.targetnodes optimizer " + "-training.epochs 10 " + "-validation.mode class " + "-validation.source1.il $DATADIR/s2_patches_B.tif " + "-validation.source1.name x " + "-validation.source2.il $DATADIR/s2_labels_B.tif " + "-validation.source2.name prediction " + "-model.saveto $TMPDIR/model2/variables/variables", + file_list=["$TMPDIR/model2/variables/variables.index"])) + + @pytest.mark.order(11) + def test_model2_inference_fcn(self): + self.assertTrue( + run_command_and_compare(command="otbcli_TensorflowModelServe " + "-source1.il $DATADIR/s2_stack.jp2 " + "-source1.rfieldx 16 " + "-source1.rfieldy 16 " + "-source1.placeholder x " + "-model.dir $TMPDIR/model2 " + "-model.fullyconv on " + "-output.names prediction " + "-out \"$TMPDIR/classif_model2.tif?&box=4000:4000:1000:1000\" uint8", + to_compare_dict={"$DATADIR/classif_model2.tif": "$TMPDIR/classif_model2.tif"}, + tol=INFERENCE_MAE_TOL)) + + @pytest.mark.order(12) + def test_model2rf_train(self): + self.assertTrue( + run_command_and_test_exist( + command="otbcli_TrainClassifierFromDeepFeatures " + "-source1.il $DATADIR/s2_stack.jp2 " + "-source1.rfieldx 16 " + "-source1.rfieldy 16 " + "-source1.placeholder x " + "-model.dir $TMPDIR/model2 " + "-model.fullyconv on " + "-optim.tilesizex 999999 " + "-optim.tilesizey 128 " + "-output.names features " + "-vd $TMPDIR/outvec_A.gpkg " + "-valid $TMPDIR/outvec_B.gpkg " + "-sample.vfn class " + "-sample.bm 0 " + "-classifier rf " + "-out $TMPDIR/RF_model_from_deep_features.yaml", + file_list=["$TMPDIR/RF_model_from_deep_features.yaml"])) + + @pytest.mark.order(13) + def test_model2rf_inference(self): + self.assertTrue( + run_command_and_compare( + command="otbcli_ImageClassifierFromDeepFeatures " + "-source1.il $DATADIR/s2_stack.jp2 " + "-source1.rfieldx 16 " + "-source1.rfieldy 16 " + "-source1.placeholder x " + "-deepmodel.dir $TMPDIR/model2 " + "-deepmodel.fullyconv on " + "-output.names features " + "-model $TMPDIR/RF_model_from_deep_features.yaml " + "-out \"$TMPDIR/RF_model_from_deep_features_map.tif?&box=4000:4000:1000:1000\" uint8", + to_compare_dict={ + "$DATADIR/RF_model_from_deep_features_map.tif": "$TMPDIR/RF_model_from_deep_features_map.tif"}, + tol=INFERENCE_MAE_TOL)) + + @pytest.mark.order(14) + def test_patch_extraction_20m(self): + self.assertTrue( + run_command_and_compare( + command="OTB_TF_NSOURCES=2 otbcli_PatchesExtraction " + "-source1.il $DATADIR/s2_20m_stack.jp2 " + "-source1.patchsizex 8 " + "-source1.patchsizey 8 " + "-source1.out $TMPDIR/s2_20m_patches_A.tif " + "-source2.il $DATADIR/s2_stack.jp2 " + "-source2.patchsizex 16 " + "-source2.patchsizey 16 " + "-source2.out $TMPDIR/s2_10m_patches_A.tif " + "-vec $TMPDIR/outvec_A.gpkg " + "-field class " + "-outlabels $TMPDIR/s2_10m_labels_A.tif uint8", + to_compare_dict={"$DATADIR/s2_10m_labels_A.tif": "$TMPDIR/s2_10m_labels_A.tif", + "$DATADIR/s2_10m_patches_A.tif": "$TMPDIR/s2_10m_patches_A.tif", + "$DATADIR/s2_20m_patches_A.tif": "$TMPDIR/s2_20m_patches_A.tif"})) + self.assertTrue( + run_command_and_compare( + command="OTB_TF_NSOURCES=2 otbcli_PatchesExtraction " + "-source1.il $DATADIR/s2_20m_stack.jp2 " + "-source1.patchsizex 8 " + "-source1.patchsizey 8 " + "-source1.out $TMPDIR/s2_20m_patches_B.tif " + "-source2.il $DATADIR/s2_stack.jp2 " + "-source2.patchsizex 16 " + "-source2.patchsizey 16 " + "-source2.out $TMPDIR/s2_10m_patches_B.tif " + "-vec $TMPDIR/outvec_B.gpkg " + "-field class " + "-outlabels $TMPDIR/s2_10m_labels_B.tif uint8", + to_compare_dict={"$DATADIR/s2_10m_labels_B.tif": "$TMPDIR/s2_10m_labels_B.tif", + "$DATADIR/s2_10m_patches_B.tif": "$TMPDIR/s2_10m_patches_B.tif", + "$DATADIR/s2_20m_patches_B.tif": "$TMPDIR/s2_20m_patches_B.tif"})) + + @pytest.mark.order(15) + def test_generate_model3(self): + self.assertTrue( + run_command_and_test_exist( + command="python $TMPDIR/otbtf_tuto_repo/01_patch_based_classification/models/create_model3.py " + "$TMPDIR/model3", + file_list=["$TMPDIR/model3/saved_model.pb"])) + + @pytest.mark.order(16) + def test_model3_train(self): + self.assertTrue( + run_command_and_test_exist( + command="OTB_TF_NSOURCES=2 otbcli_TensorflowModelTrain " + "-training.source1.il $DATADIR/s2_20m_patches_A.tif " + "-training.source1.patchsizex 8 " + "-training.source1.patchsizey 8 " + "-training.source1.placeholder x1 " + "-training.source2.il $DATADIR/s2_10m_patches_A.tif " + "-training.source2.patchsizex 16 " + "-training.source2.patchsizey 16 " + "-training.source2.placeholder x2 " + "-training.source3.il $DATADIR/s2_10m_labels_A.tif " + "-training.source3.patchsizex 1 " + "-training.source3.patchsizey 1 " + "-training.source3.placeholder y " + "-model.dir $TMPDIR/model3 " + "-training.targetnodes optimizer " + "-training.epochs 10 " + "-validation.mode class " + "-validation.source1.il $DATADIR/s2_20m_patches_B.tif " + "-validation.source1.name x1 " + "-validation.source2.il $DATADIR/s2_10m_patches_B.tif " + "-validation.source2.name x2 " + "-validation.source3.il $DATADIR/s2_10m_labels_B.tif " + "-validation.source3.name prediction " + "-model.saveto $TMPDIR/model3/variables/variables", + file_list=["$TMPDIR/model3/variables/variables.index"])) + + @pytest.mark.order(17) + def test_model3_inference_pb(self): + self.assertTrue( + run_command_and_compare( + command= + "OTB_TF_NSOURCES=2 otbcli_TensorflowModelServe " + "-source1.il $DATADIR/s2_20m_stack.jp2 " + "-source1.rfieldx 8 " + "-source1.rfieldy 8 " + "-source1.placeholder x1 " + "-source2.il $DATADIR/s2_stack.jp2 " + "-source2.rfieldx 16 " + "-source2.rfieldy 16 " + "-source2.placeholder x2 " + "-model.dir $TMPDIR/model3 " + "-output.names prediction " + "-out \"$TMPDIR/classif_model3_pb.tif?&box=2000:2000:500:500&gdal:co:compress=deflate\"", + to_compare_dict={"$DATADIR/classif_model3_pb.tif": "$TMPDIR/classif_model3_pb.tif"}, + tol=INFERENCE_MAE_TOL)) + + @pytest.mark.order(18) + def test_model3_inference_fcn(self): + self.assertTrue( + run_command_and_compare( + command= + "OTB_TF_NSOURCES=2 otbcli_TensorflowModelServe " + "-source1.il $DATADIR/s2_20m_stack.jp2 " + "-source1.rfieldx 8 " + "-source1.rfieldy 8 " + "-source1.placeholder x1 " + "-source2.il $DATADIR/s2_stack.jp2 " + "-source2.rfieldx 16 " + "-source2.rfieldy 16 " + "-source2.placeholder x2 " + "-model.dir $TMPDIR/model3 " + "-model.fullyconv on " + "-output.names prediction " + "-out \"$TMPDIR/classif_model3_fcn.tif?&box=2000:2000:500:500&gdal:co:compress=deflate\"", + to_compare_dict={"$DATADIR/classif_model3_fcn.tif": "$TMPDIR/classif_model3_fcn.tif"}, + tol=INFERENCE_MAE_TOL)) + + @pytest.mark.order(19) + def test_generate_model4(self): + self.assertTrue( + run_command_and_test_exist( + command="python $TMPDIR/otbtf_tuto_repo/02_semantic_segmentation/models/create_model4.py " + "$TMPDIR/model4", + file_list=["$TMPDIR/model4/saved_model.pb"])) + + @pytest.mark.order(20) + def test_patches_selection_semseg(self): + self.assertTrue( + run_command_and_test_exist( + command="otbcli_PatchesSelection " + "-in $DATADIR/fake_spot6.jp2 " + "-grid.step 64 " + "-grid.psize 64 " + "-outtrain $TMPDIR/outvec_A_semseg.gpkg " + "-outvalid $TMPDIR/outvec_B_semseg.gpkg", + file_list=["$TMPDIR/outvec_A_semseg.gpkg", + "$TMPDIR/outvec_B_semseg.gpkg"])) + + @pytest.mark.order(21) + def test_patch_extraction_semseg(self): + self.assertTrue( + run_command_and_compare( + command="OTB_TF_NSOURCES=2 otbcli_PatchesExtraction " + "-source1.il $DATADIR/fake_spot6.jp2 " + "-source1.patchsizex 64 " + "-source1.patchsizey 64 " + "-source1.out \"$TMPDIR/amsterdam_patches_A.tif?&gdal:co:compress=deflate\" " + "-source2.il $TMPDIR/otbtf_tuto_repo/02_semantic_segmentation/" + "amsterdam_dataset/terrain_truth/amsterdam_labelimage.tif " + "-source2.patchsizex 64 " + "-source2.patchsizey 64 " + "-source2.out \"$TMPDIR/amsterdam_labels_A.tif?&gdal:co:compress=deflate\" " + "-vec $TMPDIR/outvec_A_semseg.gpkg " + "-field id ", + to_compare_dict={"$DATADIR/amsterdam_labels_A.tif": "$TMPDIR/amsterdam_labels_A.tif", + "$DATADIR/amsterdam_patches_A.tif": "$TMPDIR/amsterdam_patches_A.tif"})) + self.assertTrue( + run_command_and_compare( + command="OTB_TF_NSOURCES=2 otbcli_PatchesExtraction " + "-source1.il $DATADIR/fake_spot6.jp2 " + "-source1.patchsizex 64 " + "-source1.patchsizey 64 " + "-source1.out \"$TMPDIR/amsterdam_patches_B.tif?&gdal:co:compress=deflate\" " + "-source2.il $TMPDIR/otbtf_tuto_repo/02_semantic_segmentation/" + "amsterdam_dataset/terrain_truth/amsterdam_labelimage.tif " + "-source2.patchsizex 64 " + "-source2.patchsizey 64 " + "-source2.out \"$TMPDIR/amsterdam_labels_B.tif?&gdal:co:compress=deflate\" " + "-vec $TMPDIR/outvec_B_semseg.gpkg " + "-field id ", + to_compare_dict={"$DATADIR/amsterdam_labels_B.tif": "$TMPDIR/amsterdam_labels_B.tif", + "$DATADIR/amsterdam_patches_B.tif": "$TMPDIR/amsterdam_patches_B.tif"})) + + @pytest.mark.order(22) + def test_model4_train(self): + self.assertTrue( + run_command_and_test_exist( + command="OTB_TF_NSOURCES=1 otbcli_TensorflowModelTrain " + "-training.source1.il $DATADIR/amsterdam_patches_A.tif " + "-training.source1.patchsizex 64 " + "-training.source1.patchsizey 64 " + "-training.source1.placeholder x " + "-training.source2.il $DATADIR/amsterdam_labels_A.tif " + "-training.source2.patchsizex 64 " + "-training.source2.patchsizey 64 " + "-training.source2.placeholder y " + "-model.dir $TMPDIR/model4 " + "-training.targetnodes optimizer " + "-training.epochs 10 " + "-validation.mode class " + "-validation.source1.il $DATADIR/amsterdam_patches_B.tif " + "-validation.source1.name x " + "-validation.source2.il $DATADIR/amsterdam_labels_B.tif " + "-validation.source2.name prediction " + "-model.saveto $TMPDIR/model4/variables/variables", + file_list=["$TMPDIR/model4/variables/variables.index"])) + + @pytest.mark.order(23) + def test_model4_inference(self): + self.assertTrue( + run_command_and_compare( + command= + "otbcli_TensorflowModelServe " + "-source1.il $DATADIR/fake_spot6.jp2 " + "-source1.rfieldx 64 " + "-source1.rfieldy 64 " + "-source1.placeholder x " + "-model.dir $TMPDIR/model4 " + "-model.fullyconv on " + "-output.names prediction_fcn " + "-output.efieldx 32 " + "-output.efieldy 32 " + "-out \"$TMPDIR/classif_model4.tif?&gdal:co:compress=deflate\" uint8", + to_compare_dict={"$DATADIR/classif_model4.tif": "$TMPDIR/classif_model4.tif"}, + tol=INFERENCE_MAE_TOL)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tools/docker/README.md b/tools/docker/README.md index 8722b52e..4246b74d 100644 --- a/tools/docker/README.md +++ b/tools/docker/README.md @@ -74,7 +74,7 @@ docker build --network='host' -t otbtf:oldstable-gpu --build-arg BASE_IMG=nvidia ### Build for another machine and save TF compiled files ```bash -# Use same ubuntu and CUDA version than your target machine, beware of CC optimization and CPU compatibilty +# Use same ubuntu and CUDA version than your target machine, beware of CC optimization and CPU compatibility # (set env variable CC_OPT_FLAGS and avoid "-march=native" if your Docker's CPU is optimized with AVX2/AVX512 but your target CPU isn't) docker build --network='host' -t otbtf:custom --build-arg BASE_IMG=nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 \ --build-arg TF=v2.5.0 --build-arg ZIP_TF_BIN=true . @@ -146,7 +146,7 @@ $ mapla ``` ## Common errors -Buid : +Build : `Error response from daemon: manifest for nvidia/cuda:11.0-cudnn8-devel-ubuntu20.04 not found: manifest unknown: manifest unknown` => Image is missing from dockerhub