feat: add the implementation of class_weight in model.fit #1189

Merged · 1 commit · Oct 6, 2023
70 changes: 69 additions & 1 deletion src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
@@ -3,6 +3,8 @@
 using Tensorflow.Keras.ArgsDefinition;
 using static Tensorflow.Binding;
 using Tensorflow.Keras.Utils;
+using Tensorflow.Util;
+using Tensorflow.Framework;
 
 namespace Tensorflow.Keras.Engine.DataAdapters
 {
@@ -24,6 +26,7 @@ public class DataHandler
         long _steps_per_execution_value;
         int _initial_epoch => args.InitialEpoch;
         int _epochs => args.Epochs;
+        NDArray _sample_weight => args.SampleWeight;
         IVariableV1 _steps_per_execution;
 
         public DataHandler(DataHandlerArgs args)
@@ -75,10 +78,75 @@ public DataHandler(DataHandlerArgs args)
             }
 
             _dataset = _adapter.GetDataset();
-            _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
             _current_step = 0;
             _step_increment = _steps_per_execution_value - 1;
             _insufficient_data = false;
+            _configure_dataset_and_inferred_steps(args.X, args.ClassWeight);
         }
+
+        void _configure_dataset_and_inferred_steps(Tensors x, Dictionary<int, float> class_weight)
+        {
+            if (_dataset == null)
+            {
+                _dataset = _adapter.GetDataset();
+                _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
+            }
+
+            if (class_weight != null)
+            {
+                _dataset = _dataset.map(_make_class_weight_map_fn(class_weight));
+            }
+            _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
+        }
+
+
+        Func<Tensors, Tensors> _make_class_weight_map_fn(Dictionary<int, float> class_weight)
+        {
+            var class_ids = class_weight.Keys.OrderBy(key => key).ToList();
+            var expected_class_ids = range(class_ids[0], class_ids[class_ids.Count - 1] + 1);
+            if (!class_ids.SequenceEqual(expected_class_ids))
+            {
+                throw new ValueError("Expected `class_weight` to be a dict with keys from 0 to one less " +
+                    $"than the number of classes, found {class_weight}");
+            }
+
+            var class_weight_list = new List<float>();
+            foreach (var class_id in class_ids)
+            {
+                class_weight_list.Add(class_weight[class_id]);
+            }
+            var class_weight_tensor = tf.convert_to_tensor(class_weight_list.ToArray());
+
+            Func<Tensors, Tensors> _class_weight_map_fn = (Tensors data) =>
+            {
+                var x = data[0];
+                var y = data[1];
+                var sw = _sample_weight == null ? null : ops.convert_to_tensor(_sample_weight);
+
+                if (y.shape.rank > 2)
+                {
+                    throw new ValueError("`class_weight` not supported for 3+ dimensional targets.");
+                }
+
+                var y_classes = smart_module.smart_cond(
+                    y.shape.rank == 2 && y.shape[1] > 1,
+                    () => math_ops.argmax(y, dimension: 1),
+                    () => math_ops.cast(tf.reshape(y, (-1)), TF_DataType.TF_INT64));
+
+                var cw = array_ops.gather(class_weight_tensor, y_classes);
+                if (sw != null)
+                {
+                    cw = tf.cast(cw, sw.dtype);
+                    cw *= sw;
+                }
+                else
+                {
+                    sw = cw;
+                }
+                return new Tensors { x, y, sw };
+            };
+
+            return _class_weight_map_fn;
+        }
+
         long _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
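In `_make_class_weight_map_fn`, the dictionary is validated (its keys must run contiguously from 0 to one less than the number of classes), packed into a tensor, and each batch's labels then index into it via a gather; the resulting per-sample weights are appended to `(x, y)` as `sw`. A minimal sketch of the gather step using TensorFlow.NET's static binding (values are made up for illustration):

```csharp
using static Tensorflow.Binding;

// class_weight = {0: 1.0, 1: 5.0} packed into a tensor indexed by class id.
var class_weight_tensor = tf.constant(new float[] { 1.0f, 5.0f });
// Flattened integer labels for a batch of four samples.
var y_classes = tf.constant(new long[] { 0, 1, 1, 0 });
// Per-sample weights gathered by label: [1.0, 5.0, 5.0, 1.0]
var cw = tf.gather(class_weight_tensor, y_classes);
```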
13 changes: 11 additions & 2 deletions src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
@@ -164,11 +164,20 @@ Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler
         }
 
 
-        Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
+        Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
         {
+            (x,y) = data_handler.DataAdapter.Expand1d(x, y);
+            var y_pred = Apply(x, training: false);
+            var loss = compiled_loss.Call(y, y_pred);
+            compiled_metrics.update_state(y, y_pred);
+            return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
+        }
+
+        Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight)
+        {
             (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
             var y_pred = Apply(x, training: false);
-            var loss = compiled_loss.Call(y, y_pred, sample_weight:sample_weight);
+            var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
             compiled_metrics.update_state(y, y_pred);
             return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
         }
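The single `test_step` is split into two overloads: the weighted one expands `sample_weight` alongside `x` and `y` and forwards it to `compiled_loss`, so each sample's loss is scaled before reduction. Roughly, the effect is as follows (an illustrative sketch, not the library's internals; the exact reduction `compiled_loss` applies may differ):

```csharp
using static Tensorflow.Binding;

// Two samples: the second one counts 3x toward the loss.
var per_sample_loss = tf.constant(new float[] { 0.2f, 0.8f });
var sample_weight = tf.constant(new float[] { 1.0f, 3.0f });
// (0.2 * 1.0 + 0.8 * 3.0) / 2 = 1.3
var loss = tf.reduce_mean(per_sample_loss * sample_weight);
```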
11 changes: 4 additions & 7 deletions src/TensorFlowNET.Keras/Engine/Model.Fit.cs
@@ -63,12 +63,6 @@ public ICallback fit(NDArray x, NDArray y,
                 ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
             }
 
-            // TODO(Wanglongzhi2001)
-            if (class_weight != null)
-            {
-                throw new NotImplementedException("class_weight is not implemented");
-            }
-
             var data_handler = new DataHandler(new DataHandlerArgs
             {
                 X = x,
@@ -78,6 +72,7 @@ public ICallback fit(NDArray x, NDArray y,
                 InitialEpoch = initial_epoch,
                 Epochs = epochs,
                 Shuffle = shuffle,
+                ClassWeight = class_weight,
                 MaxQueueSize = max_queue_size,
                 Workers = workers,
                 UseMultiprocessing = use_multiprocessing,
@@ -126,11 +121,12 @@ public ICallback fit(IEnumerable<NDArray> x, NDArray y,
             {
                 X = new Tensors(x.ToArray()),
                 Y = y,
+                SampleWeight = sample_weight,
                 BatchSize = batch_size,
                 InitialEpoch = initial_epoch,
                 Epochs = epochs,
                 Shuffle = shuffle,
-                SampleWeight = sample_weight,
+                ClassWeight = class_weight,
                 MaxQueueSize = max_queue_size,
                 Workers = workers,
                 UseMultiprocessing = use_multiprocessing,
@@ -174,6 +170,7 @@ public History fit(IDatasetV2 dataset,
                 InitialEpoch = initial_epoch,
                 Epochs = epochs,
                 Shuffle = shuffle,
+                SampleWeight = sample_weight,
                 MaxQueueSize = max_queue_size,
                 Workers = workers,
                 UseMultiprocessing = use_multiprocessing,
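With the `NotImplementedException` removed and `ClassWeight` threaded through `DataHandlerArgs` in the `fit` overloads, class weighting becomes usable from user code. A hypothetical call (model and array names are placeholders), keeping in mind the constraint above that keys must cover 0 through the number of classes minus one:

```csharp
using System.Collections.Generic;

// Upweight the minority class (1) five-fold in the training loss.
var class_weight = new Dictionary<int, float>
{
    [0] = 1.0f,
    [1] = 5.0f,
};
model.fit(x_train, y_train, batch_size: 32, epochs: 10, class_weight: class_weight);
```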