feat: improve suggestions based on AST (#30)

* feat: improve suggestions based on AST

* feat: bump version to `0.3.0`
This commit is contained in:
Luc Georges 2023-10-11 19:33:57 +02:00 committed by GitHub
parent b6d6c6cccd
commit cdbf76fd43
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 682 additions and 41 deletions

244
Cargo.lock generated
View file

@ -659,7 +659,7 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
[[package]]
name = "llm-ls"
version = "0.2.2"
version = "0.3.0"
dependencies = [
"home",
"reqwest",
@ -671,6 +671,28 @@ dependencies = [
"tracing",
"tracing-appender",
"tracing-subscriber",
"tree-sitter",
"tree-sitter-bash",
"tree-sitter-c",
"tree-sitter-c-sharp",
"tree-sitter-cpp",
"tree-sitter-elixir",
"tree-sitter-erlang",
"tree-sitter-go",
"tree-sitter-html",
"tree-sitter-java",
"tree-sitter-javascript",
"tree-sitter-json",
"tree-sitter-lua",
"tree-sitter-md",
"tree-sitter-objc",
"tree-sitter-python",
"tree-sitter-r",
"tree-sitter-ruby",
"tree-sitter-rust",
"tree-sitter-scala",
"tree-sitter-swift",
"tree-sitter-typescript",
]
[[package]]
@ -1731,6 +1753,226 @@ dependencies = [
"tracing-serde",
]
[[package]]
name = "tree-sitter"
version = "0.20.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e747b1f9b7b931ed39a548c1fae149101497de3c1fc8d9e18c62c1a66c683d3d"
dependencies = [
"cc",
"regex",
]
[[package]]
name = "tree-sitter-bash"
version = "0.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "096f57b3b44c04bfc7b21a4da44bfa16adf1f88aba18993b8478a091076d0968"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-c"
version = "0.20.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30b03bdf218020057abee831581a74bff8c298323d6c6cd1a70556430ded9f4b"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-c-sharp"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9ab3dc608f34924fa9e10533a95f62dbc14b6de0ddd7107722eba66fe19ae31"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-cpp"
version = "0.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23b4b625f46a7370544b9cf0545532c26712ae49bfc02eb09825db358b9f79e1"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-elixir"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a9916f3e1c80b3c8aab8582604e97e8720cb9b893489b347cf999f80f9d469e"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-erlang"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d110d62a7ae35b985d8cfbc4de6e9281c7cbf268c466e30ebb31c2d3f861141"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-go"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ad6d11f19441b961af2fda7f12f5d0dac325f6d6de83836a1d3750018cc5114"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-html"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "184e6b77953a354303dc87bf5fe36558c83569ce92606e7b382a0dc1b7443443"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-java"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2adc5696bf5abf761081d7457d2bb82d0e3b28964f4214f63fd7e720ef462653"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-javascript"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edbc663376bdd294bd1f0a6daf859aedb9aa5bdb72217d7ad8ba2d5314102cf7"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-json"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50d82d2e33ee675dc71289e2ace4f8f9cf96d36d81400e9dae5ea61edaf5dea6"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-lua"
version = "0.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0968cf4962ead1d26da28921dde1fd97407e7bbcf2f959cd20cf04ba2daa9421"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-md"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a237fa10f6b466b76c783c79b08cc172581e547ef1dbb6ddf1f8b4e230157e1"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-objc"
version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f465c1a24f400b1e4837c97ef350954dea05ff72030f6808fb3945e04fe0b27"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-python"
version = "0.20.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c93b1b1fbd0d399db3445f51fd3058e43d0b4dcff62ddbdb46e66550978aa5"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-r"
version = "0.19.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "522c13f4cc46213148b19d4ad40a988ffabd51fd90eb7de759844fbde49bda0c"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-ruby"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ac30cbb1560363ae76e1ccde543d6d99087421e228cc47afcec004b86bb711a"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-rust"
version = "0.20.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0832309b0b2b6d33760ce5c0e818cb47e1d72b468516bfe4134408926fa7594"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-scala"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93df43ab4f2b3299fe97e73eb9b946bbca453b402bea8debf1fa69ab4e28412b"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-swift"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eee2dbeb101a88a1d9e4883e3fbda6c799cf676f6a1cf59e4fc3862e67e70118"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-typescript"
version = "0.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75049f0aafabb2aac205d7bb24da162b53dcd0cfb326785f25a2f32efa8071a"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "try-lock"
version = "0.2.4"

View file

@ -1,6 +1,6 @@
[package]
name = "llm-ls"
version = "0.2.2"
version = "0.3.0"
edition = "2021"
[[bin]]
@ -17,4 +17,26 @@ tower-lsp = "0.20"
tracing = "0.1"
tracing-appender = "0.2"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
tree-sitter = "0.20"
tree-sitter-bash = "0.20"
tree-sitter-c = "0.20"
tree-sitter-cpp = "0.20"
tree-sitter-c-sharp = "0.20"
tree-sitter-elixir = "0.1"
tree-sitter-erlang = "0.2"
tree-sitter-go = "0.20"
tree-sitter-html = "0.19"
tree-sitter-java = "0.20"
tree-sitter-javascript = "0.20"
tree-sitter-json = "0.20"
tree-sitter-lua = "0.0.19"
tree-sitter-md = "0.1"
tree-sitter-objc = "3"
tree-sitter-python = "0.20"
tree-sitter-r = "0.19"
tree-sitter-ruby = "0.20"
tree-sitter-rust = "0.20"
tree-sitter-scala = "0.20"
tree-sitter-swift = "0.3"
tree-sitter-typescript = "0.20"

View file

@ -0,0 +1,196 @@
use ropey::Rope;
use tower_lsp::jsonrpc::Result;
use tree_sitter::{Parser, Tree};
use crate::internal_error;
use crate::language_id::LanguageId;
fn get_parser(language_id: LanguageId) -> Result<Parser> {
match language_id {
LanguageId::Bash => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_bash::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::C => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_c::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::Cpp => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_cpp::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::CSharp => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_c_sharp::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::Elixir => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_elixir::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::Erlang => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_erlang::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::Go => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_go::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::Html => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_html::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::Java => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_java::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::JavaScript | LanguageId::JavaScriptReact => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_javascript::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::Json => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_json::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::Lua => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_lua::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::Markdown => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_md::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::ObjectiveC => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_objc::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::Python => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_python::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::R => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_r::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::Ruby => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_ruby::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::Rust => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_rust::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::Scala => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_scala::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::Swift => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_swift::language())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::TypeScript => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_typescript::language_typescript())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::TypeScriptReact => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_typescript::language_tsx())
.map_err(internal_error)?;
Ok(parser)
}
LanguageId::Unknown => Ok(Parser::new()),
}
}
/// An open text document together with the tree-sitter state used to parse it.
pub struct Document {
// Recorded when the document is opened; the lint allowance suggests nothing
// reads it yet.
#[allow(dead_code)]
language_id: LanguageId,
/// Current contents of the document, stored as a rope.
pub text: Rope,
/// Parser created by `get_parser` when the document was opened; reused for
/// every re-parse in `change`.
parser: Parser,
/// Result of the most recent `Parser::parse` call; `None` when that call
/// produced no tree.
pub tree: Option<Tree>,
}
impl Document {
pub async fn open(language_id: &str, text: &str) -> Result<Self> {
let language_id = language_id.into();
let rope = Rope::from_str(text);
let mut parser = get_parser(language_id)?;
let tree = parser.parse(text, None);
Ok(Document {
language_id,
text: rope,
parser,
tree,
})
}
pub async fn change(&mut self, text: &str) -> Result<()> {
let rope = Rope::from_str(text);
self.tree = self.parser.parse(text, None);
self.text = rope;
Ok(())
}
}

