#pragma once

// Builds a ray-cone description for the surface point lg.pos as seen from the camera.
//   img_res : image resolution; only .xy is forwarded to cam.getConeDiff().
//   lg      : local geometry of the hit; only lg.pos is read.
//   cam     : camera providing the ray origin and the per-unit-distance cone spread.
// Returns float4(xyz = unit direction from the ray origin towards lg.pos,
//                w   = cone spread multiplied by the distance to lg.pos).
float4 compute_ray_cone(uint3 img_res, LocalGeometry lg, CameraState cam)
{
    // Vector from the camera ray origin to the shaded position.
    const float3 to_surface = lg.pos - cam.makeRayOrigin();
    const float spread = cam.getConeDiff(img_res.xy);

    // The cone width scales linearly with the travelled distance.
    const float width_at_hit = spread * length(to_surface);
    return float4(normalize(to_surface), width_at_hit);
}

// Input of a hash cache lookup: hash map configuration plus the surface sample to look up.
struct cache_query_t
{
    // Current sample index; hash_cache_lookup() stamps cells with sample_idx + 1.
    uint sample_idx;
    // Number of samples after which a stamped cell is considered expired by hash_cache_lookup().
    uint hash_map_cell_lifetime;
    // Maximum number of cells probed by hash_cache_lookup() before giving up.
    uint hash_map_run_length;
    // log2 of the hash block width (compute_hashes() uses 1u << hash_map_block_exp).
    uint hash_map_block_exp;

    // When true, compute_hashes() mixes a quantized `dir` into the hash.
    bool use_dir;
    // When true, compute_hashes() mixes the cube face of `normal` into the hash.
    bool use_normal;
    // Instance identifier, combined with the position hash in compute_hashes().
    uint instance_id;
    // Position that gets quantized into voxels of size 2^floor(lod).
    float3 local_pos;
    // Level of detail; its floor selects the voxel size, its fraction the LOD blend weight.
    float lod;
    // Direction used for the optional directional hash and jittered in jitter_cache_query().
    float3 dir;
    // Surface roughness; controls the direction jitter lobe in jitter_cache_query().
    float roughness;
    // Surface normal; selects the projection axis / cube face and the jitter plane.
    float3 normal;
};

// Address of a query inside the hash cache, as produced by compute_hashes().
struct cache_address_t
{
    // Hash stored in / compared against the high word of the hash mask to detect collisions.
    uint collision_hash;
    // Linear index of the first hash block to probe; hash_collision() adds the probe offset.
    uint block_idx;
    // Cell coordinate inside the hash block.
    uint3 cell_idx3;
};

// Outcome of hash_cache_lookup().
struct cache_query_result_t
{
    // True unless a matching cell with a non-zero diffuse alpha was found.
    bool cache_miss;
    // Cell this thread successfully stamped via compare-exchange; UINT_MAX if none.
    uint3 cache_idx;
    // Cell whose collision hash matched the query; UINT_MAX if none.
    uint3 cache_hit_idx;
    // Number of probed cells that neither hit nor got stamped.
    uint nb_collisions;

    // Radiance read from the matching cell (initial value if there was no hit).
    cached_radiance_t radiance;
    // Initialized by init_cache_query_result(); not otherwise touched in this part of the file.
    cached_radiance_t rec_radiance;
};

// Returns a query result in its "nothing found" state: cache miss, both cell
// indices set to the UINT_MAX sentinel, no collisions and freshly initialized radiance.
cache_query_result_t init_cache_query_result()
{
    const uint3 invalid_cell = uint3(UINT_MAX, UINT_MAX, UINT_MAX);

    cache_query_result_t result;
    result.cache_miss = true;
    result.nb_collisions = 0u;
    result.cache_idx = invalid_cell;
    result.cache_hit_idx = invalid_cell;
    result.radiance = init_cached_radiance();
    result.rec_radiance = init_cached_radiance();
    return result;
}

