Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified source_.jar
Binary file not shown.
24 changes: 24 additions & 0 deletions src/main/cod/demo/src/main/test/unsafe/UnsafePointerBasics.cod
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
unit test.unsafe (main: this)

x: int = 42

unsafe LowLevel {
ptr: *u8
buffer: u8[1024]

method() :: int {
buffer[0] = 42
ptr = &buffer[0]
value := *ptr
ptr = ptr + 16
~> (value)
}
}

share UnsafePointerMain {
share main() {
low: LowLevel = safe(LowLevel())
out(low.method())
out(x)
}
}
22 changes: 22 additions & 0 deletions src/main/cod/demo/src/main/test/unsafe/UnsafeSafeCommit.cod
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
unit test.unsafe (main: this)

share unsafe UnsafeBox {
value: i8 = 0

share unsafe this(v: int) {
value = v
}

share unsafe unbox(v: int) :: i8 {
wrapped: i8 = v
~> (wrapped)
}
}

share UnsafeMain {
share main() {
box: UnsafeBox = safe(UnsafeBox(260))
n: int = safe(box.unbox(260))
out(n)
}
}
3 changes: 2 additions & 1 deletion src/main/java/cod/ast/node/Method.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ public class Method extends Base {
public List<Stmt> body = new ArrayList<Stmt>();
public boolean isBuiltin = false;
public boolean isPolicyMethod = false;
public boolean isUnsafe = false;

@Override
public final <T> T accept(VisitorImpl<T> visitor) {
return visitor.visit(this);
}

}
}
3 changes: 2 additions & 1 deletion src/main/java/cod/ast/node/Type.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public class Type extends Base {
public List<Stmt> statements = new ArrayList<Stmt>();
public List<Constructor> constructors = new ArrayList<Constructor>();
public List<String> implementedPolicies = new ArrayList<String>();
public boolean isUnsafe = false;

// Make Token fields transient
public transient Token extendToken;
Expand All @@ -32,4 +33,4 @@ public class Type extends Base {
public final <T> T accept(VisitorImpl<T> visitor) {
return visitor.visit(this);
}
}
}
31 changes: 31 additions & 0 deletions src/main/java/cod/interpreter/Interpreter.java
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,8 @@ public Object evalMethod(Method node, ObjectInstance obj, Map<String, Object> lo
ctx.currentClass = associatedClass;
}
}
boolean unsafeContext = node.isUnsafe || (ctx.currentClass != null && ctx.currentClass.isUnsafe);
ctx.setUnsafeExecutionContext(unsafeContext);

visitor.pushContext(ctx);
Object result = null;
Expand Down Expand Up @@ -986,6 +988,14 @@ public Object evalMethodCall(
}
throw new ProgramError("Method not found: " + call.name);
}

ExecutionContext callerCtx = ExecutionContext.getCurrentContext();
if (method.isUnsafe && !isUnsafeExecutionContext(callerCtx) && !ExecutionContext.isUnsafeCommitAllowed()) {
throw new ProgramError(
"Unsafe method '" + method.methodName + "' cannot be called in a safe context. Use safe("
+ call.name
+ "(...)).");
}

boolean hasSingleSlot = method.returnSlots != null && method.returnSlots.size() == 1;
if (call.slotNames.isEmpty() && hasSingleSlot && !call.isSingleSlotCall) {
Expand Down Expand Up @@ -1075,6 +1085,8 @@ public Object evalMethodCall(
}
}

argValue = typeSystem.normalizeForDeclaredType(paramType, argValue);

if (paramType.contains("|")) {
String activeType = typeSystem.getConcreteType(typeSystem.unwrap(argValue));
argValue = new TypeHandler.Value(argValue, activeType, paramType);
Expand Down Expand Up @@ -1120,6 +1132,8 @@ public Object evalMethodCall(
ctx.currentClass = classType;
}
}
boolean unsafeContext = method.isUnsafe || (ctx.currentClass != null && ctx.currentClass.isUnsafe);
ctx.setUnsafeExecutionContext(unsafeContext);

visitor.pushContext(ctx);
boolean calledMethodHasSlots = method.returnSlots != null && !method.returnSlots.isEmpty();
Expand Down Expand Up @@ -1225,6 +1239,23 @@ private Type findTypeByName(String className) {
return null;
}
}

