/* eslint no-restricted-syntax: 0 */
/* eslint no-await-in-loop: 0 */
/* eslint no-else-return: 0 */
/* eslint no-unreachable: 0 */
/* eslint no-loop-func: 0 */
import * as tfcore from '@tensorflow/tfjs-core';
import * as tf from '@tensorflow/tfjs';
import * as mobilenetModule from '@tensorflow-models/mobilenet';
import * as seedrandom from 'seedrandom';
import { Camera } from '../camera';
import { STATE, MODEL_TYPES } from '../params';
import { fisherYatesWithTarget } from '../util';


/**
 * Object-detection handler: captures samples of an object inside a rectangle
 * of the camera frame, augments them, embeds them with MobileNet and trains a
 * small dense head that regresses `[scaledClass, x0, x1, y0, y1]`.
 *
 * Each stored training target is `[classIndex, x0, x1, y0, y1]`; the class
 * column is scaled by `videoWidth / numberOfClasses` in `customLossFunction`,
 * and `classify` undoes that scaling.
 */
export default class ObjectsHandler {
  modelType = MODEL_TYPES.OBJECTS;

  state = STATE;

  seed = seedrandom('testSuite');

  // Hyper-parameters for the dense head trained on top of the MobileNet embeddings.
  trainingParams = {
    denseUnits: 200,
    epochs: 20,
    learningRate: 0.001,
    batchSize: 48,
  };

  samplesRectangleMinSize = 70;

  // When true, the sample background is replaced by the frame's average colour.
  removeImageBackground = false;

  // Augmentations applied to every displaced copy of the sampled rectangle
  // (see `augmentRectangleImage` for the supported names).
  imageAugmentation = [
    'unchanged', // required
    'flipLeftRight',
    'rotate45',
    'rotate90',
    'rotate135',
    'rotate180',
    // 'zoomIn',
  ];

  samplesRectangleMaxSize = 1; // videoHeight / rectangleMaxSize

  samplesDisplacementFactor = 3; // rectangleWidth / displacementFactor

  samplesPreviewTime = 1; // ms. 0 to disable

  rectangleDrawingDebouncer = 0; // Number of predictions averaged per drawn rectangle. 0 to disable

  rectangleDrawingPosition = { x0: 0, y0: 0, x1: 0, y1: 0 };

  rectangleDrawingPositionAvg = [];

  rectangleRandomPosition = true;

  CANVAS_WIDTH = 360;

  CANVAS_HEIGHT = 270;

  classColors = [
    '#60aaff',
    '#60ff63',
    '#fffa60',
    '#fc60ff',
    '#ff8d3b',
    '#3bffff',
    '#a03bff',
    '#3bff93',
  ];

  wrapper = null;

  camera = null;

  rafId = null;

  classifying = false;

  predictionCallback = null;

  // Trained dense head (the MobileNet feature extractor lives in `state.model`).
  model = null;

  truncatedModel = null;

  // Per-class arrays of embeddings, indexed by `classNamesToIndex[className]`.
  samples = [];

  // Cached background used while building displaced samples for one `addSample` call.
  sampleClippedImage = null;

  // Per-class arrays of `[classIndex, x0, x1, y0, y1]`, parallel to `samples`.
  samplesTargets = [];

  trainDataset = null;

  validationDataset = null;

  classNamesToIndex = {};

  // Replaced by the real embedding shape in `loadPreTrainedModel`.
  outputShape = [1, 224, 224, 3];

  rectangleSize = 100;

  videoCenterX = 0;

  videoCenterY = 0;

  /**
   * @param {*} wrapper - Element handed to `Camera.setupCamera`.
   * @param {Function} predictionCallback - Receives `{ [className]: confidence }` on each classification.
   */
  constructor(wrapper, predictionCallback) {
    this.wrapper = wrapper;
    this.predictionCallback = predictionCallback;
  }

