feat: testbed (#39)

This commit is contained in:
Luc Georges 2023-11-06 21:26:37 +01:00 committed by GitHub
parent 4aacd7087b
commit c7affd0da9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
51 changed files with 3633 additions and 74 deletions

91
.github/workflows/test.yml vendored Normal file
View file

@ -0,0 +1,91 @@
name: test

on:
  workflow_dispatch:
  push:
    branches: [main]
  pull_request:

# Cancel superseded in-flight runs for the same branch / PR.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  testbed:
    runs-on: [self-hosted, intel-cpu, 8-cpu, ci]
    container:
      image: ubuntu:22.04
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - name: Install dependencies
        run: |
          apt update
          DEBIAN_FRONTEND=noninteractive apt install -y pkg-config protobuf-compiler libssl-dev curl build-essential git-all gfortran
      - name: Install Rust toolchain
        uses: actions-rust-lang/setup-rust-toolchain@v1
        with:
          rustflags: ''
          toolchain: nightly
      - name: Install Python 3.10
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      - name: Install node 18
        uses: actions/setup-node@v3
        with:
          node-version: 18
      - name: Install yarn
        run: |
          npm i -g yarn
      - name: Set up cargo cache
        uses: actions/cache@v3
        continue-on-error: false
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-cargo-${{ hashFiles('Cargo.lock') }}
          # Prefix fallback: restores the most recent cache when the lockfile
          # hash changes. (An exact duplicate of `key` would never fall back.)
          restore-keys: ${{ runner.os }}-cargo-
      - name: Build project
        run: cargo build -r
      # CI runs (push / PR) use the smaller CI repository list.
      - name: Run testbed (CI repositories)
        run: cargo run --bin testbed -r -- --api-token $API_TOKEN -r `pwd`/crates/testbed/repositories-ci.yaml
        if: github.event_name == 'push' || github.event_name == 'pull_request'
        env:
          API_TOKEN: ${{ secrets.API_TOKEN }}
      # Manual dispatch runs the full repository list.
      - name: Run testbed (full)
        run: cargo run --bin testbed -r -- --api-token $API_TOKEN
        if: github.event_name == 'workflow_dispatch'
        env:
          API_TOKEN: ${{ secrets.API_TOKEN }}
      - name: Find Comment
        uses: peter-evans/find-comment@v2
        id: fc
        if: github.event_name == 'pull_request'
        with:
          issue-number: ${{ github.event.pull_request.number }}
          comment-author: 'github-actions[bot]'
          body-includes: '| Repository name | Source type | Average hole completion time (s) | Pass percentage |'
      - name: Create or update comment
        if: github.event_name == 'pull_request'
        uses: peter-evans/create-or-update-comment@v3
        with:
          comment-id: ${{ steps.fc.outputs.comment-id }}
          issue-number: ${{ github.event.pull_request.number }}
          body-path: results.md
          edit-mode: replace

5
.gitignore vendored
View file

@ -1,2 +1,7 @@
.vscode/
dist/ dist/
target/ target/
.DS_Store
__pycache__/
results.md
.pytest_cache/

728
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,16 +1,17 @@
[workspace] [workspace]
members = ["xtask/", "crates/*"] members = ["xtask/", "crates/*"]
resolver = "2" resolver = "2"
exclude = ["crates/testbed/repositories/*"]
[workspace.package] [workspace.package]
edition = "2021" edition = "2021"
license = "Apache-2.0" license = "Apache-2.0"
authors = ["Luc Georges <luc@huggingface.co>"] authors = ["Luc Georges <luc@huggingface.co>"]
[profile.dev] # [profile.dev]
# Disabling debug info speeds up builds a bunch, # Disabling debug info speeds up builds a bunch,
# and we don't rely on it for debugging that much. # and we don't rely on it for debugging that much.
debug = 0 # debug = 0
[profile.dev.package] [profile.dev.package]
# This speeds up `cargo xtask dist`. # This speeds up `cargo xtask dist`.

View file

@ -9,11 +9,20 @@ name = "llm-ls"
[dependencies] [dependencies]
home = "0.5" home = "0.5"
ropey = "1.6" ropey = "1.6"
reqwest = { version = "0.11", default-features = false, features = ["json", "rustls-tls"] } reqwest = { version = "0.11", default-features = false, features = [
"json",
"rustls-tls",
] }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = "1" serde_json = "1"
tokenizers = { version = "0.13", default-features = false, features = ["onig"] } tokenizers = { version = "0.14", default-features = false, features = ["onig"] }
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "rt-multi-thread"] } tokio = { version = "1", features = [
"fs",
"io-std",
"io-util",
"macros",
"rt-multi-thread",
] }
tower-lsp = "0.20" tower-lsp = "0.20"
tracing = "0.1" tracing = "0.1"
tracing-appender = "0.2" tracing-appender = "0.2"
@ -43,8 +52,4 @@ tree-sitter-typescript = "0.20"
[dependencies.uuid] [dependencies.uuid]
version = "1.4" version = "1.4"
features = [ features = ["v4", "fast-rng", "serde"]
"v4",
"fast-rng",
"serde",
]

View file

@ -663,7 +663,7 @@ impl LanguageServer for Backend {
async fn initialized(&self, _: InitializedParams) { async fn initialized(&self, _: InitializedParams) {
self.client self.client
.log_message(MessageType::INFO, "{llm-ls} initialized") .log_message(MessageType::INFO, "llm-ls initialized")
.await; .await;
info!("initialized language server"); info!("initialized language server");
} }
@ -686,15 +686,15 @@ impl LanguageServer for Backend {
Err(err) => error!("error opening {uri}: {err}"), Err(err) => error!("error opening {uri}: {err}"),
} }
self.client self.client
.log_message(MessageType::INFO, "{llm-ls} file opened") .log_message(MessageType::INFO, format!("{uri} opened"))
.await; .await;
} }
async fn did_change(&self, params: DidChangeTextDocumentParams) { async fn did_change(&self, params: DidChangeTextDocumentParams) {
self.client
.log_message(MessageType::INFO, "{llm-ls} file changed")
.await;
let uri = params.text_document.uri.to_string(); let uri = params.text_document.uri.to_string();
self.client
.log_message(MessageType::INFO, format!("{uri} changed"))
.await;
let mut document_map = self.document_map.write().await; let mut document_map = self.document_map.write().await;
let doc = document_map.get_mut(&uri); let doc = document_map.get_mut(&uri);
if let Some(doc) = doc { if let Some(doc) = doc {
@ -708,20 +708,20 @@ impl LanguageServer for Backend {
} }
async fn did_save(&self, params: DidSaveTextDocumentParams) { async fn did_save(&self, params: DidSaveTextDocumentParams) {
self.client
.log_message(MessageType::INFO, "{llm-ls} file saved")
.await;
let uri = params.text_document.uri.to_string(); let uri = params.text_document.uri.to_string();
self.client
.log_message(MessageType::INFO, format!("{uri} saved"))
.await;
info!("{uri} saved"); info!("{uri} saved");
} }
// TODO: // TODO:
// textDocument/didClose // textDocument/didClose
async fn did_close(&self, params: DidCloseTextDocumentParams) { async fn did_close(&self, params: DidCloseTextDocumentParams) {
self.client
.log_message(MessageType::INFO, "{llm-ls} file closed")
.await;
let uri = params.text_document.uri.to_string(); let uri = params.text_document.uri.to_string();
self.client
.log_message(MessageType::INFO, format!("{uri} closed"))
.await;
info!("{uri} closed"); info!("{uri} closed");
} }

View file

@ -0,0 +1,16 @@
[package]
name = "lsp-client"
version = "0.1.0"
edition.workspace = true
license.workspace = true
authors.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
lsp-types = "0.94"
serde = "1"
serde_json = "1"
tokio = { version = "1", features = ["io-util", "process"] }
tracing = "0.1"

View file

@ -0,0 +1,5 @@
# LSP Client
Rust LSP Client.
Heavily inspired by rust-analyzer's lsp-server implementation.

View file

@ -0,0 +1,130 @@
use std::sync::Arc;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{oneshot, Mutex};
use tokio::task::JoinHandle;
use tracing::{debug, error};
use crate::error::Result;
use crate::msg::{Message, Notification, Request, Response};
use crate::res_queue::ResQueue;
use crate::server::Server;
/// Bidirectional message channel wiring an [`LspClient`] to a server process.
///
/// `sender` carries outgoing messages to the server; `receiver` yields
/// messages coming back from it.
pub struct Connection {
    pub(crate) sender: UnboundedSender<Message>,
    pub(crate) receiver: UnboundedReceiver<Message>,
}
/// Asynchronous LSP client handle.
///
/// Cloning is cheap: all fields are reference-counted and shared.
#[derive(Clone)]
pub struct LspClient {
    // Task draining the connection's receiver and dispatching messages.
    reader_thread: Arc<JoinHandle<()>>,
    // Pending outgoing requests, each awaiting its matching response.
    res_queue: Arc<Mutex<ResQueue<oneshot::Sender<Response>>>>,
    // Handle to the spawned server process' I/O tasks.
    server: Arc<Server>,
    // Sender half used to push messages to the server.
    server_sender: UnboundedSender<Message>,
}
impl LspClient {
    /// Spawns a reader task draining `conn.receiver` and returns a client
    /// wired to `server`.
    pub async fn new(conn: Connection, server: Server) -> Self {
        let res_queue = Arc::new(Mutex::new(ResQueue::default()));
        let res_queue_clone = res_queue.clone();
        let mut rx = conn.receiver;
        let reader_thread = tokio::spawn(async move {
            while let Some(msg) = rx.recv().await {
                match msg {
                    Message::Request(req) => Self::on_request(req),
                    Message::Notification(not) => Self::on_notification(not),
                    Message::Response(res) => {
                        Self::complete_request(res_queue_clone.clone(), res).await
                    }
                }
            }
        });
        Self {
            reader_thread: Arc::new(reader_thread),
            res_queue,
            server: Arc::new(server),
            server_sender: conn.sender,
        }
    }

    // Server-initiated requests are unsupported: reaching this panics via todo!().
    fn on_request(_req: Request) {
        todo!("requests are not handled by client");
    }

    // Notifications from the server are only logged at debug level.
    fn on_notification(not: Notification) {
        debug!("received notification: {not:?}");
    }

    /// Sends request `R` and awaits its response.
    ///
    /// Registers a oneshot channel in the response queue under a fresh request
    /// id, then waits for the reader task to complete it.
    ///
    /// # Errors
    /// Fails if the oneshot channel closes before a response arrives.
    pub async fn send_request<R: lsp_types::request::Request>(
        &self,
        params: R::Params,
    ) -> Result<Response> {
        let (sender, receiver) = oneshot::channel::<Response>();
        let request =
            self.res_queue
                .lock()
                .await
                .outgoing
                .register(R::METHOD.to_string(), params, sender);
        self.send(request.into());
        Ok(receiver.await?)
    }

    // Resolves the pending request matching `response.id`.
    // Panics if the id is unknown or the waiting receiver was dropped.
    async fn complete_request(
        res_queue: Arc<Mutex<ResQueue<oneshot::Sender<Response>>>>,
        response: Response,
    ) {
        let sender = res_queue
            .lock()
            .await
            .outgoing
            .complete(response.id.clone())
            .expect("received response for unknown request");
        sender.send(response).unwrap();
    }

    /// Sends notification `N` to the server (fire and forget).
    pub fn send_notification<N: lsp_types::notification::Notification>(&self, params: N::Params) {
        let not = Notification::new(N::METHOD.to_string(), params);
        self.send(not.into());
    }

    /// Sends the LSP `shutdown` request and awaits its acknowledgement.
    pub async fn shutdown(&self) -> Result<()> {
        self.send_request::<lsp_types::request::Shutdown>(())
            .await?;
        Ok(())
    }

    /// Exit will join on server threads waiting for exit.
    ///
    /// This will fail if there are other strong references to the [`Server`] instance
    pub async fn exit(self) {
        self.send_notification::<lsp_types::notification::Exit>(());
        // Join the reader task; only possible when this is the last clone.
        match Arc::into_inner(self.reader_thread) {
            Some(reader) => match reader.await {
                Ok(r) => r,
                Err(err) => {
                    error!("client reader panicked!");
                    std::panic::panic_any(err)
                }
            },
            None => error!("error joining client thread, resources may have been leaked"),
        };
        // Join the server's I/O tasks, same last-clone requirement.
        match Arc::into_inner(self.server) {
            Some(server) => {
                match server.join().await {
                    Ok(_) => (),
                    Err(err) => error!("thread exited with error: {}", err),
                };
            }
            None => error!("error joining server threads, resources may have been leaked"),
        };
    }

    // Pushes a message onto the server channel; panics if the writer task is gone.
    fn send(&self, message: Message) {
        self.server_sender.send(message).unwrap();
    }
}

View file

@ -0,0 +1,95 @@
use std::fmt;
use std::io;
use tokio::sync::oneshot::error::RecvError;
use crate::msg::ResponseError;
/// Error raised when the LSP wire protocol is violated.
#[derive(Debug, Clone, PartialEq)]
pub struct ProtocolError(pub(crate) String);

impl std::error::Error for ProtocolError {}

impl fmt::Display for ProtocolError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` is exactly what `str`'s Display does, so width/alignment
        // flags are honored just like when delegating to the inner string.
        f.pad(&self.0)
    }
}
/// Failure while converting a raw message into typed params.
#[derive(Debug)]
pub enum ExtractError {
    /// The extracted message was of a different method than expected.
    MethodMismatch(String, String),
    /// Failed to deserialize the message.
    JsonError {
        msg_type: String,
        method: Option<String>,
        error: serde_json::Error,
    },
    /// Server responded with an Error
    ResponseError { error: ResponseError },
}

impl std::error::Error for ExtractError {}

impl fmt::Display for ExtractError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ExtractError::MethodMismatch(asked, req_method) => write!(
                f,
                "Method mismatch for request, extract for '{}' != request method '{}'",
                asked, req_method
            ),
            ExtractError::JsonError {
                msg_type,
                method,
                error,
            } => {
                // Render a missing method as the literal string "None".
                let method = method.as_deref().unwrap_or("None");
                write!(
                    f,
                    "Invalid message body\nMessage type: {msg_type}\nMethod: {method}\n error: {error}"
                )
            }
            ExtractError::ResponseError { error } => {
                write!(f, "Server answered with an error message\n error: {error}")
            }
        }
    }
}
/// Top-level error type for the LSP client crate.
#[derive(Debug)]
pub enum Error {
    // The response channel closed before a reply was received.
    ChannelClosed(RecvError),
    // Underlying I/O failure while talking to the server process.
    Io(io::Error),
    // Neither a binary path nor a command was given to the server builder.
    MissingBinaryPath,
}
impl std::error::Error for Error {}
impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Error::ChannelClosed(e) => write!(f, "Channel closed: {}", e),
            Error::Io(e) => write!(f, "IO error: {}", e),
            Error::MissingBinaryPath => write!(f, "Missing binary path"),
        }
    }
}
impl From<RecvError> for Error {
    fn from(value: RecvError) -> Self {
        Self::ChannelClosed(value)
    }
}
impl From<io::Error> for Error {
    fn from(value: io::Error) -> Self {
        Self::Io(value)
    }
}
/// Convenience alias used throughout the crate.
pub type Result<T> = std::result::Result<T, Error>;

View file

@ -0,0 +1,5 @@
//! Async LSP client library: connection handling, wire-format messages,
//! pending-response bookkeeping and server process management.
pub mod client;
pub mod error;
pub mod msg;
pub mod res_queue;
pub mod server;

View file

@ -0,0 +1,458 @@
use std::{
fmt::{self, Display},
io,
marker::Unpin,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::debug;
use crate::error::ExtractError;
/// Any JSON-RPC message exchanged with the server.
///
/// `#[serde(untagged)]`: the variant is inferred from the fields present on
/// the wire rather than from an explicit tag.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum Message {
    Request(Request),
    Response(Response),
    Notification(Notification),
}
impl From<Request> for Message {
    fn from(request: Request) -> Message {
        Message::Request(request)
    }
}
impl From<Response> for Message {
    fn from(response: Response) -> Message {
        Message::Response(response)
    }
}
impl From<Notification> for Message {
    fn from(notification: Notification) -> Message {
        Message::Notification(notification)
    }
}
/// JSON-RPC request id; either an integer or a string on the wire.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[serde(transparent)]
pub struct RequestId(IdRepr);
// Private representation: untagged so serde picks the matching wire form.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[serde(untagged)]
enum IdRepr {
    I32(i32),
    String(String),
}
impl From<i32> for RequestId {
    fn from(id: i32) -> RequestId {
        RequestId(IdRepr::I32(id))
    }
}
impl From<String> for RequestId {
    fn from(id: String) -> RequestId {
        RequestId(IdRepr::String(id))
    }
}
impl fmt::Display for RequestId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.0 {
            IdRepr::I32(it) => fmt::Display::fmt(it, f),
            // Use debug here, to make it clear that `92` and `"92"` are
            // different, and to reduce WTF factor if the server uses `" "` as
            // an ID.
            IdRepr::String(it) => fmt::Debug::fmt(it, f),
        }
    }
}
/// A JSON-RPC request: expects a matching [`Response`] with the same id.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Request {
    pub id: RequestId,
    pub method: String,
    // Omitted from the wire when null; defaults to null when absent.
    #[serde(default = "serde_json::Value::default")]
    #[serde(skip_serializing_if = "serde_json::Value::is_null")]
    pub params: serde_json::Value,
}
/// Body of a response: exactly one of `result` or `error` on the wire.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ResponseContent {
    Result(serde_json::Value),
    Error(ResponseError),
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Response {
    // JSON RPC allows this to be null if it was impossible
    // to decode the request's id. Ignore this special case
    // and just die horribly.
    pub id: RequestId,
    #[serde(flatten)]
    pub content: ResponseContent,
}
/// Error payload carried by a failed [`Response`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ResponseError {
    pub code: i32,
    pub message: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
