diff --git a/gnmi_server/client_subscribe.go b/gnmi_server/client_subscribe.go index 8c36642d..bf791e79 100644 --- a/gnmi_server/client_subscribe.go +++ b/gnmi_server/client_subscribe.go @@ -161,7 +161,7 @@ func (c *Client) Run(stream gnmipb.GNMI_SubscribeServer) (err error) { if origin == "openconfig" { dc, err = sdc.NewTranslClient(prefix, paths, ctx, extensions, sdc.TranslWildcardOption{}) } else if IsNativeOrigin(origin) { - dc, err = sdc.NewMixedDbClient(paths, prefix, origin, gnmipb.Encoding_JSON_IETF, "") + dc, err = sdc.NewMixedDbClient(paths, prefix, origin, gnmipb.Encoding_JSON_IETF, "", "") } else if len(origin) != 0 { return grpc.Errorf(codes.Unimplemented, "Unsupported origin: %s", origin) } else if target == "" { diff --git a/gnmi_server/server.go b/gnmi_server/server.go index be490f54..f3ec24ce 100644 --- a/gnmi_server/server.go +++ b/gnmi_server/server.go @@ -85,6 +85,7 @@ type Config struct { ZmqPort string IdleConnDuration int ConfigTableName string + Vrf string } var AuthLock sync.Mutex @@ -409,7 +410,7 @@ func (s *Server) Get(ctx context.Context, req *gnmipb.GetRequest) (*gnmipb.GetRe } } if check := IsNativeOrigin(origin); check { - dc, err = sdc.NewMixedDbClient(paths, prefix, origin, encoding, s.config.ZmqPort) + dc, err = sdc.NewMixedDbClient(paths, prefix, origin, encoding, s.config.ZmqPort, s.config.Vrf) } else { dc, err = sdc.NewTranslClient(prefix, paths, ctx, extensions) } @@ -507,7 +508,7 @@ func (s *Server) Set(ctx context.Context, req *gnmipb.SetRequest) (*gnmipb.SetRe common_utils.IncCounter(common_utils.GNMI_SET_FAIL) return nil, grpc.Errorf(codes.Unimplemented, "GNMI native write is disabled") } - dc, err = sdc.NewMixedDbClient(paths, prefix, origin, encoding, s.config.ZmqPort) + dc, err = sdc.NewMixedDbClient(paths, prefix, origin, encoding, s.config.ZmqPort, s.config.Vrf) } else { if s.config.EnableTranslibWrite == false { common_utils.IncCounter(common_utils.GNMI_SET_FAIL) @@ -584,7 +585,7 @@ func (s *Server) Capabilities(ctx context.Context, req *gnmipb.CapabilityRequest var supportedModels []gnmipb.ModelData dc, _ := sdc.NewTranslClient(nil, nil, ctx, extensions) supportedModels = append(supportedModels, dc.Capabilities()...) - dc, _ = sdc.NewMixedDbClient(nil, nil, "", gnmipb.Encoding_JSON_IETF, s.config.ZmqPort) + dc, _ = sdc.NewMixedDbClient(nil, nil, "", gnmipb.Encoding_JSON_IETF, s.config.ZmqPort, s.config.Vrf) supportedModels = append(supportedModels, dc.Capabilities()...) suppModels := make([]*gnmipb.ModelData, len(supportedModels)) diff --git a/sonic_data_client/client_test.go b/sonic_data_client/client_test.go index dcb4d74f..55cadaf4 100644 --- a/sonic_data_client/client_test.go +++ b/sonic_data_client/client_test.go @@ -793,17 +793,17 @@ func TestGetZmqClient(t *testing.T) { dpusTable.Hset("dpu0", "midplane_interface", "dpu0") dhcpPortTable.Hset("bridge-midplane|dpu0", "ips@", "127.0.0.2,127.0.0.1") - client, err := getZmqClient("dpu0", "") + client, err := getZmqClient("dpu0", "", "") if client != nil || err != nil { t.Errorf("empty ZMQ port should not get ZMQ client") } - client, err = getZmqClient("dpu0", "1234") + client, err = getZmqClient("dpu0", "1234", "") if client == nil { t.Errorf("get ZMQ client failed") } - client, err = getZmqClient("", "1234") + client, err = getZmqClient("", "1234", "") if client == nil { t.Errorf("get ZMQ client failed") } diff --git a/sonic_data_client/mixed_db_client.go b/sonic_data_client/mixed_db_client.go index ee9c8b07..4fe1bb22 100644 --- a/sonic_data_client/mixed_db_client.go +++ b/sonic_data_client/mixed_db_client.go @@ -159,10 +159,10 @@ func getZmqAddress(container string, zmqPort string) (string, error) { var zmqClientMap = map[string]swsscommon.ZmqClient{} -func getZmqClientByAddress(zmqAddress string) (swsscommon.ZmqClient, error) { +func getZmqClientByAddress(zmqAddress string, vrf string) (swsscommon.ZmqClient, error) { client, ok := zmqClientMap[zmqAddress] if !ok { - client = swsscommon.NewZmqClient(zmqAddress) + client = swsscommon.NewZmqClient(zmqAddress, vrf) zmqClientMap[zmqAddress] = client } @@ -181,7 +181,7 @@ func removeZmqClient(zmqClient swsscommon.ZmqClient) (error) { return fmt.Errorf("Can't find ZMQ client in zmqClientMap: %v", zmqClient) } -func getZmqClient(dpuId string, zmqPort string) (swsscommon.ZmqClient, error) { +func getZmqClient(dpuId string, zmqPort string, vrf string) (swsscommon.ZmqClient, error) { if zmqPort == "" { // ZMQ feature disabled when zmqPort flag not set return nil, nil @@ -189,7 +189,7 @@ func getZmqClient(dpuId string, zmqPort string) (swsscommon.ZmqClient, error) { if dpuId == sdcfg.SONIC_DEFAULT_CONTAINER { // When DPU ID is default, create ZMQ with local address - return getZmqClientByAddress("tcp://" + LOCAL_ADDRESS + ":" + zmqPort) + return getZmqClientByAddress("tcp://" + LOCAL_ADDRESS + ":" + zmqPort, vrf) } zmqAddress, err := getZmqAddress(dpuId, zmqPort) @@ -197,7 +197,7 @@ func getZmqClient(dpuId string, zmqPort string) (swsscommon.ZmqClient, error) { return nil, fmt.Errorf("Get ZMQ address failed: %v", err) } - return getZmqClientByAddress(zmqAddress) + return getZmqClientByAddress(zmqAddress, vrf) } // This function get target present in GNMI Request and @@ -493,7 +493,7 @@ func init() { initRedisDbMap() } -func NewMixedDbClient(paths []*gnmipb.Path, prefix *gnmipb.Path, origin string, encoding gnmipb.Encoding, zmqPort string) (Client, error) { +func NewMixedDbClient(paths []*gnmipb.Path, prefix *gnmipb.Path, origin string, encoding gnmipb.Encoding, zmqPort string, vrf string) (Client, error) { var err error // Initialize RedisDbMap for test @@ -556,7 +556,7 @@ func NewMixedDbClient(paths []*gnmipb.Path, prefix *gnmipb.Path, origin string, client.workPath = common_utils.GNMI_WORK_PATH // continer is DPU ID - client.zmqClient, err = getZmqClient(container, zmqPort) + client.zmqClient, err = getZmqClient(container, zmqPort, vrf) if err != nil { return nil, fmt.Errorf("Get ZMQ client failed: %v", err) } diff --git a/telemetry/telemetry.go b/telemetry/telemetry.go index 466cc33d..cb56e10c 100644 --- a/telemetry/telemetry.go +++ b/telemetry/telemetry.go @@ -57,6 +57,7 @@ type TelemetryConfig struct { WithMasterArbitration *bool WithSaveOnSet *bool IdleConnDuration *int + Vrf *string } func main() { @@ -165,6 +166,7 @@ func setupFlags(fs *flag.FlagSet) (*TelemetryConfig, *gnmi.Config, error) { WithMasterArbitration: fs.Bool("with-master-arbitration", false, "Enables master arbitration policy."), WithSaveOnSet: fs.Bool("with-save-on-set", false, "Enables save-on-set."), IdleConnDuration: fs.Int("idle_conn_duration", 5, "Seconds before server closes idle connections"), + Vrf: fs.String("vrf", "", "VRF name, when zmq_address belong on a VRF, need VRF name to bind ZMQ."), } fs.Var(&telemetryCfg.UserAuth, "client_auth", "Client auth mode(s) - none,cert,password") @@ -227,6 +229,7 @@ func setupFlags(fs *flag.FlagSet) (*TelemetryConfig, *gnmi.Config, error) { cfg.Threshold = int(*telemetryCfg.Threshold) cfg.IdleConnDuration = int(*telemetryCfg.IdleConnDuration) cfg.ConfigTableName = *telemetryCfg.ConfigTableName + cfg.Vrf = *telemetryCfg.Vrf // TODO: After other dependent projects are migrated to ZmqPort, remove ZmqAddress zmqAddress := *telemetryCfg.ZmqAddress