  /** Sets up the camera, starts the render loop and loads MobileNet. */
  init = async () => {
    console.log('Handling model Objects');
    await this.initDefaultValueMap();
    this.camera = await Camera.setupCamera(this.state.camera, this.wrapper);
    this.renderPrediction();
    await this.loadPreTrainedModel();
  }

  /** Replaces `state` with a shallow copy of the shared STATE defaults. */
  initDefaultValueMap = async () => {
    const newState = { ...STATE };
    // newState.camera = { ...STATE.camera, sizeOption: '224 X 224' };

    this.state = newState;
  }

  /**
   * Loads MobileNet v2 into `state.model` and records the embedding shape it
   * produces for the current video frame in `outputShape`.
   */
  loadPreTrainedModel = async () => {
    this.state.model = await mobilenetModule.load({
      version: 2,
      alpha: 1.0,
    });
    // tidy disposes both the captured frame and the embedding tensor.
    this.outputShape = tf.tidy(() => {
      const frame = tfcore.browser.fromPixels(this.camera.video);
      return this.state.model.infer(frame, true).shape;
    });
  }

  /** Render loop: draws one frame, then schedules the next one. */
  renderPrediction = async () => {
    await this.renderResult();
    this.rafId = requestAnimationFrame(this.renderPrediction);
    return null;
  }

  /**
   * Draws the current frame. While classifying, overlays the predicted box and
   * class; otherwise overlays the sampling rectangle (unless a sample preview
   * is being shown).
   */
  renderResult = async () => {
    const { camera } = this;
    if (camera.video.readyState < 2) {
      await new Promise((resolve) => {
        camera.video.onloadeddata = () => {
          resolve(true);
        };
      });
    }

    if (!this.classifying) {
      camera.drawCtx();
      if (!camera.preview) {
        const size = this.calculateReactangleOnCurrentFrame();
        if (size) {
          camera.drawRectangle(size.x0, size.y0, size.x1, size.y1);
        }
      }
      return;
    }

    const prediction = await this.classify();
    camera.drawCtx();
    if (!(prediction?.data && prediction?.confidence > 0)) {
      return;
    }

    this.updateRectangleDrawingPosition(prediction.data);
    const classIndex = this.classNamesToIndex[prediction.className];
    const { x0, y0, x1, y1 } = this.rectangleDrawingPosition;
    camera.drawRectangle(
      x0,
      y0,
      x1,
      y1,
      prediction.className,
      // Wrap around so more classes than colours still get a colour.
      this.classColors[classIndex % this.classColors.length],
      '4',
      false,
    );
  }

  /**
   * Updates `rectangleDrawingPosition` from a prediction laid out as
   * `[class, x0, x1, y0, y1]`, optionally averaging over
   * `rectangleDrawingDebouncer` predictions before moving the box.
   * @param {ArrayLike<number>} data - Raw head output.
   */
  updateRectangleDrawingPosition = (data) => {
    const position = [data[1], data[3], data[2], data[4]]; // x0, y0, x1, y1
    if (!(this.rectangleDrawingDebouncer > 0)) {
      const [x0, y0, x1, y1] = position;
      this.rectangleDrawingPosition = { x0, y0, x1, y1 };
      return;
    }

    const history = this.rectangleDrawingPositionAvg;
    if (history.length > this.rectangleDrawingDebouncer) {
      const mean = (column) => history.reduce((acc, val) => acc + val[column], 0) / history.length;
      this.rectangleDrawingPosition = {
        x0: mean(0),
        y0: mean(1),
        x1: mean(2),
        y1: mean(3),
      };
      this.rectangleDrawingPositionAvg = [];
    } else {
      history.push(position);
    }
  }

  /** Linearly maps `value` from [inMin, inMax] to [outMin, outMax]. */
  mapValue = (value, inMin, inMax, outMin, outMax) => {
    return ((value - inMin) * (outMax - outMin)) / (inMax - inMin) + outMin;
  }

