/*
  Copyright(C) 2015, tetraface Inc. All rights reserved.
  
  This is a modified implementation of Disney's BRDF shader.
  The original code is here:
    https://github.com/wdas/brdf
*/
/*
# Copyright Disney Enterprises, Inc.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License
# and the following modification to it: Section 6 Trademarks.
# deleted and replaced with:
#
# 6. Trademarks. This License does not grant permission to use the
# trade names, trademarks, service marks, or product names of the
# Licensor and its affiliates, except as required for reproducing
# the content of the NOTICE file.
#
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
*/

#include "world.inc"
//#include "material.inc"
#include "light.inc"
#if OIT
#include "oit.inc"
#endif
// Always carry a tangent frame (Tangent/Binormal) from VS to PS, even when no
// normal map is bound: BRDF uses it to orient the anisotropic specular lobe.
// VS synthesizes a tangent when the mesh supplies a zero-length one.
#define USE_TANGENT 1

// Per-material constants bound at slot b2.
// NOTE(review): under HLSL constant-buffer packing a float4 never straddles a
// 16-byte register, so BaseColor and subsurfaceColor each start a new register
// (12 bytes of padding after presence, 8 bytes after subsurface). The CPU-side
// struct that fills this buffer must mirror that layout — confirm before
// reordering or inserting members.
cbuffer ConstantBufferMaterial : register( b2 )
{
	float presence; //Alpha (not read by the code in this file as shown)
	float4 BaseColor;       // VS: copied to Col (or only its .w when VERTEXCOLOR)
	float4 emissive;        // PS: .xyz added to the lit color before clamping
	float metallic; //reflection; BRDF: blends Cspec0 toward Cdlin, fades diffuse/sheen
	float specular;         // BRDF: scales the dielectric specular reflectance (x0.08)
	float specularTint;     // BRDF: tints dielectric specular toward the base hue
	float anisotropic;      // BRDF: stretches ax/ay via aspect = sqrt(1 - 0.9*anisotropic)
	float roughness;
	float subsurface;       // BRDF: blend factor between diffuse and subsurface terms
	float4 subsurfaceColor; // BRDF: .xyz tints the subsurface term
	float sheen;
	float sheenTint;
	float clearcoat;
	float clearcoatGloss;   // BRDF: maps to clearcoat alpha lerp(.1, .001, gloss)
}

// When the includer does not set ALPHAMAP, it follows TEXTURE.
#ifndef ALPHAMAP
#define ALPHAMAP TEXTURE
#endif

// Optional resources. Each one is compiled in only when its feature macro is
// non-zero, and each sits on a fixed texture/sampler slot pair (t#/s#).
#if TEXTURE
// Base color map; GetBaseColor uses it to replace the vertex/material RGB.
Texture2D ColorMap : register( t0 );
SamplerState ColorMapSampler : register( s0 );
#endif
#if ALPHAMAP
// Only the .a channel is read (GetBaseColor).
Texture2D AlphaMap : register( t1 );
SamplerState AlphaMapSampler : register( s1 );
#endif
#if NORMALMAP
// Tangent-space normal map, decoded from [0,1] to [-1,1] in PS.
Texture2D NormalMap : register( t2 );
SamplerState NormalMapSampler : register( s2 );
#endif
#if SHADOW
// Depth in .r, compared against SMPos in GetShadow.
Texture2D ShadowMap : register( t3 );
SamplerState ShadowMapSampler : register( s3 );
#endif

// Per-vertex input to VS.
struct VS_INPUT
{
	float4 Pos : POSITION;
	float3 Normal : NORMAL;
#if 1 //always //TEXTURE || ALPHAMAP || NORMALMAP
	// Declared unconditionally (the feature test is commented out), so the
	// input layout always contains TEXCOORD0.
	float2 TexCoord : TEXCOORD0;
#endif
#if VERTEXCOLOR
	float4 Col : COLOR0;
#endif
#if NORMALMAP || USE_TANGENT
	// xyz: tangent; a zero-length xyz makes VS synthesize one.
	// w: multiplies the binormal in VS (presumably a +/-1 handedness sign).
	float4 Tangent : TANGENT;
#endif
};

// Output of VS and GSpatch. The field list must stay in step with PS_INPUT
// so the stage signatures link.
struct VS_OUTPUT
{
	float4 Pos : SV_POSITION;
	float4 Col : COLOR;
	float3 Normal : NORMAL;
#if TEXTURE || ALPHAMAP || NORMALMAP
	float2 TexCoord : TEXCOORD0;
#endif
#if NORMALMAP || USE_TANGENT
	float3 Tangent : TANGENT;
	float3 Binormal : BINORMAL;
#endif
	// VS stores the untransformed input position here; PS passes it to IsArbClipped.
	float4 WorldPos : TEXCOORD3;
#if SHADOW
	// In.Pos transformed by ShadowMapProj (clip space of the shadow projection).
	float4 SMPos : TEXCOORD4;
#endif
};