View file

@ -0,0 +1,107 @@
use std::fmt;
/// Languages the server distinguishes between.
///
/// `Unknown` is the catch-all used for any identifier that is not recognised.
/// `Debug`, `PartialEq` and `Eq` are derived so values can be logged with
/// `{:?}` and compared directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LanguageId {
    Bash,
    C,
    Cpp,
    CSharp,
    Elixir,
    Erlang,
    Go,
    Html,
    Java,
    JavaScript,
    JavaScriptReact,
    Json,
    Lua,
    Markdown,
    ObjectiveC,
    Python,
    R,
    Ruby,
    Rust,
    Scala,
    Swift,
    TypeScript,
    TypeScriptReact,
    /// Any language without a dedicated variant.
    Unknown,
}
impl fmt::Display for LanguageId {
    /// Writes the textual identifier of the language (for example
    /// `"shellscript"` for `Bash`, `"unknown"` for `Unknown`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let identifier = match self {
            Self::Bash => "shellscript",
            Self::C => "c",
            Self::Cpp => "cpp",
            Self::CSharp => "csharp",
            Self::Elixir => "elixir",
            Self::Erlang => "erlang",
            Self::Go => "go",
            Self::Html => "html",
            Self::Java => "java",
            Self::JavaScript => "javascript",
            Self::JavaScriptReact => "javascriptreact",
            Self::Json => "json",
            Self::Lua => "lua",
            Self::Markdown => "markdown",
            Self::ObjectiveC => "objective-c",
            Self::Python => "python",
            Self::R => "r",
            Self::Ruby => "ruby",
            Self::Rust => "rust",
            Self::Scala => "scala",
            Self::Swift => "swift",
            Self::TypeScript => "typescript",
            Self::TypeScriptReact => "typescriptreact",
            Self::Unknown => "unknown",
        };
        f.write_str(identifier)
    }
}
/// Error value carrying a language id string that was rejected as invalid.
// NOTE(review): nothing in the visible code constructs this type — confirm
// whether it is still needed.
pub struct LanguageIdError {
/// The offending language id; echoed by the `Display` implementation.
language_id: String,
}
impl fmt::Display for LanguageIdError {
    /// Renders the error as `Invalid language id: <id>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self { language_id } = self;
        write!(f, "Invalid language id: {language_id}")
    }
}
impl From<&str> for LanguageId {
fn from(value: &str) -> Self {
match value {
"c" => Self::C,
"cpp" => Self::Cpp,
"csharp" => Self::CSharp,
"elixir" => Self::Elixir,
"erlang" => Self::Erlang,
"go" => Self::Go,
"html" => Self::Html,
"java" => Self::Java,
"javascript" => Self::JavaScript,
"javascriptreact" => Self::JavaScriptReact,
"json" => Self::Json,
"lua" => Self::Lua,
"markdown" => Self::Markdown,
"objective-c" => Self::ObjectiveC,
"python" => Self::Python,
"r" => Self::R,
"ruby" => Self::Ruby,
"rust" => Self::Rust,
"scala" => Self::Scala,
"shellscript" => Self::Bash,
"swift" => Self::Swift,
"typescript" => Self::TypeScript,
"typescriptreact" => Self::TypeScriptReact,
_ => Self::Unknown,
}
}
}
impl From<String> for LanguageId {
    /// Owned-string convenience wrapper: borrows `value` and defers to the
    /// `From<&str>` conversion.
    fn from(value: String) -> Self {
        let identifier: &str = &value;
        identifier.into()
    }
}

