// Depth scale constant. The SSAO pass uses 2.0/GBUFFER_MAX_DEPTH as its sampling-sphere
// radius in stored-depth units, and the blur passes use 1.0/GBUFFER_MAX_DEPTH as their
// depth-similarity threshold.
// NOTE(review): presumably the view-space distance that maps to depth 1.0 when the
// G-buffer is written -- confirm against the G-buffer writer, which is not in this file.
#define GBUFFER_MAX_DEPTH 500.0f

// Vertex input consumed by ssao_vs.
struct VIn
{
    float4 p   : POSITION;   // vertex position; ssao_vs transforms it and uses sign(p.xy) to identify the corner
    float3 n   : NORMAL;     // vertex normal; not read by ssao_vs
    float2 uv  : TEXCOORD0;  // incoming texture coordinate; not read by ssao_vs, which rebuilds UVs from p
};

// Vertex output produced by ssao_vs.
struct VOut
{
    float4 p   : POSITION;   // position after multiplication by the wvp matrix
    float2 uv  : TEXCOORD0;  // texture coordinate in [0,1] derived from the sign of the input position (Y flipped)
    float3 ray : TEXCOORD1;  // farCorner scaled per-corner by (sign(x), sign(y), 1)
};

// Interpolated pixel-shader input for ssao_ps; mirrors the TEXCOORD members of VOut.
struct PIn
{
    float2 uv  : TEXCOORD0;  // texture coordinate used for all G-buffer / noise lookups
    float3 ray : TEXCOORD1;  // interpolated ray from ssao_vs; not read by ssao_ps in this file
};

// Vertex shader for the SSAO pass.
//
// IN        - input vertex (only IN.p is read).
// wvp       - matrix applied to IN.p to produce the output position.
// farCorner - vector scaled per-vertex by the sign of the vertex XY to build OUT.ray.
//
// Returns the transformed position, a [0,1] texture coordinate and the per-corner ray.
VOut ssao_vs(VIn IN, uniform float4x4 wvp, uniform float3 farCorner)
{
    VOut OUT;
    OUT.p = mul(wvp, IN.p);

    // Snap the XY position to -1 / 0 / +1 so small inaccuracies in the input
    // do not leak into the UVs or the ray. The explicit .xy swizzle replaces an
    // implicit float4 -> float2 truncation of sign(IN.p).
    float2 cornerSign = sign(IN.p.xy);

    // Convert to image space: flip Y, then remap [-1,1] to [0,1].
    OUT.uv = (float2(cornerSign.x, -cornerSign.y) + 1.0) * 0.5;

    // Ray for this corner: farCorner with X/Y mirrored according to the corner.
    OUT.ray = farCorner * float3(cornerSign, 1);
    return OUT;
}

// Rebuilds a three-component vector from two components stored in [0,1].
//
// xy - the stored X/Y components; they are remapped to [-1,1] before use.
//
// Returns float3(x, y, z) with z = sqrt(1 - x*x - y*y), i.e. the non-negative Z
// that makes the vector unit length. The radicand is clamped to [0,1] so that
// an input whose remapped XY length slightly exceeds 1 (e.g. from quantisation
// or filtering) yields z = 0 instead of NaN.
float3 computeZ(float2 xy)
{
    xy = 2.0 * xy - 1.0;
    return float3(xy, sqrt(saturate(1.0 - dot(xy, xy))));
}

// for ps_3_0, we want to use tex2Dlod because it's faster
/*ps_3_0 float4 TEX2DLOD(sampler2D map, float2 uv)
{
    return tex2Dlod(map, float4(uv.xy, 0, 0));
}*/

// Texture lookup wrapper used by every pixel shader in this file.
// This variant performs a plain tex2D fetch of `map` at `uv` and returns the texel.
float4 TEX2DLOD(sampler2D map, float2 uv)
{
    float4 texel = tex2D(map, uv);
    return texel;
}

// Number of symmetric sample pairs taken by the SSAO pixel shader.
#define NUM_PAIRS 6
// Parenthesised so the macro expands as a single value inside larger
// expressions (e.g. `x / RANGE`).
#define RANGE (60.0/1024.0)

#define pi 3.14159

// Returns the i-th base sample offset: a point at angle i/NUM_PAIRS of a full
// turn whose distance from the origin is (i+1)/(NUM_PAIRS+2).
float2 GetRotatedSample(float i)
{
    float radius = (i+1) / (NUM_PAIRS+2);
    float angle  = i / NUM_PAIRS * 2 * pi;
    float2 direction = float2(cos(angle), sin(angle));
    return radius * direction;
}

// Two samples per pair plus one. Parenthesised so the macro expands as a
// single value inside larger expressions (e.g. `2 * NUM_SAMPLES`).
#define NUM_SAMPLES (NUM_PAIRS*2+1)

