/*
 * Decompiled with CFR 0.152.
 */
package ws.palladian.nodes.classification.text;

import java.io.File;
import java.io.IOException;
import java.lang.invoke.CallSite;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.UUID;
import java.util.function.Function;
import java.util.function.Predicate;
import org.apache.commons.lang3.StringUtils;
import org.knime.core.data.DataColumnSpec;
import org.knime.core.data.DataRow;
import org.knime.core.data.DataTableSpec;
import org.knime.core.data.StringValue;
import org.knime.core.data.filestore.FileStore;
import org.knime.core.data.util.memory.MemoryAlertListener;
import org.knime.core.data.util.memory.MemoryAlertSystem;
import org.knime.core.node.BufferedDataTable;
import org.knime.core.node.CanceledExecutionException;
import org.knime.core.node.ExecutionContext;
import org.knime.core.node.ExecutionMonitor;
import org.knime.core.node.InvalidSettingsException;
import org.knime.core.node.NodeLogger;
import org.knime.core.node.NodeModel;
import org.knime.core.node.NodeSettingsRO;
import org.knime.core.node.NodeSettingsWO;
import org.knime.core.node.port.PortObject;
import org.knime.core.node.port.PortObjectSpec;
import org.knime.core.node.port.PortType;
import org.knime.core.node.streamable.InputPortRole;
import org.knime.core.node.streamable.PartitionInfo;
import org.knime.core.node.streamable.PortInput;
import org.knime.core.node.streamable.PortObjectOutput;
import org.knime.core.node.streamable.PortOutput;
import org.knime.core.node.streamable.RowInput;
import org.knime.core.node.streamable.StreamableOperator;
import ws.palladian.classification.text.DictionaryModel;
import ws.palladian.classification.text.FeatureSetting;
import ws.palladian.classification.text.PalladianTextClassifier;
import ws.palladian.helper.collection.AbstractIterator2;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.functional.Predicates;
import ws.palladian.nodes.PalladianPluginActivator;
import ws.palladian.nodes.classification.text.FileStoreTextClassifierPortObject;
import ws.palladian.nodes.classification.text.ITextClassifierPortObject;
import ws.palladian.nodes.classification.text.RowTrainableConverter;
import ws.palladian.nodes.classification.text.TextClassifierLearnerSettings;
import ws.palladian.nodes.classification.text.TextClassifierPortObjectSpec;
import ws.palladian.nodes.helper.PalladianKnimeHelper;

/**
 * KNIME node model that trains a {@link PalladianTextClassifier} from an input table and emits the
 * resulting {@link DictionaryModel} wrapped in a {@link FileStoreTextClassifierPortObject}.
 *
 * <p>The text, category and weight columns are taken from {@link TextClassifierLearnerSettings}.
 * Training is available both through the buffered {@link #execute} path and through the streaming
 * operator returned by {@link #createStreamableOperator}. Both delegate to {@link #runTraining}.
 */