View file

@ -1,3 +1,4 @@
use document::Document;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
use ropey::Rope;
use serde::{Deserialize, Deserializer, Serialize};
@ -12,14 +13,66 @@ use tokio::sync::RwLock;
use tower_lsp::jsonrpc::{Error, Result};
use tower_lsp::lsp_types::*;
use tower_lsp::{Client, LanguageServer, LspService, Server};
use tracing::{debug, error, info};
use tracing::{debug, error, info, warn};
use tracing_appender::rolling;
use tracing_subscriber::EnvFilter;
mod document;
mod language_id;
/// One hour (3,600 seconds).
// NOTE(review): judging by the name this rate-limits a repeated warning; the
// use site is not visible here — confirm.
const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
/// Name of this language server.
const NAME: &str = "llm-ls";
/// Crate version, taken from `CARGO_PKG_VERSION` at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Kind of completion chosen by `should_complete` for a cursor position.
#[derive(Debug, PartialEq, Eq)]
enum CompletionType {
/// No completion should be produced; the request returns an empty list.
Empty,
/// Only the first line of each generation is kept.
SingleLine,
/// Generations are returned without truncation.
MultiLine,
}
fn should_complete(document: &Document, position: Position) -> CompletionType {
let row = position.line as usize;
let column = position.character as usize;
if let Some(tree) = &document.tree {
let current_node = tree.root_node().descendant_for_point_range(
tree_sitter::Point { row, column },
tree_sitter::Point { row, column },
);
if let Some(node) = current_node {
if node == tree.root_node() {
return CompletionType::MultiLine;
}
let start = node.start_position();
let end = node.end_position();
let mut start_offset = document.text.line_to_char(start.row) + start.column;
let mut end_offset = document.text.line_to_char(end.row) + end.column - 1;
let start_char = document.text.char(start_offset);
if !start_char.is_whitespace() {
start_offset += 1;
}
let end_char = document.text.char(end_offset);
if !end_char.is_whitespace() {
end_offset -= 1;
}
if start_offset >= end_offset {
return CompletionType::SingleLine;
}
let slice = document.text.slice(start_offset..end_offset);
if slice.to_string().trim().is_empty() {
return CompletionType::MultiLine;
}
}
}
let start_idx = document.text.line_to_char(row);
let next_char = document.text.char(start_idx + column);
if next_char.is_whitespace() {
CompletionType::SingleLine
} else {
CompletionType::Empty
}
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
enum TokenizerConfig {
@ -100,20 +153,6 @@ enum APIResponse {
Error(APIError),
}
#[derive(Debug)]
struct Document {
#[allow(dead_code)]
language_id: String,
text: Rope,
}
impl Document {
fn new(language_id: String, text: Rope) -> Self {
Self { language_id, text }
}
}
#[derive(Debug)]
struct Backend {
cache_dir: PathBuf,
client: Client,
@ -313,7 +352,11 @@ async fn request_completion(
}
}
fn parse_generations(generations: Vec<Generation>, tokens_to_clear: &[String]) -> Vec<Completion> {
fn parse_generations(
generations: Vec<Generation>,
tokens_to_clear: &[String],
completion_type: CompletionType,
) -> Vec<Completion> {
generations
.into_iter()
.map(|g| {
@ -321,7 +364,20 @@ fn parse_generations(generations: Vec<Generation>, tokens_to_clear: &[String]) -
for token in tokens_to_clear {
generated_text = generated_text.replace(token, "")
}
Completion { generated_text }
match completion_type {
CompletionType::Empty => {
warn!("completion type should not be empty when post processing completions");
Completion { generated_text }
}
CompletionType::SingleLine => Completion {
generated_text: generated_text
.split_once('\n')
.unwrap_or((&generated_text, ""))
.0
.to_owned(),
},
CompletionType::MultiLine => Completion { generated_text },
}
})
.collect()
}
@ -466,6 +522,12 @@ impl Backend {
*unauthenticated_warn_at = Instant::now();
}
}
let completion_type = should_complete(document, params.text_document_position.position);
info!("completion type: {completion_type:?}");
if completion_type == CompletionType::Empty {
return Ok(vec![]);
}
let tokenizer = get_tokenizer(
&params.model,
&mut *self.tokenizer_map.write().await,
@ -500,7 +562,11 @@ impl Backend {
)
.await?;
Ok(parse_generations(result, &params.tokens_to_clear))
Ok(parse_generations(
result,
&params.tokens_to_clear,
completion_type,
))
}
}
@ -526,40 +592,46 @@ impl LanguageServer for Backend {
self.client
.log_message(MessageType::INFO, "{llm-ls} initialized")
.await;
let _ = info!("initialized language server");
info!("initialized language server");
}
// TODO:
// textDocument/didClose
async fn did_open(&self, params: DidOpenTextDocumentParams) {
let uri = params.text_document.uri.to_string();
match Document::open(
&params.text_document.language_id,
&params.text_document.text,
)
.await
{
Ok(document) => {
self.document_map
.write()
.await
.insert(uri.clone(), document);
info!("{uri} opened");
}
Err(err) => error!("error opening {uri}: {err}"),
}
self.client
.log_message(MessageType::INFO, "{llm-ls} file opened")
.await;
let rope = ropey::Rope::from_str(&params.text_document.text);
let uri = params.text_document.uri.to_string();
*self
.document_map
.write()
.await
.entry(uri.clone())
.or_insert(Document::new("unknown".to_owned(), Rope::new())) =
Document::new(params.text_document.language_id, rope);
info!("{uri} opened");
}
async fn did_change(&self, params: DidChangeTextDocumentParams) {
self.client
.log_message(MessageType::INFO, "{llm-ls} file changed")
.await;
let rope = ropey::Rope::from_str(&params.content_changes[0].text);
let uri = params.text_document.uri.to_string();
let mut document_map = self.document_map.write().await;
let doc = document_map
.entry(uri.clone())
.or_insert(Document::new("unknown".to_owned(), Rope::new()));
doc.text = rope;
info!("{uri} changed");
let doc = document_map.get_mut(&uri);
if let Some(doc) = doc {
match doc.change(&params.content_changes[0].text).await {
Ok(()) => info!("{uri} changed"),
Err(err) => error!("error when changing {uri}: {err}"),
}
} else {
warn!("textDocument/didChange {uri}: document not found");
}
}
async fn did_save(&self, params: DidSaveTextDocumentParams) {
@ -570,6 +642,8 @@ impl LanguageServer for Backend {
info!("{uri} saved");
}
// TODO:
// textDocument/didClose
async fn did_close(&self, params: DidCloseTextDocumentParams) {
self.client
.log_message(MessageType::INFO, "{llm-ls} file closed")
@ -579,7 +653,7 @@ impl LanguageServer for Backend {
}
async fn shutdown(&self) -> Result<()> {
let _ = debug!("shutdown");
debug!("shutdown");
Ok(())
}
}