Skip to content

Commit

Permalink
Added colorscale to axes.plot() (ManimCommunity#3148)
Browse files Browse the repository at this point in the history
* add colorscale to plot

* Update manim/mobject/graphing/coordinate_systems.py

Co-authored-by: Benjamin Hackl <[email protected]>

* updated typing and moved one line

* added test

* fix input_to_graph_point error

* Performance improvement by using cairo color drawing

* Add OpenGL support

* Add OpenGL tests and split test for x and y axis for more behavior coverage

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Updated gradient_line tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Benjamin Hackl <[email protected]>
Co-authored-by: MrDiver <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Francisco Manríquez Novoa <[email protected]>
Co-authored-by: chopan <[email protected]>
  • Loading branch information
6 people authored Jul 25, 2024
1 parent 3a4ab4c commit 20d0194
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 0 deletions.
59 changes: 59 additions & 0 deletions manim/mobject/graphing/coordinate_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
ManimColor,
ParsableManimColor,
color_gradient,
interpolate_color,
invert_color,
)
from manim.utils.config_ops import merge_dicts_recursively, update_dict_recursively
Expand Down Expand Up @@ -628,6 +629,8 @@ def plot(
function: Callable[[float], float],
x_range: Sequence[float] | None = None,
use_vectorized: bool = False,
colorscale: Union[Iterable[Color], Iterable[Color, float]] | None = None,
colorscale_axis: int = 1,
**kwargs: Any,
) -> ParametricFunction:
"""Generates a curve based on a function.
Expand All @@ -641,6 +644,12 @@ def plot(
use_vectorized
Whether to pass in the generated t value array to the function. Only use this if your function supports it.
Output should be a numpy array of shape ``[y_0, y_1, ...]``
colorscale
Colors of the function. Optional parameter used when coloring a function by values. Passing a list of colors
and a colorscale_axis will color the function by y-value. Passing a list of tuples in the form ``(color, pivot)``
allows user-defined pivots where the color transitions.
colorscale_axis
Defines the axis on which the colorscale is applied (0 = x, 1 = y), default is y-axis (1).
kwargs
Additional parameters to be passed to :class:`~.ParametricFunction`.
Expand Down Expand Up @@ -719,7 +728,57 @@ def log_func(x):
use_vectorized=use_vectorized,
**kwargs,
)

graph.underlying_function = function

if colorscale:
if type(colorscale[0]) in (list, tuple):
new_colors, pivots = [
[i for i, j in colorscale],
[j for i, j in colorscale],
]
else:
new_colors = colorscale

ranges = [self.x_range, self.y_range]
pivot_min = ranges[colorscale_axis][0]
pivot_max = ranges[colorscale_axis][1]
pivot_frequency = (pivot_max - pivot_min) / (len(new_colors) - 1)
pivots = np.arange(
start=pivot_min,
stop=pivot_max + pivot_frequency,
step=pivot_frequency,
)

resolution = 0.01 if len(x_range) == 2 else x_range[2]
sample_points = np.arange(x_range[0], x_range[1] + resolution, resolution)
color_list = []
for samp_x in sample_points:
axis_value = (samp_x, function(samp_x))[colorscale_axis]
if axis_value <= pivots[0]:
color_list.append(new_colors[0])
elif axis_value >= pivots[-1]:
color_list.append(new_colors[-1])
else:
for i, pivot in enumerate(pivots):
if pivot > axis_value:
color_index = (axis_value - pivots[i - 1]) / (
pivots[i] - pivots[i - 1]
)
color_index = min(color_index, 1)
mob_color = interpolate_color(
new_colors[i - 1],
new_colors[i],
color_index,
)
color_list.append(mob_color)
break
if config.renderer == RendererType.OPENGL:
graph.set_color(color_list)
else:
graph.set_stroke(color_list)
graph.set_sheen_direction(RIGHT)

return graph

def plot_implicit_curve(
Expand Down
Binary file not shown.
Binary file not shown.
34 changes: 34 additions & 0 deletions tests/opengl/test_coordinate_system_opengl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
tempconfig,
)
from manim import CoordinateSystem as CS
from manim.utils.color import BLUE, GREEN, ORANGE, RED, YELLOW
from manim.utils.testing.frames_comparison import frames_comparison

__module_test__ = "coordinate_system_opengl"


def test_initial_config(using_opengl_renderer):
Expand Down Expand Up @@ -138,3 +142,33 @@ def test_input_to_graph_point(using_opengl_renderer):
# test the line_graph implementation
position = np.around(ax.input_to_graph_point(x=PI, graph=line_graph), decimals=4)
np.testing.assert_array_equal(position, (2.6928, 1.2876, 0))


@frames_comparison
def test_gradient_line_graph_x_axis(scene, using_opengl_renderer):
"""Test that using `colorscale` generates a line whose gradient matches the y-axis"""
axes = Axes(x_range=[-3, 3], y_range=[-3, 3])

curve = axes.plot(
lambda x: 0.1 * x**3,
x_range=(-3, 3, 0.001),
colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED],
colorscale_axis=0,
)

scene.add(axes, curve)


@frames_comparison
def test_gradient_line_graph_y_axis(scene, using_opengl_renderer):
"""Test that using `colorscale` generates a line whose gradient matches the y-axis"""
axes = Axes(x_range=[-3, 3], y_range=[-3, 3])

curve = axes.plot(
lambda x: 0.1 * x**3,
x_range=(-3, 3, 0.001),
colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED],
colorscale_axis=1,
)

scene.add(axes, curve)
Binary file not shown.
Binary file not shown.
30 changes: 30 additions & 0 deletions tests/test_graphical_units/test_coordinate_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,33 @@ def test_number_plane_log(scene):
)

scene.add(VGroup(plane1, plane2).arrange())


@frames_comparison
def test_gradient_line_graph_x_axis(scene):
"""Test that using `colorscale` generates a line whose gradient matches the y-axis"""
axes = Axes(x_range=[-3, 3], y_range=[-3, 3])

curve = axes.plot(
lambda x: 0.1 * x**3,
x_range=(-3, 3, 0.001),
colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED],
colorscale_axis=0,
)

scene.add(axes, curve)


@frames_comparison
def test_gradient_line_graph_y_axis(scene):
"""Test that using `colorscale` generates a line whose gradient matches the y-axis"""
axes = Axes(x_range=[-3, 3], y_range=[-3, 3])

curve = axes.plot(
lambda x: 0.1 * x**3,
x_range=(-3, 3, 0.001),
colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED],
colorscale_axis=1,
)

scene.add(axes, curve)

0 comments on commit 20d0194

Please sign in to comment.