Files
mihomo/transport/sudoku/crypto/aead.go

131 lines
2.5 KiB
Go

package crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"golang.org/x/crypto/chacha20poly1305"
)
// AEADConn wraps a net.Conn and, when an AEAD cipher is configured,
// encrypts outgoing data and decrypts incoming data frame by frame.
//
// Each frame on the wire is a 2-byte big-endian length followed by the
// nonce and the sealed ciphertext. When aead is nil (method "none"), Read
// and Write pass straight through to the embedded connection.
type AEADConn struct {
// Conn is the underlying transport; its methods are promoted unless
// overridden by AEADConn.
net.Conn
// aead is the cipher used to seal and open frames; nil disables encryption.
aead cipher.AEAD
// readBuf holds decrypted plaintext that has not yet been returned by Read.
readBuf bytes.Buffer
// nonceSize caches aead.NonceSize(); it is left at zero when aead is nil.
nonceSize int
}
// NewAEADConn wraps c in an AEADConn using the named cipher method.
//
// Supported methods are "none" (no encryption; reads and writes pass
// through to c), "aes-128-gcm" and "chacha20-poly1305". The cipher key is
// derived from key by taking its SHA-256 digest: AES-128-GCM uses the first
// 16 bytes, ChaCha20-Poly1305 uses all 32.
//
// It returns an error for an unknown method or when the cipher cannot be
// constructed.
func NewAEADConn(c net.Conn, key string, method string) (*AEADConn, error) {
	if method == "none" {
		return &AEADConn{Conn: c, aead: nil}, nil
	}

	keyBytes := sha256.Sum256([]byte(key))

	var (
		aead cipher.AEAD
		err  error
	)
	switch method {
	case "aes-128-gcm":
		// Check the block cipher error explicitly so a failure can never
		// hand a nil cipher.Block to cipher.NewGCM.
		var block cipher.Block
		block, err = aes.NewCipher(keyBytes[:16])
		if err != nil {
			return nil, err
		}
		aead, err = cipher.NewGCM(block)
	case "chacha20-poly1305":
		aead, err = chacha20poly1305.New(keyBytes[:])
	default:
		return nil, fmt.Errorf("unsupported cipher: %s", method)
	}
	if err != nil {
		return nil, err
	}

	return &AEADConn{
		Conn:      c,
		aead:      aead,
		nonceSize: aead.NonceSize(),
	}, nil
}
// Write encrypts p and sends it on the underlying connection as one or more
// frames. With no cipher configured it forwards p unchanged.
//
// Frame layout: a 2-byte big-endian length, then a fresh random nonce, then
// the sealed ciphertext (including the authentication tag). The length
// covers nonce plus ciphertext, so p is split into chunks small enough for
// that sum to fit in 16 bits.
//
// It returns the number of plaintext bytes from p whose frames were written
// without error, and the first error encountered, if any.
func (cc *AEADConn) Write(p []byte) (int, error) {
	if cc.aead == nil {
		return cc.Conn.Write(p)
	}

	const (
		headerSize  = 2
		maxFrameLen = 65535 // largest value the 2-byte length field can carry
	)
	overhead := cc.aead.Overhead()
	maxPayload := maxFrameLen - cc.nonceSize - overhead
	prefixLen := headerSize + cc.nonceSize

	// One buffer is reused for every frame: header and nonce occupy the
	// prefix, and Seal appends the ciphertext directly after them, so no
	// per-chunk ciphertext allocation or extra copy is needed.
	largest := len(p)
	if largest > maxPayload {
		largest = maxPayload
	}
	frame := make([]byte, prefixLen, prefixLen+largest+overhead)
	nonce := frame[headerSize:prefixLen]

	written := 0
	for len(p) > 0 {
		chunk := p
		if len(chunk) > maxPayload {
			chunk = chunk[:maxPayload]
		}

		if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
			return written, err
		}

		// Seal writes only past the prefix, so reading the nonce from the
		// prefix of the destination is safe.
		out := cc.aead.Seal(frame[:prefixLen], nonce, chunk, nil)
		binary.BigEndian.PutUint16(out[:headerSize], uint16(len(out)-headerSize))

		if _, err := cc.Conn.Write(out); err != nil {
			return written, err
		}

		written += len(chunk)
		p = p[len(chunk):]
	}
	return written, nil
}
// Read returns decrypted data from the connection. With no cipher
// configured it reads directly from the underlying connection.
//
// Plaintext left over from a previously decrypted frame is served first.
// Otherwise Read pulls whole frames (2-byte big-endian length, nonce,
// ciphertext) from the connection, authenticates and decrypts them, and
// buffers the plaintext.
//
// It returns an error if the underlying read fails, if a frame is shorter
// than the nonce ("frame too short"), or if authentication/decryption
// fails ("decryption failed").
func (cc *AEADConn) Read(p []byte) (int, error) {
	if cc.aead == nil {
		return cc.Conn.Read(p)
	}

	var header [2]byte
	// Keep reading frames until some plaintext is buffered. A valid frame
	// with an empty payload would otherwise leave readBuf empty, and
	// bytes.Buffer.Read would then report io.EOF even though the
	// connection is still open.
	for cc.readBuf.Len() == 0 {
		if _, err := io.ReadFull(cc.Conn, header[:]); err != nil {
			return 0, err
		}

		frameLen := int(binary.BigEndian.Uint16(header[:]))
		// Reject impossible frames before allocating and reading the body.
		if frameLen < cc.nonceSize {
			return 0, errors.New("frame too short")
		}

		body := make([]byte, frameLen)
		if _, err := io.ReadFull(cc.Conn, body); err != nil {
			return 0, err
		}

		nonce, ciphertext := body[:cc.nonceSize], body[cc.nonceSize:]
		// Decrypt in place, reusing the ciphertext's storage for the output.
		plaintext, err := cc.aead.Open(ciphertext[:0], nonce, ciphertext, nil)
		if err != nil {
			return 0, errors.New("decryption failed")
		}
		cc.readBuf.Write(plaintext)
	}
	return cc.readBuf.Read(p)
}