Refactored mpq_bytebuf to use netty Future and EventExecutorGroup

This commit is contained in:
Collin Smith
2021-09-18 13:34:32 -07:00
parent 9adfa09ae4
commit 5e8c3590a1
10 changed files with 601 additions and 489 deletions

View File

@@ -0,0 +1,286 @@
package com.riiablo.mpq_bytebuf;
import io.netty.buffer.ByteBuf;
import io.netty.util.concurrent.DefaultEventExecutorGroup;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.EventExecutorGroup;
import io.netty.util.concurrent.FastThreadLocal;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import com.riiablo.logger.LogManager;
import com.riiablo.logger.Logger;
import static com.riiablo.mpq_bytebuf.Mpq.Block.FLAG_COMPRESSED;
import static com.riiablo.mpq_bytebuf.Mpq.Block.FLAG_ENCRYPTED;
import static com.riiablo.mpq_bytebuf.Mpq.Block.FLAG_IMPLODE;
public class DecoderExecutorGroup extends DefaultEventExecutorGroup {
private static final Logger log = LogManager.getLogger(DecoderExecutorGroup.class);
private static final boolean DEBUG_MODE = !true;
static final FastThreadLocal<Decoder> decoders = new FastThreadLocal<Decoder>() {
@Override
protected Decoder initialValue() {
return new Decoder();
}
};
final Thread shutdownHook = new Thread(new Runnable() {
@Override
public void run() {
shutdownGracefully();
}
});
public DecoderExecutorGroup(int nThreads) {
super(nThreads);
try {
Runtime.getRuntime().addShutdownHook(shutdownHook);
} catch (Throwable t) {
log.warn("Problem occurred while trying to add runtime shutdown hook.", t);
}
}
@Override
public Future<?> shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) {
try {
Runtime.getRuntime().removeShutdownHook(shutdownHook);
} catch (IllegalStateException ignored) {
// called during runtime shutdown -- already shutting down
} catch (Throwable t) {
log.warn("Problem occurred while trying to remove runtime shutdown hook.", t);
}
return super.shutdownGracefully(quietPeriod, timeout, unit);
}
public DecodingTask newDecodingTask(
EventExecutor executor,
MpqFileHandle handle,
int offset,
int length
) {
return new DecodingTask(this, executor, handle, offset, length);
}
public static final class DecodingTask {
final EventExecutorGroup group;
final EventExecutor executor;
final MpqFileHandle handle;
final int offset;
final int length;
int numTasks;
final PromiseCombiner combiner;
DecodingTask(
EventExecutorGroup group,
EventExecutor executor,
MpqFileHandle handle,
int offset,
int length
) {
this.group = group;
this.executor = executor;
this.handle = handle;
this.offset = offset;
this.length = length;
this.numTasks = 0;
this.combiner = new PromiseCombiner(executor);
}
int numTasks() {
return numTasks;
}
Future<?> submit(
int sector,
int sectorOffset,
int sectorCSize,
int sectorFSize,
ByteBuf dst,
int dstOffset
) {
final Runnable task = new SectorDecodeTask(
handle,
sector,
sectorOffset,
sectorCSize,
sectorFSize,
dst,
dstOffset);
final Future<?> future = group.submit(task);
combiner.add(future);
return future;
}
Promise<Void> combine(Promise<Void> aggregatePromise) {
combiner.finish(aggregatePromise);
return aggregatePromise;
}
}
static final class SectorDecodeTask implements Runnable {
final Mpq mpq;
final MpqFileHandle handle;
final int sector;
final int sectorOffset;
final int sectorCSize;
final int sectorFSize;
final ByteBuf dst;
final int dstOffset;
SectorDecodeTask(
MpqFileHandle handle,
int sector,
int sectorOffset,
int sectorCSize,
int sectorFSize,
ByteBuf dst,
int dstOffset
) {
this.handle = handle;
this.mpq = handle.mpq;
this.sector = sector;
this.sectorOffset = sectorOffset;
this.sectorCSize = sectorCSize;
this.sectorFSize = sectorFSize;
this.dst = dst;
this.dstOffset = dstOffset;
}
@Override
public void run() {
if (handle.decoded(sector, dst)) return;
if (DEBUG_MODE) {
log.tracef(
"Decoding sector %s[%d] +0x%08x %d bytes -> %d bytes",
handle, sector, sectorOffset, sectorCSize, sectorFSize);
}
final boolean requiresDecompression = sectorCSize < sectorFSize;
final int flags = handle.flags;
if ((flags & FLAG_ENCRYPTED) == 0 && !requiresDecompression) {
ArchiveReadTask.getBytes(handle, sectorOffset, sectorFSize, dst, dstOffset);
return;
}
final ByteBuf bufferSlice = dst.slice(dstOffset, sectorFSize).writerIndex(0);
final ByteBuf sectorSlice = handle.mpq.sectorBuffer(); // thread-safe
final ByteBuf scratch = handle.mpq.sectorBuffer(); // thread-safe
try {
ArchiveReadTask
.getBytes(handle, sectorOffset, sectorCSize, sectorSlice, 0)
.writerIndex(sectorCSize);
if ((flags & FLAG_ENCRYPTED) == FLAG_ENCRYPTED) {
if (DEBUG_MODE) log.trace("Decrypting sector...");
Decrypter.decrypt(handle.encryptionKey() + sector, sectorSlice);
if (DEBUG_MODE) log.trace("Decrypted {} bytes", sectorFSize);
}
final Decoder decoder = decoders.get();
if ((flags & FLAG_COMPRESSED) == FLAG_COMPRESSED && requiresDecompression) {
if (DEBUG_MODE) log.trace("Decompressing sector...");
decoder.decode(sectorSlice, bufferSlice, scratch, sectorCSize, sectorFSize);
if (DEBUG_MODE) log.trace("Decompressed {} bytes", sectorFSize);
}
if ((flags & FLAG_IMPLODE) == FLAG_IMPLODE && requiresDecompression) {
if (DEBUG_MODE) log.trace("Exploding sector...");
decoder.decode(Decoder.FLAG_IMPLODE, sectorSlice, bufferSlice, scratch, sectorCSize, sectorFSize);
if (DEBUG_MODE) log.trace("Exploded {} bytes", sectorFSize);
}
} finally {
scratch.release();
sectorSlice.release();
}
}
/** Decodes the specified sector in the caller's thread. */
static ByteBuf decodeSync(
MpqFileHandle handle,
int sector,
int sectorOffset,
int sectorCSize,
int sectorFSize,
ByteBuf dst,
int dstOffset
) {
new SectorDecodeTask(
handle,
sector,
sectorOffset,
sectorCSize,
sectorFSize,
dst,
dstOffset).run();
return dst.setIndex(dstOffset, dstOffset + sectorFSize);
}
}
public ArchiveReadTask newArchiveReadTask(
EventExecutor executor,
MpqFileHandle handle,
int offset,
int length,
ByteBuf dst,
int dstIndex
) {
return new ArchiveReadTask(this, executor, handle, offset, length, dst, dstIndex);
}
static final class ArchiveReadTask implements Callable<ByteBuf> {
static ByteBuf getBytes(
MpqFileHandle handle,
int offset,
int length,
ByteBuf dst,
int dstIndex
) {
synchronized (handle.mpq.lock()) {
return dst.setBytes(dstIndex, handle.archive, offset, length);
}
}
final EventExecutorGroup group;
final EventExecutor executor;
final MpqFileHandle handle;
final int offset;
final int length;
final ByteBuf dst;
final int dstIndex;
ArchiveReadTask(
EventExecutorGroup group,
EventExecutor executor,
MpqFileHandle handle,
int offset,
int length,
ByteBuf dst,
int dstIndex
) {
this.group = group;
this.executor = executor;
this.handle = handle;
this.offset = offset;
this.length = length;
this.dst = dst;
this.dstIndex = dstIndex;
}
/**
* Only to be used for reading complete uncompressed files
*/
@Override
public ByteBuf call() {
return getBytes(handle, offset, length, dst, dstIndex);
}
Future<ByteBuf> submit() {
return group.submit(this);
}
}
}
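The class above is driven through newDecodingTask: one DecodingTask per read, one submit per sector, and an aggregate promise obtained from combine. A minimal sketch of that flow, assuming it lives in the com.riiablo.mpq_bytebuf package (submit and combine are package-private), that the handle, offsets, and sizes are already resolved, and that ImmediateEventExecutor is acceptable for notifications:

package com.riiablo.mpq_bytebuf;

import io.netty.buffer.ByteBuf;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.concurrent.Promise;

class DecodingTaskSketch {
  // Hypothetical helper: decode a single sector into dst and block until done.
  static void decodeOneSector(DecoderExecutorGroup decoder, MpqFileHandle handle,
                              int sectorOffset, int sectorCSize, int sectorFSize,
                              ByteBuf dst) throws InterruptedException {
    // ImmediateEventExecutor keeps PromiseCombiner.add/finish on the calling thread
    EventExecutor executor = ImmediateEventExecutor.INSTANCE;
    DecoderExecutorGroup.DecodingTask task =
        decoder.newDecodingTask(executor, handle, 0, sectorFSize);
    task.submit(0, sectorOffset, sectorCSize, sectorFSize, dst, 0); // sector 0 -> dst[0..sectorFSize)
    Promise<Void> done = task.combine(executor.<Void>newPromise());
    done.await(); // completes once every submitted sector decode has finished
  }
}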

View File

@@ -1,358 +0,0 @@
package com.riiablo.mpq_bytebuf;
import io.netty.buffer.ByteBuf;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.exception.ExceptionUtils;
import com.artemis.utils.BitVector;
import com.riiablo.logger.LogManager;
import com.riiablo.logger.Logger;
import static com.riiablo.mpq_bytebuf.Mpq.Block.FLAG_COMPRESSED;
import static com.riiablo.mpq_bytebuf.Mpq.Block.FLAG_ENCRYPTED;
import static com.riiablo.mpq_bytebuf.Mpq.Block.FLAG_IMPLODE;
public final class DecodingService extends ForkJoinPool {
private static final Logger log = LogManager.getLogger(DecodingService.class);
private static final boolean DEBUG_MODE = !true;
public static final Callback IGNORE = new Callback() {
@Override
public void onDecoded(MpqFileHandle handle, int offset, int length, ByteBuf buffer) {
}
@Override
public void onError(MpqFileHandle handle, Throwable throwable) {
}
};
static final ThreadLocal<Decoder> decoders = new ThreadLocal<Decoder>() {
@Override
protected Decoder initialValue() {
return new Decoder();
}
};
final Thread shutdownHook = new Thread(new Runnable() {
@Override
public void run() {
gracefulShutdown();
}
});
DecodingService(int nThreads) {
super(nThreads);
try {
Runtime.getRuntime().addShutdownHook(shutdownHook);
} catch (Throwable t) {
log.warn("Problem occurred while trying to add runtime shutdown hook.", t);
}
}
public Future<ByteBuf> submit(DecodingTask task) {
return super.submit(task);
}
public boolean gracefulShutdown() {
try {
Runtime.getRuntime().removeShutdownHook(shutdownHook);
} catch (IllegalStateException ignored) {
// called during runtime shutdown -- who cares
} catch (Throwable t) {
log.warn("Problem occurred while trying to remove runtime shutdown hook.", t);
}
shutdown();
boolean shutdown;
try {
shutdown = awaitTermination(Long.MAX_VALUE, TimeUnit.MILLISECONDS);
if (!shutdown) log.error("Executor did not terminate in the specified time.");
} catch (InterruptedException t) {
log.error(ExceptionUtils.getRootCauseMessage(t), t);
shutdown = false;
}
if (!shutdown) {
List<Runnable> droppedTasks = shutdownNow();
log.error("Executor was abruptly shut down. {} tasks will not be executed.", droppedTasks.size());
}
return shutdown;
}
public interface Callback {
void onDecoded(MpqFileHandle handle, int offset, int length, ByteBuf buffer);
void onError(MpqFileHandle handle, Throwable throwable);
}
static final class ArchiveRead implements Callable<ByteBuf> {
static ByteBuf getBytes(
MpqFileHandle handle,
int offset,
int length,
ByteBuf dst,
int dstIndex
) {
synchronized (handle.mpq.lock()) {
return dst.setBytes(dstIndex, handle.archive, offset, length);
}
}
final MpqFileHandle handle;
final int offset;
final int length;
final ByteBuf dst;
final int dstIndex;
final Callback callback;
ArchiveRead(
MpqFileHandle handle,
int offset,
int length,
ByteBuf dst,
int dstIndex,
Callback callback
) {
this.handle = handle;
this.offset = offset;
this.length = length;
this.dst = dst;
this.dstIndex = dstIndex;
this.callback = callback;
}
/**
* Only to be used for reading complete uncompressed files
*/
@Override
public ByteBuf call() {
getBytes(handle, offset, length, dst, dstIndex);
synchronized (handle.decoded) { handle.decoded.unsafeSet(0); }
callback.onDecoded(handle, offset, length, dst);
return dst;
}
}
static final class DecodingTask extends RecursiveTask<ByteBuf> {
final MpqFileHandle handle;
final int offset;
final int length;
final Collection<SectorDecodeTask> sectors;
final BitVector decoding;
final Callback callback;
DecodingTask(
MpqFileHandle handle,
int offset,
int length,
Collection<SectorDecodeTask> sectors,
BitVector decoding,
Callback callback
) {
this.handle = handle;
this.offset = offset;
this.length = length;
this.sectors = sectors;
this.decoding = decoding;
this.callback = callback;
}
@Override
protected ByteBuf compute() {
try {
if (DEBUG_MODE) log.trace("Decoding {} sectors...", sectors.size());
invokeAll(sectors);
if (DEBUG_MODE) log.trace("Decoded {} sectors", sectors.size());
synchronized (handle.decoded) { handle.decoded.or(decoding); }
ByteBuf buffer = handle.buffer.slice(offset, length);
callback.onDecoded(handle, offset, length, buffer);
return buffer;
} catch (Throwable t) {
callback.onError(handle, t);
throw t;
}
}
static Builder builder(
MpqFileHandle handle,
int offset,
int length,
int numSectors,
Callback callback
) {
return new Builder(handle, offset, length, numSectors, callback);
}
static ByteBuf decodeSync(
MpqFileHandle handle,
int sector,
int sectorOffset,
int sectorCSize,
int sectorFSize,
ByteBuf buffer,
int bufferOffset
) {
new DecodingService.SectorDecodeTask(
handle,
sector,
sectorOffset,
sectorCSize,
sectorFSize,
buffer,
bufferOffset).compute();
return buffer.setIndex(bufferOffset, bufferOffset + sectorFSize);
}
static final class Builder {
final MpqFileHandle handle;
final int offset;
final int length;
final Collection<SectorDecodeTask> sectors;
final BitVector decoding;
final Callback callback;
Builder(
MpqFileHandle handle,
int offset,
int length,
int numSectors,
Callback callback
) {
this.handle = handle;
this.offset = offset;
this.length = length;
this.sectors = new ArrayList<>(numSectors);
this.decoding = new BitVector(handle.numSectors); // bits must match handle#decoded
this.callback = callback;
}
int size() {
return sectors.size();
}
Builder add(
int sector,
int sectorOffset,
int sectorCSize,
int sectorFSize,
ByteBuf buffer,
int bufferOffset
) {
decoding.unsafeSet(sector);
sectors.add(new SectorDecodeTask(
handle,
sector,
sectorOffset,
sectorCSize,
sectorFSize,
buffer,
bufferOffset));
return this;
}
DecodingTask build() {
return new DecodingTask(handle, offset, length, sectors, decoding, callback);
}
}
}
static final class SectorDecodeTask extends RecursiveAction {
final Mpq mpq;
final MpqFileHandle handle;
final int sector;
final int sectorOffset;
final int sectorCSize;
final int sectorFSize;
final ByteBuf buffer;
final int bufferOffset;
SectorDecodeTask(
MpqFileHandle handle,
int sector,
int sectorOffset,
int sectorCSize,
int sectorFSize,
ByteBuf buffer,
int bufferOffset
) {
this.mpq = handle.mpq;
this.handle = handle;
this.sector = sector;
this.sectorOffset = sectorOffset;
this.sectorCSize = sectorCSize;
this.sectorFSize = sectorFSize;
this.buffer = buffer;
this.bufferOffset = bufferOffset;
}
@Override
protected void compute() {
// try {
decode();
// } catch (Throwable t) {
// log.errorf(
// "Error decoding sector %s[%d] +0x%08x %d bytes -> %d bytes into %d",
// handle, sector, sectorOffset, sectorCSize, sectorFSize, bufferOffset);
// log.error("... cont arch " + handle.archive);
// log.error("... cont buff " + buffer);
// throw t;
// }
}
void decode() {
synchronized (handle.decoded) { if (handle.decoded.unsafeGet(sector)) return; }
if (DEBUG_MODE) {
log.tracef(
"Decoding sector %s[%d] +0x%08x %d bytes -> %d bytes",
handle, sector, sectorOffset, sectorCSize, sectorFSize);
}
final boolean requiresDecompression = sectorCSize < sectorFSize;
final int flags = handle.flags;
if ((flags & FLAG_ENCRYPTED) == 0 && !requiresDecompression) {
ArchiveRead.getBytes(handle, sectorOffset, sectorFSize, buffer, bufferOffset);
return;
}
final ByteBuf bufferSlice = buffer.slice(bufferOffset, sectorFSize).writerIndex(0);
final ByteBuf sectorSlice = handle.mpq.sectorBuffer(); // thread-safe
final ByteBuf scratch = handle.mpq.sectorBuffer(); // thread-safe
try {
ArchiveRead
.getBytes(handle, sectorOffset, sectorCSize, sectorSlice, 0)
.writerIndex(sectorCSize);
if ((flags & FLAG_ENCRYPTED) == FLAG_ENCRYPTED) {
if (DEBUG_MODE) log.trace("Decrypting sector...");
Decrypter.decrypt(handle.encryptionKey() + sector, sectorSlice);
if (DEBUG_MODE) log.trace("Decrypted {} bytes", sectorFSize);
}
final Decoder decoder = decoders.get();
if ((flags & FLAG_COMPRESSED) == FLAG_COMPRESSED && requiresDecompression) {
if (DEBUG_MODE) log.trace("Decompressing sector...");
decoder.decode(sectorSlice, bufferSlice, scratch, sectorCSize, sectorFSize);
if (DEBUG_MODE) log.trace("Decompressed {} bytes", sectorFSize);
}
if ((flags & FLAG_IMPLODE) == FLAG_IMPLODE && requiresDecompression) {
if (DEBUG_MODE) log.trace("Exploding sector...");
decoder.decode(FLAG_IMPLODE, sectorSlice, bufferSlice, scratch, sectorCSize, sectorFSize);
if (DEBUG_MODE) log.trace("Exploded {} bytes", sectorFSize);
}
} finally {
scratch.release();
sectorSlice.release();
}
}
}
}
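The file removed above exposed the callback-based API (Callback#onDecoded / Callback#onError); with this commit those call sites move to netty futures. A hedged before/after sketch of that migration, assuming an already-opened MpqFileHandle and hypothetical consume/handleError helpers:

import io.netty.buffer.ByteBuf;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor;

import com.riiablo.mpq_bytebuf.MpqFileHandle;

class MigrationSketch {
  // Before (removed): handle.bufferAsync(new DecodingService.Callback() {
  //   public void onDecoded(MpqFileHandle h, int offset, int length, ByteBuf buffer) { consume(buffer); }
  //   public void onError(MpqFileHandle h, Throwable t) { handleError(t); }
  // });

  // After: success and failure both arrive through the same netty Future.
  static void bufferAsync(MpqFileHandle handle) {
    Future<ByteBuf> future = handle.bufferAsync(ImmediateEventExecutor.INSTANCE);
    future.addListener((FutureListener<ByteBuf>) f -> {
      if (f.isSuccess()) consume(f.getNow()); // decoded contents
      else handleError(f.cause());            // decoding failed
    });
  }

  static void consume(ByteBuf buffer) { /* hypothetical consumer */ }
  static void handleError(Throwable t) { /* hypothetical error handler */ }
}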

View File

@@ -276,6 +276,10 @@ public enum Decrypter {
(byte) 0xf8, (byte) 0xf9, (byte) 0xfa, (byte) 0xfb, (byte) 0xfc, (byte) 0xfd, (byte) 0xfe, (byte) 0xff,
};
public static byte[] charMap() {
return charMap;
}
public static String fix(String str) {
final byte[] charMap = Decrypter.charMap;
final byte[] bytes = getBytes(str);

View File

@@ -370,12 +370,12 @@ public final class Mpq implements Disposable {
/** @deprecated for use in tests, use MpqFileHandleResolver instead */
@Deprecated
MpqFileHandle open(final DecodingService decoder, final CharSequence filename, final short locale) {
MpqFileHandle open(final DecoderExecutorGroup decoder, final CharSequence filename, final short locale) {
final int index = get(filename, locale);
return open(decoder, index, filename);
}
MpqFileHandle open(final DecodingService decoder, final int index, final CharSequence filename) {
MpqFileHandle open(final DecoderExecutorGroup decoder, final int index, final CharSequence filename) {
return hashTable.open(decoder, this, index, filename);
}
@@ -502,7 +502,7 @@ public final class Mpq implements Disposable {
return bestId;
}
MpqFileHandle open(final DecodingService decoder, final Mpq mpq, final int index, final CharSequence filename) {
MpqFileHandle open(final DecoderExecutorGroup decoder, final Mpq mpq, final int index, final CharSequence filename) {
MpqFileHandle handle = this.handle[index];
if (handle == null) {
final int blockId = this.blockId[index];

View File

@@ -4,6 +4,11 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.ReferenceCountUpdater;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
@@ -16,11 +21,10 @@ import java.io.Reader;
import java.io.Writer;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.lang3.concurrent.ConcurrentUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import com.artemis.utils.BitVector;
@@ -30,11 +34,8 @@ import com.badlogic.gdx.files.FileHandle;
import com.riiablo.logger.LogManager;
import com.riiablo.logger.Logger;
import com.riiablo.mpq_bytebuf.DecodingService.ArchiveRead;
import com.riiablo.mpq_bytebuf.DecodingService.Callback;
import com.riiablo.mpq_bytebuf.DecodingService.DecodingTask;
import com.riiablo.mpq_bytebuf.DecoderExecutorGroup.DecodingTask;
import static com.riiablo.mpq_bytebuf.DecodingService.IGNORE;
import static com.riiablo.mpq_bytebuf.Mpq.Block.FLAG_ENCRYPTED;
import static com.riiablo.mpq_bytebuf.Mpq.Block.FLAG_EXISTS;
import static com.riiablo.util.ImplUtils.unsupported;
@@ -85,7 +86,7 @@ public final class MpqFileHandle extends FileHandle implements ReferenceCounted
@SuppressWarnings("unused")
private volatile int refCnt = updater.initialValue();
final DecodingService decoder;
final DecoderExecutorGroup decoder;
public final Mpq mpq;
final int index;
public final String filename;
@@ -106,7 +107,7 @@ public final class MpqFileHandle extends FileHandle implements ReferenceCounted
int encryptionKey;
MpqFileHandle(
DecodingService decoder,
DecoderExecutorGroup decoder,
Mpq mpq,
int index,
String filename,
@@ -178,8 +179,8 @@ public final class MpqFileHandle extends FileHandle implements ReferenceCounted
* released, instead the handle must be released when it is no longer needed.
*
* @see #buffer(int, int)
* @see #bufferAsync(Callback)
* @see #bufferAsync(int, int, Callback)
* @see #bufferAsync(EventExecutor)
* @see #bufferAsync(EventExecutor, int, int)
* @see #release()
*/
public ByteBuf buffer() {
@@ -195,30 +196,18 @@ public final class MpqFileHandle extends FileHandle implements ReferenceCounted
* @param length length of decompressed contents after offset
*
* @see #buffer()
* @see #bufferAsync(Callback)
* @see #bufferAsync(int, int, Callback)
* @see #bufferAsync(EventExecutor)
* @see #bufferAsync(EventExecutor, int, int)
* @see #release()
*/
public ByteBuf buffer(int offset, int length) {
if (offset + length > FSize) {
throw new IndexOutOfBoundsException(
String.format(
"offset(+0x%x) + length(0x%x) exceeds declared FSize(0x%x)",
offset, length, FSize));
}
Future<ByteBuf> future = ensureReadable(offset, length, IGNORE);
try {
return future.get();
} catch (InterruptedException | ExecutionException t) {
return bufferAsync(ImmediateEventExecutor.INSTANCE, offset, length).get();
} catch (InterruptedException | CancellationException | ExecutionException t) {
return ExceptionUtils.rethrow(t);
}
}
public Future<ByteBuf> bufferAsync() {
return bufferAsync(IGNORE);
}
/**
* Schedules the contents of this mpq file for decoding and returns a future
* used to track the progress. Decoding the buffer may take a lot of time
@@ -230,15 +219,15 @@ public final class MpqFileHandle extends FileHandle implements ReferenceCounted
*
* @see #buffer()
* @see #buffer(int, int)
* @see #bufferAsync()
* @see #bufferAsync(int, int, Callback)
* @see #bufferAsync(EventExecutor)
* @see #bufferAsync(EventExecutor, int, int)
* @see #release()
*/
public Future<ByteBuf> bufferAsync(Callback callback) {
return bufferAsync(0, FSize, callback);
public Future<ByteBuf> bufferAsync(EventExecutor executor) {
return bufferAsync(executor, 0, FSize);
}
public Future<ByteBuf> bufferAsync(int offset, int length, Callback callback) {
public Future<ByteBuf> bufferAsync(EventExecutor executor, int offset, int length) {
if (offset + length > FSize) {
throw new IndexOutOfBoundsException(
String.format(
@@ -246,7 +235,7 @@ public final class MpqFileHandle extends FileHandle implements ReferenceCounted
offset, length, FSize));
}
return ensureReadable(offset, length, callback);
return ensureReadable(executor, offset, length);
}
int encryptionKey() {
@@ -255,15 +244,15 @@ public final class MpqFileHandle extends FileHandle implements ReferenceCounted
: encryptionKey;
}
Future<ByteBuf> ensureReadable(final int offset, final int length, final Callback callback) {
Future<ByteBuf> ensureReadable(EventExecutor executor, int offset, int length) {
if (numSectors < 0) {
readSectorOffsets();
allocateBuffer();
}
return numSectors == 0
? readRawArchive(callback)
: decodeSectors(offset, length, callback);
? readRawArchive(executor)
: decodeSectors(executor, offset, length);
}
ByteBuf readSectorOffsets() {
@@ -297,53 +286,97 @@ public final class MpqFileHandle extends FileHandle implements ReferenceCounted
return buffer;
}
Future<ByteBuf> readRawArchive(Callback callback) {
Future<ByteBuf> readRawArchive(EventExecutor executor) {
assert numSectors == 0 : "copyBuffer requires numSectors=" + numSectors;
final boolean decoded;
synchronized (this.decoded) {
decoded = this.decoded.unsafeGet(0); // using bit 0 as decoded tag for buffer
}
if (decoded) {
if (decoded(0)) { // using bit 0 as decoded tag for buffer
final ByteBuf buffer = this.buffer;
callback.onDecoded(this, 0, FSize, buffer);
return ConcurrentUtils.constantFuture(buffer);
return executor.newSucceededFuture(buffer);
}
return decoder.submit(new ArchiveRead(this, 0, FSize, buffer, 0, callback));
return decoder
.newArchiveReadTask(executor, this, 0, FSize, buffer, 0)
.submit()
.addListener(new FutureListener<ByteBuf>() {
@Override
public void operationComplete(Future<ByteBuf> future) {
setDecoded(0, buffer); // using bit 0 as decoded tag for buffer
}
});
}
Future<ByteBuf> decodeSectors(int offset, int length, Callback callback) {
Future<ByteBuf> decodeSectors(EventExecutor executor, final int offset, final int length) {
final int sectorSize = mpq.sectorSize;
int startSector = offset / sectorSize;
int endSector = (offset + length + sectorSize - 1) / sectorSize;
DecodingTask.Builder task = null;
DecodingTask task = null;
for (int i = startSector; i < endSector; i++) {
synchronized (decoded) { if (decoded.unsafeGet(i)) continue; }
final int bufferOffset = i * sectorSize;
final int sectorOffset = sectorOffsets.getIntLE(i << 2);
final int nextSectorOffset = sectorOffsets.getIntLE((i + 1) << 2);
final int sector = i;
if (decoded(sector)) continue;
final int bufferOffset = sector * sectorSize;
final int sectorOffset = sectorOffsets.getIntLE(sector << 2);
final int nextSectorOffset = sectorOffsets.getIntLE((sector + 1) << 2);
final int sectorCSize = nextSectorOffset - sectorOffset;
final int sectorFSize = Math.min(FSize - bufferOffset, sectorSize);
if (task == null) {
task = DecodingTask.builder(
this,
offset,
length,
endSector - i,
callback);
}
task.add(i, sectorOffset, sectorCSize, sectorFSize, buffer, bufferOffset);
if (task == null) task = decoder.newDecodingTask(executor, this, offset, length);
// if (buffer == null) throw new AssertionError("buffer was null?");
final ByteBuf buffer = this.buffer;
task.submit(sector, sectorOffset, sectorCSize, sectorFSize, buffer, bufferOffset)
.addListener(new FutureListener<Object>() {
@Override
public void operationComplete(Future<Object> future) {
setDecoded(sector, buffer);
}
});
}
if (task != null) {
if (DEBUG_MODE) log.trace("Submitting {} sectors for decoding", task.size());
return decoder.submit(task.build());
if (DEBUG_MODE) log.trace("Submitting {} sectors for decoding", task.numTasks());
final Promise<ByteBuf> aggregatePromise = executor.newPromise();
task.combine(executor.<Void>newPromise())
.addListener(new FutureListener<Void>() {
@Override
public void operationComplete(Future<Void> future) {
aggregatePromise.setSuccess(buffer.slice(offset, length));
}
});
return aggregatePromise;
}
ByteBuf buffer = this.buffer.slice(offset, length);
callback.onDecoded(this, offset, length, buffer);
return ConcurrentUtils.constantFuture(buffer);
return executor.newSucceededFuture(buffer);
}
/**
* Queries whether the buffer backed by this file handle already contains the
* decoded contents of the specified sector.
*
* @see #decoded(int, ByteBuf)
*/
boolean decoded(final int sector) {
return decoded(sector, this.buffer);
}
/**
* Queries whether the buffer backed by this file handle already contains the
* decoded contents of the specified sector, and that the given buffer is the
* backing buffer (and thus contains those contents).
*
* @see #decoded(int)
*/
boolean decoded(final int sector, final ByteBuf buffer) {
assert buffer != null : "buffer cannot be null";
if (this.buffer != buffer) return false;
synchronized (decoded) { return decoded.get(sector); }
}
/**
* Marks the specified sector as decoded iff {@code buffer} is the backing
* buffer.
*/
void setDecoded(final int sector, final ByteBuf buffer) {
assert buffer != null : "buffer cannot be null";
if (this.buffer != buffer) return;
synchronized (decoded) { decoded.unsafeSet(sector); }
}
public InputStream stream() {
@ -630,4 +663,22 @@ public final class MpqFileHandle extends FileHandle implements ReferenceCounted
public long lastModified() {
return unsupported("not supported for mpq files");
}
public FutureListener<ByteBuf> releasingFuture() {
return new ReleasingFuture(this);
}
public static final class ReleasingFuture implements FutureListener<ByteBuf> {
final MpqFileHandle handle;
public ReleasingFuture(MpqFileHandle handle) {
this.handle = handle;
}
@Override
public void operationComplete(Future<ByteBuf> future) {
log.info("Releasing {}", handle);
handle.release();
}
}
}
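The releasingFuture() listener added at the end of this file pairs naturally with bufferAsync: consume the decoded buffer in one listener, then let a second listener release the handle. A short usage sketch under those assumptions (the executor choice is arbitrary; listeners are notified in the order they were added):

import io.netty.buffer.ByteBuf;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;

import com.riiablo.mpq_bytebuf.MpqFileHandle;

class ReleasingSketch {
  // Decodes the whole file asynchronously and releases the handle afterwards.
  static Future<ByteBuf> decodeAndRelease(MpqFileHandle handle, EventExecutor executor) {
    Future<ByteBuf> future = handle.bufferAsync(executor);
    future.addListener((FutureListener<ByteBuf>) f -> {
      if (f.isSuccess()) {
        // f.getNow() is the decoded buffer; copy what is needed before release
      }
    });
    future.addListener(handle.releasingFuture()); // logs and releases the handle
    return future;
  }
}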

View File

@@ -33,7 +33,7 @@ public class MpqFileResolver implements FileHandleResolver, Disposable {
public final Mpq d2video;
final Array<Mpq> mpqs = Array.of(true, 16, Mpq.class);
final DecodingService decoder = new DecodingService(2);
final DecoderExecutorGroup decoder = new DecoderExecutorGroup(2);
final Array<MountPoint> mounts = Array.of(true, 4, MountPoint.class);
public MpqFileResolver() {
@@ -64,7 +64,7 @@ public class MpqFileResolver implements FileHandleResolver, Disposable {
@Override
public void dispose() {
decoder.gracefulShutdown();
decoder.shutdownGracefully();
for (Mpq mpq : mpqs) mpq.dispose();
mpqs.clear();
}
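With the resolver now owning a DecoderExecutorGroup, the lifecycle is create, resolve, dispose. A minimal sketch, assuming a libGDX context is initialized and using an illustrative MPQ-internal path (resolve is inherited from FileHandleResolver):

import com.badlogic.gdx.files.FileHandle;
import com.riiablo.mpq_bytebuf.MpqFileResolver;

class ResolverLifecycleSketch {
  static void run() {
    MpqFileResolver resolver = new MpqFileResolver(); // mounts the mpqs and creates the decoder group
    try {
      FileHandle handle = resolver.resolve("data\\global\\excel\\armor.txt"); // illustrative path
      // ... read from handle ...
    } finally {
      resolver.dispose(); // shuts the DecoderExecutorGroup down gracefully, then disposes the mpqs
    }
  }
}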

View File

@@ -8,7 +8,6 @@ import java.io.InputStream;
import com.riiablo.logger.LogManager;
import com.riiablo.logger.Logger;
import com.riiablo.mpq_bytebuf.DecodingService.DecodingTask;
import static com.riiablo.mpq_bytebuf.Decrypter.ENCRYPTION;
import static com.riiablo.mpq_bytebuf.Decrypter.SEED2;
@@ -19,7 +18,6 @@ public final class MpqStream extends InputStream {
private static final boolean DEBUG_MODE = !true;
final MpqFileHandle handle;
final DecodingService decoder;
final boolean releaseOnClose;
final boolean encrypted;
ByteBuf buffer;
@@ -65,7 +63,6 @@ public final class MpqStream extends InputStream {
}
this.handle = handle;
this.decoder = handle.decoder;
this.encrypted = handle.encrypted();
this.releaseOnClose = releaseOnClose;
this.bytesRead = offset;
@@ -188,7 +185,7 @@ public final class MpqStream extends InputStream {
final int bufferOffset = currentSector * sectorSize;
final int sectorCSize = nextSectorOffset - sectorOffset;
final int sectorFSize = Math.min(handle.FSize - bufferOffset, sectorSize);
return DecodingTask.decodeSync(
return DecoderExecutorGroup.SectorDecodeTask.decodeSync(
handle,
currentSector++,
sectorOffset,

View File

@@ -0,0 +1,166 @@
package com.riiablo.mpq_bytebuf;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.EventExecutorGroup;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.ObjectUtil;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Copy of {@link io.netty.util.concurrent.PromiseCombiner} with the
* modification that {@code doneCount} is an {@link AtomicInteger} in order to
* allow for atomic increments and checks. Maybe I was doing something wrong,
* but when submitting tasks to an {@link EventExecutorGroup}, aggregating the
results, and blocking using the standard netty impl, I was getting concurrency
errors that caused incorrect increments of {@code doneCount} when
{@link PromiseCombiner#executor} was {@link io.netty.util.concurrent.ImmediateEventExecutor#INSTANCE}.
*
* @see <a href="https://github.com/netty/netty/blob/d50cfc69e03a35e984d12d381de7a89ac8f0b2d7/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java">Netty Github Repository d50cfc6</a>
*/
final class PromiseCombiner {
private int expectedCount;
private final AtomicInteger doneCount = new AtomicInteger();
private Promise<Void> aggregatePromise;
private Throwable cause;
private final GenericFutureListener<Future<?>> listener = new GenericFutureListener<Future<?>>() {
@Override
public void operationComplete(final Future<?> future) {
if (executor.inEventLoop()) {
operationComplete0(future);
} else {
executor.execute(new Runnable() {
@Override
public void run() {
operationComplete0(future);
}
});
}
}
private void operationComplete0(Future<?> future) {
assert executor.inEventLoop();
final int doneCount = PromiseCombiner.this.doneCount.incrementAndGet();
if (!future.isSuccess() && cause == null) {
cause = future.cause();
}
if (doneCount == expectedCount && aggregatePromise != null) {
tryPromise();
}
}
};
private final EventExecutor executor;
/**
* The {@link EventExecutor} to use for notifications. You must call {@link #add(Future)}, {@link
* #addAll(Future[])} and {@link #finish(Promise)} from within the {@link EventExecutor} thread.
*
* @param executor
* the {@link EventExecutor} to use for notifications.
*/
public PromiseCombiner(EventExecutor executor) {
this.executor = ObjectUtil.checkNotNull(executor, "executor");
}
/**
* Adds a new promise to be combined. New promises may be added until an aggregate promise is
* added via the {@link io.netty.util.concurrent.PromiseCombiner#finish(Promise)} method.
*
* @param promise
* the promise to add to this promise combiner
*
* @deprecated Replaced by {@link io.netty.util.concurrent.PromiseCombiner#add(Future)}.
*/
@Deprecated
public void add(Promise promise) {
add((Future) promise);
}
/**
* Adds a new future to be combined. New futures may be added until an aggregate promise is added
* via the {@link io.netty.util.concurrent.PromiseCombiner#finish(Promise)} method.
*
* @param future
* the future to add to this promise combiner
*/
@SuppressWarnings({"unchecked", "rawtypes"})
public void add(Future future) {
checkAddAllowed();
checkInEventLoop();
++expectedCount;
future.addListener(listener);
}
/**
* Adds new promises to be combined. New promises may be added until an aggregate promise is added
* via the {@link io.netty.util.concurrent.PromiseCombiner#finish(Promise)} method.
*
* @param promises
* the promises to add to this promise combiner
*
* @deprecated Replaced by {@link io.netty.util.concurrent.PromiseCombiner#addAll(Future[])}
*/
@Deprecated
public void addAll(Promise... promises) {
addAll((Future[]) promises);
}
/**
* Adds new futures to be combined. New futures may be added until an aggregate promise is added
* via the {@link io.netty.util.concurrent.PromiseCombiner#finish(Promise)} method.
*
* @param futures
* the futures to add to this promise combiner
*/
@SuppressWarnings({"unchecked", "rawtypes"})
public void addAll(Future... futures) {
for (Future future : futures) {
this.add(future);
}
}
/**
* <p>Sets the promise to be notified when all combined futures have finished. If all combined
* futures succeed,
* then the aggregate promise will succeed. If one or more combined futures fails, then the
* aggregate promise will fail with the cause of one of the failed futures. If more than one
* combined future fails, then exactly which failure will be assigned to the aggregate promise is
* undefined.</p>
*
* <p>After this method is called, no more futures may be added via the {@link
* io.netty.util.concurrent.PromiseCombiner#add(Future)} or
* {@link io.netty.util.concurrent.PromiseCombiner#addAll(Future[])} methods.</p>
*
* @param aggregatePromise
* the promise to notify when all combined futures have finished
*/
public void finish(Promise<Void> aggregatePromise) {
ObjectUtil.checkNotNull(aggregatePromise, "aggregatePromise");
checkInEventLoop();
if (this.aggregatePromise != null) {
throw new IllegalStateException("Already finished");
}
this.aggregatePromise = aggregatePromise;
if (doneCount.get() == expectedCount) {
tryPromise();
}
}
private void checkInEventLoop() {
if (!executor.inEventLoop()) {
throw new IllegalStateException("Must be called from EventExecutor thread");
}
}
private boolean tryPromise() {
return (cause == null) ? aggregatePromise.trySuccess(null) : aggregatePromise.tryFailure(cause);
}
private void checkAddAllowed() {
if (aggregatePromise != null) {
throw new IllegalStateException("Adding promises is not allowed after finished adding");
}
}
}
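A minimal aggregation sketch for this modified combiner, assuming it is called from the com.riiablo.mpq_bytebuf package (the class is package-private) and from the executor's thread, which is always the case for ImmediateEventExecutor:

package com.riiablo.mpq_bytebuf;

import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;

class PromiseCombinerSketch {
  // Combines a batch of futures into one promise that succeeds when all of
  // them succeed, or fails with one of the causes if any of them fails.
  static Promise<Void> combineAll(EventExecutor executor, Future<?>... futures) {
    PromiseCombiner combiner = new PromiseCombiner(executor);
    for (Future<?> f : futures) combiner.add(f);
    Promise<Void> aggregate = executor.newPromise();
    combiner.finish(aggregate); // no futures may be added after this point
    return aggregate;
  }
}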

View File

@@ -9,25 +9,27 @@ import static org.junit.jupiter.api.TestInstance.Lifecycle.*;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.concurrent.Promise;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicReference;
import com.badlogic.gdx.files.FileHandle;
import com.riiablo.RiiabloTest;
import com.riiablo.logger.Level;
import com.riiablo.logger.LogManager;
import com.riiablo.mpq_bytebuf.DecodingService.Callback;
import static com.riiablo.mpq_bytebuf.DecodingService.IGNORE;
import static com.riiablo.mpq_bytebuf.Mpq.DEFAULT_LOCALE;
class DecodingTest extends RiiabloTest {
@BeforeAll
public static void before() {
LogManager.setLevel("com.riiablo.mpq_bytebuf.MpqFileHandle", Level.TRACE);
LogManager.setLevel("com.riiablo.mpq_bytebuf.DecodingService", Level.TRACE);
LogManager.setLevel("com.riiablo.mpq_bytebuf.DecoderExecutorGroup", Level.TRACE);
}
@Nested
@@ -42,8 +44,8 @@ class DecodingTest extends RiiabloTest {
"data\\global\\CHARS\\BA\\LG\\BALGLITTNHTH.DCC",
"DATA\\GLOBAL\\CHARS\\PA\\LA\\PALALITTN1HS.DCC",
})
void decode(String in) {
DecodingService decoder = new DecodingService(4);
void decode_await(String in) {
DecoderExecutorGroup decoder = new DecoderExecutorGroup(4);
try {
MpqFileHandle handle = mpq.open(decoder, in, DEFAULT_LOCALE);
try {
@@ -56,7 +58,7 @@ class DecodingTest extends RiiabloTest {
handle.release();
}
} finally {
decoder.gracefulShutdown();
decoder.shutdownGracefully();
}
}
@@ -65,68 +67,32 @@ class DecodingTest extends RiiabloTest {
"data\\global\\CHARS\\BA\\LG\\BALGLITTNHTH.DCC",
"DATA\\GLOBAL\\CHARS\\PA\\LA\\PALALITTN1HS.DCC",
})
void decode_async(String in) {
DecodingService decoder = new DecodingService(4);
try {
MpqFileHandle handle = mpq.open(decoder, in, DEFAULT_LOCALE);
try {
final ByteBuf actual;
Future<ByteBuf> future = handle.bufferAsync(IGNORE);
try {
actual = future.get();
} catch (InterruptedException | ExecutionException t) {
fail(t);
return;
}
FileHandle handle_out = testAsset(in);
ByteBuf expected = Unpooled.wrappedBuffer(handle_out.readBytes());
assertTrue(ByteBufUtil.equals(expected, actual));
} finally {
handle.release();
}
} finally {
decoder.gracefulShutdown();
}
}
@ParameterizedTest
@ValueSource(strings = {
"data\\global\\CHARS\\BA\\LG\\BALGLITTNHTH.DCC",
"DATA\\GLOBAL\\CHARS\\PA\\LA\\PALALITTN1HS.DCC",
})
void decode_callback(String in) {
DecodingService decoder = new DecodingService(4);
void decode_future(String in) {
DecoderExecutorGroup decoder = new DecoderExecutorGroup(4);
try {
MpqFileHandle handle = mpq.open(decoder, in, DEFAULT_LOCALE);
try {
final AtomicReference<ByteBuf> actual = new AtomicReference<>();
Future<ByteBuf> future = handle.bufferAsync(new Callback() {
@Override
public void onDecoded(MpqFileHandle handle, int offset, int length, ByteBuf buffer) {
actual.set(buffer);
}
@Override
public void onError(MpqFileHandle handle, Throwable throwable) {
}
});
final EventExecutor executor = ImmediateEventExecutor.INSTANCE;
Promise<ByteBuf> actual = executor.newPromise();
Future<ByteBuf> future = handle.bufferAsync(executor);
future.addListener((FutureListener<ByteBuf>) f -> actual.setSuccess(f.getNow()));
try {
ByteBuf futureResult = future.get();
assertSame(actual.get(), futureResult);
} catch (InterruptedException | ExecutionException t) {
actual.await();
assertSame(actual.getNow(), futureResult);
} catch (InterruptedException | CancellationException | ExecutionException t) {
fail(t);
return;
}
FileHandle handle_out = testAsset(in);
ByteBuf expected = Unpooled.wrappedBuffer(handle_out.readBytes());
assertTrue(ByteBufUtil.equals(expected, actual.get()));
assertTrue(ByteBufUtil.equals(expected, actual.getNow()));
} finally {
handle.release();
}
} finally {
decoder.gracefulShutdown();
decoder.shutdownGracefully();
}
}
}

View File

@@ -26,7 +26,7 @@ class MpqStreamTest extends RiiabloTest {
public static void before() {
LogManager.setLevel("com.riiablo.mpq_bytebuf", Level.WARN);
LogManager.setLevel("com.riiablo.mpq_bytebuf.MpqStream", Level.TRACE);
LogManager.setLevel("com.riiablo.mpq_bytebuf.DecodingService", Level.WARN);
LogManager.setLevel("com.riiablo.mpq_bytebuf.DecoderExecutorGroup", Level.WARN);
}
@Test
@@ -45,7 +45,7 @@ class MpqStreamTest extends RiiabloTest {
}
void read(String in) throws IOException {
DecodingService decoder = new DecodingService(4);
DecoderExecutorGroup decoder = new DecoderExecutorGroup(4);
try {
MpqFileHandle handle = mpq.open(decoder, in, DEFAULT_LOCALE);
InputStream stream = MpqStream.open(handle, 0, handle.FSize, true);
@@ -61,7 +61,7 @@ class MpqStreamTest extends RiiabloTest {
ByteBuf expected = Unpooled.wrappedBuffer(handle_out.readBytes());
assertTrue(ByteBufUtil.equals(expected, actual));
} finally {
decoder.gracefulShutdown();
decoder.shutdownGracefully();
}
}
}