Skip to content

Commit

Permalink
Merge pull request #18 from luxonis/sanitize-nn-output
Browse files Browse the repository at this point in the history
Sanitize NN output (NaN, Infinite); handle corrupted frames
  • Loading branch information
SzabolcsGergely authored May 12, 2020
2 parents cb09505 + dbb6996 commit 1022dbb
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 10 deletions.
4 changes: 2 additions & 2 deletions host/core/host_data_packet.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ struct HostDataPacket
} catch (const std::exception& e)
{
std::cerr << e.what() << std::endl;
result = new py::array(py::dtype("f"), {1}, {});
result = nullptr;
}

//py::gil_scoped_release release; // REQUIRED ???

// std::cout << "===> c++ getPythonNumpyArray " << t.ellapsed_us() << " us\n";
Expand Down
14 changes: 14 additions & 0 deletions host/core/nnet/tensor_entry.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,18 @@ struct TensorEntry
return getFloatByIndex(arr_index);
}

// Returns true if every float value in this tensor entry is finite
// (neither NaN nor +/-infinity). Corrupted frames from the device can
// carry such values; callers discard entries that fail this check.
bool checkValidTensorEntry() const
{
    // Hoist the loop bound: the property count does not change mid-scan.
    const int numProperties = getPropertiesNumber();
    for(int idx = 0; idx < numProperties; idx++)
    {
        const float tensorValue = getFloatByIndex(idx);
        // Use the std-qualified spellings: C++ <cmath> only guarantees
        // std::isnan/std::isinf; the unqualified names are the C <math.h>
        // macros and are not portable in C++ translation units.
        // NOTE(review): requires <cmath> to be included — confirm it is
        // pulled in by this header's include chain.
        if(std::isnan(tensorValue) || std::isinf(tensorValue))
        {
            printf("invalid tensor packet, discarding \n");
            return false;
        }
    }
    return true;
}

};
22 changes: 16 additions & 6 deletions host/core/nnet/tensor_entry_container.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ class TensorEntryContainer
// &ti.output_property_value_string_to_index[ti.output_properties_dimensions[0]];
// }
}

entry.push_back(te);
if(te.checkValidTensorEntry() == true)
{
entry.push_back(te);
}
}

return entry;
Expand Down Expand Up @@ -112,12 +114,20 @@ struct PyTensorEntryContainerIterator

// Advances the iterator and returns the next valid tensor entry group.
// Entries whose values failed sanitization come back empty from
// getByIndex() and are skipped here, so Python callers never observe
// a corrupted entry. Throws py::stop_iteration when the container is
// exhausted (the pybind11 convention for ending Python iteration).
std::vector<TensorEntry> next()
{
    while(true)
    {
        if (index == seq.size())
        {
            throw py::stop_iteration();
        }
        // Fetch the current entry and advance past it in one step.
        std::vector<TensorEntry> next_entry = seq.getByIndex(index++);
        if(next_entry.empty())
        {
            // Discarded (invalid) entry — keep scanning forward.
            continue;
        }
        return next_entry;
    }
}

TensorEntryContainer &seq;
Expand Down
1 change: 1 addition & 0 deletions host/core/nnet/tensor_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ struct TensorInfo
{
assert(output_entry_iteration_index < output_dimensions.size());
//assert(false); // TODO: correct ?
assert(output_dimensions[output_entry_iteration_index] != 0);
return getTensorSize() / output_dimensions[output_entry_iteration_index];
}
}
Expand Down
10 changes: 9 additions & 1 deletion host/core/pipeline/host_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@ void HostPipeline::onNewData(

if (!_data_queue_lf.push(host_data))
{
std::cout << "Data queue is full " << info.name << ":\n";
std::unique_lock<std::mutex> guard(q_lock);
_data_queue_lf.pop();
guard.unlock();
if (!_data_queue_lf.push(host_data))
{
std::cerr << "Data queue is full " << info.name << ":\n";
}
}

// std::cout << "===> onNewData " << t.ellapsed_us() << " us\n";
Expand Down Expand Up @@ -80,6 +86,7 @@ void HostPipeline::consumePackets()
Timer consume_dur;
_consumed_packets.clear();

std::unique_lock<std::mutex> guard(q_lock);
_data_queue_lf.consume_all(
[this]
(std::shared_ptr<HostDataPacket>& data)
Expand All @@ -88,6 +95,7 @@ void HostPipeline::consumePackets()
this->_consumed_packets.push_back(data);
}
);
guard.unlock();

if (!this->_consumed_packets.empty())
{
Expand Down
4 changes: 3 additions & 1 deletion host/core/pipeline/host_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <set>
#include <tuple>
#include <vector>
#include <mutex>

#include <boost/lockfree/spsc_queue.hpp>
#include <boost/lockfree/queue.hpp>
Expand All @@ -20,7 +21,7 @@ class HostPipeline
: public DataObserver<StreamInfo, StreamData>
{
protected:
const unsigned c_data_queue_size = 100;
const unsigned c_data_queue_size = 30;

boost::lockfree::spsc_queue<std::shared_ptr<HostDataPacket>> _data_queue_lf;
std::list<std::shared_ptr<HostDataPacket>> _consumed_packets; // TODO: temporary solution
Expand All @@ -44,6 +45,7 @@ class HostPipeline
std::list<std::shared_ptr<HostDataPacket>> getConsumedDataPackets();

private:
std::mutex q_lock;
// from DataObserver<StreamInfo, StreamData>
virtual void onNewData(const StreamInfo& info, const StreamData& data) final;
// from DataObserver<StreamInfo, StreamData>
Expand Down

0 comments on commit 1022dbb

Please sign in to comment.