diff --git a/iterator.go b/iterator.go index 2e1e564..435adeb 100644 --- a/iterator.go +++ b/iterator.go @@ -63,28 +63,38 @@ func NewIter(ipStr string) (it *Iter, startIp net.IP, err error) { } else if strings.Contains(ipStr, ":") { it.isIpv6 = true // 填充缩写 - if strings.Count(ipStr, "::") == 1 { - // :: 扩展 - buf := ":0000" - for i := strings.Count(ipStr, ":"); i < 7; i++ { - buf += ":0000" - } - ipStr = strings.Replace(ipStr, "::", buf+":", 1) - // 补零 - ipv6C := strings.Split(ipStr, ":") - for i, v := range ipv6C { - ipv6D := strings.Split(v, "-") - for i2, v2 := range ipv6D { - for len(v2) < 4 { - v2 = "0" + v2 - } - ipv6D[i2] = v2 + // :: 扩展 + var fill = func(ipStr string) string { + buf := strings.Builder{} + for i := strings.Count(ipStr, ":"); i < 8; i++ { + buf.WriteString(":0000") + } + ipStr = strings.Replace(ipStr, "::", buf.String()+":", 1) + buf.Reset() + return ipStr + } + if strings.Count(ipStr, "::") == 1 { + ipStr = fill(ipStr) + } else if strings.Count(ipStr, "::") == 2 && strings.Count(ipStr, "-") == 1 { + iL := strings.Split(ipStr, "-") + iL[0] = fill(iL[0]) + iL[1] = fill(iL[1]) + ipStr = strings.Join(iL, "-") + } + // 补零 + ipv6C := strings.Split(ipStr, ":") + for i, v := range ipv6C { + ipv6D := strings.Split(v, "-") + for i2, v2 := range ipv6D { + for len(v2) < 4 { + v2 = "0" + v2 } - ipv6C[i] = strings.Join(ipv6D, "-") + ipv6D[i2] = v2 } - ipStr = strings.Join(ipv6C, ":") + ipv6C[i] = strings.Join(ipv6D, "-") } + ipStr = strings.Join(ipv6C, ":") } it.ipStr = ipStr if !it.isIpv4 && !it.isIpv6 { diff --git a/iterator_test.go b/iterator_test.go index 57eb080..a579c38 100644 --- a/iterator_test.go +++ b/iterator_test.go @@ -1,14 +1,31 @@ package iprange import ( + "fmt" "net" "testing" ) func TestName(t *testing.T) { - for _, v := range []string{"1.1.1.1", "1.1.1.2/30", "1.1.1.0-255", "1.1-2.0-1.4", "1.1.1.1-1.1.2.1", "2001::59:63", "2001::59:63/126", "2001::59:63-f2", "2001::59-60:63-f2", "2001::59:63-2001::59:f2"} { - t.Logf("Test %s", v) - it, startIp, err := NewIter(v) + testCases := []struct { + input string + want string + }{ + {input: "1.1.1.1", want: "1.1.1.1"}, + {input: "1.1.1.2/30", want: "1.1.1.3"}, + {input: "1.1.1.0-255", want: "1.1.1.200"}, + {input: "1.1-2.0-1.4", want: "1.2.0.4"}, + {input: "1.1.1.1-1.1.2.1", want: "1.1.2.0"}, + {input: "2001::59:63", want: "2001::59:63"}, + {input: "2001::59:63/126", want: "2001::59:62"}, + {input: "2001::59:63-f2", want: "2001::59:f0"}, + {input: "2001::59-60:63-f2", want: "2001::60:f0"}, + {input: "2001::59:63-2001::59:f2", want: "2001::59:f0"}, + } + + for _, v := range testCases { + t.Logf("Test %s", v.input) + it, startIp, err := NewIter(v.input) if err != nil { t.Fatal(err) } @@ -30,17 +47,18 @@ func TestName(t *testing.T) { // 迭代 it.GetIpByIndex(0) // rest index - for itn := startIp; it.HasNext(); itn = it.Next() { + i := 0 + for itn := startIp; it.HasNext() && i <= 3; itn = it.Next() { t.Log(itn) + i++ } // 包含判断 - t.Log("Contains 1.1.1.0?", it.Contains(net.ParseIP("1.1.1.0"))) - t.Log("Contains 1.1.1.1?", it.Contains(net.ParseIP("1.1.1.1"))) - t.Log("Contains 1.1.1.3?", it.Contains(net.ParseIP("1.1.1.3"))) - t.Log("Contains 2001::59:63?", it.Contains(net.ParseIP("2001::59:63"))) - t.Log("Contains 2001::59:f2?", it.Contains(net.ParseIP("2001::59:f2"))) - t.Log("Contains 2001::59:f3?", it.Contains(net.ParseIP("2001::59:f3"))) + if !it.Contains(net.ParseIP(v.want)) { + t.Error(fmt.Sprintf("[ERR] %s Contains %s?", v.input, v.want), it.Contains(net.ParseIP(v.want))) + } else { + t.Log(fmt.Sprintf("%s Contains %s?", v.input, v.want), it.Contains(net.ParseIP(v.want))) + } } // 简单的获取IP序列