diff --git a/component/fakeip/pool.go b/component/fakeip/pool.go index e2c1072254..12c063324b 100644 --- a/component/fakeip/pool.go +++ b/component/fakeip/pool.go @@ -36,6 +36,7 @@ type Pool struct { cycle bool mux sync.Mutex host []C.DomainMatcher + mode C.FilterMode ipnet netip.Prefix store store } @@ -66,6 +67,14 @@ func (p *Pool) LookBack(ip netip.Addr) (string, bool) { // ShouldSkipped return if domain should be skipped func (p *Pool) ShouldSkipped(domain string) bool { + should := p.shouldSkipped(domain) + if p.mode == C.FilterWhiteList { + return !should + } + return should +} + +func (p *Pool) shouldSkipped(domain string) bool { for _, matcher := range p.host { if matcher.MatchDomain(domain) { return true @@ -157,6 +166,7 @@ func (p *Pool) restoreState() { type Options struct { IPNet netip.Prefix Host []C.DomainMatcher + Mode C.FilterMode // Size sets the maximum number of entries in memory // and does not work if Persistence is true @@ -187,6 +197,7 @@ func New(options Options) (*Pool, error) { offset: first.Prev(), cycle: false, host: options.Host, + mode: options.Mode, ipnet: options.IPNet, } if options.Persistence { diff --git a/component/fakeip/pool_test.go b/component/fakeip/pool_test.go index 1d4fa05f0a..923cca574d 100644 --- a/component/fakeip/pool_test.go +++ b/component/fakeip/pool_test.go @@ -164,6 +164,28 @@ func TestPool_Skip(t *testing.T) { for _, pool := range pools { assert.True(t, pool.ShouldSkipped("example.com")) assert.False(t, pool.ShouldSkipped("foo.com")) + assert.False(t, pool.shouldSkipped("baz.com")) + } +} + +func TestPool_SkipWhiteList(t *testing.T) { + ipnet := netip.MustParsePrefix("192.168.0.1/29") + tree := trie.New[struct{}]() + assert.NoError(t, tree.Insert("example.com", struct{}{})) + assert.False(t, tree.IsEmpty()) + pools, tempfile, err := createPools(Options{ + IPNet: ipnet, + Size: 10, + Host: []C.DomainMatcher{tree.NewDomainSet()}, + Mode: C.FilterWhiteList, + }) + assert.Nil(t, err) + defer os.Remove(tempfile) + + for _, pool := range pools { + assert.False(t, pool.ShouldSkipped("example.com")) + assert.True(t, pool.ShouldSkipped("foo.com")) + assert.True(t, pool.ShouldSkipped("baz.com")) } } diff --git a/config/config.go b/config/config.go index c250d3ec21..ed30bfe452 100644 --- a/config/config.go +++ b/config/config.go @@ -205,6 +205,7 @@ type RawDNS struct { EnhancedMode C.DNSMode `yaml:"enhanced-mode" json:"enhanced-mode"` FakeIPRange string `yaml:"fake-ip-range" json:"fake-ip-range"` FakeIPFilter []string `yaml:"fake-ip-filter" json:"fake-ip-filter"` + FakeIPFilterMode C.FilterMode `yaml:"fake-ip-filter-mode" json:"fake-ip-filter-mode"` DefaultNameserver []string `yaml:"default-nameserver" json:"default-nameserver"` CacheAlgorithm string `yaml:"cache-algorithm" json:"cache-algorithm"` NameServerPolicy *orderedmap.OrderedMap[string, any] `yaml:"nameserver-policy" json:"nameserver-policy"` @@ -474,6 +475,7 @@ func DefaultRawConfig() *RawConfig { "www.msftnsci.com", "www.msftconnecttest.com", }, + FakeIPFilterMode: C.FilterBlackList, }, NTP: RawNTP{ Enable: false, @@ -1458,6 +1460,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rul IPNet: fakeIPRange, Size: 1000, Host: host, + Mode: cfg.FakeIPFilterMode, Persistence: rawCfg.Profile.StoreFakeIP, }) if err != nil { diff --git a/constant/dns.go b/constant/dns.go index 3d97d97b71..8d038a6bbb 100644 --- a/constant/dns.go +++ b/constant/dns.go @@ -43,7 +43,9 @@ func (e DNSMode) MarshalYAML() (any, error) { // UnmarshalJSON unserialize EnhancedMode with json func (e *DNSMode) UnmarshalJSON(data []byte) error { var tp string - json.Unmarshal(data, &tp) + if err := json.Unmarshal(data, &tp); err != nil { + return err + } mode, exist := DNSModeMapping[tp] if !exist { return errors.New("invalid mode") @@ -115,6 +117,64 @@ func NewDNSPrefer(prefer string) DNSPrefer { } } +// FilterModeMapping is a mapping for FilterMode enum +var FilterModeMapping = map[string]FilterMode{ + FilterBlackList.String(): FilterBlackList, + FilterWhiteList.String(): FilterWhiteList, +} + +type FilterMode int + +const ( + FilterBlackList FilterMode = iota + FilterWhiteList +) + +func (e FilterMode) String() string { + switch e { + case FilterBlackList: + return "blacklist" + case FilterWhiteList: + return "whitelist" + default: + return "unknown" + } +} + +func (e FilterMode) MarshalYAML() (interface{}, error) { + return e.String(), nil +} + +func (e *FilterMode) UnmarshalYAML(unmarshal func(interface{}) error) error { + var tp string + if err := unmarshal(&tp); err != nil { + return err + } + mode, exist := FilterModeMapping[tp] + if !exist { + return errors.New("invalid mode") + } + *e = mode + return nil +} + +func (e FilterMode) MarshalJSON() ([]byte, error) { + return json.Marshal(e.String()) +} + +func (e *FilterMode) UnmarshalJSON(data []byte) error { + var tp string + if err := json.Unmarshal(data, &tp); err != nil { + return err + } + mode, exist := FilterModeMapping[tp] + if !exist { + return errors.New("invalid mode") + } + *e = mode + return nil +} + type HTTPVersion string const ( diff --git a/docs/config.yaml b/docs/config.yaml index bb60b28649..1da37841cb 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -249,6 +249,9 @@ dns: - rule-set:fakeip-filter # fakeip-filter 为 geosite 中名为 fakeip-filter 的分类(需要自行保证该分类存在) - geosite:fakeip-filter + # 配置fake-ip-filter的匹配模式,默认为blacklist,即如果匹配成功不返回fake-ip + # 可设置为whitelist,即只有匹配成功才返回fake-ip + fake-ip-filter-mode: blacklist # use-hosts: true # 查询 hosts