// Pixel shader input. The leading fields mirror VS_OUTPUT (same semantics,
// same order), followed by the system values the pixel stage adds.
struct PS_INPUT
{
	float4 Pos : SV_POSITION;
	float4 Col : COLOR;
	float3 Normal : NORMAL;
#if TEXTURE || ALPHAMAP || NORMALMAP
	float2 TexCoord : TEXCOORD0;
#endif
#if NORMALMAP || USE_TANGENT
	float3 Tangent : TANGENT;
	float3 Binormal : BINORMAL;
#endif
	float4 WorldPos : TEXCOORD3;
#if SHADOW
	float4 SMPos : TEXCOORD4;
#endif
	// NOTE(review): SV_IsFrontFace is TRUE for front-facing primitives, yet this
	// field is named IsBack and PS flips the shading frame when it is true.
	// Presumably the renderer's winding/culling convention makes the
	// rasterizer's "front" the model's back side — confirm before renaming.
	bool IsBack : SV_IsFrontFace;
#if MSAA
    // Not read by the code shown here; presumably consumed by the OIT store path.
    uint Coverage : SV_COVERAGE;
#endif
};

// Pixel shader output: color in target 0, plus (with RENDER_NORMAL) a
// WorldView-transformed normal in target 1.
struct PS_OUTPUT
{
	float4 Color    : SV_Target0;
#if RENDER_NORMAL
	float4 Normal   : SV_Target1;
#endif
};

#include "oit_store.inc"


// Vertex shader
//   Projects the vertex, forwards color/normal/UV, and supplies a tangent
//   frame. WorldPos receives the untransformed input position (In.Pos).
VS_OUTPUT VS(const VS_INPUT In)
{
	VS_OUTPUT Out;

	Out.Pos = mul(In.Pos, WorldViewProj);
	Out.WorldPos = In.Pos;
#if SHADOW
	// Position in the shadow projection; looked up later by GetShadow.
	Out.SMPos = mul(In.Pos, ShadowMapProj);
#endif

#if VERTEXCOLOR
	// Vertex RGB; alpha is the material alpha modulated by the vertex alpha.
	Out.Col = float4(In.Col.xyz, BaseColor.w * In.Col.w);
#else
	Out.Col = BaseColor;
#endif

	Out.Normal = In.Normal;
#if TEXTURE || ALPHAMAP || NORMALMAP
	Out.TexCoord = In.TexCoord;
#endif

#if NORMALMAP || USE_TANGENT
	float3 mesh_tangent = In.Tangent.xyz;
	if (length(mesh_tangent) > 0) {
		// Use the supplied tangent; Tangent.w sets the sign of the binormal.
		Out.Tangent = mesh_tangent;
		Out.Binormal = normalize(cross(mesh_tangent, In.Normal)) * In.Tangent.w;
	} else {
		// No tangent supplied: cross the normal with whichever of +Y / +Z is
		// less aligned with it, so the cross product does not degenerate.
		float3 ref_axis = float3(0,0,1);
		if (abs(dot(In.Normal, float3(0,1,0))) < abs(dot(In.Normal, float3(0,0,1))))
			ref_axis = float3(0,1,0);
		Out.Tangent = normalize(cross(ref_axis, In.Normal));
		Out.Binormal = normalize(cross(Out.Tangent, In.Normal));
	}
#endif
	return Out;
}


// Geometry shader
//   Converts a line-with-adjacency primitive (4 vertices) into a triangle strip
//   of 3 or 4 vertices, emitted in the order 1, 0, 2, 3. Vertex 3 is emitted
//   only when its normal has non-zero length (presumably a zero normal marks
//   a triangular rather than quadrilateral patch).
[maxvertexcount(4)]
void GSpatch(lineadj VS_OUTPUT input[4], inout TriangleStream<VS_OUTPUT> stream)
{
	// Every VS_OUTPUT member is forwarded unchanged, so whole vertices are
	// appended instead of copying each (conditionally compiled) field.
	stream.Append(input[1]);
	stream.Append(input[0]);
	stream.Append(input[2]);

	if (length(input[3].Normal) > 0.0)
		stream.Append(input[3]);

	stream.RestartStrip();
}


