#include "core.hlsl"
#include "scene.h"
#include "radiance_caching.h"

// Per-dispatch parameters. main() reads display_res, screen_cache_res, sample_idx,
// the per-layer cameras (cam[]), screen_space_num_layers and screen_M from it.
cbuffer g_lubo { radiance_caching_ubo_t g_lubo; };
// Source screen cache: per-pixel sample mask (occupancy/visibility bits, decoded via
// sample_mask.hlsl helpers), first-hit entries, and diffuse radiance whose .a channel
// feeds the adaptive-shading importance.
RWTexture2DArray<uint> g_src_cache_mask;
RWTexture2DArray<float4> g_src_cache_fh;
RWTexture2DArray<float4> g_src_cache_diffuse;
// Destination screen cache: slots are reserved in the mask by try_reserve_slot(), then
// first-hit / diffuse entries are copied in (single-pass reprojection only).
RWTexture2DArray<uint> g_dst_cache_mask;
RWTexture2DArray<float4> g_dst_cache_fh;
RWTexture2DArray<float4> g_dst_cache_diffuse;
// Two-pass reprojection: packed source cache index for each reserved destination slot.
RWTexture2DArray<uint> g_repr_indices;
// Adaptive sampling: per-pixel atomic min of the depth bits, and atomic max of a packed
// 64-bit value {depth bits (high 32), depth-slope bits (low 32)}.
RWTexture2DArray<int> g_depthmin;
RWTexture2DArray<int64_t> g_depthmax;
// Adaptive shading: rotating (frame_idx % 3) importance histograms plus bin range.
RWStructuredBuffer<shading_statistics_t> g_shading_statistics;

#include "bindings.hlsl"
#include "gltf.hlsl"
#include "scene.hlsl"

// Specialization constants selecting the shader variant.
// mask_mode: 0 -> the per-pixel slot count is screen_space_num_layers, otherwise screen_M * screen_M.
[[vk::constant_id(0)]] const int mask_mode = 0;
// hybrid_cache: not referenced directly in main(); presumably consumed by the included
// sampling helpers — verify before removing.
[[vk::constant_id(1)]] const int hybrid_cache = 0;
// sc_has_radiance: nonzero -> the diffuse radiance is copied along with the first-hit entry.
[[vk::constant_id(2)]] const int sc_has_radiance = 1;
// two_pass_reprojection: nonzero -> only packed source indices are written (g_repr_indices)
// instead of copying cache entries directly.
[[vk::constant_id(3)]] const int two_pass_reprojection = 0;
// sampling_mode: ADAPTIVE_SAMPLING enables per-pixel depth min/max accumulation.
[[vk::constant_id(4)]] const int sampling_mode = 0;
// shading_mode: ADAPTIVE_SHADING enables importance histogram / range statistics.
[[vk::constant_id(5)]] const int shading_mode = 0;

#include "common.hlsl"
#include "sample_mask.hlsl"
#include "sample_cache.hlsl"

