package cre.algorithm.cdt;

import cre.Config.OtherConfig;
import cre.algorithm.AbstractAlgorithm;
import cre.algorithm.CanShowOutput;
import cre.algorithm.CanShowStatus;
import cre.algorithm.StratifiedSampleHelper;
import cre.view.ResizablePanel;
import cre.view.tree.TreePanel;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.batik.svggen.SVGSyntax;
import org.apache.commons.io.IOUtils;

/* loaded from: input_file:cre/algorithm/cdt/CDTAlgorithm.class */
/**
 * Runs the CDT (Causal Decision Tree) algorithm on a CSV data file.
 *
 * <p>{@link #doAlgorithm} builds a tree on the full data file, returns it as a "Diagram" panel and,
 * depending on {@link OtherConfig#getValidation()}, evaluates the algorithm with repeated stratified
 * cross validation or repeated hold-out validation.
 */
public class CDTAlgorithm extends AbstractAlgorithm {
    /** Tree height used when the constructor receives no configuration. */
    private static final int DEFAULT_HEIGHT = 5;

    /** Message shown to the user whenever the details are only written to the log. */
    private static final String SEE_LOG = "ERROR: see log for more details.";

    /** Algorithm settings; a default configuration is created when the constructor gets {@code null}. */
    public CDTConfig config;

    /**
     * Creates the algorithm for {@code file}.
     *
     * @param file the data file to analyse (checked by {@link #init()} to be a {@code .csv} file)
     * @param cDTConfig the settings to use; when {@code null}, a new configuration with height
     *     {@value #DEFAULT_HEIGHT} and pruning enabled is created
     */
    public CDTAlgorithm(File file, CDTConfig cDTConfig) {
        super(file);
        if (cDTConfig != null) {
            this.config = cDTConfig;
        } else {
            this.config = new CDTConfig();
            this.config.setHeight(DEFAULT_HEIGHT);
            this.config.setPruned(true);
        }
    }

    /**
     * Creates the algorithm for {@code file} with the default configuration.
     *
     * @param file the data file to analyse
     */
    public CDTAlgorithm(File file) {
        this(file, null);
    }

    /**
     * Checks that the data file is a CSV file.
     *
     * @throws Exception if the absolute path of the data file does not end with {@code .csv}
     */
    @Override // cre.algorithm.AbstractAlgorithm
    public void init() throws Exception {
        String path = this.filePath.getAbsolutePath();
        if (!path.endsWith(".csv")) {
            throw new Exception("Current data file: " + path + ".\nFor CDT, only CSV format file is permitted.");
        }
    }

    @Override // cre.algorithm.AbstractAlgorithm
    public String getName() {
        return "CDT (Causal Decision Tree)";
    }

    @Override // cre.algorithm.AbstractAlgorithm
    public String getIntroduction() {
        return "A tree model for discovering and representing causal relationships.\n\nReferences\n[1] Jiuyong Li, Saisai Ma, Thuc Duy Le, Lin Liu and Jixue Liu, Causal Decision Trees, IEEE Transactions on Knowledge and Data Engineering (TKDE), 29 (2): 257-271, 2016. \n\nOptions\nCausal discovery  --  Discover and represent causal relationships using CDT.\nClassification  --  Build a CDT and use the causal features included in the CDT for classification.\n";
    }

    /** Returns the live configuration object (not a copy). */
    @Override // cre.algorithm.AbstractAlgorithm
    public Cloneable getConfiguration() {
        return this.config;
    }

    /**
     * Creates an initialised algorithm for another data file.
     *
     * <p>The new instance shares this instance's configuration object; it is not cloned.
     *
     * @throws Exception if {@link #init()} rejects {@code file}
     */
    @Override // cre.algorithm.AbstractAlgorithm
    public AbstractAlgorithm getCloneBecauseChangeOfFile(File file) throws Exception {
        CDTAlgorithm cDTAlgorithm = new CDTAlgorithm(file, this.config);
        cDTAlgorithm.init();
        return cDTAlgorithm;
    }