// Returns the shadow attenuation for this pixel: 1.0 when lit and 0.5 when the
// pixel lies behind the depth stored in ShadowMap. Always 1.0 when SHADOW is off.
float GetShadow(PS_INPUT In)
{
#if SHADOW
	// SMPos is a clip-space position (VS: mul(In.Pos, ShadowMapProj)). Apply the
	// perspective divide before using it as NDC. For an orthographic shadow
	// projection w == 1 and this is a no-op; for a perspective one the
	// undivided coordinates would sample the wrong texel and compare the
	// wrong depth.
	float3 sm_ndc = In.SMPos.xyz / In.SMPos.w;
	// NDC [-1,1] -> texture [0,1], with Y flipped for texture space.
	float2 smt = float2((sm_ndc.x+1)*0.5, 1.0-(sm_ndc.y+1)*0.5);
	float smz = ShadowMap.Sample(ShadowMapSampler, smt).r;
	// Constant depth bias of 0.005, presumably against self-shadowing artifacts.
	float sm_coef = (sm_ndc.z < smz+0.005) ? 1.0 : 0.5;
	return sm_coef;
#else
	return 1.0;
#endif
}

// Returns the surface base color with its alpha, and discards the pixel when
// the resulting alpha is below 1/255.
//   TEXTURE  : ColorMap replaces the interpolated RGB; alpha is modulated by In.Col.w.
//   ALPHAMAP : alpha is further modulated by AlphaMap's alpha channel.
float4 GetBaseColor(PS_INPUT In)
{
	float4 color;
#if TEXTURE
	color = ColorMap.Sample(ColorMapSampler, In.TexCoord);
	color.a = color.a * In.Col.w;
#else
	color = In.Col;
#endif
#if ALPHAMAP
	color.a = color.a * AlphaMap.Sample(AlphaMapSampler, In.TexCoord).a;
#endif

	// clip() discards when its argument is negative, i.e. alpha < 1/255.
	clip(color.a - 1.0/255.0);
	return color;
}


// Pi, used by the BRDF lobes and by PS to undo their 1/PI normalization.
static const float PI = 3.14159265358979323846;

// Returns x squared.
float sqr(float x)
{
	return x * x;
}

// Schlick Fresnel weight: (1 - u)^5, with (1 - u) clamped to [0,1].
// u is the cosine of the relevant angle (e.g. NdotL, NdotV, LdotH).
float SchlickFresnel(float u)
{
	float one_minus_u = saturate(1-u);
	float squared = one_minus_u*one_minus_u;
	// squared^2 * one_minus_u == one_minus_u^5 without calling pow().
	return squared*squared*one_minus_u;
}

// GTR normal distribution with exponent 1, used by BRDF for the clearcoat lobe.
// For a >= 1 it degenerates to the uniform value 1/PI.
float GTR1(float NdotH, float a)
{
	if (a >= 1)
		return 1/PI;

	float alpha_sq = a*a;
	float denom_t = 1 + (alpha_sq - 1)*NdotH*NdotH;
	return (alpha_sq - 1) / (PI*log(alpha_sq)*denom_t);
}

// Isotropic GTR distribution with exponent 2 (GGX). BRDF below uses the
// anisotropic variant GTR2_aniso instead of this function.
float GTR2(float NdotH, float a)
{
	float alpha_sq = a*a;
	float k = 1 + (alpha_sq-1)*NdotH*NdotH;
	return alpha_sq / (PI * k*k);
}

// Anisotropic GTR2 (GGX) distribution, with roughness ax along the tangent X
// and ay along the binormal Y. Returns 0 instead of dividing by a
// non-positive denominator.
float GTR2_aniso(float NdotH, float HdotX, float HdotY, float ax, float ay)
{
	float stretched = sqr(HdotX/ax) + sqr(HdotY/ay) + NdotH*NdotH;
	float denom = PI * ax*ay * sqr(stretched);
	if (denom > 0)
		return 1 / denom;
	return 0;
}

// Smith GGX masking term for one direction, in the form that folds in its
// share of the 1/(4 NdotL NdotV) denominator (hence no Ndotv in the numerator).
// Ndotv is clamped to [0,1] in the linear term.
float smithG_GGX(float Ndotv, float alphaG)
{
	float alpha_sq = alphaG*alphaG;
	float cos_sq = Ndotv*Ndotv;
	float denom = saturate(Ndotv) + sqrt(alpha_sq + cos_sq - alpha_sq*cos_sq);
	return 1 / denom;
}

