import { useEffect, useState, useRef, useContext } from 'react';
import {
  Group,
  Mesh,
  MeshStandardMaterial,
  DoubleSide,
  RepeatWrapping,
  WebGLProgramParametersWithUniforms,
  SRGBColorSpace,
  NoColorSpace,
  Texture,
} from 'three';
import { useFrame } from '@react-three/fiber';
import { GLTF } from 'three/examples/jsm/loaders/GLTFLoader';

import { Triplet } from '../../../../types';
import { basicNoise3, powCurve } from '../shaders';
import { StoreContext } from '../../../../store';
import { BackgroundStore } from '../../../../store/BackgroundStore';

/**
 * Texture family shared by several tree variants. Selects the shared maps
 * `/grasslands/common/<family>-albedo.png` and `/grasslands/common/<family>-normal.jpg`.
 */
type GrlTreeFamily = 'birch' | 'fir' | 'pine';

/**
 * Concrete tree variants in the grasslands asset set. Each value doubles as the
 * folder and file stem of its model: `/grasslands/<type>/<type>.gltf`.
 */
export type GrlTreeType =
  // birch family
  | 'Birch L'
  | 'Birch S'
  // fir family
  | 'Fir L - Variant 1'
  | 'Fir L - Variant 2'
  // pine family
  | 'Pine L'
  | 'Pine M'
  | 'Pine XL';

/**
 * Loads the albedo and normal textures shared by every tree of one family.
 *
 * Both `loadAsset` calls are made up front and awaited together.
 *
 * @param type  Texture family whose shared maps should be loaded.
 * @param store Background store used to fetch the assets.
 * @returns A promise for the `[albedo, normal]` tuple.
 */
const getTextures = async (type: GrlTreeFamily, store: BackgroundStore) => {
  const albedoRequest = store.loadAsset<Texture>(`/grasslands/common/${type}-albedo.png`);
  const normalRequest = store.loadAsset<Texture>(`/grasslands/common/${type}-normal.jpg`);
  return Promise.all([albedoRequest, normalRequest]);
};

/**
 * Maps a concrete tree variant to the texture family whose shared maps it uses
 * (see `getTextures`, which loads `/grasslands/common/<family>-*`).
 *
 * @param type Tree variant.
 * @returns The variant's texture family.
 * @throws Error if `type` is not a known GrlTreeType. This can only happen from
 *   untyped callers; previously such values silently produced `undefined`.
 */
const typeToFamily = (type: GrlTreeType): GrlTreeFamily => {
  switch (type) {
    case 'Birch L':
    case 'Birch S':
      return 'birch';
    case 'Fir L - Variant 1':
    case 'Fir L - Variant 2':
      return 'fir';
    case 'Pine L':
    case 'Pine M':
    case 'Pine XL':
      return 'pine';
    default: {
      // Compile-time exhaustiveness check: adding a GrlTreeType without a case fails here.
      const unhandled: never = type;
      throw new Error(`Unknown GrlTreeType: ${String(unhandled)}`);
    }
  }
};

/** Uniforms driving the leaf wind sway; the same objects are mutated in place every frame. */
type LeafWindUniforms = {
  time: { value: number };
  windIntensity: { value: number };
};

/**
 * Applies the sampling settings for a family's shared albedo/normal textures.
 * NOTE(review): `store.loadAsset` may return cached Texture instances shared with other
 * trees — unverified here; every assignment below is idempotent, so repeating it is harmless.
 */
const configureTextures = (albedo: Texture, normal: Texture) => {
  for (const texture of [albedo, normal]) {
    texture.wrapS = RepeatWrapping;
    texture.wrapT = RepeatWrapping;
    texture.premultiplyAlpha = true;
    texture.flipY = true;
  }
  albedo.colorSpace = SRGBColorSpace;
  // Normal maps store direction vectors, not colours. Tagging them sRGB makes three.js apply an
  // sRGB->linear decode that bends the normals, so they must stay in NoColorSpace.
  normal.colorSpace = NoColorSpace;
};

/**
 * Picks the LOD0 group and its trunk/leaves meshes out of a grasslands tree GLTF.
 * Expects the layout `scene > children[0] > children[0] (LOD0) > [trunk, leaves]`.
 *
 * @throws Error when the GLTF does not have that layout.
 */
