diff --git a/convert/reader.go b/convert/reader.go index 56a8ae89..ce95208e 100644 --- a/convert/reader.go +++ b/convert/reader.go @@ -56,22 +56,25 @@ func (t *tensorBase) SetRepacker(fn repacker) { type repacker func(string, []float32, []uint64) ([]float32, error) func parseTensors(fsys fs.FS) ([]Tensor, error) { - patterns := map[string]func(fs.FS, ...string) ([]Tensor, error){ - "model-*-of-*.safetensors": parseSafetensors, - "model.safetensors": parseSafetensors, - "pytorch_model-*-of-*.bin": parseTorch, - "pytorch_model.bin": parseTorch, - "consolidated.*.pth": parseTorch, + patterns := []struct { + Pattern string + Func func(fs.FS, ...string) ([]Tensor, error) + }{ + {"model-*-of-*.safetensors", parseSafetensors}, + {"model.safetensors", parseSafetensors}, + {"pytorch_model-*-of-*.bin", parseTorch}, + {"pytorch_model.bin", parseTorch}, + {"consolidated.*.pth", parseTorch}, } - for pattern, parseFn := range patterns { - matches, err := fs.Glob(fsys, pattern) + for _, pattern := range patterns { + matches, err := fs.Glob(fsys, pattern.Pattern) if err != nil { return nil, err } if len(matches) > 0 { - return parseFn(fsys, matches...) + return pattern.Func(fsys, matches...) } } diff --git a/convert/tokenizer.go b/convert/tokenizer.go index cca40eb0..0d42a6d8 100644 --- a/convert/tokenizer.go +++ b/convert/tokenizer.go @@ -220,19 +220,22 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) { } func parseVocabulary(fsys fs.FS) (*Vocabulary, error) { - patterns := map[string]func(fs.FS) (*Vocabulary, error){ - "tokenizer.model": parseSentencePiece, - "tokenizer.json": parseVocabularyFromTokenizer, + patterns := []struct { + Pattern string + Func func(fs.FS) (*Vocabulary, error) + }{ + {"tokenizer.model", parseSentencePiece}, + {"tokenizer.json", parseVocabularyFromTokenizer}, } - for pattern, parseFn := range patterns { - if _, err := fs.Stat(fsys, pattern); errors.Is(err, os.ErrNotExist) { + for _, pattern := range patterns { + if _, err := fs.Stat(fsys, pattern.Pattern); errors.Is(err, os.ErrNotExist) { continue } else if err != nil { return nil, err } - return parseFn(fsys) + return pattern.Func(fsys) } return nil, errors.New("unknown tensor format")