コード例 #1
0
// handle_connect serves an incoming "connect" channel. The channel delivers a
// single packet whose body is itself an encoded packet ("inner"), which is
// one of two things:
//
//   - a handshake: the inner header is exactly one binary byte holding the
//     CSID, and the inner body is decrypted with the matching local key;
//   - a key packet: the inner header's two-character hex fields describe the
//     peer, where a bool value marks the peer's CSID and string values carry
//     intermediate parts, and the inner body is the peer's public key for
//     that CSID.
//
// From either form the peer's hashname and identity are derived. If
// mod.config.AllowConnect is set and rejects (peer, ch.RemoteHashname()),
// nothing more happens. Otherwise an exchange is created for the peer.
//
//   - For a handshake, a connection over the channel's exchange (the
//     "router") is attached to the new exchange as a pipe, the handshake is
//     applied, and any response is sent back via mod.peerVia.
//   - For a key packet, a fresh handshake is generated and sent via
//     mod.peerVia.
//
// On success the introduction tracked for the peer is resolved with the new
// exchange. Every failure aborts silently; the channel is always killed on
// return.
func (mod *module) handle_connect(ch *e3x.Channel) {
	defer ch.Kill()

	var (
		payload   = bufpool.New()
		peer      hashname.H
		peerIdent *e3x.Identity
		hs        cipherset.Handshake
	)

	local, err := mod.e.LocalIdentity()
	if err != nil {
		return
	}

	outer, err := ch.ReadPacket()
	if err != nil {
		return
	}

	// Copy the outer body into the pooled buffer so it can be decoded as a packet.
	outer.Body(payload.SetLen(outer.BodyLen()).RawBytes()[:0])

	inner, err := lob.Decode(payload)
	if err != nil {
		return
	}

	if hdr := inner.Header(); hdr.IsBinary() && len(hdr.Bytes) == 1 {
		// Handshake: the single header byte selects which local key decrypts it.
		csid := hdr.Bytes[0]
		key := local.Keys()[csid]
		if key == nil {
			return
		}

		if hs, err = cipherset.DecryptHandshake(csid, key, inner.Body(nil)); err != nil {
			return
		}

		if peer, err = hashname.FromIntermediates(hs.Parts()); err != nil {
			return
		}

		peerIdent, err = e3x.NewIdentity(cipherset.Keys{
			hs.CSID(): hs.PublicKey(),
		}, hs.Parts(), nil)
		if err != nil {
			return
		}
	} else {
		// Key packet: collect intermediate parts and the CSID marker from the
		// two-character hex header fields; other fields are ignored.
		parts := make(cipherset.Parts)
		var csid uint8
		for label, value := range inner.Header().Extra {
			if len(label) != 2 {
				continue
			}

			raw, decErr := hex.DecodeString(label)
			if decErr != nil {
				continue
			}

			switch v := value.(type) {
			case bool:
				csid = raw[0]
			case string:
				parts[raw[0]] = v
			}
		}

		if peer, err = hashname.FromKeyAndIntermediates(csid, inner.Body(nil), parts); err != nil {
			return
		}

		pub, err := cipherset.DecodeKeyBytes(csid, inner.Body(nil), nil)
		if err != nil {
			return
		}

		peerIdent, err = e3x.NewIdentity(cipherset.Keys{csid: pub}, parts, nil)
		if err != nil {
			return
		}
	}

	if peer == "" {
		return
	}

	if allow := mod.config.AllowConnect; allow != nil && !allow(peer, ch.RemoteHashname()) {
		return
	}

	x, err := mod.e.CreateExchange(peerIdent)
	if err != nil {
		return
	}

	if hs != nil {
		// The inner body was a handshake: attach a pipe that routes through
		// the channel's exchange, then apply the handshake to the new exchange.
		router := ch.Exchange()
		via := &peerAddr{
			router: router.RemoteHashname(),
		}

		conn := newConnection(x.RemoteHashname(), via, router, func() {
			mod.unregisterConnection(router, x.LocalToken())
		})

		pipe, added := x.AddPipeConnection(conn, nil)
		if added {
			mod.registerConnection(router, x.LocalToken(), conn)
		}

		resp, ok := x.ApplyHandshake(hs, pipe)
		if !ok {
			return
		}

		if resp != nil {
			if err = mod.peerVia(ch.Exchange(), peer, resp); err != nil {
				return
			}
		}
	} else {
		// The inner body was a key packet: start the handshake from this side.
		reply, err := x.GenerateHandshake()
		if err != nil {
			return
		}

		if err = mod.peerVia(ch.Exchange(), peer, reply); err != nil {
			return
		}
	}

	// Notify on-exchange callbacks.
	mod.getIntroduction(peer).resolve(x, nil)
}
コード例 #2
0
ファイル: resolve_dns.go プロジェクト: utamaro/gogotelehash
// resolveSRV resolves uri to an identity using DNS.
//
// The steps are:
//
//  1. Look up the SRV record for service "mesh" over proto on the host part
//     of uri.Canonical (any port is ignored). Exactly one SRV record is
//     required.
//  2. Require the SRV target to be "<hashname>.<domain>" with a valid
//     52-character hashname, and to not be a CNAME.
//  3. Turn the target's A/AAAA records, combined with the SRV port, into
//     addresses for proto ("udp" or "tcp").
//  4. Rebuild the public keys published in the target's TXT records.
//
// The identity is returned only if the hashname computed from those keys
// matches the one in the SRV target.
//
// Hosts ending in ".public" are rejected. Lookup errors are returned as-is;
// validation failures are returned as *net.DNSError.
func resolveSRV(uri *URI, proto string) (*e3x.Identity, error) {
	// ignore port
	host, _, _ := net.SplitHostPort(uri.Canonical)
	if host == "" {
		host = uri.Canonical
	}

	// normalize to a fully qualified name
	if !strings.HasSuffix(host, ".") {
		host += "."
	}

	// ignore .public
	if strings.HasSuffix(host, ".public.") {
		return nil, &net.DNSError{Name: host, Err: "cannot resolve .public hostnames using DNS"}
	}

	// lookup SRV records; exactly one is required
	_, srvs, err := net.LookupSRV("mesh", proto, host)
	if err != nil {
		return nil, err
	}
	if len(srvs) > 1 {
		return nil, &net.DNSError{Name: host, Err: "too many SRV records"}
	}
	if len(srvs) == 0 {
		return nil, &net.DNSError{Name: host, Err: "no SRV records"}
	}

	var (
		srv     = srvs[0]
		port    = srv.Port
		portStr = strconv.Itoa(int(port))
		hn      hashname.H
		keys    cipherset.Keys
	)

	{ // detect valid target
		parts := strings.SplitN(srv.Target, ".", 2)
		if len(parts) != 2 || len(parts[0]) != 52 || len(parts[1]) == 0 {
			return nil, &net.DNSError{Name: host, Err: "SRV must target a <hashname>.<domain> domain"}
		}

		hn = hashname.H(parts[0])
		if !hn.Valid() {
			return nil, &net.DNSError{Name: host, Err: "SRV must target a <hashname>.<domain> domain"}
		}
	}

	// detect CNAMEs (they are not allowed)
	cname, err := net.LookupCNAME(srv.Target)
	if err != nil {
		return nil, err
	}
	if cname != "" && cname != srv.Target {
		return nil, &net.DNSError{Name: host, Err: "CNAME record are not allowed"}
	}

	// lookup A AAAA records
	ips, err := net.LookupIP(srv.Target)
	if err != nil {
		return nil, err
	}
	if len(ips) == 0 {
		return nil, &net.DNSError{Name: host, Err: "no A or AAAA records"}
	}

	// lookup TXT
	txts, err := net.LookupTXT(srv.Target)
	if err != nil {
		return nil, err
	}
	if len(txts) == 0 {
		return nil, &net.DNSError{Name: host, Err: "no TXT records"}
	}

	// make addrs; IPs that cannot be resolved for proto are skipped
	addrs := make([]net.Addr, 0, len(ips))
	for _, ip := range ips {
		var (
			addr net.Addr
		)

		switch proto {
		case "udp":
			addr, _ = transports.ResolveAddr("udp4", net.JoinHostPort(ip.String(), portStr))
			if addr == nil {
				addr, _ = transports.ResolveAddr("udp6", net.JoinHostPort(ip.String(), portStr))
			}
		case "tcp":
			addr, _ = transports.ResolveAddr("tcp4", net.JoinHostPort(ip.String(), portStr))
			if addr == nil {
				addr, _ = transports.ResolveAddr("tcp6", net.JoinHostPort(ip.String(), portStr))
			}
			// case "http":
			// 	addr, _ = http.NewAddr(ip, port)
		}

		if addr != nil {
			addrs = append(addrs, addr)
		}
	}

	// parse keys: reassemble each CSID's key from its TXT fragments in
	// numeric part order; keys that fail to decode are skipped
	keyData := collectSRVKeyParts(txts)
	keys = make(cipherset.Keys, len(keyData))
	for csid, str := range keyData {
		key, err := cipherset.DecodeKey(csid, str, "")
		if err != nil {
			continue
		}

		keys[csid] = key
	}

	ident, err := e3x.NewIdentity(keys, nil, addrs)
	if err != nil {
		return nil, err
	}

	if hn != ident.Hashname() {
		return nil, &net.DNSError{Name: host, Err: "invalid keys"}
	}

	return ident, nil
}

// srvKeyPart is one fragment of a public key taken from a TXT record of the
// form "<csid-hex><part>=<value>".
type srvKeyPart struct {
	csid  uint8
	part  uint64
	value string
}

// srvKeyParts sorts key fragments by CSID, then numerically by part number,
// then by value, so that duplicate labels still get a deterministic order.
type srvKeyParts []srvKeyPart

func (p srvKeyParts) Len() int      { return len(p) }
func (p srvKeyParts) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
func (p srvKeyParts) Less(i, j int) bool {
	a, b := p[i], p[j]
	if a.csid != b.csid {
		return a.csid < b.csid
	}
	if a.part != b.part {
		return a.part < b.part
	}
	return a.value < b.value
}

// collectSRVKeyParts parses TXT strings of the form "<label>=<value>". The
// first two label characters are the CSID in hex; any remaining characters
// are a decimal part number from 0 to 255. A label with no part number counts
// as part 0.
//
// For each CSID it returns the values concatenated in ascending numeric part
// order. A plain lexical sort of the records would put "1a10=" before "1a1="
// and break keys split into ten or more parts.
//
// Malformed entries are ignored:
//   - records without exactly one '=';
//   - labels shorter than two characters;
//   - labels whose first two characters are not hex;
//   - part numbers that are not decimal.
func collectSRVKeyParts(txts []string) map[uint8]string {
	frags := make(srvKeyParts, 0, len(txts))
	for _, txt := range txts {
		fields := strings.Split(txt, "=")
		if len(fields) != 2 {
			continue
		}

		label, value := fields[0], fields[1]
		if len(label) < 2 {
			continue
		}

		// parse the CSID portion of the label
		id, err := strconv.ParseUint(label[:2], 16, 8)
		if err != nil {
			continue
		}

		// parse the key-part portion of the label
		var part uint64
		if len(label) > 2 {
			part, err = strconv.ParseUint(label[2:], 10, 8)
			if err != nil {
				continue
			}
		}

		frags = append(frags, srvKeyPart{csid: uint8(id), part: part, value: value})
	}

	sort.Sort(frags)

	keyData := make(map[uint8]string, 10)
	for _, f := range frags {
		keyData[f.csid] += f.value
	}
	return keyData
}