// Modified Disney principled BRDF for a single light.
// Returns the reflected contribution already multiplied by NdotL, or zero when
// the light is at or below the surface (NdotL <= 0).
//   L, V       : light and view vectors used for N.L, N.V and the half vector
//                (assumed unit length and pointing away from the surface —
//                TODO confirm against the callers' LightDir/LightPos/ViewDir)
//   N          : shading normal
//   X, Y       : tangent and binormal; they orient the anisotropic specular lobe
//   base_color : RGB base color, used directly as Cdlin (no sRGB-to-linear
//                conversion happens here)
//   shadow     : attenuation factor; it scales only the diffuse/subsurface
//                term, not sheen, specular or clearcoat
// Material parameters (metallic, specular, roughness, sheen, clearcoat, ...)
// come from ConstantBufferMaterial.
float3 BRDF(float3 L, float3 V, float3 N, float3 X, float3 Y, float3 base_color, float shadow)
{
    float NdotL = dot(N,L);
    float NdotV = dot(N,V);
    // Only NdotL is tested. A negative NdotV is tolerated: SchlickFresnel and
    // smithG_GGX clamp their cosine, and ss is forced to 0 when NdotL+NdotV <= 0.
    if (NdotL <= 0){
		return float3(0,0,0);
	}else{
		float3 H = normalize(L+V);
		float NdotH = dot(N,H);
		float LdotH = dot(L,H);

		float3 Cdlin = base_color;
		float Cdlum = .3*Cdlin[0] + .6*Cdlin[1]  + .1*Cdlin[2]; // luminance approx.

		float3 Ctint = Cdlum > 0 ? Cdlin/Cdlum : float3(1,1,1); // normalize lum. to isolate hue+sat
		// Dielectric specular reflectance (specular * 0.08, optionally tinted),
		// blended toward the base color as metallic -> 1.
		float3 Cspec0 = lerp(specular*0.08*lerp(float3(1,1,1), Ctint, specularTint), Cdlin, metallic);
		float3 Csheen = lerp(float3(1,1,1), Ctint, sheenTint);

		// Diffuse fresnel - go from 1 at normal incidence to .5 at grazing
		// and lerp in diffuse retro-reflection based on roughness
		float FL = SchlickFresnel(NdotL);
		float FV = SchlickFresnel(NdotV);
		float Fd90 = 0.5 + 2 * LdotH*LdotH * roughness;
		float Fd = lerp(1, Fd90, FL) * lerp(1, Fd90, FV);

		// Based on Hanrahan-Krueger brdf approximation of isotropic bssrdf
		// 1.25 scale is used to (roughly) preserve albedo
		// Fss90 used to "flatten" retroreflection based on roughness
		float Fss90 = LdotH*LdotH*roughness;
		float Fss = lerp(1, Fss90, FL) * lerp(1, Fss90, FV);
		float ss = (NdotL + NdotV > 0) ? 1.25 * (Fss * (1 / (NdotL + NdotV) - .5) + .5) : 0;

		// specular: anisotropic GTR2 with roughness stretched along X / Y
		float aspect = sqrt(1-anisotropic*.9);
		float ax = max(.001, sqr(roughness)/aspect);
		float ay = max(.001, sqr(roughness)*aspect);
		float Ds = GTR2_aniso(NdotH, dot(H, X), dot(H, Y), ax, ay);
		float FH = SchlickFresnel(LdotH);
		float3 Fs = lerp(Cspec0, float3(1,1,1), FH);
		float roughg = sqr(roughness*.5+.5);
		float Gs = smithG_GGX(NdotL, roughg) * smithG_GGX(NdotV, roughg);

		// sheen
		float3 Fsheen = FH * sheen * Csheen;

		// clearcoat (ior = 1.5 -> F0 = 0.04)
		float Dr = GTR1(NdotH, lerp(.1,.001,clearcoatGloss));
		float Fr = lerp(.04, 1, FH);
		float Gr = smithG_GGX(NdotL, .25) * smithG_GGX(NdotV, .25);

		// Diffuse blends toward a subsurface term tinted by subsurfaceColor;
		// shadow attenuates only this diffuse part.
		float3 diffuse = (1/PI) * lerp(Fd*Cdlin, ss*subsurfaceColor.xyz, subsurface) * shadow;
		// Diffuse and sheen fade out as metallic -> 1; specular and clearcoat do not.
		float3 result = (diffuse+Fsheen)*(1-metallic) + Gs*Fs*Ds + .25*clearcoat*Gr*Fr*Dr;
		return result * NdotL;
	}
}


