package org.instantsvm;

import java.io.File;
import java.io.IOException;
import java.util.Vector;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;
import org.instantsvm.utils.LibSvmConsole;
import org.instantsvm.utils.LibSvmIO;

/* loaded from: input_file:org/instantsvm/SVM.class */
/**
 * Convenience wrapper around the libsvm Java API: builds an {@code svm_problem}
 * from in-memory samples, trains a model, runs a C/gamma grid search with
 * cross validation, predicts on new samples and saves/loads models.
 *
 * <p>Instances hold mutable state (the current problem, parameters and model)
 * that is overwritten by {@code train}, {@code xval} and {@code load}.
 */
public class SVM {
    /** Raw libsvm parameters used by the last {@code train}/{@code xval} call. */
    protected svm_parameter param;
    /** Training problem built by {@link #load(Vector, Vector)}. */
    protected svm_problem prob;
    /** Trained or loaded model; {@code null} until {@code train} or {@code load} succeeds. */
    protected svm_model model;
    /** Project-level wrapper around {@link #param}. */
    protected Parameters parameters;

    /** Creates an empty instance with no model; call {@code train} or {@code load} before predicting. */
    public SVM() {
    }

    /**
     * Creates an instance from a model file.
     *
     * @param str path of the model file to read
     * @throws IOException if the model cannot be read
     */
    public SVM(String str) throws IOException {
        load(str);
    }

    /**
     * Creates an instance around an already-built libsvm model.
     *
     * @param svm_modelVar the model to use; must not be {@code null}
     */
    public SVM(svm_model svm_modelVar) {
        mountModel(svm_modelVar);
    }

    /**
     * Trains a model using {@link #getDefaultParameters()}.
     *
     * @param vector  one feature vector per sample
     * @param vector2 one target value per sample
     */
    public void train(Vector<svm_node[]> vector, Vector<Double> vector2) {
        train(vector, vector2, getDefaultParameters());
    }

    /**
     * Trains a model with the given parameters and stores it in this instance.
     *
     * @param vector     one feature vector per sample
     * @param vector2    one target value per sample
     * @param parameters training parameters; the raw libsvm parameters are written
     *                   back into this object after training
     */
    public void train(Vector<svm_node[]> vector, Vector<Double> vector2, Parameters parameters) {
        this.parameters = parameters;
        this.param = parameters.getParam();
        load(vector, vector2);
        this.model = svm.svm_train(this.prob, this.param);
        parameters.setParam(this.param);
    }

    /**
     * Runs a grid search over C and gamma, scoring every grid point with
     * cross validation, and reports the best pair.
     *
     * <p>The grid points are {@code d + k * (d2 - d) / i2} for {@code k} in
     * {@code [0, i2)} (and likewise for gamma), so the upper bounds themselves
     * are never evaluated. "Best" means the highest accuracy for classification
     * and the lowest squared error for regression. On return, {@link #param}
     * still holds the last grid point evaluated, not the best one.
     *
     * @param vector     one feature vector per sample
     * @param vector2    one target value per sample
     * @param parameters base parameters; C and gamma are overwritten per grid point
     * @param i          number of cross-validation folds
     * @param d          lower bound (inclusive) of the C range
     * @param d2         upper bound (exclusive) of the C range
     * @param i2         number of C values to try
     * @param d3         lower bound (inclusive) of the gamma range
     * @param d4         upper bound (exclusive) of the gamma range
     * @param i3         number of gamma values to try
     * @return the score grid (indexed {@code [cIndex][gammaIndex]}) plus the best C and gamma
     */
    public XValResult xval(Vector<svm_node[]> vector, Vector<Double> vector2, Parameters parameters, int i, double d, double d2, int i2, double d3, double d4, int i3) {
        this.parameters = parameters;
        this.param = parameters.getParam();
        load(vector, vector2);

        final double[][] scores = new double[i2][i3];
        final double cStep = (d2 - d) / i2;
        final double gammaStep = (d4 - d3) / i3;

        // do_cross_validation returns an error for regression (minimise) but an
        // accuracy for everything else (maximise).
        final boolean lowerIsBetter = isRegression(this.param.svm_type);
        double bestScore = lowerIsBetter ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
        double bestC = 0.0d;
        double bestGamma = 0.0d;

        for (int cIndex = 0; cIndex < i2; cIndex++) {
            for (int gammaIndex = 0; gammaIndex < i3; gammaIndex++) {
                this.param.C = d + (cStep * cIndex);
                this.param.gamma = d3 + (gammaStep * gammaIndex);
                final double score = do_cross_validation(i);
                scores[cIndex][gammaIndex] = score;
                final boolean improved = lowerIsBetter ? score < bestScore : score > bestScore;
                if (improved) {
                    bestScore = score;
                    bestC = this.param.C;
                    bestGamma = this.param.gamma;
                }
            }
        }
        return new XValResult(scores, bestC, bestGamma);
    }

