#include "core.hlsl"
#include "scene.h"
#include "radiance_caching.h"
#include "catmullrom_cubic.hlsl"
#include "dodgson_quadratic.hlsl"

// Per-dispatch uniforms. Fields read in this file include display_res, cam/old_cam
// (indexed by array layer), sample_idx, hash_map_* parameters and spatial_lod_bias.
// NOTE(review): bindings.hlsl (included below) presumably derives bindings from these
// declarations - keep names/order stable unless that file is updated too.
cbuffer g_lubo { radiance_caching_ubo_t g_lubo; };
// Sample cache: per-pixel visibility mask, encoded first hits and cached diffuse radiance
// (read by lookup_samplecache()).
RWTexture2DArray<uint> g_cache_mask;
RWTexture2DArray<float4> g_cache_fh;
RWTexture2DArray<float4> g_cache_diffuse;
// Resolved colour written by main().
RWTexture2DArray<float4> g_display_colour;
// History colour; only referenced by commented-out code in main().
Texture2DArray<float4> g_hist_colour;
// Reconstructed depth and motion vectors written by main().
RWTexture2DArray<float> g_depth;
RWTexture2DArray<float2> g_mv;
// Colour moments, written by main() when generate_variance is enabled.
RWTexture2DArray<float4> g_variance;
// History moments; only referenced by commented-out code in main().
Texture2DArray<float4> g_hist_variance;
// Hash (world-space) radiance cache, queried via filtered_hash_lookup() when hybrid_cache is set.
HASH_CACHE_SHADER_TEXTURE_TYPE<HASH_MASK_SHADER_TYPE> g_hash_mask;
HASH_CACHE_SHADER_TEXTURE_TYPE<float4> g_hash_cache;
// Debug output; only referenced by commented-out code in main().
RWTexture2DArray<float4> g_dbg;

#include "bindings.hlsl"
#include "gltf.hlsl"
#include "scene.hlsl"

// Specialization constants (Vulkan constant_id). Non-zero enables the feature.
// When set, pixels with <= 1 shading sample are painted with a highlight colour in main().
[[vk::constant_id(0)]] const int highlight_cache_misses = 0;
// Not referenced in this file; presumably consumed by sample_mask.hlsl - confirm.
[[vk::constant_id(3)]] const int mask_mode = 0;
// When set, main() writes colour moments to g_variance.
[[vk::constant_id(4)]] const int generate_variance = 0;
// When set, samples without sample-cache radiance fall back to the hash cache lookup.
[[vk::constant_id(5)]] const int hybrid_cache = 0;
// When set, diffuse radiance is read from g_cache_diffuse.
[[vk::constant_id(6)]] const int sc_has_radiance = 1;
// When set, the hash lookup LOD is increased at grazing view angles (see lookup_samplecache()).
[[vk::constant_id(7)]] const int adjust_hash_lod_by_normal = 0;
// Not referenced in this file; presumably consumed by an included header - confirm.
[[vk::constant_id(8)]] const int cache_glossy = 0;

#include "common.hlsl"
#include "hash_cache.hlsl"
#include "sample_mask.hlsl"
#include "sample_cache.hlsl"

// Per-pixel reduction of the sample-cache entries: produced by lookup_samplecache(),
// consumed by main() to resolve colour, moments, depth and motion vectors.
struct processed_samples_t
{
    // Initialised to 0; not updated anywhere in this file.
    uint num_depth_rejected;
    // Copy of lookup_samplecache()'s traced_depth argument.
    float rejection_depth;
    // Accumulated via choose_closest_sample(); currently a component-wise min of hit positions.
    float3 closest_pos;
    // Instance id of the sample with the largest cam.worldToScreenUv(pos).z
    // (presumably reverse-Z, so "largest" == closest - confirm). UINT_MAX when none.
    uint closest_hit_instance;
    // The largest projected depth seen among the visible samples.
    float closest_depth;
    // Fraction of visible samples whose hit instance is UINT_MAX (environment).
    float env_fraction;
    // Number of visible samples in the pixel's cache mask.
    float num_trace_samples;
    // Sum of the radiance weights (diffuse.a) of samples that contributed radiance.
    float num_shade_samples;
    // Environment colour accumulated from environment-hitting samples (normalised in lookup_samplecache()).
    float3 background;
    // Mean of (emission + base_color * ambient) over the visible samples.
    float3 emission;
    // Mean base colour over the visible samples.
    float3 albedo;
    // Weighted mean of the diffuse radiance, and of its square (second moment).
    float3 diffuse;
    float3 diffuse2;
};