impl Display for ResponseError {
    /// Renders as `Name [code]: message`, or a generic form when the numeric
    /// code maps to no known [`ErrorCode`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Ok(code) = ErrorCode::try_from_i32(self.code) {
            write!(f, "{:?} [{}]: {}", code, self.code, self.message)
        } else {
            write!(f, "Unknown error code [{}]: {}", self.code, self.message)
        }
    }
}
/// Well-known JSON-RPC / LSP error codes.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub enum ErrorCode {
    // Defined by JSON RPC:
    ParseError = -32700,
    InvalidRequest = -32600,
    MethodNotFound = -32601,
    InvalidParams = -32602,
    InternalError = -32603,
    ServerErrorStart = -32099,
    ServerErrorEnd = -32000,
    /// Error code indicating that a server received a notification or
    /// request before the server has received the `initialize` request.
    ServerNotInitialized = -32002,
    UnknownErrorCode = -32001,
    // Defined by the protocol:
    /// The client has canceled a request and a server has detected
    /// the cancel.
    RequestCanceled = -32800,
    /// The server detected that the content of a document got
    /// modified outside normal conditions. A server should
    /// NOT send this error code if it detects a content change
    /// in it unprocessed messages. The result even computed
    /// on an older state might still be useful for the client.
    ///
    /// If a client decides that a result is not of any use anymore
    /// the client should cancel the request.
    ContentModified = -32801,
    /// The server cancelled the request. This error code should
    /// only be used for requests that explicitly support being
    /// server cancellable.
    ///
    /// @since 3.17.0
    ServerCancelled = -32802,
    /// A request failed but it was syntactically correct, e.g the
    /// method name was known and the parameters were valid. The error
    /// message should contain human readable information about why
    /// the request failed.
    ///
    /// @since 3.17.0
    RequestFailed = -32803,
}

impl ErrorCode {
    // Maps a raw integer to a known code; `Err(())` for anything unrecognised.
    fn try_from_i32(code: i32) -> Result<ErrorCode, ()> {
        use ErrorCode::*;
        let known = match code {
            -32700 => ParseError,
            -32600 => InvalidRequest,
            -32601 => MethodNotFound,
            -32602 => InvalidParams,
            -32603 => InternalError,
            -32099 => ServerErrorStart,
            -32000 => ServerErrorEnd,
            -32002 => ServerNotInitialized,
            -32001 => UnknownErrorCode,
            -32800 => RequestCanceled,
            -32801 => ContentModified,
            -32802 => ServerCancelled,
            -32803 => RequestFailed,
            _ => return Err(()),
        };
        Ok(known)
    }
}
/// A JSON-RPC notification: carries no id, no response is expected.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Notification {
    pub method: String,
    // Omitted from the wire when null; defaults to null when absent.
    #[serde(default = "serde_json::Value::default")]
    #[serde(skip_serializing_if = "serde_json::Value::is_null")]
    pub params: serde_json::Value,
}
impl Message {
    /// Reads one LSP-framed message from `r`; `Ok(None)` signals a clean EOF.
    pub async fn read<R: AsyncBufRead + Unpin + ?Sized>(r: &mut R) -> io::Result<Option<Message>> {
        Message::_read(r).await
    }
    async fn _read<R: AsyncBufRead + Unpin + ?Sized>(r: &mut R) -> io::Result<Option<Message>> {
        let text = match read_msg_text(r).await? {
            None => return Ok(None),
            Some(text) => text,
        };
        let msg = serde_json::from_str(&text)?;
        Ok(Some(msg))
    }
    /// Serializes `self` (with the `jsonrpc: "2.0"` envelope) and writes it
    /// to `w` using LSP `Content-Length` framing.
    pub async fn write<W: AsyncWrite + Unpin>(self, w: &mut W) -> io::Result<()> {
        self._write(w).await
    }
    async fn _write<W: AsyncWrite + Unpin>(self, w: &mut W) -> io::Result<()> {
        // Wrapper adding the constant `jsonrpc` field next to the flattened message.
        #[derive(Serialize)]
        struct JsonRpc {
            jsonrpc: &'static str,
            #[serde(flatten)]
            msg: Message,
        }
        let text = serde_json::to_string(&JsonRpc {
            jsonrpc: "2.0",
            msg: self,
        })?;
        write_msg_text(w, &text).await
    }
}
impl Response {
    /// Builds a success response; panics if `result` fails to serialize.
    pub fn new_ok<R: Serialize>(id: RequestId, result: R) -> Response {
        Response {
            id,
            content: ResponseContent::Result(serde_json::to_value(result).unwrap()),
        }
    }
    /// Builds an error response with the given JSON-RPC code and message.
    pub fn new_err(id: RequestId, code: i32, message: String) -> Response {
        let error = ResponseError {
            code,
            message,
            data: None,
        };
        Response {
            id,
            content: ResponseContent::Error(error),
        }
    }
    /// Deserializes the `result` payload into `P`.
    ///
    /// # Errors
    /// `JsonError` when the payload does not match `P`; `ResponseError` when
    /// the server answered with an error instead of a result.
    pub fn extract<P: DeserializeOwned>(self) -> Result<(RequestId, P), ExtractError> {
        match self.content {
            ResponseContent::Result(result) => match serde_json::from_value(result) {
                Ok(params) => Ok((self.id, params)),
                Err(error) => Err(ExtractError::JsonError {
                    msg_type: "response".to_owned(),
                    method: None,
                    error,
                }),
            },
            ResponseContent::Error(error) => Err(ExtractError::ResponseError { error }),
        }
    }
}
impl Request {
    /// Builds a request; panics if `params` fails to serialize.
    pub fn new<P: Serialize>(id: RequestId, method: String, params: P) -> Request {
        Request {
            id,
            method,
            params: serde_json::to_value(params).unwrap(),
        }
    }
    /// Deserializes the params into `P` after checking the method matches.
    ///
    /// # Errors
    /// `MethodMismatch` when `method` differs from the request's method;
    /// `JsonError` when the params do not match `P`.
    pub fn extract<P: DeserializeOwned>(
        self,
        method: &str,
    ) -> Result<(RequestId, P), ExtractError> {
        if self.method != method {
            return Err(ExtractError::MethodMismatch(self.method, method.to_owned()));
        }
        match serde_json::from_value(self.params) {
            Ok(params) => Ok((self.id, params)),
            Err(error) => Err(ExtractError::JsonError {
                msg_type: "request".to_owned(),
                method: Some(self.method),
                error,
            }),
        }
    }
}
impl Notification {
    /// Builds a notification; panics if `params` fails to serialize.
    pub fn new(method: String, params: impl Serialize) -> Notification {
        Notification {
            method,
            params: serde_json::to_value(params).unwrap(),
        }
    }
    /// Deserializes the params into `P` after checking the method matches.
    ///
    /// # Errors
    /// `MethodMismatch` when `method` differs; `JsonError` when the params
    /// do not match `P`.
    pub fn extract<P: DeserializeOwned>(self, method: &str) -> Result<P, ExtractError> {
        if self.method != method {
            return Err(ExtractError::MethodMismatch(self.method, method.to_owned()));
        }
        match serde_json::from_value(self.params) {
            Ok(params) => Ok(params),
            Err(error) => Err(ExtractError::JsonError {
                msg_type: "notification".to_owned(),
                method: Some(self.method),
                error,
            }),
        }
    }
    // True for the LSP `exit` notification, which terminates the writer loop.
    pub(crate) fn is_exit(&self) -> bool {
        self.method == "exit"
    }
}
/// Reads one `Content-Length`-framed message body from `inp`.
///
/// Returns `Ok(None)` on EOF before any header line. Malformed headers or a
/// missing `Content-Length` yield an `InvalidData` error.
async fn read_msg_text<R: AsyncBufRead + Unpin + ?Sized>(
    inp: &mut R,
) -> io::Result<Option<String>> {
    fn invalid_data(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
        io::Error::new(io::ErrorKind::InvalidData, error)
    }
    macro_rules! invalid_data {
        ($($tt:tt)*) => (invalid_data(format!($($tt)*)))
    }
    let mut size = None;
    let mut buf = String::new();
    // Header section: `Name: value` lines terminated by \r\n, ended by an
    // empty line. Only Content-Length is interpreted; other headers are ignored.
    loop {
        buf.clear();
        if inp.read_line(&mut buf).await? == 0 {
            return Ok(None);
        }
        if !buf.ends_with("\r\n") {
            return Err(invalid_data!("malformed header: {:?}", buf));
        }
        let buf = &buf[..buf.len() - 2];
        if buf.is_empty() {
            break;
        }
        let mut parts = buf.splitn(2, ": ");
        let header_name = parts.next().unwrap();
        let header_value = parts
            .next()
            .ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?;
        if header_name.eq_ignore_ascii_case("Content-Length") {
            size = Some(header_value.parse::<usize>().map_err(invalid_data)?);
        }
    }
    let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?;
    // Body: exactly `size` bytes, fully overwritten by read_exact, then
    // validated as UTF-8.
    let mut buf = buf.into_bytes();
    buf.resize(size, 0);
    inp.read_exact(&mut buf).await?;
    let buf = String::from_utf8(buf).map_err(invalid_data)?;
    debug!("< {}", buf);
    Ok(Some(buf))
}
/// Writes `msg` to `out` with LSP `Content-Length` framing and flushes.
async fn write_msg_text<W: AsyncWrite + Unpin>(out: &mut W, msg: &str) -> io::Result<()> {
    debug!("> {}", msg);
    let header = format!("Content-Length: {}\r\n\r\n", msg.len());
    out.write_all(header.as_bytes()).await?;
    out.write_all(msg.as_bytes()).await?;
    out.flush().await
}
// Wire-format tests: params/result null handling and envelope shape.
#[cfg(test)]
mod tests {
    use crate::msg::{ResponseContent, ResponseError};
    use super::{Message, Notification, Request, RequestId, Response};
    #[test]
    fn shutdown_with_explicit_null() {
        let text = "{\"jsonrpc\": \"2.0\",\"id\": 3,\"method\": \"shutdown\", \"params\": null }";
        let msg: Message = serde_json::from_str(text).unwrap();
        assert!(
            matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown")
        );
    }
    #[test]
    fn shutdown_with_no_params() {
        let text = "{\"jsonrpc\": \"2.0\",\"id\": 3,\"method\": \"shutdown\"}";
        let msg: Message = serde_json::from_str(text).unwrap();
        assert!(
            matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown")
        );
    }
    #[test]
    fn notification_with_explicit_null() {
        let text = "{\"jsonrpc\": \"2.0\",\"method\": \"exit\", \"params\": null }";
        let msg: Message = serde_json::from_str(text).unwrap();
        assert!(matches!(msg, Message::Notification(not) if not.method == "exit"));
    }
    #[test]
    fn notification_with_no_params() {
        let text = "{\"jsonrpc\": \"2.0\",\"method\": \"exit\"}";
        let msg: Message = serde_json::from_str(text).unwrap();
        assert!(matches!(msg, Message::Notification(not) if not.method == "exit"));
    }
    #[test]
    fn serialize_request_with_null_params() {
        let msg = Message::Request(Request {
            id: RequestId::from(3),
            method: "shutdown".into(),
            params: serde_json::Value::Null,
        });
        let serialized = serde_json::to_string(&msg).unwrap();
        assert_eq!("{\"id\":3,\"method\":\"shutdown\"}", serialized);
    }
    #[test]
    fn serialize_notification_with_null_params() {
        let msg = Message::Notification(Notification {
            method: "exit".into(),
            params: serde_json::Value::Null,
        });
        let serialized = serde_json::to_string(&msg).unwrap();
        assert_eq!("{\"method\":\"exit\"}", serialized);
    }
    #[test]
    fn serialize_response_with_null_result() {
        let text = "{\"id\":1,\"result\":null}";
        let msg = Message::Response(Response {
            id: RequestId::from(1),
            content: ResponseContent::Result(serde_json::Value::Null),
        });
        let serialized = serde_json::to_string(&msg).unwrap();
        assert_eq!(text, serialized);
    }
    #[test]
    fn serialize_response_with_error() {
        let text = "{\"id\":1,\"error\":{\"code\":-32603,\"message\":\"some error message\"}}";
        let msg = Message::Response(Response {
            id: RequestId::from(1),
            content: ResponseContent::Error(ResponseError {
                code: -32603,
                message: "some error message".to_owned(),
                data: None,
            }),
        });
        let serialized = serde_json::to_string(&msg).unwrap();
        assert_eq!(text, serialized);
    }
}

