diff --git a/compiler+runtime/include/cpp/jank/analyze/expr/case.hpp b/compiler+runtime/include/cpp/jank/analyze/expr/case.hpp index c9f3fcedb..9b2f10518 100644 --- a/compiler+runtime/include/cpp/jank/analyze/expr/case.hpp +++ b/compiler+runtime/include/cpp/jank/analyze/expr/case.hpp @@ -14,8 +14,8 @@ namespace jank::analyze::expr native_integer shift{}; native_integer mask{}; native_box default_expr{}; - std::vector transformed_keys{}; - std::vector> exprs{}; + native_vector keys{}; + native_vector> exprs{}; void propagate_position(expression_position const pos) { diff --git a/compiler+runtime/include/cpp/jank/analyze/processor.hpp b/compiler+runtime/include/cpp/jank/analyze/processor.hpp index c747d190c..ad343163e 100644 --- a/compiler+runtime/include/cpp/jank/analyze/processor.hpp +++ b/compiler+runtime/include/cpp/jank/analyze/processor.hpp @@ -6,7 +6,6 @@ #include #include #include -#include #include namespace jank::runtime @@ -153,12 +152,6 @@ namespace jank::analyze /* Returns whether the form is a special symbol. */ native_bool is_special(runtime::object_ptr form); - struct keys_and_exprs - { - std::vector keys{}; - std::vector exprs{}; - }; - using special_function_type = std::functiondata.first().unwrap()); - if(first.data->type != object_type::symbol) - { - return err(error{ "invalid case: first element must be 'case*'" }); - } - if(runtime::expect_object(first).data->name != "case*") - { - return err(error{ "invalid case: first element must be 'case*'" }); - } - auto it{ o->data.rest() }; + if (it.first().is_none()) { + return err(error{"missing value expression."}); + } auto const value_expr_obj{ it.first().unwrap() }; auto const value_expr{ analyze(value_expr_obj, f, expression_position::value, fc, needs_box) }; + if (value_expr.is_err()) { + return err(error{value_expr.expect_err()}); + } it = it.rest(); + if (it.first().is_none()) { + return err(error{"missing shift value."}); + } auto const shift_obj{ it.first().unwrap() }; if(shift_obj.data->type != object_type::integer) { - return err(error{ "expected integer for shift" }); + return err(error{ "expected integer for shift." }); } auto const shift{ runtime::expect_object(shift_obj) }; it = it.rest(); + if(it.first().is_none()) { + return err(error{"missing mask value."}); + } auto const mask_obj{ it.first().unwrap() }; if(mask_obj.data->type != object_type::integer) { - return err(error{ "expected integer for mask" }); + return err(error{ "expected integer for mask." }); } auto const mask{ runtime::expect_object(mask_obj) }; it = it.rest(); + if(it.first().is_none()) { + return err(error{"missing default expression."}); + } auto const default_expr_obj{ it.first().unwrap() }; auto const default_expr{ analyze(default_expr_obj, f, position, fc, needs_box) }; it = it.rest(); + if(it.first().is_none()) { + return err(error{"missing keys and expressions for 'case'."}); + } auto const imap_obj{ it.first().unwrap() }; + struct keys_and_exprs + { + native_vector keys{}; + native_vector exprs{}; + }; auto const keys_exprs{ visit_map_like( - [&](auto const &typed_map) -> string_result { + [&](auto const typed_imap_obj) -> string_result { keys_and_exprs ret{}; - for(auto auto_{ typed_map->seq() }; auto_ != nullptr; auto_ = auto_->next()) + for(auto seq{ typed_imap_obj->seq() }; seq != nullptr; seq = seq->next()) { - auto const e{ auto_->first() }; - auto const k{ runtime::nth(e, make_box(0)) }; - auto const v{ runtime::nth(e, make_box(1)) }; - if(k.data->type != object_type::integer) + auto const e{ seq->first() }; + auto const k_obj{ runtime::nth(e, make_box(0)) }; + auto const v_obj{ runtime::nth(e, make_box(1)) }; + if(k_obj.data->type != object_type::integer) { return err("Map key for case* is expected to be an integer"); } - auto const transformed_key{ runtime::expect_object(k) }; - auto const expr{ analyze(v, f, position, fc, needs_box) }; + auto const key{ runtime::expect_object(k_obj) }; + auto const expr{ analyze(v_obj, f, position, fc, needs_box) }; if(expr.is_err()) { return err(expr.expect_err().message); } - ret.keys.push_back(transformed_key->data); + ret.keys.push_back(key->data); ret.exprs.push_back(expr.expect_ok()); } return ret; @@ -237,6 +250,7 @@ namespace jank::analyze return err("Expect map-like for case keys and exprs"); }, imap_obj) }; + if(keys_exprs.is_err()) { return err(error{ keys_exprs.expect_err() }); diff --git a/compiler+runtime/src/cpp/jank/c_api.cpp b/compiler+runtime/src/cpp/jank/c_api.cpp index e2bd07069..0538383c8 100644 --- a/compiler+runtime/src/cpp/jank/c_api.cpp +++ b/compiler+runtime/src/cpp/jank/c_api.cpp @@ -817,11 +817,11 @@ extern "C" return to_hash(o_obj); } - native_integer to_integer(object const *o) + static native_integer to_integer_or_hash(object const *o) { if(o->type == object_type::integer) { - return dyn_cast(o)->data; + return expect_object(o)->data; } return to_hash(o); @@ -830,23 +830,25 @@ extern "C" jank_native_integer jank_to_integer(jank_object_ptr const o) { auto const o_obj(reinterpret_cast(o)); - return to_integer(o_obj); + return to_integer_or_hash(o_obj); } - jank_native_integer shift_mask_case_integer(jank_object_ptr const o, + jank_native_integer jank_shift_mask_case_integer(jank_object_ptr const o, jank_native_integer const shift, jank_native_integer const mask) { auto const o_obj(reinterpret_cast(o)); - auto integer{ to_integer(o_obj) }; + auto integer{ to_integer_or_hash(o_obj) }; if(mask != 0) { if(o_obj->type == object_type::integer) { - integer = integer >= std::numeric_limits::min() - && integer <= std::numeric_limits::max() + /* We don't hash the integer if it's an int32 value. This is to be consistent with how keys are hashed in jank's + * case macro. */ + integer = (integer >= std::numeric_limits::min() + && integer <= std::numeric_limits::max()) ? integer - : to_hash(make_box(integer)); + : hash::integer(integer); } integer = (integer >> shift) & mask; } diff --git a/compiler+runtime/src/cpp/jank/codegen/llvm_processor.cpp b/compiler+runtime/src/cpp/jank/codegen/llvm_processor.cpp index 781b247d7..343c9a2a1 100644 --- a/compiler+runtime/src/cpp/jank/codegen/llvm_processor.cpp +++ b/compiler+runtime/src/cpp/jank/codegen/llvm_processor.cpp @@ -965,7 +965,7 @@ namespace jank::codegen ctx->builder->getInt64Ty(), { ctx->builder->getPtrTy(), ctx->builder->getInt64Ty(), ctx->builder->getInt64Ty() }, false)); - auto const fn(ctx->module->getOrInsertFunction("shift_mask_case_integer", integer_fn_type)); + auto const fn(ctx->module->getOrInsertFunction("jank_shift_mask_case_integer", integer_fn_type)); llvm::SmallVector const args{ value, llvm::ConstantInt::getSigned(ctx->builder->getInt64Ty(), expr.shift), @@ -973,11 +973,11 @@ namespace jank::codegen }; auto const call(ctx->builder->CreateCall(fn, args)); auto const switch_val(ctx->builder->CreateIntCast(call, ctx->builder->getInt64Ty(), true)); - auto const default_block = llvm::BasicBlock::Create(*ctx->llvm_ctx, "default", current_fn); + auto const default_block{llvm::BasicBlock::Create(*ctx->llvm_ctx, "default", current_fn)}; auto const switch_ - = ctx->builder->CreateSwitch(switch_val, default_block, expr.transformed_keys.size()); + {ctx->builder->CreateSwitch(switch_val, default_block, expr.keys.size())}; auto const merge_block - = is_return ? nullptr : llvm::BasicBlock::Create(*ctx->llvm_ctx, "merge", current_fn); + {is_return ? nullptr : llvm::BasicBlock::Create(*ctx->llvm_ctx, "merge", current_fn)}; ctx->builder->SetInsertPoint(default_block); auto const default_val{ gen(expr.default_expr, arity) }; @@ -985,16 +985,16 @@ namespace jank::codegen { ctx->builder->CreateBr(merge_block); } - auto const default_block_exit = ctx->builder->GetInsertBlock(); + auto const default_block_exit{ctx->builder->GetInsertBlock()}; llvm::SmallVector case_blocks; llvm::SmallVector case_values; - for(size_t block_counter = 0; block_counter < expr.transformed_keys.size(); ++block_counter) + for(size_t block_counter = 0; block_counter < expr.keys.size(); ++block_counter) { - auto const block_name = fmt::format("case_{}", block_counter); + auto const block_name { fmt::format("case_{}", block_counter)}; auto const block{ llvm::BasicBlock::Create(*ctx->llvm_ctx, block_name, current_fn) }; switch_->addCase(llvm::ConstantInt::getSigned(ctx->builder->getInt64Ty(), - expr.transformed_keys[block_counter]), + expr.keys[block_counter]), block); ctx->builder->SetInsertPoint(block); @@ -1007,13 +1007,12 @@ namespace jank::codegen case_blocks.push_back(ctx->builder->GetInsertBlock()); } - // Handle value position with PHI node if(!is_return) { ctx->builder->SetInsertPoint(merge_block); - auto const phi = ctx->builder->CreatePHI(ctx->builder->getPtrTy(), - expr.transformed_keys.size() + 1, - "switch_tmp"); + auto const phi {ctx->builder->CreatePHI(ctx->builder->getPtrTy(), + expr.keys.size() + 1, + "switch_tmp")}; phi->addIncoming(default_val, default_block_exit); for(size_t i = 0; i < case_blocks.size(); ++i) { diff --git a/compiler+runtime/src/jank/clojure/core.jank b/compiler+runtime/src/jank/clojure/core.jank index 945507752..980e3c3eb 100644 --- a/compiler+runtime/src/jank/clojure/core.jank +++ b/compiler+runtime/src/jank/clojure/core.jank @@ -3458,7 +3458,9 @@ (-> x (bit-shift-right shift) (bit-and mask))) (defn- case-hash - "Returns the input if it is within the range of a 32-bit signed integer, otherwise returns the hash of the input." + "Returns the input if it is within the range of a 32-bit signed integer, otherwise returns the hash of the + input. The native hash returns a int32 int so this is to make sure that (case-hash (case-hash x)) == (case-hash x). + A key may be hashed more than once and we need to make sure its value does not change in later hashing." [input] (if (and (integer? input) (>= input int32-min-value) @@ -3584,18 +3586,15 @@ (if (== (count tests) (count hashes)) (if (fits-table? hashes) ; compact case ints, no shift-mask - [0 - 0 + [0 0 (case-map expr-sym default case-hash identity tests thens skip-check)] (let [[shift mask] (or (maybe-min-hash hashes) [0 0])] (if (zero? mask) ; sparse case ints, no shift-mask - [0 - 0 + [0 0 (case-map expr-sym default case-hash identity tests thens skip-check)] ; compact case ints, with shift-mask - [shift - mask + [shift mask (case-map expr-sym default #(shift-mask shift mask (case-hash %)) identity tests thens skip-check)]))) ; resolve hash collisions and try again (let [[tests thens skip-check]