diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 4e9a6c01..9d852df7 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -216,7 +216,7 @@ def next_module(self): return self._next_module[self._ref_count-1] def backward_release(self, flag): - if self._ref_count == 1: + if self._ref_count == 1 and self._backward_block_ctx is not None: self._backward_block_ctx.exit(flag, True) config['load_stream'].record_event(config['load_event']) self._ref_count -= 1