From 3c2ba5101adc0d1b55076134de7a11462b7b9b29 Mon Sep 17 00:00:00 2001 From: Charles Oliver Nutter Date: Wed, 19 Feb 2025 18:07:13 -0600 Subject: [PATCH] Guard access to ptr behind frozen check --- ext/java/org/jruby/ext/stringio/StringIO.java | 133 +++++++++--------- 1 file changed, 69 insertions(+), 64 deletions(-) diff --git a/ext/java/org/jruby/ext/stringio/StringIO.java b/ext/java/org/jruby/ext/stringio/StringIO.java index ae8ff55..2bd9ab6 100644 --- a/ext/java/org/jruby/ext/stringio/StringIO.java +++ b/ext/java/org/jruby/ext/stringio/StringIO.java @@ -97,7 +97,14 @@ static class StringIOData { int flags; volatile Object owner; } - StringIOData ptr; + private StringIOData ptr; + + // MRI: get_strio, StringIO macro + private StringIOData getPtr() { + // equivalent to rb_io_taint_check without tainting + checkFrozen(); + return ptr; + } private static final String STRINGIO_VERSION = "3.1.4"; @@ -141,7 +148,7 @@ public static RubyClass createStringIOClass(final Ruby runtime) { // mri: get_enc public Encoding getEncoding() { - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); Encoding enc = ptr.enc; if (enc != null) { return enc; @@ -156,7 +163,7 @@ public Encoding getEncoding() { } public void setEncoding(Encoding enc) { - ptr.enc = enc; + getPtr().enc = enc; } @JRubyMethod(name = "new", rest = true, meta = true) @@ -256,7 +263,7 @@ private static IRubyObject yieldOrReturn(ThreadContext context, Block block, Str try { val = block.yield(context, strio); } finally { - strio.ptr.string = null; + strio.getPtr().string = null; strio.flags &= ~STRIO_READWRITE; } } @@ -270,7 +277,7 @@ protected StringIO(Ruby runtime, RubyClass klass) { @JRubyMethod(visibility = PRIVATE, keywords = true) public IRubyObject initialize(ThreadContext context) { - if (ptr == null) { + if (getPtr() == null) { ptr = new StringIOData(); } @@ -282,7 +289,7 @@ public IRubyObject initialize(ThreadContext context) { @JRubyMethod(visibility = PRIVATE, keywords = true) public IRubyObject initialize(ThreadContext context, IRubyObject arg0) { - if (ptr == null) { + if (getPtr() == null) { ptr = new StringIOData(); } @@ -294,7 +301,7 @@ public IRubyObject initialize(ThreadContext context, IRubyObject arg0) { @JRubyMethod(visibility = PRIVATE, keywords = true) public IRubyObject initialize(ThreadContext context, IRubyObject arg0, IRubyObject arg1) { - if (ptr == null) { + if (getPtr() == null) { ptr = new StringIOData(); } @@ -306,7 +313,7 @@ public IRubyObject initialize(ThreadContext context, IRubyObject arg0, IRubyObje @JRubyMethod(visibility = PRIVATE, keywords = true) public IRubyObject initialize(ThreadContext context, IRubyObject arg0, IRubyObject arg1, IRubyObject arg2) { - if (ptr == null) { + if (getPtr() == null) { ptr = new StringIOData(); } @@ -322,7 +329,7 @@ private void strioInit(ThreadContext context, int argc, IRubyObject arg0, IRubyO IRubyObject string = context.nil; IRubyObject vmode = context.nil; - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); boolean locked = lock(context, ptr); try { @@ -410,7 +417,7 @@ public IRubyObject initialize_copy(ThreadContext context, IRubyObject other) { if (this == otherIO) return this; - ptr = otherIO.ptr; + ptr = otherIO.getPtr(); flags = flags & ~STRIO_READWRITE | otherIO.flags & STRIO_READWRITE; return this; @@ -418,7 +425,7 @@ public IRubyObject initialize_copy(ThreadContext context, IRubyObject other) { @JRubyMethod public IRubyObject binmode(ThreadContext context) { - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); ptr.enc = EncodingUtils.ascii8bitEncoding(context.runtime); if (writable()) ptr.string.setEncoding(ptr.enc); @@ -477,7 +484,7 @@ public IRubyObject closed_p() { public IRubyObject close_read(ThreadContext context) { // ~ checkReadable() : checkInitialized(); - if ( (ptr.flags & OpenFile.READABLE) == 0 ) { + if ( (getPtr().flags & OpenFile.READABLE) == 0 ) { throw context.runtime.newIOError("not opened for reading"); } int flags = this.flags; @@ -497,7 +504,7 @@ public IRubyObject closed_read_p() { public IRubyObject close_write(ThreadContext context) { // ~ checkWritable() : checkInitialized(); - if ( (ptr.flags & OpenFile.WRITABLE) == 0 ) { + if ( (getPtr().flags & OpenFile.WRITABLE) == 0 ) { throw context.runtime.newIOError("not opened for writing"); } int flags = this.flags; @@ -614,7 +621,7 @@ public IRubyObject each_byte(ThreadContext context, Block block) { if (!block.isGiven()) return enumeratorize(runtime, this, "each_byte"); checkReadable(); - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); boolean locked = lock(context, ptr); try { @@ -648,11 +655,13 @@ public IRubyObject each_char(final ThreadContext context, final Block block) { @JRubyMethod(name = {"eof", "eof?"}) public IRubyObject eof(ThreadContext context) { checkReadable(); + StringIOData ptr = getPtr(); if (ptr.pos < ptr.string.size()) return context.fals; return context.tru; } private boolean isEndOfString() { + StringIOData ptr = getPtr(); return ptr.string == null || ptr.pos >= ptr.string.size(); } @@ -662,7 +671,7 @@ public IRubyObject getc(ThreadContext context) { if (isEndOfString()) return context.nil; - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); boolean locked = lock(context, ptr); try { @@ -685,7 +694,7 @@ public IRubyObject getbyte(ThreadContext context) { if (isEndOfString()) return context.nil; int c; - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); boolean locked = lock(context, ptr); try { c = ptr.string.getByteList().get(ptr.pos++) & 0xFF; @@ -699,7 +708,7 @@ public IRubyObject getbyte(ThreadContext context) { // MRI: strio_substr // must be called under lock private RubyString strioSubstr(Ruby runtime, int pos, int len, Encoding enc) { - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); final RubyString string = ptr.string; int rlen = string.size() - pos; @@ -750,25 +759,25 @@ private static int bm_search(byte[] little, int lstart, int llen, byte[] big, in @JRubyMethod(name = "gets", writes = FrameField.LASTLINE) public IRubyObject gets(ThreadContext context) { - if (ptr.string == null) return context.nil; + if (getPtr().string == null) return context.nil; return Getline.getlineCall(context, GETLINE, this, getEncoding()); } @JRubyMethod(name = "gets", writes = FrameField.LASTLINE) public IRubyObject gets(ThreadContext context, IRubyObject arg0) { - if (ptr.string == null) return context.nil; + if (getPtr().string == null) return context.nil; return Getline.getlineCall(context, GETLINE, this, getEncoding(), arg0); } @JRubyMethod(name = "gets", writes = FrameField.LASTLINE) public IRubyObject gets(ThreadContext context, IRubyObject arg0, IRubyObject arg1) { - if (ptr.string == null) return context.nil; + if (getPtr().string == null) return context.nil; return Getline.getlineCall(context, GETLINE, this, getEncoding(), arg0, arg1); } @JRubyMethod(name = "gets", writes = FrameField.LASTLINE) public IRubyObject gets(ThreadContext context, IRubyObject arg0, IRubyObject arg1, IRubyObject arg2) { - if (ptr.string == null) return context.nil; + if (getPtr().string == null) return context.nil; return Getline.getlineCall(context, GETLINE, this, getEncoding(), arg0, arg1, arg2); } @@ -792,7 +801,7 @@ public IRubyObject gets(ThreadContext context, IRubyObject[] args) { self.checkReadable(); if (limit == 0) { - if (self.ptr.string == null) return context.nil; + if (self.getPtr().string == null) return context.nil; return RubyString.newEmptyString(context.runtime, self.getEncoding()); } @@ -808,7 +817,7 @@ public IRubyObject gets(ThreadContext context, IRubyObject[] args) { private static final Getline.Callback GETLINE_YIELD = (context, self, rs, limit, chomp, block) -> { IRubyObject line; - StringIOData ptr = self.ptr; + StringIOData ptr = self.getPtr(); if (ptr.string == null || ptr.pos > ptr.string.size()) { return self; } @@ -831,7 +840,7 @@ public IRubyObject gets(ThreadContext context, IRubyObject[] args) { RubyArray ary = (RubyArray) context.runtime.newArray(); IRubyObject line; - StringIOData ptr = self.ptr; + StringIOData ptr = self.getPtr(); if (ptr.string == null || ptr.pos > ptr.string.size()) { return null; } @@ -863,7 +872,7 @@ private IRubyObject getline(ThreadContext context, final IRubyObject rs, int lim return context.nil; } - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); Encoding enc = getEncoding(); boolean locked = lock(context, ptr); @@ -967,19 +976,19 @@ private static int chompNewlineWidth(byte[] bytes, int s, int e) { @JRubyMethod(name = {"length", "size"}) public IRubyObject length(ThreadContext context) { checkInitialized(); - RubyString myString = ptr.string; + RubyString myString = getPtr().string; if (myString == null) return RubyFixnum.zero(context.runtime); return getRuntime().newFixnum(myString.size()); } @JRubyMethod(name = "lineno") public IRubyObject lineno(ThreadContext context) { - return context.runtime.newFixnum(ptr.lineno); + return context.runtime.newFixnum(getPtr().lineno); } @JRubyMethod(name = "lineno=", required = 1) public IRubyObject set_lineno(ThreadContext context, IRubyObject arg) { - ptr.lineno = RubyNumeric.fix2int(arg); + getPtr().lineno = RubyNumeric.fix2int(arg); return context.nil; } @@ -988,7 +997,7 @@ public IRubyObject set_lineno(ThreadContext context, IRubyObject arg) { public IRubyObject pos(ThreadContext context) { checkInitialized(); - return context.runtime.newFixnum(ptr.pos); + return context.runtime.newFixnum(getPtr().pos); } @JRubyMethod(name = "pos=", required = 1) @@ -1001,13 +1010,13 @@ public IRubyObject set_pos(IRubyObject arg) { if (p > Integer.MAX_VALUE) throw getRuntime().newArgumentError("JRuby does not support StringIO larger than " + Integer.MAX_VALUE + " bytes"); - ptr.pos = (int)p; + getPtr().pos = (int)p; return arg; } private void strioExtend(ThreadContext context, int pos, int len) { - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); boolean locked = lock(context, ptr); try { @@ -1046,12 +1055,12 @@ public IRubyObject putc(ThreadContext context, IRubyObject ch) { checkModifiable(); if (ch instanceof RubyString) { - if (ptr.string == null) return context.nil; + if (getPtr().string == null) return context.nil; str = substrString((RubyString) ch, str, runtime); } else { byte c = RubyNumeric.num2chr(ch); - if (ptr.string == null) return context.nil; + if (getPtr().string == null) return context.nil; str = RubyString.newString(runtime, new byte[]{c}); } write(context, str); @@ -1083,7 +1092,7 @@ private IRubyObject readCommon(ThreadContext context, int argc, IRubyObject arg0 IRubyObject str = context.nil; boolean binary = false; - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); int pos = ptr.pos; boolean locked = lock(context, ptr); @@ -1181,7 +1190,7 @@ public IRubyObject pread(ThreadContext context, IRubyObject arg0, IRubyObject ar @SuppressWarnings("fallthrough") private RubyString preadCommon(ThreadContext context, int argc, IRubyObject arg0, IRubyObject arg1, IRubyObject arg2) { IRubyObject str = context.nil; - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); Ruby runtime = context.runtime; int offset; final RubyString string; @@ -1296,7 +1305,7 @@ public IRubyObject reopen(ThreadContext context) { // MRI: strio_reopen @JRubyMethod(name = "reopen", keywords = true) public IRubyObject reopen(ThreadContext context, IRubyObject arg0) { - checkFrozen(); + checkModifiable(); if (!(arg0 instanceof RubyString)) { return initialize_copy(context, arg0); @@ -1310,7 +1319,7 @@ public IRubyObject reopen(ThreadContext context, IRubyObject arg0) { // MRI: strio_reopen @JRubyMethod(name = "reopen", keywords = true) public IRubyObject reopen(ThreadContext context, IRubyObject arg0, IRubyObject arg1) { - checkFrozen(); + checkModifiable(); // reset the state strioInit(context, 2, arg0, arg1, null); @@ -1320,7 +1329,7 @@ public IRubyObject reopen(ThreadContext context, IRubyObject arg0, IRubyObject a // MRI: strio_reopen @JRubyMethod(name = "reopen", keywords = true) public IRubyObject reopen(ThreadContext context, IRubyObject arg0, IRubyObject arg1, IRubyObject arg2) { - checkFrozen(); + checkModifiable(); // reset the state strioInit(context, 3, arg0, arg1, arg2); @@ -1331,7 +1340,7 @@ public IRubyObject reopen(ThreadContext context, IRubyObject arg0, IRubyObject a public IRubyObject rewind(ThreadContext context) { checkInitialized(); - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); boolean locked = lock(context, ptr); try { @@ -1355,7 +1364,7 @@ public IRubyObject seek(ThreadContext context, IRubyObject arg0, IRubyObject arg } private RubyFixnum seekCommon(ThreadContext context, int argc, IRubyObject arg0, IRubyObject arg1) { - checkFrozen(); + checkModifiable(); Ruby runtime = context.runtime; @@ -1368,7 +1377,7 @@ private RubyFixnum seekCommon(ThreadContext context, int argc, IRubyObject arg0, checkOpen(); - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); boolean locked = lock(context, ptr); try { @@ -1398,7 +1407,7 @@ private RubyFixnum seekCommon(ThreadContext context, int argc, IRubyObject arg0, @JRubyMethod(name = "string=", required = 1) public IRubyObject set_string(ThreadContext context, IRubyObject arg) { checkFrozen(); - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); boolean locked = lock(context, ptr); try { @@ -1415,7 +1424,7 @@ public IRubyObject set_string(ThreadContext context, IRubyObject arg) { @JRubyMethod(name = "string") public IRubyObject string(ThreadContext context) { - RubyString string = ptr.string; + RubyString string = getPtr().string; if (string == null) return context.nil; return string; @@ -1432,7 +1441,7 @@ public IRubyObject truncate(ThreadContext context, IRubyObject len) { checkWritable(); int l = RubyFixnum.fix2int(len); - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); RubyString string = ptr.string; boolean locked = lock(context, ptr); @@ -1464,7 +1473,7 @@ public IRubyObject ungetc(ThreadContext context, IRubyObject arg) { checkModifiable(); checkReadable(); - if (ptr.string == null) return context.nil; + if (getPtr().string == null) return context.nil; if (arg.isNil()) return arg; if (arg instanceof RubyInteger) { @@ -1491,7 +1500,7 @@ public IRubyObject ungetc(ThreadContext context, IRubyObject arg) { } private void ungetbyteCommon(ThreadContext context, int c) { - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); boolean locked = lock(context, ptr); try { @@ -1522,7 +1531,7 @@ private void ungetbyteCommon(ThreadContext context, RubyString ungetBytes) { private void ungetbyteCommon(ThreadContext context, byte[] ungetBytes, int cp, int cl) { if (cl == 0) return; - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); boolean locked = lock(context, ptr); try { @@ -1576,7 +1585,7 @@ public IRubyObject ungetbyte(ThreadContext context, IRubyObject arg) { if (arg.isNil()) return arg; checkModifiable(); - if (ptr.string == null) return context.nil; + if (getPtr().string == null) return context.nil; if (arg instanceof RubyInteger) { ungetbyteCommon(context, ((RubyInteger) ((RubyInteger) arg).op_mod(context, 256)).getIntValue()); @@ -1686,7 +1695,7 @@ private long stringIOWrite(ThreadContext context, Ruby runtime, IRubyObject arg) RubyString str = arg.asString(); int len, olen; - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); boolean locked = lock(context, ptr); try { @@ -1791,7 +1800,7 @@ public IRubyObject set_encoding(ThreadContext context, IRubyObject ext_enc) { } } - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); boolean locked = lock(context, ptr); try { @@ -1827,12 +1836,14 @@ public IRubyObject set_encoding(ThreadContext context, IRubyObject enc, IRubyObj @JRubyMethod public IRubyObject set_encoding_by_bom(ThreadContext context) { - if (setEncodingByBOM(context) == null) return context.nil; + StringIOData ptr = getPtr(); + + if (setEncodingByBOM(context, ptr) == null) return context.nil; return context.runtime.getEncodingService().convertEncodingToRubyEncoding(ptr.enc); } - private Encoding setEncodingByBOM(ThreadContext context) { + private Encoding setEncodingByBOM(ThreadContext context, StringIOData ptr) { Encoding enc = detectBOM(context, ptr.string, (ctx, enc2, bomlen) -> { ptr.pos = bomlen; if (writable()) { @@ -1908,7 +1919,7 @@ public IRubyObject each_codepoint(ThreadContext context, Block block) { checkReadable(); - StringIOData ptr = this.ptr; + StringIOData ptr = this.getPtr(); boolean locked = lock(context, ptr); try { @@ -2156,25 +2167,19 @@ public IRubyObject puts(ThreadContext context, IRubyObject[] args) { return GenericWritable.puts(context, this, args); } - /* rb: check_modifiable */ - public void checkFrozen() { - super.checkFrozen(); - checkInitialized(); - } - private boolean readable() { return (flags & STRIO_READABLE) != 0 - && (ptr.flags & OpenFile.READABLE) != 0; + && (getPtr().flags & OpenFile.READABLE) != 0; } private boolean writable() { return (flags & STRIO_WRITABLE) != 0 - && (ptr.flags & OpenFile.WRITABLE) != 0; + && (getPtr().flags & OpenFile.WRITABLE) != 0; } private boolean closed() { return !((flags & STRIO_READWRITE) != 0 - && (ptr.flags & OpenFile.READWRITE) != 0); + && (getPtr().flags & OpenFile.READWRITE) != 0); } /* rb: readable */ @@ -2197,11 +2202,11 @@ private void checkWritable() { private void checkModifiable() { checkFrozen(); - if (ptr.string.isFrozen()) throw getRuntime().newIOError("not modifiable string"); + if (getPtr().string.isFrozen()) throw getRuntime().newIOError("not modifiable string"); } private void checkInitialized() { - if (ptr == null) { + if (getPtr() == null) { throw getRuntime().newIOError("uninitialized stream"); } }