Skip to content

Commit

Permalink
[NFC] Add HeapType::getKind returning a new HeapTypeKind enum
Browse files Browse the repository at this point in the history
The HeapType API has functions like `isBasic()`, `isStruct()`,
`isSignature()`, etc. to test the classification of a heap type. Many
users have to call these functions in sequence and handle all or most of
the possible classifications. When we add a new kind of heap type,
finding and updating all these sites is a manual and error-prone
process.

To make adding new heap type kinds easier, introduce a new API that
returns an enum classifying the heap type. The enum can be used in
switch statements and the compiler's exhaustiveness checker will flag
use sites that need to be updated when we add a new kind of heap type.

This commit uses the new enum internally in the type system, but
follow-on commits will add new uses and convert uses of the existing
APIs to use `getKind` instead.
  • Loading branch information
tlively committed Aug 5, 2024
1 parent beec630 commit 19384e6
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 110 deletions.
34 changes: 23 additions & 11 deletions src/wasm-type.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,14 @@ class Type {

enum Shareability { Shared, Unshared };

enum class HeapTypeKind {
Basic,
Func,
Struct,
Array,
Cont,
};

class HeapType {
// Unlike `Type`, which represents the types of values on the WebAssembly
// stack, `HeapType` is used to describe the structures that reference types
Expand Down Expand Up @@ -364,17 +372,21 @@ class HeapType {
HeapType(Struct&& struct_);
HeapType(Array array);

HeapTypeKind getKind() const;

constexpr bool isBasic() const { return id <= _last_basic_type; }
bool isFunction() const;
bool isData() const;
bool isSignature() const;
// Indicates whether the given type was defined to be of the form
// `(cont $ft)`. Returns false for `cont`, the top type of the continuation
// type hierarchy (and all other types). In other words, this is analogous to
// `isSignature`, but for continuation types.
bool isContinuation() const;
bool isStruct() const;
bool isArray() const;
bool isFunction() const {
return isMaybeShared(func) || getKind() == HeapTypeKind::Func;
}
bool isData() const {
auto kind = getKind();
return isMaybeShared(string) || kind == HeapTypeKind::Struct ||
kind == HeapTypeKind::Array;
}
bool isSignature() const { return getKind() == HeapTypeKind::Func; }
bool isContinuation() const { return getKind() == HeapTypeKind::Cont; }
bool isStruct() const { return getKind() == HeapTypeKind::Struct; }
bool isArray() const { return getKind() == HeapTypeKind::Array; }
bool isBottom() const;
bool isOpen() const;
bool isShared() const { return getShared() == Shared; }
Expand All @@ -383,7 +395,7 @@ class HeapType {

// Check if the type is a given basic heap type, while ignoring whether it is
// shared or not.
bool isMaybeShared(BasicHeapType type) {
bool isMaybeShared(BasicHeapType type) const {
return isBasic() && getBasic(Unshared) == type;
}

Expand Down
173 changes: 74 additions & 99 deletions src/wasm/wasm-type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,32 +93,28 @@ struct HeapTypeInfo {
// (i.e. contains only this type).
RecGroupInfo* recGroup = nullptr;
size_t recGroupIndex = 0;
enum Kind {
SignatureKind,
ContinuationKind,
StructKind,
ArrayKind,
} kind;
HeapTypeKind kind;
union {
Signature signature;
Continuation continuation;
Struct struct_;
Array array;
};

HeapTypeInfo(Signature sig) : kind(SignatureKind), signature(sig) {}
HeapTypeInfo(Signature sig) : kind(HeapTypeKind::Func), signature(sig) {}
HeapTypeInfo(Continuation continuation)
: kind(ContinuationKind), continuation(continuation) {}
HeapTypeInfo(const Struct& struct_) : kind(StructKind), struct_(struct_) {}
: kind(HeapTypeKind::Cont), continuation(continuation) {}
HeapTypeInfo(const Struct& struct_)
: kind(HeapTypeKind::Struct), struct_(struct_) {}
HeapTypeInfo(Struct&& struct_)
: kind(StructKind), struct_(std::move(struct_)) {}
HeapTypeInfo(Array array) : kind(ArrayKind), array(array) {}
: kind(HeapTypeKind::Struct), struct_(std::move(struct_)) {}
HeapTypeInfo(Array array) : kind(HeapTypeKind::Array), array(array) {}
~HeapTypeInfo();

constexpr bool isSignature() const { return kind == SignatureKind; }
constexpr bool isContinuation() const { return kind == ContinuationKind; }
constexpr bool isStruct() const { return kind == StructKind; }
constexpr bool isArray() const { return kind == ArrayKind; }
constexpr bool isSignature() const { return kind == HeapTypeKind::Func; }
constexpr bool isContinuation() const { return kind == HeapTypeKind::Cont; }
constexpr bool isStruct() const { return kind == HeapTypeKind::Struct; }
constexpr bool isArray() const { return kind == HeapTypeKind::Array; }
constexpr bool isData() const { return isStruct() || isArray(); }
};

Expand Down Expand Up @@ -440,14 +436,16 @@ HeapType::BasicHeapType getBasicHeapSupertype(HeapType type) {
}
auto* info = getHeapTypeInfo(type);
switch (info->kind) {
case HeapTypeInfo::SignatureKind:
case HeapTypeKind::Func:
return HeapTypes::func.getBasic(info->share);
case HeapTypeInfo::ContinuationKind:
case HeapTypeKind::Cont:
return HeapTypes::cont.getBasic(info->share);
case HeapTypeInfo::StructKind:
case HeapTypeKind::Struct:
return HeapTypes::struct_.getBasic(info->share);
case HeapTypeInfo::ArrayKind:
case HeapTypeKind::Array:
return HeapTypes::array.getBasic(info->share);
case HeapTypeKind::Basic:
break;
}
WASM_UNREACHABLE("unexpected kind");
};
Expand Down Expand Up @@ -572,18 +570,20 @@ bool TypeInfo::operator==(const TypeInfo& other) const {

HeapTypeInfo::~HeapTypeInfo() {
switch (kind) {
case SignatureKind:
case HeapTypeKind::Func:
signature.~Signature();
return;
case ContinuationKind:
case HeapTypeKind::Cont:
continuation.~Continuation();
return;
case StructKind:
case HeapTypeKind::Struct:
struct_.~Struct();
return;
case ArrayKind:
case HeapTypeKind::Array:
array.~Array();
return;
case HeapTypeKind::Basic:
break;
}
WASM_UNREACHABLE("unexpected kind");
}
Expand Down Expand Up @@ -1094,52 +1094,11 @@ HeapType::HeapType(Array array) {
HeapType(globalRecGroupStore.insert(std::make_unique<HeapTypeInfo>(array)));
}

bool HeapType::isFunction() const {
if (isBasic()) {
return id == func;
} else {
return getHeapTypeInfo(*this)->isSignature();
}
}

bool HeapType::isData() const {
if (isBasic()) {
return id == struct_ || id == array || id == string;
} else {
return getHeapTypeInfo(*this)->isData();
}
}

bool HeapType::isSignature() const {
if (isBasic()) {
return false;
} else {
return getHeapTypeInfo(*this)->isSignature();
}
}

bool HeapType::isContinuation() const {
HeapTypeKind HeapType::getKind() const {
if (isBasic()) {
return false;
} else {
return getHeapTypeInfo(*this)->isContinuation();
}
}

bool HeapType::isStruct() const {
if (isBasic()) {
return false;
} else {
return getHeapTypeInfo(*this)->isStruct();
}
}

bool HeapType::isArray() const {
if (isBasic()) {
return false;
} else {
return getHeapTypeInfo(*this)->isArray();
return HeapTypeKind::Basic;
}
return getHeapTypeInfo(*this)->kind;
}

bool HeapType::isBottom() const {
Expand Down Expand Up @@ -1248,14 +1207,16 @@ std::optional<HeapType> HeapType::getSuperType() const {

auto* info = getHeapTypeInfo(*this);
switch (info->kind) {
case HeapTypeInfo::SignatureKind:
case HeapTypeKind::Func:
return HeapType(func).getBasic(share);
case HeapTypeInfo::ContinuationKind:
case HeapTypeKind::Cont:
return HeapType(cont).getBasic(share);
case HeapTypeInfo::StructKind:
case HeapTypeKind::Struct:
return HeapType(struct_).getBasic(share);
case HeapTypeInfo::ArrayKind:
case HeapTypeKind::Array:
return HeapType(array).getBasic(share);
case HeapTypeKind::Basic:
break;
}
WASM_UNREACHABLE("unexpected kind");
}
Expand Down Expand Up @@ -1341,13 +1302,15 @@ HeapType::BasicHeapType HeapType::getUnsharedBottom() const {
}
auto* info = getHeapTypeInfo(*this);
switch (info->kind) {
case HeapTypeInfo::SignatureKind:
case HeapTypeKind::Func:
return nofunc;
case HeapTypeInfo::ContinuationKind:
case HeapTypeKind::Cont:
return nocont;
case HeapTypeInfo::StructKind:
case HeapTypeInfo::ArrayKind:
case HeapTypeKind::Struct:
case HeapTypeKind::Array:
return none;
case HeapTypeKind::Basic:
break;
}
WASM_UNREACHABLE("unexpected kind");
}
Expand Down Expand Up @@ -2177,18 +2140,20 @@ size_t RecGroupHasher::hash(const HeapTypeInfo& info) const {
wasm::rehash(digest, info.share);
wasm::rehash(digest, info.kind);
switch (info.kind) {
case HeapTypeInfo::SignatureKind:
case HeapTypeKind::Func:
hash_combine(digest, hash(info.signature));
return digest;
case HeapTypeInfo::ContinuationKind:
case HeapTypeKind::Cont:
hash_combine(digest, hash(info.continuation));
return digest;
case HeapTypeInfo::StructKind:
case HeapTypeKind::Struct:
hash_combine(digest, hash(info.struct_));
return digest;
case HeapTypeInfo::ArrayKind:
case HeapTypeKind::Array:
hash_combine(digest, hash(info.array));
return digest;
case HeapTypeKind::Basic:
break;
}
WASM_UNREACHABLE("unexpected kind");
}
Expand Down Expand Up @@ -2318,14 +2283,16 @@ bool RecGroupEquator::eq(const HeapTypeInfo& a, const HeapTypeInfo& b) const {
return false;
}
switch (a.kind) {
case HeapTypeInfo::SignatureKind:
case HeapTypeKind::Func:
return eq(a.signature, b.signature);
case HeapTypeInfo::ContinuationKind:
case HeapTypeKind::Cont:
return eq(a.continuation, b.continuation);
case HeapTypeInfo::StructKind:
case HeapTypeKind::Struct:
return eq(a.struct_, b.struct_);
case HeapTypeInfo::ArrayKind:
case HeapTypeKind::Array:
return eq(a.array, b.array);
case HeapTypeKind::Basic:
break;
}
WASM_UNREACHABLE("unexpected kind");
}
Expand Down Expand Up @@ -2432,23 +2399,25 @@ void TypeGraphWalkerBase<Self>::scanHeapType(HeapType* ht) {
}
auto* info = getHeapTypeInfo(*ht);
switch (info->kind) {
case HeapTypeInfo::SignatureKind:
case HeapTypeKind::Func:
taskList.push_back(Task::scan(&info->signature.results));
taskList.push_back(Task::scan(&info->signature.params));
break;
case HeapTypeInfo::ContinuationKind:
case HeapTypeKind::Cont:
taskList.push_back(Task::scan(&info->continuation.type));
break;
case HeapTypeInfo::StructKind: {
case HeapTypeKind::Struct: {
auto& fields = info->struct_.fields;
for (auto field = fields.rbegin(); field != fields.rend(); ++field) {
taskList.push_back(Task::scan(&field->type));
}
break;
}
case HeapTypeInfo::ArrayKind:
case HeapTypeKind::Array:
taskList.push_back(Task::scan(&info->array.element.type));
break;
case HeapTypeKind::Basic:
WASM_UNREACHABLE("unexpected kind");
}
}

Expand Down Expand Up @@ -2476,18 +2445,20 @@ struct TypeBuilder::Impl {
void set(HeapTypeInfo&& hti) {
info->kind = hti.kind;
switch (info->kind) {
case HeapTypeInfo::SignatureKind:
case HeapTypeKind::Func:
info->signature = hti.signature;
break;
case HeapTypeInfo::ContinuationKind:
case HeapTypeKind::Cont:
info->continuation = hti.continuation;
break;
case HeapTypeInfo::StructKind:
case HeapTypeKind::Struct:
info->struct_ = std::move(hti.struct_);
break;
case HeapTypeInfo::ArrayKind:
case HeapTypeKind::Array:
info->array = hti.array;
break;
case HeapTypeKind::Basic:
WASM_UNREACHABLE("unexpected kind");
}
initialized = true;
}
Expand Down Expand Up @@ -2608,14 +2579,16 @@ bool isValidSupertype(const HeapTypeInfo& sub, const HeapTypeInfo& super) {
}
SubTyper typer;
switch (sub.kind) {
case HeapTypeInfo::SignatureKind:
case HeapTypeKind::Func:
return typer.isSubType(sub.signature, super.signature);
case HeapTypeInfo::ContinuationKind:
case HeapTypeKind::Cont:
return typer.isSubType(sub.continuation, super.continuation);
case HeapTypeInfo::StructKind:
case HeapTypeKind::Struct:
return typer.isSubType(sub.struct_, super.struct_);
case HeapTypeInfo::ArrayKind:
case HeapTypeKind::Array:
return typer.isSubType(sub.array, super.array);
case HeapTypeKind::Basic:
break;
}
WASM_UNREACHABLE("unknown kind");
}
Expand All @@ -2640,28 +2613,30 @@ validateType(HeapTypeInfo& info, std::unordered_set<HeapType>& seenTypes) {
}
if (info.share == Shared) {
switch (info.kind) {
case HeapTypeInfo::SignatureKind:
case HeapTypeKind::Func:
// TODO: Figure out and enforce shared function rules.
break;
case HeapTypeInfo::ContinuationKind:
case HeapTypeKind::Cont:
if (!info.continuation.type.isShared()) {
return TypeBuilder::ErrorReason::InvalidFuncType;
}
break;
case HeapTypeInfo::StructKind:
case HeapTypeKind::Struct:
for (auto& field : info.struct_.fields) {
if (field.type.isRef() && !field.type.getHeapType().isShared()) {
return TypeBuilder::ErrorReason::InvalidUnsharedField;
}
}
break;
case HeapTypeInfo::ArrayKind: {
case HeapTypeKind::Array: {
auto elem = info.array.element.type;
if (elem.isRef() && !elem.getHeapType().isShared()) {
return TypeBuilder::ErrorReason::InvalidUnsharedField;
}
break;
}
case HeapTypeKind::Basic:
WASM_UNREACHABLE("unexpected kind");
}
}
return std::nullopt;
Expand Down

0 comments on commit 19384e6

Please sign in to comment.