// populateEndpoints walks p.streams and, for every stream, registers a
// StreamEndpointSpec on both of its ends: on the input spec of the destination
// processor and on the (first) output router of the source processor. The
// stream's index in p.streams is used as its StreamID.
//
// It panics if a stream's sourceRouterSlot does not match the position the
// stream is about to take in the source router.
func (dsp *distSQLPlanner) populateEndpoints(planCtx *planningCtx, p *physicalPlan) {
	// Note: we could fill in the input/output specs directly instead of adding
	// streams to p.streams, but this makes the rest of the code a bit simpler.
	for streamIdx, stream := range p.streams {
		src := &p.processors[stream.sourceProcessor]
		dst := &p.processors[stream.destProcessor]

		// A stream between two processors on the same node is local; anything
		// else is remote.
		remote := src.node != dst.node
		epType := distsqlrun.StreamEndpointSpec_LOCAL
		if remote {
			epType = distsqlrun.StreamEndpointSpec_REMOTE
		}

		// Receiving side: this endpoint is recorded without a target address.
		inEndpoint := distsqlrun.StreamEndpointSpec{
			Type:     epType,
			StreamID: distsqlrun.StreamID(streamIdx),
		}
		input := &dst.spec.Input[stream.destInput]
		input.Streams = append(input.Streams, inEndpoint)

		// Sending side: same endpoint, plus the destination node's address when
		// the stream is remote.
		outEndpoint := inEndpoint
		if remote {
			outEndpoint.TargetAddr = planCtx.nodeAddresses[dst.node]
		}

		router := &src.spec.Output[0]
		// The stream will land at index len(router.Streams) in the router, which
		// must equal its sourceRouterSlot. This holds as long as the streams are
		// in order; if that assumption ever changes, they can be reordered here
		// according to sourceRouterSlot.
		if slot := len(router.Streams); slot != stream.sourceRouterSlot {
			panic(fmt.Sprintf(
				"sourceRouterSlot mismatch: %d, expected %d", slot, stream.sourceRouterSlot,
			))
		}
		router.Streams = append(router.Streams, outEndpoint)
	}
}
// Example no. 2

