Class COErrorRate
- java.lang.Object
-
- com.tibco.patterns.learn.training.COIterationCount
-
- com.tibco.patterns.learn.training.COErrorRate
-
- All Implemented Interfaces:
ConvergenceObserver, TrainingObserver
- Direct Known Subclasses:
COErrorRateMin
public class COErrorRate extends COIterationCount
Stops the training when the error rate over the validation dataset during the last iteration is below the given threshold. Gathers additional statistics. The error rate and other statistics include predictions made by untrained submodels. This class and its descendants can only be used to train one model, not many models.
In practice it may be difficult to estimate the error rate that can be achieved with the specific training and validation datasets. If the target error rate is too high, the training will stop before the model reaches its best state. If it is too low and unachievable, the model will train for the maximum number of iterations, which may result in overtraining. Therefore it is recommended to use the COErrorRateMin class to determine the lowest achievable error rate and then retrain the model with the best number of iterations.
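The trade-off between a target that is too high and one that is too low can be illustrated with a small stand-alone sketch. It does not use the library: the class name, the method names, and the per-iteration error rates are hypothetical and chosen only for illustration, and the array length stands in for the maximum number of iterations. It is not the implementation of COErrorRate or COErrorRateMin.

```java
// Illustrative stand-in only; names and numbers are hypothetical, not library output.
final class TargetErrorRateSketch {

    /** Returns the 1-based iteration at which training would stop for the given target. */
    static int stopIteration(double[] errorRates, double target) {
        for (int i = 0; i < errorRates.length; i++) {
            if (errorRates[i] < target) {
                return i + 1; // error rate is below the target: stop here
            }
        }
        return errorRates.length; // target never reached: all iterations are performed
    }

    /** Returns the 1-based iteration with the lowest error rate (the first one on ties). */
    static int bestIteration(double[] errorRates) {
        int best = 0;
        for (int i = 1; i < errorRates.length; i++) {
            if (errorRates[i] < errorRates[best]) {
                best = i;
            }
        }
        return best + 1;
    }

    public static void main(String[] args) {
        // Hypothetical validation error rates; they rise again after iteration 4.
        double[] rates = {0.20, 0.10, 0.05, 0.03, 0.04, 0.06};
        System.out.println(stopIteration(rates, 0.15));  // 2: target too high, stops early
        System.out.println(stopIteration(rates, 0.001)); // 6: target unachievable, runs to the end
        System.out.println(bestIteration(rates));        // 4: iteration with the lowest error rate
    }
}
```

With these made-up numbers, a target of 0.15 stops at iteration 2 although iteration 4 is better, while a target of 0.001 runs past the best iteration to the maximum.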
-
-
Nested Class Summary
Nested Classes
static class COErrorRate.IterationResult - Stores statistics for a full or partial training iteration.
-
Field Summary
Fields
protected COErrorRate.IterationResult currRes
protected COErrorRate.IterationResult prevRes
-
Constructor Summary
Constructors
COErrorRate() - Creates object with default target error rate and max iterations.
COErrorRate(double targetErrorRate, int maxIterations) - Creates object with specified parameters.
-
Method Summary
Methods
void beginIteration() - Remembers result of previous iteration, prepares to accumulate results for the current iteration.
void endIteration() - If printing is set up, prints results of this iteration.
void evaluatePrediction(VectorExample e) - Updates statistics for the current training iteration.
static double getDftTargetErrorRate()
int getMinIterations()
double getProgressEstimate() - Returns an estimate of training progress, which can be used to indicate how close the desired convergence is.
COErrorRate.IterationResult getResult()
double getTargetErrorRate()
boolean isConverged() - Returns true if the target error rate has been reached in the last training iteration, after performing at least the minimum number of iterations.
boolean isPerfectResult() - Returns true if all examples in the dataset are predicted correctly.
boolean needStopTraining() - Returns true if the target error rate (after performing the minimum number of iterations) or the maximum number of iterations have been reached.
void printHeaderLine() - Prints the header of the iteration list.
void setMaxIterations(int maxIter) - Sets the maximum number of model training iterations to perform.
void setMinIterations(int minIter) - Sets the minimum number of iterations to perform.
void setPrintOptions(int printAfterNIter, Partition partition, boolean printHeader) - Enables printing of the iteration results, sets console printing options, prints header.
java.lang.String toString()
-
Methods inherited from class com.tibco.patterns.learn.training.COIterationCount
getDftMaxIterations, getMaxIterations, getNIterations, hasIterations, verifyIterationBegun, verifyIterationEnded
-
-
-
-
Field Detail
-
currRes
protected COErrorRate.IterationResult currRes
-
prevRes
protected COErrorRate.IterationResult prevRes
-
-
Constructor Detail
-
COErrorRate
public COErrorRate(double targetErrorRate, int maxIterations)
Creates object with specified parameters.
- Parameters:
targetErrorRate - the error rate that is sufficient to stop model training.
maxIterations - maximum number of model training iterations to perform.
- Throws:
java.lang.IllegalArgumentException - if maxIterations is not positive, or target error rate is not between 0 and 1.
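The argument checks described for this constructor might look roughly like the following stand-alone sketch. The class and method names are hypothetical. The documentation does not say whether the bounds 0 and 1 are themselves accepted, so treating them as inclusive here is an assumption.

```java
// Hypothetical stand-in for the documented argument checks; not the library's code.
final class ConstructorArgsSketch {

    /** Throws IllegalArgumentException for arguments the documentation describes as invalid. */
    static void validate(double targetErrorRate, int maxIterations) {
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("maxIterations must be positive: " + maxIterations);
        }
        // ASSUMPTION: the range is inclusive, [0, 1]. The negated form also rejects NaN.
        if (!(targetErrorRate >= 0.0 && targetErrorRate <= 1.0)) {
            throw new IllegalArgumentException("targetErrorRate must be between 0 and 1: " + targetErrorRate);
        }
    }

    public static void main(String[] args) {
        validate(0.005, 100); // accepted: 0.005 is the documented default target error rate
        try {
            validate(1.5, 100);
        } catch (IllegalArgumentException ex) {
            System.out.println("rejected: " + ex.getMessage());
        }
    }
}
```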
-
COErrorRate
public COErrorRate()
Creates object with default target error rate and max iterations. Can be used when the object is treated as TrainingObserver only, i.e. we do not care about the stopping criteria.
-
-
Method Detail
-
getDftTargetErrorRate
public static double getDftTargetErrorRate()
- Returns:
- default target error rate (0.005).
-
toString
public java.lang.String toString()
- Overrides:
toString in class COIterationCount
-
setMinIterations
public void setMinIterations(int minIter)
Sets the minimum number of iterations to perform. It is 0 by default (not used). Can be used to extend training until the learning rate becomes reasonably small. This can avoid "lucky" results in early iterations when validation error rate happens to be the lowest, but the training error rate is still very high. Note that if zero errors in the training dataset are reached sooner, training will still stop without performing this minimum number of iterations.
- Parameters:
minIter - the minimum number of iterations to set.
- Throws:
java.lang.IllegalArgumentException - if the given value is not below max iterations.
-
getMinIterations
public int getMinIterations()
- Returns:
- the minimum number of iterations to perform.
-
setMaxIterations
public void setMaxIterations(int maxIter)
Sets the maximum number of model training iterations to perform.
- Overrides:
setMaxIterations in class COIterationCount
- Throws:
java.lang.IllegalArgumentException - if the current min iterations is not below maxIter, or if maxIter is not positive.
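The interplay between setMinIterations and setMaxIterations (the minimum must always stay strictly below the maximum) can be sketched as follows. This is a hypothetical stand-in class, not the library's implementation. Whether a negative minimum is rejected is not documented, so it is left unchecked here.

```java
// Hypothetical stand-in showing the documented min/max ordering constraint.
final class IterationBoundsSketch {
    private int minIterations = 0; // documented default: 0 (not used)
    private int maxIterations;

    IterationBoundsSketch(int maxIterations) {
        setMaxIterations(maxIterations);
    }

    void setMinIterations(int minIter) {
        if (minIter >= maxIterations) {
            throw new IllegalArgumentException("minIter must be below max iterations");
        }
        minIterations = minIter;
    }

    void setMaxIterations(int maxIter) {
        if (maxIter <= 0 || minIterations >= maxIter) {
            throw new IllegalArgumentException("maxIter must be positive and above min iterations");
        }
        maxIterations = maxIter;
    }

    int getMinIterations() { return minIterations; }
    int getMaxIterations() { return maxIterations; }

    public static void main(String[] args) {
        IterationBoundsSketch bounds = new IterationBoundsSketch(100);
        bounds.setMinIterations(10);   // accepted: 10 is below 100
        try {
            bounds.setMaxIterations(10); // rejected: current minimum (10) is not below 10
        } catch (IllegalArgumentException ex) {
            System.out.println("rejected: " + ex.getMessage());
        }
    }
}
```

One consequence of this ordering is that raising the minimum above the current maximum requires raising the maximum first.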
-
getTargetErrorRate
public double getTargetErrorRate()
- Returns:
- the target error rate that is sufficient to stop model training.
-
getResult
public final COErrorRate.IterationResult getResult()
- Returns:
- a copy of the result of the last iteration, or null if no training was performed.
-
beginIteration
public void beginIteration()
Remembers result of previous iteration, prepares to accumulate results for the current iteration. Overriding methods must call this method.
- Specified by:
beginIteration in interface TrainingObserver
- Overrides:
beginIteration in class COIterationCount
- Throws:
java.lang.IllegalStateException - if an iteration has already begun.
-
endIteration
public void endIteration()
If printing is set up, prints results of this iteration.
- Specified by:
endIteration in interface TrainingObserver
- Overrides:
endIteration in class COIterationCount
- Throws:
java.lang.IllegalStateException - if iteration has not begun.
-
evaluatePrediction
public void evaluatePrediction(VectorExample e)
Updates statistics for the current training iteration. If printing is set up, prints incorrectly classified examples after the previously specified number of iterations. Overriding methods must call this method.
- Specified by:
evaluatePrediction in interface TrainingObserver
- Overrides:
evaluatePrediction in class COIterationCount
- Parameters:
e - an example that must contain RLink prediction
- Throws:
java.lang.IllegalStateException - if iteration has not begun.
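The call order implied by beginIteration, evaluatePrediction and endIteration, together with the documented IllegalStateException conditions, can be sketched with a stand-alone stand-in. The class below is hypothetical: it takes a boolean "predicted correctly" flag instead of a VectorExample, and returning 0.0 for an iteration with no examples is an assumption, not documented behaviour.

```java
// Hypothetical stand-in for the observer lifecycle; not the library's implementation.
final class IterationLifecycleSketch {
    private boolean inIteration = false;
    private int nIterations = 0;
    private int nExamples = 0;
    private int nErrors = 0;

    void beginIteration() {
        if (inIteration) throw new IllegalStateException("an iteration has already begun");
        inIteration = true;
        nExamples = 0; // start accumulating statistics for the new iteration
        nErrors = 0;
    }

    void evaluatePrediction(boolean predictedCorrectly) {
        if (!inIteration) throw new IllegalStateException("iteration has not begun");
        nExamples++;
        if (!predictedCorrectly) nErrors++;
    }

    void endIteration() {
        if (!inIteration) throw new IllegalStateException("iteration has not begun");
        inIteration = false;
        nIterations++;
    }

    /** Error rate of the last completed iteration. */
    double errorRate() {
        if (inIteration || nIterations == 0) {
            throw new IllegalStateException("iteration has not ended or no iterations were performed");
        }
        return nExamples == 0 ? 0.0 : (double) nErrors / nExamples; // ASSUMPTION for the empty case
    }

    public static void main(String[] args) {
        IterationLifecycleSketch observer = new IterationLifecycleSketch();
        observer.beginIteration();
        boolean[] outcomes = {true, false, true, true}; // hypothetical prediction outcomes
        for (boolean ok : outcomes) {
            observer.evaluatePrediction(ok);
        }
        observer.endIteration();
        System.out.println(observer.errorRate()); // 0.25: one error out of four examples
    }
}
```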
-
getProgressEstimate
public double getProgressEstimate()
Description copied from interface: ConvergenceObserver
Returns an estimate of training progress, which can be used to indicate how close the desired convergence is. The interpretation and range of the returned value can vary between subclasses. Typically, the return value approaches 0 as the model is converging. Should be called after an iteration, not in the middle of iteration.
- Specified by:
getProgressEstimate in interface ConvergenceObserver
- Overrides:
getProgressEstimate in class COIterationCount
- Returns:
- the error rate of the last training iteration
- Throws:
java.lang.IllegalStateException - if iteration has not ended or no iterations were performed.
-
isConverged
public boolean isConverged()
Returns true if the target error rate has been reached in the last training iteration, after performing at least the minimum number of iterations.
The error rate includes the errors for examples predicted with 0 confidence (see COErrorRate.IterationResult.getNUntrainedPred()). Some of these errors may be from submodels that are never trained, so model training will have no impact for these errors. Thus it is essential to eliminate untrained predictions (by adding more training examples to relevant subsets) so that the model is more likely to reach the target error rate.
- Specified by:
isConverged in interface ConvergenceObserver
- Overrides:
isConverged in class COIterationCount
- Returns:
- true if the error rate of the last training iteration is below the target error rate, and at least the minimum number of iterations have been performed.
- Throws:
java.lang.IllegalStateException - if iteration has not ended or no iterations were performed.
-
needStopTraining
public boolean needStopTraining()
Returns true if the target error rate (after performing the minimum number of iterations) or the maximum number of iterations have been reached. See isConverged().
- Specified by:
needStopTraining in interface ConvergenceObserver
- Overrides:
needStopTraining in class COIterationCount
- Returns:
- true if the error rate of the last training iteration is below the target error rate and at least the minimum number of iterations have been performed, or if maximum number of iterations was reached.
- Throws:
java.lang.IllegalStateException - if iteration has not ended or no iterations were performed.
-
isPerfectResult
public boolean isPerfectResult()
Returns true if all examples in the dataset are predicted correctly. This is used to stop the training process based on the training dataset results: if all training examples were predicted correctly, the model cannot be trained better with these examples. This stopping criterion does not use the minimum number of iterations.
- Specified by:
isPerfectResult in interface TrainingObserver
- Overrides:
isPerfectResult in class COIterationCount
- Returns:
- true if all examples in the dataset were predicted correctly.
- Throws:
java.lang.IllegalStateException - if iteration has not ended or no iterations were performed.
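The three stopping conditions documented for isConverged, needStopTraining and isPerfectResult can be written as plain boolean expressions. The sketch below is a hypothetical restatement of those descriptions as static functions over explicit arguments. It is not the library's code, and the parameter names are chosen for illustration.

```java
// Hypothetical restatement of the documented stopping conditions; not the library's code.
final class StopDecisionSketch {

    /** Target reached: error rate strictly below target, and minimum iterations performed. */
    static boolean isConverged(double errorRate, double targetErrorRate,
                               int nIterations, int minIterations) {
        return errorRate < targetErrorRate && nIterations >= minIterations;
    }

    /** Stop when converged, or when the maximum number of iterations has been reached. */
    static boolean needStopTraining(double errorRate, double targetErrorRate,
                                    int nIterations, int minIterations, int maxIterations) {
        return isConverged(errorRate, targetErrorRate, nIterations, minIterations)
                || nIterations >= maxIterations;
    }

    /** Perfect result: no errors at all; the minimum number of iterations is not consulted. */
    static boolean isPerfectResult(int nErrors) {
        return nErrors == 0;
    }

    public static void main(String[] args) {
        // Below target, but the minimum of 5 iterations has not been performed yet.
        System.out.println(isConverged(0.004, 0.005, 3, 5));            // false
        // Same error rate once the minimum has been performed.
        System.out.println(isConverged(0.004, 0.005, 5, 5));            // true
        // Far from the target, but the maximum number of iterations is reached.
        System.out.println(needStopTraining(0.2, 0.005, 50, 0, 50));    // true
        // Zero errors counts as perfect regardless of the iteration count.
        System.out.println(isPerfectResult(0));                         // true
    }
}
```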
-
setPrintOptions
public void setPrintOptions(int printAfterNIter, Partition partition, boolean printHeader)
Enables printing of the iteration results, sets console printing options, prints header. If printing is needed, this should be called before the first iteration.
- Parameters:
printAfterNIter - print errors starting with this iteration.
partition - the data partition that this object will be used with.
printHeader - if true, prints the header of the iteration list. The header for the training observer must be printed before the header for the validation observer.
- Throws:
java.lang.IllegalArgumentException - if printAfterNIter is negative.
-
printHeaderLine
public void printHeaderLine()
Prints the header of the iteration list. The header for the training observer must be printed before the header for the validation observer.
- Throws:
java.lang.IllegalStateException - if setPrintOptions was not called.
-
-