Skip to content

Commit f5ba382

Browse files
Regularizer addition and fixes
1 parent 8775b0b commit f5ba382

File tree

9 files changed

+198
-69
lines changed

9 files changed

+198
-69
lines changed
Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
1-
namespace Tensorflow.Keras
1+
using Newtonsoft.Json;
2+
using System.Collections.Generic;
3+
using Tensorflow.Keras.Saving.Common;
4+
5+
namespace Tensorflow.Keras
26
{
3-
public interface IRegularizer
4-
{
5-
Tensor Apply(RegularizerArgs args);
7+
[JsonConverter(typeof(CustomizedRegularizerJsonConverter))]
8+
public interface IRegularizer
9+
{
10+
[JsonProperty("class_name")]
11+
string ClassName { get; }
12+
[JsonProperty("config")]
13+
IDictionary<string, object> Config { get; }
14+
Tensor Apply(RegularizerArgs args);
615
}
716
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
using Newtonsoft.Json.Linq;
2+
using Newtonsoft.Json;
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Text;
6+
using Tensorflow.Operations.Regularizers;
7+
8+
namespace Tensorflow.Keras.Saving.Common
9+
{
10+
class RegularizerInfo
11+
{
12+
public string class_name { get; set; }
13+
public JObject config { get; set; }
14+
}
15+
16+
public class CustomizedRegularizerJsonConverter : JsonConverter
17+
{
18+
public override bool CanConvert(Type objectType)
19+
{
20+
return objectType == typeof(IRegularizer);
21+
}
22+
23+
public override bool CanRead => true;
24+
25+
public override bool CanWrite => true;
26+
27+
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
28+
{
29+
var regularizer = value as IRegularizer;
30+
if (regularizer is null)
31+
{
32+
JToken.FromObject(null).WriteTo(writer);
33+
return;
34+
}
35+
JToken.FromObject(new RegularizerInfo()
36+
{
37+
class_name = regularizer.ClassName,
38+
config = JObject.FromObject(regularizer.Config)
39+
}, serializer).WriteTo(writer);
40+
}
41+
42+
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
43+
{
44+
var info = serializer.Deserialize<RegularizerInfo>(reader);
45+
if (info is null)
46+
{
47+
return null;
48+
}
49+
return info.class_name switch
50+
{
51+
"L1L2" => new L1L2 (info.config["l1"].ToObject<float>(), info.config["l2"].ToObject<float>()),
52+
"L1" => new L1(info.config["l1"].ToObject<float>()),
53+
"L2" => new L2(info.config["l2"].ToObject<float>()),
54+
};
55+
}
56+
}
57+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
3+
using Tensorflow.Keras;
4+
5+
namespace Tensorflow.Operations.Regularizers
6+
{
7+
public class L1 : IRegularizer
8+
{
9+
float _l1;
10+
private readonly Dictionary<string, object> _config;
11+
12+
public string ClassName => "L2";
13+
public virtual IDictionary<string, object> Config => _config;
14+
15+
public L1(float l1 = 0.01f)
16+
{
17+
// l1 = 0.01 if l1 is None else l1
18+
// validate_float_arg(l1, name = "l1")
19+
// self.l1 = ops.convert_to_tensor(l1)
20+
this._l1 = l1;
21+
22+
_config = new();
23+
_config["l1"] = _l1;
24+
}
25+
26+
27+
public Tensor Apply(RegularizerArgs args)
28+
{
29+
//return self.l1 * ops.sum(ops.absolute(x))
30+
return _l1 * math_ops.reduce_sum(math_ops.abs(args.X));
31+
}
32+
}
33+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
using System;
2+
3+
using Tensorflow.Keras;
4+
5+
namespace Tensorflow.Operations.Regularizers
6+
{
7+
public class L1L2 : IRegularizer
8+
{
9+
float _l1;
10+
float _l2;
11+
private readonly Dictionary<string, object> _config;
12+
13+
public string ClassName => "L1L2";
14+
public virtual IDictionary<string, object> Config => _config;
15+
16+
public L1L2(float l1 = 0.0f, float l2 = 0.0f)
17+
{
18+
//l1 = 0.0 if l1 is None else l1
19+
//l2 = 0.0 if l2 is None else l2
20+
// validate_float_arg(l1, name = "l1")
21+
// validate_float_arg(l2, name = "l2")
22+
23+
// self.l1 = l1
24+
// self.l2 = l2
25+
this._l1 = l1;
26+
this._l2 = l2;
27+
28+
_config = new();
29+
_config["l1"] = l1;
30+
_config["l2"] = l2;
31+
}
32+
33+
public Tensor Apply(RegularizerArgs args)
34+
{
35+
//regularization = ops.convert_to_tensor(0.0, dtype = x.dtype)
36+
//if self.l1:
37+
// regularization += self.l1 * ops.sum(ops.absolute(x))
38+
//if self.l2:
39+
// regularization += self.l2 * ops.sum(ops.square(x))
40+
//return regularization
41+
42+
Tensor regularization = tf.constant(0.0, args.X.dtype);
43+
regularization += _l1 * math_ops.reduce_sum(math_ops.abs(args.X));
44+
regularization += _l2 * math_ops.reduce_sum(math_ops.square(args.X));
45+
return regularization;
46+
}
47+
}
48+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
3+
using Tensorflow.Keras;
4+
5+
namespace Tensorflow.Operations.Regularizers
6+
{
7+
public class L2 : IRegularizer
8+
{
9+
float _l2;
10+
private readonly Dictionary<string, object> _config;
11+
12+
public string ClassName => "L2";
13+
public virtual IDictionary<string, object> Config => _config;
14+
15+
public L2(float l2 = 0.01f)
16+
{
17+
// l2 = 0.01 if l2 is None else l2
18+
// validate_float_arg(l2, name = "l2")
19+
// self.l2 = l2
20+
this._l2 = l2;
21+
22+
_config = new();
23+
_config["l2"] = _l2;
24+
}
25+
26+
27+
public Tensor Apply(RegularizerArgs args)
28+
{
29+
//return self.l2 * ops.sum(ops.square(x))
30+
return _l2 * math_ops.reduce_sum(math_ops.square(args.X));
31+
}
32+
}
33+
}
Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
11
namespace Tensorflow.Keras
22
{
3-
public class Regularizers
4-
{
5-
public IRegularizer l2(float l2 = 0.01f)
6-
=> new L2(l2);
7-
}
3+
public class Regularizers
4+
{
5+
public IRegularizer l1(float l1 = 0.01f)
6+
=> new Tensorflow.Operations.Regularizers.L1(l1);
7+
public IRegularizer l2(float l2 = 0.01f)
8+
=> new Tensorflow.Operations.Regularizers.L2(l2);
9+
10+
//From TF source
11+
//# The default value for l1 and l2 are different from the value in l1_l2
12+
//# for backward compatibility reason. Eg, L1L2(l2=0.1) will only have l2
13+
//# and no l1 penalty.
14+
public IRegularizer l1l2(float l1 = 0.00f, float l2 = 0.00f)
15+
=> new Tensorflow.Operations.Regularizers.L1L2(l1, l2);
16+
}
817
}

src/TensorFlowNET.Keras/Regularizers/L1.cs

Lines changed: 0 additions & 19 deletions
This file was deleted.

src/TensorFlowNET.Keras/Regularizers/L1L2.cs

Lines changed: 0 additions & 24 deletions
This file was deleted.

src/TensorFlowNET.Keras/Regularizers/L2.cs

Lines changed: 0 additions & 17 deletions
This file was deleted.

0 commit comments

Comments
 (0)