From d50fa7470f978302f3d2cc9162269d8f9306b473 Mon Sep 17 00:00:00 2001 From: Roman Kisel Date: Tue, 8 Jul 2025 11:28:44 -0700 Subject: [PATCH] host file access --- Cargo.lock | 17 + Cargo.toml | 2 + openhcl/diag_client/Cargo.toml | 1 + openhcl/diag_client/src/lib.rs | 21 + openhcl/diag_proto/src/diag.proto | 6 + openhcl/diag_server/src/diag_service.rs | 37 ++ openhcl/ohcldiag-dev/Cargo.toml | 2 + openhcl/ohcldiag-dev/src/main.rs | 53 ++ openhcl/underhill_core/Cargo.toml | 2 + openhcl/underhill_core/src/dispatch/mod.rs | 35 ++ openhcl/underhill_core/src/lib.rs | 20 + support/host_file_access/Cargo.toml | 18 + support/host_file_access/src/lib.rs | 539 +++++++++++++++++++++ 13 files changed, 753 insertions(+) create mode 100644 support/host_file_access/Cargo.toml create mode 100644 support/host_file_access/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 2a59a0d297..0e12eb9187 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1041,6 +1041,7 @@ dependencies = [ "fs-err", "futures", "guid", + "host_file_access", "inspect", "inspect_proto", "mesh_rpc", @@ -2882,6 +2883,18 @@ dependencies = [ "tracing", ] +[[package]] +name = "host_file_access" +version = "0.0.0" +dependencies = [ + "bitfield-struct 0.10.1", + "futures", + "open_enum", + "thiserror 2.0.12", + "tracing", + "zerocopy 0.8.24", +] + [[package]] name = "http" version = "1.3.1" @@ -4757,6 +4770,7 @@ dependencies = [ "fs-err", "futures", "futures-concurrency", + "host_file_access", "inspect", "kmsg", "mesh", @@ -4767,6 +4781,7 @@ dependencies = [ "thiserror 2.0.12", "tracing-subscriber", "unicycle", + "zerocopy 0.8.24", ] [[package]] @@ -7521,6 +7536,7 @@ dependencies = [ "guid", "hcl", "hcl_compat_uefi_nvram_storage", + "host_file_access", "hvdef", "hyperv_ic_guest", "hyperv_ic_resources", @@ -7568,6 +7584,7 @@ dependencies = [ "serde_helpers", "serde_json", "serial_16550_resources", + "sha2", "socket2", "sparse_mmap", "state_unit", diff --git a/Cargo.toml b/Cargo.toml index 7b3397be98..8cf706bbd4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ members = [ # fuzzing "support/inspect/fuzz", "support/mesh/mesh_rpc/fuzz", + "support/host_file_access", "support/sparse_mmap/fuzz", "support/ucs2/fuzz", "vm/devices/chipset/fuzz", @@ -140,6 +141,7 @@ safeatomic = { path = "support/safeatomic" } serde_helpers = { path = "support/serde_helpers" } sev_guest_device = { path = "support/sev_guest_device" } sparse_mmap = { path = "support/sparse_mmap" } +host_file_access = { path = "support/host_file_access" } task_control = { path = "support/task_control" } tdx_guest_device = { path = "support/tdx_guest_device" } tee_call = { path = "support/tee_call" } diff --git a/openhcl/diag_client/Cargo.toml b/openhcl/diag_client/Cargo.toml index c0c375ee73..15f885d13a 100644 --- a/openhcl/diag_client/Cargo.toml +++ b/openhcl/diag_client/Cargo.toml @@ -16,6 +16,7 @@ inspect.workspace = true mesh_rpc.workspace = true unix_socket.workspace = true pal_async.workspace = true +host_file_access.workspace = true vmsocket.workspace = true anyhow.workspace = true diff --git a/openhcl/diag_client/src/lib.rs b/openhcl/diag_client/src/lib.rs index 6e49089de4..5e8fa56350 100644 --- a/openhcl/diag_client/src/lib.rs +++ b/openhcl/diag_client/src/lib.rs @@ -772,6 +772,27 @@ impl DiagClient { Ok(state.data) } + + /// Connect to the synthetic data endpoint. + pub async fn synthetic_data( + &self, + id: &String, + ) -> anyhow::Result> { + let (conn, socket) = self.connect_data().await?; + self.ttrpc + .call() + .start( + diag_proto::UnderhillDiag::HostFile, + diag_proto::HostFileRequest { + id: id.clone(), + conn, + }, + ) + .await + .map_err(grpc_status)?; + + Ok(socket) + } } fn grpc_status(status: Status) -> anyhow::Error { diff --git a/openhcl/diag_proto/src/diag.proto b/openhcl/diag_proto/src/diag.proto index 08cc2df11b..bf6678bb44 100644 --- a/openhcl/diag_proto/src/diag.proto +++ b/openhcl/diag_proto/src/diag.proto @@ -29,6 +29,7 @@ service UnderhillDiag { rpc ReadFile(FileRequest) returns (google.protobuf.Empty); rpc DumpSavedState(google.protobuf.Empty) returns (DumpSavedStateResponse); rpc PacketCapture(NetworkPacketCaptureRequest) returns (NetworkPacketCaptureResponse); + rpc HostFile(HostFileRequest) returns (google.protobuf.Empty); } message ExecRequest { @@ -106,3 +107,8 @@ message NetworkPacketCaptureResponse { message CrashRequest { int32 pid = 1; } + +message HostFileRequest { + string id = 1; + uint64 conn = 2; +} diff --git a/openhcl/diag_server/src/diag_service.rs b/openhcl/diag_server/src/diag_service.rs index 2d96f41455..3db089122e 100644 --- a/openhcl/diag_server/src/diag_service.rs +++ b/openhcl/diag_server/src/diag_service.rs @@ -12,6 +12,7 @@ use diag_proto::ExecRequest; use diag_proto::ExecResponse; use diag_proto::FILE_LINE_MAX; use diag_proto::FileRequest; +use diag_proto::HostFileRequest; use diag_proto::KmsgRequest; use diag_proto::NetworkPacketCaptureRequest; use diag_proto::NetworkPacketCaptureResponse; @@ -84,6 +85,11 @@ pub enum DiagRequest { Resume(FailableRpc<(), ()>), /// Save VTL2 state Save(FailableRpc<(), Vec>), + /// Get the `vmlinux` image from the host. + /// Just experimental, not used in production -- must be signed, even better come inside an IGVM. + VmLinux(FailableRpc), + /// Write the log from the host. + SomeLog(FailableRpc), /// Setup network trace PacketCapture(FailableRpc, PacketCaptureParams>), /// Profile VTL2 @@ -240,6 +246,10 @@ impl DiagServiceHandler { UnderhillDiag::DumpSavedState((), response) => response.send(grpc_result( ctx.until_cancelled(self.handle_dump_saved_state()).await, )), + UnderhillDiag::HostFile(request, response) => response.send(grpc_result( + ctx.until_cancelled(self.handle_host_file_access(driver, &request)) + .await, + )), } } @@ -581,6 +591,33 @@ impl DiagServiceHandler { .await } + async fn handle_host_file_access( + &self, + _driver: &(impl Driver + Spawn + Clone), + request: &HostFileRequest, + ) -> anyhow::Result<()> { + let params = self.take_connection(request.conn).await?; + match request.id.as_str() { + "vmlinux" => { + tracing::info!("Reading vmlinux from the host"); + self.request_send + .call_failable(DiagRequest::VmLinux, params.into_inner()) + .await?; + } + "log" => { + tracing::info!("Wring the log from the host"); + self.request_send + .call_failable(DiagRequest::SomeLog, params.into_inner()) + .await?; + } + _ => { + tracing::warn!("Unsupported host file access request: {}", request.id); + } + } + + Ok(()) + } + async fn handle_packet_capture( &self, request: &NetworkPacketCaptureRequest, diff --git a/openhcl/ohcldiag-dev/Cargo.toml b/openhcl/ohcldiag-dev/Cargo.toml index c4070fee02..093bf83a93 100644 --- a/openhcl/ohcldiag-dev/Cargo.toml +++ b/openhcl/ohcldiag-dev/Cargo.toml @@ -14,6 +14,7 @@ inspect.workspace = true mesh.workspace = true pal_async.workspace = true pal.workspace = true +host_file_access.workspace = true term.workspace = true anyhow.workspace = true @@ -27,6 +28,7 @@ socket2.workspace = true thiserror.workspace = true tracing-subscriber = { workspace = true, features = ["env-filter"] } unicycle.workspace = true +zerocopy.workspace = true [lints] workspace = true diff --git a/openhcl/ohcldiag-dev/src/main.rs b/openhcl/ohcldiag-dev/src/main.rs index 6276895914..9a5b01dc2d 100644 --- a/openhcl/ohcldiag-dev/src/main.rs +++ b/openhcl/ohcldiag-dev/src/main.rs @@ -157,6 +157,24 @@ enum Command { #[clap(long, requires = "serial")] pipe_path: Option, }, + /// Requests data by the string ID, providing the ability to download + /// synthetized data such that requires seeks within the file. + HostFile { + /// The ID of the data to synthesize. + #[clap(short, long)] + id: String, + /// The output file path. + dst: PathBuf, + /// Maximum data size the host allows. + #[clap(short, long)] + size_limit: Option, + /// Allow existing file. + #[clap(short, long, default_value = "false")] + existing: bool, + /// Allow writing to the file. + #[clap(short, long, default_value = "false")] + write: bool, + }, /// Writes the contents of the file. File { /// Keep waiting for and writing new data as its logged. @@ -873,6 +891,41 @@ pub fn main() -> anyhow::Result<()> { let mut file = create_or_stderr(&output)?; file.write_all(&client.dump_saved_state().await?)?; } + Command::HostFile { + id, + dst, + size_limit, + existing, + write, + } => { + let medium = fs_err::OpenOptions::new() + .read(true) + .write(write) + .create(!existing) + .open(dst) + .context("failed to open file")?; + let client = new_client(driver.clone(), &vm)?; + let transport = client.synthetic_data(&id).await?; + + let mut data_stor = + host_file_access::HostFileStorage::new(medium, size_limit.into()); + match data_stor.run_async(transport).await { + Ok(_) => {} + Err(host_file_access::HostFileError::EndOfFile) => { + eprintln!("Synthetic data {id} transfer complete, end of file reached."); + } + Err(e) => { + eprintln!("Failed to run data transfer: {e}"); + return Err(anyhow::Error::from(e)); + } + } + + let bytes_written = data_stor.bytes_written(); + println!( + "Synthetic data {id} transfer complete, {} bytes written.", + bytes_written + ); + } } Ok(()) }) diff --git a/openhcl/underhill_core/Cargo.toml b/openhcl/underhill_core/Cargo.toml index da4ed2d7bf..1851fe6ab6 100644 --- a/openhcl/underhill_core/Cargo.toml +++ b/openhcl/underhill_core/Cargo.toml @@ -59,6 +59,7 @@ hcl_compat_uefi_nvram_storage = { workspace = true, features = ["inspect", "save get_helpers.workspace = true get_protocol.workspace = true guest_emulation_transport.workspace = true +host_file_access.workspace = true ide.workspace = true ide_resources.workspace = true input_core.workspace = true @@ -162,6 +163,7 @@ parking_lot.workspace = true serde = { workspace = true, features = ["derive"] } serde_helpers.workspace = true serde_json.workspace = true +sha2 = { workspace = true, features = ["std"] } socket2.workspace = true thiserror = { workspace = true, features = ["std"] } time = { workspace = true, features = ["macros"] } diff --git a/openhcl/underhill_core/src/dispatch/mod.rs b/openhcl/underhill_core/src/dispatch/mod.rs index 83519832e1..9ca778c5dd 100644 --- a/openhcl/underhill_core/src/dispatch/mod.rs +++ b/openhcl/underhill_core/src/dispatch/mod.rs @@ -47,10 +47,13 @@ use openhcl_dma_manager::OpenhclDmaManager; use pal_async::task::Spawn; use pal_async::task::Task; use parking_lot::Mutex; +use sha2::Digest; use socket2::Socket; use state_unit::SavedStateUnit; use state_unit::SpawnedUnit; use state_unit::StateUnits; +use std::io::Read; +use std::io::Write; use std::sync::Arc; use std::time::Duration; use tracing::Instrument; @@ -78,6 +81,8 @@ pub enum UhVmRpc { Pause(Rpc<(), bool>), Resume(Rpc<(), bool>), Save(FailableRpc<(), Vec>), + VmLinux(FailableRpc), + SomeLog(FailableRpc), ClearHalt(Rpc<(), bool>), // TODO: remove this, and use DebugRequest::Resume PacketCapture(FailableRpc, PacketCaptureParams>), } @@ -355,6 +360,36 @@ impl LoadedVm { }) .await } + UhVmRpc::VmLinux(rpc) => { + tracing::info!(CVM_ALLOWED, "reading vmlinux from the host"); + + pal_async::local::block_with_io(async |_| { + rpc.handle_failable::<_, anyhow::Error>(async |socket| { + let mut hfa = host_file_access::HostFileAccess::new(socket); + let mut vmlinux = vec![]; + hfa.read_to_end(&mut vmlinux)?; + + let mut hasher = sha2::Sha256::new(); + hasher.update(&vmlinux); + let result = hasher.finalize(); + tracing::info!(CVM_ALLOWED, "vmlinux sha256: {:x?}", result); + Ok(()) + }) + .await + }); + } + UhVmRpc::SomeLog(rpc) => { + tracing::info!(CVM_ALLOWED, "writing log to the host"); + + pal_async::local::block_with_io(async |_| { + rpc.handle_failable::<_, anyhow::Error>(async |socket| { + let mut hfa = host_file_access::HostFileAccess::new(socket); + hfa.write_all(b"Hello from underhill!")?; + Ok(()) + }) + .await + }); + } }, Event::ServicingRequest(message) => { // Explicitly destructure the message for easier tracking of its changes. diff --git a/openhcl/underhill_core/src/lib.rs b/openhcl/underhill_core/src/lib.rs index e03205b2e8..7ce6995d64 100644 --- a/openhcl/underhill_core/src/lib.rs +++ b/openhcl/underhill_core/src/lib.rs @@ -604,6 +604,26 @@ async fn run_control( }) .detach(); } + diag_server::DiagRequest::VmLinux(rpc) => { + tracing::info!(CVM_ALLOWED, "reading vmlinux from the host"); + let Some(workers) = &mut workers else { + rpc.complete(Err(RemoteError::new(anyhow::anyhow!( + "worker has not been started yet" + )))); + continue; + }; + workers.vm_rpc.send(UhVmRpc::VmLinux(rpc)); + } + diag_server::DiagRequest::SomeLog(rpc) => { + tracing::info!(CVM_ALLOWED, "writing log to the host"); + let Some(workers) = &mut workers else { + rpc.complete(Err(RemoteError::new(anyhow::anyhow!( + "worker has not been started yet" + )))); + continue; + }; + workers.vm_rpc.send(UhVmRpc::SomeLog(rpc)); + } diag_server::DiagRequest::PacketCapture(rpc) => { let Some(workers) = &mut workers else { rpc.complete(Err(RemoteError::new(anyhow::anyhow!( diff --git a/support/host_file_access/Cargo.toml b/support/host_file_access/Cargo.toml new file mode 100644 index 0000000000..7e9054dde7 --- /dev/null +++ b/support/host_file_access/Cargo.toml @@ -0,0 +1,18 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +[package] +name = "host_file_access" +edition.workspace = true +rust-version.workspace = true + +[dependencies] +bitfield-struct.workspace = true +futures.workspace = true +open_enum.workspace = true +thiserror.workspace = true +tracing.workspace = true +zerocopy.workspace = true + +[lints] +workspace = true diff --git a/support/host_file_access/src/lib.rs b/support/host_file_access/src/lib.rs new file mode 100644 index 0000000000..b2d226581d --- /dev/null +++ b/support/host_file_access/src/lib.rs @@ -0,0 +1,539 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Simple transfer protocol with support for seeks. +//! +//! The forward seeks beyond the size of the file might be platform dependendant. +//! Normally, the OS extends the file thus providing a way to truncate the file +//! to the desired size fast. +//! +//! Each transfer starts with the header packetthat provides the seek position, +//! size and the dierction. It might be followed by the data bytes sent in the +//! either direction. The sentinel value of `-1i128` for the header designates the +//! end of the file. Any transfer after that will return an error. + +use bitfield_struct::bitfield; +use futures::AsyncReadExt; +use futures::AsyncWriteExt; +use open_enum::open_enum; +use std::io::Read; +use std::io::Seek; +use std::io::Write; +use zerocopy::FromBytes; +use zerocopy::Immutable; +use zerocopy::IntoBytes; + +open_enum! { + ///! Represents the position from which to seek in the medium. + #[derive(IntoBytes, FromBytes, Immutable)] + pub(crate) enum SeekPosition: u8 { + ///! Seek relative to the current position. + CURRENT = 0b00, + ///! Seek relative to the start of the medium. + START = 0b01, + ///! Seek relative to the end of the medium. + END = 0b10, + } +} + +impl SeekPosition { + const fn into_bits(self) -> u8 { + self.0 + } + + const fn from_bits(bits: u8) -> Self { + Self(bits) + } +} + +open_enum! { + ///! Represents the direction of the data transfer. + #[derive(IntoBytes, FromBytes, Immutable)] + pub(crate) enum HostFileOperation: u8 { + ///! Write data to the medium. + WRITE = 0, + ///! Read data from the medium. + READ = 1, + } +} + +impl HostFileOperation { + const fn into_bits(self) -> u8 { + self.0 + } + + const fn from_bits(bits: u8) -> Self { + Self(bits) + } +} + +///! Represents the transport header. +#[derive(IntoBytes, FromBytes, Immutable)] +#[bitfield(u128)] +pub(crate) struct TransportHeader { + #[bits(62)] + pub seek_amount: i64, + #[bits(2)] + pub seek_pos: SeekPosition, + pub data_size: u32, + #[bits(1)] + pub direction: HostFileOperation, + #[bits(31)] + _reserved1: u32, +} + +impl TransportHeader { + ///! Checks if the header represents the end of the file. + pub fn is_eof(&self) -> bool { + self.into_bits() == Self::eof().into_bits() + } + + ///! Creates the end of file header. + pub fn eof() -> Self { + Self::from_bits(-1i128 as u128) + } + + ///! Checks if the header requests flushing the medium. + pub fn is_flush(&self) -> bool { + self.into_bits() == Self::flush().into_bits() + } + + ///! Creates a new transport header to request flushing the medium. + pub fn flush() -> Self { + Self::from_bits(0) + } +} + +///! Errors that can occur during host data operations. +#[derive(Debug, thiserror::Error)] +pub enum HostFileError { + ///! End of file reached. + #[error("end of file reached")] + EndOfFile, + ///! Invalid seek position specified in the header. + #[error("invalid seek position specified in the header")] + InvalidSeekPosition, + ///! Invalid data size specified in the header. + #[error("invalid data size specified in the header")] + InvalidDataSize, + ///! Write limit exceeded. + #[error("write limit exceeded")] + WriteLimitExceeded, + ///! Invalid direction specified in the header. + #[error("invalid direction specified in the header")] + InvalidDirection, + ///! An I/O error occurred during the operation. + #[error("I/O error occurred during the operation")] + IoError(#[source] std::io::Error), +} + +///! Provides input for the host data operations. +pub(crate) enum HostData<'a> { + ///! Data to be written to the medium. + Write(&'a [u8]), + ///! Buffer to read data into from the medium. + Read(&'a mut [u8]), +} + +///! Represents the write limit for the data storage. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WriteLimit { + ///! No write limit. + None, + ///! A specific write limit in bytes. + Limit(usize), +} + +impl Into for WriteLimit { + fn into(self) -> usize { + match self { + WriteLimit::None => usize::MAX, + WriteLimit::Limit(limit) => limit, + } + } +} + +impl Into> for WriteLimit { + fn into(self) -> Option { + match self { + WriteLimit::None => None, + WriteLimit::Limit(limit) => Some(limit), + } + } +} + +impl From> for WriteLimit { + fn from(limit: Option) -> Self { + match limit { + Some(size) => WriteLimit::Limit(size), + None => WriteLimit::None, + } + } +} + +///! A storage medium for the data operations. +pub struct HostFileStorage { + medium: M, + write_limit: WriteLimit, + bytes_read: usize, + bytes_written: usize, + eof: bool, +} + +impl HostFileStorage { + ///! Creates a new `HostFileStorage` with the given medium. + pub fn new(medium: M, write_limit: WriteLimit) -> Self { + Self { + medium, + write_limit, + bytes_read: 0, + bytes_written: 0, + eof: false, + } + } + + ///! Returns the number of bytes read so far. + pub fn bytes_read(&self) -> usize { + self.bytes_read + } + + ///! Returns the number of bytes written so far. + pub fn bytes_written(&self) -> usize { + self.bytes_written + } + + ///! Performs a data transfer with the given header and data. + ///! Returns the number of bytes transferred or an error if the operation failed. + pub(crate) fn transfer( + &mut self, + header: TransportHeader, + data: HostData<'_>, + ) -> Result { + if self.eof { + tracing::debug!("Transfer requested after EOF reached"); + return Err(HostFileError::EndOfFile); + } + + if header.is_eof() { + tracing::debug!("End of file header received"); + self.eof = true; + return Err(HostFileError::EndOfFile); + } else if header.is_flush() { + tracing::debug!("Flush header received"); + self.medium.flush().map_err(HostFileError::IoError)?; + return Ok(0); + } + + tracing::debug!("Transfer header {header:?}"); + + if let WriteLimit::Limit(limit) = self.write_limit { + if self.bytes_written + header.data_size() as usize > limit { + tracing::debug!( + "Write limit exceeded: {} bytes written, limit is {} bytes", + self.bytes_written, + limit + ); + return Err(HostFileError::WriteLimitExceeded); + } + } + + let seek_amount = header.seek_amount(); + if seek_amount != 0 { + let seek_from = match header.seek_pos() { + SeekPosition::CURRENT => std::io::SeekFrom::Current(seek_amount as i64), + SeekPosition::START => std::io::SeekFrom::Start(seek_amount as u64), + SeekPosition::END => std::io::SeekFrom::End(seek_amount as i64), + _ => return Err(HostFileError::InvalidSeekPosition), + }; + + self.medium + .seek(seek_from) + .map_err(HostFileError::IoError)?; + } + + let bytes_transferred = match header.direction() { + HostFileOperation::WRITE => { + if let HostData::Write(bytes_to_write) = data { + if bytes_to_write.len() != header.data_size() as usize { + return Err(HostFileError::InvalidDataSize); + } + + self.medium + .write_all(bytes_to_write) + .map_err(HostFileError::IoError)?; + self.bytes_written += bytes_to_write.len(); + + bytes_to_write.len() + } else { + return Err(HostFileError::InvalidDirection); + } + } + HostFileOperation::READ => { + if let HostData::Read(buffer) = data { + if buffer.len() != header.data_size() as usize { + return Err(HostFileError::InvalidDataSize); + } + + let bytes_read = self.medium.read(buffer).map_err(HostFileError::IoError)?; + self.bytes_read += bytes_read; + + bytes_read + } else { + return Err(HostFileError::InvalidDirection); + } + } + _ => return Err(HostFileError::InvalidDirection), + }; + + tracing::debug!( + "Transfer completed: {} bytes transferred, totals: {} bytes written, {} bytes read", + bytes_transferred, + self.bytes_written, + self.bytes_read + ); + + Ok(bytes_transferred) + } + + ///! Runs the data operations on the provided transport. + pub fn run(&mut self, mut transport: T) -> Result<(), HostFileError> { + if self.eof { + return Err(HostFileError::EndOfFile); + } + + loop { + let mut header = TransportHeader::eof(); + transport + .read_exact(header.as_mut_bytes()) + .map_err(HostFileError::IoError)?; + + if header.is_eof() { + self.eof = true; + return Ok(()); + } + + let mut buf = vec![0; header.data_size() as usize]; + match header.direction() { + HostFileOperation::READ => { + let bytes_read = self.transfer(header, HostData::Read(&mut buf))?; + transport + .write_all((bytes_read as u128).as_bytes()) + .map_err(HostFileError::IoError)?; + transport + .write_all(&buf[..bytes_read]) + .map_err(HostFileError::IoError)?; + } + HostFileOperation::WRITE => { + transport + .read_exact(&mut buf) + .map_err(HostFileError::IoError)?; + self.transfer(header, HostData::Write(&buf))?; + } + _ => return Err(HostFileError::InvalidDirection), + } + } + } + + ///! Runs the data operations asynchronously on the provided transport. + pub async fn run_async( + &mut self, + mut transport: T, + ) -> Result<(), HostFileError> { + if self.eof { + return Err(HostFileError::EndOfFile); + } + + loop { + let mut header = TransportHeader::eof(); + transport + .read_exact(header.as_mut_bytes()) + .await + .map_err(HostFileError::IoError)?; + + if header.is_eof() { + self.eof = true; + return Ok(()); + } + + let mut buf = vec![0; header.data_size() as usize]; + match header.direction() { + HostFileOperation::READ => { + let bytes_read = self.transfer(header, HostData::Read(&mut buf))?; + transport + .write_all((bytes_read as u128).as_bytes()) + .await + .map_err(HostFileError::IoError)?; + transport + .write_all(&buf[..bytes_read]) + .await + .map_err(HostFileError::IoError)?; + } + HostFileOperation::WRITE => { + transport + .read_exact(&mut buf) + .await + .map_err(HostFileError::IoError)?; + self.transfer(header, HostData::Write(&buf))?; + } + _ => return Err(HostFileError::InvalidDirection), + } + } + } +} + +///! A wrapper around a transport that provides file-like access. +pub struct HostFileAccess { + transport: T, +} + +impl HostFileAccess { + ///! Creates a new `HostFileAccess` with the given transport. + pub fn new(transport: T) -> Self { + Self { transport } + } +} + +impl Drop for HostFileAccess { + fn drop(&mut self) { + let header = TransportHeader::eof(); + if let Err(e) = self.transport.write_all(header.as_bytes()) { + tracing::error!("Failed to write EOF header: {}", e); + } + } +} + +impl Write for HostFileAccess { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + tracing::debug!("Writing {} bytes to host file access", buf.len()); + + let header = TransportHeader::new() + .with_seek_amount(0) + .with_seek_pos(SeekPosition::START) + .with_data_size(buf.len() as u32) + .with_direction(HostFileOperation::WRITE); + + self.transport + .write_all(header.as_bytes()) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + self.transport + .write_all(buf) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + tracing::debug!("Flushing host file access"); + + let header = TransportHeader::flush(); + self.transport + .write_all(header.as_bytes()) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + Ok(()) + } +} + +impl Read for HostFileAccess { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + tracing::debug!("Reading {} bytes from host file access", buf.len()); + + let header = TransportHeader::new() + .with_seek_amount(0) + .with_seek_pos(SeekPosition::START) + .with_data_size(buf.len() as u32) + .with_direction(HostFileOperation::READ); + + self.transport + .write_all(header.as_bytes()) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + + let mut size_header = [0; 16]; + self.transport + .read_exact(&mut size_header) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + let size = u128::from_le_bytes(size_header) as usize; + + tracing::debug!("{} bytes available", size); + + if size == 0 { + tracing::debug!("End of file reached"); + return Ok(0); + } + if size > buf.len() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Buffer too small for read operation", + )); + } + + self.transport + .read_exact(buf[..size].as_mut()) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + Ok(size) + } +} + +impl Seek for HostFileAccess { + fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { + tracing::debug!("Seeking in host file access to {:?}", pos); + + let seek_amount = match pos { + std::io::SeekFrom::Start(offset) => offset as i64, + std::io::SeekFrom::End(offset) => offset as i64, + std::io::SeekFrom::Current(offset) => offset as i64, + }; + + let header = TransportHeader::new() + .with_seek_amount(seek_amount) + .with_seek_pos(SeekPosition::START) + .with_data_size(0) + .with_direction(HostFileOperation::READ); + + self.transport + .write_all(header.as_bytes()) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + Ok(seek_amount as u64) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Cursor; + + #[test] + fn test_host_data_storage() { + // Write data to the storage + let data = b"Hello, world!"; + let cursor = Cursor::new(Vec::new()); + let mut storage = HostFileStorage::new(cursor, WriteLimit::None); + let header = TransportHeader::new() + .with_seek_amount(0) + .with_seek_pos(SeekPosition::START) + .with_data_size(data.len() as u32) + .with_direction(HostFileOperation::WRITE); + let result = storage.transfer(header, HostData::Write(data)); + assert!(result.is_ok()); + assert_eq!(storage.bytes_written(), data.len()); + assert_eq!(storage.bytes_read(), 0); + assert!(!storage.eof); + + // Read data from the storage + let mut buffer = vec![0; data.len()]; + let header = TransportHeader::new() + .with_seek_amount(0) + .with_seek_pos(SeekPosition::START) + .with_data_size(data.len() as u32) + .with_direction(HostFileOperation::READ); + let result = storage.transfer(header, HostData::Read(&mut buffer)); + assert!(result.is_ok()); + assert_eq!(buffer, data); + assert_eq!(storage.bytes_written(), data.len()); + assert_eq!(storage.bytes_read(), data.len()); + assert!(!storage.eof); + + // Test end of file + let eof_header = TransportHeader::eof(); + let result = storage.transfer(eof_header, HostData::Write(&[])); + assert!(matches!(result, Err(HostFileError::EndOfFile))); + } +}