Skip to content

Commit ab34663

Browse files
committed
move shaders to separate mod
1 parent 0e026d4 commit ab34663

File tree

7 files changed

+215
-186
lines changed

7 files changed

+215
-186
lines changed

tests/difftests/lib/src/scaffold/compute/backend.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::config::Config;
2+
use crate::scaffold::shader::SpirvShader;
23
use anyhow::Result;
34

45
/// Configuration for a GPU buffer
@@ -9,6 +10,25 @@ pub struct BufferConfig {
910
pub initial_data: Option<Vec<u8>>,
1011
}
1112

13+
impl BufferConfig {
14+
pub fn writeback(size: usize) -> Self {
15+
Self {
16+
size: size as u64,
17+
usage: BufferUsage::Storage,
18+
initial_data: None,
19+
}
20+
}
21+
22+
pub fn read_only<A: bytemuck::NoUninit>(slice: &[A]) -> Self {
23+
let vec = bytemuck::cast_slice(slice).to_vec();
24+
Self {
25+
size: vec.len() as u64,
26+
usage: BufferUsage::StorageReadOnly,
27+
initial_data: Some(vec),
28+
}
29+
}
30+
}
31+
1232
/// Buffer usage type
1333
#[derive(Clone, Copy, PartialEq)]
1434
pub enum BufferUsage {
@@ -17,8 +37,6 @@ pub enum BufferUsage {
1737
Uniform,
1838
}
1939

20-
use super::SpirvShader;
21-
2240
/// A generic trait for compute backends
2341
pub trait ComputeBackend: Sized {
2442
/// Initialize the backend

tests/difftests/lib/src/scaffold/compute/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ mod ash;
22
mod backend;
33
mod wgpu;
44

5+
pub use crate::scaffold::shader::*;
56
pub use ash::AshBackend;
67
pub use backend::{BufferConfig, BufferUsage, ComputeBackend, ComputeShaderTest, ComputeTest};
78
pub use wgpu::{
8-
RustComputeShader, SpirvShader, WgpuBackend, WgpuComputeTest, WgpuComputeTestMultiBuffer,
9-
WgpuComputeTestPushConstants, WgpuShader, WgslComputeShader,
9+
WgpuBackend, WgpuComputeTest, WgpuComputeTestMultiBuffer, WgpuComputeTestPushConstants,
1010
};

tests/difftests/lib/src/scaffold/compute/wgpu.rs

Lines changed: 5 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -1,152 +1,17 @@
1+
use super::backend::{self, ComputeBackend};
12
use crate::config::Config;
3+
use crate::scaffold::shader::RustComputeShader;
4+
use crate::scaffold::shader::WgpuShader;
5+
use crate::scaffold::shader::WgslComputeShader;
26
use anyhow::Context;
37
use bytemuck::Pod;
48
use futures::executor::block_on;
5-
use spirv_builder::{ModuleResult, SpirvBuilder};
6-
use std::{
7-
borrow::Cow,
8-
env,
9-
fs::{self, File},
10-
io::Write,
11-
path::PathBuf,
12-
sync::Arc,
13-
};
9+
use std::{borrow::Cow, fs::File, io::Write, sync::Arc};
1410
use wgpu::{PipelineCompilationOptions, util::DeviceExt};
1511

16-
use super::backend::{self, ComputeBackend};
17-
1812
pub type BufferConfig = backend::BufferConfig;
1913
pub type BufferUsage = backend::BufferUsage;
2014

21-
/// Trait for shaders that can provide SPIRV bytes.
22-
pub trait SpirvShader {
23-
/// Returns the SPIRV bytes and entry point name.
24-
fn spirv_bytes(&self) -> anyhow::Result<(Vec<u8>, String)>;
25-
}
26-
27-
/// Trait for shaders that can create wgpu modules.
28-
pub trait WgpuShader {
29-
/// Creates a wgpu shader module.
30-
fn create_module(
31-
&self,
32-
device: &wgpu::Device,
33-
) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)>;
34-
}
35-
36-
/// A compute shader written in Rust compiled with spirv-builder.
37-
pub struct RustComputeShader {
38-
pub path: PathBuf,
39-
pub target: String,
40-
pub capabilities: Vec<spirv_builder::Capability>,
41-
}
42-
43-
impl RustComputeShader {
44-
pub fn new<P: Into<PathBuf>>(path: P) -> Self {
45-
Self {
46-
path: path.into(),
47-
target: "spirv-unknown-vulkan1.1".to_string(),
48-
capabilities: Vec::new(),
49-
}
50-
}
51-
52-
pub fn with_target<P: Into<PathBuf>>(path: P, target: impl Into<String>) -> Self {
53-
Self {
54-
path: path.into(),
55-
target: target.into(),
56-
capabilities: Vec::new(),
57-
}
58-
}
59-
60-
pub fn with_capability(mut self, capability: spirv_builder::Capability) -> Self {
61-
self.capabilities.push(capability);
62-
self
63-
}
64-
}
65-
66-
impl SpirvShader for RustComputeShader {
67-
fn spirv_bytes(&self) -> anyhow::Result<(Vec<u8>, String)> {
68-
let mut builder = SpirvBuilder::new(&self.path, &self.target)
69-
.print_metadata(spirv_builder::MetadataPrintout::None)
70-
.release(true)
71-
.multimodule(false)
72-
.shader_panic_strategy(spirv_builder::ShaderPanicStrategy::SilentExit)
73-
.preserve_bindings(true);
74-
75-
for capability in &self.capabilities {
76-
builder = builder.capability(*capability);
77-
}
78-
79-
let artifact = builder.build().context("SpirvBuilder::build() failed")?;
80-
81-
if artifact.entry_points.len() != 1 {
82-
anyhow::bail!(
83-
"Expected exactly one entry point, found {}",
84-
artifact.entry_points.len()
85-
);
86-
}
87-
let entry_point = artifact.entry_points.into_iter().next().unwrap();
88-
89-
let shader_bytes = match artifact.module {
90-
ModuleResult::SingleModule(path) => fs::read(&path)
91-
.with_context(|| format!("reading spv file '{}' failed", path.display()))?,
92-
ModuleResult::MultiModule(_modules) => {
93-
anyhow::bail!("MultiModule modules produced");
94-
}
95-
};
96-
97-
Ok((shader_bytes, entry_point))
98-
}
99-
}
100-
101-
impl WgpuShader for RustComputeShader {
102-
fn create_module(
103-
&self,
104-
device: &wgpu::Device,
105-
) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)> {
106-
let (shader_bytes, entry_point) = self.spirv_bytes()?;
107-
108-
if shader_bytes.len() % 4 != 0 {
109-
anyhow::bail!("SPIR-V binary length is not a multiple of 4");
110-
}
111-
let shader_words: Vec<u32> = bytemuck::cast_slice(&shader_bytes).to_vec();
112-
let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
113-
label: Some("Compute Shader"),
114-
source: wgpu::ShaderSource::SpirV(Cow::Owned(shader_words)),
115-
});
116-
Ok((module, Some(entry_point)))
117-
}
118-
}
119-
120-
/// A WGSL compute shader.
121-
pub struct WgslComputeShader {
122-
pub path: PathBuf,
123-
pub entry_point: Option<String>,
124-
}
125-
126-
impl WgslComputeShader {
127-
pub fn new<P: Into<PathBuf>>(path: P, entry_point: Option<String>) -> Self {
128-
Self {
129-
path: path.into(),
130-
entry_point,
131-
}
132-
}
133-
}
134-
135-
impl WgpuShader for WgslComputeShader {
136-
fn create_module(
137-
&self,
138-
device: &wgpu::Device,
139-
) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)> {
140-
let shader_source = fs::read_to_string(&self.path)
141-
.with_context(|| format!("reading wgsl source file '{}'", &self.path.display()))?;
142-
let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
143-
label: Some("Compute Shader"),
144-
source: wgpu::ShaderSource::Wgsl(Cow::Owned(shader_source)),
145-
});
146-
Ok((module, self.entry_point.clone()))
147-
}
148-
}
149-
15015
/// Compute test that is generic over the shader type.
15116
pub struct WgpuComputeTest<S> {
15217
shader: S,
@@ -539,48 +404,6 @@ impl ComputeBackend for WgpuBackend {
539404
}
540405
}
541406