View file

@ -0,0 +1,41 @@
use std::collections::HashMap;
use serde::Serialize;
use crate::msg::{Request, RequestId};
/// Manages the set of pending responses
#[derive(Debug)]
pub struct ResQueue<O> {
    pub outgoing: Outgoing<O>,
}
// Manual impl because `O` itself need not implement `Default`.
impl<O> Default for ResQueue<O> {
    fn default() -> ResQueue<O> {
        ResQueue {
            outgoing: Outgoing {
                next_id: 0,
                pending: HashMap::default(),
            },
        }
    }
}
/// Outgoing requests awaiting a response, keyed by request id.
#[derive(Debug)]
pub struct Outgoing<O> {
    // Monotonically increasing id assigned to the next registered request.
    next_id: i32,
    // Per-request data (e.g. a oneshot sender) resolved on completion.
    pending: HashMap<RequestId, O>,
}
impl<O> Outgoing<O> {
    /// Allocates a fresh id, stores `data` under it and builds the request.
    pub fn register<P: Serialize>(&mut self, method: String, params: P, data: O) -> Request {
        let id = RequestId::from(self.next_id);
        self.pending.insert(id.clone(), data);
        self.next_id += 1;
        Request::new(id, method, params)
    }
    /// Removes and returns the data registered for `id`, if any.
    pub fn complete(&mut self, id: RequestId) -> Option<O> {
        self.pending.remove(&id)
    }
}

View file

@ -0,0 +1,148 @@
use std::{io, path::PathBuf, process::Stdio};
use tokio::{
io::{BufReader, BufWriter},
process::{Child, Command},
sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
task::JoinHandle,
};
use tracing::{debug, error};
use crate::msg::Message;
use crate::{
client::Connection,
error::{Error, Result},
};
/// Handle to a spawned language server process' I/O tasks.
pub struct Server {
    threads: IoThreads,
}
impl Server {
    /// Entry point for configuring and starting a server; see [`ServerBuilder`].
    pub fn build() -> ServerBuilder {
        ServerBuilder {
            binary_path: None,
            command: None,
            transport: Transport::default(),
        }
    }
    /// join server's threads to the current thread
    pub async fn join(self) -> Result<()> {
        self.threads.join().await?;
        Ok(())
    }
}
/// Transport used to communicate with the server process.
#[derive(Default)]
pub enum Transport {
    /// Child process stdin/stdout (the default).
    #[default]
    Stdio,
    /// Socket transport; not implemented yet.
    Socket,
}
/// Builder for [`Server`]: configure a binary path or a full command.
pub struct ServerBuilder {
    // Path to the server binary; used only when no `command` is set.
    binary_path: Option<PathBuf>,
    // Pre-configured command; takes precedence over `binary_path`.
    command: Option<Command>,
    transport: Transport,
}
impl ServerBuilder {
    /// Sets the server binary to spawn.
    pub fn binary_path(mut self, binary_path: PathBuf) -> Self {
        self.binary_path = Some(binary_path);
        self
    }
    /// Sets a fully configured command to spawn (overrides `binary_path`).
    pub fn command(mut self, command: Command) -> Self {
        self.command = Some(command);
        self
    }
    /// Selects the transport; only [`Transport::Stdio`] is implemented.
    pub fn transport(mut self, transport: Transport) -> Self {
        self.transport = transport;
        self
    }
    /// Spawns the server process and returns the connection plus its handle.
    ///
    /// # Errors
    /// `MissingBinaryPath` when neither a command nor a binary path was set;
    /// any I/O error from spawning the child process.
    pub async fn start(self) -> Result<(Connection, Server)> {
        let mut command = if let Some(command) = self.command {
            command
        } else if let Some(path) = self.binary_path {
            Command::new(path)
        } else {
            return Err(Error::MissingBinaryPath);
        };
        match self.transport {
            Transport::Stdio => {
                let child = command
                    .stdin(Stdio::piped())
                    .stdout(Stdio::piped())
                    .spawn()?;
                let (sender, receiver, threads) = stdio(child);
                Ok((Connection { sender, receiver }, Server { threads }))
            }
            Transport::Socket => {
                todo!("socket transport not implemented");
            }
        }
    }
}
/// Wires the child process' stdin/stdout to unbounded message channels.
///
/// Spawns a writer task (channel -> child stdin; stops after forwarding the
/// `exit` notification) and a reader task (child stdout -> channel).
//
// NOTE(review): as rendered here, `child` appears to be captured by both
// spawned `async move` blocks (stdin taken in the writer, stdout in the
// reader), which would not borrow-check — confirm against the real source
// whether stdin/stdout are taken before spawning (lines may have been lost
// in the diff rendering).
fn stdio(
    mut child: Child,
) -> (
    UnboundedSender<Message>,
    UnboundedReceiver<Message>,
    IoThreads,
) {
    let (writer_sender, mut writer_receiver) = unbounded_channel::<Message>();
    let writer = tokio::spawn(async move {
        let stdin = child.stdin.take().unwrap();
        let mut bufr = BufWriter::new(stdin);
        while let Some(it) = writer_receiver.recv().await {
            // Checked before `it` is consumed by `write` below.
            let is_exit = matches!(&it, Message::Notification(n) if n.is_exit());
            debug!("sending message {:#?}", it);
            it.write(&mut bufr).await?;
            if is_exit {
                break;
            }
        }
        Ok(())
    });
    let (reader_sender, reader_receiver) = unbounded_channel::<Message>();
    let reader = tokio::spawn(async move {
        let stdout = child.stdout.take().unwrap();
        let mut reader = BufReader::new(stdout);
        while let Some(msg) = Message::read(&mut reader).await? {
            debug!("received message {:#?}", msg);
            reader_sender
                .send(msg)
                .expect("receiver was dropped, failed to send a message");
        }
        Ok(())
    });
    let threads = IoThreads { reader, writer };
    (writer_sender, reader_receiver, threads)
}
/// Join handles for the reader/writer tasks servicing the child process.
pub struct IoThreads {
    reader: JoinHandle<io::Result<()>>,
    writer: JoinHandle<io::Result<()>>,
}
impl IoThreads {
    /// Awaits both tasks, logging (not propagating) their internal I/O errors.
    ///
    /// # Errors
    /// Fails only when a task itself panicked (join error).
    pub async fn join(self) -> io::Result<()> {
        match self.reader.await? {
            Ok(_) => (),
            Err(err) => {
                error!("reader err: {err}");
            }
        }
        match self.writer.await? {
            Ok(_) => (),
            Err(err) => {
                error!("writer err: {err}");
            }
        }
        Ok(())
    }
}

35
crates/testbed/Cargo.toml Normal file
View file