    /**
     * Predicts a value for every sample using the current model.
     *
     * @param vector samples to predict
     * @return one prediction per sample, in input order
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public double[] apply(Vector<svm_node[]> vector) {
        return predict(vector, null, 0);
    }

    /** @return the parameters of the last training run or of the mounted model; may be {@code null} */
    public Parameters getParameters() {
        return this.parameters;
    }

    /**
     * @return the model's support vectors (the model's own array, not a copy)
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public svm_node[][] getSupportVectors() {
        requireModel();
        return this.model.SV;
    }

    /**
     * @return the model's support-vector coefficients (the model's own array, not a copy)
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public double[][] getCoefs() {
        requireModel();
        return this.model.sv_coef;
    }

    /**
     * Saves the model as file {@code str2} inside directory {@code str},
     * creating the directory (and missing parents) when needed.
     *
     * @param str  target directory; an empty string means "no directory", in which
     *             case {@code str2} is used as the full path
     * @param str2 file name
     * @throws IOException           if the directory cannot be created or the model cannot be written
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public void save(String str, String str2) throws IOException {
        requireModel();
        if (str.length() == 0) {
            // new File("", name) would resolve against the file-system root.
            svm.svm_save_model(str2, this.model);
            return;
        }
        final File directory = new File(str);
        // Re-check isDirectory() after mkdirs() in case something else created it meanwhile.
        if (!directory.isDirectory() && !directory.mkdirs() && !directory.isDirectory()) {
            throw new IOException("Could not create directory '" + str + "'");
        }
        // Join with File so the file lands inside the directory even without a trailing separator.
        svm.svm_save_model(new File(directory, str2).getPath(), this.model);
    }

    /**
     * Saves the model to the given path.
     *
     * @param str path of the file to write
     * @throws IOException           if the model cannot be written
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public void save(String str) throws IOException {
        requireModel();
        svm.svm_save_model(str, this.model);
    }

    /**
     * Loads a model from a file and makes it the current model.
     *
     * @param str path of the model file to read
     * @throws IOException if the model cannot be read
     */
    public void load(String str) throws IOException {
        final svm_model loaded = LibSvmIO.loadModel(str);
        // NOTE(review): LibSvmIO is not visible here; guard in case it returns null
        // for an unreadable file instead of throwing.
        if (loaded == null) {
            throw new IOException("Could not load an SVM model from '" + str + "'");
        }
        mountModel(loaded);
    }

    /**
     * Prints the support vectors and their coefficients to standard output.
     *
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public void print() {
        requireModel();
        System.out.println("Support vectors (" + this.model.SV.length + "):");
        LibSvmConsole.print(this.model.SV);
        System.out.println("Coefficients (" + this.model.sv_coef.length + "):");
        LibSvmConsole.print(this.model.sv_coef);
    }

    /** @return a fresh {@link Parameters} object built with its no-argument constructor */
    public Parameters getDefaultParameters() {
        return new Parameters();
    }

    /**
     * Makes the given model current and derives {@link #parameters} from it.
     *
     * @param svm_modelVar the model to use; must not be {@code null}
     */
    protected void mountModel(svm_model svm_modelVar) {
        if (svm_modelVar == null) {
            throw new NullPointerException("model must not be null");
        }
        this.model = svm_modelVar;
        this.parameters = new Parameters(svm_modelVar.param);
    }

