Added support for reliable UDP congestion control and packet resending -- no transcieve yet #78

This commit is contained in:
Collin Smith
2020-06-24 21:53:45 -07:00
parent 398f034755
commit e36b832667
10 changed files with 360 additions and 18 deletions

View File

@ -19,7 +19,7 @@ public abstract class MessageChannel implements ReliablePacketController.PacketL
} }
public abstract void reset(); public abstract void reset();
public abstract void update(float delta, DatagramChannel ch); public abstract void update(float delta, int channelId, DatagramChannel ch);
public abstract void sendMessage(int channelId, DatagramChannel ch, ByteBuf bb); public abstract void sendMessage(int channelId, DatagramChannel ch, ByteBuf bb);
public abstract void onMessageReceived(ChannelHandlerContext ctx, DatagramPacket packet); public abstract void onMessageReceived(ChannelHandlerContext ctx, DatagramPacket packet);

View File

@ -11,8 +11,8 @@ public class Packet {
public static final int USHORT_MAX_VALUE = 0xFFFF; public static final int USHORT_MAX_VALUE = 0xFFFF;
static final int MAX_PACKET_HEADER_SIZE = 10; public static final int MAX_PACKET_HEADER_SIZE = 10;
static final int FRAGMENT_HEADER_SIZE = 6; public static final int FRAGMENT_HEADER_SIZE = 6;
static final int SINGLE = 0; static final int SINGLE = 0;
static final int FRAGMENTED = 1 << 0; static final int FRAGMENTED = 1 << 0;

View File

@ -54,12 +54,18 @@ public class ReliableEndpoint implements Endpoint<DatagramPacket>, MessageChanne
@Override @Override
public void reset() { public void reset() {
for (MessageChannel mc : channels) if (mc != null) mc.reset(); final MessageChannel[] channels = this.channels;
for (int i = 0, s = channels.length; i < s; i++) {
channels[i].reset();
}
} }
@Override @Override
public void update(float delta) { public void update(float delta) {
for (MessageChannel mc : channels) if (mc != null) mc.update(delta, channel); final MessageChannel[] channels = this.channels;
for (int i = 0, s = channels.length; i < s; i++) {
channels[i].update(delta, i, channel);
}
} }
@Override @Override

View File

@ -5,6 +5,8 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.socket.DatagramChannel; import io.netty.channel.socket.DatagramChannel;
import io.netty.channel.socket.DatagramPacket; import io.netty.channel.socket.DatagramPacket;
import com.badlogic.gdx.math.MathUtils;
import com.riiablo.net.reliable.data.FragmentReassemblyData; import com.riiablo.net.reliable.data.FragmentReassemblyData;
import com.riiablo.net.reliable.data.ReceivedPacketData; import com.riiablo.net.reliable.data.ReceivedPacketData;
import com.riiablo.net.reliable.data.SentPacketData; import com.riiablo.net.reliable.data.SentPacketData;
@ -16,6 +18,8 @@ public class ReliablePacketController {
private static final boolean DEBUG_SEND = DEBUG && true; private static final boolean DEBUG_SEND = DEBUG && true;
private static final boolean DEBUG_RECEIVE = DEBUG && true; private static final boolean DEBUG_RECEIVE = DEBUG && true;
private static final float TOLERANCE = 0.00001f;
private final ReliableConfiguration config; private final ReliableConfiguration config;
private final MessageChannel channel; private final MessageChannel channel;
@ -23,7 +27,12 @@ public class ReliablePacketController {
private final SequenceBuffer<ReceivedPacketData> receivedPackets; private final SequenceBuffer<ReceivedPacketData> receivedPackets;
private final SequenceBuffer<FragmentReassemblyData> fragmentReassembly; private final SequenceBuffer<FragmentReassemblyData> fragmentReassembly;
private long time; private float time;
private float rtt;
private float packetLoss;
private float sentBandwidth;
private float receivedBandwidth;
private float ackedBandwidth;
public ReliablePacketController(ReliableConfiguration config, MessageChannel channel) { public ReliablePacketController(ReliableConfiguration config, MessageChannel channel) {
this.config = config; this.config = config;
@ -42,15 +51,132 @@ public class ReliablePacketController {
return channel.sequence = (channel.sequence + 1) & Packet.USHORT_MAX_VALUE; return channel.sequence = (channel.sequence + 1) & Packet.USHORT_MAX_VALUE;
} }
public void reset() { public float rtt() {
return rtt;
}
public void reset() {
channel.sequence = 0;
for (int i = 0, s = config.fragmentReassemblyBufferSize; i < s; i++) {
FragmentReassemblyData reassemblyData = fragmentReassembly.atIndex(i);
if (reassemblyData != null) reassemblyData.dataBuffer.clear();
}
sentPackets.reset();
receivedPackets.reset();
fragmentReassembly.reset();
} }
public void update(float delta) { public void update(float delta) {
this.time = time; time += delta;
updatePacketLoss();
updateSentBandwidth();
updateReceivedBandwidth();
updateAckedBandwidth();
}
private void updatePacketLoss() {
int baseSequence = (sentPackets.getSequence() - config.sentPacketBufferSize + 1 + Packet.USHORT_MAX_VALUE) & Packet.USHORT_MAX_VALUE;
int numDropped = 0;
int numSamples = config.sentPacketBufferSize / 2;
for (int i = 0; i < numSamples; i++) {
int sequence = (baseSequence + i) & Packet.USHORT_MAX_VALUE;
SentPacketData sentPacketData = sentPackets.find(sequence);
if (sentPacketData != null && !sentPacketData.acked) numDropped++;
}
float packetLoss = numDropped / (float) numSamples;
if (MathUtils.isEqual(this.packetLoss, packetLoss, TOLERANCE)) {
this.packetLoss += (packetLoss - this.packetLoss) * config.packetLossSmoothingFactor;
} else {
this.packetLoss = packetLoss;
}
}
private void updateSentBandwidth() {
int baseSequence = (sentPackets.getSequence() - config.sentPacketBufferSize + 1 + Packet.USHORT_MAX_VALUE) & Packet.USHORT_MAX_VALUE;
int bytesSent = 0;
float startTime = Float.MAX_VALUE;
float finishTime = 0f;
int numSamples = config.sentPacketBufferSize / 2;
for (int i = 0; i < numSamples; i++) {
int sequence = (baseSequence + i) & Packet.USHORT_MAX_VALUE;
SentPacketData sentPacketData = sentPackets.find(sequence);
if (sentPacketData == null) continue;
bytesSent += sentPacketData.packetSize;
startTime = Math.min(startTime, sentPacketData.time);
finishTime = Math.max(finishTime, sentPacketData.time);
}
if (startTime != Float.MAX_VALUE && finishTime != 0f) {
float sentBandwidth = bytesSent / (finishTime - startTime) * 8f / 1000f;
if (MathUtils.isEqual(this.sentBandwidth, sentBandwidth, TOLERANCE)) {
this.sentBandwidth += (sentBandwidth - this.sentBandwidth) * config.bandwidthSmoothingFactor;
} else {
this.sentBandwidth = sentBandwidth;
}
}
}
private void updateReceivedBandwidth() {
synchronized (receivedPackets) {
int baseSequence = (receivedPackets.getSequence() - config.receivedPacketBufferSize + 1 + Packet.USHORT_MAX_VALUE) & Packet.USHORT_MAX_VALUE;
int bytesReceived = 0;
float startTime = Float.MAX_VALUE;
float finishTime = 0f;
int numSamples = config.receivedPacketBufferSize / 2;
for (int i = 0; i < numSamples; i++) {
int sequence = (baseSequence + i) & Packet.USHORT_MAX_VALUE;
ReceivedPacketData receivedPacketData = receivedPackets.find(sequence);
if (receivedPacketData == null) continue;
bytesReceived += receivedPacketData.packetSize;
startTime = Math.min(startTime, receivedPacketData.time);
finishTime = Math.max(finishTime, receivedPacketData.time);
}
if (startTime != Float.MAX_VALUE && finishTime != 0f) {
float receivedBandwidth = bytesReceived / (finishTime - startTime) * 8f / 1000f;
if (MathUtils.isEqual(this.receivedBandwidth, receivedBandwidth, TOLERANCE)) {
this.receivedBandwidth += (receivedBandwidth - this.receivedBandwidth) * config.bandwidthSmoothingFactor;
} else {
this.receivedBandwidth = receivedBandwidth;
}
}
}
}
private void updateAckedBandwidth() {
int baseSequence = (sentPackets.getSequence() - config.sentPacketBufferSize + 1 + Packet.USHORT_MAX_VALUE) & Packet.USHORT_MAX_VALUE;
int bytesSent = 0;
float startTime = Float.MAX_VALUE;
float finishTime = 0f;
int numSamples = config.sentPacketBufferSize / 2;
for (int i = 0; i < numSamples; i++) {
int sequence = (baseSequence + i) & Packet.USHORT_MAX_VALUE;
SentPacketData sentPacketData = sentPackets.find(sequence);
if (sentPacketData == null || !sentPacketData.acked) continue;
bytesSent += sentPacketData.packetSize;
startTime = Math.min(startTime, sentPacketData.time);
finishTime = Math.max(finishTime, sentPacketData.time);
}
if (startTime != Float.MAX_VALUE && finishTime != 0f) {
float ackedBandwidth = bytesSent / (finishTime - startTime) * 8f / 1000f;
if (MathUtils.isEqual(this.ackedBandwidth, ackedBandwidth, TOLERANCE)) {
this.ackedBandwidth += (ackedBandwidth - this.ackedBandwidth) * config.bandwidthSmoothingFactor;
} else {
this.ackedBandwidth = ackedBandwidth;
}
}
} }
public void sendAck(int channelId, DatagramChannel ch) { public void sendAck(int channelId, DatagramChannel ch) {
if (DEBUG_SEND) Log.debug(TAG, "sendAck");
int ack, ackBits; int ack, ackBits;
synchronized (receivedPackets) { synchronized (receivedPackets) {
ack = receivedPackets.generateAck(); ack = receivedPackets.generateAck();
@ -87,7 +213,7 @@ public class ReliablePacketController {
SentPacketData sentPacketData = sentPackets.insert(sequence); SentPacketData sentPacketData = sentPackets.insert(sequence);
sentPacketData.time = this.time; sentPacketData.time = this.time;
// sentPacketData.packetSize = sentPacketData.packetSize = packetSize;
sentPacketData.acked = false; sentPacketData.acked = false;
if (packetSize <= config.fragmentThreshold) { if (packetSize <= config.fragmentThreshold) {
@ -165,8 +291,15 @@ public class ReliablePacketController {
if (DEBUG_RECEIVE) Log.debug(TAG, "acked packet %d", ackSequence); if (DEBUG_RECEIVE) Log.debug(TAG, "acked packet %d", ackSequence);
ReliableEndpoint.stats.NUM_PACKETS_ACKED++; ReliableEndpoint.stats.NUM_PACKETS_ACKED++;
sentPacketData.acked = true; sentPacketData.acked = true;
// ack packet callback // ack packet callback
// TODO: rtt
float rtt = (time - sentPacketData.time) * 1000f;
if ((this.rtt == 0.0f && rtt > 0.0f) || MathUtils.isEqual(this.rtt, rtt, TOLERANCE)) {
this.rtt = rtt;
} else {
this.rtt += (rtt - this.rtt) * config.rttSmoothingFactor;
}
} }
} }
} }

View File

@ -8,7 +8,7 @@ public class SequenceBuffer<T> {
private int sequence; private int sequence;
private final int numEntries; public final int numEntries;
private final int entrySequence[]; private final int entrySequence[];
private final Object entryData[]; private final Object entryData[];
@ -26,6 +26,10 @@ public class SequenceBuffer<T> {
sequence = 0; sequence = 0;
} }
public int getSequence() {
return sequence;
}
public void reset() { public void reset() {
sequence = 0; sequence = 0;
Arrays.fill(entrySequence, INVALID_SEQUENCE); Arrays.fill(entrySequence, INVALID_SEQUENCE);

View File

@ -1,39 +1,228 @@
package com.riiablo.net.reliable.channel; package com.riiablo.net.reliable.channel;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.socket.DatagramChannel; import io.netty.channel.socket.DatagramChannel;
import io.netty.channel.socket.DatagramPacket; import io.netty.channel.socket.DatagramPacket;
import com.badlogic.gdx.math.MathUtils;
import com.badlogic.gdx.utils.IntArray;
import com.badlogic.gdx.utils.Queue;
import com.riiablo.net.reliable.Log;
import com.riiablo.net.reliable.MessageChannel; import com.riiablo.net.reliable.MessageChannel;
import com.riiablo.net.reliable.Packet;
import com.riiablo.net.reliable.ReliableConfiguration; import com.riiablo.net.reliable.ReliableConfiguration;
import com.riiablo.net.reliable.ReliableUtils;
import com.riiablo.net.reliable.SequenceBuffer;
public class ReliableMessageChannel extends MessageChannel { public class ReliableMessageChannel extends MessageChannel {
private static final String TAG = "ReliableMessageChannel"; private static final String TAG = "ReliableMessageChannel";
private static final boolean DEBUG = true;
private static final boolean DEBUG_SEND = DEBUG && true;
private static final boolean DEBUG_RECEIVE = DEBUG && true;
private final ByteBuf packetBuffer = Unpooled.buffer();
private final SequenceBuffer<BufferedPacket> sendBuffer;
private final SequenceBuffer<BufferedPacket> receiveBuffer;
private final SequenceBuffer<OutgoingPacketSet> ackBuffer;
private final Queue<ByteBuf> messageQueue = new Queue<>(64, ByteBuf.class);
private final IntArray outgoingMessageIds = new IntArray(256);
private float time;
private float lastBufferFlush;
private float lastMessageSend;
private int oldestUnacked;
// private int sequence; // hides MessageChannel#sequence
private int nextReceive;
private boolean congestionControl = false;
private float congestionDisableTimer;
private float congestionDisableInterval;
private float lastCongestionSwitchTime;
public ReliableMessageChannel(PacketTransceiver packetTransceiver) { public ReliableMessageChannel(PacketTransceiver packetTransceiver) {
super(new ReliableConfiguration(), packetTransceiver); super(new ReliableConfiguration(), packetTransceiver);
this.sendBuffer = new SequenceBuffer<>(BufferedPacket.class, 256);
this.receiveBuffer = new SequenceBuffer<>(BufferedPacket.class, 256);
this.ackBuffer = new SequenceBuffer<>(OutgoingPacketSet.class, 256);
time = 0.0f;
lastBufferFlush = -1.0f;
lastMessageSend = 0.0f;
congestionDisableInterval = 5.0f;
sequence = 0;
nextReceive = 0;
oldestUnacked = 0;
} }
@Override @Override
public void reset() { public void reset() {
packetController.reset();
sendBuffer.reset();
// receiveBuffer.reset(); // this isn't in the original code? why?
ackBuffer.reset();
lastBufferFlush = -1.0f;
lastMessageSend = 0.0f;
congestionControl = false;
lastCongestionSwitchTime = 0.0f;
congestionDisableTimer = 0.0f;
congestionDisableInterval = 5.0f;
sequence = 0;
nextReceive = 0;
oldestUnacked = 0;
} }
@Override @Override
public void update(float delta, DatagramChannel ch) { public void update(float delta, int channelId, DatagramChannel ch) {
packetController.update(delta);
time += delta;
// see if we can pop messages off of the message queue and put them into the send queue
updateQueue(channelId, ch);
updateCongestion(delta, channelId, ch);
}
private void updateQueue(int channelId, DatagramChannel ch) {
if (messageQueue.size > 0) {
int sendBufferSize = 0;
for (int seq = oldestUnacked; ReliableUtils.sequenceLessThan(seq, sequence); seq = (seq + 1) & Packet.USHORT_MAX_VALUE) {
if (sendBuffer.exists(seq)) sendBufferSize++;
}
if (sendBufferSize < sendBuffer.numEntries) {
ByteBuf packetData = messageQueue.removeFirst();
sendMessage(channelId, ch, packetData);
}
}
}
private void updateCongestion(float delta, int channelId, DatagramChannel ch) {
boolean conditionsBad = packetController.rtt() >= 250.0f; // 250ms
// if conditions are bad, immediately enable congestion control and reset the congestion timer
if (conditionsBad) {
if (!congestionControl) {
// if we're within 10 seconds of the last time we switched, double the threshold interval
if (time - lastCongestionSwitchTime < 10.0) {
congestionDisableInterval = Math.min(congestionDisableInterval * 2, 60.0f);
}
lastCongestionSwitchTime = time;
}
congestionControl = true;
congestionDisableTimer = 0.0f;
}
// if we're in bad mode, and conditions are good, update the timer and see if we can disable
// congestion control
if (congestionControl && !conditionsBad) {
congestionDisableTimer += delta;
if (congestionDisableTimer >= congestionDisableInterval) {
congestionControl = false;
lastCongestionSwitchTime = time;
congestionDisableTimer = 0.0f;
}
}
// as long as conditions are good, halve the threshold interval every 10 seconds
if (!congestionControl) {
congestionDisableTimer += delta;
if (congestionDisableTimer > 10.0f) {
congestionDisableInterval = Math.max(congestionDisableInterval * 0.5f, 5.0f);
}
}
// if we're in congestion control mode, only send packets 10 times per second. otherwise, send
// 30 times per second
float flushInterval = congestionControl ? (1.0f / 10) : (1.0f / 30);
if (time - lastBufferFlush >= flushInterval) {
lastBufferFlush = time;
processSendBuffer(channelId, ch);
}
}
private void processSendBuffer(int channelId, DatagramChannel ch) {
// int numUnacked = 0;
// for (int seq = oldestUnacked; ReliableUtils.sequenceLessThan(seq, sequence); seq = (seq + 1) & Packet.USHORT_MAX_VALUE) {
// numUnacked++;
// }
for (int seq = oldestUnacked; ReliableUtils.sequenceLessThan(seq, sequence); seq = (seq + 1) & Packet.USHORT_MAX_VALUE) {
// never send message ID >= (oldestUnacked + bufferSize)
if (seq >= (oldestUnacked + 256)) break;
// for any message that hasn't been sent in the last 0.1 seconds and fits in the available
// space of our message packer, add it
BufferedPacket packet = sendBuffer.find(seq);
if (packet != null && !packet.writeLock) {
if (MathUtils.isEqual(time, packet.time, 0.1f)) continue;
boolean packetFits = false;
int packetSize = packetBuffer.readableBytes() + packet.bb.readableBytes();
if (packet.bb.readableBytes() < config.fragmentThreshold) {
packetFits = packetSize <= (config.fragmentThreshold - Packet.MAX_PACKET_HEADER_SIZE);
} else {
packetFits = packetSize <= (config.maxPacketSize - Packet.FRAGMENT_HEADER_SIZE - Packet.MAX_PACKET_HEADER_SIZE);
}
// if the packet won't fit, flush the message packet
if (!packetFits) {
flushPacketBuffer(channelId, ch);
}
packet.time = time;
packetBuffer.writeBytes(packet.bb);
outgoingMessageIds.add(seq);
lastMessageSend = time;
}
}
// if it has been 0.1 seconds since the last time we sent a message, send an empty message
if (time - lastMessageSend >= 0.1f) {
packetController.sendAck(channelId, ch);
lastMessageSend = time;
}
// flush and remaining messages in the packet buffer
flushPacketBuffer(channelId, ch);
}
private void flushPacketBuffer(int channelId, DatagramChannel ch) {
if (packetBuffer.readableBytes() > 0) {
int outgoingSeq = packetController.sendPacket(channelId, ch, packetBuffer);
OutgoingPacketSet outgoingPacket = ackBuffer.insert(outgoingSeq);
// store message IDs so we can map packet-level acks to message ID acks
outgoingPacket.messageIds.clear();
outgoingPacket.messageIds.addAll(outgoingMessageIds);
packetBuffer.clear();
outgoingMessageIds.clear();
}
} }
@Override @Override
public void sendMessage(int channelId, DatagramChannel ch, ByteBuf bb) { public void sendMessage(int channelId, DatagramChannel ch, ByteBuf bb) {
if (DEBUG_SEND) Log.debug(TAG, "sendMessage " + bb);
} }
@Override @Override
public void onMessageReceived(ChannelHandlerContext ctx, DatagramPacket packet) { public void onMessageReceived(ChannelHandlerContext ctx, DatagramPacket packet) {
if (DEBUG_SEND) Log.debug(TAG, "onMessageReceived " + packet);
} }
@Override @Override
@ -45,4 +234,14 @@ public class ReliableMessageChannel extends MessageChannel {
public void onPacketProcessed(int sequence, ByteBuf bb) { public void onPacketProcessed(int sequence, ByteBuf bb) {
} }
public static class BufferedPacket {
boolean writeLock = true;
float time;
ByteBuf bb;
}
public static class OutgoingPacketSet {
final IntArray messageIds = new IntArray();
}
} }

View File

@ -36,7 +36,7 @@ public class UnreliableMessageChannel extends MessageChannel {
} }
@Override @Override
public void update(float delta, DatagramChannel ch) { public void update(float delta, int channelId, DatagramChannel ch) {
packetController.update(delta); packetController.update(delta);
} }

View File

@ -31,7 +31,7 @@ public class UnreliableOrderedMessageChannel extends MessageChannel {
} }
@Override @Override
public void update(float delta, DatagramChannel ch) { public void update(float delta, int channelId, DatagramChannel ch) {
packetController.update(delta); packetController.update(delta);
} }

View File

@ -1,6 +1,6 @@
package com.riiablo.net.reliable.data; package com.riiablo.net.reliable.data;
public class ReceivedPacketData { public class ReceivedPacketData {
public long time; public float time;
public int packetSize; public int packetSize;
} }

View File

@ -1,7 +1,7 @@
package com.riiablo.net.reliable.data; package com.riiablo.net.reliable.data;
public class SentPacketData { public class SentPacketData {
public long time; public float time;
public boolean acked; public boolean acked;
public int packetSize; public int packetSize;
} }