@ -0,0 +1,35 @@
[package]
name = "testbed"
version = "0.1.0"
resolver = "2"
edition.workspace = true
license.workspace = true
authors.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1"
clap = { version = "4", features = ["derive"] }
futures = "0.3"
futures-util = "0.3"
home = "0.5"
lsp-client = { path = "../lsp-client" }
lsp-types = "0.94"
rand = "0.8"
reqwest = { version = "0.11", features = ["stream"] }
ropey = "1.6"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
serde_yaml = "0.9"
tempfile = "3"
tokio = "1"
tokio-util = { version = "0.7", features = ["compat"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
url = "2"
zip = "0.6"
[dependencies.uuid]
version = "1.5"
features = ["v4", "fast-rng", "serde"]

80
crates/testbed/README.md Normal file
View file

@ -0,0 +1,80 @@
# testbed
testbed is a framework to evaluate the efficiency of the completions generated by llm-ls and the underlying model.
It works by first making holes in files, then generating completions for a given list of repositories, and finally running the associated unit tests.
The result is a table containing a line for each repository and the total with the average percentage of successful unit tests.
Here is a simplified pseudo code algorithm for testbed:
```
read the repositories file
read the holes file(s)
for each repository
spawn a thread
setup the repository
for each hole
make the hole as specified by the file
generate completions
build the code
run the tests
print results
```
## Running testbed
Before running testbed you will need to create a repositories file. It is a YAML file containing a list of repositories to test.
It also contains the parameters to the `llm-ls/getCompletions` request.
Repositories can be sourced either from your local storage or from GitHub.
You can check the repositories files at the root of the crate to see the full structure.
### Generating holes
Before running testbed, you will need to generate a holes file for each repository. To generate a holes file run testbed with the `-g` option. You can specify the number of holes to make with `-n <number>`. It will take the list of repositories in your YAML file and create the associated files at the defined path.
### Setup
testbed runs completions for each repository in parallel. It will first create a temporary directory, then copy or download the repository's source files to that location and finally run the setup commands.
Setup commands are useful to install dependencies.
```yaml
setup_commands:
- ["python3", ["-m", "venv", "huggingface_hub-venv"]]
- ["huggingface_hub-venv/bin/python3", ["-m", "pip", "install", ".[dev]"]]
```
### Build
Before running the tests, testbed will run a build command to check if the code is valid.
To configure the commands, you can do the following:
```yaml
build_command: huggingface_hub-venv/bin/python3
build_args: ["-m", "compileall", "-q", "."]
```
### Runners
testbed supports two test runners at the moment:
- cargo
- pytest
To configure your runner, you have the following options:
```yaml
runner: pytest
runner_command: huggingface_hub-venv/bin/python3
runner_extra_args:
- "-k"
- "_utils_ and not _utils_cache and not _utils_http and not paginate and not git"
```
You can override the runner's command with `runner_command`, which is useful when setting up dependencies in a venv.
## References
testbed was inspired by [human-eval](https://github.com/openai/human-eval) and [RepoEval](https://arxiv.org/abs/2303.12570).

View file

@ -0,0 +1 @@
[{"cursor":{"line":38,"character":6},"file":"src/lib.rs"},{"cursor":{"line":568,"character":8},"file":"src/lib.rs"},{"cursor":{"line":271,"character":11},"file":"src/lib.rs"},{"cursor":{"line":887,"character":14},"file":"src/lib.rs"},{"cursor":{"line":697,"character":12},"file":"src/lib.rs"},{"cursor":{"line":707,"character":6},"file":"src/lib.rs"},{"cursor":{"line":240,"character":10},"file":"src/lib.rs"},{"cursor":{"line":387,"character":13},"file":"src/lib.rs"},{"cursor":{"line":501,"character":8},"file":"src/lib.rs"},{"cursor":{"line":178,"character":6},"file":"src/lib.rs"}]

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1 @@
[{"cursor":{"line":115,"character":0},"file":"src/stores/mod.rs"},{"cursor":{"line":136,"character":1},"file":"src/stores/expiring_value_cache.rs"},{"cursor":{"line":92,"character":4},"file":"src/lru_list.rs"},{"cursor":{"line":280,"character":14},"file":"src/stores/timed.rs"},{"cursor":{"line":294,"character":14},"file":"src/stores/unbound.rs"},{"cursor":{"line":267,"character":1},"file":"src/stores/unbound.rs"},{"cursor":{"line":142,"character":0},"file":"src/stores/expiring_value_cache.rs"},{"cursor":{"line":370,"character":12},"file":"src/stores/sized.rs"},{"cursor":{"line":280,"character":10},"file":"src/proc_macro.rs"},{"cursor":{"line":177,"character":5},"file":"src/stores/timed_sized.rs"}]

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1 @@
[{"cursor":{"line":12,"character":3},"file":"src/lib.rs"},{"cursor":{"line":12,"character":11},"file":"src/lib.rs"},{"cursor":{"line":12,"character":5},"file":"src/lib.rs"},{"cursor":{"line":0,"character":4},"file":"src/lib.rs"},{"cursor":{"line":12,"character":3},"file":"src/lib.rs"},{"cursor":{"line":12,"character":6},"file":"src/lib.rs"},{"cursor":{"line":0,"character":10},"file":"src/lib.rs"},{"cursor":{"line":0,"character":3},"file":"src/lib.rs"},{"cursor":{"line":12,"character":4},"file":"src/lib.rs"},{"cursor":{"line":0,"character":10},"file":"src/lib.rs"}]

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1 @@
[{"cursor":{"line":1300,"character":8},"file":"fastapi/param_functions.py"},{"cursor":{"line":0,"character":12},"file":"fastapi/middleware/trustedhost.py"},{"cursor":{"line":1,"character":5},"file":"fastapi/middleware/httpsredirect.py"},{"cursor":{"line":177,"character":14},"file":"fastapi/openapi/docs.py"},{"cursor":{"line":37,"character":5},"file":"fastapi/security/http.py"},{"cursor":{"line":0,"character":7},"file":"fastapi/testclient.py"},{"cursor":{"line":647,"character":6},"file":"fastapi/param_functions.py"},{"cursor":{"line":419,"character":0},"file":"fastapi/_compat.py"},{"cursor":{"line":100,"character":1},"file":"fastapi/exceptions.py"},{"cursor":{"line":107,"character":14},"file":"fastapi/security/oauth2.py"}]

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1 @@
[{"cursor":{"line":73,"character":10},"file":"helix-core/src/chars.rs"},{"cursor":{"line":257,"character":11},"file":"helix-dap/src/types.rs"},{"cursor":{"line":39,"character":14},"file":"helix-view/src/info.rs"},{"cursor":{"line":116,"character":12},"file":"helix-term/src/ui/mod.rs"},{"cursor":{"line":1,"character":14},"file":"helix-term/src/ui/text.rs"},{"cursor":{"line":2,"character":5},"file":"helix-core/src/config.rs"},{"cursor":{"line":151,"character":14},"file":"helix-view/src/gutter.rs"},{"cursor":{"line":11,"character":10},"file":"helix-term/src/ui/lsp.rs"},{"cursor":{"line":18,"character":0},"file":"helix-term/src/ui/text.rs"},{"cursor":{"line":230,"character":3},"file":"helix-term/src/ui/markdown.rs"}]

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1 @@
[{"cursor":{"line":485,"character":14},"file":"src/huggingface_hub/__init__.py"},{"cursor":{"line":134,"character":9},"file":"src/huggingface_hub/_space_api.py"},{"cursor":{"line":122,"character":3},"file":"src/huggingface_hub/inference_api.py"},{"cursor":{"line":1485,"character":0},"file":"src/huggingface_hub/file_download.py"},{"cursor":{"line":184,"character":4},"file":"src/huggingface_hub/inference/_common.py"},{"cursor":{"line":157,"character":13},"file":"src/huggingface_hub/utils/_validators.py"},{"cursor":{"line":32,"character":10},"file":"src/huggingface_hub/utils/_paths.py"},{"cursor":{"line":1596,"character":11},"file":"src/huggingface_hub/file_download.py"},{"cursor":{"line":20,"character":4},"file":"src/huggingface_hub/commands/__init__.py"},{"cursor":{"line":47,"character":8},"file":"src/huggingface_hub/commands/user.py"}]

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1 @@
[{"cursor":{"line":54,"character":12},"file":"src/Type.ts"},{"cursor":{"line":108,"character":1},"file":"src/Eq.ts"},{"cursor":{"line":6,"character":4},"file":"src/Reporter.ts"},{"cursor":{"line":98,"character":7},"file":"src/Type.ts"},{"cursor":{"line":88,"character":1},"file":"src/Type.ts"},{"cursor":{"line":140,"character":11},"file":"src/Eq.ts"},{"cursor":{"line":52,"character":2},"file":"src/Kleisli.ts"},{"cursor":{"line":47,"character":2},"file":"src/PathReporter.ts"},{"cursor":{"line":109,"character":11},"file":"src/Schemable.ts"},{"cursor":{"line":484,"character":1},"file":"src/TaskDecoder.ts"}]

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1 @@
[{"cursor":{"line":136,"character":0},"file":"rust/lance/src/utils/sql.rs"},{"cursor":{"line":121,"character":7},"file":"rust/lance-index/src/vector/utils.rs"},{"cursor":{"line":182,"character":1},"file":"rust/lance-linalg/src/simd/i32.rs"},{"cursor":{"line":527,"character":2},"file":"rust/lance-core/src/io/writer.rs"},{"cursor":{"line":123,"character":0},"file":"rust/lance/benches/scan.rs"},{"cursor":{"line":14,"character":1},"file":"rust/lance-core/src/utils.rs"},{"cursor":{"line":211,"character":2},"file":"rust/lance-core/src/encodings/dictionary.rs"},{"cursor":{"line":508,"character":8},"file":"rust/lance-core/src/datatypes/field.rs"},{"cursor":{"line":624,"character":7},"file":"rust/lance/src/dataset/cleanup.rs"},{"cursor":{"line":396,"character":5},"file":"rust/lance/src/index/vector/opq.rs"}]

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1 @@
[{"cursor":{"line":0,"character":6},"file":"rust/vectordb/src/io.rs"},{"cursor":{"line":65,"character":4},"file":"rust/vectordb/src/utils.rs"},{"cursor":{"line":138,"character":7},"file":"rust/vectordb/src/data/inspect.rs"},{"cursor":{"line":25,"character":3},"file":"rust/vectordb/src/error.rs"},{"cursor":{"line":479,"character":3},"file":"rust/vectordb/src/table.rs"},{"cursor":{"line":133,"character":10},"file":"rust/vectordb/src/data/sanitize.rs"},{"cursor":{"line":54,"character":1},"file":"rust/vectordb/src/utils.rs"},{"cursor":{"line":138,"character":9},"file":"rust/vectordb/src/table.rs"},{"cursor":{"line":34,"character":10},"file":"rust/vectordb/src/database.rs"},{"cursor":{"line":159,"character":1},"file":"rust/vectordb/src/data/inspect.rs"}]

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1 @@
[{"cursor":{"line":374,"character":3},"file":"src/picklescan/scanner.py"},{"cursor":{"line":7,"character":4},"file":"src/picklescan/torch.py"},{"cursor":{"line":0,"character":5},"file":"src/picklescan/__main__.py"},{"cursor":{"line":0,"character":3},"file":"src/picklescan/torch.py"},{"cursor":{"line":351,"character":10},"file":"src/picklescan/scanner.py"},{"cursor":{"line":2,"character":6},"file":"src/picklescan/__main__.py"},{"cursor":{"line":192,"character":10},"file":"src/picklescan/scanner.py"},{"cursor":{"line":0,"character":9},"file":"src/picklescan/__init__.py"},{"cursor":{"line":2,"character":2},"file":"src/picklescan/__main__.py"},{"cursor":{"line":33,"character":3},"file":"src/picklescan/cli.py"}]

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1 @@
[{"cursor":{"line":3,"character":9},"file":"starlette/_utils.py"},{"cursor":{"line":0,"character":10},"file":"starlette/__init__.py"},{"cursor":{"line":38,"character":2},"file":"starlette/config.py"},{"cursor":{"line":23,"character":14},"file":"starlette/_utils.py"},{"cursor":{"line":393,"character":5},"file":"starlette/datastructures.py"},{"cursor":{"line":114,"character":2},"file":"starlette/testclient.py"},{"cursor":{"line":187,"character":13},"file":"starlette/templating.py"},{"cursor":{"line":79,"character":3},"file":"starlette/status.py"},{"cursor":{"line":129,"character":3},"file":"starlette/middleware/cors.py"},{"cursor":{"line":22,"character":6},"file":"starlette/middleware/sessions.py"}]

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1 @@
[{"cursor":{"line":165,"character":10},"file":"src/helpers/util.ts"},{"cursor":{"line":147,"character":7},"file":"src/ZodError.ts"},{"cursor":{"line":0,"character":14},"file":"src/ZodError.ts"},{"cursor":{"line":166,"character":9},"file":"src/helpers/util.ts"},{"cursor":{"line":5,"character":5},"file":"src/helpers/enumUtil.ts"},{"cursor":{"line":10,"character":2},"file":"src/errors.ts"},{"cursor":{"line":134,"character":12},"file":"src/benchmarks/primitives.ts"},{"cursor":{"line":2,"character":10},"file":"src/benchmarks/index.ts"},{"cursor":{"line":33,"character":5},"file":"src/benchmarks/realworld.ts"},{"cursor":{"line":12,"character":13},"file":"src/benchmarks/index.ts"}]

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,247 @@
---
context_window: 2000
fim:
enabled: true
prefix: <fim_prefix>
middle: <fim_middle>
suffix: <fim_suffix>
model: bigcode/starcoder
request_params:
max_new_tokens: 150
temperature: 0.2
do_sample: true
top_p: 0.95
tls_skip_verify_insecure: false
tokenizer_config:
repository: bigcode/starcoder
tokens_to_clear: ["<|endoftext|>"]
repositories:
- source:
type: local
path: simple
src_path: src
build_command: cargo
build_args: ["build"]
language: rust
runner: cargo
holes_file: simple.json
- source:
type: github
owner: mmaitre314
name: picklescan
revision: 40001cd1caa9e041b1bce1b80f3707056cd8be52
src_path: src/picklescan
build_command: picklescan-venv/bin/python3
build_args: ["-m", "compileall", "-q", "."]
language: python
runner: pytest
runner_command: picklescan-venv/bin/python3
setup_commands:
- ["python3", ["-m", "venv", "picklescan-venv"]]
- ["picklescan-venv/bin/python3", ["-m", "pip", "install", "."]]
- ["picklescan-venv/bin/python3", ["-m", "pip", "install", "-r", "requirements.txt"]]
holes_file: picklescan-smol.json
- source:
type: github
owner: huggingface
name: huggingface_hub
revision: a48eb89d4186bc84bca67b117cf29a0ee0b69774
src_path: src/huggingface_hub
build_command: huggingface_hub-venv/bin/python3
build_args: ["-m", "compileall", "-q", "."]
language: python
runner: pytest
runner_command: huggingface_hub-venv/bin/python3
runner_extra_args:
- "-k"
- "_utils_ and not _utils_cache and not _utils_http and not paginate and not git"
setup_commands:
- ["python3", ["-m", "venv", "huggingface_hub-venv"]]
- ["huggingface_hub-venv/bin/python3", ["-m", "pip", "install", ".[dev]"]]
holes_file: huggingface_hub-smol.json
- source:
type: github
owner: tiangolo
name: fastapi
revision: e4b21c6eab7cd58caf3c6c492ea1ce7945425dd1
src_path: fastapi
build_command: fastapi-venv/bin/python3
build_args: ["-m", "compileall", "-q", "."]
language: python
runner: pytest
runner_command: fastapi-venv/bin/python3
setup_commands:
- ["python3", ["-m", "venv", "fastapi-venv"]]
- ["fastapi-venv/bin/python3", ["-m", "pip", "install", "--upgrade", "pip"]]
- ["fastapi-venv/bin/python3", ["-m", "pip", "install", "-r", "requirements-tests.txt"]]
- ["fastapi-venv/bin/python3", ["-m", "pip", "install", "pydantic"]]
holes_file: fastapi-smol.json
- source:
type: github
owner: encode
name: starlette
revision: 657e7e7b728e13dc66cc3f77dffd00a42545e171
src_path: starlette
build_command: starlette-venv/bin/python3
build_args: ["-m", "compileall", "-q", "."]
language: python
runner: pytest
runner_command: starlette-venv/bin/python3
setup_commands:
- ["python3", ["-m", "venv", "starlette-venv"]]
- ["starlette-venv/bin/python3", ["-m", "pip", "install", "--upgrade", "pip"]]
- ["starlette-venv/bin/python3", ["-m", "pip", "install", "-r", "requirements.txt"]]
holes_file: starlette-smol.json
- source:
type: github
owner: lancedb
name: lancedb
revision: 682e95fa8388d5839c8a782063beb307c4fca4bc
src_path: rust/vectordb/src
build_command: cargo
build_args: ["build"]
language: rust
runner: cargo
# this is to avoid skewing the average hole completion time
setup_commands:
- ["cargo", ["build"]]
holes_file: lancedb-smol.json
- source:
type: github
owner: lancedb
name: lance
revision: c8ee16ec31eeca884c78bd4a400404aaa994ed46
src_path: rust
exclude_paths:
- .cargo
- .vscode
- .gitignore
- README.md
- img.png
build_command: cargo
build_args: ["build", "--all-features", "--manifest-path", "rust/Cargo.toml"]
language: rust
runner: cargo
runner_extra_args: ["--all-features", "--manifest-path", "rust/Cargo.toml"]
setup_commands:
- ["rm", ["rust/lance-core/protos", "rust/lance-index/protos"]]
- ["ln", ["-s", "../../protos", "rust/lance-core/protos"]]
- ["ln", ["-s", "../../protos", "rust/lance-index/protos"]]
- ["cargo", ["build", "--all-features", "--manifest-path", "rust/Cargo.toml"]]
holes_file: lance-smol.json
- source:
type: github
owner: tkaitchuck
name: constrandom
revision: e9f560ba14e09ff9db9caca3f2dfa3ff52cc96de
src_path: src
build_command: cargo
build_args: ["build"]
language: rust
runner: cargo
setup_commands:
- ["cargo", ["build"]]
holes_file: constrandom-smol.json
- source:
type: github
owner: jaemk
name: cached
revision: b1015561fb121c3e5698183c39df202e5a83994a
src_path: src
build_command: cargo
build_args: ["build"]
language: rust
runner: cargo
setup_commands:
- ["cargo", ["build"]]
holes_file: cached-smol.json
- source:
type: github
owner: smol-rs
name: async-executor
revision: b91875e73bd9aec582e099d8c792514381fc8d0f
src_path: src
build_command: cargo
build_args: ["build"]
language: rust
runner: cargo
setup_commands:
- ["cargo", ["build"]]
holes_file: async-executor-smol.json
- source:
type: github
owner: gcanti
name: io-ts
revision: 616583de0198632cad7820ed8701b15f654c7fd2
src_path: src
build_command: npm
build_args: ["run", "build"]
language: typescript
runner: vitest
setup_commands:
- ["npm", ["install"]]
holes_file: io-ts-smol.json
- source:
type: github
owner: colinhacks
name: zod
revision: 481c9ba1932203777f6fe9497bb2a8a1d33c620e
src_path: src
exclude_paths: ["src/__tests__"]
build_command: yarn
build_args: ["build"]
language: typescript
runner: jest
runner_command: yarn
runner_args: ["test", "--", "--no-colors"]
setup_commands:
- ["yarn", ["install"]]
holes_file: zod-smol.json
- source:
type: github
owner: helix-editor
name: helix
revision: ae6a0a9cfd377fbfa494760282498cf2ca322782
exclude_paths:
- .cargo
- .github
- book
- contrib
- docs
- helix-core/tests
- helix-term/tests
- helix-tui/tests
- helix-view/tests
- runtime
- xtask
- .envrc
- .gitattributes
- .gitignore
- .ignore
- CHANGELOG.md
- Cargo.lock
- LICENSE
- README.md
- VERSION
- base16_theme.toml
- default.nix
- flake.lock
- flake.nix
- grammars.nix
- languages.toml
- logo.svg
- logo_dark.svg
- logo_light.svg
- rust-toolchain.toml
- rustfmt.toml
- screenshot.png
- shell.nix
- theme.toml
build_command: cargo
build_args: ["build"]
language: rust
runner: cargo
runner_extra_args: ["--workspace"]
setup_commands:
- ["cargo", ["build"]]
holes_file: helix-smol.json

View file

@ -0,0 +1,247 @@
---
context_window: 2000
fim:
enabled: true
prefix: <fim_prefix>
middle: <fim_middle>
suffix: <fim_suffix>
model: bigcode/starcoder
request_params:
max_new_tokens: 150
temperature: 0.2
do_sample: true
top_p: 0.95
tls_skip_verify_insecure: false
tokenizer_config:
repository: bigcode/starcoder
tokens_to_clear: ["<|endoftext|>"]
repositories:
- source:
type: local
path: simple
src_path: src
build_command: cargo
build_args: ["build"]
language: rust
runner: cargo
holes_file: simple.json
- source:
type: github
owner: mmaitre314
name: picklescan
revision: 40001cd1caa9e041b1bce1b80f3707056cd8be52
src_path: src/picklescan
build_command: picklescan-venv/bin/python3
build_args: ["-m", "compileall", "-q", "."]
language: python
runner: pytest
runner_command: picklescan-venv/bin/python3
setup_commands:
- ["python3", ["-m", "venv", "picklescan-venv"]]
- ["picklescan-venv/bin/python3", ["-m", "pip", "install", "."]]
- ["picklescan-venv/bin/python3", ["-m", "pip", "install", "-r", "requirements.txt"]]
holes_file: picklescan.json
- source:
type: github
owner: huggingface
name: huggingface_hub
revision: a48eb89d4186bc84bca67b117cf29a0ee0b69774
src_path: src/huggingface_hub
build_command: huggingface_hub-venv/bin/python3
build_args: ["-m", "compileall", "-q", "."]
language: python
runner: pytest
runner_command: huggingface_hub-venv/bin/python3
runner_extra_args:
- "-k"
- "_utils_ and not _utils_cache and not _utils_http and not paginate and not git"
setup_commands:
- ["python3", ["-m", "venv", "huggingface_hub-venv"]]
- ["huggingface_hub-venv/bin/python3", ["-m", "pip", "install", ".[dev]"]]
holes_file: huggingface_hub.json
- source:
type: github
owner: tiangolo
name: fastapi
revision: e4b21c6eab7cd58caf3c6c492ea1ce7945425dd1
src_path: fastapi
build_command: fastapi-venv/bin/python3
build_args: ["-m", "compileall", "-q", "."]
language: python
runner: pytest
runner_command: fastapi-venv/bin/python3
setup_commands:
- ["python3", ["-m", "venv", "fastapi-venv"]]
- ["fastapi-venv/bin/python3", ["-m", "pip", "install", "--upgrade", "pip"]]
- ["fastapi-venv/bin/python3", ["-m", "pip", "install", "-r", "requirements-tests.txt"]]
- ["fastapi-venv/bin/python3", ["-m", "pip", "install", "pydantic"]]
holes_file: fastapi.json
- source:
type: github
owner: encode
name: starlette
revision: 657e7e7b728e13dc66cc3f77dffd00a42545e171
src_path: starlette
build_command: starlette-venv/bin/python3
build_args: ["-m", "compileall", "-q", "."]
language: python
runner: pytest
runner_command: starlette-venv/bin/python3
setup_commands:
- ["python3", ["-m", "venv", "starlette-venv"]]
- ["starlette-venv/bin/python3", ["-m", "pip", "install", "--upgrade", "pip"]]
- ["starlette-venv/bin/python3", ["-m", "pip", "install", "-r", "requirements.txt"]]
holes_file: starlette.json
- source:
type: github
owner: lancedb
name: lancedb
revision: 682e95fa8388d5839c8a782063beb307c4fca4bc
src_path: rust/vectordb/src
build_command: cargo
build_args: ["build"]
language: rust
runner: cargo
# this is to avoid skewing the average hole completion time
setup_commands:
- ["cargo", ["build"]]
holes_file: lancedb.json
- source:
type: github
owner: lancedb
name: lance
revision: 4f7fb490c6b36a659b2c40b3945b4d533356acf2
src_path: rust
exclude_paths:
- .cargo
- .vscode
- .gitignore
- README.md
- img.png
build_command: cargo
build_args: ["build", "--all-features", "--manifest-path", "rust/Cargo.toml"]
language: rust
runner: cargo
runner_extra_args: ["--all-features", "--manifest-path", "rust/Cargo.toml"]
setup_commands:
- ["rm", ["rust/lance-core/protos", "rust/lance-index/protos"]]
- ["ln", ["-s", "../../protos", "rust/lance-core/protos"]]
- ["ln", ["-s", "../../protos", "rust/lance-index/protos"]]
- ["cargo", ["build", "--all-features", "--manifest-path", "rust/Cargo.toml"]]
holes_file: lance.json
- source:
type: github
owner: tkaitchuck
name: constrandom
revision: e9f560ba14e09ff9db9caca3f2dfa3ff52cc96de
src_path: src
build_command: cargo
build_args: ["build"]
language: rust
runner: cargo
setup_commands:
- ["cargo", ["build"]]
holes_file: constrandom.json
- source:
type: github
owner: jaemk
name: cached
revision: b1015561fb121c3e5698183c39df202e5a83994a
src_path: src
build_command: cargo
build_args: ["build"]
language: rust
runner: cargo
setup_commands:
- ["cargo", ["build"]]
holes_file: cached.json
- source:
type: github
owner: smol-rs
name: async-executor
revision: b91875e73bd9aec582e099d8c792514381fc8d0f
src_path: src
build_command: cargo
build_args: ["build"]
language: rust
runner: cargo
setup_commands:
- ["cargo", ["build"]]
holes_file: async-executor.json
- source:
type: github
owner: gcanti
name: io-ts
revision: 616583de0198632cad7820ed8701b15f654c7fd2
src_path: src
build_command: npm
build_args: ["run", "build"]
language: typescript
runner: vitest
setup_commands:
- ["npm", ["install"]]
holes_file: io-ts.json
- source:
type: github
owner: colinhacks
name: zod
revision: 481c9ba1932203777f6fe9497bb2a8a1d33c620e
src_path: src
exclude_paths: ["src/__tests__"]
build_command: yarn
build_args: ["build"]
language: typescript
runner: jest
runner_command: yarn
runner_args: ["test", "--", "--no-colors"]
setup_commands:
- ["yarn", ["install"]]
holes_file: zod.json
- source:
type: github
owner: helix-editor
name: helix
revision: ae6a0a9cfd377fbfa494760282498cf2ca322782
exclude_paths:
- .cargo
- .github
- book
- contrib
- docs
- helix-core/tests
- helix-term/tests
- helix-tui/tests
- helix-view/tests
- runtime
- xtask
- .envrc
- .gitattributes
- .gitignore
- .ignore
- CHANGELOG.md
- Cargo.lock
- LICENSE
- README.md
- VERSION
- base16_theme.toml
- default.nix
- flake.lock
- flake.nix
- grammars.nix
- languages.toml
- logo.svg
- logo_dark.svg
- logo_light.svg
- rust-toolchain.toml
- rustfmt.toml
- screenshot.png
- shell.nix
- theme.toml
build_command: cargo
build_args: ["build"]
language: rust
runner: cargo
runner_extra_args: ["--workspace"]
setup_commands:
- ["cargo", ["build"]]
holes_file: helix.json

View file

@ -0,0 +1,7 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "simple"
version = "0.1.0"

View file

@ -0,0 +1,10 @@
[workspace]
[package]
name = "simple"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]

View file

@ -0,0 +1,46 @@
/// Returns the sum of the two operands.
fn sum(lhs: i32, rhs: i32) -> i32 {
    rhs + lhs
}
/// Returns `lhs` minus `rhs`.
///
/// Overflows (panicking in debug builds) when the result is outside `i32` range.
fn sub(lhs: i32, rhs: i32) -> i32 {
    lhs - rhs
}
/// Returns the product of the two operands.
fn mul(lhs: i32, rhs: i32) -> i32 {
    rhs * lhs
}
/// Returns `lhs` divided by `rhs` (integer division, truncating toward zero).
///
/// Panics when `rhs` is 0, and on overflow for `i32::MIN / -1`.
fn div(lhs: i32, rhs: i32) -> i32 {
    lhs / rhs
}
/// Entry point: exercises each arithmetic helper once and prints the results.
fn main() {
    let sum_result = sum(42, 42);
    println!("42 + 42 = {}", sum_result);
    let sub_result = sub(41, 42);
    println!("41 - 42 = {}", sub_result);
    let mul_result = mul(42, 42);
    println!("42 * 42 = {}", mul_result);
    let div_result = div(42, 42);
    println!("42 / 42 = {}", div_result);
}
#[cfg(test)]
mod tests {
    use super::{div, mul, sub, sum};

    #[test]
    fn test_sum() {
        assert_eq!(sum(42, 42), 42 + 42);
    }

    #[test]
    fn test_sub() {
        assert_eq!(sub(42, 42), 42 - 42);
        assert_eq!(sub(41, 42), 41 - 42);
    }

    #[test]
    fn test_mul() {
        assert_eq!(mul(42, 42), 42 * 42);
    }

    #[test]
    fn test_div() {
        assert_eq!(div(42, 42), 42 / 42);
    }
}

View file

@ -0,0 +1,113 @@
use std::{
collections::VecDeque,
path::{Path, PathBuf},
};
use anyhow::anyhow;
use rand::{seq::SliceRandom, Rng};
use ropey::Rope;
use tokio::{
fs::{self, OpenOptions},
io::{AsyncReadExt, AsyncWriteExt},
};
use tracing::info;
use crate::{setup_repo_dir, Hole, RepositoriesConfig};
/// Returns `Ok(true)` when the file at `file_path` is empty or contains only whitespace.
///
/// Reads the whole file into memory; this is only called on individual source
/// files while walking a repository, which are expected to be small.
async fn file_is_empty(file_path: impl AsRef<Path>) -> anyhow::Result<bool> {
    // fs::read_to_string replaces the manual File::open + read_to_string dance
    // with a single call; the error type (io::Error -> anyhow) is unchanged.
    let content = fs::read_to_string(file_path).await?;
    Ok(content.trim().is_empty())
}
/// Generates `holes_per_repo` random completion holes for each configured repository
/// and writes them as JSON to `holes_dir_path/<repo.holes_file>`.
///
/// A hole is a (line, column) cursor position inside a non-empty, non-comment line
/// of a source file. When `filter_repos` is true, only repositories whose name is
/// in `filter_list` are processed.
///
/// # Errors
///
/// Fails on I/O errors, when a repository yields no candidate files, or when no
/// eligible line can be found after a bounded number of sampling attempts.
pub(crate) async fn generate_holes(
    repositories_config: RepositoriesConfig,
    repos_dir_path: &Path,
    holes_dir_path: &Path,
    holes_per_repo: usize,
    filter_repos: bool,
    filter_list: Vec<String>,
) -> anyhow::Result<()> {
    let mut rng = rand::thread_rng();
    for repo in repositories_config.repositories {
        if filter_repos && !filter_list.contains(&repo.name()) {
            continue;
        }
        let repo_name = repo.name();
        info!("creating {} holes for {}", holes_per_repo, repo_name);
        let (_tmp_dir, path) = setup_repo_dir(repos_dir_path, &repo.source).await?;

        // Walk the source tree iteratively and gather every non-empty code file,
        // skipping anything under an excluded path.
        let mut files = vec![];
        let mut stack = VecDeque::new();
        let exclude_paths = repo
            .source
            .exclude_paths()
            .iter()
            .map(|p| path.join(p))
            .collect::<Vec<PathBuf>>();
        stack.push_back(path.join(repo.source.src_path()));
        while let Some(src) = stack.pop_back() {
            let mut entries = fs::read_dir(&src).await?;
            while let Some(entry) = entries.next_entry().await? {
                let entry_type = entry.file_type().await?;
                let src_path = entry.path();
                if exclude_paths.iter().any(|p| src_path.starts_with(p)) {
                    continue;
                }
                if entry_type.is_dir() {
                    stack.push_back(src_path);
                } else if entry_type.is_file()
                    && repo
                        .language
                        .is_code_file(src_path.file_name().unwrap().to_str().unwrap())
                    && !file_is_empty(&src_path).await?
                {
                    files.push(src_path);
                }
            }
        }

        // Sample holes by rejection: draw a random file and line, and retry when
        // the line is blank or a comment. The attempt cap guards against an
        // infinite loop when a repository has no eligible line at all.
        let mut holes = vec![];
        let mut i = 0;
        let max_attempts = holes_per_repo.saturating_mul(1_000);
        let mut attempts = 0usize;
        while i < holes_per_repo {
            attempts += 1;
            if attempts > max_attempts {
                return Err(anyhow!(
                    "could not find {} eligible holes for {} after {} attempts",
                    holes_per_repo,
                    repo_name,
                    max_attempts
                ));
            }
            let file_path = files
                .choose(&mut rng)
                .ok_or(anyhow!("files vec is empty"))?;
            let mut content = String::new();
            fs::File::open(&file_path)
                .await?
                .read_to_string(&mut content)
                .await?;
            let rope = Rope::from_str(&content);
            let line_nb = rng.gen_range(0..rope.len_lines());
            let line = rope.line(line_nb);
            let line_string = line.to_string();
            let trimmed = line_string.trim();
            if trimmed.starts_with(repo.language.comment_token()) || trimmed.is_empty() {
                continue;
            }
            // Column is capped at 14 — presumably to bias holes toward the start
            // of the line; NOTE(review): confirm this bias is intentional.
            let column_nb = rng.gen_range(0..15.min(line.len_chars()));
            holes.push(Hole::new(
                line_nb as u32,
                column_nb as u32,
                file_path.strip_prefix(&path)?.to_str().unwrap().to_owned(),
            ));
            i += 1;
        }
        // Overwrite any existing holes file for this repository.
        let mut file = OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(&holes_dir_path.join(repo.holes_file))
            .await?;
        file.write_all(serde_json::to_string(&holes)?.as_bytes())
            .await?;
    }
    Ok(())
}

View file

@ -0,0 +1,44 @@
use std::fmt;
use serde::{Deserialize, Serialize};
// Known source-file extensions per supported language, matched with `ends_with`
// in `Language::is_code_file` below.
// const JS_EXT: [&str; 2] = [".js", ".jsx"];
const PY_EXT: [&str; 1] = [".py"];
const RS_EXT: [&str; 1] = [".rs"];
// NOTE(review): ".d.ts" is redundant for `ends_with` matching since ".ts"
// already matches it — harmless, but could be dropped.
const TS_EXT: [&str; 3] = [".ts", ".tsx", ".d.ts"];
/// Programming language of a repository's source files.
///
/// Serialized in lowercase (e.g. `python`) in the repositories YAML config.
#[derive(Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub(crate) enum Language {
    Python,
    Rust,
    Typescript,
}
impl fmt::Display for Language {
    /// Writes the lowercase language name, matching the serde representation.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let name = match self {
            Self::Python => "python",
            Self::Rust => "rust",
            Self::Typescript => "typescript",
        };
        f.write_str(name)
    }
}
impl Language {
    /// Returns true when `file_name` ends with one of this language's
    /// source-file extensions.
    pub(crate) fn is_code_file(&self, file_name: &str) -> bool {
        let extensions: &[&str] = match self {
            Self::Python => &PY_EXT,
            Self::Rust => &RS_EXT,
            Self::Typescript => &TS_EXT,
        };
        extensions.iter().any(|ext| file_name.ends_with(ext))
    }