    /**
     * Builds {@link #prob} from the samples. {@link #param} must already be set.
     *
     * <p>Following libsvm's own training tool, a gamma of 0 is replaced by
     * {@code 1 / maxFeatureIndex}, and for precomputed kernels the first node of
     * each row is checked; violations are reported on standard error only.
     *
     * @param vector  one feature vector per sample
     * @param vector2 one target value per sample; its size defines the sample count
     * @throws IllegalArgumentException if there are fewer feature vectors than targets
     */
    protected void load(Vector<svm_node[]> vector, Vector<Double> vector2) {
        final int sampleCount = vector2.size();
        if (vector.size() < sampleCount) {
            throw new IllegalArgumentException(
                    "Got " + sampleCount + " target values but only " + vector.size() + " feature vectors");
        }

        this.prob = new svm_problem();
        this.prob.l = sampleCount;
        this.prob.x = new svm_node[sampleCount][];
        this.prob.y = new double[sampleCount];

        int maxIndex = 0;
        for (int row = 0; row < sampleCount; row++) {
            final svm_node[] features = vector.elementAt(row);
            this.prob.x[row] = features;
            this.prob.y[row] = vector2.elementAt(row).doubleValue();
            if (features != null) {
                for (svm_node node : features) {
                    if (node != null && node.index > maxIndex) {
                        maxIndex = node.index;
                    }
                }
            }
        }

        // Unset gamma defaults to 1 / (number of features).
        if (this.param.gamma == 0.0d && maxIndex > 0) {
            this.param.gamma = 1.0d / maxIndex;
        }

        if (this.param.kernel_type == svm_parameter.PRECOMPUTED) {
            for (int row = 0; row < sampleCount; row++) {
                final svm_node[] features = this.prob.x[row];
                if (features == null || features.length == 0 || features[0].index != 0) {
                    System.err.print("Wrong kernel matrix: first column must be 0:sample_serial_number\n");
                    if (features == null || features.length == 0) {
                        continue;
                    }
                }
                final int serialNumber = (int) features[0].value;
                if (serialNumber <= 0 || serialNumber > maxIndex) {
                    System.err.print("Wrong input format: sample_serial_number out of range\n");
                }
            }
        }
    }

    /**
     * Predicts every sample and, when expected targets are supplied, prints
     * accuracy (classification) or error statistics (regression) to standard output.
     *
     * @param vector samples to predict
     * @param dArr   expected target per sample, or {@code null} to skip the statistics
     * @param i      {@code 1} to predict with probability estimates (C-SVC and nu-SVC only)
     * @return one prediction per sample, in input order
     * @throws IllegalStateException if no model has been trained or loaded
     */
    protected double[] predict(Vector<svm_node[]> vector, double[] dArr, int i) {
        requireModel();
        final int svmType = svm.svm_get_svm_type(this.model);
        final boolean regression = isRegression(svmType);
        final boolean withProbability =
                i == 1 && (svmType == svm_parameter.C_SVC || svmType == svm_parameter.NU_SVC);

        double[] probabilityEstimates = null;
        if (i == 1) {
            if (regression) {
                System.out.print("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=" + svm.svm_get_svr_probability(this.model) + "\n");
            } else {
                probabilityEstimates = new double[svm.svm_get_nr_class(this.model)];
            }
        }

        final int sampleCount = vector.size();
        final double[] predictions = new double[sampleCount];
        int correct = 0;
        int total = 0;
        double squaredError = 0.0d;
        double sumPredicted = 0.0d;
        double sumTarget = 0.0d;
        double sumPredictedSq = 0.0d;
        double sumTargetSq = 0.0d;
        double sumProduct = 0.0d;

        for (int idx = 0; idx < sampleCount; idx++) {
            final svm_node[] sample = vector.get(idx);
            final double predicted = withProbability
                    ? svm.svm_predict_probability(this.model, sample, probabilityEstimates)
                    : svm.svm_predict(this.model, sample);
            predictions[idx] = predicted;

            if (dArr != null) {
                final double target = dArr[idx];
                if (predicted == target) {
                    correct++;
                }
                final double diff = predicted - target;
                squaredError += diff * diff;
                sumPredicted += predicted;
                sumTarget += target;
                sumPredictedSq += predicted * predicted;
                sumTargetSq += target * target;
                sumProduct += predicted * target;
                total++;
            }
        }

        // Statistics only exist when targets were compared; with none, the
        // divisions below would be by zero.
        if (total > 0) {
            if (regression) {
                System.out.print("Mean squared error = " + (squaredError / total) + " (regression)\n");
                System.out.print("Squared correlation coefficient = " + squaredCorrelation(total, sumPredicted, sumTarget, sumPredictedSq, sumTargetSq, sumProduct) + " (regression)\n");
            } else {
                // Floating-point division: int / int would truncate to 0 or 1.
                System.out.print("Accuracy = " + (((double) correct / total) * 100.0d) + "% (" + correct + "/" + total + ") (classification)\n");
            }
        }
        return predictions;
    }