542-
/// For WGSL, the code checks for "shader.wgsl" then "compute.wgsl".
543-
impl Default for WgslComputeShader {
544-
fn default() -> Self {
545-
let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
546-
let manifest_path = PathBuf::from(manifest_dir);
547-
let shader_path = manifest_path.join("shader.wgsl");
548-
let compute_path = manifest_path.join("compute.wgsl");
549-
550-
let (file, source) = if shader_path.exists() {
551-
(
552-
shader_path.clone(),
553-
fs::read_to_string(&shader_path).unwrap_or_default(),
554-
)
555-
} else if compute_path.exists() {
556-
(
557-
compute_path.clone(),
558-
fs::read_to_string(&compute_path).unwrap_or_default(),
559-
)
560-
} else {
561-
panic!("No default WGSL shader found in manifest directory");
562-
};
563-
564-
let entry_point = if source.contains("fn main_cs(") {
565-
Some("main_cs".to_string())
566-
} else if source.contains("fn main(") {
567-
Some("main".to_string())
568-
} else {
569-
None
570-
};
571-
572-
Self::new(file, entry_point)
573-
}
574-
}
575-
576-
/// For the SPIR-V shader, the manifest directory is used as the build path.
577-
impl Default for RustComputeShader {
578-
fn default() -> Self {
579-
let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
580-
Self::new(PathBuf::from(manifest_dir))
581-
}
582-
}
583-
584407
impl<S> WgpuComputeTestMultiBuffer<S>
585408
where
586409
S: WgpuShader,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pub mod compute;
2+
pub mod shader;
23
pub mod skip;
34