// Returns v scaled to unit length, or v unchanged when it is the zero vector,
// so a degenerate input never turns into NaN.
float3 NormalizeNonZero(float3 v)
{
	float len_sq = dot(v, v);
	return (len_sq > 0) ? v * rsqrt(len_sq) : v;
}

// Pixel shader
//   Discards pixels rejected by IsArbClipped and builds the shading frame,
//   optionally perturbed by the normal map. The frame is flipped for faces
//   flagged by In.IsBack. The BRDF is summed over the active light(s) and the
//   result is written as a gamma-encoded (1/2.2) color.
//   With OIT, the color is passed to StoreOIT instead. With RENDER_NORMAL, a
//   WorldView-transformed normal is written to the second target.
#if OIT
[earlydepthstencil]
#endif
PS_OUTPUT PS(PS_INPUT In)
{
	PS_OUTPUT output;
	
	if(IsArbClipped(In.WorldPos.xyz))
		discard;

	float4 base_color = GetBaseColor(In);
	float shadow = GetShadow(In);
	
#if NORMALMAP
	// Decode the tangent-space normal from [0,1] to [-1,1].
	float3 normal_col = NormalMap.Sample(NormalMapSampler, In.TexCoord).xyz * 2.0f - 1.0f;
	normal_col *= NormalMapFlip;
	// normalize() returns its result; it must be assigned back. The previous
	// bare `normalize(normal_col);` discarded it and left the vector unnormalized.
	normal_col = NormalizeNonZero(normal_col);

	// Map the decoded vector into the interpolated frame. The tangent and
	// binormal reuse the same coefficients with the basis rows cycled, so an
	// unperturbed texel (0,0,1) yields exactly (Tangent, Binormal, Normal).
	// NOTE(review): for perturbed texels the resulting tangent/binormal are
	// not exactly orthogonal to nrm_v.
	float3x3 mtxn = {In.Tangent, In.Binormal, In.Normal};
	float3 nrm_v = mul(normal_col, mtxn);
	float3x3 mtxt = {In.Binormal, In.Normal, In.Tangent};
	float3 tan_v = mul(normal_col, mtxt);
	float3x3 mtxb = {In.Normal, In.Tangent, In.Binormal};
	float3 bin_v = mul(normal_col, mtxb);
#else
	float3 nrm_v = In.Normal;
	float3 tan_v = In.Tangent;
	float3 bin_v = In.Binormal;
#endif
	// Vectors interpolated across a triangle are generally no longer unit
	// length. Renormalize them before they enter the BRDF's dot products and
	// the anisotropic H.X / H.Y terms.
	nrm_v = NormalizeNonZero(nrm_v);
	tan_v = NormalizeNonZero(tan_v);
	bin_v = NormalizeNonZero(bin_v);

	// IsBack is bound to SV_IsFrontFace (see PS_INPUT). When it is set, the
	// frame is flipped and the color is tinted by BackFaceColor.
	if(In.IsBack){
		nrm_v = -nrm_v;
		tan_v = -tan_v;
		bin_v = -bin_v;
		base_color.xyz *= BackFaceColor.xyz;
	}

#if MULTILIGHT
	float3 b = float3(0,0,0);
	for(int i=0; i<LIGHT_MAX && i<LightNum; i++){
		float3 light_dir;
		// LightPos[i].w == 0: xyz is a position (point light), so aim from the
		// surface toward it. Otherwise xyz is used directly as the direction.
		if(LightPos[i].w == 0){
			light_dir = normalize(LightPos[i].xyz - In.WorldPos.xyz); // point light
		}else{
			light_dir = LightPos[i].xyz; // directional light
		}
		b += BRDF(light_dir, ViewDir, nrm_v, tan_v, bin_v, base_color.xyz, shadow);
	}
#else
	float3 b = BRDF(LightDir, ViewDir, nrm_v, tan_v, bin_v, base_color.xyz, shadow);
#endif

	// The BRDF lobes carry 1/PI factors; scale by PI, add emission, clamp to
	// [0,1], then gamma-encode with exponent 1/2.2.
	output.Color.xyz = pow(saturate(b * PI + emissive.xyz), 1.0/2.2);
	output.Color.w = base_color.w;

#if OIT
    StoreOIT(In, output.Color);
    output.Color = float4(0,0,0,0);	// This does not affect anything because RenderTargetWriteMask is 0.
#endif

#if RENDER_NORMAL
	// NOTE(review): writes the interpolated vertex normal, not the perturbed
	// and back-face-flipped nrm_v — confirm that is intended.
	output.Normal = float4(mul(In.Normal.xyz, (float3x3)WorldView), 1);
#endif

	return output;
}