    /**
     * Builds a CDT on the whole data file and optionally validates it.
     *
     * @return the diagram panel(s) for the full-data tree (possibly empty if no tree root was
     *     produced or the panel could not be created), or {@code null} if the header could not be
     *     read, the full-data tree failed, or a validation fold failed
     */
    @Override // cre.algorithm.AbstractAlgorithm
    public List<ResizablePanel> doAlgorithm(CanShowOutput canShowOutput, CanShowStatus canShowStatus, OtherConfig otherConfig) {
        String absolutePath = this.filePath.getAbsolutePath();
        canShowOutput.showOutputString("Scheme: " + this.config.toString());
        canShowOutput.showOutputString("File Name: " + absolutePath + IOUtils.LINE_SEPARATOR_UNIX);

        String[] attributes;
        try {
            attributes = readAttributeNames();
        } catch (IOException e) {
            // Without the header the attribute count needed for sampling is unknown; stop here
            // instead of failing later on a null array.
            canShowOutput.showOutputString(SEE_LOG);
            e.printStackTrace();
            return null;
        }
        canShowOutput.showOutputString("Attributes:");
        for (String attribute : attributes) {
            canShowOutput.showOutputString("\t" + attribute);
        }

        canShowOutput.showOutputString("==== full training set ===");
        canShowStatus.showStatus("Building...");
        try {
            CDT cdt = new CDT(this.config, absolutePath, canShowOutput, null, -1, null, null, false);
            cdt.createDecisionTree();
            List<ResizablePanel> panels = createDiagramPanels(cdt, canShowOutput);

            ArrayList<CDTValidationStatistic> statistics = validate(otherConfig, attributes.length, canShowOutput, canShowStatus);
            // Nothing to average when validation was skipped, stopped, failed to sample, or ran 0 times.
            if (statistics != null && !statistics.isEmpty()) {
                CDTValidationStatistic average = CDTValidationStatistic.average(statistics);
                canShowOutput.showOutputString(average.toString());
                canShowOutput.showOutputString("");
                canShowOutput.showOutputString(average.getDetailedAccuracy());
            }
            return panels;
        } catch (Exception e) {
            // Covers the full-data tree, any failed fold tree (FoldBuildException) and the reporting.
            canShowOutput.showOutputString(SEE_LOG);
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Reads the first line of the data file and splits it on commas.
     *
     * @return the header fields, in file order
     * @throws IOException if the file cannot be read or contains no line at all
     */
    private String[] readAttributeNames() throws IOException {
        try (BufferedReader reader = new BufferedReader(new FileReader(this.filePath))) {
            String header = reader.readLine();
            if (header == null) {
                throw new IOException("Data file is empty: " + this.filePath.getAbsolutePath());
            }
            return header.split(SVGSyntax.COMMA);
        }
    }

    /**
     * Wraps the tree of {@code cdt} in a panel tagged "Diagram".
     *
     * <p>A failure while creating the panel is reported but does not abort the run.
     *
     * @return a list with the diagram panel, or an empty list if the tree has no root or the
     *     panel could not be built
     */
    private static List<ResizablePanel> createDiagramPanels(CDT cdt, CanShowOutput canShowOutput) {
        List<ResizablePanel> panels = new ArrayList<>();
        try {
            if (cdt.rootYizhao != null) {
                TreePanel diagram = new TreePanel(cdt.rootYizhao);
                panels.add(diagram);
                diagram.setTag("Diagram");
            }
        } catch (Exception e) {
            canShowOutput.showOutputString(SEE_LOG);
            e.printStackTrace();
        }
        return panels;
    }

    /**
     * Runs the validation scheme selected in {@code otherConfig}.
     *
     * <p>Returned as {@code ArrayList} because that is the type the original code handed to
     * {@code CDTValidationStatistic.average}; its declared parameter type is not visible here.
     *
     * @param attributeCount number of header fields; the last one is passed to the sampler as the
     *     stratification column
     * @return one statistic per repetition, or {@code null} when no validation is selected, the
     *     user asked to stop, or the samples could not be prepared (reported as "ERROR: ...")
     * @throws FoldBuildException if building a tree for any fold fails
     */
    private ArrayList<CDTValidationStatistic> validate(OtherConfig otherConfig, int attributeCount,
            CanShowOutput canShowOutput, CanShowStatus canShowStatus) throws FoldBuildException {
        switch (otherConfig.getValidation()) {
            case CROSS_VALIDATION: {
                int folds = otherConfig.getCrossValidationFolds();
                int repeats = otherConfig.getValidationRepeatTimes();
                canShowOutput.showOutputString("\n========Cross Validation(folds: " + folds + ", repeat: " + repeats + ")=======\n");
                try {
                    return crossValidate(folds, repeats, attributeCount, canShowOutput, canShowStatus);
                } catch (FoldBuildException e) {
                    throw e;
                } catch (Exception e) {
                    e.printStackTrace();
                    canShowOutput.showOutputString("ERROR: " + e.getMessage());
                    return null;
                }
            }
            case VALIDATION: {
                int testPercent = otherConfig.getTest();
                int repeats = otherConfig.getValidationRepeatTimes();
                canShowOutput.showOutputString("\n========Validation(testing: " + testPercent + "%, repeat: " + repeats + ")=======\n");
                try {
                    return holdOutValidate(testPercent, repeats, attributeCount, canShowOutput, canShowStatus);
                } catch (FoldBuildException e) {
                    throw e;
                } catch (Exception e) {
                    e.printStackTrace();
                    canShowOutput.showOutputString("ERROR: " + e.getMessage());
                    return null;
                }
            }
            default:
                return null;
        }
    }

    /**
     * Repeats stratified {@code folds}-fold cross validation {@code repeats} times.
     *
     * @return one statistic per repetition, or {@code null} as soon as a stop is requested
     * @throws FoldBuildException if building the tree for a fold fails
     * @throws Exception if the stratified samples cannot be prepared
     */
    private ArrayList<CDTValidationStatistic> crossValidate(int folds, int repeats, int attributeCount,
            CanShowOutput canShowOutput, CanShowStatus canShowStatus) throws Exception {
        StratifiedSampleHelper sampler = new StratifiedSampleHelper(this.filePath.getAbsolutePath(), SVGSyntax.COMMA, attributeCount - 1, true, folds, attributeCount, canShowOutput);
        ArrayList<CDTValidationStatistic> statistics = new ArrayList<>();
        for (int repeat = 0; repeat < repeats; repeat++) {
            // Filled by every fold's CDT and then summarised per repetition.
            // NOTE(review): element types are defined by CDT / CDTValidationStatistic (not visible here), hence raw.
            ArrayList firstCollector = new ArrayList();
            ArrayList secondCollector = new ArrayList();
            int[] sampleLines = sampler.nextLines();
            for (int fold = 0; fold < folds; fold++) {
                if (isShouldStop()) {
                    return null;
                }
                canShowStatus.showStatus("times: " + (repeat + 1) + "; fold: " + (fold + 1));
                buildFoldTree(sampleLines, fold, firstCollector, secondCollector, canShowOutput);
            }
            statistics.add(new CDTValidationStatistic(firstCollector, secondCollector));
        }
        return statistics;
    }

    /**
     * Repeats a stratified hold-out split (testing share {@code testPercent}%) {@code repeats} times.
     *
     * @return one statistic per repetition, or {@code null} as soon as a stop is requested
     * @throws FoldBuildException if building the tree for a repetition fails
     * @throws Exception if the stratified samples cannot be prepared
     */
    private ArrayList<CDTValidationStatistic> holdOutValidate(int testPercent, int repeats, int attributeCount,
            CanShowOutput canShowOutput, CanShowStatus canShowStatus) throws Exception {
        StratifiedSampleHelper sampler = new StratifiedSampleHelper(this.filePath.getAbsolutePath(), SVGSyntax.COMMA, attributeCount - 1, false, testPercent / 100.0d, attributeCount, canShowOutput);
        ArrayList<CDTValidationStatistic> statistics = new ArrayList<>();
        for (int repeat = 0; repeat < repeats; repeat++) {
            if (isShouldStop()) {
                return null;
            }
            canShowStatus.showStatus("times: " + (repeat + 1));
            // NOTE(review): element types are defined by CDT / CDTValidationStatistic (not visible here), hence raw.
            ArrayList firstCollector = new ArrayList();
            ArrayList secondCollector = new ArrayList();
            buildFoldTree(sampler.nextLines(), 0, firstCollector, secondCollector, canShowOutput);
            statistics.add(new CDTValidationStatistic(firstCollector, secondCollector));
        }
        return statistics;
    }

    /**
     * Builds one validation tree on the given sample, letting CDT append its results to the two
     * collectors.
     *
     * @throws FoldBuildException wrapping whatever the CDT construction or build threw
     */
    private void buildFoldTree(int[] sampleLines, int fold, ArrayList firstCollector, ArrayList secondCollector,
            CanShowOutput canShowOutput) throws FoldBuildException {
        try {
            new CDT(this.config, this.filePath.getAbsolutePath(), canShowOutput, sampleLines, fold, firstCollector, secondCollector, true).createDecisionTree();
        } catch (Exception e) {
            throw new FoldBuildException(e);
        }
    }

    /** Signals that a validation tree could not be built; aborts the whole run. */
    private static final class FoldBuildException extends Exception {
        private static final long serialVersionUID = 1L;

        FoldBuildException(Exception cause) {
            super(cause);
        }
    }

    /** Returns a copy of this algorithm with a cloned configuration (when one is set). */
    @Override // cre.algorithm.AbstractAlgorithm
    public Object clone() {
        CDTAlgorithm cDTAlgorithm = new CDTAlgorithm(this.filePath);
        if (this.config != null) {
            cDTAlgorithm.config = (CDTConfig) this.config.clone();
        }
        return cDTAlgorithm;
    }
}