45
pub use skip::Skip;
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
mod rust_gpu_shader;
2+
mod wgsl_shader;
3+
4+
pub use rust_gpu_shader::RustComputeShader;
5+
pub use wgsl_shader::WgslComputeShader;
6+
7+
/// Trait for shaders that can provide SPIRV bytes.
8+
pub trait SpirvShader {
9+
/// Returns the SPIRV bytes and entry point name.
10+
fn spirv_bytes(&self) -> anyhow::Result<(Vec<u8>, String)>;
11+
}
12+
13+
/// Trait for shaders that can create wgpu modules.
14+
pub trait WgpuShader {
15+
/// Creates a wgpu shader module.
16+
fn create_module(
17+
&self,
18+
device: &wgpu::Device,
19+
) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)>;
20+
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
use crate::scaffold::shader::{SpirvShader, WgpuShader};
2+
use anyhow::Context;
3+
use spirv_builder::{ModuleResult, SpirvBuilder};
4+
use std::borrow::Cow;
5+
use std::path::PathBuf;
6+
use std::{env, fs};
7+
8+
/// A compute shader written in Rust compiled with spirv-builder.
9+
pub struct RustComputeShader {
10+
pub path: PathBuf,
11+
pub target: String,
12+
pub capabilities: Vec<spirv_builder::Capability>,
13+
}
14+
15+
impl RustComputeShader {
16+
pub fn new<P: Into<PathBuf>>(path: P) -> Self {
17+
Self {
18+
path: path.into(),
19+
target: "spirv-unknown-vulkan1.1".to_string(),
20+
capabilities: Vec::new(),
21+
}
22+
}
23+
24+
pub fn with_target<P: Into<PathBuf>>(path: P, target: impl Into<String>) -> Self {
25+
Self {
26+
path: path.into(),
27+
target: target.into(),
28+
capabilities: Vec::new(),
29+
}
30+
}
31+
32+
pub fn with_capability(mut self, capability: spirv_builder::Capability) -> Self {
33+
self.capabilities.push(capability);
34+
self
35+
}
36+
}
37+
38+
impl SpirvShader for RustComputeShader {
39+
fn spirv_bytes(&self) -> anyhow::Result<(Vec<u8>, String)> {
40+
let mut builder = SpirvBuilder::new(&self.path, &self.target)
41+
.print_metadata(spirv_builder::MetadataPrintout::None)
42+
.release(true)
43+
.multimodule(false)
44+
.shader_panic_strategy(spirv_builder::ShaderPanicStrategy::SilentExit)
45+
.preserve_bindings(true);
46+
47+
for capability in &self.capabilities {
48+
builder = builder.capability(*capability);
49+
}
50+
51+
let artifact = builder.build().context("SpirvBuilder::build() failed")?;
52+
53+
if artifact.entry_points.len() != 1 {
54+
anyhow::bail!(
55+
"Expected exactly one entry point, found {}",
56+
artifact.entry_points.len()
57+
);
58+
}
59+
let entry_point = artifact.entry_points.into_iter().next().unwrap();
60+
61+
let shader_bytes = match artifact.module {
62+
ModuleResult::SingleModule(path) => fs::read(&path)
63+
.with_context(|| format!("reading spv file '{}' failed", path.display()))?,
64+
ModuleResult::MultiModule(_modules) => {
65+
anyhow::bail!("MultiModule modules produced");
66+
}
67+
};
68+
69+
Ok((shader_bytes, entry_point))
70+
}
71+
}
72+
73+
impl WgpuShader for RustComputeShader {
74+
fn create_module(
75+
&self,
76+
device: &wgpu::Device,
77+
) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)> {
78+
let (shader_bytes, entry_point) = self.spirv_bytes()?;
79+
80+
if shader_bytes.len() % 4 != 0 {
81+
anyhow::bail!("SPIR-V binary length is not a multiple of 4");
82+
}
83+
let shader_words: Vec<u32> = bytemuck::cast_slice(&shader_bytes).to_vec();
84+
let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
85+
label: Some("Compute Shader"),
86+
source: wgpu::ShaderSource::SpirV(Cow::Owned(shader_words)),
87+
});
88+
Ok((module, Some(entry_point)))
89+
}
90+
}
91+
92+
/// For the SPIR-V shader, the manifest directory is used as the build path.
93+
impl Default for RustComputeShader {
94+
fn default() -> Self {
95+
let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
96+
Self::new(PathBuf::from(manifest_dir))
97+
}
98+
}

0 commit comments

Comments
 (0)