#include "core.hlsl"
#include "scene.h"
#include "radiance_caching.h"

// Per-pass uniforms for the radiance-caching pass (camera(s), sample index, history length, ...).
// NOTE(review): bindings.hlsl is included after these declarations; if binding slots are assigned
// implicitly in declaration order, reordering the resources below may change their bindings — confirm.
cbuffer g_lubo { radiance_caching_ubo_t g_lubo; };
// Per-pixel primary-hit ids: .x = instance id (UINT_MAX means nothing to shade), .y = primitive id.
Texture2DArray<uint2> g_prim_id;
// Per-pixel material id of the primary hit.
Texture2DArray<uint> g_mat_id;
// Per-pixel barycentric coordinates of the primary hit within its primitive.
Texture2DArray<float2> g_bary;
// History inputs consumed by temp_cache_lookup (presumably the previous frame's cached diffuse radiance
// and depth — confirm against temp_repr_cache.hlsl), sampled through g_hist_sampler.
Texture2DArray<float4> g_hist_diffuse;
Texture2DArray<float> g_hist_depth;
SamplerState g_hist_sampler;
// Output: diffuse cache entries written by update_temp_cache_entry.
RWTexture2DArray<float4> g_diffuse;
// Output: display image; this pass adds its albedo-modulated diffuse result to the existing value.
RWTexture2DArray<float4> g_display_colour;

#include "bindings.hlsl"
#include "gltf.hlsl"
#include "scene.hlsl"
#include "pathtracer.hlsl"
#include "common.hlsl"
#include "radiance_caching.hlsl"
#include "temp_repr_cache.hlsl"

// Specialization constant (id 0). When nonzero, pixels whose temporal cache lookup reports a miss
// get g_display_colour overwritten with (0, FLT_MAX, FLT_MAX, 1) — a saturated cyan — before this
// pass adds its own contribution, making misses easy to spot. Defaults to off.
[[vk::constant_id(0)]] const int highlight_cache_misses = 0;

#ifdef RGEN

// Glossy radiance is not cached by this pass: glossy_albedo stays zero and only the diffuse
// result is written back to the cache.
static const bool cache_glossy = false;

// clang-format off
[shader("raygeneration")]
void main()
{
    // clang-format on
    // Ray-generation entry point of the radiance-caching pass. For every launch pixel:
    //   1. seed the per-pixel RNG and rebuild the primary hit from the G-buffer textures,
    //   2. fetch the previously cached diffuse radiance via temp_cache_lookup,
    //   3. compute fresh direct + indirect lighting, divide out the albedo and clip fireflies,
    //   4. hand old and new radiance to update_temp_cache_entry (writes g_diffuse) and add the
    //      albedo-modulated result to g_display_colour.
    // NOTE: the call order below is significant — helpers may advance the global `seed` RNG,
    // and g_display_colour may be written for miss highlighting before the final accumulation.
    uint3 pixel = DispatchRaysIndex();
    uint3 res = dim3(g_display_colour);
    uint3 launch_dims = DispatchRaysDimensions();

    // Only the centre pixel of the launch emits debug printouts.
    dbg = all(pixel == launch_dims / 2);

    // Hash a linear (sample, layer, row, column) index so every pixel/sample gets its own RNG stream.
    uint texels_per_layer = res.x * res.y;
    uint texels_per_sample = texels_per_layer * res.z;
    uint linear_idx = g_lubo.sample_idx * texels_per_sample
        + pixel.z * texels_per_layer
        + pixel.y * res.x
        + pixel.x;
    seed = tea(linear_idx, g_ubo.generation);

    /*****************************************************************************/
    /* look up diffuse radiance in cache and mark the cache cell as in use */
    uint2 hit_ids = g_prim_id[pixel];
    uint hit_instance = hit_ids.x;
    uint hit_prim = hit_ids.y;
    // An instance id of UINT_MAX means there is no surface to shade at this pixel.
    if (hit_instance == UINT_MAX)
        return;

    CameraState view = g_lubo.cam[pixel.z];
    float3 eye = view.makeRayOrigin();
    const float cone_spread = view.getConeDiff(launch_dims.xy);

    uint material = g_mat_id[pixel];
    float2 hit_bary = g_bary[pixel];
    LocalGeometry geom = get_local_geometry(hit_instance, hit_prim, hit_bary);
    float3 eye_to_hit = geom.pos - eye;
    float3 view_dir = normalize(eye_to_hit);
    GltfShadeParams shading = get_shading_params(material, view_dir, 0, geom);

    // Re-project the hit distance onto the pixel-centre ray so the footprint LOD does not
    // depend on where inside the pixel the primary hit landed.
    float3 centre_dir = view.makeRayDirection(pixel.xy, launch_dims.xy, 0.5);
    float3 stable_hitpoint = eye + length(eye_to_hit) * centre_dir;
    const float lod = log2(cone_spread * length(stable_hitpoint - eye));

    bool cache_miss = true;
    const float4 rec_diffuse = temp_cache_lookup(
        pixel,
        g_lubo,
        geom,
        shading.normal,
        g_hist_sampler,
        g_hist_diffuse,
        g_hist_depth,
        cache_miss
    );

    if (cache_miss && highlight_cache_misses != 0)
        g_display_colour[pixel] = float4(0, HLSL_FLT_MAX, HLSL_FLT_MAX, 1);

    float4 new_diffuse = float4(0, 0, 0, 1);
    float4 new_glossy = float4(0, 0, 0, 1);

    const float3 diffuse_albedo = shading.lambertian_weight(0.0) * shading.base_color.rgb;
    float3 glossy_albedo = 0.0f;
    if (cache_glossy)
    {
        glossy_albedo = max(0.1, (shading.metallic_weight() + shading.specular_reflection_weight(1)) * shading.base_color.rgb);
    }

    // Radiance is stored demodulated (divided by albedo); albedo_eps guards against division by zero.
    float3 diffuse_denom = max(diffuse_albedo, albedo_eps);
    float3 glossy_denom = max(glossy_albedo, albedo_eps);

    float2x4 direct = compute_direct_light(g_ubo.max_bounces, g_ubo.env, geom, view_dir, shading, cache_glossy, true);
    new_diffuse.rgb += direct[0].rgb / diffuse_denom;
    new_glossy.rgb += direct[1].rgb / glossy_denom;

    float2x4 indirect = compute_indirect_light(g_ubo.max_bounces, g_ubo.env, geom, view_dir, cone_spread, lod, shading, cache_glossy);
    new_diffuse.rgb += indirect[0].rgb / diffuse_denom;
    new_glossy.rgb += indirect[1].rgb / glossy_denom;

    new_diffuse.rgb = clip_fireflies(new_diffuse.rgb);
    // NOTE(review): new_glossy is computed but never stored while cache_glossy is false.
    new_glossy.rgb = clip_fireflies(new_glossy.rgb);

    // Write the cache entry (history length bounded by g_lubo.max_sample_history) and add the
    // re-modulated diffuse radiance on top of whatever g_display_colour already holds.
    g_display_colour[pixel].rgb += diffuse_albedo * update_temp_cache_entry(g_diffuse, pixel, rec_diffuse, new_diffuse, g_lubo.max_sample_history);
}
#endif // RGEN