// Footprint (in texels) of one cache entry: the glossy layout when cache_glossy
// is set, the diffuse layout otherwise.
uint3 entry_dims()
{
    if (!cache_glossy)
        return hash_cell_diffuse_dims;
    return hash_cell_glossy_dims;
}

// Number of cache entries along each axis: the texture size in texels divided
// (component-wise, integer division) by the per-entry footprint.
uint3 cache_dims(HASH_CACHE_SHADER_TEXTURE_TYPE<float4> cache)
{
    const uint3 texel_dims = dim3(cache);
    const uint3 entry_footprint = entry_dims();
    return texel_dims / entry_footprint;
}

// Extent of one hash block in cells. The block is a cube of width
// 2^hash_block_exp, flattened to a single layer in z when HASH_CACHE_ARRAYED is set.
uint3 hash_block_dims(uint hash_block_exp)
{
    const uint width = 1u << hash_block_exp;
    if (HASH_CACHE_ARRAYED)
        return uint3(width, width, 1);
    return uint3(width, width, width);
}

// Converts a linear index into an x-major 3D coordinate inside a grid of size cdims.
// The index is wrapped modulo the total cell count first.
// NOTE(review): every component of cdims must be non-zero, otherwise the
// modulo/division below is undefined - confirm callers guarantee this.
uint3 cache_idx3(uint cache_idx, uint3 cdims)
{
    const uint slice_size = cdims.x * cdims.y;
    const uint wrapped = cache_idx % (slice_size * cdims.z);

    const uint x = wrapped % cdims.x;
    const uint y = (wrapped % slice_size) / cdims.x;
    const uint z = wrapped / slice_size;
    return uint3(x, y, z);
}

// Cell coordinate of the i-th probe for an address: advances the linear block
// index by i (wrapping over all blocks) and keeps the cell offset inside the block.
// NOTE(review): the block size comes from g_lubo.hash_block_size_exp here, while
// compute_hashes() uses the query's hash_map_block_exp - presumably both hold
// the same value; verify against the callers.
uint3 hash_collision(cache_address_t hashes, uint i, uint3 cdims)
{
    const uint3 block_extent = hash_block_dims(g_lubo.hash_block_size_exp);
    const uint3 blocks_per_axis = cdims / block_extent;
    const uint3 probe_block = cache_idx3(hashes.block_idx + i, blocks_per_axis);
    return hashes.cell_idx3 + probe_block * block_extent;
}

// Splits a packed 64-bit hash mask into (x = upper 32 bits, y = lower 32 bits).
uint2 decode_hash_mask(uint64_t mask_val)
{
    const uint upper = uint(mask_val >> 32);
    const uint lower = uint(mask_val & 0xFFFFFFFF);
    return uint2(upper, lower);
}

// Packs (x, y) into one 64-bit hash mask: x occupies the upper 32 bits, y the lower 32 bits.
uint64_t encode_hash_mask(uint2 mask_vals)
{
    const uint64_t upper = uint64_t(mask_vals.x) << 32;
    const uint64_t lower = uint64_t(mask_vals.y);
    return upper | lower;
}

// Reads the hash mask stored at cache_idx and returns it unpacked as
// (x = upper 32 bits, y = lower 32 bits). The cache_dims parameter is unused.
uint2 read_info_entry(HASH_CACHE_SHADER_TEXTURE_TYPE<HASH_MASK_SHADER_TYPE> hash_mask, uint3 cache_idx, uint3 cache_dims)
{
    const uint64_t packed = hash_mask[cache_idx];
    return decode_hash_mask(packed);
}

