Skip to content

Commit

Permalink
Rework initialization of constants & class variables (#15333)
Browse files Browse the repository at this point in the history
Changes the `flag` that keeps track of initialization status of constants and class variables to have 3 states instead of 2. The third one indicates that the value is currently being initialized and allows detecting recursion.
Previously we used an array to keep track of values being currently initialized. This array is now unnecessary.

The signature of `__crystal_once` changes: it now takes an Int8 (i8) instead of Bool (i1) and drops the once `state` pointer which isn't needed anymore. So `__crystal_once_init` no longer initializes a state pointer and returns nil.

Also introduces a fast path for the (very likely) scenario that the variable is already initialized which doesn't need a mutex.
Also introduces an LLVM optimization that instructs LLVM to optimize away repeated calls to `__crystal_once` for the same initializer.

Requires a new compiler build to benefit from the improvement. The legacy versions of `__crystal_once` and `__crystal_once_init` are still supported by both the stdlib and the compiler to keep both forward & backward compatibility (1.15 and below can build 1.16+ and 1.16+ can build 1.15 and below).

A follow-up could leverage `ReferenceStorage` and `.unsafe_construct` to inline the `Mutex` instead of allocating in the GC heap. Along with #15330 then `__crystal_once_init` could become allocation free, which could prove useful for such a core/low level feature.

Co-authored-by: David Keller <[email protected]>
  • Loading branch information
ysbaddaden and BlobCodes authored Jan 20, 2025
1 parent 39aaae5 commit 8d02c8b
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 65 deletions.
6 changes: 3 additions & 3 deletions src/compiler/crystal/codegen/class_var.cr
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class Crystal::CodeGenVisitor
initialized_flag_name = class_var_global_initialized_name(class_var)
initialized_flag = @main_mod.globals[initialized_flag_name]?
unless initialized_flag
initialized_flag = @main_mod.globals.add(@main_llvm_context.int1, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int1.const_int(0)
initialized_flag = @main_mod.globals.add(@main_llvm_context.int8, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int8.const_int(0)
initialized_flag.linkage = LLVM::Linkage::Internal if @single_module
initialized_flag.thread_local = true if class_var.thread_local?
end
Expand Down Expand Up @@ -61,7 +61,7 @@ class Crystal::CodeGenVisitor
initialized_flag_name = class_var_global_initialized_name(class_var)
initialized_flag = @llvm_mod.globals[initialized_flag_name]?
unless initialized_flag
initialized_flag = @llvm_mod.globals.add(llvm_context.int1, initialized_flag_name)
initialized_flag = @llvm_mod.globals.add(llvm_context.int8, initialized_flag_name)
initialized_flag.thread_local = true if class_var.thread_local?
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/compiler/crystal/codegen/const.cr
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class Crystal::CodeGenVisitor
initialized_flag_name = const.initialized_llvm_name
initialized_flag = @main_mod.globals[initialized_flag_name]?
unless initialized_flag
initialized_flag = @main_mod.globals.add(@main_llvm_context.int1, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int1.const_int(0)
initialized_flag = @main_mod.globals.add(@main_llvm_context.int8, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int8.const_int(0)
initialized_flag.linkage = LLVM::Linkage::Internal if @single_module
end
initialized_flag
Expand Down
54 changes: 36 additions & 18 deletions src/compiler/crystal/codegen/once.cr
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,49 @@ class Crystal::CodeGenVisitor
if once_init_fun = typed_fun?(@main_mod, ONCE_INIT)
once_init_fun = check_main_fun ONCE_INIT, once_init_fun

once_state_global = @main_mod.globals.add(once_init_fun.type.return_type, ONCE_STATE)
once_state_global.linkage = LLVM::Linkage::Internal if @single_module
once_state_global.initializer = once_init_fun.type.return_type.null

state = call once_init_fun
store state, once_state_global
if once_init_fun.type.return_type.void?
call once_init_fun
else
# legacy (kept for backward compatibility): the compiler must save the
# state returned by __crystal_once_init
once_state_global = @main_mod.globals.add(once_init_fun.type.return_type, ONCE_STATE)
once_state_global.linkage = LLVM::Linkage::Internal if @single_module
once_state_global.initializer = once_init_fun.type.return_type.null

state = call once_init_fun
store state, once_state_global
end
end
end

def run_once(flag, func : LLVMTypedFunction)
once_fun = main_fun(ONCE)
once_init_fun = main_fun(ONCE_INIT)

# both of these should be Void*
once_state_type = once_init_fun.type.return_type
once_initializer_type = once_fun.func.params.last.type
once_fun_params = once_fun.func.params
once_initializer_type = once_fun_params.last.type # must be Void*
initializer = pointer_cast(func.func.to_value, once_initializer_type)

once_state_global = @llvm_mod.globals[ONCE_STATE]? || begin
global = @llvm_mod.globals.add(once_state_type, ONCE_STATE)
global.linkage = LLVM::Linkage::External
global
if once_fun_params.size == 2
args = [flag, initializer]
else
# legacy (kept for backward compatibility): the compiler must pass the
# state returned by __crystal_once_init to __crystal_once as the first
# argument
once_init_fun = main_fun(ONCE_INIT)
once_state_type = once_init_fun.type.return_type # must be Void*

once_state_global = @llvm_mod.globals[ONCE_STATE]? || begin
global = @llvm_mod.globals.add(once_state_type, ONCE_STATE)
global.linkage = LLVM::Linkage::External
global
end

state = load(once_state_type, once_state_global)
{% if LibLLVM::IS_LT_150 %}
flag = bit_cast(flag, @llvm_context.int1.pointer) # cast Int8* to Bool*
{% end %}
args = [state, flag, initializer]
end

state = load(once_state_type, once_state_global)
initializer = pointer_cast(func.func.to_value, once_initializer_type)
call once_fun, [state, flag, initializer]
call once_fun, args
end
end
172 changes: 130 additions & 42 deletions src/crystal/once.cr
Original file line number Diff line number Diff line change
@@ -1,54 +1,142 @@
# This file defines the functions `__crystal_once_init` and `__crystal_once` expected
# by the compiler. `__crystal_once` is called each time a constant or class variable
# has to be initialized and is its responsibility to verify the initializer is executed
# only once. `__crystal_once_init` is executed only once at the beginning of the program
# and the result is passed on each call to `__crystal_once`.

# This implementation uses an array to store the initialization flag pointers for each value
# to find infinite loops and raise an error. In multithread mode a mutex is used to
# avoid race conditions between threads.

# :nodoc:
class Crystal::OnceState
@rec = [] of Bool*

def once(flag : Bool*, initializer : Void*)
unless flag.value
if @rec.includes?(flag)
raise "Recursion while initializing class variables and/or constants"
# This file defines two functions expected by the compiler:
#
# - `__crystal_once_init`: executed only once at the beginning of the program
# and, for the legacy implementation, the result is passed on each call to
# `__crystal_once`.
#
# - `__crystal_once`: called each time a constant or class variable has to be
# initialized and is its responsibility to verify the initializer is executed
# only once and to fail on recursion.

# In multithread mode a mutex is used to avoid race conditions between threads.
#
# On Win32, `Crystal::System::FileDescriptor#@@reader_thread` spawns a new
# thread even without the `preview_mt` flag, and the thread can also reference
# Crystal constants, leading to race conditions, so we always enable the mutex.

{% if compare_versions(Crystal::VERSION, "1.16.0-dev") >= 0 %}
# This implementation uses an enum over the initialization flag pointer for
# each value to find infinite loops and raise an error.

module Crystal
# :nodoc:
enum OnceState : Int8
Processing = -1
Uninitialized = 0
Initialized = 1
end

{% if flag?(:preview_mt) || flag?(:win32) %}
@@once_mutex = uninitialized Mutex

# :nodoc:
def self.once_mutex=(@@once_mutex : Mutex)
end
@rec << flag
{% end %}

Proc(Nil).new(initializer, Pointer(Void).null).call
flag.value = true
# :nodoc:
# Using @[NoInline] so LLVM optimizes for the hot path (var already
# initialized).
@[NoInline]
def self.once(flag : OnceState*, initializer : Void*) : Nil
{% if flag?(:preview_mt) || flag?(:win32) %}
@@once_mutex.synchronize { once_exec(flag, initializer) }
{% else %}
once_exec(flag, initializer)
{% end %}

@rec.pop
# safety check, and allows to safely call `Intrinsics.unreachable` in
# `__crystal_once`
unless flag.value.initialized?
System.print_error "BUG: failed to initialize constant or class variable\n"
LibC._exit(1)
end
end

private def self.once_exec(flag : OnceState*, initializer : Void*) : Nil
case flag.value
in .initialized?
return
in .uninitialized?
flag.value = :processing
Proc(Nil).new(initializer, Pointer(Void).null).call
flag.value = :initialized
in .processing?
raise "Recursion while initializing class variables and/or constants"
end
end
end

# on Win32, `Crystal::System::FileDescriptor#@@reader_thread` spawns a new
# thread even without the `preview_mt` flag, and the thread can also reference
# Crystal constants, leading to race conditions, so we always enable the mutex
# TODO: can this be improved?
{% if flag?(:preview_mt) || flag?(:win32) %}
@mutex = Mutex.new(:reentrant)
# :nodoc:
fun __crystal_once_init : Nil
{% if flag?(:preview_mt) || flag?(:win32) %}
Crystal.once_mutex = Mutex.new(:reentrant)
{% end %}
end

# :nodoc:
#
# Using `@[AlwaysInline]` allows LLVM to optimize const accesses. Since this
# is a `fun` the function will still appear in the symbol table, though it
# will never be called.
@[AlwaysInline]
fun __crystal_once(flag : Crystal::OnceState*, initializer : Void*) : Nil
return if flag.value.initialized?

Crystal.once(flag, initializer)

# tell LLVM that it can optimize away repeated `__crystal_once` calls for
# this global (e.g. repeated access to constant in a single funtion);
# this is truly unreachable otherwise `Crystal.once` would have panicked
Intrinsics.unreachable unless flag.value.initialized?
end
{% else %}
# This implementation uses a global array to store the initialization flag
# pointers for each value to find infinite loops and raise an error.

# :nodoc:
class Crystal::OnceState
@rec = [] of Bool*

@[NoInline]
def once(flag : Bool*, initializer : Void*)
unless flag.value
@mutex.synchronize do
previous_def
if @rec.includes?(flag)
raise "Recursion while initializing class variables and/or constants"
end
@rec << flag

Proc(Nil).new(initializer, Pointer(Void).null).call
flag.value = true

@rec.pop
end
end
{% end %}
end

# :nodoc:
fun __crystal_once_init : Void*
Crystal::OnceState.new.as(Void*)
end

# :nodoc:
fun __crystal_once(state : Void*, flag : Bool*, initializer : Void*)
state.as(Crystal::OnceState).once(flag, initializer)
end

{% if flag?(:preview_mt) || flag?(:win32) %}
@mutex = Mutex.new(:reentrant)

@[NoInline]
def once(flag : Bool*, initializer : Void*)
unless flag.value
@mutex.synchronize do
previous_def
end
end
end
{% end %}
end

# :nodoc:
fun __crystal_once_init : Void*
Crystal::OnceState.new.as(Void*)
end

# :nodoc:
@[AlwaysInline]
fun __crystal_once(state : Void*, flag : Bool*, initializer : Void*)
return if flag.value
state.as(Crystal::OnceState).once(flag, initializer)
Intrinsics.unreachable unless flag.value
end
{% end %}
17 changes: 17 additions & 0 deletions src/intrinsics.cr
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,23 @@ module Intrinsics
macro va_end(ap)
::LibIntrinsics.va_end({{ap}})
end

# Should codegen to the following LLVM IR (before being inlined):
# ```
# define internal void @"*Intrinsics::unreachable:NoReturn"() #12 {
# entry:
# unreachable
# }
# ```
#
# Can be used like `@llvm.assume(i1 cond)` as `unreachable unless (assumption)`.
#
# WARNING: the behaviour of the program is undefined if the assumption is broken!
@[AlwaysInline]
def self.unreachable : NoReturn
x = uninitialized NoReturn
x
end
end

macro debugger
Expand Down

0 comments on commit 8d02c8b

Please sign in to comment.