From 7506de289b3fdffe422e7006b0fa0bd673f3168f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=E7=AB=99=E8=B4=B4=E5=90=A7=E8=9C=A1=E6=B2=B9?= Date: Sat, 19 Aug 2023 11:10:28 +0800 Subject: [PATCH] Improve CFB8 implementation (#256) * Add CFB8 tests & benchmark * Improve CFB8 implementation * Cleanup code * Speed up with copy function * Even faster * Fix & more tests * Fix tests * Fix typo --- net/CFB8/cfb8.go | 103 +++++++++++++++++------ net/CFB8/cfb8_test.go | 186 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 266 insertions(+), 23 deletions(-) create mode 100644 net/CFB8/cfb8_test.go diff --git a/net/CFB8/cfb8.go b/net/CFB8/cfb8.go index 0ba56f9..5714a2b 100644 --- a/net/CFB8/cfb8.go +++ b/net/CFB8/cfb8.go @@ -1,51 +1,108 @@ -// Package CFB8 is copied from https://play.golang.org/p/LTbId4b6M2 +// Package CFB8 implements CFB8 block cipher mode of operation used by Minecraft protocol. package CFB8 -import "crypto/cipher" +import ( + "crypto/cipher" + "unsafe" +) type CFB8 struct { c cipher.Block blockSize int - iv, tmp []byte + ivPos int + iv []byte de bool } func NewCFB8Decrypt(c cipher.Block, iv []byte) *CFB8 { - cp := make([]byte, len(iv)) - copy(cp, iv) - return &CFB8{ - c: c, - blockSize: c.BlockSize(), - iv: cp, - tmp: make([]byte, c.BlockSize()), - de: true, - } + return newCFB8(c, iv, true) } func NewCFB8Encrypt(c cipher.Block, iv []byte) *CFB8 { - cp := make([]byte, len(iv)) + return newCFB8(c, iv, false) +} + +func newCFB8(c cipher.Block, iv []byte, de bool) *CFB8 { + cp := make([]byte, len(iv)*3) copy(cp, iv) return &CFB8{ c: c, blockSize: c.BlockSize(), iv: cp, - tmp: make([]byte, c.BlockSize()), - de: false, + de: de, } } func (cf *CFB8) XORKeyStream(dst, src []byte) { - for i := 0; i < len(src); i++ { - val := src[i] - copy(cf.tmp, cf.iv) - cf.c.Encrypt(cf.iv, cf.iv) - val = val ^ cf.iv[0] + if len(dst) < len(src) { + panic("cfb8: output smaller than input") + } - copy(cf.iv, cf.tmp[1:]) + // If dst and src does not overlap in first block size, + // and the length of src is greater than 2*blockSize, + // we can use an optimized implementation. + if len(src) > cf.blockSize<<1 && + (uintptr(unsafe.Pointer(&dst[0]))+uintptr(cf.blockSize) <= uintptr(unsafe.Pointer(&src[0])) || + uintptr(unsafe.Pointer(&src[0]))+uintptr(len(src)) <= uintptr(unsafe.Pointer(&dst[0]))) { + // encrypt/decrypt first blockSize bytes + // After this, the IV will come to the same as + // the last blockSize of ciphertext, so + // we can reuse them without copy. + cf.XORKeyStream(dst, src[:cf.blockSize]) + var ciphertext []byte if cf.de { - cf.iv[15] = src[i] + ciphertext = src } else { - cf.iv[15] = val + ciphertext = dst + } + dst = dst[cf.blockSize:] + src = src[cf.blockSize:] + iv := cf.iv + _ = iv[0] // bounds check hint to compiler; see golang.org/issue/14808 + var ( + i int + val byte + ) + for i, val = range src { + cf.c.Encrypt(iv, ciphertext[i:]) + dst[i] = val ^ iv[0] + } + // copy the current IV for next operation + copy(iv, ciphertext[i+1:i+1+cf.blockSize]) + cf.ivPos = 0 + return + } + + for i, val := range src { + posPlusBlockSize := cf.ivPos + cf.blockSize + // fast mod; 2*blockSize must be a non-negative integer power of 2 + tempPos := posPlusBlockSize & (cf.blockSize<<1 - 1) + // reuse space to store encrypted block + cf.c.Encrypt(cf.iv[tempPos:], cf.iv[cf.ivPos:]) + // Only the first byte of the encrypted block is used + // for encryption/decryption, other bytes are ignored. + val ^= cf.iv[tempPos] + + if cf.ivPos == cf.blockSize<<1 { + // bound reached; move to next round for next operation + // copy next block to the start of the ring buffer + copy(cf.iv, cf.iv[cf.ivPos+1:]) + // insert the encrypted byte to the end of IV + if cf.de { + cf.iv[cf.blockSize-1] = src[i] + } else { + cf.iv[cf.blockSize-1] = val + } + cf.ivPos = 0 + } else { + // insert the encrypted byte to the end of IV + if cf.de { + cf.iv[posPlusBlockSize] = src[i] + } else { + cf.iv[posPlusBlockSize] = val + } + // move to next block + cf.ivPos += 1 } dst[i] = val diff --git a/net/CFB8/cfb8_test.go b/net/CFB8/cfb8_test.go new file mode 100644 index 0000000..c0b59e4 --- /dev/null +++ b/net/CFB8/cfb8_test.go @@ -0,0 +1,186 @@ +package CFB8 + +import ( + "bytes" + "crypto/aes" + "crypto/rand" + "encoding/hex" + "testing" +) + +// cfb8Tests contains the test vectors from +// https://csrc.nist.gov/publications/nistpubs/800-38a/sp800-38a.pdf, section +// F.3.7. Modified for Minecraft CFB8 tests. +var cfb8Tests = []struct { + key, iv, plaintext, ciphertext string +}{ + { + "2b7e151628aed2a6abf7158809cf4f3c", + "000102030405060708090a0b0c0d0e0f", + "6bc1bee22e409f96e93d7e117393172a", + "3b79424c9c0dd436bace9e0ed4586a4f", + }, + { + "2b7e151628aed2a6abf7158809cf4f3c", + "3B3FD92EB72DAD20333449F8E83CFB4A", + "ae2d8a571e03ac9c9eb76fac45af8e51", + "c8b0723943d71f61a2e5b0e8cedf87c8", + }, + { + "2b7e151628aed2a6abf7158809cf4f3c", + "C8A64537A0B3A93FCDE3CDAD9F1CE58B", + "30c81c46a35ce411e5fbc1191a0a52ef", + "260d20e9395d3501067286d3a2a7002f", + }, + { + "2b7e151628aed2a6abf7158809cf4f3c", + "26751F67A3CBB140B1808CF187A4F4DF", + "f69f2445df4f9b17ad2b417be66c3710", + "c0af633cd9c599309f924802af599ee6", + }, + { + "2b7e151628aed2a6abf7158809cf4f3c", + "000102030405060708090a0b0c0d0e0f", + "0ecbd6d36cd12962ce671b4d96fb95aaa902096aeac366e13a6ae57c05d48673cf320c626689d05548f65fd6a108630c1d4e3aab543b006823c7a9422e97c0431587537c384f99a11488ffd9b2e9b46f49005a7e5cef64e27e2de3cf3fb87c1524766601", + "5efb6f6b93cf5f0e135a0c932f59f9aaa2276e4b06cd4f5edca4baba735ac7708dd7c0f9e92c6b89d2245b0d9a6356b0e98529cd45e56df22e914ef9e0792facaab707af90c13162bfad06a240eb6adcbf3365fd84a003f8083f4662a7a27232c72c6c0c", + }, +} + +func TestCFB8VectorsNonOverlapping(t *testing.T) { + for i, test := range cfb8Tests { + key, err := hex.DecodeString(test.key) + if err != nil { + t.Fatal(err) + } + iv, err := hex.DecodeString(test.iv) + if err != nil { + t.Fatal(err) + } + plaintext, err := hex.DecodeString(test.plaintext) + if err != nil { + t.Fatal(err) + } + expected, err := hex.DecodeString(test.ciphertext) + if err != nil { + t.Fatal(err) + } + + block, err := aes.NewCipher(key) + if err != nil { + t.Fatal(err) + } + + ciphertext := make([]byte, len(plaintext)) + cfb := NewCFB8Encrypt(block, iv) + if len(plaintext) > 50 { + cfb.XORKeyStream(ciphertext, plaintext[:len(plaintext)/2]) + cfb.XORKeyStream(ciphertext[len(plaintext)/2:], plaintext[len(plaintext)/2:]) + } else { + cfb.XORKeyStream(ciphertext, plaintext) + } + + if !bytes.Equal(ciphertext, expected) { + t.Errorf("#%d: wrong output: got %x, expected %x", i, ciphertext, expected) + } + + cfbdec := NewCFB8Decrypt(block, iv) + plaintextCopy := make([]byte, len(ciphertext)) + if len(ciphertext) > 50 { + cfbdec.XORKeyStream(plaintextCopy, ciphertext[:len(ciphertext)/2]) + cfbdec.XORKeyStream(plaintextCopy[len(ciphertext)/2:], ciphertext[len(ciphertext)/2:]) + } else { + cfbdec.XORKeyStream(plaintextCopy, ciphertext) + } + + if !bytes.Equal(plaintextCopy, plaintext) { + t.Errorf("#%d: wrong plaintext: got %x, expected %x", i, plaintextCopy, plaintext) + } + } +} + +func TestCFB8VectorsOverlapped(t *testing.T) { + for i, test := range cfb8Tests { + key, err := hex.DecodeString(test.key) + if err != nil { + t.Fatal(err) + } + iv, err := hex.DecodeString(test.iv) + if err != nil { + t.Fatal(err) + } + plaintext, err := hex.DecodeString(test.plaintext) + if err != nil { + t.Fatal(err) + } + expected, err := hex.DecodeString(test.ciphertext) + if err != nil { + t.Fatal(err) + } + + block, err := aes.NewCipher(key) + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, len(plaintext)) + copy(buf, plaintext) + cfb := NewCFB8Encrypt(block, iv) + if len(buf) > 50 { + cfb.XORKeyStream(buf, buf[:len(buf)/2]) + cfb.XORKeyStream(buf[len(buf)/2:], buf[len(buf)/2:]) + } else { + cfb.XORKeyStream(buf, buf) + } + + if !bytes.Equal(buf, expected) { + t.Errorf("#%d: wrong output: got %x, expected %x", i, buf, expected) + } + + cfbdec := NewCFB8Decrypt(block, iv) + if len(buf) > 50 { + cfbdec.XORKeyStream(buf, buf[:len(buf)/2]) + cfbdec.XORKeyStream(buf[len(buf)/2:], buf[len(buf)/2:]) + } else { + cfbdec.XORKeyStream(buf, buf) + } + + if !bytes.Equal(buf, plaintext) { + t.Errorf("#%d: wrong plaintext: got %x, expected %x", i, buf, plaintext) + } + } +} + +func BenchmarkCFB8AES1KOverlapped(b *testing.B) { + var key [16]byte + var iv [16]byte + rand.Read(key[:]) + rand.Read(iv[:]) + buf := make([]byte, 1024) + aes, _ := aes.NewCipher(key[:]) + stream := NewCFB8Encrypt(aes, iv[:]) + + b.SetBytes(int64(len(buf))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + stream.XORKeyStream(buf, buf) + } +} + +func BenchmarkCFB8AES1KNonOverlapping(b *testing.B) { + var key [16]byte + var iv [16]byte + rand.Read(key[:]) + rand.Read(iv[:]) + buf := make([]byte, 1024) + buf2 := make([]byte, 1024) + aes, _ := aes.NewCipher(key[:]) + stream := NewCFB8Encrypt(aes, iv[:]) + + b.SetBytes(int64(len(buf))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + stream.XORKeyStream(buf2, buf) + } +}