//
// This is a modified version of the BlurCS compute shader from Microsoft's MiniEngine
// library. The copyright notice from the original version is included below.
//
// The original source code of MiniEngine is available on GitHub.
// https://github.com/Microsoft/DirectX-Graphics-Samples
//

//
// Copyright (c) Microsoft. All rights reserved.
// This code is licensed under the MIT License (MIT).
// THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF
// ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY
// IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR
// PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT.
//
// Developed by Minigraph
//
// Author:  Bob Brown
//

#pragma warning(disable : 3568)
#pragma exclude_renderers gles gles3 d3d11_9x

#include "../StdLib.hlsl"

// Input texture. KMain samples it with sampler_LinearClamp at explicit UVs, mip level 0.
Texture2D<float4> _Source;
// Output texture. Each thread of KMain writes exactly one texel, at its SV_DispatchThreadID.
RWTexture2D<float4> _Result;

// NOTE(review): Unity presumably derives the filter/wrap state (linear, clamp) from this
// sampler's name — confirm against the target Unity version's inline sampler state rules.
SamplerState sampler_LinearClamp;

CBUFFER_START(cb)
    // Only .zw is read: KMain multiplies pixel coordinates by it to get sampling UVs.
    // Presumably (width, height, 1 / width, 1 / height) of the result — confirm with the C# caller.
    float4 _Size;
CBUFFER_END

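// Cache for a 16x16 block of pixels: the 8x8 center that this thread group blurs and writes out,
// plus a 4-pixel apron on every side. In the first pass each uint packs the same color channel
// of two horizontally adjacent pixels as half floats; after the horizontal blur each uint holds
// one full-precision channel value of a single pixel.
// The reason for separating channels is to reduce bank conflicts in the local data memory
// controller. A large stride will cause more threads to collide on the same memory bank.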
// One array per color channel. 128 uints = 16 rows x 8 uints, with row r starting at index r * 8.
groupshared uint gs_cacheR[128];
groupshared uint gs_cacheG[128];
groupshared uint gs_cacheB[128];
groupshared uint gs_cacheA[128];

// 9-tap symmetric filter with binomial weights (1 8 28 56 70 56 28 8 1) / 256, which sum to 1.
// The taps are passed in order a..i; e is the center tap.
float4 BlurPixels(float4 a, float4 b, float4 c, float4 d, float4 e, float4 f, float4 g, float4 h, float4 i)
{
    // All weights are exact binary fractions, so these equal the original decimal constants.
    const float centerWeight = 70.0 / 256.0;
    const float ring1Weight  = 56.0 / 256.0;
    const float ring2Weight  = 28.0 / 256.0;
    const float ring3Weight  =  8.0 / 256.0;
    const float ring4Weight  =  1.0 / 256.0;

    // Accumulate from the center outwards (same evaluation order as a single chained sum).
    float4 weightedSum = centerWeight * e;
    weightedSum += ring1Weight * (d + f);
    weightedSum += ring2Weight * (c + g);
    weightedSum += ring3Weight * (b + h);
    weightedSum += ring4Weight * (a + i);
    return weightedSum;
}

// Packs two pixels into cache slot `index` at half precision. For every channel, pixel1 goes
// into the low 16 bits and pixel2 into the high 16 bits of that channel's uint.
void Store2Pixels(uint index, float4 pixel1, float4 pixel2)
{
    uint4 lowHalves  = f32tof16(pixel1);
    uint4 highHalves = f32tof16(pixel2) << 16;
    uint4 packedWords = lowHalves | highHalves;

    gs_cacheR[index] = packedWords.x;
    gs_cacheG[index] = packedWords.y;
    gs_cacheB[index] = packedWords.z;
    gs_cacheA[index] = packedWords.w;
}

// Unpacks the two half-precision pixels written by Store2Pixels at cache slot `index`:
// the low 16 bits of each channel word become pixel1 and the high 16 bits become pixel2.
void Load2Pixels(uint index, out float4 pixel1, out float4 pixel2)
{
    uint4 packedWords;
    packedWords.x = gs_cacheR[index];
    packedWords.y = gs_cacheG[index];
    packedWords.z = gs_cacheB[index];
    packedWords.w = gs_cacheA[index];

    // f16tof32 converts only the low half of each component.
    pixel1 = f16tof32(packedWords);
    pixel2 = f16tof32(packedWords >> 16);
}

// Stores one full-precision pixel at cache slot `index`, keeping each channel's raw float bits.
void Store1Pixel(uint index, float4 pixel)
{
    uint4 rawBits = asuint(pixel);
    gs_cacheR[index] = rawBits.x;
    gs_cacheG[index] = rawBits.y;
    gs_cacheB[index] = rawBits.z;
    gs_cacheA[index] = rawBits.w;
}

// Reads back the full-precision pixel that Store1Pixel wrote at cache slot `index`.
void Load1Pixel(uint index, out float4 pixel)
{
    pixel.r = asfloat(gs_cacheR[index]);
    pixel.g = asfloat(gs_cacheG[index]);
    pixel.b = asfloat(gs_cacheB[index]);
    pixel.a = asfloat(gs_cacheA[index]);
}

// Blur two pixels horizontally.  This reduces LDS reads and pixel unpacking.
//
// Reads 5 packed uints (10 half-precision pixels s0..s9) starting at leftMostIndex and writes
// two full-precision results: outIndex receives the 9-tap blur centered on s4, and
// outIndex + 1 the blur centered on s5.
//
// NOTE(review): the results are written into the same gs_cache* arrays the inputs come from.
// With KMain's indexing, the four consecutive threads that share a cache row have overlapping
// read/write ranges (e.g. thread x = 0 writes row slot 1, which thread x = 1 reads), and there is
// no barrier between the loads and the stores. This relies on those threads running in the same
// wave so that every load completes before any store — confirm for all target platforms.
// In any case, all loads below must stay ahead of the stores.
void BlurHorizontally(uint outIndex, uint leftMostIndex)
{
    float4 s0, s1, s2, s3, s4, s5, s6, s7, s8, s9;
    Load2Pixels(leftMostIndex + 0, s0, s1);
    Load2Pixels(leftMostIndex + 1, s2, s3);
    Load2Pixels(leftMostIndex + 2, s4, s5);
    Load2Pixels(leftMostIndex + 3, s6, s7);
    Load2Pixels(leftMostIndex + 4, s8, s9);

    // Two overlapping 9-tap windows, offset by one pixel.
    Store1Pixel(outIndex    , BlurPixels(s0, s1, s2, s3, s4, s5, s6, s7, s8));
    Store1Pixel(outIndex + 1, BlurPixels(s1, s2, s3, s4, s5, s6, s7, s8, s9));
}

// Vertically blurs one column of the horizontally blurred cache and writes the result to
// _Result[pixelCoord]. Nine taps are read, starting at topMostIndex and stepping one cache
// row (8 entries) at a time.
void BlurVertically(uint2 pixelCoord, uint topMostIndex)
{
    const uint rowStride = 8;
    float4 taps[9];

    [unroll]
    for (uint tap = 0; tap < 9; ++tap)
    {
        Load1Pixel(topMostIndex + tap * rowStride, taps[tap]);
    }

    // Write to the final target
    _Result[pixelCoord] = BlurPixels(taps[0], taps[1], taps[2], taps[3], taps[4],
                                     taps[5], taps[6], taps[7], taps[8]);
}

#pragma kernel KMain

// When DISABLE_COMPUTE_SHADERS is defined, KMain is replaced by TRIVIAL_COMPUTE_KERNEL,
// a macro defined outside this file (presumably in StdLib.hlsl — confirm).
#ifdef DISABLE_COMPUTE_SHADERS

TRIVIAL_COMPUTE_KERNEL(KMain)

#else

// Each 8x8 thread group produces an 8x8 tile of _Result. Every thread samples a 2x2 quad,
// so the group fills a 16x16 cache footprint that starts 4 pixels above and to the left of
// its output tile; that apron supplies the outer taps of the 9-tap filters.
[numthreads(8, 8, 1)]
void KMain(uint2 groupId : SV_GroupID, uint2 groupThreadId : SV_GroupThreadID, uint2 dispatchThreadId : SV_DispatchThreadID)
{
    // Upper-left pixel of the group's footprint, then of this thread's quad.
    int2 groupUL = int2(groupId << 3u) - 4;
    int2 quadUL = int2(groupThreadId << 1u) + groupUL;

    // Sample the four pixel centers of the quad.
    float2 base = float2(quadUL);
    float2 texelToUV = _Size.zw;
    float4 topLeft     = _Source.SampleLevel(sampler_LinearClamp, (base                    + 0.5) * texelToUV, 0.0);
    float4 topRight    = _Source.SampleLevel(sampler_LinearClamp, (base + float2(1.0, 0.0) + 0.5) * texelToUV, 0.0);
    float4 bottomLeft  = _Source.SampleLevel(sampler_LinearClamp, (base + float2(0.0, 1.0) + 0.5) * texelToUV, 0.0);
    float4 bottomRight = _Source.SampleLevel(sampler_LinearClamp, (base + float2(1.0, 1.0) + 0.5) * texelToUV, 0.0);

    // A 16-pixel cache row occupies 8 packed uints; this quad covers cache rows
    // 2 * groupThreadId.y (slot quadSlot) and the one below it (slot quadSlot + 8).
    uint quadSlot = groupThreadId.y * 16u + groupThreadId.x;
    Store2Pixels(quadSlot,      topLeft,    topRight);
    Store2Pixels(quadSlot + 8u, bottomLeft, bottomRight);

    GroupMemoryBarrierWithGroupSync();

    // Horizontal pass: threads x = 0..3 handle cache row 2 * groupThreadId.y and
    // threads x = 4..7 the row below it, each producing two output pixels.
    uint rowStart = groupThreadId.y * 16u;
    uint writeSlot = rowStart + groupThreadId.x * 2u;
    uint readSlot  = rowStart + groupThreadId.x + (groupThreadId.x & 4u);
    BlurHorizontally(writeSlot, readSlot);

    GroupMemoryBarrierWithGroupSync();

    // Vertical pass: blur column groupThreadId.x starting at cache row groupThreadId.y
    // and write the final value for this thread's pixel.
    BlurVertically(dispatchThreadId, groupThreadId.y * 8u + groupThreadId.x);
}

#endif // DISABLE_COMPUTE_SHADERS