diff --git a/src/video_core/host_shaders/astc_decoder.comp b/src/video_core/host_shaders/astc_decoder.comp index 71a284e206..e62fc677ec 100644 --- a/src/video_core/host_shaders/astc_decoder.comp +++ b/src/video_core/host_shaders/astc_decoder.comp @@ -118,7 +118,7 @@ EncodingData CreateEncodingData(uint encoding, uint num_bits, uint bit_val, uint void ResultEmplaceBack(EncodingData val) { if (result_index >= result_vector_max_index) { - // Alert callers to avoid decoding more than needed by this phase + // Alert callers to avoid decoding more than needed by this phase result_limit_reached = true; return; } @@ -138,19 +138,18 @@ uint ReplicateBitTo9(uint value) { return value * 511; } +// Assumes num_bits < to_bit, num_bits and to_bit != 0 uint ReplicateBits(uint value, uint num_bits, uint to_bit) { - const uint v = value & uint((1 << num_bits) - 1); - uint res = v; - uint reslen = num_bits; - while (reslen < to_bit) { - const uint num_dst_bits_to_shift_up = min(num_bits, to_bit - reslen); - const uint num_src_bits_to_shift_down = num_bits - num_dst_bits_to_shift_up; - - res <<= num_dst_bits_to_shift_up; - res |= (v >> num_src_bits_to_shift_down); - reslen += num_bits; - } - return res; + const uint repl = value & ((1 << num_bits) - 1); + uint v = repl; + v |= v << (num_bits << 0); // [ xxxx xxxr ] + v |= v << (num_bits << 1); // [ xxxx xxrr ] + v |= v << (num_bits << 2); // [ xxxx rrrr ] + v |= v << (num_bits << 3); // [ rrrr rrrr ] + const uint shift = (to_bit % num_bits); + v <<= shift; + v |= repl >> (num_bits - shift); + return v & ((1 << to_bit) - 1); } uint FastReplicateTo8(uint value, uint num_bits) { @@ -191,72 +190,32 @@ uint Hash52(uint p) { return p; } -uint Select2DPartition(uint seed, uint x, uint y, uint partition_count) { - if ((block_dims.y * block_dims.x) < 32) { - x <<= 1; - y <<= 1; - } - +uint Select2DPartition(uint seed, uvec2 pos, uint partition_count) { + pos <<= uint(block_dims.y * block_dims.x < 32); seed += (partition_count - 1) * 1024; - const uint rnum = Hash52(uint(seed)); - uint seed1 = uint(rnum & 0xF); - uint seed2 = uint((rnum >> 4) & 0xF); - uint seed3 = uint((rnum >> 8) & 0xF); - uint seed4 = uint((rnum >> 12) & 0xF); - uint seed5 = uint((rnum >> 16) & 0xF); - uint seed6 = uint((rnum >> 20) & 0xF); - uint seed7 = uint((rnum >> 24) & 0xF); - uint seed8 = uint((rnum >> 28) & 0xF); - - seed1 = (seed1 * seed1); - seed2 = (seed2 * seed2); - seed3 = (seed3 * seed3); - seed4 = (seed4 * seed4); - seed5 = (seed5 * seed5); - seed6 = (seed6 * seed6); - seed7 = (seed7 * seed7); - seed8 = (seed8 * seed8); - - uint sh1, sh2; - if ((seed & 1) > 0) { - sh1 = (seed & 2) > 0 ? 4 : 5; - sh2 = (partition_count == 3) ? 6 : 5; - } else { - sh1 = (partition_count == 3) ? 6 : 5; - sh2 = (seed & 2) > 0 ? 4 : 5; - } - seed1 >>= sh1; - seed2 >>= sh2; - seed3 >>= sh1; - seed4 >>= sh2; - seed5 >>= sh1; - seed6 >>= sh2; - seed7 >>= sh1; - seed8 >>= sh2; - - uint a = seed1 * x + seed2 * y + (rnum >> 14); - uint b = seed3 * x + seed4 * y + (rnum >> 10); - uint c = seed5 * x + seed6 * y + (rnum >> 6); - uint d = seed7 * x + seed8 * y + (rnum >> 2); - - a &= 0x3F; - b &= 0x3F; - c &= 0x3F; - d &= 0x3F; - - if (partition_count < 4) { - d = 0; - } - if (partition_count < 3) { - c = 0; - } - - if (a >= b && a >= c && a >= d) { + uvec2 shift = uvec2( + (seed & 2) > 0 ? 4 : 5, + (partition_count == 3) ? 6 : 5 + ); + shift.xy = (seed & 1) > 0 ? shift.xy : shift.yx; + uvec4 rseed[2] = uvec4[]( + (uvec4(rnum) >> uvec4(0, 4, 8, 12)) & 0xf, + (uvec4(rnum) >> uvec4(16, 20, 24, 28)) & 0xf + ); + rseed[0] = (rseed[0] * rseed[0]) >> shift.xyxy; + rseed[1] = (rseed[1] * rseed[1]) >> shift.xyxy; + const uvec4 rnum_vec = uvec4(rnum) >> uvec4(14, 10, 6, 2); + const uvec4 result_mask = ((uvec4(0, 1, 2, 3) - partition_count) >> 8) & 0x3f; + uvec4 result = uvec4( + rseed[0].xz * pos.xx + rseed[0].yw * pos.yy + rnum_vec.xy, + rseed[1].xz * pos.xx + rseed[1].yw * pos.yy + rnum_vec.zw + ) & result_mask; + if (result.x >= result.y && result.x >= result.z && result.x >= result.w) { return 0; - } else if (b >= c && b >= d) { + } else if (result.y >= result.z && result.y >= result.w) { return 1; - } else if (c >= d) { + } else if (result.z >= result.w) { return 2; } else { return 3; @@ -459,11 +418,11 @@ void DecodeColorValues(uvec4 modes, uint num_partitions, uint color_data_bits, o } DecodeIntegerSequence(range - 1, num_values); uint out_index = 0; - for (int itr = 0; itr < result_index; ++itr) { + for (int i = 0; i < result_index; ++i) { if (out_index >= num_values) { break; } - const EncodingData val = GetEncodingFromVector(itr); + const EncodingData val = GetEncodingFromVector(i); const uint encoding = Encoding(val); const uint bitlen = NumBits(val); const uint bitval = BitValue(val); @@ -746,9 +705,8 @@ void UnquantizeTexelWeights(uvec2 size, bool is_dual_plane) { const uint num_planes = is_dual_plane ? 2 : 1; const uint area = size.x * size.y; const uint loop_count = min(result_index, area * num_planes); - for (uint itr = 0; itr < loop_count; ++itr) { - result_vector[itr] = - UnquantizeTexelWeight(GetEncodingFromVector(itr)); + for (uint i = 0; i < loop_count; ++i) { + result_vector[i] = UnquantizeTexelWeight(GetEncodingFromVector(i)); } } @@ -1055,7 +1013,7 @@ void DecompressBlock(ivec3 coord) { for (uint i = 0; i < block_dims.x; i++) { uint local_partition = 0; if (num_partitions > 1) { - local_partition = Select2DPartition(partition_index, i, j, num_partitions); + local_partition = Select2DPartition(partition_index, uvec2(i, j), num_partitions); } const uvec4 C0 = ReplicateByteTo16(endpoints0[local_partition]); const uvec4 C1 = ReplicateByteTo16(endpoints1[local_partition]);