#pragma once

#include "catmullrom_cubic.hlsl"
#include "dodgson_quadratic.hlsl"

// Blend factor applied to the new sample when a temporal cache entry is updated
// (used as the lerp weight in update_temp_cache_entry).
static const float temp_alpha = 0.2f;


// Reprojects the current pixel's surface point into the previous frame and
// reconstructs its cached diffuse value from the history buffers with a
// depth-tested bilinear filter.
//
// pixel         - xy: display pixel; z: layer index, also used to select lubo.cam / lubo.old_cam.
// lubo          - per-frame constants (display resolution, current and previous cameras).
// lg            - geometry of the current hit (world position, instance id).
// normal        - currently unused.
// linearSampler - currently unused; history is fetched with point loads.
// hist_diffuse  - previous frame's cached values; the returned .a is the filtered
//                 history .a clamped to max_spp.
// hist_depth    - previous frame's depth; its reciprocal is compared against the
//                 reprojected distance.
// cache_miss    - set to true when the point reprojects on-screen but every tap is
//                 rejected; never cleared by this function.
// Returns the filtered history, or float4(0,0,0,0) if the point reprojects outside
// [0,1] or no tap is accepted.
float4 temp_cache_lookup(
    uint3 pixel,
    radiance_caching_ubo_t lubo,
    LocalGeometry lg,
    float3 normal,
    SamplerState linearSampler,
    Texture2DArray<float4> hist_diffuse,
    Texture2DArray<float> hist_depth,
    inout bool cache_miss
)
{
    const uint3 display_res = lubo.display_res;
    CameraState cam = lubo.cam[pixel.z];
    float3 origin = cam.makeRayOrigin();

    // Rebuild the hit point along this pixel's ray at offset 0.5 (presumably the pixel
    // center) at the actual hit distance, then move it with the instance from its
    // current to its previous transform so object motion is followed.
    transform_buffer_entry_t transform = get_transform(lg.instance_id);
    float3 stable_hitpoint = origin + length(lg.pos - origin) * cam.makeRayDirection(pixel.xy, display_res.xy, 0.5);
    float3 local_stable_hitpoint = mul(transform.world_to_object, float4(stable_hitpoint, 1));
    float3 hitpoint = mul(transform.previous_object_to_world, float4(local_stable_hitpoint, 1.0f));
    float3 histUv = lubo.old_cam[pixel.z].worldToScreenUv(hitpoint);

    float dist = 1.0 / histUv.z;
    float dist_slope = 2 * get_dist_slope(lg, cam, display_res.xy);

    // NOTE(review): z is range-checked along with xy, and this early-out does not set
    // cache_miss (only a full depth rejection below does) -- confirm both are intended.
    bool outside = any(0.0f > histUv) || any(histUv > 1.0f);
    if (outside)
        return float4(0, 0, 0, 0);

    // Depth acceptance window: +-5% around the reprojected distance, widened by the
    // surface distance slope. Identical for every tap, so computed once.
    const float threshold0 = 0.95 * (dist - dist_slope);
    const float threshold1 = 1.05 * (dist + dist_slope);

    // 2x2 bilinear footprint, offsets relative to the top-left tap.
    const int2 kernel[4] =
    {
        int2(0, 0), int2(1, 0),
        int2(0, 1), int2(1, 1)
    };

    // Continuous position in history pixel space; the -0.5 puts texel centers on
    // integer coordinates so floor() yields the top-left tap of the footprint.
    float2 hist_pixel = histUv.xy * display_res.xy - 0.5;
    float2 ipos = floor(hist_pixel);
    float2 bilinear_w = hist_pixel - ipos;
    // Signed on purpose: at the left/top screen edge the first tap is -1, which must be
    // rejected instead of going through an ill-defined negative float->uint conversion.
    int2 base = int2(ipos);

    const float weights[4] =
    {
        (1 - bilinear_w.x) * (1 - bilinear_w.y),
        bilinear_w.x * (1 - bilinear_w.y),
        (1 - bilinear_w.x) * bilinear_w.y,
        bilinear_w.x * bilinear_w.y,
    };

    float4 hist_c = 0.0f;
    float sum_w = 0.0f;
    for (uint i = 0; i < 4; i++)
    {
        int2 p = base + kernel[i];
        // Valid texel indices are [0, display_res) in x and y.
        if (any(p < 0) || any(p >= int2(display_res.xy)) || pixel.z > display_res.z)
            continue;

        uint3 texel = uint3(uint2(p), pixel.z);
        float4 c = hist_diffuse[texel];
        float d = 1.0 / hist_depth[texel];

        float w = clamp(weights[i], 0, 1);

        // Drop taps whose stored distance falls outside the acceptance window
        // (likely a different surface / disocclusion).
        if (threshold0 > d || d > threshold1)
            w = 0.0f;

        hist_c += w * c;
        sum_w += w;
    }

    const float max_spp = 5;
    if (sum_w > 0)
        hist_c = float4(hist_c.rgb / sum_w, min(max_spp, hist_c.a / sum_w));
    else
    {
        cache_miss = true;
        hist_c = float4(0, 0, 0, 0);
    }

    return hist_c;
}

// Writes the blended temporal value for one cache entry and returns its color.
//
// When the reprojected history is empty (rec_diffuse.a == 0), the new sample is stored
// unchanged, including its alpha. Otherwise the color is
// lerp(history, new, temp_alpha) and alpha is written as 1.
// max_sample_history is currently unused; it is kept for interface compatibility
// (a sample-count based filter used to consume it).
float3 update_temp_cache_entry(RWTexture2DArray<float4> cached_diffuse, uint3 cache_idx, float4 rec_diffuse, float4 new_diffuse, uint max_sample_history)
{
    const bool has_history = !(rec_diffuse.a == 0);

    float4 blended;
    if (has_history)
        blended = float4(lerp(rec_diffuse.rgb, new_diffuse.rgb, temp_alpha), 1);
    else
        blended = new_diffuse;

    cached_diffuse[cache_idx] = blended;
    return blended.rgb;
}
