diff --git a/manim/mobject/graphing/coordinate_systems.py b/manim/mobject/graphing/coordinate_systems.py index 435d7aced8..5730862118 100644 --- a/manim/mobject/graphing/coordinate_systems.py +++ b/manim/mobject/graphing/coordinate_systems.py @@ -48,6 +48,7 @@ ManimColor, ParsableManimColor, color_gradient, + interpolate_color, invert_color, ) from manim.utils.config_ops import merge_dicts_recursively, update_dict_recursively @@ -628,6 +629,8 @@ def plot( function: Callable[[float], float], x_range: Sequence[float] | None = None, use_vectorized: bool = False, + colorscale: Union[Iterable[Color], Iterable[Color, float]] | None = None, + colorscale_axis: int = 1, **kwargs: Any, ) -> ParametricFunction: """Generates a curve based on a function. @@ -641,6 +644,12 @@ def plot( use_vectorized Whether to pass in the generated t value array to the function. Only use this if your function supports it. Output should be a numpy array of shape ``[y_0, y_1, ...]`` + colorscale + Colors of the function. Optional parameter used when coloring a function by values. Passing a list of colors + and a colorscale_axis will color the function by y-value. Passing a list of tuples in the form ``(color, pivot)`` + allows user-defined pivots where the color transitions. + colorscale_axis + Defines the axis on which the colorscale is applied (0 = x, 1 = y), default is y-axis (1). kwargs Additional parameters to be passed to :class:`~.ParametricFunction`. @@ -719,7 +728,57 @@ def log_func(x): use_vectorized=use_vectorized, **kwargs, ) + graph.underlying_function = function + + if colorscale: + if type(colorscale[0]) in (list, tuple): + new_colors, pivots = [ + [i for i, j in colorscale], + [j for i, j in colorscale], + ] + else: + new_colors = colorscale + + ranges = [self.x_range, self.y_range] + pivot_min = ranges[colorscale_axis][0] + pivot_max = ranges[colorscale_axis][1] + pivot_frequency = (pivot_max - pivot_min) / (len(new_colors) - 1) + pivots = np.arange( + start=pivot_min, + stop=pivot_max + pivot_frequency, + step=pivot_frequency, + ) + + resolution = 0.01 if len(x_range) == 2 else x_range[2] + sample_points = np.arange(x_range[0], x_range[1] + resolution, resolution) + color_list = [] + for samp_x in sample_points: + axis_value = (samp_x, function(samp_x))[colorscale_axis] + if axis_value <= pivots[0]: + color_list.append(new_colors[0]) + elif axis_value >= pivots[-1]: + color_list.append(new_colors[-1]) + else: + for i, pivot in enumerate(pivots): + if pivot > axis_value: + color_index = (axis_value - pivots[i - 1]) / ( + pivots[i] - pivots[i - 1] + ) + color_index = min(color_index, 1) + mob_color = interpolate_color( + new_colors[i - 1], + new_colors[i], + color_index, + ) + color_list.append(mob_color) + break + if config.renderer == RendererType.OPENGL: + graph.set_color(color_list) + else: + graph.set_stroke(color_list) + graph.set_sheen_direction(RIGHT) + return graph def plot_implicit_curve( diff --git a/tests/opengl/control_data/coordinate_system_opengl/gradient_line_graph_x_axis_using_opengl_renderer[None].npz b/tests/opengl/control_data/coordinate_system_opengl/gradient_line_graph_x_axis_using_opengl_renderer[None].npz new file mode 100644 index 0000000000..6efda8b2e0 Binary files /dev/null and b/tests/opengl/control_data/coordinate_system_opengl/gradient_line_graph_x_axis_using_opengl_renderer[None].npz differ diff --git a/tests/opengl/control_data/coordinate_system_opengl/gradient_line_graph_y_axis_using_opengl_renderer[None].npz b/tests/opengl/control_data/coordinate_system_opengl/gradient_line_graph_y_axis_using_opengl_renderer[None].npz new file mode 100644 index 0000000000..3c7afcda7d Binary files /dev/null and b/tests/opengl/control_data/coordinate_system_opengl/gradient_line_graph_y_axis_using_opengl_renderer[None].npz differ diff --git a/tests/opengl/test_coordinate_system_opengl.py b/tests/opengl/test_coordinate_system_opengl.py index b63a431234..c9f5cc81b6 100644 --- a/tests/opengl/test_coordinate_system_opengl.py +++ b/tests/opengl/test_coordinate_system_opengl.py @@ -20,6 +20,10 @@ tempconfig, ) from manim import CoordinateSystem as CS +from manim.utils.color import BLUE, GREEN, ORANGE, RED, YELLOW +from manim.utils.testing.frames_comparison import frames_comparison + +__module_test__ = "coordinate_system_opengl" def test_initial_config(using_opengl_renderer): @@ -138,3 +142,33 @@ def test_input_to_graph_point(using_opengl_renderer): # test the line_graph implementation position = np.around(ax.input_to_graph_point(x=PI, graph=line_graph), decimals=4) np.testing.assert_array_equal(position, (2.6928, 1.2876, 0)) + + +@frames_comparison +def test_gradient_line_graph_x_axis(scene, using_opengl_renderer): + """Test that using `colorscale` generates a line whose gradient matches the y-axis""" + axes = Axes(x_range=[-3, 3], y_range=[-3, 3]) + + curve = axes.plot( + lambda x: 0.1 * x**3, + x_range=(-3, 3, 0.001), + colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED], + colorscale_axis=0, + ) + + scene.add(axes, curve) + + +@frames_comparison +def test_gradient_line_graph_y_axis(scene, using_opengl_renderer): + """Test that using `colorscale` generates a line whose gradient matches the y-axis""" + axes = Axes(x_range=[-3, 3], y_range=[-3, 3]) + + curve = axes.plot( + lambda x: 0.1 * x**3, + x_range=(-3, 3, 0.001), + colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED], + colorscale_axis=1, + ) + + scene.add(axes, curve) diff --git a/tests/test_graphical_units/control_data/coordinate_system/gradient_line_graph_x_axis.npz b/tests/test_graphical_units/control_data/coordinate_system/gradient_line_graph_x_axis.npz new file mode 100644 index 0000000000..6efda8b2e0 Binary files /dev/null and b/tests/test_graphical_units/control_data/coordinate_system/gradient_line_graph_x_axis.npz differ diff --git a/tests/test_graphical_units/control_data/coordinate_system/gradient_line_graph_y_axis.npz b/tests/test_graphical_units/control_data/coordinate_system/gradient_line_graph_y_axis.npz new file mode 100644 index 0000000000..3c7afcda7d Binary files /dev/null and b/tests/test_graphical_units/control_data/coordinate_system/gradient_line_graph_y_axis.npz differ diff --git a/tests/test_graphical_units/test_coordinate_systems.py b/tests/test_graphical_units/test_coordinate_systems.py index d7603f4d00..7d6dad67af 100644 --- a/tests/test_graphical_units/test_coordinate_systems.py +++ b/tests/test_graphical_units/test_coordinate_systems.py @@ -141,3 +141,33 @@ def test_number_plane_log(scene): ) scene.add(VGroup(plane1, plane2).arrange()) + + +@frames_comparison +def test_gradient_line_graph_x_axis(scene): + """Test that using `colorscale` generates a line whose gradient matches the y-axis""" + axes = Axes(x_range=[-3, 3], y_range=[-3, 3]) + + curve = axes.plot( + lambda x: 0.1 * x**3, + x_range=(-3, 3, 0.001), + colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED], + colorscale_axis=0, + ) + + scene.add(axes, curve) + + +@frames_comparison +def test_gradient_line_graph_y_axis(scene): + """Test that using `colorscale` generates a line whose gradient matches the y-axis""" + axes = Axes(x_range=[-3, 3], y_range=[-3, 3]) + + curve = axes.plot( + lambda x: 0.1 * x**3, + x_range=(-3, 3, 0.001), + colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED], + colorscale_axis=1, + ) + + scene.add(axes, curve)