feat: add llama.cpp backend (#94)

* feat: add `llama.cpp` backend

* fix(ci): install stable toolchain instead of nightly

* fix(ci): use different model

---------

Co-authored-by: flopes <FredericoPerimLopes@users.noreply.github.com>
Luc Georges 2024-05-23 16:56:13 +02:00 committed by Luc Georges
parent 078d4c7af2
commit 0e95bb3589
8 changed files with 87 additions and 28 deletions

@@ -12,7 +12,6 @@ env:
   RUSTFLAGS: "-D warnings -W unreachable-pub"
   RUSTUP_MAX_RETRIES: 10
   FETCH_DEPTH: 0 # pull in the tags for the version string
-  MACOSX_DEPLOYMENT_TARGET: 10.15
   CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER: aarch64-linux-gnu-gcc
   CARGO_TARGET_ARM_UNKNOWN_LINUX_GNUEABIHF_LINKER: arm-linux-gnueabihf-gcc

@@ -40,10 +40,7 @@ jobs:
           DEBIAN_FRONTEND=noninteractive apt install -y pkg-config protobuf-compiler libssl-dev curl build-essential git-all gfortran
       - name: Install Rust toolchain
-        uses: actions-rust-lang/setup-rust-toolchain@v1
-        with:
-          rustflags: ''
-          toolchain: nightly
+        uses: dtolnay/rust-toolchain@stable
       - name: Install Python 3.10
         uses: actions/setup-python@v5

@@ -67,10 +67,9 @@ pub enum Backend {
         #[serde(default = "hf_default_url", deserialize_with = "parse_url")]
         url: String,
     },
-    // TODO:
-    // LlamaCpp {
-    //     url: String,
-    // },
+    LlamaCpp {
+        url: String,
+    },
     Ollama {
         url: String,
     },
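Note: the `Backend` enum is what the `backend:` key of the configuration deserializes into, so uncommenting the `LlamaCpp` variant is what makes `backend: llamacpp` selectable. A hedged sketch of how such an internally tagged enum maps a config object to the new variant; the `tag`/`rename_all` serde attributes here are my assumption for illustration, not necessarily llm-ls's exact setup:

```rust
use serde::Deserialize;

// Simplified stand-in for llm-ls's Backend enum.
#[derive(Debug, Deserialize)]
#[serde(tag = "backend", rename_all = "lowercase")]
enum Backend {
    LlamaCpp { url: String },
    Ollama { url: String },
}

fn main() {
    // The "backend" field picks the variant; remaining fields fill it in.
    let cfg = r#"{"backend": "llamacpp", "url": "http://localhost:8080"}"#;
    let backend: Backend = serde_json::from_str(cfg).unwrap();
    assert!(matches!(backend, Backend::LlamaCpp { .. }));
}
```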

@@ -67,6 +67,37 @@ fn parse_api_text(text: &str) -> Result<Vec<Generation>> {
     }
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+struct LlamaCppGeneration {
+    content: String,
+}
+
+impl From<LlamaCppGeneration> for Generation {
+    fn from(value: LlamaCppGeneration) -> Self {
+        Generation {
+            generated_text: value.content,
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+enum LlamaCppAPIResponse {
+    Generation(LlamaCppGeneration),
+    Error(APIError),
+}
+
+fn build_llamacpp_headers() -> HeaderMap {
+    HeaderMap::new()
+}
+
+fn parse_llamacpp_text(text: &str) -> Result<Vec<Generation>> {
+    match serde_json::from_str(text)? {
+        LlamaCppAPIResponse::Generation(gen) => Ok(vec![gen.into()]),
+        LlamaCppAPIResponse::Error(err) => Err(Error::LlamaCpp(err)),
+    }
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 struct OllamaGeneration {
     response: String,
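Note on the parsing above: llama.cpp's HTTP server returns the completion as a JSON object whose `content` field holds the generated text, and the `#[serde(untagged)]` enum lets a single `from_str` call accept either that shape or an error payload. A minimal self-contained sketch of the same idea; the sample bodies and the error field name are illustrative, not captured from a real server:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct LlamaCppGeneration {
    content: String,
}

// Stand-in for llm-ls's APIError; the field name is assumed.
#[derive(Debug, Deserialize)]
struct ApiError {
    error: String,
}

// Untagged: serde tries each variant in order until one matches the JSON shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum LlamaCppApiResponse {
    Generation(LlamaCppGeneration),
    Error(ApiError),
}

fn main() {
    let ok = r#"{"content": "fn add(a: i32, b: i32) -> i32 { a + b }"}"#;
    let err = r#"{"error": "model not loaded"}"#;

    match serde_json::from_str::<LlamaCppApiResponse>(ok).unwrap() {
        LlamaCppApiResponse::Generation(g) => assert!(g.content.starts_with("fn add")),
        LlamaCppApiResponse::Error(_) => unreachable!(),
    }
    assert!(matches!(
        serde_json::from_str::<LlamaCppApiResponse>(err).unwrap(),
        LlamaCppApiResponse::Error(_)
    ));
}
```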
@@ -192,6 +223,9 @@ pub(crate) fn build_body(
                 request_body.insert("parameters".to_owned(), params);
             }
         }
+        Backend::LlamaCpp { .. } => {
+            request_body.insert("prompt".to_owned(), Value::String(prompt));
+        }
         Backend::Ollama { .. } | Backend::OpenAi { .. } => {
             request_body.insert("prompt".to_owned(), Value::String(prompt));
             request_body.insert("model".to_owned(), Value::String(model));
@@ -208,6 +242,7 @@ pub(crate) fn build_headers(
 ) -> Result<HeaderMap> {
     match backend {
         Backend::HuggingFace { .. } => build_api_headers(api_token, ide),
+        Backend::LlamaCpp { .. } => Ok(build_llamacpp_headers()),
         Backend::Ollama { .. } => Ok(build_ollama_headers()),
         Backend::OpenAi { .. } => build_openai_headers(api_token, ide),
         Backend::Tgi { .. } => build_tgi_headers(api_token, ide),
@@ -217,6 +252,7 @@ pub(crate) fn build_headers(
 pub(crate) fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
     match backend {
         Backend::HuggingFace { .. } => parse_api_text(text),
+        Backend::LlamaCpp { .. } => parse_llamacpp_text(text),
         Backend::Ollama { .. } => parse_ollama_text(text),
         Backend::OpenAi { .. } => parse_openai_text(text),
         Backend::Tgi { .. } => parse_tgi_text(text),

@@ -168,7 +168,7 @@ impl TryFrom<Vec<tower_lsp::lsp_types::PositionEncodingKind>> for PositionEncodi
 }
 
 impl PositionEncodingKind {
-    pub fn to_lsp_type(&self) -> tower_lsp::lsp_types::PositionEncodingKind {
+    pub fn to_lsp_type(self) -> tower_lsp::lsp_types::PositionEncodingKind {
         match self {
             PositionEncodingKind::Utf8 => tower_lsp::lsp_types::PositionEncodingKind::UTF8,
             PositionEncodingKind::Utf16 => tower_lsp::lsp_types::PositionEncodingKind::UTF16,
@@ -205,9 +205,10 @@ impl Document {
     ) -> Result<()> {
         match change.range {
             Some(range) => {
-                if range.start.line < range.end.line
+                if range.start.line > range.end.line
                     || (range.start.line == range.end.line
-                        && range.start.character <= range.end.character) {
+                        && range.start.character > range.end.character)
+                {
                     return Err(Error::InvalidRange(range));
                 }
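The guard above was inverted before this change: it returned `InvalidRange` for well-formed ranges (start before or equal to end) and let inverted ones through. The corrected condition errors only when the start position comes after the end, so a zero-width range (a pure insertion point) is now accepted. A standalone sketch of the predicate, with simplified stand-ins for the LSP `Position`/`Range` types:

```rust
#[derive(Clone, Copy)]
struct Position {
    line: u32,
    character: u32,
}

struct Range {
    start: Position,
    end: Position,
}

// Same condition as the fixed code: invalid only if start comes after end.
fn is_invalid(range: &Range) -> bool {
    range.start.line > range.end.line
        || (range.start.line == range.end.line
            && range.start.character > range.end.character)
}

fn main() {
    let forward = Range {
        start: Position { line: 1, character: 0 },
        end: Position { line: 2, character: 4 },
    };
    let empty = Range {
        start: Position { line: 3, character: 7 },
        end: Position { line: 3, character: 7 },
    };
    let inverted = Range {
        start: Position { line: 5, character: 9 },
        end: Position { line: 5, character: 2 },
    };
    assert!(!is_invalid(&forward));
    assert!(!is_invalid(&empty)); // zero-width ranges are valid insertion points
    assert!(is_invalid(&inverted));
}
```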
@@ -219,7 +220,10 @@ impl Document {
 
                 // 1. Get the line at which the change starts.
                 let change_start_line_idx = range.start.line as usize;
-                let change_start_line = self.text.get_line(change_start_line_idx).ok_or_else(|| Error::OutOfBoundLine(change_start_line_idx, self.text.len_lines()))?;
+                let change_start_line =
+                    self.text.get_line(change_start_line_idx).ok_or_else(|| {
+                        Error::OutOfBoundLine(change_start_line_idx, self.text.len_lines())
+                    })?;
 
                 // 2. Get the line at which the change ends. (Small optimization
                 //    where we first check whether start and end line are the
@@ -228,7 +232,9 @@ impl Document {
                 let change_end_line_idx = range.end.line as usize;
                 let change_end_line = match same_line {
                     true => change_start_line,
-                    false => self.text.get_line(change_end_line_idx).ok_or_else(|| Error::OutOfBoundLine(change_end_line_idx, self.text.len_lines()))?,
+                    false => self.text.get_line(change_end_line_idx).ok_or_else(|| {
+                        Error::OutOfBoundLine(change_end_line_idx, self.text.len_lines())
+                    })?,
                 };
 
                 fn compute_char_idx(
@@ -330,7 +336,7 @@ impl Document {
                     self.tree = Some(new_tree);
                 }
                 None => {
-                    return Err(Error::TreeSitterParseError);
+                    return Err(Error::TreeSitterParsing);
                 }
             }
         }
@@ -416,7 +422,9 @@ mod test {
         let mut rope = Rope::from_str(
             "let a = '🥸 你好';\rfunction helloWorld() { return '🤲🏿'; }\nlet b = 'Hi, 😊';",
         );
-        let mut doc = Document::open(&LanguageId::JavaScript.to_string(), &rope.to_string()).await.unwrap();
+        let mut doc = Document::open(&LanguageId::JavaScript.to_string(), &rope.to_string())
+            .await
+            .unwrap();
 
         let mut parser = Parser::new();
         parser
@@ -464,7 +472,9 @@ mod test {
     #[tokio::test]
     async fn test_text_document_apply_content_change_bounds() {
         let rope = Rope::from_str("");
-        let mut doc = Document::open(&LanguageId::Unknown.to_string(), &rope.to_string()).await.unwrap();
+        let mut doc = Document::open(&LanguageId::Unknown.to_string(), &rope.to_string())
+            .await
+            .unwrap();
 
         assert!(doc
             .apply_content_change(new_change!(0, 0, 0, 1, ""), PositionEncodingKind::Utf16)
@@ -513,7 +523,9 @@ mod test {
     async fn test_document_update_tree_consistency_easy() {
         let a = "let a = '你好';\rlet b = 'Hi, 😊';";
 
-        let mut document = Document::open(&LanguageId::JavaScript.to_string(), a).await.unwrap();
+        let mut document = Document::open(&LanguageId::JavaScript.to_string(), a)
+            .await
+            .unwrap();
 
         document
             .apply_content_change(new_change!(0, 9, 0, 11, "𐐀"), PositionEncodingKind::Utf16)
@@ -541,7 +553,9 @@ mod test {
     async fn test_document_update_tree_consistency_medium() {
         let a = "let a = '🥸 你好';\rfunction helloWorld() { return '🤲🏿'; }\nlet b = 'Hi, 😊';";
 
-        let mut document = Document::open(&LanguageId::JavaScript.to_string(), a).await.unwrap();
+        let mut document = Document::open(&LanguageId::JavaScript.to_string(), a)
+            .await
+            .unwrap();
 
         document
             .apply_content_change(new_change!(0, 14, 2, 13, ""), PositionEncodingKind::Utf16)

@@ -33,6 +33,8 @@ pub enum Error {
     InvalidRepositoryId,
     #[error("invalid tokenizer path")]
     InvalidTokenizerPath,
+    #[error("llama.cpp error: {0}")]
+    LlamaCpp(crate::backend::APIError),
     #[error("ollama error: {0}")]
     Ollama(crate::backend::APIError),
     #[error("openai error: {0}")]
@@ -50,7 +52,7 @@ pub enum Error {
     #[error("tgi error: {0}")]
     Tgi(crate::backend::APIError),
     #[error("tree-sitter parse error: timeout possibly exceeded")]
-    TreeSitterParseError,
+    TreeSitterParsing,
     #[error("tree-sitter language error: {0}")]
     TreeSitterLanguage(#[from] tree_sitter::LanguageError),
     #[error("tokenizer error: {0}")]
@@ -60,7 +62,7 @@ pub enum Error {
     #[error("unknown backend: {0}")]
     UnknownBackend(String),
     #[error("unknown encoding kind: {0}")]
-    UnknownEncodingKind(String)
+    UnknownEncodingKind(String),
 }
 
 pub(crate) type Result<T> = std::result::Result<T, Error>;

@@ -417,6 +417,17 @@ async fn get_tokenizer(
 fn build_url(backend: Backend, model: &str) -> String {
     match backend {
         Backend::HuggingFace { url } => format!("{url}/models/{model}"),
+        Backend::LlamaCpp { mut url } => {
+            if url.ends_with("/completions") {
+                url
+            } else if url.ends_with('/') {
+                url.push_str("completions");
+                url
+            } else {
+                url.push_str("/completions");
+                url
+            }
+        }
         Backend::Ollama { url } => url,
         Backend::OpenAi { url } => url,
         Backend::Tgi { url } => url,
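The new `LlamaCpp` arm normalizes the configured URL so requests always hit the server's completions endpoint, whether the user configured a bare host, a trailing slash, or the full path. A quick standalone check of the three cases; this is the same logic in free-function form, and the helper name is mine:

```rust
fn normalize(mut url: String) -> String {
    // Same three cases as the build_url arm above.
    if url.ends_with("/completions") {
        url
    } else if url.ends_with('/') {
        url.push_str("completions");
        url
    } else {
        url.push_str("/completions");
        url
    }
}

fn main() {
    for configured in [
        "http://localhost:8080",
        "http://localhost:8080/",
        "http://localhost:8080/completions",
    ] {
        assert_eq!(
            normalize(configured.to_owned()),
            "http://localhost:8080/completions"
        );
    }
}
```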
@@ -540,7 +551,8 @@ impl LanguageServer for LlmService {
                 general_capabilities
                     .position_encodings
                     .map(TryFrom::try_from)
-            }).unwrap_or(Ok(document::PositionEncodingKind::Utf16))?;
+            })
+            .unwrap_or(Ok(document::PositionEncodingKind::Utf16))?;
 
         *self.position_encoding.write().await = position_encoding;

@@ -2,10 +2,10 @@
 context_window: 2000
 fim:
   enabled: true
-  prefix: <fim_prefix>
-  middle: <fim_middle>
-  suffix: <fim_suffix>
-model: bigcode/starcoder
+  prefix: "<PRE> "
+  middle: " <MID>"
+  suffix: " <SUF>"
+model: codellama/CodeLlama-13b-hf
 backend: huggingface
 request_body:
   max_new_tokens: 150
@@ -14,8 +14,8 @@ request_body:
   top_p: 0.95
 tls_skip_verify_insecure: false
 tokenizer_config:
-  repository: bigcode/starcoder
-  tokens_to_clear: ["<|endoftext|>"]
+  repository: codellama/CodeLlama-13b-hf
+  tokens_to_clear: ["<EOT>"]
 repositories:
   - source:
       type: local
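The CI config swap above replaces StarCoder's FIM tokens with CodeLlama's infill tokens (and `<EOT>` as the end-of-text token to strip), pointing the tokenizer at the matching repository; the tokens must match the model family or fill-in-the-middle prompts degenerate into plain text. A sketch of how such tokens typically wrap the cursor context into an infill prompt; the assembly order follows the common CodeLlama convention, and the helper is illustrative rather than llm-ls's actual code:

```rust
struct FimTokens {
    prefix: &'static str,
    middle: &'static str,
    suffix: &'static str,
}

// Illustrative assembly: <PRE> {before-cursor} <SUF>{after-cursor} <MID>
fn build_infill_prompt(fim: &FimTokens, before: &str, after: &str) -> String {
    format!("{}{}{}{}{}", fim.prefix, before, fim.suffix, after, fim.middle)
}

fn main() {
    let codellama = FimTokens {
        prefix: "<PRE> ",
        middle: " <MID>",
        suffix: " <SUF>",
    };
    let prompt = build_infill_prompt(
        &codellama,
        "def add(a, b):\n    return ",
        "\n\nprint(add(1, 2))",
    );
    assert!(prompt.starts_with("<PRE> "));
    assert!(prompt.ends_with(" <MID>"));
}
```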