Ejemplo n.º 1
0
// ReadUDPRequest parses the header of a SOCKS5 UDP datagram held in packet.
//
// The first two bytes are reserved and ignored, byte 2 is stored as the
// fragment number and byte 3 selects the address encoding (IPv4, IPv6 or a
// length-prefixed domain name), followed by a 2-byte port. Whatever follows
// the port is copied into request.Data; Data is left unset when nothing
// follows.
//
// It returns transport.CorruptedPacket when packet is too short for the
// declared address form and ErrorUnknownAddressType for an unrecognised
// address type byte.
func ReadUDPRequest(packet []byte) (*Socks5UDPRequest, error) {
	// Reserved (2) + fragment (1) + address type (1) + at least one more byte.
	if len(packet) < 5 {
		return nil, transport.CorruptedPacket
	}

	request := &Socks5UDPRequest{Fragment: packet[2]}

	payloadStart := 0
	switch kind := packet[3]; kind {
	case AddrTypeIPv4:
		const portAt = 8 // 4 header bytes + 4 address bytes
		if len(packet) < portAt+2 {
			return nil, transport.CorruptedPacket
		}
		request.Port = v2net.PortFromBytes(packet[portAt : portAt+2])
		request.Address = v2net.IPAddress(packet[4:portAt])
		payloadStart = portAt + 2

	case AddrTypeIPv6:
		const portAt = 20 // 4 header bytes + 16 address bytes
		if len(packet) < portAt+2 {
			return nil, transport.CorruptedPacket
		}
		request.Port = v2net.PortFromBytes(packet[portAt : portAt+2])
		request.Address = v2net.IPAddress(packet[4:portAt])
		payloadStart = portAt + 2

	case AddrTypeDomain:
		// packet[4] is the length of the name that starts at packet[5].
		nameEnd := 5 + int(packet[4])
		portEnd := nameEnd + 2
		if len(packet) < portEnd {
			return nil, transport.CorruptedPacket
		}
		name := string(packet[5:nameEnd])
		request.Port = v2net.PortFromBytes(packet[nameEnd:portEnd])
		// A "domain" that is really an IP literal is stored as an IP address.
		if ip := net.ParseIP(name); ip != nil {
			request.Address = v2net.IPAddress(ip)
		} else {
			request.Address = v2net.DomainAddress(name)
		}
		payloadStart = portEnd

	default:
		log.Warning("Unknown address type ", kind)
		return nil, ErrorUnknownAddressType
	}

	if payloadStart < len(packet) {
		request.Data = alloc.NewBuffer().Clear().Append(packet[payloadStart:])
	}

	return request, nil
}
Ejemplo n.º 2
0
// Unmarshal decodes a serialized SwitchAccount from data into the receiver.
//
// Layout, in the order it is read here:
//
//	1 byte   host length N
//	N bytes  host, parsed with v2net.ParseAddress
//	2 bytes  port
//	16 bytes ID
//	2 bytes  AlterIds
//	1 byte   Level
//	1 byte   ValidMin
//
// It returns transport.CorruptedPacket when data is empty or ends before any
// of these fields. Fields decoded before the truncation point have already
// been written to the receiver when an error is returned.
func (this *SwitchAccount) Unmarshal(data []byte) error {
	// Guard the length-prefix read; indexing an empty slice would panic.
	if len(data) == 0 {
		return transport.CorruptedPacket
	}
	lenHost := int(data[0])
	if len(data) < lenHost+1 {
		return transport.CorruptedPacket
	}
	this.Host = v2net.ParseAddress(string(data[1 : 1+lenHost]))

	portStart := 1 + lenHost
	if len(data) < portStart+2 {
		return transport.CorruptedPacket
	}
	this.Port = v2net.PortFromBytes(data[portStart : portStart+2])

	idStart := portStart + 2
	if len(data) < idStart+16 {
		return transport.CorruptedPacket
	}
	// The parse error is discarded, as before; the slice is always 16 bytes.
	// NOTE(review): confirm uuid.ParseBytes can only fail on a wrong length.
	this.ID, _ = uuid.ParseBytes(data[idStart : idStart+16])

	alterIdStart := idStart + 16
	if len(data) < alterIdStart+2 {
		return transport.CorruptedPacket
	}
	this.AlterIds = serial.ParseUint16(data[alterIdStart : alterIdStart+2])

	// Single-byte fields need index+1 bytes to be present.
	levelStart := alterIdStart + 2
	if len(data) < levelStart+1 {
		return transport.CorruptedPacket
	}
	this.Level = vmess.UserLevel(data[levelStart])

	timeStart := levelStart + 1
	if len(data) < timeStart+1 {
		return transport.CorruptedPacket
	}
	this.ValidMin = data[timeStart]
	return nil
}
Ejemplo n.º 3
0
// Unmarshal decodes a serialized SwitchAccount from data into the receiver.
//
// Layout, in the order it is read here:
//
//	1 byte   host length N
//	N bytes  host, parsed with v2net.ParseAddress
//	2 bytes  port
//	16 bytes ID
//	2 bytes  AlterIds
//	8 bytes  ValidUntil, passed to time.Unix as seconds
//
// It returns transport.CorruptedPacket when data is empty or ends before any
// of these fields. Fields decoded before the truncation point have already
// been written to the receiver when an error is returned.
func (this *SwitchAccount) Unmarshal(data []byte) error {
	// Guard the length-prefix read; indexing an empty slice would panic.
	if len(data) == 0 {
		return transport.CorruptedPacket
	}
	lenHost := int(data[0])
	if len(data) < lenHost+1 {
		return transport.CorruptedPacket
	}
	this.Host = v2net.ParseAddress(string(data[1 : 1+lenHost]))

	portStart := 1 + lenHost
	if len(data) < portStart+2 {
		return transport.CorruptedPacket
	}
	this.Port = v2net.PortFromBytes(data[portStart : portStart+2])

	idStart := portStart + 2
	if len(data) < idStart+16 {
		return transport.CorruptedPacket
	}
	// The parse error is discarded, as before; the slice is always 16 bytes.
	// NOTE(review): confirm uuid.ParseBytes can only fail on a wrong length.
	this.ID, _ = uuid.ParseBytes(data[idStart : idStart+16])

	alterIdStart := idStart + 16
	if len(data) < alterIdStart+2 {
		return transport.CorruptedPacket
	}
	this.AlterIds = serial.ParseUint16(data[alterIdStart : alterIdStart+2])

	timeStart := alterIdStart + 2
	if len(data) < timeStart+8 {
		return transport.CorruptedPacket
	}
	this.ValidUntil = time.Unix(serial.BytesLiteral(data[timeStart:timeStart+8]).Int64Value(), 0)
	return nil
}
Ejemplo n.º 4
0
// ReadRequest reads a Shadowsocks request header from reader: one address
// type byte, the address in the form that byte selects (4-byte IPv4, 16-byte
// IPv6, or a domain name preceded by a 1-byte length), then a 2-byte port.
//
// Every failure, whether a read error or an unknown address type, is logged
// and reported to the caller as transport.CorruptedPacket.
func ReadRequest(reader io.Reader) (*Request, error) {
	scratch := alloc.NewSmallBuffer()
	defer scratch.Release()

	if _, err := v2net.ReadAllBytes(reader, scratch.Value[:1]); err != nil {
		log.Error("Shadowsocks: Failed to read address type: ", err)
		return nil, transport.CorruptedPacket
	}

	request := &Request{}

	switch addrType := scratch.Value[0]; addrType {
	case AddrTypeIPv4:
		ipv4 := scratch.Value[:4]
		if _, err := v2net.ReadAllBytes(reader, ipv4); err != nil {
			log.Error("Shadowsocks: Failed to read IPv4 address: ", err)
			return nil, transport.CorruptedPacket
		}
		request.Address = v2net.IPAddress(ipv4)

	case AddrTypeIPv6:
		ipv6 := scratch.Value[:16]
		if _, err := v2net.ReadAllBytes(reader, ipv6); err != nil {
			log.Error("Shadowsocks: Failed to read IPv6 address: ", err)
			return nil, transport.CorruptedPacket
		}
		request.Address = v2net.IPAddress(ipv6)

	case AddrTypeDomain:
		if _, err := v2net.ReadAllBytes(reader, scratch.Value[:1]); err != nil {
			log.Error("Shadowsocks: Failed to read domain lenth: ", err)
			return nil, transport.CorruptedPacket
		}
		// Fix the name's extent before the next read overwrites the length byte.
		name := scratch.Value[:int(scratch.Value[0])]
		if _, err := v2net.ReadAllBytes(reader, name); err != nil {
			log.Error("Shadowsocks: Failed to read domain: ", err)
			return nil, transport.CorruptedPacket
		}
		request.Address = v2net.DomainAddress(string(name))

	default:
		log.Error("Shadowsocks: Unknown address type: ", addrType)
		return nil, transport.CorruptedPacket
	}

	port := scratch.Value[:2]
	if _, err := v2net.ReadAllBytes(reader, port); err != nil {
		log.Error("Shadowsocks: Failed to read port: ", err)
		return nil, transport.CorruptedPacket
	}
	request.Port = v2net.PortFromBytes(port)

	return request, nil
}
Ejemplo n.º 5
0
// ReadRequest reads a SOCKS5 request from reader: a 4-byte head (version,
// command, one reserved byte, address type), the address in the form the
// address type selects, and a 2-byte port.
//
// Read errors from io.ReadFull are returned unchanged. An unrecognised
// address type is logged and yields transport.ErrCorruptedPacket. When an
// error occurs after the head has been read, the partially filled request is
// returned together with the error.
func ReadRequest(reader io.Reader) (request *Socks5Request, err error) {
	scratch := alloc.NewSmallBuffer()
	defer scratch.Release()

	head := scratch.Value[:4]
	if _, err = io.ReadFull(reader, head); err != nil {
		return nil, err
	}

	// head[2] is a reserved field and is not stored.
	request = &Socks5Request{
		Version:  head[0],
		Command:  head[1],
		AddrType: head[3],
	}

	switch request.AddrType {
	case AddrTypeIPv4:
		_, err = io.ReadFull(reader, request.IPv4[:])

	case AddrTypeIPv6:
		_, err = io.ReadFull(reader, request.IPv6[:])

	case AddrTypeDomain:
		lengthByte := scratch.Value[:1]
		if _, err = io.ReadFull(reader, lengthByte); err != nil {
			return request, err
		}
		name := scratch.Value[:lengthByte[0]]
		if _, err = io.ReadFull(reader, name); err != nil {
			return request, err
		}
		// The string conversion copies, so Domain does not alias the buffer.
		request.Domain = string(name)

	default:
		log.Warning("Socks: Unexpected address type ", request.AddrType)
		return request, transport.ErrCorruptedPacket
	}
	if err != nil {
		return request, err
	}

	port := scratch.Value[:2]
	if _, err = io.ReadFull(reader, port); err != nil {
		return request, err
	}
	request.Port = v2net.PortFromBytes(port)

	return request, nil
}
Ejemplo n.º 6
0
// ReadAuthentication reads the opening message of a SOCKS connection from
// reader with a single Read call and interprets it by its version byte.
//
//   - If the first byte equals socks4Version, auth4 is filled from the 8-byte
//     SOCKS4 head (version, command, 2-byte port, 4-byte IP) and err is set to
//     Socks4Downgrade.
//   - If it equals socksVersion, auth is filled with the version, the method
//     count and the method bytes.
//
// Errors: the Read error is returned unchanged; transport.ErrCorruptedPacket
// when too few bytes were read; proxy.ErrInvalidProtocolVersion for any other
// version byte; proxy.ErrInvalidAuthentication when the method count is zero
// or does not match the number of method bytes actually read.
func ReadAuthentication(reader io.Reader) (auth Socks5AuthenticationRequest, auth4 Socks4AuthenticationRequest, err error) {
	buffer := alloc.NewSmallBuffer()
	defer buffer.Release()

	nBytes, err := reader.Read(buffer.Value)
	if err != nil {
		return
	}
	if nBytes < 2 {
		log.Warning("Socks: expected 2 bytes read, but only ", nBytes, " bytes read")
		err = transport.ErrCorruptedPacket
		return
	}

	if buffer.Value[0] == socks4Version {
		// The SOCKS4 fields below span bytes 0..8. Without this check a short
		// read would be parsed from bytes that were never received.
		if nBytes < 8 {
			log.Warning("Socks: expected 8 bytes read, but only ", nBytes, " bytes read")
			err = transport.ErrCorruptedPacket
			return
		}
		auth4.Version = buffer.Value[0]
		auth4.Command = buffer.Value[1]
		auth4.Port = v2net.PortFromBytes(buffer.Value[2:4])
		copy(auth4.IP[:], buffer.Value[4:8])
		err = Socks4Downgrade
		return
	}

	auth.version = buffer.Value[0]
	if auth.version != socksVersion {
		log.Warning("Socks: Unknown protocol version ", auth.version)
		err = proxy.ErrInvalidProtocolVersion
		return
	}

	auth.nMethods = buffer.Value[1]
	if auth.nMethods == 0 {
		log.Warning("Socks: Zero length of authentication methods")
		err = proxy.ErrInvalidAuthentication
		return
	}

	// Everything after the 2-byte head must be exactly the advertised methods.
	if nBytes-2 != int(auth.nMethods) {
		log.Warning("Socks: Unmatching number of auth methods, expecting ", auth.nMethods, ", but got ", nBytes-2)
		err = proxy.ErrInvalidAuthentication
		return
	}
	copy(auth.authMethods[:], buffer.Value[2:nBytes])
	return
}
Ejemplo n.º 7
0
// Unmarshal decodes a protocol.CommandSwitchAccount from data.
//
// Layout, in the order it is read here:
//
//	1 byte   host length N (Host is left unset when N is 0)
//	N bytes  host, parsed with v2net.ParseAddress
//	2 bytes  port
//	16 bytes ID
//	2 bytes  AlterIds
//	1 byte   Level
//	1 byte   ValidMin
//
// It returns the decoded command, or nil and transport.ErrCorruptedPacket
// when data is empty or ends before any of these fields.
func (this *CommandSwitchAccountFactory) Unmarshal(data []byte) (interface{}, error) {
	cmd := new(protocol.CommandSwitchAccount)
	if len(data) == 0 {
		return nil, transport.ErrCorruptedPacket
	}
	lenHost := int(data[0])
	if len(data) < lenHost+1 {
		return nil, transport.ErrCorruptedPacket
	}
	if lenHost > 0 {
		cmd.Host = v2net.ParseAddress(string(data[1 : 1+lenHost]))
	}

	portStart := 1 + lenHost
	if len(data) < portStart+2 {
		return nil, transport.ErrCorruptedPacket
	}
	cmd.Port = v2net.PortFromBytes(data[portStart : portStart+2])

	idStart := portStart + 2
	if len(data) < idStart+16 {
		return nil, transport.ErrCorruptedPacket
	}
	// The parse error is discarded, as before; the slice is always 16 bytes.
	// NOTE(review): confirm uuid.ParseBytes can only fail on a wrong length.
	cmd.ID, _ = uuid.ParseBytes(data[idStart : idStart+16])

	alterIdStart := idStart + 16
	if len(data) < alterIdStart+2 {
		return nil, transport.ErrCorruptedPacket
	}
	cmd.AlterIds = serial.BytesToUint16(data[alterIdStart : alterIdStart+2])

	levelStart := alterIdStart + 2
	if len(data) < levelStart+1 {
		return nil, transport.ErrCorruptedPacket
	}
	cmd.Level = protocol.UserLevel(data[levelStart])

	// data[timeStart] needs timeStart+1 bytes; the old check was one short and
	// let a truncated packet panic on the index below.
	timeStart := levelStart + 1
	if len(data) < timeStart+1 {
		return nil, transport.ErrCorruptedPacket
	}
	cmd.ValidMin = data[timeStart]
	return cmd, nil
}
Ejemplo n.º 8
0
// Read reads a VMessRequest from a byte stream.
//
// The first proto.IDBytesLen bytes are looked up in this.vUserSet; an unknown
// ID yields proxy.ErrorInvalidAuthentication. The remainder of the header is
// read through an AES decryption stream built from the user's command key and
// an IV derived by hashing the matched timestamp.
//
// Decrypted header layout, as read here:
//
//	byte  0       version (must equal Version)
//	bytes 1..16   request IV
//	bytes 17..32  request key
//	byte  33      response header
//	byte  34      option (bytes 35..36 are reserved)
//	byte  37      command
//	bytes 38..39  port
//	byte  40      address type, followed by the address itself
//	last 4 bytes  big-endian FNV-1a 32-bit hash of all preceding header bytes
//
// Read errors are returned unchanged. transport.ErrorCorruptedPacket is
// returned for an empty domain name, an unknown address type, or a checksum
// mismatch; proxy.ErrorInvalidProtocolVersion for a wrong version byte.
func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
	buffer := alloc.NewSmallBuffer()
	defer buffer.Release()

	nBytes, err := io.ReadFull(reader, buffer.Value[:proto.IDBytesLen])
	if err != nil {
		log.Debug("VMess: Failed to read request ID (", nBytes, " bytes): ", err)
		return nil, err
	}

	userObj, timeSec, valid := this.vUserSet.GetUser(buffer.Value[:nBytes])
	if !valid {
		return nil, proxy.ErrorInvalidAuthentication
	}

	timestampHash := TimestampHash()
	timestampHash.Write(timeSec.HashBytes())
	iv := timestampHash.Sum(nil)
	aesStream, err := v2crypto.NewAesDecryptionStream(userObj.ID.CmdKey(), iv)
	if err != nil {
		log.Debug("VMess: Failed to create AES stream: ", err)
		return nil, err
	}

	decryptor := v2crypto.NewCryptionReader(aesStream, reader)

	nBytes, err = io.ReadFull(decryptor, buffer.Value[:41])
	if err != nil {
		log.Debug("VMess: Failed to read request header (", nBytes, " bytes): ", err)
		return nil, err
	}
	// bufferLen tracks how many decrypted bytes are covered by the checksum.
	bufferLen := nBytes

	request := &VMessRequest{
		User:    userObj,
		Version: buffer.Value[0],
	}

	if request.Version != Version {
		log.Warning("VMess: Invalid protocol version ", request.Version)
		return nil, proxy.ErrorInvalidProtocolVersion
	}

	// IV and key are copied because buffer is released when Read returns.
	request.RequestIV = append([]byte(nil), buffer.Value[1:17]...)   // 16 bytes
	request.RequestKey = append([]byte(nil), buffer.Value[17:33]...) // 16 bytes
	request.ResponseHeader = buffer.Value[33]                        // 1 byte
	request.Option = buffer.Value[34]                                // 1 byte + 2 bytes reserved
	request.Command = buffer.Value[37]

	request.Port = v2net.PortFromBytes(buffer.Value[38:40])

	switch buffer.Value[40] {
	case addrTypeIPv4:
		nBytes, err = io.ReadFull(decryptor, buffer.Value[41:45]) // 4 bytes
		if err != nil {
			log.Debug("VMess: Failed to read target IPv4 (", nBytes, " bytes): ", err)
			return nil, err
		}
		bufferLen += 4
		request.Address = v2net.IPAddress(buffer.Value[41:45])
	case addrTypeIPv6:
		nBytes, err = io.ReadFull(decryptor, buffer.Value[41:57]) // 16 bytes
		if err != nil {
			log.Debug("VMess: Failed to read target IPv6 (", nBytes, " bytes): ", err)
			return nil, err
		}
		bufferLen += 16
		request.Address = v2net.IPAddress(buffer.Value[41:57])
	case addrTypeDomain:
		nBytes, err = io.ReadFull(decryptor, buffer.Value[41:42])
		if err != nil {
			log.Debug("VMess: Failed to read target domain (", nBytes, " bytes): ", err)
			return nil, err
		}
		domainLength := int(buffer.Value[41])
		if domainLength == 0 {
			return nil, transport.ErrorCorruptedPacket
		}
		nBytes, err = io.ReadFull(decryptor, buffer.Value[42:42+domainLength])
		if err != nil {
			log.Debug("VMess: Failed to read target domain (", nBytes, " bytes): ", err)
			return nil, err
		}
		bufferLen += 1 + domainLength
		domainBytes := append([]byte(nil), buffer.Value[42:42+domainLength]...)
		request.Address = v2net.DomainAddress(string(domainBytes))
	default:
		// Without this branch an unrecognised type would fall through and
		// produce a request whose Address was never set.
		log.Warning("VMess: Unknown address type ", buffer.Value[40])
		return nil, transport.ErrorCorruptedPacket
	}

	nBytes, err = io.ReadFull(decryptor, buffer.Value[bufferLen:bufferLen+4])
	if err != nil {
		log.Debug("VMess: Failed to read checksum (", nBytes, " bytes): ", err)
		return nil, err
	}

	fnv1a := fnv.New32a()
	fnv1a.Write(buffer.Value[:bufferLen])
	actualHash := fnv1a.Sum32()
	expectedHash := binary.BigEndian.Uint32(buffer.Value[bufferLen : bufferLen+4])

	if actualHash != expectedHash {
		return nil, transport.ErrorCorruptedPacket
	}

	return request, nil
}
Ejemplo n.º 9
0
// ReadRequest reads a Shadowsocks request header from reader and, when the
// one-time-auth (OTA) flag is set, verifies the header's authentication tag.
//
// The first byte carries the address type in its low 4 bits and the OTA flag
// in bit 0x10. It is followed by the address (4-byte IPv4, 16-byte IPv6, or a
// 1-byte-length-prefixed domain) and a 2-byte port.
//
// When udp is true, one further Read pulls the rest of the datagram into the
// buffer and it becomes request.UDPPayload; with OTA the last AuthSize bytes
// of that read are the tag and are excluded from the payload. When udp is
// false and OTA is set, exactly AuthSize tag bytes are read after the port.
//
// auth computes the expected tag over the header bytes (plus the payload in
// the UDP case). All failures are logged and reported as
// transport.ErrorCorruptedPacket, except a failing UDP payload read, which is
// only logged.
func ReadRequest(reader io.Reader, auth *Authenticator, udp bool) (*Request, error) {
	buffer := alloc.NewSmallBuffer()
	defer buffer.Release()

	_, err := io.ReadFull(reader, buffer.Value[:1])
	if err != nil {
		log.Error("Shadowsocks: Failed to read address type: ", err)
		return nil, transport.ErrorCorruptedPacket
	}
	// lenBuffer counts the bytes of buffer.Value that are covered by the tag.
	lenBuffer := 1

	request := new(Request)

	addrType := (buffer.Value[0] & 0x0F)
	if (buffer.Value[0] & 0x10) == 0x10 {
		request.OTA = true
	}
	switch addrType {
	case AddrTypeIPv4:
		_, err := io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+4])
		if err != nil {
			log.Error("Shadowsocks: Failed to read IPv4 address: ", err)
			return nil, transport.ErrorCorruptedPacket
		}
		request.Address = v2net.IPAddress(buffer.Value[lenBuffer : lenBuffer+4])
		lenBuffer += 4
	case AddrTypeIPv6:
		_, err := io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+16])
		if err != nil {
			log.Error("Shadowsocks: Failed to read IPv6 address: ", err)
			return nil, transport.ErrorCorruptedPacket
		}
		request.Address = v2net.IPAddress(buffer.Value[lenBuffer : lenBuffer+16])
		lenBuffer += 16
	case AddrTypeDomain:
		_, err := io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+1])
		if err != nil {
			log.Error("Shadowsocks: Failed to read domain lenth: ", err)
			return nil, transport.ErrorCorruptedPacket
		}
		domainLength := int(buffer.Value[lenBuffer])
		lenBuffer++
		_, err = io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+domainLength])
		if err != nil {
			log.Error("Shadowsocks: Failed to read domain: ", err)
			return nil, transport.ErrorCorruptedPacket
		}
		request.Address = v2net.DomainAddress(string(buffer.Value[lenBuffer : lenBuffer+domainLength]))
		lenBuffer += domainLength
	default:
		log.Error("Shadowsocks: Unknown address type: ", addrType)
		return nil, transport.ErrorCorruptedPacket
	}

	_, err = io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+2])
	if err != nil {
		log.Error("Shadowsocks: Failed to read port: ", err)
		return nil, transport.ErrorCorruptedPacket
	}

	request.Port = v2net.PortFromBytes(buffer.Value[lenBuffer : lenBuffer+2])
	lenBuffer += 2

	var authBytes []byte

	if udp {
		nBytes, err := reader.Read(buffer.Value[lenBuffer:])
		if err != nil {
			// Kept non-fatal as before: the error is logged and parsing
			// continues with whatever nBytes reports.
			log.Error("Shadowsocks: Failed to read UDP payload: ", err)
		}
		buffer.Slice(0, lenBuffer+nBytes)
		if request.OTA {
			// The tag occupies the last AuthSize bytes of what was read. A
			// shorter read would make the slice bounds below invert and panic.
			if nBytes < AuthSize {
				log.Error("Shadowsocks: UDP payload too short for OTA: ", nBytes)
				return nil, transport.ErrorCorruptedPacket
			}
			authBytes = buffer.Value[lenBuffer+nBytes-AuthSize:]
			request.UDPPayload = alloc.NewSmallBuffer().Clear().Append(buffer.Value[lenBuffer : lenBuffer+nBytes-AuthSize])
			// Extend the authenticated region over the payload.
			lenBuffer = lenBuffer + nBytes - AuthSize
		} else {
			request.UDPPayload = alloc.NewSmallBuffer().Clear().Append(buffer.Value[lenBuffer:])
		}
	} else {
		if request.OTA {
			authBytes = buffer.Value[lenBuffer : lenBuffer+AuthSize]
			_, err = io.ReadFull(reader, authBytes)
			if err != nil {
				log.Error("Shadowsocks: Failed to read OTA: ", err)
				return nil, transport.ErrorCorruptedPacket
			}
		}
	}

	if request.OTA {
		actualAuth := auth.Authenticate(nil, buffer.Value[0:lenBuffer])
		if !serial.BytesLiteral(actualAuth).Equals(serial.BytesLiteral(authBytes)) {
			log.Error("Shadowsocks: Invalid OTA: ", actualAuth)
			return nil, transport.ErrorCorruptedPacket
		}
	}

	return request, nil
}
Ejemplo n.º 10
0
// DecodeRequestHeader reads and decodes a request header from reader.
//
// The first protocol.IDBytesLen bytes are looked up in this.userValidator; an
// unknown ID yields protocol.ErrorInvalidUser. The remainder is read through
// an AES decryption stream built from the user's command key and an IV that
// is the MD5 of hashTimestamp(timestamp).
//
// Decrypted header layout, as read here:
//
//	byte  0       version (must equal Version)
//	bytes 1..16   request body IV   (stored on the session)
//	bytes 17..32  request body key  (stored on the session)
//	byte  33      response header   (stored on the session)
//	byte  34      option (bytes 35..36 are reserved)
//	byte  37      command
//	bytes 38..39  port
//	byte  40      address type, followed by the address itself
//	last 4 bytes  checksum: FNV-1a 32-bit hash of all preceding header bytes
//
// Read errors are returned unchanged. transport.ErrorCorruptedPacket is
// returned for an empty domain name, an unknown address type, or a checksum
// mismatch; protocol.ErrorInvalidVersion for a wrong version byte.
func (this *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) {
	buffer := alloc.NewSmallBuffer()
	defer buffer.Release()

	_, err := io.ReadFull(reader, buffer.Value[:protocol.IDBytesLen])
	if err != nil {
		log.Error("Raw: Failed to read request header: ", err)
		return nil, err
	}

	user, timestamp, valid := this.userValidator.Get(buffer.Value[:protocol.IDBytesLen])
	if !valid {
		return nil, protocol.ErrorInvalidUser
	}

	timestampHash := md5.New()
	timestampHash.Write(hashTimestamp(timestamp))
	iv := timestampHash.Sum(nil)
	aesStream := crypto.NewAesDecryptionStream(user.ID.CmdKey(), iv)
	decryptor := crypto.NewCryptionReader(aesStream, reader)

	nBytes, err := io.ReadFull(decryptor, buffer.Value[:41])
	if err != nil {
		log.Debug("Raw: Failed to read request header (", nBytes, " bytes): ", err)
		return nil, err
	}
	// bufferLen tracks how many decrypted bytes are covered by the checksum.
	bufferLen := nBytes

	request := &protocol.RequestHeader{
		User:    user,
		Version: buffer.Value[0],
	}

	if request.Version != Version {
		log.Warning("Raw: Invalid protocol version ", request.Version)
		return nil, protocol.ErrorInvalidVersion
	}

	// IV and key are copied because buffer is released when this returns.
	this.requestBodyIV = append([]byte(nil), buffer.Value[1:17]...)   // 16 bytes
	this.requestBodyKey = append([]byte(nil), buffer.Value[17:33]...) // 16 bytes
	this.responseHeader = buffer.Value[33]                            // 1 byte
	request.Option = protocol.RequestOption(buffer.Value[34])         // 1 byte + 2 bytes reserved
	request.Command = protocol.RequestCommand(buffer.Value[37])

	request.Port = v2net.PortFromBytes(buffer.Value[38:40])

	switch buffer.Value[40] {
	case AddrTypeIPv4:
		nBytes, err = io.ReadFull(decryptor, buffer.Value[41:45]) // 4 bytes
		if err != nil {
			log.Debug("VMess: Failed to read target IPv4 (", nBytes, " bytes): ", err)
			return nil, err
		}
		bufferLen += 4
		request.Address = v2net.IPAddress(buffer.Value[41:45])
	case AddrTypeIPv6:
		nBytes, err = io.ReadFull(decryptor, buffer.Value[41:57]) // 16 bytes
		if err != nil {
			log.Debug("VMess: Failed to read target IPv6 (", nBytes, " bytes): ", err)
			return nil, err
		}
		bufferLen += 16
		request.Address = v2net.IPAddress(buffer.Value[41:57])
	case AddrTypeDomain:
		nBytes, err = io.ReadFull(decryptor, buffer.Value[41:42])
		if err != nil {
			log.Debug("VMess: Failed to read target domain (", nBytes, " bytes): ", err)
			return nil, err
		}
		domainLength := int(buffer.Value[41])
		if domainLength == 0 {
			return nil, transport.ErrorCorruptedPacket
		}
		nBytes, err = io.ReadFull(decryptor, buffer.Value[42:42+domainLength])
		if err != nil {
			log.Debug("VMess: Failed to read target domain (", nBytes, " bytes): ", err)
			return nil, err
		}
		bufferLen += 1 + domainLength
		domainBytes := append([]byte(nil), buffer.Value[42:42+domainLength]...)
		request.Address = v2net.DomainAddress(string(domainBytes))
	default:
		// Without this branch an unrecognised type would fall through and
		// produce a header whose Address was never set.
		log.Warning("Raw: Unknown address type ", buffer.Value[40])
		return nil, transport.ErrorCorruptedPacket
	}

	nBytes, err = io.ReadFull(decryptor, buffer.Value[bufferLen:bufferLen+4])
	if err != nil {
		log.Debug("VMess: Failed to read checksum (", nBytes, " bytes): ", err)
		return nil, err
	}

	fnv1a := fnv.New32a()
	fnv1a.Write(buffer.Value[:bufferLen])
	actualHash := fnv1a.Sum32()
	expectedHash := serial.BytesLiteral(buffer.Value[bufferLen : bufferLen+4]).Uint32Value()

	if actualHash != expectedHash {
		return nil, transport.ErrorCorruptedPacket
	}

	return request, nil
}
Ejemplo n.º 11
0
// Read reads a VMessRequest from a byte stream.
//
// It reads vmess.IDBytesLen bytes, resolves them to a user via
// this.vUserSet.GetUser (proxyerrors.InvalidAuthentication when not valid),
// then reads the rest of the header through an AES decryption stream created
// from the user's command key and user.Int64Hash(timeSec).
//
// The decrypted header is: version (byte 0), request IV (1..16), request key
// (17..32), response header (33..36), command (37), port (38..39), address
// type (40), the address itself, and a trailing 4-byte big-endian FNV-1a
// 32-bit hash of everything before it. A hash mismatch yields
// transport.CorruptedPacket; read errors are returned unchanged.
//
// NOTE(review): buffer is never released in this function, and RequestIV,
// RequestKey and ResponseHeader are sub-slices of buffer.Value rather than
// copies. Adding a Release here without copying them first would leave the
// request pointing into a recycled buffer — confirm the pool's semantics
// before changing either.
func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
	buffer := alloc.NewSmallBuffer()

	// User ID, read directly from the undecrypted stream.
	nBytes, err := v2net.ReadAllBytes(reader, buffer.Value[:vmess.IDBytesLen])
	if err != nil {
		return nil, err
	}

	userObj, timeSec, valid := this.vUserSet.GetUser(buffer.Value[:nBytes])
	if !valid {
		return nil, proxyerrors.InvalidAuthentication
	}

	aesStream, err := v2crypto.NewAesDecryptionStream(userObj.ID().CmdKey(), user.Int64Hash(timeSec))
	if err != nil {
		return nil, err
	}

	// Everything from here on is read through the decryptor.
	decryptor := v2crypto.NewCryptionReader(aesStream, reader)

	// Fixed-size part of the header; this overwrites the ID bytes in buffer.
	nBytes, err = v2net.ReadAllBytes(decryptor, buffer.Value[:41])
	if err != nil {
		return nil, err
	}
	// bufferLen counts the decrypted bytes that the trailing hash covers.
	bufferLen := nBytes

	request := &VMessRequest{
		User:    userObj,
		Version: buffer.Value[0],
	}

	if request.Version != Version {
		// NOTE(review): this passes a printf-style "%d" verb — confirm that
		// log.Warning formats its arguments rather than concatenating them.
		log.Warning("Invalid protocol version %d", request.Version)
		return nil, proxyerrors.InvalidProtocolVersion
	}

	// These three fields alias buffer.Value (no copy is made).
	request.RequestIV = buffer.Value[1:17]       // 16 bytes
	request.RequestKey = buffer.Value[17:33]     // 16 bytes
	request.ResponseHeader = buffer.Value[33:37] // 4 bytes
	request.Command = buffer.Value[37]

	request.Port = v2net.PortFromBytes(buffer.Value[38:40])

	// Byte 40 selects the address encoding. There is no default branch, so an
	// unrecognised value leaves request.Address unset and parsing continues
	// with the checksum.
	switch buffer.Value[40] {
	case addrTypeIPv4:
		_, err = v2net.ReadAllBytes(decryptor, buffer.Value[41:45]) // 4 bytes
		bufferLen += 4
		if err != nil {
			return nil, err
		}
		request.Address = v2net.IPAddress(buffer.Value[41:45])
	case addrTypeIPv6:
		_, err = v2net.ReadAllBytes(decryptor, buffer.Value[41:57]) // 16 bytes
		bufferLen += 16
		if err != nil {
			return nil, err
		}
		request.Address = v2net.IPAddress(buffer.Value[41:57])
	case addrTypeDomain:
		// One length byte, then that many bytes of domain name.
		_, err = v2net.ReadAllBytes(decryptor, buffer.Value[41:42])
		if err != nil {
			return nil, err
		}
		domainLength := int(buffer.Value[41])
		_, err = v2net.ReadAllBytes(decryptor, buffer.Value[42:42+domainLength])
		if err != nil {
			return nil, err
		}
		bufferLen += 1 + domainLength
		request.Address = v2net.DomainAddress(string(buffer.Value[42 : 42+domainLength]))
	}

	// The 4 bytes after the address are the expected hash of the header.
	_, err = v2net.ReadAllBytes(decryptor, buffer.Value[bufferLen:bufferLen+4])
	if err != nil {
		return nil, err
	}

	fnv1a := fnv.New32a()
	fnv1a.Write(buffer.Value[:bufferLen])
	actualHash := fnv1a.Sum32()
	expectedHash := binary.BigEndian.Uint32(buffer.Value[bufferLen : bufferLen+4])

	if actualHash != expectedHash {
		return nil, transport.CorruptedPacket
	}

	return request, nil
}
Ejemplo n.º 12
0
// ReadRequest reads a SOCKS5 request from reader: a 4-byte head (version,
// command, one reserved byte, address type), the address in the form the
// address type selects (4-byte IPv4, 16-byte IPv6, or a domain preceded by a
// 1-byte length), and a 2-byte port.
//
// Each field is read with io.ReadFull, so a request that arrives split across
// several reads is assembled instead of being rejected. A stream that ends in
// the middle of a field, or an unknown address type, yields
// transport.CorruptedPacket; any other read error (including io.EOF before
// the first byte of a field) is returned unchanged. When an error occurs
// after the head has been read, the partially filled request is returned
// together with the error.
func ReadRequest(reader io.Reader) (request *Socks5Request, err error) {
	buffer := alloc.NewSmallBuffer()
	defer buffer.Release()

	// readFull fills b completely. A stream that ends part-way through b is
	// reported as transport.CorruptedPacket, matching how a short read was
	// reported before.
	readFull := func(b []byte) error {
		_, e := io.ReadFull(reader, b)
		if e == io.ErrUnexpectedEOF {
			return transport.CorruptedPacket
		}
		return e
	}

	if err = readFull(buffer.Value[:4]); err != nil {
		return
	}
	request = &Socks5Request{
		Version: buffer.Value[0],
		Command: buffer.Value[1],
		// buffer[2] is a reserved field
		AddrType: buffer.Value[3],
	}
	switch request.AddrType {
	case AddrTypeIPv4:
		if err = readFull(request.IPv4[:]); err != nil {
			return
		}
	case AddrTypeDomain:
		if err = readFull(buffer.Value[0:1]); err != nil {
			return
		}
		domainLength := buffer.Value[0]
		var nBytes int
		nBytes, err = io.ReadFull(reader, buffer.Value[:domainLength])
		if err == io.ErrUnexpectedEOF {
			log.Info("Unable to read domain with %d bytes, expecting %d bytes", nBytes, domainLength)
			err = transport.CorruptedPacket
		}
		if err != nil {
			return
		}
		request.Domain = string(append([]byte(nil), buffer.Value[:domainLength]...))
	case AddrTypeIPv6:
		if err = readFull(request.IPv6[:]); err != nil {
			return
		}
	default:
		log.Info("Unexpected address type %d", request.AddrType)
		err = transport.CorruptedPacket
		return
	}

	if err = readFull(buffer.Value[:2]); err != nil {
		return
	}

	request.Port = v2net.PortFromBytes(buffer.Value[:2])
	return
}