Skip to content

Commit

Permalink
+ Adding support for type conversion and modulo expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
artiomchi committed Nov 3, 2018
1 parent dbfbac6 commit c218138
Show file tree
Hide file tree
Showing 11 changed files with 256 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Also supports injecting sql command generators to add support for other provider
+ Adding support for static property/field accessors (e.g. DateTime.Now)
* Explicitly throwing an exception when using identity keys as upsert match columns (since it wouldn't have worked correctly anyway)
+ Added help links to exceptions linking to more details
+ Added support for basic type conversions and modulo operator in expressions
</PackageReleaseNotes>
</PropertyGroup>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,48 @@ private static object GetValueInternal<TSource>(this Expression expression, Lamb
{
switch (expression.NodeType)
{
case ExpressionType.Call:
{
var methodExp = (MethodCallExpression)expression;
var context = methodExp.Object?.GetValueInternal<TSource>(container, useExpressionCompiler, true);
var arguments = methodExp.Arguments.Select(a => a.GetValueInternal<TSource>(container, useExpressionCompiler, true)).ToArray();
return methodExp.Method.Invoke(context, arguments);
}

case ExpressionType.Coalesce:
{
var coalesceExp = (BinaryExpression)expression;
var left = coalesceExp.Left.GetValueInternal<TSource>(container, useExpressionCompiler, nested);
var right = coalesceExp.Right.GetValueInternal<TSource>(container, useExpressionCompiler, nested);

if (left == null)
return right;
if (!(left is IKnownValue))
return left;

if (!(left is IKnownValue leftValue))
leftValue = new ConstantValue(left);
if (!(right is IKnownValue rightValue))
rightValue = new ConstantValue(right);

return new KnownExpression(expression.NodeType, leftValue, rightValue);
}

case ExpressionType.Constant:
{
return ((ConstantExpression)expression).Value;
}

case ExpressionType.Convert:
{
var convertExp = (UnaryExpression)expression;
if (!nested)
return convertExp.Operand.GetValueInternal<TSource>(container, useExpressionCompiler, nested);

var value = convertExp.Operand.GetValueInternal<TSource>(container, useExpressionCompiler, true);
return Convert.ChangeType(value, convertExp.Type);
}

case ExpressionType.MemberAccess:
{
var memberExp = (MemberExpression)expression;
Expand Down Expand Up @@ -64,18 +101,11 @@ private static object GetValueInternal<TSource>(this Expression expression, Lamb
return result;
}

case ExpressionType.Call:
{
var methodExp = (MethodCallExpression)expression;
var context = methodExp.Object?.GetValueInternal<TSource>(container, useExpressionCompiler, true);
var arguments = methodExp.Arguments.Select(a => a.GetValueInternal<TSource>(container, useExpressionCompiler, true)).ToArray();
return methodExp.Method.Invoke(context, arguments);
}

case ExpressionType.Add:
case ExpressionType.Subtract:
case ExpressionType.Multiply:
case ExpressionType.Divide:
case ExpressionType.Modulo:
case ExpressionType.Multiply:
case ExpressionType.Subtract:
{
var exp = (BinaryExpression)expression;
if (!nested && exp.Method == null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,16 +175,28 @@ protected virtual string ExpandExpression(KnownExpression expression)
{
case ExpressionType.Add:
case ExpressionType.Divide:
case ExpressionType.Modulo:
case ExpressionType.Multiply:
case ExpressionType.Subtract:
var left = ExpandValue(expression.Value1);
var right = ExpandValue(expression.Value2);
var op = GetSimpleOperator(expression.ExpressionType);
return $"{left} {op} {right}";
{
var left = ExpandValue(expression.Value1);
var right = ExpandValue(expression.Value2);
var op = GetSimpleOperator(expression.ExpressionType);
return $"{left} {op} {right}";
}

case ExpressionType.Coalesce:
{
var left = ExpandValue(expression.Value1);
var right = ExpandValue(expression.Value2);
return $"COALESCE({left}, {right})";
}

case ExpressionType.MemberAccess:
case ExpressionType.Constant:
return ExpandValue(expression.Value1);
{
return ExpandValue(expression.Value1);
}

default: throw new NotSupportedException("Don't know how to process operation: " + expression.ExpressionType);
}
Expand All @@ -201,6 +213,7 @@ protected virtual string GetSimpleOperator(ExpressionType expressionType)
{
case ExpressionType.Add: return "+";
case ExpressionType.Divide: return "/";
case ExpressionType.Modulo: return "%";
case ExpressionType.Multiply: return "*";
case ExpressionType.Subtract: return "-";
default: throw new InvalidOperationException($"{expressionType} is not a simple arithmetic operation");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class Country
[Required, StringLength(2)]
public string ISO { get; set; }
public DateTime Created { get; set; }
public DateTime Updated { get; set; }
public DateTime? Updated { get; set; }
}

[Table("Dash-Table")]
Expand Down
74 changes: 73 additions & 1 deletion test/FlexLabs.EntityFrameworkCore.Upsert.Tests/EF/BasicTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ public void Dispose()
Name = "...loading...",
ISO = "AU",
Created = new DateTime(1970, 1, 1),
Updated = new DateTime(1970, 1, 1),
};
PageVisit _dbVisitOld = new PageVisit
{
Expand Down Expand Up @@ -390,6 +389,42 @@ public void Upsert_Country_Update_On_WhenMatched_Values(TestDbContext.DbDriver d
}
}

[Theory]
[MemberData(nameof(GetDatabaseEngines))]
public void Upsert_Country_Update_On_WhenMatched_Constants(TestDbContext.DbDriver driver)
{
ResetDb(driver);
using (var dbContext = new TestDbContext(_dataContexts[driver]))
{
var newCountry = new Country
{
Name = "Australia",
ISO = "AU",
Created = _now,
Updated = _now,
};

dbContext.Countries.Upsert(newCountry)
.On(c => c.ISO)
.WhenMatched(c => new Country
{
Name = "Australia",
Updated = _now,
})
.Run();

Assert.Collection(dbContext.Countries.OrderBy(c => c.ID),
country =>
{
Assert.Equal(newCountry.ISO, country.ISO);
Assert.Equal(newCountry.Name, country.Name);
Assert.NotEqual(newCountry.Created, country.Created);
Assert.Equal(_dbCountry.Created, country.Created);
Assert.Equal(newCountry.Updated, country.Updated);
});
}
}

[Theory]
[MemberData(nameof(GetDatabaseEngines))]
public void Upsert_Country_Insert_On_WhenMatched(TestDbContext.DbDriver driver)
Expand Down Expand Up @@ -814,6 +849,43 @@ public void Upsert_PageVisit_Update_On_WhenMatched_ValueDivide(TestDbContext.DbD
}
}

[Theory]
[MemberData(nameof(GetDatabaseEngines))]
public void Upsert_PageVisit_Update_On_WhenMatched_ValueModulo(TestDbContext.DbDriver driver)
{
ResetDb(driver);
using (var dbContext = new TestDbContext(_dataContexts[driver]))
{
var newVisit = new PageVisit
{
UserID = 1,
Date = DateTime.Today,
Visits = 1,
FirstVisit = _now,
LastVisit = _now,
};

dbContext.PageVisits.Upsert(newVisit)
.On(pv => new { pv.UserID, pv.Date })
.WhenMatched(pv => new PageVisit
{
Visits = pv.Visits % 4,
LastVisit = _now,
})
.Run();

Assert.Collection(dbContext.PageVisits.OrderBy(c => c.ID),
visit => AssertEqual(_dbVisitOld, visit),
visit =>
{
Assert.Equal(_dbVisit.Visits % 4, visit.Visits);
Assert.NotEqual(newVisit.FirstVisit, visit.FirstVisit);
Assert.Equal(_dbVisit.FirstVisit, visit.FirstVisit);
Assert.Equal(newVisit.LastVisit, visit.LastVisit);
});
}
}

[Theory]
[MemberData(nameof(GetDatabaseEngines))]
public void UpsertRange_PageVisit_Update_On_WhenMatched(TestDbContext.DbDriver driver)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,26 @@ public void ExpressionHelpersTests_ValueDivide()
Assert.Equal(4, value2.Value);
}

[Fact]
public void ExpressionHelpersTests_ValueModulo()
{
Expression<Func<TestEntity, TestEntity>> exp = e => new TestEntity
{
Num1 = e.Num1 % 4,
};

var memberAssig = GetMemberExpression(exp);
var expValue = memberAssig.GetValue<TestEntity>(exp);

var knownValue = Assert.IsType<KnownExpression>(expValue);
Assert.Equal(ExpressionType.Modulo, knownValue.ExpressionType);
var value1 = Assert.IsType<ParameterProperty>(knownValue.Value1);
Assert.Equal("Num1", value1.PropertyName);
Assert.True(value1.IsLeftParameter);
var value2 = Assert.IsType<ConstantValue>(knownValue.Value2);
Assert.Equal(4, value2.Value);
}

[Fact]
public void ExpressionHelpersTests_Property()
{
Expand Down Expand Up @@ -276,6 +296,70 @@ public void ExpressionHelpersTests_DateTime_Now()
Assert.True(updated < DateTime.Now.AddMinutes(1));
}

[Fact]
public void ExpressionHelpersTests_Nullable_Assign()
{
int value = 5;

Expression<Func<TestEntity, TestEntity, TestEntity>> exp = (e1, e2) => new TestEntity
{
NumNullable1 = value,
};

var memberAssig = GetMemberExpression(exp);
var expValue = memberAssig.GetValue<TestEntity>(exp);
var num = Assert.IsType<int>(expValue);
Assert.Equal(value, num);
}

[Fact]
public void ExpressionHelpersTests_Nullable_Cast()
{
int? value = 5;

Expression<Func<TestEntity, TestEntity, TestEntity>> exp = (e1, e2) => new TestEntity
{
Num1 = (int)value,
};

var memberAssig = GetMemberExpression(exp);
var expValue = memberAssig.GetValue<TestEntity>(exp);
var num = Assert.IsType<int>(expValue);
Assert.Equal(value.Value, num);
}

[Fact]
public void ExpressionHelpersTests_Nullable_Coalesce()
{
int? value = 5;

Expression<Func<TestEntity, TestEntity, TestEntity>> exp = (e1, e2) => new TestEntity
{
Num1 = value ?? 0,
};

var memberAssig = GetMemberExpression(exp);
var expValue = memberAssig.GetValue<TestEntity>(exp);
var num = Assert.IsType<int>(expValue);
Assert.Equal(value, num);
}

[Fact]
public void ExpressionHelpersTests_Nullable_GetValueOrDefault()
{
int? value = 5;

Expression<Func<TestEntity, TestEntity, TestEntity>> exp = (e1, e2) => new TestEntity
{
Num1 = value.GetValueOrDefault(),
};

var memberAssig = GetMemberExpression(exp);
var expValue = memberAssig.GetValue<TestEntity>(exp);
var num = Assert.IsType<int>(expValue);
Assert.Equal(value.Value, num);
}

[Fact]
public void ExpressionHelperTests_UnsupportedExpression()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,8 @@ public class MySqlUpsertCommandRunnerTests : RelationalCommandRunnerTestsBase

protected override string Update_BinaryAdd_Sql =>
"INSERT INTO myTable (`Name`, `Status`) VALUES (@p0, @p1) ON DUPLICATE KEY UPDATE `Status` = `Status` + @p2";

protected override string Update_Coalesce_Sql =>
"INSERT INTO myTable (`Name`, `Status`) VALUES (@p0, @p1) ON DUPLICATE KEY UPDATE `Status` = COALESCE(`Status`, @p2)";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,10 @@ public class PostgreSqlUpsertCommandRunnerTests : RelationalCommandRunnerTestsBa
"INSERT INTO myTable AS \"T\" (\"Name\", \"Status\") " +
"VALUES (@p0, @p1) ON CONFLICT (\"ID\") " +
"DO UPDATE SET \"Status\" = \"T\".\"Status\" + @p2";

protected override string Update_Coalesce_Sql =>
"INSERT INTO myTable AS \"T\" (\"Name\", \"Status\") " +
"VALUES (@p0, @p1) ON CONFLICT (\"ID\") " +
"DO UPDATE SET \"Status\" = COALESCE(\"T\".\"Status\", @p2)";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,5 +112,28 @@ public void SqlSyntaxRunner_Update_BinaryAdd()

Assert.Equal(Update_BinaryAdd_Sql, generatedSql);
}

protected abstract string Update_Coalesce_Sql { get; }
[Fact]
public void SqlSyntaxRunner_Update_Coalesce()
{
var runner = GetRunner();
var tableName = "myTable";
ICollection<(string ColumnName, ConstantValue Value)> entity = new[]
{
( "Name", new ConstantValue("value") { ArgumentIndex = 0 } ),
( "Status", new ConstantValue(3) { ArgumentIndex = 1} ),
};
var updates = new[]
{
("Status", new KnownExpression(ExpressionType.Coalesce,
new ParameterProperty("Status", true) { Property = new MockProperty("Status") },
new ConstantValue(1) { ArgumentIndex = 2 }))
};

var generatedSql = runner.GenerateCommand(tableName, new[] { entity }, new[] { "ID" }, updates);

Assert.Equal(Update_Coalesce_Sql, generatedSql);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,12 @@ public class SqlServerUpsertCommandRunnerTests : RelationalCommandRunnerTestsBas
"ON [T].[ID] = [S].[ID] " +
"WHEN NOT MATCHED BY TARGET THEN INSERT ([Name], [Status]) VALUES ([Name], [Status]) " +
"WHEN MATCHED THEN UPDATE SET [Status] = [T].[Status] + @p2;";

protected override string Update_Coalesce_Sql =>
"MERGE INTO myTable WITH (HOLDLOCK) AS [T] " +
"USING ( VALUES (@p0, @p1) ) AS [S] ([Name], [Status]) " +
"ON [T].[ID] = [S].[ID] " +
"WHEN NOT MATCHED BY TARGET THEN INSERT ([Name], [Status]) VALUES ([Name], [Status]) " +
"WHEN MATCHED THEN UPDATE SET [Status] = COALESCE([T].[Status], @p2);";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class TestEntity
{
public int Num1 { get; set; }
public int Num2 { get; set; }
public int? NumNullable1 { get; set; }
public string Text1 { get; set; }
public string Text2 { get; set; }
public DateTime Updated { get; set; }
Expand Down

0 comments on commit c218138

Please sign in to comment.