// RecoverKey returns the AES key used to generate the given white-box construction.
//
// The attack works on the first two rounds of constr and has three phases:
//
//  1. Decomposition: each round is split by aspn.DecomposeSPN (with cspn.SAS)
//     into an S-box layer, an affine layer, and a second S-box layer. The last
//     S-box layer of round 1 and the first S-box layer of round 2 are fused
//     into a single "middle" S-box layer.
//  2. Disambiguation: the two affine layers are normalized with clean, and the
//     constant and linear parts of the middle S-boxes are moved out into the
//     leading and trailing S-box layers.
//  3. Extraction: each of the 16 key bytes is found by trying all 256 values
//     against the corresponding leading S-box with isAS. The recovered bytes
//     are mapped through the left affine layer and then passed through
//     backOneRound for round 2 and then round 1 to produce the returned key.
func RecoverKey(constr *chow.Construction) []byte {
	round1, round2 := round{
		construction: constr,
		round:        1,
	}, round{
		construction: constr,
		round:        2,
	}

	// Decomposition Phase
	// With cspn.SAS, each result is indexed as [0] S-box layer
	// (ConcatenatedBlock), [1] affine layer (BlockAffine), [2] S-box layer
	// (ConcatenatedBlock), as the type assertions below require.
	constr1 := aspn.DecomposeSPN(round1, cspn.SAS)
	constr2 := aspn.DecomposeSPN(round2, cspn.SAS)

	var (
		leading, middle, trailing sboxLayer
		left, right               = affineLayer(constr1[1].(encoding.BlockAffine)), affineLayer(constr2[1].(encoding.BlockAffine))
	)

	for pos := 0; pos < 16; pos++ {
		leading[pos] = constr1[0].(encoding.ConcatenatedBlock)[pos]
		// A ShiftRows step sits between the two rounds (see the commented-out
		// equivalence check below). The round-1 output S-box at pos is therefore
		// paired with the round-2 input S-box at common.ShiftRows(pos).
		middle[pos] = encoding.ComposedBytes{
			constr1[2].(encoding.ConcatenatedBlock)[pos],
			constr2[0].(encoding.ConcatenatedBlock)[common.ShiftRows(pos)],
		}
		trailing[pos] = constr2[2].(encoding.ConcatenatedBlock)[pos]
	}

	// Disambiguation Phase
	// Disambiguate the affine layer.
	// clean normalizes each affine layer and returns the byte-wise encodings
	// split off its input and output. Those encodings are absorbed into the
	// neighbouring S-box layers so the overall composition is unchanged.
	lin, lout := left.clean()
	rin, rout := right.clean()

	leading.rightCompose(lin, common.NoShift)
	middle.leftCompose(lout, common.NoShift).rightCompose(rin, common.ShiftRows)
	trailing.leftCompose(rout, common.NoShift)

	// The SPN decomposition naturally leaves the affine layers without a constant part.
	// We would push it into the S-boxes here if that wasn't the case.

	// Move the constant off of the input and output of the S-boxes.
	// The constants removed from the middle layer must cross the affine layers
	// to reach the outer S-boxes:
	//   - the input constant is pulled back through left (Decode) into leading;
	//   - the output constant is pushed forward through right (Encode) into trailing.
	mcin, mcout := middle.cleanConstant()
	mcin, mcout = left.Decode(mcin), right.Encode(mcout)

	leading.rightCompose(encoding.DecomposeConcatenatedBlock(encoding.BlockAdditive(mcin)), common.NoShift)
	trailing.leftCompose(encoding.DecomposeConcatenatedBlock(encoding.BlockAdditive(mcout)), common.NoShift)

	// Move the multiplication off of the input and output of the middle S-boxes.
	mlin, mlout := middle.cleanLinear()

	leading.rightCompose(mlin, common.NoShift)
	trailing.leftCompose(mlout, common.NoShift)

	// Debugging check (disabled): the re-assembled layers should be equivalent
	// to the two original rounds with ShiftRows between them.
	// fmt.Println(encoding.ProbablyEquivalentBlocks(
	// 	encoding.ComposedBlocks{aspn.Encoding{round1}, ShiftRows{}, aspn.Encoding{round2}},
	// 	encoding.ComposedBlocks{leading, left, middle, ShiftRows{}, right, trailing},
	// ))
	// Output: true

	// Extract the key from the leading S-boxes.
	// For each position, the byte guess is accepted when
	// leading[pos] -> XOR guess -> inverse AES S-box satisfies isAS
	// (presumably "is an affine/S-box shape"; NOTE(review): confirm against isAS).
	key := [16]byte{}

	for pos := 0; pos < 16; pos++ {
		for guess := 0; guess < 256; guess++ {
			cand := encoding.ComposedBytes{
				leading[pos],
				encoding.ByteAdditive(guess),
				encoding.InverseByte{sbox{}},
			}

			if isAS(cand) {
				key[pos] = byte(guess)
				break
			}
		}
		// NOTE(review): if no guess satisfies isAS, key[pos] silently stays 0
		// and the returned key is wrong; consider surfacing this failure.
	}

	// The guesses were taken on leading's output, which feeds left in the
	// composition, so map the recovered bytes through left.
	key = left.Encode(key)

	return backOneRound(backOneRound(key[:], 2), 1)
}
// shiftRoundKey adds the fixed SubBytes constant to a round key and returns the result as an encoding.Block. func shiftRoundKey(key []byte) encoding.BlockAdditive { out := [16]byte{} encoding.XOR(out[:], subBytesConst, key) return encoding.BlockAdditive(out) }