// Loads the cache entry at cache_idx. The diffuse term sits in the first texel of
// the entry; with cache_glossy enabled, the two gloss values and the two outgoing
// directions follow at y offsets 1..4. The cache_dims parameter is unused.
cached_radiance_t read_hashcache_entry(HASH_CACHE_SHADER_TEXTURE_TYPE<float4> cache, uint3 cache_idx, uint3 cache_dims)
{
    // First texel of this entry inside the texture.
    const uint3 base = cache_idx * entry_dims();

    cached_radiance_t entry = init_cached_radiance();
    entry.diffuse = cache[base];
    if (cache_glossy)
    {
        entry.gloss[0] = cache[base + uint3(0, 1, 0)];
        entry.gloss[1] = cache[base + uint3(0, 2, 0)];
        entry.out_dir[0] = cache[base + uint3(0, 3, 0)];
        entry.out_dir[1] = cache[base + uint3(0, 4, 0)];
    }
    return entry;
}

// Stores `entry` at cache_idx using the same texel layout read_hashcache_entry()
// expects: diffuse first, then (only with cache_glossy) gloss[0], gloss[1],
// out_dir[0], out_dir[1] at y offsets 1..4. The cache_dims parameter is unused.
void write_hashcache_entry(HASH_CACHE_SHADER_TEXTURE_TYPE<float4> cache, uint3 cache_idx, uint3 cache_dims, cached_radiance_t entry)
{
    // First texel of this entry inside the texture.
    const uint3 base = cache_idx * entry_dims();

    cache[base] = entry.diffuse;
    if (!cache_glossy)
        return;

    cache[base + uint3(0, 1, 0)] = entry.gloss[0];
    cache[base + uint3(0, 2, 0)] = entry.gloss[1];
    cache[base + uint3(0, 3, 0)] = entry.out_dir[0];
    cache[base + uint3(0, 4, 0)] = entry.out_dir[1];
}

// Hashes four integers into one uint: each component is multiplied by its
// prime p[i] and folded into the running seed.
//   aabb_val : the four values to hash.
//   p        : one multiplier per component.
uint hash(int4 aabb_val, int p[4])
{
    // prime hash https://doi.org/10.1109/RT.2007.4342602
    uint seed = 0;
    for (int axis = 0; axis < 4; ++axis)
    {
        // better hash_combination from: https://github.com/g-truc/glm/blob/master/glm/gtx/hash.inl
        const uint mixed = aabb_val[axis] * p[axis] + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        seed ^= mixed;
    }
    return seed;
}

// First variant of hash(): hashes aabb_val with its own fixed set of four primes.
uint hash1(int4 aabb_val)
{
    const int primes[4] = { 1404211, 2043997, 6868819, 8445629 };
    return hash(aabb_val, primes);
}

// Second variant of hash(): same scheme as hash1() but with a different set of primes.
uint hash2(int4 aabb_val)
{
    const int primes[4] = { 1476047, 2312899, 4450909, 9627113 };
    return hash(aabb_val, primes);
}

// Jarzynski2020 https://jcgt.org/published/0009/03/02/
// Maps four uints to four hashed uints. The component updates are intentionally
// sequential: each line consumes the components already updated by the lines above,
// so their order must not be changed.
uint4 pcg4d(uint4 v)
{
    // per-component multiply-add step
    v = v * 1664525u + 1013904223u;

    // first mixing round
    // v += v.yzxy * v.wxyz;
    v.x = v.x + v.y * v.w;
    v.y = v.y + v.z * v.x;
    v.z = v.z + v.x * v.y;
    v.w = v.w + v.y * v.z;

    // fold the upper half of every component into its lower half
    v ^= (v >> 16u);

    // second mixing round
    // v += v.yzxy * v.wxyz;
    v.x = v.x + v.y * v.w;
    v.y = v.y + v.z * v.x;
    v.z = v.z + v.x * v.y;
    v.w = v.w + v.y * v.z;

    return v;
}

// Unfolds a linear seed into a 2D cell coordinate (x fastest) inside dims; z is always 0.
uint3 hash_to_idx(uint seed, uint2 dims)
{
    const uint column = seed % dims.x;
    const uint row = (seed / dims.x) % dims.y;
    return uint3(column, row, 0);
}

