Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions pkg/tcpip/network/ipv4/ipv4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4387,6 +4387,36 @@ func newTCPPacket(t *testing.T, srcAddr, dstAddr tcpip.Address, ttl uint8, tcpCh
return pkt
}

func newICMPTimestampPacket(t *testing.T, srcAddr, dstAddr tcpip.Address, ttl uint8) *stack.PacketBuffer {
t.Helper()
const icmpTimestampSize = header.ICMPv4PayloadOffset + 12
totalLength := header.IPv4MinimumSize + icmpTimestampSize
hdr := prependable.New(totalLength)

icmpH := header.ICMPv4(hdr.Prepend(icmpTimestampSize))
icmpH.SetType(header.ICMPv4Timestamp)
icmpH.SetCode(header.ICMPv4UnusedCode)
icmpH.SetChecksum(0)
icmpH.SetChecksum(^checksum.Checksum(icmpH, 0))

ipH := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
ipH.Encode(&header.IPv4Fields{
TotalLength: uint16(totalLength),
Protocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
SrcAddr: srcAddr,
DstAddr: dstAddr,
})
ipH.SetChecksum(0)
ipH.SetChecksum(^ipH.CalculateChecksum())

pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(hdr.View()),
})
pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
return pkt
}

func TestForwardingTCPChecksum(t *testing.T) {
ctx := newTestContext()
defer ctx.cleanup()
Expand Down Expand Up @@ -4452,3 +4482,50 @@ func TestForwardingTCPChecksum(t *testing.T) {
t.Errorf("expected valid TCP checksum, but got invalid")
}
}

func TestForwardingICMPv4TimestampChecksumNoPanic(t *testing.T) {
ctx := newTestContext()
defer ctx.cleanup()
s := ctx.s

endpoints := make(map[tcpip.NICID]*channel.Endpoint)
for nicID, addr := range defaultEndpointConfigs {
ep := channel.New(1, ipv4.MaxTotalSize, "")
defer ep.Close()

if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: addr}
if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
endpoints[nicID] = ep
}

s.SetRouteTable([]tcpip.Route{
{
Destination: incomingIPv4Addr.Subnet(),
NIC: incomingNICID,
},
{
Destination: outgoingIPv4Addr.Subnet(),
NIC: outgoingNICID,
},
})

if err := s.SetForwardingDefaultAndAllNICs(header.IPv4ProtocolNumber, true); err != nil {
t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", header.IPv4ProtocolNumber, err)
}

requestPkt := newICMPTimestampPacket(t, remoteIPv4Addr1, remoteIPv4Addr2, 64)
defer requestPkt.DecRef()

endpoints[incomingNICID].InjectInbound(header.IPv4ProtocolNumber, requestPkt)

reply := endpoints[outgoingNICID].Read()
if reply == nil {
t.Fatal("Expected forwarded ICMPv4 Timestamp packet through outgoing NIC")
}
defer reply.DecRef()
}
2 changes: 1 addition & 1 deletion pkg/tcpip/stack/packet_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,7 @@ func (pk *PacketBuffer) GetHeaders() (netHdr header.Network, transHdr header.Tra
return pk.Network(), icmpHeader, false, true
case header.ICMPv4DstUnreachable, header.ICMPv4TimeExceeded, header.ICMPv4ParamProblem:
default:
panic(fmt.Sprintf("unexpected ICMPv4 type = %d", icmpType))
return nil, nil, false, false
}

h, ok := pk.Data().PullUp(header.IPv4MinimumSize)
Expand Down
24 changes: 24 additions & 0 deletions pkg/tcpip/stack/packet_buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,30 @@ func TestGetHeadersUnknownProtocol(t *testing.T) {
}
}

// TestGetHeadersUnsupportedICMPv4Type verifies that packet-controlled ICMPv4
// types do not cause GetHeaders to panic.
func TestGetHeadersUnsupportedICMPv4Type(t *testing.T) {
pk := NewPacketBuffer(PacketBufferOptions{
Payload: buffer.MakeWithData([]byte{
byte(header.ICMPv4Timestamp),
byte(header.ICMPv4UnusedCode),
0, 0, // checksum
0, 0, 0, 0, // identifier and sequence/timestamp fields
}),
})
defer pk.DecRef()

pk.TransportProtocolNumber = header.ICMPv4ProtocolNumber
if _, ok := pk.TransportHeader().Consume(header.ICMPv4MinimumSize); !ok {
t.Fatal("failed to consume ICMPv4 header")
}

netHdr, transHdr, isICMPError, ok := pk.GetHeaders()
if ok {
t.Errorf("GetHeaders() = (%v, %v, %v, true); want _, _, _, false", netHdr, transHdr, isICMPError)
}
}

// TestCalculateTransportChecksumUnknownProtocol verifies that
// CalculateTransportChecksum() does not panic when encountering an unknown
// transport protocol, even when falling back to parsing the network header.
Expand Down