#include "core.hlsl"
#include "scene.h"
#include "radiance_caching.h"

// Per-dispatch uniforms. This file reads screen_cache_res, display_res, random_seed,
// cam[], sample_idx, the screen-space density target/tolerance and screen_M from it.
cbuffer g_lubo { radiance_caching_ubo_t g_lubo; };
// Per-pixel sample mask word: loaded at the launch index at the start of the
// raygen shader, updated per processed sample, and stored back at the end.
RWTexture2DArray<uint> g_cache_mask;
// Encoded first-hit record (instance id, primitive id, material id, barycentrics)
// for each cache slot; see encode_first_hit / decode_first_hit.
RWTexture2DArray<float4> g_cache_fh;
// Per-slot cached diffuse value; reset to 0 whenever a slot receives a new first hit
// (only when sc_has_radiance is non-zero).
RWTexture2DArray<float4> g_cache_diffuse;
// Only forwarded to density_analysis() in this file.
RWTexture2DArray<int> g_depthmin;
// Only forwarded to density_analysis() in this file.
// NOTE(review): element type is int64_t while g_depthmin uses int -- confirm the
// asymmetry is intentional and matches the density_analysis() parameter types.
RWTexture2DArray<int64_t> g_depthmax;
// Only forwarded to density_analysis() in this file.
RWStructuredBuffer<shading_statistics_t> g_shading_statistics;
// Debug output target; only forwarded to density_analysis() in this file.
RWTexture2DArray<float4> g_dbg;

// Vulkan specialization constants (values are supplied at pipeline creation; the
// initializers below are the defaults).
// mask_mode == 1: new samples are shot through the sub-pixel position decoded from
// the slot index; any other value uses a Halton(2,3) jitter inside the pixel.
[[vk::constant_id(0)]] const int mask_mode = 0;
// Forwarded unchanged to density_analysis().
[[vk::constant_id(1)]] const int sampling_mode = 0;
// Non-zero: g_cache_diffuse is cleared when a cache slot is (re)initialised.
[[vk::constant_id(2)]] const int sc_has_radiance = 1;
// Not referenced in the code visible in this file section; presumably consumed by
// one of the included headers -- TODO confirm.
[[vk::constant_id(3)]] const int hybrid_cache = 0;
// Forwarded unchanged to density_analysis().
[[vk::constant_id(4)]] const int shading_mode = 0;

#include "bindings.hlsl"
#include "gltf.hlsl"
#include "scene.hlsl"
#include "pathtracer.hlsl"
#include "common.hlsl"
#include "sample_mask.hlsl"
#include "sample_cache.hlsl"

#ifdef RGEN
/**
 * Ray-generation entry point of the sample-cache visibility / refill pass.
 *
 * One invocation handles one launch index (x, y, layer). It
 *   1. loads the sample mask word of that index,
 *   2. asks density_analysis() which sample slots have to be processed
 *      (returned as a bitmask, one bit per slot),
 *   3. for every selected slot either
 *        - re-validates an occupied slot by shooting a ray from the camera
 *          origin at the cached hit position and updating its visibility bit, or
 *        - shoots a fresh primary ray and stores the resulting first hit in
 *          the slot,
 *   4. writes the updated mask word back.
 *
 * Side effects: writes g_cache_mask, g_cache_fh and (if sc_has_radiance)
 * g_cache_diffuse; sets the globals `dbg` and `seed`. Other resources are
 * only handed to density_analysis().
 */