// Returns a processed_samples_t in its "no samples" state: every accumulator zeroed,
// closest_pos at HLSL_FLT_MAX (so the first min() in choose_closest_sample replaces it)
// and closest_hit_instance at the UINT_MAX "no instance" sentinel.
// Every field is written so that early-return paths never hand back undefined values.
processed_samples_t init_processed_samples()
{
    processed_samples_t samples;
    samples.num_depth_rejected = 0;
    // Previously left uninitialised; lookup_samplecache() returns it as-is for out-of-range pixels.
    samples.rejection_depth = 0;
    samples.closest_pos = HLSL_FLT_MAX;
    samples.closest_depth = 0;
    // Same bit pattern as the old "-1", but stated as the unsigned sentinel the file compares against.
    samples.closest_hit_instance = UINT_MAX;
    samples.env_fraction = 0;
    samples.num_trace_samples = 0;
    samples.num_shade_samples = 0;
    samples.background = 0;
    samples.albedo = 0;
    samples.emission = 0;
    samples.diffuse = 0;
    samples.diffuse2 = 0;
    return samples;
}

// Folds a new hit position `pos` into the running "closest" position.
// Currently this is a component-wise min (stable, but not a true nearest point).
// The distance-based selection relative to `origin` is kept below but disabled.
float3 choose_closest_sample(float3 origin, float3 pos, float3 closest_pos)
{
    // TODO: fix this, simple_min should be false, but code chooses unstable samples
    const bool use_componentwise_min = true;

    if (!use_componentwise_min)
    {
        const float current_dist = length(origin - closest_pos);
        const float candidate_dist = length(origin - pos);
        // Strictly nearer candidate wins outright.
        if (candidate_dist < current_dist)
            return pos;
        // Tie: break it deterministically with a component-wise min.
        if (candidate_dist == current_dist)
            return min(closest_pos, pos);
        return closest_pos;
    }

    return min(closest_pos, pos);
}

// float sample_weight(
//     uint3 img_idx,
//     uint3 display_res,
//     float3 sample_uv3
// )
// {
//     const bool equal_weights = true;
//     if (equal_weights)
//         return 1.0f;

//     const float2 frac_uv = display_res.xy * sample_uv3.xy - (img_idx.xy + 0.5);
//     return gaussian2d(frac_uv);
// }

