// TestMerkleBlockWire round-trips a MsgMerkleBlock through BtcEncode and
// BtcDecode at several protocol versions, comparing the serialized bytes and
// the decoded message against known-good values.
func TestMerkleBlockWire(t *testing.T) {
	cases := []struct {
		in   *wire.MsgMerkleBlock // Message to encode
		out  *wire.MsgMerkleBlock // Expected decoded message
		buf  []byte               // Wire encoding
		pver uint32               // Protocol version for wire encoding
	}{
		// Latest protocol version.
		{
			in:   &merkleBlockOne,
			out:  &merkleBlockOne,
			buf:  merkleBlockOneBytes,
			pver: wire.ProtocolVersion,
		},

		// Protocol version BIP0037Version.
		{
			in:   &merkleBlockOne,
			out:  &merkleBlockOne,
			buf:  merkleBlockOneBytes,
			pver: wire.BIP0037Version,
		},
	}

	t.Logf("Running %d tests", len(cases))
	for idx, tc := range cases {
		// Serialize the message and compare against the expected bytes.
		var encoded bytes.Buffer
		if err := tc.in.BtcEncode(&encoded, tc.pver); err != nil {
			t.Errorf("BtcEncode #%d error %v", idx, err)
			continue
		}
		if got := encoded.Bytes(); !bytes.Equal(got, tc.buf) {
			t.Errorf("BtcEncode #%d\n got: %s want: %s", idx,
				spew.Sdump(got), spew.Sdump(tc.buf))
			continue
		}

		// Deserialize the expected bytes and compare against the
		// expected message.
		var decoded wire.MsgMerkleBlock
		if err := decoded.BtcDecode(bytes.NewReader(tc.buf), tc.pver); err != nil {
			t.Errorf("BtcDecode #%d error %v", idx, err)
			continue
		}
		if !reflect.DeepEqual(&decoded, tc.out) {
			t.Errorf("BtcDecode #%d\n got: %s want: %s", idx,
				spew.Sdump(&decoded), spew.Sdump(tc.out))
		}
	}
}
// TestMerkleBlockOverflowErrors performs tests to ensure decoding merkle
// blocks that are intentionally crafted to use large values for the number of
// hashes and flags are handled properly.  This could otherwise potentially be
// used as an attack vector.
func TestMerkleBlockOverflowErrors(t *testing.T) {
	// Use protocol version 70001 specifically here instead of the latest
	// protocol version because the test data is using bytes encoded with
	// that version.
	pver := uint32(70001)

	// Create bytes for a merkle block that claims to have more than the max
	// allowed tx hashes: the valid encoding up to the hash count, followed
	// by an oversized count.
	var buf bytes.Buffer
	if err := wire.TstWriteVarInt(&buf, pver, wire.MaxTxPerBlock+1); err != nil {
		t.Fatalf("TstWriteVarInt: unexpected error %v", err)
	}
	numHashesOffset := 84
	exceedMaxHashes := make([]byte, numHashesOffset)
	copy(exceedMaxHashes, merkleBlockOneBytes[:numHashesOffset])
	exceedMaxHashes = append(exceedMaxHashes, buf.Bytes()...)

	// Create bytes for a merkle block that claims to have more than the max
	// allowed flag bytes: the valid encoding up to the flag byte count,
	// followed by an oversized count.
	buf.Reset()
	if err := wire.TstWriteVarInt(&buf, pver, wire.MaxFlagsPerMerkleBlock+1); err != nil {
		t.Fatalf("TstWriteVarInt: unexpected error %v", err)
	}
	numFlagBytesOffset := 117
	exceedMaxFlagBytes := make([]byte, numFlagBytesOffset)
	copy(exceedMaxFlagBytes, merkleBlockOneBytes[:numFlagBytesOffset])
	exceedMaxFlagBytes = append(exceedMaxFlagBytes, buf.Bytes()...)

	tests := []struct {
		buf  []byte // Wire encoding
		pver uint32 // Protocol version for wire encoding
		err  error  // Expected error
	}{
		// Block that claims to have more than max allowed hashes.
		{exceedMaxHashes, pver, &wire.MessageError{}},
		// Block that claims to have more than max allowed flag bytes.
		{exceedMaxFlagBytes, pver, &wire.MessageError{}},
	}

	t.Logf("Running %d tests", len(tests))
	for i, test := range tests {
		// Decode from wire format.  Only the dynamic type of the
		// returned error is compared against the expected error.
		var msg wire.MsgMerkleBlock
		r := bytes.NewReader(test.buf)
		err := msg.BtcDecode(r, test.pver)
		if reflect.TypeOf(err) != reflect.TypeOf(test.err) {
			t.Errorf("BtcDecode #%d wrong error got: %v (%T), want: %T",
				i, err, err, test.err)
		}
	}
}
// NewMerkleBlock returns a new *wire.MsgMerkleBlock and an array of the matched
// transaction index numbers based on the passed block and filter.
//
// Every transaction in the block is run through filter.MatchTxAndUpdate in
// block order.  The resulting flag bits are packed eight per byte into the
// message's Flags field, least significant bit first.
func NewMerkleBlock(block *btcutil.Block, filter *Filter) (*wire.MsgMerkleBlock, []uint32) {
	numTx := uint32(len(block.Transactions()))
	mBlock := merkleBlock{
		numTx:       numTx,
		allHashes:   make([]*wire.ShaHash, 0, numTx),
		matchedBits: make([]byte, 0, numTx),
	}

	// Record one match bit and one hash per transaction, remembering the
	// index of each transaction the filter matched.
	var matchedIndices []uint32
	for txIndex, tx := range block.Transactions() {
		var matched byte
		if filter.MatchTxAndUpdate(tx) {
			matched = 0x01
			matchedIndices = append(matchedIndices, uint32(txIndex))
		}
		mBlock.matchedBits = append(mBlock.matchedBits, matched)
		mBlock.allHashes = append(mBlock.allHashes, tx.Sha())
	}

	// The tree height is the first level whose width is at most one node.
	var height uint32
	for mBlock.calcTreeWidth(height) > 1 {
		height++
	}

	// Build the depth-first partial merkle tree starting from the root.
	mBlock.traverseAndBuild(height, 0)

	// Assemble the resulting message from the collected hashes and bits.
	msgMerkleBlock := &wire.MsgMerkleBlock{
		Header:       block.MsgBlock().Header,
		Transactions: mBlock.numTx,
		Hashes:       make([]*wire.ShaHash, 0, len(mBlock.finalHashes)),
		Flags:        make([]byte, (len(mBlock.bits)+7)/8),
	}
	for _, sha := range mBlock.finalHashes {
		// The error is deliberately discarded; this function has no
		// error return to report it through.
		_ = msgMerkleBlock.AddTxHash(sha)
	}
	for i, bit := range mBlock.bits {
		msgMerkleBlock.Flags[i/8] |= bit << uint(i%8)
	}
	return msgMerkleBlock, matchedIndices
}
// TestMerkleBlockWireErrors performs negative tests against wire encode and
// decode of MsgMerkleBlock to confirm error paths work correctly.  Each case
// limits the fixed-size writer/reader to a given number of bytes and checks
// the error returned by BtcEncode and BtcDecode.
func TestMerkleBlockWireErrors(t *testing.T) {
	// Use protocol version 70001 specifically here instead of the latest
	// because the test data is using bytes encoded with that protocol
	// version.
	pver := uint32(70001)
	pverNoMerkleBlock := wire.BIP0037Version - 1
	wireErr := &wire.MessageError{}

	type wireErrTest struct {
		in       *wire.MsgMerkleBlock // Value to encode
		buf      []byte               // Wire encoding
		pver     uint32               // Protocol version for wire encoding
		max      int                  // Max size of fixed buffer to induce errors
		writeErr error                // Expected write error
		readErr  error                // Expected read error
	}

	// Buffer sizes that force a short write / EOF in each successive field.
	truncations := []int{
		0,   // Force error in version.
		4,   // Force error in prev block hash.
		36,  // Force error in merkle root.
		68,  // Force error in timestamp.
		72,  // Force error in difficulty bits.
		76,  // Force error in header nonce.
		80,  // Force error in transaction count.
		84,  // Force error in num hashes.
		85,  // Force error in hashes.
		117, // Force error in num flag bytes.
		118, // Force error in flag bytes.
	}

	tests := make([]wireErrTest, 0, len(truncations)+1)
	for _, limit := range truncations {
		tests = append(tests, wireErrTest{
			in:       &merkleBlockOne,
			buf:      merkleBlockOneBytes,
			pver:     pver,
			max:      limit,
			writeErr: io.ErrShortWrite,
			readErr:  io.EOF,
		})
	}
	// Force error due to unsupported protocol version.
	tests = append(tests, wireErrTest{
		in:       &merkleBlockOne,
		buf:      merkleBlockOneBytes,
		pver:     pverNoMerkleBlock,
		max:      119,
		writeErr: wireErr,
		readErr:  wireErr,
	})

	// mismatch reports whether got is not the expected error.  The dynamic
	// types must always agree; errors which are not of type
	// wire.MessageError must additionally compare equal.
	mismatch := func(got, want error) bool {
		if reflect.TypeOf(got) != reflect.TypeOf(want) {
			return true
		}
		if _, isMsgErr := got.(*wire.MessageError); isMsgErr {
			return false
		}
		return got != want
	}

	t.Logf("Running %d tests", len(tests))
	for idx, tc := range tests {
		// Encode to wire format.
		encErr := tc.in.BtcEncode(newFixedWriter(tc.max), tc.pver)
		if mismatch(encErr, tc.writeErr) {
			t.Errorf("BtcEncode #%d wrong error got: %v, want: %v",
				idx, encErr, tc.writeErr)
			continue
		}

		// Decode from wire format.
		var decoded wire.MsgMerkleBlock
		decErr := decoded.BtcDecode(newFixedReader(tc.max, tc.buf), tc.pver)
		if mismatch(decErr, tc.readErr) {
			t.Errorf("BtcDecode #%d wrong error got: %v, want: %v",
				idx, decErr, tc.readErr)
		}
	}
}
// TestMerkleBlock tests the MsgMerkleBlock API: the command string, the
// maximum payload length, the limit AddTxHash enforces on the number of
// hashes, an encode/decode round trip, and the encode-time limits on the
// number of hashes and flag bytes.
func TestMerkleBlock(t *testing.T) {
	pver := wire.ProtocolVersion

	// Block 1 header.
	prevHash := &blockOne.Header.PrevBlock
	merkleHash := &blockOne.Header.MerkleRoot
	bits := blockOne.Header.Bits
	nonce := blockOne.Header.Nonce
	bh := wire.NewBlockHeader(prevHash, merkleHash, bits, nonce)

	// Ensure the command is expected value.
	wantCmd := "merkleblock"
	msg := wire.NewMsgMerkleBlock(bh)
	if cmd := msg.Command(); cmd != wantCmd {
		t.Errorf("NewMsgMerkleBlock: wrong command - got %v want %v",
			cmd, wantCmd)
	}

	// Ensure max payload is expected value for latest protocol version.
	// The expected limit is a fixed 1000000 bytes.
	wantPayload := uint32(1000000)
	maxPayload := msg.MaxPayloadLength(pver)
	if maxPayload != wantPayload {
		t.Errorf("MaxPayloadLength: wrong max payload length for "+
			"protocol version %d - got %v, want %v", pver,
			maxPayload, wantPayload)
	}

	// Load maxTxPerBlock hashes.  Each hash is built from fresh random
	// bytes; data is reused as scratch space between iterations.
	data := make([]byte, 32)
	for i := 0; i < wire.MaxTxPerBlock; i++ {
		if _, err := rand.Read(data); err != nil {
			t.Errorf("rand.Read failed: %v", err)
			return
		}
		hash, err := wire.NewShaHash(data)
		if err != nil {
			t.Errorf("NewShaHash failed: %v\n", err)
			return
		}

		if err = msg.AddTxHash(hash); err != nil {
			t.Errorf("AddTxHash failed: %v\n", err)
			return
		}
	}

	// Add one more Tx to test failure.
	if _, err := rand.Read(data); err != nil {
		t.Errorf("rand.Read failed: %v", err)
		return
	}
	hash, err := wire.NewShaHash(data)
	if err != nil {
		t.Errorf("NewShaHash failed: %v\n", err)
		return
	}

	if err = msg.AddTxHash(hash); err == nil {
		t.Errorf("AddTxHash succeeded when it should have failed")
		return
	}

	// Test encode with latest protocol version.
	var buf bytes.Buffer
	err = msg.BtcEncode(&buf, pver)
	if err != nil {
		t.Errorf("encode of MsgMerkleBlock failed %v err <%v>", msg, err)
	}

	// Test decode with latest protocol version.
	readmsg := wire.MsgMerkleBlock{}
	err = readmsg.BtcDecode(&buf, pver)
	if err != nil {
		t.Errorf("decode of MsgMerkleBlock failed [%v] err <%v>", buf, err)
	}

	// Force extra hash to test maxTxPerBlock.  Appending to the slice
	// directly bypasses the limit AddTxHash enforced above.
	msg.Hashes = append(msg.Hashes, hash)
	err = msg.BtcEncode(&buf, pver)
	if err == nil {
		t.Errorf("encode of MsgMerkleBlock succeeded with too many " +
			"tx hashes when it should have failed")
		return
	}

	// Force too many flag bytes to test maxFlagsPerMerkleBlock.
	// Reset the number of hashes back to a valid value (only the last
	// hash is kept) so the flag byte limit is what triggers the failure.
	msg.Hashes = msg.Hashes[len(msg.Hashes)-1:]
	msg.Flags = make([]byte, wire.MaxFlagsPerMerkleBlock+1)
	err = msg.BtcEncode(&buf, pver)
	if err == nil {
		t.Errorf("encode of MsgMerkleBlock succeeded with too many " +
			"flag bytes when it should have failed")
		return
	}
}