Skip to content

Commit

Permalink
Merge pull request #937 from ksss/const-lit
Browse files Browse the repository at this point in the history
Const types as literal types
  • Loading branch information
soutaro authored Mar 28, 2022
2 parents c37fcf3 + 03cbebb commit c6325af
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 29 deletions.
31 changes: 17 additions & 14 deletions lib/rbs/prototype/rb.rb
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def process(node, decls:, comments:, context:)
# Give up type prediction when node is MASGN.
Types::Bases::Any.new(location: nil)
else
node_type(value_node)
literal_to_type(value_node)
end
decls << AST::Declarations::Constant.new(
name: const_name,
Expand Down Expand Up @@ -408,7 +408,7 @@ def function_type_from_body(node, def_name)
name = lvasgn.children[0]
fun.optional_positionals << Types::Function::Param.new(
name: name,
type: node_type(lvasgn.children[1])
type: param_type(lvasgn.children[1])
)
end

Expand All @@ -429,7 +429,7 @@ def function_type_from_body(node, def_name)
when nil, :NODE_SPECIAL_REQUIRED_KEYWORD
fun.required_keywords[name] = Types::Function::Param.new(name: name, type: untyped)
when RubyVM::AbstractSyntaxTree::Node
fun.optional_keywords[name] = Types::Function::Param.new(name: name, type: node_type(value))
fun.optional_keywords[name] = Types::Function::Param.new(name: name, type: param_type(value))
else
raise "Unexpected keyword arg value: #{value}"
end
Expand Down Expand Up @@ -477,9 +477,9 @@ def literal_to_type(node)
when :DREGX
BuiltinNames::Regexp.instance_type
when :TRUE
BuiltinNames::TrueClass.instance_type
Types::Literal.new(literal: true, location: nil)
when :FALSE
BuiltinNames::FalseClass.instance_type
Types::Literal.new(literal: false, location: nil)
when :NIL
Types::Bases::Nil.new(location: nil)
when :LIT
Expand Down Expand Up @@ -536,6 +536,14 @@ def literal_to_type(node)
value_type = types_to_union_type(value_types)
BuiltinNames::Hash.instance_type([key_type, value_type])
end
when :CALL
receiver, method_name, * = node.children
case method_name
when :freeze, :tap, :itself, :dup, :clone, :taint, :untaint, :extend
literal_to_type(receiver)
else
default
end
else
untyped
end
Expand Down Expand Up @@ -570,7 +578,7 @@ def range_element_type(types)
end
end

def node_type(node, default: Types::Bases::Any.new(location: nil))
def param_type(node, default: Types::Bases::Any.new(location: nil))
case node.type
when :LIT
case node.children[0]
Expand All @@ -597,19 +605,14 @@ def node_type(node, default: Types::Bases::Any.new(location: nil))
BuiltinNames::Array.instance_type(default)
when :HASH
BuiltinNames::Hash.instance_type(default, default)
when :CALL
receiver, method_name, * = node.children
case method_name
when :freeze, :tap, :itself, :dup, :clone, :taint, :untaint, :extend
node_type(receiver)
else
default
end
else
default
end
end

# backward compatible
alias node_type param_type

def private
@private ||= AST::Members::Private.new(location: nil)
end
Expand Down
26 changes: 11 additions & 15 deletions test/rbs/rb_prototype_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ def self.world
end
def kw_req(a:) end
def opt_with_method_call(a = 'a'.freeze, b: 'b'.dup) end
end
EOR

Expand All @@ -63,8 +61,6 @@ def hello: (untyped a, ?::Integer b, *untyped c, untyped d, e: untyped e, ?f: ::
def self.world: () { (untyped, untyped, untyped, x: untyped, y: untyped) -> untyped } -> untyped
def kw_req: (a: untyped a) -> nil
def opt_with_method_call: (?::String a, ?b: ::String b) -> nil
end
EOF
end
Expand Down Expand Up @@ -134,9 +130,9 @@ def regx: () -> ::Regexp
def dregx: () -> ::Regexp
def t: () -> ::TrueClass
def t: () -> true
def f: () -> ::FalseClass
def f: () -> false
def n: () -> nil
Expand Down Expand Up @@ -574,11 +570,11 @@ module Foo

assert_write parser.decls, <<-EOF
module Foo
VERSION: ::String
VERSION: "0.1.1"
FROZEN: ::String
FROZEN: "str"
::Hello::World: ::Symbol
::Hello::World: :foo
end
EOF
end
Expand Down Expand Up @@ -622,21 +618,21 @@ def test_literal_types
parser.parse(rb)

assert_write parser.decls, <<-EOF
A: ::Integer
A: 1
B: ::Float
C: ::String
D: ::Symbol
D: :hello
E: untyped?
E: nil
F: bool
F: false
G: ::Array[untyped]
G: ::Array[1 | 2 | 3]
H: ::Hash[untyped, untyped]
H: { id: 123 }
EOF
end

Expand Down

0 comments on commit c6325af

Please sign in to comment.