fix(plugins): handle concurrent http downloads (#3664)
This commit is contained in:
parent
ec1eea3ba1
commit
63208879da
2 changed files with 33 additions and 2 deletions
|
|
@ -107,6 +107,7 @@ pub struct WasmBridge {
|
|||
default_keybinds: Keybinds,
|
||||
keybinds: HashMap<ClientId, Keybinds>,
|
||||
base_modes: HashMap<ClientId, InputMode>,
|
||||
downloader: Downloader,
|
||||
}
|
||||
|
||||
impl WasmBridge {
|
||||
|
|
@ -129,6 +130,7 @@ impl WasmBridge {
|
|||
let plugin_cache: Arc<Mutex<HashMap<PathBuf, Module>>> =
|
||||
Arc::new(Mutex::new(HashMap::new()));
|
||||
let watcher = None;
|
||||
let downloader = Downloader::new(ZELLIJ_CACHE_DIR.to_path_buf());
|
||||
WasmBridge {
|
||||
connected_clients,
|
||||
senders,
|
||||
|
|
@ -157,6 +159,7 @@ impl WasmBridge {
|
|||
default_keybinds,
|
||||
keybinds: HashMap::new(),
|
||||
base_modes: HashMap::new(),
|
||||
downloader,
|
||||
}
|
||||
}
|
||||
pub fn load_plugin(
|
||||
|
|
@ -213,6 +216,7 @@ impl WasmBridge {
|
|||
let default_shell = self.default_shell.clone();
|
||||
let default_layout = self.default_layout.clone();
|
||||
let layout_dir = self.layout_dir.clone();
|
||||
let downloader = self.downloader.clone();
|
||||
let default_mode = self
|
||||
.base_modes
|
||||
.get(&client_id)
|
||||
|
|
@ -236,7 +240,8 @@ impl WasmBridge {
|
|||
.map(ToString::to_string)
|
||||
.collect();
|
||||
|
||||
let downloader = Downloader::new(ZELLIJ_CACHE_DIR.to_path_buf());
|
||||
// if the url is already in cache, we'll use that version, otherwise
|
||||
// we'll download it, place it in cache and then use it
|
||||
match downloader.download(url, Some(&file_name)).await {
|
||||
Ok(_) => plugin.path = ZELLIJ_CACHE_DIR.join(&file_name),
|
||||
Err(e) => handle_plugin_loading_failure(
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
use async_std::sync::Mutex;
|
||||
use async_std::{
|
||||
fs,
|
||||
io::{ReadExt, WriteExt},
|
||||
|
|
@ -5,7 +6,9 @@ use async_std::{
|
|||
};
|
||||
use isahc::prelude::*;
|
||||
use isahc::{config::RedirectPolicy, HttpClient, Request};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use url::Url;
|
||||
|
||||
|
|
@ -17,16 +20,22 @@ pub enum DownloaderError {
|
|||
HttpError(#[from] isahc::http::Error),
|
||||
#[error("IoError: {0}")]
|
||||
Io(#[source] std::io::Error),
|
||||
#[error("StdIoError: {0}")]
|
||||
StdIoError(#[from] std::io::Error),
|
||||
#[error("File name cannot be found in URL: {0}")]
|
||||
NotFoundFileName(String),
|
||||
#[error("Failed to parse URL body: {0}")]
|
||||
InvalidUrlBody(String),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Downloader {
|
||||
client: Option<HttpClient>,
|
||||
location: PathBuf,
|
||||
// the whole thing is an Arc/Mutex so that Downloader is thread safe, and the individual values of
|
||||
// the HashMap are Arc/Mutexes (Mutexi?) to represent that individual downloads should not
|
||||
// happen concurrently
|
||||
download_locks: Arc<Mutex<HashMap<String, Arc<Mutex<()>>>>>,
|
||||
}
|
||||
|
||||
impl Default for Downloader {
|
||||
|
|
@ -38,6 +47,7 @@ impl Default for Downloader {
|
|||
.build()
|
||||
.ok(),
|
||||
location: PathBuf::from(""),
|
||||
download_locks: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -51,6 +61,7 @@ impl Downloader {
|
|||
.build()
|
||||
.ok(),
|
||||
location,
|
||||
download_locks: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -67,6 +78,14 @@ impl Downloader {
|
|||
Some(name) => name.to_string(),
|
||||
None => self.parse_name(url)?,
|
||||
};
|
||||
|
||||
// we do this to make sure only one download of a specific url is happening at a time
|
||||
// otherwise the downloads corrupt each other (and we waste lots of system resources)
|
||||
let download_lock = self.acquire_download_lock(&file_name).await;
|
||||
// it's important that _lock remains in scope, otherwise it gets dropped and the lock is
|
||||
// released before the download is complete
|
||||
let _lock = download_lock.lock().await;
|
||||
|
||||
let file_path = self.location.join(file_name.as_str());
|
||||
if file_path.exists() {
|
||||
log::debug!("File already exists: {:?}", file_path);
|
||||
|
|
@ -157,6 +176,13 @@ impl Downloader {
|
|||
.ok_or_else(|| DownloaderError::NotFoundFileName(url.to_string()))
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
async fn acquire_download_lock(&self, file_name: &String) -> Arc<Mutex<()>> {
|
||||
let mut lock_dict = self.download_locks.lock().await;
|
||||
let download_lock = lock_dict
|
||||
.entry(file_name.clone())
|
||||
.or_insert_with(|| Default::default());
|
||||
download_lock.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue