/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/* eslint guard-for-in: 0 */
/* eslint no-restricted-syntax: 0 */
/* eslint no-else-return: 0 */
/* eslint camelcase: 0 */
/* eslint no-param-reassign: 0 */
/* eslint no-underscore-dangle: 0 */
import * as tfcore from '@tensorflow/tfjs-core';
import * as tf from '@tensorflow/tfjs';
import { TUNABLE_FLAG_VALUE_RANGE_MAP } from './params';

/**
 * Detect an iPhone, iPad or iPod from the browser user-agent string.
 *
 * @returns {boolean} True when the user agent mentions one of those devices
 *     (case-insensitive).
 */
export function isiOS() {
  const { userAgent } = navigator;
  return /iPhone|iPad|iPod/i.test(userAgent);
}

/**
 * Detect an Android device from the browser user-agent string.
 *
 * @returns {boolean} True when the user agent contains "Android"
 *     (case-insensitive).
 */
export function isAndroid() {
  const { userAgent } = navigator;
  return /Android/i.test(userAgent);
}

/**
 * Whether the current browser looks like a mobile device (Android or iOS),
 * based solely on the user-agent checks in `isAndroid` and `isiOS`.
 *
 * @returns {boolean}
 */
export function isMobile() {
  if (isAndroid()) {
    return true;
  }
  return isiOS();
}

/**
 * Reset the target backend and make it the active backend.
 *
 * If the engine already holds an instance of the backend (it is present in
 * `engine.registry`), the backend is removed and re-registered with the same
 * factory and priority before `setBackend` is called. Presumably this forces
 * `setBackend` to build a fresh instance that observes newly set environment
 * flags — confirm against the tfjs engine version in use.
 *
 * @param backendName The name of the backend to be reset.
 * @throws {Error} If no factory is registered under `backendName`.
 */
async function resetBackend(backendName) {
  const engine = tfcore.engine();
  if (!(backendName in engine.registryFactory)) {
    throw new Error(`${backendName} backend is not registered.`);
  }

  if (backendName in engine.registry) {
    const backendFactory = tfcore.findBackendFactory(backendName);
    // registerBackend() defaults the priority to 1; carry over the original
    // priority so re-registration doesn't change the engine's backend ranking.
    // If the entry has no priority, `undefined` lets the default apply.
    const { priority } = engine.registryFactory[backendName];
    tfcore.removeBackend(backendName);
    tfcore.registerBackend(backendName, backendFactory, priority);
  }

  await tfcore.setBackend(backendName);
}

/**
 * Set environment flags and, for a `tfjs-*` backend, reset that backend.
 *
 * This is a wrapper of `tfcore.env().setFlags()` that constrains callers to
 * tunable flags (the keys of `TUNABLE_FLAG_VALUE_RANGE_MAP`) and, for each
 * flag, to one of the values listed for it in that map. All flags are
 * validated before any of them is applied.
 *
 * ```js
 * const flagConfig = {
 *   WEBGL_PACK: false,
 * };
 * await setBackendAndEnvFlags(flagConfig, 'tfjs-webgl');
 *
 * console.log(tfcore.env().getBool('WEBGL_PACK')); // false
 * ```
 *
 * @param flagConfig An object of flag-value pairs. When null or undefined,
 *     nothing is set and no backend is reset.
 * @param backend Backend identifier of the form `<runtime>-<backendName>`,
 *     e.g. `tfjs-webgl`. Only the `tfjs` runtime triggers a backend reset; a
 *     missing or non-string value skips the reset.
 * @throws {Error} If `flagConfig` is not an object, names an unknown flag, or
 *     gives a value outside that flag's allowed values; also propagates the
 *     error from resetting an unregistered backend.
 */
export async function setBackendAndEnvFlags(flagConfig, backend) {
  if (flagConfig == null) {
    return;
  } else if (typeof flagConfig !== 'object') {
    throw new Error(`An object is expected, while a(n) ${typeof flagConfig} is found.`);
  }

  // Only own keys count: inherited names such as `toString` must not pass the
  // lookup (they previously crashed on `.indexOf` with a TypeError).
  const isTunable = (flag) => Object.prototype.hasOwnProperty.call(TUNABLE_FLAG_VALUE_RANGE_MAP, flag);

  // Validate every flag before touching the environment so a bad entry leaves
  // the current flags unchanged.
  for (const flag of Object.keys(flagConfig)) {
    // TODO: check whether flag can be set as flagConfig[flag].
    if (!isTunable(flag)) {
      throw new Error(`${flag} is not a tunable or valid environment flag.`);
    }
    const allowedValues = TUNABLE_FLAG_VALUE_RANGE_MAP[flag];
    if (!allowedValues.includes(flagConfig[flag])) {
      throw new Error(
        `${flag} value is expected to be in the range [${allowedValues}], while ${flagConfig[flag]}`
        + ' is found.',
      );
    }
  }

  tfcore.env().setFlags(flagConfig);

  if (typeof backend !== 'string') {
    return;
  }
  const [runtime, backendName] = backend.split('-');
  if (runtime === 'tfjs') {
    await resetBackend(backendName);
  }
}


/**
 * Crop the largest centered square out of an image tensor, optionally
 * converting RGB to a single grayscale channel.
 *
 * @param img Image tensor read as [height, width, channels]: shape[0] is
 *     used as the height, shape[1] as the width, and three axes are sliced.
 * @param grayscaleMode When true and `grayscaleInput` is false, return one
 *     channel computed as a weighted sum of the R, G and B channels.
 * @param grayscaleInput When true, skip the RGB-to-gray conversion and just
 *     crop the input.
 * @returns A tensor of shape [size, size, C] with size = min(height, width).
 *     C is 1 in grayscale mode, otherwise min(channels, 3) (a fourth, e.g.
 *     alpha, channel is dropped).
 */
export function cropTensor(img, grayscaleMode, grayscaleInput) {
  const [height, width, channels] = img.shape;
  const size = Math.min(height, width);
  // Floor the offsets: when (height - size) or (width - size) is odd, the exact
  // center start is fractional, which is not a valid slice index.
  const beginHeight = Math.floor((height - size) / 2);
  const beginWidth = Math.floor((width - size) / 2);

  if (grayscaleMode && !grayscaleInput) {
    const rgbCropped = img.slice([beginHeight, beginWidth, 0], [size, size, 3]);
    // Per-channel weights, approximately the ITU-R BT.601 luma coefficients.
    // The [3]-shaped weights broadcast over the last (channel) axis.
    const rgbWeights = [0.2989, 0.5870, 0.1140];
    const gray = tfcore.sum(tfcore.mul(rgbCropped, rgbWeights), -1);
    return tfcore.expandDims(gray, -1);
  }

  // Keep at most three channels, but don't request more than the input has
  // (e.g. a single-channel grayscale input).
  const numChannels = Math.min(channels, 3);
  return img.slice([beginHeight, beginWidth, 0], [size, size, numChannels]);
}

/**
 * Grab the center square of a raster element as a normalized, batched tensor.
 *
 * Runs inside `tf.tidy`, so intermediate tensors are released and only the
 * returned tensor is kept.
 *
 * @param rasterElement Any input accepted by `tf.browser.fromPixels`.
 * @param grayscale When true, the crop is converted to a single channel.
 * @returns A float32 tensor of shape [1, size, size, channels], where size is
 *     the shorter side of the input, with values computed as pixel / 127 - 1.
 */
export function capture(rasterElement, grayscale = false) {
  return tf.tidy(() => {
    const pixels = tf.browser.fromPixels(rasterElement);

    // Crop the image so we're using the center square.
    const cropped = cropTensor(pixels, grayscale);

    // Add a leading batch dimension so the batch size is 1.
    const batchedImage = cropped.expandDims(0);

    // Map pixels from [0, 255] to about [-1, 1]: x / 127 - 1. Dividing by 127
    // (not 127.5) makes 255 map to ~1.008; the formula is kept unchanged so
    // inputs match what existing models presumably saw in training (verify).
    // tf.cast replaces the deprecated chained toFloat().
    return tf.cast(batchedImage, 'float32').div(tf.scalar(127)).sub(tf.scalar(1));
  });
}


/**
 * Flatten a list of {x, y, z} points into a single flat array.
 *
 * @param hand Array of objects with `x`, `y` and `z` properties.
 * @returns [x0, y0, z0, x1, y1, z1, ...] in input order.
 */
export function normalizePlain(hand) {
  const flattened = [];
  hand.forEach(({ x, y, z }) => {
    flattened.push(x, y, z);
  });
  return flattened;
}

/**
 * Convert a list of {x, y, z} points into a list of [x, y, z] triples.
 *
 * @param hand Array of objects with `x`, `y` and `z` properties.
 * @returns A new array with one [x, y, z] array per input point.
 */
export function normalize(hand) {
  const toTriple = ({ x, y, z }) => [x, y, z];
  return hand.map(toTriple);
}

/**
 * Build a one-hot vector as a plain array.
 *
 * @param label Index to set to 1.
 * @param numClasses Length of the zero-filled vector.
 * @returns An array of `numClasses` zeros with a 1 at position `label`.
 *     An out-of-range `label` is written as-is, so it extends the array.
 */
export function flatOneHot(label, numClasses) {
  const encoded = Array(numClasses).fill(0);
  encoded[label] = 1;
  return encoded;
}

/**
 * Return a shuffled copy of `array` using the Fisher-Yates algorithm.
 *
 * @param array Array to shuffle; it is copied first and never modified.
 * @param seed Optional zero-argument function used instead of `Math.random`
 *     to draw each swap position; any falsy value falls back to
 *     `Math.random`.
 * @returns A new shuffled array, or null when `array` is falsy.
 */
export function fisherYates(array, seed) {
  if (!array) {
    return null;
  }
  const nextRandom = seed || Math.random;

  // Work on a copy so the caller's array is left untouched.
  const result = array.slice();
  let last = array.length - 1;
  while (last > 0) {
    const pick = Math.floor(nextRandom() * (last + 1));
    const held = result[last];
    result[last] = result[pick];
    result[pick] = held;
    last -= 1;
  }
  return result;
}

/**
 * Shuffle `array` and `target` with one shared Fisher-Yates permutation, so
 * `shuffledTarget[i]` still belongs to `shuffled[i]`.
 *
 * @param array Array to shuffle; copied, never modified.
 * @param target Array of the same length as `array`; element i pairs with
 *     `array[i]`. Copied, never modified.
 * @param seed Optional zero-argument function used instead of `Math.random`
 *     to draw each swap position; any falsy value falls back to
 *     `Math.random`.
 * @returns `{ shuffled, shuffledTarget }`, or null when `array` is falsy.
 * @throws {Error} If `target` is missing or its length differs from `array`'s
 *     (pairs would otherwise be silently misaligned).
 */
export function fisherYatesWithTarget(array, target, seed = null) {
  if (!array) {
    return null;
  }
  if (!target || target.length !== array.length) {
    throw new Error('target must have the same length as array.');
  }
  const nextRandom = seed || Math.random;

  // Clone both arrays so the caller's data is not modified.
  const shuffled = array.slice();
  const shuffledTarget = target.slice();

  for (let i = shuffled.length - 1; i > 0; i -= 1) {
    const randomIndex = Math.floor(nextRandom() * (i + 1));

    [shuffled[i], shuffled[randomIndex]] = [shuffled[randomIndex], shuffled[i]];
    [shuffledTarget[i], shuffledTarget[randomIndex]] = [shuffledTarget[randomIndex], shuffledTarget[i]];
  }
  return { shuffled, shuffledTarget };
}

/**
 * Build shuffled train/validation `tf.data` datasets from per-class samples.
 *
 * Each class is shuffled and then split: ceil(validationFraction * n) of its
 * n samples go to validation, the rest to training. Every example is labeled
 * with a one-hot vector of length `samples.length` whose hot index is the
 * class index. Finally both splits are shuffled across classes.
 *
 * The caller's `samples` array is not modified.
 *
 * @param samples Array indexed by class; each entry is an array of feature
 *     values (one element per example). Falsy entries are skipped but still
 *     count toward the one-hot length.
 * @param seed Optional zero-argument function used instead of `Math.random`
 *     for every shuffle; null uses `Math.random`.
 * @param validationFraction Fraction of each class held out for validation,
 *     rounded up per class. Must be within [0, 1]. Defaults to 0.15.
 * @returns `{ trainDataset, validationDataset }`, each a zipped `{ xs, ys }`
 *     dataset.
 * @throws {Error} If `validationFraction` is outside [0, 1].
 */
