fix parser name

This commit is contained in:
Michael Yang 2024-04-26 17:11:47 -07:00
parent 9cf0f2e973
commit bd8eed57fc
2 changed files with 32 additions and 6 deletions

View file

@ -27,8 +27,9 @@ const (
) )
var ( var (
errMissingFrom = errors.New("no FROM line") errMissingFrom = errors.New("no FROM line")
errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"") errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
) )
func Format(cmds []Command) string { func Format(cmds []Command) string {
@ -82,7 +83,11 @@ func Parse(r io.Reader) (cmds []Command, err error) {
// process the state transition, some transitions need to be intercepted and redirected // process the state transition, some transitions need to be intercepted and redirected
if next != curr { if next != curr {
switch curr { switch curr {
case stateName, stateParameter: case stateName:
if !isValidCommand(b.String()) {
return nil, errInvalidCommand
}
// next state sometimes depends on the current buffer value // next state sometimes depends on the current buffer value
switch s := strings.ToLower(b.String()); s { switch s := strings.ToLower(b.String()); s {
case "from": case "from":
@ -97,9 +102,11 @@ func Parse(r io.Reader) (cmds []Command, err error) {
default: default:
cmd.Name = s cmd.Name = s
} }
case stateParameter:
cmd.Name = b.String()
case stateMessage: case stateMessage:
if !isValidMessageRole(b.String()) { if !isValidMessageRole(b.String()) {
return nil, errInvalidRole return nil, errInvalidMessageRole
} }
role = b.String() role = b.String()
@ -182,7 +189,7 @@ func parseRuneForState(r rune, cs state) (state, rune, error) {
case isSpace(r): case isSpace(r):
return stateValue, 0, nil return stateValue, 0, nil
default: default:
return stateNil, 0, errors.New("invalid") return stateNil, 0, errInvalidCommand
} }
case stateValue: case stateValue:
switch { switch {
@ -279,3 +286,12 @@ func isNewline(r rune) bool {
func isValidMessageRole(role string) bool { func isValidMessageRole(role string) bool {
return role == "system" || role == "user" || role == "assistant" return role == "system" || role == "user" || role == "assistant"
} }
func isValidCommand(cmd string) bool {
switch strings.ToLower(cmd) {
case "from", "license", "template", "system", "adapter", "parameter", "message":
return true
default:
return false
}
}

View file

@ -104,6 +104,16 @@ PARAMETER param1
assert.ErrorIs(t, err, io.ErrUnexpectedEOF) assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
} }
func TestParserBadCommand(t *testing.T) {
input := `
FROM foo
BADCOMMAND param1 value1
`
_, err := Parse(strings.NewReader(input))
assert.ErrorIs(t, err, errInvalidCommand)
}
func TestParserMessages(t *testing.T) { func TestParserMessages(t *testing.T) {
var cases = []struct { var cases = []struct {
input string input string
@ -165,7 +175,7 @@ FROM foo
MESSAGE badguy I'm a bad guy! MESSAGE badguy I'm a bad guy!
`, `,
nil, nil,
errInvalidRole, errInvalidMessageRole,
}, },
{ {
` `