Skip to content

Commit

Permalink
outputparser: improve BooleanOutputParser (#978)
Browse files Browse the repository at this point in the history
The BooleanOutputParser requests the LLM to respond with a boolean,
and gives examples such as `true` or `false`. However, it only parsed
respones that include YES or NO. This commits adds more values for
parsing and changes the tests to fit them.
  • Loading branch information
amitaifrey authored Sep 13, 2024
1 parent 346f626 commit ddb8293
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 18 deletions.
26 changes: 15 additions & 11 deletions outputparser/boolean_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ import (

// BooleanParser is an output parser used to parse the output of an LLM as a boolean.
type BooleanParser struct {
TrueStr string
FalseStr string
TrueStrings []string
FalseStrings []string
}

// NewBooleanParser returns a new BooleanParser.
func NewBooleanParser() BooleanParser {
return BooleanParser{
TrueStr: "YES",
FalseStr: "NO",
TrueStrings: []string{"YES", "TRUE"},
FalseStrings: []string{"NO", "FALSE"},
}
}

Expand All @@ -33,20 +33,24 @@ func (p BooleanParser) GetFormatInstructions() string {

func (p BooleanParser) parse(text string) (bool, error) {
text = normalize(text)
booleanStrings := []string{p.TrueStr, p.FalseStr}

if !slices.Contains(booleanStrings, text) {
return false, ParseError{
Text: text,
Reason: fmt.Sprintf("Expected output to be either '%s' or '%s', received %s", p.TrueStr, p.FalseStr, text),
}
if slices.Contains(p.TrueStrings, text) {
return true, nil
}

return text == p.TrueStr, nil
if slices.Contains(p.FalseStrings, text) {
return false, nil
}

return false, ParseError{
Text: text,
Reason: fmt.Sprintf("Expected output to one of %v, received %s", append(p.TrueStrings, p.FalseStrings...), text),
}
}

func normalize(text string) string {
text = strings.TrimSpace(text)
text = strings.Trim(text, "'\"`")
text = strings.ToUpper(text)

return text
Expand Down
59 changes: 52 additions & 7 deletions outputparser/boolean_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,70 @@ func TestBooleanParser(t *testing.T) {
},
{
input: "YESNO",
err: outputparser.ParseError{},
expected: false,
},
{
input: "ok",
err: outputparser.ParseError{},
expected: false,
},
{
input: "true",
expected: true,
},
{
input: "false",
expected: false,
},
{
input: "True",
expected: true,
},
{
input: "False",
expected: false,
},
{
input: "TRUE",
expected: true,
},
{
input: "FALSE",
expected: false,
},
{
input: "'TRUE'",
expected: true,
},
{
input: "`TRUE`",
expected: true,
},
{
input: "'TRUE`",
expected: true,
},
}

for _, tc := range testCases {
parser := outputparser.NewBooleanParser()

actual, err := parser.Parse(tc.input)
if tc.err != nil && err == nil {
t.Errorf("Expected error %v, got nil", tc.err)
}
t.Run(tc.input, func(t *testing.T) {
t.Parallel()

result, err := parser.Parse(tc.input)
if err != nil && tc.err == nil {
t.Errorf("Unexpected error: %v", err)
}

if err == nil && tc.err != nil {
t.Errorf("Expected error %v, got nil", tc.err)
}

if actual != tc.expected {
t.Errorf("Expected %v, got %v", tc.expected, actual)
}
if result != tc.expected {
t.Errorf("Expected %v, but got %v", tc.expected, result)
}
})
}
}

0 comments on commit ddb8293

Please sign in to comment.