// clang-format off
// Reprojects every occupied source screen-cache sample of pixel gId into the
// destination cache as seen from the camera g_lubo.cam[gId.z].
//
// For each occupied source slot:
//  - decode its first-hit entry, rebuild the world-space position and project it;
//  - try to reserve a destination slot (try_reserve_slot). On success, either store
//    the packed source index (two_pass_reprojection) or copy the first-hit entry and,
//    if sc_has_radiance, the diffuse value;
//  - ADAPTIVE_SAMPLING: atomically fold the sample's depth into g_depthmin and its
//    packed {depth, depth slope} into g_depthmax at the destination pixel;
//  - ADAPTIVE_SHADING: accumulate the importance of visible, reprojected samples.
// Finally, for ADAPTIVE_SHADING, the averaged importance is counted into the current
// frame's histogram slot, and the wave-wide min/max importance is folded into the
// next frame's bin range.
[numthreads(8, 8, 1)]
void main(uint3 gId : SV_DispatchThreadID)
{
    // clang-format on
    const uint3 img_res = g_lubo.display_res;
    const uint3 cache_res = g_lubo.screen_cache_res;
    if (gId.z >= img_res.z)
        return;
    dbg = all(gId == img_res / 2);

    const uint frame_idx = g_lubo.sample_idx;
    const CameraState cam = g_lubo.cam[gId.z];
    const float3 cam_origin = cam.makeRayOrigin();
    const float pixel_size = cam.get_pixel_diagonal(img_res.xy);
    const uint src_mask_val = g_src_cache_mask[gId];

    const uint num_slots = mask_mode == 0 ? g_lubo.screen_space_num_layers : g_lubo.screen_M * g_lubo.screen_M;

    // adaptive shading: sum of per-sample importance and number of contributing samples
    float importance = 0;
    float importance_w = 0;

    // Visit each slot flagged in the occupancy mask exactly once. The loop is driven by
    // the mask itself rather than by a separately computed sample count, so firstbitlow()
    // is never evaluated on an empty mask (which would yield 0xFFFFFFFF as slot index).
    uint outstanding_samples_mask = get_occupancy_mask(src_mask_val);
    while (outstanding_samples_mask != 0u)
    {
        const uint z = firstbitlow(outstanding_samples_mask);
        outstanding_samples_mask &= ~(1u << z);

        const bool src_occupied = is_sample_occupied(src_mask_val, z);
        const bool src_visible = is_sample_visible(src_mask_val, z, num_slots);
        if (!src_occupied)
            continue;

        uint3 src_cache_idx = decode_cache_idx(gId, z);
        float4 fh_val = g_src_cache_fh[src_cache_idx];
        const first_hit_t fh = decode_first_hit(fh_val);

        // reproject the sample into the new camera space
        LocalGeometry lg = get_local_geometry(fh.instance_id, fh.prim_id, fh.bary);
        const float3 dst_uv = cam.worldToScreen(lg.pos);

        // reserve slot in cache and write entry if there is still space
        uint target_slot = 0;
        bool replace_sample = try_reserve_slot(cache_res, gId, src_visible, dst_uv, g_dst_cache_mask, target_slot);

        if (replace_sample)
        {
            uint3 dst_cache_idx = compute_cache_idx(cache_res, dst_uv.xy, gId.z, target_slot);
            if (two_pass_reprojection)
                g_repr_indices[dst_cache_idx] = pack_cache_idx(src_cache_idx);
            else
            {
                g_dst_cache_fh[dst_cache_idx] = fh_val;
                if (sc_has_radiance)
                    g_dst_cache_diffuse[dst_cache_idx] = g_src_cache_diffuse[src_cache_idx];
            }

            if (sampling_mode == ADAPTIVE_SAMPLING)
            {
                uint3 dst_mask_idx = compute_cache_idx(img_res, dst_uv.xy, gId.z, 0);
                float hyp_depth = cam.worldToScreenUv(lg.pos).z;
                float3 view_vec = lg.pos - cam_origin;
                float view_dist = length(view_vec);
                float3 view_dir = view_vec / view_dist;
                float cos_n = clamp(abs(dot(lg.surface_normal, view_dir)), 0, 1);
                float angle_n = acos(cos_n);
                float angle_p = atan(pixel_size);
                // Presumably the change in view distance across one pixel footprint on a
                // surface inclined by angle_n to the view direction.
                float depth_slope = abs(view_dist - (view_dist * cos_n / cos(angle_n - angle_p)));
                if (fh.instance_id == UINT_MAX)
                    depth_slope = 0.01 * TMAX * pixel_size; // environment sample
                // Pack depth bits into the high half and slope bits into the low half so a
                // single 64-bit InterlockedMax orders by depth first. asuint() zero-extends
                // the low half; asint() would sign-extend and could overwrite the depth bits.
                int64_t dslope_val = (int64_t(asint(hyp_depth)) << 32) | int64_t(asuint(depth_slope));
                int new_value = asint(hyp_depth);
                InterlockedMax(g_depthmax[dst_mask_idx], dslope_val);
                InterlockedMin(g_depthmin[dst_mask_idx], new_value);
            }

            if (shading_mode == ADAPTIVE_SHADING && src_visible)
            {
                // NOTE(review): reads g_src_cache_diffuse even when sc_has_radiance == 0;
                // confirm its alpha channel is valid in that configuration.
                float sample_importance = fh.instance_id != UINT_MAX ? 1.0f / (1 + g_src_cache_diffuse[src_cache_idx].a) : importance_min_val;
                importance += sample_importance;
                importance_w += 1.0f;
            }
        }
    }

    // importance_w > 0 guards the division below against pixels with no contributing sample.
    if ((shading_mode == ADAPTIVE_SHADING) && (importance_w > 0.0f) && (importance > importance_min_val))
    {
        importance /= importance_w;

        // Statistics rotate through three slots: this frame's histogram is accumulated in
        // stats_idx, while the observed importance range is written to the next slot.
        uint stats_idx = frame_idx % 3;
        uint next_stats_idx = (frame_idx + 1) % 3;
        shading_statistics_t stats = g_shading_statistics[stats_idx];
        float lowest_bin = asfloat(stats.lowest_bin);
        float highest_bin = asfloat(stats.highest_bin);
        float log_lowest_bin = log2(lowest_bin);
        float log_highest_bin = log2(highest_bin);

        // Bin 0 receives importances >= highest_bin, bin num_bins-1 those <= lowest_bin;
        // values in between are remapped in log2 space (assumes remapf maps
        // [log_highest_bin, log_lowest_bin] onto [0, num_bins]).
        uint bin_idx = 0;
        if (importance <= lowest_bin)
            bin_idx = num_bins - 1;
        else if (importance >= highest_bin)
            bin_idx = 0;
        else
            bin_idx = clamp(remapf<uint>(log_highest_bin, log_lowest_bin, 0, num_bins, log2(importance)), 0, num_bins - 1);

        InterlockedAdd(g_shading_statistics[stats_idx].bins[bin_idx], uint(importance_w));
        // Replace sentinel-valued importances with the opposite sentinel so they never
        // widen the recorded range.
        float min_importance = importance > importance_min_val ? importance : importance_max_val;
        float max_importance = importance < importance_max_val ? importance : importance_min_val;
        // Integer min/max on the float bit patterns preserves ordering for non-negative
        // floats (importance is assumed non-negative).
        int wave_min = WaveActiveMin(asint(min_importance));
        int wave_max = WaveActiveMax(asint(max_importance));
        if (WaveIsFirstLane())
        {
            InterlockedMin(g_shading_statistics[next_stats_idx].lowest_bin, wave_min);
            InterlockedMax(g_shading_statistics[next_stats_idx].highest_bin, wave_max);
        }
    }
}
