From bd8eed57fc7d5f9b9b9d333b9c395864fb2378d8 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 26 Apr 2024 17:11:47 -0700 Subject: [PATCH] fix parser name --- parser/parser.go | 26 +++++++++++++++++++++----- parser/parser_test.go | 12 +++++++++++- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/parser/parser.go b/parser/parser.go index 6c451e99..9d1f3388 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -27,8 +27,9 @@ const ( ) var ( - errMissingFrom = errors.New("no FROM line") - errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"") + errMissingFrom = errors.New("no FROM line") + errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"") + errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"") ) func Format(cmds []Command) string { @@ -82,7 +83,11 @@ func Parse(r io.Reader) (cmds []Command, err error) { // process the state transition, some transitions need to be intercepted and redirected if next != curr { switch curr { - case stateName, stateParameter: + case stateName: + if !isValidCommand(b.String()) { + return nil, errInvalidCommand + } + // next state sometimes depends on the current buffer value switch s := strings.ToLower(b.String()); s { case "from": @@ -97,9 +102,11 @@ func Parse(r io.Reader) (cmds []Command, err error) { default: cmd.Name = s } + case stateParameter: + cmd.Name = b.String() case stateMessage: if !isValidMessageRole(b.String()) { - return nil, errInvalidRole + return nil, errInvalidMessageRole } role = b.String() @@ -182,7 +189,7 @@ func parseRuneForState(r rune, cs state) (state, rune, error) { case isSpace(r): return stateValue, 0, nil default: - return stateNil, 0, errors.New("invalid") + return stateNil, 0, errInvalidCommand } case stateValue: switch { @@ -279,3 +286,12 @@ func isNewline(r rune) bool { func isValidMessageRole(role string) bool { return role == "system" || role == "user" || role == "assistant" } + +func isValidCommand(cmd string) bool { + switch strings.ToLower(cmd) { + case "from", "license", "template", "system", "adapter", "parameter", "message": + return true + default: + return false + } +} diff --git a/parser/parser_test.go b/parser/parser_test.go index 0b08f1ab..a28205aa 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -104,6 +104,16 @@ PARAMETER param1 assert.ErrorIs(t, err, io.ErrUnexpectedEOF) } +func TestParserBadCommand(t *testing.T) { + input := ` +FROM foo +BADCOMMAND param1 value1 +` + _, err := Parse(strings.NewReader(input)) + assert.ErrorIs(t, err, errInvalidCommand) + +} + func TestParserMessages(t *testing.T) { var cases = []struct { input string @@ -165,7 +175,7 @@ FROM foo MESSAGE badguy I'm a bad guy! `, nil, - errInvalidRole, + errInvalidMessageRole, }, { `