  /**
   * Classifies the current video frame with the trained head.
   * Calls `predictionCallback` with `{ [className]: confidence }`.
   * @returns {Promise<{data, className, confidence}|null>} The last class whose
   *   scaled prediction is within 0.3 of its index, or null when no model /
   *   callback is set or nothing matched.
   */
  classify = async () => {
    if (!this.model || !this.predictionCallback) {
      return null;
    }

    // tidy releases the frame, the embedding and its batched copy every frame.
    const logits = tf.tidy(() => {
      const imageTensor = tfcore.browser.fromPixels(this.camera.video);
      const embeddings = this.state.model.infer(imageTensor, true);
      return this.model.predict(embeddings.expandDims(0));
    });
    let data;
    try {
      data = await logits.data();
    } finally {
      logits.dispose();
    }

    const classEntries = Object.entries(this.classNamesToIndex);
    // Undo the class-column scaling applied in customLossFunction.
    const mult = this.camera.video.videoWidth / classEntries.length;
    const predictedClass = data[0] / mult;

    const results = {};
    let resultsData = null;
    for (const [className, classNameValue] of classEntries) {
      const distance = Math.abs(classNameValue - predictedClass);
      if (distance <= 0.3) {
        // 0 distance -> 1 (certain); 0.5 distance -> 0 (in doubt between two classes).
        const confidencePercentage = this.mapValue(distance, 0.5, 0, 0, 1);
        results[className] = confidencePercentage;
        resultsData = {
          data,
          className,
          confidence: confidencePercentage,
        };
      }
    }
    this.predictionCallback(results);
    return resultsData;
  }

  /** Re-creates the camera and restarts the render loop. */
  startVideo = async () => {
    this.camera = await Camera.setupCamera(this.state.camera, this.wrapper);
    this.renderPrediction();
  }

  /** Stops the render loop and the camera tracks, and clears the canvas. */
  stopVideo = () => {
    window.cancelAnimationFrame(this.rafId);
    this.rafId = null;

    if (this.camera != null) {
      // srcObject may already be null if the stream was released elsewhere.
      this.camera.video.srcObject?.getVideoTracks().forEach((track) => track.stop());
      this.camera.clearCtx();
    }

    if (this.predictionCallback && this.classifying) {
      this.predictionCallback({});
    }
  }

  /** Forgets all collected samples and class names and stops classifying. */
  clearAllSamples = () => {
    this.classifying = false;
    this.samples = [];
    this.samplesTargets = [];
    this.classNamesToIndex = {};
  }

  startClassifying = () => {
    this.classifying = true;
  }

  stopClassifying = () => {
    this.classifying = false;
  }

  /** Shallow-merges `params` into `trainingParams`. */
  setTrainingParams = (params) => {
    this.trainingParams = {
      ...this.trainingParams,
      ...params,
    };
  }

  getTrainingParams = () => {
    return this.trainingParams;
  }

  /**
   * Returns the sampling rectangle for the current frame. When
   * `rectangleRandomPosition` is set, a new random size/center is drawn first
   * (and the flag is cleared).
   *
   * The size is forced to be even and the center to be a whole pixel, so the
   * corners are integers and the rectangle always lies inside the frame; the
   * slice/pad code in `displaceRectangleImageWithFill` relies on that.
   *
   * @param {number} [resizeTo=0] - When > 0, coordinates are rescaled from video
   *   pixels to a `resizeTo`-sized square.
   * @returns {{x0:number, x1:number, y0:number, y1:number}|null} null without a camera.
   */
  calculateReactangleOnCurrentFrame = (resizeTo = 0) => {
    if (!this.camera) {
      return null;
    }
    const { videoWidth, videoHeight } = this.camera.video;

    if (!this.videoCenterX || !this.videoCenterY) {
      this.videoCenterX = Math.floor(videoWidth / 2);
      this.videoCenterY = Math.floor(videoHeight / 2);
      this.rectangleSize = 100;
    }

    if (this.rectangleRandomPosition) {
      const randomSize = Math.floor(Math.random() * ((videoHeight / this.samplesRectangleMaxSize) - this.samplesRectangleMinSize - 1) + this.samplesRectangleMinSize);
      const randomHalf = Math.floor(randomSize / 2);
      this.rectangleSize = randomHalf * 2;
      // Integer bounds so the inclusive random center never overshoots the frame.
      const maxX = videoWidth - randomHalf;
      const maxY = videoHeight - randomHalf;
      this.videoCenterX = Math.floor(Math.random() * (maxX - randomHalf + 1) + randomHalf);
      this.videoCenterY = Math.floor(Math.random() * (maxY - randomHalf + 1) + randomHalf);
      this.rectangleRandomPosition = false;
    }

    const halfSize = this.rectangleSize / 2;
    const x0 = this.videoCenterX - halfSize;
    const y0 = this.videoCenterY - halfSize;
    const x1 = this.videoCenterX + halfSize;
    const y1 = this.videoCenterY + halfSize;
    if (resizeTo > 0) {
      return {
        x0: (x0 * resizeTo) / videoWidth,
        x1: (x1 * resizeTo) / videoWidth,
        y0: (y0 * resizeTo) / videoHeight,
        y1: (y1 * resizeTo) / videoHeight,
      };
    }
    return { x0, x1, y0, y1 };
  }

  /**
   * Briefly shows `imageTensor` with a rectangle on the camera canvas.
   * The temporary normalised tensor is always released, and a failing
   * `toPixels` now rejects instead of hanging forever.
   */
  previewSample = async (imageTensor, x0, y0, x1, y1) => {
    const previewFrame = tf.tidy(() => tf.squeeze(imageTensor.div(255)));
    try {
      await tfcore.browser.toPixels(previewFrame, this.camera.canvas);
      this.camera.drawRectangle(x0, y0, x1, y1);
      await new Promise((resolve) => {
        setTimeout(resolve, this.samplesPreviewTime);
      });
    } finally {
      previewFrame.dispose();
    }
    // Apply a horizontal flip (translate by video width, scale x by -1) to the
    // canvas context. NOTE(review): presumably re-applies the mirrored camera
    // transform after toPixels redraws the canvas — confirm against Camera.
    this.camera.ctx.translate(this.camera.video.videoWidth, 0);
    this.camera.ctx.scale(-1, 1);
  }

  /**
   * Builds a training sample from the whole frame using the current sampling rectangle.
   * @param {tf.Tensor3D} imageTensor - Frame tensor; ownership stays with the caller.
   * @param {number} [className=0] - Class index stored in the target.
   * @returns {Promise<{image, target}|null>} target is `[className, x0, x1, y0, y1]`.
   */
  calculateFeaturesOnCurrentFrame = async (imageTensor, className = 0) => {
    if (!this.camera) {
      return null;
    }
    const rect = this.calculateReactangleOnCurrentFrame();
    const boundingBox = [rect.x0, rect.x1, rect.y0, rect.y1];

    if (this.samplesPreviewTime) {
      await this.previewSample(imageTensor, rect.x0, rect.y0, rect.x1, rect.y1);
    }

    return { image: imageTensor, target: [className, ...boundingBox] };
  }

  /**
   * Builds a training sample for a rectangle that was moved to (offsetX, offsetY).
   * @returns {Promise<{image, target, edgeX, edgeY}|null>} `edgeX`/`edgeY` are the
   *   next offsets (offset + size / samplesDisplacementFactor).
   */
  calculateFeaturesOnCurrentFrameWithOffset = async (imageTensor, offsetX, offsetY, rectangle, className = 0) => {
    if (!this.camera) {
      return null;
    }
    const width = rectangle.x1 - rectangle.x0;
    const height = rectangle.y1 - rectangle.y0;
    const newX1 = width + offsetX;
    const newY1 = height + offsetY;

    if (this.samplesPreviewTime) {
      await this.previewSample(imageTensor, offsetX, offsetY, newX1, newY1);
    }

    return {
      image: imageTensor,
      target: [className, offsetX, newX1, offsetY, newY1],
      edgeX: offsetX + (width / this.samplesDisplacementFactor),
      edgeY: offsetY + (height / this.samplesDisplacementFactor),
    };
  }

  /** Convolves the image with a `level`x`level` box kernel (sum / 30). */
  blurImage(imageTensor, level) {
    return tf.tidy(() => {
      const channels = imageTensor.shape[2];
      const kernel = tf.ones([level, level, channels, channels]).div(30);
      return tf.conv2d(imageTensor.cast('float32'), kernel, 1, 'same');
    });
  }

  /** Translates the image by (dx, dy) with nearest-neighbour sampling; returns a batched tensor. */
  displaceRectangleImage(imageTensor, dx, dy) {
    return tf.tidy(() => {
      const translationMatrix = tf.tensor2d([1, 0, dx, 0, 1, dy, 0, 0], [1, 8]);
      return tf.image.transform(imageTensor.expandDims(0).cast('float32'), translationMatrix, 'nearest', 'nearest', 0);
    });
  }

  /** Loads an image URL and resizes it to CANVAS_HEIGHT x CANVAS_WIDTH; rejects if it fails to load. */
  async loadImageAsTensor(imagePath) {
    const img = new Image();
    await new Promise((resolve, reject) => {
      img.onload = resolve;
      img.onerror = () => reject(new Error(`Could not load image: ${imagePath}`));
      img.src = imagePath;
    });
    return tf.tidy(() => tf.browser.fromPixels(img).resizeNearestNeighbor([this.CANVAS_HEIGHT, this.CANVAS_WIDTH]));
  }

  /**
   * Builds the background shared by all displaced samples of one `addSample`
   * call: either a flat average-colour image, or the frame darkened outside the
   * rectangle with the rectangle itself filled with the average colour.
   */
  buildSampleBackground(imageTensor, rectangle, width, height, averagePixelValueTensor) {
    const [imageHeight, imageWidth, channels] = imageTensor.shape;
    return tf.tidy(() => {
      const averagePixel = averagePixelValueTensor.expandDims(0).expandDims(0);
      if (this.removeImageBackground) {
        return averagePixel.tile([imageHeight, imageWidth, 1]);
      }
      const { x0, y0, x1, y1 } = rectangle;
      const paddings = [[y0, imageHeight - y1], [x0, imageWidth - x1], [0, 0]];
      const mask = tf.onesLike(imageTensor.slice([y0, x0, 0], [height, width, channels])).pad(paddings);
      // Keep 20% of the brightness outside the rectangle, 0% inside.
      const invertedMaskGrey = tf.onesLike(mask).sub(mask).clipByValue(0.2, 1);
      const outsideRectangle = imageTensor.mul(invertedMaskGrey);
      const filledRectangle = averagePixel.tile([height, width, 1]).pad(paddings).clipByValue(0, 155);
      return outsideRectangle.add(filledRectangle).clipByValue(0, 255);
    });
  }

  /**
   * Applies one named augmentation to a [h, w, c] crop. 'unchanged' (or any
   * unknown name) returns the crop itself; others return a new float32 tensor.
   */
  augmentRectangleImage(rectangleImage, augmentation, fillValue) {
    const [height, width, channels] = rectangleImage.shape;
    const rotations = new Map([
      ['rotate45', Math.PI / 4],
      ['rotate90', Math.PI / 2],
      ['rotate135', 3 * (Math.PI / 4)],
      ['rotate180', Math.PI / 1],
    ]);

    let transform = null;
    if (augmentation === 'flipLeftRight') {
      transform = (batch) => tf.image.flipLeftRight(batch);
    } else if (rotations.has(augmentation)) {
      const angle = rotations.get(augmentation);
      transform = (batch) => tf.image.rotateWithOffset(batch, angle, fillValue, 0.5);
    } else if (augmentation === 'zoomIn') {
      transform = (batch) => tf.image.cropAndResize(batch, [[0.2, 0.2, 0.8, 0.8]], [0], [height, width]);
    }
    if (!transform) {
      return rectangleImage;
    }
    return tf.tidy(() => {
      const transformed = transform(rectangleImage.cast('float32').expandDims(0));
      return tf.slice(tf.squeeze(transformed), [0, 0, 0], [height, width, channels]);
    });
  }

  /**
   * Cuts `rectangle` out of the frame, augments it and pastes it at
   * (offsetX, offsetY) over a background cached in `sampleClippedImage`.
   * All intermediates are released; only the returned image (owned by the
   * caller) and the cached background survive.
   */
  displaceRectangleImageWithFill(imageTensor, offsetX, offsetY, rectangle, augmentation = 'unchanged') {
    return tf.tidy(() => {
      const { x0, y0, x1, y1 } = rectangle;
      const [imageHeight, imageWidth, channels] = imageTensor.shape;
      const width = (x1 < imageWidth) ? x1 - x0 : imageWidth - x0;
      const height = (y1 < imageHeight) ? y1 - y0 : imageHeight - y0;
      const averagePixelValueTensor = imageTensor.mean([0, 1]);
      const averagePixelValue = averagePixelValueTensor.dataSync()[0];

      if (!this.sampleClippedImage) {
        // keep: the background outlives this tidy and is reused by later calls.
        this.sampleClippedImage = tf.keep(this.buildSampleBackground(imageTensor, rectangle, width, height, averagePixelValueTensor));
      }

      const paddings = [
        [offsetY, imageHeight - offsetY - height],
        [offsetX, imageWidth - offsetX - width],
        [0, 0],
      ];
      const rectangleImage = tf.slice(imageTensor, [y0, x0, 0], [height, width, channels]);
      const augmented = this.augmentRectangleImage(rectangleImage, augmentation, averagePixelValue);
      const rectangleImagePadded = augmented.pad(paddings);

      const mask = tf.onesLike(rectangleImage).pad(paddings);
      const backgroundWithHole = this.sampleClippedImage.mul(tf.onesLike(mask).sub(mask));
      return backgroundWithHole.add(rectangleImagePadded);
    });
  }

  /** Embeds `sample.image` with MobileNet and appends it (and its target) to class `classIndex`. */
  pushSampleEmbedding = (classIndex, sample) => {
    const embeddings = this.state.model.infer(sample.image, true);
    try {
      this.samples[classIndex].push(embeddings.dataSync());
    } finally {
      embeddings.dispose();
    }
    this.samplesTargets[classIndex].push(sample.target);
  }

  /**
   * Captures the current frame and stores augmented samples for `className`:
   * the sampling rectangle is slid across a CANVAS_WIDTH x CANVAS_HEIGHT area
   * in steps of size / samplesDisplacementFactor, with every augmentation at
   * each position, plus the untouched original frame.
   * @returns {Promise<number>} Number of samples now stored for the class (0 if not ready).
   */
  addSample = async (className) => {
    if (!this.camera || !className || !this.state.model) {
      return 0;
    }
    if (this.classNamesToIndex[className] === undefined) {
      this.classNamesToIndex[className] = Object.keys(this.classNamesToIndex).length;
    }
    const index = this.classNamesToIndex[className];
    if (!this.samples[index]) {
      this.samples[index] = [];
      this.samplesTargets[index] = [];
    }

    this.camera.preview = true;
    const video = tfcore.browser.fromPixels(this.camera.video);
    try {
      const rectangle = this.calculateReactangleOnCurrentFrame();
      const width = rectangle.x1 - rectangle.x0;
      const height = rectangle.y1 - rectangle.y0;

      // Each call builds its own background from the freshly captured frame.
      if (this.sampleClippedImage) {
        this.sampleClippedImage.dispose();
        this.sampleClippedImage = null;
      }

      let offsetX = 0;
      let offsetY = 0;
      let sample = null;
      while (offsetY + height < this.CANVAS_HEIGHT) {
        while (offsetX + width < this.CANVAS_WIDTH) {
          for (const augmentation of this.imageAugmentation) {
            const displacedTensor = this.displaceRectangleImageWithFill(video, offsetX, offsetY, rectangle, augmentation);
            try {
              sample = await this.calculateFeaturesOnCurrentFrameWithOffset(displacedTensor, offsetX, offsetY, rectangle, index);
              this.pushSampleEmbedding(index, sample);
            } finally {
              displacedTensor.dispose();
            }
          }
          offsetX = sample.edgeX;
        }
        offsetY = sample.edgeY;
        offsetX = 0;
      }

      // Take the original (non-displaced) sample too.
      const oriSample = await this.calculateFeaturesOnCurrentFrame(video, index);
      this.pushSampleEmbedding(index, oriSample);

      this.rectangleRandomPosition = true;
      return this.samples[index].length;
    } finally {
      video.dispose();
      this.camera.preview = false;
    }
  }

  /** Returns the collected samples in the shape accepted by `setSamples`. */
  getSamples = () => {
    return {
      modelType: this.modelType,
      samples: this.samples,
      samplesTargets: this.samplesTargets,
    };
  }

  /**
   * Restores samples previously produced by `getSamples`.
   * @param {{modelType, samples, samplesTargets}} samples
   * @param {Array} classesArray - Indexed by class index; element `[0]` is the class name.
   * @returns {Object|false} `{ [className]: sampleCount }`, or false when nothing matched.
   */
  setSamples = (samples, classesArray) => {
    if (samples.modelType !== this.modelType) {
      return false;
    }
    // Saved embeddings can be longer than the current ones ([1296] vs [1280]): truncate.
    const embeddingLength = this.outputShape[1];
    this.samples = samples.samples.map((sampleArr) => (
      sampleArr ? sampleArr.map((sample) => sample.slice(0, embeddingLength)) : sampleArr
    ));
    this.samplesTargets = samples.samplesTargets;

    const nSamples = {};
    this.samples.forEach((sampleArr, classIndex) => {
      if (!sampleArr || classesArray[classIndex] === undefined) {
        return;
      }
      const className = classesArray[classIndex][0];
      if (this.classNamesToIndex[className] === undefined) {
        this.classNamesToIndex[className] = classIndex;
      }
      nSamples[className] = sampleArr.length;
    });
    return Object.keys(nSamples).length > 0 ? nSamples : false;
  }

  /**
   * Mean squared error after scaling the class column of `yTrue` by
   * `videoWidth / numberOfClasses`, so the class and bounding-box columns
   * contribute on a comparable scale.
   */
  customLossFunction = (yTrue, yPred) => {
    return tf.tidy(() => {
      const LABEL_MULTIPLIER = [this.camera.video.videoWidth / Object.keys(this.classNamesToIndex).length, 1, 1, 1, 1];
      return tf.metrics.meanSquaredError(yTrue.mul(LABEL_MULTIPLIER), yPred);
    });
  }

  /** Builds the dense head: flatten -> dense(relu, L2) -> dropout -> dense(5). */
  buildNewHead = (inputShape) => {
    const newHead = tf.sequential();
    newHead.add(tf.layers.flatten({ inputShape }));
    newHead.add(tf.layers.dense({
      units: this.trainingParams.denseUnits,
      activation: 'relu',
      kernelRegularizer: tf.regularizers.l2({ l2: 0.01 }),
    }));
    newHead.add(tf.layers.dropout({ rate: 0.1 }));
    newHead.add(tf.layers.dense({
      units: 5,
    }));
    return newHead;
  }