public class TextClassifierLearnerNodeModel
extends NodeModel {
    private static final NodeLogger logger = NodeLogger.getLogger(TextClassifierLearnerNodeModel.class);

    /** Number of trained categories above which a warning message is set on the node. */
    private static final int NUM_CATEGORIES_WARNING = 25;

    private final TextClassifierLearnerSettings settings = new TextClassifierLearnerSettings();

    /** Creates the model with one data table input port and one text classifier output port. */
    protected TextClassifierLearnerNodeModel() {
        super(new PortType[]{BufferedDataTable.TYPE}, new PortType[]{ITextClassifierPortObject.TYPE});
    }

    /**
     * Buffered (non-streaming) execution: checks the license and trains on the whole input table.
     *
     * @param inData array whose first element is the training {@link BufferedDataTable}
     * @param exec execution context for progress, cancellation and file store creation
     * @return single-element array containing the trained classifier port object
     * @throws Exception if the license check or the training fails
     */
    @Override
    protected PortObject[] execute(PortObject[] inData, ExecutionContext exec) throws Exception {
        PalladianPluginActivator.checkLicense();
        BufferedDataTable trainTable = (BufferedDataTable) inData[0];
        long rowCount = trainTable.size();
        FileStoreTextClassifierPortObject portObject = runTraining(trainTable.getSpec(), rowCount, trainTable, exec);
        return new PortObject[]{portObject};
    }

    /**
     * Trains a dictionary model from the given rows and wraps it in a file-store-backed port object.
     *
     * @param inTableSpec spec of the training data, used to resolve the configured column indices
     * @param rowCount number of rows, or {@code -1} when unknown (as passed by the streaming path)
     * @param rows training rows, consumed by the classifier
     * @param exec execution context for progress, cancellation and file store creation
     * @return the port object holding the trained model
     * @throws CanceledExecutionException if training was canceled
     * @throws Exception if training fails for any other reason
     */
    // RowTrainableConverter / CollectionHelper are used through raw types here, as in the original code.
    @SuppressWarnings({"unchecked", "rawtypes"})
    private FileStoreTextClassifierPortObject runTraining(DataTableSpec inTableSpec, long rowCount, Iterable<DataRow> rows, ExecutionContext exec) throws CanceledExecutionException, Exception {
        int textIndex = inTableSpec.findColumnIndex(this.settings.textColumn.getStringValue());
        int categoryIndex = inTableSpec.findColumnIndex(this.settings.categoryColumn.getStringValue());
        int weightIndex = inTableSpec.findColumnIndex(this.settings.weightColumn.getStringValue());
        RowTrainableConverter converter = new RowTrainableConverter(exec, textIndex, categoryIndex, weightIndex, rowCount);
        // Map each row through the converter and drop rows for which it yields null.
        Iterable trainables = CollectionHelper.filter(CollectionHelper.convert(rows, (Function) converter), (Predicate) Predicates.NOT_NULL);

        exec.setProgress("Training the model");
        FeatureSetting featureSetting = this.settings.featureSettings.getFeatureSetting();
        logger.info("Feature setting: " + featureSetting);
        PalladianTextClassifier classifier = new PalladianTextClassifier(featureSetting);

        // Read the flag once so the listener is removed exactly when it was added, and register it
        // right before the guarded section so no exception in between can leak the registration.
        boolean memoryWarningsEnabled = !this.settings.disableMemoryWarnings.getBooleanValue();
        if (memoryWarningsEnabled) {
            MemoryAlertSystem.getInstance().addListener((MemoryAlertListener) converter);
        }
        DictionaryModel dictionaryModel;
        try {
            dictionaryModel = (DictionaryModel) classifier.train(trainables);
        } catch (Exception e) {
            // A cancellation may arrive wrapped in another exception (presumably raised from within the
            // converter during iteration); unwrap it so the framework sees a proper cancellation.
            if (e.getCause() instanceof CanceledExecutionException) {
                throw (CanceledExecutionException) e.getCause();
            }
            throw e;
        } finally {
            if (memoryWarningsEnabled) {
                MemoryAlertSystem.getInstance().removeListener((MemoryAlertListener) converter);
            }
        }

        if (converter.getNumCategories() > NUM_CATEGORIES_WARNING) {
            this.setWarningMessage("Trained lots of categories (" + converter.getNumCategories() + "), make sure the correct category input has been selected and keep in mind, that classification performance degrades, when a great number of categories need to be predicted.");
        }
        TextClassifierPortObjectSpec portObjectSpec = new TextClassifierPortObjectSpec(this.settings.textColumn.getStringValue());
        FileStore fileStore = exec.createFileStore(UUID.randomUUID().toString());
        return FileStoreTextClassifierPortObject.createPortObject(portObjectSpec, dictionaryModel, fileStore);
    }

    /**
     * Creates the streaming operator, which trains on rows polled from the {@link RowInput} and
     * publishes the trained model through a {@link PortObjectOutput}.
     *
     * @param partitionInfo partition information (unused)
     * @param inSpecs input specs; the first element is the training table's {@link DataTableSpec}
     * @return the streamable operator
     * @throws InvalidSettingsException declared by the framework contract
     */
    @Override
    public StreamableOperator createStreamableOperator(PartitionInfo partitionInfo, PortObjectSpec[] inSpecs) throws InvalidSettingsException {
        final DataTableSpec inSpec = (DataTableSpec) inSpecs[0];
        return new StreamableOperator() {

            @Override
            public void runFinal(PortInput[] inputs, PortOutput[] outputs, ExecutionContext exec) throws Exception {
                // The buffered execute() path enforces the license; the streaming path must not bypass it.
                PalladianPluginActivator.checkLicense();
                final RowInput trainTable = (RowInput) inputs[0];
                // Rows are polled from the RowInput one by one, so this Iterable supports a single pass only.
                Iterable<DataRow> iterableRows = new Iterable<DataRow>() {
                    private boolean consumed = false;

                    @Override
                    public Iterator<DataRow> iterator() {
                        if (this.consumed) {
                            throw new IllegalStateException("iterator was already used; this should not happen");
                        }
                        this.consumed = true;
                        return new AbstractIterator2<DataRow>() {

                            @Override
                            protected DataRow getNext() {
                                DataRow row;
                                try {
                                    row = trainTable.poll();
                                } catch (InterruptedException e) {
                                    // Restore the interrupt status so code further up the stack can observe it.
                                    Thread.currentThread().interrupt();
                                    trainTable.close();
                                    throw new IllegalStateException(e);
                                }
                                if (row == null) {
                                    // poll() returned null: input exhausted.
                                    trainTable.close();
                                    return (DataRow) this.finished();
                                }
                                return row;
                            }
                        };
                    }
                };
                FileStoreTextClassifierPortObject portObject = TextClassifierLearnerNodeModel.this.runTraining(inSpec, -1L, iterableRows, exec);
                logger.debug("created dictionary: " + portObject.getSummary());
                PortObjectOutput resultObject = (PortObjectOutput) outputs[0];
                resultObject.setPortObject(portObject);
            }
        };
    }

    /** The single input port is streamable but not distributable. */
    @Override
    public InputPortRole[] getInputPortRoles() {
        return new InputPortRole[]{InputPortRole.NONDISTRIBUTED_STREAMABLE};
    }

    /** Nothing to reset: this model keeps no state besides its settings. */
    @Override
    protected void reset() {
    }

    /**
     * Validates the input spec and, if the text or category column is not configured or not found,
     * guesses a {@link StringValue}-compatible column and stores it in the settings.
     *
     * @param inSpecs input specs; the first element is the training table's {@link DataTableSpec}
     * @return single-element array with the output {@link TextClassifierPortObjectSpec}
     * @throws InvalidSettingsException if the settings or input spec are invalid
     */
    @Override
    protected PortObjectSpec[] configure(PortObjectSpec[] inSpecs) throws InvalidSettingsException {
        DataTableSpec inSpec = (DataTableSpec) inSpecs[0];
        DataColumnSpec textColSpec = PalladianKnimeHelper.getColumn(inSpec, this.settings.textColumn.getStringValue(), StringValue.class);
        DataColumnSpec categoryColSpec = PalladianKnimeHelper.getColumn(inSpec, this.settings.categoryColumn.getStringValue(), StringValue.class);
        List<String> guessedColumns = new ArrayList<>();
        // When the text column is guessed, the category guess is started after it
        // (presumably guessColumn searches from startIndex onwards, so both do not pick the same column).
        int startIndex = 0;
        if (textColSpec == null) {
            textColSpec = PalladianKnimeHelper.guessColumn(inSpec, StringValue.class);
            String columnName = textColSpec.getName();
            guessedColumns.add("text input: " + columnName);
            this.settings.textColumn.setStringValue(columnName);
            startIndex = inSpec.findColumnIndex(columnName) + 1;
        }
        if (categoryColSpec == null) {
            categoryColSpec = PalladianKnimeHelper.guessColumn(inSpec, startIndex, StringValue.class);
            String columnName = categoryColSpec.getName();
            guessedColumns.add("category input: " + columnName);
            this.settings.categoryColumn.setStringValue(columnName);
        }
        if (!guessedColumns.isEmpty()) {
            String warningMessage = String.format("Guessed column name(s): %s", StringUtils.join(guessedColumns, ", "));
            this.setWarningMessage(warningMessage);
        }
        return new PortObjectSpec[]{new TextClassifierPortObjectSpec(this.settings.textColumn.getStringValue())};
    }

    /** Delegates to {@link TextClassifierLearnerSettings#saveSettingsTo}. */
    @Override
    protected void saveSettingsTo(NodeSettingsWO settings) {
        this.settings.saveSettingsTo(settings);
    }

    /** Delegates to {@link TextClassifierLearnerSettings#loadValidatedSettingsFrom}. */
    @Override
    protected void loadValidatedSettingsFrom(NodeSettingsRO settings) throws InvalidSettingsException {
        this.settings.loadValidatedSettingsFrom(settings);
    }

    /** Delegates to {@link TextClassifierLearnerSettings#validateSettings}. */
    @Override
    protected void validateSettings(NodeSettingsRO settings) throws InvalidSettingsException {
        this.settings.validateSettings(settings);
    }

    /** No internals to load: the trained model travels in the output port object. */
    @Override
    protected void loadInternals(File internDir, ExecutionMonitor exec) throws IOException, CanceledExecutionException {
    }

    /** No internals to save: the trained model travels in the output port object. */
    @Override
    protected void saveInternals(File internDir, ExecutionMonitor exec) throws IOException, CanceledExecutionException {
    }
}