export function convertToTfDataset(samples = [], seed = null, validationFraction = 0.15) {
  if (!(validationFraction >= 0 && validationFraction <= 1)) {
    throw new Error(`validationFraction must be in [0, 1], while ${validationFraction} is found.`);
  }
  const numClasses = samples.length;

  // First shuffle each class individually, into a new array rather than
  // writing back into the caller's `samples`.
  // TODO: we could basically replicate this by inserting randomly
  const shuffledClasses = [];
  for (let i = 0; i < numClasses; i += 1) {
    shuffledClasses.push(fisherYates(samples[i], seed));
  }

  // Then break each class into train and validation examples.
  let trainDataset = [];
  let validationDataset = [];
  shuffledClasses.forEach((classSamples, classIndex) => {
    // fisherYates returns null for a missing class; skip it.
    if (!classSamples) {
      return;
    }
    const label = flatOneHot(classIndex, numClasses);
    const toExample = (data) => ({ data, label });

    const numValidation = Math.ceil(validationFraction * classSamples.length);
    const numTrain = classSamples.length - numValidation;

    trainDataset = trainDataset.concat(classSamples.slice(0, numTrain).map(toExample));
    validationDataset = validationDataset.concat(classSamples.slice(numTrain).map(toExample));
  });

  // Finally shuffle both splits so classes are interleaved.
  trainDataset = fisherYates(trainDataset, seed);
  validationDataset = fisherYates(validationDataset, seed);

  const toDataset = (examples) => tf.data.zip({
    xs: tf.data.array(examples.map((example) => example.data)),
    ys: tf.data.array(examples.map((example) => example.label)),
  });

  return {
    trainDataset: toDataset(trainDataset),
    validationDataset: toDataset(validationDataset),
  };
}

/**
 * Build shuffled train/validation `tf.data` datasets of (input, target)
 * tensor pairs, grouped by class.
 *
 * For each class, one index permutation is shuffled and applied to both
 * `samples[i]` and `samplesTargets[i]`, so every input keeps its own target.
 * ceil(validationFraction * n) of a class's n examples go to validation and
 * the rest to training; both splits are then shuffled across classes.
 *
 * The caller's arrays are not modified.
 *
 * @param samples Array indexed by class; each entry is an array of inputs,
 *     each turned into a tensor with a leading dimension of 1. Falsy entries
 *     are skipped.
 * @param samplesTargets Array parallel to `samples`; `samplesTargets[i][k]` is
 *     the target for `samples[i][k]` and is turned into a tensor.
 * @param seed Optional zero-argument function used instead of `Math.random`
 *     for every shuffle; null uses `Math.random`.
 * @param validationFraction Fraction of each class held out for validation,
 *     rounded up per class. Must be within [0, 1]. Defaults to 0.15.
 * @returns `{ trainDataset, validationDataset }`, each a zipped `{ xs, ys }`
 *     dataset.
 * @throws {Error} If `validationFraction` is outside [0, 1], or a class's
 *     targets are missing or differ in length from its samples.
 */
export function convertToTfDatasetMultiTarget(
  samples = [],
  samplesTargets = [],
  seed = null,
  validationFraction = 0.15,
) {
  if (!(validationFraction >= 0 && validationFraction <= 1)) {
    throw new Error(`validationFraction must be in [0, 1], while ${validationFraction} is found.`);
  }

  let trainDataset = [];
  let validationDataset = [];

  // For each class, add paired examples to the train and validation splits.
  for (let i = 0; i < samples.length; i += 1) {
    const classSamples = samples[i];
    if (classSamples) {
      const classTargets = samplesTargets[i];
      if (!classTargets || classTargets.length !== classSamples.length) {
        throw new Error(`samplesTargets[${i}] must contain one target per sample of samples[${i}].`);
      }

      // Shuffle a single index permutation and apply it to both arrays.
      // Shuffling inputs and targets independently would break the pairing.
      const order = fisherYates(Array.from({ length: classSamples.length }, (_, k) => k), seed);
      const numValidation = Math.ceil(validationFraction * order.length);
      const numTrain = order.length - numValidation;

      // tf.tidy disposes the intermediate (un-expanded) tensor and keeps the
      // returned one.
      const toExample = (k) => ({
        data: tf.tidy(() => tf.tensor(classSamples[k]).expandDims(0)),
        label: tf.tensor(classTargets[k]),
      });

      trainDataset = trainDataset.concat(order.slice(0, numTrain).map(toExample));
      validationDataset = validationDataset.concat(order.slice(numTrain).map(toExample));
    }
  }

  // Finally shuffle both splits; each element carries its own data and label,
  // so pairs stay intact.
  trainDataset = fisherYates(trainDataset, seed);
  validationDataset = fisherYates(validationDataset, seed);

  const toDataset = (examples) => tf.data.zip({
    xs: tf.data.array(examples.map((example) => example.data)),
    ys: tf.data.array(examples.map((example) => example.label)),
  });

  return {
    trainDataset: toDataset(trainDataset),
    validationDataset: toDataset(validationDataset),
  };
}