// https://en.wikipedia.org/wiki/User:Microwerx in https://en.wikipedia.org/wiki/Cube_mapping
// Projects a direction onto a cube map.
//   p : direction to project; it does not need to be normalized.
// Returns float3(u, v, face) with u, v in [0, 1] and face = 0..5 for
// +X, -X, +Y, -Y, +Z, -Z.
// The face tests are independent on purpose: when several axes tie for the largest
// magnitude, the last matching test wins (same tie-breaking as before).
// Degenerate input (zero vector, or components for which no face test fires, e.g. NaN)
// yields the face centre (0.5, 0.5) instead of reading uninitialized values or
// dividing zero by zero.
float3 convert_xyz_to_cube_uv(float3 p)
{
    const float absX = abs(p.x);
    const float absY = abs(p.y);
    const float absZ = abs(p.z);

    const bool isXPositive = p.x > 0;
    const bool isYPositive = p.y > 0;
    const bool isZPositive = p.z > 0;

    const bool xIsMajor = absX >= absY && absX >= absZ;
    const bool yIsMajor = absY >= absX && absY >= absZ;
    const bool zIsMajor = absZ >= absX && absZ >= absY;

    float3 ret = 0;

    // Defaults describe the face centre; they are only kept if no face test fires.
    float maxAxis = 1.0f;
    float uc = 0.0f;
    float vc = 0.0f;

    // POSITIVE X
    if (isXPositive && xIsMajor)
    {
        // u (0 to 1) goes from +z to -z
        // v (0 to 1) goes from -y to +y
        maxAxis = absX;
        uc = -p.z;
        vc = p.y;
        ret.z = 0;
    }
    // NEGATIVE X
    if (!isXPositive && xIsMajor)
    {
        // u (0 to 1) goes from -z to +z
        // v (0 to 1) goes from -y to +y
        maxAxis = absX;
        uc = p.z;
        vc = p.y;
        ret.z = 1;
    }
    // POSITIVE Y
    if (isYPositive && yIsMajor)
    {
        // u (0 to 1) goes from -x to +x
        // v (0 to 1) goes from +z to -z
        maxAxis = absY;
        uc = p.x;
        vc = -p.z;
        ret.z = 2;
    }
    // NEGATIVE Y
    if (!isYPositive && yIsMajor)
    {
        // u (0 to 1) goes from -x to +x
        // v (0 to 1) goes from -z to +z
        maxAxis = absY;
        uc = p.x;
        vc = p.z;
        ret.z = 3;
    }
    // POSITIVE Z
    if (isZPositive && zIsMajor)
    {
        // u (0 to 1) goes from -x to +x
        // v (0 to 1) goes from -y to +y
        maxAxis = absZ;
        uc = p.x;
        vc = p.y;
        ret.z = 4;
    }
    // NEGATIVE Z
    if (!isZPositive && zIsMajor)
    {
        // u (0 to 1) goes from +x to -x
        // v (0 to 1) goes from -y to +y
        maxAxis = absZ;
        uc = -p.x;
        vc = p.y;
        ret.z = 5;
    }

    // A zero-length direction has no meaningful projection: keep the selected face
    // but fall back to its centre rather than computing 0 / 0.
    if (!(maxAxis > 0.0f))
    {
        maxAxis = 1.0f;
        uc = 0.0f;
        vc = 0.0f;
    }

    // Convert range from -1 to 1 to 0 to 1
    ret.x = 0.5f * (uc / maxAxis + 1.0f);
    ret.y = 0.5f * (vc / maxAxis + 1.0f);
    return ret;
}

