Class COIterationCount

  • All Implemented Interfaces:
    ConvergenceObserver, TrainingObserver
    Direct Known Subclasses:
    COErrorRate, COScoreChange

    public class COIterationCount
    extends java.lang.Object
    implements ConvergenceObserver
    A simple class that counts iterations and stops training once the given maximum number of iterations has been reached. It can be used to retrain a model after the best number of iterations has been determined with COErrorRateMin. It can also serve as a base class, so that training does not run forever when model accuracy never reaches the desired level. An instance of this class or of any of its descendants must be used to train only one model.
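The stopping rule described above can be illustrated with a minimal, self-contained Java sketch. It does not use this library: IterationLimitSketch and TrainingLoopSketch are hypothetical names, the counter mirrors only the documented behavior (count in beginIteration(), report needStopTraining() once the maximum is reached), and the model update is left as a placeholder.

```java
import java.util.List;

/** Hypothetical stand-in that mirrors only the documented counting behavior of COIterationCount. */
class IterationLimitSketch {
    private final int maxIter;
    private int nIterations;

    IterationLimitSketch(int maxIter) {
        if (maxIter <= 0) {
            throw new IllegalArgumentException("maxIter must be positive: " + maxIter);
        }
        this.maxIter = maxIter;
    }

    /** Counts one more iteration, as beginIteration() is documented to do. */
    void beginIteration() {
        nIterations++;
    }

    void endIteration() {
        // The documented class only switches state here; this sketch keeps no state to switch.
    }

    /** True once the iteration count has reached the maximum. */
    boolean needStopTraining() {
        return nIterations >= maxIter;
    }

    int getNIterations() {
        return nIterations;
    }
}

class TrainingLoopSketch {
    /** Repeats passes over the training data until the observer asks to stop; returns the pass count. */
    static int train(IterationLimitSketch observer, List<double[]> trainingData) {
        do {
            observer.beginIteration();
            for (double[] example : trainingData) {
                // Placeholder: a real trainer would update the model with this example
                // and pass the resulting prediction to the observer.
            }
            observer.endIteration();
        } while (!observer.needStopTraining());
        return observer.getNIterations();
    }

    public static void main(String[] args) {
        List<double[]> data = List.of(new double[] {1.0}, new double[] {2.0});
        System.out.println(train(new IterationLimitSketch(5), data)); // prints 5
    }
}
```

A do-while loop is used because at least one iteration always runs; with a maximum of 5 the loop performs exactly 5 passes regardless of the data.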
    • Constructor Detail

      • COIterationCount

        public COIterationCount(int maxIter)
        Creates an observer with the given maximum number of iterations.
        Parameters:
        maxIter - maximum number of model training iterations to perform.
        Throws:
        java.lang.IllegalArgumentException - if maxIter is not positive.
      • COIterationCount

        public COIterationCount()
        Creates an observer that uses the default maximum number of iterations.
        See Also:
        getDftMaxIterations()
    • Method Detail

      • getDftMaxIterations

        public static int getDftMaxIterations()
        Returns:
        default maximum number of iterations (100).
      • toString

        public java.lang.String toString()
        Overrides:
        toString in class java.lang.Object
      • getMaxIterations

        public int getMaxIterations()
        Returns:
        the maximum number of model training iterations allowed by this observer.
      • setMaxIterations

        public void setMaxIterations(int maxIter)
        Sets the maximum number of model training iterations to perform.
        Throws:
        java.lang.IllegalArgumentException - if maxIter is not positive.
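The constructor and setter contracts (a positive maximum, with 100 as the default returned by getDftMaxIterations()) could be implemented roughly as below. MaxIterationsSketch and its checkPositive helper are hypothetical; routing both the constructor and the setter through one validation method is an assumed design choice, not something this documentation states.

```java
/** Hypothetical sketch of the documented constructors and max-iteration accessors. */
class MaxIterationsSketch {
    /** Documented default returned by getDftMaxIterations(). */
    private static final int DEFAULT_MAX_ITERATIONS = 100;

    private int maxIter;

    /** Mirrors COIterationCount(int maxIter): rejects non-positive values. */
    MaxIterationsSketch(int maxIter) {
        this.maxIter = checkPositive(maxIter);
    }

    /** Mirrors COIterationCount(): delegates to the other constructor with the default. */
    MaxIterationsSketch() {
        this(getDftMaxIterations());
    }

    static int getDftMaxIterations() {
        return DEFAULT_MAX_ITERATIONS;
    }

    int getMaxIterations() {
        return maxIter;
    }

    void setMaxIterations(int maxIter) {
        this.maxIter = checkPositive(maxIter);
    }

    /** Shared validation: both the constructor and the setter throw on non-positive values. */
    private static int checkPositive(int maxIter) {
        if (maxIter <= 0) {
            throw new IllegalArgumentException("maxIter must be positive, got " + maxIter);
        }
        return maxIter;
    }

    public static void main(String[] args) {
        MaxIterationsSketch observer = new MaxIterationsSketch();
        System.out.println(observer.getMaxIterations()); // prints 100
        observer.setMaxIterations(20);
        System.out.println(observer.getMaxIterations()); // prints 20
    }
}
```

Because checkPositive throws before the field is assigned, a rejected call to setMaxIterations leaves the previous maximum unchanged.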
      • verifyIterationBegun

        protected void verifyIterationBegun()
        Call from methods that process incoming predictions to ensure that an iteration has begun.
        Throws:
        java.lang.IllegalStateException - if iteration has not begun.
      • verifyIterationEnded

        protected void verifyIterationEnded()
        Call from methods that return iteration results to ensure that an iteration has ended.
        Throws:
        java.lang.IllegalStateException - if iteration has not ended or no iterations were performed.
      • beginIteration

        public void beginIteration()
        Counts an additional iteration and switches the object's state. Overriding methods must call this method.
        Specified by:
        beginIteration in interface TrainingObserver
        Throws:
        java.lang.IllegalStateException - if an iteration has already begun.
      • endIteration

        public void endIteration()
        Only switches the object's state. Descendants do not need to override it, but any overriding method must call this method.
        Specified by:
        endIteration in interface TrainingObserver
        Throws:
        java.lang.IllegalStateException - if iteration has not begun.
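Taken together, beginIteration(), endIteration() and the two protected checks describe a small state machine. The sketch below is one possible reading of it, assuming a single boolean flag records whether an iteration is in progress; IterationStateSketch and its fields are hypothetical and do not reflect the library's actual internals.

```java
/** Hypothetical sketch of the begin/end state handling described for COIterationCount. */
class IterationStateSketch {
    private int nIterations;      // iterations started so far
    private boolean inIteration;  // true between beginIteration() and endIteration()

    public void beginIteration() {
        if (inIteration) {
            throw new IllegalStateException("an iteration has already begun");
        }
        nIterations++;
        inIteration = true;
    }

    public void endIteration() {
        verifyIterationBegun();
        inIteration = false;
    }

    public boolean hasIterations() {
        return nIterations > 0;
    }

    /** For methods that consume predictions, such as evaluatePrediction(...). */
    protected void verifyIterationBegun() {
        if (!inIteration) {
            throw new IllegalStateException("iteration has not begun");
        }
    }

    /** For methods that report results, such as isConverged() or getProgressEstimate(). */
    protected void verifyIterationEnded() {
        if (inIteration || !hasIterations()) {
            throw new IllegalStateException("iteration has not ended or no iterations were performed");
        }
    }

    public static void main(String[] args) {
        IterationStateSketch state = new IterationStateSketch();
        state.beginIteration();
        try {
            state.verifyIterationEnded(); // still inside the iteration, so this throws
        } catch (IllegalStateException e) {
            System.out.println("rejected: " + e.getMessage());
        }
        state.endIteration();
        state.verifyIterationEnded(); // passes: one iteration has completed
    }
}
```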
      • evaluatePrediction

        public void evaluatePrediction(VectorExample e)
        Does nothing. Descendants typically override it to accumulate statistics.
        Specified by:
        evaluatePrediction in interface TrainingObserver
        Parameters:
        e - an example that must have the RLink prediction assigned.
        Throws:
        java.lang.IllegalStateException - if iteration has not begun.
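A descendant that accumulates statistics would follow the override rules above: call the superclass method first, then do its own bookkeeping. To stay self-contained, the sketch declares a minimal base class. CountingBase, ErrorCountSketch and LabeledPrediction (a label plus an assigned prediction) are hypothetical stand-ins for COIterationCount, one of its descendants, and VectorExample; their real fields and methods may differ.

```java
/** Hypothetical stand-in for VectorExample: a label and an assigned prediction. */
class LabeledPrediction {
    final int label;
    final int prediction;

    LabeledPrediction(int label, int prediction) {
        this.label = label;
        this.prediction = prediction;
    }
}

/** Minimal hypothetical base mirroring the documented begin/end/evaluate contract. */
class CountingBase {
    private int nIterations;
    private boolean inIteration;

    public void beginIteration() {
        if (inIteration) throw new IllegalStateException("an iteration has already begun");
        nIterations++;
        inIteration = true;
    }

    public void endIteration() {
        if (!inIteration) throw new IllegalStateException("iteration has not begun");
        inIteration = false;
    }

    /** Does nothing beyond the state check, like the documented base method. */
    public void evaluatePrediction(LabeledPrediction e) {
        if (!inIteration) throw new IllegalStateException("iteration has not begun");
    }

    public int getNIterations() {
        return nIterations;
    }
}

/** Hypothetical descendant that counts wrong predictions in each iteration. */
class ErrorCountSketch extends CountingBase {
    private int errors;
    private int total;

    @Override
    public void beginIteration() {
        super.beginIteration(); // required: keeps the iteration count and state correct
        errors = 0;             // reset per-iteration statistics
        total = 0;
    }

    @Override
    public void evaluatePrediction(LabeledPrediction e) {
        super.evaluatePrediction(e); // rejects predictions outside an iteration
        total++;
        if (e.prediction != e.label) {
            errors++;
        }
    }

    /** Error rate of the current or most recent iteration; 0 if no examples were seen. */
    public double getErrorRate() {
        return total == 0 ? 0.0 : (double) errors / total;
    }

    public static void main(String[] args) {
        ErrorCountSketch observer = new ErrorCountSketch();
        observer.beginIteration();
        observer.evaluatePrediction(new LabeledPrediction(1, 1));
        observer.evaluatePrediction(new LabeledPrediction(0, 1));
        observer.endIteration();
        System.out.println(observer.getErrorRate()); // prints 0.5
    }
}
```

Skipping the super calls would silently break the iteration count and the state checks, which is why the documentation makes them mandatory.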
      • getProgressEstimate

        public double getProgressEstimate()
        Description copied from interface: ConvergenceObserver
        Returns an estimate of training progress, which can be used to indicate how close training is to the desired convergence. The interpretation and range of the returned value can vary between subclasses. Typically, the return value approaches 0 as the model converges. It should be called after an iteration has ended, not in the middle of an iteration.
        Specified by:
        getProgressEstimate in interface ConvergenceObserver
        Returns:
        a progress estimate that approaches 0 as the number of iterations approaches the maximum number of iterations.
        Throws:
        java.lang.IllegalStateException - if iteration has not ended or no iterations were performed.
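The documentation does not give the formula behind getProgressEstimate(); it only says the value approaches 0 as the iteration count approaches the maximum. One formula with that property, assumed here purely for illustration, is the fraction of the iteration budget still unused, (maxIterations - nIterations) / maxIterations. ProgressEstimateSketch is a hypothetical name.

```java
/**
 * Hypothetical progress estimate: the fraction of the iteration budget still unused.
 * This linear formula is an assumption, not the library's implementation.
 */
class ProgressEstimateSketch {
    static double progressEstimate(int nIterations, int maxIterations) {
        if (nIterations <= 0) {
            // Mirrors the documented IllegalStateException when no iterations were performed.
            throw new IllegalStateException("no iterations were performed");
        }
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("maxIterations must be positive");
        }
        // Clamp at 0 so the estimate never goes negative past the maximum.
        int remaining = Math.max(0, maxIterations - nIterations);
        return (double) remaining / maxIterations;
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 4; n++) {
            System.out.println(n + " -> " + progressEstimate(n, 4));
        }
        // prints 1 -> 0.75, 2 -> 0.5, 3 -> 0.25, 4 -> 0.0 (one per line)
    }
}
```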
      • isConverged

        public boolean isConverged()
        Description copied from interface: ConvergenceObserver
        Returns whether the trained model has converged. Depending on implementation, this can be based on one or several iterations of evaluating the verification set.
        Specified by:
        isConverged in interface ConvergenceObserver
        Returns:
        true if the number of iterations has reached the maximum number of iterations.
        Throws:
        java.lang.IllegalStateException - if iteration has not ended or no iterations were performed.
      • needStopTraining

        public boolean needStopTraining()
        Description copied from interface: ConvergenceObserver
        Returns whether training has to be stopped after the current iteration. This may mean that the model has converged, based on all score evaluations performed so far, or that the model has failed to converge and another termination condition was reached.
        Specified by:
        needStopTraining in interface ConvergenceObserver
        Returns:
        true if the number of iterations has reached the maximum number of iterations.
        Throws:
        java.lang.IllegalStateException - if iteration has not ended or no iterations were performed.
      • isPerfectResult

        public boolean isPerfectResult()
        Specified by:
        isPerfectResult in interface TrainingObserver
        Returns:
        false. This class has no information on model performance, so it can never stop training based on results on the training dataset.
        Throws:
        java.lang.IllegalStateException - if iteration has not ended or no iterations were performed.
      • hasIterations

        public boolean hasIterations()
        Specified by:
        hasIterations in interface TrainingObserver
        Returns:
        true if any training iterations have been started using this observer.
      • getNIterations

        public int getNIterations()
        Can be used with subclasses to distinguish between convergence and reaching the maximum number of iterations without converging.
        Returns:
        the number of training iterations that have been started.
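The sketch below shows how an iteration count can separate the two ways a run can end. ThresholdObserverSketch is a hypothetical observer, not one of this library's subclasses: it converges when a simulated loss drops below a threshold, and the loss curve 1/n is invented for the example. A run that stops before the maximum must have converged; a run that reaches the maximum needs isConverged() to tell whether convergence happened on the last allowed iteration.

```java
/**
 * Hypothetical observer: training stops when a simulated loss falls below a
 * threshold or when the iteration cap is reached.
 */
class ThresholdObserverSketch {
    private final int maxIterations;
    private final double threshold;
    private int nIterations;
    private double lastLoss = Double.POSITIVE_INFINITY;

    ThresholdObserverSketch(int maxIterations, double threshold) {
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("maxIterations must be positive");
        }
        this.maxIterations = maxIterations;
        this.threshold = threshold;
    }

    /** Records one completed iteration and the loss measured after it. */
    void recordIteration(double loss) {
        nIterations++;
        lastLoss = loss;
    }

    boolean isConverged() {
        return lastLoss < threshold;
    }

    boolean needStopTraining() {
        return isConverged() || nIterations >= maxIterations;
    }

    int getNIterations() {
        return nIterations;
    }

    /** Uses the iteration count to tell early convergence from running out of iterations. */
    String describeStop() {
        if (getNIterations() < maxIterations) {
            // Only convergence can stop this observer before the cap.
            return "converged early after " + getNIterations() + " iterations";
        }
        return isConverged()
                ? "converged on the final allowed iteration"
                : "hit the iteration limit without converging";
    }

    /** Simulates a run whose loss after n iterations is 1/n (an invented curve). */
    static ThresholdObserverSketch run(int maxIterations, double threshold) {
        ThresholdObserverSketch observer = new ThresholdObserverSketch(maxIterations, threshold);
        do {
            observer.recordIteration(1.0 / (observer.getNIterations() + 1));
        } while (!observer.needStopTraining());
        return observer;
    }

    public static void main(String[] args) {
        System.out.println(run(10, 0.3).describeStop()); // prints "converged early after 4 iterations"
        System.out.println(run(3, 0.3).describeStop());  // prints "hit the iteration limit without converging"
    }
}
```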