Skip to content

Commit

Permalink
SROA: generalize unswitchtupleunion optimization
Browse files Browse the repository at this point in the history
This commit improves SROA pass by extending the `unswitchtupleunion`
optimization to handle the general parametric types, e.g.:
```julia
julia> struct A{T}
           x::T
       end;

julia> function foo(a1, a2, c)
           t = c ? A(a1) : A(a2)
           return getfield(t, :x)
       end;

julia> only(Base.code_ircode(foo, (Int,Float64,Bool); optimize_until="SROA"))
```

> Before
```
2 1 ─      goto #3 if not _4                                          │
  2 ─ %2 = %new(A{Int64}, _2)::A{Int64}                               │╻ A
  └──      goto #4                                                    │
  3 ─ %4 = %new(A{Float64}, _3)::A{Float64}                           │╻ A
  4 ┄ %5 = φ (#2 => %2, #3 => %4)::Union{A{Float64}, A{Int64}}        │
3 │   %6 = Main.getfield(%5, :x)::Union{Float64, Int64}               │
  └──      return %6                                                  │
   => Union{Float64, Int64}
```

> After
```
julia> only(Base.code_ircode(foo, (Int,Float64,Bool); optimize_until="SROA"))
2 1 ─      goto #3 if not _4                                           │
  2 ─      nothing::A{Int64}                                           │╻ A
  └──      goto #4                                                     │
  3 ─      nothing::A{Float64}                                         │╻ A
  4 ┄ %8 = φ (#2 => _2, #3 => _3)::Union{Float64, Int64}               │
  │        nothing::Union{A{Float64}, A{Int64}}
3 │   %6 = %8::Union{Float64, Int64}                                   │
  └──      return %6                                                   │
   => Union{Float64, Int64}
```
  • Loading branch information
aviatesk committed Jul 11, 2023
1 parent 680e3b3 commit 59aad42
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 5 deletions.
4 changes: 2 additions & 2 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1107,8 +1107,8 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
end
struct_typ = widenconst(argextype(val, compact))
struct_typ_unwrapped = unwrap_unionall(struct_typ)
if isa(struct_typ, Union) && struct_typ <: Tuple
struct_typ_unwrapped = unswitchtupleunion(struct_typ_unwrapped)
if isa(struct_typ, Union)
struct_typ_unwrapped = unswitchtypeunion(struct_typ_unwrapped)
end
if isa(struct_typ_unwrapped, Union) && is_isdefined
lift_comparison!(isdefined, compact, idx, stmt, lifting_cache, 𝕃ₒ)
Expand Down
4 changes: 1 addition & 3 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,7 @@ function unionall_depth(@nospecialize ua) # aka subtype_env_size
return depth
end

# convert a Union of Tuple types to a Tuple of Unions
unswitchtupleunion(u::Union) = unswitchtypeunion(u, Tuple.name)

# convert a Union of same `UnionAll` types to the `UnionAll` type whose parameter is the Unions
function unswitchtypeunion(u::Union, typename::Union{Nothing,Core.TypeName}=nothing)
ts = uniontypes(u)
n = -1
Expand Down
21 changes: 21 additions & 0 deletions test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1390,3 +1390,24 @@ function wrap1_wrap1_wrapper(b, x, y)
end
@test wrap1_wrap1_wrapper(true, 1, 1.0) === 1.0
@test wrap1_wrap1_wrapper(false, 1, 1.0) === 1

# Test unswitching-union optimization within SRO Apass
function sroaunswitchtuple(c, x1, x2)
t = c ? (x1,) : (x2,)
return getfield(t, 1)
end
struct UnswitchUnion{T}
x::T
end
function sroaunswitchstruct(c, x1, x2)
x = c ? UnswitchUnion(x1) : UnswitchUnion(x2)
return getfield(x, :x)
end
let src = code_typed1(sroaunswitchtuple, Tuple{Bool, Int, Float64})
@test count(isnew, src.code) == 0
@test count(iscall((src, getfield)), src.code) == 0
end
let src = code_typed1(sroaunswitchstruct, Tuple{Bool, Int, Float64})
@test count(isnew, src.code) == 0
@test count(iscall((src, getfield)), src.code) == 0
end

0 comments on commit 59aad42

Please sign in to comment.