From c0a00f68aec6f8f7481be723e56015a8911513dd Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 22 Apr 2024 15:37:14 -0700 Subject: [PATCH 1/9] refactor modelfile parser --- parser/parser.go | 298 ++++++++++++++++++++++++++------------- parser/parser_test.go | 316 +++++++++++++++++++++++++++++++++++++----- server/routes_test.go | 1 - 3 files changed, 485 insertions(+), 130 deletions(-) diff --git a/parser/parser.go b/parser/parser.go index 947848b2..edb81615 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -6,8 +6,9 @@ import ( "errors" "fmt" "io" - "log/slog" "slices" + "strconv" + "strings" ) type Command struct { @@ -15,118 +16,219 @@ type Command struct { Args string } -func (c *Command) Reset() { - c.Name = "" - c.Args = "" -} +type state int -func Parse(reader io.Reader) ([]Command, error) { - var commands []Command - var command, modelCommand Command +const ( + stateNil state = iota + stateName + stateValue + stateParameter + stateMessage + stateComment +) - scanner := bufio.NewScanner(reader) - scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize) - scanner.Split(scanModelfile) - for scanner.Scan() { - line := scanner.Bytes() +var errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"") - fields := bytes.SplitN(line, []byte(" "), 2) - if len(fields) == 0 || len(fields[0]) == 0 { - continue +func Parse(r io.Reader) (cmds []Command, err error) { + var cmd Command + var curr state + var b bytes.Buffer + var role string + + br := bufio.NewReader(r) + for { + r, _, err := br.ReadRune() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return nil, err } - switch string(bytes.ToUpper(fields[0])) { - case "FROM": - command.Name = "model" - command.Args = string(bytes.TrimSpace(fields[1])) - // copy command for validation - modelCommand = command - case "ADAPTER": - command.Name = string(bytes.ToLower(fields[0])) - command.Args = string(bytes.TrimSpace(fields[1])) - case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT": - command.Name = string(bytes.ToLower(fields[0])) - command.Args = string(fields[1]) - case "PARAMETER": - fields = bytes.SplitN(fields[1], []byte(" "), 2) - if len(fields) < 2 { - return nil, fmt.Errorf("missing value for %s", fields) + next, r, err := parseRuneForState(r, curr) + if errors.Is(err, io.ErrUnexpectedEOF) { + return nil, fmt.Errorf("%w: %s", err, b.String()) + } else if err != nil { + return nil, err + } + + if next != curr { + switch curr { + case stateName, stateParameter: + switch s := strings.ToLower(b.String()); s { + case "from": + cmd.Name = "model" + case "parameter": + next = stateParameter + case "message": + next = stateMessage + fallthrough + default: + cmd.Name = s + } + case stateMessage: + if !slices.Contains([]string{"system", "user", "assistant"}, b.String()) { + return nil, errInvalidRole + } + + role = b.String() + case stateComment, stateNil: + // pass + case stateValue: + s := b.String() + + s, ok := unquote(b.String()) + if !ok || isSpace(r) { + if _, err := b.WriteRune(r); err != nil { + return nil, err + } + + continue + } + + if role != "" { + s = role + ": " + s + role = "" + } + + cmd.Args = s + cmds = append(cmds, cmd) } - command.Name = string(fields[0]) - command.Args = string(bytes.TrimSpace(fields[1])) - case "EMBED": - return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead") - case "MESSAGE": - command.Name = string(bytes.ToLower(fields[0])) - fields = bytes.SplitN(fields[1], []byte(" "), 2) - if len(fields) < 2 { - return nil, fmt.Errorf("should be in the format ") + b.Reset() + curr = next + } + + if strconv.IsPrint(r) { + if _, err := b.WriteRune(r); err != nil { + return nil, err } - if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) { - return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"") - } - command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1])) + } + } + + // flush the buffer + switch curr { + case stateComment, stateNil: + // pass; nothing to flush + case stateValue: + if _, ok := unquote(b.String()); !ok { + return nil, io.ErrUnexpectedEOF + } + + cmd.Args = b.String() + cmds = append(cmds, cmd) + default: + return nil, io.ErrUnexpectedEOF + } + + for _, cmd := range cmds { + if cmd.Name == "model" { + return cmds, nil + } + } + + return nil, errors.New("no FROM line") +} + +func parseRuneForState(r rune, cs state) (state, rune, error) { + switch cs { + case stateNil: + switch { + case r == '#': + return stateComment, 0, nil + case isSpace(r), isNewline(r): + return stateNil, 0, nil default: - if !bytes.HasPrefix(fields[0], []byte("#")) { - // log a warning for unknown commands - slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0])) - } - continue + return stateName, r, nil + } + case stateName: + switch { + case isAlpha(r): + return stateName, r, nil + case isSpace(r): + return stateValue, 0, nil + default: + return stateNil, 0, errors.New("invalid") + } + case stateValue: + switch { + case isNewline(r): + return stateNil, r, nil + case isSpace(r): + return stateNil, r, nil + default: + return stateValue, r, nil + } + case stateParameter: + switch { + case isAlpha(r), isNumber(r), r == '_': + return stateParameter, r, nil + case isSpace(r): + return stateValue, 0, nil + default: + return stateNil, 0, io.ErrUnexpectedEOF + } + case stateMessage: + switch { + case isAlpha(r): + return stateMessage, r, nil + case isSpace(r): + return stateValue, 0, nil + default: + return stateNil, 0, io.ErrUnexpectedEOF + } + case stateComment: + switch { + case isNewline(r): + return stateNil, 0, nil + default: + return stateComment, 0, nil + } + default: + return stateNil, 0, errors.New("") + } +} + +func unquote(s string) (string, bool) { + if len(s) == 0 { + return "", false + } + + // TODO: single quotes + if len(s) >= 3 && s[:3] == `"""` { + if len(s) >= 6 && s[len(s)-3:] == `"""` { + return s[3 : len(s)-3], true } - commands = append(commands, command) - command.Reset() + return "", false } - if modelCommand.Args == "" { - return nil, errors.New("no FROM line for the model was specified") - } - - return commands, scanner.Err() -} - -func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) { - advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF) - if err != nil { - return 0, nil, err - } - - if advance > 0 && token != nil { - return advance, token, nil - } - - advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF) - if err != nil { - return 0, nil, err - } - - if advance > 0 && token != nil { - return advance, token, nil - } - - return bufio.ScanLines(data, atEOF) -} - -func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) { - newline := bytes.IndexByte(data, '\n') - - if start := bytes.Index(data, openBytes); start >= 0 && start < newline { - end := bytes.Index(data[start+len(openBytes):], closeBytes) - if end < 0 { - if atEOF { - return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes) - } else { - return 0, nil, nil - } + if len(s) >= 1 && s[0] == '"' { + if len(s) >= 2 && s[len(s)-1] == '"' { + return s[1 : len(s)-1], true } - n := start + len(openBytes) + end + len(closeBytes) - - newData := data[:start] - newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...) - return n, newData, nil + return "", false } - return 0, nil, nil + return s, true +} + +func isAlpha(r rune) bool { + return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' +} + +func isNumber(r rune) bool { + return r >= '0' && r <= '9' +} + +func isSpace(r rune) bool { + return r == ' ' || r == '\t' +} + +func isNewline(r rune) bool { + return r == '\r' || r == '\n' +} + +func isValidRole(role string) bool { + return role == "system" || role == "user" || role == "assistant" } diff --git a/parser/parser_test.go b/parser/parser_test.go index 25e849b5..09ed2b92 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1,13 +1,16 @@ package parser import ( + "bytes" + "fmt" + "io" "strings" "testing" "github.com/stretchr/testify/assert" ) -func Test_Parser(t *testing.T) { +func TestParser(t *testing.T) { input := ` FROM model1 @@ -35,7 +38,7 @@ TEMPLATE template1 assert.Equal(t, expectedCommands, commands) } -func Test_Parser_NoFromLine(t *testing.T) { +func TestParserNoFromLine(t *testing.T) { input := ` PARAMETER param1 value1 @@ -48,7 +51,7 @@ PARAMETER param2 value2 assert.ErrorContains(t, err, "no FROM line") } -func Test_Parser_MissingValue(t *testing.T) { +func TestParserParametersMissingValue(t *testing.T) { input := ` FROM foo @@ -58,41 +61,292 @@ PARAMETER param1 reader := strings.NewReader(input) _, err := Parse(reader) - assert.ErrorContains(t, err, "missing value for [param1]") - + assert.ErrorIs(t, err, io.ErrUnexpectedEOF) } -func Test_Parser_Messages(t *testing.T) { - - input := ` +func TestParserMessages(t *testing.T) { + var cases = []struct { + input string + expected []Command + err error + }{ + { + ` +FROM foo +MESSAGE system You are a Parser. Always Parse things. +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: You are a Parser. Always Parse things."}, + }, + nil, + }, + { + ` FROM foo MESSAGE system You are a Parser. Always Parse things. MESSAGE user Hey there! MESSAGE assistant Hello, I want to parse all the things! -` - - reader := strings.NewReader(input) - commands, err := Parse(reader) - assert.Nil(t, err) - - expectedCommands := []Command{ - {Name: "model", Args: "foo"}, - {Name: "message", Args: "system: You are a Parser. Always Parse things."}, - {Name: "message", Args: "user: Hey there!"}, - {Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, - } - - assert.Equal(t, expectedCommands, commands) -} - -func Test_Parser_Messages_BadRole(t *testing.T) { - - input := ` +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: You are a Parser. Always Parse things."}, + {Name: "message", Args: "user: Hey there!"}, + {Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, + }, + nil, + }, + { + ` +FROM foo +MESSAGE system """ +You are a multiline Parser. Always Parse things. +""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: \nYou are a multiline Parser. Always Parse things.\n"}, + }, + nil, + }, + { + ` FROM foo MESSAGE badguy I'm a bad guy! -` +`, + nil, + errInvalidRole, + }, + { + ` +FROM foo +MESSAGE system +`, + nil, + io.ErrUnexpectedEOF, + }, + { + ` +FROM foo +MESSAGE system`, + nil, + io.ErrUnexpectedEOF, + }, + } - reader := strings.NewReader(input) - _, err := Parse(reader) - assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"") + for _, c := range cases { + t.Run("", func(t *testing.T) { + commands, err := Parse(strings.NewReader(c.input)) + assert.ErrorIs(t, err, c.err) + assert.Equal(t, c.expected, commands) + }) + } +} + +func TestParserQuoted(t *testing.T) { + var cases = []struct { + multiline string + expected []Command + err error + }{ + { + ` +FROM foo +TEMPLATE """ +This is a +multiline template. +""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: "\nThis is a\nmultiline template.\n"}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE """ +This is a +multiline template.""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: "\nThis is a\nmultiline template."}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE """This is a +multiline template.""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: "This is a\nmultiline template."}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE """This is a multiline template.""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: "This is a multiline template."}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE """This is a multiline template."" + `, + nil, + io.ErrUnexpectedEOF, + }, + { + ` +FROM foo +TEMPLATE " + `, + nil, + io.ErrUnexpectedEOF, + }, + { + ` +FROM foo +TEMPLATE """ +This is a multiline template with "quotes". +""" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: "\nThis is a multiline template with \"quotes\".\n"}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE """""" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: ""}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE "" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: ""}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE "'" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: "'"}, + }, + nil, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + commands, err := Parse(strings.NewReader(c.multiline)) + assert.ErrorIs(t, err, c.err) + assert.Equal(t, c.expected, commands) + }) + } +} + +func TestParserParameters(t *testing.T) { + var cases = []string{ + "numa true", + "num_ctx 1", + "num_batch 1", + "num_gqa 1", + "num_gpu 1", + "main_gpu 1", + "low_vram true", + "f16_kv true", + "logits_all true", + "vocab_only true", + "use_mmap true", + "use_mlock true", + "num_thread 1", + "num_keep 1", + "seed 1", + "num_predict 1", + "top_k 1", + "top_p 1.0", + "tfs_z 1.0", + "typical_p 1.0", + "repeat_last_n 1", + "temperature 1.0", + "repeat_penalty 1.0", + "presence_penalty 1.0", + "frequency_penalty 1.0", + "mirostat 1", + "mirostat_tau 1.0", + "mirostat_eta 1.0", + "penalize_newline true", + "stop foo", + } + + for _, c := range cases { + t.Run(c, func(t *testing.T) { + var b bytes.Buffer + fmt.Fprintln(&b, "FROM foo") + fmt.Fprintln(&b, "PARAMETER", c) + t.Logf("input: %s", b.String()) + _, err := Parse(&b) + assert.Nil(t, err) + }) + } +} + +func TestParserOnlyFrom(t *testing.T) { + commands, err := Parse(strings.NewReader("FROM foo")) + assert.Nil(t, err) + + expected := []Command{{Name: "model", Args: "foo"}} + assert.Equal(t, expected, commands) +} + +func TestParserComments(t *testing.T) { + var cases = []struct { + input string + expected []Command + }{ + { + ` +# comment +FROM foo + `, + []Command{ + {Name: "model", Args: "foo"}, + }, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + commands, err := Parse(strings.NewReader(c.input)) + assert.Nil(t, err) + assert.Equal(t, c.expected, commands) + }) + } } diff --git a/server/routes_test.go b/server/routes_test.go index 4f907702..6ac98367 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -238,6 +238,5 @@ func Test_Routes(t *testing.T) { if tc.Expected != nil { tc.Expected(t, resp) } - } } From 238715037dec4f1afc49b2a008ae26d110133922 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 24 Apr 2024 16:08:51 -0700 Subject: [PATCH 2/9] linting --- parser/parser.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/parser/parser.go b/parser/parser.go index edb81615..1b80ebec 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "slices" "strconv" "strings" ) @@ -66,7 +65,7 @@ func Parse(r io.Reader) (cmds []Command, err error) { cmd.Name = s } case stateMessage: - if !slices.Contains([]string{"system", "user", "assistant"}, b.String()) { + if !isValidRole(b.String()) { return nil, errInvalidRole } @@ -74,8 +73,6 @@ func Parse(r io.Reader) (cmds []Command, err error) { case stateComment, stateNil: // pass case stateValue: - s := b.String() - s, ok := unquote(b.String()) if !ok || isSpace(r) { if _, err := b.WriteRune(r); err != nil { From abe614c705736eed06440d657ab75ca094fc78f3 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 24 Apr 2024 16:12:56 -0700 Subject: [PATCH 3/9] tests --- parser/parser.go | 7 +- parser/parser_test.go | 169 ++++++++++++++++++++++++++++-------------- 2 files changed, 118 insertions(+), 58 deletions(-) diff --git a/parser/parser.go b/parser/parser.go index 1b80ebec..a8133d78 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -26,7 +26,10 @@ const ( stateComment ) -var errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"") +var ( + errMissingFrom = errors.New("no FROM line") + errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"") +) func Parse(r io.Reader) (cmds []Command, err error) { var cmd Command @@ -123,7 +126,7 @@ func Parse(r io.Reader) (cmds []Command, err error) { } } - return nil, errors.New("no FROM line") + return nil, errMissingFrom } func parseRuneForState(r rune, cs state) (state, rune, error) { diff --git a/parser/parser_test.go b/parser/parser_test.go index 09ed2b92..94b4e8ad 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -11,7 +11,6 @@ import ( ) func TestParser(t *testing.T) { - input := ` FROM model1 ADAPTER adapter1 @@ -38,21 +37,62 @@ TEMPLATE template1 assert.Equal(t, expectedCommands, commands) } -func TestParserNoFromLine(t *testing.T) { +func TestParserFrom(t *testing.T) { + var cases = []struct { + input string + expected []Command + err error + }{ + { + "FROM foo", + []Command{{Name: "model", Args: "foo"}}, + nil, + }, + { + "FROM /path/to/model", + []Command{{Name: "model", Args: "/path/to/model"}}, + nil, + }, + { + "FROM /path/to/model/fp16.bin", + []Command{{Name: "model", Args: "/path/to/model/fp16.bin"}}, + nil, + }, + { + "FROM llama3:latest", + []Command{{Name: "model", Args: "llama3:latest"}}, + nil, + }, + { + "FROM llama3:7b-instruct-q4_K_M", + []Command{{Name: "model", Args: "llama3:7b-instruct-q4_K_M"}}, + nil, + }, + { + "", nil, errMissingFrom, + }, + { + "PARAMETER param1 value1", + nil, + errMissingFrom, + }, + { + "PARAMETER param1 value1\nFROM foo", + []Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}}, + nil, + }, + } - input := ` -PARAMETER param1 value1 -PARAMETER param2 value2 -` - - reader := strings.NewReader(input) - - _, err := Parse(reader) - assert.ErrorContains(t, err, "no FROM line") + for _, c := range cases { + t.Run("", func(t *testing.T) { + commands, err := Parse(strings.NewReader(c.input)) + assert.ErrorIs(t, err, c.err) + assert.Equal(t, c.expected, commands) + }) + } } func TestParserParametersMissingValue(t *testing.T) { - input := ` FROM foo PARAMETER param1 @@ -261,6 +301,17 @@ TEMPLATE "'" }, nil, }, + { + ` +FROM foo +TEMPLATE """''"'""'""'"'''''""'""'""" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: `''"'""'""'"'''''""'""'`}, + }, + nil, + }, } for _, c := range cases { @@ -273,59 +324,65 @@ TEMPLATE "'" } func TestParserParameters(t *testing.T) { - var cases = []string{ - "numa true", - "num_ctx 1", - "num_batch 1", - "num_gqa 1", - "num_gpu 1", - "main_gpu 1", - "low_vram true", - "f16_kv true", - "logits_all true", - "vocab_only true", - "use_mmap true", - "use_mlock true", - "num_thread 1", - "num_keep 1", - "seed 1", - "num_predict 1", - "top_k 1", - "top_p 1.0", - "tfs_z 1.0", - "typical_p 1.0", - "repeat_last_n 1", - "temperature 1.0", - "repeat_penalty 1.0", - "presence_penalty 1.0", - "frequency_penalty 1.0", - "mirostat 1", - "mirostat_tau 1.0", - "mirostat_eta 1.0", - "penalize_newline true", - "stop foo", + var cases = map[string]struct { + name, value string + }{ + "numa true": {"numa", "true"}, + "num_ctx 1": {"num_ctx", "1"}, + "num_batch 1": {"num_batch", "1"}, + "num_gqa 1": {"num_gqa", "1"}, + "num_gpu 1": {"num_gpu", "1"}, + "main_gpu 1": {"main_gpu", "1"}, + "low_vram true": {"low_vram", "true"}, + "f16_kv true": {"f16_kv", "true"}, + "logits_all true": {"logits_all", "true"}, + "vocab_only true": {"vocab_only", "true"}, + "use_mmap true": {"use_mmap", "true"}, + "use_mlock true": {"use_mlock", "true"}, + "num_thread 1": {"num_thread", "1"}, + "num_keep 1": {"num_keep", "1"}, + "seed 1": {"seed", "1"}, + "num_predict 1": {"num_predict", "1"}, + "top_k 1": {"top_k", "1"}, + "top_p 1.0": {"top_p", "1.0"}, + "tfs_z 1.0": {"tfs_z", "1.0"}, + "typical_p 1.0": {"typical_p", "1.0"}, + "repeat_last_n 1": {"repeat_last_n", "1"}, + "temperature 1.0": {"temperature", "1.0"}, + "repeat_penalty 1.0": {"repeat_penalty", "1.0"}, + "presence_penalty 1.0": {"presence_penalty", "1.0"}, + "frequency_penalty 1.0": {"frequency_penalty", "1.0"}, + "mirostat 1": {"mirostat", "1"}, + "mirostat_tau 1.0": {"mirostat_tau", "1.0"}, + "mirostat_eta 1.0": {"mirostat_eta", "1.0"}, + "penalize_newline true": {"penalize_newline", "true"}, + "stop ### User:": {"stop", "### User:"}, + "stop ### User: ": {"stop", "### User: "}, + "stop \"### User:\"": {"stop", "### User:"}, + "stop \"### User: \"": {"stop", "### User: "}, + "stop \"\"\"### User:\"\"\"": {"stop", "### User:"}, + "stop \"\"\"### User:\n\"\"\"": {"stop", "### User:\n"}, + "stop <|endoftext|>": {"stop", "<|endoftext|>"}, + "stop <|eot_id|>": {"stop", "<|eot_id|>"}, + "stop ": {"stop", ""}, } - for _, c := range cases { - t.Run(c, func(t *testing.T) { + for k, v := range cases { + t.Run(k, func(t *testing.T) { var b bytes.Buffer fmt.Fprintln(&b, "FROM foo") - fmt.Fprintln(&b, "PARAMETER", c) - t.Logf("input: %s", b.String()) - _, err := Parse(&b) + fmt.Fprintln(&b, "PARAMETER", k) + commands, err := Parse(&b) assert.Nil(t, err) + + assert.Equal(t, []Command{ + {Name: "model", Args: "foo"}, + {Name: v.name, Args: v.value}, + }, commands) }) } } -func TestParserOnlyFrom(t *testing.T) { - commands, err := Parse(strings.NewReader("FROM foo")) - assert.Nil(t, err) - - expected := []Command{{Name: "model", Args: "foo"}} - assert.Equal(t, expected, commands) -} - func TestParserComments(t *testing.T) { var cases = []struct { input string From 8907bf51d235e86854c762870f203380866f1ae3 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 24 Apr 2024 19:17:26 -0700 Subject: [PATCH 4/9] fix multiline --- parser/parser.go | 9 ++++-- parser/parser_test.go | 70 ++++++++++++++++++++++++++++--------------- 2 files changed, 53 insertions(+), 26 deletions(-) diff --git a/parser/parser.go b/parser/parser.go index a8133d78..a42c0275 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -110,11 +110,16 @@ func Parse(r io.Reader) (cmds []Command, err error) { case stateComment, stateNil: // pass; nothing to flush case stateValue: - if _, ok := unquote(b.String()); !ok { + s, ok := unquote(b.String()) + if !ok { return nil, io.ErrUnexpectedEOF } - cmd.Args = b.String() + if role != "" { + s = role + ": " + s + } + + cmd.Args = s cmds = append(cmds, cmd) default: return nil, io.ErrUnexpectedEOF diff --git a/parser/parser_test.go b/parser/parser_test.go index 94b4e8ad..1eb10157 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -124,6 +124,16 @@ MESSAGE system You are a Parser. Always Parse things. { ` FROM foo +MESSAGE system You are a Parser. Always Parse things.`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: You are a Parser. Always Parse things."}, + }, + nil, + }, + { + ` +FROM foo MESSAGE system You are a Parser. Always Parse things. MESSAGE user Hey there! MESSAGE assistant Hello, I want to parse all the things! @@ -192,57 +202,57 @@ func TestParserQuoted(t *testing.T) { { ` FROM foo -TEMPLATE """ +SYSTEM """ This is a -multiline template. +multiline system. """ `, []Command{ {Name: "model", Args: "foo"}, - {Name: "template", Args: "\nThis is a\nmultiline template.\n"}, + {Name: "system", Args: "\nThis is a\nmultiline system.\n"}, }, nil, }, { ` FROM foo -TEMPLATE """ +SYSTEM """ This is a -multiline template.""" +multiline system.""" `, []Command{ {Name: "model", Args: "foo"}, - {Name: "template", Args: "\nThis is a\nmultiline template."}, + {Name: "system", Args: "\nThis is a\nmultiline system."}, }, nil, }, { ` FROM foo -TEMPLATE """This is a -multiline template.""" +SYSTEM """This is a +multiline system.""" `, []Command{ {Name: "model", Args: "foo"}, - {Name: "template", Args: "This is a\nmultiline template."}, + {Name: "system", Args: "This is a\nmultiline system."}, }, nil, }, { ` FROM foo -TEMPLATE """This is a multiline template.""" +SYSTEM """This is a multiline system.""" `, []Command{ {Name: "model", Args: "foo"}, - {Name: "template", Args: "This is a multiline template."}, + {Name: "system", Args: "This is a multiline system."}, }, nil, }, { ` FROM foo -TEMPLATE """This is a multiline template."" +SYSTEM """This is a multiline system."" `, nil, io.ErrUnexpectedEOF, @@ -250,7 +260,7 @@ TEMPLATE """This is a multiline template."" { ` FROM foo -TEMPLATE " +SYSTEM " `, nil, io.ErrUnexpectedEOF, @@ -258,57 +268,69 @@ TEMPLATE " { ` FROM foo -TEMPLATE """ -This is a multiline template with "quotes". +SYSTEM """ +This is a multiline system with "quotes". """ `, []Command{ {Name: "model", Args: "foo"}, - {Name: "template", Args: "\nThis is a multiline template with \"quotes\".\n"}, + {Name: "system", Args: "\nThis is a multiline system with \"quotes\".\n"}, }, nil, }, { ` FROM foo -TEMPLATE """""" +SYSTEM """""" `, []Command{ {Name: "model", Args: "foo"}, - {Name: "template", Args: ""}, + {Name: "system", Args: ""}, }, nil, }, { ` FROM foo -TEMPLATE "" +SYSTEM "" `, []Command{ {Name: "model", Args: "foo"}, - {Name: "template", Args: ""}, + {Name: "system", Args: ""}, }, nil, }, { ` FROM foo -TEMPLATE "'" +SYSTEM "'" `, []Command{ {Name: "model", Args: "foo"}, - {Name: "template", Args: "'"}, + {Name: "system", Args: "'"}, }, nil, }, { ` FROM foo -TEMPLATE """''"'""'""'"'''''""'""'""" +SYSTEM """''"'""'""'"'''''""'""'""" `, []Command{ {Name: "model", Args: "foo"}, - {Name: "template", Args: `''"'""'""'"'''''""'""'`}, + {Name: "system", Args: `''"'""'""'"'''''""'""'`}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE """ +{{ .Prompt }} +"""`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: "\n{{ .Prompt }}\n"}, }, nil, }, From 4d083635803fa9072c50a37a3cb9e2c5e8f97867 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 26 Apr 2024 15:13:27 -0700 Subject: [PATCH 5/9] comments --- parser/parser.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/parser/parser.go b/parser/parser.go index a42c0275..c6667d66 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -53,22 +53,26 @@ func Parse(r io.Reader) (cmds []Command, err error) { return nil, err } + // process the state transition, some transitions need to be intercepted and redirected if next != curr { switch curr { case stateName, stateParameter: + // next state sometimes depends on the current buffer value switch s := strings.ToLower(b.String()); s { case "from": cmd.Name = "model" case "parameter": + // transition to stateParameter which sets command name next = stateParameter case "message": + // transition to stateMessage which validates the message role next = stateMessage fallthrough default: cmd.Name = s } case stateMessage: - if !isValidRole(b.String()) { + if !isValidMessageRole(b.String()) { return nil, errInvalidRole } @@ -234,6 +238,6 @@ func isNewline(r rune) bool { return r == '\r' || r == '\n' } -func isValidRole(role string) bool { +func isValidMessageRole(role string) bool { return role == "system" || role == "user" || role == "assistant" } From 176ad3aa6edef56e00edb67ffec720a49a835060 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 24 Apr 2024 18:49:14 -0700 Subject: [PATCH 6/9] parser: add commands format --- cmd/cmd.go | 24 +++++++---------- parser/parser.go | 39 ++++++++++++++++++++++++++++ parser/parser_test.go | 60 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 15 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 2315ad1a..e3c1d873 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -17,7 +17,6 @@ import ( "os" "os/signal" "path/filepath" - "regexp" "runtime" "strings" "syscall" @@ -57,12 +56,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error { p := progress.NewProgress(os.Stderr) defer p.Stop() - modelfile, err := os.ReadFile(filename) + modelfile, err := os.Open(filename) if err != nil { return err } + defer modelfile.Close() - commands, err := parser.Parse(bytes.NewReader(modelfile)) + commands, err := parser.Parse(modelfile) if err != nil { return err } @@ -76,10 +76,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error { spinner := progress.NewSpinner(status) p.Add(status, spinner) - for _, c := range commands { - switch c.Name { + for i := range commands { + switch commands[i].Name { case "model", "adapter": - path := c.Args + path := commands[i].Args if path == "~" { path = home } else if strings.HasPrefix(path, "~/") { @@ -91,7 +91,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } fi, err := os.Stat(path) - if errors.Is(err, os.ErrNotExist) && c.Name == "model" { + if errors.Is(err, os.ErrNotExist) && commands[i].Name == "model" { continue } else if err != nil { return err @@ -114,13 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - name := c.Name - if c.Name == "model" { - name = "from" - } - - re := regexp.MustCompile(fmt.Sprintf(`(?im)^(%s)\s+%s\s*$`, name, c.Args)) - modelfile = re.ReplaceAll(modelfile, []byte("$1 @"+digest)) + commands[i].Args = "@"+digest } } @@ -150,7 +144,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { quantization, _ := cmd.Flags().GetString("quantization") - request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization} + request := api.CreateRequest{Name: args[0], Modelfile: parser.Format(commands), Quantization: quantization} if err := client.Create(cmd.Context(), &request, fn); err != nil { return err } diff --git a/parser/parser.go b/parser/parser.go index c6667d66..22e07235 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -31,6 +31,33 @@ var ( errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"") ) +func Format(cmds []Command) string { + var b bytes.Buffer + for _, cmd := range cmds { + name := cmd.Name + args := cmd.Args + + switch cmd.Name { + case "model": + name = "from" + args = cmd.Args + case "license", "template", "system", "adapter": + args = quote(args) + // pass + case "message": + role, message, _ := strings.Cut(cmd.Args, ": ") + args = role + " " + quote(message) + default: + name = "parameter" + args = cmd.Name + " " + cmd.Args + } + + fmt.Fprintln(&b, strings.ToUpper(name), args) + } + + return b.String() +} + func Parse(r io.Reader) (cmds []Command, err error) { var cmd Command var curr state @@ -197,6 +224,18 @@ func parseRuneForState(r rune, cs state) (state, rune, error) { } } +func quote(s string) string { + if strings.Contains(s, "\n") || strings.HasSuffix(s, " ") { + if strings.Contains(s, "\"") { + return `"""` + s + `"""` + } + + return strconv.Quote(s) + } + + return s +} + func unquote(s string) (string, bool) { if len(s) == 0 { return "", false diff --git a/parser/parser_test.go b/parser/parser_test.go index 1eb10157..0b08f1ab 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -429,3 +429,63 @@ FROM foo }) } } + +func TestParseFormatParse(t *testing.T) { + var cases = []string{ + ` +FROM foo +ADAPTER adapter1 +LICENSE MIT +PARAMETER param1 value1 +PARAMETER param2 value2 +TEMPLATE template1 +MESSAGE system You are a Parser. Always Parse things. +MESSAGE user Hey there! +MESSAGE assistant Hello, I want to parse all the things! +`, + ` +FROM foo +ADAPTER adapter1 +LICENSE MIT +PARAMETER param1 value1 +PARAMETER param2 value2 +TEMPLATE template1 +MESSAGE system """ +You are a store greeter. Always responsed with "Hello!". +""" +MESSAGE user Hey there! +MESSAGE assistant Hello, I want to parse all the things! +`, + ` +FROM foo +ADAPTER adapter1 +LICENSE """ +Very long and boring legal text. +Blah blah blah. +"Oh look, a quote!" +""" + +PARAMETER param1 value1 +PARAMETER param2 value2 +TEMPLATE template1 +MESSAGE system """ +You are a store greeter. Always responsed with "Hello!". +""" +MESSAGE user Hey there! +MESSAGE assistant Hello, I want to parse all the things! +`, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + commands, err := Parse(strings.NewReader(c)) + assert.NoError(t, err) + + commands2, err := Parse(strings.NewReader(Format(commands))) + assert.NoError(t, err) + + assert.Equal(t, commands, commands2) + }) + } + +} From 9cf0f2e9736c31bbd00f4671613d0e31ecb3c4ea Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 26 Apr 2024 16:59:31 -0700 Subject: [PATCH 7/9] use parser.Format instead of templating modelfile --- parser/parser.go | 13 +++--- server/images.go | 104 +++++++++++++++++++---------------------------- server/routes.go | 12 +++--- 3 files changed, 54 insertions(+), 75 deletions(-) diff --git a/parser/parser.go b/parser/parser.go index 22e07235..6c451e99 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -32,7 +32,7 @@ var ( ) func Format(cmds []Command) string { - var b bytes.Buffer + var sb strings.Builder for _, cmd := range cmds { name := cmd.Name args := cmd.Args @@ -43,19 +43,18 @@ func Format(cmds []Command) string { args = cmd.Args case "license", "template", "system", "adapter": args = quote(args) - // pass case "message": role, message, _ := strings.Cut(cmd.Args, ": ") args = role + " " + quote(message) default: name = "parameter" - args = cmd.Name + " " + cmd.Args + args = cmd.Name + " " + quote(cmd.Args) } - fmt.Fprintln(&b, strings.ToUpper(name), args) + fmt.Fprintln(&sb, strings.ToUpper(name), args) } - return b.String() + return sb.String() } func Parse(r io.Reader) (cmds []Command, err error) { @@ -225,12 +224,12 @@ func parseRuneForState(r rune, cs state) (state, rune, error) { } func quote(s string) string { - if strings.Contains(s, "\n") || strings.HasSuffix(s, " ") { + if strings.Contains(s, "\n") || strings.HasPrefix(s, " ") || strings.HasSuffix(s, " ") { if strings.Contains(s, "\"") { return `"""` + s + `"""` } - return strconv.Quote(s) + return `"` + s + `"` } return s diff --git a/server/images.go b/server/images.go index 4e4107f7..68840c1a 100644 --- a/server/images.go +++ b/server/images.go @@ -21,7 +21,6 @@ import ( "runtime" "strconv" "strings" - "text/template" "golang.org/x/exp/slices" @@ -64,6 +63,48 @@ func (m *Model) IsEmbedding() bool { return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") } +func (m *Model) Commands() (cmds []parser.Command) { + cmds = append(cmds, parser.Command{Name: "model", Args: m.ModelPath}) + + if m.Template != "" { + cmds = append(cmds, parser.Command{Name: "template", Args: m.Template}) + } + + if m.System != "" { + cmds = append(cmds, parser.Command{Name: "system", Args: m.System}) + } + + for _, adapter := range m.AdapterPaths { + cmds = append(cmds, parser.Command{Name: "adapter", Args: adapter}) + } + + for _, projector := range m.ProjectorPaths { + cmds = append(cmds, parser.Command{Name: "projector", Args: projector}) + } + + for k, v := range m.Options { + switch v := v.(type) { + case []any: + for _, s := range v { + cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", s)}) + } + default: + cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", v)}) + } + } + + for _, license := range m.License { + cmds = append(cmds, parser.Command{Name: "license", Args: license}) + } + + for _, msg := range m.Messages { + cmds = append(cmds, parser.Command{Name: "message", Args: fmt.Sprintf("%s %s", msg.Role, msg.Content)}) + } + + return cmds + +} + type Message struct { Role string `json:"role"` Content string `json:"content"` @@ -901,67 +942,6 @@ func DeleteModel(name string) error { return nil } -func ShowModelfile(model *Model) (string, error) { - var mt struct { - *Model - From string - Parameters map[string][]any - } - - mt.Parameters = make(map[string][]any) - for k, v := range model.Options { - if s, ok := v.([]any); ok { - mt.Parameters[k] = s - continue - } - - mt.Parameters[k] = []any{v} - } - - mt.Model = model - mt.From = model.ModelPath - - if model.ParentModel != "" { - mt.From = model.ParentModel - } - - modelFile := `# Modelfile generated by "ollama show" -# To build a new Modelfile based on this one, replace the FROM line with: -# FROM {{ .ShortName }} - -FROM {{ .From }} -TEMPLATE """{{ .Template }}""" - -{{- if .System }} -SYSTEM """{{ .System }}""" -{{- end }} - -{{- range $adapter := .AdapterPaths }} -ADAPTER {{ $adapter }} -{{- end }} - -{{- range $k, $v := .Parameters }} -{{- range $parameter := $v }} -PARAMETER {{ $k }} {{ printf "%#v" $parameter }} -{{- end }} -{{- end }}` - - tmpl, err := template.New("").Parse(modelFile) - if err != nil { - slog.Info(fmt.Sprintf("error parsing template: %q", err)) - return "", err - } - - var buf bytes.Buffer - - if err = tmpl.Execute(&buf, mt); err != nil { - slog.Info(fmt.Sprintf("error executing template: %q", err)) - return "", err - } - - return buf.String(), nil -} - func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) fn(api.ProgressResponse{Status: "retrieving manifest"}) diff --git a/server/routes.go b/server/routes.go index b1962d23..35b20f56 100644 --- a/server/routes.go +++ b/server/routes.go @@ -728,12 +728,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } } - mf, err := ShowModelfile(model) - if err != nil { - return nil, err - } - - resp.Modelfile = mf + var sb strings.Builder + fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"") + fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:") + fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName) + fmt.Fprint(&sb, parser.Format(model.Commands())) + resp.Modelfile = sb.String() return resp, nil } From bd8eed57fc7d5f9b9b9d333b9c395864fb2378d8 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 26 Apr 2024 17:11:47 -0700 Subject: [PATCH 8/9] fix parser name --- parser/parser.go | 26 +++++++++++++++++++++----- parser/parser_test.go | 12 +++++++++++- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/parser/parser.go b/parser/parser.go index 6c451e99..9d1f3388 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -27,8 +27,9 @@ const ( ) var ( - errMissingFrom = errors.New("no FROM line") - errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"") + errMissingFrom = errors.New("no FROM line") + errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"") + errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"") ) func Format(cmds []Command) string { @@ -82,7 +83,11 @@ func Parse(r io.Reader) (cmds []Command, err error) { // process the state transition, some transitions need to be intercepted and redirected if next != curr { switch curr { - case stateName, stateParameter: + case stateName: + if !isValidCommand(b.String()) { + return nil, errInvalidCommand + } + // next state sometimes depends on the current buffer value switch s := strings.ToLower(b.String()); s { case "from": @@ -97,9 +102,11 @@ func Parse(r io.Reader) (cmds []Command, err error) { default: cmd.Name = s } + case stateParameter: + cmd.Name = b.String() case stateMessage: if !isValidMessageRole(b.String()) { - return nil, errInvalidRole + return nil, errInvalidMessageRole } role = b.String() @@ -182,7 +189,7 @@ func parseRuneForState(r rune, cs state) (state, rune, error) { case isSpace(r): return stateValue, 0, nil default: - return stateNil, 0, errors.New("invalid") + return stateNil, 0, errInvalidCommand } case stateValue: switch { @@ -279,3 +286,12 @@ func isNewline(r rune) bool { func isValidMessageRole(role string) bool { return role == "system" || role == "user" || role == "assistant" } + +func isValidCommand(cmd string) bool { + switch strings.ToLower(cmd) { + case "from", "license", "template", "system", "adapter", "parameter", "message": + return true + default: + return false + } +} diff --git a/parser/parser_test.go b/parser/parser_test.go index 0b08f1ab..a28205aa 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -104,6 +104,16 @@ PARAMETER param1 assert.ErrorIs(t, err, io.ErrUnexpectedEOF) } +func TestParserBadCommand(t *testing.T) { + input := ` +FROM foo +BADCOMMAND param1 value1 +` + _, err := Parse(strings.NewReader(input)) + assert.ErrorIs(t, err, errInvalidCommand) + +} + func TestParserMessages(t *testing.T) { var cases = []struct { input string @@ -165,7 +175,7 @@ FROM foo MESSAGE badguy I'm a bad guy! `, nil, - errInvalidRole, + errInvalidMessageRole, }, { ` From 5ea844964e4a59eef8a3162dc0f2e54c16bbb6e2 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 1 May 2024 09:53:36 -0700 Subject: [PATCH 9/9] cmd: import regexp --- cmd/cmd.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/cmd.go b/cmd/cmd.go index e3c1d873..fa3172ca 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -17,6 +17,7 @@ import ( "os" "os/signal" "path/filepath" + "regexp" "runtime" "strings" "syscall"