Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ability to mock protected methods with and without return value #845

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/NSubstitute/Core/IThreadLocalContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ public interface IThreadLocalContext
void EnqueueArgumentSpecification(IArgumentSpecification spec);
IList<IArgumentSpecification> DequeueAllArgumentSpecifications();

/// <summary>
/// Peeks into the argument specifications
/// </summary>
/// <returns>Enqueued argument specifications</returns>
IList<IArgumentSpecification> PeekAllArgumentSpecifications();

void SetPendingRaisingEventArgumentsFactory(Func<ICall, object?[]> getArguments);
/// <summary>
/// Returns the previously set arguments factory and resets the stored value.
Expand Down
18 changes: 18 additions & 0 deletions src/NSubstitute/Core/ThreadLocalContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,24 @@ public IList<IArgumentSpecification> DequeueAllArgumentSpecifications()
return queue;
}

/// <inheritdoc/>
public IList<IArgumentSpecification> PeekAllArgumentSpecifications()
{
var queue = _argumentSpecifications.Value;
if (queue == null) { throw new SubstituteInternalException("Argument specification queue is null."); }

if (queue.Count > 0)
{
var items = new IArgumentSpecification[queue.Count];

queue.CopyTo(items, 0);

return items;
}

return EmptySpecifications;
}

public void SetPendingRaisingEventArgumentsFactory(Func<ICall, object?[]> getArguments)
{
_getArgumentsForRaisingEvent.Value = getArguments;
Expand Down
59 changes: 59 additions & 0 deletions src/NSubstitute/Extensions/ProtectedExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
using System.Reflection;
using NSubstitute.Core;
using NSubstitute.Core.Arguments;

// Disable nullability for client API, so it does not affect clients.
#nullable disable annotations

namespace NSubstitute.Extensions;

public static class ProtectedExtensions
{
/// <summary>
/// Configure behavior for a protected method with return value
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="obj">The object.</param>
/// <param name="methodName">Name of the method.</param>
/// <param name="args">The method arguments.</param>
/// <returns>Result object from the method invocation.</returns>
/// <exception cref="System.ArgumentNullException">Substitute - Cannot mock null object</exception>
/// <exception cref="System.ArgumentException">Must provide valid protected method name to mock - methodName</exception>
public static object Protected<T>(this T obj, string methodName, params object[] args) where T : class
{
if (obj == null) { throw new ArgumentNullException(nameof(obj), "Cannot mock null object"); }
if (string.IsNullOrWhiteSpace(methodName)) { throw new ArgumentException("Must provide valid protected method name to mock", nameof(methodName)); }

IList<IArgumentSpecification> argTypes = SubstitutionContext.Current.ThreadContext.PeekAllArgumentSpecifications();
MethodInfo mthdInfo = obj.GetType().GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance, Type.DefaultBinder, argTypes.Select(x => x.ForType).ToArray(), null);

if (mthdInfo == null) { throw new Exception($"Method {methodName} not found"); }
if (!mthdInfo.IsVirtual) { throw new Exception($"Method {methodName} is not virtual"); }

return mthdInfo.Invoke(obj, args);
}

/// <summary>
/// Configure behavior for a protected method with no return vlaue
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="obj">The object.</param>
/// <param name="methodName">Name of the method.</param>
/// <param name="args">The method arguments.</param>
/// <returns>WhenCalled&lt;T&gt;.</returns>
/// <exception cref="System.ArgumentNullException">Substitute - Cannot mock null object</exception>
/// <exception cref="System.ArgumentException">Must provide valid protected method name to mock - methodName</exception>
public static WhenCalled<T> When<T>(this T obj, string methodName, params object[] args) where T : class
{
if (obj == null) { throw new ArgumentNullException(nameof(obj), "Cannot mock null object"); }
if (string.IsNullOrWhiteSpace(methodName)) { throw new ArgumentException("Must provide valid protected method name to mock", nameof(methodName)); }

IList<IArgumentSpecification> argTypes = SubstitutionContext.Current.ThreadContext.PeekAllArgumentSpecifications();
MethodInfo mthdInfo = obj.GetType().GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance, Type.DefaultBinder, argTypes.Select(y => y.ForType).ToArray(), null);

if (mthdInfo == null) { throw new Exception($"Method {methodName} not found"); }
if (!mthdInfo.IsVirtual) { throw new Exception($"Method {methodName} is not virtual"); }

return new WhenCalled<T>(SubstitutionContext.Current, obj, x => mthdInfo.Invoke(x, args), MatchArgs.AsSpecifiedInCall);
}
}
46 changes: 46 additions & 0 deletions tests/NSubstitute.Acceptance.Specs/Infrastructure/AnotherClass.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
namespace NSubstitute.Acceptance.Specs.Infrastructure;