// checkDistAggregationInfo tests that a flow with multiple local stages and a
// final stage (in accordance with DistAggregationInfo) gets the same result as
// a naive aggregation flow that has a single non-distributed stage.
//
// Both types of flows are set up and run against the first numRows of the given
// table. We assume the table's first column is the primary key, with values
// from 1 to numRows.
//
// Parameters:
//   - t, srv: the test handle and the test server the flows are run against.
//   - tableDesc: descriptor of the table to read from.
//   - colIdx: index of the column to aggregate; the caller is expected to pick
//     a non-PK column that works with fn.
//   - numRows: number of rows (PK values 1..numRows) to aggregate over.
//   - fn: the aggregation function used by the non-distributed flow.
//   - info: the local/final stage functions used by the distributed flow.
//
// Mismatches in row length, column types or results are reported via t.Errorf;
// setup failures abort the test via t.Fatal.
func checkDistAggregationInfo(
	t *testing.T,
	srv serverutils.TestServerInterface,
	tableDesc *sqlbase.TableDescriptor,
	colIdx int,
	numRows int,
	fn distsqlrun.AggregatorSpec_Func,
	info DistAggregationInfo,
) {
	colType := tableDesc.Columns[colIdx].Type

	// makeTableReader builds a table reader processor that reads column colIdx
	// for the primary keys in [startPK, endPK) and sends its output to the local
	// stream with the given ID.
	makeTableReader := func(startPK, endPK int, streamID int) distsqlrun.ProcessorSpec {
		tr := distsqlrun.TableReaderSpec{
			Table:         *tableDesc,
			OutputColumns: []uint32{uint32(colIdx)},
			Spans:         make([]distsqlrun.TableReaderSpan, 1),
		}

		var err error
		tr.Spans[0].Span.Key, err = sqlbase.MakePrimaryIndexKey(tableDesc, startPK)
		if err != nil {
			t.Fatal(err)
		}
		tr.Spans[0].Span.EndKey, err = sqlbase.MakePrimaryIndexKey(tableDesc, endPK)
		if err != nil {
			t.Fatal(err)
		}

		return distsqlrun.ProcessorSpec{
			Core: distsqlrun.ProcessorCoreUnion{TableReader: &tr},
			Output: []distsqlrun.OutputRouterSpec{{
				Type: distsqlrun.OutputRouterSpec_PASS_THROUGH,
				Streams: []distsqlrun.StreamEndpointSpec{
					{Type: distsqlrun.StreamEndpointSpec_LOCAL, StreamID: distsqlrun.StreamID(streamID)},
				},
			}},
		}
	}

	// First run a flow that aggregates all the rows without any local stages.

	rowsNonDist := runTestFlow(
		t, srv,
		makeTableReader(1, numRows+1, 0),
		distsqlrun.ProcessorSpec{
			Input: []distsqlrun.InputSyncSpec{{
				Type:        distsqlrun.InputSyncSpec_UNORDERED,
				ColumnTypes: []sqlbase.ColumnType{colType},
				Streams: []distsqlrun.StreamEndpointSpec{
					{Type: distsqlrun.StreamEndpointSpec_LOCAL, StreamID: 0},
				},
			}},
			Core: distsqlrun.ProcessorCoreUnion{Aggregator: &distsqlrun.AggregatorSpec{
				Aggregations: []distsqlrun.AggregatorSpec_Aggregation{{Func: fn, ColIdx: 0}},
			}},
			Output: []distsqlrun.OutputRouterSpec{{
				Type: distsqlrun.OutputRouterSpec_PASS_THROUGH,
				Streams: []distsqlrun.StreamEndpointSpec{
					{Type: distsqlrun.StreamEndpointSpec_SYNC_RESPONSE},
				},
			}},
		},
	)

	// Now run a flow with (up to) 4 separate table readers, each with its own
	// local stage, all feeding into a single final stage.

	numParallel := 4
	// The type outputted by the local stage can be different than the input type
	// (e.g. DECIMAL instead of INT).
	// NOTE(review): this queries the output type of fn rather than of
	// info.LocalStage; the two agree only if fn and info.LocalStage produce the
	// same type - confirm this holds for every DistAggregationInfo entry.
	_, intermediaryType, err := distsqlrun.GetAggregateInfo(fn, colType)
	if err != nil {
		t.Fatal(err)
	}

	// Never use more readers than there are rows: the rows are split evenly
	// between the readers, so an extra reader would end up with an empty span
	// (start key == end key). numRows <= 0 is left alone so that the final
	// stage still has input streams.
	if numRows > 0 && numParallel > numRows {
		numParallel = numRows
	}
	finalProc := distsqlrun.ProcessorSpec{
		Input: []distsqlrun.InputSyncSpec{{
			Type:        distsqlrun.InputSyncSpec_UNORDERED,
			ColumnTypes: []sqlbase.ColumnType{intermediaryType},
		}},
		Core: distsqlrun.ProcessorCoreUnion{Aggregator: &distsqlrun.AggregatorSpec{
			Aggregations: []distsqlrun.AggregatorSpec_Aggregation{{Func: info.FinalStage, ColIdx: 0}},
		}},
		Output: []distsqlrun.OutputRouterSpec{{
			Type: distsqlrun.OutputRouterSpec_PASS_THROUGH,
			Streams: []distsqlrun.StreamEndpointSpec{
				{Type: distsqlrun.StreamEndpointSpec_SYNC_RESPONSE},
			},
		}},
	}
	var procs []distsqlrun.ProcessorSpec
	for i := 0; i < numParallel; i++ {
		// Reader i covers PKs [1+i*numRows/numParallel, 1+(i+1)*numRows/numParallel).
		// Stream 2*i connects reader i to its local aggregator; stream 2*i+1
		// connects that aggregator to the final stage.
		tr := makeTableReader(1+i*numRows/numParallel, 1+(i+1)*numRows/numParallel, 2*i)
		agg := distsqlrun.ProcessorSpec{
			Input: []distsqlrun.InputSyncSpec{{
				Type:        distsqlrun.InputSyncSpec_UNORDERED,
				ColumnTypes: []sqlbase.ColumnType{colType},
				Streams: []distsqlrun.StreamEndpointSpec{
					{Type: distsqlrun.StreamEndpointSpec_LOCAL, StreamID: distsqlrun.StreamID(2 * i)},
				},
			}},
			Core: distsqlrun.ProcessorCoreUnion{Aggregator: &distsqlrun.AggregatorSpec{
				Aggregations: []distsqlrun.AggregatorSpec_Aggregation{{Func: info.LocalStage, ColIdx: 0}},
			}},
			Output: []distsqlrun.OutputRouterSpec{{
				Type: distsqlrun.OutputRouterSpec_PASS_THROUGH,
				Streams: []distsqlrun.StreamEndpointSpec{
					{Type: distsqlrun.StreamEndpointSpec_LOCAL, StreamID: distsqlrun.StreamID(2*i + 1)},
				},
			}},
		}
		procs = append(procs, tr, agg)
		finalProc.Input[0].Streams = append(finalProc.Input[0].Streams, distsqlrun.StreamEndpointSpec{
			Type:     distsqlrun.StreamEndpointSpec_LOCAL,
			StreamID: distsqlrun.StreamID(2*i + 1),
		})
	}
	procs = append(procs, finalProc)
	rowsDist := runTestFlow(t, srv, procs...)

	// Compare the first result row of both flows: same number of columns, same
	// column types, and finally the same string rendering of all rows.
	if len(rowsDist[0]) != len(rowsNonDist[0]) {
		t.Errorf("different row lengths (dist: %d non-dist: %d)", len(rowsDist[0]), len(rowsNonDist[0]))
	} else {
		for i := range rowsDist[0] {
			tDist := rowsDist[0][i].Type.String()
			tNonDist := rowsNonDist[0][i].Type.String()
			if tDist != tNonDist {
				t.Errorf("different type for column %d (dist: %s non-dist: %s)", i, tDist, tNonDist)
			}
		}
	}
	if rowsDist.String() != rowsNonDist.String() {
		t.Errorf("different results\nw/o local stage:   %s\nwith local stage:  %s", rowsNonDist, rowsDist)
	}
}