From 620d5c569e965ac93ac5c58bca5d3d8938cb98bc Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Sat, 8 Jun 2024 12:32:02 -0700 Subject: [PATCH] fix parsing big endian gguf --- llm/ggml.go | 13 +++++-------- llm/gguf.go | 15 ++++++++++++++- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/llm/ggml.go b/llm/ggml.go index 645447d5..16da4c9d 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -231,8 +231,7 @@ const ( // Magic constant for `ggla` files (LoRA adapter). FILE_MAGIC_GGLA = 0x67676C61 // Magic constant for `gguf` files (versioned, gguf) - FILE_MAGIC_GGUF_LE = 0x46554747 - FILE_MAGIC_GGUF_BE = 0x47475546 + FILE_MAGIC_GGUF = 0x46554747 ) var ErrUnsupportedFormat = errors.New("unsupported model format") @@ -247,7 +246,7 @@ func DetectGGMLType(b []byte) string { return "ggjt" case FILE_MAGIC_GGLA: return "ggla" - case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE: + case FILE_MAGIC_GGUF: return "gguf" default: return "" @@ -255,21 +254,19 @@ func DetectGGMLType(b []byte) string { } func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) { - var magic uint32 + var magic [4]byte if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil { return nil, 0, err } var c container - switch magic { + switch binary.LittleEndian.Uint32(magic[:]) { case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT: return nil, 0, ErrUnsupportedFormat case FILE_MAGIC_GGLA: c = &containerGGLA{} - case FILE_MAGIC_GGUF_LE: + case FILE_MAGIC_GGUF: c = &containerGGUF{ByteOrder: binary.LittleEndian} - case FILE_MAGIC_GGUF_BE: - c = &containerGGUF{ByteOrder: binary.BigEndian} default: return nil, 0, errors.New("invalid file magic") } diff --git a/llm/gguf.go b/llm/gguf.go index 234efe57..8c64e166 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -36,10 +36,23 @@ func (c *containerGGUF) Name() string { } func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) { - if err := binary.Read(rs, c.ByteOrder, &c.Version); err != nil { + var version [4]byte + if err := binary.Read(rs, c.ByteOrder, &version); err != nil { return nil, err } + // if the lower 16 bits are 0, the byte order is probably wrong + if c.ByteOrder.Uint32(version[:])&1<<4 == 0 { + switch c.ByteOrder { + case binary.LittleEndian: + c.ByteOrder = binary.BigEndian + case binary.BigEndian: + c.ByteOrder = binary.LittleEndian + } + } + + c.Version = c.ByteOrder.Uint32(version[:]) + var err error switch c.Version { case 1: