diff --git a/README.md b/README.md index f754a52..c555671 100644 --- a/README.md +++ b/README.md @@ -374,7 +374,7 @@ For this purpose exists `IValueGenerator` interface which your class must implem ```c# public interface IValueGenerator { - object GetValue(); + object GetValue(); } ``` @@ -383,7 +383,7 @@ Here is an example of custom value generator: ```c# private class AutoIncrementValueGenerator : IValueGenerator { - public object GetValue() => 123; + public object GetValue() => 123; } ``` @@ -777,6 +777,24 @@ _database.ExecuteWithTempTable(ids, (database, tableName) => database.Query() .From($"PERSON AS P INNER JOIN {tableName} AS T ON (P.Id = T.Value)") .ToList()); + +public class IdDto +{ + public IdDto(int id) + { + Id = id; + } + + public int Id { get; set; } +} + +var ids = new List(){ new IdDto(0), new IdDto(1), new IdDto(2), new IdDto(3) }; + +_database.ExecuteWithTempTable(ids, (database, tableName) + => database.Query() + .Select("P.*") + .From($"PERSON AS P INNER JOIN {tableName} AS T ON (P.Id = T.Id)") + .ToList()); ``` ### SQL commands executing diff --git a/src/IDatabaseExtensions.TempTable.cs b/src/IDatabaseExtensions.TempTable.cs index 9041833..ac4edab 100644 --- a/src/IDatabaseExtensions.TempTable.cs +++ b/src/IDatabaseExtensions.TempTable.cs @@ -1,8 +1,10 @@ using Kros.Data.BulkActions; using Kros.KORM.Data; using Kros.KORM.Extensions; +using Kros.KORM.Metadata; using System; using System.Collections.Generic; +using System.Linq; using System.Threading.Tasks; namespace Kros.KORM @@ -234,25 +236,38 @@ private static void InsertValuesIntoTempTable( IEnumerable values, string tempTableName) { - database.ExecuteNonQuery($"CREATE TABLE {tempTableName} ( Value {typeof(TValue).ToSqlDataType()} )"); + TableInfo tableInfo = Database.DatabaseMapper.GetTableInfo(); + string columns = GetColumnsWithSqlTypes(tableInfo, typeof(TValue)); + database.ExecuteNonQuery($"CREATE TABLE {tempTableName} ( {columns} )"); using IBulkInsert bulkInsert = database.CreateBulkInsert(); bulkInsert.DestinationTableName = tempTableName; - using var reader = new EnumerableDataReader(values, new string[] { "Value" }); + using var reader = new EnumerableDataReader(values, GetColumns(tableInfo, typeof(TValue))); bulkInsert.Insert(reader); } + private static string GetColumnsWithSqlTypes(TableInfo tableInfo, Type type) + => (type.IsPrimitive || type == typeof(string)) + ? $"Value {type.ToSqlDataType()}" + : string.Join( + ",", + tableInfo.Columns.Select(c => $"[{c.PropertyInfo.Name}] {c.PropertyInfo.PropertyType.ToSqlDataType()}")); + + private static IEnumerable GetColumns(TableInfo tableInfo, Type type) + => (type.IsPrimitive || type == typeof(string)) + ? new string[] { "Value" } + : tableInfo.Columns.Select(c => c.PropertyInfo.Name); + private static void InsertValuesIntoTempTable( IDatabase database, IDictionary values, string tempTableName) { - database.ExecuteNonQuery($"CREATE TABLE {tempTableName} " + - $"( [Key] {typeof(TKey).ToSqlDataType()}, " + - $" [Value] {typeof(TValue).ToSqlDataType()} )"); + database.ExecuteNonQuery( + $"CREATE TABLE {tempTableName}([Key] {typeof(TKey).ToSqlDataType()}, [Value] {typeof(TValue).ToSqlDataType()})"); using IBulkInsert bulkInsert = database.CreateBulkInsert(); bulkInsert.DestinationTableName = tempTableName; diff --git a/src/Kros.KORM.csproj b/src/Kros.KORM.csproj index 12b1c72..793bc6a 100644 --- a/src/Kros.KORM.csproj +++ b/src/Kros.KORM.csproj @@ -2,7 +2,7 @@ netcoreapp2.1;net46 - 4.3.1 + 4.3.2 KROS a. s. KROS a. s. KORM is fast, easy to use, micro ORM tool (Kros Object Relation Mapper). diff --git a/tests/Kros.KORM.UnitTests/Integration/IDatabaseExtensionsTempTableShould.cs b/tests/Kros.KORM.UnitTests/Integration/IDatabaseExtensionsTempTableShould.cs index 51b6056..779564c 100644 --- a/tests/Kros.KORM.UnitTests/Integration/IDatabaseExtensionsTempTableShould.cs +++ b/tests/Kros.KORM.UnitTests/Integration/IDatabaseExtensionsTempTableShould.cs @@ -1,6 +1,9 @@ using FluentAssertions; using Kros.Extensions; +using Kros.KORM.Metadata.Attribute; using Kros.KORM.UnitTests.Base; +using Microsoft.Data.SqlClient; +using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; @@ -40,6 +43,72 @@ public void ExecuteWithTempTableList() } } + [Fact] + public void ExecuteWithTempTableObjectList() + { + using (IDatabase database = CreateDatabase(CreateTable_TestTable, InsertDataScript2)) + { + List ids = CreateIds(); + var affectedCount = database.ExecuteWithTempTable( + ids, + (database, tableName) => database.ExecuteNonQuery( + $@"UPDATE P + SET P.Age = 18 + FROM People AS P INNER JOIN {tableName} AS T ON (P.Id = T.Id)")); + + affectedCount.Should().Be(4); + } + } + + [Fact] + public void ExecuteWithTempTableObjectListNoMapColumnThrowException() + { + using (IDatabase database = CreateDatabase(CreateTable_TestTable, InsertDataScript2)) + { + List data = CreateTestData(); + Action act = () => database.ExecuteWithTempTable( + data, + (database, tableName) => database.ExecuteNonQuery( + $@"UPDATE P + SET P.Age = 18 + FROM People AS P INNER JOIN {tableName} AS T ON (P.Id = T.Number)")); + act.Should().Throw().WithMessage("Invalid column name 'Number'."); + } + } + + [Fact] + public void ExecuteWithTempTableObjectListAliasColumnThrowException() + { + using (IDatabase database = CreateDatabase(CreateTable_TestTable, InsertDataScript2)) + { + List data = CreateTestData(); + Action act = () => database.ExecuteWithTempTable( + data, + (database, tableName) => database.ExecuteNonQuery( + $@"UPDATE P + SET P.Age = 18 + FROM People AS P INNER JOIN {tableName} AS T ON (P.Id = T.Value)")); + act.Should().Throw().WithMessage("Invalid column name 'Value'."); + } + } + + [Fact] + public void ExecuteWithTempTableObjectListAliasColumn() + { + using (IDatabase database = CreateDatabase(CreateTable_TestTable, InsertDataScript2)) + { + List data = CreateTestData(); + var affectedCount = database.ExecuteWithTempTable( + data, + (database, tableName) => database.ExecuteNonQuery( + $@"UPDATE P + SET P.Age = 18 + FROM People AS P INNER JOIN {tableName} AS T ON (P.Id = T.Id)")); + + affectedCount.Should().Be(1); + } + } + [Fact] public async Task ExecuteWithTempTableListAsync() { @@ -57,6 +126,23 @@ public async Task ExecuteWithTempTableListAsync() } } + [Fact] + public async Task ExecuteWithTempTableObjectListAsync() + { + using (IDatabase database = CreateDatabase(CreateTable_TestTable, InsertDataScript2)) + { + List ids = CreateIds(); + var affectedCount = await database.ExecuteWithTempTableAsync( + ids, + (database, tableName) => database.ExecuteNonQueryAsync( + $@"UPDATE P + SET P.Age = 18 + FROM People AS P INNER JOIN {tableName} AS T ON (P.Id = T.Id)")); + + affectedCount.Should().Be(4); + } + } + [Fact] public void ExecuteWithTempTableTList() { @@ -72,6 +158,23 @@ public void ExecuteWithTempTableTList() } } + [Fact] + public void ExecuteWithTempTableTObjectList() + { + using (IDatabase database = CreateDatabase(CreateTable_TestTable, InsertDataScript2)) + { + List ids = CreateIds(); + IEnumerable result = database.ExecuteWithTempTable( + ids, + (database, tableName) => database.Query() + .Select("P.*") + .From($"People AS P INNER JOIN {tableName} AS T ON (P.Id = T.Id)") + .ToList()); + + result.Should().HaveCount(4); + } + } + [Fact] public async Task ExecuteWithTempTableTListAsync() { @@ -89,6 +192,24 @@ public async Task ExecuteWithTempTableTListAsync() } } + [Fact] + public async Task ExecuteWithTempTableTObjectListAsync() + { + using (IDatabase database = CreateDatabase(CreateTable_TestTable, InsertDataScript2)) + { + List ids = CreateIds(); + IEnumerable result = await database.ExecuteWithTempTableAsync( + ids, + (database, tableName) => database.Query() + .Select("P.*") + .From($"People AS P INNER JOIN {tableName} AS T ON (P.Id = T.Id)") + .ToList() + .AsTask()); + + result.Should().HaveCount(4); + } + } + [Fact] public void ExecuteWithTempTableDictionary() { @@ -175,5 +296,40 @@ await database.ExecuteNonQueryAsync( result.Should().HaveCount(2); } } + + private static List CreateIds() + => new List() + { + new IdDto(1), + new IdDto(2), + new IdDto(3), + new IdDto(4), + new IdDto(456), + new IdDto(789) + }; + + private static List CreateTestData() + => new List() + { + new TestDto(1, 1), + new TestDto(789, 789) + }; + + private record IdDto(int Id); + + private class TestDto + { + public TestDto(int id, int number) + { + Id = id; + Number = number; + } + + [Alias("Value")] + public int Id { get; set; } + + [NoMap] + public int Number { get; set; } + } } }