Skip to content

Commit

Permalink
Updates to V3 Observation (#1506)
Browse files Browse the repository at this point in the history
This includes:
* Adds wrapper for calculated responses (also cleans up some of the
calculation code from v2 to make it more reusable)
* Adds merging to V3 observation
  • Loading branch information
n-h-diaz authored Jan 24, 2025
1 parent 3933206 commit bbfacfd
Show file tree
Hide file tree
Showing 13 changed files with 256 additions and 95 deletions.
2 changes: 1 addition & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ func main() {
// Processors
processors := []*dispatcher.Processor{}
if *enableV3 {
var calculationProcessor dispatcher.Processor = &observation.CalculationProcessor{}
var calculationProcessor dispatcher.Processor = observation.NewCalculationProcessor(dataSources, c.SVFormula())
processors = append(processors, &calculationProcessor)
}

Expand Down
28 changes: 11 additions & 17 deletions internal/merger/merger.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,7 @@ func MergeNode(main, aux *pbv2.NodeResponse) (*pbv2.NodeResponse, error) {

// MergeMultiNode merges multiple V2 NodeResponses into a single response.
// It assumes allResp is ordered from highest to lowest priority.
func MergeMultiNode(
allResp []*pbv2.NodeResponse,
) (*pbv2.NodeResponse, error) {
func MergeMultiNode(allResp []*pbv2.NodeResponse) (*pbv2.NodeResponse, error) {
if len(allResp) == 0 {
return &pbv2.NodeResponse{}, nil
}
Expand Down Expand Up @@ -272,22 +270,18 @@ func MergeObservation(main, aux *pbv2.ObservationResponse) *pbv2.ObservationResp
return main
}

// MergeMultipleObservations merges multiple V2 observation responses, ranked
// in order of preference.
func MergeMultipleObservations(obs ...*pbv2.ObservationResponse) *pbv2.ObservationResponse {
if obs == nil {
return nil
}
if len(obs) == 0 {
return nil
}
if len(obs) == 1 {
return obs[0]
// MergeMultiObservation merges multiple V2 ObservationResponses into a single response.
// It assumes allResp is ordered from highest to lowest priority.
func MergeMultiObservation(allResp []*pbv2.ObservationResponse) *pbv2.ObservationResponse {
if len(allResp) == 0 {
return &pbv2.ObservationResponse{}
}
if len(obs) == 2 {
return MergeObservation(obs[0], obs[1])
prev := allResp[0]
for i := 1; i < len(allResp); i++ {
cur := MergeObservation(prev, allResp[i])
prev = cur
}
return MergeObservation(obs[0], MergeMultipleObservations(obs[1:]...))
return prev
}

// MergeObservationDates merges two V1 observation-dates responses.
Expand Down
37 changes: 17 additions & 20 deletions internal/merger/merger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -672,19 +672,17 @@ func TestMergeObservation(t *testing.T) {
}
}

func TestMergeMultipleObservations(t *testing.T) {
func TestMergeMultiObservation(t *testing.T) {
cmpOpts := cmp.Options{
protocmp.Transform(),
}

for _, c := range []struct {
o1 *pbv2.ObservationResponse
o2 *pbv2.ObservationResponse
o3 *pbv2.ObservationResponse
want *pbv2.ObservationResponse
allResp []*pbv2.ObservationResponse
want *pbv2.ObservationResponse
}{
{
&pbv2.ObservationResponse{
{[]*pbv2.ObservationResponse{
{
ByVariable: map[string]*pbv2.VariableObservation{
"var1": {
ByEntity: map[string]*pbv2.EntityObservation{
Expand All @@ -705,7 +703,7 @@ func TestMergeMultipleObservations(t *testing.T) {
},
},
},
&pbv2.ObservationResponse{
{
ByVariable: map[string]*pbv2.VariableObservation{
"var1": {
ByEntity: map[string]*pbv2.EntityObservation{
Expand All @@ -726,7 +724,7 @@ func TestMergeMultipleObservations(t *testing.T) {
},
},
},
&pbv2.ObservationResponse{
{
ByVariable: map[string]*pbv2.VariableObservation{
"var1": {
ByEntity: map[string]*pbv2.EntityObservation{
Expand All @@ -746,7 +744,7 @@ func TestMergeMultipleObservations(t *testing.T) {
},
},
},
},
}},
&pbv2.ObservationResponse{
ByVariable: map[string]*pbv2.VariableObservation{
"var1": {
Expand Down Expand Up @@ -788,9 +786,9 @@ func TestMergeMultipleObservations(t *testing.T) {
},
},
} {
got := MergeMultipleObservations(c.o1, c.o2, c.o3)
got := MergeMultiObservation(c.allResp)
if diff := cmp.Diff(got, c.want, cmpOpts); diff != "" {
t.Errorf("MergeMultipleObservations(%v, %v, %v) got diff: %s", c.o1, c.o2, c.o3, diff)
t.Errorf("MergeMultiObservation(%v) got diff: %s", c.allResp, diff)
}
}
}
Expand Down Expand Up @@ -979,7 +977,6 @@ func TestMergeBulkVariableInfoResponse(t *testing.T) {
}
}


func TestMergeSearchStatVarResponse(t *testing.T) {
cmpOpts := cmp.Options{protocmp.Transform()}
for _, tc := range []struct {
Expand All @@ -997,7 +994,7 @@ func TestMergeSearchStatVarResponse(t *testing.T) {
},
{
Name: "sv2",
Dcid: "svid2",
Dcid: "svid2",
},
},
Matches: []string{"match1", "match2"},
Expand Down Expand Up @@ -1025,7 +1022,7 @@ func TestMergeSearchStatVarResponse(t *testing.T) {
},
{
Name: "sv2",
Dcid: "svid2",
Dcid: "svid2",
},
},
Matches: []string{"match1", "match2"},
Expand All @@ -1038,7 +1035,7 @@ func TestMergeSearchStatVarResponse(t *testing.T) {
},
{
Name: "sv2",
Dcid: "svid2",
Dcid: "svid2",
},
},
Matches: []string{"match1", "match2"},
Expand All @@ -1053,7 +1050,7 @@ func TestMergeSearchStatVarResponse(t *testing.T) {
},
{
Name: "sv2",
Dcid: "svid2",
Dcid: "svid2",
},
},
Matches: []string{"match1", "match2"},
Expand All @@ -1066,7 +1063,7 @@ func TestMergeSearchStatVarResponse(t *testing.T) {
},
{
Name: "sv4",
Dcid: "svid4",
Dcid: "svid4",
},
},
Matches: []string{"match1", "match3"},
Expand All @@ -1079,15 +1076,15 @@ func TestMergeSearchStatVarResponse(t *testing.T) {
},
{
Name: "sv2",
Dcid: "svid2",
Dcid: "svid2",
},
{
Name: "sv3",
Dcid: "svid3",
},
{
Name: "sv4",
Dcid: "svid4",
Dcid: "svid4",
},
},
Matches: []string{"match1", "match2", "match3"},
Expand Down
33 changes: 27 additions & 6 deletions internal/server/datasources/datasources.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func (ds *DataSources) Node(ctx context.Context, in *pbv2.NodeRequest) (*pbv2.No
for _, source := range ds.sources {
respChan := make(chan *pbv2.NodeResponse, 1)
errGroup.Go(func() error {
defer close(respChan)
resp, err := (*source).Node(errCtx, in)
if err != nil {
return err
Expand All @@ -57,20 +58,40 @@ func (ds *DataSources) Node(ctx context.Context, in *pbv2.NodeRequest) (*pbv2.No

allResp := []*pbv2.NodeResponse{}
for _, respChan := range dsRespChan {
close(respChan)
allResp = append(allResp, <-respChan)
}

return merger.MergeMultiNode(allResp)
}

func (ds *DataSources) Observation(ctx context.Context, in *pbv2.ObservationRequest) (*pbv2.ObservationResponse, error) {
if len(ds.sources) == 0 {
return nil, fmt.Errorf("no sources found")
errGroup, errCtx := errgroup.WithContext(ctx)
dsRespChan := []chan *pbv2.ObservationResponse{}

for _, source := range ds.sources {
respChan := make(chan *pbv2.ObservationResponse, 1)
errGroup.Go(func() error {
defer close(respChan)
resp, err := (*source).Observation(errCtx, in)
if err != nil {
return err
}
respChan <- resp
return nil
})
dsRespChan = append(dsRespChan, respChan)
}
// Returning only the first one right now.
// TODO: Execute in parallel and returned merged response.
return (*ds.sources[0]).Observation(ctx, in)

if err := errGroup.Wait(); err != nil {
return nil, err
}

allResp := []*pbv2.ObservationResponse{}
for _, respChan := range dsRespChan {
allResp = append(allResp, <-respChan)
}

return merger.MergeMultiObservation(allResp), nil
}

func (ds *DataSources) NodeSearch(ctx context.Context, in *pbv2.NodeSearchRequest) (*pbv2.NodeSearchResponse, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/server/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func (s *Server) V2Observation(
}
// initialResp is preferred over any calculated response.
combinedResp := append([]*pbv2.ObservationResponse{initialResp}, calculatedResps...)
return merger.MergeMultipleObservations(combinedResp...), nil
return merger.MergeMultiObservation(combinedResp), nil
}

// V2Sparql implements API for Mixer.V2Sparql.
Expand Down
2 changes: 1 addition & 1 deletion internal/server/spanner/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func (sc *SpannerClient) GetObservations(ctx context.Context, variables []string
return observations, nil
}

// GetObservations retrieves observations from Spanner given a list of variables and entities.
// GetObservationsContainedInPlace retrieves observations from Spanner given a list of variables and an entity expression.
func (sc *SpannerClient) GetObservationsContainedInPlace(ctx context.Context, variables []string, containedInPlace *v2.ContainedInPlace) ([]*Observation, error) {
var observations []*Observation
if len(variables) == 0 || containedInPlace == nil {
Expand Down
24 changes: 5 additions & 19 deletions internal/server/v2/observation/calculation.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ import (
)

type Equation struct {
variable string
formula string
Variable string
Formula string
}

// Computes a calculation for a variable and entity, based on a formula and input data.
Expand All @@ -43,7 +43,7 @@ func Calculate(
entity *pbv2.DcidOrExpression,
inputReq *pbv2.ObservationRequest,
) (*pbv2.ObservationResponse, error) {
variableFormula, err := formula.NewVariableFormula(equation.formula)
variableFormula, err := formula.NewVariableFormula(equation.Formula)
if err != nil {
return nil, err
}
Expand All @@ -62,18 +62,7 @@ func Calculate(
if err != nil {
return nil, err
}
intermediateResp, err := evalExpr(variableFormula.Expr, variableFormula.LeafData, inputObs)
if err != nil {
return nil, err
}
if intermediateResp.variableObs == nil {
return nil, fmt.Errorf("nil calculation response")
}
calculatedResp, err := formatCalculatedResponse(intermediateResp.variableObs, inputObs.Facets, equation)
if err != nil {
return nil, err
}
return calculatedResp, nil
return EvalExpr(variableFormula, inputObs, equation)
}

// Detects holes in a V2ObservationResponse and attempts to fill them using calculations.
Expand All @@ -87,10 +76,7 @@ func MaybeCalculateHoles(
inputResp *pbv2.ObservationResponse,
) ([]*pbv2.ObservationResponse, error) {
result := []*pbv2.ObservationResponse{}
holes, err := findObservationResponseHoles(inputReq, inputResp)
if err != nil {
return nil, err
}
holes := FindObservationResponseHoles(inputReq, inputResp)
for variable, entity := range holes {
formulas, ok := cachedata.SVFormula()[variable]
if !ok {
Expand Down
31 changes: 26 additions & 5 deletions internal/server/v2/observation/calculation_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ type intermediateObsResponse struct {
}

// Given an input ObservationResponse, generate a map of variable -> entities with missing data.
func findObservationResponseHoles(
func FindObservationResponseHoles(
inputReq *pbv2.ObservationRequest,
inputResp *pbv2.ObservationResponse,
) (map[string]*pbv2.DcidOrExpression, error) {
) map[string]*pbv2.DcidOrExpression {
result := map[string]*pbv2.DcidOrExpression{}
// Formula variables are handled by DerivedSeries.
if inputReq.Variable.GetFormula() != "" {
return result, nil
return result
}
for variable, variableObs := range inputResp.ByVariable {
if len(inputReq.Entity.GetDcids()) > 0 {
Expand All @@ -60,7 +60,7 @@ func findObservationResponseHoles(
}
}
}
return result, nil
return result
}

func compareFacet(facet1, facet2 *pb.Facet) bool {
Expand Down Expand Up @@ -284,6 +284,27 @@ func evalBinaryExpr(
return nil, fmt.Errorf("invalid binary expr")
}

// EvalExpr evaluates a parsed variable formula against the input observations
// and returns the calculated result as an ObservationResponse.
//
// The formula's AST (variableFormula.Expr) is evaluated with its leaf data
// against inputObs. The resulting variable observation is then formatted,
// together with the facets from inputObs, under equation.Variable.
//
// It returns an error if evaluation fails, if evaluation yields no variable
// observation, or if formatting the response fails.
func EvalExpr(
	variableFormula *formula.VariableFormula,
	inputObs *pbv2.ObservationResponse,
	equation *Equation,
) (*pbv2.ObservationResponse, error) {
	intermediateResp, err := evalExpr(variableFormula.Expr, variableFormula.LeafData, inputObs)
	if err != nil {
		return nil, err
	}
	// Check the pointer itself as well as its payload so that a (nil, nil)
	// result from evalExpr is reported as an error instead of panicking.
	if intermediateResp == nil || intermediateResp.variableObs == nil {
		return nil, fmt.Errorf("nil calculation response")
	}

	calculatedResp, err := formatCalculatedResponse(intermediateResp.variableObs, inputObs.Facets, equation)
	if err != nil {
		return nil, err
	}
	return calculatedResp, nil
}

// Recursively iterate through the AST and perform the calculation.
func evalExpr(
node ast.Node,
Expand Down Expand Up @@ -326,7 +347,7 @@ func formatCalculatedResponse(
) (*pbv2.ObservationResponse, error) {
resp := &pbv2.ObservationResponse{
ByVariable: map[string]*pbv2.VariableObservation{
equation.variable: variableObs,
equation.Variable: variableObs,
},
Facets: map[string]*pb.Facet{},
}
Expand Down
8 changes: 2 additions & 6 deletions internal/server/v2/observation/calculation_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,9 @@ func TestFindObservationResponseHoles(t *testing.T) {
},
},
} {
got, err := findObservationResponseHoles(c.inputReq, c.inputResp)
if err != nil {
t.Errorf("error running TestFindObservationResponseHoles: %s", err)
continue
}
got := FindObservationResponseHoles(c.inputReq, c.inputResp)
if ok := reflect.DeepEqual(got, c.want); !ok {
t.Errorf("findObservationResponseHoles(%v, %v) = %v, want %v",
t.Errorf("FindObservationResponseHoles(%v, %v) = %v, want %v",
c.inputReq, c.inputResp, got, c.want)
}
}
Expand Down
Loading

0 comments on commit bbfacfd

Please sign in to comment.