    /**
     * Cross-validates {@link #prob} with {@link #param} and prints the result.
     *
     * @param i number of folds
     * @return for regression, the <em>sum</em> of squared errors (the printed mean
     *         is this value divided by the sample count); otherwise the accuracy
     *         in percent
     */
    protected double do_cross_validation(int i) {
        final int sampleCount = this.prob.l;
        final double[] predicted = new double[sampleCount];
        svm.svm_cross_validation(this.prob, this.param, i, predicted);

        if (!isRegression(this.param.svm_type)) {
            int correct = 0;
            for (int idx = 0; idx < sampleCount; idx++) {
                if (predicted[idx] == this.prob.y[idx]) {
                    correct++;
                }
            }
            final double accuracy = (100.0d * correct) / sampleCount;
            System.out.print("Cross Validation Accuracy = " + accuracy + "%\n");
            return accuracy;
        }

        double squaredError = 0.0d;
        double sumPredicted = 0.0d;
        double sumTarget = 0.0d;
        double sumPredictedSq = 0.0d;
        double sumTargetSq = 0.0d;
        double sumProduct = 0.0d;
        for (int idx = 0; idx < sampleCount; idx++) {
            final double target = this.prob.y[idx];
            final double value = predicted[idx];
            final double diff = value - target;
            squaredError += diff * diff;
            sumPredicted += value;
            sumTarget += target;
            sumPredictedSq += value * value;
            sumTargetSq += target * target;
            sumProduct += value * target;
        }
        System.out.print("Cross Validation Mean squared error = " + (squaredError / sampleCount) + "\n");
        System.out.print("Cross Validation Squared correlation coefficient = " + squaredCorrelation(sampleCount, sumPredicted, sumTarget, sumPredictedSq, sumTargetSq, sumProduct) + "\n");
        return squaredError;
    }

    /** Tells whether the libsvm SVM type is one of the two regression types. */
    private static boolean isRegression(int svmType) {
        return svmType == svm_parameter.EPSILON_SVR || svmType == svm_parameter.NU_SVR;
    }

    /** Computes the squared correlation coefficient from running sums over {@code n} pairs. */
    private static double squaredCorrelation(int n, double sumPredicted, double sumTarget,
            double sumPredictedSq, double sumTargetSq, double sumProduct) {
        final double covariance = (n * sumProduct) - (sumPredicted * sumTarget);
        final double predictedVariance = (n * sumPredictedSq) - (sumPredicted * sumPredicted);
        final double targetVariance = (n * sumTargetSq) - (sumTarget * sumTarget);
        return (covariance * covariance) / (predictedVariance * targetVariance);
    }

    /** Fails fast with a clear message when no model has been trained or loaded yet. */
    private void requireModel() {
        if (this.model == null) {
            throw new IllegalStateException("No model available: call train(...) or load(...) first");
        }
    }
}
