func addSelection(p Plan, child LogicalPlan, conditions []expression.Expression, allocator *idAllocator) error { conditions = expression.PropagateConstant(p.context(), conditions) selection := &Selection{ Conditions: conditions, baseLogicalPlan: newBaseLogicalPlan(Sel, allocator)} selection.self = selection selection.initIDAndContext(p.context()) selection.SetSchema(child.GetSchema().Clone()) selection.correlated = child.IsCorrelated() for _, cond := range conditions { selection.correlated = selection.correlated || cond.IsCorrelated() } return InsertPlan(p, child, selection) }
// PredicatePushDown implements LogicalPlan PredicatePushDown interface. func (p *Selection) PredicatePushDown(predicates []expression.Expression) ([]expression.Expression, LogicalPlan, error) { retConditions, child, err := p.GetChildByIndex(0).(LogicalPlan).PredicatePushDown(append(p.Conditions, predicates...)) if err != nil { return nil, nil, errors.Trace(err) } if len(retConditions) > 0 { p.Conditions = expression.PropagateConstant(p.ctx, retConditions) return nil, p, nil } err = RemovePlan(p) if err != nil { return nil, nil, errors.Trace(err) } return nil, child, nil }
func (s *testPlanSuite) TestConstantPropagation(c *C) { defer testleak.AfterTest(c)() cases := []struct { sql string after string }{ { sql: "a = b and b = c and c = d and d = 1", after: "eq(test.t.a, 1), eq(test.t.b, 1), eq(test.t.c, 1), eq(test.t.d, 1)", }, { sql: "a = b and b = 1 and a = null and c = d and c > 2 and c != 4 and d != 5", after: "0", }, { sql: "a = b and b = 1 and c = d and c > 2 and c != 4 and d != 5", after: "eq(test.t.a, 1), eq(test.t.b, 1), eq(test.t.c, test.t.d), gt(test.t.c, 2), gt(test.t.d, 2), ne(test.t.c, 4), ne(test.t.c, 5), ne(test.t.d, 4), ne(test.t.d, 5)", }, { sql: "a = b and b > 0 and a = c", after: "eq(test.t.a, test.t.b), eq(test.t.a, test.t.c), gt(test.t.a, 0), gt(test.t.b, 0), gt(test.t.c, 0)", }, { sql: "a = b and b = c and c LIKE 'abc%'", after: "eq(test.t.a, test.t.b), eq(test.t.b, test.t.c), like(cast(test.t.c), abc%, 92)", }, { sql: "a = b and a > 2 and b > 3 and a < 1 and b < 2", after: "eq(test.t.a, test.t.b), gt(test.t.a, 2), gt(test.t.a, 3), gt(test.t.b, 2), gt(test.t.b, 3), lt(test.t.a, 1), lt(test.t.a, 2), lt(test.t.b, 1), lt(test.t.b, 2)", }, { sql: "a = 1 and cast(null as SIGNED) is null", after: "1, eq(test.t.a, 1)", }, } for _, ca := range cases { sql := "select * from t where " + ca.sql comment := Commentf("for %s", sql) stmt, err := s.ParseOneStmt(sql, "", "") c.Assert(err, IsNil, comment) err = mockResolve(stmt) c.Assert(err, IsNil) builder := &planBuilder{ allocator: new(idAllocator), ctx: mock.NewContext(), colMapper: make(map[*ast.ColumnNameExpr]int), } p := builder.build(stmt) c.Assert(builder.err, IsNil) lp := p.(LogicalPlan) var ( sel *Selection ok bool result []string ) v := lp for { if sel, ok = v.(*Selection); ok { break } v = v.GetChildByIndex(0).(LogicalPlan) } newConds := expression.PropagateConstant(builder.ctx, sel.Conditions) for _, v := range newConds { result = append(result, v.String()) } sort.Strings(result) c.Assert(strings.Join(result, ", "), Equals, ca.after, Commentf("for %s", ca.sql)) } }
// PredicatePushDown implements LogicalPlan PredicatePushDown interface.
//
// Flow:
//  1. outerJoinSimplify is given the incoming predicates first.
//  2. If tryToGetJoinGroup reports a reorderable join group, joinReOrderSolver
//     builds a new join tree. That tree replaces p under p.parents[0], and
//     push-down is re-run on it.
//  3. Otherwise the predicates are classified by extractOnCondition into
//     equal / left-only / right-only / other conditions. For InnerJoin, the
//     join's existing Left/Right/Equal/Other conditions are merged in first
//     and the whole set goes through constant propagation.
//  4. Depending on JoinType, some conditions are pushed to each child and the
//     rest are kept on the join or returned to the caller in ret.
//  5. Conditions a child cannot absorb are wrapped in a new Selection directly
//     above that child via addSelection.
//
// On the non-reorder path the returned plan is always p (retPlan = p plus the
// final naked return).
func (p *Join) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, retPlan LogicalPlan, err error) {
	err = outerJoinSimplify(p, predicates)
	if err != nil {
		return nil, nil, errors.Trace(err)
	}
	groups, valid := tryToGetJoinGroup(p)
	if valid {
		// Replace p with the reordered join under p's first parent, then push
		// the same predicates into the new join.
		// NOTE(review): assumes p has at least one parent here — confirm.
		e := joinReOrderSolver{allocator: p.allocator}
		e.reorderJoin(groups, predicates)
		newJoin := e.resultJoin
		parent := p.parents[0]
		newJoin.SetParents(parent)
		parent.ReplaceChild(p, newJoin)
		return newJoin.PredicatePushDown(predicates)
	}
	var leftCond, rightCond []expression.Expression
	retPlan = p
	leftPlan := p.GetChildByIndex(0).(LogicalPlan)
	rightPlan := p.GetChildByIndex(1).(LogicalPlan)
	var (
		equalCond                              []*expression.ScalarFunction
		leftPushCond, rightPushCond, otherCond []expression.Expression
	)
	if p.JoinType != InnerJoin {
		// Non-inner joins: only the incoming predicates are classified; the
		// join's own ON-derived conditions are handled per case below.
		equalCond, leftPushCond, rightPushCond, otherCond = extractOnCondition(predicates, leftPlan, rightPlan)
	} else {
		// Inner join: pool every condition the join currently holds with the
		// incoming predicates, propagate constants across the pool, then
		// reclassify from scratch.
		tempCond := make([]expression.Expression, 0, len(p.LeftConditions)+len(p.RightConditions)+len(p.EqualConditions)+len(p.OtherConditions)+len(predicates))
		tempCond = append(tempCond, p.LeftConditions...)
		tempCond = append(tempCond, p.RightConditions...)
		tempCond = append(tempCond, expression.ScalarFuncs2Exprs(p.EqualConditions)...)
		tempCond = append(tempCond, p.OtherConditions...)
		tempCond = append(tempCond, predicates...)
		equalCond, leftPushCond, rightPushCond, otherCond = extractOnCondition(expression.PropagateConstant(p.ctx, tempCond), leftPlan, rightPlan)
	}
	switch p.JoinType {
	case LeftOuterJoin, SemiJoinWithAux:
		// The join's RightConditions move down to the right child and are
		// cleared from p. Only left-only predicates go to the left child. The
		// remaining predicates (equal, other, right-only) are returned for the
		// caller to keep above this join. p.LeftConditions is left untouched.
		rightCond = p.RightConditions
		p.RightConditions = nil
		leftCond = leftPushCond
		ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...)
		ret = append(ret, rightPushCond...)
	case RightOuterJoin:
		// Mirror image of the LeftOuterJoin case.
		leftCond = p.LeftConditions
		p.LeftConditions = nil
		rightCond = rightPushCond
		ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...)
		ret = append(ret, leftPushCond...)
	case SemiJoin:
		// Both sides receive the join's own side conditions plus the matching
		// side-only predicates; the join's side conditions are then cleared.
		// NOTE(review): this repeats the extractOnCondition call already made
		// above with identical arguments (JoinType != InnerJoin).
		// NOTE(review): equalCond/otherCond extracted from predicates are
		// neither pushed nor added to ret here, so they appear to be dropped —
		// confirm this is intended for SemiJoin.
		equalCond, leftPushCond, rightPushCond, otherCond = extractOnCondition(predicates, leftPlan, rightPlan)
		leftCond = append(p.LeftConditions, leftPushCond...)
		rightCond = append(p.RightConditions, rightPushCond...)
		p.LeftConditions = nil
		p.RightConditions = nil
	case InnerJoin:
		// Side-only conditions go down to the children; equal and other
		// conditions become the join's new condition sets.
		p.LeftConditions = nil
		p.RightConditions = nil
		p.EqualConditions = equalCond
		p.OtherConditions = otherCond
		leftCond = leftPushCond
		rightCond = rightPushCond
	}
	// The plans returned by the children are discarded here; whatever a child
	// could not absorb comes back as leftRet/rightRet.
	leftRet, _, err1 := leftPlan.PredicatePushDown(leftCond)
	if err1 != nil {
		return nil, nil, errors.Trace(err1)
	}
	rightRet, _, err2 := rightPlan.PredicatePushDown(rightCond)
	if err2 != nil {
		return nil, nil, errors.Trace(err2)
	}
	// Keep unabsorbed conditions as a Selection between p and the child.
	if len(leftRet) > 0 {
		err2 = addSelection(p, leftPlan, leftRet, p.allocator)
		if err2 != nil {
			return nil, nil, errors.Trace(err2)
		}
	}
	if len(rightRet) > 0 {
		err2 = addSelection(p, rightPlan, rightRet, p.allocator)
		if err2 != nil {
			return nil, nil, errors.Trace(err2)
		}
	}
	return
}