summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/video_core/renderer_vulkan/vk_shader_decompiler.cpp2271
-rw-r--r--src/video_core/renderer_vulkan/vk_shader_decompiler.h74
2 files changed, 1648 insertions, 697 deletions
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
index 76894275b..8f517bdc1 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
@@ -3,8 +3,10 @@
// Refer to the license.txt file included.
#include <functional>
+#include <limits>
#include <map>
-#include <set>
+#include <type_traits>
+#include <utility>
#include <fmt/format.h>
@@ -23,7 +25,9 @@
#include "video_core/shader/node.h"
#include "video_core/shader/shader_ir.h"
-namespace Vulkan::VKShader {
+namespace Vulkan {
+
+namespace {
using Sirit::Id;
using Tegra::Engines::ShaderType;
@@ -35,22 +39,60 @@ using namespace VideoCommon::Shader;
using Maxwell = Tegra::Engines::Maxwell3D::Regs;
using Operation = const OperationNode&;
+class ASTDecompiler;
+class ExprDecompiler;
+
// TODO(Rodrigo): Use rasterizer's value
-constexpr u32 MAX_CONSTBUFFER_FLOATS = 0x4000;
-constexpr u32 MAX_CONSTBUFFER_ELEMENTS = MAX_CONSTBUFFER_FLOATS / 4;
-constexpr u32 STAGE_BINDING_STRIDE = 0x100;
+constexpr u32 MaxConstBufferFloats = 0x4000;
+constexpr u32 MaxConstBufferElements = MaxConstBufferFloats / 4;
+
+constexpr u32 NumInputPatches = 32; // This value seems to be the standard
+
+enum class Type { Void, Bool, Bool2, Float, Int, Uint, HalfFloat };
+
+class Expression final {
+public:
+ Expression(Id id, Type type) : id{id}, type{type} {
+ ASSERT(type != Type::Void);
+ }
+ Expression() : type{Type::Void} {}
-enum class Type { Bool, Bool2, Float, Int, Uint, HalfFloat };
+ Id id{};
+ Type type{};
+};
+static_assert(std::is_standard_layout_v<Expression>);
-struct SamplerImage {
- Id image_type;
- Id sampled_image_type;
- Id sampler;
+struct TexelBuffer {
+ Id image_type{};
+ Id image{};
};
-namespace {
+struct SampledImage {
+ Id image_type{};
+ Id sampled_image_type{};
+ Id sampler{};
+};
+
+struct StorageImage {
+ Id image_type{};
+ Id image{};
+};
+
+struct AttributeType {
+ Type type;
+ Id scalar;
+ Id vector;
+};
+
+struct VertexIndices {
+ std::optional<u32> position;
+ std::optional<u32> viewport;
+ std::optional<u32> point_size;
+ std::optional<u32> clip_distances;
+};
spv::Dim GetSamplerDim(const Sampler& sampler) {
+ ASSERT(!sampler.IsBuffer());
switch (sampler.GetType()) {
case Tegra::Shader::TextureType::Texture1D:
return spv::Dim::Dim1D;
@@ -66,6 +108,138 @@ spv::Dim GetSamplerDim(const Sampler& sampler) {
}
}
+std::pair<spv::Dim, bool> GetImageDim(const Image& image) {
+ switch (image.GetType()) {
+ case Tegra::Shader::ImageType::Texture1D:
+ return {spv::Dim::Dim1D, false};
+ case Tegra::Shader::ImageType::TextureBuffer:
+ return {spv::Dim::Buffer, false};
+ case Tegra::Shader::ImageType::Texture1DArray:
+ return {spv::Dim::Dim1D, true};
+ case Tegra::Shader::ImageType::Texture2D:
+ return {spv::Dim::Dim2D, false};
+ case Tegra::Shader::ImageType::Texture2DArray:
+ return {spv::Dim::Dim2D, true};
+ case Tegra::Shader::ImageType::Texture3D:
+ return {spv::Dim::Dim3D, false};
+ default:
+ UNIMPLEMENTED_MSG("Unimplemented image type={}", static_cast<u32>(image.GetType()));
+ return {spv::Dim::Dim2D, false};
+ }
+}
+
+/// Returns the number of vertices present in a primitive topology.
+u32 GetNumPrimitiveTopologyVertices(Maxwell::PrimitiveTopology primitive_topology) {
+ switch (primitive_topology) {
+ case Maxwell::PrimitiveTopology::Points:
+ return 1;
+ case Maxwell::PrimitiveTopology::Lines:
+ case Maxwell::PrimitiveTopology::LineLoop:
+ case Maxwell::PrimitiveTopology::LineStrip:
+ return 2;
+ case Maxwell::PrimitiveTopology::Triangles:
+ case Maxwell::PrimitiveTopology::TriangleStrip:
+ case Maxwell::PrimitiveTopology::TriangleFan:
+ return 3;
+ case Maxwell::PrimitiveTopology::LinesAdjacency:
+ case Maxwell::PrimitiveTopology::LineStripAdjacency:
+ return 4;
+ case Maxwell::PrimitiveTopology::TrianglesAdjacency:
+ case Maxwell::PrimitiveTopology::TriangleStripAdjacency:
+ return 6;
+ case Maxwell::PrimitiveTopology::Quads:
+ UNIMPLEMENTED_MSG("Quads");
+ return 3;
+ case Maxwell::PrimitiveTopology::QuadStrip:
+ UNIMPLEMENTED_MSG("QuadStrip");
+ return 3;
+ case Maxwell::PrimitiveTopology::Polygon:
+ UNIMPLEMENTED_MSG("Polygon");
+ return 3;
+ case Maxwell::PrimitiveTopology::Patches:
+ UNIMPLEMENTED_MSG("Patches");
+ return 3;
+ default:
+ UNREACHABLE();
+ return 3;
+ }
+}
+
+spv::ExecutionMode GetExecutionMode(Maxwell::TessellationPrimitive primitive) {
+ switch (primitive) {
+ case Maxwell::TessellationPrimitive::Isolines:
+ return spv::ExecutionMode::Isolines;
+ case Maxwell::TessellationPrimitive::Triangles:
+ return spv::ExecutionMode::Triangles;
+ case Maxwell::TessellationPrimitive::Quads:
+ return spv::ExecutionMode::Quads;
+ }
+ UNREACHABLE();
+ return spv::ExecutionMode::Triangles;
+}
+
+spv::ExecutionMode GetExecutionMode(Maxwell::TessellationSpacing spacing) {
+ switch (spacing) {
+ case Maxwell::TessellationSpacing::Equal:
+ return spv::ExecutionMode::SpacingEqual;
+ case Maxwell::TessellationSpacing::FractionalOdd:
+ return spv::ExecutionMode::SpacingFractionalOdd;
+ case Maxwell::TessellationSpacing::FractionalEven:
+ return spv::ExecutionMode::SpacingFractionalEven;
+ }
+ UNREACHABLE();
+ return spv::ExecutionMode::SpacingEqual;
+}
+
+spv::ExecutionMode GetExecutionMode(Maxwell::PrimitiveTopology input_topology) {
+ switch (input_topology) {
+ case Maxwell::PrimitiveTopology::Points:
+ return spv::ExecutionMode::InputPoints;
+ case Maxwell::PrimitiveTopology::Lines:
+ case Maxwell::PrimitiveTopology::LineLoop:
+ case Maxwell::PrimitiveTopology::LineStrip:
+ return spv::ExecutionMode::InputLines;
+ case Maxwell::PrimitiveTopology::Triangles:
+ case Maxwell::PrimitiveTopology::TriangleStrip:
+ case Maxwell::PrimitiveTopology::TriangleFan:
+ return spv::ExecutionMode::Triangles;
+ case Maxwell::PrimitiveTopology::LinesAdjacency:
+ case Maxwell::PrimitiveTopology::LineStripAdjacency:
+ return spv::ExecutionMode::InputLinesAdjacency;
+ case Maxwell::PrimitiveTopology::TrianglesAdjacency:
+ case Maxwell::PrimitiveTopology::TriangleStripAdjacency:
+ return spv::ExecutionMode::InputTrianglesAdjacency;
+ case Maxwell::PrimitiveTopology::Quads:
+ UNIMPLEMENTED_MSG("Quads");
+ return spv::ExecutionMode::Triangles;
+ case Maxwell::PrimitiveTopology::QuadStrip:
+ UNIMPLEMENTED_MSG("QuadStrip");
+ return spv::ExecutionMode::Triangles;
+ case Maxwell::PrimitiveTopology::Polygon:
+ UNIMPLEMENTED_MSG("Polygon");
+ return spv::ExecutionMode::Triangles;
+ case Maxwell::PrimitiveTopology::Patches:
+ UNIMPLEMENTED_MSG("Patches");
+ return spv::ExecutionMode::Triangles;
+ }
+ UNREACHABLE();
+ return spv::ExecutionMode::Triangles;
+}
+
+spv::ExecutionMode GetExecutionMode(Tegra::Shader::OutputTopology output_topology) {
+ switch (output_topology) {
+ case Tegra::Shader::OutputTopology::PointList:
+ return spv::ExecutionMode::OutputPoints;
+ case Tegra::Shader::OutputTopology::LineStrip:
+ return spv::ExecutionMode::OutputLineStrip;
+ case Tegra::Shader::OutputTopology::TriangleStrip:
+ return spv::ExecutionMode::OutputTriangleStrip;
+ default:
+ UNREACHABLE();
+ return spv::ExecutionMode::OutputPoints;
+ }
+}
+
/// Returns true if an attribute index is one of the 32 generic attributes
constexpr bool IsGenericAttribute(Attribute::Index attribute) {
return attribute >= Attribute::Index::Attribute_0 &&
@@ -73,7 +247,7 @@ constexpr bool IsGenericAttribute(Attribute::Index attribute) {
}
/// Returns the location of a generic attribute
-constexpr u32 GetGenericAttributeLocation(Attribute::Index attribute) {
+u32 GetGenericAttributeLocation(Attribute::Index attribute) {
ASSERT(IsGenericAttribute(attribute));
return static_cast<u32>(attribute) - static_cast<u32>(Attribute::Index::Attribute_0);
}
@@ -87,20 +261,146 @@ bool IsPrecise(Operation operand) {
return false;
}
-} // namespace
-
-class ASTDecompiler;
-class ExprDecompiler;
-
-class SPIRVDecompiler : public Sirit::Module {
+class SPIRVDecompiler final : public Sirit::Module {
public:
- explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderType stage)
- : Module(0x00010300), device{device}, ir{ir}, stage{stage}, header{ir.GetHeader()} {
+ explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderType stage,
+ const Specialization& specialization)
+ : Module(0x00010300), device{device}, ir{ir}, stage{stage}, header{ir.GetHeader()},
+ specialization{specialization} {
AddCapability(spv::Capability::Shader);
+ AddCapability(spv::Capability::UniformAndStorageBuffer16BitAccess);
+ AddCapability(spv::Capability::ImageQuery);
+ AddCapability(spv::Capability::Image1D);
+ AddCapability(spv::Capability::ImageBuffer);
+ AddCapability(spv::Capability::ImageGatherExtended);
+ AddCapability(spv::Capability::SampledBuffer);
+ AddCapability(spv::Capability::StorageImageWriteWithoutFormat);
+ AddCapability(spv::Capability::SubgroupBallotKHR);
+ AddCapability(spv::Capability::SubgroupVoteKHR);
+ AddExtension("SPV_KHR_shader_ballot");
+ AddExtension("SPV_KHR_subgroup_vote");
AddExtension("SPV_KHR_storage_buffer_storage_class");
AddExtension("SPV_KHR_variable_pointers");
+
+ if (ir.UsesViewportIndex()) {
+ AddCapability(spv::Capability::MultiViewport);
+ if (device.IsExtShaderViewportIndexLayerSupported()) {
+ AddExtension("SPV_EXT_shader_viewport_index_layer");
+ AddCapability(spv::Capability::ShaderViewportIndexLayerEXT);
+ }
+ }
+
+ if (device.IsFloat16Supported()) {
+ AddCapability(spv::Capability::Float16);
+ }
+ t_scalar_half = Name(TypeFloat(device.IsFloat16Supported() ? 16 : 32), "scalar_half");
+ t_half = Name(TypeVector(t_scalar_half, 2), "half");
+
+ const Id main = Decompile();
+
+ switch (stage) {
+ case ShaderType::Vertex:
+ AddEntryPoint(spv::ExecutionModel::Vertex, main, "main", interfaces);
+ break;
+ case ShaderType::TesselationControl:
+ AddCapability(spv::Capability::Tessellation);
+ AddEntryPoint(spv::ExecutionModel::TessellationControl, main, "main", interfaces);
+ AddExecutionMode(main, spv::ExecutionMode::OutputVertices,
+ header.common2.threads_per_input_primitive);
+ break;
+ case ShaderType::TesselationEval:
+ AddCapability(spv::Capability::Tessellation);
+ AddEntryPoint(spv::ExecutionModel::TessellationEvaluation, main, "main", interfaces);
+ AddExecutionMode(main, GetExecutionMode(specialization.tessellation.primitive));
+ AddExecutionMode(main, GetExecutionMode(specialization.tessellation.spacing));
+ AddExecutionMode(main, specialization.tessellation.clockwise
+ ? spv::ExecutionMode::VertexOrderCw
+ : spv::ExecutionMode::VertexOrderCcw);
+ break;
+ case ShaderType::Geometry:
+ AddCapability(spv::Capability::Geometry);
+ AddEntryPoint(spv::ExecutionModel::Geometry, main, "main", interfaces);
+ AddExecutionMode(main, GetExecutionMode(specialization.primitive_topology));
+ AddExecutionMode(main, GetExecutionMode(header.common3.output_topology));
+ AddExecutionMode(main, spv::ExecutionMode::OutputVertices,
+ header.common4.max_output_vertices);
+ // TODO(Rodrigo): Where can we get this info from?
+ AddExecutionMode(main, spv::ExecutionMode::Invocations, 1U);
+ break;
+ case ShaderType::Fragment:
+ AddEntryPoint(spv::ExecutionModel::Fragment, main, "main", interfaces);
+ AddExecutionMode(main, spv::ExecutionMode::OriginUpperLeft);
+ if (header.ps.omap.depth) {
+ AddExecutionMode(main, spv::ExecutionMode::DepthReplacing);
+ }
+ break;
+ case ShaderType::Compute:
+ const auto workgroup_size = specialization.workgroup_size;
+ AddExecutionMode(main, spv::ExecutionMode::LocalSize, workgroup_size[0],
+ workgroup_size[1], workgroup_size[2]);
+ AddEntryPoint(spv::ExecutionModel::GLCompute, main, "main", interfaces);
+ break;
+ }
}
+private:
+ Id Decompile() {
+ DeclareCommon();
+ DeclareVertex();
+ DeclareTessControl();
+ DeclareTessEval();
+ DeclareGeometry();
+ DeclareFragment();
+ DeclareCompute();
+ DeclareRegisters();
+ DeclarePredicates();
+ DeclareLocalMemory();
+ DeclareSharedMemory();
+ DeclareInternalFlags();
+ DeclareInputAttributes();
+ DeclareOutputAttributes();
+
+ u32 binding = specialization.base_binding;
+ binding = DeclareConstantBuffers(binding);
+ binding = DeclareGlobalBuffers(binding);
+ binding = DeclareTexelBuffers(binding);
+ binding = DeclareSamplers(binding);
+ binding = DeclareImages(binding);
+
+ const Id main = OpFunction(t_void, {}, TypeFunction(t_void));
+ AddLabel();
+
+ if (ir.IsDecompiled()) {
+ DeclareFlowVariables();
+ DecompileAST();
+ } else {
+ AllocateLabels();
+ DecompileBranchMode();
+ }
+
+ OpReturn();
+ OpFunctionEnd();
+
+ return main;
+ }
+
+ void DefinePrologue() {
+ if (stage == ShaderType::Vertex) {
+ // Clear Position to avoid reading trash on the Z conversion.
+ const auto position_index = out_indices.position.value();
+ const Id position = AccessElement(t_out_float4, out_vertex, position_index);
+ OpStore(position, v_varying_default);
+
+ if (specialization.point_size) {
+ const u32 point_size_index = out_indices.point_size.value();
+ const Id out_point_size = AccessElement(t_out_float, out_vertex, point_size_index);
+ OpStore(out_point_size, Constant(t_float, *specialization.point_size));
+ }
+ }
+ }
+
+ void DecompileAST();
+
void DecompileBranchMode() {
const u32 first_address = ir.GetBasicBlocks().begin()->first;
const Id loop_label = OpLabel("loop");
@@ -111,14 +411,15 @@ public:
std::vector<Sirit::Literal> literals;
std::vector<Id> branch_labels;
- for (const auto& pair : labels) {
- const auto [literal, label] = pair;
+ for (const auto& [literal, label] : labels) {
literals.push_back(literal);
branch_labels.push_back(label);
}
- jmp_to = Emit(OpVariable(TypePointer(spv::StorageClass::Function, t_uint),
- spv::StorageClass::Function, Constant(t_uint, first_address)));
+ jmp_to = OpVariable(TypePointer(spv::StorageClass::Function, t_uint),
+ spv::StorageClass::Function, Constant(t_uint, first_address));
+ AddLocalVariable(jmp_to);
+
std::tie(ssy_flow_stack, ssy_flow_stack_top) = CreateFlowStack();
std::tie(pbk_flow_stack, pbk_flow_stack_top) = CreateFlowStack();
@@ -128,151 +429,118 @@ public:
Name(pbk_flow_stack, "pbk_flow_stack");
Name(pbk_flow_stack_top, "pbk_flow_stack_top");
- Emit(OpBranch(loop_label));
- Emit(loop_label);
- Emit(OpLoopMerge(merge_label, continue_label, spv::LoopControlMask::Unroll));
- Emit(OpBranch(dummy_label));
+ DefinePrologue();
+
+ OpBranch(loop_label);
+ AddLabel(loop_label);
+ OpLoopMerge(merge_label, continue_label, spv::LoopControlMask::MaskNone);
+ OpBranch(dummy_label);
- Emit(dummy_label);
+ AddLabel(dummy_label);
const Id default_branch = OpLabel();
- const Id jmp_to_load = Emit(OpLoad(t_uint, jmp_to));
- Emit(OpSelectionMerge(jump_label, spv::SelectionControlMask::MaskNone));
- Emit(OpSwitch(jmp_to_load, default_branch, literals, branch_labels));
+ const Id jmp_to_load = OpLoad(t_uint, jmp_to);
+ OpSelectionMerge(jump_label, spv::SelectionControlMask::MaskNone);
+ OpSwitch(jmp_to_load, default_branch, literals, branch_labels);
- Emit(default_branch);
- Emit(OpReturn());
+ AddLabel(default_branch);
+ OpReturn();
- for (const auto& pair : ir.GetBasicBlocks()) {
- const auto& [address, bb] = pair;
- Emit(labels.at(address));
+ for (const auto& [address, bb] : ir.GetBasicBlocks()) {
+ AddLabel(labels.at(address));
VisitBasicBlock(bb);
const auto next_it = labels.lower_bound(address + 1);
const Id next_label = next_it != labels.end() ? next_it->second : default_branch;
- Emit(OpBranch(next_label));
+ OpBranch(next_label);
}
- Emit(jump_label);
- Emit(OpBranch(continue_label));
- Emit(continue_label);
- Emit(OpBranch(loop_label));
- Emit(merge_label);
+ AddLabel(jump_label);
+ OpBranch(continue_label);
+ AddLabel(continue_label);
+ OpBranch(loop_label);
+ AddLabel(merge_label);
}
- void DecompileAST();
+private:
+ friend class ASTDecompiler;
+ friend class ExprDecompiler;
- void Decompile() {
- const bool is_fully_decompiled = ir.IsDecompiled();
- AllocateBindings();
- if (!is_fully_decompiled) {
- AllocateLabels();
- }
+ static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
- DeclareVertex();
- DeclareGeometry();
- DeclareFragment();
- DeclareRegisters();
- DeclarePredicates();
- if (is_fully_decompiled) {
- DeclareFlowVariables();
+ void AllocateLabels() {
+ for (const auto& pair : ir.GetBasicBlocks()) {
+ const u32 address = pair.first;
+ labels.emplace(address, OpLabel(fmt::format("label_0x{:x}", address)));
}
- DeclareLocalMemory();
- DeclareInternalFlags();
- DeclareInputAttributes();
- DeclareOutputAttributes();
- DeclareConstantBuffers();
- DeclareGlobalBuffers();
- DeclareSamplers();
+ }
- execute_function =
- Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void)));
- Emit(OpLabel());
+ void DeclareCommon() {
+ thread_id =
+ DeclareInputBuiltIn(spv::BuiltIn::SubgroupLocalInvocationId, t_in_uint, "thread_id");
+ }
- if (is_fully_decompiled) {
- DecompileAST();
- } else {
- DecompileBranchMode();
+ void DeclareVertex() {
+ if (stage != ShaderType::Vertex) {
+ return;
}
+ Id out_vertex_struct;
+ std::tie(out_vertex_struct, out_indices) = DeclareVertexStruct();
+ const Id vertex_ptr = TypePointer(spv::StorageClass::Output, out_vertex_struct);
+ out_vertex = OpVariable(vertex_ptr, spv::StorageClass::Output);
+ interfaces.push_back(AddGlobalVariable(Name(out_vertex, "out_vertex")));
- Emit(OpReturn());
- Emit(OpFunctionEnd());
+ // Declare input attributes
+ vertex_index = DeclareInputBuiltIn(spv::BuiltIn::VertexIndex, t_in_uint, "vertex_index");
+ instance_index =
+ DeclareInputBuiltIn(spv::BuiltIn::InstanceIndex, t_in_uint, "instance_index");
}
- ShaderEntries GetShaderEntries() const {
- ShaderEntries entries;
- entries.const_buffers_base_binding = const_buffers_base_binding;
- entries.global_buffers_base_binding = global_buffers_base_binding;
- entries.samplers_base_binding = samplers_base_binding;
- for (const auto& cbuf : ir.GetConstantBuffers()) {
- entries.const_buffers.emplace_back(cbuf.second, cbuf.first);
- }
- for (const auto& gmem_pair : ir.GetGlobalMemory()) {
- const auto& [base, usage] = gmem_pair;
- entries.global_buffers.emplace_back(base.cbuf_index, base.cbuf_offset);
- }
- for (const auto& sampler : ir.GetSamplers()) {
- entries.samplers.emplace_back(sampler);
- }
- for (const auto& attribute : ir.GetInputAttributes()) {
- if (IsGenericAttribute(attribute)) {
- entries.attributes.insert(GetGenericAttributeLocation(attribute));
- }
+ void DeclareTessControl() {
+ if (stage != ShaderType::TesselationControl) {
+ return;
}
- entries.clip_distances = ir.GetClipDistances();
- entries.shader_length = ir.GetLength();
- entries.entry_function = execute_function;
- entries.interfaces = interfaces;
- return entries;
- }
-
-private:
- friend class ASTDecompiler;
- friend class ExprDecompiler;
-
- static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
+ DeclareInputVertexArray(NumInputPatches);
+ DeclareOutputVertexArray(header.common2.threads_per_input_primitive);
- void AllocateBindings() {
- const u32 binding_base = static_cast<u32>(stage) * STAGE_BINDING_STRIDE;
- u32 binding_iterator = binding_base;
+ tess_level_outer = DeclareBuiltIn(
+ spv::BuiltIn::TessLevelOuter, spv::StorageClass::Output,
+ TypePointer(spv::StorageClass::Output, TypeArray(t_float, Constant(t_uint, 4U))),
+ "tess_level_outer");
+ Decorate(tess_level_outer, spv::Decoration::Patch);
- const auto Allocate = [&binding_iterator](std::size_t count) {
- const u32 current_binding = binding_iterator;
- binding_iterator += static_cast<u32>(count);
- return current_binding;
- };
- const_buffers_base_binding = Allocate(ir.GetConstantBuffers().size());
- global_buffers_base_binding = Allocate(ir.GetGlobalMemory().size());
- samplers_base_binding = Allocate(ir.GetSamplers().size());
+ tess_level_inner = DeclareBuiltIn(
+ spv::BuiltIn::TessLevelInner, spv::StorageClass::Output,
+ TypePointer(spv::StorageClass::Output, TypeArray(t_float, Constant(t_uint, 2U))),
+ "tess_level_inner");
+ Decorate(tess_level_inner, spv::Decoration::Patch);
- ASSERT_MSG(binding_iterator - binding_base < STAGE_BINDING_STRIDE,
- "Stage binding stride is too small");
+ invocation_id = DeclareInputBuiltIn(spv::BuiltIn::InvocationId, t_in_int, "invocation_id");
}
- void AllocateLabels() {
- for (const auto& pair : ir.GetBasicBlocks()) {
- const u32 address = pair.first;
- labels.emplace(address, OpLabel(fmt::format("label_0x{:x}", address)));
- }
- }
-
- void DeclareVertex() {
- if (stage != ShaderType::Vertex)
+ void DeclareTessEval() {
+ if (stage != ShaderType::TesselationEval) {
return;
+ }
+ DeclareInputVertexArray(NumInputPatches);
+ DeclareOutputVertex();
- DeclareVertexRedeclarations();
+ tess_coord = DeclareInputBuiltIn(spv::BuiltIn::TessCoord, t_in_float3, "tess_coord");
}
void DeclareGeometry() {
- if (stage != ShaderType::Geometry)
+ if (stage != ShaderType::Geometry) {
return;
-
- UNIMPLEMENTED();
+ }
+ const u32 num_input = GetNumPrimitiveTopologyVertices(specialization.primitive_topology);
+ DeclareInputVertexArray(num_input);
+ DeclareOutputVertex();
}
void DeclareFragment() {
- if (stage != ShaderType::Fragment)
+ if (stage != ShaderType::Fragment) {
return;
+ }
for (u32 rt = 0; rt < static_cast<u32>(frag_colors.size()); ++rt) {
if (!IsRenderTargetUsed(rt)) {
@@ -296,10 +564,19 @@ private:
interfaces.push_back(frag_depth);
}
- frag_coord = DeclareBuiltIn(spv::BuiltIn::FragCoord, spv::StorageClass::Input, t_in_float4,
- "frag_coord");
- front_facing = DeclareBuiltIn(spv::BuiltIn::FrontFacing, spv::StorageClass::Input,
- t_in_bool, "front_facing");
+ frag_coord = DeclareInputBuiltIn(spv::BuiltIn::FragCoord, t_in_float4, "frag_coord");
+ front_facing = DeclareInputBuiltIn(spv::BuiltIn::FrontFacing, t_in_bool, "front_facing");
+ point_coord = DeclareInputBuiltIn(spv::BuiltIn::PointCoord, t_in_float2, "point_coord");
+ }
+
+ void DeclareCompute() {
+ if (stage != ShaderType::Compute) {
+ return;
+ }
+
+ workgroup_id = DeclareInputBuiltIn(spv::BuiltIn::WorkgroupId, t_in_uint3, "workgroup_id");
+ local_invocation_id =
+ DeclareInputBuiltIn(spv::BuiltIn::LocalInvocationId, t_in_uint3, "local_invocation_id");
}
void DeclareRegisters() {
@@ -327,21 +604,44 @@ private:
}
void DeclareLocalMemory() {
- if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) {
- const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4);
- const Id type_array = TypeArray(t_float, Constant(t_uint, element_count));
- const Id type_pointer = TypePointer(spv::StorageClass::Private, type_array);
- Name(type_pointer, "LocalMemory");
+ // TODO(Rodrigo): Unstub kernel local memory size and pass it from a register at
+ // specialization time.
+ const u64 lmem_size = stage == ShaderType::Compute ? 0x400 : header.GetLocalMemorySize();
+ if (lmem_size == 0) {
+ return;
+ }
+ const auto element_count = static_cast<u32>(Common::AlignUp(lmem_size, 4) / 4);
+ const Id type_array = TypeArray(t_float, Constant(t_uint, element_count));
+ const Id type_pointer = TypePointer(spv::StorageClass::Private, type_array);
+ Name(type_pointer, "LocalMemory");
+
+ local_memory =
+ OpVariable(type_pointer, spv::StorageClass::Private, ConstantNull(type_array));
+ AddGlobalVariable(Name(local_memory, "local_memory"));
+ }
+
+ void DeclareSharedMemory() {
+ if (stage != ShaderType::Compute) {
+ return;
+ }
+ t_smem_uint = TypePointer(spv::StorageClass::Workgroup, t_uint);
- local_memory =
- OpVariable(type_pointer, spv::StorageClass::Private, ConstantNull(type_array));
- AddGlobalVariable(Name(local_memory, "local_memory"));
+ const u32 smem_size = specialization.shared_memory_size;
+ if (smem_size == 0) {
+ // Avoid declaring an empty array.
+ return;
}
+ const auto element_count = static_cast<u32>(Common::AlignUp(smem_size, 4) / 4);
+ const Id type_array = TypeArray(t_uint, Constant(t_uint, element_count));
+ const Id type_pointer = TypePointer(spv::StorageClass::Workgroup, type_array);
+ Name(type_pointer, "SharedMemory");
+
+ shared_memory = OpVariable(type_pointer, spv::StorageClass::Workgroup);
+ AddGlobalVariable(Name(shared_memory, "shared_memory"));
}
void DeclareInternalFlags() {
- constexpr std::array<const char*, INTERNAL_FLAGS_COUNT> names = {"zero", "sign", "carry",
- "overflow"};
+ constexpr std::array names = {"zero", "sign", "carry", "overflow"};
for (std::size_t flag = 0; flag < INTERNAL_FLAGS_COUNT; ++flag) {
const auto flag_code = static_cast<InternalFlag>(flag);
const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
@@ -349,17 +649,53 @@ private:
}
}
+ void DeclareInputVertexArray(u32 length) {
+ constexpr auto storage = spv::StorageClass::Input;
+ std::tie(in_indices, in_vertex) = DeclareVertexArray(storage, "in_indices", length);
+ }
+
+ void DeclareOutputVertexArray(u32 length) {
+ constexpr auto storage = spv::StorageClass::Output;
+ std::tie(out_indices, out_vertex) = DeclareVertexArray(storage, "out_indices", length);
+ }
+
+ std::tuple<VertexIndices, Id> DeclareVertexArray(spv::StorageClass storage_class,
+ std::string name, u32 length) {
+ const auto [struct_id, indices] = DeclareVertexStruct();
+ const Id vertex_array = TypeArray(struct_id, Constant(t_uint, length));
+ const Id vertex_ptr = TypePointer(storage_class, vertex_array);
+ const Id vertex = OpVariable(vertex_ptr, storage_class);
+ AddGlobalVariable(Name(vertex, std::move(name)));
+ interfaces.push_back(vertex);
+ return {indices, vertex};
+ }
+
+ void DeclareOutputVertex() {
+ Id out_vertex_struct;
+ std::tie(out_vertex_struct, out_indices) = DeclareVertexStruct();
+ const Id out_vertex_ptr = TypePointer(spv::StorageClass::Output, out_vertex_struct);
+ out_vertex = OpVariable(out_vertex_ptr, spv::StorageClass::Output);
+ interfaces.push_back(AddGlobalVariable(Name(out_vertex, "out_vertex")));
+ }
+
void DeclareInputAttributes() {
for (const auto index : ir.GetInputAttributes()) {
if (!IsGenericAttribute(index)) {
continue;
}
- UNIMPLEMENTED_IF(stage == ShaderType::Geometry);
-
const u32 location = GetGenericAttributeLocation(index);
- const Id id = OpVariable(t_in_float4, spv::StorageClass::Input);
- Name(AddGlobalVariable(id), fmt::format("in_attr{}", location));
+ const auto type_descriptor = GetAttributeType(location);
+ Id type;
+ if (IsInputAttributeArray()) {
+ type = GetTypeVectorDefinitionLut(type_descriptor.type).at(3);
+ type = TypeArray(type, Constant(t_uint, GetNumInputVertices()));
+ type = TypePointer(spv::StorageClass::Input, type);
+ } else {
+ type = type_descriptor.vector;
+ }
+ const Id id = OpVariable(type, spv::StorageClass::Input);
+ AddGlobalVariable(Name(id, fmt::format("in_attr{}", location)));
input_attributes.emplace(index, id);
interfaces.push_back(id);
@@ -389,8 +725,21 @@ private:
if (!IsGenericAttribute(index)) {
continue;
}
- const auto location = GetGenericAttributeLocation(index);
- const Id id = OpVariable(t_out_float4, spv::StorageClass::Output);
+ const u32 location = GetGenericAttributeLocation(index);
+ Id type = t_float4;
+ Id varying_default = v_varying_default;
+ if (IsOutputAttributeArray()) {
+ const u32 num = GetNumOutputVertices();
+ type = TypeArray(type, Constant(t_uint, num));
+ if (device.GetDriverID() != vk::DriverIdKHR::eIntelProprietaryWindows) {
+ // Intel's proprietary driver fails to setup defaults for arrayed output
+ // attributes.
+ varying_default = ConstantComposite(type, std::vector(num, varying_default));
+ }
+ }
+ type = TypePointer(spv::StorageClass::Output, type);
+
+ const Id id = OpVariable(type, spv::StorageClass::Output, varying_default);
Name(AddGlobalVariable(id), fmt::format("out_attr{}", location));
output_attributes.emplace(index, id);
interfaces.push_back(id);
@@ -399,10 +748,8 @@ private:
}
}
- void DeclareConstantBuffers() {
- u32 binding = const_buffers_base_binding;
- for (const auto& entry : ir.GetConstantBuffers()) {
- const auto [index, size] = entry;
+ u32 DeclareConstantBuffers(u32 binding) {
+ for (const auto& [index, size] : ir.GetConstantBuffers()) {
const Id type = device.IsKhrUniformBufferStandardLayoutSupported() ? t_cbuf_scalar_ubo
: t_cbuf_std140_ubo;
const Id id = OpVariable(type, spv::StorageClass::Uniform);
@@ -412,12 +759,11 @@ private:
Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
constant_buffers.emplace(index, id);
}
+ return binding;
}
- void DeclareGlobalBuffers() {
- u32 binding = global_buffers_base_binding;
- for (const auto& entry : ir.GetGlobalMemory()) {
- const auto [base, usage] = entry;
+ u32 DeclareGlobalBuffers(u32 binding) {
+ for (const auto& [base, usage] : ir.GetGlobalMemory()) {
const Id id = OpVariable(t_gmem_ssbo, spv::StorageClass::StorageBuffer);
AddGlobalVariable(
Name(id, fmt::format("gmem_{}_{}", base.cbuf_index, base.cbuf_offset)));
@@ -426,89 +772,187 @@ private:
Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
global_buffers.emplace(base, id);
}
+ return binding;
+ }
+
+ u32 DeclareTexelBuffers(u32 binding) {
+ for (const auto& sampler : ir.GetSamplers()) {
+ if (!sampler.IsBuffer()) {
+ continue;
+ }
+ ASSERT(!sampler.IsArray());
+ ASSERT(!sampler.IsShadow());
+
+ constexpr auto dim = spv::Dim::Buffer;
+ constexpr int depth = 0;
+ constexpr int arrayed = 0;
+ constexpr bool ms = false;
+ constexpr int sampled = 1;
+ constexpr auto format = spv::ImageFormat::Unknown;
+ const Id image_type = TypeImage(t_float, dim, depth, arrayed, ms, sampled, format);
+ const Id pointer_type = TypePointer(spv::StorageClass::UniformConstant, image_type);
+ const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant);
+ AddGlobalVariable(Name(id, fmt::format("sampler_{}", sampler.GetIndex())));
+ Decorate(id, spv::Decoration::Binding, binding++);
+ Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
+
+ texel_buffers.emplace(sampler.GetIndex(), TexelBuffer{image_type, id});
+ }
+ return binding;
}
- void DeclareSamplers() {
- u32 binding = samplers_base_binding;
+ u32 DeclareSamplers(u32 binding) {
for (const auto& sampler : ir.GetSamplers()) {
+ if (sampler.IsBuffer()) {
+ continue;
+ }
const auto dim = GetSamplerDim(sampler);
const int depth = sampler.IsShadow() ? 1 : 0;
const int arrayed = sampler.IsArray() ? 1 : 0;
- // TODO(Rodrigo): Sampled 1 indicates that the image will be used with a sampler. When
- // SULD and SUST instructions are implemented, replace this value.
- const int sampled = 1;
- const Id image_type =
- TypeImage(t_float, dim, depth, arrayed, false, sampled, spv::ImageFormat::Unknown);
+ constexpr bool ms = false;
+ constexpr int sampled = 1;
+ constexpr auto format = spv::ImageFormat::Unknown;
+ const Id image_type = TypeImage(t_float, dim, depth, arrayed, ms, sampled, format);
const Id sampled_image_type = TypeSampledImage(image_type);
const Id pointer_type =
TypePointer(spv::StorageClass::UniformConstant, sampled_image_type);
const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant);
AddGlobalVariable(Name(id, fmt::format("sampler_{}", sampler.GetIndex())));
+ Decorate(id, spv::Decoration::Binding, binding++);
+ Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
- sampler_images.insert(
- {static_cast<u32>(sampler.GetIndex()), {image_type, sampled_image_type, id}});
+ sampled_images.emplace(sampler.GetIndex(),
+ SampledImage{image_type, sampled_image_type, id});
+ }
+ return binding;
+ }
+
+ u32 DeclareImages(u32 binding) {
+ for (const auto& image : ir.GetImages()) {
+ const auto [dim, arrayed] = GetImageDim(image);
+ constexpr int depth = 0;
+ constexpr bool ms = false;
+ constexpr int sampled = 2; // This won't be accessed with a sampler
+ constexpr auto format = spv::ImageFormat::Unknown;
+ const Id image_type = TypeImage(t_uint, dim, depth, arrayed, ms, sampled, format, {});
+ const Id pointer_type = TypePointer(spv::StorageClass::UniformConstant, image_type);
+ const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant);
+ AddGlobalVariable(Name(id, fmt::format("image_{}", image.GetIndex())));
Decorate(id, spv::Decoration::Binding, binding++);
Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
+ if (image.IsRead() && !image.IsWritten()) {
+ Decorate(id, spv::Decoration::NonWritable);
+ } else if (image.IsWritten() && !image.IsRead()) {
+ Decorate(id, spv::Decoration::NonReadable);
+ }
+
+ images.emplace(static_cast<u32>(image.GetIndex()), StorageImage{image_type, id});
}
+ return binding;
+ }
+
+ bool IsInputAttributeArray() const {
+ return stage == ShaderType::TesselationControl || stage == ShaderType::TesselationEval ||
+ stage == ShaderType::Geometry;
}
- void DeclareVertexRedeclarations() {
- vertex_index = DeclareBuiltIn(spv::BuiltIn::VertexIndex, spv::StorageClass::Input,
- t_in_uint, "vertex_index");
- instance_index = DeclareBuiltIn(spv::BuiltIn::InstanceIndex, spv::StorageClass::Input,
- t_in_uint, "instance_index");
+ bool IsOutputAttributeArray() const {
+ return stage == ShaderType::TesselationControl;
+ }
- bool is_clip_distances_declared = false;
- for (const auto index : ir.GetOutputAttributes()) {
- if (index == Attribute::Index::ClipDistances0123 ||
- index == Attribute::Index::ClipDistances4567) {
- is_clip_distances_declared = true;
- }
+ u32 GetNumInputVertices() const {
+ switch (stage) {
+ case ShaderType::Geometry:
+ return GetNumPrimitiveTopologyVertices(specialization.primitive_topology);
+ case ShaderType::TesselationControl:
+ case ShaderType::TesselationEval:
+ return NumInputPatches;
+ default:
+ UNREACHABLE();
+ return 1;
}
+ }
- std::vector<Id> members;
- members.push_back(t_float4);
- if (ir.UsesPointSize()) {
- members.push_back(t_float);
- }
- if (is_clip_distances_declared) {
- members.push_back(TypeArray(t_float, Constant(t_uint, 8)));
- }
-
- const Id gl_per_vertex_struct = Name(TypeStruct(members), "PerVertex");
- Decorate(gl_per_vertex_struct, spv::Decoration::Block);
-
- u32 declaration_index = 0;
- const auto MemberDecorateBuiltIn = [&](spv::BuiltIn builtin, std::string name,
- bool condition) {
- if (!condition)
- return u32{};
- MemberName(gl_per_vertex_struct, declaration_index, name);
- MemberDecorate(gl_per_vertex_struct, declaration_index, spv::Decoration::BuiltIn,
- static_cast<u32>(builtin));
- return declaration_index++;
+ u32 GetNumOutputVertices() const {
+ switch (stage) {
+ case ShaderType::TesselationControl:
+ return header.common2.threads_per_input_primitive;
+ default:
+ UNREACHABLE();
+ return 1;
+ }
+ }
+
+ std::tuple<Id, VertexIndices> DeclareVertexStruct() {
+ struct BuiltIn {
+ Id type;
+ spv::BuiltIn builtin;
+ const char* name;
+ };
+ std::vector<BuiltIn> members;
+ members.reserve(4);
+
+ const auto AddBuiltIn = [&](Id type, spv::BuiltIn builtin, const char* name) {
+ const auto index = static_cast<u32>(members.size());
+ members.push_back(BuiltIn{type, builtin, name});
+ return index;
};
- position_index = MemberDecorateBuiltIn(spv::BuiltIn::Position, "position", true);
- point_size_index =
- MemberDecorateBuiltIn(spv::BuiltIn::PointSize, "point_size", ir.UsesPointSize());
- clip_distances_index = MemberDecorateBuiltIn(spv::BuiltIn::ClipDistance, "clip_distances",
- is_clip_distances_declared);
+ VertexIndices indices;
+ indices.position = AddBuiltIn(t_float4, spv::BuiltIn::Position, "position");
+
+ if (ir.UsesViewportIndex()) {
+ if (stage != ShaderType::Vertex || device.IsExtShaderViewportIndexLayerSupported()) {
+ indices.viewport = AddBuiltIn(t_int, spv::BuiltIn::ViewportIndex, "viewport_index");
+ } else {
+ LOG_ERROR(Render_Vulkan,
+ "Shader requires ViewportIndex but it's not supported on this "
+ "stage with this device.");
+ }
+ }
+
+ if (ir.UsesPointSize() || specialization.point_size) {
+ indices.point_size = AddBuiltIn(t_float, spv::BuiltIn::PointSize, "point_size");
+ }
+
+ const auto& output_attributes = ir.GetOutputAttributes();
+ const bool declare_clip_distances =
+ std::any_of(output_attributes.begin(), output_attributes.end(), [](const auto& index) {
+ return index == Attribute::Index::ClipDistances0123 ||
+ index == Attribute::Index::ClipDistances4567;
+ });
+ if (declare_clip_distances) {
+ indices.clip_distances = AddBuiltIn(TypeArray(t_float, Constant(t_uint, 8)),
+ spv::BuiltIn::ClipDistance, "clip_distances");
+ }
+
+ std::vector<Id> member_types;
+ member_types.reserve(members.size());
+ for (std::size_t i = 0; i < members.size(); ++i) {
+ member_types.push_back(members[i].type);
+ }
+ const Id per_vertex_struct = Name(TypeStruct(member_types), "PerVertex");
+ Decorate(per_vertex_struct, spv::Decoration::Block);
+
+ for (std::size_t index = 0; index < members.size(); ++index) {
+ const auto& member = members[index];
+ MemberName(per_vertex_struct, static_cast<u32>(index), member.name);
+ MemberDecorate(per_vertex_struct, static_cast<u32>(index), spv::Decoration::BuiltIn,
+ static_cast<u32>(member.builtin));
+ }
- const Id type_pointer = TypePointer(spv::StorageClass::Output, gl_per_vertex_struct);
- per_vertex = OpVariable(type_pointer, spv::StorageClass::Output);
- AddGlobalVariable(Name(per_vertex, "per_vertex"));
- interfaces.push_back(per_vertex);
+ return {per_vertex_struct, indices};
}
void VisitBasicBlock(const NodeBlock& bb) {
for (const auto& node : bb) {
- static_cast<void>(Visit(node));
+ [[maybe_unused]] const Type type = Visit(node).type;
+ ASSERT(type == Type::Void);
}
}
- Id Visit(const Node& node) {
+ Expression Visit(const Node& node) {
if (const auto operation = std::get_if<OperationNode>(&*node)) {
const auto operation_index = static_cast<std::size_t>(operation->GetCode());
const auto decompiler = operation_decompilers[operation_index];
@@ -516,18 +960,21 @@ private:
UNREACHABLE_MSG("Operation decompiler {} not defined", operation_index);
}
return (this->*decompiler)(*operation);
+ }
- } else if (const auto gpr = std::get_if<GprNode>(&*node)) {
+ if (const auto gpr = std::get_if<GprNode>(&*node)) {
const u32 index = gpr->GetIndex();
if (index == Register::ZeroIndex) {
- return Constant(t_float, 0.0f);
+ return {v_float_zero, Type::Float};
}
- return Emit(OpLoad(t_float, registers.at(index)));
+ return {OpLoad(t_float, registers.at(index)), Type::Float};
+ }
- } else if (const auto immediate = std::get_if<ImmediateNode>(&*node)) {
- return BitcastTo<Type::Float>(Constant(t_uint, immediate->GetValue()));
+ if (const auto immediate = std::get_if<ImmediateNode>(&*node)) {
+ return {Constant(t_uint, immediate->GetValue()), Type::Uint};
+ }
- } else if (const auto predicate = std::get_if<PredicateNode>(&*node)) {
+ if (const auto predicate = std::get_if<PredicateNode>(&*node)) {
const auto value = [&]() -> Id {
switch (const auto index = predicate->GetIndex(); index) {
case Tegra::Shader::Pred::UnusedIndex:
@@ -535,74 +982,107 @@ private:
case Tegra::Shader::Pred::NeverExecute:
return v_false;
default:
- return Emit(OpLoad(t_bool, predicates.at(index)));
+ return OpLoad(t_bool, predicates.at(index));
}
}();
if (predicate->IsNegated()) {
- return Emit(OpLogicalNot(t_bool, value));
+ return {OpLogicalNot(t_bool, value), Type::Bool};
}
- return value;
+ return {value, Type::Bool};
+ }
- } else if (const auto abuf = std::get_if<AbufNode>(&*node)) {
+ if (const auto abuf = std::get_if<AbufNode>(&*node)) {
const auto attribute = abuf->GetIndex();
- const auto element = abuf->GetElement();
+ const u32 element = abuf->GetElement();
+ const auto& buffer = abuf->GetBuffer();
+
+ const auto ArrayPass = [&](Id pointer_type, Id composite, std::vector<u32> indices) {
+ std::vector<Id> members;
+ members.reserve(std::size(indices) + 1);
+
+ if (buffer && IsInputAttributeArray()) {
+ members.push_back(AsUint(Visit(buffer)));
+ }
+ for (const u32 index : indices) {
+ members.push_back(Constant(t_uint, index));
+ }
+ return OpAccessChain(pointer_type, composite, members);
+ };
switch (attribute) {
- case Attribute::Index::Position:
- if (stage != ShaderType::Fragment) {
- UNIMPLEMENTED();
- break;
- } else {
+ case Attribute::Index::Position: {
+ if (stage == ShaderType::Fragment) {
if (element == 3) {
- return Constant(t_float, 1.0f);
+ return {Constant(t_float, 1.0f), Type::Float};
}
- return Emit(OpLoad(t_float, AccessElement(t_in_float, frag_coord, element)));
+ return {OpLoad(t_float, AccessElement(t_in_float, frag_coord, element)),
+ Type::Float};
+ }
+ const auto elements = {in_indices.position.value(), element};
+ return {OpLoad(t_float, ArrayPass(t_in_float, in_vertex, elements)), Type::Float};
+ }
+ case Attribute::Index::PointCoord: {
+ switch (element) {
+ case 0:
+ case 1:
+ return {OpCompositeExtract(t_float, OpLoad(t_float2, point_coord), element),
+ Type::Float};
}
+ UNIMPLEMENTED_MSG("Unimplemented point coord element={}", element);
+ return {v_float_zero, Type::Float};
+ }
case Attribute::Index::TessCoordInstanceIDVertexID:
// TODO(Subv): Find out what the values are for the first two elements when inside a
// vertex shader, and what's the value of the fourth element when inside a Tess Eval
// shader.
- ASSERT(stage == ShaderType::Vertex);
switch (element) {
+ case 0:
+ case 1:
+ return {OpLoad(t_float, AccessElement(t_in_float, tess_coord, element)),
+ Type::Float};
case 2:
- return BitcastFrom<Type::Uint>(Emit(OpLoad(t_uint, instance_index)));
+ return {OpLoad(t_uint, instance_index), Type::Uint};
case 3:
- return BitcastFrom<Type::Uint>(Emit(OpLoad(t_uint, vertex_index)));
+ return {OpLoad(t_uint, vertex_index), Type::Uint};
}
UNIMPLEMENTED_MSG("Unmanaged TessCoordInstanceIDVertexID element={}", element);
- return Constant(t_float, 0);
+ return {Constant(t_uint, 0U), Type::Uint};
case Attribute::Index::FrontFacing:
// TODO(Subv): Find out what the values are for the other elements.
ASSERT(stage == ShaderType::Fragment);
if (element == 3) {
- const Id is_front_facing = Emit(OpLoad(t_bool, front_facing));
- const Id true_value =
- BitcastTo<Type::Float>(Constant(t_int, static_cast<s32>(-1)));
- const Id false_value = BitcastTo<Type::Float>(Constant(t_int, 0));
- return Emit(OpSelect(t_float, is_front_facing, true_value, false_value));
+ const Id is_front_facing = OpLoad(t_bool, front_facing);
+ const Id true_value = Constant(t_int, static_cast<s32>(-1));
+ const Id false_value = Constant(t_int, 0);
+ return {OpSelect(t_int, is_front_facing, true_value, false_value), Type::Int};
}
UNIMPLEMENTED_MSG("Unmanaged FrontFacing element={}", element);
- return Constant(t_float, 0.0f);
+ return {v_float_zero, Type::Float};
default:
if (IsGenericAttribute(attribute)) {
- const Id pointer =
- AccessElement(t_in_float, input_attributes.at(attribute), element);
- return Emit(OpLoad(t_float, pointer));
+ const u32 location = GetGenericAttributeLocation(attribute);
+ const auto type_descriptor = GetAttributeType(location);
+ const Type type = type_descriptor.type;
+ const Id attribute_id = input_attributes.at(attribute);
+ const Id pointer = ArrayPass(type_descriptor.scalar, attribute_id, {element});
+ return {OpLoad(GetTypeDefinition(type), pointer), type};
}
break;
}
UNIMPLEMENTED_MSG("Unhandled input attribute: {}", static_cast<u32>(attribute));
+ return {v_float_zero, Type::Float};
+ }
- } else if (const auto cbuf = std::get_if<CbufNode>(&*node)) {
+ if (const auto cbuf = std::get_if<CbufNode>(&*node)) {
const Node& offset = cbuf->GetOffset();
const Id buffer_id = constant_buffers.at(cbuf->GetIndex());
Id pointer{};
if (device.IsKhrUniformBufferStandardLayoutSupported()) {
- const Id buffer_offset = Emit(OpShiftRightLogical(
- t_uint, BitcastTo<Type::Uint>(Visit(offset)), Constant(t_uint, 2u)));
- pointer = Emit(
- OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0u), buffer_offset));
+ const Id buffer_offset =
+ OpShiftRightLogical(t_uint, AsUint(Visit(offset)), Constant(t_uint, 2U));
+ pointer =
+ OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0U), buffer_offset);
} else {
Id buffer_index{};
Id buffer_element{};
@@ -614,53 +1094,76 @@ private:
buffer_element = Constant(t_uint, (offset_imm / 4) % 4);
} else if (std::holds_alternative<OperationNode>(*offset)) {
// Indirect access
- const Id offset_id = BitcastTo<Type::Uint>(Visit(offset));
- const Id unsafe_offset = Emit(OpUDiv(t_uint, offset_id, Constant(t_uint, 4)));
- const Id final_offset = Emit(OpUMod(
- t_uint, unsafe_offset, Constant(t_uint, MAX_CONSTBUFFER_ELEMENTS - 1)));
- buffer_index = Emit(OpUDiv(t_uint, final_offset, Constant(t_uint, 4)));
- buffer_element = Emit(OpUMod(t_uint, final_offset, Constant(t_uint, 4)));
+ const Id offset_id = AsUint(Visit(offset));
+ const Id unsafe_offset = OpUDiv(t_uint, offset_id, Constant(t_uint, 4));
+ const Id final_offset =
+ OpUMod(t_uint, unsafe_offset, Constant(t_uint, MaxConstBufferElements - 1));
+ buffer_index = OpUDiv(t_uint, final_offset, Constant(t_uint, 4));
+ buffer_element = OpUMod(t_uint, final_offset, Constant(t_uint, 4));
} else {
UNREACHABLE_MSG("Unmanaged offset node type");
}
- pointer = Emit(OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0),
- buffer_index, buffer_element));
+ pointer = OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0), buffer_index,
+ buffer_element);
}
- return Emit(OpLoad(t_float, pointer));
+ return {OpLoad(t_float, pointer), Type::Float};
+ }
- } else if (const auto gmem = std::get_if<GmemNode>(&*node)) {
+ if (const auto gmem = std::get_if<GmemNode>(&*node)) {
const Id gmem_buffer = global_buffers.at(gmem->GetDescriptor());
- const Id real = BitcastTo<Type::Uint>(Visit(gmem->GetRealAddress()));
- const Id base = BitcastTo<Type::Uint>(Visit(gmem->GetBaseAddress()));
+ const Id real = AsUint(Visit(gmem->GetRealAddress()));
+ const Id base = AsUint(Visit(gmem->GetBaseAddress()));
+
+ Id offset = OpISub(t_uint, real, base);
+ offset = OpUDiv(t_uint, offset, Constant(t_uint, 4U));
+ return {OpLoad(t_float,
+ OpAccessChain(t_gmem_float, gmem_buffer, Constant(t_uint, 0U), offset)),
+ Type::Float};
+ }
- Id offset = Emit(OpISub(t_uint, real, base));
- offset = Emit(OpUDiv(t_uint, offset, Constant(t_uint, 4u)));
- return Emit(OpLoad(t_float, Emit(OpAccessChain(t_gmem_float, gmem_buffer,
- Constant(t_uint, 0u), offset))));
+ if (const auto lmem = std::get_if<LmemNode>(&*node)) {
+ Id address = AsUint(Visit(lmem->GetAddress()));
+ address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U));
+ const Id pointer = OpAccessChain(t_prv_float, local_memory, address);
+ return {OpLoad(t_float, pointer), Type::Float};
+ }
+
+ if (const auto smem = std::get_if<SmemNode>(&*node)) {
+ Id address = AsUint(Visit(smem->GetAddress()));
+ address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U));
+ const Id pointer = OpAccessChain(t_smem_uint, shared_memory, address);
+ return {OpLoad(t_uint, pointer), Type::Uint};
+ }
- } else if (const auto conditional = std::get_if<ConditionalNode>(&*node)) {
+ if (const auto internal_flag = std::get_if<InternalFlagNode>(&*node)) {
+ const Id flag = internal_flags.at(static_cast<std::size_t>(internal_flag->GetFlag()));
+ return {OpLoad(t_bool, flag), Type::Bool};
+ }
+
+ if (const auto conditional = std::get_if<ConditionalNode>(&*node)) {
// It's invalid to call conditional on nested nodes, use an operation instead
const Id true_label = OpLabel();
const Id skip_label = OpLabel();
- const Id condition = Visit(conditional->GetCondition());
- Emit(OpSelectionMerge(skip_label, spv::SelectionControlMask::MaskNone));
- Emit(OpBranchConditional(condition, true_label, skip_label));
- Emit(true_label);
+ const Id condition = AsBool(Visit(conditional->GetCondition()));
+ OpSelectionMerge(skip_label, spv::SelectionControlMask::MaskNone);
+ OpBranchConditional(condition, true_label, skip_label);
+ AddLabel(true_label);
- ++conditional_nest_count;
+ conditional_branch_set = true;
+ inside_branch = false;
VisitBasicBlock(conditional->GetCode());
- --conditional_nest_count;
-
- if (inside_branch == 0) {
- Emit(OpBranch(skip_label));
+ conditional_branch_set = false;
+ if (!inside_branch) {
+ OpBranch(skip_label);
} else {
- inside_branch--;
+ inside_branch = false;
}
- Emit(skip_label);
+ AddLabel(skip_label);
return {};
+ }
- } else if (const auto comment = std::get_if<CommentNode>(&*node)) {
- Name(Emit(OpUndef(t_void)), comment->GetText());
+ if (const auto comment = std::get_if<CommentNode>(&*node)) {
+ Name(OpUndef(t_void), comment->GetText());
return {};
}
@@ -669,94 +1172,126 @@ private:
}
template <Id (Module::*func)(Id, Id), Type result_type, Type type_a = result_type>
- Id Unary(Operation operation) {
+ Expression Unary(Operation operation) {
const Id type_def = GetTypeDefinition(result_type);
- const Id op_a = VisitOperand<type_a>(operation, 0);
+ const Id op_a = As(Visit(operation[0]), type_a);
- const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a)));
+ const Id value = (this->*func)(type_def, op_a);
if (IsPrecise(operation)) {
Decorate(value, spv::Decoration::NoContraction);
}
- return value;
+ return {value, result_type};
}
template <Id (Module::*func)(Id, Id, Id), Type result_type, Type type_a = result_type,
Type type_b = type_a>
- Id Binary(Operation operation) {
+ Expression Binary(Operation operation) {
const Id type_def = GetTypeDefinition(result_type);
- const Id op_a = VisitOperand<type_a>(operation, 0);
- const Id op_b = VisitOperand<type_b>(operation, 1);
+ const Id op_a = As(Visit(operation[0]), type_a);
+ const Id op_b = As(Visit(operation[1]), type_b);
- const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b)));
+ const Id value = (this->*func)(type_def, op_a, op_b);
if (IsPrecise(operation)) {
Decorate(value, spv::Decoration::NoContraction);
}
- return value;
+ return {value, result_type};
}
template <Id (Module::*func)(Id, Id, Id, Id), Type result_type, Type type_a = result_type,
Type type_b = type_a, Type type_c = type_b>
- Id Ternary(Operation operation) {
+ Expression Ternary(Operation operation) {
const Id type_def = GetTypeDefinition(result_type);
- const Id op_a = VisitOperand<type_a>(operation, 0);
- const Id op_b = VisitOperand<type_b>(operation, 1);
- const Id op_c = VisitOperand<type_c>(operation, 2);
+ const Id op_a = As(Visit(operation[0]), type_a);
+ const Id op_b = As(Visit(operation[1]), type_b);
+ const Id op_c = As(Visit(operation[2]), type_c);
- const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c)));
+ const Id value = (this->*func)(type_def, op_a, op_b, op_c);
if (IsPrecise(operation)) {
Decorate(value, spv::Decoration::NoContraction);
}
- return value;
+ return {value, result_type};
}
template <Id (Module::*func)(Id, Id, Id, Id, Id), Type result_type, Type type_a = result_type,
Type type_b = type_a, Type type_c = type_b, Type type_d = type_c>
- Id Quaternary(Operation operation) {
+ Expression Quaternary(Operation operation) {
const Id type_def = GetTypeDefinition(result_type);
- const Id op_a = VisitOperand<type_a>(operation, 0);
- const Id op_b = VisitOperand<type_b>(operation, 1);
- const Id op_c = VisitOperand<type_c>(operation, 2);
- const Id op_d = VisitOperand<type_d>(operation, 3);
+ const Id op_a = As(Visit(operation[0]), type_a);
+ const Id op_b = As(Visit(operation[1]), type_b);
+ const Id op_c = As(Visit(operation[2]), type_c);
+ const Id op_d = As(Visit(operation[3]), type_d);
- const Id value =
- BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c, op_d)));
+ const Id value = (this->*func)(type_def, op_a, op_b, op_c, op_d);
if (IsPrecise(operation)) {
Decorate(value, spv::Decoration::NoContraction);
}
- return value;
+ return {value, result_type};
}
- Id Assign(Operation operation) {
+ Expression Assign(Operation operation) {
const Node& dest = operation[0];
const Node& src = operation[1];
- Id target{};
+ Expression target{};
if (const auto gpr = std::get_if<GprNode>(&*dest)) {
if (gpr->GetIndex() == Register::ZeroIndex) {
// Writing to Register::ZeroIndex is a no op
return {};
}
- target = registers.at(gpr->GetIndex());
+ target = {registers.at(gpr->GetIndex()), Type::Float};
} else if (const auto abuf = std::get_if<AbufNode>(&*dest)) {
- target = [&]() -> Id {
+ const auto& buffer = abuf->GetBuffer();
+ const auto ArrayPass = [&](Id pointer_type, Id composite, std::vector<u32> indices) {
+ std::vector<Id> members;
+ members.reserve(std::size(indices) + 1);
+
+ if (buffer && IsOutputAttributeArray()) {
+ members.push_back(AsUint(Visit(buffer)));
+ }
+ for (const u32 index : indices) {
+ members.push_back(Constant(t_uint, index));
+ }
+ return OpAccessChain(pointer_type, composite, members);
+ };
+
+ target = [&]() -> Expression {
+ const u32 element = abuf->GetElement();
switch (const auto attribute = abuf->GetIndex(); attribute) {
- case Attribute::Index::Position:
- return AccessElement(t_out_float, per_vertex, position_index,
- abuf->GetElement());
+ case Attribute::Index::Position: {
+ const u32 index = out_indices.position.value();
+ return {ArrayPass(t_out_float, out_vertex, {index, element}), Type::Float};
+ }
case Attribute::Index::LayerViewportPointSize:
- UNIMPLEMENTED_IF(abuf->GetElement() != 3);
- return AccessElement(t_out_float, per_vertex, point_size_index);
- case Attribute::Index::ClipDistances0123:
- return AccessElement(t_out_float, per_vertex, clip_distances_index,
- abuf->GetElement());
- case Attribute::Index::ClipDistances4567:
- return AccessElement(t_out_float, per_vertex, clip_distances_index,
- abuf->GetElement() + 4);
+ switch (element) {
+ case 2: {
+ if (!out_indices.viewport) {
+ return {};
+ }
+ const u32 index = out_indices.viewport.value();
+ return {AccessElement(t_out_int, out_vertex, index), Type::Int};
+ }
+ case 3: {
+ const auto index = out_indices.point_size.value();
+ return {AccessElement(t_out_float, out_vertex, index), Type::Float};
+ }
+ default:
+ UNIMPLEMENTED_MSG("LayerViewportPoint element={}", abuf->GetElement());
+ return {};
+ }
+ case Attribute::Index::ClipDistances0123: {
+ const u32 index = out_indices.clip_distances.value();
+ return {AccessElement(t_out_float, out_vertex, index, element), Type::Float};
+ }
+ case Attribute::Index::ClipDistances4567: {
+ const u32 index = out_indices.clip_distances.value();
+ return {AccessElement(t_out_float, out_vertex, index, element + 4),
+ Type::Float};
+ }
default:
if (IsGenericAttribute(attribute)) {
- return AccessElement(t_out_float, output_attributes.at(attribute),
- abuf->GetElement());
+ const Id composite = output_attributes.at(attribute);
+ return {ArrayPass(t_out_float, composite, {element}), Type::Float};
}
UNIMPLEMENTED_MSG("Unhandled output attribute: {}",
static_cast<u32>(attribute));
@@ -764,72 +1299,154 @@ private:
}
}();
+ } else if (const auto patch = std::get_if<PatchNode>(&*dest)) {
+ target = [&]() -> Expression {
+ const u32 offset = patch->GetOffset();
+ switch (offset) {
+ case 0:
+ case 1:
+ case 2:
+ case 3:
+ return {AccessElement(t_out_float, tess_level_outer, offset % 4), Type::Float};
+ case 4:
+ case 5:
+ return {AccessElement(t_out_float, tess_level_inner, offset % 4), Type::Float};
+ }
+ UNIMPLEMENTED_MSG("Unhandled patch output offset: {}", offset);
+ return {};
+ }();
+
} else if (const auto lmem = std::get_if<LmemNode>(&*dest)) {
- Id address = BitcastTo<Type::Uint>(Visit(lmem->GetAddress()));
- address = Emit(OpUDiv(t_uint, address, Constant(t_uint, 4)));
- target = Emit(OpAccessChain(t_prv_float, local_memory, {address}));
+ Id address = AsUint(Visit(lmem->GetAddress()));
+ address = OpUDiv(t_uint, address, Constant(t_uint, 4));
+ target = {OpAccessChain(t_prv_float, local_memory, address), Type::Float};
+
+ } else if (const auto smem = std::get_if<SmemNode>(&*dest)) {
+ ASSERT(stage == ShaderType::Compute);
+ Id address = AsUint(Visit(smem->GetAddress()));
+ address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U));
+ target = {OpAccessChain(t_smem_uint, shared_memory, address), Type::Uint};
+
+ } else if (const auto gmem = std::get_if<GmemNode>(&*dest)) {
+ const Id real = AsUint(Visit(gmem->GetRealAddress()));
+ const Id base = AsUint(Visit(gmem->GetBaseAddress()));
+ const Id diff = OpISub(t_uint, real, base);
+ const Id offset = OpShiftRightLogical(t_uint, diff, Constant(t_uint, 2));
+
+ const Id gmem_buffer = global_buffers.at(gmem->GetDescriptor());
+ target = {OpAccessChain(t_gmem_float, gmem_buffer, Constant(t_uint, 0), offset),
+ Type::Float};
+
+ } else {
+ UNIMPLEMENTED();
}
- Emit(OpStore(target, Visit(src)));
+ OpStore(target.id, As(Visit(src), target.type));
return {};
}
- Id FCastHalf0(Operation operation) {
- UNIMPLEMENTED();
- return {};
+ template <u32 offset>
+ Expression FCastHalf(Operation operation) {
+ const Id value = AsHalfFloat(Visit(operation[0]));
+ return {GetFloatFromHalfScalar(OpCompositeExtract(t_scalar_half, value, offset)),
+ Type::Float};
}
- Id FCastHalf1(Operation operation) {
- UNIMPLEMENTED();
- return {};
- }
+ Expression FSwizzleAdd(Operation operation) {
+ const Id minus = Constant(t_float, -1.0f);
+ const Id plus = v_float_one;
+ const Id zero = v_float_zero;
+ const Id lut_a = ConstantComposite(t_float4, minus, plus, minus, zero);
+ const Id lut_b = ConstantComposite(t_float4, minus, minus, plus, minus);
- Id FSwizzleAdd(Operation operation) {
- UNIMPLEMENTED();
- return {};
- }
+ Id mask = OpLoad(t_uint, thread_id);
+ mask = OpBitwiseAnd(t_uint, mask, Constant(t_uint, 3));
+ mask = OpShiftLeftLogical(t_uint, mask, Constant(t_uint, 1));
+ mask = OpShiftRightLogical(t_uint, AsUint(Visit(operation[2])), mask);
+ mask = OpBitwiseAnd(t_uint, mask, Constant(t_uint, 3));
- Id HNegate(Operation operation) {
- UNIMPLEMENTED();
- return {};
+ const Id modifier_a = OpVectorExtractDynamic(t_float, lut_a, mask);
+ const Id modifier_b = OpVectorExtractDynamic(t_float, lut_b, mask);
+
+ const Id op_a = OpFMul(t_float, AsFloat(Visit(operation[0])), modifier_a);
+ const Id op_b = OpFMul(t_float, AsFloat(Visit(operation[1])), modifier_b);
+ return {OpFAdd(t_float, op_a, op_b), Type::Float};
}
- Id HClamp(Operation operation) {
- UNIMPLEMENTED();
- return {};
+ Expression HNegate(Operation operation) {
+ const bool is_f16 = device.IsFloat16Supported();
+ const Id minus_one = Constant(t_scalar_half, is_f16 ? 0xbc00 : 0xbf800000);
+ const Id one = Constant(t_scalar_half, is_f16 ? 0x3c00 : 0x3f800000);
+ const auto GetNegate = [&](std::size_t index) {
+ return OpSelect(t_scalar_half, AsBool(Visit(operation[index])), minus_one, one);
+ };
+ const Id negation = OpCompositeConstruct(t_half, GetNegate(1), GetNegate(2));
+ return {OpFMul(t_half, AsHalfFloat(Visit(operation[0])), negation), Type::HalfFloat};
}
- Id HCastFloat(Operation operation) {
- UNIMPLEMENTED();
- return {};
+ Expression HClamp(Operation operation) {
+ const auto Pack = [&](std::size_t index) {
+ const Id scalar = GetHalfScalarFromFloat(AsFloat(Visit(operation[index])));
+ return OpCompositeConstruct(t_half, scalar, scalar);
+ };
+ const Id value = AsHalfFloat(Visit(operation[0]));
+ const Id min = Pack(1);
+ const Id max = Pack(2);
+
+ const Id clamped = OpFClamp(t_half, value, min, max);
+ if (IsPrecise(operation)) {
+ Decorate(clamped, spv::Decoration::NoContraction);
+ }
+ return {clamped, Type::HalfFloat};
}
- Id HUnpack(Operation operation) {
- UNIMPLEMENTED();
- return {};
+ Expression HCastFloat(Operation operation) {
+ const Id value = GetHalfScalarFromFloat(AsFloat(Visit(operation[0])));
+ return {OpCompositeConstruct(t_half, value, Constant(t_scalar_half, 0)), Type::HalfFloat};
}
- Id HMergeF32(Operation operation) {
- UNIMPLEMENTED();
- return {};
+ Expression HUnpack(Operation operation) {
+ Expression operand = Visit(operation[0]);
+ const auto type = std::get<Tegra::Shader::HalfType>(operation.GetMeta());
+ if (type == Tegra::Shader::HalfType::H0_H1) {
+ return operand;
+ }
+ const auto value = [&] {
+ switch (std::get<Tegra::Shader::HalfType>(operation.GetMeta())) {
+ case Tegra::Shader::HalfType::F32:
+ return GetHalfScalarFromFloat(AsFloat(operand));
+ case Tegra::Shader::HalfType::H0_H0:
+ return OpCompositeExtract(t_scalar_half, AsHalfFloat(operand), 0);
+ case Tegra::Shader::HalfType::H1_H1:
+ return OpCompositeExtract(t_scalar_half, AsHalfFloat(operand), 1);
+ default:
+ UNREACHABLE();
+ return ConstantNull(t_half);
+ }
+ }();
+ return {OpCompositeConstruct(t_half, value, value), Type::HalfFloat};
}
- Id HMergeH0(Operation operation) {
- UNIMPLEMENTED();
- return {};
+ Expression HMergeF32(Operation operation) {
+ const Id value = AsHalfFloat(Visit(operation[0]));
+ return {GetFloatFromHalfScalar(OpCompositeExtract(t_scalar_half, value, 0)), Type::Float};
}
- Id HMergeH1(Operation operation) {
- UNIMPLEMENTED();
- return {};
+ template <u32 offset>
+ Expression HMergeHN(Operation operation) {
+ const Id target = AsHalfFloat(Visit(operation[0]));
+ const Id source = AsHalfFloat(Visit(operation[1]));
+ const Id object = OpCompositeExtract(t_scalar_half, source, offset);
+ return {OpCompositeInsert(t_half, object, target, offset), Type::HalfFloat};
}
- Id HPack2(Operation operation) {
- UNIMPLEMENTED();
- return {};
+ Expression HPack2(Operation operation) {
+ const Id low = GetHalfScalarFromFloat(AsFloat(Visit(operation[0])));
+ const Id high = GetHalfScalarFromFloat(AsFloat(Visit(operation[1])));
+ return {OpCompositeConstruct(t_half, low, high), Type::HalfFloat};
}
- Id LogicalAssign(Operation operation) {
+ Expression LogicalAssign(Operation operation) {
const Node& dest = operation[0];
const Node& src = operation[1];
@@ -850,106 +1467,190 @@ private:
target = internal_flags.at(static_cast<u32>(flag->GetFlag()));
}
- Emit(OpStore(target, Visit(src)));
+ OpStore(target, AsBool(Visit(src)));
return {};
}
- Id LogicalPick2(Operation operation) {
- UNIMPLEMENTED();
- return {};
+ Id GetTextureSampler(Operation operation) {
+ const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+ ASSERT(!meta.sampler.IsBuffer());
+
+ const auto& entry = sampled_images.at(meta.sampler.GetIndex());
+ return OpLoad(entry.sampled_image_type, entry.sampler);
}
- Id LogicalAnd2(Operation operation) {
- UNIMPLEMENTED();
- return {};
+ Id GetTextureImage(Operation operation) {
+ const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+ const u32 index = meta.sampler.GetIndex();
+ if (meta.sampler.IsBuffer()) {
+ const auto& entry = texel_buffers.at(index);
+ return OpLoad(entry.image_type, entry.image);
+ } else {
+ const auto& entry = sampled_images.at(index);
+ return OpImage(entry.image_type, GetTextureSampler(operation));
+ }
}
- Id GetTextureSampler(Operation operation) {
- const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
- const auto entry = sampler_images.at(static_cast<u32>(meta->sampler.GetIndex()));
- return Emit(OpLoad(entry.sampled_image_type, entry.sampler));
+ Id GetImage(Operation operation) {
+ const auto& meta = std::get<MetaImage>(operation.GetMeta());
+ const auto entry = images.at(meta.image.GetIndex());
+ return OpLoad(entry.image_type, entry.image);
}
- Id GetTextureImage(Operation operation) {
- const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
- const auto entry = sampler_images.at(static_cast<u32>(meta->sampler.GetIndex()));
- return Emit(OpImage(entry.image_type, GetTextureSampler(operation)));
+ Id AssembleVector(const std::vector<Id>& coords, Type type) {
+ const Id coords_type = GetTypeVectorDefinitionLut(type).at(coords.size() - 1);
+ return coords.size() == 1 ? coords[0] : OpCompositeConstruct(coords_type, coords);
}
- Id GetTextureCoordinates(Operation operation) {
- const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
+ Id GetCoordinates(Operation operation, Type type) {
std::vector<Id> coords;
for (std::size_t i = 0; i < operation.GetOperandsCount(); ++i) {
- coords.push_back(Visit(operation[i]));
+ coords.push_back(As(Visit(operation[i]), type));
}
- if (meta->sampler.IsArray()) {
- const Id array_integer = BitcastTo<Type::Int>(Visit(meta->array));
- coords.push_back(Emit(OpConvertSToF(t_float, array_integer)));
+ if (const auto meta = std::get_if<MetaTexture>(&operation.GetMeta())) {
+ // Add array coordinate for textures
+ if (meta->sampler.IsArray()) {
+ Id array = AsInt(Visit(meta->array));
+ if (type == Type::Float) {
+ array = OpConvertSToF(t_float, array);
+ }
+ coords.push_back(array);
+ }
}
- if (meta->sampler.IsShadow()) {
- coords.push_back(Visit(meta->depth_compare));
+ return AssembleVector(coords, type);
+ }
+
+ Id GetOffsetCoordinates(Operation operation) {
+ const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+ std::vector<Id> coords;
+ coords.reserve(meta.aoffi.size());
+ for (const auto& coord : meta.aoffi) {
+ coords.push_back(AsInt(Visit(coord)));
}
+ return AssembleVector(coords, Type::Int);
+ }
- const std::array<Id, 4> t_float_lut = {nullptr, t_float2, t_float3, t_float4};
- return coords.size() == 1
- ? coords[0]
- : Emit(OpCompositeConstruct(t_float_lut.at(coords.size() - 1), coords));
+ std::pair<Id, Id> GetDerivatives(Operation operation) {
+ const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+ const auto& derivatives = meta.derivates;
+ ASSERT(derivatives.size() % 2 == 0);
+
+ const std::size_t components = derivatives.size() / 2;
+ std::vector<Id> dx, dy;
+ dx.reserve(components);
+ dy.reserve(components);
+ for (std::size_t index = 0; index < components; ++index) {
+ dx.push_back(AsFloat(Visit(derivatives.at(index * 2 + 0))));
+ dy.push_back(AsFloat(Visit(derivatives.at(index * 2 + 1))));
+ }
+ return {AssembleVector(dx, Type::Float), AssembleVector(dy, Type::Float)};
}
- Id GetTextureElement(Operation operation, Id sample_value) {
- const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
- ASSERT(meta);
- return Emit(OpCompositeExtract(t_float, sample_value, meta->element));
+ Expression GetTextureElement(Operation operation, Id sample_value, Type type) {
+ const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+ const auto type_def = GetTypeDefinition(type);
+ return {OpCompositeExtract(type_def, sample_value, meta.element), type};
}
- Id Texture(Operation operation) {
- const Id texture = Emit(OpImageSampleImplicitLod(t_float4, GetTextureSampler(operation),
- GetTextureCoordinates(operation)));
- return GetTextureElement(operation, texture);
+ Expression Texture(Operation operation) {
+ const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+ UNIMPLEMENTED_IF(!meta.aoffi.empty());
+
+ const bool can_implicit = stage == ShaderType::Fragment;
+ const Id sampler = GetTextureSampler(operation);
+ const Id coords = GetCoordinates(operation, Type::Float);
+
+ if (meta.depth_compare) {
+ // Depth sampling
+ UNIMPLEMENTED_IF(meta.bias);
+ const Id dref = AsFloat(Visit(meta.depth_compare));
+ if (can_implicit) {
+ return {OpImageSampleDrefImplicitLod(t_float, sampler, coords, dref, {}),
+ Type::Float};
+ } else {
+ return {OpImageSampleDrefExplicitLod(t_float, sampler, coords, dref,
+ spv::ImageOperandsMask::Lod, v_float_zero),
+ Type::Float};
+ }
+ }
+
+ std::vector<Id> operands;
+ spv::ImageOperandsMask mask{};
+ if (meta.bias) {
+ mask = mask | spv::ImageOperandsMask::Bias;
+ operands.push_back(AsFloat(Visit(meta.bias)));
+ }
+
+ Id texture;
+ if (can_implicit) {
+ texture = OpImageSampleImplicitLod(t_float4, sampler, coords, mask, operands);
+ } else {
+ texture = OpImageSampleExplicitLod(t_float4, sampler, coords,
+ mask | spv::ImageOperandsMask::Lod, v_float_zero,
+ operands);
+ }
+ return GetTextureElement(operation, texture, Type::Float);
}
- Id TextureLod(Operation operation) {
- const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
- const Id texture = Emit(OpImageSampleExplicitLod(
- t_float4, GetTextureSampler(operation), GetTextureCoordinates(operation),
- spv::ImageOperandsMask::Lod, Visit(meta->lod)));
- return GetTextureElement(operation, texture);
+ Expression TextureLod(Operation operation) {
+ const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+
+ const Id sampler = GetTextureSampler(operation);
+ const Id coords = GetCoordinates(operation, Type::Float);
+ const Id lod = AsFloat(Visit(meta.lod));
+
+ spv::ImageOperandsMask mask = spv::ImageOperandsMask::Lod;
+ std::vector<Id> operands;
+ if (!meta.aoffi.empty()) {
+ mask = mask | spv::ImageOperandsMask::Offset;
+ operands.push_back(GetOffsetCoordinates(operation));
+ }
+
+ if (meta.sampler.IsShadow()) {
+ const Id dref = AsFloat(Visit(meta.depth_compare));
+ return {
+ OpImageSampleDrefExplicitLod(t_float, sampler, coords, dref, mask, lod, operands),
+ Type::Float};
+ }
+ const Id texture = OpImageSampleExplicitLod(t_float4, sampler, coords, mask, lod, operands);
+ return GetTextureElement(operation, texture, Type::Float);
}
- Id TextureGather(Operation operation) {
- const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
- const auto coords = GetTextureCoordinates(operation);
+ Expression TextureGather(Operation operation) {
+ const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+ UNIMPLEMENTED_IF(!meta.aoffi.empty());
- Id texture;
- if (meta->sampler.IsShadow()) {
- texture = Emit(OpImageDrefGather(t_float4, GetTextureSampler(operation), coords,
- Visit(meta->component)));
+ const Id coords = GetCoordinates(operation, Type::Float);
+ Id texture{};
+ if (meta.sampler.IsShadow()) {
+ texture = OpImageDrefGather(t_float4, GetTextureSampler(operation), coords,
+ AsFloat(Visit(meta.depth_compare)));
} else {
u32 component_value = 0;
- if (meta->component) {
- const auto component = std::get_if<ImmediateNode>(&*meta->component);
+ if (meta.component) {
+ const auto component = std::get_if<ImmediateNode>(&*meta.component);
ASSERT_MSG(component, "Component is not an immediate value");
component_value = component->GetValue();
}
- texture = Emit(OpImageGather(t_float4, GetTextureSampler(operation), coords,
- Constant(t_uint, component_value)));
+ texture = OpImageGather(t_float4, GetTextureSampler(operation), coords,
+ Constant(t_uint, component_value));
}
-
- return GetTextureElement(operation, texture);
+ return GetTextureElement(operation, texture, Type::Float);
}
- Id TextureQueryDimensions(Operation operation) {
- const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
- const auto image_id = GetTextureImage(operation);
- AddCapability(spv::Capability::ImageQuery);
+ Expression TextureQueryDimensions(Operation operation) {
+ const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+ UNIMPLEMENTED_IF(!meta.aoffi.empty());
+ UNIMPLEMENTED_IF(meta.depth_compare);
- if (meta->element == 3) {
- return BitcastTo<Type::Float>(Emit(OpImageQueryLevels(t_int, image_id)));
+ const auto image_id = GetTextureImage(operation);
+ if (meta.element == 3) {
+ return {OpImageQueryLevels(t_int, image_id), Type::Int};
}
- const Id lod = VisitOperand<Type::Uint>(operation, 0);
+ const Id lod = AsUint(Visit(operation[0]));
const std::size_t coords_count = [&]() {
- switch (const auto type = meta->sampler.GetType(); type) {
+ switch (const auto type = meta.sampler.GetType(); type) {
case Tegra::Shader::TextureType::Texture1D:
return 1;
case Tegra::Shader::TextureType::Texture2D:
@@ -963,141 +1664,190 @@ private:
}
}();
- if (meta->element >= coords_count) {
- return Constant(t_float, 0.0f);
+ if (meta.element >= coords_count) {
+ return {v_float_zero, Type::Float};
}
const std::array<Id, 3> types = {t_int, t_int2, t_int3};
- const Id sizes = Emit(OpImageQuerySizeLod(types.at(coords_count - 1), image_id, lod));
- const Id size = Emit(OpCompositeExtract(t_int, sizes, meta->element));
- return BitcastTo<Type::Float>(size);
+ const Id sizes = OpImageQuerySizeLod(types.at(coords_count - 1), image_id, lod);
+ const Id size = OpCompositeExtract(t_int, sizes, meta.element);
+ return {size, Type::Int};
}
- Id TextureQueryLod(Operation operation) {
- UNIMPLEMENTED();
- return {};
+ Expression TextureQueryLod(Operation operation) {
+ const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+ UNIMPLEMENTED_IF(!meta.aoffi.empty());
+ UNIMPLEMENTED_IF(meta.depth_compare);
+
+ if (meta.element >= 2) {
+ UNREACHABLE_MSG("Invalid element");
+ return {v_float_zero, Type::Float};
+ }
+ const auto sampler_id = GetTextureSampler(operation);
+
+ const Id multiplier = Constant(t_float, 256.0f);
+ const Id multipliers = ConstantComposite(t_float2, multiplier, multiplier);
+
+ const Id coords = GetCoordinates(operation, Type::Float);
+ Id size = OpImageQueryLod(t_float2, sampler_id, coords);
+ size = OpFMul(t_float2, size, multipliers);
+ size = OpConvertFToS(t_int2, size);
+ return GetTextureElement(operation, size, Type::Int);
}
- Id TexelFetch(Operation operation) {
+ Expression TexelFetch(Operation operation) {
+ const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+ UNIMPLEMENTED_IF(meta.depth_compare);
+
+ const Id image = GetTextureImage(operation);
+ const Id coords = GetCoordinates(operation, Type::Int);
+ Id fetch;
+ if (meta.lod && !meta.sampler.IsBuffer()) {
+ fetch = OpImageFetch(t_float4, image, coords, spv::ImageOperandsMask::Lod,
+ AsInt(Visit(meta.lod)));
+ } else {
+ fetch = OpImageFetch(t_float4, image, coords);
+ }
+ return GetTextureElement(operation, fetch, Type::Float);
+ }
+
+ Expression TextureGradient(Operation operation) {
+ const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+ UNIMPLEMENTED_IF(!meta.aoffi.empty());
+
+ const Id sampler = GetTextureSampler(operation);
+ const Id coords = GetCoordinates(operation, Type::Float);
+ const auto [dx, dy] = GetDerivatives(operation);
+ const std::vector grad = {dx, dy};
+
+ static constexpr auto mask = spv::ImageOperandsMask::Grad;
+ const Id texture = OpImageSampleImplicitLod(t_float4, sampler, coords, mask, grad);
+ return GetTextureElement(operation, texture, Type::Float);
+ }
+
+ Expression ImageLoad(Operation operation) {
UNIMPLEMENTED();
return {};
}
- Id TextureGradient(Operation operation) {
- UNIMPLEMENTED();
+ Expression ImageStore(Operation operation) {
+ const auto meta{std::get<MetaImage>(operation.GetMeta())};
+ std::vector<Id> colors;
+ for (const auto& value : meta.values) {
+ colors.push_back(AsUint(Visit(value)));
+ }
+
+ const Id coords = GetCoordinates(operation, Type::Int);
+ const Id texel = OpCompositeConstruct(t_uint4, colors);
+
+ OpImageWrite(GetImage(operation), coords, texel, {});
return {};
}
- Id ImageLoad(Operation operation) {
+ Expression AtomicImageAdd(Operation operation) {
UNIMPLEMENTED();
return {};
}
- Id ImageStore(Operation operation) {
+ Expression AtomicImageMin(Operation operation) {
UNIMPLEMENTED();
return {};
}
- Id AtomicImageAdd(Operation operation) {
+ Expression AtomicImageMax(Operation operation) {
UNIMPLEMENTED();
return {};
}
- Id AtomicImageAnd(Operation operation) {
+ Expression AtomicImageAnd(Operation operation) {
UNIMPLEMENTED();
return {};
}
- Id AtomicImageOr(Operation operation) {
+ Expression AtomicImageOr(Operation operation) {
UNIMPLEMENTED();
return {};
}
- Id AtomicImageXor(Operation operation) {
+ Expression AtomicImageXor(Operation operation) {
UNIMPLEMENTED();
return {};
}
- Id AtomicImageExchange(Operation operation) {
+ Expression AtomicImageExchange(Operation operation) {
UNIMPLEMENTED();
return {};
}
- Id Branch(Operation operation) {
- const auto target = std::get_if<ImmediateNode>(&*operation[0]);
- UNIMPLEMENTED_IF(!target);
-
- Emit(OpStore(jmp_to, Constant(t_uint, target->GetValue())));
- Emit(OpBranch(continue_label));
- inside_branch = conditional_nest_count;
- if (conditional_nest_count == 0) {
- Emit(OpLabel());
+ Expression Branch(Operation operation) {
+ const auto& target = std::get<ImmediateNode>(*operation[0]);
+ OpStore(jmp_to, Constant(t_uint, target.GetValue()));
+ OpBranch(continue_label);
+ inside_branch = true;
+ if (!conditional_branch_set) {
+ AddLabel();
}
return {};
}
- Id BranchIndirect(Operation operation) {
- const Id op_a = VisitOperand<Type::Uint>(operation, 0);
+ Expression BranchIndirect(Operation operation) {
+ const Id op_a = AsUint(Visit(operation[0]));
- Emit(OpStore(jmp_to, op_a));
- Emit(OpBranch(continue_label));
- inside_branch = conditional_nest_count;
- if (conditional_nest_count == 0) {
- Emit(OpLabel());
+ OpStore(jmp_to, op_a);
+ OpBranch(continue_label);
+ inside_branch = true;
+ if (!conditional_branch_set) {
+ AddLabel();
}
return {};
}
- Id PushFlowStack(Operation operation) {
- const auto target = std::get_if<ImmediateNode>(&*operation[0]);
- ASSERT(target);
-
+ Expression PushFlowStack(Operation operation) {
+ const auto& target = std::get<ImmediateNode>(*operation[0]);
const auto [flow_stack, flow_stack_top] = GetFlowStack(operation);
- const Id current = Emit(OpLoad(t_uint, flow_stack_top));
- const Id next = Emit(OpIAdd(t_uint, current, Constant(t_uint, 1)));
- const Id access = Emit(OpAccessChain(t_func_uint, flow_stack, current));
+ const Id current = OpLoad(t_uint, flow_stack_top);
+ const Id next = OpIAdd(t_uint, current, Constant(t_uint, 1));
+ const Id access = OpAccessChain(t_func_uint, flow_stack, current);
- Emit(OpStore(access, Constant(t_uint, target->GetValue())));
- Emit(OpStore(flow_stack_top, next));
+ OpStore(access, Constant(t_uint, target.GetValue()));
+ OpStore(flow_stack_top, next);
return {};
}
- Id PopFlowStack(Operation operation) {
+ Expression PopFlowStack(Operation operation) {
const auto [flow_stack, flow_stack_top] = GetFlowStack(operation);
- const Id current = Emit(OpLoad(t_uint, flow_stack_top));
- const Id previous = Emit(OpISub(t_uint, current, Constant(t_uint, 1)));
- const Id access = Emit(OpAccessChain(t_func_uint, flow_stack, previous));
- const Id target = Emit(OpLoad(t_uint, access));
-
- Emit(OpStore(flow_stack_top, previous));
- Emit(OpStore(jmp_to, target));
- Emit(OpBranch(continue_label));
- inside_branch = conditional_nest_count;
- if (conditional_nest_count == 0) {
- Emit(OpLabel());
+ const Id current = OpLoad(t_uint, flow_stack_top);
+ const Id previous = OpISub(t_uint, current, Constant(t_uint, 1));
+ const Id access = OpAccessChain(t_func_uint, flow_stack, previous);
+ const Id target = OpLoad(t_uint, access);
+
+ OpStore(flow_stack_top, previous);
+ OpStore(jmp_to, target);
+ OpBranch(continue_label);
+ inside_branch = true;
+ if (!conditional_branch_set) {
+ AddLabel();
}
return {};
}
- Id PreExit() {
- switch (stage) {
- case ShaderType::Vertex: {
- // TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't
- // seem to be working on Nvidia's drivers and Intel (mesa and blob) doesn't support it.
- const Id z_pointer = AccessElement(t_out_float, per_vertex, position_index, 2u);
- Id depth = Emit(OpLoad(t_float, z_pointer));
- depth = Emit(OpFAdd(t_float, depth, Constant(t_float, 1.0f)));
- depth = Emit(OpFMul(t_float, depth, Constant(t_float, 0.5f)));
- Emit(OpStore(z_pointer, depth));
- break;
+ void PreExit() {
+ if (stage == ShaderType::Vertex) {
+ const u32 position_index = out_indices.position.value();
+ const Id z_pointer = AccessElement(t_out_float, out_vertex, position_index, 2U);
+ const Id w_pointer = AccessElement(t_out_float, out_vertex, position_index, 3U);
+ Id depth = OpLoad(t_float, z_pointer);
+ depth = OpFAdd(t_float, depth, OpLoad(t_float, w_pointer));
+ depth = OpFMul(t_float, depth, Constant(t_float, 0.5f));
+ OpStore(z_pointer, depth);
}
- case ShaderType::Fragment: {
+ if (stage == ShaderType::Fragment) {
const auto SafeGetRegister = [&](u32 reg) {
// TODO(Rodrigo): Replace with contains once C++20 releases
if (const auto it = registers.find(reg); it != registers.end()) {
- return Emit(OpLoad(t_float, it->second));
+ return OpLoad(t_float, it->second);
}
- return Constant(t_float, 0.0f);
+ return v_float_zero;
};
UNIMPLEMENTED_IF_MSG(header.ps.omap.sample_mask != 0,
@@ -1112,8 +1862,8 @@ private:
// TODO(Subv): Figure out how dual-source blending is configured in the Switch.
for (u32 component = 0; component < 4; ++component) {
if (header.ps.IsColorComponentOutputEnabled(rt, component)) {
- Emit(OpStore(AccessElement(t_out_float, frag_colors.at(rt), component),
- SafeGetRegister(current_reg)));
+ OpStore(AccessElement(t_out_float, frag_colors.at(rt), component),
+ SafeGetRegister(current_reg));
++current_reg;
}
}
@@ -1121,110 +1871,117 @@ private:
if (header.ps.omap.depth) {
// The depth output is always 2 registers after the last color output, and
// current_reg already contains one past the last color register.
- Emit(OpStore(frag_depth, SafeGetRegister(current_reg + 1)));
+ OpStore(frag_depth, SafeGetRegister(current_reg + 1));
}
- break;
- }
}
-
- return {};
}
- Id Exit(Operation operation) {
+ Expression Exit(Operation operation) {
PreExit();
- inside_branch = conditional_nest_count;
- if (conditional_nest_count > 0) {
- Emit(OpReturn());
+ inside_branch = true;
+ if (conditional_branch_set) {
+ OpReturn();
} else {
const Id dummy = OpLabel();
- Emit(OpBranch(dummy));
- Emit(dummy);
- Emit(OpReturn());
- Emit(OpLabel());
+ OpBranch(dummy);
+ AddLabel(dummy);
+ OpReturn();
+ AddLabel();
}
return {};
}
- Id Discard(Operation operation) {
- inside_branch = conditional_nest_count;
- if (conditional_nest_count > 0) {
- Emit(OpKill());
+ Expression Discard(Operation operation) {
+ inside_branch = true;
+ if (conditional_branch_set) {
+ OpKill();
} else {
const Id dummy = OpLabel();
- Emit(OpBranch(dummy));
- Emit(dummy);
- Emit(OpKill());
- Emit(OpLabel());
+ OpBranch(dummy);
+ AddLabel(dummy);
+ OpKill();
+ AddLabel();
}
return {};
}
- Id EmitVertex(Operation operation) {
- UNIMPLEMENTED();
+ Expression EmitVertex(Operation) {
+ OpEmitVertex();
return {};
}
- Id EndPrimitive(Operation operation) {
- UNIMPLEMENTED();
+ Expression EndPrimitive(Operation operation) {
+ OpEndPrimitive();
return {};
}
- Id YNegate(Operation operation) {
- UNIMPLEMENTED();
- return {};
+ Expression InvocationId(Operation) {
+ return {OpLoad(t_int, invocation_id), Type::Int};
}
- template <u32 element>
- Id LocalInvocationId(Operation) {
+ Expression YNegate(Operation) {
UNIMPLEMENTED();
- return {};
+ return {Constant(t_float, 1.0f), Type::Float};
}
template <u32 element>
- Id WorkGroupId(Operation) {
- UNIMPLEMENTED();
- return {};
+ Expression LocalInvocationId(Operation) {
+ const Id id = OpLoad(t_uint3, local_invocation_id);
+ return {OpCompositeExtract(t_uint, id, element), Type::Uint};
}
- Id BallotThread(Operation) {
- UNIMPLEMENTED();
- return {};
+ template <u32 element>
+ Expression WorkGroupId(Operation operation) {
+ const Id id = OpLoad(t_uint3, workgroup_id);
+ return {OpCompositeExtract(t_uint, id, element), Type::Uint};
}
- Id VoteAll(Operation) {
- UNIMPLEMENTED();
- return {};
- }
+ Expression BallotThread(Operation operation) {
+ const Id predicate = AsBool(Visit(operation[0]));
+ const Id ballot = OpSubgroupBallotKHR(t_uint4, predicate);
- Id VoteAny(Operation) {
- UNIMPLEMENTED();
- return {};
+ if (!device.IsWarpSizePotentiallyBiggerThanGuest()) {
+ // Guest-like devices can just return the first index.
+ return {OpCompositeExtract(t_uint, ballot, 0U), Type::Uint};
+ }
+
+ // The others will have to return what is local to the current thread.
+ // For instance a device with a warp size of 64 will return the upper uint when the current
+ // thread is 38.
+ const Id tid = OpLoad(t_uint, thread_id);
+ const Id thread_index = OpShiftRightLogical(t_uint, tid, Constant(t_uint, 5));
+ return {OpVectorExtractDynamic(t_uint, ballot, thread_index), Type::Uint};
}
- Id VoteEqual(Operation) {
- UNIMPLEMENTED();
- return {};
+ template <Id (Module::*func)(Id, Id)>
+ Expression Vote(Operation operation) {
+ // TODO(Rodrigo): Handle devices with different warp sizes
+ const Id predicate = AsBool(Visit(operation[0]));
+ return {(this->*func)(t_bool, predicate), Type::Bool};
}
- Id ThreadId(Operation) {
- UNIMPLEMENTED();
- return {};
+ Expression ThreadId(Operation) {
+ return {OpLoad(t_uint, thread_id), Type::Uint};
}
- Id ShuffleIndexed(Operation) {
- UNIMPLEMENTED();
- return {};
+ Expression ShuffleIndexed(Operation operation) {
+ const Id value = AsFloat(Visit(operation[0]));
+ const Id index = AsUint(Visit(operation[1]));
+ return {OpSubgroupReadInvocationKHR(t_float, value, index), Type::Float};
}
- Id DeclareBuiltIn(spv::BuiltIn builtin, spv::StorageClass storage, Id type,
- const std::string& name) {
+ Id DeclareBuiltIn(spv::BuiltIn builtin, spv::StorageClass storage, Id type, std::string name) {
const Id id = OpVariable(type, storage);
Decorate(id, spv::Decoration::BuiltIn, static_cast<u32>(builtin));
- AddGlobalVariable(Name(id, name));
+ AddGlobalVariable(Name(id, std::move(name)));
interfaces.push_back(id);
return id;
}
+ Id DeclareInputBuiltIn(spv::BuiltIn builtin, Id type, std::string name) {
+ return DeclareBuiltIn(builtin, spv::StorageClass::Input, type, std::move(name));
+ }
+
bool IsRenderTargetUsed(u32 rt) const {
for (u32 component = 0; component < 4; ++component) {
if (header.ps.IsColorComponentOutputEnabled(rt, component)) {
@@ -1242,66 +1999,148 @@ private:
members.push_back(Constant(t_uint, element));
}
- return Emit(OpAccessChain(pointer_type, composite, members));
+ return OpAccessChain(pointer_type, composite, members);
}
- template <Type type>
- Id VisitOperand(Operation operation, std::size_t operand_index) {
- const Id value = Visit(operation[operand_index]);
-
- switch (type) {
+ Id As(Expression expr, Type wanted_type) {
+ switch (wanted_type) {
case Type::Bool:
+ return AsBool(expr);
case Type::Bool2:
+ return AsBool2(expr);
case Type::Float:
- return value;
+ return AsFloat(expr);
case Type::Int:
- return Emit(OpBitcast(t_int, value));
+ return AsInt(expr);
case Type::Uint:
- return Emit(OpBitcast(t_uint, value));
+ return AsUint(expr);
case Type::HalfFloat:
- UNIMPLEMENTED();
+ return AsHalfFloat(expr);
+ default:
+ UNREACHABLE();
+ return expr.id;
}
- UNREACHABLE();
- return value;
}
- template <Type type>
- Id BitcastFrom(Id value) {
- switch (type) {
- case Type::Bool:
- case Type::Bool2:
+ Id AsBool(Expression expr) {
+ ASSERT(expr.type == Type::Bool);
+ return expr.id;
+ }
+
+ Id AsBool2(Expression expr) {
+ ASSERT(expr.type == Type::Bool2);
+ return expr.id;
+ }
+
+ Id AsFloat(Expression expr) {
+ switch (expr.type) {
case Type::Float:
- return value;
+ return expr.id;
case Type::Int:
case Type::Uint:
- return Emit(OpBitcast(t_float, value));
+ return OpBitcast(t_float, expr.id);
case Type::HalfFloat:
- UNIMPLEMENTED();
+ if (device.IsFloat16Supported()) {
+ return OpBitcast(t_float, expr.id);
+ }
+ return OpBitcast(t_float, OpPackHalf2x16(t_uint, expr.id));
+ default:
+ UNREACHABLE();
+ return expr.id;
}
- UNREACHABLE();
- return value;
}
- template <Type type>
- Id BitcastTo(Id value) {
- switch (type) {
- case Type::Bool:
- case Type::Bool2:
+ Id AsInt(Expression expr) {
+ switch (expr.type) {
+ case Type::Int:
+ return expr.id;
+ case Type::Float:
+ case Type::Uint:
+ return OpBitcast(t_int, expr.id);
+ case Type::HalfFloat:
+ if (device.IsFloat16Supported()) {
+ return OpBitcast(t_int, expr.id);
+ }
+ return OpPackHalf2x16(t_int, expr.id);
+ default:
UNREACHABLE();
+ return expr.id;
+ }
+ }
+
+ Id AsUint(Expression expr) {
+ switch (expr.type) {
+ case Type::Uint:
+ return expr.id;
case Type::Float:
- return Emit(OpBitcast(t_float, value));
case Type::Int:
- return Emit(OpBitcast(t_int, value));
- case Type::Uint:
- return Emit(OpBitcast(t_uint, value));
+ return OpBitcast(t_uint, expr.id);
case Type::HalfFloat:
- UNIMPLEMENTED();
+ if (device.IsFloat16Supported()) {
+ return OpBitcast(t_uint, expr.id);
+ }
+ return OpPackHalf2x16(t_uint, expr.id);
+ default:
+ UNREACHABLE();
+ return expr.id;
+ }
+ }
+
+ Id AsHalfFloat(Expression expr) {
+ switch (expr.type) {
+ case Type::HalfFloat:
+ return expr.id;
+ case Type::Float:
+ case Type::Int:
+ case Type::Uint:
+ if (device.IsFloat16Supported()) {
+ return OpBitcast(t_half, expr.id);
+ }
+ return OpUnpackHalf2x16(t_half, AsUint(expr));
+ default:
+ UNREACHABLE();
+ return expr.id;
+ }
+ }
+
+ Id GetHalfScalarFromFloat(Id value) {
+ if (device.IsFloat16Supported()) {
+ return OpFConvert(t_scalar_half, value);
}
- UNREACHABLE();
return value;
}
- Id GetTypeDefinition(Type type) {
+ Id GetFloatFromHalfScalar(Id value) {
+ if (device.IsFloat16Supported()) {
+ return OpFConvert(t_float, value);
+ }
+ return value;
+ }
+
+ AttributeType GetAttributeType(u32 location) const {
+ if (stage != ShaderType::Vertex) {
+ return {Type::Float, t_in_float, t_in_float4};
+ }
+ switch (specialization.attribute_types.at(location)) {
+ case Maxwell::VertexAttribute::Type::SignedNorm:
+ case Maxwell::VertexAttribute::Type::UnsignedNorm:
+ case Maxwell::VertexAttribute::Type::Float:
+ return {Type::Float, t_in_float, t_in_float4};
+ case Maxwell::VertexAttribute::Type::SignedInt:
+ return {Type::Int, t_in_int, t_in_int4};
+ case Maxwell::VertexAttribute::Type::UnsignedInt:
+ return {Type::Uint, t_in_uint, t_in_uint4};
+ case Maxwell::VertexAttribute::Type::UnsignedScaled:
+ case Maxwell::VertexAttribute::Type::SignedScaled:
+ UNIMPLEMENTED();
+ return {Type::Float, t_in_float, t_in_float4};
+ default:
+ UNREACHABLE();
+ return {Type::Float, t_in_float, t_in_float4};
+ }
+ }
+
+ Id GetTypeDefinition(Type type) const {
switch (type) {
case Type::Bool:
return t_bool;
@@ -1314,10 +2153,25 @@ private:
case Type::Uint:
return t_uint;
case Type::HalfFloat:
+ return t_half;
+ default:
+ UNREACHABLE();
+ return {};
+ }
+ }
+
+ std::array<Id, 4> GetTypeVectorDefinitionLut(Type type) const {
+ switch (type) {
+ case Type::Float:
+ return {nullptr, t_float2, t_float3, t_float4};
+ case Type::Int:
+ return {nullptr, t_int2, t_int3, t_int4};
+ case Type::Uint:
+ return {nullptr, t_uint2, t_uint3, t_uint4};
+ default:
UNIMPLEMENTED();
+ return {};
}
- UNREACHABLE();
- return {};
}
std::tuple<Id, Id> CreateFlowStack() {
@@ -1327,9 +2181,11 @@ private:
constexpr auto storage_class = spv::StorageClass::Function;
const Id flow_stack_type = TypeArray(t_uint, Constant(t_uint, FLOW_STACK_SIZE));
- const Id stack = Emit(OpVariable(TypePointer(storage_class, flow_stack_type), storage_class,
- ConstantNull(flow_stack_type)));
- const Id top = Emit(OpVariable(t_func_uint, storage_class, Constant(t_uint, 0)));
+ const Id stack = OpVariable(TypePointer(storage_class, flow_stack_type), storage_class,
+ ConstantNull(flow_stack_type));
+ const Id top = OpVariable(t_func_uint, storage_class, Constant(t_uint, 0));
+ AddLocalVariable(stack);
+ AddLocalVariable(top);
return std::tie(stack, top);
}
@@ -1358,8 +2214,8 @@ private:
&SPIRVDecompiler::Unary<&Module::OpFNegate, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpFAbs, Type::Float>,
&SPIRVDecompiler::Ternary<&Module::OpFClamp, Type::Float>,
- &SPIRVDecompiler::FCastHalf0,
- &SPIRVDecompiler::FCastHalf1,
+ &SPIRVDecompiler::FCastHalf<0>,
+ &SPIRVDecompiler::FCastHalf<1>,
&SPIRVDecompiler::Binary<&Module::OpFMin, Type::Float>,
&SPIRVDecompiler::Binary<&Module::OpFMax, Type::Float>,
&SPIRVDecompiler::Unary<&Module::OpCos, Type::Float>,
@@ -1407,7 +2263,7 @@ private:
&SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Uint, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>,
- &SPIRVDecompiler::Binary<&Module::OpShiftRightArithmetic, Type::Uint>,
+ &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Uint>,
@@ -1426,8 +2282,8 @@ private:
&SPIRVDecompiler::HCastFloat,
&SPIRVDecompiler::HUnpack,
&SPIRVDecompiler::HMergeF32,
- &SPIRVDecompiler::HMergeH0,
- &SPIRVDecompiler::HMergeH1,
+ &SPIRVDecompiler::HMergeHN<0>,
+ &SPIRVDecompiler::HMergeHN<1>,
&SPIRVDecompiler::HPack2,
&SPIRVDecompiler::LogicalAssign,
@@ -1435,8 +2291,9 @@ private:
&SPIRVDecompiler::Binary<&Module::OpLogicalOr, Type::Bool>,
&SPIRVDecompiler::Binary<&Module::OpLogicalNotEqual, Type::Bool>,
&SPIRVDecompiler::Unary<&Module::OpLogicalNot, Type::Bool>,
- &SPIRVDecompiler::LogicalPick2,
- &SPIRVDecompiler::LogicalAnd2,
+ &SPIRVDecompiler::Binary<&Module::OpVectorExtractDynamic, Type::Bool, Type::Bool2,
+ Type::Uint>,
+ &SPIRVDecompiler::Unary<&Module::OpAll, Type::Bool, Type::Bool2>,
&SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::Float>,
&SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::Float>,
@@ -1444,7 +2301,7 @@ private:
&SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::Float>,
&SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::Float>,
&SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::Float>,
- &SPIRVDecompiler::Unary<&Module::OpIsNan, Type::Bool>,
+ &SPIRVDecompiler::Unary<&Module::OpIsNan, Type::Bool, Type::Float>,
&SPIRVDecompiler::Binary<&Module::OpSLessThan, Type::Bool, Type::Int>,
&SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Int>,
@@ -1460,19 +2317,19 @@ private:
&SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Uint>,
&SPIRVDecompiler::Binary<&Module::OpUGreaterThanEqual, Type::Bool, Type::Uint>,
- &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::HalfFloat>,
- &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::HalfFloat>,
- &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::HalfFloat>,
- &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::HalfFloat>,
- &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::HalfFloat>,
- &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::HalfFloat>,
+ &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool2, Type::HalfFloat>,
+ &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool2, Type::HalfFloat>,
+ &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool2, Type::HalfFloat>,
+ &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool2, Type::HalfFloat>,
+ &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool2, Type::HalfFloat>,
+ &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool2, Type::HalfFloat>,
// TODO(Rodrigo): Should these use the OpFUnord* variants?
- &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::HalfFloat>,
- &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::HalfFloat>,
- &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::HalfFloat>,
- &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::HalfFloat>,
- &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::HalfFloat>,
- &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::HalfFloat>,
+ &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool2, Type::HalfFloat>,
+ &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool2, Type::HalfFloat>,
+ &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool2, Type::HalfFloat>,
+ &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool2, Type::HalfFloat>,
+ &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool2, Type::HalfFloat>,
+ &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool2, Type::HalfFloat>,
&SPIRVDecompiler::Texture,
&SPIRVDecompiler::TextureLod,
@@ -1509,9 +2366,9 @@ private:
&SPIRVDecompiler::WorkGroupId<2>,
&SPIRVDecompiler::BallotThread,
- &SPIRVDecompiler::VoteAll,
- &SPIRVDecompiler::VoteAny,
- &SPIRVDecompiler::VoteEqual,
+ &SPIRVDecompiler::Vote<&Module::OpSubgroupAllKHR>,
+ &SPIRVDecompiler::Vote<&Module::OpSubgroupAnyKHR>,
+ &SPIRVDecompiler::Vote<&Module::OpSubgroupAllEqualKHR>,
&SPIRVDecompiler::ThreadId,
&SPIRVDecompiler::ShuffleIndexed,
@@ -1522,8 +2379,7 @@ private:
const ShaderIR& ir;
const ShaderType stage;
const Tegra::Shader::Header header;
- u64 conditional_nest_count{};
- u64 inside_branch{};
+ const Specialization& specialization;
const Id t_void = Name(TypeVoid(), "void");
@@ -1551,20 +2407,28 @@ private:
const Id t_func_uint = Name(TypePointer(spv::StorageClass::Function, t_uint), "func_uint");
const Id t_in_bool = Name(TypePointer(spv::StorageClass::Input, t_bool), "in_bool");
+ const Id t_in_int = Name(TypePointer(spv::StorageClass::Input, t_int), "in_int");
+ const Id t_in_int4 = Name(TypePointer(spv::StorageClass::Input, t_int4), "in_int4");
const Id t_in_uint = Name(TypePointer(spv::StorageClass::Input, t_uint), "in_uint");
+ const Id t_in_uint3 = Name(TypePointer(spv::StorageClass::Input, t_uint3), "in_uint3");
+ const Id t_in_uint4 = Name(TypePointer(spv::StorageClass::Input, t_uint4), "in_uint4");
const Id t_in_float = Name(TypePointer(spv::StorageClass::Input, t_float), "in_float");
+ const Id t_in_float2 = Name(TypePointer(spv::StorageClass::Input, t_float2), "in_float2");
+ const Id t_in_float3 = Name(TypePointer(spv::StorageClass::Input, t_float3), "in_float3");
const Id t_in_float4 = Name(TypePointer(spv::StorageClass::Input, t_float4), "in_float4");
+ const Id t_out_int = Name(TypePointer(spv::StorageClass::Output, t_int), "out_int");
+
const Id t_out_float = Name(TypePointer(spv::StorageClass::Output, t_float), "out_float");
const Id t_out_float4 = Name(TypePointer(spv::StorageClass::Output, t_float4), "out_float4");
const Id t_cbuf_float = TypePointer(spv::StorageClass::Uniform, t_float);
const Id t_cbuf_std140 = Decorate(
- Name(TypeArray(t_float4, Constant(t_uint, MAX_CONSTBUFFER_ELEMENTS)), "CbufStd140Array"),
- spv::Decoration::ArrayStride, 16u);
+ Name(TypeArray(t_float4, Constant(t_uint, MaxConstBufferElements)), "CbufStd140Array"),
+ spv::Decoration::ArrayStride, 16U);
const Id t_cbuf_scalar = Decorate(
- Name(TypeArray(t_float, Constant(t_uint, MAX_CONSTBUFFER_FLOATS)), "CbufScalarArray"),
- spv::Decoration::ArrayStride, 4u);
+ Name(TypeArray(t_float, Constant(t_uint, MaxConstBufferFloats)), "CbufScalarArray"),
+ spv::Decoration::ArrayStride, 4U);
const Id t_cbuf_std140_struct = MemberDecorate(
Decorate(TypeStruct(t_cbuf_std140), spv::Decoration::Block), 0, spv::Decoration::Offset, 0);
const Id t_cbuf_scalar_struct = MemberDecorate(
@@ -1572,28 +2436,43 @@ private:
const Id t_cbuf_std140_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_std140_struct);
const Id t_cbuf_scalar_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_scalar_struct);
+ Id t_smem_uint{};
+
const Id t_gmem_float = TypePointer(spv::StorageClass::StorageBuffer, t_float);
const Id t_gmem_array =
- Name(Decorate(TypeRuntimeArray(t_float), spv::Decoration::ArrayStride, 4u), "GmemArray");
+ Name(Decorate(TypeRuntimeArray(t_float), spv::Decoration::ArrayStride, 4U), "GmemArray");
const Id t_gmem_struct = MemberDecorate(
Decorate(TypeStruct(t_gmem_array), spv::Decoration::Block), 0, spv::Decoration::Offset, 0);
const Id t_gmem_ssbo = TypePointer(spv::StorageClass::StorageBuffer, t_gmem_struct);
const Id v_float_zero = Constant(t_float, 0.0f);
+ const Id v_float_one = Constant(t_float, 1.0f);
+
+ // Nvidia uses these defaults for varyings (e.g. position and generic attributes)
+ const Id v_varying_default =
+ ConstantComposite(t_float4, v_float_zero, v_float_zero, v_float_zero, v_float_one);
+
const Id v_true = ConstantTrue(t_bool);
const Id v_false = ConstantFalse(t_bool);
- Id per_vertex{};
+ Id t_scalar_half{};
+ Id t_half{};
+
+ Id out_vertex{};
+ Id in_vertex{};
std::map<u32, Id> registers;
std::map<Tegra::Shader::Pred, Id> predicates;
std::map<u32, Id> flow_variables;
Id local_memory{};
+ Id shared_memory{};
std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{};
std::map<Attribute::Index, Id> input_attributes;
std::map<Attribute::Index, Id> output_attributes;
std::map<u32, Id> constant_buffers;
std::map<GlobalMemoryBase, Id> global_buffers;
- std::map<u32, SamplerImage> sampler_images;
+ std::map<u32, TexelBuffer> texel_buffers;
+ std::map<u32, SampledImage> sampled_images;
+ std::map<u32, StorageImage> images;
Id instance_index{};
Id vertex_index{};
@@ -1601,18 +2480,20 @@ private:
Id frag_depth{};
Id frag_coord{};
Id front_facing{};
-
- u32 position_index{};
- u32 point_size_index{};
- u32 clip_distances_index{};
+ Id point_coord{};
+ Id tess_level_outer{};
+ Id tess_level_inner{};
+ Id tess_coord{};
+ Id invocation_id{};
+ Id workgroup_id{};
+ Id local_invocation_id{};
+ Id thread_id{};
+
+ VertexIndices in_indices;
+ VertexIndices out_indices;
std::vector<Id> interfaces;
- u32 const_buffers_base_binding{};
- u32 global_buffers_base_binding{};
- u32 samplers_base_binding{};
-
- Id execute_function{};
Id jmp_to{};
Id ssy_flow_stack_top{};
Id pbk_flow_stack_top{};
@@ -1620,6 +2501,9 @@ private:
Id pbk_flow_stack{};
Id continue_label{};
std::map<u32, Id> labels;
+
+ bool conditional_branch_set{};
+ bool inside_branch{};
};
class ExprDecompiler {
@@ -1630,25 +2514,25 @@ public:
const Id type_def = decomp.GetTypeDefinition(Type::Bool);
const Id op1 = Visit(expr.operand1);
const Id op2 = Visit(expr.operand2);
- return decomp.Emit(decomp.OpLogicalAnd(type_def, op1, op2));
+ return decomp.OpLogicalAnd(type_def, op1, op2);
}
Id operator()(const ExprOr& expr) {
const Id type_def = decomp.GetTypeDefinition(Type::Bool);
const Id op1 = Visit(expr.operand1);
const Id op2 = Visit(expr.operand2);
- return decomp.Emit(decomp.OpLogicalOr(type_def, op1, op2));
+ return decomp.OpLogicalOr(type_def, op1, op2);
}
Id operator()(const ExprNot& expr) {
const Id type_def = decomp.GetTypeDefinition(Type::Bool);
const Id op1 = Visit(expr.operand1);
- return decomp.Emit(decomp.OpLogicalNot(type_def, op1));
+ return decomp.OpLogicalNot(type_def, op1);
}
Id operator()(const ExprPredicate& expr) {
const auto pred = static_cast<Tegra::Shader::Pred>(expr.predicate);
- return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred)));
+ return decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred));
}
Id operator()(const ExprCondCode& expr) {
@@ -1670,12 +2554,15 @@ public:
}
} else if (const auto flag = std::get_if<InternalFlagNode>(&*cc)) {
target = decomp.internal_flags.at(static_cast<u32>(flag->GetFlag()));
+ } else {
+ UNREACHABLE();
}
- return decomp.Emit(decomp.OpLoad(decomp.t_bool, target));
+
+ return decomp.OpLoad(decomp.t_bool, target);
}
Id operator()(const ExprVar& expr) {
- return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index)));
+ return decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index));
}
Id operator()(const ExprBoolean& expr) {
@@ -1684,9 +2571,9 @@ public:
Id operator()(const ExprGprEqual& expr) {
const Id target = decomp.Constant(decomp.t_uint, expr.value);
- const Id gpr = decomp.BitcastTo<Type::Uint>(
- decomp.Emit(decomp.OpLoad(decomp.t_float, decomp.registers.at(expr.gpr))));
- return decomp.Emit(decomp.OpLogicalEqual(decomp.t_uint, gpr, target));
+ Id gpr = decomp.OpLoad(decomp.t_float, decomp.registers.at(expr.gpr));
+ gpr = decomp.OpBitcast(decomp.t_uint, gpr);
+ return decomp.OpLogicalEqual(decomp.t_uint, gpr, target);
}
Id Visit(const Expr& node) {
@@ -1714,16 +2601,16 @@ public:
const Id condition = expr_parser.Visit(ast.condition);
const Id then_label = decomp.OpLabel();
const Id endif_label = decomp.OpLabel();
- decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
- decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
- decomp.Emit(then_label);
+ decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone);
+ decomp.OpBranchConditional(condition, then_label, endif_label);
+ decomp.AddLabel(then_label);
ASTNode current = ast.nodes.GetFirst();
while (current) {
Visit(current);
current = current->GetNext();
}
- decomp.Emit(decomp.OpBranch(endif_label));
- decomp.Emit(endif_label);
+ decomp.OpBranch(endif_label);
+ decomp.AddLabel(endif_label);
}
void operator()([[maybe_unused]] const ASTIfElse& ast) {
@@ -1741,7 +2628,7 @@ public:
void operator()(const ASTVarSet& ast) {
ExprDecompiler expr_parser{decomp};
const Id condition = expr_parser.Visit(ast.condition);
- decomp.Emit(decomp.OpStore(decomp.flow_variables.at(ast.index), condition));
+ decomp.OpStore(decomp.flow_variables.at(ast.index), condition);
}
void operator()([[maybe_unused]] const ASTLabel& ast) {
@@ -1758,12 +2645,11 @@ public:
const Id loop_start_block = decomp.OpLabel();
const Id loop_end_block = decomp.OpLabel();
current_loop_exit = endloop_label;
- decomp.Emit(decomp.OpBranch(loop_label));
- decomp.Emit(loop_label);
- decomp.Emit(
- decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone));
- decomp.Emit(decomp.OpBranch(loop_start_block));
- decomp.Emit(loop_start_block);
+ decomp.OpBranch(loop_label);
+ decomp.AddLabel(loop_label);
+ decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone);
+ decomp.OpBranch(loop_start_block);
+ decomp.AddLabel(loop_start_block);
ASTNode current = ast.nodes.GetFirst();
while (current) {
Visit(current);
@@ -1771,8 +2657,8 @@ public:
}
ExprDecompiler expr_parser{decomp};
const Id condition = expr_parser.Visit(ast.condition);
- decomp.Emit(decomp.OpBranchConditional(condition, loop_label, endloop_label));
- decomp.Emit(endloop_label);
+ decomp.OpBranchConditional(condition, loop_label, endloop_label);
+ decomp.AddLabel(endloop_label);
}
void operator()(const ASTReturn& ast) {
@@ -1781,27 +2667,27 @@ public:
const Id condition = expr_parser.Visit(ast.condition);
const Id then_label = decomp.OpLabel();
const Id endif_label = decomp.OpLabel();
- decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
- decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
- decomp.Emit(then_label);
+ decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone);
+ decomp.OpBranchConditional(condition, then_label, endif_label);
+ decomp.AddLabel(then_label);
if (ast.kills) {
- decomp.Emit(decomp.OpKill());
+ decomp.OpKill();
} else {
decomp.PreExit();
- decomp.Emit(decomp.OpReturn());
+ decomp.OpReturn();
}
- decomp.Emit(endif_label);
+ decomp.AddLabel(endif_label);
} else {
const Id next_block = decomp.OpLabel();
- decomp.Emit(decomp.OpBranch(next_block));
- decomp.Emit(next_block);
+ decomp.OpBranch(next_block);
+ decomp.AddLabel(next_block);
if (ast.kills) {
- decomp.Emit(decomp.OpKill());
+ decomp.OpKill();
} else {
decomp.PreExit();
- decomp.Emit(decomp.OpReturn());
+ decomp.OpReturn();
}
- decomp.Emit(decomp.OpLabel());
+ decomp.AddLabel(decomp.OpLabel());
}
}
@@ -1811,17 +2697,17 @@ public:
const Id condition = expr_parser.Visit(ast.condition);
const Id then_label = decomp.OpLabel();
const Id endif_label = decomp.OpLabel();
- decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
- decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
- decomp.Emit(then_label);
- decomp.Emit(decomp.OpBranch(current_loop_exit));
- decomp.Emit(endif_label);
+ decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone);
+ decomp.OpBranchConditional(condition, then_label, endif_label);
+ decomp.AddLabel(then_label);
+ decomp.OpBranch(current_loop_exit);
+ decomp.AddLabel(endif_label);
} else {
const Id next_block = decomp.OpLabel();
- decomp.Emit(decomp.OpBranch(next_block));
- decomp.Emit(next_block);
- decomp.Emit(decomp.OpBranch(current_loop_exit));
- decomp.Emit(decomp.OpLabel());
+ decomp.OpBranch(next_block);
+ decomp.AddLabel(next_block);
+ decomp.OpBranch(current_loop_exit);
+ decomp.AddLabel(decomp.OpLabel());
}
}
@@ -1842,20 +2728,51 @@ void SPIRVDecompiler::DecompileAST() {
flow_variables.emplace(i, AddGlobalVariable(id));
}
+ DefinePrologue();
+
const ASTNode program = ir.GetASTProgram();
ASTDecompiler decompiler{*this};
decompiler.Visit(program);
const Id next_block = OpLabel();
- Emit(OpBranch(next_block));
- Emit(next_block);
+ OpBranch(next_block);
+ AddLabel(next_block);
+}
+
+} // Anonymous namespace
+
+ShaderEntries GenerateShaderEntries(const VideoCommon::Shader::ShaderIR& ir) {
+ ShaderEntries entries;
+ for (const auto& cbuf : ir.GetConstantBuffers()) {
+ entries.const_buffers.emplace_back(cbuf.second, cbuf.first);
+ }
+ for (const auto& [base, usage] : ir.GetGlobalMemory()) {
+ entries.global_buffers.emplace_back(base.cbuf_index, base.cbuf_offset, usage.is_written);
+ }
+ for (const auto& sampler : ir.GetSamplers()) {
+ if (sampler.IsBuffer()) {
+ entries.texel_buffers.emplace_back(sampler);
+ } else {
+ entries.samplers.emplace_back(sampler);
+ }
+ }
+ for (const auto& image : ir.GetImages()) {
+ entries.images.emplace_back(image);
+ }
+ for (const auto& attribute : ir.GetInputAttributes()) {
+ if (IsGenericAttribute(attribute)) {
+ entries.attributes.insert(GetGenericAttributeLocation(attribute));
+ }
+ }
+ entries.clip_distances = ir.GetClipDistances();
+ entries.shader_length = ir.GetLength();
+ entries.uses_warps = ir.UsesWarps();
+ return entries;
}
-DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
- ShaderType stage) {
- auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage);
- decompiler->Decompile();
- return {std::move(decompiler), decompiler->GetShaderEntries()};
+std::vector<u32> Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
+ ShaderType stage, const Specialization& specialization) {
+ return SPIRVDecompiler(device, ir, stage, specialization).Assemble();
}
-} // namespace Vulkan::VKShader
+} // namespace Vulkan
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.h b/src/video_core/renderer_vulkan/vk_shader_decompiler.h
index 203fc00d0..2b01321b6 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.h
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.h
@@ -5,29 +5,28 @@
#pragma once
#include <array>
+#include <bitset>
#include <memory>
#include <set>
+#include <type_traits>
#include <utility>
#include <vector>
-#include <sirit/sirit.h>
-
#include "common/common_types.h"
#include "video_core/engines/maxwell_3d.h"
+#include "video_core/engines/shader_type.h"
#include "video_core/shader/shader_ir.h"
-namespace VideoCommon::Shader {
-class ShaderIR;
-}
-
namespace Vulkan {
class VKDevice;
}
-namespace Vulkan::VKShader {
+namespace Vulkan {
using Maxwell = Tegra::Engines::Maxwell3D::Regs;
+using TexelBufferEntry = VideoCommon::Shader::Sampler;
using SamplerEntry = VideoCommon::Shader::Sampler;
+using ImageEntry = VideoCommon::Shader::Image;
constexpr u32 DESCRIPTOR_SET = 0;
@@ -46,39 +45,74 @@ private:
class GlobalBufferEntry {
public:
- explicit GlobalBufferEntry(u32 cbuf_index, u32 cbuf_offset)
- : cbuf_index{cbuf_index}, cbuf_offset{cbuf_offset} {}
+ constexpr explicit GlobalBufferEntry(u32 cbuf_index, u32 cbuf_offset, bool is_written)
+ : cbuf_index{cbuf_index}, cbuf_offset{cbuf_offset}, is_written{is_written} {}
- u32 GetCbufIndex() const {
+ constexpr u32 GetCbufIndex() const {
return cbuf_index;
}
- u32 GetCbufOffset() const {
+ constexpr u32 GetCbufOffset() const {
return cbuf_offset;
}
+ constexpr bool IsWritten() const {
+ return is_written;
+ }
+
private:
u32 cbuf_index{};
u32 cbuf_offset{};
+ bool is_written{};
};
struct ShaderEntries {
- u32 const_buffers_base_binding{};
- u32 global_buffers_base_binding{};
- u32 samplers_base_binding{};
+ u32 NumBindings() const {
+ return static_cast<u32>(const_buffers.size() + global_buffers.size() +
+ texel_buffers.size() + samplers.size() + images.size());
+ }
+
std::vector<ConstBufferEntry> const_buffers;
std::vector<GlobalBufferEntry> global_buffers;
+ std::vector<TexelBufferEntry> texel_buffers;
std::vector<SamplerEntry> samplers;
+ std::vector<ImageEntry> images;
std::set<u32> attributes;
std::array<bool, Maxwell::NumClipDistances> clip_distances{};
std::size_t shader_length{};
- Sirit::Id entry_function{};
- std::vector<Sirit::Id> interfaces;
+ bool uses_warps{};
+};
+
+struct Specialization final {
+ u32 base_binding{};
+
+ // Compute specific
+ std::array<u32, 3> workgroup_size{};
+ u32 shared_memory_size{};
+
+ // Graphics specific
+ Maxwell::PrimitiveTopology primitive_topology{};
+ std::optional<float> point_size{};
+ std::array<Maxwell::VertexAttribute::Type, Maxwell::NumVertexAttributes> attribute_types{};
+
+ // Tessellation specific
+ struct {
+ Maxwell::TessellationPrimitive primitive{};
+ Maxwell::TessellationSpacing spacing{};
+ bool clockwise{};
+ } tessellation;
+};
+// Old gcc versions don't consider this trivially copyable.
+// static_assert(std::is_trivially_copyable_v<Specialization>);
+
+struct SPIRVShader {
+ std::vector<u32> code;
+ ShaderEntries entries;
};
-using DecompilerResult = std::pair<std::unique_ptr<Sirit::Module>, ShaderEntries>;
+ShaderEntries GenerateShaderEntries(const VideoCommon::Shader::ShaderIR& ir);
-DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
- Tegra::Engines::ShaderType stage);
+std::vector<u32> Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
+ Tegra::Engines::ShaderType stage, const Specialization& specialization);
-} // namespace Vulkan::VKShader
+} // namespace Vulkan