Skip to content

Commit

Permalink
Make Basic, Advanced Automod, and Custom Commands compatible with mes…
Browse files Browse the repository at this point in the history
…sage forwards. (#1733)

* Added Support for message-forwards

* fix stupid test case failure

---------

Co-authored-by: Ashish <[email protected]>
  • Loading branch information
ashishjh-bst and ashishjh-bst authored Oct 4, 2024
1 parent 4857016 commit 807cbed
Show file tree
Hide file tree
Showing 9 changed files with 219 additions and 181 deletions.
1 change: 0 additions & 1 deletion antiphishing/antiphishing.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,5 @@ func CheckMessageForPhishingDomains(input string) (string, error) {
if len(matches) < 1 {
return "", nil
}

return queryPhishingLinks(matches)
}
7 changes: 1 addition & 6 deletions automod/automod_bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,13 @@ func (p *Plugin) checkMessage(evt *eventsystem.EventData, msg *discordgo.Message

ms := dstate.MemberStateFromMember(msg.Member)

stripped := ""
return !p.CheckTriggers(nil, evt.GS, ms, msg, cs, func(trig *ParsedPart) (activated bool, err error) {
if stripped == "" {
stripped = PrepareMessageForWordCheck(msg.Content)
}

cast, ok := trig.Part.(MessageTrigger)
if !ok {
return
}

return cast.CheckMessage(&TriggerContext{GS: evt.GS, MS: ms, Data: trig.ParsedSettings}, cs, msg, stripped)
return cast.CheckMessage(&TriggerContext{GS: evt.GS, MS: ms, Data: trig.ParsedSettings}, cs, msg)
})
}

Expand Down
2 changes: 1 addition & 1 deletion automod/rulepart.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ type TriggerContext struct {
type MessageTrigger interface {
RulePart

CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message, mdStripped string) (isAffected bool, err error)
CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message) (isAffected bool, err error)
}

// ViolationListener is a trigger that gets triggered on a violation
Expand Down
124 changes: 67 additions & 57 deletions automod/triggers.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"github.com/botlabs-gg/yagpdb/v2/safebrowsing"
)

var forwardSlashReplacer = strings.NewReplacer("\\", "")
var SanitizeTextName = "Also match visually similar characters such as \"Ĥéĺĺó\""

/////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -97,7 +96,7 @@ func (mc *MentionsTrigger) UserSettings() []*SettingDef {
}
}