// Returns a copy of `q` whose position, lod and direction are jittered with a
// low-discrepancy sequence driven by jitter_idx.
//   jitter_idx : index into the Halton sequences (presumably periodic - verify against caller).
//   q          : query to perturb; its normal, lod, roughness and dir steer the jitter.
// Each jittered dimension uses its own prime Halton base (2, 3, 5, 7, 11). The bases
// must be pairwise coprime: the previous base 9 shares a factor with base 3, which
// correlates the second direction sample with the second position sample.
cache_query_t jitter_cache_query(uint jitter_idx, cache_query_t q)
{
    // generate random vector based on periodic index (jitter_idx)
    float3 jitter_uv = float3(Halton(2, jitter_idx), Halton(3, jitter_idx), Halton(5, jitter_idx));
    float2 jitter_uv2 = float2(Halton(7, jitter_idx), Halton(11, jitter_idx));
    const bool use_box_muller = false;
    if (use_box_muller)
    {
        jitter_uv = boxmuller3D(jitter_uv);
        jitter_uv2.xy = boxmuller(jitter_uv2);
    }

    // position is jittered in the plane of the hitpoint normal:
    // the offset spans +-2^lod along the first two basis vectors of the normal frame
    float3x3 onb = onb_frisvad2(q.normal);
    float2 jitter_vec = pow(2, q.lod) * (2 * jitter_uv.xy - 1);
    q.local_pos += (jitter_vec.x * onb[0] + jitter_vec.y * onb[1]);
    // the third sample shifts the lod by the raw sample value
    q.lod += jitter_uv.z;

    // direction is jittered within phong lobe based on roughness:
    // the lobe exponent k shrinks from max_k towards 0 as roughness approaches 1
    const float max_k = 128;
    const float k = clamp((1.0f - sqrt(clamp(q.roughness, min_roughness, 1.0f))) * max_k, 0, max_k);
    q.dir = mul(transpose(onb_frisvad2(q.dir)), sample_cosine_power_hemisphere2(k, jitter_uv2).xyz).xyz;

    return q;
}

// Computes the cache address of a query: the hash block to start probing at, the cell
// inside that block and a collision hash used to verify cell ownership.
//   q     : query (position, lod, instance, optional normal / direction terms).
//   cdims : cache size in entries per axis (see cache_dims()).
// Returns the address; block_idx is already wrapped to the number of available blocks.
cache_address_t compute_hashes(cache_query_t q, uint3 cdims)
{
    // quantize object-space position to retrieve voxel index
    int ilod = floor(q.lod);
    int3 ipos = int3(floor(q.local_pos / pow(2.0f, ilod)));
    uint3 hdims = hash_block_dims(q.hash_map_block_exp);
    // split the voxel index into a block index and a cell offset inside the block
    // NOTE(review): ipos is signed while the divisor is unsigned; presumably the
    // operation is carried out in unsigned arithmetic, which keeps the cell offset
    // non-negative for negative coordinates - confirm for the target compiler.
    int3 iblock = ipos / (1u << q.hash_map_block_exp);
    int3 icell = ipos % (1u << q.hash_map_block_exp);
    // xyz of `hashes` derive from the block index, w from the full voxel index
    uint4 hashes = pcg4d(int4(iblock, ilod));
    uint4 hashes2 = pcg4d(int4(ipos, ilod));
    hashes.w = hashes2.w;

    // combine position hashes with instance hash
    // NOTE(review): unlike the combines below this one assigns instead of xor-ing into
    // `hashes` - looks intentional, but confirm.
    hashes = hash1(q.instance_id) + 0x9e3779b9 + (hashes << 6) + (hashes >> 2);

    // project entire volume of 3d hashblock onto a 2d square because most geometry can be approximated by a surface
    // where surfaces are more complex or overlap, the collision hash will keep correctness!
    if (HASH_CACHE_ARRAYED)
    {
        // drop the cell coordinate along the dominant normal axis
        if (abs(q.normal[0]) >= abs(q.normal[1]) && abs(q.normal[0]) >= abs(q.normal[2]))
            icell = uint3(icell.yz, 0);
        else if (abs(q.normal[1]) >= abs(q.normal[0]) && abs(q.normal[1]) >= abs(q.normal[2]))
            icell = uint3(icell.zx, 0);
        else
            icell = uint3(icell.xy, 0);
    }

    // if needed, hashes can be distinguished by the hit normal to prevent most radiance leaks through thin surfaces
    if (q.use_normal)
    {
        // only the cube face (0..5) of the normal is hashed, not the position on the face
        uint cube_plane = uint(convert_xyz_to_cube_uv(q.normal).z);
        uint dir_hash = hash2(cube_plane);
        hashes ^= dir_hash + 0x9e3779b9 + (hashes << 6) + (hashes >> 2);
    }

    // when storing non-diffuse radiance without lobe filter, the cache needs to be 5d
    // and include a quantized outgoing direction.
    // TODO: use a better sphere projection and maybe transform into onb first
    if (q.use_dir)
    {
        // quantize the cube map uv into max_k steps per axis and keep the face index as is
        const float max_k = 128;
        float3 uvc = convert_xyz_to_cube_uv(q.dir);
        uint4 int_dir = int4(floor(float3(max_k, max_k, 1) * uvc), 0);
        int4 dir_hashes = pcg4d(int_dir);
        hashes ^= dir_hashes +  0x9e3779b9 + (hashes << 6) + (hashes >> 2);
    }

    // number of hash blocks per axis and number of cells per block
    uint3 bdims = cdims / hdims;
    uint block_size = hdims.x * hdims.y * hdims.z;

    cache_address_t ret;
    ret.collision_hash = hashes.w;
    // NOTE(review): assumes the cache holds at least one full block per axis
    // (bdims product non-zero) - otherwise the modulo below is undefined.
    ret.block_idx = (hashes.x / block_size) % (bdims.x * bdims.y * bdims.z);
    ret.cell_idx3 = icell;
    return ret;
}

