fix(plugins): handle concurrent http downloads (#3664)

This commit is contained in:
Aram Drevekenin 2024-10-11 15:26:05 +02:00 committed by GitHub
parent ec1eea3ba1
commit 63208879da
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 33 additions and 2 deletions

View file

@@ -107,6 +107,7 @@ pub struct WasmBridge {
default_keybinds: Keybinds,
keybinds: HashMap<ClientId, Keybinds>,
base_modes: HashMap<ClientId, InputMode>,
downloader: Downloader,
}
impl WasmBridge {
@@ -129,6 +130,7 @@ impl WasmBridge {
let plugin_cache: Arc<Mutex<HashMap<PathBuf, Module>>> =
Arc::new(Mutex::new(HashMap::new()));
let watcher = None;
let downloader = Downloader::new(ZELLIJ_CACHE_DIR.to_path_buf());
WasmBridge {
connected_clients,
senders,
@@ -157,6 +159,7 @@ impl WasmBridge {
default_keybinds,
keybinds: HashMap::new(),
base_modes: HashMap::new(),
downloader,
}
}
pub fn load_plugin(
@@ -213,6 +216,7 @@ impl WasmBridge {
let default_shell = self.default_shell.clone();
let default_layout = self.default_layout.clone();
let layout_dir = self.layout_dir.clone();
let downloader = self.downloader.clone();
let default_mode = self
.base_modes
.get(&client_id)
@@ -236,7 +240,8 @@
.map(ToString::to_string)
.collect();
let downloader = Downloader::new(ZELLIJ_CACHE_DIR.to_path_buf());
// if the url is already in cache, we'll use that version, otherwise
// we'll download it, place it in cache and then use it
match downloader.download(url, Some(&file_name)).await {
Ok(_) => plugin.path = ZELLIJ_CACHE_DIR.join(&file_name),
Err(e) => handle_plugin_loading_failure(

View file

@@ -1,3 +1,4 @@
use async_std::sync::Mutex;
use async_std::{
fs,
io::{ReadExt, WriteExt},
@@ -5,7 +6,9 @@ use async_std::{
};
use isahc::prelude::*;
use isahc::{config::RedirectPolicy, HttpClient, Request};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use thiserror::Error;
use url::Url;
@@ -17,16 +20,22 @@ pub enum DownloaderError {
HttpError(#[from] isahc::http::Error),
#[error("IoError: {0}")]
Io(#[source] std::io::Error),
#[error("StdIoError: {0}")]
StdIoError(#[from] std::io::Error),
#[error("File name cannot be found in URL: {0}")]
NotFoundFileName(String),
#[error("Failed to parse URL body: {0}")]
InvalidUrlBody(String),
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Downloader {
// HTTP client used for the transfers; None when client construction failed
// (built with `.ok()` in the constructors)
client: Option<HttpClient>,
// directory that downloaded files are written into
location: PathBuf,
// the outer Arc<Mutex<..>> makes this map safely shareable across clones of
// Downloader, while each per-file Arc<Mutex<()>> value serializes downloads of
// that file so the same URL is never downloaded by two tasks concurrently
download_locks: Arc<Mutex<HashMap<String, Arc<Mutex<()>>>>>,
}
impl Default for Downloader {
@@ -38,6 +47,7 @@ impl Default for Downloader {
.build()
.ok(),
location: PathBuf::from(""),
download_locks: Default::default(),
}
}
}
@@ -51,6 +61,7 @@ impl Downloader {
.build()
.ok(),
location,
download_locks: Default::default(),
}
}
@@ -67,6 +78,14 @@ impl Downloader {
Some(name) => name.to_string(),
None => self.parse_name(url)?,
};
// we do this to make sure only one download of a specific url is happening at a time
// otherwise the downloads corrupt each other (and we waste lots of system resources)
let download_lock = self.acquire_download_lock(&file_name).await;
// it's important that _lock remains in scope, otherwise it gets dropped and the lock is
// released before the download is complete
let _lock = download_lock.lock().await;
let file_path = self.location.join(file_name.as_str());
if file_path.exists() {
log::debug!("File already exists: {:?}", file_path);
@@ -157,6 +176,13 @@ impl Downloader {
.ok_or_else(|| DownloaderError::NotFoundFileName(url.to_string()))
.map(|s| s.to_string())
}
/// Returns the per-file lock guarding downloads of `file_name`, creating it on
/// first use. Callers lock the returned mutex for the duration of a download so
/// that only one download of a given file can run at a time.
///
/// Takes `&str` instead of `&String` (clippy `ptr_arg`); existing `&String`
/// call sites still work via deref coercion.
async fn acquire_download_lock(&self, file_name: &str) -> Arc<Mutex<()>> {
    // hold the map lock only long enough to fetch-or-insert the entry; the
    // caller then awaits the per-file lock without blocking the whole map
    let mut locks = self.download_locks.lock().await;
    locks
        .entry(file_name.to_owned())
        // Arc<Mutex<()>>: Default, so or_default() replaces the redundant
        // or_insert_with(|| Default::default()) closure
        .or_default()
        .clone()
}
}
#[cfg(test)]