From 0e1f98f40bc93aa2ab7d1d66399d14cbddc04fe0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gerg=C5=91=20S=C3=A1lyi?= Date: Sat, 12 Apr 2025 19:51:53 +0200 Subject: [PATCH] Add unix signal handling Implement unix signal handling using a basic self pipe Shutdown gracefully on termination signals TERM, INT or HUP Reserve signals USR1 and USR2 for future use --- Cargo.lock | 2 + Cargo.toml | 2 + src/main.rs | 34 +++++- src/signal.rs | 278 ++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 315 insertions(+), 1 deletion(-) create mode 100644 src/signal.rs diff --git a/Cargo.lock b/Cargo.lock index 8a2038b..5a21766 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -697,9 +697,11 @@ dependencies = [ "env_logger", "fast_image_resize", "image", + "libc", "log", "mio", "niri-ipc", + "rustix", "serde", "serde_json", "smithay-client-toolkit", diff --git a/Cargo.toml b/Cargo.toml index e6a9304..6d0911a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,8 +19,10 @@ clap = { version = "4.5.3", features = ["derive"] } env_logger = "0.11.3" fast_image_resize = "5.0.0" image = "0.25.0" +libc = "0.2.171" log = "0.4.21" mio = { version = "1.0.2", features = ["os-ext", "os-poll"] } +rustix = "0.38.44" serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" swayipc = "3.0.2" diff --git a/src/main.rs b/src/main.rs index 8715c0d..3ecdf54 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ +mod compositors; mod cli; mod image; -mod compositors; +mod signal; mod wayland; use std::{ @@ -37,6 +38,7 @@ use smithay_client_toolkit::reexports::protocols use crate::{ cli::{Cli, PixelFormat}, compositors::{Compositor, ConnectionTask, WorkspaceVisible}, + signal::SignalPipe, wayland::State, }; @@ -128,6 +130,22 @@ fn run() -> anyhow::Result<()> { const SWAY: Token = Token(1); ConnectionTask::spawn_subscribe_event_loop(compositor, tx, waker); + const SIGNAL: Token = Token(2); + let signal_pipe = match SignalPipe::new() { + Ok(signal_pipe) => { + poll.registry().register( + &mut SourceFd(&signal_pipe.as_raw_fd()), + SIGNAL, + Interest::READABLE + ).unwrap(); + Some(signal_pipe) + }, + Err(e) => { + error!("Failed to set up signal handling: {e}"); + None + } + }; + loop { event_queue.flush().unwrap(); event_queue.dispatch_pending(&mut state).unwrap(); @@ -150,6 +168,20 @@ fn run() -> anyhow::Result<()> { &mut event_queue ), SWAY => handle_sway_event(&mut state, &rx), + SIGNAL => match signal_pipe.as_ref().unwrap().read() { + Err(e) => error!("Failed to read the signal pipe: {e}"), + Ok(signal_flags) => { + if let Some(signal) = signal_flags.any_termination() { + info!("Received signal {signal}, exiting"); + return Ok(()); + } else if signal_flags.has_usr1() + || signal_flags.has_usr2() + { + error!("Received signal USR1 or USR2 is \ + reserved for future functionality"); + } + }, + }, _ => unreachable!() } } diff --git a/src/signal.rs b/src/signal.rs new file mode 100644 index 0000000..8daff26 --- /dev/null +++ b/src/signal.rs @@ -0,0 +1,278 @@ +use std::{ + ffi::c_int, + io, + mem::{ManuallyDrop, MaybeUninit}, + os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd}, + ptr, + sync::atomic::{AtomicI32, Ordering::Relaxed}, +}; + +use libc::{ + raise, SA_RESETHAND, SA_RESTART, SIG_DFL, SIG_ERR, + SIGHUP, SIGINT, SIGUSR1, SIGUSR2, SIGTERM, + sigaction, sigemptyset, signal, sigset_t, write, +}; +use rustix::{ + fs::{fcntl_setfl, OFlags}, + io::{fcntl_setfd, FdFlags, read_uninit}, + pipe::pipe, +}; + +const TERM_SIGNALS: [c_int; 3] = [SIGHUP, SIGINT, SIGTERM]; +const OTHER_SIGNALS: [c_int; 2] = [SIGUSR1, SIGUSR2]; + +const TERM: u8 = 1 << 0; +const INT: u8 = 1 << 1; +const HUP: u8 = 1 << 2; +const USR1: u8 = 1 << 3; +const USR2: u8 = 1 << 4; + +static PIPE_FD: AtomicI32 = AtomicI32::new(-1); + +pub struct SignalPipe { + read_half: OwnedFd, +} + +impl SignalPipe { + pub fn new() -> io::Result { + unsafe { + let (read_half, write_half) = pipe()?; + fcntl_setfd(&read_half, FdFlags::CLOEXEC)?; + fcntl_setfd(&write_half, FdFlags::CLOEXEC)?; + fcntl_setfl(&read_half, OFlags::NONBLOCK)?; + fcntl_setfl(&write_half, OFlags::NONBLOCK)?; + PIPE_FD.compare_exchange( + -1, + write_half.as_raw_fd(), + Relaxed, + Relaxed, + ).unwrap(); + let _ = ManuallyDrop::new(write_half); + let ret = SignalPipe { read_half }; + let sigset_empty = sigset_empty()?; + for signum in TERM_SIGNALS { + sigaction_set_handler( + signum, + handle_termination_signals, + sigset_empty, + SA_RESTART | SA_RESETHAND, + )?; + } + for signum in OTHER_SIGNALS { + sigaction_set_handler( + signum, + handle_other_signals, + sigset_empty, + SA_RESTART, + )?; + } + Ok(ret) + } + } + + pub fn read(&self) -> io::Result { + let mut buf = [MaybeUninit::::uninit(); 64]; + let mut flags = 0; + for byte in read_uninit(&self.read_half, &mut buf)?.0 { + assert_ne!(*byte, 0); + flags |= *byte; + } + Ok(SignalFlags(flags)) + } +} + +impl Drop for SignalPipe { + fn drop(&mut self) { + for signum in OTHER_SIGNALS { + sigaction_reset_default(signum).unwrap(); + } + for signum in TERM_SIGNALS { + sigaction_reset_default(signum).unwrap(); + } + let write_half_fd = PIPE_FD.swap(-1, Relaxed); + assert_ne!(write_half_fd, -1); + drop(unsafe { OwnedFd::from_raw_fd(write_half_fd) }); + } +} + +impl AsRawFd for SignalPipe { + fn as_raw_fd(&self) -> RawFd { + self.read_half.as_raw_fd() + } +} + +#[derive(Clone, Copy)] +pub struct SignalFlags(u8); + +impl SignalFlags { + pub fn any_termination(self) -> Option<&'static str> { + if self.0 & TERM != 0 { + Some("TERM") + } else if self.0 & INT != 0 { + Some("INT") + } else if self.0 & HUP != 0 { + Some("HUP") + } else { + None + } + } + pub fn has_usr1(self) -> bool { + self.0 & USR1 != 0 + } + pub fn has_usr2(self) -> bool { + self.0 & USR2 != 0 + } +} + +fn sigset_empty() -> io::Result { + unsafe { + let mut sigset = MaybeUninit::uninit(); + if sigemptyset(sigset.as_mut_ptr()) < 0 { + return Err(io::Error::last_os_error()) + } + Ok(sigset.assume_init()) + } +} + +unsafe fn sigaction_set_handler( + signum: c_int, + handler: extern "C" fn(c_int), + mask: sigset_t, + flags: c_int +) -> io::Result<()> { + unsafe { + if sigaction( + signum, + &sigaction { + sa_sigaction: handler as _, + sa_mask: mask, + sa_flags: flags, + sa_restorer: None, + }, + ptr::null_mut(), + ) < 0 { + return Err(io::Error::last_os_error()) + } + Ok(()) + } +} + +fn sigaction_reset_default(signum: c_int) -> io::Result<()> { + unsafe { + if signal(signum, SIG_DFL) == SIG_ERR { + return Err(io::Error::last_os_error()) + } + Ok(()) + } +} + +extern "C" fn handle_termination_signals(signum: c_int) { + unsafe { + let _errno_guard = ErrnoGuard::new(); + let byte: u8 = match signum { + SIGTERM => TERM, + SIGINT => INT, + SIGHUP => HUP, + _ => 0, + }; + // In case of an error termination signals will have SA_RESETHAND set + // so re-raise the signal to invoke the default handler + if write( + PIPE_FD.load(Relaxed), + ptr::from_ref(&byte).cast(), + 1, + ) != 1 { + raise(signum); + } + } +} + +extern "C" fn handle_other_signals(signum: c_int) { + unsafe { + let _errno_guard = ErrnoGuard::new(); + let byte: u8 = match signum { + SIGUSR1 => USR1, + SIGUSR2 => USR2, + _ => 0, + }; + // In case of an error ignore non-termination signals + let _ = write( + PIPE_FD.load(Relaxed), + ptr::from_ref(&byte).cast(), + 1, + ); + } +} + +struct ErrnoGuard(i32); + +impl ErrnoGuard { + unsafe fn new() -> ErrnoGuard { + ErrnoGuard(errno()) + } +} + +impl Drop for ErrnoGuard { + fn drop(&mut self) { + set_errno(self.0) + } +} + +// Based on the Rust Standard Library: +// .rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/std/src/sys/pal/unix/os.rs +// with changes for stable rust from the errno crate: +// https://github.com/lambda-fairy/rust-errno/blob/main/src/unix.rs +// under licence MIT OR Apache-2.0 + +unsafe extern "C" { + #[cfg_attr( + any( + target_os = "linux", + target_os = "emscripten", + target_os = "fuchsia", + target_os = "l4re", + target_os = "hurd", + target_os = "dragonfly", + ), + link_name = "__errno_location" + )] + #[cfg_attr( + any( + target_os = "netbsd", + target_os = "openbsd", + target_os = "cygwin", + target_os = "android", + target_os = "redox", + target_os = "nuttx", + target_env = "newlib", + target_os = "vxworks", + ), + link_name = "__errno" + )] + #[cfg_attr( + any(target_os = "solaris", target_os = "illumos"), + link_name = "___errno" + )] + #[cfg_attr(target_os = "nto", link_name = "__get_errno_ptr")] + #[cfg_attr( + any(target_os = "freebsd", target_vendor = "apple"), + link_name = "__error" + )] + #[cfg_attr(target_os = "haiku", link_name = "_errnop")] + #[cfg_attr(target_os = "aix", link_name = "_Errno")] + // SAFETY: this will always return the same pointer on a given thread. + fn errno_location() -> *mut c_int; +} + +/// Returns the platform-specific value of errno +#[inline] +fn errno() -> i32 { + unsafe { (*errno_location()) as i32 } +} + +/// Sets the platform-specific value of errno +// needed for readdir and syscall! +#[inline] +fn set_errno(e: i32) { + unsafe { *errno_location() = e as c_int } +}