// SSAO pixel shader.
//
// For the current pixel it decodes a depth from geomMap, then takes NUM_PAIRS
// pairs of samples placed symmetrically around the pixel (rotated per-pixel by
// a value read from randMap) and compares their depths with the pixel's depth
// to estimate how occluded the pixel is.
//
// IN       - interpolated inputs; only IN.uv is read (IN.ray is unused here).
// ptMat    - unused in this shader; kept so the parameter list stays unchanged.
// far      - unused in this shader; kept so the parameter list stays unchanged.
// texSize  - xy is used to tile the noise lookup; x also scales the pixel clamp
//            of the sampling radius.
// geomMap  - G-buffer texture; depth is decoded from its z and w channels.
// randMap  - noise texture that supplies the per-pixel rotation.
//
// Returns (v, v, v, 1) where v is the occlusion term.
float4 ssao_ps(
    PIn IN,
    uniform float4x4 ptMat,
    uniform float far,
    uniform float4 texSize,
    uniform sampler2D geomMap : TEXUNIT0,
    uniform sampler2D randMap  : TEXUNIT1): COLOR0
{
    // Depth is split over two channels: depth = w + z / 256.
    float2 bitShifts = float2(1.0/256.0, 1);

    half4 geomTex = TEX2DLOD(geomMap, IN.uv);
    half baseDepth = dot(geomTex.zw , bitShifts);
    // NOTE(review): the blur passes decode this same depth in float; half keeps
    // fewer mantissa bits than the two channels provide -- confirm this is intended.

    // Noise lookup tiles every 4 pixels (uv * texSize / 4, wrapped by frac);
    // the texel is scaled by 2 and shifted by -1 (presumably [0,1] -> [-1,1]).
    half4 noiseTex = TEX2DLOD(randMap, frac(IN.uv*texSize.xy /4))*2-1;

    // 2x2 matrix built from the noise xy; applied to every base offset below.
    half2x2 rotation =
    {
        { noiseTex.y, noiseTex.x },
        { -noiseTex.x, noiseTex.y }
    };

    // Base offsets, one per pair. To use more pairs, raise NUM_PAIRS and append
    // the matching GetRotatedSample(n) entries.
    const half2 OFFSETS1[NUM_PAIRS] =
    {
        GetRotatedSample(0),
        GetRotatedSample(1),
        GetRotatedSample(2),
        GetRotatedSample(3),
        GetRotatedSample(4),
        GetRotatedSample(5)
    };

    // Accumulators start at 1 and 2, i.e. an initial ratio of 0.5.
    half occ = 1;
    half numSamples = 2;

    // Sampling-sphere radius expressed in stored-depth units.
    half sphereRadiusZB = 2.0f / GBUFFER_MAX_DEPTH;

    #define MINPIXEL 0.0f
    #define MAXPIXEL 100.0f

    // Radius in texture space: shrinks as depth grows, clamped to
    // [MINPIXEL, MAXPIXEL] divided by texSize.x.
    half radiusTex = clamp(sphereRadiusZB / baseDepth, MINPIXEL / texSize.x, MAXPIXEL / texSize.x);

    for(int i = 0; i < NUM_PAIRS; i++)
    {
        half2 offset1 = mul(rotation, OFFSETS1[i]);

        // Texture coordinates are kept in float: half precision near 1.0 is
        // about 1/2048, which is coarser than one texel on common target sizes
        // and would quantise the sample positions.
        float2 offseted1 = IN.uv + offset1 * radiusTex;
        float2 offseted2 = IN.uv - offset1 * radiusTex;

        half4 offsetSample1 = TEX2DLOD(geomMap, offseted1);
        half4 offsetSample2 = TEX2DLOD(geomMap, offseted2);

        half2 offsetDepth;
        offsetDepth.x = dot(offsetSample1.zw , bitShifts);
        offsetDepth.y = dot(offsetSample2.zw , bitShifts);

        // Signed depth difference of each sample relative to the centre pixel.
        half2 diff = offsetDepth - baseDepth.xx;

        // Same radial fraction that GetRotatedSample used for this offset.
        half normalizedOffsetLen = (half)(i+1)/(NUM_PAIRS+2);

        // Depth extent used to normalise diff at this radial distance:
        // 1.5 * radius * sqrt(1 - r^2).
        half2 segmentDiff   = 1.5*sphereRadiusZB.xx*sqrt(1-normalizedOffsetLen*normalizedOffsetLen);

        half2 normalizedDiff = (diff / segmentDiff) + 0.5;

        half minDiff = min(normalizedDiff.x, normalizedDiff.y);

        // Pair weight: 1 when minDiff >= 0, falling linearly to 0 at minDiff = -1.
        half sampleadd = saturate(1+minDiff);

        occ += (saturate(normalizedDiff.x) +saturate(normalizedDiff.y))*sampleadd;
        numSamples += 2 * sampleadd;
    }

    occ = occ / numSamples;

    half finalocc = saturate(occ*2);

    // Pixels whose decoded depth is above 1 - 1/256 get +1, which pushes the
    // result to at least 1 (presumably background pixels -- confirm).
    if(baseDepth > (1.0f-1/256.0f))
        finalocc += 1;

    return half4(finalocc, finalocc, finalocc, 1);
}

// Multiplier applied to the specular term in ssaoBlurY_ps when it is combined
// with the diffuse * SSAO term.
#define SPECULAR_WEIGHT 3

