Class Trainer
- java.lang.Object
-
- com.tibco.patterns.learn.training.Trainer
-
public final class Trainer extends java.lang.Object
Implements automated training for a single RLink model. Trains the model on a training dataset of examples until the convergence criteria (measured over the same or a different dataset) are satisfied. The RLink model must first be created using the RLink class. The lowest level of training that Trainer can perform is a single iteration over a given dataset of feature vectors. (Individual feature vectors are trained by the RLink class.)
-
-
Method Summary
Modifier and Type, Method, Description
void evaluate(RLinkDataSet<? extends VectorExample> examples, TrainingObserver co)
    Calculates RLink predictions and evaluates them for the given examples.
double evaluateErrorRate(RLinkDataSet<? extends VectorExample> examples)
    Convenience method that creates and uses COErrorRate to evaluate given examples.
int getModelId()
PredictOptions getPredictOptions()
static boolean isDftRandomOrder()
void learn(RLinkDataSet<? extends VectorExample> examples)
    Performs a single training iteration for all given training examples.
static boolean needContinueTraining(TrainingObserver trainTO, ConvergenceObserver vldCO)
RLinkOut predict(FeatureVector featureVector)
    Calculates the RLink prediction for the given feature vector.
void predict(VectorExample e)
    Calculates and saves RLink prediction in the given example.
void setPredictOptions(PredictOptions predictOpts)
    Sets options to be used for all model predictions.
void setRandomSeed(long seed)
    Sets the random seed to be used for randomizing the order of examples.
void trainIteration(RLinkExperiment<? extends VectorExample> exper, TrainingObserver trainTO, ConvergenceObserver vldCO)
    Runs a single training iteration.
boolean trainToConverge(RLinkDataSet<? extends VectorExample> examples, ConvergenceObserver co)
    Uses the same set of examples for both training and validation.
boolean trainToConverge(RLinkExperiment<? extends VectorExample> exper, ConvergenceObserver vldCO)
    Trains the model until training termination criteria over the validation dataset (as specified in vldCO) are satisfied.
boolean trainToConverge(RLinkExperiment<? extends VectorExample> exper, TrainingObserver trainTO, ConvergenceObserver vldCO)
    Trains the model until training termination criteria over the validation dataset (as specified in vldCO) are satisfied, or until the model correctly predicts all examples in the training dataset.
-
-
-
Constructor Detail
-
Trainer
public Trainer(int modelId, boolean randomOrder)
Creates a Trainer to train one existing model.
- Parameters:
modelId - the ID of a model that was created in RLink.
randomOrder - if true, the order of examples is randomized before each training iteration. A fixed random seed is used to enable repeatable training.
- Throws:
java.lang.ArrayIndexOutOfBoundsException - if a model with the given modelId does not exist.
-
Trainer
public Trainer(int modelId)
Creates a Trainer to train one existing model. Uses random order with the default seed.
- Parameters:
modelId - the ID of a model that was created in RLink.
- Throws:
java.lang.ArrayIndexOutOfBoundsException - if a model with the given modelId does not exist.
-
-
Method Detail
-
isDftRandomOrder
public static boolean isDftRandomOrder()
- Returns:
- whether the training order is randomized by default (true).
-
setRandomSeed
public void setRandomSeed(long seed)
Sets the random seed to be used for randomizing the order of examples. Using the same seed enables the same order of training (given the same datasets). Should be used before starting training iterations. If random order is not used, the method does nothing.
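The effect of a fixed seed can be illustrated with a minimal, self-contained sketch. It does not use the Trainer class and makes no claim about how Trainer shuffles internally; the class name SeededOrderSketch and the helper shuffledCopy are hypothetical, and java.util.Collections.shuffle stands in for whatever randomization the library actually performs. The sketch only shows the two documented properties: the same seed gives the same order, and the caller's collection is left unchanged.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical illustration only; not part of the library API.
final class SeededOrderSketch {

    // Returns a shuffled copy of the examples. The input list is never modified,
    // mirroring the documented guarantee that the given dataset keeps its order.
    static <T> List<T> shuffledCopy(List<T> examples, long seed) {
        List<T> copy = new ArrayList<>(examples);
        // A Random created from the same seed produces the same permutation
        // for lists of the same size.
        Collections.shuffle(copy, new Random(seed));
        return copy;
    }

    public static void main(String[] args) {
        List<String> examples = List.of("a", "b", "c", "d", "e");
        List<String> first = shuffledCopy(examples, 42L);
        List<String> second = shuffledCopy(examples, 42L);
        System.out.println("same order for same seed: " + first.equals(second));
        System.out.println("original order kept: " + examples);
    }
}
```

With the real class, the analogous step would be calling setRandomSeed(long) on a Trainer constructed with random order enabled, before the first training iteration.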
-
getModelId
public int getModelId()
- Returns:
- the ID of the model that this object works with.
-
setPredictOptions
public void setPredictOptions(PredictOptions predictOpts)
Sets the options to be used for all model predictions. If no specific options object is set, the default options are used.
- Parameters:
predictOpts - the prediction options. If null, default options will be used.
-
getPredictOptions
public PredictOptions getPredictOptions()
- Returns:
- the options that are currently used for all model predictions. Returns null if default options are used.
-
learn
public void learn(RLinkDataSet<? extends VectorExample> examples)
Performs a single training iteration for all given training examples. Randomizes the order of training examples if needed, but the order in the given examples object is never changed. Use evaluate(RLinkDataSet, TrainingObserver) to obtain results from the model after performing one or more training iterations.
- Parameters:
examples - training dataset containing the training examples.
-
evaluate
public void evaluate(RLinkDataSet<? extends VectorExample> examples, TrainingObserver co)
Calculates RLink predictions and evaluates them for the given examples. Saves the model prediction in each example of the dataset.
- Parameters:
examples - the dataset with examples to evaluate. Not null.
co - observer that accumulates statistics of evaluation. Using COScoreChange only makes sense immediately after learn() is called, not for a separate evaluation. The observer must not have begun an iteration.
-
evaluateErrorRate
public double evaluateErrorRate(RLinkDataSet<? extends VectorExample> examples)
Convenience method that creates and uses COErrorRate to evaluate the given examples.
- Parameters:
examples - the dataset with examples to evaluate.
- Returns:
- error rate over the given examples.
- See Also:
evaluate(RLinkDataSet, TrainingObserver)
-
predict
public void predict(VectorExample e)
Calculates the RLink prediction and saves it in the given example.
- Parameters:
e - example to be predicted.
-
predict
public RLinkOut predict(FeatureVector featureVector)
Calculates the RLink prediction for the given feature vector.
- Parameters:
featureVector - the vector to be predicted. Not null.
- Returns:
- the prediction for the given feature vector.
-
trainIteration
public void trainIteration(RLinkExperiment<? extends VectorExample> exper, TrainingObserver trainTO, ConvergenceObserver vldCO)
Runs a single training iteration. Trains the model with all examples in the training dataset, then evaluates the model with the training dataset (storing results in trainTO) and with the validation dataset (storing results in vldCO).
If the dataset(s) in this experiment have been used for training or evaluation earlier AND trainTO or vldCO uses predictions from the previous iteration (e.g. COScoreChange), then RLinkExperiment.clearPredictions() must be called before the first call to this method.
- Parameters:
exper - experiment that must contain training and validation datasets. Not null. If the validation dataset is empty, validation is performed using the training dataset. If the training dataset is empty, no actual training is performed.
trainTO - used to accumulate statistics for the training dataset. May be null if these statistics are not important (saves time).
vldCO - used to accumulate statistics for the dataset used to validate the model and to determine when to stop the training. Not null.
- Throws:
java.lang.IllegalArgumentException - if the same object is used as trainTO and vldCO.
-
needContinueTraining
public static boolean needContinueTraining(TrainingObserver trainTO, ConvergenceObserver vldCO)
- Parameters:
trainTO - the same object that was used to accumulate predictions for the training dataset in the trainIteration() call. May be null if it was not used. The observer must have ended an iteration.
vldCO - the same object that was used in the trainIteration() call to accumulate predictions for the dataset used to validate the model and to determine when to stop the training. The observer must have ended an iteration.
- Returns:
- true if another iteration should be performed to train the model: vldCO.needStopTraining() is false, and trainTO.isPerfectResult() is false.
- Throws:
java.lang.NullPointerException - if vldCO is null.
java.lang.IllegalArgumentException - if the same object is used as trainTO and vldCO.
- See Also:
ConvergenceObserver.needStopTraining(), TrainingObserver.isPerfectResult()
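The documented return rule and exceptions can be sketched with stand-in types. TrainObs and ConvObs below are hypothetical minimal interfaces, not the library's TrainingObserver and ConvergenceObserver, and needContinue is not the library's implementation. Two points are assumptions rather than documented behavior: a null trainTO is treated as "not a perfect result", and the null check runs before the same-object check.

```java
import java.util.Objects;

// Hypothetical illustration of the documented contract; not library code.
final class ContinueTrainingSketch {

    // Stand-in for the training-side observer (assumed shape).
    interface TrainObs {
        boolean isPerfectResult();
    }

    // Stand-in for the validation-side observer (assumed shape).
    interface ConvObs {
        boolean needStopTraining();
    }

    static boolean needContinue(TrainObs trainTO, ConvObs vldCO) {
        // Documented: NullPointerException if vldCO is null.
        Objects.requireNonNull(vldCO, "vldCO");
        // Documented: IllegalArgumentException if both arguments are the same object.
        if ((Object) trainTO == vldCO) {
            throw new IllegalArgumentException("trainTO and vldCO must be different objects");
        }
        // Assumption: when trainTO is null, only the validation criteria are monitored.
        boolean perfect = trainTO != null && trainTO.isPerfectResult();
        // Continue only while validation has not asked to stop
        // and the training set is not yet predicted perfectly.
        return !vldCO.needStopTraining() && !perfect;
    }

    public static void main(String[] args) {
        System.out.println(needContinue(() -> false, () -> false));
        System.out.println(needContinue(() -> true, () -> false));
        System.out.println(needContinue(null, () -> true));
    }
}
```

In a manual training loop, a check of this kind would typically follow each trainIteration() call, using the same two observers that were passed to that call.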
-
trainToConverge
public boolean trainToConverge(RLinkExperiment<? extends VectorExample> exper, TrainingObserver trainTO, ConvergenceObserver vldCO)
Trains the model until the training termination criteria over the validation dataset (as specified in vldCO) are satisfied, or until the model correctly predicts all examples in the training dataset. Gathers statistics for the training dataset.
- Parameters:
exper - contains the two datasets used in training.
trainTO - gathers statistics for the training dataset. May be null if these statistics are not important (saves time), but then only the termination criteria over the validation set are monitored.
vldCO - gathers statistics for the validation dataset. Not null.
- Returns:
- true if the convergence criteria in vldCO were satisfied; false if training failed to converge and another termination condition was reached.
- Throws:
java.lang.IllegalArgumentException - if the same object is used as trainTO and vldCO.
-
trainToConverge
public boolean trainToConverge(RLinkExperiment<? extends VectorExample> exper, ConvergenceObserver vldCO)
Trains the model until the training termination criteria over the validation dataset (as specified in vldCO) are satisfied. Note that training does not stop when a zero error rate is reached for the training dataset (since only one ConvergenceObserver is used).
- Parameters:
exper - contains the two datasets used in training. Not null.
vldCO - gathers training statistics for the validation dataset. Not null.
- Returns:
- true if convergence criteria in vldCO were satisfied; false if it failed to converge and another termination condition was reached.
-
trainToConverge
public boolean trainToConverge(RLinkDataSet<? extends VectorExample> examples, ConvergenceObserver co)
Trains the model using the same set of examples for both training and validation. Note that training does not stop when a zero error rate is reached for the training dataset (since only one ConvergenceObserver is used). Not recommended; a separate validation dataset should be used.
- Parameters:
examples - the single training (and validation) dataset. Not null.
co - gathers training statistics for the given dataset. Not null.
- Returns:
- true if convergence criteria in co were satisfied; false if it failed to converge and another termination condition was reached.
-
-