diff --git a/source_.jar b/source_.jar index d733f6ea..37595bb5 100644 Binary files a/source_.jar and b/source_.jar differ diff --git a/src/main/cod/demo/src/main/test/unsafe/UnsafePointerBasics.cod b/src/main/cod/demo/src/main/test/unsafe/UnsafePointerBasics.cod new file mode 100644 index 00000000..62d5558f --- /dev/null +++ b/src/main/cod/demo/src/main/test/unsafe/UnsafePointerBasics.cod @@ -0,0 +1,24 @@ +unit test.unsafe (main: this) + +x: int = 42 + +unsafe LowLevel { + ptr: *u8 + buffer: u8[1024] + + method() :: int { + buffer[0] = 42 + ptr = &buffer[0] + value := *ptr + ptr = ptr + 16 + ~> (value) + } +} + +share UnsafePointerMain { + share main() { + low: LowLevel = safe(LowLevel()) + out(low.method()) + out(x) + } +} diff --git a/src/main/cod/demo/src/main/test/unsafe/UnsafeSafeCommit.cod b/src/main/cod/demo/src/main/test/unsafe/UnsafeSafeCommit.cod new file mode 100644 index 00000000..1d1062f9 --- /dev/null +++ b/src/main/cod/demo/src/main/test/unsafe/UnsafeSafeCommit.cod @@ -0,0 +1,22 @@ +unit test.unsafe (main: this) + +share unsafe UnsafeBox { + value: i8 = 0 + + share unsafe this(v: int) { + value = v + } + + share unsafe unbox(v: int) :: i8 { + wrapped: i8 = v + ~> (wrapped) + } +} + +share UnsafeMain { + share main() { + box: UnsafeBox = safe(UnsafeBox(260)) + n: int = safe(box.unbox(260)) + out(n) + } +} diff --git a/src/main/java/cod/ast/node/Method.java b/src/main/java/cod/ast/node/Method.java index 4db48c9c..842b5dcb 100644 --- a/src/main/java/cod/ast/node/Method.java +++ b/src/main/java/cod/ast/node/Method.java @@ -15,10 +15,11 @@ public class Method extends Base { public List body = new ArrayList(); public boolean isBuiltin = false; public boolean isPolicyMethod = false; + public boolean isUnsafe = false; @Override public final T accept(VisitorImpl visitor) { return visitor.visit(this); } -} \ No newline at end of file +} diff --git a/src/main/java/cod/ast/node/Type.java b/src/main/java/cod/ast/node/Type.java index 5ac02c16..9d3557c3 100644 --- a/src/main/java/cod/ast/node/Type.java +++ b/src/main/java/cod/ast/node/Type.java @@ -19,6 +19,7 @@ public class Type extends Base { public List statements = new ArrayList(); public List constructors = new ArrayList(); public List implementedPolicies = new ArrayList(); + public boolean isUnsafe = false; // Make Token fields transient public transient Token extendToken; @@ -32,4 +33,4 @@ public class Type extends Base { public final T accept(VisitorImpl visitor) { return visitor.visit(this); } -} \ No newline at end of file +} diff --git a/src/main/java/cod/interpreter/Interpreter.java b/src/main/java/cod/interpreter/Interpreter.java index 6e413cbd..7f31faf1 100644 --- a/src/main/java/cod/interpreter/Interpreter.java +++ b/src/main/java/cod/interpreter/Interpreter.java @@ -923,6 +923,8 @@ public Object evalMethod(Method node, ObjectInstance obj, Map lo ctx.currentClass = associatedClass; } } + boolean unsafeContext = node.isUnsafe || (ctx.currentClass != null && ctx.currentClass.isUnsafe); + ctx.setUnsafeExecutionContext(unsafeContext); visitor.pushContext(ctx); Object result = null; @@ -986,6 +988,14 @@ public Object evalMethodCall( } throw new ProgramError("Method not found: " + call.name); } + + ExecutionContext callerCtx = ExecutionContext.getCurrentContext(); + if (method.isUnsafe && !isUnsafeExecutionContext(callerCtx) && !ExecutionContext.isUnsafeCommitAllowed()) { + throw new ProgramError( + "Unsafe method '" + method.methodName + "' cannot be called in a safe context. Use safe(" + + call.name + + "(...))."); + } boolean hasSingleSlot = method.returnSlots != null && method.returnSlots.size() == 1; if (call.slotNames.isEmpty() && hasSingleSlot && !call.isSingleSlotCall) { @@ -1075,6 +1085,8 @@ public Object evalMethodCall( } } + argValue = typeSystem.normalizeForDeclaredType(paramType, argValue); + if (paramType.contains("|")) { String activeType = typeSystem.getConcreteType(typeSystem.unwrap(argValue)); argValue = new TypeHandler.Value(argValue, activeType, paramType); @@ -1120,6 +1132,8 @@ public Object evalMethodCall( ctx.currentClass = classType; } } + boolean unsafeContext = method.isUnsafe || (ctx.currentClass != null && ctx.currentClass.isUnsafe); + ctx.setUnsafeExecutionContext(unsafeContext); visitor.pushContext(ctx); boolean calledMethodHasSlots = method.returnSlots != null && !method.returnSlots.isEmpty(); @@ -1225,6 +1239,23 @@ private Type findTypeByName(String className) { return null; } } + + private boolean isUnsafeExecutionContext(ExecutionContext ctx) { + if (ctx == null) return false; + if (ctx.currentClass != null && ctx.currentClass.isUnsafe) { + return true; + } + if (ctx.currentMethodName == null || ctx.currentMethodName.isEmpty()) { + return false; + } + Type searchType = ctx.currentClass; + if (searchType == null && ctx.objectInstance != null) { + searchType = ctx.objectInstance.type; + } + if (searchType == null) return false; + Method currentMethod = constructorResolver.findMethodInHierarchy(searchType, ctx.currentMethodName, ctx); + return currentMethod != null && currentMethod.isUnsafe; + } public void clearAllCaches() { importResolver.clearCache(); diff --git a/src/main/java/cod/interpreter/InterpreterVisitor.java b/src/main/java/cod/interpreter/InterpreterVisitor.java index 65043697..0ec1f481 100644 --- a/src/main/java/cod/interpreter/InterpreterVisitor.java +++ b/src/main/java/cod/interpreter/InterpreterVisitor.java @@ -266,6 +266,17 @@ public Object visit(ConstructorCall node) { } try { + ExecutionContext ctx = getCurrentContext(); + Type targetType = interpreter.getImportResolver().findType(node.className); + if (targetType != null + && targetType.isUnsafe + && !isUnsafeExecutionContext(ctx) + && !ExecutionContext.isUnsafeCommitAllowed()) { + throw new ProgramError( + "Unsafe class '" + targetType.name + "' cannot be constructed in a safe context. Use safe(" + + targetType.name + + "(...))."); + } return interpreter.getConstructorResolver().resolveAndCreate(node, getCurrentContext()); } catch (ProgramError e) { throw e; @@ -1439,6 +1450,100 @@ private List evaluateMethodCallArguments(MethodCall methodCall) { return evaluatedArgs; } + private boolean isUnsafeExecutionContext(ExecutionContext ctx) { + if (ctx == null) return false; + if (ctx.currentClass != null && ctx.currentClass.isUnsafe) { + return true; + } + Method currentMethod = resolveCurrentContextMethod(ctx); + return currentMethod != null && currentMethod.isUnsafe; + } + + private Method resolveCurrentContextMethod(ExecutionContext ctx) { + if (ctx == null || ctx.currentMethodName == null || ctx.currentMethodName.isEmpty()) { + return null; + } + Type searchType = ctx.currentClass; + if (searchType == null && ctx.objectInstance != null) { + searchType = ctx.objectInstance.type; + } + if (searchType == null) { + return null; + } + return interpreter.getConstructorResolver().findMethodInHierarchy(searchType, ctx.currentMethodName, ctx); + } + + private Method resolveMethodForCall(MethodCall node, ExecutionContext ctx) { + Method method = null; + String callName = node.name; + String callQualifiedName = node.qualifiedName; + + if (ctx.currentClass != null) { + method = interpreter.getConstructorResolver().findMethodInHierarchy(ctx.currentClass, callName, ctx); + } + + if (method == null && ctx.objectInstance != null && ctx.objectInstance.type != null) { + method = interpreter.getConstructorResolver().findMethodInHierarchy(ctx.objectInstance.type, callName, ctx); + } + + if (method == null) { + String qName = callQualifiedName; + if (qName != null && qName.contains(".")) { + String[] parts = qName.split("\\."); + if (parts.length == 2) { + String receiver = parts[0]; + String methodName = parts[1]; + if (ctx.locals().containsKey(receiver)) { + Object receiverObj = ctx.locals().get(receiver); + if (receiverObj instanceof ObjectInstance) { + ObjectInstance objInst = (ObjectInstance) receiverObj; + if (objInst.type != null) { + qName = objInst.type.name + "." + methodName; + } + } + } + } + } + if (qName == null) qName = callName; + method = interpreter.getImportResolver().findMethod(qName); + } + + return method; + } + + private Object executeSafeCommit(MethodCall node, ExecutionContext ctx) { + if (isUnsafeExecutionContext(ctx)) { + throw new ProgramError( + "safe() is not allowed inside unsafe classes or methods; these contexts already have permission to execute unsafe code"); + } + if (node.arguments == null || node.arguments.size() != 1) { + throw new ProgramError("safe() expects exactly one argument"); + } + + Expr argument = node.arguments.get(0); + boolean unsafeTarget = false; + + if (argument instanceof MethodCall) { + Method targetMethod = resolveMethodForCall((MethodCall) argument, ctx); + unsafeTarget = targetMethod != null && targetMethod.isUnsafe; + } else if (argument instanceof ConstructorCall) { + Type targetType = interpreter.getImportResolver().findType(((ConstructorCall) argument).className); + unsafeTarget = targetType != null && targetType.isUnsafe; + } + + if (!unsafeTarget) { + throw new ProgramError( + "safe() requires an unsafe method call or unsafe class constructor as its argument, but the provided expression is not marked unsafe"); + } + + ExecutionContext.enterUnsafeCommitAllowance(); + try { + return dispatch(argument); + } finally { + ExecutionContext.exitUnsafeCommitAllowance(); + } + } + @Override public Object visit(Identifier node) { if (node == null) { @@ -1827,6 +1932,10 @@ public Object visit(MethodCall node) { } } } + + if ("safe".equals(callName) && (callQualifiedName == null || "safe".equals(callQualifiedName))) { + return executeSafeCommit(node, ctx); + } // Evaluate all arguments first List evaluatedArgs = evaluateMethodCallArguments(node); @@ -1885,6 +1994,13 @@ public Object visit(MethodCall node) { throw new ProgramError("Method not found: " + callName); } + if (method.isUnsafe && !isUnsafeExecutionContext(ctx) && !ExecutionContext.isUnsafeCommitAllowed()) { + throw new ProgramError( + "Unsafe method '" + method.methodName + "' cannot be called in a safe context. Use safe(" + + callName + + "(...))."); + } + // Check if this is a single-slot call boolean hasSingleSlot = method.returnSlots != null && method.returnSlots.size() == 1; if (node.slotNames.isEmpty() && hasSingleSlot) { diff --git a/src/main/java/cod/interpreter/context/ExecutionContext.java b/src/main/java/cod/interpreter/context/ExecutionContext.java index 7800484a..0fd38cc9 100644 --- a/src/main/java/cod/interpreter/context/ExecutionContext.java +++ b/src/main/java/cod/interpreter/context/ExecutionContext.java @@ -33,10 +33,17 @@ public class ExecutionContext { // ========== THREAD LOCAL CONTEXT ========== private static final ThreadLocal currentContext = new ThreadLocal(); + private static final ThreadLocal unsafeCommitDepth = new ThreadLocal() { + @Override + protected Integer initialValue() { + return 0; + } + }; // ========== OPTIMIZED LOOP CONTEXT ========== private boolean inOptimizedLoop = false; private List pendingOutputs = new ArrayList(); + private boolean unsafeExecutionContext = false; // ========== TYPE HANDLER ========== private final TypeHandler typeHandler; @@ -73,6 +80,19 @@ public Map getLocalsMap() { public static void clearCurrentContext() { currentContext.remove(); } + + public static void enterUnsafeCommitAllowance() { + unsafeCommitDepth.set(unsafeCommitDepth.get() + 1); + } + + public static void exitUnsafeCommitAllowance() { + int current = unsafeCommitDepth.get(); + unsafeCommitDepth.set(current > 0 ? current - 1 : 0); + } + + public static boolean isUnsafeCommitAllowed() { + return unsafeCommitDepth.get() > 0; + } /** * Mark that we're entering an optimized loop @@ -114,6 +134,14 @@ public List flushPendingOutputs() { pendingOutputs.clear(); return outputs; } + + public boolean isUnsafeExecutionContext() { + return unsafeExecutionContext; + } + + public void setUnsafeExecutionContext(boolean unsafeExecutionContext) { + this.unsafeExecutionContext = unsafeExecutionContext; + } public ExecutionContext(ObjectInstance obj, Map locals, Map slotValues, Map slotTypes, diff --git a/src/main/java/cod/interpreter/handler/AssignmentHandler.java b/src/main/java/cod/interpreter/handler/AssignmentHandler.java index fb47a509..3112e6f7 100644 --- a/src/main/java/cod/interpreter/handler/AssignmentHandler.java +++ b/src/main/java/cod/interpreter/handler/AssignmentHandler.java @@ -144,6 +144,7 @@ private Object assignToSlot(String slotTarget, Object value, ExecutionContext ct if (ctx.hasSlot(slotTarget)) { // O(1) String declaredType = ctx.getSlotType(slotTarget); // O(1) validateAssignmentType(declaredType, value, slotTarget); + value = typeSystem.normalizeForDeclaredType(declaredType, value); value = typeSystem.wrapUnionType(value, declaredType); @@ -411,6 +412,7 @@ private Object updateVariableInScope(String varName, Object newValue, if (declaredType != null) { validateAssignmentType(declaredType, newValue, varName); + newValue = typeSystem.normalizeForDeclaredType(declaredType, newValue); newValue = typeSystem.wrapUnionType(newValue, declaredType); } diff --git a/src/main/java/cod/interpreter/handler/ExpressionHandler.java b/src/main/java/cod/interpreter/handler/ExpressionHandler.java index bc229cb8..3c8d8415 100644 --- a/src/main/java/cod/interpreter/handler/ExpressionHandler.java +++ b/src/main/java/cod/interpreter/handler/ExpressionHandler.java @@ -44,6 +44,10 @@ public Object handleBinaryOp(BinaryOp node, ExecutionContext ctx) { switch (node.op) { case "+": case "+=": + if (typeSystem.unwrap(left) instanceof TypeHandler.PointerValue + || typeSystem.unwrap(right) instanceof TypeHandler.PointerValue) { + return handlePointerArithmetic(left, right, true, ctx); + } if (left instanceof String || right instanceof String || left instanceof TextLiteral || right instanceof TextLiteral) { @@ -78,6 +82,10 @@ public Object handleBinaryOp(BinaryOp node, ExecutionContext ctx) { case "-": case "-=": + if (typeSystem.unwrap(left) instanceof TypeHandler.PointerValue + || typeSystem.unwrap(right) instanceof TypeHandler.PointerValue) { + return handlePointerArithmetic(left, right, false, ctx); + } result = typeSystem.subtractNumbers(left, right); break; @@ -142,6 +150,14 @@ public Object handleUnaryOp(Unary node, ExecutionContext ctx) { } try { + if ("&".equals(node.op)) { + ensureUnsafeContext(ctx, "Address-of operator '&'"); + if (!(node.operand instanceof IndexAccess)) { + throw new ProgramError("Address-of operator '&' requires an index expression like '&buffer[0]'"); + } + return createPointerFromIndexAccess((IndexAccess) node.operand, ctx); + } + Object operand = dispatcher.dispatch(node.operand); switch (node.op) { @@ -151,6 +167,9 @@ public Object handleUnaryOp(Unary node, ExecutionContext ctx) { return operand; case "!": return !typeSystem.isTruthy(operand); + case "*": + ensureUnsafeContext(ctx, "Pointer dereference '*'"); + return dereferencePointer(operand); default: throw new ProgramError("Unknown unary operator: " + node.op); } @@ -160,6 +179,99 @@ public Object handleUnaryOp(Unary node, ExecutionContext ctx) { throw new InternalError("Unary operation failed: " + node.op, e); } } + + private void ensureUnsafeContext(ExecutionContext ctx, String featureName) { + boolean unsafe = ctx.isUnsafeExecutionContext() + || (ctx.currentClass != null && ctx.currentClass.isUnsafe); + if (!unsafe) { + throw new ProgramError(featureName + " is only available inside unsafe class/method contexts"); + } + } + + private TypeHandler.PointerValue createPointerFromIndexAccess(IndexAccess access, ExecutionContext ctx) { + Object container = dispatcher.dispatch(access.array); + container = typeSystem.unwrap(container); + Object indexObj = dispatcher.dispatch(access.index); + indexObj = typeSystem.unwrap(indexObj); + + if (container instanceof NaturalArray) { + long idx = toLongIndex(indexObj); + NaturalArray arr = (NaturalArray) container; + if (idx < 0 || idx >= arr.size()) { + throw new ProgramError("Pointer address index out of bounds: " + idx); + } + return new TypeHandler.PointerValue(arr, idx, arr.getElementType()); + } + + if (container instanceof List) { + long idx = toLongIndex(indexObj); + List list = (List) container; + if (idx < 0 || idx >= list.size()) { + throw new ProgramError("Pointer address index out of bounds: " + idx); + } + String pointedType = "any"; + Object pointedValue = list.get((int) idx); + if (pointedValue != null) { + pointedType = typeSystem.getConcreteType(typeSystem.unwrap(pointedValue)); + } + return new TypeHandler.PointerValue(list, idx, pointedType); + } + + throw new ProgramError("Address-of operator '&' only supports array/list index targets"); + } + + private Object dereferencePointer(Object pointerObj) { + Object unwrapped = typeSystem.unwrap(pointerObj); + if (!(unwrapped instanceof TypeHandler.PointerValue)) { + throw new ProgramError("Dereference '*' expects a pointer value"); + } + TypeHandler.PointerValue pointer = (TypeHandler.PointerValue) unwrapped; + if (pointer.container instanceof NaturalArray) { + return ((NaturalArray) pointer.container).get(pointer.index); + } + if (pointer.container instanceof List) { + List list = (List) pointer.container; + int idx = Math.toIntExact(pointer.index); + if (idx < 0 || idx >= list.size()) { + throw new ProgramError("Pointer dereference out of bounds: " + idx); + } + return list.get(idx); + } + throw new ProgramError("Unsupported pointer target"); + } + + private Object handlePointerArithmetic(Object left, Object right, boolean addition, ExecutionContext ctx) { + ensureUnsafeContext(ctx, "Pointer arithmetic"); + Object leftUnwrapped = typeSystem.unwrap(left); + Object rightUnwrapped = typeSystem.unwrap(right); + + TypeHandler.PointerValue pointer; + long offset; + + if (leftUnwrapped instanceof TypeHandler.PointerValue && !(rightUnwrapped instanceof TypeHandler.PointerValue)) { + pointer = (TypeHandler.PointerValue) leftUnwrapped; + offset = toLongIndex(rightUnwrapped); + } else if (rightUnwrapped instanceof TypeHandler.PointerValue && !(leftUnwrapped instanceof TypeHandler.PointerValue) && addition) { + pointer = (TypeHandler.PointerValue) rightUnwrapped; + offset = toLongIndex(leftUnwrapped); + } else { + throw new ProgramError("Pointer arithmetic expects pointer +/- integer"); + } + + long nextIndex = addition ? pointer.index + offset : pointer.index - offset; + if (pointer.container instanceof NaturalArray) { + NaturalArray arr = (NaturalArray) pointer.container; + if (nextIndex < 0 || nextIndex >= arr.size()) { + throw new ProgramError("Pointer arithmetic out of bounds: " + nextIndex); + } + } else if (pointer.container instanceof List) { + int size = ((List) pointer.container).size(); + if (nextIndex < 0 || nextIndex >= size) { + throw new ProgramError("Pointer arithmetic out of bounds: " + nextIndex); + } + } + return new TypeHandler.PointerValue(pointer.container, nextIndex, pointer.pointedType); + } public Object handleTypeCast(TypeCast node, ExecutionContext ctx) { if (node == null) { diff --git a/src/main/java/cod/interpreter/handler/TypeHandler.java b/src/main/java/cod/interpreter/handler/TypeHandler.java index 402f837c..346165ab 100644 --- a/src/main/java/cod/interpreter/handler/TypeHandler.java +++ b/src/main/java/cod/interpreter/handler/TypeHandler.java @@ -7,6 +7,8 @@ import cod.range.NaturalArray; import static cod.syntax.Keyword.*; +import java.math.BigDecimal; +import java.math.BigInteger; import java.util.ArrayList; import java.util.AbstractList; import java.util.List; @@ -99,11 +101,63 @@ public int hashCode() { return Objects.hash(value, activeType, declaredType); } } + + public static class PointerValue { + public final Object container; + public final long index; + public final String pointedType; + + public PointerValue(Object container, long index, String pointedType) { + this.container = container; + this.index = index; + this.pointedType = pointedType; + } + + @Override + public String toString() { + return "&" + pointedType + "@" + index; + } + } // AutoStackingNumber constants private static final AutoStackingNumber ZERO = AutoStackingNumber.valueOf("0"); private static final AutoStackingNumber ONE = AutoStackingNumber.valueOf("1"); private static final int LAZY_ARRAY_MEMO_MAX_SIZE = 8192; + private static final String[] UNSAFE_NUMERIC_TYPES = { + "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f32", "f64" + }; + + public boolean isPointerType(String type) { + return type != null && type.startsWith("*") && type.length() > 1; + } + + public boolean isSizedArrayType(String type) { + if (type == null) return false; + int l = type.lastIndexOf('['); + int r = type.lastIndexOf(']'); + if (l <= 0 || r != type.length() - 1) return false; + String sizePart = type.substring(l + 1, r).trim(); + if (sizePart.isEmpty()) return false; + for (int i = 0; i < sizePart.length(); i++) { + if (!Character.isDigit(sizePart.charAt(i))) return false; + } + return true; + } + + public String getSizedArrayElementType(String type) { + if (!isSizedArrayType(type)) return null; + return type.substring(0, type.lastIndexOf('[')); + } + + public int getSizedArrayLength(String type) { + if (!isSizedArrayType(type)) return -1; + String sizePart = type.substring(type.lastIndexOf('[') + 1, type.length() - 1).trim(); + try { + return Integer.parseInt(sizePart); + } catch (NumberFormatException e) { + return -1; + } + } // Helper to check if value is none public boolean isNoneValue(Object obj) { @@ -183,9 +237,31 @@ public boolean isTruthy(Object value) { public boolean isTypeLiteral(String str) { return str.equals("int") || str.equals("float") || str.equals("text") || str.equals("bool") || str.equals("type") || str.equals("none") || + isPointerType(str) || isSizedArrayType(str) || + isUnsafeNumericType(str) || str.equals("[]") || str.startsWith("[") || str.startsWith("(") || str.contains("|"); } + + public boolean isUnsafeNumericType(String type) { + if (type == null) return false; + for (String unsafeType : UNSAFE_NUMERIC_TYPES) { + if (unsafeType.equals(type)) { + return true; + } + } + return false; + } + + public Object normalizeForDeclaredType(String declaredType, Object value) { + if (declaredType == null) return value; + String normalized = declaredType.trim(); + if (!isUnsafeNumericType(normalized)) { + return value; + } + Object converted = convertType(value, normalized); + return new Value(converted, normalized, normalized); + } public Object processTypeLiteral(String typeLiteral) { if (typeLiteral.equals("none")) { @@ -193,6 +269,16 @@ public Object processTypeLiteral(String typeLiteral) { } return Value.createTypeValue(typeLiteral); } + + private String normalizeTypeSignature(String typeSig) { + if (typeSig == null) return null; + String trimmed = typeSig.trim(); + if (isSizedArrayType(trimmed)) { + String inner = normalizeTypeSignature(getSizedArrayElementType(trimmed)); + return "[" + inner + "]"; + } + return trimmed; + } // === TypeHandler Validation with Special Cases === @@ -910,6 +996,23 @@ public Object convertType(Object value, String targetType) { if (targetType.equals("none")) { return new NoneLiteral(); } + + if (isUnsafeNumericType(targetType)) { + return convertUnsafeNumeric(value, targetType); + } + + if (isPointerType(targetType)) { + Object unwrapped = unwrap(value); + if (unwrapped instanceof PointerValue) { + PointerValue pointer = (PointerValue) unwrapped; + String expectedPointedType = normalizeTypeSignature(targetType.substring(1)); + String actualPointedType = normalizeTypeSignature(pointer.pointedType); + if (expectedPointedType.equals(actualPointedType)) { + return pointer; + } + } + throw new ProgramError("Cannot convert '" + value + "' to pointer type " + targetType); + } if (value instanceof FloatLiteral) { AutoStackingNumber num = ((FloatLiteral) value).value; @@ -1026,6 +1129,64 @@ public Object convertType(Object value, String targetType) { "), targetType=" + targetType ); } + + private Object convertUnsafeNumeric(Object value, String targetType) { + if (targetType.equals("f32")) { + double numeric = toDouble(value); + return AutoStackingNumber.fromDouble((double) ((float) numeric)); + } + if (targetType.equals("f64")) { + double numeric = toDouble(value); + return AutoStackingNumber.fromDouble(numeric); + } + BigInteger integral = toIntegralBigInteger(value); + return wrapIntegerUnsafe(integral, targetType); + } + + private BigInteger toIntegralBigInteger(Object value) { + Object unwrapped = unwrap(value); + if (unwrapped instanceof IntLiteral) { + return new BigInteger(((IntLiteral) unwrapped).value.toString()); + } + if (unwrapped instanceof FloatLiteral) { + AutoStackingNumber n = ((FloatLiteral) unwrapped).value; + return new BigDecimal(n.toString()).toBigInteger(); + } + if (unwrapped instanceof AutoStackingNumber) { + return new BigDecimal(((AutoStackingNumber) unwrapped).toString()).toBigInteger(); + } + if (unwrapped instanceof Integer || unwrapped instanceof Long) { + return BigInteger.valueOf(((Number) unwrapped).longValue()); + } + if (unwrapped instanceof Float || unwrapped instanceof Double) { + return BigDecimal.valueOf(((Number) unwrapped).doubleValue()).toBigInteger(); + } + throw new ProgramError( + "Unsafe numeric types require int or float values, got: " + getConcreteType(unwrapped)); + } + + private AutoStackingNumber wrapIntegerUnsafe(BigInteger value, String targetType) { + int bits = 64; + boolean signed = true; + if (targetType.equals("i8")) bits = 8; + else if (targetType.equals("i16")) bits = 16; + else if (targetType.equals("i32")) bits = 32; + else if (targetType.equals("i64")) bits = 64; + else if (targetType.equals("u8")) { bits = 8; signed = false; } + else if (targetType.equals("u16")) { bits = 16; signed = false; } + else if (targetType.equals("u32")) { bits = 32; signed = false; } + else if (targetType.equals("u64")) { bits = 64; signed = false; } + + BigInteger modulus = BigInteger.ONE.shiftLeft(bits); + BigInteger wrapped = value.mod(modulus); + if (signed) { + BigInteger signBoundary = BigInteger.ONE.shiftLeft(bits - 1); + if (wrapped.compareTo(signBoundary) >= 0) { + wrapped = wrapped.subtract(modulus); + } + } + return AutoStackingNumber.valueOf(wrapped.toString()); + } public String getConcreteType(Object value) { if (value instanceof Value) { @@ -1047,6 +1208,10 @@ public String getConcreteType(Object value) { // Return the element type of the array, not "list" return arr.getElementType(); } + + if (value instanceof PointerValue) { + return "*" + ((PointerValue) value).pointedType; + } if (value instanceof IntLiteral) return INT.toString(); if (value instanceof FloatLiteral) return FLOAT.toString(); @@ -1074,7 +1239,10 @@ public String getConcreteType(Object value) { } public boolean validateType(String typeSig, Object value) { - String typeSigTrimmed = typeSig.trim(); + if (typeSig == null) { + return true; + } + String typeSigTrimmed = normalizeTypeSignature(typeSig); if (typeSigTrimmed.contains("|")) { if (!isTypeStructurallyValid(typeSigTrimmed)) { throw new ProgramError("Union type contains illegal keywords: " + typeSig); @@ -1128,7 +1296,7 @@ public boolean areEqual(Object a, Object b) { private boolean validateTypeInternal(String typeSig, Object rawValue, String concreteType) { if (typeSig == null) return true; - String type = typeSig.trim(); + String type = normalizeTypeSignature(typeSig); if (type.equals(ANY.toString())) return true; @@ -1156,6 +1324,16 @@ private boolean validateTypeInternal(String typeSig, Object rawValue, String con return isNoneValue; } + if (isPointerType(type)) { + if (isNoneValue(rawValue)) return false; + Object unwrapped = unwrap(rawValue); + if (!(unwrapped instanceof PointerValue)) return false; + PointerValue pointer = (PointerValue) unwrapped; + String pointedType = normalizeTypeSignature(pointer.pointedType); + String expectedPointedType = normalizeTypeSignature(type.substring(1)); + return expectedPointedType.equals(pointedType); + } + if (type.startsWith("[") && type.endsWith("]")) { if (!(rawValue instanceof List || rawValue instanceof NaturalArray)) return false; @@ -1206,9 +1384,17 @@ private boolean validateTypeInternal(String typeSig, Object rawValue, String con private boolean isValidTypeSignature(String str) { if (str == null || str.isEmpty()) return false; + + if (isPointerType(str)) { + return isValidTypeSignature(str.substring(1)); + } + if (isSizedArrayType(str)) { + return isValidTypeSignature(getSizedArrayElementType(str)); + } if (str.equals("int") || str.equals("float") || str.equals("text") || - str.equals("bool") || str.equals("type") || str.equals("none")) { + str.equals("bool") || str.equals("type") || str.equals("none") || + isUnsafeNumericType(str)) { return true; } @@ -1256,6 +1442,13 @@ private boolean isTypeStructurallyValid(String typeSig) { } String type = typeSig.trim(); if (type.isEmpty()) return false; + + if (isPointerType(type)) { + return isTypeStructurallyValid(type.substring(1)); + } + if (isSizedArrayType(type)) { + return isTypeStructurallyValid(getSizedArrayElementType(type)); + } if (type.startsWith("[") && type.endsWith("]")) { String inner = type.substring(1, type.length() - 1); @@ -1272,7 +1465,8 @@ private boolean isTypeStructurallyValid(String typeSig) { if (type.equals(INT.toString()) || type.equals(FLOAT.toString()) || type.equals(TEXT.toString()) || type.equals(BOOL.toString()) || - type.equals(ANY.toString()) || type.equals("none")) { + type.equals(ANY.toString()) || type.equals("none") || + isUnsafeNumericType(type)) { return true; } if (Character.isUpperCase(type.charAt(0))) return true; @@ -1280,6 +1474,11 @@ private boolean isTypeStructurallyValid(String typeSig) { } private boolean checkPrimitiveMatch(String type, Object rawValue) { + if (isPointerType(type)) { + Object unwrapped = unwrap(rawValue); + return unwrapped instanceof PointerValue; + } + if (type.startsWith("[") && type.endsWith("]")) { if (type.equals("[]")) { return rawValue instanceof List || rawValue instanceof NaturalArray; @@ -1306,6 +1505,15 @@ private boolean checkPrimitiveMatch(String type, Object rawValue) { rawValue instanceof Boolean; } else if (type.equals("none")) { return isNoneValue(rawValue); + } else if (isUnsafeNumericType(type)) { + return rawValue instanceof IntLiteral + || rawValue instanceof FloatLiteral + || rawValue instanceof Integer + || rawValue instanceof Long + || rawValue instanceof Float + || rawValue instanceof Double + || rawValue instanceof AutoStackingNumber + || (rawValue instanceof Value && isUnsafeNumericType(((Value) rawValue).activeType)); } return false; } diff --git a/src/main/java/cod/parser/BaseParser.java b/src/main/java/cod/parser/BaseParser.java index 8ecdf1db..a55f53bf 100644 --- a/src/main/java/cod/parser/BaseParser.java +++ b/src/main/java/cod/parser/BaseParser.java @@ -271,9 +271,31 @@ protected String getTypeName(TokenType type) { } protected boolean isTypeStart(Token token) { - return any(is(token, INT, TEXT, FLOAT, BOOL, TYPE), + return any(is(token, INT, TEXT, FLOAT, BOOL, TYPE, I8, I16, I32, I64, U8, U16, U32, U64, F32, F64), is(token, ID), - is(token, LPAREN, LBRACKET)); + is(token, LPAREN, LBRACKET, MUL)); + } + + protected boolean isUnsafeTypeContext() { + return ctx.isInUnsafeDeclaration(); + } + + protected boolean isUnsafeNumericTypeKeyword(Token token) { + return is(token, I8, I16, I32, I64, U8, U16, U32, U64, F32, F64); + } + + protected boolean isUnsafeNumericTypeName(String typeName) { + if (typeName == null) return false; + return typeName.equals("i8") + || typeName.equals("i16") + || typeName.equals("i32") + || typeName.equals("i64") + || typeName.equals("u8") + || typeName.equals("u16") + || typeName.equals("u32") + || typeName.equals("u64") + || typeName.equals("f32") + || typeName.equals("f64"); } protected String parseQualifiedName() { @@ -303,7 +325,13 @@ protected boolean canBeMethod(Token token) { protected String parseTypeReference() { StringBuilder type = new StringBuilder(); - if (is(LBRACKET)) { + if (is(MUL)) { + Token pointerToken = expect(MUL); + if (!isUnsafeTypeContext()) { + throw error("Pointer types can only be used inside an unsafe class or method", pointerToken); + } + type.append("*").append(parseTypeReference()); + } else if (is(LBRACKET)) { expect(LBRACKET); if (is(RBRACKET)) { expect(RBRACKET); @@ -319,12 +347,30 @@ protected String parseTypeReference() { Token typeToken = now(); if (isTypeStart(typeToken) && !is(typeToken, LBRACKET)) { String typeName = consume().getText(); + if (isUnsafeNumericTypeName(typeName) && !isUnsafeTypeContext()) { + throw error( + "Unsafe type '" + typeName + "' can only be used inside an unsafe class or method", + typeToken); + } type.append(typeName); } else { throw error("Expected type name"); } } + while (is(LBRACKET)) { + Token lbracketToken = expect(LBRACKET); + if (!isUnsafeTypeContext()) { + throw error("Sized array types can only be used inside an unsafe class or method", lbracketToken); + } + if (is(INT_LIT)) { + type.append("[").append(expect(INT_LIT).getText()).append("]"); + } else { + type.append("[]"); + } + expect(RBRACKET); + } + if (consume(QUESTION)) { return type.toString() + "|none"; } @@ -381,12 +427,12 @@ protected boolean isExprStart(Token t) { return false; } return any(is(t, INT_LIT, FLOAT_LIT, TEXT_LIT, BOOL_LIT, ID), - is(t, LPAREN, LBRACKET, BANG, PLUS, MINUS, DOLLAR), + is(t, LPAREN, LBRACKET, BANG, PLUS, MINUS, DOLLAR, AMPERSAND, MUL), is(t, NONE, TRUE, FALSE, SUPER, THIS)); } protected boolean isClassStart() { - if (is(SHARE, LOCAL)) { + if (is(SHARE, LOCAL, UNSAFE)) { return true; } diff --git a/src/main/java/cod/parser/DeclarationParser.java b/src/main/java/cod/parser/DeclarationParser.java index 5a66698d..e3749a81 100644 --- a/src/main/java/cod/parser/DeclarationParser.java +++ b/src/main/java/cod/parser/DeclarationParser.java @@ -23,6 +23,7 @@ public class DeclarationParser extends BaseParser { private final PolicyValidator policyValidator; private Type currentParsingClass = null; + private Method currentParsingMethod = null; public DeclarationParser( ParserContext ctx, StatementParser statementParser, ImportResolver importResolver) { super(ctx); @@ -48,6 +49,21 @@ private Type getCurrentParsingClass() { return currentParsingClass; } + private void setCurrentParsingMethod(Method method) { + currentParsingMethod = method; + } + + @Override + protected boolean isUnsafeTypeContext() { + if (super.isUnsafeTypeContext()) { + return true; + } + if (currentParsingMethod != null && currentParsingMethod.isUnsafe) { + return true; + } + return currentParsingClass != null && currentParsingClass.isUnsafe; + } + public void validateClassViralPolicies(Type type, Program currentProgram) { policyValidator.validateClassViralPolicies(type, currentProgram); @@ -220,6 +236,7 @@ private void parseNamedArgumentList(List args, List argNames) { public Type parseType() { Keyword visibility = null; Token visibilityToken = null; + boolean isUnsafeType = false; if (is(SHARE, LOCAL)) { visibilityToken = now(); @@ -237,6 +254,11 @@ public Type parseType() { } } + if (is(UNSAFE)) { + consume(); + isUnsafeType = true; + } + Token typeNameToken = now(); String typeName = expect(ID).getText(); @@ -281,6 +303,7 @@ public Type parseType() { } Type type = ASTFactory.createType(typeName, visibility, extendName, typeNameToken); + type.isUnsafe = isUnsafeType; type.implementedPolicies = implementedPolicies; type.policyTokens = policyTokens; type.extendToken = extendToken; @@ -288,25 +311,32 @@ public Type parseType() { setCurrentParsingClass(type); - expect(LBRACE); - while (!is(RBRACE)) { - if (isFieldDeclaration()) { - type.fields.add(parseField()); - } else if (isConstructorDeclaration()) { - Constructor constructor = parseConstructor(); - type.constructors.add(constructor); - } else if (isMethodDeclaration()) { - Method method = parseMethod(); - method.associatedClass = type.name; - type.methods.add(method); - } else { - type.statements.add(statementParser.parseStmt()); + if (type.isUnsafe) { + ctx.enterUnsafeDeclaration(); + } + try { + expect(LBRACE); + while (!is(RBRACE)) { + if (isFieldDeclaration()) { + type.fields.add(parseField()); + } else if (isConstructorDeclaration()) { + Constructor constructor = parseConstructor(); + type.constructors.add(constructor); + } else if (isMethodDeclaration()) { + Method method = parseMethod(); + method.associatedClass = type.name; + type.methods.add(method); + } else { + type.statements.add(statementParser.parseStmt()); + } } + expect(RBRACE); + } finally { + if (type.isUnsafe) { + ctx.exitUnsafeDeclaration(); + } + setCurrentParsingClass(null); } - - setCurrentParsingClass(null); - - expect(RBRACE); return type; } @@ -467,38 +497,68 @@ public Method parseMethod() { boolean isBuiltin = false; Keyword visibility = Keyword.SHARE; boolean isPolicyMethod = false; + boolean isUnsafeMethod = false; Token visibilityToken = null; + boolean sawVisibility = false; - if (is(POLICY)) { - expect(POLICY); - isPolicyMethod = true; + if (is(SHARE, LOCAL)) { + sawVisibility = true; + visibilityToken = now(); + Token currentVisibility = consume(); - Type currentClass = getCurrentParsingClass(); - if (!nil(currentClass)) { - visibility = currentClass.visibility; - } else { - visibility = Keyword.SHARE; - } + if (is(currentVisibility, SHARE)) { + visibility = Keyword.SHARE; + } else if (is(currentVisibility, LOCAL)) { + visibility = Keyword.LOCAL; + } else { + throw error( + "Internal parser error: isVisibilityModifier() returned true for non-visibility keyword: '" + + currentVisibility.getText() + + "'", + visibilityToken); + } + } - } else if (is(BUILTIN)) { + boolean consumedModifier = true; + while (consumedModifier) { + consumedModifier = false; + if (is(POLICY)) { + expect(POLICY); + isPolicyMethod = true; + consumedModifier = true; + continue; + } + if (is(BUILTIN)) { expect(BUILTIN); isBuiltin = true; - visibility = Keyword.SHARE; - } else if (is(SHARE, LOCAL)) { - visibilityToken = now(); - Token currentVisibility = consume(); + consumedModifier = true; + continue; + } + if (is(UNSAFE)) { + expect(UNSAFE); + isUnsafeMethod = true; + consumedModifier = true; + } + } - if (is(currentVisibility, SHARE)) { - visibility = Keyword.SHARE; - } else if (is(currentVisibility, LOCAL)) { - visibility = Keyword.LOCAL; - } else { - throw error( - "Internal parser error: isVisibilityModifier() returned true for non-visibility keyword: '" - + currentVisibility.getText() - + "'", - visibilityToken); - } + if (is(SHARE, LOCAL)) { + throw error("Visibility modifier must appear before other modifiers in method declarations", now()); + } + + if (isUnsafeMethod && !sawVisibility) { + throw error( + "Unsafe methods require an explicit visibility modifier before 'unsafe'. " + + "Expected: share unsafe methodName(...) or local unsafe methodName(...)", + startToken); + } + + if (isPolicyMethod && !sawVisibility) { + Type currentClass = getCurrentParsingClass(); + if (!nil(currentClass)) { + visibility = currentClass.visibility; + } else { + visibility = Keyword.SHARE; + } } String methodName; @@ -518,100 +578,113 @@ public Method parseMethod() { Method method = ASTFactory.createMethod(methodName, visibility, null, nameToken); method.isBuiltin = isBuiltin; method.isPolicyMethod = isPolicyMethod; + method.isUnsafe = isUnsafeMethod; + + if (method.isUnsafe) { + ctx.enterUnsafeDeclaration(); + } + setCurrentParsingMethod(method); + + try { + expect(LPAREN); + + if (isBuiltin) { + int parenDepth = 1; + while (!is(EOF) && parenDepth > 0) { + Token t = now(); + if (is(t, LPAREN)) { + parenDepth++; + } else if (is(t, RPAREN)) { + parenDepth--; + if (parenDepth == 0) { + expect(RPAREN); + break; + } + } + consume(); + } + } else { + if (!is(RPAREN)) { + method.parameters.add(parseParameter()); + while (consume(COMMA)) { + method.parameters.add(parseParameter()); + } + } + expect(RPAREN); + } - expect(LPAREN); - - if (isBuiltin) { - int parenDepth = 1; - while (!is(EOF) && parenDepth > 0) { - Token t = now(); - if (is(t, LPAREN)) { - parenDepth++; - } else if (is(t, RPAREN)) { - parenDepth--; - if (parenDepth == 0) { - expect(RPAREN); - break; - } - } - consume(); - } - } else { - if (!is(RPAREN)) { - method.parameters.add(parseParameter()); - while (consume(COMMA)) { - method.parameters.add(parseParameter()); - } - } - expect(RPAREN); - } - - // Parse slot contract if present (:: syntax) - if (isSlotDeclaration()) { - method.returnSlots = slotParser.parseSlotContract(); - } else { - method.returnSlots = new ArrayList(); - } + // Parse slot contract if present (:: syntax) + if (isSlotDeclaration()) { + method.returnSlots = slotParser.parseSlotContract(); + } else { + method.returnSlots = new ArrayList(); + } - if (isBuiltin) { - while (getPosition() < tokens.size()) { - Token current = now(); + if (isBuiltin) { + while (getPosition() < tokens.size()) { + Token current = now(); - if (is(current, RBRACE) - || is(current, SHARE, LOCAL, BUILTIN, POLICY)) { - break; - } + if (is(current, RBRACE) + || is(current, SHARE, LOCAL, BUILTIN, POLICY, UNSAFE)) { + break; + } - consume(); - } + consume(); + } - if (is(TILDE_ARROW, LBRACE)) { - Token current = now(); - throw error( - "Builtin method '" - + methodName - + "' cannot have a body. " - + "Builtin methods are only declarations, not implementations.\n" - + "Remove '~>' or '{...}' after builtin method signature.", - current); - } + if (is(TILDE_ARROW, LBRACE)) { + Token current = now(); + throw error( + "Builtin method '" + + methodName + + "' cannot have a body. " + + "Builtin methods are only declarations, not implementations.\n" + + "Remove '~>' or '{...}' after builtin method signature.", + current); + } - return method; - } + return method; + } - // Parse method body - if (is(TILDE_ARROW)) { - Token tildeArrowToken = now(); - expect(TILDE_ARROW); + // Parse method body + if (is(TILDE_ARROW)) { + Token tildeArrowToken = now(); + expect(TILDE_ARROW); - List slotAssignments = - slotParser.parseParenthesizedSlotAssignments(tildeArrowToken); + List slotAssignments = + slotParser.parseParenthesizedSlotAssignments(tildeArrowToken); - if (slotAssignments.size() == 1) { - method.body.add(slotAssignments.get(0)); - } else { - MultipleSlotAssignment multiAssign = - ASTFactory.createMultipleSlotAsmt(slotAssignments, tildeArrowToken); - method.body.add(multiAssign); - } + if (slotAssignments.size() == 1) { + method.body.add(slotAssignments.get(0)); + } else { + MultipleSlotAssignment multiAssign = + ASTFactory.createMultipleSlotAsmt(slotAssignments, tildeArrowToken); + method.body.add(multiAssign); + } - } else if (is(LBRACE)) { - expect(LBRACE); - while (!is(RBRACE)) { - method.body.add(statementParser.parseStmt()); - } - expect(RBRACE); - - ReturnContractValidator.validateMethodReturnContract(method, currentParsingClass, startToken); - } else { - Token current = now(); - throw error( - "Expected '~>' or '{' after method signature, but found " - + getTypeName(current.type) - + " ('" - + current.getText() - + "')", - current); + } else if (is(LBRACE)) { + expect(LBRACE); + while (!is(RBRACE)) { + method.body.add(statementParser.parseStmt()); + } + expect(RBRACE); + + ReturnContractValidator.validateMethodReturnContract(method, currentParsingClass, startToken); + } else { + Token current = now(); + throw error( + "Expected '~>' or '{' after method signature, but found " + + getTypeName(current.type) + + " ('" + + current.getText() + + "')", + current); + } + } finally { + setCurrentParsingMethod(null); + if (method.isUnsafe) { + ctx.exitUnsafeDeclaration(); + } } return method; @@ -822,9 +895,15 @@ public Boolean parse() throws ParseError { Token first = next(offset); if (nil(first)) return false; - if (is(first, SHARE, LOCAL, BUILTIN, POLICY)) { + if (is(first, SHARE, LOCAL, BUILTIN, POLICY, UNSAFE)) { offset++; while (wsComments(offset)) offset++; + Token maybeMoreModifier = next(offset); + while (is(maybeMoreModifier, BUILTIN, POLICY, UNSAFE)) { + offset++; + while (wsComments(offset)) offset++; + maybeMoreModifier = next(offset); + } } Token nameToken = next(offset); @@ -865,8 +944,9 @@ public Boolean parse() throws ParseError { return false; } - String type = parseTypeReference(); - if (nil(type)) { + try { + parseTypeReference(); + } catch (ParseError e) { return false; } diff --git a/src/main/java/cod/parser/ExpressionParser.java b/src/main/java/cod/parser/ExpressionParser.java index 9b05713e..f26f3f2b 100644 --- a/src/main/java/cod/parser/ExpressionParser.java +++ b/src/main/java/cod/parser/ExpressionParser.java @@ -1321,7 +1321,7 @@ private Expr parsePrefix() { } } - if (is(BANG, PLUS, MINUS)) { + if (is(BANG, PLUS, MINUS, AMPERSAND, MUL)) { Token opToken = consume(); Expr operand = parsePrecedence(PREC_UNARY); return ASTFactory.createUnaryOp(opToken.getText(), operand, opToken); diff --git a/src/main/java/cod/parser/MainParser.java b/src/main/java/cod/parser/MainParser.java index 27e24a6b..49bf90bd 100644 --- a/src/main/java/cod/parser/MainParser.java +++ b/src/main/java/cod/parser/MainParser.java @@ -89,7 +89,9 @@ public Program parseProgram() { // DIRECT CHECK: If we see "local" or "share" text, try to parse as method first if (currentToken != null && - ("local".equals(currentToken.getText()) || "share".equals(currentToken.getText()))) { + ("local".equals(currentToken.getText()) + || "share".equals(currentToken.getText()) + || "unsafe".equals(currentToken.getText()))) { ParserState savedState = getCurrentState(); try { Method method = declarationParser.parseMethod(); @@ -183,7 +185,10 @@ private boolean isTopLevelMethodDeclaration() { public Boolean parse() throws ParseError { ParserState savedState = getCurrentState(); try { - if (is(SHARE, LOCAL)) { + if (is(SHARE, LOCAL, UNSAFE)) { + consume(); + } + while (is(BUILTIN, POLICY, UNSAFE)) { consume(); } @@ -342,12 +347,20 @@ private boolean isMethodDeclarationStart() { Token first = now(); if (first == null) return false; - if (is(first, LOCAL, SHARE, BUILTIN, POLICY)) { - Token second = next(); - if (is(second, ID) || canBeMethod(second)) { - Token third = next(2); - return is(third, LPAREN); - } + int offset = 0; + Token token = next(offset); + + if (is(token, LOCAL, SHARE, UNSAFE, BUILTIN, POLICY)) { + offset++; + token = next(offset); + } + while (is(token, BUILTIN, POLICY, UNSAFE)) { + offset++; + token = next(offset); + } + if (is(token, ID) || canBeMethod(token)) { + Token afterName = next(offset + 1); + return is(afterName, LPAREN); } return false; } finally { diff --git a/src/main/java/cod/parser/context/ParserContext.java b/src/main/java/cod/parser/context/ParserContext.java index daab7a87..b5a86ac4 100644 --- a/src/main/java/cod/parser/context/ParserContext.java +++ b/src/main/java/cod/parser/context/ParserContext.java @@ -13,6 +13,7 @@ public final class ParserContext { private ParserState state; private final Stack backtrackStack = new Stack<>(); + private int unsafeDeclarationDepth = 0; public ParserContext(ParserState initialState) { this.state = initialState; @@ -164,8 +165,22 @@ public List getTokens() { return state.getTokens(); } + public void enterUnsafeDeclaration() { + unsafeDeclarationDepth++; + } + + public void exitUnsafeDeclaration() { + if (unsafeDeclarationDepth > 0) { + unsafeDeclarationDepth--; + } + } + + public boolean isInUnsafeDeclaration() { + return unsafeDeclarationDepth > 0; + } + @Override public String toString() { return state.toString(); } -} \ No newline at end of file +} diff --git a/src/main/java/cod/semantic/ConstructorResolver.java b/src/main/java/cod/semantic/ConstructorResolver.java index 364a582c..ab94a92d 100644 --- a/src/main/java/cod/semantic/ConstructorResolver.java +++ b/src/main/java/cod/semantic/ConstructorResolver.java @@ -162,6 +162,11 @@ public ObjectInstance resolveAndCreate(ConstructorCall call, ExecutionContext ct "Available types: " + getAvailableTypes(ctx) ); } + + if (type.isUnsafe && !isUnsafeExecutionContext(ctx) && !ExecutionContext.isUnsafeCommitAllowed()) { + throw new ProgramError( + "Unsafe class '" + type.name + "' cannot be constructed in a safe context. Use safe(" + type.name + "(...))."); + } validateInheritanceHierarchy(type, ctx); @@ -359,6 +364,7 @@ private Map matchConstructorArguments(Constructor constructor, if (!typeSystem.validateType(param.type, argValue)) { return null; } + argValue = typeSystem.normalizeForDeclaredType(param.type, argValue); argValues.put(param.name, argValue); } else if (!param.hasDefaultValue) { return null; @@ -381,6 +387,7 @@ private Map matchConstructorArguments(Constructor constructor, if (!typeSystem.validateType(param.type, argValue)) { return null; } + argValue = typeSystem.normalizeForDeclaredType(param.type, argValue); argValues.put(param.name, argValue); } else if (!param.hasDefaultValue) { return null; @@ -430,6 +437,7 @@ private ConstructorMatch tryMatchNamedConstructor(Constructor constructor, if (!typeSystem.validateType(param.type, argValue)) { return null; } + argValue = typeSystem.normalizeForDeclaredType(param.type, argValue); argumentValues.put(param.name, argValue); @@ -439,6 +447,9 @@ private ConstructorMatch tryMatchNamedConstructor(Constructor constructor, } } else if (param.hasDefaultValue) { Object defaultValue = evaluateDefaultValue(evaluator, param, ctx); + if (param.type != null) { + defaultValue = typeSystem.normalizeForDeclaredType(param.type, defaultValue); + } argumentValues.put(param.name, defaultValue); conversionScore++; } else { @@ -483,6 +494,9 @@ private ConstructorMatch tryMatchPositionalConstructor(Constructor constructor, return null; } Object defaultValue = evaluateDefaultValue(evaluator, param, ctx); + if (param.type != null) { + defaultValue = typeSystem.normalizeForDeclaredType(param.type, defaultValue); + } argumentValues.put(param.name, defaultValue); conversionScore++; } else { @@ -495,6 +509,7 @@ private ConstructorMatch tryMatchPositionalConstructor(Constructor constructor, if (!typeSystem.validateType(param.type, argValue)) { return null; } + argValue = typeSystem.normalizeForDeclaredType(param.type, argValue); argumentValues.put(param.name, argValue); @@ -506,6 +521,9 @@ private ConstructorMatch tryMatchPositionalConstructor(Constructor constructor, argIndex++; } else if (param.hasDefaultValue) { Object defaultValue = evaluateDefaultValue(evaluator, param, ctx); + if (param.type != null) { + defaultValue = typeSystem.normalizeForDeclaredType(param.type, defaultValue); + } argumentValues.put(param.name, defaultValue); conversionScore++; } else { @@ -550,6 +568,8 @@ private Object evaluateArgument(Evaluator evaluator, Expr argExpr, Param param, } return null; } + + argValue = typeSystem.normalizeForDeclaredType(param.type, argValue); if (param.type.contains("|")) { String activeType = typeSystem.getConcreteType(typeSystem.unwrap(argValue)); @@ -581,6 +601,23 @@ private Object evaluateDefaultValue(Evaluator evaluator, Param param, ExecutionC throw new InternalError("Default value evaluation failed for parameter: " + param.name, e); } } + + private boolean isUnsafeExecutionContext(ExecutionContext ctx) { + if (ctx == null) return false; + if (ctx.currentClass != null && ctx.currentClass.isUnsafe) { + return true; + } + if (ctx.currentMethodName == null || ctx.currentMethodName.isEmpty()) { + return false; + } + Type searchType = ctx.currentClass; + if (searchType == null && ctx.objectInstance != null) { + searchType = ctx.objectInstance.type; + } + if (searchType == null) return false; + Method currentMethod = findMethodInHierarchy(searchType, ctx.currentMethodName, ctx); + return currentMethod != null && currentMethod.isUnsafe; + } private boolean isUnderscore(Expr expr) { return expr instanceof Identifier && "_".equals(((Identifier) expr).name); @@ -703,6 +740,8 @@ private ObjectInstance createInstance(Type type, } ExecutionContext constrCtx = new ExecutionContext(obj, new HashMap(), null, null, ctx.getTypeHandler()); + constrCtx.currentClass = type; + constrCtx.setUnsafeExecutionContext(type.isUnsafe); for (Map.Entry entry : match.argumentValues.entrySet()) { constrCtx.setVariable(entry.getKey(), entry.getValue()); @@ -857,13 +896,36 @@ private void initializeFields(Type type, ObjectInstance obj, ExecutionContext ct obj.fields.put(field.name, defaultValue); } else { String fieldType = field.type; - if (fieldType.contains(INT.toString())) { + if (fieldType != null && typeSystem.isSizedArrayType(fieldType)) { + int length = typeSystem.getSizedArrayLength(fieldType); + int sizedLength = Math.max(length, 0); + String elementType = typeSystem.getSizedArrayElementType(fieldType); + List initialized = new ArrayList(sizedLength); + Object elementDefault = 0; + if (elementType != null && typeSystem.isUnsafeNumericType(elementType)) { + elementDefault = typeSystem.convertType(0, elementType); + } else if (elementType != null && elementType.contains(FLOAT.toString())) { + elementDefault = 0.0; + } else if (elementType != null && elementType.contains(TEXT.toString())) { + elementDefault = ""; + } else if (elementType != null && elementType.contains(BOOL.toString())) { + elementDefault = false; + } + for (int i = 0; i < sizedLength; i++) { + initialized.add(elementDefault); + } + obj.fields.put(field.name, initialized); + } else if (fieldType != null && typeSystem.isPointerType(fieldType)) { + obj.fields.put(field.name, null); + } else if (fieldType != null && typeSystem.isUnsafeNumericType(fieldType)) { + obj.fields.put(field.name, typeSystem.convertType(0, fieldType)); + } else if (fieldType != null && fieldType.contains(INT.toString())) { obj.fields.put(field.name, 0); - } else if (fieldType.contains(FLOAT.toString())) { + } else if (fieldType != null && fieldType.contains(FLOAT.toString())) { obj.fields.put(field.name, 0.0); - } else if (fieldType.contains(TEXT.toString())) { + } else if (fieldType != null && fieldType.contains(TEXT.toString())) { obj.fields.put(field.name, ""); - } else if (fieldType.contains(BOOL.toString())) { + } else if (fieldType != null && fieldType.contains(BOOL.toString())) { obj.fields.put(field.name, false); } else { obj.fields.put(field.name, null);