public abstract class AnotherClass
{
protected abstract string ProtectedMethod();

protected abstract string ProtectedMethod(int i);

protected abstract string ProtectedMethod(string msg, int i, char j);

protected abstract void ProtectedMethodWithNoReturn();

protected abstract void ProtectedMethodWithNoReturn(int i);

protected abstract void ProtectedMethodWithNoReturn(string msg, int i, char j);

public string DoWork()
{
return ProtectedMethod();
}

public string DoWork(int i)
{
return ProtectedMethod(i);
}

public string DoWork(string msg, int i, char j)
{
return ProtectedMethod(msg, i, j);
}

public void DoVoidWork()
{
ProtectedMethodWithNoReturn();
}

public void DoVoidWork(int i)
{
ProtectedMethodWithNoReturn(i);
}

public void DoVoidWork(string msg, int i, char j)
{
ProtectedMethodWithNoReturn(msg, i, j);
}
}
123 changes: 123 additions & 0 deletions tests/NSubstitute.Acceptance.Specs/ProtectedExtensionsTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
using NSubstitute.Acceptance.Specs.Infrastructure;
using NSubstitute.Extensions;
using NUnit.Framework;

namespace NSubstitute.Acceptance.Specs;

public class ProtectedExtensionsTests
{
[Test]
public void Should_mock_and_verify_protected_method_with_no_args()
{
var expectedMsg = "unit test message";
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.Protected("ProtectedMethod").Returns(expectedMsg);

Assert.That(worker.DoWork(sub), Is.EqualTo(expectedMsg));
sub.Received(1).Protected("ProtectedMethod");
}

[Test]
public void Should_mock_and_verify_protected_method_with_arg()
{
var expectedMsg = "unit test message";
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.Protected("ProtectedMethod", Arg.Any<int>()).Returns(expectedMsg);

Assert.That(worker.DoMoreWork(sub, 5), Is.EqualTo(expectedMsg));
var a = sub.Received(1);
a.Protected("ProtectedMethod", Arg.Any<int>());
}

[Test]
public void Should_mock_and_verify_protected_method_with_multiple_args()
{
var expectedMsg = "unit test message";
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.Protected("ProtectedMethod", Arg.Any<string>(), Arg.Any<int>(), Arg.Any<char>()).Returns(expectedMsg);

Assert.That(worker.DoEvenMoreWork(sub, 3, 'x'), Is.EqualTo(expectedMsg));
sub.Received(1).Protected("ProtectedMethod", Arg.Any<string>(), Arg.Any<int>(), Arg.Any<char>());
}

[Test]
public void Should_mock_and_verify_method_with_no_return_and_no_args()
{
var count = 0;
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.When("ProtectedMethodWithNoReturn").Do(x => count++);

worker.DoVoidWork(sub);
Assert.That(count, Is.EqualTo(1));
sub.Received(1).Protected("ProtectedMethodWithNoReturn");
}

[Test]
public void Should_mock_and_verify_method_with_no_return_with_arg()
{
var count = 0;
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.When("ProtectedMethodWithNoReturn", Arg.Any<int>()).Do(x => count++);

worker.DoVoidWork(sub, 5);
Assert.That(count, Is.EqualTo(1));
sub.Received(1).Protected("ProtectedMethodWithNoReturn", Arg.Any<int>());
}

[Test]
public void Should_mock_and_verify_method_with_no_return_with_multiple_args()
{
var count = 0;
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.When("ProtectedMethodWithNoReturn", Arg.Any<string>(), Arg.Any<int>(), Arg.Any<char>()).Do(x => count++);

worker.DoVoidWork(sub, 5, 'x');
Assert.That(count, Is.EqualTo(1));
sub.Received(1).Protected("ProtectedMethodWithNoReturn", Arg.Any<string>(), Arg.Any<int>(), Arg.Any<char>());
}

private class Worker
{
internal string DoWork(AnotherClass worker)
{
return worker.DoWork();
}

internal string DoMoreWork(AnotherClass worker, int i)
{
return worker.DoWork(i);
}

internal string DoEvenMoreWork(AnotherClass worker, int i, char j)
{
return worker.DoWork("worker", i, j);
}

internal void DoVoidWork(AnotherClass worker)
{
worker.DoVoidWork();
}

internal void DoVoidWork(AnotherClass worker, int i)
{
worker.DoVoidWork(i);
}

internal void DoVoidWork(AnotherClass worker, int i, char j)
{
worker.DoVoidWork("void worker", i, j);
}
}
}