Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jianlingzhong committed Feb 12, 2025
1 parent da7ccce commit e64ebe6
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 58 deletions.
4 changes: 2 additions & 2 deletions compiler+runtime/include/cpp/jank/analyze/expr/case.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ namespace jank::analyze::expr
native_integer shift{};
native_integer mask{};
native_box<E> default_expr{};
std::vector<native_integer> transformed_keys{};
std::vector<native_box<E>> exprs{};
native_vector<native_integer> keys{};
native_vector<native_box<E>> exprs{};

void propagate_position(expression_position const pos)
{
Expand Down
7 changes: 0 additions & 7 deletions compiler+runtime/include/cpp/jank/analyze/processor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include <jank/runtime/var.hpp>
#include <jank/analyze/local_frame.hpp>
#include <jank/analyze/expression.hpp>
#include <jank/runtime/obj/persistent_sorted_map.hpp>
#include <jank/option.hpp>

namespace jank::runtime
Expand Down Expand Up @@ -153,12 +152,6 @@ namespace jank::analyze
/* Returns whether the form is a special symbol. */
native_bool is_special(runtime::object_ptr form);

struct keys_and_exprs
{
std::vector<native_integer> keys{};
std::vector<expression_ptr> exprs{};
};

using special_function_type
= std::function<expression_result(runtime::obj::persistent_list_ptr const &,
local_frame_ptr &,
Expand Down
2 changes: 1 addition & 1 deletion compiler+runtime/include/cpp/jank/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ extern "C"
jank_native_hash jank_to_hash(jank_object_ptr o);
jank_native_integer jank_to_integer(jank_object_ptr o);
jank_native_integer
shift_mask_case_integer(jank_object_ptr o, jank_native_integer shift, jank_native_integer mask);
jank_shift_mask_case_integer(jank_object_ptr o, jank_native_integer shift, jank_native_integer mask);

void jank_set_meta(jank_object_ptr o, jank_object_ptr meta);

Expand Down
56 changes: 35 additions & 21 deletions compiler+runtime/src/cpp/jank/analyze/processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,61 +174,74 @@ namespace jank::analyze
return err(error{ "invalid case: incorrect number of elements in form" });
}

auto const first(o->data.first().unwrap());
if(first.data->type != object_type::symbol)
{
return err(error{ "invalid case: first element must be 'case*'" });
}
if(runtime::expect_object<obj::symbol>(first).data->name != "case*")
{
return err(error{ "invalid case: first element must be 'case*'" });
}

auto it{ o->data.rest() };
if (it.first().is_none()) {
return err(error{"missing value expression."});
}
auto const value_expr_obj{ it.first().unwrap() };
auto const value_expr{ analyze(value_expr_obj, f, expression_position::value, fc, needs_box) };
if (value_expr.is_err()) {
return err(error{value_expr.expect_err()});
}

it = it.rest();
if (it.first().is_none()) {
return err(error{"missing shift value."});
}
auto const shift_obj{ it.first().unwrap() };
if(shift_obj.data->type != object_type::integer)
{
return err(error{ "expected integer for shift" });
return err(error{ "expected integer for shift." });
}
auto const shift{ runtime::expect_object<runtime::obj::integer>(shift_obj) };

it = it.rest();
if(it.first().is_none()) {
return err(error{"missing mask value."});
}
auto const mask_obj{ it.first().unwrap() };
if(mask_obj.data->type != object_type::integer)
{
return err(error{ "expected integer for mask" });
return err(error{ "expected integer for mask." });
}
auto const mask{ runtime::expect_object<runtime::obj::integer>(mask_obj) };

it = it.rest();
if(it.first().is_none()) {
return err(error{"missing default expression."});
}
auto const default_expr_obj{ it.first().unwrap() };
auto const default_expr{ analyze(default_expr_obj, f, position, fc, needs_box) };

it = it.rest();
if(it.first().is_none()) {
return err(error{"missing keys and expressions for 'case'."});
}
auto const imap_obj{ it.first().unwrap() };
struct keys_and_exprs
{
native_vector<native_integer> keys{};
native_vector<expression_ptr> exprs{};
};
auto const keys_exprs{ visit_map_like(
[&](auto const &typed_map) -> string_result<keys_and_exprs> {
[&](auto const typed_imap_obj) -> string_result<keys_and_exprs> {
keys_and_exprs ret{};
for(auto auto_{ typed_map->seq() }; auto_ != nullptr; auto_ = auto_->next())
for(auto seq{ typed_imap_obj->seq() }; seq != nullptr; seq = seq->next())
{
auto const e{ auto_->first() };
auto const k{ runtime::nth(e, make_box(0)) };
auto const v{ runtime::nth(e, make_box(1)) };
if(k.data->type != object_type::integer)
auto const e{ seq->first() };
auto const k_obj{ runtime::nth(e, make_box(0)) };
auto const v_obj{ runtime::nth(e, make_box(1)) };
if(k_obj.data->type != object_type::integer)
{
return err("Map key for case* is expected to be an integer");
}
auto const transformed_key{ runtime::expect_object<obj::integer>(k) };
auto const expr{ analyze(v, f, position, fc, needs_box) };
auto const key{ runtime::expect_object<obj::integer>(k_obj) };
auto const expr{ analyze(v_obj, f, position, fc, needs_box) };
if(expr.is_err())
{
return err(expr.expect_err().message);
}
ret.keys.push_back(transformed_key->data);
ret.keys.push_back(key->data);
ret.exprs.push_back(expr.expect_ok());
}
return ret;
Expand All @@ -237,6 +250,7 @@ namespace jank::analyze
return err("Expect map-like for case keys and exprs");
},
imap_obj) };

if(keys_exprs.is_err())
{
return err(error{ keys_exprs.expect_err() });
Expand Down
18 changes: 10 additions & 8 deletions compiler+runtime/src/cpp/jank/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,11 +817,11 @@ extern "C"
return to_hash(o_obj);
}

native_integer to_integer(object const *o)
static native_integer to_integer_or_hash(object const *o)
{
if(o->type == object_type::integer)
{
return dyn_cast<obj::integer>(o)->data;
return expect_object<obj::integer>(o)->data;
}

return to_hash(o);
Expand All @@ -830,23 +830,25 @@ extern "C"
jank_native_integer jank_to_integer(jank_object_ptr const o)
{
auto const o_obj(reinterpret_cast<object *>(o));
return to_integer(o_obj);
return to_integer_or_hash(o_obj);
}

jank_native_integer shift_mask_case_integer(jank_object_ptr const o,
jank_native_integer jank_shift_mask_case_integer(jank_object_ptr const o,
jank_native_integer const shift,
jank_native_integer const mask)
{
auto const o_obj(reinterpret_cast<object *>(o));
auto integer{ to_integer(o_obj) };
auto integer{ to_integer_or_hash(o_obj) };
if(mask != 0)
{
if(o_obj->type == object_type::integer)
{
integer = integer >= std::numeric_limits<int32_t>::min()
&& integer <= std::numeric_limits<int32_t>::max()
/* We don't hash the integer if it's an int32 value. This is to be consistent with how keys are hashed in jank's
* case macro. */
integer = (integer >= std::numeric_limits<int32_t>::min()
&& integer <= std::numeric_limits<int32_t>::max())
? integer
: to_hash(make_box(integer));
: hash::integer(integer);
}
integer = (integer >> shift) & mask;
}
Expand Down
23 changes: 11 additions & 12 deletions compiler+runtime/src/cpp/jank/codegen/llvm_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,36 +965,36 @@ namespace jank::codegen
ctx->builder->getInt64Ty(),
{ ctx->builder->getPtrTy(), ctx->builder->getInt64Ty(), ctx->builder->getInt64Ty() },
false));
auto const fn(ctx->module->getOrInsertFunction("shift_mask_case_integer", integer_fn_type));
auto const fn(ctx->module->getOrInsertFunction("jank_shift_mask_case_integer", integer_fn_type));
llvm::SmallVector<llvm::Value *, 3> const args{
value,
llvm::ConstantInt::getSigned(ctx->builder->getInt64Ty(), expr.shift),
llvm::ConstantInt::getSigned(ctx->builder->getInt64Ty(), expr.mask)
};
auto const call(ctx->builder->CreateCall(fn, args));
auto const switch_val(ctx->builder->CreateIntCast(call, ctx->builder->getInt64Ty(), true));
auto const default_block = llvm::BasicBlock::Create(*ctx->llvm_ctx, "default", current_fn);
auto const default_block{llvm::BasicBlock::Create(*ctx->llvm_ctx, "default", current_fn)};
auto const switch_
= ctx->builder->CreateSwitch(switch_val, default_block, expr.transformed_keys.size());
{ctx->builder->CreateSwitch(switch_val, default_block, expr.keys.size())};
auto const merge_block
= is_return ? nullptr : llvm::BasicBlock::Create(*ctx->llvm_ctx, "merge", current_fn);
{is_return ? nullptr : llvm::BasicBlock::Create(*ctx->llvm_ctx, "merge", current_fn)};

ctx->builder->SetInsertPoint(default_block);
auto const default_val{ gen(expr.default_expr, arity) };
if(!is_return)
{
ctx->builder->CreateBr(merge_block);
}
auto const default_block_exit = ctx->builder->GetInsertBlock();
auto const default_block_exit{ctx->builder->GetInsertBlock()};

llvm::SmallVector<llvm::BasicBlock *> case_blocks;
llvm::SmallVector<llvm::Value *> case_values;
for(size_t block_counter = 0; block_counter < expr.transformed_keys.size(); ++block_counter)
for(size_t block_counter = 0; block_counter < expr.keys.size(); ++block_counter)
{
auto const block_name = fmt::format("case_{}", block_counter);
auto const block_name { fmt::format("case_{}", block_counter)};
auto const block{ llvm::BasicBlock::Create(*ctx->llvm_ctx, block_name, current_fn) };
switch_->addCase(llvm::ConstantInt::getSigned(ctx->builder->getInt64Ty(),
expr.transformed_keys[block_counter]),
expr.keys[block_counter]),
block);

ctx->builder->SetInsertPoint(block);
Expand All @@ -1007,13 +1007,12 @@ namespace jank::codegen
case_blocks.push_back(ctx->builder->GetInsertBlock());
}

// Handle value position with PHI node
if(!is_return)
{
ctx->builder->SetInsertPoint(merge_block);
auto const phi = ctx->builder->CreatePHI(ctx->builder->getPtrTy(),
expr.transformed_keys.size() + 1,
"switch_tmp");
auto const phi {ctx->builder->CreatePHI(ctx->builder->getPtrTy(),
expr.keys.size() + 1,
"switch_tmp")};
phi->addIncoming(default_val, default_block_exit);
for(size_t i = 0; i < case_blocks.size(); ++i)
{
Expand Down
13 changes: 6 additions & 7 deletions compiler+runtime/src/jank/clojure/core.jank
Original file line number Diff line number Diff line change
Expand Up @@ -3458,7 +3458,9 @@
(-> x (bit-shift-right shift) (bit-and mask)))

(defn- case-hash
"Returns the input if it is within the range of a 32-bit signed integer, otherwise returns the hash of the input."
"Returns the input if it is within the range of a 32-bit signed integer, otherwise returns the hash of the
input. The native hash returns a int32 int so this is to make sure that (case-hash (case-hash x)) == (case-hash x).
A key may be hashed more than once and we need to make sure its value does not change in later hashing."
[input]
(if (and (integer? input)
(>= input int32-min-value)
Expand Down Expand Up @@ -3584,18 +3586,15 @@
(if (== (count tests) (count hashes))
(if (fits-table? hashes)
; compact case ints, no shift-mask
[0
0
[0 0
(case-map expr-sym default case-hash identity tests thens skip-check)]
(let [[shift mask] (or (maybe-min-hash hashes) [0 0])]
(if (zero? mask)
; sparse case ints, no shift-mask
[0
0
[0 0
(case-map expr-sym default case-hash identity tests thens skip-check)]
; compact case ints, with shift-mask
[shift
mask
[shift mask
(case-map expr-sym default #(shift-mask shift mask (case-hash %)) identity tests thens skip-check)])))
; resolve hash collisions and try again
(let [[tests thens skip-check]
Expand Down

0 comments on commit e64ebe6

Please sign in to comment.