Skip to content

Commit

Permalink
[XLS] Map next-state elements to next_value nodes
Browse files Browse the repository at this point in the history
IR verification now checks that next-state elements are no longer being used; they can only be used internally to a pass (currently used to continue supporting proc inlining).

PiperOrigin-RevId: 712551183
  • Loading branch information
ericastor authored and copybara-github committed Jan 6, 2025
1 parent 67bf106 commit cb4cf1f
Show file tree
Hide file tree
Showing 47 changed files with 518 additions and 994 deletions.
11 changes: 4 additions & 7 deletions xls/estimators/delay_model/analyze_critical_path_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,10 @@ TEST_F(AnalyzeCriticalPathTest, ProcWithState) {
AnalyzeCriticalPath(proc, /*clock_period_ps=*/std::nullopt,
*delay_estimator_));

ASSERT_EQ(cp.size(), 3);
EXPECT_EQ(cp[0].node, rev.node());
EXPECT_EQ(cp[0].path_delay_ps, 2);
EXPECT_EQ(cp[1].node, neg.node());
EXPECT_EQ(cp[1].path_delay_ps, 1);
EXPECT_EQ(cp[2].node, proc->GetStateRead(int64_t{0}));
EXPECT_EQ(cp[2].path_delay_ps, 0);
EXPECT_THAT(cp, ElementsAre(FieldsAre(m::Next(), _, 2, _),
FieldsAre(rev.node(), _, 2, _),
FieldsAre(neg.node(), _, 1, _),
FieldsAre(st.node(), _, 0, _)));
}