  /** Pairs every stored embedding with its target, skipping empty class slots. */
  collectTrainingPairs = () => {
    const images = [];
    const targetValues = [];
    this.samples.forEach((sampleArr, classIndex) => {
      if (!sampleArr) {
        return;
      }
      const targetArr = this.samplesTargets?.[classIndex] || [];
      sampleArr.forEach((sample, i) => {
        if (targetArr[i] == null) {
          return;
        }
        images.push([sample]);
        targetValues.push(targetArr[i]);
      });
    });
    return { images, targetValues };
  }

  /**
   * Trains a new head on the collected samples (inputs are validated before
   * the previous model is discarded).
   * @param {*} callbacks - Passed to `model.fit`.
   * @returns {Promise<tf.Sequential>} The trained model.
   * @throws {Error} On a non-positive batch size or when there are no samples.
   */
  train = async (callbacks) => {
    const batchSize = parseInt(this.trainingParams.batchSize, 10);
    if (!(batchSize > 0)) {
      throw new Error('Batch size is 0 or NaN. Please choose a non-zero fraction');
    }
    const { images, targetValues } = this.collectTrainingPairs();
    if (images.length === 0) {
      throw new Error('No samples to train on. Please add samples first');
    }

    if (this.model) {
      this.model.dispose();
    }
    this.model = this.buildNewHead(this.outputShape);
    this.model.setUserDefinedMetadata({ classNames: this.classNamesToIndex });

    const optimizer = tf.train.rmsprop(parseFloat(this.trainingParams.learningRate));
    let targets = [];
    let xs = null;
    let ys = null;
    try {
      this.model.compile({
        optimizer,
        loss: this.customLossFunction,
        metrics: ['accuracy'],
      });
      targets = targetValues.map((target) => tf.tensor(target));
      // validationSplit takes the last 20%, so shuffle before fitting.
      const { shuffled, shuffledTarget } = fisherYatesWithTarget(images, targets);
      xs = tf.stack(shuffled);
      ys = tf.stack(shuffledTarget);
      await this.model.fit(xs, ys, {
        epochs: parseInt(this.trainingParams.epochs, 10),
        batchSize,
        validationSplit: 0.2,
        callbacks,
        shuffle: true,
      });
    } finally {
      tf.dispose([...targets, xs, ys].filter(Boolean));
      optimizer.dispose();
    }

    return this.model;
  }

  /**
   * Installs a loaded model and restores the class map from its metadata.
   * @returns {Object|false} The restored `classNamesToIndex`, or false without a model.
   */
  adoptLoadedModel = (loadedModel) => {
    if (!loadedModel) {
      return false;
    }
    if (this.model && this.model !== loadedModel) {
      this.model.dispose();
    }
    this.model = loadedModel;
    this.classNamesToIndex = loadedModel.getUserDefinedMetadata()?.classNames ?? {};
    return this.classNamesToIndex;
  }

  /** Saves the trained head to browser local storage. */
  save = async (name = 'modelo-educabot-objetos') => {
    if (!this.model) {
      throw new Error('There is no trained model to save');
    }
    return this.model.save(`localstorage://${name}`);
  }

  /** Loads a head from browser local storage and makes it the active model. */
  load = async (name = 'modelo-educabot-objetos') => {
    const loadedModel = await tf.loadLayersModel(`localstorage://${name}`);
    return this.adoptLoadedModel(loadedModel);
  }

  /** Triggers a browser download of the trained head. */
  export = async (name = 'modelo-educabot-objetos') => {
    if (!this.model) {
      throw new Error('There is no trained model to export');
    }
    return this.model.save(`downloads://${name}`);
  }

  /**
   * Imports a head from user-selected files.
   * @param {File} json - model.json file.
   * @param {File} weights - Weights binary file.
   * @returns {Promise<Object|false>} The restored `classNamesToIndex`.
   */
  import = async (json, weights) => {
    const loadedModel = await tf.loadLayersModel(tf.io.browserFiles([json, weights]));
    return this.adoptLoadedModel(loadedModel);
  }
}