    /// Line-comment token for this language, used to skip comment lines when
    /// generating holes.
    pub(crate) fn comment_token(&self) -> &str {
        match self {
            Self::Python => "#",
            Self::Rust => "//",
            Self::Typescript => "//",
        }
    }
}

717
crates/testbed/src/main.rs Normal file
View file

@ -0,0 +1,717 @@
use std::{
collections::{HashMap, VecDeque},
fmt::Display,
io::BufReader,
path::{Path, PathBuf},
process::Stdio,
sync::Arc,
time::Instant,
};
use anyhow::anyhow;
use clap::Parser;
use futures_util::{stream::FuturesUnordered, StreamExt, TryStreamExt};
use lang::Language;
use lsp_client::{client::LspClient, msg::RequestId, server::Server};
use lsp_types::{
DidOpenTextDocumentParams, InitializeParams, TextDocumentIdentifier, TextDocumentItem,
TextDocumentPositionParams,
};
use ropey::Rope;
use runner::Runner;
use serde::{Deserialize, Serialize};
use tempfile::TempDir;
use tokio::{
fs::{self, read_to_string, File, OpenOptions},
io::{self, AsyncReadExt, AsyncWriteExt},
process::Command,
sync::{RwLock, Semaphore},
};
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tracing::{debug, error, info, info_span, warn, Instrument};
use tracing_subscriber::EnvFilter;
use url::Url;
use crate::{
holes_generator::generate_holes,
runner::run_test,
types::{
FimParams, GetCompletions, GetCompletionsParams, GetCompletionsResult, Ide, RequestParams,
TokenizerConfig,
},
};
mod holes_generator;
mod lang;
mod runner;
mod types;
/// Testbed runs llm-ls' code completion to measure its performance
// NOTE: clap turns the `///` field docs below into `--help` text, so they are
// runtime-visible strings and must not be reworded casually.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Hugging Face Inference API Token
    // when omitted, get_api_token falls back to ~/.cache/huggingface/token
    #[arg(short, long)]
    api_token: Option<String>,
    /// Comma separated list of repos in the repositories file to run completions or holes generation for;
    /// matches on path for local repos and `owner/name` for github repos
    #[arg(short, long)]
    filter: Option<String>,
    /// When this is specified, holes files will be generated based on the repositories.yaml file
    #[arg(short, long, action)]
    generate_holes: bool,
    /// Path to the directory containing the holes files
    #[arg(short = 'H', long)]
    holes_dir_path: Option<String>,
    /// Number of holes to create per repository
    #[arg(short = 'n', long, default_value_t = 100)]
    holes_per_repo: usize,
    /// Path to llm-ls' binary
    #[arg(short, long)]
    llm_ls_bin_path: Option<String>,
    /// Path to the local repositories/ directory
    #[arg(short = 'R', long)]
    repos_dir_path: Option<String>,
    /// Path to the repositories.yaml file
    #[arg(short, long)]
    repos_file_path: Option<String>,
}
/// Repository sourced from local storage.
#[derive(Clone, Deserialize, Serialize)]
struct LocalRepo {
    // path relative to the repos dir; also used as the repository's display name
    path: PathBuf,
    // subdirectory (relative to the repo root) scanned for source files
    src_path: String,
    // paths skipped when walking the tree for code files
    #[serde(default)]
    exclude_paths: Vec<String>,
}
/// Repository downloaded from GitHub as a zip archive of a pinned revision.
#[derive(Clone, Deserialize, Serialize)]
struct GithubRepo {
    owner: String,
    name: String,
    // commit sha / tag used to build the archive download URL
    revision: String,
    // subdirectory (relative to the repo root) scanned for source files
    #[serde(default)]
    src_path: String,
    // paths skipped when walking the tree for code files
    #[serde(default)]
    exclude_paths: Vec<String>,
}
/// Where a repository's code comes from; tagged by the `type` field in the
/// YAML config (`local` or `github`).
#[derive(Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "lowercase")]
enum RepoSource {
    Local(LocalRepo),
    Github(GithubRepo),
}
impl RepoSource {
    /// Returns the source tag as a string, mirroring the serde `type` field.
    fn source_type(&self) -> String {
        let tag = match self {
            Self::Local(_) => "local",
            Self::Github(_) => "github",
        };
        tag.to_owned()
    }

