ollama/parser/parser.go

283 lines
5.1 KiB
Go
Raw Normal View History

package parser
import (
"bufio"
2023-07-25 13:22:23 -04:00
"bytes"
"errors"
2023-07-27 12:55:48 -04:00
"fmt"
"io"
2024-04-22 18:37:14 -04:00
"strconv"
"strings"
)
// Command is one Modelfile directive as produced by Parse and consumed by
// Format. Name is the lower-cased command name (FROM is stored as "model",
// and PARAMETER entries use the parameter's own name). Args is the command's
// value; MESSAGE commands store it as "<role>: <text>".
type Command struct {
	Name, Args string
}
2024-04-22 18:37:14 -04:00
// state identifies what the Modelfile parser is currently reading.
type state int

// Parser states, numbered by iota. stateNil is the zero value, so a freshly
// declared state starts between commands.
const (
	stateNil       state = iota // between commands: skipping whitespace and blank lines
	stateName                   // reading a command keyword such as FROM
	stateValue                  // reading a command's value
	stateParameter              // reading the parameter name that follows PARAMETER
	stateMessage                // reading the role that follows MESSAGE
	stateComment                // skipping a '#' comment through the end of the line
)
2024-04-24 19:12:56 -04:00
// Sentinel errors returned by Parse.
var (
	// errMissingFrom is returned when the input contains no FROM command.
	errMissingFrom = errors.New("no FROM line")

	// errInvalidRole is returned when a MESSAGE role is not one of the
	// accepted role names.
	errInvalidRole = errors.New(`role must be one of "system", "user", or "assistant"`)
)
2024-04-22 18:37:14 -04:00
2024-04-24 21:49:14 -04:00
// Format renders cmds as Modelfile text, writing one "<KEYWORD> <value>" line
// per command. It inverts the naming that Parse produces:
//
//   - "model" is written as FROM with its value unchanged;
//   - "license", "template", "system" and "adapter" keep their own keyword,
//     and their value is passed through quote;
//   - "message" Args of the form "<role>: <text>" are written as
//     MESSAGE <role> <quoted text>;
//   - any other name is treated as a model parameter and written as
//     PARAMETER <name> <value>.
//
// Keywords are upper-cased on output.
func Format(cmds []Command) string {
	var sb strings.Builder
	for _, c := range cmds {
		keyword, value := c.Name, c.Args

		switch c.Name {
		case "model":
			keyword = "from"
		case "license", "template", "system", "adapter":
			value = quote(c.Args)
		case "message":
			role, text, _ := strings.Cut(c.Args, ": ")
			value = role + " " + quote(text)
		default:
			keyword, value = "parameter", c.Name+" "+c.Args
		}

		sb.WriteString(strings.ToUpper(keyword))
		sb.WriteByte(' ')
		sb.WriteString(value)
		sb.WriteByte('\n')
	}
	return sb.String()
}
2024-04-22 18:37:14 -04:00
func Parse(r io.Reader) (cmds []Command, err error) {
var cmd Command
var curr state
var b bytes.Buffer
var role string
br := bufio.NewReader(r)
for {
r, _, err := br.ReadRune()
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return nil, err
}
2024-04-22 18:37:14 -04:00
next, r, err := parseRuneForState(r, curr)
if errors.Is(err, io.ErrUnexpectedEOF) {
return nil, fmt.Errorf("%w: %s", err, b.String())
} else if err != nil {
return nil, err
}
2024-04-26 18:13:27 -04:00
// process the state transition, some transitions need to be intercepted and redirected
2024-04-22 18:37:14 -04:00
if next != curr {
switch curr {
case stateName, stateParameter:
2024-04-26 18:13:27 -04:00
// next state sometimes depends on the current buffer value
2024-04-22 18:37:14 -04:00
switch s := strings.ToLower(b.String()); s {
case "from":
cmd.Name = "model"
case "parameter":
2024-04-26 18:13:27 -04:00
// transition to stateParameter which sets command name
2024-04-22 18:37:14 -04:00
next = stateParameter
case "message":
2024-04-26 18:13:27 -04:00
// transition to stateMessage which validates the message role
2024-04-22 18:37:14 -04:00
next = stateMessage
fallthrough
default:
cmd.Name = s
}
case stateMessage:
2024-04-26 18:13:27 -04:00
if !isValidMessageRole(b.String()) {
2024-04-22 18:37:14 -04:00
return nil, errInvalidRole
}
role = b.String()
case stateComment, stateNil:
// pass
case stateValue:
s, ok := unquote(b.String())
if !ok || isSpace(r) {
if _, err := b.WriteRune(r); err != nil {
return nil, err
}
continue
}
if role != "" {
s = role + ": " + s
role = ""
}
cmd.Args = s
cmds = append(cmds, cmd)
2023-08-10 19:09:02 -04:00
}
2024-04-22 18:37:14 -04:00
b.Reset()
curr = next
}
if strconv.IsPrint(r) {
if _, err := b.WriteRune(r); err != nil {
return nil, err
2023-08-10 19:22:08 -04:00
}
2024-04-22 18:37:14 -04:00
}
}
// flush the buffer
switch curr {
case stateComment, stateNil:
// pass; nothing to flush
case stateValue:
2024-04-24 22:17:26 -04:00
s, ok := unquote(b.String())
if !ok {
2024-04-22 18:37:14 -04:00
return nil, io.ErrUnexpectedEOF
}
2024-04-24 22:17:26 -04:00
if role != "" {
s = role + ": " + s
}
cmd.Args = s
2024-04-22 18:37:14 -04:00
cmds = append(cmds, cmd)
default:
return nil, io.ErrUnexpectedEOF
}
2024-04-22 18:37:14 -04:00
for _, cmd := range cmds {
if cmd.Name == "model" {
return cmds, nil
}
}
2024-04-24 19:12:56 -04:00
return nil, errMissingFrom
}
2024-04-22 18:37:14 -04:00
func parseRuneForState(r rune, cs state) (state, rune, error) {
switch cs {
case stateNil:
switch {
case r == '#':
return stateComment, 0, nil
case isSpace(r), isNewline(r):
return stateNil, 0, nil
default:
return stateName, r, nil
}
case stateName:
switch {
case isAlpha(r):
return stateName, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
return stateNil, 0, errors.New("invalid")
}
case stateValue:
switch {
case isNewline(r):
return stateNil, r, nil
case isSpace(r):
return stateNil, r, nil
default:
return stateValue, r, nil
}
case stateParameter:
switch {
case isAlpha(r), isNumber(r), r == '_':
return stateParameter, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
return stateNil, 0, io.ErrUnexpectedEOF
}
case stateMessage:
switch {
case isAlpha(r):
return stateMessage, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
return stateNil, 0, io.ErrUnexpectedEOF
}
case stateComment:
switch {
case isNewline(r):
return stateNil, 0, nil
default:
return stateComment, 0, nil
}
default:
return stateNil, 0, errors.New("")
2023-07-27 12:55:48 -04:00
}
2024-04-22 18:37:14 -04:00
}
2023-07-27 12:55:48 -04:00
2024-04-24 21:49:14 -04:00
// quote prepares s to be written as a Modelfile value that survives a Parse
// round trip.
//
// s is returned unchanged unless it:
//   - is empty,
//   - contains a line break ('\n' or '\r'),
//   - ends with a space, or
//   - begins with a double quote.
//
// In those cases s is wrapped in "...", or in """...""" when s itself
// contains a double quote. The text between the delimiters is left verbatim.
// unquote only strips delimiters and does not interpret escape sequences, so
// escaping here (as strconv.Quote does) would turn newlines, quotes and
// backslashes into literal escape text on the way back. A value that
// contains `"""` cannot be represented.
func quote(s string) string {
	needsQuotes := s == "" ||
		strings.ContainsAny(s, "\r\n") ||
		strings.HasSuffix(s, " ") ||
		strings.HasPrefix(s, `"`)
	if !needsQuotes {
		return s
	}

	if strings.Contains(s, `"`) {
		return `"""` + s + `"""`
	}
	return `"` + s + `"`
}
2024-04-22 18:37:14 -04:00
// unquote strips Modelfile value delimiters from s.
//
// A value wrapped in """...""" or "..." has its delimiters removed. A value
// that opens a quote but does not close it (or is too short to hold both
// delimiters) yields ("", false). An unquoted value is returned as-is with
// true, except that the empty string reports false. No escape sequences are
// interpreted.
func unquote(s string) (string, bool) {
	const triple = `"""`

	// TODO: single quotes
	switch {
	case s == "":
		return "", false
	case strings.HasPrefix(s, triple):
		// Both delimiters must fit without overlapping.
		if len(s) < 2*len(triple) || !strings.HasSuffix(s, triple) {
			return "", false
		}
		return s[len(triple) : len(s)-len(triple)], true
	case s[0] == '"':
		if len(s) < 2 || s[len(s)-1] != '"' {
			return "", false
		}
		return s[1 : len(s)-1], true
	default:
		return s, true
	}
}
2024-04-22 18:37:14 -04:00
// isAlpha reports whether r is an ASCII letter (a-z or A-Z).
func isAlpha(r rune) bool {
	switch {
	case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z':
		return true
	default:
		return false
	}
}
2024-04-22 18:37:14 -04:00
// isNumber reports whether r is an ASCII decimal digit (0-9).
func isNumber(r rune) bool {
	if r < '0' || r > '9' {
		return false
	}
	return true
}
2024-04-22 18:37:14 -04:00
// isSpace reports whether r is an in-line whitespace rune: a space or a tab.
// Line breaks are not included (see isNewline).
func isSpace(r rune) bool {
	switch r {
	case ' ', '\t':
		return true
	}
	return false
}
2024-04-22 18:37:14 -04:00
// isNewline reports whether r is a line-break rune: carriage return or line
// feed.
func isNewline(r rune) bool {
	switch r {
	case '\r', '\n':
		return true
	}
	return false
}
2024-04-26 18:13:27 -04:00
// isValidMessageRole reports whether role is one of the MESSAGE roles the
// parser accepts. The comparison is exact and case-sensitive.
func isValidMessageRole(role string) bool {
	switch role {
	case "system", "user", "assistant":
		return true
	default:
		return false
	}
}