// Reduces the visible sample-cache entries of pixel gId to per-pixel averages.
//
// Pass 1 averages albedo and emission over all visible samples. It also collects the
// environment (instance UINT_MAX) contribution and the closest-hit bookkeeping used by
// main() for depth and motion vectors.
// Pass 2 averages diffuse radiance and its square, weighted by the radiance alpha. The
// radiance comes from the sample cache or, with hybrid_cache, from the hash cache.
//
// gId          - pixel coordinate; z selects the array layer.
// img_res      - display resolution; pixels outside it return init_processed_samples().
// cam          - camera used for projection and ray-cone construction.
// reject_depth - currently unused here.
// traced_depth - copied into rejection_depth.
processed_samples_t lookup_samplecache(
    uint3 gId,
    uint3 img_res,
    CameraState cam,
    bool reject_depth,
    float traced_depth
)
{
    processed_samples_t samples = init_processed_samples();
    if (any(gId >= img_res))
        return samples;

    const uint mask_val = g_cache_mask[gId];
    const uint num_visible_samples = count_visible_samples(mask_val, gId, num_sample_slots());
    // Loop-invariant: the ray origin does not depend on the sample.
    const float3 ray_origin = cam.makeRayOrigin();
    samples.rejection_depth = traced_depth;
    samples.num_trace_samples = num_visible_samples;

    // Pass 1: determine the average albedo/emission and how many samples hit the
    // background, so that the background can be mixed in correctly.
    uint outstanding_samples_mask = get_visible_samples_mask(mask_val);
    for (uint i = 0; i < num_visible_samples; i++)
    {
        const uint z = firstbitlow(outstanding_samples_mask);
        outstanding_samples_mask &= ~(1u << z);
        const uint3 cache_idx = decode_cache_idx(gId, z);
        // Only the first hit is needed here; the cached radiance is read in pass 2.
        const first_hit_t fh = decode_first_hit(g_cache_fh[cache_idx]);

        const LocalGeometry lg = get_local_geometry(fh.instance_id, fh.prim_id, fh.bary);

        const float sample_depth = cam.worldToScreenUv(lg.pos).z;
        if (sample_depth > samples.closest_depth)
        {
            samples.closest_hit_instance = lg.instance_id;
            samples.closest_depth = sample_depth;
        }
        samples.closest_pos = choose_closest_sample(ray_origin, lg.pos, samples.closest_pos);

        const float4 ray_cone = compute_ray_cone(img_res, lg, cam);
        const float cone_width = ray_cone.w;
        const float3 out_dir = ray_cone.xyz;
        const GltfShadeParams sp = get_shading_params(fh.mat_id, out_dir, cone_width / sqrt(32), lg);
        // NOTE(review): environment samples also contribute albedo/emission here; whether
        // that is intended depends on what get_shading_params returns for them - confirm.
        samples.albedo += sp.base_color.rgb;
        samples.emission += sp.emission + sp.base_color.rgb * g_ubo.env.ambient;

        if (lg.instance_id == UINT_MAX)
        {
            // env_fraction holds the environment sample *count* until normalisation below.
            samples.env_fraction += 1;
            samples.background += getEnvColor(g_ubo.env, out_dir, true, 0);
        }
    }
    if (samples.num_trace_samples > 0)
    {
        // background was summed over environment samples only. Average it over those samples:
        // main() weights it by env_fraction via lerp(), so dividing by the total sample count
        // would apply that weight twice (env_fraction^2).
        if (samples.env_fraction > 0)
            samples.background /= samples.env_fraction;
        samples.albedo.rgb /= samples.num_trace_samples;
        samples.emission.rgb /= samples.num_trace_samples;
        samples.env_fraction /= samples.num_trace_samples;
    }

    // Pass 2: alpha-weighted mean of the diffuse radiance and of its square.
    outstanding_samples_mask = get_visible_samples_mask(mask_val);
    for (uint i = 0; i < num_visible_samples; i++)
    {
        const uint z = firstbitlow(outstanding_samples_mask);
        outstanding_samples_mask &= ~(1u << z);
        const uint3 cache_idx = decode_cache_idx(gId, z);
        float4 diffuse = 0;
        if (sc_has_radiance && samples.env_fraction < 1)
            diffuse = g_cache_diffuse[cache_idx];
        // NOTE(review): when sc_has_radiance is set but every sample is environment
        // (env_fraction == 1), control falls through to the hash lookup below - confirm intended.
        else if (hybrid_cache)
        {
            first_hit_t fh = decode_first_hit(g_cache_fh[cache_idx]);
            const LocalGeometry lg = get_local_geometry(fh.instance_id, fh.prim_id, fh.bary);
            float4 ray_cone = compute_ray_cone(img_res, lg, cam);
            const float cone_width = ray_cone.w;
            const float3 out_dir = ray_cone.xyz;
            float lod = log2(cone_width);
            const GltfShadeParams sp = get_shading_params(fh.mat_id, out_dir, cone_width / sqrt(32), lg);

            // Coarser LOD at grazing angles.
            // NOTE(review): a dot product of exactly 0 gives log2(0) = -inf, so lod = +inf.
            if (adjust_hash_lod_by_normal)
                lod -= log2(abs(dot(out_dir, lg.surface_normal)));
            cache_query_t query;
            query.sample_idx = g_lubo.sample_idx;
            query.hash_map_run_length = g_lubo.hash_map_run_length;
            query.hash_map_cell_lifetime = g_lubo.hash_map_cell_lifetime;
            query.hash_map_block_exp = g_lubo.hash_block_size_exp;
            query.use_dir = false;
            query.use_normal = true;
            query.instance_id = fh.instance_id;
            query.local_pos = lg.local_pos;
            query.lod = lod + g_lubo.spatial_lod_bias;
            query.dir = out_dir;
            query.roughness = sp.roughness;
            // Orient the normal to face away from the view direction.
            query.normal = dot(-out_dir, sp.normal) > 0 ? -sp.normal : sp.normal;

            diffuse = filtered_hash_lookup(g_hash_mask, g_hash_cache, query).diffuse;
        }

        // float3 sample_uv3 = cam.worldToScreenUv(lg.pos);
        float weight = 1.0; // sample_weight(gId, img_res, sample_uv3);
        // diffuse.a is the radiance weight; zero means "no radiance available".
        if (diffuse.a)
        {
            float3 colour = diffuse.rgb;
            samples.diffuse += weight * diffuse.a * colour;
            samples.diffuse2 += weight * diffuse.a * sqr(colour);
            samples.num_shade_samples += weight * diffuse.a;
        }
    }

    if (samples.num_shade_samples > 0)
    {
        samples.diffuse.rgb /= samples.num_shade_samples;
        samples.diffuse2.rgb /= samples.num_shade_samples;
    }

    return samples;
}