    /// Directory to scan for source files, relative to the repository root.
    fn src_path(&self) -> String {
        let path = match self {
            Self::Local(local) => &local.src_path,
            Self::Github(github) => &github.src_path,
        };
        path.clone()
    }

    /// Paths excluded from the source-file walk.
    fn exclude_paths(&self) -> Vec<String> {
        let paths = match self {
            Self::Local(local) => &local.exclude_paths,
            Self::Github(github) => &github.exclude_paths,
        };
        paths.clone()
    }
}
/// One repository entry of the repositories YAML file: where to fetch it, how
/// to build it, and how to run its tests.
#[derive(Clone, Deserialize, Serialize)]
struct Repository {
    // command + args used to check the repository compiles before running tests
    build_command: String,
    build_args: Vec<String>,
    // extra environment entries for spawned commands — assumed "KEY=VALUE"
    // strings; TODO confirm against the code that consumes them
    env: Option<Vec<String>>,
    // JSON file (relative to the holes dir) listing this repo's holes
    holes_file: String,
    language: Language,
    runner: Runner,
    // optional override of the runner's default command (e.g. a venv python)
    runner_command: Option<String>,
    runner_args: Option<Vec<String>>,
    #[serde(default)]
    runner_extra_args: Vec<String>,
    // (command, args) pairs executed once after checkout, e.g. dependency installs
    setup_commands: Option<Vec<(String, Vec<String>)>>,
    source: RepoSource,
}
impl Repository {
    /// Display name of the repository: the relative path for local sources,
    /// `owner/name` for GitHub sources.
    ///
    /// # Panics
    ///
    /// Panics if a local repository path is not valid UTF-8.
    fn name(&self) -> String {
        match &self.source {
            // expect() states the invariant instead of a bare unwrap()
            RepoSource::Local(local) => local
                .path
                .to_str()
                .expect("local repository path should be valid UTF-8")
                .to_owned(),
            RepoSource::Github(github) => format!("{}/{}", github.owner, github.name),
        }
    }
}
/// A completion "hole": a cursor position in a repository file where a
/// completion request will be issued.
#[derive(Clone, Deserialize, Serialize)]
struct Hole {
    // zero-based line/character position of the hole
    cursor: lsp_types::Position,
    /// relative path of a file in the repository
    file: String,
}
impl Hole {
    /// Builds a hole at (`line`, `character`) inside `file`.
    fn new(line: u32, character: u32, file: String) -> Self {
        let cursor = lsp_types::Position::new(line, character);
        Self { cursor, file }
    }
}
impl Display for Hole {
    /// Formats as `path [line, character]` for logging.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cursor = &self.cursor;
        write!(f, "{} [{}, {}]", self.file, cursor.line, cursor.character)
    }
}
// unused for now, consider all holes as lines
// enum HoleType {
// Line,
// Multiline
// }
/// Top-level structure of the repositories YAML file: completion-request
/// parameters shared by all repositories, plus the repository list itself.
#[derive(Clone, Deserialize, Serialize)]
struct RepositoriesConfig {
    context_window: usize,
    // fill-in-the-middle prompt tokens and whether FIM is enabled
    fim: FimParams,
    model: String,
    request_params: RequestParams,
    repositories: Vec<Repository>,
    tls_skip_verify_insecure: bool,
    tokenizer_config: Option<TokenizerConfig>,
    // NOTE(review): the YAML examples nest tokens_to_clear under
    // tokenizer_config; confirm how this top-level field is populated
    tokens_to_clear: Vec<String>,
}
/// Outcome of completing a single hole.
struct HoleCompletionResult {
    repo_name: String,
    // "local" or "github", from RepoSource::source_type
    repo_source_type: String,
    // presumably the fraction of repository tests passing after the completion
    // was applied — TODO confirm against the runner code
    pass_percentage: f32,
    // wall-clock time to obtain the completion, in milliseconds
    completion_time_ms: u128,
}
impl HoleCompletionResult {
fn new(
repo_name: String,
repo_source_type: String,
pass_percentage: f32,
completion_time_ms: u128,
) -> Self {
Self {
repo_name,
repo_source_type,
pass_percentage,
completion_time_ms,
}
}
}
async fn get_api_token(args_token: Option<String>) -> anyhow::Result<Option<String>> {
if args_token.is_some() {
Ok(args_token)
} else {
let home_dir = home::home_dir().ok_or(anyhow!("failed to find home dir"))?;
let cached_token = home_dir.join(".cache/huggingface/token");
if cached_token.try_exists()? {
let mut token = String::new();
File::open(cached_token)
.await?
.read_to_string(&mut token)
.await?;
Ok(Some(token.trim().to_owned()))
} else {
Ok(None)
}
}
}
/// Downloads `repo` at its pinned revision as a zip archive from GitHub and
/// extracts it inside `temp_dir`, returning the path of the extracted
/// repository directory.
async fn download_repo_from_github(
    temp_dir: &TempDir,
    repo: &GithubRepo,
) -> anyhow::Result<PathBuf> {
    // the archive is expected to extract into a `<name>-<revision>` directory
    let repo_dir_name = format!("{}-{}", repo.name, repo.revision);
    let archive_path = temp_dir.path().join(format!("{}.zip", repo_dir_name));
    let mut archive = File::create(&archive_path).await?;
    let stream = reqwest::get(&format!(
        "https://github.com/{}/{}/archive/{}.zip",
        repo.owner, repo.name, repo.revision,
    ))
    .await?
    .error_for_status()?
    .bytes_stream();
    // adapt the reqwest byte stream (futures AsyncRead) to a tokio AsyncRead
    // so it can be streamed to disk with tokio::io::copy
    let stream = stream
        .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
        .into_async_read();
    let mut stream = stream.compat();
    io::copy(&mut stream, &mut archive).await?;
    // NOTE(review): extraction uses blocking std I/O on the async runtime;
    // consider spawn_blocking if this ever shows up as a stall.
    let archive = BufReader::new(std::fs::File::open(archive_path)?);
    zip::ZipArchive::new(archive)?.extract(temp_dir.path())?;
    Ok(temp_dir.path().join(repo_dir_name))
}
/// Recursively copies the contents of `source` into `dest` (which must
/// already exist), using an explicit stack instead of recursion.
///
/// Entries that are neither regular files nor directories (e.g. symlinks)
/// are skipped silently.
///
/// # Errors
/// Fails on any filesystem error (read, create, copy, canonicalize).
async fn copy_dir_contents(source: &Path, dest: &Path) -> anyhow::Result<()> {
    let mut stack = VecDeque::new();
    stack.push_back((source.to_path_buf(), dest.to_path_buf()));
    while let Some((src, dst)) = stack.pop_back() {
        // canonicalize once per directory instead of once per entry;
        // the result is loop-invariant within a directory
        let dst = fs::canonicalize(&dst).await?;
        let mut entries = fs::read_dir(&src).await?;
        while let Some(entry) = entries.next_entry().await? {
            let entry_type = entry.file_type().await?;
            let src_path = entry.path();
            let dst_path = dst.join(entry.file_name());
            if entry_type.is_dir() {
                fs::create_dir(&dst_path).await?;
                stack.push_back((src_path, dst_path));
            } else if entry_type.is_file() {
                fs::copy(&src_path, &dst_path).await?;
            }
        }
    }
    Ok(())
}
/// Materializes a repository inside a fresh temporary directory.
///
/// Local repositories are copied from `repos_dir_path`; GitHub repositories
/// are downloaded and extracted. Returns the temporary-directory guard
/// (which deletes the directory on drop) together with the repository root
/// path inside it.
async fn setup_repo_dir(
    repos_dir_path: &Path,
    source: &RepoSource,
) -> anyhow::Result<(TempDir, PathBuf)> {
    let temp_dir = TempDir::new()?;
    let repo_path = match source {
        RepoSource::Local(local) => {
            debug!("setting up local repo: {}", local.path.to_str().unwrap());
            copy_dir_contents(&repos_dir_path.join(&local.path), temp_dir.path()).await?;
            temp_dir.path().to_path_buf()
        }
        RepoSource::Github(github) => {
            debug!("setting repo from github: {}/{}", github.owner, github.name);
            download_repo_from_github(&temp_dir, github).await?
        }
    };
    Ok((temp_dir, repo_path))
}
/// Parses `NAME=value` strings into `(name, value)` pairs.
///
/// A `None` or empty list yields an empty vector.
///
/// # Errors
/// Fails on the first entry that does not contain a `=` separator.
fn parse_env(env: &Option<Vec<String>>) -> anyhow::Result<Vec<(String, String)>> {
    env.iter()
        .flatten()
        .map(|var| {
            var.split_once('=')
                .map(|(name, value)| (name.to_owned(), value.to_owned()))
                // build the error lazily so the happy path allocates nothing
                .ok_or_else(|| anyhow!("failed to split env var {var}"))
        })
        .collect()
}
async fn run_setup(
commands: &Vec<(String, Vec<String>)>,
env: &Option<Vec<String>>,
repo_path: impl AsRef<Path>,
) -> anyhow::Result<()> {
let parsed_env = parse_env(env)?;
for command in commands {
let mut status_cmd = Command::new(&command.0);
for (name, value) in &parsed_env {
status_cmd.env(name, value);
}
debug!("running setup command: {} {:?}", command.0, command.1);
let status = status_cmd
.args(&command.1)
.current_dir(&repo_path)
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn()?
.wait()
.await?;
if !status.success() {
return Err(anyhow!(
"error running: \"{} {}\"",
command.0,
command.1.join(" ")
));
}
}
Ok(())
}
/// Runs the repository's build command in `repo_path` and reports whether
/// it exited successfully.
///
/// Stdout and stderr are discarded; `env` entries are `NAME=value` strings.
async fn build(
    command: &str,
    args: &Vec<String>,
    env: &Option<Vec<String>>,
    repo_path: impl AsRef<Path>,
) -> anyhow::Result<bool> {
    let parsed_env = parse_env(env)?;
    let mut build_cmd = Command::new(command);
    for (name, value) in parsed_env {
        build_cmd.env(name, value);
    }
    debug!("building repo: {command} {args:?}");
    let exit_status = build_cmd
        .args(args)
        .current_dir(repo_path)
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .spawn()?
        .wait()
        .await?;
    Ok(exit_status.success())
}
/// Runs the full hole-completion loop for a single repository.
///
/// For each hole in the repository's holes file: removes the rest of the
/// hole's line, asks llm-ls for a completion, writes the completion into the
/// file, builds the project and runs its tests, records the result, then
/// restores the original file content. The `semaphore` bounds how many
/// repositories run this loop concurrently.
///
/// Returns one `HoleCompletionResult` per hole that produced a valid
/// llm-ls response.
#[allow(clippy::too_many_arguments)]
async fn complete_holes(
    repo: Repository,
    client: Arc<LspClient>,
    file_cache: Arc<RwLock<HashMap<PathBuf, Rope>>>,
    holes_dir_path: PathBuf,
    repos_dir_path: PathBuf,
    repos_config: RepositoriesConfig,
    api_token: Option<String>,
    semaphore: Arc<Semaphore>,
) -> anyhow::Result<Vec<HoleCompletionResult>> {
    // held for the whole repository run to cap concurrency
    let permit = semaphore.acquire_owned().await?;
    let span = info_span!("complete_hole", repo_name = repo.name());
    async move {
        let holes_file_path = holes_dir_path.join(&repo.holes_file);
        let mut holes = String::new();
        File::open(holes_file_path)
            .await?
            .read_to_string(&mut holes)
            .await?;
        let holes: Vec<Hole> = serde_json::from_str(&holes)?;
        // progress is logged every ~10% of holes (at least every hole)
        let ten_percent = if holes.len() >= 10 {
            holes.len() / 10
        } else {
            1
        };
        info!("running {} hole completions", holes.len());
        let RepositoriesConfig {
            context_window,
            fim,
            model,
            request_params,
            tls_skip_verify_insecure,
            tokenizer_config,
            tokens_to_clear,
            ..
        } = repos_config;
        // `_temp_dir` must stay alive: dropping it deletes the checkout
        let (_temp_dir, repo_path) = setup_repo_dir(&repos_dir_path, &repo.source).await?;
        if let Some(commands) = &repo.setup_commands {
            run_setup(commands, &repo.env, &repo_path).await?;
        }
        let mut hole_completions_result = Vec::with_capacity(holes.len());
        for (idx, hole) in holes.iter().enumerate() {
            let hole_instant = Instant::now();
            let file_path = repo_path.join(&hole.file);
            let file_path_str = file_path
                .to_str()
                .ok_or(anyhow!("failed to convert file to str"))?;
            // fetch the file from the cache, reading it from disk on first use
            let mut file_content = if file_cache.read().await.contains_key(&file_path) {
                file_cache
                    .read()
                    .await
                    .get(&file_path)
                    .ok_or(anyhow!("failed to find {} in file cache", file_path_str))?
                    .to_owned()
            } else {
                let file_content = Rope::from_str(&read_to_string(&file_path).await?);
                file_cache
                    .write()
                    .await
                    .insert(file_path.clone(), file_content.clone());
                file_content
            };
            // pristine copy used to restore the file after this hole
            let original_content = file_content.clone();
            // the "hole" is the remainder of the line after the cursor,
            // excluding the trailing newline (hence the `- 1`)
            let hole_start = file_content.line_to_char(hole.cursor.line as usize)
                + hole.cursor.character as usize;
            let hole_end = hole_start
                + file_content
                    .line(hole.cursor.line as usize)
                    .slice(hole.cursor.character as usize..)
                    .len_chars()
                - 1;
            file_content.remove(hole_start..hole_end);
            let uri = Url::parse(&format!("file:/{file_path_str}"))?;
            // tell the server about the file with the hole removed
            client.send_notification::<lsp_types::notification::DidOpenTextDocument>(
                DidOpenTextDocumentParams {
                    text_document: TextDocumentItem {
                        uri: uri.clone(),
                        language_id: repo.language.to_string(),
                        version: 0,
                        text: file_content.to_string(),
                    },
                },
            );
            let response = client
                .send_request::<GetCompletions>(GetCompletionsParams {
                    api_token: api_token.clone(),
                    context_window,
                    fim: fim.clone(),
                    ide: Ide::default(),
                    model: model.clone(),
                    request_params: request_params.clone(),
                    text_document_position: TextDocumentPositionParams {
                        position: hole.cursor,
                        text_document: TextDocumentIdentifier { uri },
                    },
                    tls_skip_verify_insecure,
                    tokens_to_clear: tokens_to_clear.clone(),
                    tokenizer_config: tokenizer_config.clone(),
                })
                .await?;
            // skip this hole (don't abort the repo) when llm-ls errors
            let (_, result): (RequestId, GetCompletionsResult) = match response.extract() {
                Ok(res) => res,
                Err(err) => {
                    error!("llm-ls response error: {err}");
                    continue;
                }
            };
            // splice the first completion into the hole and write it to disk
            file_content.insert(hole_start, &result.completions[0].generated_text);
            let mut file = OpenOptions::new()
                .write(true)
                .truncate(true)
                .open(&file_path)
                .await?;
            file.write_all(file_content.to_string().as_bytes()).await?;
            // a completion only scores if the project still builds
            let test_percentage =
                if build(&repo.build_command, &repo.build_args, &repo.env, &repo_path).await? {
                    run_test(
                        repo.runner,
                        &repo.runner_command,
                        &repo.runner_args,
                        &mut repo.runner_extra_args.clone(),
                        &repo.env,
                        &repo_path,
                    )
                    .await?
                } else {
                    0f32
                };
            debug!("{} passed {}%", hole.to_string(), test_percentage * 100f32);
            hole_completions_result.push(HoleCompletionResult::new(
                repo.name(),
                repo.source.source_type(),
                test_percentage,
                hole_instant.elapsed().as_millis(),
            ));
            // restore the original file content for the next hole
            let mut file = OpenOptions::new()
                .write(true)
                .truncate(true)
                .open(&file_path)
                .await?;
            file.write_all(original_content.to_string().as_bytes())
                .await?;
            if (idx + 1) % ten_percent == 0 {
                info!("completed {}%", (idx + 1) / ten_percent * 10);
            }
        }
        drop(permit);
        info!("finished running hole completions");
        Ok(hole_completions_result)
    }
    .instrument(span)
    .await
}
/// Testbed entry point: loads the repositories configuration, runs hole
/// completions against llm-ls for every (non-filtered) repository, then
/// aggregates results into a markdown table written to `results.md`.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // log level comes from the LOG_LEVEL env var, defaulting to "info"
    tracing_subscriber::fmt()
        .with_target(true)
        .with_line_number(true)
        .with_env_filter(
            EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")),
        )
        .init();
    let args = Args::parse();
    let api_token = get_api_token(args.api_token).await?;
    let current_dir = std::env::current_dir()?;
    // paths below default to locations relative to the current working directory
    let llm_ls_path = if let Some(bin_path) = args.llm_ls_bin_path {
        bin_path.into()
    } else {
        current_dir.join("target/release/llm-ls")
    };
    let repos_dir_path = if let Some(path) = args.repos_dir_path {
        path.into()
    } else {
        current_dir.join("crates/testbed/repositories")
    };
    let repos_file_path = if let Some(path) = args.repos_file_path {
        path.into()
    } else {
        current_dir.join("crates/testbed/repositories.yaml")
    };
    let holes_dir_path = if let Some(path) = args.holes_dir_path {
        path.into()
    } else {
        current_dir.join("crates/testbed/holes")
    };
    // optional comma-separated list of repository names to restrict the run to
    let (filter_repos, filter_list) = if let Some(filter) = args.filter {
        (true, filter.split(',').map(|s| s.to_owned()).collect())
    } else {
        (false, vec![])
    };
    let mut repos_file = String::new();
    File::open(&repos_file_path)
        .await?
        .read_to_string(&mut repos_file)
        .await?;
    let repos_config: RepositoriesConfig = serde_yaml::from_str(&repos_file)?;
    // hole-generation mode: write holes files and exit without running completions
    if args.generate_holes {
        return generate_holes(
            repos_config,
            &repos_dir_path,
            &holes_dir_path,
            args.holes_per_repo,
            filter_repos,
            filter_list,
        )
        .await;
    }
    debug!(
        "initializing language server at path: {}",
        llm_ls_path.to_str().unwrap()
    );
    let (conn, server) = Server::build().binary_path(llm_ls_path).start().await?;
    let client = Arc::new(LspClient::new(conn, server).await);
    client
        .send_request::<lsp_types::request::Initialize>(InitializeParams::default())
        .await?;
    let file_cache = Arc::new(RwLock::new(HashMap::new()));
    let mut passing_tests_percentage = vec![];
    let repositories = repos_config.repositories.clone();
    let mut handles = FuturesUnordered::new();
    // allow at most 8 repositories to run their completion loops concurrently
    let semaphore = Arc::new(Semaphore::new(8));
    // spawn one task per repository that passes the filter
    for repo in repositories {
        if filter_repos && !filter_list.contains(&repo.name()) {
            continue;
        }
        let client = client.clone();
        let file_cache = file_cache.clone();
        let holes_dir_path = holes_dir_path.clone();
        let repos_dir_path = repos_dir_path.clone();
        let repos_config = repos_config.clone();
        let api_token = api_token.clone();
        let semaphore = semaphore.clone();
        handles.push(tokio::spawn(async move {
            complete_holes(
                repo,
                client,
                file_cache,
                holes_dir_path,
                repos_dir_path,
                repos_config,
                api_token,
                semaphore,
            )
            .await
        }));
    }
    // collect results, aborting on the first task panic or completion error
    while let Some(res) = handles.next().await {
        match res {
            Ok(Ok(res)) => passing_tests_percentage.extend(res),
            Ok(Err(err)) => return Err(err),
            Err(err) => return Err(err.into()),
        }
    }
    // aggregate per (repo name, source type): (total time ms, total pass %, count)
    let mut results_map: HashMap<(String, String), (u128, f32, f32)> = HashMap::new();
    for res in passing_tests_percentage {
        results_map
            .entry((res.repo_name, res.repo_source_type))
            .and_modify(|p| {
                p.0 += res.completion_time_ms;
                p.1 += res.pass_percentage;
                p.2 += 1f32;
            })
            .or_insert((res.completion_time_ms, res.pass_percentage, 1f32));
    }
    // render a markdown table: one row per repository plus a total row
    let mut results_table =
        "| Repository name | Source type | Average hole completion time (s) | Pass percentage |\n| :-------------- | :---------- | -------------------------------: | --------------: |\n".to_owned();
    let mut total_time = 0;
    let mut total_percentage = 0f32;
    let mut total_count = 0f32;
    for (k, v) in results_map.iter() {
        let avg = v.1 / v.2;
        let avg_time = v.0 as f32 / v.2;
        results_table.push_str(&format!(
            "| {} | {} | {} | {}% |\n",
            k.0,
            k.1,
            avg_time / 1_000f32,
            avg * 100f32
        ));
        total_percentage += v.1;
        total_count += v.2;
        total_time += v.0;
    }
    let total_avg = total_percentage / total_count;
    let total_time_avg = total_time as f32 / total_count;
    results_table.push_str(&format!(
        "| **Total** | -- | {} | {}% |\n\n",
        total_time_avg / 1_000f32,
        total_avg * 100f32
    ));
    results_table.push_str(
        &[
            "**Note:** The \"hole completion time\" represents the full process of:",
            " - replacing the code from the file with a completion from the model",
            " - building the project",
            " - running the tests",
        ]
        .join("\n"),
    );
    info!("llm-ls results:\n{}", results_table);
    // persist the table so CI can post it as a pull-request comment
    OpenOptions::new()
        .create(true)
        .write(true)
        .truncate(true)
        .open("results.md")
        .await?
        .write_all(results_table.as_bytes())
        .await?;
    client.shutdown().await?;
    match Arc::into_inner(client) {
        Some(client) => client.exit().await,
        None => warn!("could not send exit notification because client is referenced elsewhere"),
    }
    Ok(())
}

