|
| 1 | +use super::backend::{self, ComputeBackend}; |
1 | 2 | use crate::config::Config;
|
| 3 | +use crate::scaffold::shader::RustComputeShader; |
| 4 | +use crate::scaffold::shader::WgpuShader; |
| 5 | +use crate::scaffold::shader::WgslComputeShader; |
2 | 6 | use anyhow::Context;
|
3 | 7 | use bytemuck::Pod;
|
4 | 8 | use futures::executor::block_on;
|
5 |
| -use spirv_builder::{ModuleResult, SpirvBuilder}; |
6 |
| -use std::{ |
7 |
| - borrow::Cow, |
8 |
| - env, |
9 |
| - fs::{self, File}, |
10 |
| - io::Write, |
11 |
| - path::PathBuf, |
12 |
| - sync::Arc, |
13 |
| -}; |
| 9 | +use std::{borrow::Cow, fs::File, io::Write, sync::Arc}; |
14 | 10 | use wgpu::{PipelineCompilationOptions, util::DeviceExt};
|
15 | 11 |
|
16 |
| -use super::backend::{self, ComputeBackend}; |
17 |
| - |
18 | 12 | pub type BufferConfig = backend::BufferConfig;
|
19 | 13 | pub type BufferUsage = backend::BufferUsage;
|
20 | 14 |
|
21 |
| -/// Trait for shaders that can provide SPIRV bytes. |
22 |
| -pub trait SpirvShader { |
23 |
| - /// Returns the SPIRV bytes and entry point name. |
24 |
| - fn spirv_bytes(&self) -> anyhow::Result<(Vec<u8>, String)>; |
25 |
| -} |
26 |
| - |
27 |
| -/// Trait for shaders that can create wgpu modules. |
28 |
| -pub trait WgpuShader { |
29 |
| - /// Creates a wgpu shader module. |
30 |
| - fn create_module( |
31 |
| - &self, |
32 |
| - device: &wgpu::Device, |
33 |
| - ) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)>; |
34 |
| -} |
35 |
| - |
36 |
| -/// A compute shader written in Rust compiled with spirv-builder. |
37 |
| -pub struct RustComputeShader { |
38 |
| - pub path: PathBuf, |
39 |
| - pub target: String, |
40 |
| - pub capabilities: Vec<spirv_builder::Capability>, |
41 |
| -} |
42 |
| - |
43 |
| -impl RustComputeShader { |
44 |
| - pub fn new<P: Into<PathBuf>>(path: P) -> Self { |
45 |
| - Self { |
46 |
| - path: path.into(), |
47 |
| - target: "spirv-unknown-vulkan1.1".to_string(), |
48 |
| - capabilities: Vec::new(), |
49 |
| - } |
50 |
| - } |
51 |
| - |
52 |
| - pub fn with_target<P: Into<PathBuf>>(path: P, target: impl Into<String>) -> Self { |
53 |
| - Self { |
54 |
| - path: path.into(), |
55 |
| - target: target.into(), |
56 |
| - capabilities: Vec::new(), |
57 |
| - } |
58 |
| - } |
59 |
| - |
60 |
| - pub fn with_capability(mut self, capability: spirv_builder::Capability) -> Self { |
61 |
| - self.capabilities.push(capability); |
62 |
| - self |
63 |
| - } |
64 |
| -} |
65 |
| - |
66 |
| -impl SpirvShader for RustComputeShader { |
67 |
| - fn spirv_bytes(&self) -> anyhow::Result<(Vec<u8>, String)> { |
68 |
| - let mut builder = SpirvBuilder::new(&self.path, &self.target) |
69 |
| - .print_metadata(spirv_builder::MetadataPrintout::None) |
70 |
| - .release(true) |
71 |
| - .multimodule(false) |
72 |
| - .shader_panic_strategy(spirv_builder::ShaderPanicStrategy::SilentExit) |
73 |
| - .preserve_bindings(true); |
74 |
| - |
75 |
| - for capability in &self.capabilities { |
76 |
| - builder = builder.capability(*capability); |
77 |
| - } |
78 |
| - |
79 |
| - let artifact = builder.build().context("SpirvBuilder::build() failed")?; |
80 |
| - |
81 |
| - if artifact.entry_points.len() != 1 { |
82 |
| - anyhow::bail!( |
83 |
| - "Expected exactly one entry point, found {}", |
84 |
| - artifact.entry_points.len() |
85 |
| - ); |
86 |
| - } |
87 |
| - let entry_point = artifact.entry_points.into_iter().next().unwrap(); |
88 |
| - |
89 |
| - let shader_bytes = match artifact.module { |
90 |
| - ModuleResult::SingleModule(path) => fs::read(&path) |
91 |
| - .with_context(|| format!("reading spv file '{}' failed", path.display()))?, |
92 |
| - ModuleResult::MultiModule(_modules) => { |
93 |
| - anyhow::bail!("MultiModule modules produced"); |
94 |
| - } |
95 |
| - }; |
96 |
| - |
97 |
| - Ok((shader_bytes, entry_point)) |
98 |
| - } |
99 |
| -} |
100 |
| - |
101 |
| -impl WgpuShader for RustComputeShader { |
102 |
| - fn create_module( |
103 |
| - &self, |
104 |
| - device: &wgpu::Device, |
105 |
| - ) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)> { |
106 |
| - let (shader_bytes, entry_point) = self.spirv_bytes()?; |
107 |
| - |
108 |
| - if shader_bytes.len() % 4 != 0 { |
109 |
| - anyhow::bail!("SPIR-V binary length is not a multiple of 4"); |
110 |
| - } |
111 |
| - let shader_words: Vec<u32> = bytemuck::cast_slice(&shader_bytes).to_vec(); |
112 |
| - let module = device.create_shader_module(wgpu::ShaderModuleDescriptor { |
113 |
| - label: Some("Compute Shader"), |
114 |
| - source: wgpu::ShaderSource::SpirV(Cow::Owned(shader_words)), |
115 |
| - }); |
116 |
| - Ok((module, Some(entry_point))) |
117 |
| - } |
118 |
| -} |
119 |
| - |
120 |
| -/// A WGSL compute shader. |
121 |
| -pub struct WgslComputeShader { |
122 |
| - pub path: PathBuf, |
123 |
| - pub entry_point: Option<String>, |
124 |
| -} |
125 |
| - |
126 |
| -impl WgslComputeShader { |
127 |
| - pub fn new<P: Into<PathBuf>>(path: P, entry_point: Option<String>) -> Self { |
128 |
| - Self { |
129 |
| - path: path.into(), |
130 |
| - entry_point, |
131 |
| - } |
132 |
| - } |
133 |
| -} |
134 |
| - |
135 |
| -impl WgpuShader for WgslComputeShader { |
136 |
| - fn create_module( |
137 |
| - &self, |
138 |
| - device: &wgpu::Device, |
139 |
| - ) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)> { |
140 |
| - let shader_source = fs::read_to_string(&self.path) |
141 |
| - .with_context(|| format!("reading wgsl source file '{}'", &self.path.display()))?; |
142 |
| - let module = device.create_shader_module(wgpu::ShaderModuleDescriptor { |
143 |
| - label: Some("Compute Shader"), |
144 |
| - source: wgpu::ShaderSource::Wgsl(Cow::Owned(shader_source)), |
145 |
| - }); |
146 |
| - Ok((module, self.entry_point.clone())) |
147 |
| - } |
148 |
| -} |
149 |
| - |
150 | 15 | /// Compute test that is generic over the shader type.
|
151 | 16 | pub struct WgpuComputeTest<S> {
|
152 | 17 | shader: S,
|
@@ -539,48 +404,6 @@ impl ComputeBackend for WgpuBackend {
|
539 | 404 | }
|
540 | 405 | }
|
541 | 406 |
|
542 |
| -/// For WGSL, the code checks for "shader.wgsl" then "compute.wgsl". |
543 |
| -impl Default for WgslComputeShader { |
544 |
| - fn default() -> Self { |
545 |
| - let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); |
546 |
| - let manifest_path = PathBuf::from(manifest_dir); |
547 |
| - let shader_path = manifest_path.join("shader.wgsl"); |
548 |
| - let compute_path = manifest_path.join("compute.wgsl"); |
549 |
| - |
550 |
| - let (file, source) = if shader_path.exists() { |
551 |
| - ( |
552 |
| - shader_path.clone(), |
553 |
| - fs::read_to_string(&shader_path).unwrap_or_default(), |
554 |
| - ) |
555 |
| - } else if compute_path.exists() { |
556 |
| - ( |
557 |
| - compute_path.clone(), |
558 |
| - fs::read_to_string(&compute_path).unwrap_or_default(), |
559 |
| - ) |
560 |
| - } else { |
561 |
| - panic!("No default WGSL shader found in manifest directory"); |
562 |
| - }; |
563 |
| - |
564 |
| - let entry_point = if source.contains("fn main_cs(") { |
565 |
| - Some("main_cs".to_string()) |
566 |
| - } else if source.contains("fn main(") { |
567 |
| - Some("main".to_string()) |
568 |
| - } else { |
569 |
| - None |
570 |
| - }; |
571 |
| - |
572 |
| - Self::new(file, entry_point) |
573 |
| - } |
574 |
| -} |
575 |
| - |
576 |
| -/// For the SPIR-V shader, the manifest directory is used as the build path. |
577 |
| -impl Default for RustComputeShader { |
578 |
| - fn default() -> Self { |
579 |
| - let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); |
580 |
| - Self::new(PathBuf::from(manifest_dir)) |
581 |
| - } |
582 |
| -} |
583 |
| - |
584 | 407 | impl<S> WgpuComputeTestMultiBuffer<S>
|
585 | 408 | where
|
586 | 409 | S: WgpuShader,
|
|
0 commit comments