// clang-format off
[numthreads(16, 8, 1)]
void main(uint3 gId : SV_DispatchThreadID, uint3 lId : SV_GroupThreadID)
{
    // clang-format on
    // Resolves one display pixel from the sample cache. It writes:
    //   g_display_colour : resolved colour
    //   g_variance       : colour moments (when generate_variance is enabled)
    //   g_depth          : depth of a view-ray-aligned "stable" hit point
    //   g_mv             : motion vector of that point
    const uint3 img_res = g_lubo.display_res;
    if (any(gId >= img_res))
        return;

    dbg = all(gId == img_res / 2); // enable printouts for debug pixel

    // Each array layer (gId.z) has its own current and previous camera.
    CameraState cam = g_lubo.cam[gId.z];
    CameraState old_cam = g_lubo.old_cam[gId.z];

    processed_samples_t samples = lookup_samplecache(gId, img_res, cam, false, 0.0f);

    // First and second colour moments: shaded geometry blended with the background by env_fraction.
    float3 colour = lerp(samples.albedo.rgb * samples.diffuse.rgb + samples.emission, samples.background, samples.env_fraction);
    float3 colour2 = lerp(sqr(samples.albedo.rgb) * samples.diffuse2.rgb + sqr(samples.emission), sqr(samples.background), samples.env_fraction);
    float4 m1 = float4(colour, 1);
    float4 m2 = float4(colour2, samples.num_trace_samples);

    if (samples.num_shade_samples <= 1 && highlight_cache_misses)
    {
        g_display_colour[gId] = float4(0, HLSL_FLT_MAX, HLSL_FLT_MAX, 1);
    }
    else
    {
        g_display_colour[gId] = m1;
    }

    if (generate_variance)
        g_variance[gId] = m2;

    // Re-project the closest sample distance onto this pixel's centre ray, so depth does
    // not jitter with the sample positions.
    // NOTE(review): with no visible samples closest_pos is HLSL_FLT_MAX, so dist overflows
    // to +inf - confirm what g_depth should hold for such pixels.
    float3 origin = cam.makeRayOrigin();
    float dist = length(origin - samples.closest_pos);
    float3 stable_hitpoint = origin + dist * cam.makeRayDirection(gId.xy, img_res.xy, 0.5);
    g_depth[gId] = cam.worldToScreenUv(stable_hitpoint).z;

    // Apply the hit instance's object motion (if any), then project with the previous camera.
    // closest_hit_instance is UINT_MAX when no instance was recorded (no samples, or an
    // environment sample). That is not a valid transform index, so fall back to camera-only
    // reprojection in that case.
    float3 previous_hitpoint = stable_hitpoint;
    if (samples.closest_hit_instance != UINT_MAX)
    {
        transform_buffer_entry_t transform = get_transform(samples.closest_hit_instance);
        float3 local_stable_hitpoint = mul(transform.world_to_object, float4(stable_hitpoint, 1));
        previous_hitpoint = mul(transform.previous_object_to_world, float4(local_stable_hitpoint, 1.0f));
    }
    float2 screenUV = old_cam.worldToScreenUv(previous_hitpoint).xy;
    float2 mv = screenUVToMV(gId.xy, img_res.xy, screenUV);
    g_mv[gId] = mv;

    // float3 hist_uv3 = float3(screenUV, gId.z);
    // float4 hist_m1 = max(0.0, SampleTextureCatmullRom(g_hist_colour, g_linear_sampler, hist_uv3, dim3(g_hist_colour).xy));
    // float4 hist_m2 = max(0.0, SampleTextureCatmullRom(g_hist_variance, g_linear_sampler, hist_uv3, dim3(g_hist_variance).xy));
    // float varX = dot(m2.rgb - sqr(m1.rgb), 1);
    // float sqr_diff = dot(m2.rgb - 2 * m1.rgb * hist_m1.rgb + hist_m2.rgb, 1);
    // g_dbg[gId] = clamp((varX + 5e-7) / (sqr_diff + 1e-6), 0, 1.0);
    // g_dbg[gId] = hist_m2 - sqr(hist_m1);
}

