diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/video_core/renderer_vulkan/vk_shader_decompiler.cpp | 2271 | ||||
-rw-r--r-- | src/video_core/renderer_vulkan/vk_shader_decompiler.h | 74 |
2 files changed, 1648 insertions, 697 deletions
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp index 76894275b..8f517bdc1 100644 --- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp +++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp @@ -3,8 +3,10 @@ // Refer to the license.txt file included. #include <functional> +#include <limits> #include <map> -#include <set> +#include <type_traits> +#include <utility> #include <fmt/format.h> @@ -23,7 +25,9 @@ #include "video_core/shader/node.h" #include "video_core/shader/shader_ir.h" -namespace Vulkan::VKShader { +namespace Vulkan { + +namespace { using Sirit::Id; using Tegra::Engines::ShaderType; @@ -35,22 +39,60 @@ using namespace VideoCommon::Shader; using Maxwell = Tegra::Engines::Maxwell3D::Regs; using Operation = const OperationNode&; +class ASTDecompiler; +class ExprDecompiler; + // TODO(Rodrigo): Use rasterizer's value -constexpr u32 MAX_CONSTBUFFER_FLOATS = 0x4000; -constexpr u32 MAX_CONSTBUFFER_ELEMENTS = MAX_CONSTBUFFER_FLOATS / 4; -constexpr u32 STAGE_BINDING_STRIDE = 0x100; +constexpr u32 MaxConstBufferFloats = 0x4000; +constexpr u32 MaxConstBufferElements = MaxConstBufferFloats / 4; + +constexpr u32 NumInputPatches = 32; // This value seems to be the standard + +enum class Type { Void, Bool, Bool2, Float, Int, Uint, HalfFloat }; + +class Expression final { +public: + Expression(Id id, Type type) : id{id}, type{type} { + ASSERT(type != Type::Void); + } + Expression() : type{Type::Void} {} -enum class Type { Bool, Bool2, Float, Int, Uint, HalfFloat }; + Id id{}; + Type type{}; +}; +static_assert(std::is_standard_layout_v<Expression>); -struct SamplerImage { - Id image_type; - Id sampled_image_type; - Id sampler; +struct TexelBuffer { + Id image_type{}; + Id image{}; }; -namespace { +struct SampledImage { + Id image_type{}; + Id sampled_image_type{}; + Id sampler{}; +}; + +struct StorageImage { + Id image_type{}; + Id image{}; +}; + +struct AttributeType { + Type type; + Id scalar; + Id vector; +}; + +struct VertexIndices { + std::optional<u32> position; + std::optional<u32> viewport; + std::optional<u32> point_size; + std::optional<u32> clip_distances; +}; spv::Dim GetSamplerDim(const Sampler& sampler) { + ASSERT(!sampler.IsBuffer()); switch (sampler.GetType()) { case Tegra::Shader::TextureType::Texture1D: return spv::Dim::Dim1D; @@ -66,6 +108,138 @@ spv::Dim GetSamplerDim(const Sampler& sampler) { } } +std::pair<spv::Dim, bool> GetImageDim(const Image& image) { + switch (image.GetType()) { + case Tegra::Shader::ImageType::Texture1D: + return {spv::Dim::Dim1D, false}; + case Tegra::Shader::ImageType::TextureBuffer: + return {spv::Dim::Buffer, false}; + case Tegra::Shader::ImageType::Texture1DArray: + return {spv::Dim::Dim1D, true}; + case Tegra::Shader::ImageType::Texture2D: + return {spv::Dim::Dim2D, false}; + case Tegra::Shader::ImageType::Texture2DArray: + return {spv::Dim::Dim2D, true}; + case Tegra::Shader::ImageType::Texture3D: + return {spv::Dim::Dim3D, false}; + default: + UNIMPLEMENTED_MSG("Unimplemented image type={}", static_cast<u32>(image.GetType())); + return {spv::Dim::Dim2D, false}; + } +} + +/// Returns the number of vertices present in a primitive topology. +u32 GetNumPrimitiveTopologyVertices(Maxwell::PrimitiveTopology primitive_topology) { + switch (primitive_topology) { + case Maxwell::PrimitiveTopology::Points: + return 1; + case Maxwell::PrimitiveTopology::Lines: + case Maxwell::PrimitiveTopology::LineLoop: + case Maxwell::PrimitiveTopology::LineStrip: + return 2; + case Maxwell::PrimitiveTopology::Triangles: + case Maxwell::PrimitiveTopology::TriangleStrip: + case Maxwell::PrimitiveTopology::TriangleFan: + return 3; + case Maxwell::PrimitiveTopology::LinesAdjacency: + case Maxwell::PrimitiveTopology::LineStripAdjacency: + return 4; + case Maxwell::PrimitiveTopology::TrianglesAdjacency: + case Maxwell::PrimitiveTopology::TriangleStripAdjacency: + return 6; + case Maxwell::PrimitiveTopology::Quads: + UNIMPLEMENTED_MSG("Quads"); + return 3; + case Maxwell::PrimitiveTopology::QuadStrip: + UNIMPLEMENTED_MSG("QuadStrip"); + return 3; + case Maxwell::PrimitiveTopology::Polygon: + UNIMPLEMENTED_MSG("Polygon"); + return 3; + case Maxwell::PrimitiveTopology::Patches: + UNIMPLEMENTED_MSG("Patches"); + return 3; + default: + UNREACHABLE(); + return 3; + } +} + +spv::ExecutionMode GetExecutionMode(Maxwell::TessellationPrimitive primitive) { + switch (primitive) { + case Maxwell::TessellationPrimitive::Isolines: + return spv::ExecutionMode::Isolines; + case Maxwell::TessellationPrimitive::Triangles: + return spv::ExecutionMode::Triangles; + case Maxwell::TessellationPrimitive::Quads: + return spv::ExecutionMode::Quads; + } + UNREACHABLE(); + return spv::ExecutionMode::Triangles; +} + +spv::ExecutionMode GetExecutionMode(Maxwell::TessellationSpacing spacing) { + switch (spacing) { + case Maxwell::TessellationSpacing::Equal: + return spv::ExecutionMode::SpacingEqual; + case Maxwell::TessellationSpacing::FractionalOdd: + return spv::ExecutionMode::SpacingFractionalOdd; + case Maxwell::TessellationSpacing::FractionalEven: + return spv::ExecutionMode::SpacingFractionalEven; + } + UNREACHABLE(); + return spv::ExecutionMode::SpacingEqual; +} + +spv::ExecutionMode GetExecutionMode(Maxwell::PrimitiveTopology input_topology) { + switch (input_topology) { + case Maxwell::PrimitiveTopology::Points: + return spv::ExecutionMode::InputPoints; + case Maxwell::PrimitiveTopology::Lines: + case Maxwell::PrimitiveTopology::LineLoop: + case Maxwell::PrimitiveTopology::LineStrip: + return spv::ExecutionMode::InputLines; + case Maxwell::PrimitiveTopology::Triangles: + case Maxwell::PrimitiveTopology::TriangleStrip: + case Maxwell::PrimitiveTopology::TriangleFan: + return spv::ExecutionMode::Triangles; + case Maxwell::PrimitiveTopology::LinesAdjacency: + case Maxwell::PrimitiveTopology::LineStripAdjacency: + return spv::ExecutionMode::InputLinesAdjacency; + case Maxwell::PrimitiveTopology::TrianglesAdjacency: + case Maxwell::PrimitiveTopology::TriangleStripAdjacency: + return spv::ExecutionMode::InputTrianglesAdjacency; + case Maxwell::PrimitiveTopology::Quads: + UNIMPLEMENTED_MSG("Quads"); + return spv::ExecutionMode::Triangles; + case Maxwell::PrimitiveTopology::QuadStrip: + UNIMPLEMENTED_MSG("QuadStrip"); + return spv::ExecutionMode::Triangles; + case Maxwell::PrimitiveTopology::Polygon: + UNIMPLEMENTED_MSG("Polygon"); + return spv::ExecutionMode::Triangles; + case Maxwell::PrimitiveTopology::Patches: + UNIMPLEMENTED_MSG("Patches"); + return spv::ExecutionMode::Triangles; + } + UNREACHABLE(); + return spv::ExecutionMode::Triangles; +} + +spv::ExecutionMode GetExecutionMode(Tegra::Shader::OutputTopology output_topology) { + switch (output_topology) { + case Tegra::Shader::OutputTopology::PointList: + return spv::ExecutionMode::OutputPoints; + case Tegra::Shader::OutputTopology::LineStrip: + return spv::ExecutionMode::OutputLineStrip; + case Tegra::Shader::OutputTopology::TriangleStrip: + return spv::ExecutionMode::OutputTriangleStrip; + default: + UNREACHABLE(); + return spv::ExecutionMode::OutputPoints; + } +} + /// Returns true if an attribute index is one of the 32 generic attributes constexpr bool IsGenericAttribute(Attribute::Index attribute) { return attribute >= Attribute::Index::Attribute_0 && @@ -73,7 +247,7 @@ constexpr bool IsGenericAttribute(Attribute::Index attribute) { } /// Returns the location of a generic attribute -constexpr u32 GetGenericAttributeLocation(Attribute::Index attribute) { +u32 GetGenericAttributeLocation(Attribute::Index attribute) { ASSERT(IsGenericAttribute(attribute)); return static_cast<u32>(attribute) - static_cast<u32>(Attribute::Index::Attribute_0); } @@ -87,20 +261,146 @@ bool IsPrecise(Operation operand) { return false; } -} // namespace - -class ASTDecompiler; -class ExprDecompiler; - -class SPIRVDecompiler : public Sirit::Module { +class SPIRVDecompiler final : public Sirit::Module { public: - explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderType stage) - : Module(0x00010300), device{device}, ir{ir}, stage{stage}, header{ir.GetHeader()} { + explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderType stage, + const Specialization& specialization) + : Module(0x00010300), device{device}, ir{ir}, stage{stage}, header{ir.GetHeader()}, + specialization{specialization} { AddCapability(spv::Capability::Shader); + AddCapability(spv::Capability::UniformAndStorageBuffer16BitAccess); + AddCapability(spv::Capability::ImageQuery); + AddCapability(spv::Capability::Image1D); + AddCapability(spv::Capability::ImageBuffer); + AddCapability(spv::Capability::ImageGatherExtended); + AddCapability(spv::Capability::SampledBuffer); + AddCapability(spv::Capability::StorageImageWriteWithoutFormat); + AddCapability(spv::Capability::SubgroupBallotKHR); + AddCapability(spv::Capability::SubgroupVoteKHR); + AddExtension("SPV_KHR_shader_ballot"); + AddExtension("SPV_KHR_subgroup_vote"); AddExtension("SPV_KHR_storage_buffer_storage_class"); AddExtension("SPV_KHR_variable_pointers"); + + if (ir.UsesViewportIndex()) { + AddCapability(spv::Capability::MultiViewport); + if (device.IsExtShaderViewportIndexLayerSupported()) { + AddExtension("SPV_EXT_shader_viewport_index_layer"); + AddCapability(spv::Capability::ShaderViewportIndexLayerEXT); + } + } + + if (device.IsFloat16Supported()) { + AddCapability(spv::Capability::Float16); + } + t_scalar_half = Name(TypeFloat(device.IsFloat16Supported() ? 16 : 32), "scalar_half"); + t_half = Name(TypeVector(t_scalar_half, 2), "half"); + + const Id main = Decompile(); + + switch (stage) { + case ShaderType::Vertex: + AddEntryPoint(spv::ExecutionModel::Vertex, main, "main", interfaces); + break; + case ShaderType::TesselationControl: + AddCapability(spv::Capability::Tessellation); + AddEntryPoint(spv::ExecutionModel::TessellationControl, main, "main", interfaces); + AddExecutionMode(main, spv::ExecutionMode::OutputVertices, + header.common2.threads_per_input_primitive); + break; + case ShaderType::TesselationEval: + AddCapability(spv::Capability::Tessellation); + AddEntryPoint(spv::ExecutionModel::TessellationEvaluation, main, "main", interfaces); + AddExecutionMode(main, GetExecutionMode(specialization.tessellation.primitive)); + AddExecutionMode(main, GetExecutionMode(specialization.tessellation.spacing)); + AddExecutionMode(main, specialization.tessellation.clockwise + ? spv::ExecutionMode::VertexOrderCw + : spv::ExecutionMode::VertexOrderCcw); + break; + case ShaderType::Geometry: + AddCapability(spv::Capability::Geometry); + AddEntryPoint(spv::ExecutionModel::Geometry, main, "main", interfaces); + AddExecutionMode(main, GetExecutionMode(specialization.primitive_topology)); + AddExecutionMode(main, GetExecutionMode(header.common3.output_topology)); + AddExecutionMode(main, spv::ExecutionMode::OutputVertices, + header.common4.max_output_vertices); + // TODO(Rodrigo): Where can we get this info from? + AddExecutionMode(main, spv::ExecutionMode::Invocations, 1U); + break; + case ShaderType::Fragment: + AddEntryPoint(spv::ExecutionModel::Fragment, main, "main", interfaces); + AddExecutionMode(main, spv::ExecutionMode::OriginUpperLeft); + if (header.ps.omap.depth) { + AddExecutionMode(main, spv::ExecutionMode::DepthReplacing); + } + break; + case ShaderType::Compute: + const auto workgroup_size = specialization.workgroup_size; + AddExecutionMode(main, spv::ExecutionMode::LocalSize, workgroup_size[0], + workgroup_size[1], workgroup_size[2]); + AddEntryPoint(spv::ExecutionModel::GLCompute, main, "main", interfaces); + break; + } } +private: + Id Decompile() { + DeclareCommon(); + DeclareVertex(); + DeclareTessControl(); + DeclareTessEval(); + DeclareGeometry(); + DeclareFragment(); + DeclareCompute(); + DeclareRegisters(); + DeclarePredicates(); + DeclareLocalMemory(); + DeclareSharedMemory(); + DeclareInternalFlags(); + DeclareInputAttributes(); + DeclareOutputAttributes(); + + u32 binding = specialization.base_binding; + binding = DeclareConstantBuffers(binding); + binding = DeclareGlobalBuffers(binding); + binding = DeclareTexelBuffers(binding); + binding = DeclareSamplers(binding); + binding = DeclareImages(binding); + + const Id main = OpFunction(t_void, {}, TypeFunction(t_void)); + AddLabel(); + + if (ir.IsDecompiled()) { + DeclareFlowVariables(); + DecompileAST(); + } else { + AllocateLabels(); + DecompileBranchMode(); + } + + OpReturn(); + OpFunctionEnd(); + + return main; + } + + void DefinePrologue() { + if (stage == ShaderType::Vertex) { + // Clear Position to avoid reading trash on the Z conversion. + const auto position_index = out_indices.position.value(); + const Id position = AccessElement(t_out_float4, out_vertex, position_index); + OpStore(position, v_varying_default); + + if (specialization.point_size) { + const u32 point_size_index = out_indices.point_size.value(); + const Id out_point_size = AccessElement(t_out_float, out_vertex, point_size_index); + OpStore(out_point_size, Constant(t_float, *specialization.point_size)); + } + } + } + + void DecompileAST(); + void DecompileBranchMode() { const u32 first_address = ir.GetBasicBlocks().begin()->first; const Id loop_label = OpLabel("loop"); @@ -111,14 +411,15 @@ public: std::vector<Sirit::Literal> literals; std::vector<Id> branch_labels; - for (const auto& pair : labels) { - const auto [literal, label] = pair; + for (const auto& [literal, label] : labels) { literals.push_back(literal); branch_labels.push_back(label); } - jmp_to = Emit(OpVariable(TypePointer(spv::StorageClass::Function, t_uint), - spv::StorageClass::Function, Constant(t_uint, first_address))); + jmp_to = OpVariable(TypePointer(spv::StorageClass::Function, t_uint), + spv::StorageClass::Function, Constant(t_uint, first_address)); + AddLocalVariable(jmp_to); + std::tie(ssy_flow_stack, ssy_flow_stack_top) = CreateFlowStack(); std::tie(pbk_flow_stack, pbk_flow_stack_top) = CreateFlowStack(); @@ -128,151 +429,118 @@ public: Name(pbk_flow_stack, "pbk_flow_stack"); Name(pbk_flow_stack_top, "pbk_flow_stack_top"); - Emit(OpBranch(loop_label)); - Emit(loop_label); - Emit(OpLoopMerge(merge_label, continue_label, spv::LoopControlMask::Unroll)); - Emit(OpBranch(dummy_label)); + DefinePrologue(); + + OpBranch(loop_label); + AddLabel(loop_label); + OpLoopMerge(merge_label, continue_label, spv::LoopControlMask::MaskNone); + OpBranch(dummy_label); - Emit(dummy_label); + AddLabel(dummy_label); const Id default_branch = OpLabel(); - const Id jmp_to_load = Emit(OpLoad(t_uint, jmp_to)); - Emit(OpSelectionMerge(jump_label, spv::SelectionControlMask::MaskNone)); - Emit(OpSwitch(jmp_to_load, default_branch, literals, branch_labels)); + const Id jmp_to_load = OpLoad(t_uint, jmp_to); + OpSelectionMerge(jump_label, spv::SelectionControlMask::MaskNone); + OpSwitch(jmp_to_load, default_branch, literals, branch_labels); - Emit(default_branch); - Emit(OpReturn()); + AddLabel(default_branch); + OpReturn(); - for (const auto& pair : ir.GetBasicBlocks()) { - const auto& [address, bb] = pair; - Emit(labels.at(address)); + for (const auto& [address, bb] : ir.GetBasicBlocks()) { + AddLabel(labels.at(address)); VisitBasicBlock(bb); const auto next_it = labels.lower_bound(address + 1); const Id next_label = next_it != labels.end() ? next_it->second : default_branch; - Emit(OpBranch(next_label)); + OpBranch(next_label); } - Emit(jump_label); - Emit(OpBranch(continue_label)); - Emit(continue_label); - Emit(OpBranch(loop_label)); - Emit(merge_label); + AddLabel(jump_label); + OpBranch(continue_label); + AddLabel(continue_label); + OpBranch(loop_label); + AddLabel(merge_label); } - void DecompileAST(); +private: + friend class ASTDecompiler; + friend class ExprDecompiler; - void Decompile() { - const bool is_fully_decompiled = ir.IsDecompiled(); - AllocateBindings(); - if (!is_fully_decompiled) { - AllocateLabels(); - } + static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount); - DeclareVertex(); - DeclareGeometry(); - DeclareFragment(); - DeclareRegisters(); - DeclarePredicates(); - if (is_fully_decompiled) { - DeclareFlowVariables(); + void AllocateLabels() { + for (const auto& pair : ir.GetBasicBlocks()) { + const u32 address = pair.first; + labels.emplace(address, OpLabel(fmt::format("label_0x{:x}", address))); } - DeclareLocalMemory(); - DeclareInternalFlags(); - DeclareInputAttributes(); - DeclareOutputAttributes(); - DeclareConstantBuffers(); - DeclareGlobalBuffers(); - DeclareSamplers(); + } - execute_function = - Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void))); - Emit(OpLabel()); + void DeclareCommon() { + thread_id = + DeclareInputBuiltIn(spv::BuiltIn::SubgroupLocalInvocationId, t_in_uint, "thread_id"); + } - if (is_fully_decompiled) { - DecompileAST(); - } else { - DecompileBranchMode(); + void DeclareVertex() { + if (stage != ShaderType::Vertex) { + return; } + Id out_vertex_struct; + std::tie(out_vertex_struct, out_indices) = DeclareVertexStruct(); + const Id vertex_ptr = TypePointer(spv::StorageClass::Output, out_vertex_struct); + out_vertex = OpVariable(vertex_ptr, spv::StorageClass::Output); + interfaces.push_back(AddGlobalVariable(Name(out_vertex, "out_vertex"))); - Emit(OpReturn()); - Emit(OpFunctionEnd()); + // Declare input attributes + vertex_index = DeclareInputBuiltIn(spv::BuiltIn::VertexIndex, t_in_uint, "vertex_index"); + instance_index = + DeclareInputBuiltIn(spv::BuiltIn::InstanceIndex, t_in_uint, "instance_index"); } - ShaderEntries GetShaderEntries() const { - ShaderEntries entries; - entries.const_buffers_base_binding = const_buffers_base_binding; - entries.global_buffers_base_binding = global_buffers_base_binding; - entries.samplers_base_binding = samplers_base_binding; - for (const auto& cbuf : ir.GetConstantBuffers()) { - entries.const_buffers.emplace_back(cbuf.second, cbuf.first); - } - for (const auto& gmem_pair : ir.GetGlobalMemory()) { - const auto& [base, usage] = gmem_pair; - entries.global_buffers.emplace_back(base.cbuf_index, base.cbuf_offset); - } - for (const auto& sampler : ir.GetSamplers()) { - entries.samplers.emplace_back(sampler); - } - for (const auto& attribute : ir.GetInputAttributes()) { - if (IsGenericAttribute(attribute)) { - entries.attributes.insert(GetGenericAttributeLocation(attribute)); - } + void DeclareTessControl() { + if (stage != ShaderType::TesselationControl) { + return; } - entries.clip_distances = ir.GetClipDistances(); - entries.shader_length = ir.GetLength(); - entries.entry_function = execute_function; - entries.interfaces = interfaces; - return entries; - } - -private: - friend class ASTDecompiler; - friend class ExprDecompiler; - - static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount); + DeclareInputVertexArray(NumInputPatches); + DeclareOutputVertexArray(header.common2.threads_per_input_primitive); - void AllocateBindings() { - const u32 binding_base = static_cast<u32>(stage) * STAGE_BINDING_STRIDE; - u32 binding_iterator = binding_base; + tess_level_outer = DeclareBuiltIn( + spv::BuiltIn::TessLevelOuter, spv::StorageClass::Output, + TypePointer(spv::StorageClass::Output, TypeArray(t_float, Constant(t_uint, 4U))), + "tess_level_outer"); + Decorate(tess_level_outer, spv::Decoration::Patch); - const auto Allocate = [&binding_iterator](std::size_t count) { - const u32 current_binding = binding_iterator; - binding_iterator += static_cast<u32>(count); - return current_binding; - }; - const_buffers_base_binding = Allocate(ir.GetConstantBuffers().size()); - global_buffers_base_binding = Allocate(ir.GetGlobalMemory().size()); - samplers_base_binding = Allocate(ir.GetSamplers().size()); + tess_level_inner = DeclareBuiltIn( + spv::BuiltIn::TessLevelInner, spv::StorageClass::Output, + TypePointer(spv::StorageClass::Output, TypeArray(t_float, Constant(t_uint, 2U))), + "tess_level_inner"); + Decorate(tess_level_inner, spv::Decoration::Patch); - ASSERT_MSG(binding_iterator - binding_base < STAGE_BINDING_STRIDE, - "Stage binding stride is too small"); + invocation_id = DeclareInputBuiltIn(spv::BuiltIn::InvocationId, t_in_int, "invocation_id"); } - void AllocateLabels() { - for (const auto& pair : ir.GetBasicBlocks()) { - const u32 address = pair.first; - labels.emplace(address, OpLabel(fmt::format("label_0x{:x}", address))); - } - } - - void DeclareVertex() { - if (stage != ShaderType::Vertex) + void DeclareTessEval() { + if (stage != ShaderType::TesselationEval) { return; + } + DeclareInputVertexArray(NumInputPatches); + DeclareOutputVertex(); - DeclareVertexRedeclarations(); + tess_coord = DeclareInputBuiltIn(spv::BuiltIn::TessCoord, t_in_float3, "tess_coord"); } void DeclareGeometry() { - if (stage != ShaderType::Geometry) + if (stage != ShaderType::Geometry) { return; - - UNIMPLEMENTED(); + } + const u32 num_input = GetNumPrimitiveTopologyVertices(specialization.primitive_topology); + DeclareInputVertexArray(num_input); + DeclareOutputVertex(); } void DeclareFragment() { - if (stage != ShaderType::Fragment) + if (stage != ShaderType::Fragment) { return; + } for (u32 rt = 0; rt < static_cast<u32>(frag_colors.size()); ++rt) { if (!IsRenderTargetUsed(rt)) { @@ -296,10 +564,19 @@ private: interfaces.push_back(frag_depth); } - frag_coord = DeclareBuiltIn(spv::BuiltIn::FragCoord, spv::StorageClass::Input, t_in_float4, - "frag_coord"); - front_facing = DeclareBuiltIn(spv::BuiltIn::FrontFacing, spv::StorageClass::Input, - t_in_bool, "front_facing"); + frag_coord = DeclareInputBuiltIn(spv::BuiltIn::FragCoord, t_in_float4, "frag_coord"); + front_facing = DeclareInputBuiltIn(spv::BuiltIn::FrontFacing, t_in_bool, "front_facing"); + point_coord = DeclareInputBuiltIn(spv::BuiltIn::PointCoord, t_in_float2, "point_coord"); + } + + void DeclareCompute() { + if (stage != ShaderType::Compute) { + return; + } + + workgroup_id = DeclareInputBuiltIn(spv::BuiltIn::WorkgroupId, t_in_uint3, "workgroup_id"); + local_invocation_id = + DeclareInputBuiltIn(spv::BuiltIn::LocalInvocationId, t_in_uint3, "local_invocation_id"); } void DeclareRegisters() { @@ -327,21 +604,44 @@ private: } void DeclareLocalMemory() { - if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) { - const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4); - const Id type_array = TypeArray(t_float, Constant(t_uint, element_count)); - const Id type_pointer = TypePointer(spv::StorageClass::Private, type_array); - Name(type_pointer, "LocalMemory"); + // TODO(Rodrigo): Unstub kernel local memory size and pass it from a register at + // specialization time. + const u64 lmem_size = stage == ShaderType::Compute ? 0x400 : header.GetLocalMemorySize(); + if (lmem_size == 0) { + return; + } + const auto element_count = static_cast<u32>(Common::AlignUp(lmem_size, 4) / 4); + const Id type_array = TypeArray(t_float, Constant(t_uint, element_count)); + const Id type_pointer = TypePointer(spv::StorageClass::Private, type_array); + Name(type_pointer, "LocalMemory"); + + local_memory = + OpVariable(type_pointer, spv::StorageClass::Private, ConstantNull(type_array)); + AddGlobalVariable(Name(local_memory, "local_memory")); + } + + void DeclareSharedMemory() { + if (stage != ShaderType::Compute) { + return; + } + t_smem_uint = TypePointer(spv::StorageClass::Workgroup, t_uint); - local_memory = - OpVariable(type_pointer, spv::StorageClass::Private, ConstantNull(type_array)); - AddGlobalVariable(Name(local_memory, "local_memory")); + const u32 smem_size = specialization.shared_memory_size; + if (smem_size == 0) { + // Avoid declaring an empty array. + return; } + const auto element_count = static_cast<u32>(Common::AlignUp(smem_size, 4) / 4); + const Id type_array = TypeArray(t_uint, Constant(t_uint, element_count)); + const Id type_pointer = TypePointer(spv::StorageClass::Workgroup, type_array); + Name(type_pointer, "SharedMemory"); + + shared_memory = OpVariable(type_pointer, spv::StorageClass::Workgroup); + AddGlobalVariable(Name(shared_memory, "shared_memory")); } void DeclareInternalFlags() { - constexpr std::array<const char*, INTERNAL_FLAGS_COUNT> names = {"zero", "sign", "carry", - "overflow"}; + constexpr std::array names = {"zero", "sign", "carry", "overflow"}; for (std::size_t flag = 0; flag < INTERNAL_FLAGS_COUNT; ++flag) { const auto flag_code = static_cast<InternalFlag>(flag); const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false); @@ -349,17 +649,53 @@ private: } } + void DeclareInputVertexArray(u32 length) { + constexpr auto storage = spv::StorageClass::Input; + std::tie(in_indices, in_vertex) = DeclareVertexArray(storage, "in_indices", length); + } + + void DeclareOutputVertexArray(u32 length) { + constexpr auto storage = spv::StorageClass::Output; + std::tie(out_indices, out_vertex) = DeclareVertexArray(storage, "out_indices", length); + } + + std::tuple<VertexIndices, Id> DeclareVertexArray(spv::StorageClass storage_class, + std::string name, u32 length) { + const auto [struct_id, indices] = DeclareVertexStruct(); + const Id vertex_array = TypeArray(struct_id, Constant(t_uint, length)); + const Id vertex_ptr = TypePointer(storage_class, vertex_array); + const Id vertex = OpVariable(vertex_ptr, storage_class); + AddGlobalVariable(Name(vertex, std::move(name))); + interfaces.push_back(vertex); + return {indices, vertex}; + } + + void DeclareOutputVertex() { + Id out_vertex_struct; + std::tie(out_vertex_struct, out_indices) = DeclareVertexStruct(); + const Id out_vertex_ptr = TypePointer(spv::StorageClass::Output, out_vertex_struct); + out_vertex = OpVariable(out_vertex_ptr, spv::StorageClass::Output); + interfaces.push_back(AddGlobalVariable(Name(out_vertex, "out_vertex"))); + } + void DeclareInputAttributes() { for (const auto index : ir.GetInputAttributes()) { if (!IsGenericAttribute(index)) { continue; } - UNIMPLEMENTED_IF(stage == ShaderType::Geometry); - const u32 location = GetGenericAttributeLocation(index); - const Id id = OpVariable(t_in_float4, spv::StorageClass::Input); - Name(AddGlobalVariable(id), fmt::format("in_attr{}", location)); + const auto type_descriptor = GetAttributeType(location); + Id type; + if (IsInputAttributeArray()) { + type = GetTypeVectorDefinitionLut(type_descriptor.type).at(3); + type = TypeArray(type, Constant(t_uint, GetNumInputVertices())); + type = TypePointer(spv::StorageClass::Input, type); + } else { + type = type_descriptor.vector; + } + const Id id = OpVariable(type, spv::StorageClass::Input); + AddGlobalVariable(Name(id, fmt::format("in_attr{}", location))); input_attributes.emplace(index, id); interfaces.push_back(id); @@ -389,8 +725,21 @@ private: if (!IsGenericAttribute(index)) { continue; } - const auto location = GetGenericAttributeLocation(index); - const Id id = OpVariable(t_out_float4, spv::StorageClass::Output); + const u32 location = GetGenericAttributeLocation(index); + Id type = t_float4; + Id varying_default = v_varying_default; + if (IsOutputAttributeArray()) { + const u32 num = GetNumOutputVertices(); + type = TypeArray(type, Constant(t_uint, num)); + if (device.GetDriverID() != vk::DriverIdKHR::eIntelProprietaryWindows) { + // Intel's proprietary driver fails to setup defaults for arrayed output + // attributes. + varying_default = ConstantComposite(type, std::vector(num, varying_default)); + } + } + type = TypePointer(spv::StorageClass::Output, type); + + const Id id = OpVariable(type, spv::StorageClass::Output, varying_default); Name(AddGlobalVariable(id), fmt::format("out_attr{}", location)); output_attributes.emplace(index, id); interfaces.push_back(id); @@ -399,10 +748,8 @@ private: } } - void DeclareConstantBuffers() { - u32 binding = const_buffers_base_binding; - for (const auto& entry : ir.GetConstantBuffers()) { - const auto [index, size] = entry; + u32 DeclareConstantBuffers(u32 binding) { + for (const auto& [index, size] : ir.GetConstantBuffers()) { const Id type = device.IsKhrUniformBufferStandardLayoutSupported() ? t_cbuf_scalar_ubo : t_cbuf_std140_ubo; const Id id = OpVariable(type, spv::StorageClass::Uniform); @@ -412,12 +759,11 @@ private: Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET); constant_buffers.emplace(index, id); } + return binding; } - void DeclareGlobalBuffers() { - u32 binding = global_buffers_base_binding; - for (const auto& entry : ir.GetGlobalMemory()) { - const auto [base, usage] = entry; + u32 DeclareGlobalBuffers(u32 binding) { + for (const auto& [base, usage] : ir.GetGlobalMemory()) { const Id id = OpVariable(t_gmem_ssbo, spv::StorageClass::StorageBuffer); AddGlobalVariable( Name(id, fmt::format("gmem_{}_{}", base.cbuf_index, base.cbuf_offset))); @@ -426,89 +772,187 @@ private: Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET); global_buffers.emplace(base, id); } + return binding; + } + + u32 DeclareTexelBuffers(u32 binding) { + for (const auto& sampler : ir.GetSamplers()) { + if (!sampler.IsBuffer()) { + continue; + } + ASSERT(!sampler.IsArray()); + ASSERT(!sampler.IsShadow()); + + constexpr auto dim = spv::Dim::Buffer; + constexpr int depth = 0; + constexpr int arrayed = 0; + constexpr bool ms = false; + constexpr int sampled = 1; + constexpr auto format = spv::ImageFormat::Unknown; + const Id image_type = TypeImage(t_float, dim, depth, arrayed, ms, sampled, format); + const Id pointer_type = TypePointer(spv::StorageClass::UniformConstant, image_type); + const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant); + AddGlobalVariable(Name(id, fmt::format("sampler_{}", sampler.GetIndex()))); + Decorate(id, spv::Decoration::Binding, binding++); + Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET); + + texel_buffers.emplace(sampler.GetIndex(), TexelBuffer{image_type, id}); + } + return binding; } - void DeclareSamplers() { - u32 binding = samplers_base_binding; + u32 DeclareSamplers(u32 binding) { for (const auto& sampler : ir.GetSamplers()) { + if (sampler.IsBuffer()) { + continue; + } const auto dim = GetSamplerDim(sampler); const int depth = sampler.IsShadow() ? 1 : 0; const int arrayed = sampler.IsArray() ? 1 : 0; - // TODO(Rodrigo): Sampled 1 indicates that the image will be used with a sampler. When - // SULD and SUST instructions are implemented, replace this value. - const int sampled = 1; - const Id image_type = - TypeImage(t_float, dim, depth, arrayed, false, sampled, spv::ImageFormat::Unknown); + constexpr bool ms = false; + constexpr int sampled = 1; + constexpr auto format = spv::ImageFormat::Unknown; + const Id image_type = TypeImage(t_float, dim, depth, arrayed, ms, sampled, format); const Id sampled_image_type = TypeSampledImage(image_type); const Id pointer_type = TypePointer(spv::StorageClass::UniformConstant, sampled_image_type); const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant); AddGlobalVariable(Name(id, fmt::format("sampler_{}", sampler.GetIndex()))); + Decorate(id, spv::Decoration::Binding, binding++); + Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET); - sampler_images.insert( - {static_cast<u32>(sampler.GetIndex()), {image_type, sampled_image_type, id}}); + sampled_images.emplace(sampler.GetIndex(), + SampledImage{image_type, sampled_image_type, id}); + } + return binding; + } + + u32 DeclareImages(u32 binding) { + for (const auto& image : ir.GetImages()) { + const auto [dim, arrayed] = GetImageDim(image); + constexpr int depth = 0; + constexpr bool ms = false; + constexpr int sampled = 2; // This won't be accessed with a sampler + constexpr auto format = spv::ImageFormat::Unknown; + const Id image_type = TypeImage(t_uint, dim, depth, arrayed, ms, sampled, format, {}); + const Id pointer_type = TypePointer(spv::StorageClass::UniformConstant, image_type); + const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant); + AddGlobalVariable(Name(id, fmt::format("image_{}", image.GetIndex()))); Decorate(id, spv::Decoration::Binding, binding++); Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET); + if (image.IsRead() && !image.IsWritten()) { + Decorate(id, spv::Decoration::NonWritable); + } else if (image.IsWritten() && !image.IsRead()) { + Decorate(id, spv::Decoration::NonReadable); + } + + images.emplace(static_cast<u32>(image.GetIndex()), StorageImage{image_type, id}); } + return binding; + } + + bool IsInputAttributeArray() const { + return stage == ShaderType::TesselationControl || stage == ShaderType::TesselationEval || + stage == ShaderType::Geometry; } - void DeclareVertexRedeclarations() { - vertex_index = DeclareBuiltIn(spv::BuiltIn::VertexIndex, spv::StorageClass::Input, - t_in_uint, "vertex_index"); - instance_index = DeclareBuiltIn(spv::BuiltIn::InstanceIndex, spv::StorageClass::Input, - t_in_uint, "instance_index"); + bool IsOutputAttributeArray() const { + return stage == ShaderType::TesselationControl; + } - bool is_clip_distances_declared = false; - for (const auto index : ir.GetOutputAttributes()) { - if (index == Attribute::Index::ClipDistances0123 || - index == Attribute::Index::ClipDistances4567) { - is_clip_distances_declared = true; - } + u32 GetNumInputVertices() const { + switch (stage) { + case ShaderType::Geometry: + return GetNumPrimitiveTopologyVertices(specialization.primitive_topology); + case ShaderType::TesselationControl: + case ShaderType::TesselationEval: + return NumInputPatches; + default: + UNREACHABLE(); + return 1; } + } - std::vector<Id> members; - members.push_back(t_float4); - if (ir.UsesPointSize()) { - members.push_back(t_float); - } - if (is_clip_distances_declared) { - members.push_back(TypeArray(t_float, Constant(t_uint, 8))); - } - - const Id gl_per_vertex_struct = Name(TypeStruct(members), "PerVertex"); - Decorate(gl_per_vertex_struct, spv::Decoration::Block); - - u32 declaration_index = 0; - const auto MemberDecorateBuiltIn = [&](spv::BuiltIn builtin, std::string name, - bool condition) { - if (!condition) - return u32{}; - MemberName(gl_per_vertex_struct, declaration_index, name); - MemberDecorate(gl_per_vertex_struct, declaration_index, spv::Decoration::BuiltIn, - static_cast<u32>(builtin)); - return declaration_index++; + u32 GetNumOutputVertices() const { + switch (stage) { + case ShaderType::TesselationControl: + return header.common2.threads_per_input_primitive; + default: + UNREACHABLE(); + return 1; + } + } + + std::tuple<Id, VertexIndices> DeclareVertexStruct() { + struct BuiltIn { + Id type; + spv::BuiltIn builtin; + const char* name; + }; + std::vector<BuiltIn> members; + members.reserve(4); + + const auto AddBuiltIn = [&](Id type, spv::BuiltIn builtin, const char* name) { + const auto index = static_cast<u32>(members.size()); + members.push_back(BuiltIn{type, builtin, name}); + return index; }; - position_index = MemberDecorateBuiltIn(spv::BuiltIn::Position, "position", true); - point_size_index = - MemberDecorateBuiltIn(spv::BuiltIn::PointSize, "point_size", ir.UsesPointSize()); - clip_distances_index = MemberDecorateBuiltIn(spv::BuiltIn::ClipDistance, "clip_distances", - is_clip_distances_declared); + VertexIndices indices; + indices.position = AddBuiltIn(t_float4, spv::BuiltIn::Position, "position"); + + if (ir.UsesViewportIndex()) { + if (stage != ShaderType::Vertex || device.IsExtShaderViewportIndexLayerSupported()) { + indices.viewport = AddBuiltIn(t_int, spv::BuiltIn::ViewportIndex, "viewport_index"); + } else { + LOG_ERROR(Render_Vulkan, + "Shader requires ViewportIndex but it's not supported on this " + "stage with this device."); + } + } + + if (ir.UsesPointSize() || specialization.point_size) { + indices.point_size = AddBuiltIn(t_float, spv::BuiltIn::PointSize, "point_size"); + } + + const auto& output_attributes = ir.GetOutputAttributes(); + const bool declare_clip_distances = + std::any_of(output_attributes.begin(), output_attributes.end(), [](const auto& index) { + return index == Attribute::Index::ClipDistances0123 || + index == Attribute::Index::ClipDistances4567; + }); + if (declare_clip_distances) { + indices.clip_distances = AddBuiltIn(TypeArray(t_float, Constant(t_uint, 8)), + spv::BuiltIn::ClipDistance, "clip_distances"); + } + + std::vector<Id> member_types; + member_types.reserve(members.size()); + for (std::size_t i = 0; i < members.size(); ++i) { + member_types.push_back(members[i].type); + } + const Id per_vertex_struct = Name(TypeStruct(member_types), "PerVertex"); + Decorate(per_vertex_struct, spv::Decoration::Block); + + for (std::size_t index = 0; index < members.size(); ++index) { + const auto& member = members[index]; + MemberName(per_vertex_struct, static_cast<u32>(index), member.name); + MemberDecorate(per_vertex_struct, static_cast<u32>(index), spv::Decoration::BuiltIn, + static_cast<u32>(member.builtin)); + } - const Id type_pointer = TypePointer(spv::StorageClass::Output, gl_per_vertex_struct); - per_vertex = OpVariable(type_pointer, spv::StorageClass::Output); - AddGlobalVariable(Name(per_vertex, "per_vertex")); - interfaces.push_back(per_vertex); + return {per_vertex_struct, indices}; } void VisitBasicBlock(const NodeBlock& bb) { for (const auto& node : bb) { - static_cast<void>(Visit(node)); + [[maybe_unused]] const Type type = Visit(node).type; + ASSERT(type == Type::Void); } } - Id Visit(const Node& node) { + Expression Visit(const Node& node) { if (const auto operation = std::get_if<OperationNode>(&*node)) { const auto operation_index = static_cast<std::size_t>(operation->GetCode()); const auto decompiler = operation_decompilers[operation_index]; @@ -516,18 +960,21 @@ private: UNREACHABLE_MSG("Operation decompiler {} not defined", operation_index); } return (this->*decompiler)(*operation); + } - } else if (const auto gpr = std::get_if<GprNode>(&*node)) { + if (const auto gpr = std::get_if<GprNode>(&*node)) { const u32 index = gpr->GetIndex(); if (index == Register::ZeroIndex) { - return Constant(t_float, 0.0f); + return {v_float_zero, Type::Float}; } - return Emit(OpLoad(t_float, registers.at(index))); + return {OpLoad(t_float, registers.at(index)), Type::Float}; + } - } else if (const auto immediate = std::get_if<ImmediateNode>(&*node)) { - return BitcastTo<Type::Float>(Constant(t_uint, immediate->GetValue())); + if (const auto immediate = std::get_if<ImmediateNode>(&*node)) { + return {Constant(t_uint, immediate->GetValue()), Type::Uint}; + } - } else if (const auto predicate = std::get_if<PredicateNode>(&*node)) { + if (const auto predicate = std::get_if<PredicateNode>(&*node)) { const auto value = [&]() -> Id { switch (const auto index = predicate->GetIndex(); index) { case Tegra::Shader::Pred::UnusedIndex: @@ -535,74 +982,107 @@ private: case Tegra::Shader::Pred::NeverExecute: return v_false; default: - return Emit(OpLoad(t_bool, predicates.at(index))); + return OpLoad(t_bool, predicates.at(index)); } }(); if (predicate->IsNegated()) { - return Emit(OpLogicalNot(t_bool, value)); + return {OpLogicalNot(t_bool, value), Type::Bool}; } - return value; + return {value, Type::Bool}; + } - } else if (const auto abuf = std::get_if<AbufNode>(&*node)) { + if (const auto abuf = std::get_if<AbufNode>(&*node)) { const auto attribute = abuf->GetIndex(); - const auto element = abuf->GetElement(); + const u32 element = abuf->GetElement(); + const auto& buffer = abuf->GetBuffer(); + + const auto ArrayPass = [&](Id pointer_type, Id composite, std::vector<u32> indices) { + std::vector<Id> members; + members.reserve(std::size(indices) + 1); + + if (buffer && IsInputAttributeArray()) { + members.push_back(AsUint(Visit(buffer))); + } + for (const u32 index : indices) { + members.push_back(Constant(t_uint, index)); + } + return OpAccessChain(pointer_type, composite, members); + }; switch (attribute) { - case Attribute::Index::Position: - if (stage != ShaderType::Fragment) { - UNIMPLEMENTED(); - break; - } else { + case Attribute::Index::Position: { + if (stage == ShaderType::Fragment) { if (element == 3) { - return Constant(t_float, 1.0f); + return {Constant(t_float, 1.0f), Type::Float}; } - return Emit(OpLoad(t_float, AccessElement(t_in_float, frag_coord, element))); + return {OpLoad(t_float, AccessElement(t_in_float, frag_coord, element)), + Type::Float}; + } + const auto elements = {in_indices.position.value(), element}; + return {OpLoad(t_float, ArrayPass(t_in_float, in_vertex, elements)), Type::Float}; + } + case Attribute::Index::PointCoord: { + switch (element) { + case 0: + case 1: + return {OpCompositeExtract(t_float, OpLoad(t_float2, point_coord), element), + Type::Float}; } + UNIMPLEMENTED_MSG("Unimplemented point coord element={}", element); + return {v_float_zero, Type::Float}; + } case Attribute::Index::TessCoordInstanceIDVertexID: // TODO(Subv): Find out what the values are for the first two elements when inside a // vertex shader, and what's the value of the fourth element when inside a Tess Eval // shader. - ASSERT(stage == ShaderType::Vertex); switch (element) { + case 0: + case 1: + return {OpLoad(t_float, AccessElement(t_in_float, tess_coord, element)), + Type::Float}; case 2: - return BitcastFrom<Type::Uint>(Emit(OpLoad(t_uint, instance_index))); + return {OpLoad(t_uint, instance_index), Type::Uint}; case 3: - return BitcastFrom<Type::Uint>(Emit(OpLoad(t_uint, vertex_index))); + return {OpLoad(t_uint, vertex_index), Type::Uint}; } UNIMPLEMENTED_MSG("Unmanaged TessCoordInstanceIDVertexID element={}", element); - return Constant(t_float, 0); + return {Constant(t_uint, 0U), Type::Uint}; case Attribute::Index::FrontFacing: // TODO(Subv): Find out what the values are for the other elements. ASSERT(stage == ShaderType::Fragment); if (element == 3) { - const Id is_front_facing = Emit(OpLoad(t_bool, front_facing)); - const Id true_value = - BitcastTo<Type::Float>(Constant(t_int, static_cast<s32>(-1))); - const Id false_value = BitcastTo<Type::Float>(Constant(t_int, 0)); - return Emit(OpSelect(t_float, is_front_facing, true_value, false_value)); + const Id is_front_facing = OpLoad(t_bool, front_facing); + const Id true_value = Constant(t_int, static_cast<s32>(-1)); + const Id false_value = Constant(t_int, 0); + return {OpSelect(t_int, is_front_facing, true_value, false_value), Type::Int}; } UNIMPLEMENTED_MSG("Unmanaged FrontFacing element={}", element); - return Constant(t_float, 0.0f); + return {v_float_zero, Type::Float}; default: if (IsGenericAttribute(attribute)) { - const Id pointer = - AccessElement(t_in_float, input_attributes.at(attribute), element); - return Emit(OpLoad(t_float, pointer)); + const u32 location = GetGenericAttributeLocation(attribute); + const auto type_descriptor = GetAttributeType(location); + const Type type = type_descriptor.type; + const Id attribute_id = input_attributes.at(attribute); + const Id pointer = ArrayPass(type_descriptor.scalar, attribute_id, {element}); + return {OpLoad(GetTypeDefinition(type), pointer), type}; } break; } UNIMPLEMENTED_MSG("Unhandled input attribute: {}", static_cast<u32>(attribute)); + return {v_float_zero, Type::Float}; + } - } else if (const auto cbuf = std::get_if<CbufNode>(&*node)) { + if (const auto cbuf = std::get_if<CbufNode>(&*node)) { const Node& offset = cbuf->GetOffset(); const Id buffer_id = constant_buffers.at(cbuf->GetIndex()); Id pointer{}; if (device.IsKhrUniformBufferStandardLayoutSupported()) { - const Id buffer_offset = Emit(OpShiftRightLogical( - t_uint, BitcastTo<Type::Uint>(Visit(offset)), Constant(t_uint, 2u))); - pointer = Emit( - OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0u), buffer_offset)); + const Id buffer_offset = + OpShiftRightLogical(t_uint, AsUint(Visit(offset)), Constant(t_uint, 2U)); + pointer = + OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0U), buffer_offset); } else { Id buffer_index{}; Id buffer_element{}; @@ -614,53 +1094,76 @@ private: buffer_element = Constant(t_uint, (offset_imm / 4) % 4); } else if (std::holds_alternative<OperationNode>(*offset)) { // Indirect access - const Id offset_id = BitcastTo<Type::Uint>(Visit(offset)); - const Id unsafe_offset = Emit(OpUDiv(t_uint, offset_id, Constant(t_uint, 4))); - const Id final_offset = Emit(OpUMod( - t_uint, unsafe_offset, Constant(t_uint, MAX_CONSTBUFFER_ELEMENTS - 1))); - buffer_index = Emit(OpUDiv(t_uint, final_offset, Constant(t_uint, 4))); - buffer_element = Emit(OpUMod(t_uint, final_offset, Constant(t_uint, 4))); + const Id offset_id = AsUint(Visit(offset)); + const Id unsafe_offset = OpUDiv(t_uint, offset_id, Constant(t_uint, 4)); + const Id final_offset = + OpUMod(t_uint, unsafe_offset, Constant(t_uint, MaxConstBufferElements - 1)); + buffer_index = OpUDiv(t_uint, final_offset, Constant(t_uint, 4)); + buffer_element = OpUMod(t_uint, final_offset, Constant(t_uint, 4)); } else { UNREACHABLE_MSG("Unmanaged offset node type"); } - pointer = Emit(OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0), - buffer_index, buffer_element)); + pointer = OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0), buffer_index, + buffer_element); } - return Emit(OpLoad(t_float, pointer)); + return {OpLoad(t_float, pointer), Type::Float}; + } - } else if (const auto gmem = std::get_if<GmemNode>(&*node)) { + if (const auto gmem = std::get_if<GmemNode>(&*node)) { const Id gmem_buffer = global_buffers.at(gmem->GetDescriptor()); - const Id real = BitcastTo<Type::Uint>(Visit(gmem->GetRealAddress())); - const Id base = BitcastTo<Type::Uint>(Visit(gmem->GetBaseAddress())); + const Id real = AsUint(Visit(gmem->GetRealAddress())); + const Id base = AsUint(Visit(gmem->GetBaseAddress())); + + Id offset = OpISub(t_uint, real, base); + offset = OpUDiv(t_uint, offset, Constant(t_uint, 4U)); + return {OpLoad(t_float, + OpAccessChain(t_gmem_float, gmem_buffer, Constant(t_uint, 0U), offset)), + Type::Float}; + } - Id offset = Emit(OpISub(t_uint, real, base)); - offset = Emit(OpUDiv(t_uint, offset, Constant(t_uint, 4u))); - return Emit(OpLoad(t_float, Emit(OpAccessChain(t_gmem_float, gmem_buffer, - Constant(t_uint, 0u), offset)))); + if (const auto lmem = std::get_if<LmemNode>(&*node)) { + Id address = AsUint(Visit(lmem->GetAddress())); + address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U)); + const Id pointer = OpAccessChain(t_prv_float, local_memory, address); + return {OpLoad(t_float, pointer), Type::Float}; + } + + if (const auto smem = std::get_if<SmemNode>(&*node)) { + Id address = AsUint(Visit(smem->GetAddress())); + address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U)); + const Id pointer = OpAccessChain(t_smem_uint, shared_memory, address); + return {OpLoad(t_uint, pointer), Type::Uint}; + } - } else if (const auto conditional = std::get_if<ConditionalNode>(&*node)) { + if (const auto internal_flag = std::get_if<InternalFlagNode>(&*node)) { + const Id flag = internal_flags.at(static_cast<std::size_t>(internal_flag->GetFlag())); + return {OpLoad(t_bool, flag), Type::Bool}; + } + + if (const auto conditional = std::get_if<ConditionalNode>(&*node)) { // It's invalid to call conditional on nested nodes, use an operation instead const Id true_label = OpLabel(); const Id skip_label = OpLabel(); - const Id condition = Visit(conditional->GetCondition()); - Emit(OpSelectionMerge(skip_label, spv::SelectionControlMask::MaskNone)); - Emit(OpBranchConditional(condition, true_label, skip_label)); - Emit(true_label); + const Id condition = AsBool(Visit(conditional->GetCondition())); + OpSelectionMerge(skip_label, spv::SelectionControlMask::MaskNone); + OpBranchConditional(condition, true_label, skip_label); + AddLabel(true_label); - ++conditional_nest_count; + conditional_branch_set = true; + inside_branch = false; VisitBasicBlock(conditional->GetCode()); - --conditional_nest_count; - - if (inside_branch == 0) { - Emit(OpBranch(skip_label)); + conditional_branch_set = false; + if (!inside_branch) { + OpBranch(skip_label); } else { - inside_branch--; + inside_branch = false; } - Emit(skip_label); + AddLabel(skip_label); return {}; + } - } else if (const auto comment = std::get_if<CommentNode>(&*node)) { - Name(Emit(OpUndef(t_void)), comment->GetText()); + if (const auto comment = std::get_if<CommentNode>(&*node)) { + Name(OpUndef(t_void), comment->GetText()); return {}; } @@ -669,94 +1172,126 @@ private: } template <Id (Module::*func)(Id, Id), Type result_type, Type type_a = result_type> - Id Unary(Operation operation) { + Expression Unary(Operation operation) { const Id type_def = GetTypeDefinition(result_type); - const Id op_a = VisitOperand<type_a>(operation, 0); + const Id op_a = As(Visit(operation[0]), type_a); - const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a))); + const Id value = (this->*func)(type_def, op_a); if (IsPrecise(operation)) { Decorate(value, spv::Decoration::NoContraction); } - return value; + return {value, result_type}; } template <Id (Module::*func)(Id, Id, Id), Type result_type, Type type_a = result_type, Type type_b = type_a> - Id Binary(Operation operation) { + Expression Binary(Operation operation) { const Id type_def = GetTypeDefinition(result_type); - const Id op_a = VisitOperand<type_a>(operation, 0); - const Id op_b = VisitOperand<type_b>(operation, 1); + const Id op_a = As(Visit(operation[0]), type_a); + const Id op_b = As(Visit(operation[1]), type_b); - const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b))); + const Id value = (this->*func)(type_def, op_a, op_b); if (IsPrecise(operation)) { Decorate(value, spv::Decoration::NoContraction); } - return value; + return {value, result_type}; } template <Id (Module::*func)(Id, Id, Id, Id), Type result_type, Type type_a = result_type, Type type_b = type_a, Type type_c = type_b> - Id Ternary(Operation operation) { + Expression Ternary(Operation operation) { const Id type_def = GetTypeDefinition(result_type); - const Id op_a = VisitOperand<type_a>(operation, 0); - const Id op_b = VisitOperand<type_b>(operation, 1); - const Id op_c = VisitOperand<type_c>(operation, 2); + const Id op_a = As(Visit(operation[0]), type_a); + const Id op_b = As(Visit(operation[1]), type_b); + const Id op_c = As(Visit(operation[2]), type_c); - const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c))); + const Id value = (this->*func)(type_def, op_a, op_b, op_c); if (IsPrecise(operation)) { Decorate(value, spv::Decoration::NoContraction); } - return value; + return {value, result_type}; } template <Id (Module::*func)(Id, Id, Id, Id, Id), Type result_type, Type type_a = result_type, Type type_b = type_a, Type type_c = type_b, Type type_d = type_c> - Id Quaternary(Operation operation) { + Expression Quaternary(Operation operation) { const Id type_def = GetTypeDefinition(result_type); - const Id op_a = VisitOperand<type_a>(operation, 0); - const Id op_b = VisitOperand<type_b>(operation, 1); - const Id op_c = VisitOperand<type_c>(operation, 2); - const Id op_d = VisitOperand<type_d>(operation, 3); + const Id op_a = As(Visit(operation[0]), type_a); + const Id op_b = As(Visit(operation[1]), type_b); + const Id op_c = As(Visit(operation[2]), type_c); + const Id op_d = As(Visit(operation[3]), type_d); - const Id value = - BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c, op_d))); + const Id value = (this->*func)(type_def, op_a, op_b, op_c, op_d); if (IsPrecise(operation)) { Decorate(value, spv::Decoration::NoContraction); } - return value; + return {value, result_type}; } - Id Assign(Operation operation) { + Expression Assign(Operation operation) { const Node& dest = operation[0]; const Node& src = operation[1]; - Id target{}; + Expression target{}; if (const auto gpr = std::get_if<GprNode>(&*dest)) { if (gpr->GetIndex() == Register::ZeroIndex) { // Writing to Register::ZeroIndex is a no op return {}; } - target = registers.at(gpr->GetIndex()); + target = {registers.at(gpr->GetIndex()), Type::Float}; } else if (const auto abuf = std::get_if<AbufNode>(&*dest)) { - target = [&]() -> Id { + const auto& buffer = abuf->GetBuffer(); + const auto ArrayPass = [&](Id pointer_type, Id composite, std::vector<u32> indices) { + std::vector<Id> members; + members.reserve(std::size(indices) + 1); + + if (buffer && IsOutputAttributeArray()) { + members.push_back(AsUint(Visit(buffer))); + } + for (const u32 index : indices) { + members.push_back(Constant(t_uint, index)); + } + return OpAccessChain(pointer_type, composite, members); + }; + + target = [&]() -> Expression { + const u32 element = abuf->GetElement(); switch (const auto attribute = abuf->GetIndex(); attribute) { - case Attribute::Index::Position: - return AccessElement(t_out_float, per_vertex, position_index, - abuf->GetElement()); + case Attribute::Index::Position: { + const u32 index = out_indices.position.value(); + return {ArrayPass(t_out_float, out_vertex, {index, element}), Type::Float}; + } case Attribute::Index::LayerViewportPointSize: - UNIMPLEMENTED_IF(abuf->GetElement() != 3); - return AccessElement(t_out_float, per_vertex, point_size_index); - case Attribute::Index::ClipDistances0123: - return AccessElement(t_out_float, per_vertex, clip_distances_index, - abuf->GetElement()); - case Attribute::Index::ClipDistances4567: - return AccessElement(t_out_float, per_vertex, clip_distances_index, - abuf->GetElement() + 4); + switch (element) { + case 2: { + if (!out_indices.viewport) { + return {}; + } + const u32 index = out_indices.viewport.value(); + return {AccessElement(t_out_int, out_vertex, index), Type::Int}; + } + case 3: { + const auto index = out_indices.point_size.value(); + return {AccessElement(t_out_float, out_vertex, index), Type::Float}; + } + default: + UNIMPLEMENTED_MSG("LayerViewportPoint element={}", abuf->GetElement()); + return {}; + } + case Attribute::Index::ClipDistances0123: { + const u32 index = out_indices.clip_distances.value(); + return {AccessElement(t_out_float, out_vertex, index, element), Type::Float}; + } + case Attribute::Index::ClipDistances4567: { + const u32 index = out_indices.clip_distances.value(); + return {AccessElement(t_out_float, out_vertex, index, element + 4), + Type::Float}; + } default: if (IsGenericAttribute(attribute)) { - return AccessElement(t_out_float, output_attributes.at(attribute), - abuf->GetElement()); + const Id composite = output_attributes.at(attribute); + return {ArrayPass(t_out_float, composite, {element}), Type::Float}; } UNIMPLEMENTED_MSG("Unhandled output attribute: {}", static_cast<u32>(attribute)); @@ -764,72 +1299,154 @@ private: } }(); + } else if (const auto patch = std::get_if<PatchNode>(&*dest)) { + target = [&]() -> Expression { + const u32 offset = patch->GetOffset(); + switch (offset) { + case 0: + case 1: + case 2: + case 3: + return {AccessElement(t_out_float, tess_level_outer, offset % 4), Type::Float}; + case 4: + case 5: + return {AccessElement(t_out_float, tess_level_inner, offset % 4), Type::Float}; + } + UNIMPLEMENTED_MSG("Unhandled patch output offset: {}", offset); + return {}; + }(); + } else if (const auto lmem = std::get_if<LmemNode>(&*dest)) { - Id address = BitcastTo<Type::Uint>(Visit(lmem->GetAddress())); - address = Emit(OpUDiv(t_uint, address, Constant(t_uint, 4))); - target = Emit(OpAccessChain(t_prv_float, local_memory, {address})); + Id address = AsUint(Visit(lmem->GetAddress())); + address = OpUDiv(t_uint, address, Constant(t_uint, 4)); + target = {OpAccessChain(t_prv_float, local_memory, address), Type::Float}; + + } else if (const auto smem = std::get_if<SmemNode>(&*dest)) { + ASSERT(stage == ShaderType::Compute); + Id address = AsUint(Visit(smem->GetAddress())); + address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U)); + target = {OpAccessChain(t_smem_uint, shared_memory, address), Type::Uint}; + + } else if (const auto gmem = std::get_if<GmemNode>(&*dest)) { + const Id real = AsUint(Visit(gmem->GetRealAddress())); + const Id base = AsUint(Visit(gmem->GetBaseAddress())); + const Id diff = OpISub(t_uint, real, base); + const Id offset = OpShiftRightLogical(t_uint, diff, Constant(t_uint, 2)); + + const Id gmem_buffer = global_buffers.at(gmem->GetDescriptor()); + target = {OpAccessChain(t_gmem_float, gmem_buffer, Constant(t_uint, 0), offset), + Type::Float}; + + } else { + UNIMPLEMENTED(); } - Emit(OpStore(target, Visit(src))); + OpStore(target.id, As(Visit(src), target.type)); return {}; } - Id FCastHalf0(Operation operation) { - UNIMPLEMENTED(); - return {}; + template <u32 offset> + Expression FCastHalf(Operation operation) { + const Id value = AsHalfFloat(Visit(operation[0])); + return {GetFloatFromHalfScalar(OpCompositeExtract(t_scalar_half, value, offset)), + Type::Float}; } - Id FCastHalf1(Operation operation) { - UNIMPLEMENTED(); - return {}; - } + Expression FSwizzleAdd(Operation operation) { + const Id minus = Constant(t_float, -1.0f); + const Id plus = v_float_one; + const Id zero = v_float_zero; + const Id lut_a = ConstantComposite(t_float4, minus, plus, minus, zero); + const Id lut_b = ConstantComposite(t_float4, minus, minus, plus, minus); - Id FSwizzleAdd(Operation operation) { - UNIMPLEMENTED(); - return {}; - } + Id mask = OpLoad(t_uint, thread_id); + mask = OpBitwiseAnd(t_uint, mask, Constant(t_uint, 3)); + mask = OpShiftLeftLogical(t_uint, mask, Constant(t_uint, 1)); + mask = OpShiftRightLogical(t_uint, AsUint(Visit(operation[2])), mask); + mask = OpBitwiseAnd(t_uint, mask, Constant(t_uint, 3)); - Id HNegate(Operation operation) { - UNIMPLEMENTED(); - return {}; + const Id modifier_a = OpVectorExtractDynamic(t_float, lut_a, mask); + const Id modifier_b = OpVectorExtractDynamic(t_float, lut_b, mask); + + const Id op_a = OpFMul(t_float, AsFloat(Visit(operation[0])), modifier_a); + const Id op_b = OpFMul(t_float, AsFloat(Visit(operation[1])), modifier_b); + return {OpFAdd(t_float, op_a, op_b), Type::Float}; } - Id HClamp(Operation operation) { - UNIMPLEMENTED(); - return {}; + Expression HNegate(Operation operation) { + const bool is_f16 = device.IsFloat16Supported(); + const Id minus_one = Constant(t_scalar_half, is_f16 ? 0xbc00 : 0xbf800000); + const Id one = Constant(t_scalar_half, is_f16 ? 0x3c00 : 0x3f800000); + const auto GetNegate = [&](std::size_t index) { + return OpSelect(t_scalar_half, AsBool(Visit(operation[index])), minus_one, one); + }; + const Id negation = OpCompositeConstruct(t_half, GetNegate(1), GetNegate(2)); + return {OpFMul(t_half, AsHalfFloat(Visit(operation[0])), negation), Type::HalfFloat}; } - Id HCastFloat(Operation operation) { - UNIMPLEMENTED(); - return {}; + Expression HClamp(Operation operation) { + const auto Pack = [&](std::size_t index) { + const Id scalar = GetHalfScalarFromFloat(AsFloat(Visit(operation[index]))); + return OpCompositeConstruct(t_half, scalar, scalar); + }; + const Id value = AsHalfFloat(Visit(operation[0])); + const Id min = Pack(1); + const Id max = Pack(2); + + const Id clamped = OpFClamp(t_half, value, min, max); + if (IsPrecise(operation)) { + Decorate(clamped, spv::Decoration::NoContraction); + } + return {clamped, Type::HalfFloat}; } - Id HUnpack(Operation operation) { - UNIMPLEMENTED(); - return {}; + Expression HCastFloat(Operation operation) { + const Id value = GetHalfScalarFromFloat(AsFloat(Visit(operation[0]))); + return {OpCompositeConstruct(t_half, value, Constant(t_scalar_half, 0)), Type::HalfFloat}; } - Id HMergeF32(Operation operation) { - UNIMPLEMENTED(); - return {}; + Expression HUnpack(Operation operation) { + Expression operand = Visit(operation[0]); + const auto type = std::get<Tegra::Shader::HalfType>(operation.GetMeta()); + if (type == Tegra::Shader::HalfType::H0_H1) { + return operand; + } + const auto value = [&] { + switch (std::get<Tegra::Shader::HalfType>(operation.GetMeta())) { + case Tegra::Shader::HalfType::F32: + return GetHalfScalarFromFloat(AsFloat(operand)); + case Tegra::Shader::HalfType::H0_H0: + return OpCompositeExtract(t_scalar_half, AsHalfFloat(operand), 0); + case Tegra::Shader::HalfType::H1_H1: + return OpCompositeExtract(t_scalar_half, AsHalfFloat(operand), 1); + default: + UNREACHABLE(); + return ConstantNull(t_half); + } + }(); + return {OpCompositeConstruct(t_half, value, value), Type::HalfFloat}; } - Id HMergeH0(Operation operation) { - UNIMPLEMENTED(); - return {}; + Expression HMergeF32(Operation operation) { + const Id value = AsHalfFloat(Visit(operation[0])); + return {GetFloatFromHalfScalar(OpCompositeExtract(t_scalar_half, value, 0)), Type::Float}; } - Id HMergeH1(Operation operation) { - UNIMPLEMENTED(); - return {}; + template <u32 offset> + Expression HMergeHN(Operation operation) { + const Id target = AsHalfFloat(Visit(operation[0])); + const Id source = AsHalfFloat(Visit(operation[1])); + const Id object = OpCompositeExtract(t_scalar_half, source, offset); + return {OpCompositeInsert(t_half, object, target, offset), Type::HalfFloat}; } - Id HPack2(Operation operation) { - UNIMPLEMENTED(); - return {}; + Expression HPack2(Operation operation) { + const Id low = GetHalfScalarFromFloat(AsFloat(Visit(operation[0]))); + const Id high = GetHalfScalarFromFloat(AsFloat(Visit(operation[1]))); + return {OpCompositeConstruct(t_half, low, high), Type::HalfFloat}; } - Id LogicalAssign(Operation operation) { + Expression LogicalAssign(Operation operation) { const Node& dest = operation[0]; const Node& src = operation[1]; @@ -850,106 +1467,190 @@ private: target = internal_flags.at(static_cast<u32>(flag->GetFlag())); } - Emit(OpStore(target, Visit(src))); + OpStore(target, AsBool(Visit(src))); return {}; } - Id LogicalPick2(Operation operation) { - UNIMPLEMENTED(); - return {}; + Id GetTextureSampler(Operation operation) { + const auto& meta = std::get<MetaTexture>(operation.GetMeta()); + ASSERT(!meta.sampler.IsBuffer()); + + const auto& entry = sampled_images.at(meta.sampler.GetIndex()); + return OpLoad(entry.sampled_image_type, entry.sampler); } - Id LogicalAnd2(Operation operation) { - UNIMPLEMENTED(); - return {}; + Id GetTextureImage(Operation operation) { + const auto& meta = std::get<MetaTexture>(operation.GetMeta()); + const u32 index = meta.sampler.GetIndex(); + if (meta.sampler.IsBuffer()) { + const auto& entry = texel_buffers.at(index); + return OpLoad(entry.image_type, entry.image); + } else { + const auto& entry = sampled_images.at(index); + return OpImage(entry.image_type, GetTextureSampler(operation)); + } } - Id GetTextureSampler(Operation operation) { - const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); - const auto entry = sampler_images.at(static_cast<u32>(meta->sampler.GetIndex())); - return Emit(OpLoad(entry.sampled_image_type, entry.sampler)); + Id GetImage(Operation operation) { + const auto& meta = std::get<MetaImage>(operation.GetMeta()); + const auto entry = images.at(meta.image.GetIndex()); + return OpLoad(entry.image_type, entry.image); } - Id GetTextureImage(Operation operation) { - const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); - const auto entry = sampler_images.at(static_cast<u32>(meta->sampler.GetIndex())); - return Emit(OpImage(entry.image_type, GetTextureSampler(operation))); + Id AssembleVector(const std::vector<Id>& coords, Type type) { + const Id coords_type = GetTypeVectorDefinitionLut(type).at(coords.size() - 1); + return coords.size() == 1 ? coords[0] : OpCompositeConstruct(coords_type, coords); } - Id GetTextureCoordinates(Operation operation) { - const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); + Id GetCoordinates(Operation operation, Type type) { std::vector<Id> coords; for (std::size_t i = 0; i < operation.GetOperandsCount(); ++i) { - coords.push_back(Visit(operation[i])); + coords.push_back(As(Visit(operation[i]), type)); } - if (meta->sampler.IsArray()) { - const Id array_integer = BitcastTo<Type::Int>(Visit(meta->array)); - coords.push_back(Emit(OpConvertSToF(t_float, array_integer))); + if (const auto meta = std::get_if<MetaTexture>(&operation.GetMeta())) { + // Add array coordinate for textures + if (meta->sampler.IsArray()) { + Id array = AsInt(Visit(meta->array)); + if (type == Type::Float) { + array = OpConvertSToF(t_float, array); + } + coords.push_back(array); + } } - if (meta->sampler.IsShadow()) { - coords.push_back(Visit(meta->depth_compare)); + return AssembleVector(coords, type); + } + + Id GetOffsetCoordinates(Operation operation) { + const auto& meta = std::get<MetaTexture>(operation.GetMeta()); + std::vector<Id> coords; + coords.reserve(meta.aoffi.size()); + for (const auto& coord : meta.aoffi) { + coords.push_back(AsInt(Visit(coord))); } + return AssembleVector(coords, Type::Int); + } - const std::array<Id, 4> t_float_lut = {nullptr, t_float2, t_float3, t_float4}; - return coords.size() == 1 - ? coords[0] - : Emit(OpCompositeConstruct(t_float_lut.at(coords.size() - 1), coords)); + std::pair<Id, Id> GetDerivatives(Operation operation) { + const auto& meta = std::get<MetaTexture>(operation.GetMeta()); + const auto& derivatives = meta.derivates; + ASSERT(derivatives.size() % 2 == 0); + + const std::size_t components = derivatives.size() / 2; + std::vector<Id> dx, dy; + dx.reserve(components); + dy.reserve(components); + for (std::size_t index = 0; index < components; ++index) { + dx.push_back(AsFloat(Visit(derivatives.at(index * 2 + 0)))); + dy.push_back(AsFloat(Visit(derivatives.at(index * 2 + 1)))); + } + return {AssembleVector(dx, Type::Float), AssembleVector(dy, Type::Float)}; } - Id GetTextureElement(Operation operation, Id sample_value) { - const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); - ASSERT(meta); - return Emit(OpCompositeExtract(t_float, sample_value, meta->element)); + Expression GetTextureElement(Operation operation, Id sample_value, Type type) { + const auto& meta = std::get<MetaTexture>(operation.GetMeta()); + const auto type_def = GetTypeDefinition(type); + return {OpCompositeExtract(type_def, sample_value, meta.element), type}; } - Id Texture(Operation operation) { - const Id texture = Emit(OpImageSampleImplicitLod(t_float4, GetTextureSampler(operation), - GetTextureCoordinates(operation))); - return GetTextureElement(operation, texture); + Expression Texture(Operation operation) { + const auto& meta = std::get<MetaTexture>(operation.GetMeta()); + UNIMPLEMENTED_IF(!meta.aoffi.empty()); + + const bool can_implicit = stage == ShaderType::Fragment; + const Id sampler = GetTextureSampler(operation); + const Id coords = GetCoordinates(operation, Type::Float); + + if (meta.depth_compare) { + // Depth sampling + UNIMPLEMENTED_IF(meta.bias); + const Id dref = AsFloat(Visit(meta.depth_compare)); + if (can_implicit) { + return {OpImageSampleDrefImplicitLod(t_float, sampler, coords, dref, {}), + Type::Float}; + } else { + return {OpImageSampleDrefExplicitLod(t_float, sampler, coords, dref, + spv::ImageOperandsMask::Lod, v_float_zero), + Type::Float}; + } + } + + std::vector<Id> operands; + spv::ImageOperandsMask mask{}; + if (meta.bias) { + mask = mask | spv::ImageOperandsMask::Bias; + operands.push_back(AsFloat(Visit(meta.bias))); + } + + Id texture; + if (can_implicit) { + texture = OpImageSampleImplicitLod(t_float4, sampler, coords, mask, operands); + } else { + texture = OpImageSampleExplicitLod(t_float4, sampler, coords, + mask | spv::ImageOperandsMask::Lod, v_float_zero, + operands); + } + return GetTextureElement(operation, texture, Type::Float); } - Id TextureLod(Operation operation) { - const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); - const Id texture = Emit(OpImageSampleExplicitLod( - t_float4, GetTextureSampler(operation), GetTextureCoordinates(operation), - spv::ImageOperandsMask::Lod, Visit(meta->lod))); - return GetTextureElement(operation, texture); + Expression TextureLod(Operation operation) { + const auto& meta = std::get<MetaTexture>(operation.GetMeta()); + + const Id sampler = GetTextureSampler(operation); + const Id coords = GetCoordinates(operation, Type::Float); + const Id lod = AsFloat(Visit(meta.lod)); + + spv::ImageOperandsMask mask = spv::ImageOperandsMask::Lod; + std::vector<Id> operands; + if (!meta.aoffi.empty()) { + mask = mask | spv::ImageOperandsMask::Offset; + operands.push_back(GetOffsetCoordinates(operation)); + } + + if (meta.sampler.IsShadow()) { + const Id dref = AsFloat(Visit(meta.depth_compare)); + return { + OpImageSampleDrefExplicitLod(t_float, sampler, coords, dref, mask, lod, operands), + Type::Float}; + } + const Id texture = OpImageSampleExplicitLod(t_float4, sampler, coords, mask, lod, operands); + return GetTextureElement(operation, texture, Type::Float); } - Id TextureGather(Operation operation) { - const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); - const auto coords = GetTextureCoordinates(operation); + Expression TextureGather(Operation operation) { + const auto& meta = std::get<MetaTexture>(operation.GetMeta()); + UNIMPLEMENTED_IF(!meta.aoffi.empty()); - Id texture; - if (meta->sampler.IsShadow()) { - texture = Emit(OpImageDrefGather(t_float4, GetTextureSampler(operation), coords, - Visit(meta->component))); + const Id coords = GetCoordinates(operation, Type::Float); + Id texture{}; + if (meta.sampler.IsShadow()) { + texture = OpImageDrefGather(t_float4, GetTextureSampler(operation), coords, + AsFloat(Visit(meta.depth_compare))); } else { u32 component_value = 0; - if (meta->component) { - const auto component = std::get_if<ImmediateNode>(&*meta->component); + if (meta.component) { + const auto component = std::get_if<ImmediateNode>(&*meta.component); ASSERT_MSG(component, "Component is not an immediate value"); component_value = component->GetValue(); } - texture = Emit(OpImageGather(t_float4, GetTextureSampler(operation), coords, - Constant(t_uint, component_value))); + texture = OpImageGather(t_float4, GetTextureSampler(operation), coords, + Constant(t_uint, component_value)); } - - return GetTextureElement(operation, texture); + return GetTextureElement(operation, texture, Type::Float); } - Id TextureQueryDimensions(Operation operation) { - const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); - const auto image_id = GetTextureImage(operation); - AddCapability(spv::Capability::ImageQuery); + Expression TextureQueryDimensions(Operation operation) { + const auto& meta = std::get<MetaTexture>(operation.GetMeta()); + UNIMPLEMENTED_IF(!meta.aoffi.empty()); + UNIMPLEMENTED_IF(meta.depth_compare); - if (meta->element == 3) { - return BitcastTo<Type::Float>(Emit(OpImageQueryLevels(t_int, image_id))); + const auto image_id = GetTextureImage(operation); + if (meta.element == 3) { + return {OpImageQueryLevels(t_int, image_id), Type::Int}; } - const Id lod = VisitOperand<Type::Uint>(operation, 0); + const Id lod = AsUint(Visit(operation[0])); const std::size_t coords_count = [&]() { - switch (const auto type = meta->sampler.GetType(); type) { + switch (const auto type = meta.sampler.GetType(); type) { case Tegra::Shader::TextureType::Texture1D: return 1; case Tegra::Shader::TextureType::Texture2D: @@ -963,141 +1664,190 @@ private: } }(); - if (meta->element >= coords_count) { - return Constant(t_float, 0.0f); + if (meta.element >= coords_count) { + return {v_float_zero, Type::Float}; } const std::array<Id, 3> types = {t_int, t_int2, t_int3}; - const Id sizes = Emit(OpImageQuerySizeLod(types.at(coords_count - 1), image_id, lod)); - const Id size = Emit(OpCompositeExtract(t_int, sizes, meta->element)); - return BitcastTo<Type::Float>(size); + const Id sizes = OpImageQuerySizeLod(types.at(coords_count - 1), image_id, lod); + const Id size = OpCompositeExtract(t_int, sizes, meta.element); + return {size, Type::Int}; } - Id TextureQueryLod(Operation operation) { - UNIMPLEMENTED(); - return {}; + Expression TextureQueryLod(Operation operation) { + const auto& meta = std::get<MetaTexture>(operation.GetMeta()); + UNIMPLEMENTED_IF(!meta.aoffi.empty()); + UNIMPLEMENTED_IF(meta.depth_compare); + + if (meta.element >= 2) { + UNREACHABLE_MSG("Invalid element"); + return {v_float_zero, Type::Float}; + } + const auto sampler_id = GetTextureSampler(operation); + + const Id multiplier = Constant(t_float, 256.0f); + const Id multipliers = ConstantComposite(t_float2, multiplier, multiplier); + + const Id coords = GetCoordinates(operation, Type::Float); + Id size = OpImageQueryLod(t_float2, sampler_id, coords); + size = OpFMul(t_float2, size, multipliers); + size = OpConvertFToS(t_int2, size); + return GetTextureElement(operation, size, Type::Int); } - Id TexelFetch(Operation operation) { + Expression TexelFetch(Operation operation) { + const auto& meta = std::get<MetaTexture>(operation.GetMeta()); + UNIMPLEMENTED_IF(meta.depth_compare); + + const Id image = GetTextureImage(operation); + const Id coords = GetCoordinates(operation, Type::Int); + Id fetch; + if (meta.lod && !meta.sampler.IsBuffer()) { + fetch = OpImageFetch(t_float4, image, coords, spv::ImageOperandsMask::Lod, + AsInt(Visit(meta.lod))); + } else { + fetch = OpImageFetch(t_float4, image, coords); + } + return GetTextureElement(operation, fetch, Type::Float); + } + + Expression TextureGradient(Operation operation) { + const auto& meta = std::get<MetaTexture>(operation.GetMeta()); + UNIMPLEMENTED_IF(!meta.aoffi.empty()); + + const Id sampler = GetTextureSampler(operation); + const Id coords = GetCoordinates(operation, Type::Float); + const auto [dx, dy] = GetDerivatives(operation); + const std::vector grad = {dx, dy}; + + static constexpr auto mask = spv::ImageOperandsMask::Grad; + const Id texture = OpImageSampleImplicitLod(t_float4, sampler, coords, mask, grad); + return GetTextureElement(operation, texture, Type::Float); + } + + Expression ImageLoad(Operation operation) { UNIMPLEMENTED(); return {}; } - Id TextureGradient(Operation operation) { - UNIMPLEMENTED(); + Expression ImageStore(Operation operation) { + const auto meta{std::get<MetaImage>(operation.GetMeta())}; + std::vector<Id> colors; + for (const auto& value : meta.values) { + colors.push_back(AsUint(Visit(value))); + } + + const Id coords = GetCoordinates(operation, Type::Int); + const Id texel = OpCompositeConstruct(t_uint4, colors); + + OpImageWrite(GetImage(operation), coords, texel, {}); return {}; } - Id ImageLoad(Operation operation) { + Expression AtomicImageAdd(Operation operation) { UNIMPLEMENTED(); return {}; } - Id ImageStore(Operation operation) { + Expression AtomicImageMin(Operation operation) { UNIMPLEMENTED(); return {}; } - Id AtomicImageAdd(Operation operation) { + Expression AtomicImageMax(Operation operation) { UNIMPLEMENTED(); return {}; } - Id AtomicImageAnd(Operation operation) { + Expression AtomicImageAnd(Operation operation) { UNIMPLEMENTED(); return {}; } - Id AtomicImageOr(Operation operation) { + Expression AtomicImageOr(Operation operation) { UNIMPLEMENTED(); return {}; } - Id AtomicImageXor(Operation operation) { + Expression AtomicImageXor(Operation operation) { UNIMPLEMENTED(); return {}; } - Id AtomicImageExchange(Operation operation) { + Expression AtomicImageExchange(Operation operation) { UNIMPLEMENTED(); return {}; } - Id Branch(Operation operation) { - const auto target = std::get_if<ImmediateNode>(&*operation[0]); - UNIMPLEMENTED_IF(!target); - - Emit(OpStore(jmp_to, Constant(t_uint, target->GetValue()))); - Emit(OpBranch(continue_label)); - inside_branch = conditional_nest_count; - if (conditional_nest_count == 0) { - Emit(OpLabel()); + Expression Branch(Operation operation) { + const auto& target = std::get<ImmediateNode>(*operation[0]); + OpStore(jmp_to, Constant(t_uint, target.GetValue())); + OpBranch(continue_label); + inside_branch = true; + if (!conditional_branch_set) { + AddLabel(); } return {}; } - Id BranchIndirect(Operation operation) { - const Id op_a = VisitOperand<Type::Uint>(operation, 0); + Expression BranchIndirect(Operation operation) { + const Id op_a = AsUint(Visit(operation[0])); - Emit(OpStore(jmp_to, op_a)); - Emit(OpBranch(continue_label)); - inside_branch = conditional_nest_count; - if (conditional_nest_count == 0) { - Emit(OpLabel()); + OpStore(jmp_to, op_a); + OpBranch(continue_label); + inside_branch = true; + if (!conditional_branch_set) { + AddLabel(); } return {}; } - Id PushFlowStack(Operation operation) { - const auto target = std::get_if<ImmediateNode>(&*operation[0]); - ASSERT(target); - + Expression PushFlowStack(Operation operation) { + const auto& target = std::get<ImmediateNode>(*operation[0]); const auto [flow_stack, flow_stack_top] = GetFlowStack(operation); - const Id current = Emit(OpLoad(t_uint, flow_stack_top)); - const Id next = Emit(OpIAdd(t_uint, current, Constant(t_uint, 1))); - const Id access = Emit(OpAccessChain(t_func_uint, flow_stack, current)); + const Id current = OpLoad(t_uint, flow_stack_top); + const Id next = OpIAdd(t_uint, current, Constant(t_uint, 1)); + const Id access = OpAccessChain(t_func_uint, flow_stack, current); - Emit(OpStore(access, Constant(t_uint, target->GetValue()))); - Emit(OpStore(flow_stack_top, next)); + OpStore(access, Constant(t_uint, target.GetValue())); + OpStore(flow_stack_top, next); return {}; } - Id PopFlowStack(Operation operation) { + Expression PopFlowStack(Operation operation) { const auto [flow_stack, flow_stack_top] = GetFlowStack(operation); - const Id current = Emit(OpLoad(t_uint, flow_stack_top)); - const Id previous = Emit(OpISub(t_uint, current, Constant(t_uint, 1))); - const Id access = Emit(OpAccessChain(t_func_uint, flow_stack, previous)); - const Id target = Emit(OpLoad(t_uint, access)); - - Emit(OpStore(flow_stack_top, previous)); - Emit(OpStore(jmp_to, target)); - Emit(OpBranch(continue_label)); - inside_branch = conditional_nest_count; - if (conditional_nest_count == 0) { - Emit(OpLabel()); + const Id current = OpLoad(t_uint, flow_stack_top); + const Id previous = OpISub(t_uint, current, Constant(t_uint, 1)); + const Id access = OpAccessChain(t_func_uint, flow_stack, previous); + const Id target = OpLoad(t_uint, access); + + OpStore(flow_stack_top, previous); + OpStore(jmp_to, target); + OpBranch(continue_label); + inside_branch = true; + if (!conditional_branch_set) { + AddLabel(); } return {}; } - Id PreExit() { - switch (stage) { - case ShaderType::Vertex: { - // TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't - // seem to be working on Nvidia's drivers and Intel (mesa and blob) doesn't support it. - const Id z_pointer = AccessElement(t_out_float, per_vertex, position_index, 2u); - Id depth = Emit(OpLoad(t_float, z_pointer)); - depth = Emit(OpFAdd(t_float, depth, Constant(t_float, 1.0f))); - depth = Emit(OpFMul(t_float, depth, Constant(t_float, 0.5f))); - Emit(OpStore(z_pointer, depth)); - break; + void PreExit() { + if (stage == ShaderType::Vertex) { + const u32 position_index = out_indices.position.value(); + const Id z_pointer = AccessElement(t_out_float, out_vertex, position_index, 2U); + const Id w_pointer = AccessElement(t_out_float, out_vertex, position_index, 3U); + Id depth = OpLoad(t_float, z_pointer); + depth = OpFAdd(t_float, depth, OpLoad(t_float, w_pointer)); + depth = OpFMul(t_float, depth, Constant(t_float, 0.5f)); + OpStore(z_pointer, depth); } - case ShaderType::Fragment: { + if (stage == ShaderType::Fragment) { const auto SafeGetRegister = [&](u32 reg) { // TODO(Rodrigo): Replace with contains once C++20 releases if (const auto it = registers.find(reg); it != registers.end()) { - return Emit(OpLoad(t_float, it->second)); + return OpLoad(t_float, it->second); } - return Constant(t_float, 0.0f); + return v_float_zero; }; UNIMPLEMENTED_IF_MSG(header.ps.omap.sample_mask != 0, @@ -1112,8 +1862,8 @@ private: // TODO(Subv): Figure out how dual-source blending is configured in the Switch. for (u32 component = 0; component < 4; ++component) { if (header.ps.IsColorComponentOutputEnabled(rt, component)) { - Emit(OpStore(AccessElement(t_out_float, frag_colors.at(rt), component), - SafeGetRegister(current_reg))); + OpStore(AccessElement(t_out_float, frag_colors.at(rt), component), + SafeGetRegister(current_reg)); ++current_reg; } } @@ -1121,110 +1871,117 @@ private: if (header.ps.omap.depth) { // The depth output is always 2 registers after the last color output, and // current_reg already contains one past the last color register. - Emit(OpStore(frag_depth, SafeGetRegister(current_reg + 1))); + OpStore(frag_depth, SafeGetRegister(current_reg + 1)); } - break; - } } - - return {}; } - Id Exit(Operation operation) { + Expression Exit(Operation operation) { PreExit(); - inside_branch = conditional_nest_count; - if (conditional_nest_count > 0) { - Emit(OpReturn()); + inside_branch = true; + if (conditional_branch_set) { + OpReturn(); } else { const Id dummy = OpLabel(); - Emit(OpBranch(dummy)); - Emit(dummy); - Emit(OpReturn()); - Emit(OpLabel()); + OpBranch(dummy); + AddLabel(dummy); + OpReturn(); + AddLabel(); } return {}; } - Id Discard(Operation operation) { - inside_branch = conditional_nest_count; - if (conditional_nest_count > 0) { - Emit(OpKill()); + Expression Discard(Operation operation) { + inside_branch = true; + if (conditional_branch_set) { + OpKill(); } else { const Id dummy = OpLabel(); - Emit(OpBranch(dummy)); - Emit(dummy); - Emit(OpKill()); - Emit(OpLabel()); + OpBranch(dummy); + AddLabel(dummy); + OpKill(); + AddLabel(); } return {}; } - Id EmitVertex(Operation operation) { - UNIMPLEMENTED(); + Expression EmitVertex(Operation) { + OpEmitVertex(); return {}; } - Id EndPrimitive(Operation operation) { - UNIMPLEMENTED(); + Expression EndPrimitive(Operation operation) { + OpEndPrimitive(); return {}; } - Id YNegate(Operation operation) { - UNIMPLEMENTED(); - return {}; + Expression InvocationId(Operation) { + return {OpLoad(t_int, invocation_id), Type::Int}; } - template <u32 element> - Id LocalInvocationId(Operation) { + Expression YNegate(Operation) { UNIMPLEMENTED(); - return {}; + return {Constant(t_float, 1.0f), Type::Float}; } template <u32 element> - Id WorkGroupId(Operation) { - UNIMPLEMENTED(); - return {}; + Expression LocalInvocationId(Operation) { + const Id id = OpLoad(t_uint3, local_invocation_id); + return {OpCompositeExtract(t_uint, id, element), Type::Uint}; } - Id BallotThread(Operation) { - UNIMPLEMENTED(); - return {}; + template <u32 element> + Expression WorkGroupId(Operation operation) { + const Id id = OpLoad(t_uint3, workgroup_id); + return {OpCompositeExtract(t_uint, id, element), Type::Uint}; } - Id VoteAll(Operation) { - UNIMPLEMENTED(); - return {}; - } + Expression BallotThread(Operation operation) { + const Id predicate = AsBool(Visit(operation[0])); + const Id ballot = OpSubgroupBallotKHR(t_uint4, predicate); - Id VoteAny(Operation) { - UNIMPLEMENTED(); - return {}; + if (!device.IsWarpSizePotentiallyBiggerThanGuest()) { + // Guest-like devices can just return the first index. + return {OpCompositeExtract(t_uint, ballot, 0U), Type::Uint}; + } + + // The others will have to return what is local to the current thread. + // For instance a device with a warp size of 64 will return the upper uint when the current + // thread is 38. + const Id tid = OpLoad(t_uint, thread_id); + const Id thread_index = OpShiftRightLogical(t_uint, tid, Constant(t_uint, 5)); + return {OpVectorExtractDynamic(t_uint, ballot, thread_index), Type::Uint}; } - Id VoteEqual(Operation) { - UNIMPLEMENTED(); - return {}; + template <Id (Module::*func)(Id, Id)> + Expression Vote(Operation operation) { + // TODO(Rodrigo): Handle devices with different warp sizes + const Id predicate = AsBool(Visit(operation[0])); + return {(this->*func)(t_bool, predicate), Type::Bool}; } - Id ThreadId(Operation) { - UNIMPLEMENTED(); - return {}; + Expression ThreadId(Operation) { + return {OpLoad(t_uint, thread_id), Type::Uint}; } - Id ShuffleIndexed(Operation) { - UNIMPLEMENTED(); - return {}; + Expression ShuffleIndexed(Operation operation) { + const Id value = AsFloat(Visit(operation[0])); + const Id index = AsUint(Visit(operation[1])); + return {OpSubgroupReadInvocationKHR(t_float, value, index), Type::Float}; } - Id DeclareBuiltIn(spv::BuiltIn builtin, spv::StorageClass storage, Id type, - const std::string& name) { + Id DeclareBuiltIn(spv::BuiltIn builtin, spv::StorageClass storage, Id type, std::string name) { const Id id = OpVariable(type, storage); Decorate(id, spv::Decoration::BuiltIn, static_cast<u32>(builtin)); - AddGlobalVariable(Name(id, name)); + AddGlobalVariable(Name(id, std::move(name))); interfaces.push_back(id); return id; } + Id DeclareInputBuiltIn(spv::BuiltIn builtin, Id type, std::string name) { + return DeclareBuiltIn(builtin, spv::StorageClass::Input, type, std::move(name)); + } + bool IsRenderTargetUsed(u32 rt) const { for (u32 component = 0; component < 4; ++component) { if (header.ps.IsColorComponentOutputEnabled(rt, component)) { @@ -1242,66 +1999,148 @@ private: members.push_back(Constant(t_uint, element)); } - return Emit(OpAccessChain(pointer_type, composite, members)); + return OpAccessChain(pointer_type, composite, members); } - template <Type type> - Id VisitOperand(Operation operation, std::size_t operand_index) { - const Id value = Visit(operation[operand_index]); - - switch (type) { + Id As(Expression expr, Type wanted_type) { + switch (wanted_type) { case Type::Bool: + return AsBool(expr); case Type::Bool2: + return AsBool2(expr); case Type::Float: - return value; + return AsFloat(expr); case Type::Int: - return Emit(OpBitcast(t_int, value)); + return AsInt(expr); case Type::Uint: - return Emit(OpBitcast(t_uint, value)); + return AsUint(expr); case Type::HalfFloat: - UNIMPLEMENTED(); + return AsHalfFloat(expr); + default: + UNREACHABLE(); + return expr.id; } - UNREACHABLE(); - return value; } - template <Type type> - Id BitcastFrom(Id value) { - switch (type) { - case Type::Bool: - case Type::Bool2: + Id AsBool(Expression expr) { + ASSERT(expr.type == Type::Bool); + return expr.id; + } + + Id AsBool2(Expression expr) { + ASSERT(expr.type == Type::Bool2); + return expr.id; + } + + Id AsFloat(Expression expr) { + switch (expr.type) { case Type::Float: - return value; + return expr.id; case Type::Int: case Type::Uint: - return Emit(OpBitcast(t_float, value)); + return OpBitcast(t_float, expr.id); case Type::HalfFloat: - UNIMPLEMENTED(); + if (device.IsFloat16Supported()) { + return OpBitcast(t_float, expr.id); + } + return OpBitcast(t_float, OpPackHalf2x16(t_uint, expr.id)); + default: + UNREACHABLE(); + return expr.id; } - UNREACHABLE(); - return value; } - template <Type type> - Id BitcastTo(Id value) { - switch (type) { - case Type::Bool: - case Type::Bool2: + Id AsInt(Expression expr) { + switch (expr.type) { + case Type::Int: + return expr.id; + case Type::Float: + case Type::Uint: + return OpBitcast(t_int, expr.id); + case Type::HalfFloat: + if (device.IsFloat16Supported()) { + return OpBitcast(t_int, expr.id); + } + return OpPackHalf2x16(t_int, expr.id); + default: UNREACHABLE(); + return expr.id; + } + } + + Id AsUint(Expression expr) { + switch (expr.type) { + case Type::Uint: + return expr.id; case Type::Float: - return Emit(OpBitcast(t_float, value)); case Type::Int: - return Emit(OpBitcast(t_int, value)); - case Type::Uint: - return Emit(OpBitcast(t_uint, value)); + return OpBitcast(t_uint, expr.id); case Type::HalfFloat: - UNIMPLEMENTED(); + if (device.IsFloat16Supported()) { + return OpBitcast(t_uint, expr.id); + } + return OpPackHalf2x16(t_uint, expr.id); + default: + UNREACHABLE(); + return expr.id; + } + } + + Id AsHalfFloat(Expression expr) { + switch (expr.type) { + case Type::HalfFloat: + return expr.id; + case Type::Float: + case Type::Int: + case Type::Uint: + if (device.IsFloat16Supported()) { + return OpBitcast(t_half, expr.id); + } + return OpUnpackHalf2x16(t_half, AsUint(expr)); + default: + UNREACHABLE(); + return expr.id; + } + } + + Id GetHalfScalarFromFloat(Id value) { + if (device.IsFloat16Supported()) { + return OpFConvert(t_scalar_half, value); } - UNREACHABLE(); return value; } - Id GetTypeDefinition(Type type) { + Id GetFloatFromHalfScalar(Id value) { + if (device.IsFloat16Supported()) { + return OpFConvert(t_float, value); + } + return value; + } + + AttributeType GetAttributeType(u32 location) const { + if (stage != ShaderType::Vertex) { + return {Type::Float, t_in_float, t_in_float4}; + } + switch (specialization.attribute_types.at(location)) { + case Maxwell::VertexAttribute::Type::SignedNorm: + case Maxwell::VertexAttribute::Type::UnsignedNorm: + case Maxwell::VertexAttribute::Type::Float: + return {Type::Float, t_in_float, t_in_float4}; + case Maxwell::VertexAttribute::Type::SignedInt: + return {Type::Int, t_in_int, t_in_int4}; + case Maxwell::VertexAttribute::Type::UnsignedInt: + return {Type::Uint, t_in_uint, t_in_uint4}; + case Maxwell::VertexAttribute::Type::UnsignedScaled: + case Maxwell::VertexAttribute::Type::SignedScaled: + UNIMPLEMENTED(); + return {Type::Float, t_in_float, t_in_float4}; + default: + UNREACHABLE(); + return {Type::Float, t_in_float, t_in_float4}; + } + } + + Id GetTypeDefinition(Type type) const { switch (type) { case Type::Bool: return t_bool; @@ -1314,10 +2153,25 @@ private: case Type::Uint: return t_uint; case Type::HalfFloat: + return t_half; + default: + UNREACHABLE(); + return {}; + } + } + + std::array<Id, 4> GetTypeVectorDefinitionLut(Type type) const { + switch (type) { + case Type::Float: + return {nullptr, t_float2, t_float3, t_float4}; + case Type::Int: + return {nullptr, t_int2, t_int3, t_int4}; + case Type::Uint: + return {nullptr, t_uint2, t_uint3, t_uint4}; + default: UNIMPLEMENTED(); + return {}; } - UNREACHABLE(); - return {}; } std::tuple<Id, Id> CreateFlowStack() { @@ -1327,9 +2181,11 @@ private: constexpr auto storage_class = spv::StorageClass::Function; const Id flow_stack_type = TypeArray(t_uint, Constant(t_uint, FLOW_STACK_SIZE)); - const Id stack = Emit(OpVariable(TypePointer(storage_class, flow_stack_type), storage_class, - ConstantNull(flow_stack_type))); - const Id top = Emit(OpVariable(t_func_uint, storage_class, Constant(t_uint, 0))); + const Id stack = OpVariable(TypePointer(storage_class, flow_stack_type), storage_class, + ConstantNull(flow_stack_type)); + const Id top = OpVariable(t_func_uint, storage_class, Constant(t_uint, 0)); + AddLocalVariable(stack); + AddLocalVariable(top); return std::tie(stack, top); } @@ -1358,8 +2214,8 @@ private: &SPIRVDecompiler::Unary<&Module::OpFNegate, Type::Float>, &SPIRVDecompiler::Unary<&Module::OpFAbs, Type::Float>, &SPIRVDecompiler::Ternary<&Module::OpFClamp, Type::Float>, - &SPIRVDecompiler::FCastHalf0, - &SPIRVDecompiler::FCastHalf1, + &SPIRVDecompiler::FCastHalf<0>, + &SPIRVDecompiler::FCastHalf<1>, &SPIRVDecompiler::Binary<&Module::OpFMin, Type::Float>, &SPIRVDecompiler::Binary<&Module::OpFMax, Type::Float>, &SPIRVDecompiler::Unary<&Module::OpCos, Type::Float>, @@ -1407,7 +2263,7 @@ private: &SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Uint, Type::Int>, &SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Uint>, &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>, - &SPIRVDecompiler::Binary<&Module::OpShiftRightArithmetic, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>, &SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Uint>, &SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Uint>, &SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Uint>, @@ -1426,8 +2282,8 @@ private: &SPIRVDecompiler::HCastFloat, &SPIRVDecompiler::HUnpack, &SPIRVDecompiler::HMergeF32, - &SPIRVDecompiler::HMergeH0, - &SPIRVDecompiler::HMergeH1, + &SPIRVDecompiler::HMergeHN<0>, + &SPIRVDecompiler::HMergeHN<1>, &SPIRVDecompiler::HPack2, &SPIRVDecompiler::LogicalAssign, @@ -1435,8 +2291,9 @@ private: &SPIRVDecompiler::Binary<&Module::OpLogicalOr, Type::Bool>, &SPIRVDecompiler::Binary<&Module::OpLogicalNotEqual, Type::Bool>, &SPIRVDecompiler::Unary<&Module::OpLogicalNot, Type::Bool>, - &SPIRVDecompiler::LogicalPick2, - &SPIRVDecompiler::LogicalAnd2, + &SPIRVDecompiler::Binary<&Module::OpVectorExtractDynamic, Type::Bool, Type::Bool2, + Type::Uint>, + &SPIRVDecompiler::Unary<&Module::OpAll, Type::Bool, Type::Bool2>, &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::Float>, &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::Float>, @@ -1444,7 +2301,7 @@ private: &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::Float>, &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::Float>, &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::Float>, - &SPIRVDecompiler::Unary<&Module::OpIsNan, Type::Bool>, + &SPIRVDecompiler::Unary<&Module::OpIsNan, Type::Bool, Type::Float>, &SPIRVDecompiler::Binary<&Module::OpSLessThan, Type::Bool, Type::Int>, &SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Int>, @@ -1460,19 +2317,19 @@ private: &SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Uint>, &SPIRVDecompiler::Binary<&Module::OpUGreaterThanEqual, Type::Bool, Type::Uint>, - &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::HalfFloat>, - &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::HalfFloat>, - &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::HalfFloat>, - &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::HalfFloat>, - &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::HalfFloat>, - &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool2, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool2, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool2, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool2, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool2, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool2, Type::HalfFloat>, // TODO(Rodrigo): Should these use the OpFUnord* variants? - &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::HalfFloat>, - &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::HalfFloat>, - &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::HalfFloat>, - &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::HalfFloat>, - &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::HalfFloat>, - &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool2, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool2, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool2, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool2, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool2, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool2, Type::HalfFloat>, &SPIRVDecompiler::Texture, &SPIRVDecompiler::TextureLod, @@ -1509,9 +2366,9 @@ private: &SPIRVDecompiler::WorkGroupId<2>, &SPIRVDecompiler::BallotThread, - &SPIRVDecompiler::VoteAll, - &SPIRVDecompiler::VoteAny, - &SPIRVDecompiler::VoteEqual, + &SPIRVDecompiler::Vote<&Module::OpSubgroupAllKHR>, + &SPIRVDecompiler::Vote<&Module::OpSubgroupAnyKHR>, + &SPIRVDecompiler::Vote<&Module::OpSubgroupAllEqualKHR>, &SPIRVDecompiler::ThreadId, &SPIRVDecompiler::ShuffleIndexed, @@ -1522,8 +2379,7 @@ private: const ShaderIR& ir; const ShaderType stage; const Tegra::Shader::Header header; - u64 conditional_nest_count{}; - u64 inside_branch{}; + const Specialization& specialization; const Id t_void = Name(TypeVoid(), "void"); @@ -1551,20 +2407,28 @@ private: const Id t_func_uint = Name(TypePointer(spv::StorageClass::Function, t_uint), "func_uint"); const Id t_in_bool = Name(TypePointer(spv::StorageClass::Input, t_bool), "in_bool"); + const Id t_in_int = Name(TypePointer(spv::StorageClass::Input, t_int), "in_int"); + const Id t_in_int4 = Name(TypePointer(spv::StorageClass::Input, t_int4), "in_int4"); const Id t_in_uint = Name(TypePointer(spv::StorageClass::Input, t_uint), "in_uint"); + const Id t_in_uint3 = Name(TypePointer(spv::StorageClass::Input, t_uint3), "in_uint3"); + const Id t_in_uint4 = Name(TypePointer(spv::StorageClass::Input, t_uint4), "in_uint4"); const Id t_in_float = Name(TypePointer(spv::StorageClass::Input, t_float), "in_float"); + const Id t_in_float2 = Name(TypePointer(spv::StorageClass::Input, t_float2), "in_float2"); + const Id t_in_float3 = Name(TypePointer(spv::StorageClass::Input, t_float3), "in_float3"); const Id t_in_float4 = Name(TypePointer(spv::StorageClass::Input, t_float4), "in_float4"); + const Id t_out_int = Name(TypePointer(spv::StorageClass::Output, t_int), "out_int"); + const Id t_out_float = Name(TypePointer(spv::StorageClass::Output, t_float), "out_float"); const Id t_out_float4 = Name(TypePointer(spv::StorageClass::Output, t_float4), "out_float4"); const Id t_cbuf_float = TypePointer(spv::StorageClass::Uniform, t_float); const Id t_cbuf_std140 = Decorate( - Name(TypeArray(t_float4, Constant(t_uint, MAX_CONSTBUFFER_ELEMENTS)), "CbufStd140Array"), - spv::Decoration::ArrayStride, 16u); + Name(TypeArray(t_float4, Constant(t_uint, MaxConstBufferElements)), "CbufStd140Array"), + spv::Decoration::ArrayStride, 16U); const Id t_cbuf_scalar = Decorate( - Name(TypeArray(t_float, Constant(t_uint, MAX_CONSTBUFFER_FLOATS)), "CbufScalarArray"), - spv::Decoration::ArrayStride, 4u); + Name(TypeArray(t_float, Constant(t_uint, MaxConstBufferFloats)), "CbufScalarArray"), + spv::Decoration::ArrayStride, 4U); const Id t_cbuf_std140_struct = MemberDecorate( Decorate(TypeStruct(t_cbuf_std140), spv::Decoration::Block), 0, spv::Decoration::Offset, 0); const Id t_cbuf_scalar_struct = MemberDecorate( @@ -1572,28 +2436,43 @@ private: const Id t_cbuf_std140_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_std140_struct); const Id t_cbuf_scalar_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_scalar_struct); + Id t_smem_uint{}; + const Id t_gmem_float = TypePointer(spv::StorageClass::StorageBuffer, t_float); const Id t_gmem_array = - Name(Decorate(TypeRuntimeArray(t_float), spv::Decoration::ArrayStride, 4u), "GmemArray"); + Name(Decorate(TypeRuntimeArray(t_float), spv::Decoration::ArrayStride, 4U), "GmemArray"); const Id t_gmem_struct = MemberDecorate( Decorate(TypeStruct(t_gmem_array), spv::Decoration::Block), 0, spv::Decoration::Offset, 0); const Id t_gmem_ssbo = TypePointer(spv::StorageClass::StorageBuffer, t_gmem_struct); const Id v_float_zero = Constant(t_float, 0.0f); + const Id v_float_one = Constant(t_float, 1.0f); + + // Nvidia uses these defaults for varyings (e.g. position and generic attributes) + const Id v_varying_default = + ConstantComposite(t_float4, v_float_zero, v_float_zero, v_float_zero, v_float_one); + const Id v_true = ConstantTrue(t_bool); const Id v_false = ConstantFalse(t_bool); - Id per_vertex{}; + Id t_scalar_half{}; + Id t_half{}; + + Id out_vertex{}; + Id in_vertex{}; std::map<u32, Id> registers; std::map<Tegra::Shader::Pred, Id> predicates; std::map<u32, Id> flow_variables; Id local_memory{}; + Id shared_memory{}; std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{}; std::map<Attribute::Index, Id> input_attributes; std::map<Attribute::Index, Id> output_attributes; std::map<u32, Id> constant_buffers; std::map<GlobalMemoryBase, Id> global_buffers; - std::map<u32, SamplerImage> sampler_images; + std::map<u32, TexelBuffer> texel_buffers; + std::map<u32, SampledImage> sampled_images; + std::map<u32, StorageImage> images; Id instance_index{}; Id vertex_index{}; @@ -1601,18 +2480,20 @@ private: Id frag_depth{}; Id frag_coord{}; Id front_facing{}; - - u32 position_index{}; - u32 point_size_index{}; - u32 clip_distances_index{}; + Id point_coord{}; + Id tess_level_outer{}; + Id tess_level_inner{}; + Id tess_coord{}; + Id invocation_id{}; + Id workgroup_id{}; + Id local_invocation_id{}; + Id thread_id{}; + + VertexIndices in_indices; + VertexIndices out_indices; std::vector<Id> interfaces; - u32 const_buffers_base_binding{}; - u32 global_buffers_base_binding{}; - u32 samplers_base_binding{}; - - Id execute_function{}; Id jmp_to{}; Id ssy_flow_stack_top{}; Id pbk_flow_stack_top{}; @@ -1620,6 +2501,9 @@ private: Id pbk_flow_stack{}; Id continue_label{}; std::map<u32, Id> labels; + + bool conditional_branch_set{}; + bool inside_branch{}; }; class ExprDecompiler { @@ -1630,25 +2514,25 @@ public: const Id type_def = decomp.GetTypeDefinition(Type::Bool); const Id op1 = Visit(expr.operand1); const Id op2 = Visit(expr.operand2); - return decomp.Emit(decomp.OpLogicalAnd(type_def, op1, op2)); + return decomp.OpLogicalAnd(type_def, op1, op2); } Id operator()(const ExprOr& expr) { const Id type_def = decomp.GetTypeDefinition(Type::Bool); const Id op1 = Visit(expr.operand1); const Id op2 = Visit(expr.operand2); - return decomp.Emit(decomp.OpLogicalOr(type_def, op1, op2)); + return decomp.OpLogicalOr(type_def, op1, op2); } Id operator()(const ExprNot& expr) { const Id type_def = decomp.GetTypeDefinition(Type::Bool); const Id op1 = Visit(expr.operand1); - return decomp.Emit(decomp.OpLogicalNot(type_def, op1)); + return decomp.OpLogicalNot(type_def, op1); } Id operator()(const ExprPredicate& expr) { const auto pred = static_cast<Tegra::Shader::Pred>(expr.predicate); - return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred))); + return decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred)); } Id operator()(const ExprCondCode& expr) { @@ -1670,12 +2554,15 @@ public: } } else if (const auto flag = std::get_if<InternalFlagNode>(&*cc)) { target = decomp.internal_flags.at(static_cast<u32>(flag->GetFlag())); + } else { + UNREACHABLE(); } - return decomp.Emit(decomp.OpLoad(decomp.t_bool, target)); + + return decomp.OpLoad(decomp.t_bool, target); } Id operator()(const ExprVar& expr) { - return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index))); + return decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index)); } Id operator()(const ExprBoolean& expr) { @@ -1684,9 +2571,9 @@ public: Id operator()(const ExprGprEqual& expr) { const Id target = decomp.Constant(decomp.t_uint, expr.value); - const Id gpr = decomp.BitcastTo<Type::Uint>( - decomp.Emit(decomp.OpLoad(decomp.t_float, decomp.registers.at(expr.gpr)))); - return decomp.Emit(decomp.OpLogicalEqual(decomp.t_uint, gpr, target)); + Id gpr = decomp.OpLoad(decomp.t_float, decomp.registers.at(expr.gpr)); + gpr = decomp.OpBitcast(decomp.t_uint, gpr); + return decomp.OpLogicalEqual(decomp.t_uint, gpr, target); } Id Visit(const Expr& node) { @@ -1714,16 +2601,16 @@ public: const Id condition = expr_parser.Visit(ast.condition); const Id then_label = decomp.OpLabel(); const Id endif_label = decomp.OpLabel(); - decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone)); - decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label)); - decomp.Emit(then_label); + decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone); + decomp.OpBranchConditional(condition, then_label, endif_label); + decomp.AddLabel(then_label); ASTNode current = ast.nodes.GetFirst(); while (current) { Visit(current); current = current->GetNext(); } - decomp.Emit(decomp.OpBranch(endif_label)); - decomp.Emit(endif_label); + decomp.OpBranch(endif_label); + decomp.AddLabel(endif_label); } void operator()([[maybe_unused]] const ASTIfElse& ast) { @@ -1741,7 +2628,7 @@ public: void operator()(const ASTVarSet& ast) { ExprDecompiler expr_parser{decomp}; const Id condition = expr_parser.Visit(ast.condition); - decomp.Emit(decomp.OpStore(decomp.flow_variables.at(ast.index), condition)); + decomp.OpStore(decomp.flow_variables.at(ast.index), condition); } void operator()([[maybe_unused]] const ASTLabel& ast) { @@ -1758,12 +2645,11 @@ public: const Id loop_start_block = decomp.OpLabel(); const Id loop_end_block = decomp.OpLabel(); current_loop_exit = endloop_label; - decomp.Emit(decomp.OpBranch(loop_label)); - decomp.Emit(loop_label); - decomp.Emit( - decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone)); - decomp.Emit(decomp.OpBranch(loop_start_block)); - decomp.Emit(loop_start_block); + decomp.OpBranch(loop_label); + decomp.AddLabel(loop_label); + decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone); + decomp.OpBranch(loop_start_block); + decomp.AddLabel(loop_start_block); ASTNode current = ast.nodes.GetFirst(); while (current) { Visit(current); @@ -1771,8 +2657,8 @@ public: } ExprDecompiler expr_parser{decomp}; const Id condition = expr_parser.Visit(ast.condition); - decomp.Emit(decomp.OpBranchConditional(condition, loop_label, endloop_label)); - decomp.Emit(endloop_label); + decomp.OpBranchConditional(condition, loop_label, endloop_label); + decomp.AddLabel(endloop_label); } void operator()(const ASTReturn& ast) { @@ -1781,27 +2667,27 @@ public: const Id condition = expr_parser.Visit(ast.condition); const Id then_label = decomp.OpLabel(); const Id endif_label = decomp.OpLabel(); - decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone)); - decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label)); - decomp.Emit(then_label); + decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone); + decomp.OpBranchConditional(condition, then_label, endif_label); + decomp.AddLabel(then_label); if (ast.kills) { - decomp.Emit(decomp.OpKill()); + decomp.OpKill(); } else { decomp.PreExit(); - decomp.Emit(decomp.OpReturn()); + decomp.OpReturn(); } - decomp.Emit(endif_label); + decomp.AddLabel(endif_label); } else { const Id next_block = decomp.OpLabel(); - decomp.Emit(decomp.OpBranch(next_block)); - decomp.Emit(next_block); + decomp.OpBranch(next_block); + decomp.AddLabel(next_block); if (ast.kills) { - decomp.Emit(decomp.OpKill()); + decomp.OpKill(); } else { decomp.PreExit(); - decomp.Emit(decomp.OpReturn()); + decomp.OpReturn(); } - decomp.Emit(decomp.OpLabel()); + decomp.AddLabel(decomp.OpLabel()); } } @@ -1811,17 +2697,17 @@ public: const Id condition = expr_parser.Visit(ast.condition); const Id then_label = decomp.OpLabel(); const Id endif_label = decomp.OpLabel(); - decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone)); - decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label)); - decomp.Emit(then_label); - decomp.Emit(decomp.OpBranch(current_loop_exit)); - decomp.Emit(endif_label); + decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone); + decomp.OpBranchConditional(condition, then_label, endif_label); + decomp.AddLabel(then_label); + decomp.OpBranch(current_loop_exit); + decomp.AddLabel(endif_label); } else { const Id next_block = decomp.OpLabel(); - decomp.Emit(decomp.OpBranch(next_block)); - decomp.Emit(next_block); - decomp.Emit(decomp.OpBranch(current_loop_exit)); - decomp.Emit(decomp.OpLabel()); + decomp.OpBranch(next_block); + decomp.AddLabel(next_block); + decomp.OpBranch(current_loop_exit); + decomp.AddLabel(decomp.OpLabel()); } } @@ -1842,20 +2728,51 @@ void SPIRVDecompiler::DecompileAST() { flow_variables.emplace(i, AddGlobalVariable(id)); } + DefinePrologue(); + const ASTNode program = ir.GetASTProgram(); ASTDecompiler decompiler{*this}; decompiler.Visit(program); const Id next_block = OpLabel(); - Emit(OpBranch(next_block)); - Emit(next_block); + OpBranch(next_block); + AddLabel(next_block); +} + +} // Anonymous namespace + +ShaderEntries GenerateShaderEntries(const VideoCommon::Shader::ShaderIR& ir) { + ShaderEntries entries; + for (const auto& cbuf : ir.GetConstantBuffers()) { + entries.const_buffers.emplace_back(cbuf.second, cbuf.first); + } + for (const auto& [base, usage] : ir.GetGlobalMemory()) { + entries.global_buffers.emplace_back(base.cbuf_index, base.cbuf_offset, usage.is_written); + } + for (const auto& sampler : ir.GetSamplers()) { + if (sampler.IsBuffer()) { + entries.texel_buffers.emplace_back(sampler); + } else { + entries.samplers.emplace_back(sampler); + } + } + for (const auto& image : ir.GetImages()) { + entries.images.emplace_back(image); + } + for (const auto& attribute : ir.GetInputAttributes()) { + if (IsGenericAttribute(attribute)) { + entries.attributes.insert(GetGenericAttributeLocation(attribute)); + } + } + entries.clip_distances = ir.GetClipDistances(); + entries.shader_length = ir.GetLength(); + entries.uses_warps = ir.UsesWarps(); + return entries; } -DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir, - ShaderType stage) { - auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage); - decompiler->Decompile(); - return {std::move(decompiler), decompiler->GetShaderEntries()}; +std::vector<u32> Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir, + ShaderType stage, const Specialization& specialization) { + return SPIRVDecompiler(device, ir, stage, specialization).Assemble(); } -} // namespace Vulkan::VKShader +} // namespace Vulkan diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.h b/src/video_core/renderer_vulkan/vk_shader_decompiler.h index 203fc00d0..2b01321b6 100644 --- a/src/video_core/renderer_vulkan/vk_shader_decompiler.h +++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.h @@ -5,29 +5,28 @@ #pragma once #include <array> +#include <bitset> #include <memory> #include <set> +#include <type_traits> #include <utility> #include <vector> -#include <sirit/sirit.h> - #include "common/common_types.h" #include "video_core/engines/maxwell_3d.h" +#include "video_core/engines/shader_type.h" #include "video_core/shader/shader_ir.h" -namespace VideoCommon::Shader { -class ShaderIR; -} - namespace Vulkan { class VKDevice; } -namespace Vulkan::VKShader { +namespace Vulkan { using Maxwell = Tegra::Engines::Maxwell3D::Regs; +using TexelBufferEntry = VideoCommon::Shader::Sampler; using SamplerEntry = VideoCommon::Shader::Sampler; +using ImageEntry = VideoCommon::Shader::Image; constexpr u32 DESCRIPTOR_SET = 0; @@ -46,39 +45,74 @@ private: class GlobalBufferEntry { public: - explicit GlobalBufferEntry(u32 cbuf_index, u32 cbuf_offset) - : cbuf_index{cbuf_index}, cbuf_offset{cbuf_offset} {} + constexpr explicit GlobalBufferEntry(u32 cbuf_index, u32 cbuf_offset, bool is_written) + : cbuf_index{cbuf_index}, cbuf_offset{cbuf_offset}, is_written{is_written} {} - u32 GetCbufIndex() const { + constexpr u32 GetCbufIndex() const { return cbuf_index; } - u32 GetCbufOffset() const { + constexpr u32 GetCbufOffset() const { return cbuf_offset; } + constexpr bool IsWritten() const { + return is_written; + } + private: u32 cbuf_index{}; u32 cbuf_offset{}; + bool is_written{}; }; struct ShaderEntries { - u32 const_buffers_base_binding{}; - u32 global_buffers_base_binding{}; - u32 samplers_base_binding{}; + u32 NumBindings() const { + return static_cast<u32>(const_buffers.size() + global_buffers.size() + + texel_buffers.size() + samplers.size() + images.size()); + } + std::vector<ConstBufferEntry> const_buffers; std::vector<GlobalBufferEntry> global_buffers; + std::vector<TexelBufferEntry> texel_buffers; std::vector<SamplerEntry> samplers; + std::vector<ImageEntry> images; std::set<u32> attributes; std::array<bool, Maxwell::NumClipDistances> clip_distances{}; std::size_t shader_length{}; - Sirit::Id entry_function{}; - std::vector<Sirit::Id> interfaces; + bool uses_warps{}; +}; + +struct Specialization final { + u32 base_binding{}; + + // Compute specific + std::array<u32, 3> workgroup_size{}; + u32 shared_memory_size{}; + + // Graphics specific + Maxwell::PrimitiveTopology primitive_topology{}; + std::optional<float> point_size{}; + std::array<Maxwell::VertexAttribute::Type, Maxwell::NumVertexAttributes> attribute_types{}; + + // Tessellation specific + struct { + Maxwell::TessellationPrimitive primitive{}; + Maxwell::TessellationSpacing spacing{}; + bool clockwise{}; + } tessellation; +}; +// Old gcc versions don't consider this trivially copyable. +// static_assert(std::is_trivially_copyable_v<Specialization>); + +struct SPIRVShader { + std::vector<u32> code; + ShaderEntries entries; }; -using DecompilerResult = std::pair<std::unique_ptr<Sirit::Module>, ShaderEntries>; +ShaderEntries GenerateShaderEntries(const VideoCommon::Shader::ShaderIR& ir); -DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir, - Tegra::Engines::ShaderType stage); +std::vector<u32> Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir, + Tegra::Engines::ShaderType stage, const Specialization& specialization); -} // namespace Vulkan::VKShader +} // namespace Vulkan |