/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.r;

import java.io.Serializable;
import org.apache.spark.SparkException;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup$;
import org.apache.spark.ml.feature.RFormula;
import org.apache.spark.ml.feature.RFormulaModel;
import org.apache.spark.ml.r.AFTSurvivalRegressionWrapper;
import org.apache.spark.ml.r.RWrapperUtils$;
import org.apache.spark.ml.regression.AFTSurvivalRegression;
import org.apache.spark.ml.util.MLReadable;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.ArrayOps$;
import scala.collection.LinearSeqOps;
import scala.collection.StringOps$;
import scala.collection.immutable.List;
import scala.reflect.ClassTag$;
import scala.util.matching.Regex;

public final class AFTSurvivalRegressionWrapper$
implements MLReadable<AFTSurvivalRegressionWrapper> {
    public static final AFTSurvivalRegressionWrapper$ MODULE$ = new AFTSurvivalRegressionWrapper$();
    private static final Regex FORMULA_REGEXP;

    static {
        MLReadable.$init$(MODULE$);
        FORMULA_REGEXP = StringOps$.MODULE$.r$extension(Predef$.MODULE$.augmentString("Surv\\(([^,]+), ([^,]+)\\) ~ (.+)"));
    }

    private Regex FORMULA_REGEXP() {
        return FORMULA_REGEXP;
    }

    private Tuple2<String, String> formulaRewrite(String formula) {
        String rewrittenFormula = null;
        String censorCol = null;
        try {
            Option option;
            String string = formula;
            if (string == null || (option = this.FORMULA_REGEXP().unapplySeq((CharSequence)string)).isEmpty() || option.get() == null || ((List)option.get()).lengthCompare(3) != 0) {
                throw new MatchError((Object)string);
            }
            String label = (String)((LinearSeqOps)option.get()).apply(0);
            String censor = (String)((LinearSeqOps)option.get()).apply(1);
            String features = (String)((LinearSeqOps)option.get()).apply(2);
            Tuple3 tuple3 = new Tuple3((Object)label, (Object)censor, (Object)features);
            String label2 = (String)tuple3._1();
            String censor2 = (String)tuple3._2();
            String features2 = (String)tuple3._3();
            if (features2.contains(".")) {
                throw new UnsupportedOperationException("Terms of survreg formula can not support dot operator.");
            }
            rewrittenFormula = label2.trim() + "~" + features2.trim();
            censorCol = censor2.trim();
        }
        catch (MatchError e) {
            throw new SparkException("Could not parse formula: " + formula);
        }
        return new Tuple2((Object)rewrittenFormula, (Object)censorCol);
    }

    public AFTSurvivalRegressionWrapper fit(String formula, Dataset<Row> data, int aggregationDepth, String stringIndexerOrderType) {
        Tuple2<String, String> tuple2 = this.formulaRewrite(formula);
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        String rewrittenFormula = (String)tuple2._1();
        String censorCol = (String)tuple2._2();
        Tuple2 tuple22 = new Tuple2((Object)rewrittenFormula, (Object)censorCol);
        String rewrittenFormula2 = (String)tuple22._1();
        String censorCol2 = (String)tuple22._2();
        RFormula rFormula = new RFormula().setFormula(rewrittenFormula2).setStringIndexerOrderType(stringIndexerOrderType);
        RWrapperUtils$.MODULE$.checkDataColumns(rFormula, data);
        Model rFormulaModel = rFormula.fit((Dataset)data);
        StructType schema = ((RFormulaModel)rFormulaModel).transform(data).schema();
        Attribute[] featureAttrs = (Attribute[])AttributeGroup$.MODULE$.fromStructField(schema.apply(rFormula.getFeaturesCol())).attributes().get();
        String[] features = (String[])ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps((Object[])featureAttrs), (Function1 & Serializable)x$3 -> (String)x$3.name().get(), ClassTag$.MODULE$.apply(String.class));
        AFTSurvivalRegression aft = ((AFTSurvivalRegression)new AFTSurvivalRegression().setCensorCol(censorCol2).setFitIntercept(rFormula.hasIntercept()).setFeaturesCol(rFormula.getFeaturesCol())).setAggregationDepth(aggregationDepth);
        Model pipeline = new Pipeline().setStages((PipelineStage[])((Object[])new PipelineStage[]{rFormulaModel, aft})).fit((Dataset)data);
        return new AFTSurvivalRegressionWrapper((PipelineModel)pipeline, features);
    }

    @Override
    public MLReader<AFTSurvivalRegressionWrapper> read() {
        return new AFTSurvivalRegressionWrapper.AFTSurvivalRegressionWrapperReader();
    }

    @Override
    public AFTSurvivalRegressionWrapper load(String path) {
        return (AFTSurvivalRegressionWrapper)MLReadable.load$(this, path);
    }

    private AFTSurvivalRegressionWrapper$() {
    }
}