// Probes up to query.hash_map_run_length cells for the entry matching `query`.
//   hash_mask : per-cell 64-bit mask, (upper = collision hash, lower = sample stamp).
//   cache     : radiance storage read on a hit.
//   reserve   : when true, uninitialized or expired cells may be claimed for this query.
//   query     : what to look up.
// Returns the radiance of a matching cell (if any), the cell this thread managed to
// stamp via compare-exchange (cache_idx, UINT_MAX if none) and the collision count.
cache_query_result_t hash_cache_lookup(
    HASH_CACHE_SHADER_TEXTURE_TYPE<HASH_MASK_SHADER_TYPE> hash_mask,
    HASH_CACHE_SHADER_TEXTURE_TYPE<float4> cache,
    bool reserve,
    cache_query_t query
)
{
    const uint3 cdims = cache_dims(cache);
    const cache_address_t address = compute_hashes(query, cdims);
    // Value written into the lower mask word when a cell is stamped by this sample.
    const uint stamp = query.sample_idx + 1;

    cache_query_result_t result = init_cache_query_result();

    for (uint probe = 0; probe < query.hash_map_run_length; probe++)
    {
        const uint3 cell = hash_collision(address, probe, cdims);
        const uint64_t packed = hash_mask[cell];
        const uint2 mask = decode_hash_mask(packed);

        bool hit = (mask.x == address.collision_hash);
        const bool unreserved = mask.y < stamp;
        const bool uninitialized = (mask.x == 0);
        const bool expired = mask.y + query.hash_map_cell_lifetime < stamp;

        // Either refresh the stamp of our own cell or, when allowed, claim a free one.
        const bool refresh_own = hit && unreserved;
        const bool claim_free = reserve && (uninitialized || expired);
        if (refresh_own || claim_free)
        {
            const uint64_t desired = encode_hash_mask(uint2(address.collision_hash, stamp));
            uint64_t previous = 0;
            InterlockedCompareExchange(hash_mask[cell], packed, desired, previous);

            // remember cache cell position and set hash value if we won the race, retest hash if other thread won
            if (previous == packed)
                result.cache_idx = cell;
            else
                hit = (decode_hash_mask(previous).x == address.collision_hash);
        }

        if (hit)
        {
            // fetch the radiance value if we have a cache hit;
            // an entry with zero diffuse alpha still counts as a miss
            result.radiance = read_hashcache_entry(cache, cell, cdims);
            result.cache_miss = result.radiance.diffuse.a == 0.0f;
            result.cache_hit_idx = cell;
        }
        else if (!reserve && uninitialized)
        {
            // for pure lookups, we may terminate early if the hashchain ends here
            break;
        }

        // stop as soon as we found valid radiance or own a cell
        const bool finished = !result.cache_miss || result.cache_idx.x != UINT_MAX;
        if (finished)
            break;

        result.nb_collisions++;
    }
    return result;
}

