/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.ann;

import breeze.linalg.DenseMatrix;
import org.apache.spark.ml.ann.DataStacker;
import org.apache.spark.ml.ann.Topology;
import org.apache.spark.ml.ann.TopologyModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.mllib.optimization.Gradient;
import scala.MatchError;
import scala.Tuple2;
import scala.Tuple3;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u0001\u00193Q!\u0001\u0002\u0001\u00051\u00111\"\u0011(O\u000fJ\fG-[3oi*\u00111\u0001B\u0001\u0004C:t'BA\u0003\u0007\u0003\tiGN\u0003\u0002\b\u0011\u0005)1\u000f]1sW*\u0011\u0011BC\u0001\u0007CB\f7\r[3\u000b\u0003-\t1a\u001c:h'\t\u0001Q\u0002\u0005\u0002\u000f'5\tqB\u0003\u0002\u0011#\u0005aq\u000e\u001d;j[&T\u0018\r^5p]*\u0011!CB\u0001\u0006[2d\u0017NY\u0005\u0003)=\u0011\u0001b\u0012:bI&,g\u000e\u001e\u0005\t-\u0001\u0011\t\u0011)A\u00051\u0005AAo\u001c9pY><\u0017p\u0001\u0001\u0011\u0005eQR\"\u0001\u0002\n\u0005m\u0011!\u0001\u0003+pa>dwnZ=\t\u0011u\u0001!\u0011!Q\u0001\ny\t1\u0002Z1uCN#\u0018mY6feB\u0011\u0011dH\u0005\u0003A\t\u00111\u0002R1uCN#\u0018mY6fe\")!\u0005\u0001C\u0001G\u00051A(\u001b8jiz\"2\u0001J\u0013'!\tI\u0002\u0001C\u0003\u0017C\u0001\u0007\u0001\u0004C\u0003\u001eC\u0001\u0007a\u0004C\u0003)\u0001\u0011\u0005\u0013&A\u0004d_6\u0004X\u000f^3\u0015\t)J4(\u0010\t\u0005W9\u0002d'D\u0001-\u0015\u0005i\u0013!B:dC2\f\u0017BA\u0018-\u0005\u0019!V\u000f\u001d7feA\u0011\u0011\u0007N\u0007\u0002e)\u00111'E\u0001\u0007Y&t\u0017\r\\4\n\u0005U\u0012$A\u0002,fGR|'\u000f\u0005\u0002,o%\u0011\u0001\b\f\u0002\u0007\t>,(\r\\3\t\u000bi:\u0003\u0019\u0001\u0019\u0002\t\u0011\fG/\u0019\u0005\u0006y\u001d\u0002\rAN\u0001\u0006Y\u0006\u0014W\r\u001c\u0005\u0006}\u001d\u0002\r\u0001M\u0001\bo\u0016Lw\r\u001b;t\u0011\u0015A\u0003\u0001\"\u0011A)\u00151\u0014IQ\"E\u0011\u0015Qt\b1\u00011\u0011\u0015at\b1\u00017\u0011\u0015qt\b1\u00011\u0011\u0015)u\b1\u00011\u0003-\u0019W/\\$sC\u0012LWM\u001c;")
public class ANNGradient extends Gradient {
    /*
     * Gradient implementation that unstacks a data vector with a DataStacker,
     * builds a TopologyModel from the supplied weights, and delegates the
     * actual gradient computation to that model.
     */

    /** Source of the {@link TopologyModel} instantiated for each weight vector. */
    private final Topology topology;

    /** Splits an incoming data vector into input, target and batch size. */
    private final DataStacker dataStacker;

    /**
     * Stores the collaborators used by both {@code compute} overloads.
     *
     * @param topology    provides {@code getInstance(weights)} to obtain a model
     * @param dataStacker provides {@code unstack(data)} to split a data vector
     */
    public ANNGradient(Topology topology, DataStacker dataStacker) {
        this.topology = topology;
        this.dataStacker = dataStacker;
    }

    /**
     * Computes the gradient into a freshly zeroed vector of the same size as
     * the weights and returns it together with the loss.
     *
     * @param data     data vector, forwarded to the four-argument overload
     * @param label    forwarded unchanged to the four-argument overload
     * @param weights2 weight vector; its size determines the gradient size
     * @return pair of (gradient vector, boxed loss value)
     */
    @Override
    public Tuple2<Vector, Object> compute(Vector data, double label, Vector weights2) {
        Vector freshGradient = Vectors$.MODULE$.zeros(weights2.size());
        double lossValue = this.compute(data, label, weights2, freshGradient);
        return new Tuple2<Vector, Object>(freshGradient, BoxesRunTime.boxToDouble(lossValue));
    }

    /**
     * Unstacks {@code data}, instantiates a model for {@code weights2} and
     * delegates to {@code TopologyModel.computeGradient}.
     *
     * @param data        data vector handed to {@code DataStacker.unstack}
     * @param label       not read by this method
     * @param weights2    weight vector used to instantiate the model
     * @param cumGradient vector passed through to {@code computeGradient};
     *                    presumably updated in place by the model - verify
     *                    against {@code TopologyModel}
     * @return the value returned by {@code TopologyModel.computeGradient}
     * @throws MatchError if {@code DataStacker.unstack} returns {@code null}
     */
    @Override
    public double compute(Vector data, double label, Vector weights2, Vector cumGradient) {
        Tuple3<DenseMatrix<Object>, DenseMatrix<Object>, Object> unstacked = this.dataStacker.unstack(data);
        if (unstacked == null) {
            // Mirrors the failed Scala pattern match on the unstacked triple.
            throw new MatchError(unstacked);
        }

        DenseMatrix<Object> batchInput = unstacked._1();
        DenseMatrix<Object> batchTarget = unstacked._2();
        int batchSize = BoxesRunTime.unboxToInt(unstacked._3());

        TopologyModel network = this.topology.getInstance(weights2);
        return network.computeGradient(batchInput, batchTarget, cumGradient, batchSize);
    }
}