private boolean isUnsafeExecutionContext(ExecutionContext ctx) {
if (ctx == null) return false;
if (ctx.currentClass != null && ctx.currentClass.isUnsafe) {
return true;
}
if (ctx.currentMethodName == null || ctx.currentMethodName.isEmpty()) {
return false;
}
Type searchType = ctx.currentClass;
if (searchType == null && ctx.objectInstance != null) {
searchType = ctx.objectInstance.type;
}
if (searchType == null) return false;
Method currentMethod = constructorResolver.findMethodInHierarchy(searchType, ctx.currentMethodName, ctx);
return currentMethod != null && currentMethod.isUnsafe;
}

public void clearAllCaches() {
importResolver.clearCache();
Expand Down
116 changes: 116 additions & 0 deletions src/main/java/cod/interpreter/InterpreterVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,17 @@ public Object visit(ConstructorCall node) {
}

try {
ExecutionContext ctx = getCurrentContext();
Type targetType = interpreter.getImportResolver().findType(node.className);
if (targetType != null
&& targetType.isUnsafe
&& !isUnsafeExecutionContext(ctx)
&& !ExecutionContext.isUnsafeCommitAllowed()) {
throw new ProgramError(
"Unsafe class '" + targetType.name + "' cannot be constructed in a safe context. Use safe("
+ targetType.name
+ "(...)).");
}
return interpreter.getConstructorResolver().resolveAndCreate(node, getCurrentContext());
} catch (ProgramError e) {
throw e;
Expand Down Expand Up @@ -1439,6 +1450,100 @@ private List<Object> evaluateMethodCallArguments(MethodCall methodCall) {
return evaluatedArgs;
}

private boolean isUnsafeExecutionContext(ExecutionContext ctx) {
if (ctx == null) return false;
if (ctx.currentClass != null && ctx.currentClass.isUnsafe) {
return true;
}
Method currentMethod = resolveCurrentContextMethod(ctx);
return currentMethod != null && currentMethod.isUnsafe;
}

private Method resolveCurrentContextMethod(ExecutionContext ctx) {
if (ctx == null || ctx.currentMethodName == null || ctx.currentMethodName.isEmpty()) {
return null;
}
Type searchType = ctx.currentClass;
if (searchType == null && ctx.objectInstance != null) {
searchType = ctx.objectInstance.type;
}
if (searchType == null) {
return null;
}
return interpreter.getConstructorResolver().findMethodInHierarchy(searchType, ctx.currentMethodName, ctx);
}

private Method resolveMethodForCall(MethodCall node, ExecutionContext ctx) {
Method method = null;
String callName = node.name;
String callQualifiedName = node.qualifiedName;

if (ctx.currentClass != null) {
method = interpreter.getConstructorResolver().findMethodInHierarchy(ctx.currentClass, callName, ctx);
}

if (method == null && ctx.objectInstance != null && ctx.objectInstance.type != null) {
method = interpreter.getConstructorResolver().findMethodInHierarchy(ctx.objectInstance.type, callName, ctx);
}

if (method == null) {
String qName = callQualifiedName;
if (qName != null && qName.contains(".")) {
String[] parts = qName.split("\\.");
if (parts.length == 2) {
String receiver = parts[0];
String methodName = parts[1];
if (ctx.locals().containsKey(receiver)) {
Object receiverObj = ctx.locals().get(receiver);
if (receiverObj instanceof ObjectInstance) {
ObjectInstance objInst = (ObjectInstance) receiverObj;
if (objInst.type != null) {
qName = objInst.type.name + "." + methodName;
}
}
}
}
}
if (qName == null) qName = callName;
method = interpreter.getImportResolver().findMethod(qName);
}

return method;
}

private Object executeSafeCommit(MethodCall node, ExecutionContext ctx) {
if (isUnsafeExecutionContext(ctx)) {
throw new ProgramError(
"safe() is not allowed inside unsafe classes or methods; these contexts already have permission to execute unsafe code");
}
if (node.arguments == null || node.arguments.size() != 1) {
throw new ProgramError("safe() expects exactly one argument");
}

Expr argument = node.arguments.get(0);
boolean unsafeTarget = false;

if (argument instanceof MethodCall) {
Method targetMethod = resolveMethodForCall((MethodCall) argument, ctx);
unsafeTarget = targetMethod != null && targetMethod.isUnsafe;
} else if (argument instanceof ConstructorCall) {
Type targetType = interpreter.getImportResolver().findType(((ConstructorCall) argument).className);
unsafeTarget = targetType != null && targetType.isUnsafe;
}

if (!unsafeTarget) {
throw new ProgramError(
"safe() requires an unsafe method call or unsafe class constructor as its argument, but the provided expression is not marked unsafe");
}

ExecutionContext.enterUnsafeCommitAllowance();
try {
return dispatch(argument);
} finally {
ExecutionContext.exitUnsafeCommitAllowance();
}
}

@Override
public Object visit(Identifier node) {
if (node == null) {
Expand Down Expand Up @@ -1827,6 +1932,10 @@ public Object visit(MethodCall node) {
}
}
}

if ("safe".equals(callName) && (callQualifiedName == null || "safe".equals(callQualifiedName))) {
return executeSafeCommit(node, ctx);
}

// Evaluate all arguments first
List<Object> evaluatedArgs = evaluateMethodCallArguments(node);
Expand Down Expand Up @@ -1885,6 +1994,13 @@ public Object visit(MethodCall node) {
throw new ProgramError("Method not found: " + callName);
}

if (method.isUnsafe && !isUnsafeExecutionContext(ctx) && !ExecutionContext.isUnsafeCommitAllowed()) {
throw new ProgramError(
"Unsafe method '" + method.methodName + "' cannot be called in a safe context. Use safe("
+ callName
+ "(...)).");
}

// Check if this is a single-slot call
boolean hasSingleSlot = method.returnSlots != null && method.returnSlots.size() == 1;
if (node.slotNames.isEmpty() && hasSingleSlot) {
Expand Down
28 changes: 28 additions & 0 deletions src/main/java/cod/interpreter/context/ExecutionContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,17 @@ public class ExecutionContext {

// ========== THREAD LOCAL CONTEXT ==========
private static final ThreadLocal<ExecutionContext> currentContext = new ThreadLocal<ExecutionContext>();
private static final ThreadLocal<Integer> unsafeCommitDepth = new ThreadLocal<Integer>() {
@Override
protected Integer initialValue() {
return 0;
}
};

// ========== OPTIMIZED LOOP CONTEXT ==========
private boolean inOptimizedLoop = false;
private List<Object> pendingOutputs = new ArrayList<Object>();
private boolean unsafeExecutionContext = false;

// ========== TYPE HANDLER ==========
private final TypeHandler typeHandler;
Expand Down Expand Up @@ -73,6 +80,19 @@ public Map<String, Object> getLocalsMap() {
public static void clearCurrentContext() {
currentContext.remove();
}

public static void enterUnsafeCommitAllowance() {
unsafeCommitDepth.set(unsafeCommitDepth.get() + 1);
}

public static void exitUnsafeCommitAllowance() {
int current = unsafeCommitDepth.get();
unsafeCommitDepth.set(current > 0 ? current - 1 : 0);
}

public static boolean isUnsafeCommitAllowed() {
return unsafeCommitDepth.get() > 0;
}

/**
* Mark that we're entering an optimized loop
Expand Down Expand Up @@ -114,6 +134,14 @@ public List<Object> flushPendingOutputs() {
pendingOutputs.clear();
return outputs;
}

public boolean isUnsafeExecutionContext() {
return unsafeExecutionContext;
}

public void setUnsafeExecutionContext(boolean unsafeExecutionContext) {
this.unsafeExecutionContext = unsafeExecutionContext;
}

public ExecutionContext(ObjectInstance obj, Map<String, Object> locals,
Map<String, Object> slotValues, Map<String, String> slotTypes,
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/cod/interpreter/handler/AssignmentHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ private Object assignToSlot(String slotTarget, Object value, ExecutionContext ct
if (ctx.hasSlot(slotTarget)) { // O(1)
String declaredType = ctx.getSlotType(slotTarget); // O(1)
validateAssignmentType(declaredType, value, slotTarget);
value = typeSystem.normalizeForDeclaredType(declaredType, value);

value = typeSystem.wrapUnionType(value, declaredType);

Expand Down Expand Up @@ -411,6 +412,7 @@ private Object updateVariableInScope(String varName, Object newValue,

if (declaredType != null) {
validateAssignmentType(declaredType, newValue, varName);
newValue = typeSystem.normalizeForDeclaredType(declaredType, newValue);
newValue = typeSystem.wrapUnionType(newValue, declaredType);
}

Expand Down
Loading