TEST_F(AnalyzeCriticalPathTest, ProcWithSendReceive) {
Expand Down
50 changes: 28 additions & 22 deletions xls/ir/function_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1165,43 +1165,49 @@ absl::Status ProcBuilder::InstantiateProc(
.status();
}

absl::StatusOr<Proc*> ProcBuilder::Build(absl::Span<const BValue> next_state) {
absl::StatusOr<Proc*> ProcBuilder::Build() {
if (ErrorPending()) {
return GetError();
}

// TODO: Remove this once fully transitioned over to `next_value` nodes.
if (!next_state.empty() && next_state.size() != state_params_.size()) {
return absl::InvalidArgumentError(
absl::StrFormat("Number of recurrent state elements given (%d) does "
"not equal the number of state elements in the proc "
"(%d)",
next_state.size(), state_params_.size()));
}
for (int64_t i = 0; i < next_state.size(); ++i) {
if (GetType(next_state[i]) != GetType(GetStateParam(i))) {
return absl::InvalidArgumentError(
absl::StrFormat("Recurrent state type %s does not match proc "
"parameter state type %s for element %d.",
GetType(GetStateParam(i))->ToString(),
GetType(next_state[i])->ToString(), i));
}
}

// down_cast the FunctionBase* to Proc*. We know this is safe because
// ProcBuilder constructs and passes a Proc to BuilderBase constructor so
// function_ is always a Proc.
Proc* proc = package()->AddProc(
absl::WrapUnique(down_cast<Proc*>(function_.release())));
for (int64_t i = 0; i < next_state.size(); ++i) {
XLS_RETURN_IF_ERROR(proc->SetNextStateElement(i, next_state[i].node()));
}
if (should_verify_) {
XLS_RETURN_IF_ERROR(VerifyProc(proc));
}
return proc;
}

absl::StatusOr<Proc*> ProcBuilder::Build(absl::Span<const BValue> next_state) {
if (!next_state.empty()) {
if (next_state.size() != state_params_.size()) {
return absl::InvalidArgumentError(
absl::StrFormat("Number of recurrent state elements given (%d) does "
"not equal the number of state elements in the proc "
"(%d)",
next_state.size(), state_params_.size()));
}
if (!proc()->next_values().empty()) {
return absl::InvalidArgumentError(
"Cannot use Build(next_state) when also using next_value nodes.");
}
for (int64_t index = 0; index < next_state.size(); ++index) {
if (GetType(next_state[index]) != GetType(GetStateParam(index))) {
return absl::InvalidArgumentError(
absl::StrFormat("Recurrent state type %s does not match provided "
"state type %s for element %d.",
GetType(GetStateParam(index))->ToString(),
GetType(next_state[index])->ToString(), index));
}
Next(GetStateParam(index), next_state[index]);
}
}
return Build();
}

BValue ProcBuilder::StateElement(std::string_view name,
const Value& initial_value,
std::optional<BValue> read_predicate,
Expand Down
7 changes: 6 additions & 1 deletion xls/ir/function_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "xls/common/casts.h"
#include "xls/common/status/ret_check.h"
Expand Down Expand Up @@ -801,10 +802,14 @@ class ProcBuilder : public BuilderBase {
return state_params_.at(index);
}

absl::StatusOr<Proc*> Build();

// Build the proc using the given BValues as the next state values. If
// `next_state` is not empty, the number of recurrent state elements in
// `next_state` must match the number of state parameters.
absl::StatusOr<Proc*> Build(absl::Span<const BValue> next_state = {});
// Provided as a convenience for the common case where we can treat the next
// state as approximately a return value.
absl::StatusOr<Proc*> Build(absl::Span<const BValue> next_state);

// Adds a state element to the proc with the given initial value (and read
// predicate if provided). Returns the newly added state read.
Expand Down
42 changes: 27 additions & 15 deletions xls/ir/function_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "xls/ir/function_builder.h"

#include <cstdint>
#include <string>
#include <vector>

Expand Down Expand Up @@ -518,12 +519,16 @@ TEST(FunctionBuilderTest, SendAndReceive) {
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, b.Build({after_all, next_state}));

EXPECT_THAT(
proc->GetNextStateElement(0),
m::AfterAll(m::Send(), m::TupleIndex(m::Receive()),
m::Send(m::StateRead(), m::StateRead(), m::Literal(1)),
m::TupleIndex(m::Receive())));
EXPECT_THAT(proc->GetNextStateElement(1),
m::Add(m::TupleIndex(m::Receive()), m::TupleIndex(m::Receive())));
proc->next_values(proc->GetStateRead(int64_t{0})),
ElementsAre(m::Next(
m::StateRead("my_token"),
m::AfterAll(m::Send(), m::TupleIndex(m::Receive()),
m::Send(m::StateRead(), m::StateRead(), m::Literal(1)),
m::TupleIndex(m::Receive())))));
EXPECT_THAT(proc->next_values(proc->GetStateRead(1)),
ElementsAre(m::Next(m::StateRead("my_state"),
m::Add(m::TupleIndex(m::Receive()),
m::TupleIndex(m::Receive())))));

EXPECT_THAT(proc->StateElements(),
ElementsAre(m::StateElement("my_token", Value::Token()),
Expand Down Expand Up @@ -863,9 +868,11 @@ TEST(FunctionBuilderTest, TokenlessProcBuilder) {
m::MinDelay(m::TupleIndex(m::Receive(m::TupleIndex(m::Receive())))),
m::Add())));

EXPECT_THAT(proc->GetNextStateElement(0),
m::Add(m::TupleIndex(m::Receive(m::Channel("a")), 1),
m::TupleIndex(m::Receive(m::Channel("b")), 1)));
EXPECT_THAT(proc->next_values(proc->GetStateRead(int64_t{0})),
ElementsAre(m::Next(
m::StateRead("st"),
m::Add(m::TupleIndex(m::Receive(m::Channel("a")), 1),
m::TupleIndex(m::Receive(m::Channel("b")), 1)))));
}

TEST(FunctionBuilderTest, StatelessProcBuilder) {
Expand Down Expand Up @@ -897,10 +904,14 @@ TEST(FunctionBuilderTest, ProcWithMultipleStateElements) {
EXPECT_THAT(proc->DumpIr(),
HasSubstr("proc the_proc(tkn: token, x: bits[32], y: bits[32], "
"z: bits[32], init={token, 1, 2, 3})"));
EXPECT_EQ(proc->GetNextStateElement(0)->GetName(), "tkn");
EXPECT_EQ(proc->GetNextStateElement(1)->GetName(), "x");
EXPECT_EQ(proc->GetNextStateElement(2)->GetName(), "x_plus_y");
EXPECT_EQ(proc->GetNextStateElement(3)->GetName(), "z");
EXPECT_THAT(proc->next_values(proc->GetStateRead(int64_t{0})),
ElementsAre(m::Next(m::StateRead("tkn"), m::StateRead("tkn"))));
EXPECT_THAT(proc->next_values(proc->GetStateRead(1)),
ElementsAre(m::Next(m::StateRead("x"), m::StateRead("x"))));
EXPECT_THAT(proc->next_values(proc->GetStateRead(2)),
ElementsAre(m::Next(m::StateRead("y"), m::Name("x_plus_y"))));
EXPECT_THAT(proc->next_values(proc->GetStateRead(3)),
ElementsAre(m::Next(m::StateRead("z"), m::StateRead("z"))));
}

TEST(FunctionBuilderTest, ProcWithNextStateElement) {
Expand All @@ -911,7 +922,7 @@ TEST(FunctionBuilderTest, ProcWithNextStateElement) {
BValue z = pb.StateElement("z", Value(UBits(3, 32)));
BValue next = pb.Next(/*state_read=*/y, /*value=*/z, /*pred=*/x);

XLS_ASSERT_OK(pb.Build(/*next_state=*/{x, y, z}));
XLS_ASSERT_OK(pb.Build());
EXPECT_THAT(next.node(),
m::Next(m::StateRead("y"), /*value=*/m::StateRead("z"),
/*predicate=*/m::StateRead("x")));
Expand All @@ -938,7 +949,8 @@ TEST(FunctionBuilderTest, TokenlessProcBuilderNoChannelOps) {
BValue state = pb.StateElement("st", Value(UBits(42, 16)));
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build({state}));

EXPECT_THAT(proc->GetNextStateElement(0), m::StateRead("st"));
EXPECT_THAT(proc->next_values(proc->GetStateRead(int64_t{0})),
ElementsAre(m::Next(m::StateRead("st"), m::StateRead("st"))));
}

TEST(FunctionBuilderTest, Assert) {
Expand Down
6 changes: 3 additions & 3 deletions xls/ir/ir_parser_error_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1239,7 +1239,7 @@ proc foo(my_token: token, my_state: bits[32], init={token, 42}) {
EXPECT_EQ(package->functions().size(), 0);
EXPECT_EQ(package->procs().size(), 1);
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, package->GetProc("foo"));
EXPECT_EQ(proc->node_count(), 2);
EXPECT_EQ(proc->node_count(), 4);
EXPECT_EQ(proc->StateElements().size(), 2);
EXPECT_EQ(proc->GetStateElement(int64_t{0})->initial_value().ToString(),
"token");
Expand Down Expand Up @@ -1467,8 +1467,8 @@ proc foo(my_token: token, my_state: bits[32], init={token, 42}) {
EXPECT_THAT(
Parser::ParsePackage(program).status(),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("Recurrent state type token does not match proc "
"parameter state type bits[32] for element 0.")));
HasSubstr("Recurrent state type token does not match provided "
"state type bits[32] for element 0.")));
}

TEST(IrParserErrorTest, ProcWithRet) {
Expand Down
2 changes: 1 addition & 1 deletion xls/ir/ir_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ proc foo(my_token: token, my_state: bits[32], init={token, 42}) {
EXPECT_EQ(package->functions().size(), 0);
EXPECT_EQ(package->procs().size(), 1);
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, package->GetProc("foo"));
EXPECT_EQ(proc->node_count(), 2);
EXPECT_EQ(proc->node_count(), 4);
EXPECT_EQ(proc->StateElements().size(), 2);
EXPECT_EQ(proc->GetStateElement(0)->initial_value().ToString(), "token");
EXPECT_EQ(proc->GetStateElement(0)->type()->ToString(), "token");
Expand Down
4 changes: 4 additions & 0 deletions xls/ir/ir_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
namespace xls {

VerifiedPackage::~VerifiedPackage() {
if (!verify_) {
return;
}

absl::Status status = VerifyPackage(this);
if (!status.ok()) {
ADD_FAILURE() << absl::StrFormat(
Expand Down
5 changes: 5 additions & 0 deletions xls/ir/ir_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ class VerifiedPackage : public Package {
public:
explicit VerifiedPackage(std::string_view name) : Package(name) {}
~VerifiedPackage() override;

void AcceptInvalid() { verify_ = false; }

private:
bool verify_ = true;
};

// A test base class with convenience functions for IR tests.
Expand Down
7 changes: 4 additions & 3 deletions xls/ir/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -903,12 +903,13 @@ absl::Status Node::ReplaceUsesWith(Node* replacement,
}

// If the replacement does not have an assigned name but this node does, move
// the name over to preserve the name. If this is a parameter node then don't
// move the name because we cannot clear the name of a parameter node.
// the name over to preserve the name. If this is a parameter or state-read
// node then don't move the name because we cannot clear the name of a
// parameter node or state element.
//
// We also don't replace the name if some use was filtered out and not
// updated.
if (all_replaced && !Is<Param>() && HasAssignedName() &&
if (all_replaced && !Is<Param>() && !Is<StateRead>() && HasAssignedName() &&
!replacement->HasAssignedName()) {
// Do not use SetName because we do not want the name to be uniqued which
// would add a suffix because (clearly) the name already exists.
Expand Down
Loading

0 comments on commit cb4cf1f

Please sign in to comment.