Skip to content

Commit

Permalink
Clearer naming
Browse files: browse the repository at this point in the history
  • Loading branch information
russcam committed Nov 19, 2024
1 parent f0ef937 commit 470870e
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions src/RankLib/Learning/Tree/LambdaMART.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public class LambdaMART : Ranker<LambdaMARTParameters>
private double[][] _modelScoresOnValidation = [];
private int _bestModelOnValidation = int.MaxValue - 2;
private int[][] _sortedIdx = [];
private FeatureHistogram _hist;
private FeatureHistogram _histogram;
private double[] _weights = [];

/// <summary>
Expand Down Expand Up @@ -319,8 +319,8 @@ await Task.Run(() =>
}
}

_hist = new FeatureHistogram(Parameters.SamplingRate, Parameters.MaxDegreeOfParallelism);
await _hist.ConstructAsync(MARTSamples, PseudoResponses, _sortedIdx, Features, _thresholds, Impacts).ConfigureAwait(false);
_histogram = new FeatureHistogram(Parameters.SamplingRate, Parameters.MaxDegreeOfParallelism);
await _histogram.ConstructAsync(MARTSamples, PseudoResponses, _sortedIdx, Features, _thresholds, Impacts).ConfigureAwait(false);

//we no longer need the sorted indexes of samples
_sortedIdx = [];
Expand Down Expand Up @@ -349,20 +349,20 @@ public override async Task LearnAsync()
await ComputePseudoResponsesAsync().ConfigureAwait(false);

//update the histogram with these training labels (the feature histogram will be used to find the best tree split)
await _hist.UpdateAsync(PseudoResponses).ConfigureAwait(false);
await _histogram.UpdateAsync(PseudoResponses).ConfigureAwait(false);

//Fit a regression tree
var rt = new RegressionTree(Parameters.TreeLeavesCount, MARTSamples, PseudoResponses, _hist, Parameters.MinimumLeafSupport);
await rt.FitAsync().ConfigureAwait(false);
var tree = new RegressionTree(Parameters.TreeLeavesCount, MARTSamples, PseudoResponses, _histogram, Parameters.MinimumLeafSupport);
await tree.FitAsync().ConfigureAwait(false);

//Add this tree to the ensemble (our model)
_ensemble.Add(rt, Parameters.LearningRate);
_ensemble.Add(tree, Parameters.LearningRate);

//update the outputs of the tree (with gamma computed using the Newton-Raphson method)
UpdateTreeOutput(rt);
UpdateTreeOutput(tree);

//Update the model's outputs on all training samples
var leaves = rt.Leaves;
var leaves = tree.Leaves;
for (var i = 0; i < leaves.Count; i++)
{
var s = leaves[i];
Expand All @@ -372,7 +372,7 @@ public override async Task LearnAsync()
}

//clear references to data that is no longer used
rt.ClearSamples();
tree.ClearSamples();

//Evaluate the current model
TrainingDataScore = ComputeModelScoreOnTraining();
Expand All @@ -384,7 +384,7 @@ public override async Task LearnAsync()
for (var i = 0; i < _modelScoresOnValidation.Length; i++)
{
for (var j = 0; j < _modelScoresOnValidation[i].Length; j++)
_modelScoresOnValidation[i][j] += Parameters.LearningRate * rt.Eval(ValidationSamples[i][j]);
_modelScoresOnValidation[i][j] += Parameters.LearningRate * tree.Eval(ValidationSamples[i][j]);
}

double score = ComputeModelScoreOnValidation();
Expand All @@ -406,20 +406,21 @@ public override async Task LearnAsync()
_ensemble.RemoveAt(_ensemble.TreeCount - 1);

TrainingDataScore = Scorer.Score(Rank(Samples));
_logger.LogInformation($"Finished successfully. {Scorer.Name} on training data: {SimpleMath.Round(TrainingDataScore, 4)}");
_logger.LogInformation("Finished successfully. {Scorer} on training data: {TrainingScore}",
Scorer.Name, SimpleMath.Round(TrainingDataScore, 4));

if (ValidationSamples != null)
{
ValidationDataScore = Scorer.Score(Rank(ValidationSamples));
_logger.LogInformation($"{Scorer.Name} on validation data: {SimpleMath.Round(ValidationDataScore, 4)}");
_logger.LogInformation("{Scorer} on validation data: {ValidationScore}", Scorer.Name, SimpleMath.Round(ValidationDataScore, 4));
}

_logger.LogInformation("-- FEATURE IMPACTS");
var ftrsSorted = MergeSorter.Sort(Impacts, false);
for (var index = 0; index < ftrsSorted.Length; index++)
var sortedImpacts = MergeSorter.Sort(Impacts, false);
for (var index = 0; index < sortedImpacts.Length; index++)
{
var ftr = ftrsSorted[index];
_logger.LogInformation($"Feature {Features[ftr]} reduced error {Impacts[ftr]}");
var ftr = sortedImpacts[index];
_logger.LogInformation("Feature {FeatureId} reduced error {Impact}", Features[ftr], Impacts[ftr]);
}
}

Expand Down Expand Up @@ -616,7 +617,8 @@ private RankList Rank(int rankListIndex, int current)
return new RankList(orig, idx);
}

private float ComputeModelScoreOnTraining() => ComputeModelScoreOnTraining(0, Samples.Count - 1, 0) / Samples.Count;
private float ComputeModelScoreOnTraining() =>
ComputeModelScoreOnTraining(0, Samples.Count - 1, 0) / Samples.Count;

private float ComputeModelScoreOnTraining(int start, int end, int current)
{
Expand Down

0 comments on commit 470870e

Please sign in to comment.