const extractLod0 = (gltf: GLTF, type: GrlTreeType) => {
  const lod0Group = gltf.scene.children[0]?.children[0] as Group | undefined;
  const trunkMesh = lod0Group?.children[0] as Mesh | undefined;
  const leavesMesh = lod0Group?.children[1] as Mesh | undefined;
  if (!lod0Group || !trunkMesh?.isMesh || !leavesMesh?.isMesh) {
    throw new Error(`GrlTree: unexpected GLTF layout for "${type}"`);
  }
  return { lod0Group, trunkMesh, leavesMesh };
};

/**
 * Negates the V coordinate of the leaves' UVs.
 *
 * The geometry is tagged via `userData` so that flipping the same geometry twice is a no-op.
 * This matters if `store.loadAsset` returns a cached GLTF (not verifiable here): without the
 * tag, a second mount would flip the UVs back.
 */
const flipLeafUvsOnce = (leavesMesh: Mesh) => {
  const { geometry } = leavesMesh;
  if (geometry.userData.grlUvsFlipped) return;
  const uvs = geometry.getAttribute('uv');
  for (let i = 0; i < uvs.count; i++) {
    uvs.setY(i, uvs.getY(i) * -1);
  }
  // Harmless before the first upload; required if the geometry was already on the GPU.
  uvs.needsUpdate = true;
  geometry.userData.grlUvsFlipped = true;
};

/** Builds the trunk material; birch trunks get an extra `powCurve(rgb, 0.6)` pass. */
const createTrunkMaterial = (family: GrlTreeFamily, albedo: Texture, normal: Texture) => {
  const material = new MeshStandardMaterial({
    map: albedo,
    normalMap: normal,
    roughness: 0.84299999475479126,
    metalness: 0,
  });
  material.onBeforeCompile = (shader: WebGLProgramParametersWithUniforms) => {
    if (family === 'birch') {
      shader.fragmentShader = `${powCurve}` + shader.fragmentShader;
      shader.fragmentShader = shader.fragmentShader.replace(
        '#include <opaque_fragment>',
        `
          #include <opaque_fragment>
        gl_FragColor.rgb = powCurve(gl_FragColor.rgb, 0.6);
        `,
      );
    }
  };
  // The injected GLSL differs per part (trunk vs leaves) and per family. Keying on both keeps
  // three.js from reusing one part's program for the other.
  material.customProgramCacheKey = () => `grl-trunk-${family}`;
  return material;
};

/**
 * Builds the alpha-tested, double-sided leaves material with a vertex wind sway.
 * Birch leaves also get a noise-based colour variation.
 *
 * @returns The material plus the uniform objects that drive its wind animation.
 */
const createLeavesMaterial = (family: GrlTreeFamily, albedo: Texture, normal: Texture) => {
  const uniforms: LeafWindUniforms = { time: { value: 0 }, windIntensity: { value: 0 } };
  const material = new MeshStandardMaterial({
    map: albedo,
    normalMap: normal,
    alphaTest: 0.45,
    roughness: 1,
    metalness: 0,
    premultipliedAlpha: true,
    side: DoubleSide,
  });
  material.onBeforeCompile = (shader: WebGLProgramParametersWithUniforms) => {
    // Hand the same uniform objects to every compile so per-frame updates always reach the GPU.
    shader.uniforms.time = uniforms.time;
    shader.uniforms.windIntensity = uniforms.windIntensity;
    shader.vertexShader =
      `
        uniform float time;
        uniform float windIntensity;
        varying vec3 vPosition;
        ` + shader.vertexShader;
    shader.vertexShader = shader.vertexShader.replace(
      '#include <begin_vertex>',
      `
      #include <begin_vertex>
      vPosition = transformed;
      float windEffect = sin((time * 3.0) + transformed.x * 2.0 + transformed.z * 2.0) * windIntensity * 0.04;
      windEffect += sin(time * 10.0 + transformed.x + transformed.z) * windIntensity * 0.01;
      transformed.y += windEffect;
      `,
    );

    shader.fragmentShader =
      `
      varying vec3 vPosition;
      ${basicNoise3}
      ` + shader.fragmentShader;
    if (family === 'birch') {
      // TODO: this is a bit wonky and doesn't really match the grove tree
      shader.fragmentShader = shader.fragmentShader.replace(
        '#include <opaque_fragment>',
        `
        #include <opaque_fragment>
        gl_FragColor.r = gl_FragColor.r > 0.01 ? pow(abs(gl_FragColor.r), 1.0 - noise(vPosition / 1.0) * 0.9) * 0.8 : gl_FragColor.r;
        gl_FragColor.g = gl_FragColor.g > 0.02 ? pow(abs(gl_FragColor.g), 1.0 - noise(vPosition / 4.0) * 0.8) * 0.5 : gl_FragColor.g;
        gl_FragColor.b = gl_FragColor.b > 0.015 ? pow(abs(gl_FragColor.b), 1.0 - noise(vPosition / 3.0) * 0.4) * 0.8 : gl_FragColor.b;
        `,
      );
    }
  };
  material.customProgramCacheKey = () => `grl-leaves-${family}`;
  return { material, uniforms };
};

