Skip to content

Commit

Permalink
Merge pull request #31 from google/select-reject-attr
Browse files Browse the repository at this point in the history
Support rejectattr + select
  • Loading branch information
ochafik authored Jan 26, 2025
2 parents 0f5f7f2 + 97c5fed commit 6c51cdb
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 56 deletions.
9 changes: 7 additions & 2 deletions include/minja/chat-template.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class chat_template {
/* .keep_trailing_newline = */ false,
});
supports_tools_ = source.find("tools") != std::string::npos;

auto renders_string_arguments =
try_raw_render({
{
Expand Down Expand Up @@ -173,7 +173,12 @@ class chat_template {
if (tool_call["type"] == "function") {
auto & function = tool_call.at("function");
std::string arguments = function.at("arguments");
function["arguments"] = json::parse(arguments);
try {
function["arguments"] = json::parse(arguments);
} catch (const std::exception & ecvt) {
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
function["arguments"] = arguments;
}
}
}
}
Expand Down
113 changes: 60 additions & 53 deletions include/minja/minja.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2648,31 +2648,34 @@ inline std::shared_ptr<Context> Context::builtins() {
return filter.call(context, actual_args);
});
};
// https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject
globals.set("reject", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
args.expectArgs("reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
auto & items = args.args[0];
auto filter_fn = context->get(args.args[1]);
if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
auto select_or_reject = [make_filter](bool is_select) {
return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
auto & items = args.args[0];
auto filter_fn = context->get(args.args[1]);
if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());

auto filter_args = Value::array();
for (size_t i = 2, n = args.args.size(); i < n; i++) {
filter_args.push_back(args.args[i]);
}
auto filter = make_filter(filter_fn, filter_args);
auto filter_args = Value::array();
for (size_t i = 2, n = args.args.size(); i < n; i++) {
filter_args.push_back(args.args[i]);
}
auto filter = make_filter(filter_fn, filter_args);

auto res = Value::array();
for (size_t i = 0, n = items.size(); i < n; i++) {
auto & item = items.at(i);
ArgumentsValue filter_args;
filter_args.args.emplace_back(item);
auto pred_res = filter.call(context, filter_args);
if (!pred_res.to_bool()) {
res.push_back(item);
auto res = Value::array();
for (size_t i = 0, n = items.size(); i < n; i++) {
auto & item = items.at(i);
ArgumentsValue filter_args;
filter_args.args.emplace_back(item);
auto pred_res = filter.call(context, filter_args);
if (pred_res.to_bool() == (is_select ? true : false)) {
res.push_back(item);
}
}
}
return res;
}));
return res;
});
};
globals.set("select", select_or_reject(/* is_select= */ true));
globals.set("reject", select_or_reject(/* is_select= */ false));
globals.set("map", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
auto res = Value::array();
if (args.args.size() == 1 &&
Expand Down Expand Up @@ -2720,41 +2723,45 @@ inline std::shared_ptr<Context> Context::builtins() {
if (!text.empty() && text.back() == '\n') out += "\n";
return out;
}));
globals.set("selectattr", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
args.expectArgs("selectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
auto & items = args.args[0];
if (items.is_null())
return Value::array();
auto attr_name = args.args[1].get<std::string>();

bool has_test = false;
Value test_fn;
ArgumentsValue test_args {{Value()}, {}};
if (args.args.size() >= 3) {
has_test = true;
test_fn = context->get(args.args[2]);
if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
for (size_t i = 3, n = args.args.size(); i < n; i++) {
test_args.args.emplace_back(args.args[i]);
auto select_or_reject_attr = [](bool is_select) {
return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
auto & items = args.args[0];
if (items.is_null())
return Value::array();
auto attr_name = args.args[1].get<std::string>();

bool has_test = false;
Value test_fn;
ArgumentsValue test_args {{Value()}, {}};
if (args.args.size() >= 3) {
has_test = true;
test_fn = context->get(args.args[2]);
if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
for (size_t i = 3, n = args.args.size(); i < n; i++) {
test_args.args.emplace_back(args.args[i]);
}
test_args.kwargs = args.kwargs;
}
test_args.kwargs = args.kwargs;
}

auto res = Value::array();
for (size_t i = 0, n = items.size(); i < n; i++) {
auto & item = items.at(i);
auto attr = item.get(attr_name);
if (has_test) {
test_args.args[0] = attr;
if (test_fn.call(context, test_args).to_bool()) {
res.push_back(item);
auto res = Value::array();
for (size_t i = 0, n = items.size(); i < n; i++) {
auto & item = items.at(i);
auto attr = item.get(attr_name);
if (has_test) {
test_args.args[0] = attr;
if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) {
res.push_back(item);
}
} else {
res.push_back(attr);
}
} else {
res.push_back(attr);
}
}
return res;
}));
return res;
});
};
globals.set("selectattr", select_or_reject_attr(/* is_select= */ true));
globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false));
globals.set("range", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
std::vector<int64_t> startEndStep(3);
std::vector<bool> param_set(3);
Expand Down
12 changes: 11 additions & 1 deletion scripts/fetch_templates_and_goldens.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def handle_chat_template(output_folder, model_id, variant, template_src, context
env.globals['strftime_now'] = strftime_now

template_handles_tools = 'tools' in template_src
supports_code_interpreter = 'code_interpreter' in template_src


def renders(messages, *, tools=[], add_generation_prompt=False, extra_context={}, expect_strings=[]):
Expand Down Expand Up @@ -142,6 +143,11 @@ def renders(messages, *, tools=[], add_generation_prompt=False, extra_context={}
context = json.load(f)

if not template_handles_tools and 'tools' in context:
print(f'Skipping {context_name} test as tools seem unsupported by template {template_file}', file=sys.stderr)
continue

if not supports_code_interpreter and 'tools' in context and any(t['type'] == 'code_interpreter' for t in context['tools']):
print(f'Skipping {context_name} test as code_interpreter seems unsupported by template {template_file}', file=sys.stderr)
continue

if not supports_system_role and any(m['role'] == 'system' for m in context['messages']):
Expand All @@ -155,7 +161,11 @@ def renders(messages, *, tools=[], add_generation_prompt=False, extra_context={}
for tool_call in message['tool_calls']:
if tool_call.get('type') == 'function':
arguments = tool_call['function']['arguments']
tool_call['function']['arguments'] = json.loads(arguments)
try:
arguments = json.loads(arguments)
except:
pass
tool_call['function']['arguments'] = arguments

if requires_typed_content:
for message in context['messages']:
Expand Down
43 changes: 43 additions & 0 deletions tests/contexts/tool_use_code_interpreter.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
{
"messages": [
{
"role": "user",
"content": "Print a hello world message with python."
},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1___",
"type": "function",
"function": {
"arguments": "print('Hello, World!')",
"name": "python"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1___",
"name": "python",
"content": "{\"stdout\": \"Hello, World!\"}"
}
],
"add_generation_prompt": true,
"bos_token": "<|startoftext|>",
"eos_token": "<|endoftext|>",
"builtin_tools": [
"wolfram_alpha",
"brave_search",
"code_interpreter"
],
"cutting_knowledge_date": "2023-04-01",
"todays_date": "2024-09-03",
"tools": [
{
"type": "code_interpreter"
}
]
}
6 changes: 6 additions & 0 deletions tests/test-syntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ TEST(SyntaxTest, SimpleCases) {
EXPECT_EQ(
R"([{'a': 1}])",
render(R"({{ [{"a": 1}, {"a": 2}, {}] | selectattr("a", "equalto", 1) | list }})", {}, {}));
EXPECT_EQ(
R"([{'a': 2}, {}])",
render(R"({{ [{"a": 1}, {"a": 2}, {}] | rejectattr("a", "equalto", 1) | list }})", {}, {}));
EXPECT_EQ(
"[1, 2]",
render(R"({{ [{"a": 1}, {"a": 2}] | map(attribute="a") | list }})", {}, {}));
Expand Down Expand Up @@ -251,6 +254,9 @@ TEST(SyntaxTest, SimpleCases) {
EXPECT_EQ(
"Tools: 1, 3...",
render("{{ 'Tools: ' + [1, 2, 3] | reject('equalto', 2) | join(', ') + '...' }}", {}, {}));
EXPECT_EQ(
"Tools: 2...",
render("{{ 'Tools: ' + [1, 2, 3] | select('equalto', 2) | join(', ') + '...' }}", {}, {}));
EXPECT_EQ(
"1, 2, 3",
render("{{ [1, 2, 3] | join(', ') }}", {}, {}));
Expand Down

0 comments on commit 6c51cdb

Please sign in to comment.