// clang-format off
[shader("raygeneration")]
void main()
{
    // clang-format on
    const uint3 LaunchID = DispatchRaysIndex();
    const uint3 LaunchSize = DispatchRaysDimensions();
    const uint3 img_res = g_lubo.display_res;

    dbg = all(LaunchID == LaunchSize / 2); // enable printouts for debug pixel
    seed = tea(LaunchID.z * (LaunchSize.x * LaunchSize.y) + LaunchID.y * LaunchSize.x + LaunchID.x, g_lubo.random_seed);

    /*****************************************************************************/
    /* camera of the current layer; all rays of this invocation share its origin */
    CameraState cam = g_lubo.cam[LaunchID.z];
    const float3 origin = cam.makeRayOrigin();

    /*****************************************************************************/
    /* select the sample slots that need a primary ray */
    const uint num_slots = num_sample_slots();
    uint mask_val = g_cache_mask[LaunchID];

    // One bit per slot that has to be processed below.
    // NOTE(review): mask_val is passed as the last argument and may be modified by
    // density_analysis() if that parameter is declared inout -- confirm.
    uint active_samples = density_analysis(
        LaunchID,
        g_lubo.sample_idx,
        img_res,
        cam,
        num_slots,
        g_lubo.screen_space_target_density,
        g_lubo.screen_space_density_tolerance,
        sampling_mode,
        shading_mode,
        g_cache_mask,
        g_cache_fh,
        g_depthmin,
        g_depthmax,
        g_shading_statistics,
        g_dbg,
        mask_val
    );

    const uint num_primary_rays = countbits(active_samples);
    for (uint sample_id = 0; sample_id < num_primary_rays; sample_id++)
    {
        // Pop the highest remaining slot bit.
        const uint target_sample = firstbithigh(active_samples);
        active_samples &= ~(1u << target_sample);

        const uint3 cache_idx = decode_cache_idx(LaunchID, target_sample);

        // aim == true: the slot already holds a sample, so the ray is aimed at the
        // cached hit position to re-validate it instead of creating a new sample.
        bool aim = is_sample_occupied(mask_val, target_sample);

        float3 dir = 0;
        float tmax = TMAX;
        if (aim)
        {
            const first_hit_t cached_hit = decode_first_hit(g_cache_fh[cache_idx]);
            const LocalGeometry cached_geo = get_local_geometry(cached_hit.instance_id, cached_hit.prim_id, cached_hit.bary);
            // The direction is intentionally left unnormalised: the cached point lies
            // at ray parameter t == 1, so the ray only needs to reach slightly past it.
            dir = cached_geo.pos - origin;
            tmax = 1.001;
        }
        else if (mask_mode == 1)
        {
            // New sample through the sub-pixel position encoded by the slot index,
            // on a grid that is screen_M times finer than the launch grid.
            const float M = g_lubo.screen_M;
            dir = cam.makeRayDirection(LaunchID.xy * M + decode_subpixel(target_sample), LaunchSize.xy * M, 0.5f);
        }
        else
        {
            // New sample at a Halton(2,3) offset inside the pixel, kept away from the
            // pixel border by the clamp.
            const float2 sample_uv = clamp(Halton23((g_lubo.sample_idx + sample_id) % 32), 0.01, 0.99);
            // float2 sample_uv = float2(rnd(seed), rnd(seed));
            dir = cam.makeRayDirection(LaunchID.xy, LaunchSize.xy, sample_uv);
        }

        RayDesc rayDesc;
        rayDesc.Origin = origin;
        rayDesc.TMin = 0;
        rayDesc.TMax = tmax;
        rayDesc.Direction = dir;
        Payload payload = make_payload(rayDesc.Origin, rayDesc.Direction);
        TraceRay(g_topLevel, path_ray_flags, 0xff, 0, 1, 0, rayDesc, payload);

        if (aim)
        {
            // update visibility bit: the cached point counts as visible only if the
            // reported hit distance lies within +-0.1% of t == 1.
            const bool is_visible = 0.999 <= payload.t && payload.t <= 1.001;

            if (is_visible)
                mask_val = set_visible_sample(mask_val, target_sample, num_slots);
            else
                mask_val = set_invisible_sample(mask_val, target_sample, num_slots);

            // if the last sample fails validation, simply overwrite it.
            // (i.e. no visible sample is left after processing the final slot)
            // NOTE(review): the payload stored below then stems from the shortened
            // validation ray (TMax = 1.001), not from a fresh full-length ray.
            if ((count_visible_samples(mask_val, LaunchID, num_slots) == 0) && (sample_id == num_primary_rays - 1))
            {
                aim = false;
            }
        }

        if (!aim)
        {
            // new hitpoints are always visible, occupancy bit is set in GI pass
            mask_val = enable_sample(mask_val, target_sample, num_slots);

            first_hit_t new_hit;
            new_hit.instance_id = payload.instance_id;
            new_hit.prim_id = payload.prim_id;
            new_hit.mat_id = payload.mat_id;
            new_hit.bary = payload.bary;
            g_cache_fh[cache_idx] = encode_first_hit(new_hit);

            // A replaced slot must not keep the value cached for the previous hit.
            if (sc_has_radiance)
                g_cache_diffuse[cache_idx] = 0;
        }
    }

    g_cache_mask[LaunchID] = mask_val;
}
#endif // RGEN
