diff --git a/common.go b/common.go index 8b19ccf..28f4af0 100644 --- a/common.go +++ b/common.go @@ -67,7 +67,7 @@ const ( recordHeaderLen = 5 // record header length maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) maxHandshakeCertificateMsg = 262144 // maximum certificate message size (256 KiB) - maxUselessRecords = 16 // maximum number of consecutive non-advancing records + maxUselessRecords = 32 // maximum number of consecutive non-advancing records ) // TLS record types. diff --git a/conn.go b/conn.go index a408e76..d99b2e6 100644 --- a/conn.go +++ b/conn.go @@ -25,16 +25,17 @@ import ( // A Conn represents a secured connection. // It implements the net.Conn interface. type Conn struct { - AuthKey []byte - ClientVer [3]byte - ClientTime time.Time - ClientShortId [8]byte + AuthKey []byte + ClientVer [3]byte + ClientTime time.Time + ClientShortId [8]byte + MaxUselessRecords int // constant conn net.Conn isClient bool handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake - quic *quicState // nil for non-QUIC connections + quic *quicState // nil for non-QUIC connections // isHandshakeComplete is true if the connection is currently transferring // application data (i.e. is not currently processing a handshake). @@ -827,7 +828,10 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { // a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3. func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { c.retryCount++ - if c.retryCount > maxUselessRecords { + if c.MaxUselessRecords <= 0 { + c.MaxUselessRecords = maxUselessRecords + } + if c.retryCount > c.MaxUselessRecords { c.sendAlert(alertUnexpectedMessage) return c.in.setErrorLocked(errors.New("tls: too many ignored records")) } @@ -1248,7 +1252,7 @@ func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) if transcript != nil { transcript.Write(data) } - + return m, nil } @@ -1375,7 +1379,7 @@ func (c *Conn) handlePostHandshakeMessage() error { return err } c.retryCount++ - if c.retryCount > maxUselessRecords { + if c.retryCount > c.MaxUselessRecords { c.sendAlert(alertUnexpectedMessage) return c.in.setErrorLocked(errors.New("tls: too many non-advancing records")) } diff --git a/record_detect.go b/fingerprint_detect.go similarity index 55% rename from record_detect.go rename to fingerprint_detect.go index d2399c3..c397acd 100644 --- a/record_detect.go +++ b/fingerprint_detect.go @@ -14,6 +14,7 @@ import ( ) var GlobalPostHandshakeRecordsLens sync.Map +var GlobalMaxCSSMsgCount sync.Map func DetectPostHandshakeRecordsLens(config *Config) { for sni := range config.ServerNames { @@ -36,7 +37,7 @@ func DetectPostHandshakeRecordsLens(config *Config) { return } } - detectConn := &DetectConn{ + detectConn := &RecordDetectConn{ Conn: target, Key: key, } @@ -60,25 +61,61 @@ func DetectPostHandshakeRecordsLens(config *Config) { } io.Copy(io.Discard, uConn) }() + go func() { + now := time.Now() + target, err := net.Dial("tcp", config.Dest) + rtt := time.Since(now) + if err != nil { + return + } + if config.Xver == 1 || config.Xver == 2 { + if _, err = proxyproto.HeaderProxyFromAddrs(config.Xver, target.LocalAddr(), target.RemoteAddr()).WriteTo(target); err != nil { + return + } + } + fingerprint := utls.HelloChrome_Auto + nextProtos := []string{"h2", "http/1.1"} + if alpn != 2 { + fingerprint = utls.HelloGolang + } + if alpn == 1 { + nextProtos = []string{"http/1.1"} + } + if alpn == 0 { + nextProtos = nil + } + conn := &CCSDetectConn{ + Conn: target, + Key: key, + rtt: rtt, + } + uConn := utls.UClient(conn, &utls.Config{ + ServerName: sni, // needs new loopvar behaviour + NextProtos: nextProtos, + }, fingerprint) + if err = uConn.Handshake(); err != nil { + return + } + }() } } } } -type DetectConn struct { +type RecordDetectConn struct { net.Conn Key string CcsSent bool } -func (c *DetectConn) Write(b []byte) (n int, err error) { +func (c *RecordDetectConn) Write(b []byte) (n int, err error) { if len(b) >= 3 && bytes.Equal(b[:3], []byte{20, 3, 3}) { c.CcsSent = true } return c.Conn.Write(b) } -func (c *DetectConn) Read(b []byte) (n int, err error) { +func (c *RecordDetectConn) Read(b []byte) (n int, err error) { if !c.CcsSent { return c.Conn.Read(b) } @@ -97,3 +134,30 @@ func (c *DetectConn) Read(b []byte) (n int, err error) { GlobalPostHandshakeRecordsLens.Store(c.Key, postHandshakeRecordsLens) return 0, io.EOF } + +var CCSMsg = []byte{0x14, 0x3, 0x3, 0x0, 0x1, 0x1} + +type CCSDetectConn struct { + net.Conn + rtt time.Duration + Key string +} + +func (c *CCSDetectConn) Write(b []byte) (n int, err error) { + if len(b) >= 3 && bytes.Equal(b[:3], []byte{20, 3, 3}) { + var i int + // 32(idx 31) → max allowed (what's we need) + // 33(idx 32) → trigger remote TLS Alert + // 34(idx 33) → trigger remote TCP RST + // 35(idx 34) → write err, pass to system + for i = range 35 { + if _, err = c.Conn.Write(CCSMsg); err != nil { + break + } else { + time.Sleep(c.rtt * 2) + } + } + GlobalMaxCSSMsgCount.Store(c.Key, i-2) + } + return c.Conn.Write(b) +} diff --git a/tls.go b/tls.go index add597a..de9dc75 100644 --- a/tls.go +++ b/tls.go @@ -372,7 +372,7 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { if err != nil { break } - go func() { // TODO: Probe target's maxUselessRecords and some time-outs in advance. + go func() { // TODO: Probe some time-outs in advance. if handshakeLen-len(s2cSaved) > 0 { io.ReadFull(target, buf[:handshakeLen-len(s2cSaved)]) } @@ -422,6 +422,10 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { } } time.Sleep(5 * time.Second) + if maxUseless, ok := GlobalMaxCSSMsgCount.Load(key); ok { + hs.c.MaxUselessRecords = maxUseless.(int) + } + } hs.c.isHandshakeComplete.Store(true) break