View file

@ -0,0 +1,278 @@
use std::{path::Path, process::Stdio};
use anyhow::anyhow;
use serde::{Deserialize, Serialize};
use tokio::{io::AsyncReadExt, process::Command};
use tracing::debug;
use crate::parse_env;
/// One deserialized line of libtest's JSON output summarizing a finished
/// test suite; only `passed` and `failed` are used by `cargo_runner`.
#[derive(Deserialize, Serialize)]
struct TestSuiteResult {
    r#type: String,
    event: String,
    passed: u32,
    failed: u32,
    ignored: u32,
    measured: u32,
    filtered_out: u32,
    exec_time: f64,
}
/// Runs `python3 -m pytest` (or `override_cmd`) in `repo_path` and parses
/// the final summary line to compute the ratio of passing tests.
///
/// Returns `0.0` when there is no output, the summary reports errors, or
/// neither passed nor failed counts are found; otherwise
/// `passed / (passed + failed)`.
async fn pytest_runner(
    override_cmd: &Option<String>,
    extra_args: &mut Vec<String>,
    repo_path: &Path,
) -> anyhow::Result<f32> {
    let cmd = if let Some(cmd) = override_cmd {
        cmd
    } else {
        "python3"
    };
    let mut args = vec![
        "-m".to_owned(),
        "pytest".to_owned(),
        "tests".to_owned(),
        "-q".to_owned(),
        "--disable-warnings".to_owned(),
        "--no-header".to_owned(),
    ];
    args.append(extra_args);
    debug!("running pytest tests: {cmd} {args:?}");
    let mut child = Command::new(cmd)
        .args(args)
        .current_dir(repo_path)
        .stdout(Stdio::piped())
        .stderr(Stdio::null())
        .spawn()?;
    let mut stdout = String::new();
    child
        .stdout
        .take()
        .ok_or(anyhow!("failed to take stdout"))?
        .read_to_string(&mut stdout)
        .await?;
    // XXX: the pytest command can still fail even after the compilation check
    // the above check should prevent an error, but better safe than sorry
    let lines = stdout.split_terminator('\n');
    // the last line looks like `== 1 failed, 2 passed in 0.12s ==`
    let result = match lines.last() {
        Some(line) => line.replace('=', "").trim().to_owned(),
        None => return Ok(0f32),
    };
    let mut passed = 0f32;
    let mut failed = 0f32;
    // strip the trailing ` in 1.23s` part; match " in " with surrounding
    // spaces so the bare substring "in" inside words like "warnings" does
    // not truncate the summary at the wrong position
    let without_time = &result[0..result.find(" in ").unwrap_or(result.len())].trim();
    for res in without_time.split(", ") {
        if res.contains("passed") {
            let passed_str = res.replace(" passed", "");
            passed = passed_str.parse::<u32>()? as f32;
        } else if res.contains("failed") && !res.contains("xfailed") {
            let failed_str = res.replace(" failed", "");
            failed = failed_str.parse::<u32>()? as f32;
        } else if res.contains("error") {
            // collection/internal errors invalidate the run entirely
            return Ok(0f32);
        }
    }
    if passed == 0f32 && failed == 0f32 {
        return Ok(0f32);
    }
    Ok(passed / (passed + failed))
}
/// Runs `cargo test` (or `override_cmd`) in `repo_path` with libtest JSON
/// output enabled and sums passed/failed counts across all suites.
///
/// Returns `0.0` when no tests ran, otherwise `passed / (passed + failed)`.
async fn cargo_runner(
    override_cmd: &Option<String>,
    extra_args: &mut Vec<String>,
    env: &Option<Vec<String>>,
    repo_path: &Path,
) -> anyhow::Result<f32> {
    let cmd = if let Some(cmd) = override_cmd {
        cmd
    } else {
        "cargo"
    };
    let mut args = vec![];
    args.append(extra_args);
    // arguments after `--` are forwarded to the test binary itself
    if !args.contains(&"--".to_owned()) {
        args.push("--".to_owned());
    }
    // libtest's JSON output format is nightly-only, hence `-Z unstable-options`
    args.extend([
        "-Z".to_owned(),
        "unstable-options".to_owned(),
        "--format".to_owned(),
        "json".to_owned(),
    ]);
    debug!("running cargo tests: {cmd} test {args:?}");
    let parsed_env = parse_env(env)?;
    let mut cmd = Command::new(cmd);
    for (name, value) in parsed_env {
        cmd.env(name, value);
    }
    let mut child = cmd
        .arg("test")
        .args(args)
        .current_dir(repo_path)
        .stdout(Stdio::piped())
        .stderr(Stdio::null())
        .spawn()?;
    let mut stdout = String::new();
    child
        .stdout
        .take()
        .ok_or(anyhow!("failed to take stdout"))?
        .read_to_string(&mut stdout)
        .await?;
    let lines = stdout.split_terminator('\n');
    let mut passed = 0;
    let mut failed = 0;
    for line in lines {
        // non-JSON lines (e.g. compiler output) are skipped silently
        let test_suite_result = match serde_json::from_str::<TestSuiteResult>(line) {
            Ok(res) => res,
            Err(_) => continue,
        };
        passed += test_suite_result.passed;
        failed += test_suite_result.failed;
    }
    if passed == 0 && failed == 0 {
        return Ok(0f32);
    }
    Ok(passed as f32 / (passed as f32 + failed as f32))
}
/// Runs `npm run test` (or the overridden command/args) in `repo_path` and
/// parses jest's `Tests:` summary line, which jest prints to stderr.
///
/// Returns `0.0` when no passed/failed counts were found, otherwise
/// `passed / (passed + failed)`.
async fn jest_runner(
    override_cmd: &Option<String>,
    override_args: &Option<Vec<String>>,
    repo_path: &Path,
) -> anyhow::Result<f32> {
    let cmd = if let Some(cmd) = override_cmd {
        cmd
    } else {
        "npm"
    };
    let default_args = vec!["run".to_owned(), "test".to_owned()];
    let args = if let Some(args) = override_args {
        args
    } else {
        &default_args
    };
    debug!("running jest tests: {cmd} {args:?}");
    // jest writes its report to stderr, stdout is discarded
    let mut child = Command::new(cmd)
        .args(args)
        .current_dir(repo_path)
        .stdout(Stdio::null())
        .stderr(Stdio::piped())
        .spawn()?;
    let mut stderr = String::new();
    child
        .stderr
        .take()
        .ok_or(anyhow!("failed to take stderr"))?
        .read_to_string(&mut stderr)
        .await?;
    let lines = stderr.split_terminator('\n');
    let mut passed = 0f32;
    let mut failed = 0f32;
    for line in lines {
        // summary looks like: `Tests: 1 failed, 2 passed, 3 total`
        if line.contains("Tests:") {
            let words = line.trim().split(' ').collect::<Vec<&str>>();
            let mut prev = words[0];
            for word in words {
                // the count precedes its label, so `prev` holds the number;
                // the label must be checked on `word` (not the whole line),
                // otherwise the first word of any line mentioning "failed"
                // would be parsed as a number and abort with a parse error
                if word.contains("passed") {
                    passed = prev.parse::<u32>()? as f32;
                } else if word.contains("failed") {
                    failed = prev.parse::<u32>()? as f32;
                }
                prev = word;
            }
        }
    }
    if passed == 0f32 && failed == 0f32 {
        return Ok(0f32);
    }
    Ok(passed / (passed + failed))
}
/// Runs `npm run test` (or the overridden command/args) in `repo_path` and
/// parses vitest's ` Tests ` summary line from stdout.
///
/// Returns `0.0` when no passed/failed counts were found, otherwise
/// `passed / (passed + failed)`.
async fn vitest_runner(
    override_cmd: &Option<String>,
    override_args: &Option<Vec<String>>,
    repo_path: &Path,
) -> anyhow::Result<f32> {
    let cmd = if let Some(cmd) = override_cmd {
        cmd
    } else {
        "npm"
    };
    let default_args = vec!["run".to_owned(), "test".to_owned()];
    let args = if let Some(args) = override_args {
        args
    } else {
        &default_args
    };
    debug!("running vitest tests: {cmd} {args:?}");
    let mut child = Command::new(cmd)
        .args(args)
        .current_dir(repo_path)
        .stdout(Stdio::piped())
        .stderr(Stdio::null())
        .spawn()?;
    let mut stdout = String::new();
    child
        .stdout
        .take()
        .ok_or(anyhow!("failed to take stdout"))?
        .read_to_string(&mut stdout)
        .await?;
    let lines = stdout.split_terminator('\n');
    let mut passed = 0f32;
    let mut failed = 0f32;
    for line in lines {
        // summary looks like: ` Tests  1 failed | 2 passed (3)`
        if line.contains(" Tests ") {
            let words = line.trim().split(' ').collect::<Vec<&str>>();
            let mut prev = words[0];
            for word in words {
                // the count precedes its label, so `prev` holds the number;
                // the label must be checked on `word` (not the whole line),
                // otherwise the first word of any line mentioning "failed"
                // would be parsed as a number and abort with a parse error
                if word.contains("passed") {
                    passed = prev.parse::<u32>()? as f32;
                } else if word.contains("failed") {
                    failed = prev.parse::<u32>()? as f32;
                }
                prev = word;
            }
        }
    }
    if passed == 0f32 && failed == 0f32 {
        return Ok(0f32);
    }
    Ok(passed / (passed + failed))
}
/// Test runner used by a testbed repository; deserialized from snake_case
/// strings in the repositories configuration (e.g. `cargo`, `pytest`).
#[derive(Clone, Copy, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Runner {
    Cargo,
    Jest,
    Pytest,
    Vitest,
}
/// Dispatches to the runner selected by `runner` and returns the fraction
/// of passing tests.
///
/// `override_cmd` replaces the default runner binary. `override_args`
/// replaces the default arguments (jest/vitest only); `extra_args` are
/// appended to the default arguments (cargo/pytest only); `env` entries are
/// `NAME=value` pairs (cargo only).
pub async fn run_test(
    runner: Runner,
    override_cmd: &Option<String>,
    override_args: &Option<Vec<String>>,
    extra_args: &mut Vec<String>,
    env: &Option<Vec<String>>,
    repo_path: &Path,
) -> anyhow::Result<f32> {
    match runner {
        Runner::Cargo => cargo_runner(override_cmd, extra_args, env, repo_path).await,
        Runner::Jest => jest_runner(override_cmd, override_args, repo_path).await,
        Runner::Pytest => pytest_runner(override_cmd, extra_args, repo_path).await,
        Runner::Vitest => vitest_runner(override_cmd, override_args, repo_path).await,
    }
}

