Skip to content

Commit

Permalink
Define iterate for RemoteChannel (#48515)
Browse files Browse the repository at this point in the history
  • Loading branch information
quinnj authored Feb 7, 2023
1 parent 1ff04d8 commit 9639c42
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/Distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module Distributed

# imports for extension
import Base: getindex, wait, put!, take!, fetch, isready, push!, length,
hash, ==, kill, close, isopen, showerror
hash, ==, kill, close, isopen, showerror, iterate, IteratorSize

# imports for use
using Base: Process, Semaphore, JLOptions, buffer_writes, @async_unwrap,
Expand Down
20 changes: 20 additions & 0 deletions src/remotecall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,23 @@ function getindex(r::RemoteChannel, args...)
end
return remotecall_fetch(getindex, r.where, r, args...)
end

function iterate(c::RemoteChannel, state=nothing)
if isopen(c) || isready(c)
try
return (take!(c), nothing)
catch e
if isa(e, InvalidStateException) ||
(isa(e, RemoteException) &&
isa(e.captured.ex, InvalidStateException) &&
e.captured.ex.state === :closed)
return nothing
end
rethrow()
end
else
return nothing
end
end

IteratorSize(::Type{<:RemoteChannel}) = SizeUnknown()
26 changes: 26 additions & 0 deletions test/distributed_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,32 @@ function test_iteration(in_c, out_c)
end

test_iteration(Channel(10), Channel(10))
test_iteration(RemoteChannel(() -> Channel(10)), RemoteChannel(() -> Channel(10)))

@everywhere function test_iteration_take(ch)
count = 0
for x in ch
count += 1
end
return count
end

@everywhere function test_iteration_put(ch, total)
for i in 1:total
put!(ch, i)
end
close(ch)
end

let ch = RemoteChannel(() -> Channel(1))
@async test_iteration_put(ch, 10)
@test 10 == @fetchfrom id_other test_iteration_take(ch)
# now reverse
ch = RemoteChannel(() -> Channel(1))
@spawnat id_other test_iteration_put(ch, 10)
@test 10 == test_iteration_take(ch)
end

# make sure exceptions propagate when waiting on Tasks
@test_throws CompositeException (@sync (@async error("oops")))
try
Expand Down

0 comments on commit 9639c42

Please sign in to comment.