diff --git a/src/utils/shader_decompiler_dxc.hpp b/src/utils/shader_decompiler_dxc.hpp index 405cc695..e54f3131 100644 --- a/src/utils/shader_decompiler_dxc.hpp +++ b/src/utils/shader_decompiler_dxc.hpp @@ -50,6 +50,7 @@ enum class TokenizerState : uint32_t { DESCRIPTION_OUTPUT_SIG_TABLE_END, DESCRIPTION_SHADER_DEBUG_NAME, DESCRIPTION_SHADER_HASH, + DESCRIPTION_FUNCTIONALITY_NOTE, DESCRIPTION_PIPELINE_RUNTIME_TITLE, DESCRIPTION_PIPELINE_RUNTIME_WHITESPACE, DESCRIPTION_PIPELINE_RUNTIME_INFO, @@ -238,7 +239,9 @@ class Decompiler { FromStringView(input, value); } output = std::format("{}", value); - if (output.find('.') == std::string::npos) { + bool has_dot = output.find('.') != std::string::npos; + bool has_plus = output.find('+') != std::string::npos; + if (!has_dot && !has_plus) { output += ".0f"; } else { output += "f"; @@ -310,8 +313,8 @@ class Decompiler { {"27", "floor"}, {"28", "ceil"}, {"29", "trunc"}, - {"83", "ddx_course"}, - {"84", "ddy_course"}, + {"83", "ddx_coarse"}, + {"84", "ddy_coarse"}, {"85", "ddx_fine"}, {"86", "ddy_fine"}, }; @@ -454,6 +457,7 @@ class Decompiler { NOINTERPOLATION, NOPERSPECTIVE, LINEAR, + CENTROID, } interp_mode; int32_t dyn_index = -1; @@ -472,6 +476,7 @@ class Decompiler { if (input == "nointerpolation") return InterpMode::NOINTERPOLATION; if (input == "noperspective") return InterpMode::NOPERSPECTIVE; if (input == "linear") return InterpMode::LINEAR; + if (input == "centroid") return InterpMode::CENTROID; throw std::invalid_argument("Unknown InterpMode"); } @@ -886,7 +891,9 @@ class Decompiler { : Resource(metadata, resource_descriptions, "t") { static auto pointer_regex = std::regex{R"(^(?:\[(\S+) x )?%"class\.([^<]+)<(?:vector<)?([^,>]+)(?:, ([^>]+)>)? ?>"\]?\*)"}; const auto [array_size, class_name, base_type, type_count] = StringViewMatch<4>(this->pointer, pointer_regex); - this->data_type = std::format("{}{}", base_type, type_count); + + auto base_type_fixed = DataType::FixBaseType(base_type); + this->data_type = std::format("{}{}", base_type_fixed, type_count); // https://github.com/microsoft/DirectXShaderCompiler/blob/b766b432678cf5f7a93567d253bb5f7fd8a0b2c7/docs/DXIL.rst#L1047 uint32_t shape; @@ -972,6 +979,12 @@ class Decompiler { size_t vector_size; std::string data_type; + static std::string FixBaseType(std::string_view data_type) { + if (data_type == "i32") return "int"; + if (data_type == "unsigned int") return "uint"; + return std::string(data_type); + } + explicit DataType(std::string_view line) { static auto regex = std::regex{R"(^(?:\[(\S+) x )?(?:<(\S+) x )?(\w+)(\*)?>?\]?$)"}; const auto [array_size, vector_size, data_type, is_pointer] = StringViewMatch<4>(line, regex); @@ -985,11 +998,7 @@ class Decompiler { } else { FromStringView(vector_size, this->vector_size); } - if (data_type == "i32") { - this->data_type = "int"; - } else { - this->data_type = data_type; - } + this->data_type = FixBaseType(data_type); } }; @@ -1609,7 +1618,55 @@ class Decompiler { auto sampler_resource = preprocess_state.sampler_resources[sampler_range_index]; decompiled = std::format("{} _{} = {}.SampleCmpLevelZero({}, {}, {});", srv_resource.data_type, variable, srv_name, sampler_name, coords, ParseFloat(compareValue)); + } else if (functionName == "@dx.op.textureGather.i32" || functionName == "@dx.op.textureGather.f32") { + // %67 = call %dx.types.ResRet.i32 @dx.op.textureGather.i32(i32 73, %dx.types.Handle %3, %dx.types.Handle %11, float %53, float %54, float undef, float undef, i32 0, i32 0, i32 0) ; TextureGather(srv,sampler,coord0,coord1,coord2,coord3,offset0,offset1,channel) + auto [op, srv, sampler, coord0, coord1, coord2, coord3, offset0, offset1, channel] = StringViewSplit<10>(functionParamsString, param_regex, 2); + auto ref_resource = std::string{srv.substr(1)}; + auto ref_sampler = std::string{sampler.substr(1)}; + const bool has_coord_z = coord2 != "undef"; + const bool has_coord_w = coord3 != "undef"; + const bool has_offset_y = offset1 != "undef"; + std::string coords; + if (has_coord_w) { + coords = std::format("float4({}, {}, {}, {})", ParseFloat(coord0), ParseFloat(coord1), ParseFloat(coord2), ParseFloat(coord3)); + } else if (has_coord_z) { + coords = std::format("float3({}, {}, {})", ParseFloat(coord0), ParseFloat(coord1), ParseFloat(coord2)); + } else { + coords = std::format("float2({}, {})", ParseFloat(coord0), ParseFloat(coord1)); + } + std::string offset; + if (has_offset_y) { + offset = std::format("int2({}, {})", ParseInt(offset0), ParseInt(offset1)); + } else { + offset = std::format("{}", ParseInt(offset0)); + } + + auto [srv_name, srv_range_index] = preprocess_state.resource_binding_variables.at(ref_resource); + auto srv_resource = preprocess_state.srv_resources[srv_range_index]; + auto [sampler_name, sampler_range_index] = preprocess_state.resource_binding_variables.at(ref_sampler); + auto sampler_resource = preprocess_state.sampler_resources[sampler_range_index]; + std::string channel_string; + if (channel == "0") { + channel_string = "Red"; + } else if (channel == "1") { + channel_string = "Green"; + } else if (channel == "2") { + channel_string = "Blue"; + } else if (channel == "3") { + channel_string = "Alpha"; + } else { + throw std::exception("Unknown Gather channel."); + } + if (functionName == "@dx.op.textureGather.i32") { + decompiled = std::format("int4 _{} = {}.Gather{}({}, {});", variable, srv_name, channel_string, sampler_name, coords); + } else { + decompiled = std::format("float4 _{} = {}.Gather{}({}, {});", variable, srv_name, channel_string, sampler_name, coords); + } + } else if (functionName == "@dx.op.waveReadLaneFirst.f32") { + // %504 = call float @dx.op.waveReadLaneFirst.f32(i32 118, float %503) ; WaveReadLaneFirst(value) + auto [op, value] = StringViewSplit<2>(functionParamsString, param_regex, 2); + decompiled = std::format("float _{} = WaveReadLaneFirst({});", variable, ParseFloat(value)); } else { std::cerr << line << "\n"; std::cerr << "Function name: " << functionName << "\n"; @@ -1957,7 +2014,7 @@ class Decompiler { } else if (functionName == "@dx.op.discard") { // call void @dx.op.discard(i32 82, i1 true) ; Discard(condition) auto [opNumber, condition] = StringViewSplit<2>(functionParamsString, param_regex, 2); - decompiled = std::format("discard({})", ParseBool(condition)); + decompiled = std::format("if ({}) discard;", ParseBool(condition)); } else if (functionName == "@dx.op.barrier") { // @dx.op.barrier(i32 80, i32 9) ; Barrier(barrierMode) auto [opNumber, barrierMode] = StringViewSplit<2>(functionParamsString, param_regex, 2); @@ -2250,6 +2307,8 @@ class Decompiler { state = TokenizerState::DESCRIPTION_RESOURCE_BINDINGS_TITLE; } else if (line == "; ViewId state:") { state = TokenizerState::DESCRIPTION_VIEW_ID_STATE_TITLE; + } else if (line == "; Note: shader requires additional functionality:") { + state = TokenizerState::DESCRIPTION_FUNCTIONALITY_NOTE; } else { throw std::invalid_argument("Unexpected description entry"); } @@ -2261,6 +2320,12 @@ class Decompiler { state++; line_number++; break; + case TokenizerState::DESCRIPTION_FUNCTIONALITY_NOTE: + if (line == ";") { + state = TokenizerState::DESCRIPTION_WHITESPACE; + } + line_number++; + break; case TokenizerState::DESCRIPTION_INPUT_SIG_TABLE_ROW: if (line == ";") { state = TokenizerState::DESCRIPTION_WHITESPACE;