Skip to content

Commit

Permalink
Expand WeakMap to allow floats and fixnums
Browse files Browse the repository at this point in the history
They will use value comparisons rather than identity comparisons
to mimic CRuby's behavior for immediate values.

Fixes jruby#7862
  • Loading branch information
headius committed Nov 10, 2023
1 parent 353808f commit 0a79d3c
Showing 1 changed file with 36 additions and 18 deletions.
54 changes: 36 additions & 18 deletions core/src/main/java/org/jruby/RubyObjectSpace.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

import java.util.Iterator;
import java.util.Map;
import java.util.stream.Stream;

import org.jruby.anno.JRubyMethod;
import org.jruby.anno.JRubyModule;
Expand All @@ -51,6 +52,7 @@
import org.jruby.util.Inspector;
import org.jruby.util.Numeric;
import org.jruby.util.collections.WeakValuedIdentityMap;
import org.jruby.util.collections.WeakValuedMap;

@JRubyModule(name="ObjectSpace")
public class RubyObjectSpace {
Expand Down Expand Up @@ -231,59 +233,74 @@ public WeakMap(Ruby runtime, RubyClass cls) {

@JRubyMethod(name = "[]")
public IRubyObject op_aref(ThreadContext context, IRubyObject key) {
IRubyObject value = map.get(key);
Map<IRubyObject, IRubyObject> weakMap = getWeakMapFor(key);
IRubyObject value = weakMap.get(key);
if (value != null) return value;
return context.nil;
}

private Map<IRubyObject, IRubyObject> getWeakMapFor(IRubyObject key) {
if (key instanceof RubyFixnum || key instanceof RubyFloat) {
return valueMap;
}

return identityMap;
}

@JRubyMethod(name = "[]=")
public IRubyObject op_aref(ThreadContext context, IRubyObject key, IRubyObject value) {
Ruby runtime = context.runtime;

map.put(key, value);
Map<IRubyObject, IRubyObject> weakMap = getWeakMapFor(key);
weakMap.put(key, value);

return runtime.newFixnum(System.identityHashCode(value));
}

@JRubyMethod(name = "key?")
public IRubyObject key_p(ThreadContext context, IRubyObject key) {
return RubyBoolean.newBoolean(context, map.get(key) != null);
Map<IRubyObject, IRubyObject> weakMap = getWeakMapFor(key);
return RubyBoolean.newBoolean(context, weakMap.get(key) != null);
}

@JRubyMethod(name = "keys")
public IRubyObject keys(ThreadContext context) {
return context.runtime.newArrayNoCopy(
map.entrySet()
.stream()
getEntryStream()
.filter(entry -> entry.getValue() != null)
.map(entry -> entry.getKey())
.map(Map.Entry::getKey)
.toArray(IRubyObject[]::new));
}

private Stream<Map.Entry<IRubyObject, IRubyObject>> getEntryStream() {
return Stream.concat(identityMap.entrySet().stream(), valueMap.entrySet().stream());
}

@JRubyMethod(name = "values")
public IRubyObject values(ThreadContext context) {
return context.runtime.newArrayNoCopy(
map.values()
.stream()
getEntryStream()
.map(Map.Entry::getValue)
.filter(ref -> ref != null)
.toArray(IRubyObject[]::new));
}

@JRubyMethod(name = {"length", "size"})
public IRubyObject size(ThreadContext context) {
return context.runtime.newFixnum(map.size());
return context.runtime.newFixnum(identityMap.size() + valueMap.size());
}

@JRubyMethod(name = {"include?", "member?"})
public IRubyObject member_p(ThreadContext context, IRubyObject key) {
return RubyBoolean.newBoolean(context, map.containsKey(key));
return RubyBoolean.newBoolean(context, getWeakMapFor(key).containsKey(key));
}

@JRubyMethod(name = {"each", "each_pair"})
public IRubyObject each(ThreadContext context, Block block) {
map.forEach((key, value) -> {
getEntryStream().forEach((entry) -> {
IRubyObject value = entry.getValue();
if (value != null) {
block.yieldSpecific(context, key, value);
block.yieldSpecific(context, entry.getKey(), value);
}
});

Expand All @@ -292,23 +309,23 @@ public IRubyObject each(ThreadContext context, Block block) {

@JRubyMethod(name = "each_key")
public IRubyObject each_key(ThreadContext context, Block block) {
for (Map.Entry<IRubyObject, IRubyObject> entry : map.entrySet()) {
getEntryStream().forEach((entry) -> {
if (entry.getValue() != null) {
block.yieldSpecific(context, entry.getKey());
}
}
});

return this;
}

@JRubyMethod(name = "each_value")
public IRubyObject each_value(ThreadContext context, Block block) {
for (Map.Entry<IRubyObject, IRubyObject> entry : map.entrySet()) {
getEntryStream().forEach((entry) -> {
IRubyObject value = entry.getValue();
if (value != null) {
block.yieldSpecific(context, value);
}
}
});

return this;
}
Expand All @@ -320,7 +337,7 @@ public IRubyObject inspect(ThreadContext context) {
RubyString part = inspectPrefix(runtime.getCurrentContext(), metaClass.getRealClass(), inspectHashCode());
int base = part.length();

map.entrySet().forEach(entry -> {
getEntryStream().forEach(entry -> {
if (entry.getValue() != null) {
if (part.length() == base) {
part.cat(Inspector.COLON_SPACE);
Expand All @@ -339,6 +356,7 @@ public IRubyObject inspect(ThreadContext context) {
return part;
}

private final WeakValuedIdentityMap<IRubyObject, IRubyObject> map = new WeakValuedIdentityMap<IRubyObject, IRubyObject>();
private final WeakValuedIdentityMap<IRubyObject, IRubyObject> identityMap = new WeakValuedIdentityMap<>();
private final WeakValuedMap<IRubyObject, IRubyObject> valueMap = new WeakValuedMap<>();
}
}

0 comments on commit 0a79d3c

Please sign in to comment.