/**
 * Encode the raw bytes of a Float32Array as a base64 string.
 *
 * Only the bytes covered by the view are encoded (its `byteOffset` and
 * `byteLength` are respected), so a subarray encodes just its own elements.
 * The bytes are in the platform's native byte order.
 *
 * @param float32Array The array to encode.
 * @returns The base64 text of the array's bytes.
 */
export function float32ArrayToBase64(float32Array) {
  const bytes = new Uint8Array(
    float32Array.buffer,
    float32Array.byteOffset,
    float32Array.byteLength,
  );
  // Convert in chunks: passing every byte as a separate argument to
  // String.fromCharCode exceeds the engine's argument limit on large arrays.
  const CHUNK_SIZE = 0x8000;
  const parts = [];
  for (let start = 0; start < bytes.length; start += CHUNK_SIZE) {
    parts.push(String.fromCharCode.apply(null, bytes.subarray(start, start + CHUNK_SIZE)));
  }
  return btoa(parts.join(''));
}

/**
 * Decode a base64 string into a Float32Array over a fresh buffer.
 *
 * The decoded bytes are reinterpreted in the platform's native byte order.
 *
 * @param base64 Base64 text whose decoded length is a multiple of 4 bytes.
 * @returns A new Float32Array.
 * @throws If the text is not valid base64 (from `atob`), or if the byte count
 *     is not a multiple of 4 (RangeError from the Float32Array constructor).
 */
export function base64ToFloat32Array(base64) {
  const decoded = atob(base64);
  // Every character atob yields is in U+0000..U+00FF, i.e. one byte each.
  const bytes = Uint8Array.from(decoded, (ch) => ch.charCodeAt(0));
  return new Float32Array(bytes.buffer);
}

/**
 * Recursively replace every Float32Array in a value with a JSON-friendly
 * `{ __type: 'Float32Array', data: <base64> }` marker object.
 *
 * Arrays and plain objects (own enumerable keys) are rebuilt as new
 * containers; any other value is returned unchanged.
 *
 * @param obj Value to convert.
 * @returns The converted copy.
 */
export function convertObjectToBase64(obj) {
  if (obj instanceof Float32Array) {
    return { __type: 'Float32Array', data: float32ArrayToBase64(obj) };
  }
  if (Array.isArray(obj)) {
    return obj.map((item) => convertObjectToBase64(item));
  }
  if (obj === null || typeof obj !== 'object') {
    return obj;
  }
  const encoded = {};
  for (const key of Object.keys(obj)) {
    encoded[key] = convertObjectToBase64(obj[key]);
  }
  return encoded;
}

/**
 * Recursively turn `{ __type: 'Float32Array', data: <base64> }` marker objects
 * back into Float32Arrays.
 *
 * Arrays and plain objects (own enumerable keys) are rebuilt as new
 * containers; any other value is returned unchanged.
 *
 * @param obj Value to convert.
 * @returns The decoded copy.
 */
export function convertBase64ToObject(obj) {
  if (obj && obj.__type === 'Float32Array') {
    return base64ToFloat32Array(obj.data);
  }
  if (Array.isArray(obj)) {
    return obj.map((item) => convertBase64ToObject(item));
  }
  if (typeof obj !== 'object' || obj === null) {
    return obj;
  }
  return Object.keys(obj).reduce((decoded, key) => {
    decoded[key] = convertBase64ToObject(obj[key]);
    return decoded;
  }, {});
}