Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve generics #1994

Merged
merged 4 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions ext/rbs_extension/parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -1119,13 +1119,15 @@ VALUE parse_type(parserstate *state) {
type_params ::= {} `[` type_param `,` ... <`]`>
| {<>}

type_param ::= kUNCHECKED? (kIN|kOUT|) tUIDENT (module_type_params == true)
type_param ::= kUNCHECKED? (kIN|kOUT|) tUIDENT upper_bound? default_type? (module_type_params == true)

type_param ::= tUIDENT (module_type_params == false)
type_param ::= tUIDENT upper_bound? default_type? (module_type_params == false)
*/
VALUE parse_type_params(parserstate *state, range *rg, bool module_type_params) {
VALUE params = EMPTY_ARRAY;

bool required_param_allowed = true;

if (state->next_token.type == pLBRACKET) {
parser_advance(state);

Expand All @@ -1136,12 +1138,14 @@ VALUE parse_type_params(parserstate *state, range *rg, bool module_type_params)
bool unchecked = false;
VALUE variance = ID2SYM(rb_intern("invariant"));
VALUE upper_bound = Qnil;
VALUE default_type = Qnil;

range param_range = NULL_RANGE;
range name_range;
range variance_range = NULL_RANGE;
range unchecked_range = NULL_RANGE;
range upper_bound_range = NULL_RANGE;
range default_type_range = NULL_RANGE;

param_range.start = state->next_token.range.start;

Expand Down Expand Up @@ -1179,13 +1183,28 @@ VALUE parse_type_params(parserstate *state, range *rg, bool module_type_params)

if (state->next_token.type == pLT) {
parser_advance(state);
upper_bound_range.start = state->current_token.range.start;
upper_bound = parse_type(state);
upper_bound_range.end = state->current_token.range.end;
}

if (state->next_token.type == kSINGLETON) {
if (module_type_params) {
if (state->next_token.type == pEQ) {
parser_advance(state);
upper_bound = parse_singleton_type(state);

default_type_range.start = state->current_token.range.start;
default_type = parse_type(state);
default_type_range.start = state->current_token.range.end;

required_param_allowed = false;
} else {
parser_advance(state);
upper_bound = parse_instance_type(state, false);
if (!required_param_allowed) {
raise_syntax_error(
state,
state->current_token,
"required type parameter is not allowed after optional type parameter"
);
}
}
}

Expand All @@ -1198,8 +1217,9 @@ VALUE parse_type_params(parserstate *state, range *rg, bool module_type_params)
rbs_loc_add_optional_child(loc, rb_intern("variance"), variance_range);
rbs_loc_add_optional_child(loc, rb_intern("unchecked"), unchecked_range);
rbs_loc_add_optional_child(loc, rb_intern("upper_bound"), upper_bound_range);
rbs_loc_add_optional_child(loc, rb_intern("default"), default_type_range);

VALUE param = rbs_ast_type_param(name, variance, unchecked, upper_bound, location);
VALUE param = rbs_ast_type_param(name, variance, unchecked, upper_bound, default_type, location);
melt_array(&params);
rb_ary_push(params, param);

Expand Down
3 changes: 2 additions & 1 deletion ext/rbs_extension/ruby_objs.c
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,12 @@ VALUE rbs_ast_annotation(VALUE string, VALUE location) {
);
}

VALUE rbs_ast_type_param(VALUE name, VALUE variance, bool unchecked, VALUE upper_bound, VALUE location) {
VALUE rbs_ast_type_param(VALUE name, VALUE variance, bool unchecked, VALUE upper_bound, VALUE default_type, VALUE location) {
VALUE args = rb_hash_new();
rb_hash_aset(args, ID2SYM(rb_intern("name")), name);
rb_hash_aset(args, ID2SYM(rb_intern("variance")), variance);
rb_hash_aset(args, ID2SYM(rb_intern("upper_bound")), upper_bound);
rb_hash_aset(args, ID2SYM(rb_intern("default_type")), default_type);
rb_hash_aset(args, ID2SYM(rb_intern("location")), location);

VALUE type_param = CLASS_NEW_INSTANCE(RBS_AST_TypeParam, 1, &args);
Expand Down
2 changes: 1 addition & 1 deletion ext/rbs_extension/ruby_objs.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
VALUE rbs_alias(VALUE typename, VALUE args, VALUE location);
VALUE rbs_ast_annotation(VALUE string, VALUE location);
VALUE rbs_ast_comment(VALUE string, VALUE location);
VALUE rbs_ast_type_param(VALUE name, VALUE variance, bool unchecked, VALUE upper_bound, VALUE location);
VALUE rbs_ast_type_param(VALUE name, VALUE variance, bool unchecked, VALUE upper_bound, VALUE default_type, VALUE location);
VALUE rbs_ast_decl_type_alias(VALUE name, VALUE type_params, VALUE type, VALUE annotations, VALUE location, VALUE comment);
VALUE rbs_ast_decl_class_super(VALUE name, VALUE args, VALUE location);
VALUE rbs_ast_decl_class(VALUE name, VALUE type_params, VALUE super_class, VALUE members, VALUE annotations, VALUE location, VALUE comment);
Expand Down
86 changes: 71 additions & 15 deletions lib/rbs/ast/type_param.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,22 @@
module RBS
module AST
class TypeParam
attr_reader :name, :variance, :location, :upper_bound
attr_reader :name, :variance, :location, :upper_bound_type, :default_type

def initialize(name:, variance:, upper_bound:, location:)
def initialize(name:, variance:, upper_bound:, location:, default_type: nil)
@name = name
@variance = variance
@upper_bound = upper_bound
@upper_bound_type = upper_bound
@location = location
@unchecked = false
@default_type = default_type
end

def upper_bound
case upper_bound_type
when Types::ClassInstance, Types::ClassSingleton, Types::Interface
upper_bound_type
end
end

def unchecked!(value = true)
Expand All @@ -26,14 +34,15 @@ def ==(other)
other.is_a?(TypeParam) &&
other.name == name &&
other.variance == variance &&
other.upper_bound == upper_bound &&
other.upper_bound_type == upper_bound_type &&
other.default_type == default_type &&
other.unchecked? == unchecked?
end

alias eql? ==

def hash
self.class.hash ^ name.hash ^ variance.hash ^ upper_bound.hash ^ unchecked?.hash
self.class.hash ^ name.hash ^ variance.hash ^ upper_bound_type.hash ^ unchecked?.hash ^ default_type.hash
end

def to_json(state = JSON::State.new)
Expand All @@ -42,29 +51,36 @@ def to_json(state = JSON::State.new)
variance: variance,
unchecked: unchecked?,
location: location,
upper_bound: upper_bound
upper_bound: upper_bound_type,
default_type: default_type
}.to_json(state)
end

def rename(name)
TypeParam.new(
name: name,
variance: variance,
upper_bound: upper_bound,
location: location
upper_bound: upper_bound_type,
location: location,
default_type: default_type
).unchecked!(unchecked?)
end

def map_type(&block)
if b = upper_bound
_upper_bound = yield(b)
if b = upper_bound_type
_upper_bound_type = yield(b)
end

if dt = default_type
_default_type = yield(dt)
end

TypeParam.new(
name: name,
variance: variance,
upper_bound: _upper_bound,
location: location
upper_bound: _upper_bound_type,
location: location,
default_type: _default_type
).unchecked!(unchecked?)
end

Expand Down Expand Up @@ -101,8 +117,9 @@ def self.rename(params, new_names:)
TypeParam.new(
name: new_name,
variance: param.variance,
upper_bound: param.upper_bound&.map_type {|type| type.sub(subst) },
location: param.location
upper_bound: param.upper_bound_type&.map_type {|type| type.sub(subst) },
location: param.location,
default_type: param.default_type&.map_type {|type| type.sub(subst) }
).unchecked!(param.unchecked?)
end
end
Expand All @@ -125,12 +142,51 @@ def to_s

s << name.to_s

if type = upper_bound
if type = upper_bound_type
s << " < #{type}"
end

if dt = default_type
s << " = #{dt}"
end

s
end

def self.application(params, args)
subst = Substitution.new()

if params.empty?
return nil
end

min_count = params.count { _1.default_type.nil? }
max_count = params.size

unless min_count <= args.size && args.size <= max_count
raise "Invalid type application: required type params=#{min_count}, optional type params=#{max_count - min_count}, given args=#{args.size}"
end

params.zip(args).each do |param, arg|
if arg
subst.add(from: param.name, to: arg)
else
subst.add(from: param.name, to: param.default_type || raise)
end
end

subst
end

def self.normalize_args(params, args)
params.zip(args).filter_map do |param, arg|
if arg
arg
else
param.default_type
end
end
end
end
end
end
48 changes: 41 additions & 7 deletions lib/rbs/cli/validate.rb
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def run
private

def validate_class_module_definition
@env.class_decls.each do |name, decl|
@env.class_decls.each do |name, entry|
RBS.logger.info "Validating class/module definition: `#{name}`..."
@builder.build_instance(name).each_type do |type|
@validator.validate_type type, context: nil
Expand All @@ -107,30 +107,47 @@ def validate_class_module_definition
@errors.add(error)
end

case decl
case entry
when Environment::ClassEntry
decl.decls.each do |decl|
entry.decls.each do |decl|
if super_class = decl.decl.super_class
super_class.args.each do |arg|
void_type_context_validator(arg, true)
no_self_type_validator(arg)
no_classish_type_validator(arg)
@validator.validate_type(arg, context: nil)
end

if super_entry = @env.normalized_class_entry(super_class.name)
InvalidTypeApplicationError.check!(type_name: super_class.name, args: super_class.args, params: super_entry.type_params, location: super_class.location)
end
end
end
when Environment::ModuleEntry
decl.decls.each do |decl|
entry.decls.each do |decl|
decl.decl.self_types.each do |self_type|
self_type.args.each do |arg|
void_type_context_validator(arg, true)
no_self_type_validator(arg)
no_classish_type_validator(arg)
@validator.validate_type(arg, context: nil)
end

self_params =
if self_type.name.class?
@env.normalized_module_entry(self_type.name)&.type_params
else
@env.interface_decls[self_type.name]&.decl&.type_params
end

if self_params
InvalidTypeApplicationError.check!(type_name: self_type.name, params: self_params, args: self_type.args, location: self_type.location)
end
end
end
end

d = decl.primary.decl
d = entry.primary.decl

@validator.validate_type_params(
d.type_params,
Expand All @@ -139,14 +156,22 @@ def validate_class_module_definition
)

d.type_params.each do |param|
if ub = param.upper_bound
if ub = param.upper_bound_type
void_type_context_validator(ub)
no_self_type_validator(ub)
no_classish_type_validator(ub)
@validator.validate_type(ub, context: nil)
end

if dt = param.default_type
void_type_context_validator(dt)
no_self_type_validator(dt)
no_classish_type_validator(dt)
@validator.validate_type(dt, context: nil)
end
end

decl.decls.each do |d|
entry.decls.each do |d|
d.decl.each_member do |member|
case member
when AST::Members::MethodDefinition
Expand All @@ -163,6 +188,15 @@ def validate_class_module_definition
void_type_context_validator(arg, true)
end
end
params =
if member.name.class?
module_decl = @env.normalized_module_entry(member.name) or raise
module_decl.type_params
else
interface_decl = @env.interface_decls.fetch(member.name)
interface_decl.decl.type_params
end
InvalidTypeApplicationError.check!(type_name: member.name, params: params, args: member.args, location: member.location)
when AST::Members::Var
void_type_context_validator(member.type)
if member.is_a?(AST::Members::ClassVariable)
Expand Down
3 changes: 2 additions & 1 deletion lib/rbs/definition.rb
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,11 @@ def initialize(type_name:, params:, ancestors:)
end

def apply(args, location:)
# Assume default types of type parameters are already added to `args`
InvalidTypeApplicationError.check!(
type_name: type_name,
args: args,
params: params,
params: params.map { AST::TypeParam.new(name: _1, variance: :invariant, upper_bound: nil, location: nil, default_type: nil) },
location: location
)

Expand Down
Loading