#include "core.hlsl"
#include "scene.h"
#include "radiance_caching.h"
#include "common.hlsl"

// Vulkan specialization constants (fixed at pipeline creation).
// NOTE(review): not referenced by main() in this file; presumably consumed by an included header — verify.
[[vk::constant_id(0)]] const int mask_mode = 0;
// NOTE(review): not referenced by main() in this file; presumably consumed by an included header — verify.
[[vk::constant_id(1)]] const int hybrid_cache = 0;
// When non-zero, main() also copies the diffuse cache (g_src_cache_diffuse -> g_dst_cache_diffuse).
[[vk::constant_id(2)]] const int sc_has_radiance = 1;

// Radiance-caching parameters; main() reads display_res, screen_cache_res and screen_space_num_layers.
cbuffer g_lubo { radiance_caching_ubo_t g_lubo; };
// Per destination-cache texel: packed index of its reprojection source, decoded with unpack_cache_idx().
Texture2DArray<uint> g_repr_source;
// Per-texel sample masks: occupancy bits (get_occupancy_mask) and, as main() writes them,
// visibility bits at offset num_slots. Source mask is read-only; destination mask is read and rewritten.
Texture2DArray<uint> g_src_cache_mask;
RWTexture2DArray<uint> g_dst_cache_mask;
// Cached "fh" payload, copied source -> destination for every occupied slot.
// NOTE(review): meaning of "fh" is not visible in this file.
Texture2DArray<float4> g_src_cache_fh;
RWTexture2DArray<float4> g_dst_cache_fh;
// Cached diffuse radiance, copied source -> destination only when sc_has_radiance is set.
Texture2DArray<float4> g_src_cache_diffuse;
RWTexture2DArray<float4> g_dst_cache_diffuse;

#include "bindings.hlsl"
#include "gltf.hlsl"
#include "scene.hlsl"
#include "sample_mask.hlsl"
#include "sample_cache.hlsl"

// clang-format off
/// Resolves reprojection for one destination-cache texel.
///
/// For every occupied slot z in the destination mask at this texel, the slot's
/// reprojection source is decoded from g_repr_source and used to:
///   1. copy the cached payload (g_src_cache_fh, plus g_src_cache_diffuse when
///      sc_has_radiance is set) into the destination cache, and
///   2. rebuild the slot's visibility bit (bit z + num_slots) from the *source*
///      sample's visibility in g_src_cache_mask.
/// The destination mask is rewritten as its occupancy bits plus the rebuilt
/// visibility bits; visibility bits it held before are discarded.
[numthreads(8, 8, 1)]
void main(
    uint3 gId : SV_DispatchThreadID
)
{
    // clang-format on
    const uint3 img_res = g_lubo.display_res;
    const uint num_slots = g_lubo.screen_space_num_layers;
    dbg = all(gId == ((img_res / 2)));

    if (any(gId >= img_res))
        return;

    const uint dst_mask_val = g_dst_cache_mask[gId];
    const uint occupancy = get_occupancy_mask(dst_mask_val);

    // Keep only occupancy bits; visibility bits are recomputed below because they
    // may have been invalidated during reprojection.
    uint new_mask_val = occupancy;

    // Visit each set bit of the occupancy mask once (bits &= bits - 1 clears the lowest set bit).
    // NOTE(review): destination cache layers are addressed by the slot index z alone,
    // which assumes a single mask layer (gId.z == 0) — confirm against the dispatch.
    for (uint remaining = occupancy; remaining != 0u; remaining &= remaining - 1u)
    {
        const uint z = firstbitlow(remaining);
        const uint3 dst_idx = uint3(gId.xy, z);
        const uint3 cache_src_idx = unpack_cache_idx(g_repr_source[dst_idx]);

        g_dst_cache_fh[dst_idx] = g_src_cache_fh[cache_src_idx];
        if (sc_has_radiance)
            g_dst_cache_diffuse[dst_idx] = g_src_cache_diffuse[cache_src_idx];

        // Source cache layers are laid out as mask_layer * num_slots + slot, so the
        // source sample's visibility must be tested at ITS slot, not the destination slot z.
        const uint src_mask_layer = cache_src_idx.z / num_slots;
        const uint src_slot = cache_src_idx.z % num_slots;
        const uint src_mask_val = g_src_cache_mask[uint3(cache_src_idx.xy, src_mask_layer)];
        if (is_sample_visible(src_mask_val, src_slot, num_slots))
            new_mask_val |= (1u << (z + num_slots));
    }
    g_dst_cache_mask[gId] = new_mask_val;
}
