Skip to content

Commit

Permalink
new: try to recover structured payload when the model returns it as d…
Browse files Browse the repository at this point in the history
…istinct arguments
  • Loading branch information
evilsocket committed Nov 21, 2024
1 parent 42a5282 commit b8fa372
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions nerve-core/src/agent/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,24 @@ impl Agent {
}

#[allow(clippy::borrowed_box)]
pub fn validate(&self, invocation: &Invocation, action: &Box<dyn Action>) -> Result<()> {
pub fn validate(&self, invocation: &mut Invocation, action: &Box<dyn Action>) -> Result<()> {
// validate prerequisites
let payload_required = action.example_payload().is_some();
let attrs_required = action.example_attributes().is_some();
let has_payload = invocation.payload.is_some();
let has_attributes = invocation.attributes.is_some();
let mut has_payload = invocation.payload.is_some();
let mut has_attributes = invocation.attributes.is_some();

// sometimes when the tool expects a json payload, the model returns it as separate arguments
// in this case we need to convert it back to a single json string
if (payload_required && !has_payload) && (!attrs_required && has_attributes) {
log::warn!(
"model returned the payload as separate arguments, converting back to payload"
);
invocation.payload = Some(serde_json::to_string(&invocation.attributes).unwrap());
invocation.attributes = None;
has_payload = true;
has_attributes = false;
}

if payload_required && !has_payload {
// payload required and not specified
Expand Down Expand Up @@ -363,7 +375,6 @@ impl Agent {
self.serializer.try_parse(response.trim())?
} else {
// the model supports function calling natively

tool_calls
};

Expand All @@ -379,15 +390,15 @@ impl Agent {
}

// for each parsed invocation
for inv in invocations {
for mut inv in invocations {
// lookup action
let action = self.state.lock().await.get_action(&inv.action);
if action.is_none() {
self.on_invalid_action(inv.clone(), None).await;
} else {
// validate prerequisites
let action = action.unwrap();
if let Err(err) = self.validate(&inv, &action) {
if let Err(err) = self.validate(&mut inv, &action) {
self.on_invalid_action(inv.clone(), Some(err.to_string()))
.await;
} else {
Expand Down

0 comments on commit b8fa372

Please sign in to comment.