// Helper of filtered_hash_lookup(): bilinearly blends the diffuse term of the four
// voxels around q.local_pos in the plane perpendicular to the dominant normal axis.
// The voxel size is 2^floor(q.lod). All non-diffuse fields come from the voxel at
// q.local_pos itself.
cached_radiance_t filtered_hash_lookup_spatial(
    HASH_CACHE_SHADER_TEXTURE_TYPE<HASH_MASK_SHADER_TYPE> hash_mask,
    HASH_CACHE_SHADER_TEXTURE_TYPE<float4> cache,
    const cache_query_t q
)
{
    float ilod = floor(q.lod);
    float voxel_size = pow(2.0f, ilod);
    float3 ipos = floor(q.local_pos / voxel_size);
    // fractional position inside the voxel, in [0, 1) per axis
    float3 tuv3 = (q.local_pos - (ipos * voxel_size)) / voxel_size;

    int majorAxis = (abs(q.normal[0]) >= abs(q.normal[1]) && abs(q.normal[0]) >= abs(q.normal[2])) ? 0 :
                    (abs(q.normal[1]) >= abs(q.normal[2])) ? 1 : 2;

    float2 tuv = majorAxis == 0 ? tuv3.yz : (majorAxis == 1 ? tuv3.zx : tuv3.xy);

    // the two in-plane axes are the ones following the major axis cyclically
    float3 axes[3] = { float3(1,0,0), float3(0,1,0), float3(0,0,1) };
    float3 step_u = axes[(majorAxis + 1) % 3] * voxel_size;
    float3 step_v = axes[(majorAxis + 2) % 3] * voxel_size;

    cache_query_t corner = q;
    corner.local_pos = q.local_pos;
    cached_radiance_t rad00 = hash_cache_lookup(hash_mask, cache, false, corner).radiance;
    corner.local_pos = q.local_pos + step_u;
    cached_radiance_t rad01 = hash_cache_lookup(hash_mask, cache, false, corner).radiance;
    corner.local_pos = q.local_pos + step_v;
    cached_radiance_t rad10 = hash_cache_lookup(hash_mask, cache, false, corner).radiance;
    corner.local_pos = q.local_pos + (axes[(majorAxis + 1) % 3] + axes[(majorAxis + 2) % 3]) * voxel_size;
    cached_radiance_t rad11 = hash_cache_lookup(hash_mask, cache, false, corner).radiance;

    cached_radiance_t blended = rad00;
    blended.diffuse = lerp(lerp(rad00.diffuse, rad01.diffuse, tuv.x), lerp(rad10.diffuse, rad11.diffuse, tuv.x), tuv.y);
    return blended;
}