func (mc *MentionsTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message, mdStripped string) (bool, error) {
func (mc *MentionsTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message) (bool, error) {
dataCast := triggerCtx.Data.(*MentionsTriggerData)
if len(m.Mentions) >= dataCast.Treshold {
return true, nil
Expand Down Expand Up @@ -136,12 +135,14 @@ func (alc *AnyLinkTrigger) UserSettings() []*SettingDef {
return []*SettingDef{}
}

func (alc *AnyLinkTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message, mdStripped string) (bool, error) {
if common.LinkRegex.MatchString(forwardSlashReplacer.Replace(m.Content)) {
return true, nil
func (alc *AnyLinkTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message) (bool, error) {
for _, content := range m.GetMessageContents() {
if common.LinkRegex.MatchString(common.ForwardSlashReplacer.Replace(content)) {
return true, nil
}
}

return false, nil

}

func (alc *AnyLinkTrigger) MergeDuplicates(data []interface{}) interface{} {
Expand Down Expand Up @@ -200,19 +201,20 @@ func (wl *WordListTrigger) UserSettings() []*SettingDef {
}
}

func (wl *WordListTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message, mdStripped string) (bool, error) {
func (wl *WordListTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message) (bool, error) {
dataCast := triggerCtx.Data.(*WorldListTriggerData)

list, err := FindFetchGuildList(triggerCtx.GS.ID, dataCast.ListID)
if err != nil {
return false, nil
}

messageFields := strings.Fields(mdStripped)

if dataCast.SanitizeText {
messageFieldsFixText := strings.Fields(confusables.SanitizeText(mdStripped))
messageFields = append(messageFields, messageFieldsFixText...) // Could be turned into a 1-liner, lmk if I should or not
var messageFields []string
for _, content := range m.GetMessageContents() {
content := PrepareMessageForWordCheck(content)
messageFields = append(messageFields, strings.Fields(content)...)
if dataCast.SanitizeText {
messageFields = append(messageFields, strings.Fields(confusables.SanitizeText(content))...) // Could be turned into a 1-liner, lmk if I should or not
}
}

for _, mf := range messageFields {
Expand Down Expand Up @@ -284,15 +286,19 @@ func (dt *DomainTrigger) UserSettings() []*SettingDef {
}
}

func (dt *DomainTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message, mdStripped string) (bool, error) {
func (dt *DomainTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message) (bool, error) {
dataCast := triggerCtx.Data.(*DomainTriggerData)

list, err := FindFetchGuildList(triggerCtx.GS.ID, dataCast.ListID)
if err != nil {
return false, nil
}

matches := common.LinkRegex.FindAllString(forwardSlashReplacer.Replace(m.Content), -1)
var matches []string
for _, content := range m.GetMessageContents() {
snapshotMatches := common.LinkRegex.FindAllString(common.ForwardSlashReplacer.Replace(content), -1)
matches = append(matches, snapshotMatches...)
}

for _, v := range matches {
if contains, _ := dt.containsDomain(v, list.Content); contains {
Expand Down Expand Up @@ -480,7 +486,7 @@ func (caps *AllCapsTrigger) UserSettings() []*SettingDef {
}
}

func (caps *AllCapsTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message, mdStripped string) (bool, error) {
func (caps *AllCapsTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message) (bool, error) {
dataCast := triggerCtx.Data.(*AllCapsTriggerData)

if len(m.Content) < dataCast.MinLength {
Expand Down Expand Up @@ -550,8 +556,8 @@ func (inv *ServerInviteTrigger) UserSettings() []*SettingDef {
return []*SettingDef{}
}

func (inv *ServerInviteTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message, mdStripped string) (bool, error) {
containsBadInvited := automod_legacy.CheckMessageForBadInvites(m.Content, m.GuildID)
func (inv *ServerInviteTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message) (bool, error) {
containsBadInvited := automod_legacy.CheckMessageForBadInvites(m)
return containsBadInvited, nil
}

Expand Down Expand Up @@ -585,17 +591,17 @@ func (a *AntiPhishingLinkTrigger) UserSettings() []*SettingDef {
return []*SettingDef{}
}

func (a *AntiPhishingLinkTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message, mdStripped string) (bool, error) {
badDomain, err := antiphishing.CheckMessageForPhishingDomains(forwardSlashReplacer.Replace(m.Content))
if err != nil {
logger.WithError(err).Error("Failed to check url ")
return false, nil
}

if badDomain != "" {
return true, nil
func (a *AntiPhishingLinkTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message) (bool, error) {
for _, content := range m.GetMessageContents() {
badDomain, err := antiphishing.CheckMessageForPhishingDomains(common.ForwardSlashReplacer.Replace(content))
if err != nil {
logger.WithError(err).Error("Failed to check url ")
continue
}
if badDomain != "" {
return true, nil
}
}

return false, nil
}

Expand Down Expand Up @@ -625,17 +631,17 @@ func (g *GoogleSafeBrowsingTrigger) UserSettings() []*SettingDef {
return []*SettingDef{}
}

func (g *GoogleSafeBrowsingTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message, mdStripped string) (bool, error) {
threat, err := safebrowsing.CheckString(forwardSlashReplacer.Replace(m.Content))
if err != nil {
logger.WithError(err).Error("Failed checking urls against google safebrowser")
return false, nil
}

if threat != nil {
return true, nil
func (g *GoogleSafeBrowsingTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message) (bool, error) {
for _, input := range m.GetMessageContents() {
threat, err := safebrowsing.CheckString(common.ForwardSlashReplacer.Replace(input))
if err != nil {
logger.WithError(err).Error("Failed checking urls against google safebrowser")
continue
}
if threat != nil {
return true, nil
}
}

return false, nil
}

Expand Down Expand Up @@ -756,12 +762,12 @@ func (s *SlowmodeTrigger) UserSettings() []*SettingDef {
return settings
}

func (s *SlowmodeTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message, mdStripped string) (bool, error) {
func (s *SlowmodeTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message) (bool, error) {
if s.Attachments && len(m.Attachments) < 1 {
return false, nil
}

if s.Links && !common.LinkRegex.MatchString(forwardSlashReplacer.Replace(m.Content)) {
if s.Links && !common.LinkRegex.MatchString(common.ForwardSlashReplacer.Replace(strings.Join(m.GetMessageContents(), ""))) {
return false, nil
}

Expand Down Expand Up @@ -798,7 +804,9 @@ func (s *SlowmodeTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.Ch
amount++
}
} else if s.Links {
linksLen := len(common.LinkRegex.FindAllString(forwardSlashReplacer.Replace(v.Content), -1))
contents := m.GetMessageContents()
contentString := strings.Join(contents, "")
linksLen := len(common.LinkRegex.FindAllString(common.ForwardSlashReplacer.Replace(contentString), -1))
if linksLen < 1 {
continue // we're only checking messages with links
}
Expand Down Expand Up @@ -884,7 +892,7 @@ func (mt *MultiMsgMentionTrigger) UserSettings() []*SettingDef {
}
}

func (mt *MultiMsgMentionTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message, mdStripped string) (bool, error) {
func (mt *MultiMsgMentionTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message) (bool, error) {
if len(m.Mentions) < 1 {
return false, nil
}
Expand Down Expand Up @@ -956,7 +964,7 @@ func (r *MessageRegexTrigger) Description() string {
return "Triggers when a message matches the provided regex"
}

func (r *MessageRegexTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message, mdStripped string) (bool, error) {
func (r *MessageRegexTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message) (bool, error) {
dataCast := triggerCtx.Data.(*BaseRegexTriggerData)

item, err := RegexCache.Fetch(dataCast.Regex, time.Minute*10, func() (interface{}, error) {
Expand All @@ -974,20 +982,22 @@ func (r *MessageRegexTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstat

re := item.Value().(*regexp.Regexp)

var sanitizedContent string
if dataCast.SanitizeText {
sanitizedContent = confusables.SanitizeText(m.Content)
}
for _, content := range m.GetMessageContents() {
var sanitizedContent string
if dataCast.SanitizeText {
sanitizedContent = confusables.SanitizeText(content)
}

if re.MatchString(m.Content) || (dataCast.SanitizeText && re.MatchString(sanitizedContent)) {
if r.BaseRegexTrigger.Inverse {
return false, nil
if re.MatchString(m.Content) || (dataCast.SanitizeText && re.MatchString(sanitizedContent)) {
if r.BaseRegexTrigger.Inverse {
continue
}
return true, nil
}
return true, nil
}

if r.BaseRegexTrigger.Inverse {
return true, nil
if r.BaseRegexTrigger.Inverse {
return true, nil
}
}

return false, nil
Expand Down Expand Up @@ -1055,7 +1065,7 @@ func (spam *SpamTrigger) UserSettings() []*SettingDef {
}
}

func (spam *SpamTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message, mdStripped string) (bool, error) {
func (spam *SpamTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message) (bool, error) {

settingsCast := triggerCtx.Data.(*SpamTriggerData)

Expand Down Expand Up @@ -1521,7 +1531,7 @@ func (mat *MessageAttachmentTrigger) UserSettings() []*SettingDef {
return []*SettingDef{}
}

func (mat *MessageAttachmentTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message, mdStripped string) (bool, error) {
func (mat *MessageAttachmentTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message) (bool, error) {
contains := len(m.Attachments) > 0
if contains && mat.RequiresAttachment {
return true, nil
Expand Down Expand Up @@ -1581,7 +1591,7 @@ func (ml *MessageLengthTrigger) UserSettings() []*SettingDef {
}
}

func (ml *MessageLengthTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message, mdStripped string) (bool, error) {
func (ml *MessageLengthTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.ChannelState, m *discordgo.Message) (bool, error) {
dataCast := triggerCtx.Data.(*MessageLengthTriggerData)

if ml.Inverted {
Expand Down
Loading

0 comments on commit 807cbed

Please sign in to comment.