View file

@ -0,0 +1,88 @@
use std::path::PathBuf;
use lsp_types::{request::Request, TextDocumentPositionParams};
use serde::{Deserialize, Deserializer, Serialize};
use uuid::Uuid;
/// Marker type for the custom `llm-ls/getCompletions` LSP request.
#[derive(Debug)]
pub(crate) enum GetCompletions {}
impl Request for GetCompletions {
    type Params = GetCompletionsParams;
    type Result = GetCompletionsResult;
    const METHOD: &'static str = "llm-ls/getCompletions";
}
/// Generation parameters included in a `GetCompletions` request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct RequestParams {
    pub(crate) max_new_tokens: u32,
    pub(crate) temperature: f32,
    pub(crate) do_sample: bool,
    pub(crate) top_p: f32,
    /// sequences that stop the generation, if any
    pub(crate) stop_tokens: Option<Vec<String>>,
}
/// Editor issuing the completion request; serialized as a lowercase string
/// (e.g. `"neovim"`), defaulting to `Unknown`.
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub(crate) enum Ide {
    Neovim,
    VSCode,
    JetBrains,
    Emacs,
    Jupyter,
    Sublime,
    VisualStudio,
    #[default]
    Unknown,
}
/// Deserializes an optional [`Ide`], mapping a missing or null value to
/// [`Ide::Unknown`].
fn parse_ide<'de, D>(d: D) -> std::result::Result<Ide, D::Error>
where
    D: Deserializer<'de>,
{
    let ide: Option<Ide> = Deserialize::deserialize(d)?;
    Ok(ide.unwrap_or(Ide::Unknown))
}
/// Fill-in-the-middle configuration: whether FIM is enabled and the special
/// tokens delimiting the prefix, middle and suffix sections of the prompt.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct FimParams {
    pub(crate) enabled: bool,
    pub(crate) prefix: String,
    pub(crate) middle: String,
    pub(crate) suffix: String,
}
/// Where the tokenizer should be loaded from; `untagged`, so the variant is
/// inferred from which fields are present in the serialized value.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub(crate) enum TokenizerConfig {
    /// local file path
    Local { path: PathBuf },
    /// Hugging Face Hub repository id
    HuggingFace { repository: String },
    /// download from `url` into the local path `to`
    Download { url: String, to: PathBuf },
}
/// Parameters of the `llm-ls/getCompletions` request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct GetCompletionsParams {
    /// document uri + cursor position, flattened into the top-level object
    #[serde(flatten)]
    pub(crate) text_document_position: TextDocumentPositionParams,
    pub(crate) request_params: RequestParams,
    /// requesting editor; missing/null deserializes to `Ide::Unknown`
    #[serde(default)]
    #[serde(deserialize_with = "parse_ide")]
    pub(crate) ide: Ide,
    pub(crate) fim: FimParams,
    pub(crate) api_token: Option<String>,
    pub(crate) model: String,
    /// special tokens to strip from the generated text
    pub(crate) tokens_to_clear: Vec<String>,
    pub(crate) tokenizer_config: Option<TokenizerConfig>,
    pub(crate) context_window: usize,
    /// disables TLS certificate verification when true (insecure)
    pub(crate) tls_skip_verify_insecure: bool,
}
/// A single generated completion returned by llm-ls.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct Completion {
    pub(crate) generated_text: String,
}
/// Result payload of the `llm-ls/getCompletions` request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct GetCompletionsResult {
    // id correlating this result with its request; not read by the testbed
    request_id: Uuid,
    pub(crate) completions: Vec<Completion>,
}