diff --git a/README.md b/README.md index 87bcda114..7dbf56e50 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ fn main() -> hyperlight_host::Result<()> { let message = "Hello, World! I am executing inside of a VM :)\n".to_string(); // in order to call a function it first must be defined in the guest and exposed so that // the host can call it - multi_use_sandbox.call_guest_function_by_name::( + multi_use_sandbox.call::( "PrintOutput", message, )?; diff --git a/fuzz/fuzz_targets/host_print.rs b/fuzz/fuzz_targets/host_print.rs index c9c889e9e..59dc1ed13 100644 --- a/fuzz/fuzz_targets/host_print.rs +++ b/fuzz/fuzz_targets/host_print.rs @@ -27,7 +27,7 @@ fuzz_target!( |data: String| -> Corpus { let mut sandbox = SANDBOX.get().unwrap().lock().unwrap(); - let len: i32 = sandbox.call_guest_function_by_name::( + let len: i32 = sandbox.call::( "PrintOutput", data, ) diff --git a/src/hyperlight_host/benches/benchmarks.rs b/src/hyperlight_host/benches/benchmarks.rs index 96cc7ecf0..4896d9e14 100644 --- a/src/hyperlight_host/benches/benchmarks.rs +++ b/src/hyperlight_host/benches/benchmarks.rs @@ -16,9 +16,7 @@ limitations under the License. use criterion::{Criterion, criterion_group, criterion_main}; use hyperlight_host::GuestBinary; -use hyperlight_host::sandbox::{ - Callable, MultiUseSandbox, SandboxConfiguration, UninitializedSandbox, -}; +use hyperlight_host::sandbox::{MultiUseSandbox, SandboxConfiguration, UninitializedSandbox}; use hyperlight_testing::simple_guest_as_string; fn create_uninit_sandbox() -> UninitializedSandbox { @@ -99,10 +97,7 @@ fn guest_call_benchmark_large_param(c: &mut Criterion) { b.iter(|| { sandbox - .call_guest_function_by_name::<()>( - "LargeParameters", - (large_vec.clone(), large_string.clone()), - ) + .call::<()>("LargeParameters", (large_vec.clone(), large_string.clone())) .unwrap() }); }); diff --git a/src/hyperlight_host/examples/func_ctx/main.rs b/src/hyperlight_host/examples/func_ctx/main.rs index 2692f9487..8aedf0983 100644 --- a/src/hyperlight_host/examples/func_ctx/main.rs +++ b/src/hyperlight_host/examples/func_ctx/main.rs @@ -29,13 +29,9 @@ fn main() { // Do several calls against a sandbox running the `simpleguest.exe` binary, // and print their results - let res: String = sbox - .call_guest_function_by_name("Echo", "hello".to_string()) - .unwrap(); + let res: String = sbox.call("Echo", "hello".to_string()).unwrap(); println!("got Echo res: {res}"); - let res: i32 = sbox - .call_guest_function_by_name("CallMalloc", 200_i32) - .unwrap(); + let res: i32 = sbox.call("CallMalloc", 200_i32).unwrap(); println!("got CallMalloc res: {res}"); } diff --git a/src/hyperlight_host/examples/guest-debugging/main.rs b/src/hyperlight_host/examples/guest-debugging/main.rs index cc39b1705..e414b72d8 100644 --- a/src/hyperlight_host/examples/guest-debugging/main.rs +++ b/src/hyperlight_host/examples/guest-debugging/main.rs @@ -75,7 +75,7 @@ fn main() -> hyperlight_host::Result<()> { let message = "Hello, World! I am executing inside of a VM with debugger attached :)\n".to_string(); multi_use_sandbox_dbg - .call_guest_function_by_name::( + .call::( "PrintOutput", // function must be defined in the guest binary message.clone(), ) @@ -84,7 +84,7 @@ fn main() -> hyperlight_host::Result<()> { let message = "Hello, World! I am executing inside of a VM without debugger attached :)\n".to_string(); multi_use_sandbox - .call_guest_function_by_name::( + .call::( "PrintOutput", // function must be defined in the guest binary message.clone(), ) diff --git a/src/hyperlight_host/examples/hello-world/main.rs b/src/hyperlight_host/examples/hello-world/main.rs index e8461954a..4a403515b 100644 --- a/src/hyperlight_host/examples/hello-world/main.rs +++ b/src/hyperlight_host/examples/hello-world/main.rs @@ -40,7 +40,7 @@ fn main() -> hyperlight_host::Result<()> { // Call guest function let message = "Hello, World! I am executing inside of a VM :)\n".to_string(); multi_use_sandbox - .call_guest_function_by_name::( + .call::( "PrintOutput", // function must be defined in the guest binary message, ) diff --git a/src/hyperlight_host/examples/logging/main.rs b/src/hyperlight_host/examples/logging/main.rs index 470441e72..23a642886 100644 --- a/src/hyperlight_host/examples/logging/main.rs +++ b/src/hyperlight_host/examples/logging/main.rs @@ -50,7 +50,7 @@ fn main() -> Result<()> { // Call a guest function 5 times to generate some log entries. for _ in 0..5 { multiuse_sandbox - .call_guest_function_by_name::("Echo", "a".to_string()) + .call::("Echo", "a".to_string()) .unwrap(); } @@ -61,7 +61,7 @@ fn main() -> Result<()> { // Call a guest function that calls the HostPrint host function 5 times to generate some log entries. for _ in 0..5 { multiuse_sandbox - .call_guest_function_by_name::("PrintOutput", msg.clone()) + .call::("PrintOutput", msg.clone()) .unwrap(); } Ok(()) @@ -94,9 +94,7 @@ fn main() -> Result<()> { for _ in 0..NUM_CALLS { barrier.wait(); - multiuse_sandbox - .call_guest_function_by_name::<()>("Spin", ()) - .unwrap_err(); + multiuse_sandbox.call::<()>("Spin", ()).unwrap_err(); } thread.join().unwrap(); diff --git a/src/hyperlight_host/examples/metrics/main.rs b/src/hyperlight_host/examples/metrics/main.rs index 4328a3bfa..3401a3140 100644 --- a/src/hyperlight_host/examples/metrics/main.rs +++ b/src/hyperlight_host/examples/metrics/main.rs @@ -61,7 +61,7 @@ fn do_hyperlight_stuff() { // Call a guest function 5 times to generate some metrics. for _ in 0..5 { multiuse_sandbox - .call_guest_function_by_name::("Echo", "a".to_string()) + .call::("Echo", "a".to_string()) .unwrap(); } @@ -72,7 +72,7 @@ fn do_hyperlight_stuff() { // Call a guest function that calls the HostPrint host function 5 times to generate some metrics. for _ in 0..5 { multiuse_sandbox - .call_guest_function_by_name::("PrintOutput", msg.clone()) + .call::("PrintOutput", msg.clone()) .unwrap(); } Ok(()) @@ -108,9 +108,7 @@ fn do_hyperlight_stuff() { for _ in 0..NUM_CALLS { barrier.wait(); - multiuse_sandbox - .call_guest_function_by_name::<()>("Spin", ()) - .unwrap_err(); + multiuse_sandbox.call::<()>("Spin", ()).unwrap_err(); } for join_handle in join_handles { diff --git a/src/hyperlight_host/examples/tracing-chrome/main.rs b/src/hyperlight_host/examples/tracing-chrome/main.rs index 259a37057..7c4001106 100644 --- a/src/hyperlight_host/examples/tracing-chrome/main.rs +++ b/src/hyperlight_host/examples/tracing-chrome/main.rs @@ -36,7 +36,7 @@ fn main() -> Result<()> { // do the function call let current_time = std::time::Instant::now(); - let res: String = sbox.call_guest_function_by_name("Echo", "Hello, World!".to_string())?; + let res: String = sbox.call("Echo", "Hello, World!".to_string())?; let elapsed = current_time.elapsed(); println!("Function call finished in {:?}.", elapsed); assert_eq!(res, "Hello, World!"); diff --git a/src/hyperlight_host/examples/tracing-otlp/main.rs b/src/hyperlight_host/examples/tracing-otlp/main.rs index 69d318c1f..db5c9baf4 100644 --- a/src/hyperlight_host/examples/tracing-otlp/main.rs +++ b/src/hyperlight_host/examples/tracing-otlp/main.rs @@ -136,7 +136,7 @@ fn run_example(wait_input: bool) -> HyperlightResult<()> { // Call a guest function 5 times to generate some log entries. for _ in 0..5 { multiuse_sandbox - .call_guest_function_by_name::("Echo", "a".to_string()) + .call::("Echo", "a".to_string()) .unwrap(); } @@ -147,7 +147,7 @@ fn run_example(wait_input: bool) -> HyperlightResult<()> { // Call a guest function that calls the HostPrint host function 5 times to generate some log entries. for _ in 0..5 { multiuse_sandbox - .call_guest_function_by_name::("PrintOutput", msg.clone()) + .call::("PrintOutput", msg.clone()) .unwrap(); } @@ -179,9 +179,7 @@ fn run_example(wait_input: bool) -> HyperlightResult<()> { ); let _entered = span.enter(); barrier.wait(); - multiuse_sandbox - .call_guest_function_by_name::<()>("Spin", ()) - .unwrap_err(); + multiuse_sandbox.call::<()>("Spin", ()).unwrap_err(); } thread.join().expect("Thread panicked"); } diff --git a/src/hyperlight_host/examples/tracing-tracy/main.rs b/src/hyperlight_host/examples/tracing-tracy/main.rs index 03867f0d5..1c3912ebf 100644 --- a/src/hyperlight_host/examples/tracing-tracy/main.rs +++ b/src/hyperlight_host/examples/tracing-tracy/main.rs @@ -42,7 +42,7 @@ fn main() -> Result<()> { // do the function call let current_time = std::time::Instant::now(); - let res: String = sbox.call_guest_function_by_name("Echo", "Hello, World!".to_string())?; + let res: String = sbox.call("Echo", "Hello, World!".to_string())?; let elapsed = current_time.elapsed(); println!("Function call finished in {:?}.", elapsed); assert_eq!(res, "Hello, World!"); diff --git a/src/hyperlight_host/examples/tracing/main.rs b/src/hyperlight_host/examples/tracing/main.rs index 66020dcfa..fce5215e4 100644 --- a/src/hyperlight_host/examples/tracing/main.rs +++ b/src/hyperlight_host/examples/tracing/main.rs @@ -78,7 +78,7 @@ fn run_example() -> Result<()> { // Call a guest function 5 times to generate some log entries. for _ in 0..5 { multiuse_sandbox - .call_guest_function_by_name::("Echo", "a".to_string()) + .call::("Echo", "a".to_string()) .unwrap(); } @@ -89,7 +89,7 @@ fn run_example() -> Result<()> { // Call a guest function that calls the HostPrint host function 5 times to generate some log entries. for _ in 0..5 { multiuse_sandbox - .call_guest_function_by_name::("PrintOutput", msg.clone()) + .call::("PrintOutput", msg.clone()) .unwrap(); } Ok(()) @@ -131,9 +131,7 @@ fn run_example() -> Result<()> { ); let _entered = span.enter(); barrier.wait(); - multiuse_sandbox - .call_guest_function_by_name::<()>("Spin", ()) - .unwrap_err(); + multiuse_sandbox.call::<()>("Spin", ()).unwrap_err(); } for join_handle in join_handles { diff --git a/src/hyperlight_host/src/metrics/mod.rs b/src/hyperlight_host/src/metrics/mod.rs index a2f4026e5..76e9b93a3 100644 --- a/src/hyperlight_host/src/metrics/mod.rs +++ b/src/hyperlight_host/src/metrics/mod.rs @@ -116,12 +116,10 @@ mod tests { }); multi - .call_guest_function_by_name::("PrintOutput", "Hello".to_string()) + .call::("PrintOutput", "Hello".to_string()) .unwrap(); - multi - .call_guest_function_by_name::("Spin", ()) - .unwrap_err(); + multi.call::("Spin", ()).unwrap_err(); thread.join().unwrap(); snapshotter.snapshot() diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index 7692a4379..671dbe7db 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -63,6 +63,9 @@ pub struct MultiUseSandbox { dispatch_ptr: RawPtr, #[cfg(gdb)] dbg_mem_access_fn: DbgMemAccessHandlerWrapper, + /// If the current state of the sandbox has been captured in a snapshot, + /// that snapshot is stored here. + snapshot: Option, } impl MultiUseSandbox { @@ -87,6 +90,7 @@ impl MultiUseSandbox { dispatch_ptr, #[cfg(gdb)] dbg_mem_access_fn, + snapshot: None, } } @@ -115,15 +119,19 @@ impl MultiUseSandbox { /// ``` #[instrument(err(Debug), skip_all, parent = Span::current())] pub fn snapshot(&mut self) -> Result { + if let Some(snapshot) = &self.snapshot { + return Ok(snapshot.clone()); + } let mapped_regions_iter = self.vm.get_mapped_regions(); let mapped_regions_vec: Vec = mapped_regions_iter.cloned().collect(); let memory_snapshot = self .mem_mgr .unwrap_mgr_mut() .snapshot(self.id, mapped_regions_vec)?; - Ok(Snapshot { - inner: memory_snapshot, - }) + let inner = Arc::new(memory_snapshot); + let snapshot = Snapshot { inner }; + self.snapshot = Some(snapshot.clone()); + Ok(snapshot) } /// Restores the sandbox's memory to a previously captured snapshot state. @@ -159,6 +167,13 @@ impl MultiUseSandbox { /// ``` #[instrument(err(Debug), skip_all, parent = Span::current())] pub fn restore(&mut self, snapshot: &Snapshot) -> Result<()> { + if let Some(snap) = &self.snapshot { + if Arc::ptr_eq(&snap.inner, &snapshot.inner) { + // If the snapshot is already the current one, no need to restore + return Ok(()); + } + } + if self.id != snapshot.inner.sandbox_id() { return Err(SnapshotSandboxMismatch); } @@ -181,12 +196,15 @@ impl MultiUseSandbox { unsafe { self.vm.map_region(region)? }; } + // The restored snapshot is now our most current snapshot + self.snapshot = Some(snapshot.clone()); + Ok(()) } /// Calls a guest function by name with the specified arguments. /// - /// Changes made to the sandbox during execution are persisted. + /// Changes made to the sandbox during execution are *not* persisted. /// /// # Examples /// @@ -215,12 +233,62 @@ impl MultiUseSandbox { /// # Ok(()) /// # } /// ``` + #[doc(hidden)] + #[deprecated( + since = "0.8.0", + note = "Deprecated in favour of call and snapshot/restore." + )] #[instrument(err(Debug), skip(self, args), parent = Span::current())] pub fn call_guest_function_by_name( &mut self, func_name: &str, args: impl ParameterTuple, ) -> Result { + let snapshot = self.snapshot()?; + let res = self.call(func_name, args); + self.restore(&snapshot)?; + res + } + + /// Calls a guest function by name with the specified arguments. + /// + /// Changes made to the sandbox during execution are persisted. + /// + /// # Examples + /// + /// ```no_run + /// # use hyperlight_host::{MultiUseSandbox, UninitializedSandbox, GuestBinary}; + /// # fn example() -> Result<(), Box> { + /// let mut sandbox: MultiUseSandbox = UninitializedSandbox::new( + /// GuestBinary::FilePath("guest.bin".into()), + /// None + /// )?.evolve()?; + /// + /// // Call function with no arguments + /// let result: i32 = sandbox.call("GetCounter", ())?; + /// + /// // Call function with single argument + /// let doubled: i32 = sandbox.call("Double", 21)?; + /// assert_eq!(doubled, 42); + /// + /// // Call function with multiple arguments + /// let sum: i32 = sandbox.call("Add", (10, 32))?; + /// assert_eq!(sum, 42); + /// + /// // Call function returning string + /// let message: String = sandbox.call("Echo", "Hello, World!".to_string())?; + /// assert_eq!(message, "Hello, World!"); + /// # Ok(()) + /// # } + /// ``` + #[instrument(err(Debug), skip(self, args), parent = Span::current())] + pub fn call( + &mut self, + func_name: &str, + args: impl ParameterTuple, + ) -> Result { + // Reset snapshot since we are mutating the sandbox state + self.snapshot = None; maybe_time_and_emit_guest_call(func_name, || { let ret = self.call_guest_function_by_name_no_reset( func_name, @@ -254,6 +322,8 @@ impl MultiUseSandbox { // writes can be rolled back when necessary. log_then_return!("TODO: Writable mappings not yet supported"); } + // Reset snapshot since we are mutating the sandbox state + self.snapshot = None; unsafe { self.vm.map_region(rgn) }?; self.mem_mgr.unwrap_mgr_mut().mapped_rgns += 1; Ok(()) @@ -397,7 +467,7 @@ impl Callable for MultiUseSandbox { func_name: &str, args: impl ParameterTuple, ) -> Result { - self.call_guest_function_by_name(func_name, args) + self.call(func_name, args) } } @@ -429,9 +499,35 @@ mod tests { use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags, MemoryRegionType}; #[cfg(target_os = "linux")] use crate::mem::shared_mem::{ExclusiveSharedMemory, GuestSharedMemory, SharedMemory as _}; - use crate::sandbox::{Callable, SandboxConfiguration}; + use crate::sandbox::SandboxConfiguration; use crate::{GuestBinary, HyperlightError, MultiUseSandbox, Result, UninitializedSandbox}; + /// Tests that call_guest_function_by_name restores the state correctly + #[test] + fn test_call_guest_function_by_name() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + let snapshot = sbox.snapshot().unwrap(); + + let _ = sbox.call::("AddToStatic", 5i32).unwrap(); + let res: i32 = sbox.call("GetStatic", ()).unwrap(); + assert_eq!(res, 5); + + sbox.restore(&snapshot).unwrap(); + #[allow(deprecated)] + let _ = sbox + .call_guest_function_by_name::("AddToStatic", 5i32) + .unwrap(); + #[allow(deprecated)] + let res: i32 = sbox.call_guest_function_by_name("GetStatic", ()).unwrap(); + assert_eq!(res, 0); + } + // Tests to ensure that many (1000) function calls can be made in a call context with a small stack (1K) and heap(14K). // This test effectively ensures that the stack is being properly reset after each call and we are not leaking memory in the Guest. #[test] @@ -481,15 +577,13 @@ mod tests { let snapshot = sbox.snapshot().unwrap(); - let _ = sbox - .call_guest_function_by_name::("AddToStatic", 5i32) - .unwrap(); + let _ = sbox.call::("AddToStatic", 5i32).unwrap(); - let res: i32 = sbox.call_guest_function_by_name("GetStatic", ()).unwrap(); + let res: i32 = sbox.call("GetStatic", ()).unwrap(); assert_eq!(res, 5); sbox.restore(&snapshot).unwrap(); - let res: i32 = sbox.call_guest_function_by_name("GetStatic", ()).unwrap(); + let res: i32 = sbox.call("GetStatic", ()).unwrap(); assert_eq!(res, 0); } @@ -515,7 +609,7 @@ mod tests { let mut sbox: MultiUseSandbox = usbox.evolve()?; - let res: Result = sbox.call_guest_function_by_name("ViolateSeccompFilters", ()); + let res: Result = sbox.call("ViolateSeccompFilters", ()); #[cfg(feature = "seccomp")] match res { @@ -551,7 +645,7 @@ mod tests { let mut sbox: MultiUseSandbox = usbox.evolve()?; - let res: Result = sbox.call_guest_function_by_name("ViolateSeccompFilters", ()); + let res: Result = sbox.call("ViolateSeccompFilters", ()); match res { Ok(_) => {} @@ -605,7 +699,7 @@ mod tests { let mut sbox = ubox.evolve().unwrap(); let host_func_result = sbox - .call_guest_function_by_name::( + .call::( "CallGivenParamlessHostFuncThatReturnsI64", "Openat_Hostfunc".to_string(), ) @@ -634,8 +728,8 @@ mod tests { [libc::SYS_openat], )?; let mut sbox = ubox.evolve().unwrap(); - let host_func_result = sbox - .call_guest_function_by_name::( + let host_func_result: i64 = sbox + .call::( "CallGivenParamlessHostFuncThatReturnsI64", "Openat_Hostfunc".to_string(), ) @@ -658,7 +752,7 @@ mod tests { let mut multi_use_sandbox: MultiUseSandbox = usbox.evolve().unwrap(); - let res: Result<()> = multi_use_sandbox.call_guest_function_by_name("TriggerException", ()); + let res: Result<()> = multi_use_sandbox.call("TriggerException", ()); assert!(res.is_err()); @@ -698,9 +792,7 @@ mod tests { let mut multi_use_sandbox: MultiUseSandbox = usbox.evolve().unwrap(); - let res: i32 = multi_use_sandbox - .call_guest_function_by_name("GetStatic", ()) - .unwrap(); + let res: i32 = multi_use_sandbox.call("GetStatic", ()).unwrap(); assert_eq!(res, 0); } @@ -742,7 +834,7 @@ mod tests { let _guard = map_mem.lock.try_read().unwrap(); let actual: Vec = sbox - .call_guest_function_by_name( + .call( "ReadMappedBuffer", (guest_base as u64, expected.len() as u64), ) @@ -780,7 +872,7 @@ mod tests { // Execute should pass since memory is executable let succeed = sbox - .call_guest_function_by_name::( + .call::( "ExecMappedBuffer", (guest_base as u64, expected.len() as u64), ) @@ -789,7 +881,7 @@ mod tests { // write should fail because the memory is mapped as read-only let err = sbox - .call_guest_function_by_name::( + .call::( "WriteMappedBuffer", (guest_base as u64, expected.len() as u64), ) diff --git a/src/hyperlight_host/src/sandbox/snapshot.rs b/src/hyperlight_host/src/sandbox/snapshot.rs index e9e996e7a..c00aa4487 100644 --- a/src/hyperlight_host/src/sandbox/snapshot.rs +++ b/src/hyperlight_host/src/sandbox/snapshot.rs @@ -14,11 +14,12 @@ See the License for the specific language governing permissions and limitations under the License. */ +use std::sync::Arc; + use crate::mem::shared_mem_snapshot::SharedMemorySnapshot; /// A snapshot capturing the state of the memory in a `MultiUseSandbox`. #[derive(Clone)] pub struct Snapshot { - // TODO: Use Arc - pub(crate) inner: SharedMemorySnapshot, + pub(crate) inner: Arc, } diff --git a/src/hyperlight_host/src/sandbox/uninitialized.rs b/src/hyperlight_host/src/sandbox/uninitialized.rs index 65e9d9a80..6d905cf52 100644 --- a/src/hyperlight_host/src/sandbox/uninitialized.rs +++ b/src/hyperlight_host/src/sandbox/uninitialized.rs @@ -434,7 +434,7 @@ mod tests { let mut sandbox: MultiUseSandbox = uninitialized_sandbox.evolve().unwrap(); let res = sandbox - .call_guest_function_by_name::>("ReadFromUserMemory", (4u64, buffer.to_vec())) + .call::>("ReadFromUserMemory", (4u64, buffer.to_vec())) .expect("Failed to call ReadFromUserMemory"); assert_eq!(res, buffer.to_vec()); diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index 967d93c53..eceabd479 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -70,9 +70,7 @@ fn interrupt_host_call() { } }); - let result = sandbox - .call_guest_function_by_name::("CallHostSpin", ()) - .unwrap_err(); + let result = sandbox.call::("CallHostSpin", ()).unwrap_err(); assert!(matches!(result, HyperlightError::ExecutionCanceledByHost())); thread.join().unwrap(); @@ -96,16 +94,12 @@ fn interrupt_in_progress_guest_call() { assert!(interrupt_handle.dropped()); }); - let res = sbox1 - .call_guest_function_by_name::("Spin", ()) - .unwrap_err(); + let res = sbox1.call::("Spin", ()).unwrap_err(); assert!(matches!(res, HyperlightError::ExecutionCanceledByHost())); barrier.wait(); // Make sure we can still call guest functions after the VM was interrupted - sbox1 - .call_guest_function_by_name::("Echo", "hello".to_string()) - .unwrap(); + sbox1.call::("Echo", "hello".to_string()).unwrap(); // drop vm to make sure other thread can detect it drop(sbox1); @@ -131,15 +125,11 @@ fn interrupt_guest_call_in_advance() { }); barrier.wait(); // wait until `kill()` is called before starting the guest call - let res = sbox1 - .call_guest_function_by_name::("Spin", ()) - .unwrap_err(); + let res = sbox1.call::("Spin", ()).unwrap_err(); assert!(matches!(res, HyperlightError::ExecutionCanceledByHost())); // Make sure we can still call guest functions after the VM was interrupted - sbox1 - .call_guest_function_by_name::("Echo", "hello".to_string()) - .unwrap(); + sbox1.call::("Echo", "hello".to_string()).unwrap(); // drop vm to make sure other thread can detect it drop(sbox1); @@ -181,9 +171,9 @@ fn interrupt_same_thread() { for _ in 0..NUM_ITERS { barrier.wait(); sbox1 - .call_guest_function_by_name::("Echo", "hello".to_string()) + .call::("Echo", "hello".to_string()) .expect("Only sandbox 2 is allowed to be interrupted"); - match sbox2.call_guest_function_by_name::("Echo", "hello".to_string()) { + match sbox2.call::("Echo", "hello".to_string()) { Ok(_) | Err(HyperlightError::ExecutionCanceledByHost()) => { // Only allow successful calls or interrupted. // The call can be successful in case the call is finished before kill() is called. @@ -191,7 +181,7 @@ fn interrupt_same_thread() { _ => panic!("Unexpected return"), }; sbox3 - .call_guest_function_by_name::("Echo", "hello".to_string()) + .call::("Echo", "hello".to_string()) .expect("Only sandbox 2 is allowed to be interrupted"); } thread.join().expect("Thread should finish"); @@ -225,9 +215,9 @@ fn interrupt_same_thread_no_barrier() { barrier.wait(); for _ in 0..NUM_ITERS { sbox1 - .call_guest_function_by_name::("Echo", "hello".to_string()) + .call::("Echo", "hello".to_string()) .expect("Only sandbox 2 is allowed to be interrupted"); - match sbox2.call_guest_function_by_name::("Echo", "hello".to_string()) { + match sbox2.call::("Echo", "hello".to_string()) { Ok(_) | Err(HyperlightError::ExecutionCanceledByHost()) => { // Only allow successful calls or interrupted. // The call can be successful in case the call is finished before kill() is called. @@ -235,7 +225,7 @@ fn interrupt_same_thread_no_barrier() { _ => panic!("Unexpected return"), }; sbox3 - .call_guest_function_by_name::("Echo", "hello".to_string()) + .call::("Echo", "hello".to_string()) .expect("Only sandbox 2 is allowed to be interrupted"); } workload_done.store(true, Ordering::Relaxed); @@ -257,9 +247,7 @@ fn interrupt_moved_sandbox() { let thread = thread::spawn(move || { barrier2.wait(); - let res = sbox1 - .call_guest_function_by_name::("Spin", ()) - .unwrap_err(); + let res = sbox1.call::("Spin", ()).unwrap_err(); assert!(matches!(res, HyperlightError::ExecutionCanceledByHost())); }); @@ -272,9 +260,7 @@ fn interrupt_moved_sandbox() { assert!(interrupt_handle2.kill()); }); - let res = sbox2 - .call_guest_function_by_name::("Spin", ()) - .unwrap_err(); + let res = sbox2.call::("Spin", ()).unwrap_err(); assert!(matches!(res, HyperlightError::ExecutionCanceledByHost())); thread.join().expect("Thread should finish"); @@ -315,9 +301,7 @@ fn interrupt_custom_signal_no_and_retry_delay() { }); for _ in 0..NUM_ITERS { - let res = sbox1 - .call_guest_function_by_name::("Spin", ()) - .unwrap_err(); + let res = sbox1.call::("Spin", ()).unwrap_err(); assert!(matches!(res, HyperlightError::ExecutionCanceledByHost())); // immediately reenter another guest function call after having being cancelled, // so that the vcpu is running again before the interruptor-thread has a chance to see that the vcpu is not running @@ -354,7 +338,7 @@ fn interrupt_spamming_host_call() { barrier.wait(); // This guest call calls "HostFunc1" in a loop let res = sbox1 - .call_guest_function_by_name::("HostCallLoop", "HostFunc1".to_string()) + .call::("HostCallLoop", "HostFunc1".to_string()) .unwrap_err(); assert!(matches!(res, HyperlightError::ExecutionCanceledByHost())); @@ -369,7 +353,7 @@ fn print_four_args_c_guest() { let uninit = UninitializedSandbox::new(guest_path, None); let mut sbox1 = uninit.unwrap().evolve().unwrap(); - let res = sbox1.call_guest_function_by_name::( + let res = sbox1.call::( "PrintFourArgs", ("Test4".to_string(), 3_i32, 4_i64, "Tested".to_string()), ); @@ -383,7 +367,7 @@ fn guest_abort() { let mut sbox1 = new_uninit().unwrap().evolve().unwrap(); let error_code: u8 = 13; // this is arbitrary let res = sbox1 - .call_guest_function_by_name::<()>("GuestAbortWithCode", error_code as i32) + .call::<()>("GuestAbortWithCode", error_code as i32) .unwrap_err(); println!("{:?}", res); assert!( @@ -396,7 +380,7 @@ fn guest_abort_with_context1() { let mut sbox1 = new_uninit().unwrap().evolve().unwrap(); let res = sbox1 - .call_guest_function_by_name::<()>("GuestAbortWithMessage", (25_i32, "Oh no".to_string())) + .call::<()>("GuestAbortWithMessage", (25_i32, "Oh no".to_string())) .unwrap_err(); println!("{:?}", res); assert!( @@ -441,10 +425,7 @@ fn guest_abort_with_context2() { Proin sagittis nisl rhoncus mattis rhoncus urna. Magna eget est lorem ipsum."; let res = sbox1 - .call_guest_function_by_name::<()>( - "GuestAbortWithMessage", - (60_i32, abort_message.to_string()), - ) + .call::<()>("GuestAbortWithMessage", (60_i32, abort_message.to_string())) .unwrap_err(); println!("{:?}", res); assert!( @@ -463,7 +444,7 @@ fn guest_abort_c_guest() { let mut sbox1 = uninit.unwrap().evolve().unwrap(); let res = sbox1 - .call_guest_function_by_name::<()>( + .call::<()>( "GuestAbortWithMessage", (75_i32, "This is a test error message".to_string()), ) @@ -480,7 +461,7 @@ fn guest_panic() { let mut sbox1 = new_uninit_rust().unwrap().evolve().unwrap(); let res = sbox1 - .call_guest_function_by_name::<()>("guest_panic", "Error... error...".to_string()) + .call::<()>("guest_panic", "Error... error...".to_string()) .unwrap_err(); println!("{:?}", res); assert!( @@ -494,9 +475,7 @@ fn guest_malloc() { let mut sbox1 = new_uninit_rust().unwrap().evolve().unwrap(); let size_to_allocate = 2000_i32; - sbox1 - .call_guest_function_by_name::("TestMalloc", size_to_allocate) - .unwrap(); + sbox1.call::("TestMalloc", size_to_allocate).unwrap(); } #[test] @@ -506,7 +485,7 @@ fn guest_allocate_vec() { let size_to_allocate = 2000_i32; let res = sbox1 - .call_guest_function_by_name::( + .call::( "CallMalloc", // uses the rust allocator to allocate a vector on heap size_to_allocate, ) @@ -522,9 +501,7 @@ fn guest_malloc_abort() { let size = 20000000_i32; // some big number that should fail when allocated - let res = sbox1 - .call_guest_function_by_name::("TestMalloc", size) - .unwrap_err(); + let res = sbox1.call::("TestMalloc", size).unwrap_err(); println!("{:?}", res); assert!( matches!(res, HyperlightError::GuestAborted(code, _) if code == ErrorCode::MallocFailed as u8) @@ -544,7 +521,7 @@ fn guest_malloc_abort() { .unwrap(); let mut sbox2 = uninit.evolve().unwrap(); - let res = sbox2.call_guest_function_by_name::( + let res = sbox2.call::( "CallMalloc", // uses the rust allocator to allocate a vector on heap size_to_allocate as i32, ); @@ -564,13 +541,11 @@ fn dynamic_stack_allocate_c_guest() { let uninit = UninitializedSandbox::new(guest_path, None); let mut sbox1: MultiUseSandbox = uninit.unwrap().evolve().unwrap(); - let res: i32 = sbox1 - .call_guest_function_by_name("StackAllocate", 100_i32) - .unwrap(); + let res: i32 = sbox1.call("StackAllocate", 100_i32).unwrap(); assert_eq!(res, 100); let res = sbox1 - .call_guest_function_by_name::("StackAllocate", 0x800_0000_i32) + .call::("StackAllocate", 0x800_0000_i32) .unwrap_err(); assert!(matches!(res, HyperlightError::StackOverflow())); } @@ -580,7 +555,7 @@ fn dynamic_stack_allocate_c_guest() { fn static_stack_allocate() { let mut sbox1 = new_uninit().unwrap().evolve().unwrap(); - let res: i32 = sbox1.call_guest_function_by_name("SmallVar", ()).unwrap(); + let res: i32 = sbox1.call("SmallVar", ()).unwrap(); assert_eq!(res, 1024); } @@ -588,9 +563,7 @@ fn static_stack_allocate() { #[test] fn static_stack_allocate_overflow() { let mut sbox1 = new_uninit().unwrap().evolve().unwrap(); - let res = sbox1 - .call_guest_function_by_name::("LargeVar", ()) - .unwrap_err(); + let res = sbox1.call::("LargeVar", ()).unwrap_err(); assert!(matches!(res, HyperlightError::StackOverflow())); } @@ -601,9 +574,7 @@ fn recursive_stack_allocate() { let iterations = 1_i32; - sbox1 - .call_guest_function_by_name::("StackOverflow", iterations) - .unwrap(); + sbox1.call::("StackOverflow", iterations).unwrap(); } // checks stack guard page (between guest stack and heap) @@ -628,7 +599,7 @@ fn guard_page_check() { // we have to create a sandbox each iteration because can't reuse after MMIO error in release mode let mut sbox1 = new_uninit_rust().unwrap().evolve().unwrap(); - let result = sbox1.call_guest_function_by_name::("test_write_raw_ptr", offset); + let result = sbox1.call::("test_write_raw_ptr", offset); if guard_range.contains(&offset) { // should have failed assert!(matches!( @@ -646,9 +617,7 @@ fn guard_page_check_2() { // this test is rust-guest only let mut sbox1 = new_uninit_rust().unwrap().evolve().unwrap(); - let result = sbox1 - .call_guest_function_by_name::<()>("InfiniteRecursion", ()) - .unwrap_err(); + let result = sbox1.call::<()>("InfiniteRecursion", ()).unwrap_err(); assert!(matches!(result, HyperlightError::StackOverflow())); } @@ -656,9 +625,7 @@ fn guard_page_check_2() { fn execute_on_stack() { let mut sbox1 = new_uninit().unwrap().evolve().unwrap(); - let result = sbox1 - .call_guest_function_by_name::("ExecuteOnStack", ()) - .unwrap_err(); + let result = sbox1.call::("ExecuteOnStack", ()).unwrap_err(); let err = result.to_string(); assert!( @@ -671,7 +638,7 @@ fn execute_on_stack() { #[ignore] // ran from Justfile because requires feature "executable_heap" fn execute_on_heap() { let mut sbox1 = new_uninit_rust().unwrap().evolve().unwrap(); - let result = sbox1.call_guest_function_by_name::("ExecuteOnHeap", ()); + let result = sbox1.call::("ExecuteOnHeap", ()); println!("{:#?}", result); #[cfg(feature = "executable_heap")] @@ -693,9 +660,7 @@ fn recursive_stack_allocate_overflow() { let iterations = 10_i32; - let res = sbox1 - .call_guest_function_by_name::<()>("StackOverflow", iterations) - .unwrap_err(); + let res = sbox1.call::<()>("StackOverflow", iterations).unwrap_err(); println!("{:?}", res); assert!(matches!(res, HyperlightError::StackOverflow())); } @@ -768,7 +733,7 @@ fn log_test_messages(levelfilter: Option) { let message = format!("Hello from log_message level {}", level as i32); sbox1 - .call_guest_function_by_name::<()>("LogMessage", (message.to_string(), level as i32)) + .call::<()>("LogMessage", (message.to_string(), level as i32)) .unwrap(); } } diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index 7ca0ba5dc..32047653c 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -19,7 +19,7 @@ use std::sync::mpsc::channel; use std::sync::{Arc, Mutex}; use common::new_uninit; -use hyperlight_host::sandbox::{Callable, SandboxConfiguration}; +use hyperlight_host::sandbox::SandboxConfiguration; use hyperlight_host::{ GuestBinary, HyperlightError, MultiUseSandbox, Result, UninitializedSandbox, new_error, }; @@ -85,9 +85,7 @@ fn float_roundtrip() { ]; let mut sandbox: MultiUseSandbox = new_uninit().unwrap().evolve().unwrap(); for f in doubles.iter() { - let res: f64 = sandbox - .call_guest_function_by_name("EchoDouble", *f) - .unwrap(); + let res: f64 = sandbox.call("EchoDouble", *f).unwrap(); assert!( res.total_cmp(f).is_eq(), @@ -97,9 +95,7 @@ fn float_roundtrip() { ); } for f in floats.iter() { - let res: f32 = sandbox - .call_guest_function_by_name("EchoFloat", *f) - .unwrap(); + let res: f32 = sandbox.call("EchoFloat", *f).unwrap(); assert!( res.total_cmp(f).is_eq(), @@ -115,7 +111,7 @@ fn float_roundtrip() { fn invalid_guest_function_name() { for mut sandbox in get_simpleguest_sandboxes(None).into_iter() { let fn_name = "FunctionDoesntExist"; - let res = sandbox.call_guest_function_by_name::(fn_name, ()); + let res = sandbox.call::(fn_name, ()); println!("{:?}", res); assert!( matches!(res.unwrap_err(), HyperlightError::GuestError(hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode::GuestFunctionNotFound, error_name) if error_name == fn_name) @@ -128,7 +124,7 @@ fn invalid_guest_function_name() { fn set_static() { for mut sandbox in get_simpleguest_sandboxes(None).into_iter() { let fn_name = "SetStatic"; - let res = sandbox.call_guest_function_by_name::(fn_name, ()); + let res = sandbox.call::(fn_name, ()); println!("{:?}", res); assert!(res.is_ok()); // the result is the size of the static array in the guest @@ -162,7 +158,7 @@ fn multiple_parameters() { macro_rules! test_case { ($sandbox:ident, $rx:ident, $name:literal, ($($p:ident),+)) => {{ let ($($p),+, ..) = args.clone(); - let res: i32 = $sandbox.call_guest_function_by_name($name, ($($p.0,)+)).unwrap(); + let res: i32 = $sandbox.call($name, ($($p.0,)+)).unwrap(); println!("{res:?}"); let output = $rx.try_recv().unwrap(); println!("{output:?}"); @@ -188,7 +184,7 @@ fn multiple_parameters() { #[cfg_attr(target_os = "windows", serial)] // using LoadLibrary requires serial tests fn incorrect_parameter_type() { for mut sandbox in get_simpleguest_sandboxes(None) { - let res = sandbox.call_guest_function_by_name::( + let res = sandbox.call::( "Echo", 2_i32, // should be string ); @@ -206,7 +202,7 @@ fn incorrect_parameter_type() { #[cfg_attr(target_os = "windows", serial)] // using LoadLibrary requires serial tests fn incorrect_parameter_num() { for mut sandbox in get_simpleguest_sandboxes(None).into_iter() { - let res = sandbox.call_guest_function_by_name::("Echo", ("1".to_string(), 2_i32)); + let res = sandbox.call::("Echo", ("1".to_string(), 2_i32)); assert!(matches!( res.unwrap_err(), HyperlightError::GuestError( @@ -237,7 +233,7 @@ fn max_memory_sandbox() { fn iostack_is_working() { for mut sandbox in get_simpleguest_sandboxes(None).into_iter() { let res: i32 = sandbox - .call_guest_function_by_name::("ThisIsNotARealFunctionButTheNameIsImportant", ()) + .call::("ThisIsNotARealFunctionButTheNameIsImportant", ()) .unwrap(); assert_eq!(res, 99); } @@ -260,19 +256,15 @@ fn simple_test_helper() -> Result<()> { let message2 = "world"; for mut sandbox in get_simpleguest_sandboxes(Some(writer.into())).into_iter() { - let res: i32 = sandbox - .call_guest_function_by_name("PrintOutput", message.to_string()) - .unwrap(); + let res: i32 = sandbox.call("PrintOutput", message.to_string()).unwrap(); assert_eq!(res, 5); - let res: String = sandbox - .call_guest_function_by_name("Echo", message2.to_string()) - .unwrap(); + let res: String = sandbox.call("Echo", message2.to_string()).unwrap(); assert_eq!(res, "world"); let buffer = [1u8, 2, 3, 4, 5, 6]; let res: Vec = sandbox - .call_guest_function_by_name("GetSizePrefixedBuffer", buffer.to_vec()) + .call("GetSizePrefixedBuffer", buffer.to_vec()) .unwrap(); assert_eq!(res, buffer); } @@ -332,7 +324,7 @@ fn callback_test_helper() -> Result<()> { // call guest function that calls host function let mut init_sandbox: MultiUseSandbox = sandbox.evolve()?; let msg = "Hello world"; - init_sandbox.call_guest_function_by_name::("GuestMethod1", msg.to_string())?; + init_sandbox.call::("GuestMethod1", msg.to_string())?; let messages = rx.try_iter().collect::>(); assert_eq!(messages, [format!("Hello from GuestFunction1, {msg}")]); @@ -375,7 +367,7 @@ fn host_function_error() -> Result<()> { let mut init_sandbox: MultiUseSandbox = sandbox.evolve()?; let msg = "Hello world"; let res = init_sandbox - .call_guest_function_by_name::("GuestMethod1", msg.to_string()) + .call::("GuestMethod1", msg.to_string()) .unwrap_err(); assert!(matches!(res, HyperlightError::Error(msg) if msg == "Host function error!")); }