예제 #1
0
// pruneTree performs a single weakest-link pruning step on the tree rooted
// at t, using the examples e.
//
// A first traversal scores every visited node with weakestLinkCostFunction
// and remembers the node with the lowest divergence-per-leaf. A second
// traversal produces the pruned tree in which that node is replaced by a
// single leaf whose value is derived from the pruner's loss function.
//
// The returned prunedStage carries the pruned tree and alpha, computed from
// the root's cost and leaf count relative to those of the selected node.
func (p *pruner) pruneTree(t *pb.TreeNode, e Examples) prunedStage {
	bestNode, bestCost, bestLeaves := &pb.TreeNode{}, math.MaxFloat64, 0
	// First pass: visit nodes purely for the side effect of tracking the
	// cheapest one; the mapped result is discarded.
	mapTree(t, e, TreeMapperFunc(func(n *pb.TreeNode, ex Examples) (*pb.TreeNode, bool) {
		nodeSquaredDivergence, nodeLeaves := weakestLinkCostFunction(n, ex)
		nodeCost := nodeSquaredDivergence / float64(nodeLeaves)
		if nodeCost < bestCost {
			// Remember the node being visited (n). Storing the root (t) here
			// would make the second pass collapse the entire tree regardless
			// of which node actually had the lowest cost.
			bestNode = n
			bestCost = nodeCost
			bestLeaves = nodeLeaves
		}
		return proto.Clone(n).(*pb.TreeNode), true
	}))

	// Second pass: copy every node except the selected one, which is matched
	// by pointer identity against the nodes of the original tree.
	prunedTree := mapTree(t, e, TreeMapperFunc(func(n *pb.TreeNode, ex Examples) (*pb.TreeNode, bool) {
		if n != bestNode {
			return proto.Clone(n).(*pb.TreeNode), true
		}

		// Otherwise, return the leaf constructed by pruning all subtrees.
		// NOTE(review): the false result presumably tells mapTree not to
		// descend below this node — confirm against mapTree.
		leafWeight := p.lossFunction.GetLeafWeight(ex)
		prior := p.lossFunction.GetPrior(ex)
		return &pb.TreeNode{
			LeafValue: proto.Float64(leafWeight * prior),
		}, false
	}))

	rootCost, rootLeaves := weakestLinkCostFunction(t, e)
	// NOTE(review): if the selected node has as many leaves as the root
	// (e.g. the root itself was selected), the denominator is zero and alpha
	// is not a finite number — confirm callers tolerate this.
	alpha := (rootCost - bestCost) / float64(rootLeaves-bestLeaves)
	return prunedStage{
		alpha: alpha,
		tree:  prunedTree,
	}
}
예제 #2
0
func TestClone(t *testing.T) {
	m := proto.Clone(cloneTestMessage).(*pb.MyMessage)
	if !proto.Equal(m, cloneTestMessage) {
		t.Errorf("Clone(%v) = %v", cloneTestMessage, m)
	}

	// Verify it was a deep copy.
	*m.Inner.Port++
	if proto.Equal(m, cloneTestMessage) {
		t.Error("Mutating clone changed the original")
	}
	// Byte fields and repeated fields should be copied.
	if &m.Pet[0] == &cloneTestMessage.Pet[0] {
		t.Error("Pet: repeated field not copied")
	}
	if &m.Others[0] == &cloneTestMessage.Others[0] {
		t.Error("Others: repeated field not copied")
	}
	if &m.Others[0].Value[0] == &cloneTestMessage.Others[0].Value[0] {
		t.Error("Others[0].Value: bytes field not copied")
	}
	if &m.RepBytes[0] == &cloneTestMessage.RepBytes[0] {
		t.Error("RepBytes: repeated field not copied")
	}
	if &m.RepBytes[0][0] == &cloneTestMessage.RepBytes[0][0] {
		t.Error("RepBytes[0]: bytes field not copied")
	}
}
예제 #3
0
// TestMerge runs every case in mergeTests: it merges src into a clone of dst
// and reports an error when the result differs from want.
func TestMerge(t *testing.T) {
	for _, tc := range mergeTests {
		// Merge into a copy so the table's dst value is left untouched.
		merged := proto.Clone(tc.dst)
		proto.Merge(merged, tc.src)
		if proto.Equal(merged, tc.want) {
			continue
		}
		t.Errorf("Merge(%v, %v)\n got %v\nwant %v\n", tc.dst, tc.src, merged, tc.want)
	}
}
예제 #4
0
func TestClone(t *testing.T) {
	m := proto.Clone(cloneTestMessage).(*pb.MyMessage)
	if !proto.Equal(m, cloneTestMessage) {
		t.Errorf("Clone(%v) = %v", cloneTestMessage, m)
	}

	// Verify it was a deep copy.
	*m.Inner.Port++
	if proto.Equal(m, cloneTestMessage) {
		t.Error("Mutating clone changed the original")
	}
}
예제 #5
0
// TestCloneNil checks that cloning a nil *pb.MyMessage yields a value that
// proto.Equal considers equal to the nil input.
func TestCloneNil(t *testing.T) {
	var nilMsg *pb.MyMessage
	cloned := proto.Clone(nilMsg)
	if proto.Equal(nilMsg, cloned) {
		return
	}
	t.Errorf("Clone(%v) = %v", nilMsg, cloned)
}