diff --git a/get.go b/get.go index c8527a2a..3d1bd3d2 100644 --- a/get.go +++ b/get.go @@ -411,7 +411,7 @@ func (s Style) GetFrameSize() (x, y int) { // GetTransform returns the transform set on the style. If no transform is set // nil is returned. -func (s Style) GetTransform() func(string) string { +func (s Style) GetTransform() Transform { return s.getAsTransform(transformKey) } @@ -520,7 +520,7 @@ func (s Style) getBorderStyle() Border { return s.borderStyle } -func (s Style) getAsTransform(propKey) func(string) string { +func (s Style) getAsTransform(propKey) Transform { if !s.isSet(transformKey) { return nil } diff --git a/list/list_test.go b/list/list_test.go index 21f8f438..f10eca67 100644 --- a/list/list_test.go +++ b/list/list_test.go @@ -306,7 +306,7 @@ func TestEnumeratorsTransform(t *testing.T) { }{ "alphabet lower": { enumeration: list.Alphabet, - style: lipgloss.NewStyle().PaddingRight(1).Transform(strings.ToLower), + style: lipgloss.NewStyle().PaddingRight(1).Transform(lipgloss.TransformFunc(strings.ToLower)), expected: ` a. Foo b. Bar @@ -315,9 +315,9 @@ c. Baz }, "arabic)": { enumeration: list.Arabic, - style: lipgloss.NewStyle().PaddingRight(1).Transform(func(s string) string { + style: lipgloss.NewStyle().PaddingRight(1).Transform(lipgloss.TransformFunc(func(s string) string { return strings.Replace(s, ".", ")", 1) - }), + })), expected: ` 1) Foo 2) Bar @@ -326,9 +326,9 @@ c. Baz }, "roman within ()": { enumeration: list.Roman, - style: lipgloss.NewStyle().Transform(func(s string) string { + style: lipgloss.NewStyle().Transform(lipgloss.TransformFunc(func(s string) string { return "(" + strings.Replace(strings.ToLower(s), ".", "", 1) + ") " - }), + })), expected: ` (i) Foo (ii) Bar @@ -337,9 +337,9 @@ c. Baz }, "bullet is dash": { enumeration: list.Bullet, - style: lipgloss.NewStyle().Transform(func(s string) string { + style: lipgloss.NewStyle().Transform(lipgloss.TransformFunc(func(s string) string { return "- " // this is better done by replacing the enumerator. - }), + })), expected: ` - Foo - Bar diff --git a/set.go b/set.go index 1e294354..6afa5d16 100644 --- a/set.go +++ b/set.go @@ -1,6 +1,8 @@ package lipgloss -import "image/color" +import ( + "image/color" +) // Set a value on the underlying rules map. func (s *Style) set(key propKey, value interface{}) { @@ -66,7 +68,7 @@ func (s *Style) set(key propKey, value interface{}) { // that negative value can be no less than -1). s.tabWidth = value.(int) case transformKey: - s.transform = value.(func(string) string) + s.transform = value.(Transform) default: if v, ok := value.(bool); ok { //nolint:nestif if v { @@ -675,15 +677,31 @@ func (s Style) StrikethroughSpaces(v bool) Style { return s } +// Transform is a method for setting a function that will be applied to the +// string at render time. +type Transform interface { + Transform(string) string +} + +// TransformFunc is a function that can be used to transform a string at render +// time. +type TransformFunc func(string) string + +// Transform applies a given function to a string at render time, allowing for +// the string being rendered to be manipuated. +func (t TransformFunc) Transform(s string) string { + return t(s) +} + // Transform applies a given function to a string at render time, allowing for // the string being rendered to be manipuated. // // Example: // -// s := NewStyle().Transform(strings.ToUpper) +// s := NewStyle().Transform(lipgloss.TransformFunc(strings.ToUpper)) // fmt.Println(s.Render("raow!") // "RAOW!" -func (s Style) Transform(fn func(string) string) Style { - s.set(transformKey, fn) +func (s Style) Transform(t Transform) Style { + s.set(transformKey, t) return s } diff --git a/style.go b/style.go index b4d4627e..74a9be92 100644 --- a/style.go +++ b/style.go @@ -145,7 +145,7 @@ type Style struct { maxHeight int tabWidth int - transform func(string) string + transform Transform } // joinString joins a list of strings into a single string separated with a @@ -271,7 +271,7 @@ func (s Style) Render(strs ...string) string { ) if transform != nil { - str = transform(str) + str = transform.Transform(str) } if s.props == 0 { diff --git a/style_test.go b/style_test.go index c660b907..7502f090 100644 --- a/style_test.go +++ b/style_test.go @@ -456,7 +456,7 @@ func TestStringTransform(t *testing.T) { "犬 yzal eht revo depmuj 狐 nworb kciuq ehT", }, } { - res := NewStyle().Bold(true).Transform(tc.fn).Render(tc.input) + res := NewStyle().Bold(true).Transform(TransformFunc(tc.fn)).Render(tc.input) expected := "\x1b[1m" + tc.expected + "\x1b[m" if res != expected { t.Errorf("Test #%d:\nExpected: %q\nGot: %q", i+1, expected, res)