Skip to content

Commit

Permalink
Merge pull request #5955 from peppy/model-backed-fuckedable
Browse files Browse the repository at this point in the history
Fix `ModelBackedDrawable` potentially crashing on badly timed `null`
  • Loading branch information
peppy authored Jul 31, 2023
2 parents e72f4fe + 597f5ac commit e301001
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 55 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Copyright (c) ppy Pty Ltd <[email protected]>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.

#nullable disable

using System;
using System.Threading;
using NUnit.Framework;
Expand All @@ -18,7 +16,7 @@ namespace osu.Framework.Tests.Visual.Drawables
{
public partial class TestSceneModelBackedDrawable : FrameworkTestScene
{
private TestModelBackedDrawable backedDrawable;
private TestModelBackedDrawable backedDrawable = null!;

private void createModelBackedDrawable(bool hasIntermediate, bool showNullModel = false) =>
Child = backedDrawable = new TestModelBackedDrawable
Expand All @@ -34,13 +32,14 @@ private void createModelBackedDrawable(bool hasIntermediate, bool showNullModel
public void TestEmptyDefaultState()
{
AddStep("setup", () => createModelBackedDrawable(false));
AddAssert("nothing shown", () => backedDrawable.DisplayedDrawable == null);
AddUntilStep("wait for load", () => backedDrawable.DelayedLoadFinished);
AddAssert("nothing shown", () => backedDrawable.DisplayedDrawable, () => Is.Null.Or.InstanceOf(Empty().GetType()));
}

[Test]
public void TestModelDefaultState()
{
TestDrawableModel drawableModel = null;
TestDrawableModel drawableModel = null!;

AddStep("setup", () =>
{
Expand All @@ -55,8 +54,8 @@ public void TestModelDefaultState()
[TestCase(true)]
public void TestChangeModel(bool hasIntermediate)
{
TestDrawableModel firstModel = null;
TestDrawableModel secondModel = null;
TestDrawableModel firstModel = null!;
TestDrawableModel secondModel = null!;

AddStep("setup", () =>
{
Expand All @@ -77,9 +76,9 @@ public void TestChangeModel(bool hasIntermediate)
[TestCase(true)]
public void TestChangeModelDuringLoad(bool hasIntermediate)
{
TestDrawableModel firstModel = null;
TestDrawableModel secondModel = null;
TestDrawableModel thirdModel = null;
TestDrawableModel firstModel = null!;
TestDrawableModel secondModel = null!;
TestDrawableModel thirdModel = null!;

AddStep("setup", () =>
{
Expand All @@ -106,8 +105,8 @@ public void TestChangeModelDuringLoad(bool hasIntermediate)
[TestCase(true)]
public void TestOutOfOrderLoad(bool hasIntermediate)
{
TestDrawableModel firstModel = null;
TestDrawableModel secondModel = null;
TestDrawableModel firstModel = null!;
TestDrawableModel secondModel = null!;

AddStep("setup", () =>
{
Expand All @@ -130,7 +129,7 @@ public void TestOutOfOrderLoad(bool hasIntermediate)
[Test]
public void TestSetNullModel()
{
TestDrawableModel drawableModel = null;
TestDrawableModel drawableModel = null!;

AddStep("setup", () =>
{
Expand All @@ -147,7 +146,7 @@ public void TestSetNullModel()
[Test]
public void TestInsideBufferedContainer()
{
TestDrawableModel drawableModel = null;
TestDrawableModel drawableModel = null!;

AddStep("setup", () =>
{
Expand Down Expand Up @@ -267,25 +266,37 @@ public TestNullDrawableModel()
private partial class TestModelBackedDrawable : ModelBackedDrawable<TestModel>
{
public bool ShowNullModel;

public bool HasIntermediate;
public bool DelayedLoadFinished;

protected override Drawable CreateDrawable(TestModel model)
protected override Drawable? CreateDrawable(TestModel? model)
{
if (model == null && ShowNullModel)
return new TestNullDrawableModel();

return model?.DrawableModel;
}

public new Drawable DisplayedDrawable => base.DisplayedDrawable;
public new Drawable? DisplayedDrawable => base.DisplayedDrawable;

public new TestModel Model
public new TestModel? Model
{
set => base.Model = value;
}

protected override bool TransformImmediately => HasIntermediate;

protected override void OnLoadStarted()
{
base.OnLoadStarted();
DelayedLoadFinished = false;
}

protected override void OnLoadFinished()
{
base.OnLoadFinished();
DelayedLoadFinished = true;
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Copyright (c) ppy Pty Ltd <[email protected]>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.

#nullable disable

using System;
using NUnit.Framework;
using osu.Framework.Graphics;
Expand All @@ -18,8 +16,8 @@ namespace osu.Framework.Tests.Visual.Drawables
{
public partial class TestSceneModelBackedDrawableWithUnloading : FrameworkTestScene
{
private TestUnloadingModelBackedDrawable backedDrawable;
private Drawable initialDrawable;
private TestUnloadingModelBackedDrawable backedDrawable = null!;
private Drawable? initialDrawable;

[SetUpSteps]
public void SetUpSteps()
Expand All @@ -43,7 +41,19 @@ public void SetUpSteps()
public void TestUnloading()
{
AddStep("mask away", () => backedDrawable.Position = new Vector2(-2));
AddUntilStep("drawable unloaded", () => initialDrawable.IsDisposed && backedDrawable.DisplayedDrawable == null);
AddUntilStep("drawable unloaded", () => initialDrawable?.IsDisposed == true && backedDrawable.DisplayedDrawable == null);

AddStep("return back", () => backedDrawable.Position = Vector2.Zero);
AddUntilStep("new drawable displayed", () => backedDrawable.DisplayedDrawable != null && backedDrawable.DisplayedDrawable != initialDrawable);
}

[Test]
public void TestUnloadingWithNullAfterUnload()
{
AddStep("mask away", () => backedDrawable.Position = new Vector2(-2));
AddUntilStep("drawable unloaded", () => initialDrawable?.IsDisposed == true && backedDrawable.DisplayedDrawable == null);

AddStep("set providing drawable to null", () => backedDrawable.ReturnNullDrawable = true);

AddStep("return back", () => backedDrawable.Position = Vector2.Zero);
AddUntilStep("new drawable displayed", () => backedDrawable.DisplayedDrawable != null && backedDrawable.DisplayedDrawable != initialDrawable);
Expand Down Expand Up @@ -76,12 +86,14 @@ public void TestTransformsAppliedOnReloading()
// on loading, ModelBackedDrawable applies immediate hide transform on new drawable then applies show transform.
AddAssert("initial hide transform applied", () => backedDrawable.HideTransforms == 1);
AddAssert("show transform applied", () => backedDrawable.ShowTransforms == 1);
AddUntilStep("new drawable alpha = 1", () => backedDrawable.DisplayedDrawable.Alpha == 1);
AddUntilStep("new drawable alpha = 1", () => backedDrawable.DisplayedDrawable?.Alpha == 1);
}

private partial class TestUnloadingModelBackedDrawable : ModelBackedDrawable<int>
{
public new Drawable DisplayedDrawable => base.DisplayedDrawable;
public bool ReturnNullDrawable;

public new Drawable? DisplayedDrawable => base.DisplayedDrawable;

public new int Model
{
Expand Down Expand Up @@ -120,8 +132,11 @@ protected override TransformSequence<Drawable> ApplyHideTransforms(Drawable draw
return base.ApplyHideTransforms(drawable);
}

protected override Drawable CreateDrawable(int model)
protected override Drawable? CreateDrawable(int model)
{
if (ReturnNullDrawable)
return null;

return new Container
{
RelativeSizeAxes = Axes.Both,
Expand Down
59 changes: 30 additions & 29 deletions osu.Framework/Graphics/Containers/ModelBackedDrawable.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
// Copyright (c) ppy Pty Ltd <[email protected]>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.

#nullable disable

using System;
using System.Collections.Generic;
using JetBrains.Annotations;
using osu.Framework.Graphics.Transforms;
using osu.Framework.Lists;

Expand All @@ -20,20 +17,20 @@ public abstract partial class ModelBackedDrawable<T> : CompositeDrawable
/// <summary>
/// The currently displayed <see cref="Drawable"/>. Null if no drawable is displayed.
/// </summary>
protected Drawable DisplayedDrawable => displayedWrapper?.Content;
protected Drawable? DisplayedDrawable => displayedWrapper?.Content;

/// <summary>
/// The <see cref="IEqualityComparer{T}"/> used to compare models to ensure that <see cref="Drawable"/>s are not updated unnecessarily.
/// </summary>
protected readonly IEqualityComparer<T> Comparer;

private T model;
private T? model;

/// <summary>
/// Gets or sets the model, potentially triggering the current <see cref="Drawable"/> to update.
/// Subclasses should expose this via a nicer property name to better represent the data being set.
/// </summary>
protected T Model
protected T? Model
{
get => model;
set
Expand All @@ -53,12 +50,12 @@ protected T Model
/// <summary>
/// The wrapper which has the current displayed content.
/// </summary>
private DelayedLoadWrapper displayedWrapper;
private DelayedLoadWrapper? displayedWrapper;

/// <summary>
/// The wrapper which is currently loading, or has finished loading (i.e <see cref="displayedWrapper"/>).
/// </summary>
private DelayedLoadWrapper currentWrapper;
private DelayedLoadWrapper? currentWrapper;

/// <summary>
/// Constructs a new <see cref="ModelBackedDrawable{T}"/> with the default <typeparamref name="T"/> equality comparer.
Expand Down Expand Up @@ -100,10 +97,10 @@ private void updateDrawable()
loadDrawable(null);
}

loadDrawable(() => CreateDrawable(model));
loadDrawable(() => CreateDrawable(model) ?? Empty());
}

private void loadDrawable(Func<Drawable> createDrawableFunc)
private void loadDrawable(Func<Drawable>? createDrawableFunc)
{
// Remove the previous wrapper if the inner drawable hasn't finished loading.
if (currentWrapper?.DelayedLoadCompleted == false)
Expand All @@ -112,7 +109,9 @@ private void loadDrawable(Func<Drawable> createDrawableFunc)
DisposeChildAsync(currentWrapper);
}

currentWrapper = createWrapper(createDrawableFunc, LoadDelay);
currentWrapper = createDrawableFunc == null
? null
: createWrapper(createDrawableFunc, LoadDelay);

if (currentWrapper == null)
{
Expand All @@ -136,14 +135,19 @@ private void loadDrawable(Func<Drawable> createDrawableFunc)
/// Invoked when a <see cref="DelayedLoadWrapper"/> has finished loading its contents.
/// May be invoked multiple times for each <see cref="DelayedLoadWrapper"/>.
/// </summary>
/// <param name="wrapper">The <see cref="DelayedLoadWrapper"/>.</param>
private void finishLoad(DelayedLoadWrapper wrapper)
/// <param name="wrapper">The current <see cref="DelayedLoadWrapper"/>.</param>
private void finishLoad(DelayedLoadWrapper? wrapper)
{
// Make the wrapper initially hidden.
ApplyHideTransforms(wrapper);
wrapper?.FinishTransforms();
TransformSequence<Drawable>? showTransforms = null;

var showTransforms = ApplyShowTransforms(wrapper);
if (wrapper != null)
{
// Make the wrapper initially hidden.
ApplyHideTransforms(wrapper);
wrapper.FinishTransforms();

showTransforms = ApplyShowTransforms(wrapper);
}

// If the wrapper hasn't changed then this invocation must be a result of a reload (e.g. DelayedLoadUnloadWrapper)
// In that case, we do not want to apply hide transforms and expire the last wrapper.
Expand Down Expand Up @@ -172,10 +176,9 @@ private void finishLoad(DelayedLoadWrapper wrapper)
/// <returns>A <see cref="DelayedLoadWrapper"/> or null if <paramref name="createContentFunc"/> returns null.</returns>
private DelayedLoadWrapper createWrapper(Func<Drawable> createContentFunc, double timeBeforeLoad)
{
var content = createContentFunc?.Invoke();

if (content == null)
return null;
// Note that this only becomes null after the first consumption.
// ie. the `createContentFunc` cannot provide a null.
Drawable? content = createContentFunc();

return CreateDelayedLoadWrapper(() =>
{
Expand Down Expand Up @@ -224,32 +227,30 @@ protected virtual void OnLoadFinished()
/// <summary>
/// Allows subclasses to customise the <see cref="DelayedLoadWrapper"/>.
/// </summary>
[NotNull]
protected virtual DelayedLoadWrapper CreateDelayedLoadWrapper([NotNull] Func<Drawable> createContentFunc, double timeBeforeLoad) =>
protected virtual DelayedLoadWrapper CreateDelayedLoadWrapper(Func<Drawable> createContentFunc, double timeBeforeLoad) =>
new DelayedLoadWrapper(createContentFunc(), timeBeforeLoad);

/// <summary>
/// Creates a custom <see cref="Drawable"/> to display a model.
/// </summary>
/// <param name="model">The model that the <see cref="Drawable"/> should represent.</param>
/// <returns>A <see cref="Drawable"/> that represents <paramref name="model"/>, or null if no <see cref="Drawable"/> should be displayed.</returns>
[CanBeNull]
protected abstract Drawable CreateDrawable([CanBeNull] T model);
protected abstract Drawable? CreateDrawable(T? model);

/// <summary>
/// Hides a drawable.
/// </summary>
/// <param name="drawable">The drawable that is to be hidden.</param>
/// <returns>The transform sequence.</returns>
protected virtual TransformSequence<Drawable> ApplyHideTransforms([CanBeNull] Drawable drawable)
=> drawable?.FadeOut(TransformDuration, Easing.OutQuint);
protected virtual TransformSequence<Drawable> ApplyHideTransforms(Drawable drawable)
=> drawable.FadeOut(TransformDuration, Easing.OutQuint);

/// <summary>
/// Shows a drawable.
/// </summary>
/// <param name="drawable">The drawable that is to be shown.</param>
/// <returns>The transform sequence.</returns>
protected virtual TransformSequence<Drawable> ApplyShowTransforms([CanBeNull] Drawable drawable)
=> drawable?.FadeIn(TransformDuration, Easing.OutQuint);
protected virtual TransformSequence<Drawable> ApplyShowTransforms(Drawable drawable)
=> drawable.FadeIn(TransformDuration, Easing.OutQuint);
}
}

0 comments on commit e301001

Please sign in to comment.