// Html2SAMLResponse extracts the SAMLResponse from a html document.
//
// It reads the base64 encoded value of the SAMLResponse form input found in
// tp.Responsebody, decodes it and parses it into a *gosaml.Xp, which is
// returned. As a diagnostic aid it also schema-validates the response and
// verifies the signature of the first saml:Assertion against the first
// signing certificate in tp.Firstidpmd. Problems found by these checks are
// only printed; the parsed response is still returned so the caller can
// inspect it.
func Html2SAMLResponse(tp *Testparams) (samlresponse *gosaml.Xp) {
	response := gosaml.NewHtmlXp(tp.Responsebody)
	samlbase64 := response.Query1(nil, `//input[@name="SAMLResponse"]/@value`)
	samlxml, err := base64.StdEncoding.DecodeString(samlbase64)
	if err != nil {
		// Keep going with whatever was decoded, but make the problem visible.
		fmt.Println("Base64DecodeError", err)
	}
	samlresponse = gosaml.NewXp(samlxml)
	if _, err := samlresponse.SchemaValidate(samlSchema); err != nil {
		fmt.Println("SchemaError")
	}

	certs := tp.Firstidpmd.Query(nil, `//md:KeyDescriptor[@use="signing" or not(@use)]/ds:KeyInfo/ds:X509Data/ds:X509Certificate`)
	if len(certs) == 0 {
		// Without a certificate the signature cannot be checked; indexing
		// certs[0] below would panic.
		entityID := tp.Firstidpmd.Query1(nil, "/@entityID")
		fmt.Printf("Could not find signing cert for: %s\n", entityID)
		log.Printf("Could not find signing cert for: %s", entityID)
		return
	}

	_, pub, err := gosaml.PublicKeyInfo(tp.Firstidpmd.NodeGetContent(certs[0]))
	if err != nil {
		fmt.Println("PublicKeyInfoError", err)
		return
	}

	assertion := samlresponse.Query(nil, "saml:Assertion[1]")
	if len(assertion) == 0 {
		// An empty (not only nil) result means there is nothing to verify.
		fmt.Println("no assertion found")
		return
	}
	if err := samlresponse.VerifySignature(assertion[0], pub); err != nil {
		fmt.Printf("SignatureVerificationError %s\n", err)
	}
	return
}
// ValidateSignature verifies the XML signatures of a SAML response against
// the IdP certificates found in md.
//
// Every signed node among /samlp:Response and its saml:Assertion children is
// checked. The call succeeds only if each of those signatures verifies with
// at least one of the certificates selected by gosaml.IdpCertQuery.
//
// It returns an error if:
//   - md holds no certificates,
//   - a certificate cannot be parsed,
//   - neither the response nor an assertion is signed, or
//   - some signature fails with every certificate.
//
// The last two error messages are prefixed with the response's Destination.
func ValidateSignature(md, xp *gosaml.Xp) (err error) {

	//no ds:Object in signatures
	certificates := md.Query(nil, gosaml.IdpCertQuery)
	if len(certificates) == 0 {
		err = errors.New("no certificates found in metadata")
		return
	}
	signatures := xp.Query(nil, "(/samlp:Response[ds:Signature] | /samlp:Response/saml:Assertion[ds:Signature])")
	destination := xp.Query1(nil, "/samlp:Response/@Destination")

	if len(signatures) == 0 {
		err = fmt.Errorf("%s neither the assertion nor the response was signed", destination)
		return
	}

	// Parse every candidate key up front. An unparsable certificate aborts
	// the validation, as before.
	keys := make([]*rsa.PublicKey, 0, len(certificates))
	for _, certificate := range certificates {
		var key *rsa.PublicKey
		if _, key, err = gosaml.PublicKeyInfo(md.NodeGetContent(certificate)); err != nil {
			return
		}
		keys = append(keys, key)
	}

	// Accept a signature as soon as any key verifies it; metadata may list
	// several certificates. The verdict is tracked per signature, not per
	// (key, signature) pair. A pair count could be satisfied by one signature
	// verifying twice while another fails with every key.
	unverified := 0
	errorstring := ""
	for _, signature := range signatures {
		var failures []error
		verified := false
		for _, key := range keys {
			signerror := xp.VerifySignature(signature, key)
			if signerror == nil {
				verified = true
				break
			}
			failures = append(failures, signerror)
		}
		if verified {
			continue
		}
		unverified++
		// Report only failures of signatures that no key could verify,
		// separated by ", ".
		for _, failure := range failures {
			if errorstring != "" {
				errorstring += ", "
			}
			errorstring += failure.Error()
		}
	}
	if unverified > 0 {
		err = fmt.Errorf("%s unable to validate signature: %s", destination, errorstring)
		return
	}
	return
}
// Example 3
// File: lMDQ.go  Project: wayf-dk/lMDQ
// Update refreshes the local metadata database from the feed at mdq.Url.
//
// It runs as follows:
//  1. Execute lMDQSchema.
//  2. Load the cached entity list and download the feed.
//  3. Schema-validate the feed; errors are only logged.
//  4. Verify the feed signature using the single certificate embedded in
//     the EntitiesDescriptor signature. That certificate's key name must
//     equal mdq.Hash.
//  5. Synchronise all entities inside one transaction:
//     - An entity whose serialised XML has the same SHA-1 as the cached
//       copy is left untouched.
//     - A changed entity is deleted and re-inserted.
//     - A new entity is inserted.
//     - A cached entity no longer present in the feed is deleted.
//  6. For every inserted entity, add lookup rows keyed by the SHA-1 of its
//     entityID and of every value matched by the indextargets queries.
//  7. Store the feed's validUntil.
//
// The deferred handler commits the transaction when err is nil on return
// and rolls it back otherwise.
func (mdq *MDQ) Update() (err error) {
	start := time.Now()
	log.Println("lMDQ updating", mdq.Url, mdq.Path)

	_, err = mdq.db.Exec(lMDQSchema)
	if err != nil {
		return
	}

	recs, err := mdq.getEntityList()
	if err != nil {
		return err
	}
	var md []byte
	if md, err = get(mdq.Url); err != nil {
		return
	}

	dom := gosaml.NewXp(md)

	// Schema errors are only logged; they do not abort the update.
	if _, err := dom.SchemaValidate(mdq.MetadataSchemaPath); err != nil {
		log.Println("feed", "SchemaError")
	}

	certificate := dom.Query(nil, "/md:EntitiesDescriptor/ds:Signature/ds:KeyInfo/ds:X509Data/ds:X509Certificate")
	if len(certificate) != 1 {
		err = errors.New("Metadata not signed")
		return
	}
	keyname, key, err := gosaml.PublicKeyInfo(dom.NodeGetContent(certificate[0]))
	if err != nil {
		return
	}

	// The signature must verify with the embedded certificate, and that
	// certificate must be the one pinned by mdq.Hash.
	sigErr := dom.VerifySignature(nil, key)
	if sigErr != nil || keyname != mdq.Hash {
		return fmt.Errorf("Signature check failed. Signature %s, %s = %s", sigErr, keyname, mdq.Hash)
	}

	tx, err := mdq.db.Begin()
	if err != nil {
		return
	}
	defer func() {
		if err != nil {
			tx.Rollback()
			return
		}
		err = tx.Commit()
	}()

	entityInsertStmt, err := tx.Prepare("insert into entity (entityid, md, hash) values ($1, $2, $3)")
	if err != nil {
		return
	}
	defer entityInsertStmt.Close()

	lookupInsertStmt, err := tx.Prepare("insert or ignore into lookup (hash, entity_id_fk) values (?, ?)")
	if err != nil {
		return err
	}
	defer lookupInsertStmt.Close()

	entityDeleteStmt, err := tx.Prepare("delete from entity where id = $1")
	if err != nil {
		return err
	}
	defer entityDeleteStmt.Close()

	vu, err := time.Parse(time.RFC3339Nano, dom.Query1(nil, "@validUntil"))
	if err != nil {
		return err
	}
	validUntil := vu.Unix()

	var added, updated, nochange, deleted int
	seen := map[string]bool{}

	entities := dom.Query(nil, "./md:EntityDescriptor")
	for _, entity := range entities {
		entityID := dom.Query1(entity, "@entityID")
		if seen[entityID] {
			log.Printf("lMDQ duplicate entityID: %s", entityID)
			continue
		}
		seen[entityID] = true
		entityMD := gosaml.NewXpFromNode(entity).X2s()
		hash := hex.EncodeToString(gosaml.Hash(crypto.SHA1, entityMD))
		rec := recs[entityID] // zero value (empty hash) when not cached yet

		switch {
		case rec.hash == hash: // no changes
			delete(recs, entityID) // remove so it won't be deleted
			nochange++
			continue
		case rec.hash != "": // update is delete + insert - then the cascading delete will also delete the potential stale lookup entries
			if _, err = entityDeleteStmt.Exec(rec.id); err != nil {
				return
			}
			updated++
			log.Printf("lMDQ updated entityID: %s", entityID)
			delete(recs, entityID) // updated - remove so it won't be deleted
		default:
			added++
			if !mdq.Silent {
				log.Printf("lMDQ new entityID: %s", entityID)
			}
		}

		var res sql.Result
		if res, err = entityInsertStmt.Exec(entityID, entityMD, hash); err != nil {
			return err
		}

		// Without the new row's id the lookup rows would point at a
		// non-existent entity, so a failure here aborts the update.
		var id int64
		if id, err = res.LastInsertId(); err != nil {
			return err
		}

		if _, err = lookupInsertStmt.Exec(hex.EncodeToString(gosaml.Hash(crypto.SHA1, entityID)), id); err != nil {
			return
		}

		for _, target := range indextargets {
			for i, location := range dom.Query(entity, target) {
				value := dom.NodeGetContent(location)
				if !mdq.Silent {
					log.Println(i, value)
				}
				if _, err = lookupInsertStmt.Exec(hex.EncodeToString(gosaml.Hash(crypto.SHA1, value)), id); err != nil {
					return
				}
			}
		}
	}
	for entid, ent := range recs { // delete entities no longer in feed
		if _, err = entityDeleteStmt.Exec(ent.id); err != nil {
			return
		}
		deleted++
		log.Printf("lMDQ deleted entityID: %s", entid)
	}

	if _, err = tx.Exec("update validuntil set validuntil = $1 where id = 1", validUntil); err != nil {
		return
	}

	log.Printf("lMDQ finished %d new, %d updated, %d unchanged, %d deleted validUntil: %s duration: %.1f",
		added, updated, nochange, deleted, time.Unix(validUntil, 0).Format(time.RFC3339), time.Since(start).Seconds())
	return
}
// SSOSendRequest2 does the 2nd part of sending the request to the final IdP.
// Creates the response and signs and optionally encrypts it.
//
// Steps:
//   - If tp.Usedoubleproxy is set, the pending redirect in tp.Resp is first
//     followed once.
//   - If the resulting redirect does not point at the host of tp.Idpmd's
//     entityID, that redirect is followed. Its result is stored in tp.Resp,
//     tp.Responsebody and tp.Err, and the method returns.
//   - Otherwise the SAMLRequest is decoded from the redirect URL and a
//     response is built into tp.Newresponse.
//   - Its first saml:Assertion is signed with tp.Privatekey. A signing
//     failure is fatal.
//   - If tp.Encryptresponse is set, the assertion is encrypted with the key
//     from the hub's encryption certificate. When tp.Env is "xdev", the
//     certificate at *testcertpath is used instead. Failures while preparing
//     encryption are stored in tp.Err.
func (tp *Testparams) SSOSendRequest2() {
	u, _ := tp.Resp.Location()

	// if going via birk we now got a scoped request to the hub
	if tp.Usedoubleproxy {

		if tp.Logxml {
			query := u.Query()
			req, _ := base64.StdEncoding.DecodeString(query["SAMLRequest"][0])
			authnrequest := gosaml.NewXp(gosaml.Inflate(req))
			log.Println("birkrequest", authnrequest.Pp())
		}

		tp.Resp, tp.Responsebody, _ = tp.sendRequest(u, tp.Resolv[u.Host], "GET", "", tp.Cookiejar)
		u, _ = tp.Resp.Location()
	}

	// We still expect to be redirected
	// if we are not at our final IdP something is rotten

	eid := tp.Idpmd.Query1(nil, "@entityID")
	idp, _ := url.Parse(eid)
	if u.Host != idp.Host {
		// Errors from HUB is 302 to https://wayf.wayf.dk/displayerror.php ... which is a 500 with html content
		// u already holds tp.Resp.Location().
		tp.Resp, tp.Responsebody, tp.Err = tp.sendRequest(u, tp.Resolv[u.Host], "GET", "", tp.Cookiejar)
		return
	}

	// get the SAMLRequest
	query := u.Query()
	req, _ := base64.StdEncoding.DecodeString(query["SAMLRequest"][0])
	authnrequest := gosaml.NewXp(gosaml.Inflate(req))

	if tp.Logxml {
		log.Println("idprequest", authnrequest.Pp())
	}

	// create a response
	tp.Newresponse = gosaml.NewResponse(gosaml.IdAndTiming{time.Now(), 4 * time.Minute, 4 * time.Hour, "", ""}, tp.Idpmd, tp.Hubspmd, authnrequest, tp.Attributestmt)

	if tp.Logxml {
		log.Println("response", tp.Newresponse.Pp())
	}

	// and sign it
	assertion := tp.Newresponse.Query(nil, "saml:Assertion[1]")[0]

	// use cert to calculate key name
	err := tp.Newresponse.Sign(assertion, tp.Privatekey, tp.Privatekeypw, tp.Certificate, tp.Hashalgorithm)
	if err != nil {
		log.Fatal(err)
	}

	if !tp.Encryptresponse {
		return
	}

	certs := tp.Hubspmd.Query(nil, `//md:KeyDescriptor[@use="encryption" or not(@use)]/ds:KeyInfo/ds:X509Data/ds:X509Certificate`)
	if len(certs) == 0 {
		// Previously this error was built and discarded, then certs[0] panicked.
		tp.Err = fmt.Errorf("Could not find encryption cert for: %s", tp.Hubspmd.Query1(nil, "/@entityID"))
		return
	}

	_, publickey, err := gosaml.PublicKeyInfo(tp.Hubspmd.NodeGetContent(certs[0]))

	if tp.Env == "xdev" {
		// In xdev the hub's key is replaced by the local test certificate, so
		// an error from the hub certificate above does not matter here.
		var (
			certbytes []byte
			pk        *x509.Certificate
		)
		if certbytes, err = ioutil.ReadFile(*testcertpath); err != nil {
			tp.Err = err
			return
		}
		if pk, err = x509.ParseCertificate(certbytes); err != nil {
			tp.Err = err
			return
		}
		rsakey, ok := pk.PublicKey.(*rsa.PublicKey)
		if !ok {
			tp.Err = fmt.Errorf("test certificate %s does not hold an RSA public key", *testcertpath)
			return
		}
		publickey = rsakey
	}
	if err != nil {
		tp.Err = err
		return
	}

	tp.Newresponse.Encrypt(assertion, publickey)
	tp.Encryptresponse = false // for now only possible for idp -> hub
}
// Newtp returns a *Testparams populated with the defaults used by the tests.
//
// Where the defaults come from:
//   - The PW environment variable, which must be set.
//   - The package-level settings *env, *dokrib, *dobirk, *dohub, *hub,
//     *birk, *hubbe, *birkbe, *trace and *logxml (presumably command line
//     flags; confirm).
//   - Metadata fetched via hub_ops, wayf_hub_public and birk_ops.
//   - The signing private key, read from /etc/ssl/wayf/signing/<keyname>.key.
//     keyname is derived from the first signing certificate of the test IdP
//     metadata.
//
// If overwrite is non-nil, its non-nil/non-empty Hubspmd, Hubidpmd, Spmd,
// Privatekey and Privatekeypw fields, and a true Encryptresponse, replace
// the defaults. Missing prerequisites terminate the program via log.Fatal.
func Newtp(overwrite *Testparams) (tp *Testparams) {
	tp = new(Testparams)
	tp.Privatekeypw = os.Getenv("PW")
	if tp.Privatekeypw == "" {
		log.Fatal("no PW environment var")
	}
	tp.Env = *env
	tp.Krib = *dokrib
	tp.Birk = *dobirk
	tp.Hub = *dohub
	tp.Spmd, _ = hub_ops.MDQ("https://wayfsp.wayf.dk")

	// Hubspmd is dereferenced right away, so a lookup failure must stop here
	// instead of turning into a nil pointer panic.
	var err error
	if tp.Hubspmd, err = wayf_hub_public.MDQ("https://wayf.wayf.dk"); err != nil {
		log.Fatal(err)
	}
	spssoDescriptors := tp.Hubspmd.Query(nil, "./md:SPSSODescriptor")
	if len(spssoDescriptors) == 0 {
		log.Fatal("no SPSSODescriptor in hub metadata")
	}
	spssoDescriptors[0].AddChild(wayfAttCSDoc.CopyNode(wayfAttCSElement, 1))
	tp.Hubidpmd, _ = wayf_hub_public.MDQ("https://wayf.wayf.dk")

	wayfserver := "wayf.wayf.dk"
	/*
		if tp.Env == "beta" {
			wayfserver = "betawayf.wayf.dk"
			tp.Hubspmd = newMD("https://betawayf.wayf.dk/module.php/saml/sp/metadata.php/betawayf.wayf.dk")
			tp.Hubidpmd = newMD("https://betawayf.wayf.dk/saml2/idp/metadata.php")
		}
	*/
	tp.Resolv = map[string]string{wayfserver: *hub, "birk.wayf.dk": *birk}

	// Idpmd is queried for its signing certificate below; fail loudly if it
	// could not be fetched.
	if tp.Idpmd, err = hub_ops.MDQ("https://this.is.not.a.valid.idp"); err != nil {
		log.Fatal(err)
	}
	tp.Firstidpmd = tp.Hubidpmd
	if tp.Birk {
		tp.Birkmd, _ = birk_ops.MDQ("https://birk.wayf.dk/birk.php/this.is.not.a.valid.idp")
	}

	tp.DSIdpentityID = "https://this.is.not.a.valid.idp"
	if tp.Krib {
		tp.DSIdpentityID = "https://birk.wayf.dk/birk.php/this.is.not.a.valid.idp"
	}
	tp.Trace = *trace
	tp.Logxml = *logxml

	tp.Cookiejar = make(map[string]map[string]*http.Cookie)
	tp.Cookiejar["wayf.wayf.dk"] = make(map[string]*http.Cookie)
	tp.Cookiejar["wayf.wayf.dk"]["wayfid"] = &http.Cookie{Name: "wayfid", Value: *hubbe}
	tp.Cookiejar["birk.wayf.dk"] = make(map[string]*http.Cookie)
	tp.Cookiejar["birk.wayf.dk"]["birkid"] = &http.Cookie{Name: "birkid", Value: *birkbe}

	tp.Attributestmt = b(avals)
	tp.Hashalgorithm = "sha1"

	certs := tp.Idpmd.Query(nil, `//md:KeyDescriptor[@use="signing" or not(@use)]/ds:KeyInfo/ds:X509Data/ds:X509Certificate`)
	if len(certs) == 0 {
		// Previously this error was built and discarded, then certs[0] panicked.
		log.Fatalf("Could not find signing cert for: %s", tp.Idpmd.Query1(nil, "/@entityID"))
	}

	keyname, _, err := gosaml.PublicKeyInfo(tp.Idpmd.NodeGetContent(certs[0]))
	if err != nil {
		log.Fatal(err)
	}

	tp.Certificate = tp.Idpmd.NodeGetContent(certs[0])
	pk, err := ioutil.ReadFile("/etc/ssl/wayf/signing/" + keyname + ".key")
	if err != nil {
		log.Fatal(err)
	}
	tp.Privatekey = string(pk)
	if overwrite != nil {
		if overwrite.Hubspmd != nil {
			tp.Hubspmd = overwrite.Hubspmd
		}
		if overwrite.Hubidpmd != nil {
			tp.Hubidpmd = overwrite.Hubidpmd
			tp.Firstidpmd = tp.Hubidpmd
		}
		if overwrite.Encryptresponse {
			tp.Encryptresponse = true
		}
		if overwrite.Spmd != nil {
			tp.Spmd = overwrite.Spmd
		}
		if overwrite.Privatekey != "" {
			tp.Privatekey = overwrite.Privatekey
		}
		if overwrite.Privatekeypw != "" {
			tp.Privatekeypw = overwrite.Privatekeypw
		}
	}

	//	m := mapFields(tp)
	//    log.Println("Mapped fields: ", m)
	return
}