/*
 * Decompiled with CFR 0.152.
 */
package ws.palladian.nodes.classification.text;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.knime.core.data.DataCell;
import org.knime.core.data.DataColumnSpec;
import org.knime.core.data.DataColumnSpecCreator;
import org.knime.core.data.DataRow;
import org.knime.core.data.DataTableSpec;
import org.knime.core.data.DataType;
import org.knime.core.data.StringValue;
import org.knime.core.data.container.AbstractCellFactory;
import org.knime.core.data.container.CellFactory;
import org.knime.core.data.container.ColumnRearranger;
import org.knime.core.data.def.DoubleCell;
import org.knime.core.data.def.StringCell;
import org.knime.core.node.BufferedDataTable;
import org.knime.core.node.ExecutionContext;
import org.knime.core.node.ExecutionMonitor;
import org.knime.core.node.InvalidSettingsException;
import org.knime.core.node.NodeLogger;
import org.knime.core.node.NodeSettingsRO;
import org.knime.core.node.NodeSettingsWO;
import org.knime.core.node.port.PortObject;
import org.knime.core.node.port.PortObjectSpec;
import org.knime.core.node.port.PortType;
import org.knime.core.node.streamable.simple.SimpleStreamableFunctionNodeModel;
import ws.palladian.classification.text.DictionaryModel;
import ws.palladian.classification.text.PalladianTextClassifier;
import ws.palladian.core.CategoryEntries;
import ws.palladian.nodes.PalladianPluginActivator;
import ws.palladian.nodes.classification.text.ITextClassifierPortObject;
import ws.palladian.nodes.classification.text.TextClassifierPortObjectSpec;
import ws.palladian.nodes.classification.text.TextClassifierPredictorNodeSettings;
import ws.palladian.nodes.helper.PalladianKnimeHelper;

/**
 * Node model that applies a Palladian {@link DictionaryModel} (input port 0) to the texts of a data
 * table (input port 1).
 *
 * <p>The output table is the input table with a {@code predictedCategory} string column appended.
 * If the "append class distribution" setting is enabled, one double column per model category is
 * inserted before it. Each such column is named after the category and holds the value of
 * {@link CategoryEntries#getProbability(String)} for that category.
 */
public class TextClassifierPredictorNodeModel extends SimpleStreamableFunctionNodeModel {
    private static final NodeLogger logger = NodeLogger.getLogger(TextClassifierPredictorNodeModel.class);
    private final TextClassifierPredictorNodeSettings settings = new TextClassifierPredictorNodeSettings();
    /** Model read from input port 0 in {@link #execute}; reset to null after all rows were processed. */
    private DictionaryModel dictionaryModel;

    /**
     * Creates the model with inputs (classifier model, text table) and one output table.
     * NOTE(review): the trailing {@code 1, 0} are presumably the streamable input/output port
     * indices of {@link SimpleStreamableFunctionNodeModel} — confirm against its Javadoc.
     */
    protected TextClassifierPredictorNodeModel() {
        super(new PortType[]{ITextClassifierPortObject.TYPE, BufferedDataTable.TYPE},
                new PortType[]{BufferedDataTable.TYPE}, 1, 0);
    }

    /**
     * Checks the license, takes the classifier model from port 0 and appends the prediction
     * columns to the table from port 1.
     *
     * @param inData the classifier port object (index 0) and the text table (index 1)
     * @param exec the execution context used to create the output table
     * @return a one-element array holding the input table with the prediction columns appended
     * @throws Exception if the license check or the table creation fails
     */
    protected PortObject[] execute(PortObject[] inData, ExecutionContext exec) throws Exception {
        PalladianPluginActivator.checkLicense();
        ITextClassifierPortObject portObject = (ITextClassifierPortObject) inData[0];
        this.dictionaryModel = portObject.getModel();
        BufferedDataTable textTable = (BufferedDataTable) inData[1];
        ColumnRearranger rearranger = this.createColumnRearranger(textTable.getSpec());
        BufferedDataTable out = exec.createColumnRearrangeTable(textTable, rearranger, exec);
        return new BufferedDataTable[]{out};
    }

    /**
     * Builds the rearranger that appends the prediction columns to {@code textTableSpec}.
     *
     * <p>The settings flag and the current model are captured once, so the output spec and the
     * cells produced per row always agree. If no model is set (e.g. during configure), the
     * rearranger can still be used to derive the output spec, but classifying a non-missing row
     * fails with an {@link IllegalStateException}.
     *
     * @param textTableSpec spec of the table containing the text column
     * @return the column rearranger appending the prediction columns
     * @throws InvalidSettingsException if the configured text column is not part of the table
     * @throws IllegalStateException if no text column is configured, or the class distribution is
     *         requested but no model is available
     */
    protected ColumnRearranger createColumnRearranger(DataTableSpec textTableSpec) throws InvalidSettingsException {
        String textColumn = this.settings.getInputColumnName();
        if (textColumn == null) {
            throw new IllegalStateException("The column for the text input is not configured correctly.");
        }
        final int inputColumnIndex = textTableSpec.findColumnIndex(textColumn);
        if (inputColumnIndex < 0) {
            throw new InvalidSettingsException(
                    "The text input column \"" + textColumn + "\" does not exist in the input table.");
        }

        // Snapshot state once; the cell factory below must not observe later changes.
        final DictionaryModel model = this.dictionaryModel;
        final boolean appendClassDist = this.settings.isAppendClassDist();

        List<DataColumnSpec> colSpecs = new ArrayList<>();
        final List<String> categories;
        if (appendClassDist) {
            if (model == null) {
                throw new IllegalStateException(
                        "The class distribution columns require the classifier model to be available.");
            }
            categories = new ArrayList<>(model.getCategories());
            // Sorted so the category columns appear in a deterministic order.
            Collections.sort(categories);
            for (String category : categories) {
                colSpecs.add(new DataColumnSpecCreator(category, DoubleCell.TYPE).createSpec());
            }
        } else {
            categories = null;
        }
        // The predicted category is always the last output column.
        colSpecs.add(new DataColumnSpecCreator("predictedCategory", StringCell.TYPE).createSpec());

        final PalladianTextClassifier textClassifier;
        if (model != null) {
            PalladianTextClassifier.Scorer scorer = this.settings.createScorer();
            textClassifier = new PalladianTextClassifier(model.getFeatureSetting(), scorer);
        } else {
            textClassifier = null;
        }

        final DataColumnSpec[] colSpecArray = colSpecs.toArray(new DataColumnSpec[0]);
        final int numCells = colSpecArray.length;
        AbstractCellFactory factory = new AbstractCellFactory(colSpecArray) {

            public DataCell[] getCells(DataRow row) {
                DataCell[] cells = new DataCell[numCells];
                DataCell textCell = row.getCell(inputColumnIndex);
                if (textCell.isMissing()) {
                    Arrays.fill(cells, DataType.getMissingCell());
                    return cells;
                }
                if (textClassifier == null) {
                    throw new IllegalStateException("No text classifier model available for classification.");
                }
                String text = ((StringValue) textCell).getStringValue();
                CategoryEntries result = textClassifier.classify(text, model);
                String categoryName = result.getMostLikelyCategory();
                // StringCell does not accept null; emit a missing cell if no category was determined.
                cells[numCells - 1] = categoryName == null ? DataType.getMissingCell() : new StringCell(categoryName);
                if (appendClassDist) {
                    for (int i = 0; i < categories.size(); i++) {
                        cells[i] = new DoubleCell(result.getProbability(categories.get(i)));
                    }
                }
                return cells;
            }

            public void afterProcessing() {
                // Drop the node's reference to the model once all rows were processed
                // (presumably to free memory between executions).
                TextClassifierPredictorNodeModel.this.dictionaryModel = null;
            }
        };
        factory.setParallelProcessing(true);
        ColumnRearranger rearranger = new ColumnRearranger(textTableSpec);
        rearranger.append(factory);
        return rearranger;
    }

    /**
     * Resolves the text column and computes the output spec where possible.
     *
     * <p>Column resolution order:
     * <ol>
     *   <li>the text column stored in the model's port spec, if none is configured yet;</li>
     *   <li>the configured column;</li>
     *   <li>a guessed string column (with a warning).</li>
     * </ol>
     * The resolved name is written back to the settings.
     *
     * @param inSpecs the classifier port spec (index 0) and the text table spec (index 1)
     * @return the output spec; its single element is null (unknown) when the class distribution is
     *         appended, because the category columns depend on the model
     * @throws InvalidSettingsException if no suitable text column can be found
     */
    protected PortObjectSpec[] configure(PortObjectSpec[] inSpecs) throws InvalidSettingsException {
        TextClassifierPortObjectSpec portObjectSpec = (TextClassifierPortObjectSpec) inSpecs[0];
        DataTableSpec textTableSpec = (DataTableSpec) inSpecs[1];
        DataColumnSpec textColSpec = null;
        if (this.settings.getInputColumnName() == null) {
            String textColumnName = portObjectSpec.getTextColumn();
            // Only look the column up when the port spec actually names one.
            if (textColumnName != null) {
                logger.debug("Name of text column from PortObjectSpec: " + textColumnName);
                textColSpec = textTableSpec.getColumnSpec(textColumnName);
            }
        }
        if (textColSpec == null) {
            textColSpec = PalladianKnimeHelper.getColumn(textTableSpec, this.settings.getInputColumnName(), StringValue.class);
        }
        if (textColSpec == null) {
            textColSpec = PalladianKnimeHelper.guessColumn(textTableSpec, StringValue.class);
            if (textColSpec == null) {
                throw new InvalidSettingsException("No string column available in the input table.");
            }
            this.setWarningMessage("Guessing input column: " + textColSpec.getName());
        }
        this.settings.setInputColumnName(textColSpec.getName());
        if (this.settings.isAppendClassDist()) {
            // The category columns are named after the model's categories; the model is only set
            // in execute, so the output spec cannot be determined here.
            return new PortObjectSpec[]{null};
        }
        ColumnRearranger rearranger = this.createColumnRearranger(textTableSpec);
        return new DataTableSpec[]{rearranger.createSpec()};
    }

    /** Persists the node settings. */
    protected void saveSettingsTo(NodeSettingsWO settings) {
        this.settings.saveSettingsTo(settings);
    }

    /** Loads previously validated node settings. */
    protected void loadValidatedSettingsFrom(NodeSettingsRO settings) throws InvalidSettingsException {
        this.settings.loadValidatedSettingsFrom(settings);
    }

    /** Validates node settings without applying them. */
    protected void validateSettings(NodeSettingsRO settings) throws InvalidSettingsException {
        this.settings.validateSettings(settings);
    }
}

