feat: add llama.cpp backend (#94)

* feat: add `llama.cpp` backend

* fix(ci): install stable toolchain instead of nightly

* fix(ci): use different model

---------

Co-authored-by: flopes <FredericoPerimLopes@users.noreply.github.com>
Luc Georges 2024-05-23 16:56:13 +02:00 committed by Luc Georges
parent 078d4c7af2
commit 0e95bb3589
8 changed files with 87 additions and 28 deletions

View file

@@ -12,7 +12,6 @@ env:
   RUSTFLAGS: "-D warnings -W unreachable-pub"
   RUSTUP_MAX_RETRIES: 10
   FETCH_DEPTH: 0 # pull in the tags for the version string
-  MACOSX_DEPLOYMENT_TARGET: 10.15
   CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER: aarch64-linux-gnu-gcc
   CARGO_TARGET_ARM_UNKNOWN_LINUX_GNUEABIHF_LINKER: arm-linux-gnueabihf-gcc

View file

@@ -40,10 +40,7 @@ jobs:
           DEBIAN_FRONTEND=noninteractive apt install -y pkg-config protobuf-compiler libssl-dev curl build-essential git-all gfortran
       - name: Install Rust toolchain
-        uses: actions-rust-lang/setup-rust-toolchain@v1
-        with:
-          rustflags: ''
-          toolchain: nightly
+        uses: dtolnay/rust-toolchain@stable
       - name: Install Python 3.10
         uses: actions/setup-python@v5

View file

@@ -67,10 +67,9 @@ pub enum Backend {
         #[serde(default = "hf_default_url", deserialize_with = "parse_url")]
         url: String,
     },
-    // TODO:
-    // LlamaCpp {
-    //     url: String,
-    // },
+    LlamaCpp {
+        url: String,
+    },
     Ollama {
         url: String,
     },

View file

@@ -67,6 +67,37 @@ fn parse_api_text(text: &str) -> Result<Vec<Generation>> {
     }
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+struct LlamaCppGeneration {
+    content: String,
+}
+
+impl From<LlamaCppGeneration> for Generation {
+    fn from(value: LlamaCppGeneration) -> Self {
+        Generation {
+            generated_text: value.content,
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+enum LlamaCppAPIResponse {
+    Generation(LlamaCppGeneration),
+    Error(APIError),
+}
+
+fn build_llamacpp_headers() -> HeaderMap {
+    HeaderMap::new()
+}
+
+fn parse_llamacpp_text(text: &str) -> Result<Vec<Generation>> {
+    match serde_json::from_str(text)? {
+        LlamaCppAPIResponse::Generation(gen) => Ok(vec![gen.into()]),
+        LlamaCppAPIResponse::Error(err) => Err(Error::LlamaCpp(err)),
+    }
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 struct OllamaGeneration {
     response: String,
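
Note: llama.cpp's HTTP server returns the generated text in a JSON object under a content key, which the untagged enum above distinguishes from an error payload. A minimal, self-contained sketch of that mapping (the response body shown is illustrative, not captured from a real server):

use serde::Deserialize;

// Mirrors the LlamaCppGeneration struct added above.
#[derive(Debug, Deserialize)]
struct LlamaCppGeneration {
    content: String,
}

fn main() -> Result<(), serde_json::Error> {
    // Illustrative llama.cpp completion response body.
    let body = r#"{"content": " return a + b;"}"#;
    let gen: LlamaCppGeneration = serde_json::from_str(body)?;
    // parse_llamacpp_text wraps this as Generation { generated_text: gen.content }.
    println!("generated_text = {}", gen.content);
    Ok(())
}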
@@ -192,6 +223,9 @@ pub(crate) fn build_body(
                 request_body.insert("parameters".to_owned(), params);
             }
         }
+        Backend::LlamaCpp { .. } => {
+            request_body.insert("prompt".to_owned(), Value::String(prompt));
+        }
         Backend::Ollama { .. } | Backend::OpenAi { .. } => {
             request_body.insert("prompt".to_owned(), Value::String(prompt));
             request_body.insert("model".to_owned(), Value::String(model));
@@ -208,6 +242,7 @@ pub(crate) fn build_headers(
 ) -> Result<HeaderMap> {
     match backend {
         Backend::HuggingFace { .. } => build_api_headers(api_token, ide),
+        Backend::LlamaCpp { .. } => Ok(build_llamacpp_headers()),
         Backend::Ollama { .. } => Ok(build_ollama_headers()),
         Backend::OpenAi { .. } => build_openai_headers(api_token, ide),
         Backend::Tgi { .. } => build_tgi_headers(api_token, ide),
@@ -217,6 +252,7 @@ pub(crate) fn build_headers(
 pub(crate) fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
     match backend {
         Backend::HuggingFace { .. } => parse_api_text(text),
+        Backend::LlamaCpp { .. } => parse_llamacpp_text(text),
         Backend::Ollama { .. } => parse_ollama_text(text),
         Backend::OpenAi { .. } => parse_openai_text(text),
         Backend::Tgi { .. } => parse_tgi_text(text),

View file

@@ -168,7 +168,7 @@ impl TryFrom<Vec<tower_lsp::lsp_types::PositionEncodingKind>> for PositionEncodi
 }
 
 impl PositionEncodingKind {
-    pub fn to_lsp_type(&self) -> tower_lsp::lsp_types::PositionEncodingKind {
+    pub fn to_lsp_type(self) -> tower_lsp::lsp_types::PositionEncodingKind {
         match self {
             PositionEncodingKind::Utf8 => tower_lsp::lsp_types::PositionEncodingKind::UTF8,
             PositionEncodingKind::Utf16 => tower_lsp::lsp_types::PositionEncodingKind::UTF16,
@@ -205,9 +205,10 @@ impl Document {
     ) -> Result<()> {
         match change.range {
             Some(range) => {
-                if range.start.line < range.end.line
+                if range.start.line > range.end.line
                     || (range.start.line == range.end.line
-                        && range.start.character <= range.end.character) {
+                        && range.start.character > range.end.character)
+                {
                     return Err(Error::InvalidRange(range));
                 }
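
Note: the previous guard had its comparisons inverted, so it flagged every well-formed range (start at or before end) as invalid; the corrected guard rejects only ranges whose start comes after their end, and an empty range (start == end) is now accepted. A standalone sketch of the fixed check, with Position standing in for the LSP type:

// Stand-in for tower_lsp::lsp_types::Position, for illustration only.
#[derive(Clone, Copy)]
struct Position {
    line: u32,
    character: u32,
}

// The corrected validity check: invalid only when start comes after end.
fn is_invalid(start: Position, end: Position) -> bool {
    start.line > end.line
        || (start.line == end.line && start.character > end.character)
}

fn main() {
    let a = Position { line: 1, character: 0 };
    let b = Position { line: 1, character: 5 };
    assert!(!is_invalid(a, b)); // forward range: valid
    assert!(!is_invalid(a, a)); // empty range: valid
    assert!(is_invalid(b, a)); // reversed range: rejected
}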
@@ -219,7 +220,10 @@ impl Document {
                 // 1. Get the line at which the change starts.
                 let change_start_line_idx = range.start.line as usize;
-                let change_start_line = self.text.get_line(change_start_line_idx).ok_or_else(|| Error::OutOfBoundLine(change_start_line_idx, self.text.len_lines()))?;
+                let change_start_line =
+                    self.text.get_line(change_start_line_idx).ok_or_else(|| {
+                        Error::OutOfBoundLine(change_start_line_idx, self.text.len_lines())
+                    })?;
 
                 // 2. Get the line at which the change ends. (Small optimization
                 // where we first check whether start and end line are the
@@ -228,7 +232,9 @@ impl Document {
                 let change_end_line_idx = range.end.line as usize;
                 let change_end_line = match same_line {
                     true => change_start_line,
-                    false => self.text.get_line(change_end_line_idx).ok_or_else(|| Error::OutOfBoundLine(change_end_line_idx, self.text.len_lines()))?,
+                    false => self.text.get_line(change_end_line_idx).ok_or_else(|| {
+                        Error::OutOfBoundLine(change_end_line_idx, self.text.len_lines())
+                    })?,
                 };
 
                 fn compute_char_idx(
@@ -330,7 +336,7 @@ impl Document {
                     self.tree = Some(new_tree);
                 }
                 None => {
-                    return Err(Error::TreeSitterParseError);
+                    return Err(Error::TreeSitterParsing);
                 }
             }
         }
@@ -416,7 +422,9 @@ mod test {
         let mut rope = Rope::from_str(
             "let a = '🥸 你好';\rfunction helloWorld() { return '🤲🏿'; }\nlet b = 'Hi, 😊';",
         );
-        let mut doc = Document::open(&LanguageId::JavaScript.to_string(), &rope.to_string()).await.unwrap();
+        let mut doc = Document::open(&LanguageId::JavaScript.to_string(), &rope.to_string())
+            .await
+            .unwrap();
 
         let mut parser = Parser::new();
         parser
@@ -464,7 +472,9 @@ mod test {
     #[tokio::test]
     async fn test_text_document_apply_content_change_bounds() {
         let rope = Rope::from_str("");
-        let mut doc = Document::open(&LanguageId::Unknown.to_string(), &rope.to_string()).await.unwrap();
+        let mut doc = Document::open(&LanguageId::Unknown.to_string(), &rope.to_string())
+            .await
+            .unwrap();
 
         assert!(doc
             .apply_content_change(new_change!(0, 0, 0, 1, ""), PositionEncodingKind::Utf16)
@@ -513,7 +523,9 @@ mod test {
     async fn test_document_update_tree_consistency_easy() {
         let a = "let a = '你好';\rlet b = 'Hi, 😊';";
-        let mut document = Document::open(&LanguageId::JavaScript.to_string(), a).await.unwrap();
+        let mut document = Document::open(&LanguageId::JavaScript.to_string(), a)
+            .await
+            .unwrap();
 
         document
             .apply_content_change(new_change!(0, 9, 0, 11, "𐐀"), PositionEncodingKind::Utf16)
@@ -541,7 +553,9 @@ mod test {
     async fn test_document_update_tree_consistency_medium() {
         let a = "let a = '🥸 你好';\rfunction helloWorld() { return '🤲🏿'; }\nlet b = 'Hi, 😊';";
-        let mut document = Document::open(&LanguageId::JavaScript.to_string(), a).await.unwrap();
+        let mut document = Document::open(&LanguageId::JavaScript.to_string(), a)
+            .await
+            .unwrap();
 
         document
             .apply_content_change(new_change!(0, 14, 2, 13, ""), PositionEncodingKind::Utf16)

View file

@@ -33,6 +33,8 @@ pub enum Error {
     InvalidRepositoryId,
     #[error("invalid tokenizer path")]
     InvalidTokenizerPath,
+    #[error("llama.cpp error: {0}")]
+    LlamaCpp(crate::backend::APIError),
     #[error("ollama error: {0}")]
     Ollama(crate::backend::APIError),
     #[error("openai error: {0}")]
@@ -50,7 +52,7 @@ pub enum Error {
     #[error("tgi error: {0}")]
     Tgi(crate::backend::APIError),
     #[error("tree-sitter parse error: timeout possibly exceeded")]
-    TreeSitterParseError,
+    TreeSitterParsing,
     #[error("tree-sitter language error: {0}")]
     TreeSitterLanguage(#[from] tree_sitter::LanguageError),
     #[error("tokenizer error: {0}")]
@@ -60,7 +62,7 @@ pub enum Error {
     #[error("unknown backend: {0}")]
     UnknownBackend(String),
     #[error("unknown encoding kind: {0}")]
-    UnknownEncodingKind(String)
+    UnknownEncodingKind(String),
 }
 
 pub(crate) type Result<T> = std::result::Result<T, Error>;

View file

@@ -417,6 +417,17 @@ async fn get_tokenizer(
 fn build_url(backend: Backend, model: &str) -> String {
     match backend {
         Backend::HuggingFace { url } => format!("{url}/models/{model}"),
+        Backend::LlamaCpp { mut url } => {
+            if url.ends_with("/completions") {
+                url
+            } else if url.ends_with('/') {
+                url.push_str("completions");
+                url
+            } else {
+                url.push_str("/completions");
+                url
+            }
+        }
         Backend::Ollama { url } => url,
         Backend::OpenAi { url } => url,
         Backend::Tgi { url } => url,
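
Note: this arm normalizes the configured llama.cpp URL so that a bare host, a trailing slash, and a full /completions path all resolve to the same endpoint. The same logic extracted into a standalone function, with a few checks (the localhost addresses are illustrative):

// URL normalization matching the Backend::LlamaCpp arm of build_url above.
fn normalize(mut url: String) -> String {
    if url.ends_with("/completions") {
        url
    } else if url.ends_with('/') {
        url.push_str("completions");
        url
    } else {
        url.push_str("/completions");
        url
    }
}

fn main() {
    for url in [
        "http://localhost:8080",
        "http://localhost:8080/",
        "http://localhost:8080/completions",
    ] {
        // All three spellings end up at the same endpoint.
        assert_eq!(normalize(url.to_string()), "http://localhost:8080/completions");
    }
}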
@@ -540,7 +551,8 @@ impl LanguageServer for LlmService {
                 general_capabilities
                     .position_encodings
                     .map(TryFrom::try_from)
-            }).unwrap_or(Ok(document::PositionEncodingKind::Utf16))?;
+            })
+            .unwrap_or(Ok(document::PositionEncodingKind::Utf16))?;
 
         *self.position_encoding.write().await = position_encoding;

View file

@@ -2,10 +2,10 @@
 context_window: 2000
 fim:
   enabled: true
-  prefix: <fim_prefix>
-  middle: <fim_middle>
-  suffix: <fim_suffix>
-model: bigcode/starcoder
+  prefix: "<PRE> "
+  middle: " <MID>"
+  suffix: " <SUF>"
+model: codellama/CodeLlama-13b-hf
 backend: huggingface
 request_body:
   max_new_tokens: 150
@@ -14,8 +14,8 @@ request_body:
   top_p: 0.95
 tls_skip_verify_insecure: false
 tokenizer_config:
-  repository: bigcode/starcoder
-  tokens_to_clear: ["<|endoftext|>"]
+  repository: codellama/CodeLlama-13b-hf
+  tokens_to_clear: ["<EOT>"]
 repositories:
   - source:
       type: local
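
Note: the example config swaps StarCoder's FIM tokens for CodeLlama's infill format, where the prompt is assembled as prefix token + prefix, suffix token + suffix, then the middle token, and <EOT> terminates the generated infill (hence tokens_to_clear). A sketch of how the three tokens above combine into a prompt; llm-ls's actual assembly code is outside this diff, so this only illustrates the format:

// Assembles a FIM prompt from the tokens in the example config above.
// Illustration of CodeLlama's infill format, not llm-ls's own code.
fn build_fim_prompt(prefix: &str, suffix: &str) -> String {
    format!("<PRE> {prefix} <SUF>{suffix} <MID>")
}

fn main() {
    let prompt = build_fim_prompt("fn add(a: i32, b: i32) -> i32 {", "}");
    assert_eq!(prompt, "<PRE> fn add(a: i32, b: i32) -> i32 { <SUF>} <MID>");
    // The model generates the middle and stops at <EOT>,
    // which is stripped via tokens_to_clear.
}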