From 63208879da9adc458cf05504ef6f5fef9bb84441 Mon Sep 17 00:00:00 2001 From: Aram Drevekenin Date: Fri, 11 Oct 2024 15:26:05 +0200 Subject: [PATCH] fix(plugins): handle concurrent http downloads (#3664) --- zellij-server/src/plugins/wasm_bridge.rs | 7 +++++- zellij-utils/src/downloader.rs | 28 +++++++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/zellij-server/src/plugins/wasm_bridge.rs b/zellij-server/src/plugins/wasm_bridge.rs index 799642e2..e1529869 100644 --- a/zellij-server/src/plugins/wasm_bridge.rs +++ b/zellij-server/src/plugins/wasm_bridge.rs @@ -107,6 +107,7 @@ pub struct WasmBridge { default_keybinds: Keybinds, keybinds: HashMap, base_modes: HashMap, + downloader: Downloader, } impl WasmBridge { @@ -129,6 +130,7 @@ impl WasmBridge { let plugin_cache: Arc>> = Arc::new(Mutex::new(HashMap::new())); let watcher = None; + let downloader = Downloader::new(ZELLIJ_CACHE_DIR.to_path_buf()); WasmBridge { connected_clients, senders, @@ -157,6 +159,7 @@ impl WasmBridge { default_keybinds, keybinds: HashMap::new(), base_modes: HashMap::new(), + downloader, } } pub fn load_plugin( @@ -213,6 +216,7 @@ impl WasmBridge { let default_shell = self.default_shell.clone(); let default_layout = self.default_layout.clone(); let layout_dir = self.layout_dir.clone(); + let downloader = self.downloader.clone(); let default_mode = self .base_modes .get(&client_id) @@ -236,7 +240,8 @@ impl WasmBridge { .map(ToString::to_string) .collect(); - let downloader = Downloader::new(ZELLIJ_CACHE_DIR.to_path_buf()); + // if the url is already in cache, we'll use that version, otherwise + // we'll download it, place it in cache and then use it match downloader.download(url, Some(&file_name)).await { Ok(_) => plugin.path = ZELLIJ_CACHE_DIR.join(&file_name), Err(e) => handle_plugin_loading_failure( diff --git a/zellij-utils/src/downloader.rs b/zellij-utils/src/downloader.rs index e9f447b0..2fd1fd41 100644 --- a/zellij-utils/src/downloader.rs +++ b/zellij-utils/src/downloader.rs @@ -1,3 +1,4 @@ +use async_std::sync::Mutex; use async_std::{ fs, io::{ReadExt, WriteExt}, @@ -5,7 +6,9 @@ use async_std::{ }; use isahc::prelude::*; use isahc::{config::RedirectPolicy, HttpClient, Request}; +use std::collections::HashMap; use std::path::PathBuf; +use std::sync::Arc; use thiserror::Error; use url::Url; @@ -17,16 +20,22 @@ pub enum DownloaderError { HttpError(#[from] isahc::http::Error), #[error("IoError: {0}")] Io(#[source] std::io::Error), + #[error("StdIoError: {0}")] + StdIoError(#[from] std::io::Error), #[error("File name cannot be found in URL: {0}")] NotFoundFileName(String), #[error("Failed to parse URL body: {0}")] InvalidUrlBody(String), } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Downloader { client: Option, location: PathBuf, + // the whole thing is an Arc/Mutex so that Downloader is thread safe, and the individual values of + // the HashMap are Arc/Mutexes (Mutexi?) to represent that individual downloads should not + // happen concurrently + download_locks: Arc>>>>, } impl Default for Downloader { @@ -38,6 +47,7 @@ impl Default for Downloader { .build() .ok(), location: PathBuf::from(""), + download_locks: Default::default(), } } } @@ -51,6 +61,7 @@ impl Downloader { .build() .ok(), location, + download_locks: Default::default(), } } @@ -67,6 +78,14 @@ impl Downloader { Some(name) => name.to_string(), None => self.parse_name(url)?, }; + + // we do this to make sure only one download of a specific url is happening at a time + // otherwise the downloads corrupt each other (and we waste lots of system resources) + let download_lock = self.acquire_download_lock(&file_name).await; + // it's important that _lock remains in scope, otherwise it gets dropped and the lock is + // released before the download is complete + let _lock = download_lock.lock().await; + let file_path = self.location.join(file_name.as_str()); if file_path.exists() { log::debug!("File already exists: {:?}", file_path); @@ -157,6 +176,13 @@ impl Downloader { .ok_or_else(|| DownloaderError::NotFoundFileName(url.to_string())) .map(|s| s.to_string()) } + async fn acquire_download_lock(&self, file_name: &String) -> Arc> { + let mut lock_dict = self.download_locks.lock().await; + let download_lock = lock_dict + .entry(file_name.clone()) + .or_insert_with(|| Default::default()); + download_lock.clone() + } } #[cfg(test)]