// Number of taps taken on each side of the centre pixel by both blur passes;
// the centre tap has weight NUM_BLUR_SAMPLES + 1.
#define NUM_BLUR_SAMPLES 3

// Horizontal depth-aware blur.
//
// uv         - texture coordinate of the pixel being blurred.
// invTexSize - x is the horizontal step between taps.
// map        - texture to blur.
// geomMap    - G-buffer texture; depth is decoded from its z and w channels.
//
// Takes NUM_BLUR_SAMPLES taps on each side of the pixel with linearly
// decreasing weights. A tap whose depth differs from the centre depth by
// 1/GBUFFER_MAX_DEPTH or more gets weight 0. Returns the weighted average.
float4 ssaoBlurX_ps(float2 uv : TEXCOORD0,
    uniform float4 invTexSize,
    uniform sampler2D map : TEXUNIT0, uniform sampler2D geomMap : TEXUNIT1): COLOR0
{
    float2 o = float2(invTexSize.x, 0);

    // Depth is split over two channels: depth = w + z / 256.
    float2 bitShifts = float2(1.0/256.0, 1);

    // Centre tap carries the largest weight.
    float4 sum = TEX2DLOD(map, uv) * (NUM_BLUR_SAMPLES + 1);
    float denom = NUM_BLUR_SAMPLES + 1;

    float baseDepth = dot(TEX2DLOD(geomMap, uv).zw , bitShifts);

    // Maximum depth difference for a tap to contribute.
    float sphereRadiusZB = 1.0f / GBUFFER_MAX_DEPTH;

    // Positive side first, then negative side (same accumulation order as two
    // separate loops). geomMap is fetched once per tap.
    for (int side = 0; side < 2; ++side)
    {
        float dir = (side == 0) ? 1.0 : -1.0;

        for (int i = 1; i <= NUM_BLUR_SAMPLES; ++i)
        {
            float2 nuv = uv + o * (dir * i);

            float newDepth = dot(TEX2DLOD(geomMap, nuv).zw , bitShifts);
            float delta = newDepth - baseDepth;

            // Linear falloff with distance, zeroed across depth discontinuities.
            float coef = (NUM_BLUR_SAMPLES + 1 - i) * (abs(delta) < sphereRadiusZB);

            sum += TEX2DLOD(map, nuv) * coef;
            denom += coef;
        }
    }

    return sum / denom;
}

// Vertical depth-aware blur followed by the final lighting combine.
//
// uv         - texture coordinate of the pixel being blurred.
// invTexSize - y is the vertical step between taps.
// map        - texture to blur; only its red channel is read.
// geomMap    - G-buffer texture; depth is decoded from z and w, and x / y are
//              read as the specular and diffuse terms of the centre pixel.
//
// Takes NUM_BLUR_SAMPLES taps on each side of the pixel with linearly
// decreasing weights. A tap whose depth differs from the centre depth by
// 1/GBUFFER_MAX_DEPTH or more gets weight 0. The blurred value is then mixed
// with the specular and diffuse terms and returned.
float4 ssaoBlurY_ps(float2 uv : TEXCOORD0,
    uniform float4 invTexSize,
    uniform sampler2D map : TEXUNIT0, uniform sampler2D geomMap : TEXUNIT1): COLOR0
{
    float2 o = float2(0, invTexSize.y);

    // Depth is split over two channels: depth = w + z / 256.
    float2 bitShifts = float2(1.0/256.0, 1);

    // Centre tap carries the largest weight.
    float sum = TEX2DLOD(map, uv).r * (NUM_BLUR_SAMPLES + 1);
    float denom = NUM_BLUR_SAMPLES + 1;

    float4 geom = TEX2DLOD(geomMap, uv);
    float baseDepth = dot(geom.zw , bitShifts);

    // Maximum depth difference for a tap to contribute.
    float sphereRadiusZB = 1.0f / GBUFFER_MAX_DEPTH;

    // Positive side first, then negative side (same accumulation order as two
    // separate loops). geomMap is fetched once per tap.
    for (int side = 0; side < 2; ++side)
    {
        float dir = (side == 0) ? 1.0 : -1.0;

        for (int i = 1; i <= NUM_BLUR_SAMPLES; ++i)
        {
            float2 nuv = uv + o * (dir * i);

            float newDepth = dot(TEX2DLOD(geomMap, nuv).zw , bitShifts);
            float delta = newDepth - baseDepth;

            // Linear falloff with distance, zeroed across depth discontinuities.
            float coef = (NUM_BLUR_SAMPLES + 1 - i) * (abs(delta) < sphereRadiusZB);

            sum += TEX2DLOD(map, nuv).r * coef;
            denom += coef;
        }
    }

    float specular = geom.x;
    float diffuse = geom.y;

    float ssaoTerm = sum / denom;

    // Specular is weighted up so it overrides the SSAO term faster; the 0.001
    // in the denominator avoids a division by zero when both terms are 0.
    return (SPECULAR_WEIGHT*specular + diffuse * ssaoTerm) / (SPECULAR_WEIGHT*specular + diffuse + 0.001);
}



