Ejemplo n.º 1
0
// Read reads one authenticated chunk from the underlying reader and returns
// its payload.
//
// A chunk is a 2-byte length prefix followed by `length` bytes, the first 4 of
// which are a 32-bit FNV-1a hash of the remaining bytes. The returned buffer
// holds only the bytes after the hash.
//
// I/O errors from the reader are returned unchanged. A length that cannot hold
// the hash or does not fit in the buffer, or a hash mismatch, yields
// transport.ErrorCorruptedPacket. The buffer is released on every error path.
func (this *AuthChunkReader) Read() (*alloc.Buffer, error) {
	buffer := alloc.NewBuffer()
	if _, err := io.ReadFull(this.reader, buffer.Value[:2]); err != nil {
		buffer.Release()
		return nil, err
	}

	length := int(serial.BytesLiteral(buffer.Value[:2]).Uint16Value())
	// The length comes from the wire: it must cover at least the 4-byte hash
	// and must fit into the buffer, otherwise the slicing below would panic.
	if length < 4 || length > cap(buffer.Value) {
		buffer.Release()
		return nil, transport.ErrorCorruptedPacket
	}
	if _, err := io.ReadFull(this.reader, buffer.Value[:length]); err != nil {
		buffer.Release()
		return nil, err
	}
	buffer.Slice(0, length)

	fnvHash := fnv.New32a()
	fnvHash.Write(buffer.Value[4:])
	expAuth := serial.BytesLiteral(fnvHash.Sum(nil))
	actualAuth := serial.BytesLiteral(buffer.Value[:4])
	if !actualAuth.Equals(expAuth) {
		buffer.Release()
		return nil, transport.ErrorCorruptedPacket
	}
	// Drop the hash so the caller only sees the payload.
	buffer.SliceFrom(4)
	return buffer, nil
}
Ejemplo n.º 2
0
// Read reads one authenticated chunk and returns its payload.
//
// A chunk is a 2-byte length prefix giving the payload size, followed by
// AuthSize bytes of authentication tag and then the payload. The tag must
// equal this.auth.Authenticate(nil, payload).
//
// I/O errors from the reader are returned unchanged. A chunk that does not fit
// into the buffer, or whose tag does not match, yields
// transport.ErrorCorruptedPacket. The buffer is released on every error path.
func (this *ChunkReader) Read() (*alloc.Buffer, error) {
	buffer := alloc.NewLargeBuffer()
	if _, err := io.ReadFull(this.reader, buffer.Value[:2]); err != nil {
		buffer.Release()
		return nil, err
	}
	// Compute the total chunk size as int: doing the addition in uint16 would
	// wrap around for large prefixes and produce a length smaller than
	// AuthSize, making buffer.Value[AuthSize:] panic. The total may also exceed
	// the large buffer, so bound it explicitly.
	length := int(serial.BytesLiteral(buffer.Value[:2]).Uint16Value()) + AuthSize
	if length > cap(buffer.Value) {
		buffer.Release()
		return nil, transport.ErrorCorruptedPacket
	}
	if _, err := io.ReadFull(this.reader, buffer.Value[:length]); err != nil {
		buffer.Release()
		return nil, err
	}
	buffer.Slice(0, length)

	authBytes := buffer.Value[:AuthSize]
	payload := buffer.Value[AuthSize:]

	actualAuthBytes := this.auth.Authenticate(nil, payload)
	if !serial.BytesLiteral(authBytes).Equals(serial.BytesLiteral(actualAuthBytes)) {
		buffer.Release()
		log.Debug("AuthenticationReader: Unexpected auth: ", authBytes)
		return nil, transport.ErrorCorruptedPacket
	}
	// Expose only the payload to the caller.
	buffer.Value = payload

	return buffer, nil
}
Ejemplo n.º 3
0
// handleCommand verifies and dispatches a response command sent by the server.
//
// data must start with a 4-byte checksum that equals the 32-bit FNV-1a hash of
// the bytes that follow it. Commands that are too short or fail the checksum
// are dropped silently. Unknown or unparsable commands are logged and dropped.
// A SwitchAccount command that carries no host inherits dest's address before
// being handed to handleSwitchAccount; other command types are ignored.
func (this *VMessOutboundHandler) handleCommand(dest v2net.Destination, cmdId byte, data []byte) {
	if len(data) < 4 {
		return
	}
	checksum, body := data[:4], data[4:]

	hasher := fnv.New32a()
	hasher.Write(body)
	if serial.BytesLiteral(checksum).Uint32Value() != hasher.Sum32() {
		return
	}

	cmd, err := command.CreateResponseCommand(cmdId)
	if err != nil {
		log.Warning("VMessOut: Unknown response command (", cmdId, "): ", err)
		return
	}
	if err = cmd.Unmarshal(body); err != nil {
		log.Warning("VMessOut: Failed to parse response command: ", err)
		return
	}

	switchCmd, isSwitch := cmd.(*command.SwitchAccount)
	if !isSwitch {
		return
	}
	if switchCmd.Host == nil {
		switchCmd.Host = dest.Address()
	}
	this.handleSwitchAccount(switchCmd)
}
Ejemplo n.º 4
0
// Unmarshal decodes a SwitchAccount command from data.
//
// Layout: 1-byte host length, host bytes, 2-byte port, 16-byte user ID,
// 2-byte alter-ID count, 8-byte expiry interpreted as Unix seconds.
// It returns transport.CorruptedPacket if data is empty or truncated.
func (this *SwitchAccount) Unmarshal(data []byte) error {
	// Guard the length byte itself; data[0] on an empty slice would panic.
	if len(data) == 0 {
		return transport.CorruptedPacket
	}
	lenHost := int(data[0])
	if len(data) < lenHost+1 {
		return transport.CorruptedPacket
	}
	this.Host = v2net.ParseAddress(string(data[1 : 1+lenHost]))
	portStart := 1 + lenHost
	if len(data) < portStart+2 {
		return transport.CorruptedPacket
	}
	this.Port = v2net.PortFromBytes(data[portStart : portStart+2])
	idStart := portStart + 2
	if len(data) < idStart+16 {
		return transport.CorruptedPacket
	}
	// The parse error is ignored as before; exactly 16 bytes are passed, which
	// presumably is the only length ParseBytes accepts — confirm if it changes.
	this.ID, _ = uuid.ParseBytes(data[idStart : idStart+16])
	alterIdStart := idStart + 16
	if len(data) < alterIdStart+2 {
		return transport.CorruptedPacket
	}
	this.AlterIds = serial.ParseUint16(data[alterIdStart : alterIdStart+2])
	timeStart := alterIdStart + 2
	if len(data) < timeStart+8 {
		return transport.CorruptedPacket
	}
	this.ValidUntil = time.Unix(serial.BytesLiteral(data[timeStart:timeStart+8]).Int64Value(), 0)
	return nil
}
Ejemplo n.º 5
0
// Unmarshal decodes a SwitchAccount command from exactly 24 bytes: a 16-byte
// user ID followed by an 8-byte expiry interpreted as Unix seconds.
// Any other input length yields transport.CorruptedPacket.
func (this *SwitchAccount) Unmarshal(data []byte) error {
	const idLen, expiryLen = 16, 8
	if len(data) != idLen+expiryLen {
		return transport.CorruptedPacket
	}
	idBytes, expiryBytes := data[:idLen], data[idLen:]
	this.ID, _ = uuid.ParseBytes(idBytes)
	seconds := serial.BytesLiteral(expiryBytes).Int64Value()
	this.ValidUntil = time.Unix(seconds, 0)
	return nil
}
Ejemplo n.º 6
0
// IPAddress creates an Address from raw IP bytes.
//
// A 4-byte slice yields an IPv4 address. A 16-byte slice yields an IPv6
// address, unless its first 10 bytes are 0x00 and the next 2 are 0xff (an
// IPv4-mapped IPv6 address), in which case the embedded IPv4 address from the
// last 4 bytes is returned instead. Any other length is logged and yields nil.
func IPAddress(ip []byte) Address {
	switch len(ip) {
	case net.IPv4len:
		var addr IPv4Address
		copy(addr[:], ip)
		return &addr
	case net.IPv6len:
		if serial.BytesLiteral(ip[0:10]).All(0) && serial.BytesLiteral(ip[10:12]).All(0xff) {
			return IPAddress(ip[12:16])
		}
		var addr IPv6Address
		copy(addr[:], ip)
		return &addr
	default:
		log.Error("Invalid IP format: ", ip)
		return nil
	}
}
Ejemplo n.º 7
0
// UnmarshalJSON parses a port range from JSON.
//
// Accepted forms are a JSON number (a single port, e.g. 1080), a JSON string
// holding a single port ("1080"), or a JSON string holding an inclusive range
// ("1000-2000"). Whitespace around the numbers in a string form is ignored.
// Every port must be in 1..65535 and From must not exceed To; otherwise
// ErrorInvalidPortRange is returned and the receiver is left unchanged.
func (this *PortRange) UnmarshalJSON(data []byte) error {
	validPort := func(v int) bool { return v > 0 && v <= 65535 }

	var maybeint int
	if err := json.Unmarshal(data, &maybeint); err == nil {
		if !validPort(maybeint) {
			log.Error("Invalid port [", serial.BytesLiteral(data), "]")
			return ErrorInvalidPortRange
		}
		this.From = Port(maybeint)
		this.To = Port(maybeint)
		return nil
	}

	var maybestring string
	if err := json.Unmarshal(data, &maybestring); err != nil {
		return ErrorInvalidPortRange
	}

	// SplitN with a non-empty separator always yields one or two parts here.
	pair := strings.SplitN(maybestring, "-", 2)
	from, err := strconv.Atoi(strings.TrimSpace(pair[0]))
	if err != nil || !validPort(from) {
		log.Error("Invalid from port ", pair[0])
		return ErrorInvalidPortRange
	}
	if len(pair) == 1 {
		this.From = Port(from)
		this.To = Port(from)
		return nil
	}

	to, err := strconv.Atoi(strings.TrimSpace(pair[1]))
	if err != nil || !validPort(to) {
		log.Error("Invalid to port ", pair[1])
		return ErrorInvalidPortRange
	}
	if from > to {
		log.Error("Invalid port range ", Port(from), " -> ", Port(to))
		return ErrorInvalidPortRange
	}
	this.From = Port(from)
	this.To = Port(to)
	return nil
}
Ejemplo n.º 8
0
// handlePlainHTTP forwards a plain (non-CONNECT) HTTP request to dest and
// relays one response back to writer.
//
// A request without a host in its URL is answered with "400 Bad Request" and
// "Connection: close". Otherwise hop-by-hop headers are stripped, the
// serialized request is dispatched as a single packet, and the response read
// from the outbound ray is written to writer. If no valid response can be
// read, the failure is logged and nothing is written. reader is not used here.
func (this *HttpProxyServer) handlePlainHTTP(request *http.Request, dest v2net.Destination, reader *bufio.Reader, writer io.Writer) {
	if len(request.URL.Host) <= 0 {
		hdr := http.Header(make(map[string][]string))
		hdr.Set("Connection", "close")
		response := &http.Response{
			Status:        "400 Bad Request",
			StatusCode:    400,
			Proto:         "HTTP/1.1",
			ProtoMajor:    1,
			ProtoMinor:    1,
			Header:        hdr,
			Body:          nil,
			ContentLength: 0,
			// Keep the struct consistent with the Connection header above.
			Close: true,
		}

		buffer := alloc.NewSmallBuffer().Clear()
		response.Write(buffer)
		writer.Write(buffer.Value)
		buffer.Release()
		return
	}

	request.Host = request.URL.Host
	stripHopByHopHeaders(request)

	requestBuffer := alloc.NewBuffer().Clear() // Don't release this buffer as it is passed into a Packet.
	request.Write(requestBuffer)
	log.Debug("Request to remote:\n", serial.BytesLiteral(requestBuffer.Value))

	packet := v2net.NewPacket(dest, requestBuffer, true)
	ray := this.packetDispatcher.DispatchToOutbound(packet)
	defer close(ray.InboundInput())

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		responseReader := bufio.NewReader(NewChanReader(ray.InboundOutput()))
		response, err := http.ReadResponse(responseReader, request)
		if err != nil {
			log.Warning("HTTP: Failed to read response: ", err)
			return
		}
		defer response.Body.Close()

		responseBuffer := alloc.NewBuffer().Clear()
		defer responseBuffer.Release()
		response.Write(responseBuffer)
		writer.Write(responseBuffer.Value)
	}()
	wg.Wait()
}
Ejemplo n.º 9
0
// Unmarshal decodes a SwitchAccount command from data.
//
// Layout: 1-byte host length, host bytes, 2-byte port, 16-byte user ID,
// 2-byte alter-ID count, 1-byte user level, 1-byte value stored in ValidMin
// (presumably a validity period in minutes — confirm with the encoder).
// When the host length is 0, Host is left untouched.
// It returns transport.ErrorCorruptedPacket if data is empty or truncated.
func (this *SwitchAccount) Unmarshal(data []byte) error {
	if len(data) == 0 {
		return transport.ErrorCorruptedPacket
	}
	lenHost := int(data[0])
	if len(data) < lenHost+1 {
		return transport.ErrorCorruptedPacket
	}
	if lenHost > 0 {
		this.Host = v2net.ParseAddress(string(data[1 : 1+lenHost]))
	}
	portStart := 1 + lenHost
	if len(data) < portStart+2 {
		return transport.ErrorCorruptedPacket
	}
	this.Port = v2net.PortFromBytes(data[portStart : portStart+2])
	idStart := portStart + 2
	if len(data) < idStart+16 {
		return transport.ErrorCorruptedPacket
	}
	this.ID, _ = uuid.ParseBytes(data[idStart : idStart+16])
	alterIdStart := idStart + 16
	if len(data) < alterIdStart+2 {
		return transport.ErrorCorruptedPacket
	}
	this.AlterIds = serial.BytesLiteral(data[alterIdStart : alterIdStart+2]).Uint16()
	levelStart := alterIdStart + 2
	if len(data) < levelStart+1 {
		return transport.ErrorCorruptedPacket
	}
	this.Level = vmess.UserLevel(data[levelStart])
	timeStart := levelStart + 1
	// One more byte is needed for data[timeStart]; checking only
	// len(data) < timeStart would let an exactly-truncated packet panic.
	if len(data) < timeStart+1 {
		return transport.ErrorCorruptedPacket
	}
	this.ValidMin = data[timeStart]
	return nil
}
Ejemplo n.º 10
0
// ReadRequest parses a Shadowsocks request header from reader.
//
// The header is a 1-byte address type (low 4 bits; bit 0x10 flags one-time
// auth, OTA), the address (4-byte IPv4, 16-byte IPv6, or a 1-byte length plus
// domain), and a 2-byte port. For UDP (udp == true) a single Read captures the
// rest of the packet as the payload, whose last AuthSize bytes are the OTA tag
// when OTA is set. For TCP with OTA, AuthSize tag bytes follow the header.
// With OTA, the tag must equal auth.Authenticate over the header (plus the UDP
// payload for UDP).
//
// Any malformed or unauthenticated input yields transport.ErrorCorruptedPacket.
func ReadRequest(reader io.Reader, auth *Authenticator, udp bool) (*Request, error) {
	buffer := alloc.NewSmallBuffer()
	defer buffer.Release()

	_, err := io.ReadFull(reader, buffer.Value[:1])
	if err != nil {
		log.Error("Shadowsocks: Failed to read address type: ", err)
		return nil, transport.ErrorCorruptedPacket
	}
	lenBuffer := 1

	request := new(Request)

	addrType := (buffer.Value[0] & 0x0F)
	if (buffer.Value[0] & 0x10) == 0x10 {
		request.OTA = true
	}
	switch addrType {
	case AddrTypeIPv4:
		_, err := io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+4])
		if err != nil {
			log.Error("Shadowsocks: Failed to read IPv4 address: ", err)
			return nil, transport.ErrorCorruptedPacket
		}
		request.Address = v2net.IPAddress(buffer.Value[lenBuffer : lenBuffer+4])
		lenBuffer += 4
	case AddrTypeIPv6:
		_, err := io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+16])
		if err != nil {
			log.Error("Shadowsocks: Failed to read IPv6 address: ", err)
			return nil, transport.ErrorCorruptedPacket
		}
		request.Address = v2net.IPAddress(buffer.Value[lenBuffer : lenBuffer+16])
		lenBuffer += 16
	case AddrTypeDomain:
		_, err := io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+1])
		if err != nil {
			log.Error("Shadowsocks: Failed to read domain lenth: ", err)
			return nil, transport.ErrorCorruptedPacket
		}
		domainLength := int(buffer.Value[lenBuffer])
		lenBuffer++
		_, err = io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+domainLength])
		if err != nil {
			log.Error("Shadowsocks: Failed to read domain: ", err)
			return nil, transport.ErrorCorruptedPacket
		}
		request.Address = v2net.DomainAddress(string(buffer.Value[lenBuffer : lenBuffer+domainLength]))
		lenBuffer += domainLength
	default:
		log.Error("Shadowsocks: Unknown address type: ", addrType)
		return nil, transport.ErrorCorruptedPacket
	}

	_, err = io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+2])
	if err != nil {
		log.Error("Shadowsocks: Failed to read port: ", err)
		return nil, transport.ErrorCorruptedPacket
	}

	request.Port = v2net.PortFromBytes(buffer.Value[lenBuffer : lenBuffer+2])
	lenBuffer += 2

	var authBytes []byte

	if udp {
		nBytes, err := reader.Read(buffer.Value[lenBuffer:])
		if err != nil {
			log.Error("Shadowsocks: Failed to read UDP payload: ", err)
		}
		buffer.Slice(0, lenBuffer+nBytes)
		if request.OTA {
			// The packet must at least carry the OTA tag; otherwise the
			// payload end would fall before its start and slicing panics.
			if nBytes < AuthSize {
				log.Error("Shadowsocks: UDP payload too short for OTA: ", nBytes)
				return nil, transport.ErrorCorruptedPacket
			}
			authBytes = buffer.Value[lenBuffer+nBytes-AuthSize:]
			request.UDPPayload = alloc.NewSmallBuffer().Clear().Append(buffer.Value[lenBuffer : lenBuffer+nBytes-AuthSize])
			lenBuffer = lenBuffer + nBytes - AuthSize
		} else {
			request.UDPPayload = alloc.NewSmallBuffer().Clear().Append(buffer.Value[lenBuffer:])
		}
	} else {
		if request.OTA {
			authBytes = buffer.Value[lenBuffer : lenBuffer+AuthSize]
			_, err = io.ReadFull(reader, authBytes)
			if err != nil {
				log.Error("Shadowsocks: Failed to read OTA: ", err)
				return nil, transport.ErrorCorruptedPacket
			}
		}
	}

	if request.OTA {
		actualAuth := auth.Authenticate(nil, buffer.Value[0:lenBuffer])
		if !serial.BytesLiteral(actualAuth).Equals(serial.BytesLiteral(authBytes)) {
			log.Error("Shadowsocks: Invalid OTA: ", actualAuth)
			// The request is discarded, so return its payload buffer too.
			if request.UDPPayload != nil {
				request.UDPPayload.Release()
			}
			return nil, transport.ErrorCorruptedPacket
		}
	}

	return request, nil
}
Ejemplo n.º 11
0
// PortFromBytes builds a Port from raw bytes, which are expected in big endian
// order; decoding is delegated to serial.BytesLiteral.Uint16Value.
// @unsafe Caller must ensure that the byte array has at least 2 elements.
func PortFromBytes(port []byte) Port {
	value := serial.BytesLiteral(port).Uint16Value()
	return Port(value)
}
Ejemplo n.º 12
0
// TestCmdKey checks that the command key derived from a freshly generated ID
// is not made entirely of zero bytes.
func TestCmdKey(t *testing.T) {
	v2testing.Current(t)

	cmdKey := NewID(uuid.New()).CmdKey()
	assert.Bool(serial.BytesLiteral(cmdKey).AllZero()).IsFalse()
}
Ejemplo n.º 13
0
// DecodeRequestHeader reads and validates a VMess request header from reader.
//
// It first reads a user-identifying hash of protocol.IDBytesLen bytes and looks
// it up in the user validator. The remainder is AES-decrypted with the user's
// command key and an IV derived from md5(hashTimestamp(timestamp)). The
// decrypted header holds:
//   - version (1 byte)
//   - body IV (16 bytes)
//   - body key (16 bytes)
//   - response header byte (1 byte)
//   - option (1 byte)
//   - 2 reserved bytes
//   - command (1 byte)
//   - port (2 bytes)
//   - address type (1 byte)
//   - address
//   - 4-byte FNV-1a checksum of everything before it
//
// The body IV, body key and response header byte are stored on the session.
//
// Errors:
//   - I/O errors are returned as-is.
//   - An unknown user yields protocol.ErrorInvalidUser.
//   - A wrong version yields protocol.ErrorInvalidVersion.
//   - An empty domain, unknown address type or checksum mismatch yields
//     transport.ErrorCorruptedPacket.
func (this *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) {
	buffer := alloc.NewSmallBuffer()
	defer buffer.Release()

	_, err := io.ReadFull(reader, buffer.Value[:protocol.IDBytesLen])
	if err != nil {
		log.Error("Raw: Failed to read request header: ", err)
		return nil, err
	}

	user, timestamp, valid := this.userValidator.Get(buffer.Value[:protocol.IDBytesLen])
	if !valid {
		return nil, protocol.ErrorInvalidUser
	}

	timestampHash := md5.New()
	timestampHash.Write(hashTimestamp(timestamp))
	iv := timestampHash.Sum(nil)
	aesStream := crypto.NewAesDecryptionStream(user.ID.CmdKey(), iv)
	decryptor := crypto.NewCryptionReader(aesStream, reader)

	nBytes, err := io.ReadFull(decryptor, buffer.Value[:41])
	if err != nil {
		log.Debug("Raw: Failed to read request header (", nBytes, " bytes): ", err)
		return nil, err
	}
	bufferLen := nBytes

	request := &protocol.RequestHeader{
		User:    user,
		Version: buffer.Value[0],
	}

	if request.Version != Version {
		log.Warning("Raw: Invalid protocol version ", request.Version)
		return nil, protocol.ErrorInvalidVersion
	}

	this.requestBodyIV = append([]byte(nil), buffer.Value[1:17]...)   // 16 bytes
	this.requestBodyKey = append([]byte(nil), buffer.Value[17:33]...) // 16 bytes
	this.responseHeader = buffer.Value[33]                            // 1 byte
	request.Option = protocol.RequestOption(buffer.Value[34])         // 1 byte + 2 bytes reserved
	request.Command = protocol.RequestCommand(buffer.Value[37])

	request.Port = v2net.PortFromBytes(buffer.Value[38:40])

	switch buffer.Value[40] {
	case AddrTypeIPv4:
		nBytes, err = io.ReadFull(decryptor, buffer.Value[41:45]) // 4 bytes
		if err != nil {
			log.Debug("VMess: Failed to read target IPv4 (", nBytes, " bytes): ", err)
			return nil, err
		}
		bufferLen += 4
		request.Address = v2net.IPAddress(buffer.Value[41:45])
	case AddrTypeIPv6:
		nBytes, err = io.ReadFull(decryptor, buffer.Value[41:57]) // 16 bytes
		if err != nil {
			log.Debug("VMess: Failed to read target IPv6 (", nBytes, " bytes): ", err)
			return nil, err
		}
		bufferLen += 16
		request.Address = v2net.IPAddress(buffer.Value[41:57])
	case AddrTypeDomain:
		nBytes, err = io.ReadFull(decryptor, buffer.Value[41:42])
		if err != nil {
			log.Debug("VMess: Failed to read target domain (", nBytes, " bytes): ", err)
			return nil, err
		}
		domainLength := int(buffer.Value[41])
		if domainLength == 0 {
			return nil, transport.ErrorCorruptedPacket
		}
		nBytes, err = io.ReadFull(decryptor, buffer.Value[42:42+domainLength])
		if err != nil {
			log.Debug("VMess: Failed to read target domain (", nBytes, " bytes): ", err)
			return nil, err
		}
		bufferLen += 1 + domainLength
		domainBytes := append([]byte(nil), buffer.Value[42:42+domainLength]...)
		request.Address = v2net.DomainAddress(string(domainBytes))
	default:
		// Without this case an unknown type whose checksum happens to match
		// would yield a request with a nil Address.
		log.Warning("VMess: Unknown address type ", buffer.Value[40])
		return nil, transport.ErrorCorruptedPacket
	}

	nBytes, err = io.ReadFull(decryptor, buffer.Value[bufferLen:bufferLen+4])
	if err != nil {
		log.Debug("VMess: Failed to read checksum (", nBytes, " bytes): ", err)
		return nil, err
	}

	fnv1a := fnv.New32a()
	fnv1a.Write(buffer.Value[:bufferLen])
	actualHash := fnv1a.Sum32()
	expectedHash := serial.BytesLiteral(buffer.Value[bufferLen : bufferLen+4]).Uint32Value()

	if actualHash != expectedHash {
		return nil, transport.ErrorCorruptedPacket
	}

	return request, nil
}