// Read-only lookup that blends the diffuse term between the query's lod and the next
// coarser one (lod + 1); never reserves cells.
//   hash_mask, cache : cache storage, forwarded to hash_cache_lookup().
//   q                : query; the fractional part of q.lod is the base blend factor.
// Returns the radiance at q.lod with a filtered diffuse term. The blend factor is
// scaled by the coarser level's share of the combined diffuse alpha; if both levels
// have zero alpha (nothing cached), the finer level's radiance is returned unchanged
// instead of dividing zero by zero.
cached_radiance_t filtered_hash_lookup(
    HASH_CACHE_SHADER_TEXTURE_TYPE<HASH_MASK_SHADER_TYPE> hash_mask,
    HASH_CACHE_SHADER_TEXTURE_TYPE<float4> cache,
    const cache_query_t q
)
{
    // TODO return a result here, essentially randomly reserve one of these
    const bool lod_filter = true;
    const bool spatial_filter = false;

    cached_radiance_t rad0;
    if (spatial_filter)
        rad0 = filtered_hash_lookup_spatial(hash_mask, cache, q);
    else
        rad0 = hash_cache_lookup(hash_mask, cache, false, q).radiance;

    cached_radiance_t rad = rad0;

    if (lod_filter)
    {
        // same position, one lod coarser
        cache_query_t q1 = q;
        q1.lod += 1;

        cached_radiance_t rad1;
        if (spatial_filter)
            rad1 = filtered_hash_lookup_spatial(hash_mask, cache, q1);
        else
            rad1 = hash_cache_lookup(hash_mask, cache, false, q1).radiance;

        const float alpha_sum = rad1.diffuse.a + rad0.diffuse.a;
        if (alpha_sum > 0.0f)
        {
            float t = q.lod - floor(q.lod);
            t = t * rad1.diffuse.a / alpha_sum;
            rad.diffuse = lerp(rad0.diffuse, rad1.diffuse, t);
        }
    }

    return rad;
}

// Blends new radiance samples into the entry returned by a lookup and writes the
// entry back if the lookup stamped a cell (qresult.cache_idx is not UINT_MAX).
//   z                  : index of the gloss / out_dir slot to update.
//   max_sample_history : forwarded to temporal_filter().
//   cache              : radiance storage to write to.
//   qresult            : lookup result providing the previous entry and the owned cell.
//   roughness          : surface roughness; sharpens or widens the directional weighting.
//   new_diffuse, new_glossy, new_out_dir : new samples to blend in.
//   old_out_dir        : only read by the (currently disabled) angular weighting path.
// Returns the updated entry, whether or not it was written.
// NOTE(review): pow() receives a base of 0 whenever the stored and new directions
// point apart, and an exponent of 0 at roughness 1 - confirm the target's pow(0, 0)
// behaviour is acceptable here.
cached_radiance_t update_hash_cache_entry(
    uint z,
    float max_sample_history,
    HASH_CACHE_SHADER_TEXTURE_TYPE<float4> cache,
    cache_query_result_t qresult,
    float roughness,
    float4 new_diffuse,
    float4 new_glossy,
    float3 new_out_dir,
    float3 old_out_dir
)
{
    cached_radiance_t entry = qresult.radiance;
    entry.diffuse = temporal_filter(entry.diffuse, new_diffuse, max_sample_history);

    if (cache_glossy)
    {
        const bool filter_direction = true;

        // Weight of the stored glossy history, reduced when the outgoing direction changed.
        float history_weight = entry.gloss[z].a;
        if (filter_direction)
        {
            // const float k = clamp((1.0f - sqrt(clamp(roughness, 0.001, 1.0f))) * 128, 0, 128);
            const float k = -2.0f * (roughness - 1.0f) / clamp(roughness, min_roughness, 1.0f);
            const float alignment = max(0, dot(entry.out_dir[z].rgb, new_out_dir.xyz));
            history_weight *= pow(alignment, k);
        }
        else
        {
            const float cos_theta = clamp(dot(old_out_dir.xyz, new_out_dir.xyz), 0.0f, 1.0f);
            const float falloff = exp(-acos(cos_theta) / max(roughness, 0.001f));
            history_weight *= falloff;
        }

        // Direction and gloss share the same attenuated history weight.
        entry.out_dir[z] = temporal_filter(float4(entry.out_dir[z].rgb, history_weight), float4(new_out_dir, 1), max_sample_history);
        entry.gloss[z] = temporal_filter(float4(entry.gloss[z].rgb, history_weight), new_glossy, max_sample_history);
    }

    // Only the thread that stamped the cell during lookup writes the entry back.
    const bool owns_cell = qresult.cache_idx.x != UINT_MAX;
    if (owns_cell)
    {
        write_hashcache_entry(cache, qresult.cache_idx, cache_dims(cache), entry);
    }

    return entry;
}
