예제 #1
0
// Open verifies and strips the SimpleAuthenticator framing from buffer in
// place. It returns true and leaves only the payload in buffer on success;
// it returns false when the frame is too short or fails verification.
//
// After xorbkd is applied, the frame is expected to be:
//
//	4-byte FNV-1a (32-bit) checksum of the rest | 2-byte payload length | payload
func (this *SimpleAuthenticator) Open(buffer *alloc.Buffer) bool {
	// Checksum (4) + length (2) is the smallest possible frame; anything shorter
	// would make the slicing below panic or read bytes outside the frame.
	if buffer.Len() < 6 {
		return false
	}

	dataLen := buffer.Len()
	// xorbkd is handed a length rounded up to a multiple of 4 (presumably it
	// works on 4-byte words — TODO confirm). xtra == 4 means already aligned.
	xtra := 4 - dataLen%4
	if xtra != 4 {
		buffer.Slice(0, dataLen+xtra)
	}
	xorbkd(buffer.Value)
	if xtra != 4 {
		buffer.Slice(0, dataLen)
	}

	fnvHash := fnv.New32a()
	fnvHash.Write(buffer.Value[4:])
	if serial.BytesToUint32(buffer.Value[:4]) != fnvHash.Sum32() {
		return false
	}

	length := serial.BytesToUint16(buffer.Value[4:6])
	if buffer.Len()-6 != int(length) {
		return false
	}

	buffer.SliceFrom(6)

	return true
}
예제 #2
0
파일: crypt.go 프로젝트: v2ray/v2ray-core
// Open implements cipher.AEAD.Open(). It applies xorbkd to a copy of
// cipherText, checks the leading 4-byte FNV-1a (32-bit) checksum over the rest
// of the frame and the 2-byte payload-length field, and on success appends the
// payload to dst and returns the extended slice. nonce and extra are unused.
func (v *SimpleAuthenticator) Open(dst, nonce, cipherText, extra []byte) ([]byte, error) {
	// Checksum (4) + length (2) is the smallest valid frame.
	if len(cipherText) < 6 {
		return nil, crypto.ErrAuthenticationFailed
	}

	// Decode a copy placed after any existing content of dst, so a non-empty
	// dst prefix is neither passed to xorbkd, hashed, nor discarded.
	start := len(dst)
	dst = append(dst, cipherText...)
	frame := dst[start:]
	frameLen := len(frame)

	// xorbkd is handed a length rounded up to a multiple of 4; the padding is
	// dropped again right after.
	if rem := frameLen % 4; rem != 0 {
		frame = append(frame, make([]byte, 4-rem)...)
	}
	xorbkd(frame)
	frame = frame[:frameLen]

	fnvHash := fnv.New32a()
	fnvHash.Write(frame[4:])
	if serial.BytesToUint32(frame[:4]) != fnvHash.Sum32() {
		return nil, crypto.ErrAuthenticationFailed
	}

	length := serial.BytesToUint16(frame[4:6])
	if frameLen-6 != int(length) {
		return nil, crypto.ErrAuthenticationFailed
	}

	// Place the payload directly after the caller's prefix. append copies with
	// memmove semantics, so overlap with frame (when it still aliases dst) is safe.
	return append(dst[:start], frame[6:]...), nil
}
예제 #3
0
// UnmarshalCommand parses a response command. data carries a 4-byte
// authentication value followed by the command payload; the value must equal
// Authenticate(payload). cmdID picks the factory that decodes the payload —
// only 1 (switch account) is known; anything else yields ErrUnknownCommand.
func UnmarshalCommand(cmdID byte, data []byte) (protocol.ResponseCommand, error) {
	// A payload of at least one byte is required after the auth prefix.
	if len(data) <= 4 {
		return nil, transport.ErrCorruptedPacket
	}

	payload := data[4:]
	want := Authenticate(payload)
	got := serial.BytesToUint32(data[:4])
	if got != want {
		return nil, transport.ErrCorruptedPacket
	}

	if cmdID != 1 {
		return nil, ErrUnknownCommand
	}
	var factory CommandFactory = new(CommandSwitchAccountFactory)
	return factory.Unmarshal(payload)
}
예제 #4
0
// UnmarshalCommand decodes a VMess response command from data, which starts
// with a 4-byte authentication value that must match Authenticate over the
// remaining bytes. cmdID selects the decoder; only command 1 (switch account)
// is recognised.
func UnmarshalCommand(cmdID byte, data []byte) (protocol.ResponseCommand, error) {
	if len(data) <= 4 {
		return nil, errors.New("VMess|Command: Insufficient length.")
	}

	body := data[4:]
	expected := Authenticate(body)
	if actual := serial.BytesToUint32(data[:4]); actual != expected {
		return nil, errors.New("VMess|Command: Invalid auth.")
	}

	if cmdID != 1 {
		return nil, ErrUnknownCommand
	}
	var factory CommandFactory = new(CommandSwitchAccountFactory)
	return factory.Unmarshal(body)
}
예제 #5
0
// ReadSegment parses one segment from the front of buf and returns it
// together with the unconsumed remainder of buf. It returns (nil, nil) when
// buf does not hold a complete segment.
//
// Layout as parsed below (integers decoded with the serial helpers):
//
//	header:      conv uint16 | cmd byte | option byte
//	CommandData: timestamp uint32 | number uint32 | sendingNext uint32 |
//	             dataLen uint16 | data [dataLen]byte
//	CommandACK:  receivingWindow uint32 | receivingNext uint32 |
//	             timestamp uint32 | count byte | numbers [count]uint32
//	otherwise:   sendingNext uint32 | receivingNext uint32 | peerRTO uint32
//
// Every length check happens before the segment is allocated, so truncated
// input never produces a half-filled segment that is then dropped.
func ReadSegment(buf []byte) (Segment, []byte) {
	if len(buf) <= 4 {
		return nil, nil
	}

	conv := serial.BytesToUint16(buf)
	cmd := Command(buf[2])
	opt := SegmentOption(buf[3])
	buf = buf[4:]

	switch cmd {
	case CommandData:
		// timestamp(4) + number(4) + sendingNext(4) + dataLen(2).
		const dataHeaderLen = 14
		if len(buf) < dataHeaderLen {
			return nil, nil
		}
		timestamp := serial.BytesToUint32(buf)
		number := serial.BytesToUint32(buf[4:])
		sendingNext := serial.BytesToUint32(buf[8:])
		dataLen := int(serial.BytesToUint16(buf[12:]))
		buf = buf[dataHeaderLen:]
		if len(buf) < dataLen {
			return nil, nil
		}

		seg := NewDataSegment()
		seg.Conv = conv
		seg.Option = opt
		seg.Timestamp = timestamp
		seg.Number = number
		seg.SendingNext = sendingNext
		seg.Data = AllocateBuffer().Clear().Append(buf[:dataLen])
		return seg, buf[dataLen:]

	case CommandACK:
		// receivingWindow(4) + receivingNext(4) + timestamp(4) + count(1).
		const ackHeaderLen = 13
		if len(buf) < ackHeaderLen {
			return nil, nil
		}
		receivingWindow := serial.BytesToUint32(buf)
		receivingNext := serial.BytesToUint32(buf[4:])
		timestamp := serial.BytesToUint32(buf[8:])
		count := int(buf[12])
		buf = buf[ackHeaderLen:]
		if len(buf) < count*4 {
			return nil, nil
		}

		seg := NewAckSegment()
		seg.Conv = conv
		seg.Option = opt
		seg.ReceivingWindow = receivingWindow
		seg.ReceivingNext = receivingNext
		seg.Timestamp = timestamp
		for i := 0; i < count; i++ {
			seg.PutNumber(serial.BytesToUint32(buf))
			buf = buf[4:]
		}
		return seg, buf
	}

	// Command-only segment: sendingNext(4) + receivingNext(4) + peerRTO(4).
	const cmdOnlyLen = 12
	if len(buf) < cmdOnlyLen {
		return nil, nil
	}

	seg := NewCmdOnlySegment()
	seg.Conv = conv
	seg.Command = cmd
	seg.Option = opt
	seg.SendingNext = serial.BytesToUint32(buf)
	seg.ReceivinNext = serial.BytesToUint32(buf[4:])
	seg.PeerRTO = serial.BytesToUint32(buf[8:])
	return seg, buf[cmdOnlyLen:]
}
예제 #6
0
// DecodeRequestHeader reads and authenticates a VMess request header from
// reader.
//
// The first protocol.IDBytesLen bytes identify the user via userValidator.
// The remainder is read through crypto.NewAesDecryptionStream, keyed with the
// user's command key and an IV of md5(hashTimestamp(timestamp)), and is laid
// out as parsed below:
//
//	[0] version | [1:17] body IV | [17:33] body key | [33] response header |
//	[34] option | [35:37] unused | [37] command | [38:40] port |
//	[40] address type | address | 4-byte FNV-1a checksum of all preceding bytes
//
// On success the body IV, body key and response header byte are stored on
// the session and the parsed header is returned.
func (this *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) {
	// 41 fixed bytes + 1 length byte + up to 255 domain bytes + 4 checksum
	// bytes fits within 512.
	buffer := make([]byte, 512)

	_, err := io.ReadFull(reader, buffer[:protocol.IDBytesLen])
	if err != nil {
		log.Info("Raw: Failed to read request header: ", err)
		return nil, io.EOF
	}

	user, timestamp, valid := this.userValidator.Get(buffer[:protocol.IDBytesLen])
	if !valid {
		return nil, protocol.ErrInvalidUser
	}

	timestampHash := md5.New()
	timestampHash.Write(hashTimestamp(timestamp))
	iv := timestampHash.Sum(nil)
	// A checked assertion: an account of another type must not panic the server.
	account, ok := user.Account.(*vmess.Account)
	if !ok {
		return nil, protocol.ErrInvalidUser
	}
	aesStream := crypto.NewAesDecryptionStream(account.ID.CmdKey(), iv)
	decryptor := crypto.NewCryptionReader(aesStream, reader)

	nBytes, err := io.ReadFull(decryptor, buffer[:41])
	if err != nil {
		log.Debug("Raw: Failed to read request header (", nBytes, " bytes): ", err)
		return nil, err
	}
	bufferLen := nBytes

	request := &protocol.RequestHeader{
		User:    user,
		Version: buffer[0],
	}

	if request.Version != Version {
		log.Info("Raw: Invalid protocol version ", request.Version)
		return nil, protocol.ErrInvalidVersion
	}

	this.requestBodyIV = append([]byte(nil), buffer[1:17]...)   // 16 bytes
	this.requestBodyKey = append([]byte(nil), buffer[17:33]...) // 16 bytes
	this.responseHeader = buffer[33]                            // 1 byte
	request.Option = protocol.RequestOption(buffer[34])         // 1 byte + 2 bytes reserved
	request.Command = protocol.RequestCommand(buffer[37])

	request.Port = v2net.PortFromBytes(buffer[38:40])

	switch buffer[40] {
	case AddrTypeIPv4:
		nBytes, err = io.ReadFull(decryptor, buffer[41:45]) // 4 bytes
		bufferLen += 4
		if err != nil {
			log.Debug("VMess: Failed to read target IPv4 (", nBytes, " bytes): ", err)
			return nil, err
		}
		request.Address = v2net.IPAddress(buffer[41:45])
	case AddrTypeIPv6:
		nBytes, err = io.ReadFull(decryptor, buffer[41:57]) // 16 bytes
		bufferLen += 16
		if err != nil {
			log.Debug("VMess: Failed to read target IPv6 (", nBytes, " bytes): ", err)
			return nil, err
		}
		request.Address = v2net.IPAddress(buffer[41:57])
	case AddrTypeDomain:
		nBytes, err = io.ReadFull(decryptor, buffer[41:42])
		if err != nil {
			log.Debug("VMess: Failed to read target domain (", nBytes, " bytes): ", err)
			return nil, err
		}
		domainLength := int(buffer[41])
		if domainLength == 0 {
			return nil, transport.ErrCorruptedPacket
		}
		nBytes, err = io.ReadFull(decryptor, buffer[42:42+domainLength])
		if err != nil {
			log.Debug("VMess: Failed to read target domain (", nBytes, " bytes): ", err)
			return nil, err
		}
		bufferLen += 1 + domainLength
		request.Address = v2net.DomainAddress(string(buffer[42 : 42+domainLength]))
	default:
		// Without this branch an unknown address type would be accepted and the
		// header returned with a nil Address.
		log.Info("VMess: Unknown address type ", buffer[40])
		return nil, transport.ErrCorruptedPacket
	}

	nBytes, err = io.ReadFull(decryptor, buffer[bufferLen:bufferLen+4])
	if err != nil {
		log.Debug("VMess: Failed to read checksum (", nBytes, " bytes): ", err)
		return nil, err
	}

	fnv1a := fnv.New32a()
	fnv1a.Write(buffer[:bufferLen])
	actualHash := fnv1a.Sum32()
	expectedHash := serial.BytesToUint32(buffer[bufferLen : bufferLen+4])

	if actualHash != expectedHash {
		return nil, transport.ErrCorruptedPacket
	}

	return request, nil
}
예제 #7
0
파일: auth.go 프로젝트: v2ray/v2ray-core
// Open implements AEAD.Open(). ciphertext must start with a 4-byte value
// equal to Authenticate over the rest of ciphertext; on success that rest is
// appended to dst. On failure dst is returned unchanged together with
// crypto.ErrAuthenticationFailed. nonce and additionalData are unused.
func (v *FnvAuthenticator) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
	// Too short to carry the authentication prefix; slicing below would panic.
	if len(ciphertext) < 4 {
		return dst, crypto.ErrAuthenticationFailed
	}
	if serial.BytesToUint32(ciphertext[:4]) != Authenticate(ciphertext[4:]) {
		return dst, crypto.ErrAuthenticationFailed
	}
	return append(dst, ciphertext[4:]...), nil
}
예제 #8
0
파일: server.go 프로젝트: v2ray/v2ray-core
// DecodeRequestHeader reads and authenticates a VMess request header from
// reader.
//
// The first protocol.IDBytesLen bytes identify the user via userValidator.
// The remainder is read through crypto.NewAesDecryptionStream, keyed with the
// account's command key and an IV of md5(hashTimestamp(timestamp)), and is
// laid out as parsed below:
//
//	[0] version | [1:17] body IV | [17:33] body key | [33] response header |
//	[34] option | [35] padding length (high nibble) / security (low nibble) |
//	[36] reserved | [37] command | [38:40] port | [40] address type |
//	address | padding | 4-byte FNV-1a checksum of all preceding bytes
//
// On success the body IV, body key and response header byte are stored on
// the session and the parsed header is returned.
func (v *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) {
	// 41 fixed + 1 + 255 domain + 15 padding + 4 checksum bytes fits within 512.
	buffer := make([]byte, 512)

	_, err := io.ReadFull(reader, buffer[:protocol.IDBytesLen])
	if err != nil {
		log.Info("VMess|Server: Failed to read request header: ", err)
		return nil, io.EOF
	}

	user, timestamp, valid := v.userValidator.Get(buffer[:protocol.IDBytesLen])
	if !valid {
		return nil, errors.New("VMess|Server: Invalid user.")
	}

	timestampHash := md5.New()
	timestampHash.Write(hashTimestamp(timestamp))
	iv := timestampHash.Sum(nil)
	account, err := user.GetTypedAccount()
	if err != nil {
		return nil, errors.Base(err).Message("VMess|Server: Failed to get user account.")
	}
	// Checked assertion: an account of another type must not panic the server.
	internalAccount, ok := account.(*vmess.InternalAccount)
	if !ok {
		return nil, errors.New("VMess|Server: Not a VMess account.")
	}

	aesStream := crypto.NewAesDecryptionStream(internalAccount.ID.CmdKey(), iv)
	decryptor := crypto.NewCryptionReader(aesStream, reader)

	nBytes, err := io.ReadFull(decryptor, buffer[:41])
	if err != nil {
		return nil, errors.Base(err).Message("VMess|Server: Failed to read request header.")
	}
	bufferLen := nBytes

	request := &protocol.RequestHeader{
		User:    user,
		Version: buffer[0],
	}

	if request.Version != Version {
		return nil, errors.New("VMess|Server: Invalid protocol version ", request.Version)
	}

	v.requestBodyIV = append([]byte(nil), buffer[1:17]...)   // 16 bytes
	v.requestBodyKey = append([]byte(nil), buffer[17:33]...) // 16 bytes
	v.responseHeader = buffer[33]                            // 1 byte
	request.Option = protocol.RequestOption(buffer[34])      // 1 byte
	paddingLen := int(buffer[35] >> 4)
	request.Security = protocol.NormSecurity(protocol.Security(buffer[35] & 0x0F))
	// 1 bytes reserved
	request.Command = protocol.RequestCommand(buffer[37])

	request.Port = v2net.PortFromBytes(buffer[38:40])

	switch buffer[40] {
	case AddrTypeIPv4:
		_, err = io.ReadFull(decryptor, buffer[41:45]) // 4 bytes
		bufferLen += 4
		if err != nil {
			return nil, errors.Base(err).Message("VMess|Server: Failed to read IPv4.")
		}
		request.Address = v2net.IPAddress(buffer[41:45])
	case AddrTypeIPv6:
		_, err = io.ReadFull(decryptor, buffer[41:57]) // 16 bytes
		bufferLen += 16
		if err != nil {
			return nil, errors.Base(err).Message("VMess|Server: Failed to read IPv6 address.")
		}
		request.Address = v2net.IPAddress(buffer[41:57])
	case AddrTypeDomain:
		_, err = io.ReadFull(decryptor, buffer[41:42])
		if err != nil {
			return nil, errors.Base(err).Message("VMess|Server: Failed to read domain.")
		}
		domainLength := int(buffer[41])
		if domainLength == 0 {
			return nil, errors.New("VMess|Server: Zero length domain.")
		}
		_, err = io.ReadFull(decryptor, buffer[42:42+domainLength])
		if err != nil {
			return nil, errors.Base(err).Message("VMess|Server: Failed to read domain.")
		}
		bufferLen += 1 + domainLength
		request.Address = v2net.DomainAddress(string(buffer[42 : 42+domainLength]))
	}

	if paddingLen > 0 {
		_, err = io.ReadFull(decryptor, buffer[bufferLen:bufferLen+paddingLen])
		if err != nil {
			// Keep the underlying cause, like every other read in this function.
			return nil, errors.Base(err).Message("VMess|Server: Failed to read padding.")
		}
		bufferLen += paddingLen
	}

	_, err = io.ReadFull(decryptor, buffer[bufferLen:bufferLen+4])
	if err != nil {
		return nil, errors.Base(err).Message("VMess|Server: Failed to read checksum.")
	}

	fnv1a := fnv.New32a()
	fnv1a.Write(buffer[:bufferLen])
	actualHash := fnv1a.Sum32()
	expectedHash := serial.BytesToUint32(buffer[bufferLen : bufferLen+4])

	if actualHash != expectedHash {
		return nil, errors.New("VMess|Server: Invalid auth.")
	}

	// An unrecognised address type leaves Address unset; reject it here.
	if request.Address == nil {
		return nil, errors.New("VMess|Server: Invalid remote address.")
	}

	return request, nil
}
예제 #9
0
// Read returns the next piece of payload from the chunked stream.
//
// Each chunk begins with a 6-byte header: a 2-byte length field and a 4-byte
// value passed to NewValidator. The chunk payload length is the length field
// minus 4. A payload length of 0 marks end of stream and yields io.EOF.
//
// A returned buffer may hold only part of a chunk; the chunk's validator is
// fed incrementally and checked once the whole chunk has been consumed, in
// which case transport.ErrCorruptedPacket is returned on mismatch. Bytes read
// beyond the current chunk are kept in this.last for the next call.
// this.chunkLength == -1 means the next bytes are a chunk header.
func (this *AuthChunkReader) Read() (*alloc.Buffer, error) {
	var buffer *alloc.Buffer
	if this.last != nil {
		buffer = this.last
		this.last = nil
	} else {
		buffer = alloc.NewBufferWithSize(4096).Clear()
	}

	if this.chunkLength == -1 {
		for buffer.Len() < 6 {
			_, err := buffer.FillFrom(this.reader)
			if err != nil {
				buffer.Release()
				return nil, io.ErrUnexpectedEOF
			}
		}
		length := serial.BytesToUint16(buffer.Value[:2])
		// The length field includes the 4 validator bytes, so values below 4
		// are malformed. They must be rejected: -1 would be mistaken for the
		// "need a header" sentinel, and other negatives make the
		// buffer.Value[:this.chunkLength] slice below panic.
		if length < 4 {
			buffer.Release()
			return nil, transport.ErrCorruptedPacket
		}
		this.chunkLength = int(length) - 4
		this.validator = NewValidator(serial.BytesToUint32(buffer.Value[2:6]))
		buffer.SliceFrom(6)
		if buffer.Len() < this.chunkLength && this.chunkLength <= 2048 {
			_, err := buffer.FillFrom(this.reader)
			if err != nil {
				buffer.Release()
				return nil, io.ErrUnexpectedEOF
			}
		}
	} else if buffer.Len() < this.chunkLength {
		_, err := buffer.FillFrom(this.reader)
		if err != nil {
			buffer.Release()
			return nil, io.ErrUnexpectedEOF
		}
	}

	// A zero-length chunk terminates the stream.
	if this.chunkLength == 0 {
		buffer.Release()
		return nil, io.EOF
	}

	if buffer.Len() < this.chunkLength {
		// Only part of the chunk is available: feed it to the validator and
		// remember how much of the chunk is still outstanding.
		this.validator.Consume(buffer.Value)
		this.chunkLength -= buffer.Len()
	} else {
		this.validator.Consume(buffer.Value[:this.chunkLength])
		if !this.validator.Validate() {
			buffer.Release()
			return nil, transport.ErrCorruptedPacket
		}
		// Carry any bytes belonging to the next chunk over to the next call.
		leftLength := buffer.Len() - this.chunkLength
		if leftLength > 0 {
			this.last = alloc.NewBufferWithSize(leftLength + 4096).Clear()
			this.last.Append(buffer.Value[this.chunkLength:])
			buffer.Slice(0, this.chunkLength)
		}

		this.chunkLength = -1
		this.validator = nil
	}

	return buffer, nil
}