diff --git a/convert/convert.go b/convert/convert.go index 8c7b0943..44783b6e 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -208,14 +208,18 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error { return err } - if vocabSize := int(p.VocabSize); vocabSize > len(t.Vocabulary.Tokens) { - slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", p.VocabSize, "actual", len(t.Vocabulary.Tokens)) + vocabSize := int(p.VocabSize) + switch { + case vocabSize > len(t.Vocabulary.Tokens): + slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens)) for i := range vocabSize - len(t.Vocabulary.Tokens) { t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i)) t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1) t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined) } - } else { + case vocabSize < len(t.Vocabulary.Tokens): + return fmt.Errorf("vocabulary is larger than expected '%d' instead of '%d'", len(t.Vocabulary.Tokens), vocabSize) + default: slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens)) }