diff --git a/convert/convert_test.go b/convert/convert_test.go
index f71ff8cd..9eb1632f 100644
--- a/convert/convert_test.go
+++ b/convert/convert_test.go
@@ -89,7 +89,7 @@ func TestMain(m *testing.M) {
 	os.Exit(m.Run())
 }
 
-func TestConvertFull(t *testing.T) {
+func TestConvertModel(t *testing.T) {
 	cases := []string{
 		"Meta-Llama-3-8B-Instruct",
 		"Meta-Llama-3.1-8B-Instruct",
diff --git a/convert/tokenizer.go b/convert/tokenizer.go
index 653df6d2..14d6ba66 100644
--- a/convert/tokenizer.go
+++ b/convert/tokenizer.go
@@ -100,8 +100,21 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
 	}
 
 	if template, ok := p["chat_template"]; ok {
-		if err := json.Unmarshal(template, &t.Template); err != nil {
-			return nil, err
+		var s []struct {
+			Name     string `json:"name"`
+			Template string `json:"template"`
+		}
+		if err := json.Unmarshal(template, &t.Template); err == nil {
+			// noop
+		} else if err := json.Unmarshal(template, &s); err == nil {
+			for _, e := range s {
+				if e.Name == "default" {
+					t.Template = e.Template
+					break
+				}
+			}
+		} else {
+			return nil, fmt.Errorf("invalid chat_template: %w", err)
 		}
 	}
 
@@ -141,7 +154,6 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
 }
 
 type tokenizer struct {
-	Version     string  `json:"version"`
 	AddedTokens []token `json:"added_tokens"`
 	Model       struct {
 		Type   string `json:"type"`
@@ -239,7 +251,7 @@ func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
 		return pattern.Func(fsys)
 	}
 
-	return nil, errors.New("unknown tensor format")
+	return nil, errors.New("unknown tokenizer format")
 }
 
 type SpecialVocabulary struct {
diff --git a/convert/tokenizer_test.go b/convert/tokenizer_test.go
new file mode 100644
index 00000000..d9550e09
--- /dev/null
+++ b/convert/tokenizer_test.go
@@ -0,0 +1,208 @@
+package convert
+
+import (
+	"io"
+	"io/fs"
+	"os"
+	"path/filepath"
+	"strings"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+func createTokenizerFS(t *testing.T, dir string, files map[string]io.Reader) fs.FS {
+	t.Helper()
+
+	for k, v := range files {
+		if err := func() error {
+			f, err := os.Create(filepath.Join(dir, k))
+			if err != nil {
+				return err
+			}
+			defer f.Close()
+
+			if _, err := io.Copy(f, v); err != nil {
+				return err
+			}
+
+			return nil
+		}(); err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+	}
+
+	return os.DirFS(dir)
+}
+
+func TestParseTokenizer(t *testing.T) {
+	cases := []struct {
+		name              string
+		fsys              fs.FS
+		specialTokenTypes []string
+		want              *Tokenizer
+	}{
+		{
+			name: "string chat template",
+			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
+				"tokenizer.json": strings.NewReader(`{}`),
+				"tokenizer_config.json": strings.NewReader(`{
+					"chat_template": ""
+				}`),
+			}),
+			want: &Tokenizer{
+				Vocabulary: &Vocabulary{Model: "gpt2"},
+				Pre:        "default",
+				Template:   "",
+			},
+		},
+		{
+			name: "list chat template",
+			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
+				"tokenizer.json": strings.NewReader(`{}`),
+				"tokenizer_config.json": strings.NewReader(`{
+					"chat_template": [
+						{
+							"name": "default",
+							"template": ""
+						},
+						{
+							"name": "tools",
+							"template": ""
+						}
+					]
+				}`),
+			}),
+			want: &Tokenizer{
+				Vocabulary: &Vocabulary{Model: "gpt2"},
+				Pre:        "default",
+				Template:   "",
+			},
+		},
+		{
+			name: "added tokens",
+			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
+				"tokenizer.json": strings.NewReader(`{
+					"added_tokens": [
+						{
+							"id": 999,
+							"content": "",
+							"special": false
+						}
+					]
+				}`),
+			}),
+			want: &Tokenizer{
+				Vocabulary: &Vocabulary{
+					Model:  "gpt2",
+					Tokens: []string{""},
+					Scores: []float32{999},
+					Types:  []int32{4},
+				},
+				Pre: "default",
+			},
+		},
+		{
+			name: "added tokens overlap vocab",
+			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
+				"tokenizer.json": strings.NewReader(`{
+					"added_tokens": [
+						{
+							"id": 0,
+							"content": "",
+							"special": true
+						}
+					],
+					"model": {
+						"vocab": {
+							"": 0
+						}
+					}
+				}`),
+			}),
+			want: &Tokenizer{
+				Vocabulary: &Vocabulary{
+					Model:  "gpt2",
+					Tokens: []string{""},
+					Scores: []float32{0},
+					Types:  []int32{3},
+				},
+				Pre: "default",
+			},
+		},
+		{
+			name: "special token types",
+			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
+				"tokenizer.json": strings.NewReader(`{
+					"added_tokens": [
+						{
+							"id": 0,
+							"content": "",
+							"special": true
+						},
+						{
+							"id": 1,
+							"content": "",
+							"special": true
+						},
+						{
+							"id": 2,
+							"content": "",
+							"special": true
+						},
+						{
+							"id": 3,
+							"content": "",
+							"special": true
+						}
+					],
+					"model": {
+						"vocab": {
+							"": 0,
+							"": 1,
+							"": 2,
+							"": 3
+						}
+					}
+				}`),
+				"tokenizer_config.json": strings.NewReader(`{
+					"add_bos_token": true,
+					"add_eos_token": false,
+					"bos_token": "",
+					"eos_token": "",
+					"pad_token": "",
+					"unk_token": ""
+				}`),
+			}),
+			specialTokenTypes: []string{"pad", "eos", "bos", "unk"},
+			want: &Tokenizer{
+				Vocabulary: &Vocabulary{
+					Model:  "gpt2",
+					Tokens: []string{"", "", "", ""},
+					Scores: []float32{0, 1, 2, 3},
+					Types:  []int32{3, 3, 3, 3},
+				},
+				SpecialVocabulary: []*SpecialVocabulary{
+					{Type: "pad", Content: "", ID: 0, AddToken: false},
+					{Type: "eos", Content: "", ID: 1, AddToken: false},
+					{Type: "bos", Content: "", ID: 2, AddToken: true},
+					{Type: "unk", Content: "", ID: 3, AddToken: false},
+				},
+				Pre: "default",
+			},
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			tokenizer, err := parseTokenizer(tt.fsys, tt.specialTokenTypes)
+			if err != nil {
+				t.Fatalf("unexpected error: %v", err)
+			}
+
+			if diff := cmp.Diff(tt.want, tokenizer); diff != "" {
+				t.Errorf("unexpected tokenizer (-want +got):\n%s", diff)
+			}
+		})
+	}
+}