/**
 * Renders one grasslands tree variant.
 *
 * Loads the family's shared textures and the variant's GLTF through the background store, then
 * assigns custom trunk/leaves materials to the LOD0 meshes. It renders nothing until both have
 * loaded. Leaves sway with `store.wind.intensity`. The effect reloads when `type` or `store`
 * changes, and disposes the materials it created on cleanup.
 *
 * NOTE(review): the LOD0 group from the GLTF is mounted directly. If `store.loadAsset` caches
 * GLTFs, two GrlTree instances of the same `type` would share one Object3D (it can only have a
 * single parent) — consider cloning if that case occurs.
 *
 * @param type     Tree variant to render.
 * @param position Forwarded to the rendered primitive.
 * @param scale    Forwarded to the rendered primitive.
 * @param rotation Forwarded to the rendered primitive.
 */
export const GrlTree = ({
  type,
  position = [0, 0, 0],
  scale = [1, 1, 1],
  rotation = [0, 0, 0],
}: {
  type: GrlTreeType;
  position?: Triplet;
  scale?: Triplet;
  rotation?: Triplet;
}) => {
  const { background: store } = useContext(StoreContext);
  const [group, setGroup] = useState<Group>();
  const windUniforms = useRef<LeafWindUniforms | undefined>(undefined);

  useEffect(() => {
    const family = typeToFamily(type);
    let cancelled = false;
    const ownedMaterials: MeshStandardMaterial[] = [];

    Promise.all([getTextures(family, store), store.loadAsset<GLTF>(`/grasslands/${type}/${type}.gltf`)])
      .then(([[albedo, normal], gltf]) => {
        // The component unmounted or `type` changed while loading: drop this result.
        if (cancelled) return;

        configureTextures(albedo, normal);
        const { lod0Group, trunkMesh, leavesMesh } = extractLod0(gltf, type);
        flipLeafUvsOnce(leavesMesh);

        const trunkMaterial = createTrunkMaterial(family, albedo, normal);
        const leaves = createLeavesMaterial(family, albedo, normal);
        ownedMaterials.push(trunkMaterial, leaves.material);

        trunkMesh.material = trunkMaterial;
        leavesMesh.material = leaves.material;
        for (const mesh of [trunkMesh, leavesMesh]) {
          mesh.castShadow = true;
          mesh.receiveShadow = true;
        }

        windUniforms.current = leaves.uniforms;
        setGroup(lod0Group);
      })
      .catch((error: unknown) => {
        console.error(`GrlTree: failed to load "${type}"`, error);
      });

    return () => {
      cancelled = true;
      windUniforms.current = undefined;
      // Textures and geometry come from the store and may be shared, so only the materials
      // created by this effect are disposed.
      ownedMaterials.forEach(material => material.dispose());
    };
  }, [type, store]);

  useFrame(state => {
    const uniforms = windUniforms.current;
    if (!uniforms) return;
    // Mutate the existing uniform objects instead of allocating new ones every frame.
    uniforms.time.value = state.clock.elapsedTime;
    uniforms.windIntensity.value = store.wind.intensity;
  });

  return group ? <primitive object={group} position={position} scale={scale} rotation={rotation} /> : null;
};
