diff --git a/.gitignore b/.gitignore index 242fabd..e936de5 100644 --- a/.gitignore +++ b/.gitignore @@ -64,3 +64,10 @@ fastlane/report.xml fastlane/Preview.html fastlane/screenshots/**/*.png fastlane/test_output + + +# Models +Models/ +*.mlpackage +AnySense.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/ +AnySense.xcworkspace/xcshareddata/swiftpm/ diff --git a/AnySense.xcodeproj/project.pbxproj b/AnySense.xcodeproj/project.pbxproj index f030751..c05d8c9 100644 --- a/AnySense.xcodeproj/project.pbxproj +++ b/AnySense.xcodeproj/project.pbxproj @@ -7,12 +7,20 @@ objects = { /* Begin PBXBuildFile section */ + 016073392EF5BE610061563B /* InferenceView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 016073382EF5BE610061563B /* InferenceView.swift */; }; + 01682BC42DF0D88100CDA03C /* ModelInfo.swift in Sources */ = {isa = PBXBuildFile; fileRef = 01F5B7C92DEA6A1000061B16 /* ModelInfo.swift */; }; + 01682BC52DF0D88C00CDA03C /* MLModel+Extensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = 01F5B7CB2DEA6A2000061B16 /* MLModel+Extensions.swift */; }; + 01682BC92DF4D7CB00CDA03C /* ARVisualizationManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = 01682BC82DF4D7CB00CDA03C /* ARVisualizationManager.swift */; }; + 01B77E6D2ED94950007823B6 /* demo.mlpackage in Resources */ = {isa = PBXBuildFile; fileRef = 01B77E6C2ED94950007823B6 /* demo.mlpackage */; }; + 01D2FDD82E64C33600F0BA98 /* ActionTransformUtils.swift in Sources */ = {isa = PBXBuildFile; fileRef = 01D2FDD52E64C33600F0BA98 /* ActionTransformUtils.swift */; }; + 01F5B7C42DEA697200061B16 /* MLInferenceManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = 01F5B7C32DEA697200061B16 /* MLInferenceManager.swift */; }; + 01F5B7C62DEA698D00061B16 /* MLInferenceResultsView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 01F5B7C52DEA698D00061B16 /* MLInferenceResultsView.swift */; }; + 01F5B7C82DEA6A0000061B16 /* ModelManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = 01F5B7C72DEA6A0000061B16 /* ModelManager.swift */; }; B9104F622BFDE63C000D4DDD /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = B9104F612BFDE63C000D4DDD /* Assets.xcassets */; }; B9944A342CFE983300232FBB /* MainPage.swift in Sources */ = {isa = PBXBuildFile; fileRef = B9944A332CFE982E00232FBB /* MainPage.swift */; }; B997CCD32D4E081300F62B49 /* dataStorage.swift in Sources */ = {isa = PBXBuildFile; fileRef = B997CCD22D4E081300F62B49 /* dataStorage.swift */; }; B997CCD52D4E082900F62B49 /* AnySenseApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = B997CCD42D4E082900F62B49 /* AnySenseApp.swift */; }; B9E4B22D2D62A9BA0032E877 /* BluetoothManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = B9E4B2292D62A9BA0032E877 /* BluetoothManager.swift */; }; - B9E4B22E2D62A9BA0032E877 /* WebRTCManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = B9E4B22B2D62A9BA0032E877 /* WebRTCManager.swift */; }; B9E4B22F2D62A9BA0032E877 /* ARViewContainer.swift in Sources */ = {isa = PBXBuildFile; fileRef = B9E4B2282D62A9BA0032E877 /* ARViewContainer.swift */; }; B9E4B2302D62A9BA0032E877 /* USBManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = B9E4B22A2D62A9BA0032E877 /* USBManager.swift */; }; B9E4B2362D62A9C40032E877 /* peripheralView.swift in Sources */ = {isa = PBXBuildFile; fileRef = B9E4B2332D62A9C40032E877 /* peripheralView.swift */; }; @@ -42,6 +50,15 @@ /* End PBXContainerItemProxy section */ /* Begin PBXFileReference section */ + 016073382EF5BE610061563B /* InferenceView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InferenceView.swift; sourceTree = ""; }; + 01682BC82DF4D7CB00CDA03C /* ARVisualizationManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ARVisualizationManager.swift; sourceTree = ""; }; + 01B77E6C2ED94950007823B6 /* demo.mlpackage */ = {isa = PBXFileReference; lastKnownFileType = folder.mlpackage; path = demo.mlpackage; sourceTree = ""; }; + 01D2FDD52E64C33600F0BA98 /* ActionTransformUtils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ActionTransformUtils.swift; sourceTree = ""; }; + 01F5B7C32DEA697200061B16 /* MLInferenceManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLInferenceManager.swift; sourceTree = ""; }; + 01F5B7C52DEA698D00061B16 /* MLInferenceResultsView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLInferenceResultsView.swift; sourceTree = ""; }; + 01F5B7C72DEA6A0000061B16 /* ModelManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelManager.swift; sourceTree = ""; }; + 01F5B7C92DEA6A1000061B16 /* ModelInfo.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelInfo.swift; sourceTree = ""; }; + 01F5B7CB2DEA6A2000061B16 /* MLModel+Extensions.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "MLModel+Extensions.swift"; sourceTree = ""; }; B9104F5A2BFDE63A000D4DDD /* AnySense.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = AnySense.app; sourceTree = BUILT_PRODUCTS_DIR; }; B9104F612BFDE63C000D4DDD /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; B9104F6A2BFDE63C000D4DDD /* AnySenseTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = AnySenseTests.xctest; sourceTree = BUILT_PRODUCTS_DIR; }; @@ -53,7 +70,6 @@ B9E4B2282D62A9BA0032E877 /* ARViewContainer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ARViewContainer.swift; sourceTree = ""; }; B9E4B2292D62A9BA0032E877 /* BluetoothManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BluetoothManager.swift; sourceTree = ""; }; B9E4B22A2D62A9BA0032E877 /* USBManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = USBManager.swift; sourceTree = ""; }; - B9E4B22B2D62A9BA0032E877 /* WebRTCManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WebRTCManager.swift; sourceTree = ""; }; B9E4B2312D62A9C40032E877 /* accountView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = accountView.swift; sourceTree = ""; }; B9E4B2322D62A9C40032E877 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = ""; }; B9E4B2332D62A9C40032E877 /* peripheralView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = peripheralView.swift; sourceTree = ""; }; @@ -110,6 +126,9 @@ B9104F5C2BFDE63A000D4DDD /* AnySense */ = { isa = PBXGroup; children = ( + 01F5B7C92DEA6A1000061B16 /* ModelInfo.swift */, + 01B77E6C2ED94950007823B6 /* demo.mlpackage */, + 01F5B7CB2DEA6A2000061B16 /* MLModel+Extensions.swift */, B9944A332CFE982E00232FBB /* MainPage.swift */, B997CCD22D4E081300F62B49 /* dataStorage.swift */, B9104F612BFDE63C000D4DDD /* Assets.xcassets */, @@ -126,9 +145,12 @@ isa = PBXGroup; children = ( B9E4B2282D62A9BA0032E877 /* ARViewContainer.swift */, + 01F5B7C32DEA697200061B16 /* MLInferenceManager.swift */, + 01D2FDD52E64C33600F0BA98 /* ActionTransformUtils.swift */, + 01682BC82DF4D7CB00CDA03C /* ARVisualizationManager.swift */, + 01F5B7C72DEA6A0000061B16 /* ModelManager.swift */, B9E4B2292D62A9BA0032E877 /* BluetoothManager.swift */, B9E4B22A2D62A9BA0032E877 /* USBManager.swift */, - B9E4B22B2D62A9BA0032E877 /* WebRTCManager.swift */, ); path = Managers; sourceTree = ""; @@ -136,8 +158,10 @@ B9E4B2352D62A9C40032E877 /* Views */ = { isa = PBXGroup; children = ( + 01F5B7C52DEA698D00061B16 /* MLInferenceResultsView.swift */, B9E4B2312D62A9C40032E877 /* accountView.swift */, B9E4B2322D62A9C40032E877 /* ContentView.swift */, + 016073382EF5BE610061563B /* InferenceView.swift */, B9E4B2332D62A9C40032E877 /* peripheralView.swift */, B9E4B2342D62A9C40032E877 /* readView.swift */, ); @@ -266,6 +290,7 @@ isa = PBXResourcesBuildPhase; buildActionMask = 2147483647; files = ( + 01B77E6D2ED94950007823B6 /* demo.mlpackage in Resources */, B9104F622BFDE63C000D4DDD /* Assets.xcassets in Resources */, B9E4D2B42D6965B90044F2D4 /* AnySenseLaunchScreen.storyboard in Resources */, ); @@ -319,16 +344,23 @@ isa = PBXSourcesBuildPhase; buildActionMask = 2147483647; files = ( + 01682BC42DF0D88100CDA03C /* ModelInfo.swift in Sources */, B9E4B2362D62A9C40032E877 /* peripheralView.swift in Sources */, + 01682BC92DF4D7CB00CDA03C /* ARVisualizationManager.swift in Sources */, B9E4B2372D62A9C40032E877 /* accountView.swift in Sources */, B9E4B2382D62A9C40032E877 /* readView.swift in Sources */, B9E4B2392D62A9C40032E877 /* ContentView.swift in Sources */, B997CCD52D4E082900F62B49 /* AnySenseApp.swift in Sources */, B997CCD32D4E081300F62B49 /* dataStorage.swift in Sources */, + 01682BC52DF0D88C00CDA03C /* MLModel+Extensions.swift in Sources */, + 01F5B7C42DEA697200061B16 /* MLInferenceManager.swift in Sources */, + 01D2FDD82E64C33600F0BA98 /* ActionTransformUtils.swift in Sources */, + 01F5B7C82DEA6A0000061B16 /* ModelManager.swift in Sources */, + 016073392EF5BE610061563B /* InferenceView.swift in Sources */, B9E4B22D2D62A9BA0032E877 /* BluetoothManager.swift in Sources */, - B9E4B22E2D62A9BA0032E877 /* WebRTCManager.swift in Sources */, B9E4B22F2D62A9BA0032E877 /* ARViewContainer.swift in Sources */, B9E4B2302D62A9BA0032E877 /* USBManager.swift in Sources */, + 01F5B7C62DEA698D00061B16 /* MLInferenceResultsView.swift in Sources */, B9944A342CFE983300232FBB /* MainPage.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; @@ -492,7 +524,7 @@ ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; CODE_SIGN_IDENTITY = "Apple Development"; CODE_SIGN_STYLE = Automatic; - CURRENT_PROJECT_VERSION = 4; + CURRENT_PROJECT_VERSION = 21; DEVELOPMENT_ASSET_PATHS = ""; DEVELOPMENT_TEAM = 88NB9U5CK6; ENABLE_PREVIEWS = YES; @@ -516,7 +548,7 @@ "$(inherited)", "@executable_path/Frameworks", ); - MARKETING_VERSION = 1.0; + MARKETING_VERSION = 1.3; PRODUCT_BUNDLE_IDENTIFIER = GRAIL.AnySense; PRODUCT_NAME = "$(TARGET_NAME)"; PROVISIONING_PROFILE_SPECIFIER = ""; @@ -538,7 +570,7 @@ ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; CODE_SIGN_IDENTITY = "Apple Development"; CODE_SIGN_STYLE = Automatic; - CURRENT_PROJECT_VERSION = 4; + CURRENT_PROJECT_VERSION = 21; DEVELOPMENT_ASSET_PATHS = ""; DEVELOPMENT_TEAM = 88NB9U5CK6; ENABLE_PREVIEWS = YES; @@ -561,7 +593,7 @@ "$(inherited)", "@executable_path/Frameworks", ); - MARKETING_VERSION = 1.0; + MARKETING_VERSION = 1.3; PRODUCT_BUNDLE_IDENTIFIER = GRAIL.AnySense; PRODUCT_NAME = "$(TARGET_NAME)"; PROVISIONING_PROFILE_SPECIFIER = ""; @@ -581,7 +613,7 @@ BUNDLE_LOADER = "$(TEST_HOST)"; CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; - DEVELOPMENT_TEAM = 9U23FMDF45; + DEVELOPMENT_TEAM = 57DF45Y5K2; ENABLE_USER_SCRIPT_SANDBOXING = NO; GENERATE_INFOPLIST_FILE = YES; IPHONEOS_DEPLOYMENT_TARGET = 17.4; @@ -601,7 +633,7 @@ BUNDLE_LOADER = "$(TEST_HOST)"; CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; - DEVELOPMENT_TEAM = 9U23FMDF45; + DEVELOPMENT_TEAM = 57DF45Y5K2; ENABLE_USER_SCRIPT_SANDBOXING = NO; GENERATE_INFOPLIST_FILE = YES; IPHONEOS_DEPLOYMENT_TARGET = 17.4; @@ -620,7 +652,7 @@ buildSettings = { CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; - DEVELOPMENT_TEAM = 9U23FMDF45; + DEVELOPMENT_TEAM = 57DF45Y5K2; ENABLE_USER_SCRIPT_SANDBOXING = NO; GENERATE_INFOPLIST_FILE = YES; MARKETING_VERSION = 1.0; @@ -638,7 +670,7 @@ buildSettings = { CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; - DEVELOPMENT_TEAM = 9U23FMDF45; + DEVELOPMENT_TEAM = 57DF45Y5K2; ENABLE_USER_SCRIPT_SANDBOXING = NO; GENERATE_INFOPLIST_FILE = YES; MARKETING_VERSION = 1.0; diff --git a/AnySense.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/AnySense.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist new file mode 100644 index 0000000..18d9810 --- /dev/null +++ b/AnySense.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist @@ -0,0 +1,8 @@ + + + + + IDEDidComputeMac32BitWarning + + + diff --git a/AnySense/.gitignore b/AnySense/.gitignore new file mode 100644 index 0000000..dcf5f86 --- /dev/null +++ b/AnySense/.gitignore @@ -0,0 +1 @@ +Models/ \ No newline at end of file diff --git a/AnySense/AnySenseApp.swift b/AnySense/AnySenseApp.swift index c9b21d8..de35206 100644 --- a/AnySense/AnySenseApp.swift +++ b/AnySense/AnySenseApp.swift @@ -11,12 +11,10 @@ import BackgroundTasks @main struct AnySenseApp: App { @StateObject var appStatus = AppInformation() - @StateObject var bluetoothManager = BluetoothManager() var body: some Scene { WindowGroup { ContentView() .environmentObject(appStatus) - .environmentObject(bluetoothManager) } } } diff --git a/AnySense/Assets.xcassets/gripper_closed.imageset/Contents.json b/AnySense/Assets.xcassets/gripper_closed.imageset/Contents.json new file mode 100644 index 0000000..f64768c --- /dev/null +++ b/AnySense/Assets.xcassets/gripper_closed.imageset/Contents.json @@ -0,0 +1,21 @@ +{ + "images" : [ + { + "filename" : "gripper_closed.png", + "idiom" : "universal", + "scale" : "1x" + }, + { + "idiom" : "universal", + "scale" : "2x" + }, + { + "idiom" : "universal", + "scale" : "3x" + } + ], + "info" : { + "author" : "xcode", + "version" : 1 + } +} \ No newline at end of file diff --git a/AnySense/Assets.xcassets/gripper_closed.imageset/gripper_closed.png b/AnySense/Assets.xcassets/gripper_closed.imageset/gripper_closed.png new file mode 100644 index 0000000..46a1b6e Binary files /dev/null and b/AnySense/Assets.xcassets/gripper_closed.imageset/gripper_closed.png differ diff --git a/AnySense/Assets.xcassets/gripper_overlay.imageset/Contents.json b/AnySense/Assets.xcassets/gripper_overlay.imageset/Contents.json new file mode 100644 index 0000000..48cb9a1 --- /dev/null +++ b/AnySense/Assets.xcassets/gripper_overlay.imageset/Contents.json @@ -0,0 +1,21 @@ +{ + "images" : [ + { + "filename" : "gripper_overlay.png", + "idiom" : "universal", + "scale" : "1x" + }, + { + "idiom" : "universal", + "scale" : "2x" + }, + { + "idiom" : "universal", + "scale" : "3x" + } + ], + "info" : { + "author" : "xcode", + "version" : 1 + } +} \ No newline at end of file diff --git a/AnySense/Assets.xcassets/gripper_overlay.imageset/gripper_overlay.png b/AnySense/Assets.xcassets/gripper_overlay.imageset/gripper_overlay.png new file mode 100644 index 0000000..98f57f3 Binary files /dev/null and b/AnySense/Assets.xcassets/gripper_overlay.imageset/gripper_overlay.png differ diff --git a/AnySense/MLModel+Extensions.swift b/AnySense/MLModel+Extensions.swift new file mode 100644 index 0000000..401227c --- /dev/null +++ b/AnySense/MLModel+Extensions.swift @@ -0,0 +1,324 @@ +import Foundation +import CoreML + +// MARK: - MLModel Extensions +extension MLModel { + + /// Compile a .mlmodel file to .mlmodelc with progress tracking + static func compileModel(at sourceURL: URL, + progressHandler: @escaping (Double) -> Void) async throws -> URL { + + return try await withCheckedThrowingContinuation { continuation in + DispatchQueue.global(qos: .userInitiated).async { + do { + // Start compilation + progressHandler(0.1) + + let compiledURL = try MLModel.compileModel(at: sourceURL) + + // Simulate progress updates during compilation + let progressSteps = [0.3, 0.5, 0.7, 0.9] + for progress in progressSteps { + progressHandler(progress) + Thread.sleep(forTimeInterval: 0.5) // Small delay for visual feedback + } + + progressHandler(1.0) + continuation.resume(returning: compiledURL) + + } catch { + continuation.resume(throwing: error) + } + } + } + } + + /// Validate model compatibility and extract metadata + static func validateModel(at url: URL) throws -> ModelMetadata { + let model = try MLModel(contentsOf: url) + return try ModelMetadata(from: model) + } + + /// Get model file size + static func getModelSize(at url: URL) -> Int64 { + do { + let attributes = try FileManager.default.attributesOfItem(atPath: url.path) + return attributes[.size] as? Int64 ?? 0 + } catch { + return 0 + } + } +} + +// MARK: - Model Metadata +struct ModelMetadata { + let inputDescription: MLFeatureDescription? + let outputDescription: MLFeatureDescription? + let modelDescription: String + let isCompatible: Bool + let requiredInputShape: [Int]? + let expectedOutputCount: Int? + let outputFeatureNames: [String] + let primaryOutputName: String? + private let allInputsByName: [String: MLFeatureDescription] + + init(from model: MLModel) throws { + let modelDescription = model.modelDescription + + // Cache all inputs (local first to avoid using self during init) + let inputsByName = modelDescription.inputDescriptionsByName + + // Get first input description (legacy use) + self.inputDescription = inputsByName.values.first + + // Get output description + self.outputDescription = modelDescription.outputDescriptionsByName.values.first + + self.modelDescription = modelDescription.metadata[.description] as? String ?? "No description" + + // Extract input shape if available (try to infer from first image or 4D array) + func localFirstImageLikeInput(_ inputs: [String: MLFeatureDescription]) -> (String, MLFeatureDescription)? { + if let d = inputs["camera_image"] { return ("camera_image", d) } + for (key, desc) in inputs { + switch desc.type { + case .image: return (key, desc) + case .multiArray: + if let shape = desc.multiArrayConstraint?.shape, shape.count >= 4 { return (key, desc) } + default: continue + } + } + return nil + } + if let (_, desc) = localFirstImageLikeInput(inputsByName) { + switch desc.type { + case .image: + if let c = desc.imageConstraint { + self.requiredInputShape = [Int(c.pixelsHigh), Int(c.pixelsWide), 3] + } else { + self.requiredInputShape = [224, 224, 3] + } + case .multiArray: + if let shape = desc.multiArrayConstraint?.shape, shape.count >= 4 { + let h = shape[shape.count-2].intValue + let w = shape[shape.count-1].intValue + self.requiredInputShape = [h, w, 3] + } else { + self.requiredInputShape = [224, 224, 3] + } + default: + self.requiredInputShape = [224, 224, 3] + } + } else { + self.requiredInputShape = nil + } + + // Expected output count (7 joint actions for our use case) + self.expectedOutputCount = modelDescription.outputDescriptionsByName.count + + // Extract output feature names for dynamic handling + self.outputFeatureNames = Array(modelDescription.outputDescriptionsByName.keys).sorted() + self.primaryOutputName = self.outputFeatureNames.first + + // Check compatibility inline - avoid calling helpers before initialization complete + self.isCompatible = !modelDescription.inputDescriptionsByName.isEmpty && + !modelDescription.outputDescriptionsByName.isEmpty && + modelDescription.inputDescriptionsByName.values.contains { desc in + switch desc.type { + case .image, .multiArray: return true + default: return false + } + } && + modelDescription.outputDescriptionsByName.values.contains { desc in + switch desc.type { + case .multiArray: return true + default: return false + } + } + + // Now that init values are ready, set cached inputs map + self.allInputsByName = inputsByName + } + + // MARK: - Dynamic helpers used by MLInferenceManager + enum ModelType { + case pointConditioned + + var displayName: String { + return "Point-Conditioned" + } + } + + var modelType: ModelType { + return .pointConditioned + } + + var temporalFrames: Int { + // Detect temporal dimension in image input shape + // [1,3,3,224,224] → 3 frames, [1,3,224,224] → 1 frame + guard let (_, desc) = firstImageLikeInput() else { return 1 } + + if desc.type == .multiArray, let shape = desc.multiArrayConstraint?.shape { + let dims = shape.map { $0.intValue } + // Check if this is a temporal model: [B, T, C, H, W] where T > 1 + if dims.count == 5 && dims[1] > 1 && dims[2] == 3 { + return dims[1] // Return temporal dimension + } + } + return 1 // Default to single frame + } + + var isTemporalModel: Bool { + return temporalFrames > 1 + } + + var requiresGoalPoint: Bool { + // Heuristic: presence of a second non-image input named "goal_point" or a small (1x3) array input + if allInputsByName.keys.contains("goal_point") { return true } + for (name, desc) in allInputsByName { + if name == "camera_image" { continue } + switch desc.type { + case .multiArray: + if let shape = desc.multiArrayConstraint?.shape { + // Accept 2D [1,3] or [3] or small shapes as goal vector + let dims = shape.map { $0.intValue } + if dims == [1,3] || dims == [3] || dims.suffix(1).first == 3 && dims.reduce(1,*) <= 16 { + return true + } + } + default: break + } + } + return false + } + + var imageInputSize: CGSize? { + guard let (_, desc) = firstImageLikeInput() else { return nil } + switch desc.type { + case .image: + if let c = desc.imageConstraint { return CGSize(width: Int(c.pixelsWide), height: Int(c.pixelsHigh)) } + case .multiArray: + if let shape = desc.multiArrayConstraint?.shape, shape.count >= 4 { + let h = shape[shape.count-2].intValue + let w = shape[shape.count-1].intValue + return CGSize(width: w, height: h) + } + default: break + } + return nil + } + + func getImageInputName() -> String? { + if allInputsByName.keys.contains("camera_image") { return "camera_image" } + if let (name, _) = firstImageLikeInput() { return name } + return nil + } + + func getGoalInputName() -> String? { + if allInputsByName.keys.contains("goal_point") { return "goal_point" } + for (name, desc) in allInputsByName where name != "camera_image" { + if desc.type == .multiArray, let shape = desc.multiArrayConstraint?.shape { + let dims = shape.map { $0.intValue } + if dims == [1,3] || dims == [3] || dims.reduce(1,*) <= 16 { return name } + } + } + return nil + } + + // Find the first image-like input (image or 4D array) + private func firstImageLikeInput() -> (String, MLFeatureDescription)? { + if let d = allInputsByName["camera_image"] { return ("camera_image", d) } + for (name, desc) in allInputsByName { + switch desc.type { + case .image: return (name, desc) + case .multiArray: + if let shape = desc.multiArrayConstraint?.shape, shape.count >= 4 { return (name, desc) } + default: continue + } + } + return nil + } +} + +// MARK: - Model File Utilities +struct ModelFileUtilities { + + /// Get the Application Support directory URL (better than Documents for internal app files) + static var applicationSupportDirectory: URL { + FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first! + } + + /// Get the models directory URL + static var modelsDirectory: URL { + let modelsDir = applicationSupportDirectory.appendingPathComponent("Models") + + // Create directory if it doesn't exist + if !FileManager.default.fileExists(atPath: modelsDir.path) { + try? FileManager.default.createDirectory(at: modelsDir, + withIntermediateDirectories: true) + } + + return modelsDir + } + + /// Get the uploaded models directory + static var uploadedModelsDirectory: URL { + let uploadedDir = modelsDirectory.appendingPathComponent("Uploaded") + + if !FileManager.default.fileExists(atPath: uploadedDir.path) { + try? FileManager.default.createDirectory(at: uploadedDir, + withIntermediateDirectories: true) + } + + return uploadedDir + } + + /// Copy uploaded model to app directory + static func copyUploadedModel(from sourceURL: URL, withName name: String) throws -> URL { + let destinationURL = uploadedModelsDirectory.appendingPathComponent("\(name).mlmodel") + + // Remove existing file if it exists + if FileManager.default.fileExists(atPath: destinationURL.path) { + try FileManager.default.removeItem(at: destinationURL) + } + + try FileManager.default.copyItem(at: sourceURL, to: destinationURL) + return destinationURL + } + + /// Replace compiled model using the recommended approach + static func replaceCompiledModel(compiledURL: URL, withName name: String) throws -> URL { + let permanentCompiledURL = uploadedModelsDirectory + .appendingPathComponent("\(name).mlmodel") + .deletingPathExtension() + .appendingPathExtension("mlmodelc") + + // Use replaceItemAt as recommended in the guide + try? FileManager.default.replaceItem(at: permanentCompiledURL, + withItemAt: compiledURL, + backupItemName: nil, + options: [], + resultingItemURL: nil) + + return permanentCompiledURL + } + + /// Delete model files + static func deleteModel(fileName: String, isUploaded: Bool) throws { + if isUploaded { + // Delete both .mlmodel and .mlmodelc if they exist from uploaded directory + let mlmodelURL = uploadedModelsDirectory.appendingPathComponent(fileName) + let mlmodelcURL = uploadedModelsDirectory + .appendingPathComponent(fileName) + .deletingPathExtension() + .appendingPathExtension("mlmodelc") + + if FileManager.default.fileExists(atPath: mlmodelURL.path) { + try FileManager.default.removeItem(at: mlmodelURL) + } + + if FileManager.default.fileExists(atPath: mlmodelcURL.path) { + try FileManager.default.removeItem(at: mlmodelcURL) + } + } + } +} diff --git a/AnySense/MainPage.swift b/AnySense/MainPage.swift index 6ee8cee..36b0d57 100644 --- a/AnySense/MainPage.swift +++ b/AnySense/MainPage.swift @@ -9,45 +9,150 @@ import SwiftUI struct MainPage: View { @EnvironmentObject private var appStatus : AppInformation - @EnvironmentObject private var bluetoothManager: BluetoothManager @Environment(\.scenePhase) private var phase - let arViewModel: ARViewModel + @ObservedObject var arViewModel: ARViewModel + let modelManager: ModelManager // Start the default page be the read page - @State private var selection = 1 + @State private var selection = 2 + + // Track if AR tabs are active for showing/hiding the shared AR view + private var isARTabActive: Bool { + selection == 1 || selection == 2 + } var body: some View { - TabView(selection: $selection){ - Group{ - PeripheralView() - .tabItem { - Label("ble-device", systemImage: "iphone.gen1.radiowaves.left.and.right") + ZStack { + // MARK: - Background layer (AR for tabs 1,2 or solid color for others) + if isARTabActive { + SharedARViewContainer(arViewModel: arViewModel) + .ignoresSafeArea() + } else { + Color.customizedBackground + .ignoresSafeArea() + } + + // MARK: - Content layer (overlays only, no backgrounds) + VStack(spacing: 0) { + // Main content area + Group { + switch selection { + case 0: + PeripheralView(arViewModel: arViewModel, bluetoothManager: arViewModel.getBLEManagerInstance()) + case 1: + ReadViewOverlay(arViewModel: arViewModel) + case 2: + InferenceViewOverlay(arViewModel: arViewModel) + case 3: + SettingsView(arViewModel: arViewModel, modelManager: modelManager) + default: + ReadViewOverlay(arViewModel: arViewModel) + } } - .tag(0) + .frame(maxWidth: .infinity, maxHeight: .infinity) - ReadView(arViewModel: arViewModel) - .tabItem { - Label("read", systemImage: "dot.scope") + // Custom tab bar at bottom + HStack { + TabBarButton(icon: "iphone.gen1.radiowaves.left.and.right", label: "ble-device", tag: 0, selection: $selection) + TabBarButton(icon: "record.circle", label: "record", tag: 1, selection: $selection) + TabBarButton(icon: "brain.head.profile", label: "inference", tag: 2, selection: $selection) + TabBarButton(icon: "gear", label: "settings", tag: 3, selection: $selection) } - .tag(1) - - - SettingsView() - .tabItem { - Label("settings", systemImage: "gear") + .padding(.vertical, 8) + .background(Color.tabBackground) + } + } + .onAppear { + // Start AR session on initial load + if isARTabActive { + arViewModel.startARSessionIfNeeded() + print("App launched - starting AR session for default tab \(selection)") + } - } - .tag(2) + // Inference is now tab-scoped: enable only on Inference tab (or when USB streaming is active) + syncInferenceForSelectedTab(selection) + } + .onChange(of: selection) { newTab in + let isARTab = (newTab == 1 || newTab == 2) + + if isARTab { + // Switching to AR tab - resume session without resetting tracking + arViewModel.resumeARSession() + print("Switched to AR tab \(newTab) - resuming AR session") + } else { + // Switching away from AR tabs - pause session + arViewModel.pauseARSession() + print("Switched to non-AR tab \(newTab) - pausing AR session") + } + + // Inference is now tab-scoped: enable only on Inference tab (or when USB streaming is active) + syncInferenceForSelectedTab(newTab) + } + .onChange(of: arViewModel.isUSBStreamingActive) { _, _ in + // If USB streaming starts/stops, resync inference enablement without requiring a tab change + syncInferenceForSelectedTab(selection) + } + .onChange(of: phase) { newPhase in + switch newPhase { + case .background: + arViewModel.stopAllActivities() + arViewModel.pauseARSession() + print("App backgrounded - all activities stopped") + case .active: + if isARTabActive { + arViewModel.resumeARSession() + print("App active - resuming AR session for tab \(selection)") + } + case .inactive: + print("App inactive") + @unknown default: + break } - .toolbarBackground(.tabBackground, for: .tabBar) - .toolbarBackground(.visible, for: .tabBar) } - .accentColor(.accentColor) } + @MainActor + private func syncInferenceForSelectedTab(_ tab: Int) { + // Inference is active only in the Inference tab, except when USB streaming explicitly needs it. + if tab == 2 { + arViewModel.mlManager?.enableInference() + return + } + + // Leaving inference tab: stop playback and disable inference unless USB streaming is active. + arViewModel.stopInferencePlayback(reset: true) + if arViewModel.isUSBStreamingActive { + // USB streaming sends joint actions; ensure inference is enabled to avoid all-zero actions. + arViewModel.mlManager?.enableInference() + } else { + arViewModel.mlManager?.disableInference() + } + } +} + +// MARK: - Custom Tab Bar Button +struct TabBarButton: View { + let icon: String + let label: String + let tag: Int + @Binding var selection: Int + + var body: some View { + Button(action: { + selection = tag + }) { + VStack(spacing: 4) { + Image(systemName: icon) + .font(.system(size: 20)) + Text(label) + .font(.caption2) + } + .foregroundColor(selection == tag ? .accentColor : .gray) + .frame(maxWidth: .infinity) + } + } } #Preview { - MainPage(arViewModel: ARViewModel()) + MainPage(arViewModel: ARViewModel(), modelManager: ModelManager()) .environmentObject(AppInformation()) - .environmentObject(BluetoothManager()) } diff --git a/AnySense/Managers/ARViewContainer.swift b/AnySense/Managers/ARViewContainer.swift index 0effdc5..efc9711 100644 --- a/AnySense/Managers/ARViewContainer.swift +++ b/AnySense/Managers/ARViewContainer.swift @@ -15,19 +15,27 @@ import CoreMedia import CoreImage import UIKit import CoreImage.CIFilterBuiltins -//import WebRTC +import Combine +import Accelerate struct RecordingFiles { let rgbFileName: URL let depthFileName: URL let timestamp: String - let rgbImagesDirectory: URL - let depthImagesDirectory: URL + let rgbImagesDirectory: URL? + let depthImagesDirectory: URL? let poseFile: URL let generalDataDirectory: String let tactileFile: URL } +enum RecordingMode { + case none + case standardRecording + case mlInference + case usbStreaming +} + func createFile(fileURL: URL) throws { let success = FileManager.default.createFile(atPath: fileURL.path, contents: nil, attributes: nil) if !success { @@ -35,21 +43,78 @@ func createFile(fileURL: URL) throws { } } -struct ARViewContainer: UIViewRepresentable { - var session: ARSession - typealias UIViewType = ARView +// MARK: - Shared AR View Container (hosts the single ARView from ARViewModel) +struct SharedARViewContainer: UIViewRepresentable { + @ObservedObject var arViewModel: ARViewModel func makeUIView(context: Context) -> ARView { - // Initialize the ARView - let arView = ARView(frame: .zero, cameraMode: .ar, automaticallyConfigureSession: false) - arView.session = session - arView.environment.sceneUnderstanding.options = [] // No extra scene understanding - return arView + print("SharedARViewContainer: returning shared ARView") + return arViewModel.getOrCreateSharedARView() } + func updateUIView(_ uiView: ARView, context: Context) { - if uiView.session !== session { - uiView.session = session + // ARView is managed by ARViewModel, no updates needed here + } +} + +// MARK: - Tap Coordinator for Shared ARView +class TapCoordinator: NSObject { + weak var arViewModel: ARViewModel? + + init(arViewModel: ARViewModel) { + self.arViewModel = arViewModel + super.init() + } + + @objc func handleTap(_ recognizer: UITapGestureRecognizer) { + guard recognizer.state == .ended, let arView = recognizer.view as? ARView else { return } + let location = recognizer.location(in: arView) + + // Try LiDAR-backed mesh raycast first + if let world = meshBackedHit(in: arView, from: location) { + var t = matrix_identity_float4x4 + t.columns.3 = SIMD4(world.x, world.y, world.z, 1) + let goalAnchor = ARAnchor(name: "goal", transform: t) + arView.session.add(anchor: goalAnchor) + print("Using LiDAR mesh raycast for 3D point: \(world)") + NotificationCenter.default.post( + name: NSNotification.Name("ARViewTapForGoal"), + object: nil, + userInfo: ["worldPoint": world, "method": "meshRaycast", "location": location, "bounds": arView.bounds] + ) + return + } + + // Fallback: ARKit plane/estimated-surface raycast + if let hit = arView.raycast(from: location, allowing: .estimatedPlane, alignment: .any).first { + let t = hit.worldTransform + let world = simd_float3(t.columns.3.x, t.columns.3.y, t.columns.3.z) + let goalAnchor = ARAnchor(name: "goal", transform: t) + arView.session.add(anchor: goalAnchor) + print("Using plane/estimated raycast fallback for 3D point: \(world)") + NotificationCenter.default.post( + name: NSNotification.Name("ARViewTapForGoal"), + object: nil, + userInfo: ["worldPoint": world, "method": "raycast", "location": location, "bounds": arView.bounds] + ) + return + } + + // Final fallback: notify with screen info only + NotificationCenter.default.post( + name: NSNotification.Name("ARViewTapForGoal"), + object: nil, + userInfo: ["location": location, "bounds": arView.bounds] + ) + } + + private func meshBackedHit(in arView: ARView, from location: CGPoint) -> SIMD3? { + guard let ray = arView.ray(through: location) else { return nil } + let hits = arView.scene.raycast(origin: ray.origin, direction: ray.direction) + if let hit = hits.first(where: { $0.entity is HasSceneUnderstanding }) { + return hit.position } + return nil } } @@ -61,17 +126,47 @@ class DepthStatus: ObservableObject { isDepthAvailable = false showAlert = true } + + public func dismissAlert() { + showAlert = false + } } class ARViewModel: ObservableObject{ + var bluetoothManager: BluetoothManager? @Published var isOpen : Bool = false @Published var depthStatus = DepthStatus() + var demosCounter : Int = -1 var session = ARSession() var audioSession = AVCaptureSession() var audioCaptureDelegate: AudioCaptureDelegate? + // ML Inference Manager - now optional and initialized later + @Published var mlManager: MLInferenceManager? + + // AR Visualization Manager for 3D pose visualization + @Published var arVisualizationManager: ARVisualizationManager + @Published var goalTapModeEnabled: Bool = false + @Published var isUSBStreamingActive: Bool = false + + // MARK: - Shared ARView (single instance for entire app lifecycle) + private var sharedARView: ARView? + private var hasSetupSharedARView = false + + // MARK: - Centralized Recording State Management + @Published var isRecording: Bool = false + @Published var recordingMode: RecordingMode = .none + private var currentRecordingFiles: RecordingFiles? + + // MARK: - Inference Playback (no file saving) + @Published var isInferencePlaying: Bool = false + @Published var isInferenceEpisodeFinished: Bool = false + + + public var userFPS: Double? public var isColorMapOpened = false + public var ifAudioEnable = false private var usbManager = USBManager() private var orientation: UIInterfaceOrientation = .portrait @@ -89,16 +184,24 @@ class ARViewModel: ObservableObject{ private var combinedRGBTransform: CGAffineTransform? private var combinedDepthTransform: CGAffineTransform? - + private var rgbOutputPixelBufferUSB: CVPixelBuffer? private var depthOutputPixelBufferUSB: CVPixelBuffer? private var depthConfidenceOutputPixelBufferUSB: CVPixelBuffer? + + // MARK: - Accelerate Optimization Properties + private var rgbTransformBuffer: vImage_Buffer? + private var lastTransformImageSize: CGSize = .zero + // MARK: - Exposed helpers for MLInferenceManager + func getARSession() -> ARSession { + return session + } private var poseFileHandle: FileHandle? // Control the destination of rgb images directory and depth images directory - private var rgbDirect: URL = URL(fileURLWithPath: "") - private var depthDirect: URL = URL(fileURLWithPath: "") + private var rgbDirect: URL? = nil + private var depthDirect: URL? = nil // Control the destination of pose data text file private var poseURL: URL = URL(fileURLWithPath: "") private var generalURL: URL = URL(fileURLWithPath: "") @@ -120,7 +223,14 @@ class ARViewModel: ObservableObject{ private var depthConfAttributes: [String: Any] = [:] private var audioOutputSettings: [String: Any] = [:] + // Combine subscriptions for ML integration + private var cancellables = Set() + + @MainActor init() { + self.arVisualizationManager = ARVisualizationManager() + bluetoothManager = BluetoothManager() + self.rgbAttributes = [ kCVPixelBufferPixelFormatTypeKey as String: Int(kCVPixelFormatType_32ARGB), kCVPixelBufferWidthKey as String: Int(viewPortSize.width), @@ -144,14 +254,119 @@ class ARViewModel: ObservableObject{ ] self.ciContext = CIContext() + updateDemoCounter() + + // Listen for goal-tap notifications and start odometry + set goal point + NotificationCenter.default.addObserver(forName: NSNotification.Name("ARViewTapForGoal"), object: nil, queue: .main) { [weak self] notif in + guard let self = self, let ml = self.mlManager else { + print(" Goal tap: No ML manager") + return + } + // Only handle taps when using a point-conditioned policy and the user enabled goal-tap mode + print(" Goal tap received - isPointConditioned: \(ml.isPointConditioned), goalTapMode: \(self.goalTapModeEnabled)") + guard ml.isPointConditioned, self.goalTapModeEnabled else { + print("Goal tap ignored - conditions not met") + return + } + // Prefer direct world point from depth/raycast if provided + if let world = notif.userInfo?["worldPoint"] as? simd_float3 { + let method = notif.userInfo?["method"] as? String ?? "unknown" + print("Using \(method) world point: \(world)") + ml.setGoalPoint(world) + self.goalTapModeEnabled = false + return + } + } + } + + func getBLEManagerInstance() -> BluetoothManager{ + return bluetoothManager!; + } + + // MARK: - Shared ARView Management + @MainActor + func getOrCreateSharedARView() -> ARView { + if let existingView = sharedARView { + return existingView + } + + print("Creating shared ARView (one-time setup)") + + // Create the single ARView instance + let arView = ARView(frame: .zero, cameraMode: .ar, automaticallyConfigureSession: false) + arView.session = session + + // Consistent rendering options + arView.renderOptions = [.disablePersonOcclusion, .disableDepthOfField, .disableMotionBlur] + + // Enable scene understanding for raycasts + arView.environment.sceneUnderstanding.options = [.collision] + + // Setup AR visualization + arVisualizationManager.setupVisualization(with: arView) + + // Add tap recognizer for goal setting + let coordinator = TapCoordinator(arViewModel: self) + let tap = UITapGestureRecognizer(target: coordinator, action: #selector(TapCoordinator.handleTap(_:))) + arView.addGestureRecognizer(tap) + // Store coordinator to prevent deallocation + objc_setAssociatedObject(arView, "tapCoordinator", coordinator, .OBJC_ASSOCIATION_RETAIN) + + sharedARView = arView + hasSetupSharedARView = true + + print("Shared ARView created and configured") + return arView } + // Resume AR session + @MainActor + func resumeARSession() { + guard !isOpen else { + print("AR session already running") + return + } + + let status = AVCaptureDevice.authorizationStatus(for: .video) + guard status == .authorized else { return } + + let configuration = createARConfiguration() + session.run(configuration, options: []) + isOpen = true + + print("AR session resumed (tracking preserved)") + } + // MARK: - Shared AR Configuration + private func createARConfiguration() -> ARWorldTrackingConfiguration { + let configuration = ARWorldTrackingConfiguration() + + for videoFormat in ARWorldTrackingConfiguration.supportedVideoFormats { + if videoFormat.captureDeviceType == .builtInWideAngleCamera { + configuration.videoFormat = videoFormat + break + } + } + + if ARWorldTrackingConfiguration.supportsFrameSemantics(.sceneDepth) { + configuration.frameSemantics.insert(.sceneDepth) + } + configuration.planeDetection = [.horizontal, .vertical] + if ARWorldTrackingConfiguration.supportsSceneReconstruction(.meshWithClassification) { + configuration.sceneReconstruction = .meshWithClassification + } else if ARWorldTrackingConfiguration.supportsSceneReconstruction(.mesh) { + configuration.sceneReconstruction = .mesh + } + configuration.environmentTexturing = .none + configuration.isAutoFocusEnabled = false + + return configuration + } + private func setupAudioSession() { guard let audioDevice = AVCaptureDevice.default(for: .audio), let audioDeviceInput = try? AVCaptureDeviceInput(device: audioDevice) else { - print("Unable to access microphone.") return } audioSession.addInput(audioDeviceInput) @@ -164,33 +379,62 @@ class ARViewModel: ObservableObject{ private func setupTransforms() { DispatchQueue.global(qos: .userInitiated).async { - while self.depthRetryCount < self.maxDepthRetries { + var attempts = 0 + let maxAttempts = 50 // Max 500ms wait + + while attempts < maxAttempts { guard let currentFrame = self.session.currentFrame else { - usleep(10000) + attempts += 1 + usleep(10000) // 10ms continue } - let flipTransform = (self.orientation.isPortrait) - ? CGAffineTransform(scaleX: -1, y: -1).translatedBy(x: -1, y: -1) - : .identity - + + let flipTransform = self.computeFlipTransform() + + // Initialize RGB transform if needed if self.combinedRGBTransform == nil { self.initializeRGBTransform(frame: currentFrame, flipTransform: flipTransform) + print("RGB transform initialized successfully") } - - if !self.depthStatus.isDepthAvailable { break } - + + // Try depth transform if self.combinedDepthTransform == nil { if self.initializeDepthTransform(frame: currentFrame, flipTransform: flipTransform) { - break + print("Depth transform initialized successfully") } } - - self.depthRetryCount += 1 + + // Exit once we have RGB transform (depth is optional) + if self.combinedRGBTransform != nil { + break + } + + attempts += 1 usleep(10000) } - + + if self.combinedRGBTransform == nil { + print("Note: RGB transform not yet initialized, will compute on-demand") + } if self.combinedDepthTransform == nil { - print("Depth initialization failed after \(self.maxDepthRetries) attempts.") + print("Note: Depth transform not yet initialized, will compute on-demand") + } + } + } + + func ensureTransformsReady() { + guard let currentFrame = session.currentFrame else { return } + + let flipTransform = computeFlipTransform() + + if combinedRGBTransform == nil { + initializeRGBTransform(frame: currentFrame, flipTransform: flipTransform) + print("RGB transform computed on-demand") + } + + if combinedDepthTransform == nil { + if initializeDepthTransform(frame: currentFrame, flipTransform: flipTransform) { + print("Depth transform computed on-demand") } } } @@ -210,7 +454,6 @@ class ARViewModel: ObservableObject{ private func initializeDepthTransform(frame: ARFrame, flipTransform: CGAffineTransform) -> Bool { guard let depthPixelBuffer = frame.sceneDepth?.depthMap else { - print("Depth map unavailable. Retrying (\(self.depthRetryCount)/\(self.maxDepthRetries))") return false } let depthSize = CGSize(width: CVPixelBufferGetWidth(depthPixelBuffer), height: CVPixelBufferGetHeight(depthPixelBuffer)) @@ -227,76 +470,251 @@ class ARViewModel: ObservableObject{ return true } + @MainActor func setupARSession() { + // Sync orientation with the current interface orientation before configuring transforms + refreshOrientationFromScene() self.startARSession() - setupAudioSession() + if(ifAudioEnable) { + setupAudioSession() + } setupTransforms() - - print("Finished setting up ARViewModel.") } + @MainActor func startARSession() { let status = AVCaptureDevice.authorizationStatus(for: .video) - guard status == .authorized else { - print("Camera permissions not granted.") - return - } - // Create and configure the AR session configuration - let configuration = ARWorldTrackingConfiguration() - - // Loop through available video formats and select the wide-angle camera format - for videoFormat in ARWorldTrackingConfiguration.supportedVideoFormats { - if videoFormat.captureDeviceType == .builtInWideAngleCamera { - print("Wide-angle camera selected: \(videoFormat)") - configuration.videoFormat = videoFormat - break - } else { - print("Unsupported video format: \(videoFormat.captureDeviceType)") - } - } - - // Set the session configuration properties - if ARWorldTrackingConfiguration.supportsFrameSemantics(.sceneDepth) { - configuration.frameSemantics.insert(.sceneDepth) - } else { - depthStatus.setUnavailable() - } - configuration.planeDetection = [] - configuration.environmentTexturing = .none // No environment texturing - configuration.sceneReconstruction = [] // No scene reconstruction - configuration.isAutoFocusEnabled = false + guard status == .authorized else { return } - // Run the session with the configuration + let configuration = createARConfiguration() session.run(configuration, options: [.resetTracking, .removeExistingAnchors]) - print("Starting session") isOpen = true } + + private func refreshOrientationFromScene() { + // Keep a consistent transform between tabs; force portrait so Record and Inference align + orientation = .portrait + } + @MainActor func pauseARSession(){ session.pause() isOpen = false + clearCachedTransforms() } + @MainActor func killARSession() { session.pause() // Pause before releasing resources session = ARSession() // Replace with a new ARSession isOpen = false - print("ARSession killed and reset.") + clearCachedTransforms() + } + + /// Clear cached transforms so they are recalculated on next session start + private func clearCachedTransforms() { + combinedRGBTransform = nil + combinedDepthTransform = nil + } + + /// Safely extract depth and confidence buffers from an AR frame + private func getDepthBuffers(from frame: ARFrame) -> (depth: CVPixelBuffer, confidence: CVPixelBuffer)? { + guard let depthBuffer = frame.sceneDepth?.depthMap, + let confidenceBuffer = frame.sceneDepth?.confidenceMap else { + return nil + } + return (depthBuffer, confidenceBuffer) } + /// Compute flip transform based on current orientation + private func computeFlipTransform() -> CGAffineTransform { + orientation.isPortrait + ? CGAffineTransform(scaleX: -1, y: -1).translatedBy(x: -1, y: -1) + : .identity + } + + // MARK: - Safe Session Management + @MainActor + func startARSessionIfNeeded() { + guard !isOpen else { + print("AR session already running") + return + } + + print("Starting AR session for ARViewContainer") + setupARSession() + } + + // MARK: - Inference Playback (no file saving) + @MainActor + func startInferencePlayback() { + // MARK: - State Validation Guards + guard !isInferencePlaying else { + print("Inference playback already active - ignoring start request") + return + } + + guard !isUSBStreamingActive else { + print("Cannot start inference playback while USB streaming is active") + return + } + + guard !isRecording else { + print("Cannot start inference playback while recording is active") + return + } + + guard recordingMode == .none else { + print("Another recording mode active: \(recordingMode) - stopping first") + stopAllActivities() + // stopAllActivities resets state; if it couldn't, bail safely + guard recordingMode == .none else { return } + return startInferencePlayback() + } + + // Ensure AR session is running + startARSessionIfNeeded() + + // MARK: - Update Centralized State + recordingMode = .mlInference + isInferencePlaying = true + isInferenceEpisodeFinished = false + + // Reset ML inference state for a new playback session (keep goal) + mlManager?.resetInferenceState() + mlManager?.latestResult = nil + mlManager?.lastResult = nil + + // Reset visualization state (fresh origin/targets for new episode) + arVisualizationManager.stopRecordingVisualization() + arVisualizationManager.enableVisualization() + arVisualizationManager.ensureVisualizationReady() + + let fps = userFPS ?? 30.0 + displayLink = CADisplayLink(target: self, selector: #selector(runInferencePlaybackTick)) + displayLink?.preferredFrameRateRange = CAFrameRateRange( + minimum: Float(fps), + maximum: Float(fps), + preferred: Float(fps) + ) + displayLink?.add(to: .main, forMode: .common) + + print("Inference playback started") + } + + @MainActor + func stopInferencePlayback(reset: Bool = true) { + guard isInferencePlaying || recordingMode == .mlInference else { + return + } + + displayLink?.invalidate() + displayLink = nil + + isInferencePlaying = false + isInferenceEpisodeFinished = false + + if recordingMode == .mlInference { + recordingMode = .none + } + + if reset { + mlManager?.resetInferenceState() + mlManager?.latestResult = nil + mlManager?.lastResult = nil + arVisualizationManager.stopRecordingVisualization() + arVisualizationManager.enableVisualization() + arVisualizationManager.ensureVisualizationReady() + // Ensure episode-finished state clears even if last result was CLOSED + arVisualizationManager.setGripperState(isClosed: false) + } + + print("Inference playback stopped") + } + + @MainActor + @objc private func runInferencePlaybackTick(link: CADisplayLink) { + // Avoid doing any work if playback has ended or mode changed + guard isInferencePlaying, recordingMode == .mlInference else { return } + + // Episode finished -> stop processing frames (but keep "Stop" available for reset) + if arVisualizationManager.isGripperClosed { + if !isInferenceEpisodeFinished { + isInferenceEpisodeFinished = true + print("Episode finished (gripper closed) - waiting for reset") + } + return + } + + guard let currentFrame = session.currentFrame else { return } + let rgbPixelBuffer = currentFrame.capturedImage + + if let mlManager = mlManager { + Task { @MainActor in + mlManager.performInference(on: rgbPixelBuffer, arFrame: currentFrame, timestamp: CACurrentMediaTime()) + } + } + } + + @MainActor func startUSBStreaming() { + // MARK: - State Validation Guards + guard !isUSBStreamingActive else { + print("USB Streaming already active - ignoring start request") + return + } + + guard recordingMode == .none else { + print("Another recording mode active: \(recordingMode) - stopping first") + stopAllActivities() + return + } + + // Ensure transforms are computed before streaming + ensureTransformsReady() + + // MARK: - Update Centralized State + recordingMode = .usbStreaming + + // Reset ML inference state for new streaming session + mlManager?.resetInferenceState() + displayLink = CADisplayLink(target: self, selector: #selector(sendFrameUSB)) displayLink?.preferredFrameRateRange = CAFrameRateRange(minimum: Float(self.userFPS!), maximum: Float(self.userFPS!), preferred: Float(self.userFPS!)) displayLink?.add(to: .main, forMode: .common) + isUSBStreamingActive = true + mlManager?.setUSBStreamingState(isActive: true) + + print("USB streaming started successfully") } + @MainActor func stopUSBStreaming() { + // MARK: - State Validation Guard + guard isUSBStreamingActive else { + print("Stop USB streaming called but not currently streaming") + return + } + displayLink?.invalidate() displayLink = nil + isUSBStreamingActive = false + mlManager?.setUSBStreamingState(isActive: false) + + // Reset ML inference state when stopping + mlManager?.resetInferenceState() + + // MARK: - Update Centralized State + if recordingMode == .usbStreaming { + recordingMode = .none + } + + print("USB streaming stopped successfully") } + @MainActor func setupUSBStreaming() { var rgbBuffer: CVPixelBuffer? @@ -309,30 +727,26 @@ class ARViewModel: ObservableObject{ &rgbBuffer ) guard status == kCVReturnSuccess else { - print("Failed to create CVPixelBuffer") return } self.rgbOutputPixelBufferUSB = rgbBuffer - if self.depthStatus.isDepthAvailable { - var depthBuffer: CVPixelBuffer? - var depthConfidenceBuffer: CVPixelBuffer? + // Try to set up depth buffers - optional, won't block if it fails + var depthBuffer: CVPixelBuffer? + var depthConfidenceBuffer: CVPixelBuffer? - let depthStatus = CVPixelBufferCreate( - kCFAllocatorDefault, - Int(depthViewPortSize.width), - Int(depthViewPortSize.height), - kCVPixelFormatType_DepthFloat32, - depthAttributes as CFDictionary, - &depthBuffer - ) - - guard depthStatus == kCVReturnSuccess else { - print("Failed to create CVPixelBuffer") - return - } + let depthStatus = CVPixelBufferCreate( + kCFAllocatorDefault, + Int(depthViewPortSize.width), + Int(depthViewPortSize.height), + kCVPixelFormatType_DepthFloat32, + depthAttributes as CFDictionary, + &depthBuffer + ) + + if depthStatus == kCVReturnSuccess { self.depthOutputPixelBufferUSB = depthBuffer - + let depthConfidenceStatus = CVPixelBufferCreate( kCFAllocatorDefault, Int(depthViewPortSize.width), @@ -341,42 +755,34 @@ class ARViewModel: ObservableObject{ depthConfAttributes as CFDictionary, &depthConfidenceBuffer ) - guard depthConfidenceStatus == kCVReturnSuccess else { - print("Failed to create CVPixelBuffer") - return + if depthConfidenceStatus == kCVReturnSuccess { + self.depthConfidenceOutputPixelBufferUSB = depthConfidenceBuffer } - self.depthConfidenceOutputPixelBufferUSB = depthConfidenceBuffer } - print("Made all USB Buffers") usbManager.connect() } + @MainActor func killUSBStreaming() { self.usbManager.disconnect() - + self.rgbOutputPixelBufferUSB = nil self.depthOutputPixelBufferUSB = nil self.depthConfidenceOutputPixelBufferUSB = nil - } - -// func startWiFiStreaming(host: String, port: UInt16) { - // Set up the network connection -// // Start WebRTC connection -// webRTCManager.setupConnection() -// } -// func stopWiFiStreaming() { -// displayLink?.invalidate() -// displayLink = nil -// streamConnection?.cancel() -// streamConnection = nil -// } - - @objc private func sendFrame(link: CADisplayLink) { - streamVideoFrameUSB() + // Clean up vImage buffers + if let transformBuffer = rgbTransformBuffer { + free(transformBuffer.data) + rgbTransformBuffer = nil + } + + isUSBStreamingActive = false + mlManager?.setUSBStreamingState(isActive: false) } + + @MainActor @objc private func sendFrameUSB(link: CADisplayLink) { streamVideoFrameUSB() } @@ -384,33 +790,43 @@ class ARViewModel: ObservableObject{ private func processDepthStreamData(depthPixelBuffer: CVPixelBuffer, outputBuffer: CVPixelBuffer, isDepth: Bool) -> Data? { CVPixelBufferLockBaseAddress(depthPixelBuffer, .readOnly) CVPixelBufferLockBaseAddress(outputBuffer, []) - - let depthCiImage = CIImage(cvPixelBuffer: depthPixelBuffer) - let depthTransformedImage = depthCiImage.transformed(by: self.combinedDepthTransform!) - self.ciContext.render(depthTransformedImage, to: outputBuffer) - + + // Try optimized depth processing + if canUseOptimizedDepthTransform(for: depthPixelBuffer) { + processDepthOptimized(depthPixelBuffer, outputBuffer: outputBuffer) + } else { + // Fallback to Core Image + let depthCiImage = CIImage(cvPixelBuffer: depthPixelBuffer) + let depthTransformedImage = depthCiImage.transformed(by: self.combinedDepthTransform ?? CGAffineTransform.identity) + self.ciContext.render(depthTransformedImage, to: outputBuffer) + } + let compressedData = self.usbManager.compressData(from: outputBuffer, isDepth: isDepth) - + CVPixelBufferUnlockBaseAddress(outputBuffer, []) CVPixelBufferUnlockBaseAddress(depthPixelBuffer, .readOnly) - + return compressedData } + @MainActor func streamVideoFrameUSB() { guard let currentFrame = session.currentFrame else {return} let rgbPixelBuffer = currentFrame.capturedImage -// TODO: Check if we need to change this at all - var depthPixelBuffer: CVPixelBuffer? = nil - var depthConfidencePixelBuffer: CVPixelBuffer? = nil - if self.depthStatus.isDepthAvailable { - guard let depthBuffer = currentFrame.sceneDepth?.depthMap else { return } - depthPixelBuffer = depthBuffer - guard let depthConfidenceBuffer = currentFrame.sceneDepth?.confidenceMap else { return } - depthConfidencePixelBuffer = depthConfidenceBuffer + + // Perform ML inference on the RGB frame during streaming (provide ARFrame for odometry/goal updates) + if let mlManager = mlManager { + Task { @MainActor in + mlManager.performInference(on: rgbPixelBuffer, arFrame: currentFrame, timestamp: CACurrentMediaTime()) + } } + // Try to get depth data if available, but continue regardless + let depthBuffers = getDepthBuffers(from: currentFrame) + let depthPixelBuffer = depthBuffers?.depth + let depthConfidencePixelBuffer = depthBuffers?.confidence + let cameraIntrinsics = currentFrame.camera.intrinsics var intrinsicCoeffs = IntrinsicMatrixCoeffs( @@ -446,47 +862,78 @@ class ARViewModel: ObservableObject{ deviceType: 1 ) + let rgbOutputBufferUSB = self.rgbOutputPixelBufferUSB + let depthOutputBufferUSB = self.depthOutputPixelBufferUSB + let depthConfOutputBufferUSB = self.depthConfidenceOutputPixelBufferUSB + let latestJointActions = self.mlManager?.latestResult?.jointPositions + let usbManager = self.usbManager + DispatchQueue.global(qos: .userInitiated).async { + guard let rgbOutputBufferUSB else { return } CVPixelBufferLockBaseAddress(rgbPixelBuffer, .readOnly) - CVPixelBufferLockBaseAddress(self.rgbOutputPixelBufferUSB!, []) - - let rgbCiImage = CIImage(cvPixelBuffer: rgbPixelBuffer) - let rgbTransformedImage = rgbCiImage.transformed(by: self.combinedRGBTransform!) + CVPixelBufferLockBaseAddress(rgbOutputBufferUSB, []) - guard let rgbCgImage = self.ciContext.createCGImage(rgbTransformedImage, from: rgbTransformedImage.extent) else{ - return + let rgbImageData: Data? + if self.canUseOptimizedTransform(for: rgbPixelBuffer) { + rgbImageData = self.processRGBOptimized(rgbPixelBuffer) + } else { + // Fallback to Core Image pipeline + let rgbCiImage = CIImage(cvPixelBuffer: rgbPixelBuffer) + let rgbTransformedImage = rgbCiImage.transformed(by: self.combinedRGBTransform ?? CGAffineTransform.identity) + + guard let rgbCgImage = self.ciContext.createCGImage(rgbTransformedImage, from: rgbTransformedImage.extent) else{ + return + } + rgbImageData = UIImage(cgImage: rgbCgImage).jpegData(compressionQuality: 0.5) } - let rgbImageData = UIImage(cgImage: rgbCgImage).jpegData(compressionQuality: 0.5) record3dHeader.rgbSize = UInt32(rgbImageData!.count) - CVPixelBufferUnlockBaseAddress(self.rgbOutputPixelBufferUSB!, []) + CVPixelBufferUnlockBaseAddress(rgbOutputBufferUSB, []) CVPixelBufferUnlockBaseAddress(rgbPixelBuffer, .readOnly) var compressedDepthData: Data? = nil var compressedDepthConfData: Data? = nil - if self.depthStatus.isDepthAvailable { - compressedDepthData = self.processDepthStreamData(depthPixelBuffer: depthPixelBuffer!, outputBuffer: self.depthOutputPixelBufferUSB!, isDepth: true) - compressedDepthConfData = self.processDepthStreamData(depthPixelBuffer: depthConfidencePixelBuffer!, outputBuffer: self.depthConfidenceOutputPixelBufferUSB!, isDepth: false) + // Process depth data if available + if let depthBuffer = depthPixelBuffer, + let depthConfBuffer = depthConfidencePixelBuffer, + let depthOutputBuffer = depthOutputBufferUSB, + let depthConfOutputBuffer = depthConfOutputBufferUSB { + compressedDepthData = self.processDepthStreamData(depthPixelBuffer: depthBuffer, outputBuffer: depthOutputBuffer, isDepth: true) + compressedDepthConfData = self.processDepthStreamData(depthPixelBuffer: depthConfBuffer, outputBuffer: depthConfOutputBuffer, isDepth: false) record3dHeader.depthSize = UInt32(compressedDepthData?.count ?? 0) record3dHeader.confidenceMapSize = UInt32(compressedDepthConfData?.count ?? 0) } - self.usbManager.sendData( + + // Always send exactly 7 floats (28 bytes) for joint actions + let jointActionsArray: [Float] + if let latestJointActions, !latestJointActions.isEmpty { + // Use actual ML inference results, ensure exactly 7 values + jointActionsArray = Array(latestJointActions.prefix(7)) + Array(repeating: 0.0, count: max(0, 7 - latestJointActions.count)) + } else { + // Fallback to zeros if no ML results available + jointActionsArray = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + } + + // Convert to exactly 28 bytes (7 floats * 4 bytes each) + let jointActionsData = Data(bytes: jointActionsArray, count: 28) + + usbManager.sendData( record3dHeaderData: Data(bytes: &record3dHeader, count: MemoryLayout.size), intrinsicMatData: Data(bytes: &intrinsicCoeffs, count: MemoryLayout.size), poseData: Data(bytes: &camera_pose, count: MemoryLayout.size), rgbImageData: rgbImageData!, + jointActionsData: jointActionsData, compressedDepthData: compressedDepthData, compressedConfData: compressedDepthConfData ) - - } } + @MainActor @objc private func updateFrame(link: CADisplayLink) { guard lastTimestamp > 0 else { // Initialize timestamp on the first call @@ -496,93 +943,210 @@ class ARViewModel: ObservableObject{ captureVideoFrame() } - func startRecording() -> RecordingFiles { - let saveFileNames = setupRecording() + @MainActor + func startRecording() -> RecordingFiles? { + // MARK: - State Validation Guards + guard !isRecording else { + print("Recording already active - ignoring start request") + return currentRecordingFiles + } + + guard recordingMode == .none else { + print("Another recording mode active: \(recordingMode) - stopping first") + stopAllActivities() + return nil + } + // Ensure transforms are computed before recording + ensureTransformsReady() + + guard let saveFileNames = setupRecording() else { + print("Failed to setup recording") + return nil + } + + // MARK: - Update Centralized State + isRecording = true + recordingMode = .standardRecording + currentRecordingFiles = saveFileNames + + // Reset ML inference state for new recording + mlManager?.resetInferenceState() + assetWriter?.startWriting() startTime = CMTimeMake(value: Int64(CACurrentMediaTime() * 1000), timescale: 1000) assetWriter?.startSession(atSourceTime: startTime!) - + + let audioEnabled = ifAudioEnable + let audioSession = self.audioSession DispatchQueue.global(qos: .background).async { - self.audioSession.startRunning() + if audioEnabled { + audioSession.startRunning() + } } - if self.depthStatus.isDepthAvailable { - depthAssetWriter?.startWriting() - depthAssetWriter?.startSession(atSourceTime: startTime!) + // Start depth recording if depth writer is available + if let depthWriter = depthAssetWriter { + depthWriter.startWriting() + depthWriter.startSession(atSourceTime: startTime!) } - + displayLink = CADisplayLink(target: self, selector: #selector(updateFrame)) displayLink?.preferredFrameRateRange = CAFrameRateRange(minimum: Float(self.userFPS!), maximum: Float(self.userFPS!), preferred: Float(self.userFPS!)) displayLink?.add(to: .main, forMode: .common) - - return saveFileNames! + + print("Recording started successfully") + return saveFileNames } + @MainActor func stopRecording(){ + // MARK: - State Validation Guard + guard isRecording else { + print("Stop recording called but not currently recording") + return + } + displayLink?.invalidate() displayLink = nil - audioSession.stopRunning() - audioInput?.markAsFinished() + + // Stop AR pose visualization + arVisualizationManager.stopRecordingVisualization() + + // Reset ML inference state when stopping + mlManager?.resetInferenceState() + + if(ifAudioEnable) { + audioSession.stopRunning() + audioInput?.markAsFinished() + } videoInput?.markAsFinished() - + audioCaptureDelegate = nil - + assetWriter?.finishWriting { self.assetWriter = nil - print("RGB Video recording finished.") } - + depthVideoInput?.markAsFinished() depthAssetWriter?.finishWriting { self.depthAssetWriter = nil - print("Depth Video recording finished.") } do { try poseFileHandle?.close() } catch { - print("Error closing pose file") + // Error closing pose file - continue cleanup + } + + // MARK: - Update Centralized State + isRecording = false + recordingMode = .none + currentRecordingFiles = nil + + updateDemoCounter() + print("Recording stopped successfully") + } + + // MARK: - Comprehensive Cleanup Method + @MainActor + func stopAllActivities() { + // If nothing is active, avoid redundant cleanup work + if !isRecording && !isUSBStreamingActive && !isInferencePlaying && recordingMode == .none && displayLink == nil { + print("No active activities to stop") + return + } + + print("Stopping all activities...") + + // Stop inference playback if active + if isInferencePlaying || recordingMode == .mlInference { + stopInferencePlayback(reset: true) + } + + // Stop recording if active + if isRecording { + stopRecording() } + + // Stop USB streaming if active + if isUSBStreamingActive { + stopUSBStreaming() + } + + // Reset ML inference state + mlManager?.resetInferenceState() + + // Stop AR visualization + arVisualizationManager.stopRecordingVisualization() + + // Invalidate any remaining display links + displayLink?.invalidate() + displayLink = nil + + // Reset state + recordingMode = .none + currentRecordingFiles = nil + isInferencePlaying = false + isInferenceEpisodeFinished = false + + print("All activities stopped") } + @MainActor func setupRecording() -> RecordingFiles? { // Determine all the destinated file saving URL or this recording by its start time let dateFormatter = DateFormatter() dateFormatter.dateFormat = "yyyy-MM-dd-HH_mm_ss" let timestamp = dateFormatter.string(from: Date()) - let fileNames = [ + var fileNames = [ "RGB": "RGB_\(timestamp).mp4", "Depth": "Depth_\(timestamp).mp4", "Pose": "AR_Pose_\(timestamp).txt", - "Tactile": "Tactile_\(timestamp).bin", - "RGBImages": "RGB_Images_\(timestamp)", - "DepthImages": isColorMapOpened ? "Depth_Colored_Images_\(timestamp)" : "Depth_Images_\(timestamp)" + "Tactile": "Tactile_\(timestamp).bin" ] + + // Only include image directories if debug frame saving is enabled + if mlManager?.saveDebugFrames == true { + fileNames["RGBImages"] = "RGB_Images_\(timestamp)" + fileNames["DepthImages"] = isColorMapOpened ? "Depth_Colored_Images_\(timestamp)" : "Depth_Images_\(timestamp)" + } - guard let documentsURL = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first else { - print("Failed to get document directory") + let generalDataDirectory = getDocumentsDirect().appendingPathComponent(timestamp) + + guard let rgbFileName = fileNames["RGB"], + let depthFileName = fileNames["Depth"], + let poseFileName = fileNames["Pose"], + let tactileFileName = fileNames["Tactile"] else { return nil } - - let generalDataDirectory = documentsURL.appendingPathComponent(timestamp) - let rgbVideoURL = generalDataDirectory.appendingPathComponent(fileNames["RGB"]!) - let depthVideoURL = generalDataDirectory.appendingPathComponent(fileNames["Depth"]!) - let poseTextURL = generalDataDirectory.appendingPathComponent(fileNames["Pose"]!) - let tactileFileURL = generalDataDirectory.appendingPathComponent(fileNames["Tactile"]!) - let rgbImagesDirectory = generalDataDirectory.appendingPathComponent(fileNames["RGBImages"]!) - let depthImagesDirectory = generalDataDirectory.appendingPathComponent(fileNames["DepthImages"]!) - + + let rgbVideoURL = generalDataDirectory.appendingPathComponent(rgbFileName) + let depthVideoURL = generalDataDirectory.appendingPathComponent(depthFileName) + let poseTextURL = generalDataDirectory.appendingPathComponent(poseFileName) + let tactileFileURL = generalDataDirectory.appendingPathComponent(tactileFileName) + + // Only create image directories if debug frame saving is enabled + var rgbImagesDirectory: URL? + var depthImagesDirectory: URL? + if mlManager?.saveDebugFrames == true, + let rgbDirName = fileNames["RGBImages"], + let depthDirName = fileNames["DepthImages"] { + rgbImagesDirectory = generalDataDirectory.appendingPathComponent(rgbDirName) + depthImagesDirectory = generalDataDirectory.appendingPathComponent(depthDirName) + } + do { try FileManager.default.createDirectory(at: generalDataDirectory, withIntermediateDirectories: true) - if self.depthStatus.isDepthAvailable { - try FileManager.default.createDirectory(at: depthImagesDirectory, withIntermediateDirectories: true) + if mlManager?.saveDebugFrames == true, + let depthDir = depthImagesDirectory { + try FileManager.default.createDirectory(at: depthDir, withIntermediateDirectories: true) } try createFile(fileURL: poseTextURL) } catch { - print("Error creating directories") + // Error creating directories - continue with setup } self.rgbDirect = rgbImagesDirectory @@ -613,29 +1177,30 @@ class ARViewModel: ObservableObject{ self.videoInput?.expectsMediaDataInRealTime = true self.assetWriter?.add(videoInput!) - self.audioInput = AVAssetWriterInput(mediaType: .audio, outputSettings: audioOutputSettings) - self.audioInput?.expectsMediaDataInRealTime = true - self.assetWriter?.add(audioInput!) - - self.pixelBufferAdapter = AVAssetWriterInputPixelBufferAdaptor(assetWriterInput: videoInput!, sourcePixelBufferAttributes: rgbAttributes) - - // Update the audio delegate with the new audioWriterInput - self.audioCaptureDelegate = AudioCaptureDelegate(writerInput: audioInput!) + if(ifAudioEnable) { + self.audioInput = AVAssetWriterInput(mediaType: .audio, outputSettings: audioOutputSettings) + self.audioInput?.expectsMediaDataInRealTime = true + self.assetWriter?.add(audioInput!) + + // Update the audio delegate with the new audioWriterInput + self.audioCaptureDelegate = AudioCaptureDelegate(writerInput: audioInput!) - // Attach the new delegate to the existing AVCaptureAudioDataOutput - if let audioOutput = self.audioSession.outputs.first(where: { $0 is AVCaptureAudioDataOutput }) as? AVCaptureAudioDataOutput { - let audioQueue = DispatchQueue(label: "AudioProcessingQueue") - audioOutput.setSampleBufferDelegate(self.audioCaptureDelegate, queue: audioQueue) + // Attach the new delegate to the existing AVCaptureAudioDataOutput + if let audioOutput = self.audioSession.outputs.first(where: { $0 is AVCaptureAudioDataOutput }) as? AVCaptureAudioDataOutput { + let audioQueue = DispatchQueue(label: "AudioProcessingQueue") + audioOutput.setSampleBufferDelegate(self.audioCaptureDelegate, queue: audioQueue) + } } - if self.depthStatus.isDepthAvailable { - setupDepthRecording(depthVideoURL: depthVideoURL) - } + self.pixelBufferAdapter = AVAssetWriterInputPixelBufferAdaptor(assetWriterInput: videoInput!, sourcePixelBufferAttributes: rgbAttributes) + + // Setup depth recording if supported + setupDepthRecording(depthVideoURL: depthVideoURL) self.poseFileHandle = try FileHandle(forWritingTo: poseTextURL) try poseFileHandle?.seekToEnd() } catch { - print("Failed to setup recording: \(error)") + // Failed to setup recording - continue with available configuration } return RecordingFiles( @@ -674,7 +1239,7 @@ class ARViewModel: ObservableObject{ sourcePixelBufferAttributes: recordingDepthAttributes ) } catch { - print("Failed to setup depth recording: \(error)") + // Failed to setup depth recording - continue without depth } } @@ -690,18 +1255,14 @@ class ARViewModel: ObservableObject{ CVPixelBufferLockBaseAddress(outputBuffer, []) let ciImage = CIImage(cvPixelBuffer: rgbPixelBuffer) - let transformedImage = ciImage.transformed(by: self.combinedRGBTransform!) //.cropped(to: cropRect) + let transformedImage = ciImage.transformed(by: self.combinedRGBTransform ?? CGAffineTransform.identity) //.cropped(to: cropRect) self.ciContext.render(transformedImage, to: outputBuffer, bounds: cropRect, colorSpace: CGColorSpaceCreateDeviceRGB()) guard let pixelBufferAdapter = self.pixelBufferAdapter else { - print("Failed to append RGB pixel buffer: Pixel buffer adapter is nil.") return false } if !pixelBufferAdapter.append(outputBuffer, withPresentationTime: currentTime) { - let isReady = pixelBufferAdapter.assetWriterInput.isReadyForMoreMediaData - let writerError = self.assetWriter?.error?.localizedDescription ?? "Unknown asset writer error." - print("Failed to append RGB pixel buffer. Adapter state: \(isReady), Time: \(currentTime), Error: \(writerError)") return false } CVPixelBufferUnlockBaseAddress(outputBuffer, []) @@ -713,14 +1274,12 @@ class ARViewModel: ObservableObject{ guard let depthVideoInput = self.depthVideoInput, depthVideoInput.isReadyForMoreMediaData else { return false } guard let depthPixelBuffer = depthPixelBuffer else { return false } guard let pixelBufferPool = self.depthPixelBufferAdapter?.pixelBufferPool else { - print("Depth pixel buffer pool is nil.") return false } var outputPixelBuffer: CVPixelBuffer? let status = CVPixelBufferPoolCreatePixelBuffer(nil, pixelBufferPool, &outputPixelBuffer) guard status == kCVReturnSuccess, let depthOutputBuffer = outputPixelBuffer else { - print("Unable to create output pixel buffer for depth.") return false } @@ -740,14 +1299,10 @@ class ARViewModel: ObservableObject{ ) guard let depthPixelBufferAdapter = self.depthPixelBufferAdapter else { - print("Failed to append depth pixel buffer: Pixel buffer adapter is nil.") return false } if !depthPixelBufferAdapter.append(depthOutputBuffer, withPresentationTime: currentTime) { - let isReady = depthPixelBufferAdapter.assetWriterInput.isReadyForMoreMediaData - let writerError = self.depthAssetWriter?.error?.localizedDescription ?? "Unknown asset writer error." - print("❌ Failed to append RGB pixel buffer. Adapter state: \(isReady), Time: \(currentTime), Error: \(writerError)") return false } CVPixelBufferUnlockBaseAddress(depthOutputBuffer, []) @@ -773,7 +1328,7 @@ class ARViewModel: ObservableObject{ try self.poseFileHandle?.write(contentsOf: data) } } catch { - print("❌ Error writing pose data: \(error)") + // Error writing pose data - continue capture } } @@ -788,8 +1343,6 @@ class ARViewModel: ObservableObject{ if let outputImage = depthFilter.outputImage { filteredImage = outputImage - } else { - print("❌ Failed to apply color controls filter to depth image.") } if(self.isColorMapOpened){ @@ -799,11 +1352,107 @@ class ARViewModel: ObservableObject{ falseColorFilter.inputImage = filteredImage if let outputImage = falseColorFilter.outputImage { filteredImage = outputImage - } else { - print("❌ Failed to apply false color filter to depth image.") } } - return filteredImage.transformed(by: self.combinedDepthTransform!) //.cropped(to: cropRect) + return filteredImage.transformed(by: self.combinedDepthTransform ?? CGAffineTransform.identity) //.cropped(to: cropRect) + } + + // MARK: - Accelerate Optimizations + private func canUseOptimizedTransform(for pixelBuffer: CVPixelBuffer) -> Bool { + // Only use optimized path for simple transforms (scale + translate) + // Skip if transform contains rotation or complex operations + guard let transform = combinedRGBTransform else { return false } + + // Check if transform is approximately a simple scale/translate + let hasRotation = abs(transform.b) > 0.001 || abs(transform.c) > 0.001 + return !hasRotation + } + + private func processRGBOptimized(_ pixelBuffer: CVPixelBuffer) -> Data? { + // For now, use a simple direct conversion approach + // This bypasses the expensive CIImage -> CGImage -> UIImage pipeline + + guard let cgImage = createCGImageDirect(from: pixelBuffer) else { + return nil + } + + return UIImage(cgImage: cgImage).jpegData(compressionQuality: 0.5) + } + + private func createCGImageDirect(from pixelBuffer: CVPixelBuffer) -> CGImage? { + let width = CVPixelBufferGetWidth(pixelBuffer) + let height = CVPixelBufferGetHeight(pixelBuffer) + + guard let baseAddress = CVPixelBufferGetBaseAddress(pixelBuffer) else { return nil } + + let bytesPerRow = CVPixelBufferGetBytesPerRow(pixelBuffer) + let colorSpace = CGColorSpaceCreateDeviceRGB() + + let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.noneSkipFirst.rawValue) + + guard let context = CGContext( + data: baseAddress, + width: width, + height: height, + bitsPerComponent: 8, + bytesPerRow: bytesPerRow, + space: colorSpace, + bitmapInfo: bitmapInfo.rawValue + ) else { return nil } + + return context.makeImage() + } + + private func canUseOptimizedDepthTransform(for pixelBuffer: CVPixelBuffer) -> Bool { + guard let transform = combinedDepthTransform else { return false } + // Check if transform is simple enough for vImage optimization + let hasRotation = abs(transform.b) > 0.001 || abs(transform.c) > 0.001 + return !hasRotation + } + + private func processDepthOptimized(_ inputBuffer: CVPixelBuffer, outputBuffer: CVPixelBuffer) { + // Simple memcpy for identity or simple scaling transforms + // This avoids Core Image overhead for depth data + + let inputWidth = CVPixelBufferGetWidth(inputBuffer) + let inputHeight = CVPixelBufferGetHeight(inputBuffer) + let outputWidth = CVPixelBufferGetWidth(outputBuffer) + let outputHeight = CVPixelBufferGetHeight(outputBuffer) + + guard let inputData = CVPixelBufferGetBaseAddress(inputBuffer), + let outputData = CVPixelBufferGetBaseAddress(outputBuffer) else { + return + } + + let inputBytesPerRow = CVPixelBufferGetBytesPerRow(inputBuffer) + let outputBytesPerRow = CVPixelBufferGetBytesPerRow(outputBuffer) + + if inputWidth == outputWidth && inputHeight == outputHeight { + // Direct copy for same-size buffers + let totalBytes = min(inputHeight * inputBytesPerRow, outputHeight * outputBytesPerRow) + memcpy(outputData, inputData, totalBytes) + } else { + // Use vImage for scaling if available + var sourceBuffer = vImage_Buffer( + data: inputData, + height: vImagePixelCount(inputHeight), + width: vImagePixelCount(inputWidth), + rowBytes: inputBytesPerRow + ) + + var destBuffer = vImage_Buffer( + data: outputData, + height: vImagePixelCount(outputHeight), + width: vImagePixelCount(outputWidth), + rowBytes: outputBytesPerRow + ) + + // Use vImage scaling for better performance than Core Image + let error = vImageScale_Planar16F(&sourceBuffer, &destBuffer, nil, vImage_Flags(kvImageNoFlags)) + if error != kvImageNoError { + print("vImage scaling failed: \(error)") + } + } } private func saveBinaryDepthData(depthPixelBuffer: CVPixelBuffer) { @@ -820,15 +1469,18 @@ class ARViewModel: ObservableObject{ let dataSize = width * height * MemoryLayout.size let data = Data(bytes: floatBuffer, count: dataSize) - // Save binary data to a file - let fileURL = self.depthDirect.appendingPathComponent("\(Int64(Date().timeIntervalSince1970*1000)).bin") - do { - try data.write(to: fileURL) - } catch { - print("Error saving binary file: \(error)") + // Save binary data to a file (only if debug frame saving is enabled) + if let depthDir = self.depthDirect { + let fileURL = depthDir.appendingPathComponent("\(Int64(Date().timeIntervalSince1970*1000)).bin") + do { + try data.write(to: fileURL) + } catch { + // Error saving binary file - continue capture + } } } + @MainActor func captureVideoFrame() { guard let currentFrame = session.currentFrame else {return} @@ -838,13 +1490,17 @@ class ARViewModel: ObservableObject{ let currentTime = CMTimeMake(value: Int64(CACurrentMediaTime() * 1000), timescale: 1000) let rgbPixelBuffer = currentFrame.capturedImage - var depthPixelBuffer: CVPixelBuffer? + let depthPixelBuffer = currentFrame.sceneDepth?.depthMap - if self.depthStatus.isDepthAvailable { - guard let depthBuffer = currentFrame.sceneDepth?.depthMap else { return } - depthPixelBuffer = depthBuffer + // Perform ML inference on the RGB frame (provide ARFrame for odometry/goal updates) + if let mlManager = mlManager { + Task { @MainActor in + mlManager.performInference(on: rgbPixelBuffer, arFrame: currentFrame, timestamp: CACurrentMediaTime()) + } } + + let cropRect = CGRect( x: 0, y: 0, width: self.viewPortSize.width, height: self.viewPortSize.height ) @@ -855,8 +1511,8 @@ class ARViewModel: ObservableObject{ DispatchQueue.global(qos: .userInitiated).async { let rgbSuccess = self.processRGBCaptureData(rgbPixelBuffer: rgbPixelBuffer, cropRect: cropRect, currentTime: currentTime) imgSuccessFlag = imgSuccessFlag && rgbSuccess - if self.depthStatus.isDepthAvailable && imgSuccessFlag { - let depthSuccess = self.processDepthCaptureData(depthPixelBuffer: depthPixelBuffer, cropRect: depthCropRect, currentTime: currentTime) + if let depthBuffer = depthPixelBuffer, imgSuccessFlag { + let depthSuccess = self.processDepthCaptureData(depthPixelBuffer: depthBuffer, cropRect: depthCropRect, currentTime: currentTime) imgSuccessFlag = imgSuccessFlag && depthSuccess } if imgSuccessFlag { @@ -865,13 +1521,62 @@ class ARViewModel: ObservableObject{ } } - func getDocumentsDirect() -> URL{ - let paths = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask) - print(paths[0].path) - return paths[0] + func getDocumentsDirect() -> URL { + FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0] } -} + func updateDemoCounter() { + let documentsURL = getDocumentsDirect() + do { + let contents = try FileManager.default.contentsOfDirectory(at: documentsURL, includingPropertiesForKeys: [.isDirectoryKey], options: [.skipsHiddenFiles, .skipsSubdirectoryDescendants]) + let directories = contents.filter { url in + var isDirectory: ObjCBool = false + FileManager.default.fileExists(atPath: url.path, isDirectory: &isDirectory) + return isDirectory.boolValue + } + demosCounter = directories.count + } catch { + demosCounter = 0 + } + } + + // MARK: - Model Manager Integration + @MainActor + func initializeMLManager(with modelManager: ModelManager) { + self.mlManager = MLInferenceManager(modelManager: modelManager) + + // Connect AR visualization to ML inference + self.mlManager?.arVisualizationManager = self.arVisualizationManager + // Provide AR session access to ML manager for goal and odometry + self.mlManager?.setARViewContainer(self) + + // Forward mlManager's property changes to arViewModel so SwiftUI updates + self.mlManager?.objectWillChange + .sink { [weak self] _ in + self?.objectWillChange.send() + } + .store(in: &cancellables) + + + } + + + // MARK: - Bluetooth Recording Helpers (Consolidated) + func startBluetoothRecording(targetURL: URL, fps: Double) { + do { + try createFile(fileURL: targetURL) + } catch { + print("Error creating tactile file.") + } + + bluetoothManager?.startRecording(targetURL: targetURL, fps: fps) + } + + func stopBluetoothRecording() { + bluetoothManager?.stopRecording() + } + + } class AudioCaptureDelegate: NSObject, AVCaptureAudioDataOutputSampleBufferDelegate { private let writerInput: AVAssetWriterInput? @@ -883,7 +1588,6 @@ class AudioCaptureDelegate: NSObject, AVCaptureAudioDataOutputSampleBufferDelega func captureOutput(_ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer, from connection: AVCaptureConnection) { // Append audio sample buffer to the writer input guard writerInput?.isReadyForMoreMediaData == true else { - print("Not ready") return } writerInput?.append(sampleBuffer) diff --git a/AnySense/Managers/ARVisualizationManager.swift b/AnySense/Managers/ARVisualizationManager.swift new file mode 100644 index 0000000..17584ca --- /dev/null +++ b/AnySense/Managers/ARVisualizationManager.swift @@ -0,0 +1,566 @@ +// +// ARVisualizationManager.swift +// AnySense +// +// Created by Krish on 2025/2/1. +// + +import Foundation +import RealityKit +import ARKit +import simd +import UIKit + +// MARK: - Action State +enum ActionState { + case waiting // User is moving toward target + case reached // Proximity triggered + + var displayName: String { + switch self { + case .waiting: return "waiting" + case .reached: return "reached" + } + } +} + +// MARK: - Target State +enum TargetState { + case active // Red target + case reached // Green target +} + +// MARK: - AR Visualization Manager +@MainActor +class ARVisualizationManager: ObservableObject { + + // MARK: - Published Properties + @Published var isVisualizationEnabled: Bool = false + @Published var actionState: ActionState = .waiting + + // MARK: - Private Properties + private var arView: ARView? + private var worldOriginAnchor: AnchorEntity? + private var targetPose: SIMD3? + private var goalPointEntity: ModelEntity? + private var worldOrigin: SIMD3 = SIMD3(0, 0, 0) + private var hasEstablishedOrigin: Bool = false + + var debugLoggingEnabled: Bool = true + var isGripperClosed: Bool = false + var useVirtualGripper: Bool = false + var applyEndOffset: Bool = true + var endOffsetMeters: Float = 0.05 + + // MARK: - Wireframe & Target Visualization + private var wireframeEntity: Entity? + private var wireframeAnchor: AnchorEntity? + private let wireframeSize: Float = 0.018 + private let wireframeOffsetMeters: Float = 0.05 + private var wireframeVisualPosition: SIMD3? + private var activeTargetEntity: ModelEntity? + private var activeTargetPosition: SIMD3? + private let targetSize: Float = 0.012 + private var lastWireframeUpdateTime: CFTimeInterval = 0 + private let wireframeUpdateInterval: CFTimeInterval = 0.033 + + // MARK: - Initialization + init() { + log("Initialized with wireframe seek-target visualization") + } + + // MARK: - Logging Helper + private func log(_ message: String) { + print("[ARViz] \(message)") + } + + // MARK: - Setup Methods + func setupVisualization(with arView: ARView) { + self.arView = arView + log("Setup completed - using wireframe seek-target visualization") + } + + // MARK: - Recording Control Methods + func startRecordingVisualization() { + print("startRecordingVisualization called") + + guard arView != nil else { + print("ARView not available for visualization") + return + } + + print("Establishing world origin for movement tracking...") + establishWorldOrigin() + enableVisualization() + + // Reset gripper state to allow visualization + isGripperClosed = false + + print("Started movement visualization - enabled=\(isVisualizationEnabled)") + } + + func stopRecordingVisualization() { + disableVisualization() + clearAllVisualization() + resetMovementTracking() + + // Reset action state + actionState = .waiting + + // Reset gripper state + isGripperClosed = false + + print("Stopped movement visualization and reset tracking") + } + + // MARK: - World Origin & Movement Tracking + private func getCurrentCameraTransform() -> float4x4 { + return arView?.session.currentFrame?.camera.transform ?? matrix_identity_float4x4 + } + + private func getCurrentCameraPosition() -> SIMD3 { + let transform = getCurrentCameraTransform() + return SIMD3(transform.columns.3.x, transform.columns.3.y, transform.columns.3.z) + } + + private func establishWorldOrigin() { + guard let currentArView = arView else { return } + guard !hasEstablishedOrigin else { return } + + worldOrigin = getCurrentCameraPosition() + hasEstablishedOrigin = true + + var t = matrix_identity_float4x4 + t.columns.3 = SIMD4(worldOrigin.x, worldOrigin.y, worldOrigin.z, 1) + let anchor = AnchorEntity(world: t) + currentArView.scene.addAnchor(anchor) + worldOriginAnchor = anchor + + print("World origin set at: \(worldOrigin)") + } + + private func resetMovementTracking() { + hasEstablishedOrigin = false + worldOrigin = SIMD3(0, 0, 0) + goalPointEntity?.removeFromParent() + goalPointEntity = nil + worldOriginAnchor?.removeFromParent() + worldOriginAnchor = nil + } + + // MARK: - Control Methods + func enableVisualization() { + isVisualizationEnabled = true + } + + func disableVisualization() { + isVisualizationEnabled = false + clearAllVisualization() + } + + private func clearAllVisualization() { + DispatchQueue.main.async { [weak self] in + guard let self = self else { return } + + self.wireframeEntity?.removeFromParent() + self.wireframeEntity = nil + self.wireframeAnchor?.removeFromParent() + self.wireframeAnchor = nil + self.wireframeVisualPosition = nil + self.activeTargetEntity?.removeFromParent() + self.activeTargetEntity = nil + self.activeTargetPosition = nil + } + } + + // MARK: - Initialization helper + func ensureVisualizationReady() { + guard isVisualizationEnabled else { return } + if !hasEstablishedOrigin { establishWorldOrigin() } + if targetPose != nil && goalPointEntity == nil && worldOriginAnchor != nil { + updateGoalPointVisualization() + } + } + + // MARK: - Wireframe Management (Ego Visualization) + func updateWireframe(cameraRelativePosition: SIMD3) { + guard isVisualizationEnabled, !isGripperClosed else { return } + + let currentTime = CACurrentMediaTime() + guard currentTime - lastWireframeUpdateTime >= wireframeUpdateInterval else { + wireframeVisualPosition = cameraRelativePosition + checkProximityAndUpdateState() + return + } + lastWireframeUpdateTime = currentTime + + wireframeVisualPosition = cameraRelativePosition + + DispatchQueue.main.async { [weak self] in + guard let self = self, let arView = self.arView else { return } + + if self.wireframeAnchor == nil { + self.wireframeAnchor = AnchorEntity(.camera) + arView.scene.addAnchor(self.wireframeAnchor!) + } + + if self.wireframeEntity == nil { + // Ghost Arrow: Semi-transparent Blue + let ghostColor = UIColor.systemBlue.withAlphaComponent(0.4) + self.wireframeEntity = self.createArrowEntity(color: ghostColor) + self.wireframeAnchor!.addChild(self.wireframeEntity!) + + // Orient arrow to face forward (-Z) which is the default for our geometry + self.wireframeEntity!.orientation = simd_quatf(angle: 0, axis: SIMD3(1, 0, 0)) + } + + self.wireframeEntity?.position = SIMD3(0, 0, -self.wireframeOffsetMeters) + } + + checkProximityAndUpdateState() + } + + // MARK: - Proximity Check + private var lastLoggedDistance: Float = -1.0 + + private func checkProximityAndUpdateState() { + guard !isGripperClosed, + let wireframePos = wireframeVisualPosition, + let activeTarget = activeTargetEntity else { return } + + // Fix: Use local position (relative to WorldOriginAnchor) to match wireframeVisualPosition coordinate space + let targetPos = activeTarget.position + + // Use distance-based proximity + let distance = length(targetPos - wireframePos) + + // Update target color based on proximity + updateTargetColor(for: distance) + + // Debug print to verify distance + if debugLoggingEnabled && distance < 0.2 { + // throttling print to avoid spam could be good, but simple print is fine for now + // print("Dist: \(distance)") + } + + // Proximity threshold: Relaxed to 2.5cm for robust interaction + let proximityThreshold: Float = 0.025 + + let isNearby = distance <= proximityThreshold + + if isNearby { + if actionState != .reached { + print("Target Reached! (Dist: \(String(format: "%.3f", distance))m)") + actionState = .reached + // Remove the target to indicate success / clear the view + activeTargetEntity?.removeFromParent() + activeTargetEntity = nil + activeTargetPosition = nil + NotificationCenter.default.post(name: NSNotification.Name("ProximityReached"), object: nil) + } + } else { + if actionState != .waiting { + actionState = .waiting + } + } + } + + // MARK: - Color Update Based on Proximity + private func updateTargetColor(for distance: Float) { + guard let activeTarget = activeTargetEntity else { return } + + // Color transition based on distance + // Far: Red (distance > 0.12m) + // Medium: Orange/Yellow (0.03m - 0.12m) + // Close: Green (< 0.03m) + + let color: UIColor + if distance > 0.12 { + // Far - Red + color = UIColor.systemRed + } else if distance > 0.03 { + // Medium distance - interpolate from red to green + let progress = (0.12 - distance) / (0.12 - 0.03) // 0.0 to 1.0 + color = interpolateColor(from: UIColor.systemRed, to: UIColor.systemGreen, progress: progress) + } else { + // Close - Green + color = UIColor.systemGreen + } + + // Update the material on main thread - need to update all children of the arrow entity + DispatchQueue.main.async { [weak self] in + guard let self = self else { return } + let newMaterial = SimpleMaterial(color: color, isMetallic: false) + + // Update all children of the arrow entity (shaft and head) + for child in activeTarget.children { + if let modelChild = child as? ModelEntity { + modelChild.model?.materials = [newMaterial] + } + } + } + } + + // MARK: - Color Interpolation Helper + private func interpolateColor(from: UIColor, to: UIColor, progress: Float) -> UIColor { + let clampedProgress = max(0.0, min(1.0, progress)) + + var fromRed: CGFloat = 0, fromGreen: CGFloat = 0, fromBlue: CGFloat = 0, fromAlpha: CGFloat = 0 + var toRed: CGFloat = 0, toGreen: CGFloat = 0, toBlue: CGFloat = 0, toAlpha: CGFloat = 0 + + from.getRed(&fromRed, green: &fromGreen, blue: &fromBlue, alpha: &fromAlpha) + to.getRed(&toRed, green: &toGreen, blue: &toBlue, alpha: &toAlpha) + + let resultRed = fromRed + (toRed - fromRed) * CGFloat(clampedProgress) + let resultGreen = fromGreen + (toGreen - fromGreen) * CGFloat(clampedProgress) + let resultBlue = fromBlue + (toBlue - fromBlue) * CGFloat(clampedProgress) + let resultAlpha = fromAlpha + (toAlpha - fromAlpha) * CGFloat(clampedProgress) + + return UIColor(red: resultRed, green: resultGreen, blue: resultBlue, alpha: resultAlpha) + } + + // MARK: - Manual Trigger Support + func forceTargetTransition() { + activeTargetEntity?.removeFromParent() + activeTargetEntity = nil + activeTargetPosition = nil + actionState = .waiting + } + + // MARK: - Proximity Configuration + func setProximityThreshold(_ threshold: Float) { + // Allow adjusting the merge distance threshold if needed + log("Proximity threshold: \(threshold)m") + } + + // MARK: - Gripper State Control + func setGripperState(isClosed: Bool) { + let previousState = isGripperClosed + isGripperClosed = isClosed + + if isClosed && !previousState { + // Gripper just closed - hide all visualization + print("[Viz] Gripper CLOSED - Hiding visualization") + + DispatchQueue.main.async { [weak self] in + guard let self = self else { return } + + // Remove wireframe + self.wireframeEntity?.removeFromParent() + self.wireframeEntity = nil + + // Remove active target + self.activeTargetEntity?.removeFromParent() + self.activeTargetEntity = nil + self.activeTargetPosition = nil + + // Set action state to waiting + self.actionState = .waiting + } + } else if !isClosed && previousState { + print("[Viz] Gripper OPENED - Visualization enabled") + } + } + + // MARK: - Virtual Gripper Control + func toggleVirtualGripper() { + useVirtualGripper.toggle() + print("Virtual gripper: \(useVirtualGripper ? "ON" : "OFF")") + } + + func setVirtualGripper(enabled: Bool) { + useVirtualGripper = enabled + print("Virtual gripper: \(enabled ? "ON" : "OFF")") + } + + // MARK: - USB Streaming Integration + private var isUSBStreamingActive: Bool = false + + func setUSBStreamingState(isActive: Bool) { + // This integrates with the existing USB streaming system + // When USB streaming is active, virtual gripper is automatically disabled + isUSBStreamingActive = isActive + if isActive { + print("USB streaming ON - Virtual gripper automatically disabled") + } else { + print("USB streaming OFF - Virtual gripper setting: \(useVirtualGripper ? "ON" : "OFF")") + } + } + + func shouldUseVirtualGripper() -> Bool { + // Virtual gripper is only used when: + // 1. useVirtualGripper is enabled AND + // 2. USB streaming is not active + return useVirtualGripper && !isUSBStreamingActive + } + + + // MARK: - Device Pose Integration + func updateActualDevicePose(from arFrame: ARFrame) { + if isVisualizationEnabled && hasEstablishedOrigin { + let t = arFrame.camera.transform + let currentCameraPosition = SIMD3(t.columns.3.x, t.columns.3.y, t.columns.3.z) - worldOrigin + + // Calculate wireframe position in World Frame + // The wireframe is fixed at (0, 0, -wireframeOffsetMeters) in Camera Frame (Forward) + // We need to rotate this offset by the camera's orientation to get it in World Frame + + // Camera Forward vector is -Z axis (column 2 is +Z/Backward) + let cameraBackward = SIMD3(t.columns.2.x, t.columns.2.y, t.columns.2.z) + let offsetInWorld = -wireframeOffsetMeters * cameraBackward + + let cameraWorldPosition = currentCameraPosition + offsetInWorld + updateWireframe(cameraRelativePosition: cameraWorldPosition) + } + } + + func setTargetPose(_ worldPoint: SIMD3) { + targetPose = worldPoint + ensureVisualizationReady() + } + + func clearTargetPose() { + targetPose = nil + goalPointEntity?.removeFromParent() + goalPointEntity = nil + } + + // MARK: - ML Integration Method + func updatePoseFromMLOutput(_ jointActions: [Float], timestamp: CFTimeInterval = CACurrentMediaTime()) { + guard isVisualizationEnabled && hasEstablishedOrigin && !isGripperClosed else { return } + guard jointActions.count >= 6 else { return } + + let (cameraDeltaTranslation, cameraRotation) = interpretMLDirections(jointActions, timestamp: timestamp) + let cameraTransform = getCurrentCameraTransform() + let rotationWorldFromCamera = simd_float3x3( + columns: ( + SIMD3(cameraTransform.columns.0.x, cameraTransform.columns.0.y, cameraTransform.columns.0.z), + SIMD3(cameraTransform.columns.1.x, cameraTransform.columns.1.y, cameraTransform.columns.1.z), + SIMD3(cameraTransform.columns.2.x, cameraTransform.columns.2.y, cameraTransform.columns.2.z) + ) + ) + let deltaTranslation = rotationWorldFromCamera * cameraDeltaTranslation + + // Convert local rotation to world rotation: R_target = R_camera * R_delta + let currentCameraRotation = simd_quatf(cameraTransform) + let targetRotation = currentCameraRotation * cameraRotation + + let currentCameraPosition = SIMD3(cameraTransform.columns.3.x, cameraTransform.columns.3.y, cameraTransform.columns.3.z) - worldOrigin + let targetPosition = currentCameraPosition + deltaTranslation + updateTarget(position: targetPosition, rotation: targetRotation) + } + + private func interpretMLDirections(_ jointActions: [Float], timestamp: CFTimeInterval = CACurrentMediaTime()) -> (translation: SIMD3, rotation: simd_quatf) { + let action7 = Array(jointActions.prefix(7)) + let mapped = ActionTransformUtils.policyToCameraEulerAction(action7, rotationUnit: .eulerXYZ) + var translationCamera = SIMD3(mapped[0], mapped[1], mapped[2]) + + if applyEndOffset { + translationCamera += SIMD3(0, 0, -endOffsetMeters) + } + + let rotationCamera = eulerToQuaternion(roll: mapped[3], pitch: mapped[4], yaw: mapped[5]) + return (translationCamera, rotationCamera) + } + + func updateTargetCube(position: SIMD3) { + // Legacy support - default identity rotation + updateTarget(position: position, rotation: simd_quatf(angle: 0, axis: SIMD3(0, 1, 0))) + } + + // MARK: - Target Management + func updateTarget(position: SIMD3, rotation: simd_quatf) { + guard isVisualizationEnabled, !isGripperClosed else { return } + + DispatchQueue.main.async { [weak self] in + guard let self = self, let worldOriginAnchor = self.worldOriginAnchor else { return } + + if self.activeTargetEntity == nil { + // Target Arrow: Solid Red + let redColor = UIColor.systemRed.withAlphaComponent(1.0) + let newTarget = self.createArrowEntity(color: redColor) + worldOriginAnchor.addChild(newTarget) + self.activeTargetEntity = newTarget + self.actionState = .waiting + } + + self.activeTargetEntity?.position = position + self.activeTargetEntity?.orientation = rotation // Aligned with Camera Frame + self.activeTargetPosition = position + print("[Viz] Target Arrow Pos: (\(String(format: "%.3f", position.x)), \(String(format: "%.3f", position.y)), \(String(format: "%.3f", position.z)))") + } + } + + // Shared Arrow Creation (Used for both Ghost and Target) + private func createArrowEntity(color: UIColor) -> ModelEntity { + let arrowGroup = Entity() + let material = SimpleMaterial(color: color, isMetallic: false) + + // Dimensions - Adjusted for visual clarity + let length: Float = 0.025 // Shortened shaft (was 0.04) + let shaftRadius: Float = 0.004 // Thicker shaft (was 0.003) + let headRadius: Float = 0.012 // Wider head (was 0.01) + let headLength: Float = 0.015 // Head length kept similar + + // Shaft + let shaft = MeshResource.generateBox(width: shaftRadius*2, height: shaftRadius*2, depth: length) + let shaftEntity = ModelEntity(mesh: shaft, materials: [material]) + shaftEntity.position = SIMD3(0, 0, -length/2) + + // Head (using Box for simplicity, but scaled to look broadly pointer-like) + let head = MeshResource.generateBox(size: headRadius*2) + let headEntity = ModelEntity(mesh: head, materials: [material]) + headEntity.position = SIMD3(0, 0, -length - headRadius/2) + // Note: Head position shifted to attach to shaft end + + // Combine + let parent = ModelEntity() + parent.addChild(shaftEntity) + parent.addChild(headEntity) + + return parent + } + + private func eulerToQuaternion(roll: Float, pitch: Float, yaw: Float) -> simd_quatf { + let phi_2 = roll / 2.0 + let theta_2 = pitch / 2.0 + let psi_2 = yaw / 2.0 + + let cos_phi_2 = cos(phi_2) + let sin_phi_2 = sin(phi_2) + let cos_theta_2 = cos(theta_2) + let sin_theta_2 = sin(theta_2) + let cos_psi_2 = cos(psi_2) + let sin_psi_2 = sin(psi_2) + + let w = cos_phi_2 * cos_theta_2 * cos_psi_2 + sin_phi_2 * sin_theta_2 * sin_psi_2 + let x = sin_phi_2 * cos_theta_2 * cos_psi_2 - cos_phi_2 * sin_theta_2 * sin_psi_2 + let y = cos_phi_2 * sin_theta_2 * cos_psi_2 + sin_phi_2 * cos_theta_2 * sin_psi_2 + let z = cos_phi_2 * cos_theta_2 * sin_psi_2 - sin_phi_2 * sin_theta_2 * cos_psi_2 + + return simd_quatf(ix: x, iy: y, iz: z, r: w) + } + + // MARK: - Goal Point Visualization + private func updateGoalPointVisualization() { + guard let targetPose = targetPose, + let worldOriginAnchor = worldOriginAnchor, + hasEstablishedOrigin else { + goalPointEntity?.removeFromParent() + goalPointEntity = nil + return + } + + goalPointEntity?.removeFromParent() + goalPointEntity = nil + + let sphereMesh = MeshResource.generateSphere(radius: 0.02) + let goalMaterial = SimpleMaterial(color: .systemRed, isMetallic: false) + goalPointEntity = ModelEntity(mesh: sphereMesh, materials: [goalMaterial]) + + let relativePosition = targetPose - worldOrigin + goalPointEntity?.position = relativePosition + worldOriginAnchor.addChild(goalPointEntity!) + print("[Viz] Sphere (Goal) Position: (\(String(format: "%.3f", relativePosition.x)), \(String(format: "%.3f", relativePosition.y)), \(String(format: "%.3f", relativePosition.z))) | Dist: \(length(relativePosition))m") + } +} diff --git a/AnySense/Managers/ActionTransformUtils.swift b/AnySense/Managers/ActionTransformUtils.swift new file mode 100644 index 0000000..12f883d --- /dev/null +++ b/AnySense/Managers/ActionTransformUtils.swift @@ -0,0 +1,214 @@ +import Foundation +import simd + +struct ActionTransformUtils { + enum RotationUnit { + case eulerXYZ // rx, ry, rz (radians) + case axisAngle // rotation vector (axis * angle) + } + // arkit camera frame to labels.json frame + private static let P: simd_float4x4 = { + let c0 = SIMD4(-1, 0, 0, 0) + let c1 = SIMD4( 0, 0, -1, 0) + let c2 = SIMD4( 0, -1, 0, 0) + let c3 = SIMD4( 0, 0, 0, 1) + return simd_float4x4(columns: (c0, c1, c2, c3)) + }() + + // labels.json frame to robot frame + private static let Z90: simd_float4x4 = { + let c0 = SIMD4( 0, 1, 0, 0) + let c1 = SIMD4(-1, 0, 0, 0) + let c2 = SIMD4( 0, 0, 1, 0) + let c3 = SIMD4( 0, 0, 0, 1) + return simd_float4x4(columns: (c0, c1, c2, c3)) + }() + + // MARK: - Public Entry + // The policy produces an action tensor in its own convention (labels.json). + // We first map it into the iPhone CAMERA frame (policy→camera), then convert + // CAMERA→ROBOT using the Z90 rotation to match the robot's execution frame. + // Input: policy action [tx, ty, tz, r1, r2, r3, gripper] + // Output: robot-frame [tx, ty, tz, rx, ry, rz, gripper] (Euler xyz) + static func toRobotActions(_ policyAction7: [Float], rotationUnit: RotationUnit = .eulerXYZ) -> [Float] { + guard policyAction7.count >= 7 else { return policyAction7 } + + // 1) Policy (Labels) → Camera Frame + let camEulerAction = policyToCameraEulerAction(policyAction7, rotationUnit: rotationUnit) + let gr = camEulerAction[6] + + // Build 4x4 Transform in Camera Frame + let T_c = buildTransform(translation: SIMD3(camEulerAction[0], camEulerAction[1], camEulerAction[2]), eulerXYZ: SIMD3(camEulerAction[3], camEulerAction[4], camEulerAction[5])) + + // 2) Camera → Robot: T_r = Z90 @ T_c @ Z90.T + // We skip the intermediate P transform (Pt * T_c * P) because T_c is already + // in the correct Camera frame, and Z90 maps Camera(Right,Up,Back) → Robot(Down,Right,Back). + let Zt = simd_transpose(Z90) + let T_r = Z90 * T_c * Zt + + // 3) 4x4 → robot action (Euler xyz) + let rxyz = eulerXYZ(from: T_r) + let t = translation(from: T_r) + + return [t.x, t.y, t.z, rxyz.x, rxyz.y, rxyz.z, gr] + } + + // Policy → CAMERA mapping in a single place so viz and robot stay consistent. + // Mapping based on frame definitions: + // - Policy (Labels): [x: Left, y: Forward, z: Down] + // - Camera (ARKit): [x: Right, y: Up, z: Back] + // + // Transformation: + // - Policy x (Left) → Camera -x (Right) + // - Policy y (Forward) → Camera -z (Back) + // - Policy z (Down) → Camera -y (Up) + static func policyToCameraEulerAction(_ policyAction7: [Float], rotationUnit: RotationUnit = .eulerXYZ) -> [Float] { + guard policyAction7.count >= 7 else { return policyAction7 } + let ml_x = policyAction7[0] // Left + let ml_y = policyAction7[1] // Forward + let ml_z = policyAction7[2] // Down + let r1 = policyAction7[3] + let r2 = policyAction7[4] + let r3 = policyAction7[5] + let gr = policyAction7[6] + + // Translation mapping (Labels → CAMERA) + // cam_x = -label_x (Left -> Right) + // cam_y = -label_z (Down -> Up) + // cam_z = -label_y (Forward -> Back) + let cam_t = SIMD3(-ml_x, -ml_z, -ml_y) + + // Rotation mapping → always return Euler xyz in CAMERA frame + var cam_euler: SIMD3 + switch rotationUnit { + case .eulerXYZ: + cam_euler = SIMD3(-r1, -r2, -r3) + case .axisAngle: + let R_cam = rotationMatrixFromAxisAngle(axisAngle: SIMD3(r1, r2, r3)) + cam_euler = eulerXYZ(from: matrixFromRotationAndTranslation(R_cam, t: SIMD3(0,0,0))) + } + + return [cam_t.x, cam_t.y, cam_t.z, cam_euler.x, cam_euler.y, cam_euler.z, gr] + } + + // MARK: - Builders + private static func buildTransform(translation t: SIMD3, eulerXYZ r: SIMD3) -> simd_float4x4 { + let R = rotationMatrixXYZ(rx: r.x, ry: r.y, rz: r.z) + var T = matrix_identity_float4x4 + T.columns.0 = SIMD4(R.columns.0.x, R.columns.0.y, R.columns.0.z, 0) + T.columns.1 = SIMD4(R.columns.1.x, R.columns.1.y, R.columns.1.z, 0) + T.columns.2 = SIMD4(R.columns.2.x, R.columns.2.y, R.columns.2.z, 0) + T.columns.3 = SIMD4(t.x, t.y, t.z, 1) + return T + } + + private static func matrixFromRotationAndTranslation(_ R: simd_float3x3, t: SIMD3) -> simd_float4x4 { + var T = matrix_identity_float4x4 + T.columns.0 = SIMD4(R.columns.0.x, R.columns.0.y, R.columns.0.z, 0) + T.columns.1 = SIMD4(R.columns.1.x, R.columns.1.y, R.columns.1.z, 0) + T.columns.2 = SIMD4(R.columns.2.x, R.columns.2.y, R.columns.2.z, 0) + T.columns.3 = SIMD4(t.x, t.y, t.z, 1) + return T + } + + // Rotation matrix for Euler 'xyz' (radians) + // Using standard formula: + // R = [[ cy*cz, -cy*sz, sy ], + // [ cx*sz + cz*sx*sy, cx*cz - sx*sy*sz, -cy*sx ], + // [ sx*sz - cx*cz*sy, cz*sx + cx*sy*sz, cx*cy ]] + private static func rotationMatrixXYZ(rx: Float, ry: Float, rz: Float) -> simd_float3x3 { + let cx = cos(rx), sx = sin(rx) + let cy = cos(ry), sy = sin(ry) + let cz = cos(rz), sz = sin(rz) + + let r00 = cy * cz + let r01 = -cy * sz + let r02 = sy + + let r10 = cx * sz + cz * sx * sy + let r11 = cx * cz - sx * sy * sz + let r12 = -cy * sx + + let r20 = sx * sz - cx * cz * sy + let r21 = cz * sx + cx * sy * sz + let r22 = cx * cy + + let c0 = SIMD3(r00, r10, r20) + let c1 = SIMD3(r01, r11, r21) + let c2 = SIMD3(r02, r12, r22) + return simd_float3x3(columns: (c0, c1, c2)) + } + + // Axis-angle (rotation vector) → rotation matrix via Rodrigues' formula + private static func rotationMatrixFromAxisAngle(axisAngle v: SIMD3) -> simd_float3x3 { + let theta = sqrt(v.x*v.x + v.y*v.y + v.z*v.z) + let eps: Float = 1e-8 + if theta < eps { + return simd_float3x3(diagonal: SIMD3(1,1,1)) + } + let k = SIMD3(v.x/theta, v.y/theta, v.z/theta) + let K = simd_float3x3(rows: [ + SIMD3( 0, -k.z, k.y), + SIMD3( k.z, 0, -k.x), + SIMD3(-k.y, k.x, 0) + ]) + let I = simd_float3x3(diagonal: SIMD3(1,1,1)) + // R = I + sinθ K + (1 - cosθ) K^2 + let K2 = simd_mul(K, K) + let R = I + sin(theta) * K + (1 - cos(theta)) * K2 + return R + } + + // Rotation matrix for Euler 'xyz' (radians) + private static func eulerXYZ(from T: simd_float4x4) -> SIMD3 { + // Reconstruct row-major elements from column-major storage + let R00 = T.columns.0.x, R01 = T.columns.1.x, R02 = T.columns.2.x + let R11 = T.columns.1.y, R12 = T.columns.2.y + let R21 = T.columns.1.z, R22 = T.columns.2.z + + let y = asin(clamp(R02, -1.0, 1.0)) + let cy = cos(y) + let eps: Float = 1e-6 + let x: Float + let z: Float + if abs(cy) > eps { + x = atan2(-R12, R22) + z = atan2(-R01, R00) + } else { + // Gimbal lock + x = atan2(R21, R11) + z = 0.0 + } + return SIMD3(x, y, z) + } + + private static func translation(from T: simd_float4x4) -> SIMD3 { + return SIMD3(T.columns.3.x, T.columns.3.y, T.columns.3.z) + } + + private static func clamp(_ v: Float, _ lo: Float, _ hi: Float) -> Float { + return max(lo, min(hi, v)) + } + + // MARK: - Debug helpers + static func debugTransformReport(_ policyAction7: [Float], rotationUnit: RotationUnit = .eulerXYZ) -> String { + guard policyAction7.count >= 7 else { return "" } + let cam = policyToCameraEulerAction(policyAction7, rotationUnit: rotationUnit) + let T_c = buildTransform(translation: SIMD3(cam[0], cam[1], cam[2]), eulerXYZ: SIMD3(cam[3], cam[4], cam[5])) + + let Zt = simd_transpose(Z90) + let T_r = Z90 * T_c * Zt + + let robotEuler = eulerXYZ(from: T_r) + let robotT = translation(from: T_r) + func fmt(_ f: Float) -> String { String(format: "%.4f", f) } + func fmt3(_ v: SIMD3) -> String { "(\(fmt(v.x)),\(fmt(v.y)),\(fmt(v.z)))" } + return [ + "policy(Lab) t,r: \(fmt3(SIMD3(policyAction7[0], policyAction7[1], policyAction7[2]))) \(fmt3(SIMD3(policyAction7[3], policyAction7[4], policyAction7[5])))", + "camera(Cam) t,r: \(fmt3(SIMD3(cam[0], cam[1], cam[2]))) \(fmt3(SIMD3(cam[3], cam[4], cam[5])))", + "robot(Exec) t,r: \(fmt3(robotT)) \(fmt3(robotEuler))", + ].joined(separator: "\n") + } +} + + diff --git a/AnySense/Managers/BluetoothManager.swift b/AnySense/Managers/BluetoothManager.swift index fd4f469..72a6928 100644 --- a/AnySense/Managers/BluetoothManager.swift +++ b/AnySense/Managers/BluetoothManager.swift @@ -21,27 +21,45 @@ class BluetoothManager : NSObject, ObservableObject{ super.init() self.centralManager = CBCentralManager(delegate: self, queue: .main) } - + + // Cleanup + deinit { + // Clean up CADisplayLink to prevent retain cycles + displayLink?.invalidate() + displayLink = nil + + // Clean up file handles + try? BTFileHandle?.close() + BTFileHandle = nil + + // Disconnect from any connected peripherals + disconnectFromDevice() + + // Stop scanning + centralManager?.stopScan() + centralManager = nil + + // BluetoothManager deinitialized + } } extension BluetoothManager: CBCentralManagerDelegate{ func centralManagerDidUpdateState(_ central: CBCentralManager) { switch central.state { case .poweredOff: - print("Is Powered Off.") + break case .poweredOn: - print("Is Powered On.") self.scan() case .unsupported: - print("Is Unsupported.") + break case .unauthorized: - print("Is Unauthorized.") + break case .unknown: - print("Unknown") + break case .resetting: - print("Resetting") + break @unknown default: - print("Error") + break } } func scan() -> Void{ @@ -87,7 +105,6 @@ extension BluetoothManager: CBCentralManagerDelegate{ // Disconnect if already connected to another peripheral if let currentPeripheral = matchedPeripheral, currentPeripheral.identifier != peripheral.identifier { - print("🔄 Disconnecting previous peripheral: \(currentPeripheral.name ?? "Unknown")") central.cancelPeripheralConnection(currentPeripheral) } @@ -124,10 +141,7 @@ extension BluetoothManager: CBCentralManagerDelegate{ extension BluetoothManager: CBPeripheralDelegate{ func peripheral(_ peripheral: CBPeripheral, didDiscoverServices error: Error?) { - print("*******************************************************") - if ((error) != nil) { - print("Error discovering services: \(error!.localizedDescription)") return } guard let services = peripheral.services else { @@ -137,7 +151,6 @@ extension BluetoothManager: CBPeripheralDelegate{ for service in services { peripheral.discoverCharacteristics(nil, for: service) } - print("Discovered Services: \(services)") } func peripheral(_ peripheral: CBPeripheral, didDiscoverCharacteristicsFor service: CBService, error: Error?) { @@ -145,27 +158,18 @@ extension BluetoothManager: CBPeripheralDelegate{ return } - print("Found \(characteristics.count) characteristics.") - // NOTE: We will simply take the first Rx characteristic and use it for reading for characteristic in characteristics { if characteristic.properties.contains(.notify) || characteristic.properties.contains(.indicate) { - print("This characteristic is Rx (Receiving data)") rxCharacteristic = characteristic peripheral.setNotifyValue(true, for: rxCharacteristic!) peripheral.readValue(for: characteristic) - print("RX Characteristic: \(rxCharacteristic.uuid)") break } - if characteristic.properties.contains(.write) || characteristic.properties.contains(.writeWithoutResponse) { - print("This characteristic is Tx (Transmitting data)") - // TODO: Code for handling Tx characteristics goes here - } } } func peripheral(_ peripheral: CBPeripheral, didUpdateValueFor characteristic: CBCharacteristic, error: Error?) { if let error = error { - print("Error updating characteristic: \(error.localizedDescription)") return } } @@ -177,19 +181,19 @@ extension BluetoothManager: CBPeripheralManagerDelegate { func peripheralManagerDidUpdateState(_ peripheral: CBPeripheralManager) { switch peripheral.state { case .poweredOn: - print("Peripheral Is Powered On.") + break case .unsupported: - print("Peripheral Is Unsupported.") + break case .unauthorized: - print("Peripheral Is Unauthorized.") + break case .unknown: - print("Peripheral Unknown") + break case .resetting: - print("Peripheral Resetting") + break case .poweredOff: - print("Peripheral Is Powered Off.") + break @unknown default: - print("Error") + break } } @@ -200,7 +204,7 @@ extension BluetoothManager: CBPeripheralManagerDelegate { try self.BTFileHandle?.seekToEnd() } catch { - print("Error opening BTFileHandle") + // Error opening BTFileHandle } displayLink = CADisplayLink(target: self, selector: #selector(recordSingleData)) @@ -214,7 +218,7 @@ extension BluetoothManager: CBPeripheralManagerDelegate { do { try BTFileHandle?.close() } catch { - print("Error closing pose file") + // Error closing pose file } } diff --git a/AnySense/Managers/MLInferenceManager.swift b/AnySense/Managers/MLInferenceManager.swift new file mode 100644 index 0000000..82ae20e --- /dev/null +++ b/AnySense/Managers/MLInferenceManager.swift @@ -0,0 +1,1354 @@ +// +// MLInferenceManager.swift +// AnySense +// +// Created by Krish on 2025/2/1. +// + +import Foundation +import ImageIO +import CoreML +import Vision +import CoreVideo +import QuartzCore +import Combine +import RealityKit +import CoreImage +import ARKit +import simd +import UIKit +import Accelerate + +// MARK: - ML Inference Results +struct InferenceResult { + let jointPositions: [Float] // 7 joint action values + let inferenceTime: TimeInterval +} + +// MARK: - ML Inference Manager +@MainActor +class MLInferenceManager: ObservableObject { + + // MARK: - Published Properties + @Published var latestResult: InferenceResult? + @Published var lastResult: InferenceResult? + @Published var isInferencePendingUI: Bool = false + + @MainActor + func clearPendingState() { + isInferencePending = false + isInferencePendingUI = false + } + @Published var isInferenceEnabled: Bool = false + @Published var inferenceFrequency: InferenceFrequency = .medium + @Published var currentGoalPoint: simd_float3? + @Published var modelMetadata: ModelMetadata? + @Published var isModelLoading: Bool = false // Tracks loading and warm-up + + // MARK: - Private Properties + private var model: MLModel? + private var lastInferenceTime: CFTimeInterval = 0 + private var inferenceQueue = DispatchQueue(label: "MLInferenceQueue", qos: .userInitiated) + + // MARK: - Goal Point Management + + // MARK: - Goal Point Management + + func setGoalPoint(_ point: simd_float3) { + self.currentGoalPoint = point + arVisualizationManager?.setTargetPose(point) + // Reset goal frame count when new goal is set + goalFrameCount = 0 + } + + // Goal conditioning mode (point-conditioned models use 3D goals) + private var goalDimension: Int = 3 + + // Track how many times goal has been used (for first-frame offset) + private var goalFrameCount: Int = 0 + + // MARK: - Frame Buffering for Temporal Models + private struct FrameBufferEntry { + let mlArray: MLMultiArray // Pre-processed [1,3,H,W] frame + let goalPoint: [Float]? // Goal at time of capture + } + private var frameBuffer: [FrameBufferEntry] = [] + private var maxBufferSize: Int = 3 // Always maintain 3 action trigger frames + + // Current frame processing (still processes every frame for potential storage) + private var currentFrameEntry: FrameBufferEntry? + + // MARK: - Proximity-based Inference Control + private var proximityReached: Bool = false + private var isInferencePending: Bool = false + private var hasRunFirstInference: Bool = false // Track if we've run initial inference + + // MARK: - Model Management + private var modelManager: ModelManager + private var cancellables = Set() + + // MARK: - AR Visualization Integration + weak var arVisualizationManager: ARVisualizationManager? + + // MARK: - Frame Processing (Taken from ARViewContainer) + private let ciContext: CIContext + private var modelInputSize = CGSize(width: 224, height: 224) + + // MARK: - Shared Buffers (Reused to avoid allocations) + private var sharedOutputPixelBuffer: CVPixelBuffer? // Reused CVPixelBuffer for frame processing + private var sharedMLMultiArrayBuffer: MLMultiArray? // Pre-allocated MLMultiArray for frame conversion + private var cachedGripperOverlays: [String: CIImage] = [:] // Cached transformed gripper overlays + + // MARK: - Gripper Overlay Properties + private var gripperOpenCIImage: CIImage? + private var gripperClosedCIImage: CIImage? + private var gripperOpenUIImage: UIImage? + private var gripperClosedUIImage: UIImage? + private var gripperOverlayBuffer: vImage_Buffer? + private var isUSBStreamingActive: Bool = false + private var currentGripperValue: Float = 1.0 // Track latest gripper value + @Published var enableGripperOverlay: Bool = true // Default enabled (for model input) + @Published var showGripperOverlayOnScreen: Bool = true { // Show overlay on AR view + didSet { + Task { @MainActor in + updateGripperOverlayDisplay() + } + } + } + @Published var currentGripperOverlayImage: UIImage? // Current overlay image for display + @Published var saveDebugFrames: Bool = false // For testing + + // MARK: - Inference Settings + enum InferenceFrequency: CaseIterable { + case high, medium, low, minute + + var interval: TimeInterval { + switch self { + case .high: return 0.0 + case .medium: return 1.0 + case .low: return 10.0 + case .minute: return 60.0 + } + } + + var displayName: String { + switch self { + case .high: return "High (30 FPS)" + case .medium: return "Medium (1 Hz)" + case .low: return "Low (0.1 FPS)" + case .minute: return "Minute (1/min)" + } + } + } + + // MARK: - Transform/Debug Settings + var rotationUnit: ActionTransformUtils.RotationUnit = .eulerXYZ + var enableTransformDebug: Bool = true + var debugLoggingEnabled: Bool = true // Enable detailed logging + // Apply server-style image orientation (Record3D publisher does rotations/mirrors) + var applyServerImageOrientation: Bool = false + + // Initialization + init(modelManager: ModelManager) { + self.modelManager = modelManager + self.ciContext = CIContext() + initializeSharedBuffers() // Initialize shared buffers early + loadActiveModel() + loadGripperOverlay() + + // Listen for active model changes + modelManager.$activeModel + .receive(on: DispatchQueue.main) + .sink { [weak self] (activeModel: ModelInfo?) in + self?.loadActiveModel() + } + .store(in: &cancellables) + + // Listen for proximity reached notifications from ARVisualizationManager + NotificationCenter.default.addObserver( + forName: NSNotification.Name("ProximityReached"), + object: nil, + queue: .main + ) { [weak self] _ in + self?.handleProximityReached() + } + } + + // MARK: - Gripper Overlay Methods + private func loadGripperOverlay() { + // Load open gripper (default/original) + if let openImage = UIImage(named: "gripper_overlay") { + gripperOpenCIImage = CIImage(image: openImage) + gripperOpenUIImage = openImage + print("Open gripper overlay loaded:") + print(" - Size: \(openImage.size)") + print(" - Scale: \(openImage.scale)") + print(" - CIImage extent: \(gripperOpenCIImage?.extent ?? .zero)") + } else { + print("Warning: Could not load gripper_overlay (open) image from assets") + } + + // Load closed gripper + if let closedImage = UIImage(named: "gripper_closed") { + gripperClosedCIImage = CIImage(image: closedImage) + gripperClosedUIImage = closedImage + print("Closed gripper overlay loaded:") + print(" - Size: \(closedImage.size)") + print(" - Scale: \(closedImage.scale)") + print(" - CIImage extent: \(gripperClosedCIImage?.extent ?? .zero)") + } else { + print("Warning: Could not load gripper_closed image from assets") + } + + // Setup vImage buffer with open gripper as default + if let openImage = UIImage(named: "gripper_overlay") { + setupGripperOverlayBuffer(from: openImage) + } + + // Set initial overlay image for display + Task { @MainActor in + updateGripperOverlayDisplay() + } + } + + @MainActor + private func updateGripperOverlayDisplay() { + guard shouldShowGripperOverlayOnScreen() else { + currentGripperOverlayImage = nil + return + } + + let isGripperClosed = currentGripperValue < 0.7 + let baseImage = isGripperClosed ? gripperClosedUIImage : gripperOpenUIImage + + print("DEBUG: Updating gripper overlay - value: \(String(format: "%.3f", currentGripperValue)), closed: \(isGripperClosed)") + + // Update published property (automatically triggers objectWillChange) + currentGripperOverlayImage = baseImage + print("DEBUG: Gripper overlay image updated: \(isGripperClosed ? "CLOSED" : "OPEN")") + } + + private func setupGripperOverlayBuffer(from uiImage: UIImage) { + guard let cgImage = uiImage.cgImage else { return } + + let width = cgImage.width + let height = cgImage.height + let bytesPerPixel = 4 + let bytesPerRow = width * bytesPerPixel + let bufferLength = height * bytesPerRow + + guard let data = malloc(bufferLength) else { + print("Warning: Could not allocate memory for gripper overlay buffer") + return + } + + var buffer = vImage_Buffer( + data: data, + height: vImagePixelCount(height), + width: vImagePixelCount(width), + rowBytes: bytesPerRow + ) + + // Convert CGImage to vImage buffer + var format = vImage_CGImageFormat( + bitsPerComponent: 8, + bitsPerPixel: 32, + colorSpace: nil, + bitmapInfo: CGBitmapInfo(rawValue: CGImageAlphaInfo.premultipliedLast.rawValue), + version: 0, + decode: nil, + renderingIntent: .defaultIntent + ) + + let error = vImageBuffer_InitWithCGImage(&buffer, &format, nil, cgImage, vImage_Flags(kvImageNoFlags)) + if error == kvImageNoError { + gripperOverlayBuffer = buffer + } else { + free(data) + print("Warning: Failed to create vImage buffer for gripper overlay: \(error)") + } + } + + func setUSBStreamingState(isActive: Bool) { + isUSBStreamingActive = isActive + let inputOverlay = shouldApplyGripperOverlayToModelInput() ? "ENABLED" : "DISABLED" + let screenOverlay = shouldShowGripperOverlayOnScreen() ? "ON" : "OFF" + print("USB streaming state: \(isActive ? "ON" : "OFF") - Overlay input: \(inputOverlay), Overlay display: \(screenOverlay)") + Task { @MainActor in + updateGripperOverlayDisplay() + } + } + + private func shouldApplyGripperOverlayToModelInput() -> Bool { + // Use gripper overlay in model input when USB streaming is OFF (virtual gripper proxy) + return enableGripperOverlay && !isUSBStreamingActive + } + + private func shouldShowGripperOverlayOnScreen() -> Bool { + // On-screen overlay is user-controlled; keep it off during USB streaming + return showGripperOverlayOnScreen && !isUSBStreamingActive + } + + private func getCurrentGripperOverlay() -> CIImage? { + // Use closed gripper when gripper value < 0.6, otherwise open gripper + let isGripperClosed = currentGripperValue < 0.7 + if saveDebugFrames { + print("Gripper state: \(String(format: "%.3f", currentGripperValue)) → \(isGripperClosed ? "CLOSED" : "OPEN")") + } + + if isGripperClosed { + return gripperClosedCIImage + } else { + return gripperOpenCIImage + } + } + + private func applyGripperOverlay(to image: CIImage) -> CIImage { + guard shouldApplyGripperOverlayToModelInput() else { + if saveDebugFrames { + print("DEBUG: Gripper overlay skipped - enableGripperOverlay: \(enableGripperOverlay), isUSBStreaming: \(isUSBStreamingActive)") + } + return image + } + + guard let gripperOverlay = getCurrentGripperOverlay() else { + if saveDebugFrames { + print("DEBUG: No gripper overlay image available - openCIImage: \(gripperOpenCIImage != nil), closedCIImage: \(gripperClosedCIImage != nil)") + } + return image + } + + // Check cache first to avoid expensive transform operations + let cacheKey = "\(currentGripperValue < 0.7 ? "closed" : "open")_\(Int(image.extent.width))x\(Int(image.extent.height))" + if let cachedOverlay = cachedGripperOverlays[cacheKey] { + return applyCachedGripperOverlay(to: image, overlay: cachedOverlay) + } + + if saveDebugFrames { + print("DEBUG: Applying gripper overlay - value: \(String(format: "%.3f", currentGripperValue))") + } + let result = applyGripperOverlayCoreImage(to: image, overlay: gripperOverlay) + + // Cache the transformed overlay for reuse + if cachedGripperOverlays.count < 10 { // Limit cache size + let transformedOverlay = createTransformedGripperOverlay(gripperOverlay, imageSize: image.extent.size) + cachedGripperOverlays[cacheKey] = transformedOverlay + } + + return result + } + + // MARK: - Cached Gripper Overlay Methods + private func createTransformedGripperOverlay(_ gripperOverlay: CIImage, imageSize: CGSize) -> CIImage { + // Apply same transformations as camera frames: scale to fit, then rotate if needed + let scale = min(imageSize.width / gripperOverlay.extent.width, imageSize.height / gripperOverlay.extent.height) + + // Build combined transform: scale -> optional orientation -> rotation -> translation + var transform = CGAffineTransform(scaleX: scale, y: scale) + + // Apply same orientation as camera frames + if applyServerImageOrientation { + transform = transform.concatenating(CGAffineTransform(rotationAngle: CGFloat.pi)) + } + + // Additional +90 degree rotation to align gripper direction with viewpoint + transform = transform.concatenating(CGAffineTransform(rotationAngle: CGFloat.pi / 2)) + + // Apply combined transform + var transformedOverlay = gripperOverlay.transformed(by: transform) + + // After rotation, translate back to origin for proper overlay positioning + let rotatedExtent = transformedOverlay.extent + transformedOverlay = transformedOverlay.transformed(by: CGAffineTransform(translationX: -rotatedExtent.origin.x, y: -rotatedExtent.origin.y)) + + return transformedOverlay + } + + private func applyCachedGripperOverlay(to image: CIImage, overlay cachedOverlay: CIImage) -> CIImage { + guard let compositeFilter = CIFilter(name: "CISourceOverCompositing") else { + return image + } + + compositeFilter.setValue(cachedOverlay, forKey: kCIInputImageKey) + compositeFilter.setValue(image, forKey: kCIInputBackgroundImageKey) + + return compositeFilter.outputImage ?? image + } + + private func applyGripperOverlayCoreImage(to image: CIImage, overlay gripperOverlay: CIImage) -> CIImage { + let imageSize = image.extent.size + + if saveDebugFrames { + print("DEBUG: Starting composite - image: \(imageSize)") + print("DEBUG: Original overlay extent: \(gripperOverlay.extent)") + } + + let transformedOverlay = createTransformedGripperOverlay(gripperOverlay, imageSize: imageSize) + + if saveDebugFrames { + print("DEBUG: Server orientation: \(applyServerImageOrientation)") + print("DEBUG: Final overlay extent: \(transformedOverlay.extent)") + } + + guard let compositeFilter = CIFilter(name: "CISourceOverCompositing") else { + if saveDebugFrames { + print("DEBUG: Failed to create CISourceOverCompositing filter") + } + return image + } + + compositeFilter.setValue(transformedOverlay, forKey: kCIInputImageKey) + compositeFilter.setValue(image, forKey: kCIInputBackgroundImageKey) + + if let result = compositeFilter.outputImage { + if saveDebugFrames { + print("DEBUG: Composite successful, result extent: \(result.extent)") + } + return result + } else { + if saveDebugFrames { + print("DEBUG: Composite filter returned nil") + } + return image + } + } + + // MARK: - Debug Frame Saving + private func saveDebugFrame(_ image: CIImage, prefix: String) { + guard saveDebugFrames else { return } + + let timestamp = Int(Date().timeIntervalSince1970 * 1000) + let filename = "\(prefix)_\(timestamp).png" + + // Get Documents directory + guard let documentsDirectory = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first else { + print("Warning: Could not access Documents directory") + return + } + + let fileURL = documentsDirectory.appendingPathComponent(filename) + + // Convert CIImage to Data + guard let colorSpace = CGColorSpace(name: CGColorSpace.sRGB) else { return } + guard let data = ciContext.pngRepresentation(of: image, format: .RGBA8, colorSpace: colorSpace) else { + print("Warning: Could not create PNG data for debug frame") + return + } + + do { + try data.write(to: fileURL) + print("Debug frame saved: \(fileURL.lastPathComponent)") + } catch { + print("Warning: Could not save debug frame: \(error)") + } + } + + // MARK: - Model Loading + private func loadActiveModel() { + guard let activeModel = modelManager.activeModel, + activeModel.compilationStatus.isCompiled else { + // No active compiled model available + model = nil + modelMetadata = nil + frameBuffer.removeAll() + hasRunFirstInference = false + return + } + + // Indicate loading started + DispatchQueue.main.async { + self.isModelLoading = true + } + + // Perform loading on background thread to keep UI responsive + Task.detached(priority: .userInitiated) { [weak self] in + guard let self = self else { return } + + do { + let loadedModel = try await self.modelManager.loadModelAsync(for: activeModel) + + // Extract model metadata for type detection + let metadata = try ModelMetadata(from: loadedModel) + + await MainActor.run { + self.model = loadedModel + self.modelMetadata = metadata + + // Always maintain 3-frame rolling buffer + self.frameBuffer.removeAll() + self.hasRunFirstInference = false // Reset for new model + + print("Model loaded: \(activeModel.name)") + print(" Temporal frames: \(metadata.temporalFrames)") + print(" Goal conditioning: \(metadata.requiresGoalPoint)") + print(" Buffer size: \(self.maxBufferSize)") + + self.modelInputSize = CGSize(width: 224, height: 224) + self.initializeSharedBuffers() + self.goalDimension = 3 + + // Mark loading as complete + self.isModelLoading = false + } + } catch { + print("Error loading model: \(error)") + await MainActor.run { + self.model = nil + self.modelMetadata = nil + self.frameBuffer.removeAll() + self.hasRunFirstInference = false + self.isModelLoading = false + } + } + } + } + + // MARK: - Shared Buffer Initialization + private func initializeSharedBuffers() { + // Initialize shared output pixel buffer (224x224 ARGB) - reused for all frame processing + if sharedOutputPixelBuffer == nil { + let width = 224 + let height = 224 + let attributes: [String: Any] = [ + kCVPixelBufferPixelFormatTypeKey as String: kCVPixelFormatType_32ARGB, + kCVPixelBufferWidthKey as String: width, + kCVPixelBufferHeightKey as String: height + ] + var pixelBuffer: CVPixelBuffer? + let status = CVPixelBufferCreate( + kCFAllocatorDefault, + width, height, + kCVPixelFormatType_32ARGB, + attributes as CFDictionary, + &pixelBuffer + ) + if status == kCVReturnSuccess { + sharedOutputPixelBuffer = pixelBuffer + } + } + + // Initialize shared MLMultiArray buffer (224x224x3) - reused for frame conversion + if sharedMLMultiArrayBuffer == nil { + do { + sharedMLMultiArrayBuffer = try MLMultiArray(shape: [1, 3, 224, 224], dataType: .float32) + } catch { + print("Warning: Could not create shared MLMultiArray buffer: \(error)") + } + } + } + + // MARK: - Model Management Integration + var hasAvailableModel: Bool { + return modelManager.hasAvailableModel && model != nil + } + + var requiresGoalPoint: Bool { + return modelMetadata?.requiresGoalPoint ?? false + } + + var isPointConditioned: Bool { + return true // All models are point-conditioned now + } + + var activeModelName: String? { + return modelManager.activeModel?.name + } + + var isUsingUploadedModel: Bool { + return modelManager.hasCompiledModel + } + + // MARK: - Goal Point Management + + + func clearGoalPoint() { + currentGoalPoint = nil + goalFrameCount = 0 + } + + // MARK: - Odometry Integration Methods (removed) + + /// Set goal dimension (2D or 3D goal conditioning) + func setGoalDimension(_ dimension: Int) { + goalDimension = dimension + } + + /// Get goal point for model input (2D or 3D based on goal_dim) + func getGoalPointForModel() -> [Float]? { + // Standardize 3D goals to labels.json frame from the CURRENT camera frame: [-x, z, y] + if goalDimension == 3 { + guard let session = getARSession(), let frame = session.currentFrame else { + return nil + } + + + // Use the world-locked goal point (ARKit keeps it fixed in world space) + guard let p_w = currentGoalPoint else { return nil } + // World (ARKit) → Camera + let T_wc = frame.camera.transform + let T_cw = simd_inverse(T_wc) + let p_c4 = simd_mul(T_cw, simd_float4(p_w.x, p_w.y, p_w.z, 1.0)) + // camera: x right, y up, z back + // labels: x left, y forward, z down + // Mapping: x = -x_cam, y = -z_cam, z = -y_cam + // Add 0.02 offset only on first frame + let yOffset: Float = (goalFrameCount == 0) ? 0.02 : 0.0 + let goalArr = [-p_c4.x, -p_c4.z + yOffset, -p_c4.y] + goalFrameCount += 1 + return goalArr + } + // If model expects 2D goals, return nil since we only support 3D goals now + return nil + } + + // MARK: - AR Session Access + weak var arViewContainer: ARViewModel? + + private func getARSession() -> ARSession? { + // Get ARSession from the connected ARViewContainer + return arViewContainer?.getARSession() + } + + func setARViewContainer(_ container: ARViewModel) { + self.arViewContainer = container + } + + // MARK: - Proximity Handler + private func handleProximityReached() { + print("[ML] Goal Reached (Proximity Trigger Received)") + guard !isInferencePending else { + print("[MLInference] Proximity reached but inference already pending - skipping") + return + } + proximityReached = true + print("[MLInference] Proximity reached - inference will trigger with next frame (firstInference: \(hasRunFirstInference))") + } + + // MARK: - Frequency-based Inference Helper + private func shouldRunBasedOnFrequency(_ timestamp: CFTimeInterval) -> Bool { + let interval = inferenceFrequency.interval + if interval == 0.0 { + return true // High frequency - every frame + } + return (timestamp - lastInferenceTime) >= interval + } + + // MARK: - Inference Methods (Using existing frame processing patterns) + func performInference(on pixelBuffer: CVPixelBuffer, arFrame: ARFrame?, timestamp: CFTimeInterval = CACurrentMediaTime()) { + // Update device pose for visualization (optional) + if let frame = arFrame { + arVisualizationManager?.updateActualDevicePose(from: frame) + } + + performInference(on: pixelBuffer, timestamp: timestamp) + } + + func performInference(on pixelBuffer: CVPixelBuffer, timestamp: CFTimeInterval = CACurrentMediaTime()) { + guard isInferenceEnabled, + let metadata = modelMetadata else { + return + } + + // Check if goal point is required but not set + if metadata.requiresGoalPoint && currentGoalPoint == nil { + return // Skip until goal point is set + } + + // Process current frame for potential action storage + do { + let processedFrame = try processFrame(pixelBuffer, targetSize: CGSize(width: 224, height: 224), debugPrefix: "current") + let goalPointArray = metadata.requiresGoalPoint ? getGoalPointForModel() : nil + currentFrameEntry = FrameBufferEntry(mlArray: processedFrame, goalPoint: goalPointArray) + } catch { + print("ERROR: Failed to process current frame: \(error)") + return + } + + // Store current frame entry for buffering + guard let currentEntry = currentFrameEntry else { + print("ERROR: No current frame entry available") + return + } + + let isFirstInference = !hasRunFirstInference + let shouldRunInference: Bool + + if isUSBStreamingActive { + // USB ON: Continuously add frames to buffer (rolling 3-frame window) + frameBuffer.append(currentEntry) + if frameBuffer.count > maxBufferSize { + frameBuffer.removeFirst() + } + print("[MLInference] USB Mode: Frame added to rolling buffer (\(frameBuffer.count)/\(maxBufferSize))") + + // Run inference based on frequency setting or first inference + shouldRunInference = isFirstInference || shouldRunBasedOnFrequency(timestamp) + } else { + // USB OFF: Proximity-triggered buffering for recording mode + let isProximityTriggered = proximityReached && !isInferencePending + shouldRunInference = isFirstInference || isProximityTriggered + + guard shouldRunInference else { + return // Don't add to buffer unless inference is triggered + } + + frameBuffer.append(currentEntry) + if frameBuffer.count > maxBufferSize { + frameBuffer.removeFirst(frameBuffer.count - maxBufferSize) + } + print("[MLInference] Recording Mode: Action frame stored (\(frameBuffer.count) action trigger frames)") + } + + guard shouldRunInference else { + return + } + + // Temporal models pad with repeated frames if needed + + // Ensure we have a model loaded + guard let model = model else { + print("ERROR: No model loaded for inference") + return + } + + // Mark inference as pending + isInferencePending = true + isInferencePendingUI = true + proximityReached = false // Reset proximity flag + + if debugLoggingEnabled { + print("[MLInference] Running inference - firstTime: \(!hasRunFirstInference), buffer: \(frameBuffer.count)") + } + + // Prepare input using buffered frames + let modelInput: MLFeatureProvider + do { + modelInput = try prepareModelInputFromBuffer(metadata: metadata) + } catch { + print("ERROR: Failed to prepare model input: \(error)") + isInferencePending = false + isInferencePendingUI = false + return + } + + inferenceQueue.async { [weak self, modelInput, model] in + guard let self = self else { return } + + let startTime = CACurrentMediaTime() + + autoreleasepool { + do { + print("DEBUG: Running model prediction with buffered frames...") + let output = try model.prediction(from: modelInput) + print("DEBUG: Model prediction succeeded") + + let inferenceTime = CACurrentMediaTime() - startTime + self.processInferenceResults(output, inferenceTime: inferenceTime) + + // Reset pending flag and mark first inference complete + DispatchQueue.main.async { + self.isInferencePending = false + self.isInferencePendingUI = false + self.lastInferenceTime = CACurrentMediaTime() // Update for frequency tracking + if !self.hasRunFirstInference { + self.hasRunFirstInference = true + print("[MLInference] First inference complete - target cube should now be visible") + } + } + } catch { + print("ERROR: Model inference failed: \(error)") + DispatchQueue.main.async { + self.isInferencePending = false + self.isInferencePendingUI = false + } + } + } + } + } + + // MARK: - Action Frame Buffer Input Preparation + private func prepareModelInputFromBuffer(metadata: ModelMetadata) throws -> MLFeatureProvider { + print("DEBUG: Preparing input from action frame buffer for point-conditioned model") + + guard !frameBuffer.isEmpty else { + throw NSError(domain: "MLInferenceManager", code: -1, userInfo: [NSLocalizedDescriptionKey: "No action frames stored yet"]) + } + + return try prepareVQBeTInputFromBuffer(metadata: metadata) + } + + // MARK: - MLMultiArray Copying Helper + /// Copy a single channel from source MLMultiArray to target MLMultiArray + /// - Parameters: + /// - source: Source MLMultiArray with shape [1, 3, H, W] + /// - target: Target MLMultiArray to copy into + /// - targetTimestep: Target timestep index + /// - targetChannel: Target channel index + /// - sourceChannel: Source channel index + private func copyMLMultiArrayChannel(source: MLMultiArray, target: MLMultiArray, targetTimestep: Int, targetChannel: Int, sourceChannel: Int) { + for h in 0..<224 { + for w in 0..<224 { + let value = source[[0, NSNumber(value: sourceChannel), NSNumber(value: h), NSNumber(value: w)]] + target[[0, NSNumber(value: targetTimestep), NSNumber(value: targetChannel), NSNumber(value: h), NSNumber(value: w)]] = value + } + } + } + + // MARK: - Action Frame Buffer Input Preparation Methods + private func prepareVQBeTInputFromBuffer(metadata: ModelMetadata) throws -> MLFeatureProvider { + let goalPointArray = frameBuffer.last?.goalPoint ?? getGoalPointForModel() + guard let goalPointArray = goalPointArray else { + throw NSError(domain: "MLInferenceManager", code: -1, userInfo: [NSLocalizedDescriptionKey: "Goal point required"]) + } + + let imageInputName = metadata.getImageInputName() ?? "camera_image" + let goalInputName = metadata.getGoalInputName() ?? "goal_point" + let temporalFrames = metadata.temporalFrames + + // Determine expected rank from model + let expectedRank: Int = { + if let d = model?.modelDescription.inputDescriptionsByName[imageInputName], + d.type == .multiArray, + let shape = d.multiArrayConstraint?.shape { + return shape.count + } + return 4 + }() + + print("DEBUG: Using buffered action frames - temporalFrames: \(temporalFrames), actionFramesAvailable: \(frameBuffer.count)") + + // Build image array from buffer + let imageArray: MLMultiArray + if temporalFrames > 1 { + // Temporal model: use available action frames, pad with repetition if needed + let framesToUse = min(temporalFrames, frameBuffer.count) + let paddingNeeded = max(0, temporalFrames - framesToUse) + + // Create temporal frame array (can't reuse since it's used in model input) + imageArray = try MLMultiArray(shape: [1, NSNumber(value: temporalFrames), 3, 224, 224], dataType: .float32) + + guard frameBuffer.count > 0 else { + throw NSError(domain: "MLInferenceManager", code: -1, userInfo: [NSLocalizedDescriptionKey: "No action frames available"]) + } + + // Pad with repeated first action frame if needed + if paddingNeeded > 0 { + let firstActionFrame = frameBuffer[0].mlArray + for t in 0.. 1 && dims[2] == 3 { + // Temporal goal: [1, T, 3] + for t in 0.. MLMultiArray { + let width = Int(targetSize.width) + let height = Int(targetSize.height) + + // Use shared output pixel buffer or create if needed + if sharedOutputPixelBuffer == nil { + initializeSharedBuffers() + } + + guard let outputBuffer = sharedOutputPixelBuffer else { + throw NSError(domain: "MLInferenceManager", code: -1, userInfo: [NSLocalizedDescriptionKey: "Failed to get shared output pixel buffer"]) + } + + // Process image - use Accelerate for scaling when possible, fallback to Core Image for complex transforms + CVPixelBufferLockBaseAddress(pixelBuffer, .readOnly) + CVPixelBufferLockBaseAddress(outputBuffer, []) + defer { + CVPixelBufferUnlockBaseAddress(outputBuffer, []) + CVPixelBufferUnlockBaseAddress(pixelBuffer, .readOnly) + } + + let inputImageSize = CGSize(width: CVPixelBufferGetWidth(pixelBuffer), height: CVPixelBufferGetHeight(pixelBuffer)) + let scaleX = targetSize.width / inputImageSize.width + let scaleY = targetSize.height / inputImageSize.height + + // Use vImage for simple scaling (no rotation/orientation), Core Image for complex transforms + if !applyServerImageOrientation && abs(scaleX - scaleY) < 0.01 { + // Simple uniform scaling - use vImage for better performance + if processFrameWithVImage(pixelBuffer: pixelBuffer, outputBuffer: outputBuffer, scale: Float(scaleX)) { + // vImage scaling succeeded, now apply gripper overlay if needed + let scaledCIImage = CIImage(cvPixelBuffer: outputBuffer) + var finalImage = scaledCIImage + + if shouldApplyGripperOverlayToModelInput(), let gripperOverlay = getCurrentGripperOverlay() { + finalImage = applyGripperOverlayCoreImage(to: scaledCIImage, overlay: gripperOverlay) + } + + // Render final image to output buffer + let cropRect = CGRect(origin: .zero, size: targetSize) + ciContext.render(finalImage, to: outputBuffer, bounds: cropRect, colorSpace: CGColorSpaceCreateDeviceRGB()) + } else { + // Fallback to Core Image pipeline + let inputImage = CIImage(cvPixelBuffer: pixelBuffer) + var scaledImage = inputImage.transformed(by: CGAffineTransform(scaleX: scaleX, y: scaleY)) + + // Save original scaled image for debugging + saveDebugFrame(scaledImage, prefix: "\(debugPrefix)_original") + + // Apply gripper overlay when USB streaming is off (virtual gripper proxy) + scaledImage = applyGripperOverlay(to: scaledImage) + saveDebugFrame(scaledImage, prefix: "\(debugPrefix)_with_overlay") + + let cropRect = CGRect(origin: .zero, size: targetSize) + ciContext.render(scaledImage, to: outputBuffer, bounds: cropRect, colorSpace: CGColorSpaceCreateDeviceRGB()) + } + } else { + // Complex transform (rotation/orientation) - use Core Image + let inputImage = CIImage(cvPixelBuffer: pixelBuffer) + var scaledImage = inputImage.transformed(by: CGAffineTransform(scaleX: scaleX, y: scaleY)) + if applyServerImageOrientation { + scaledImage = scaledImage.oriented(.down) + } + + // Save original scaled image for debugging + saveDebugFrame(scaledImage, prefix: "\(debugPrefix)_original") + + // Apply gripper overlay when USB streaming is off (virtual gripper proxy) + scaledImage = applyGripperOverlay(to: scaledImage) + saveDebugFrame(scaledImage, prefix: "\(debugPrefix)_with_overlay") + + let cropRect = CGRect(origin: .zero, size: targetSize) + ciContext.render(scaledImage, to: outputBuffer, bounds: cropRect, colorSpace: CGColorSpaceCreateDeviceRGB()) + } + + // Convert to MLMultiArray as single frame [1,3,H,W] for buffering + return try convertPixelBufferToMLMultiArray(outputBuffer, width: width, height: height) + } + + // MARK: - Unified Pixel Buffer to MLMultiArray Conversion (Accelerate Optimized) + private func convertPixelBufferToMLMultiArray(_ pixelBuffer: CVPixelBuffer, width: Int, height: Int) throws -> MLMultiArray { + // Use shared buffer if available, otherwise create new one + let inputArray: MLMultiArray + if let sharedBuffer = sharedMLMultiArrayBuffer { + inputArray = sharedBuffer + // Clear the buffer by zeroing it out efficiently + memset(inputArray.dataPointer, 0, inputArray.count * MemoryLayout.size) + } else { + // Fallback: create new MLMultiArray + inputArray = try MLMultiArray(shape: [1, 3, NSNumber(value: height), NSNumber(value: width)], dataType: .float32) + } + + guard let baseAddress = CVPixelBufferGetBaseAddress(pixelBuffer) else { + throw NSError(domain: "MLInferenceManager", code: -1, userInfo: [NSLocalizedDescriptionKey: "Failed to get pixel buffer base address"]) + } + + let bytesPerRow = CVPixelBufferGetBytesPerRow(pixelBuffer) + let totalPixels = width * height + + // Use Accelerate for faster conversion + let rPtr = inputArray.dataPointer.assumingMemoryBound(to: Float.self) + let gPtr = rPtr.advanced(by: totalPixels) + let bPtr = gPtr.advanced(by: totalPixels) + + // Use vDSP for efficient channel extraction and conversion + // ARGB format: [A, R, G, B, A, R, G, B, ...] + // We need to extract R (offset 1), G (offset 2), B (offset 3) from each 4-byte pixel + + let buffer = baseAddress.assumingMemoryBound(to: UInt8.self) + let tempBufferSize = totalPixels * MemoryLayout.size + guard let tempR = malloc(tempBufferSize), + let tempG = malloc(tempBufferSize), + let tempB = malloc(tempBufferSize) else { + // Fallback to manual conversion if memory allocation fails + return try convertPixelBufferToMLMultiArrayManual(pixelBuffer: pixelBuffer, inputArray: inputArray, width: width, height: height) + } + defer { + free(tempR) + free(tempG) + free(tempB) + } + + // Extract channels using optimized stride-based approach + var pixelIndex = 0 + for y in 0...size) + return copyArray + } + + return inputArray + } + + // Fallback manual conversion method + private func convertPixelBufferToMLMultiArrayManual(pixelBuffer: CVPixelBuffer, inputArray: MLMultiArray, width: Int, height: Int) throws -> MLMultiArray { + guard let baseAddress = CVPixelBufferGetBaseAddress(pixelBuffer) else { + throw NSError(domain: "MLInferenceManager", code: -1, userInfo: [NSLocalizedDescriptionKey: "Failed to get pixel buffer base address"]) + } + + let buffer = baseAddress.assumingMemoryBound(to: UInt8.self) + let bytesPerRow = CVPixelBufferGetBytesPerRow(pixelBuffer) + let totalPixels = width * height + + let rPtr = inputArray.dataPointer.assumingMemoryBound(to: Float.self) + let gPtr = rPtr.advanced(by: totalPixels) + let bPtr = gPtr.advanced(by: totalPixels) + + var pixelIndex = 0 + for y in 0.. Bool { + guard let inputBase = CVPixelBufferGetBaseAddress(pixelBuffer), + let outputBase = CVPixelBufferGetBaseAddress(outputBuffer) else { + return false + } + + let inputWidth = CVPixelBufferGetWidth(pixelBuffer) + let inputHeight = CVPixelBufferGetHeight(pixelBuffer) + let outputWidth = CVPixelBufferGetWidth(outputBuffer) + let outputHeight = CVPixelBufferGetHeight(outputBuffer) + let inputBytesPerRow = CVPixelBufferGetBytesPerRow(pixelBuffer) + let outputBytesPerRow = CVPixelBufferGetBytesPerRow(outputBuffer) + + var sourceBuffer = vImage_Buffer( + data: inputBase, + height: vImagePixelCount(inputHeight), + width: vImagePixelCount(inputWidth), + rowBytes: inputBytesPerRow + ) + + var destBuffer = vImage_Buffer( + data: outputBase, + height: vImagePixelCount(outputHeight), + width: vImagePixelCount(outputWidth), + rowBytes: outputBytesPerRow + ) + + // Use vImageScale_ARGB8888 for fast scaling + let error = vImageScale_ARGB8888(&sourceBuffer, &destBuffer, nil, vImage_Flags(kvImageHighQualityResampling)) + return error == kvImageNoError + } + + // MARK: - Result Processing + + /// Extract joint positions from model output, handling both single-step and multi-step outputs + private func extractJointPositions(from resultArray: MLMultiArray) -> [Float] { + let shape = resultArray.shape.map { $0.intValue } + + // Check if this is a multi-step temporal output: [T,1,7] or similar + if shape.count == 3 && shape[0] > 1 && shape[2] >= 7 { + // Multi-step output: extract last timestep [T-1, 0, 0...6] + let lastTimestep = shape[0] - 1 + let jointPositions = (0..<7).map { i in + resultArray[[NSNumber(value: lastTimestep), 0, NSNumber(value: i)]].floatValue + } + print("Multi-step output detected: shape \(shape), using last timestep [\(lastTimestep)]") + return jointPositions + } else { + // Single-step output: extract directly + let outputCount = min(resultArray.count, 10) + return (0..= 7 { + currentGripperValue = jointPositions[6] + let isGripperClosed = currentGripperValue < 0.7 + print("[ML] PointCond - Gripper Value: \(String(format: "%.3f", currentGripperValue)) | State: \(isGripperClosed ? "CLOSED" : "OPEN")") + + Task { @MainActor [weak self] in + self?.updateGripperOverlayDisplay() + // Update AR visualization manager gripper state to stop visualization when closed + self?.arVisualizationManager?.setGripperState(isClosed: isGripperClosed) + } + } + + let result = InferenceResult( + jointPositions: jointPositions, + inferenceTime: inferenceTime + ) + + // Update UI on main thread + DispatchQueue.main.async { [weak self] in + self?.latestResult = result + self?.lastResult = result + + // Only enable visualization when NOT in USB streaming mode (recording mode only) + if let arManager = self?.arVisualizationManager, + jointPositions.count >= 6, + self?.isUSBStreamingActive != true, + arManager.isVisualizationEnabled { + arManager.ensureVisualizationReady() + arManager.updatePoseFromMLOutput(jointPositions, timestamp: self?.lastInferenceTime ?? CACurrentMediaTime()) + } + + // Joint actions are automatically sent via USB stream (transform to robot frame) + if jointPositions.count >= 7 { + let src = Array(jointPositions.prefix(7)) + if self?.enableTransformDebug == true { + let report = ActionTransformUtils.debugTransformReport(src, rotationUnit: self?.rotationUnit ?? .eulerXYZ) + print("Coordinate Transform:\n\(report)") + } + } + } + } + + // MARK: - Control Methods + func enableInference() { + isInferenceEnabled = true + let modelName = modelManager.activeModel?.name ?? "No model" + print("Inference enabled: \(modelName)") + if enableTransformDebug { + print("Transform debug enabled (\(rotationUnit))") + } + + // Update gripper overlay to show inference status + Task { @MainActor in + updateGripperOverlayDisplay() + } + } + + func disableInference() { + isInferenceEnabled = false + latestResult = nil + // Preserve lastResult so UI can continue showing the previous inference output while idle + isInferencePending = false + isInferencePendingUI = false + print("Inference disabled") + + // Update gripper overlay to hide inference status + Task { @MainActor in + updateGripperOverlayDisplay() + } + } + + func resetInferenceState() { + hasRunFirstInference = false + proximityReached = false + isInferencePending = false + isInferencePendingUI = false + frameBuffer.removeAll() + currentFrameEntry = nil + goalFrameCount = 0 // Reset goal frame count + print("Inference state reset - ready for new recording") + } + + // MARK: - Manual Inference Trigger + func triggerInferenceManually() { + guard isInferenceEnabled, + let metadata = modelMetadata else { + print("[MLInference] Cannot trigger manually - inference disabled or no model") + return + } + + // Check if goal point is required but not set + if metadata.requiresGoalPoint && currentGoalPoint == nil { + print("[MLInference] Cannot trigger manually - goal point required but not set") + return + } + + // Ensure we have a model loaded + guard let model = model else { + print("[MLInference] Cannot trigger manually - no model loaded") + return + } + + // Skip if inference already pending + guard !isInferencePending else { + print("[MLInference] Cannot trigger manually - inference already pending") + return + } + + // Store current frame for manual trigger + guard let currentEntry = currentFrameEntry else { + print("[MLInference] Cannot trigger manually - no current frame available") + return + } + + // Add frame to buffer following same logic as automatic inference + if isUSBStreamingActive { + frameBuffer.append(currentEntry) + if frameBuffer.count > maxBufferSize { + frameBuffer.removeFirst() + } + print("[MLInference] Manual trigger - USB mode: Frame added to rolling buffer (\(frameBuffer.count))") + } else { + frameBuffer.append(currentEntry) + if frameBuffer.count > maxBufferSize { + frameBuffer.removeFirst(frameBuffer.count - maxBufferSize) + } + print("[MLInference] Manual trigger - Recording mode: Action frame stored (\(frameBuffer.count))") + } + + // Mark inference as pending + isInferencePending = true + + // Prepare input using buffered frames + let modelInput: MLFeatureProvider + do { + modelInput = try prepareModelInputFromBuffer(metadata: metadata) + } catch { + print("ERROR: Failed to prepare model input for manual trigger: \(error)") + isInferencePending = false + return + } + + inferenceQueue.async { [weak self, modelInput, model] in + guard let self = self else { return } + + let startTime = CACurrentMediaTime() + + autoreleasepool { + do { + print("DEBUG: Running manual model prediction with buffered frames...") + let output = try model.prediction(from: modelInput) + print("DEBUG: Manual model prediction succeeded") + + let inferenceTime = CACurrentMediaTime() - startTime + self.processInferenceResults(output, inferenceTime: inferenceTime) + + // Reset pending flag and mark first inference complete + DispatchQueue.main.async { + self.isInferencePending = false + self.lastInferenceTime = CACurrentMediaTime() // Update for frequency tracking + if !self.hasRunFirstInference { + self.hasRunFirstInference = true + print("[MLInference] First inference complete (manual) - target cube should now be visible") + } else { + // For manual triggers after first inference, transition current target to fading + // This allows the new target to appear immediately + self.arVisualizationManager?.forceTargetTransition() + } + } + } catch { + print("ERROR: Manual model inference failed: \(error)") + DispatchQueue.main.async { + self.isInferencePending = false + } + } + } + } + } + + func setInferenceFrequency(_ frequency: InferenceFrequency) { + inferenceFrequency = frequency + print("Inference frequency: \(frequency.displayName)") + } + +} diff --git a/AnySense/Managers/ModelManager.swift b/AnySense/Managers/ModelManager.swift new file mode 100644 index 0000000..7bc151c --- /dev/null +++ b/AnySense/Managers/ModelManager.swift @@ -0,0 +1,485 @@ +import Foundation +import CoreML +import Combine + +// MARK: - Model Manager +class ModelManager: ObservableObject { + + // MARK: - Published Properties + @Published var availableModels: [ModelInfo] = [] + @Published var activeModel: ModelInfo? + @Published var isCompiling: Bool = false + @Published var compilationProgress: Double = 0.0 + @Published var compilationError: String? + + // MARK: - Private Properties + private var modelRegistry: ModelRegistry + private let registryURL: URL + private var cancellables = Set() + + // MARK: - Initialization + init() { + self.registryURL = ModelFileUtilities.modelsDirectory.appendingPathComponent("model_registry.json") + self.modelRegistry = ModelRegistry() + + loadModelRegistry() + setupBundledModel() + } + + // MARK: - Public Properties + var hasAvailableModel: Bool { + return !compiledModels.isEmpty + } + + var hasCompiledModel: Bool { + return activeModel?.source == .uploaded && activeModel?.compilationStatus.isCompiled == true + } + + var compiledModels: [ModelInfo] { + return availableModels.filter { $0.compilationStatus.isCompiled } + } + + var activeModelID: UUID? { + get { activeModel?.id } + set { + if let newID = newValue { + setActiveModel(id: newID) + } + } + } + + // MARK: - Model Registry Management + private func loadModelRegistry() { + do { + if FileManager.default.fileExists(atPath: registryURL.path) { + let data = try Data(contentsOf: registryURL) + modelRegistry = try JSONDecoder().decode(ModelRegistry.self, from: data) + availableModels = modelRegistry.models + activeModel = modelRegistry.activeModel + print("Loaded model registry with \(modelRegistry.models.count) models") + // Validate entries and fix stale active model pointing to removed files + validateAndFixRegistry() + } + } catch { + print("Failed to load model registry: \(error)") + modelRegistry = ModelRegistry() + } + } + + // Remove models whose files no longer exist and fix an invalid active model + private func validateAndFixRegistry() { + // Prune missing files + let original = availableModels + availableModels = availableModels.filter { getModelURL(for: $0) != nil } + if availableModels.count != original.count { + print("Registry cleanup: removed \(original.count - availableModels.count) missing model entries") + } + // Fix active model if missing + if let active = activeModel, getModelURL(for: active) == nil { + print("Active model missing on disk: \(active.name). Selecting a valid model...") + activeModel = nil + } + if activeModel == nil { + // Prioritize any bundled model, then fallback to any available model + let preferredModel = availableModels.first { model in + model.source == .bundled && getModelURL(for: model) != nil + } ?? availableModels.first { getModelURL(for: $0) != nil } + + if let next = preferredModel { + // Directly set without dispatching to avoid race during init + for i in availableModels.indices { availableModels[i].isActive = availableModels[i].id == next.id } + activeModel = next + modelRegistry.setActiveModel(id: next.id) + print("Switched active model to: \(next.name)") + } + } + saveModelRegistry() + } + + private func saveModelRegistry() { + do { + modelRegistry.models = availableModels + modelRegistry.activeModelID = activeModel?.id + + let data = try JSONEncoder().encode(modelRegistry) + try data.write(to: registryURL) + print("Saved model registry") + } catch { + print("Failed to save model registry: \(error)") + } + } + + // MARK: - Bundled Model Setup + private func setupBundledModel() { + // Find and register any bundled model dynamically + let bundledModelNames = findBundledModelNames() + guard let defaultModelName = bundledModelNames.first else { + print("No bundled CoreML models found - user will need to upload a model") + return + } + var added: [ModelInfo] = [] + + let alreadyExists = availableModels.contains { $0.source == .bundled && $0.name == defaultModelName } + guard !alreadyExists else { + // Model already registered, just ensure it's active if no active model + if activeModel == nil, let bundledModel = availableModels.first(where: { $0.source == .bundled && $0.name == defaultModelName }) { + setActiveModel(id: bundledModel.id) + print("Set default model: \(bundledModel.name)") + } + return + } + + // Find the model URL using existing logic + let tempModelInfo = ModelInfo(name: defaultModelName, fileName: "", source: .bundled) + guard let url = getModelURL(for: tempModelInfo) else { + print("Bundled model '\(defaultModelName)' not found in app bundle") + return + } + + // Determine file extension for fileName + let ext = url.pathExtension.isEmpty ? "mlpackage" : url.pathExtension + var info = ModelInfo( + name: defaultModelName, + fileName: "\(defaultModelName).\(ext)", + source: .bundled + ) + info.compilationStatus = .compiled + availableModels.append(info) + added.append(info) + print("Registered bundled model: \(defaultModelName)") + + // Set as default active model if none is active yet + if activeModel == nil { + setActiveModel(id: info.id) + print("Set default model: \(defaultModelName)") + } else { + print("Active model already exists: \(activeModel?.name ?? "Unknown")") + } + + if !added.isEmpty { saveModelRegistry() } + } + + private func findBundledModelNames() -> [String] { + let extensions = ["mlpackage", "mlmodelc", "mlmodel"] + var modelNames: [String] = [] + + for ext in extensions { + if let urls = Bundle.main.urls(forResourcesWithExtension: ext, subdirectory: nil) { + modelNames.append(contentsOf: urls.map { $0.deletingPathExtension().lastPathComponent }) + } + } + + return Array(Set(modelNames)) // Remove duplicates + } + + // MARK: - Model Upload and Compilation + func uploadAndCompileModel(from sourceURL: URL, withName customName: String? = nil) async throws { + + await MainActor.run { + isCompiling = true + compilationProgress = 0.0 + compilationError = nil + } + + do { + // Follow the best practices guide exactly + let shouldStopAccessing = sourceURL.startAccessingSecurityScopedResource() + defer { + if shouldStopAccessing { + sourceURL.stopAccessingSecurityScopedResource() + } + } + + print("DEBUG: Processing file from: \(sourceURL.path)") + + // Generate model name (strip known extensions) + let fileName = sourceURL.lastPathComponent + let ext = sourceURL.pathExtension.lowercased() + let baseName = sourceURL.deletingPathExtension().lastPathComponent + let modelName = customName ?? baseName + + // Check for duplicate names + if availableModels.contains(where: { $0.name == modelName }) { + throw ModelError.duplicateName("A model with this name already exists") + } + + // Decide import strategy by extension + let fm = FileManager.default + let uploadsDir = ModelFileUtilities.uploadedModelsDirectory + var fileSize: Int64 = 0 + var finalCompiledURL: URL? = nil + var localModelURL: URL? = nil + + switch ext { + case "mlmodel": + // Copy raw .mlmodel and compile + let local = try ModelFileUtilities.copyUploadedModel(from: sourceURL, withName: modelName) + localModelURL = local + print("DEBUG: Copied .mlmodel to local storage: \(local.path)") + guard fm.fileExists(atPath: local.path) else { throw ModelError.modelNotFound("Copied file does not exist") } + fileSize = MLModel.getModelSize(at: local) + case "mlpackage": + // Compile the .mlpackage directly from sourceURL (no need to copy the whole package first) + localModelURL = sourceURL + fileSize = MLModel.getModelSize(at: sourceURL) + case "mlmodelc": + // Already compiled; copy as-is into uploads directory with normalized name + let dest = uploadsDir.appendingPathComponent("\(modelName).mlmodelc") + if fm.fileExists(atPath: dest.path) { try fm.removeItem(at: dest) } + try fm.copyItem(at: sourceURL, to: dest) + finalCompiledURL = dest + fileSize = MLModel.getModelSize(at: dest) + print("DEBUG: Copied compiled model (.mlmodelc) to: \(dest.path)") + default: + throw ModelError.invalidFile("Unsupported file type. Use .mlmodel, .mlpackage, or .mlmodelc") + } + + let modelUploadDate = Date() // Use current date as upload date + + // Create model info (we'll validate compatibility after compilation) + let modelInfo = ModelInfo( + name: modelName, + fileName: fileName, + source: .uploaded, + fileSize: fileSize, + uploadDate: modelUploadDate + ) + + // Add to registry + await MainActor.run { + availableModels.append(modelInfo) + saveModelRegistry() + } + + // Update progress: file copied/prepared + await MainActor.run { + compilationProgress = 0.2 + } + + // If we imported a raw .mlmodel, compile it now + if (ext == "mlmodel" || ext == "mlpackage"), let local = localModelURL { + await MainActor.run { + compilationProgress = 0.3 + } + + let tempCompiledURL = try await MLModel.compileModel(at: local) { [weak self] progress in + Task { @MainActor in + // Map compile progress (0.0-1.0) to overall progress (0.3-0.9) + self?.compilationProgress = 0.3 + (progress * 0.6) + } + } + print("DEBUG: Compiled to temp location: \(tempCompiledURL.path)") + + await MainActor.run { + compilationProgress = 0.9 + } + + // Validate compiled + do { + let metadata = try MLModel.validateModel(at: tempCompiledURL) + guard metadata.isCompatible else { throw ModelError.incompatibleModel("Model format not compatible with app requirements") } + print("DEBUG: Model validation passed") + } catch { print("DEBUG: Model validation warning: \(error.localizedDescription)") } + + // Move compiled to uploads dir + finalCompiledURL = try ModelFileUtilities.replaceCompiledModel( + compiledURL: tempCompiledURL, + withName: modelName + ) + print("DEBUG: Final compiled location: \(finalCompiledURL!.path)") + } else if let compiled = finalCompiledURL { + await MainActor.run { + compilationProgress = 0.5 + } + + // Validate mlpackage/mlmodelc directly + do { + let metadata = try MLModel.validateModel(at: compiled) + print("DEBUG: Direct model validation: compatible=\(metadata.isCompatible)") + } catch { print("DEBUG: Model validation warning: \(error.localizedDescription)") } + } + + // Update model status + let modelId = modelInfo.id + + await MainActor.run { + if let index = availableModels.firstIndex(where: { $0.id == modelId }) { + availableModels[index].compilationStatus = .compiled + } + + isCompiling = false + compilationProgress = 1.0 + saveModelRegistry() + + // Automatically activate the newly uploaded model + setActiveModel(id: modelId) + + print("Successfully compiled model: \(modelName)") + } + + } catch { + await MainActor.run { + isCompiling = false + compilationError = error.localizedDescription + print("Failed to upload/compile model: \(error)") + } + throw error + } + } + + // MARK: - Model Management + func setActiveModel(id: UUID) { + print("DEBUG: setActiveModel called with id: \(id)") + print("DEBUG: Available models: \(availableModels.map { "\($0.name) (\($0.id))" })") + + // Ensure all UI updates happen on main thread + DispatchQueue.main.async { [weak self] in + guard let self = self else { return } + + // Trigger UI update before making changes + self.objectWillChange.send() + + // Deactivate all models + for i in self.availableModels.indices { + self.availableModels[i].isActive = false + } + + // Activate selected model + if let index = self.availableModels.firstIndex(where: { $0.id == id }) { + self.availableModels[index].isActive = true + self.activeModel = self.availableModels[index] + self.modelRegistry.setActiveModel(id: id) + self.saveModelRegistry() + + print("DEBUG: Switched to model: \(self.activeModel?.name ?? "Unknown")") + print("DEBUG: hasAvailableModel: \(self.hasAvailableModel)") + print("DEBUG: hasCompiledModel: \(self.hasCompiledModel)") + } else { + print("DEBUG: Model with id \(id) not found!") + } + } + } + + func deleteModel(id: UUID) throws { + guard let modelInfo = availableModels.first(where: { $0.id == id }) else { + throw ModelError.modelNotFound("Model not found") + } + + // Can't delete bundled models + guard modelInfo.source == .uploaded else { + throw ModelError.cannotDeleteBundled("Cannot delete built-in models") + } + + // Delete files + try ModelFileUtilities.deleteModel(fileName: modelInfo.fileName, isUploaded: modelInfo.source == .uploaded) + + // Remove from registry + availableModels.removeAll { $0.id == id } + + // If this was the active model, switch to bundled model + if activeModel?.id == id { + if let bundledModel = availableModels.first(where: { $0.source == .bundled }) { + setActiveModel(id: bundledModel.id) + } else { + activeModel = nil + } + } + + saveModelRegistry() + print("Deleted model: \(modelInfo.name)") + } + + // MARK: - Model Loading + func getModelURL(for modelInfo: ModelInfo) -> URL? { + switch modelInfo.source { + case .bundled: + // Use dynamic discovery for bundled models + let extensions = ["mlmodelc", "mlpackage", "mlmodel"] + for ext in extensions { + if let url = Bundle.main.url(forResource: modelInfo.name, withExtension: ext) { + return url + } + } + return nil + + case .uploaded: + // Prefer compiled .mlmodelc if present + let compiledURL = ModelFileUtilities.uploadedModelsDirectory + .appendingPathComponent("\(modelInfo.name).mlmodelc") + if FileManager.default.fileExists(atPath: compiledURL.path) { return compiledURL } + + // Support .mlpackage in uploaded directory + let packageURL = ModelFileUtilities.uploadedModelsDirectory + .appendingPathComponent("\(modelInfo.name).mlpackage") + if FileManager.default.fileExists(atPath: packageURL.path) { return packageURL } + + // Fallback to .mlmodel + let uploadedURL = ModelFileUtilities.uploadedModelsDirectory + .appendingPathComponent("\(modelInfo.name).mlmodel") + if FileManager.default.fileExists(atPath: uploadedURL.path) { return uploadedURL } + + return nil + } + } + + func loadModel(for modelInfo: ModelInfo) throws -> MLModel { + guard let modelURL = getModelURL(for: modelInfo) else { + throw ModelError.modelNotFound("Model file not found: \(modelInfo.name)") + } + + // Use async loading with configuration for better performance + let config = MLModelConfiguration() + config.computeUnits = .all // Use all available compute units + + return try MLModel(contentsOf: modelURL, configuration: config) + } + + func loadModelAsync(for modelInfo: ModelInfo) async throws -> MLModel { + guard let modelURL = getModelURL(for: modelInfo) else { + throw ModelError.modelNotFound("Model file not found: \(modelInfo.name)") + } + + return try await Task.detached(priority: .userInitiated) { + let config = MLModelConfiguration() + config.computeUnits = .all + return try MLModel(contentsOf: modelURL, configuration: config) + }.value + } + + func getActiveModelMetadata() -> ModelMetadata? { + guard let activeModel = activeModel, + let modelURL = getModelURL(for: activeModel) else { + return nil + } + + do { + let model = try MLModel(contentsOf: modelURL) + return try ModelMetadata(from: model) + } catch { + print("Failed to get model metadata: \(error)") + return nil + } + } +} + +// MARK: - Model Errors +enum ModelError: LocalizedError { + case modelNotFound(String) + case compilationFailed(String) + case incompatibleModel(String) + case duplicateName(String) + case cannotDeleteBundled(String) + case invalidFile(String) + + var errorDescription: String? { + switch self { + case .modelNotFound(let message), + .compilationFailed(let message), + .incompatibleModel(let message), + .duplicateName(let message), + .cannotDeleteBundled(let message), + .invalidFile(let message): + return message + } + } +} diff --git a/AnySense/Managers/TransformDebug.swift b/AnySense/Managers/TransformDebug.swift new file mode 100644 index 0000000..583a2ee --- /dev/null +++ b/AnySense/Managers/TransformDebug.swift @@ -0,0 +1,24 @@ +import Foundation +import simd + +struct TransformDebug { + static func runSamples(rotationUnit: ActionTransformUtils.RotationUnit = .eulerXYZ) { + print("===== Transform Debug Samples (rotationUnit=\(rotationUnit)) =====") + let samples: [[Float]] = [ + // [down, right, backward, r1, r2, r3, gripper] + [0.0, 0.0, 0.10, 0.0, 0.0, 0.0, 0.5], // backward (+z in policy) + [0.0, 0.10, 0.00, 0.0, 0.0, 0.0, 0.5], // right + [0.10, 0.0, 0.00, 0.0, 0.0, 0.0, 0.5], // down + [0.0, 0.0, 0.00, 0.10, 0.0, 0.0, 0.5], // roll (or rx axisAngle) + [0.0, 0.0, 0.00, 0.0, 0.10, 0.0, 0.5], // pitch (or ry) + [0.0, 0.0, 0.00, 0.0, 0.0, 0.10, 0.5], // yaw (or rz) + ] + for s in samples { + let report = ActionTransformUtils.debugTransformReport(s, rotationUnit: rotationUnit) + print(report) + print("--------------------------------------------------") + } + } +} + + diff --git a/AnySense/Managers/USBManager.swift b/AnySense/Managers/USBManager.swift index 75fbf39..9c7c233 100644 --- a/AnySense/Managers/USBManager.swift +++ b/AnySense/Managers/USBManager.swift @@ -21,7 +21,7 @@ struct Record3DHeader { var rgbWidth: UInt32 var rgbHeight: UInt32 var depthWidth: UInt32 - var depthHeight: UInt32 + var depthHeight: UInt32 var confidenceWidth: UInt32 var confidenceHeight: UInt32 var rgbSize: UInt32 @@ -29,6 +29,7 @@ struct Record3DHeader { var confidenceMapSize: UInt32 var miscSize: UInt32 var deviceType: UInt32 + // jointActions are always 28 bytes (7 floats), embedded in message body after RGB data } struct IntrinsicMatrixCoeffs { @@ -61,24 +62,21 @@ class USBManager { listener?.stateUpdateHandler = { state in switch state { case .ready: - print("Server ready and listening on port 1337") + print("USB listener ready") case .failed(let error): - print("Listener failed with error: \(error)") + print("USB listener failed: \(error)") default: break } } listener?.newConnectionHandler = { [weak self] connection in - print("Connection received") self?.handleConnection(connection: connection) - -// self?.sendData(connection: connection, message: "Hello from iPhone!") } listener?.start(queue: .main) } catch { - print("Failed to start listener: \(error)") + // Failed to start listener } } @@ -86,14 +84,12 @@ class USBManager { // Cancel the listener if it exists if let listener = listener { listener.cancel() - print("Listener cancelled") } listener = nil // Cancel the active connection if it exists if let connection = activeConnection { connection.cancel() - print("Connection cancelled") } activeConnection = nil } @@ -108,14 +104,14 @@ class USBManager { intrinsicMatData: Data, poseData: Data, rgbImageData: Data, + jointActionsData: Data, // Always exactly 28 bytes (7 floats) compressedDepthData: Data? = nil, compressedConfData: Data? = nil ) { guard let activeConnection = activeConnection else { - print("No active connection. Cannot send data.") return } - var messageBody = record3dHeaderData + intrinsicMatData + poseData + rgbImageData + var messageBody = record3dHeaderData + intrinsicMatData + poseData + rgbImageData + jointActionsData if let depthData = compressedDepthData { messageBody += depthData } @@ -127,12 +123,12 @@ class USBManager { let ptHeaderData = Data(bytes: &self.ptHeader, count:MemoryLayout.size) let completeMessage = ptHeaderData + messageBody - print("Sending data of size: \(completeMessage.count)") + + print("USB data: \(completeMessage.count) bytes total") + activeConnection.send(content:completeMessage, completion: .contentProcessed {error in if let error = error { - print("Failed to send data: \(error)") - } else { - print("Image data sent successfully") + // USB send failed } }) } @@ -141,9 +137,7 @@ class USBManager { let data = message.data(using: .utf8)! connection.send(content: data, completion: .contentProcessed { error in if let error = error { - print("Failed to send data: \(error)") - } else { - print("Data sent successfully") + // Data send failed } }) } @@ -151,7 +145,6 @@ class USBManager { func compressData(from pixelBuffer: CVPixelBuffer, isDepth: Bool) -> Data? { // Extract depth data guard let baseAddress = CVPixelBufferGetBaseAddress(pixelBuffer) else { - print("Failed to access depth buffer base address") return nil } @@ -179,11 +172,11 @@ class USBManager { ) guard compressedSize > 0 else { - print("Failed to compress depth map") return nil } // Return compressed depth data return Data(bytes: compressedBuffer, count: compressedSize) } + } diff --git a/AnySense/ModelInfo.swift b/AnySense/ModelInfo.swift new file mode 100644 index 0000000..d9e38ff --- /dev/null +++ b/AnySense/ModelInfo.swift @@ -0,0 +1,134 @@ +import Foundation + +// MARK: - Model Information +struct ModelInfo: Identifiable, Codable, Equatable { + let id: UUID + let name: String + let fileName: String + let source: ModelSource + var compilationStatus: CompilationStatus + let fileSize: Int64 + let uploadDate: Date + var isActive: Bool + + init(name: String, fileName: String, source: ModelSource, fileSize: Int64 = 0, uploadDate: Date? = nil) { + self.id = UUID() + self.name = name + self.fileName = fileName + self.source = source + self.compilationStatus = source == .bundled ? .compiled : .notCompiled + self.fileSize = fileSize + self.uploadDate = uploadDate ?? Date() + self.isActive = false + } + + var displayName: String { + switch source { + case .bundled: + return "\(name) (Built-in)" + case .uploaded: + return name + } + } + + var statusDescription: String { + switch compilationStatus { + case .notCompiled: + return "Not compiled" + case .compiling(let progress): + return "Compiling (\(Int(progress * 100))%)" + case .compiled: + return "Ready" + case .failed(let error): + return "Failed: \(error)" + } + } +} + +// MARK: - Model Source +enum ModelSource: String, Codable, CaseIterable { + case bundled + case uploaded + + var displayName: String { + switch self { + case .bundled: + return "Built-in" + case .uploaded: + return "Uploaded" + } + } +} + +// MARK: - Compilation Status +enum CompilationStatus: Codable, Equatable { + case notCompiled + case compiling(progress: Double) + case compiled + case failed(error: String) + + var isCompiled: Bool { + if case .compiled = self { + return true + } + return false + } + + var isCompiling: Bool { + if case .compiling = self { + return true + } + return false + } + + var progress: Double { + if case .compiling(let progress) = self { + return progress + } + return 0.0 + } +} + +// MARK: - Model Registry +struct ModelRegistry: Codable { + var models: [ModelInfo] + var activeModelID: UUID? + let version: String = "1.0" + + init() { + self.models = [] + self.activeModelID = nil + } + + mutating func addModel(_ model: ModelInfo) { + models.append(model) + } + + mutating func updateModel(_ model: ModelInfo) { + if let index = models.firstIndex(where: { $0.id == model.id }) { + models[index] = model + } + } + + mutating func removeModel(id: UUID) { + models.removeAll { $0.id == id } + if activeModelID == id { + activeModelID = models.first?.id + } + } + + mutating func setActiveModel(id: UUID) { + for i in models.indices { + models[i].isActive = models[i].id == id + } + activeModelID = id + } + + var activeModel: ModelInfo? { + return models.first { $0.isActive } + } + + var compiledModels: [ModelInfo] { + return models.filter { $0.compilationStatus.isCompiled } + } +} diff --git a/AnySense/Views/ContentView.swift b/AnySense/Views/ContentView.swift index 0b0aecc..4b80216 100644 --- a/AnySense/Views/ContentView.swift +++ b/AnySense/Views/ContentView.swift @@ -6,12 +6,11 @@ // import SwiftUI -import CoreBluetooth import AVFoundation struct ContentView: View { @EnvironmentObject var appStatus : AppInformation - @EnvironmentObject var bluetoothManager: BluetoothManager + @StateObject private var modelManager = ModelManager() @StateObject private var arViewModel = ARViewModel() @State private var hasPermissions = false @@ -43,9 +42,10 @@ struct ContentView: View { } .onAppear { checkPermissions() + setupModelManager() } .fullScreenCover(isPresented: $showMainPage) { - MainPage(arViewModel: arViewModel) + MainPage(arViewModel: arViewModel, modelManager: modelManager) } .alert(isPresented: $showPermissionAlert) { Alert( @@ -70,6 +70,11 @@ struct ContentView: View { } } + private func setupModelManager() { + // Initialize the ML inference manager with model manager + arViewModel.initializeMLManager(with: modelManager) + } + private func openAppSettings() { if let url = URL(string: UIApplication.openSettingsURLString) { UIApplication.shared.open(url) @@ -103,6 +108,11 @@ class AppInformation : ObservableObject{ @Published var gridProjectionTrigger: GridMode = .off @Published var colorMapTrigger: Bool = false @Published var ifBluetoothConnected: Bool = false + @Published var ifAudioRecordingEnabled: Bool = false + + // MARK: - Inference Settings + @Published var showGripperOverlay: Bool = true + @Published var enableGripperOverlayInModel: Bool = true } diff --git a/AnySense/Views/InferenceView.swift b/AnySense/Views/InferenceView.swift new file mode 100644 index 0000000..a6c9db6 --- /dev/null +++ b/AnySense/Views/InferenceView.swift @@ -0,0 +1,419 @@ +// +// InferenceView.swift +// Anysense +// +// Created by Krish Mehta +// ML Inference and AR Visualization View +// + +import SwiftUI +import UIKit +import CoreBluetooth +import BackgroundTasks +import UserNotifications +import Foundation +import AVFoundation +import ARKit + +// MARK: - InferenceView Overlay +struct InferenceViewOverlay: View { + @EnvironmentObject var appStatus: AppInformation + @ObservedObject var arViewModel: ARViewModel + @State var openFlash = true + @State private var deviceOrientation: UIDeviceOrientation = .unknown + @State private var isLandscape = false + + private var instructionRotation: Angle { + switch deviceOrientation { + case .landscapeLeft: + return .degrees(90) + case .landscapeRight: + return .degrees(-90) + default: + return .degrees(0) + } + } + + var body: some View { + GeometryReader { geometry in + let screenWidth = geometry.size.width + let screenHeight = geometry.size.height + let arViewHeight = min(screenWidth * 1.33, 0.75 * screenHeight) + let arViewWidth = min(arViewHeight / 1.33, screenWidth) + let arViewPadding = 0.2 * arViewHeight + let buttonSize: CGFloat = min(screenWidth * 0.25, 80) + let btBarHeight: CGFloat = 25.0 + let gridSize = appStatus.gridProjectionTrigger.rawValue + + ZStack { + // Transparent background - AR view shows through from MainPage + Color.clear + + ZStack { + // Gripper Overlay on AR View + if let mlManager = arViewModel.mlManager, + let overlayImage = mlManager.currentGripperOverlayImage { + Image(uiImage: overlayImage) + .resizable() + .scaledToFit() + .allowsHitTesting(false) + } + + // Guided Flow Instructions + VStack { + Spacer() + + let instructionText: String = { + guard isLandscape else { + return "Hold landscape" + } + + guard let mlManager = arViewModel.mlManager else { + return "Loading model…" + } + + if arViewModel.isInferenceEpisodeFinished { + return "Episode finished!" + } + + if arViewModel.isInferencePlaying { + return "Follow the arrows" + } + + if mlManager.requiresGoalPoint { + if mlManager.currentGoalPoint == nil { + return arViewModel.goalTapModeEnabled + ? "Click on an object to track" + : "Click on set goal" + } + return "Press play to start demo" + } + + return "Press play to start demo" + }() + + HStack { + Spacer() + Text(instructionText) + .font(.system(size: 14, weight: .medium)) + .multilineTextAlignment(.center) + .foregroundColor(.white.opacity(0.9)) + .padding(.horizontal, 16) + .padding(.vertical, 8) + .background(Color.gray.opacity(0.6)) + .cornerRadius(8) + .shadow(color: .black.opacity(0.2), radius: 2, x: 0, y: 1) + .rotationEffect(isLandscape ? instructionRotation : .degrees(0)) + .animation(.easeInOut, value: instructionText) + .animation(.easeInOut, value: isLandscape) + Spacer() + } + + Spacer() + .frame(height: 60) + } + .allowsHitTesting(false) + + // Manual Next Action Button + if let mlManager = arViewModel.mlManager, + mlManager.isInferenceEnabled && arViewModel.isInferencePlaying && !arViewModel.isInferenceEpisodeFinished { + VStack { + HStack { + Spacer() + VStack(spacing: 2) { + Button(action: { + mlManager.triggerInferenceManually() + UIImpactFeedbackGenerator(style: appStatus.hapticFeedbackLevel).impactOccurred() + }) { + Image(systemName: "arrow.forward.circle.fill") + .font(.title2) + .foregroundColor(.white) + .background( + Circle() + .fill(Color.blue.opacity(0.8)) + .frame(width: 44, height: 44) + ) + .shadow(color: .black.opacity(0.3), radius: 4, x: 0, y: 2) + } + + Text("Get next action") + .font(.system(size: 9, weight: .medium)) + .foregroundColor(.white) + .padding(.horizontal, 4) + .padding(.vertical, 1) + .background(Color.black.opacity(0.7)) + .cornerRadius(3) + .shadow(color: .black.opacity(0.4), radius: 1, x: 0, y: 0.5) + } + .padding(.trailing, 12) + .padding(.top, 12) + } + Spacer() + } + } + + // Model Loading Indicator + if let mlManager = arViewModel.mlManager, mlManager.isModelLoading { + ZStack { + Color.black.opacity(0.4) + .ignoresSafeArea() + VStack(spacing: 16) { + ProgressView() + .controlSize(.large) + .tint(.white) + Text("Preparing Model...") + .font(.headline) + .foregroundColor(.white) + } + } + .transition(.opacity) + .zIndex(100) + } + } + .frame(width: arViewWidth, height: arViewHeight) + .padding(.bottom, arViewPadding) + + // ML Status Overlay + VStack { + HStack { + VStack(alignment: .leading, spacing: 8) { + if arViewModel.mlManager?.isInferenceEnabled == true { + if let mlManager = arViewModel.mlManager { + MLInferenceResultsView(mlManager: mlManager) + } + } + } + Spacer() + } + Spacer() + .padding(.bottom, 10) + } + .frame(width: arViewWidth, height: arViewHeight) + .padding(.bottom, arViewPadding) + .allowsHitTesting(false) + + // Top bar with notch area + Bluetooth status + VStack(spacing: 0) { + // White bar for notch/safe area + Color.white + .frame(height: geometry.safeAreaInsets.top) + + // Bluetooth status bar + Text(appStatus.ifBluetoothConnected ? "bluetooth device connected" : "bluetooth device disconnected") + .font(.footnote) + .foregroundColor(Color.white) + .frame(maxWidth: .infinity) + .frame(height: btBarHeight) + .background(appStatus.ifBluetoothConnected ? .green : .red) + + Spacer() + } + .ignoresSafeArea(edges: .top) + + // Grid overlay + if appStatus.gridProjectionTrigger.rawValue > 0 { + VStack { + Path { path in + for col in 1.. some View { + configuration.label.scaleEffect(isPlaying ? 0.55 : 1) + } + } + + func toggleInferencePlayback() { + if arViewModel.isOpen { + if !arViewModel.isInferencePlaying { + arViewModel.startInferencePlayback() + } else { + arViewModel.stopInferencePlayback(reset: true) + arViewModel.goalTapModeEnabled = false + arViewModel.mlManager?.clearGoalPoint() + arViewModel.arVisualizationManager.clearTargetPose() + } + } + UIImpactFeedbackGenerator(style: appStatus.hapticFeedbackLevel).impactOccurred() + } + + func toggleFlash() { + guard let device = AVCaptureDevice.default(for: AVMediaType.video) else { return } + if device.hasTorch { + do { + try device.lockForConfiguration() + device.torchMode = openFlash ? .on : .off + device.unlockForConfiguration() + } catch { + print("Flash could not be used") + } + } + openFlash = !openFlash + UIImpactFeedbackGenerator(style: appStatus.hapticFeedbackLevel).impactOccurred() + } + + private func updateOrientation() { + let orientation = UIDevice.current.orientation + guard orientation.isValidInterfaceOrientation else { return } + deviceOrientation = orientation + isLandscape = orientation.isLandscape + } +} + +#Preview { + InferenceViewOverlay(arViewModel: ARViewModel()) + .environmentObject(AppInformation()) +} \ No newline at end of file diff --git a/AnySense/Views/MLInferenceResultsView.swift b/AnySense/Views/MLInferenceResultsView.swift new file mode 100644 index 0000000..51cc693 --- /dev/null +++ b/AnySense/Views/MLInferenceResultsView.swift @@ -0,0 +1,96 @@ +// +// MLInferenceResultsView.swift +// AnySense +// +// Created by Krish on 2025/2/1. +// + +import SwiftUI + +struct MLInferenceResultsView: View { + @ObservedObject var mlManager: MLInferenceManager + + var body: some View { + VStack(alignment: .leading, spacing: 6) { + Text("Inference State") + .foregroundColor(.white) + .font(.subheadline) + .fontWeight(.semibold) + + if mlManager.isInferencePendingUI { + HStack(spacing: 6) { + ProgressView() + .progressViewStyle(.circular) + .tint(.white) + .scaleEffect(0.8) + Text(mlManager.latestResult == nil && mlManager.lastResult == nil ? "Analyzing..." : "Updating...") + .foregroundColor(.white.opacity(0.8)) + .font(.caption) + } + } + + if let result = mlManager.latestResult ?? mlManager.lastResult { + GripperBlock(result: result) + } else if !mlManager.isInferencePendingUI { + Text("Analyzing...") + .foregroundColor(.white.opacity(0.7)) + .font(.caption) + .italic() + } + } + .padding(8) + .background( + RoundedRectangle(cornerRadius: 8) + .fill(Color.black.opacity(0.75)) + .overlay( + RoundedRectangle(cornerRadius: 8) + .stroke(Color.white.opacity(0.2), lineWidth: 1) + ) + ) + .frame(maxWidth: 160) + } +} + +// MARK: - Gripper Subview +private struct GripperBlock: View { + let result: InferenceResult + + private var gripperValue: Float { + return result.jointPositions.count >= 7 ? result.jointPositions[6] : 0.0 + } + + private var gripperState: String { + return gripperValue < 0.7 ? "CLOSED" : "OPEN" + } + + private var stateColor: Color { + return gripperValue < 0.7 ? .red : .green + } + + var body: some View { + VStack(alignment: .leading, spacing: 4) { + HStack { + Text("Value:") + .foregroundColor(.white.opacity(0.7)) + .font(.caption) + Text(String(format: "%.3f", gripperValue)) + .foregroundColor(.orange) + .font(.caption) + .fontWeight(.medium) + .fontDesign(.monospaced) + Spacer() + } + + HStack { + Text("State:") + .foregroundColor(.white.opacity(0.7)) + .font(.caption) + Text(gripperState) + .foregroundColor(stateColor) + .font(.caption) + .fontWeight(.bold) + Spacer() + } + } + } +} diff --git a/AnySense/Views/accountView.swift b/AnySense/Views/accountView.swift index 840fbcd..fd00cf5 100644 --- a/AnySense/Views/accountView.swift +++ b/AnySense/Views/accountView.swift @@ -6,17 +6,41 @@ // import SwiftUI +import CoreBluetooth +import AVFoundation +import UniformTypeIdentifiers struct SettingsView : View{ @EnvironmentObject var appStatus: AppInformation + @ObservedObject var arViewModel: ARViewModel + let modelManager: ModelManager - let frequencyOptions = ["0.1", "0.05", "0.033", "0.02", "0.017", "0.01"] // Frequency options + // File picker state + @State private var showingFilePicker = false + @State private var showingAlert = false + @State private var alertMessage = "" + + // Track current frequency for UI updates + @State private var currentFrequencyIndex: Int = 1 + + // Map available inference frequencies to picker choices + private let inferenceOptions: [MLInferenceManager.InferenceFrequency] = MLInferenceManager.InferenceFrequency.allCases + + // Helper function for short display names + private func shortDisplayName(for frequency: MLInferenceManager.InferenceFrequency) -> String { + switch frequency { + case .high: return "30 Hz" + case .medium: return "1 Hz" + case .low: return "0.1 Hz" + case .minute: return "0.017 Hz" + } + } var body : some View{ ZStack{ Color.customizedBackground .ignoresSafeArea() - Form{ + Form { Section(header: Text("GENERAL")) { HStack { VStack(alignment: .leading, spacing: 8) { @@ -32,13 +56,10 @@ struct SettingsView : View{ let binding = Binding( get: { appStatus.rgbdVideoStreaming }, set: { newValue in - if newValue != StreamingMode.wifi { // Disable Option wifi - appStatus.rgbdVideoStreaming = newValue - } + appStatus.rgbdVideoStreaming = newValue } ) - Picker("Streaming Options", selection: binding) { // Temporary fix to keep wifi option but disable it - Text("Wi-Fi").tag(StreamingMode.wifi).opacity(0.5) + Picker("Streaming Options", selection: binding) { // Text("USB").tag(StreamingMode.usb) Text("Off").tag(StreamingMode.off) } @@ -47,6 +68,9 @@ struct SettingsView : View{ } .padding(.vertical, 5) .padding(.vertical, 5) + HStack{ + Toggle("Audio recording enabled", isOn: $appStatus.ifAudioRecordingEnabled) + } HStack{ Text("Buttons haptic feedback") .font(.body) @@ -76,39 +100,184 @@ struct SettingsView : View{ .foregroundStyle(.gray) } } - } -// Section(header: Text("INFO")) { -// NavigationLink { -// InstructionView() -// } label: { -// HStack { -// Text("How to use?") -// .font(.body) -// .foregroundColor(.black) -// Spacer() -// } -// } -// NavigationLink { -// fileMarkdownView() -// } label: { -// HStack { -// Text("About") -// .font(.body) -// .foregroundColor(.black) -// Spacer() -// } -// } -// } + + // MARK: - Model Management Section + Section(header: Text("MODEL MANAGEMENT")) { + // Upload Model Button + Button("Upload Model") { + showingFilePicker = true + } + .foregroundColor(.blue) + .sheet(isPresented: $showingFilePicker) { + ModelImporter(onPickDocument: handleModelUpload) + } + + // Compilation Progress + if modelManager.isCompiling { + HStack { + Text("Compiling model...") + .font(.body) + .foregroundColor(.primary) + Spacer() + VStack(alignment: .trailing, spacing: 4) { + ProgressView(value: modelManager.compilationProgress) + .frame(width: 100) + Text("\(Int(modelManager.compilationProgress * 100))%") + .font(.caption) + .foregroundColor(.gray) + } + } + .padding(.vertical, 5) + } + + // Model Selection (when compiled models available) + if !modelManager.compiledModels.isEmpty { + Picker("Select Model", selection: Binding( + get: { + let activeID = modelManager.activeModelID + // print("DEBUG: Picker get - activeModelID: \(String(describing: activeID))") + return activeID + }, + set: { newValue in + // print("DEBUG: Picker set - newValue: \(String(describing: newValue))") + if let newValue = newValue { + // Force immediate UI update + DispatchQueue.main.async { + modelManager.setActiveModel(id: newValue) + } + } + } + )) { + ForEach(modelManager.compiledModels) { model in + Text(model.displayName).tag(model.id as UUID?) + } + } + .pickerStyle(MenuPickerStyle()) + .padding(.vertical, 5) + .id(modelManager.activeModelID?.uuidString ?? "none") // Force refresh when activeModel changes + } + } + + // MARK: - Inference Settings Section + Section(header: Text("INFERENCE SETTINGS")) { + // Inference is tab-scoped (enabled automatically in the Inference tab) + VStack(alignment: .leading, spacing: 4) { + Text("Inference runs automatically in the Inference tab") + .font(.caption) + .foregroundColor(.gray) + } + .padding(.vertical, 5) + + // Inference Frequency Slider + if let mlManager = arViewModel.mlManager { + VStack(alignment: .leading, spacing: 12) { + Text("Inference Frequency") + .font(.body) + .foregroundColor(.primary) + + let sliderBinding = Binding( + get: { + Double(currentFrequencyIndex) + }, + set: { newValue in + let index = Int(newValue.rounded()) + if index >= 0 && index < inferenceOptions.count { + currentFrequencyIndex = index + mlManager.setInferenceFrequency(inferenceOptions[index]) + } + } + ) + + Slider(value: sliderBinding, + in: 0...Double(inferenceOptions.count - 1), + step: 1) + + HStack { + ForEach(0..) { + switch result { + case .success(let url): + Task { + do { + try await modelManager.uploadAndCompileModel(from: url) + + DispatchQueue.main.async { + alertMessage = "Model uploaded and compiled successfully!" + showingAlert = true + } + } catch { + DispatchQueue.main.async { + alertMessage = "Failed to upload model: \(error.localizedDescription)" + showingAlert = true + } + } + } + + case .failure(let error): + alertMessage = "Failed to select file: \(error.localizedDescription)" + showingAlert = true + } } + } enum StreamingMode: String { case off = "Off" - case wifi = "Wi-Fi" case usb = "USB" } @@ -118,7 +287,54 @@ enum GridMode: Int { case _5x5 = 5 } +// MARK: - Model Importer +struct ModelImporter: UIViewControllerRepresentable { + let onPickDocument: (Result) -> Void + + func makeUIViewController(context: Context) -> UIDocumentPickerViewController { + // Prefer system-declared UTIs; fall back to filename extensions for safety + let mlmodel = UTType(importedAs: "com.apple.coreml.model") + let mlpackage = UTType(importedAs: "com.apple.coreml.modelpackage") + // Compiled model UTI name varies by SDK; use a broad fallback as well + let mlmodelc = UTType(importedAs: "com.apple.coreml.compiled-model") + + var allowedTypes: [UTType] = [mlmodel, mlpackage, mlmodelc, .package, .data, .item] + // Add filename-extension fallbacks to catch older devices + if let byExt1 = UTType(filenameExtension: "mlmodel") { allowedTypes.append(byExt1) } + if let byExt2 = UTType(filenameExtension: "mlmodelc") { allowedTypes.append(byExt2) } + if let byExt3 = UTType(filenameExtension: "mlpackage") { allowedTypes.append(byExt3) } + + let picker = UIDocumentPickerViewController(forOpeningContentTypes: allowedTypes) + picker.delegate = context.coordinator + picker.allowsMultipleSelection = false + return picker + } + + func updateUIViewController(_ uiViewController: UIDocumentPickerViewController, context: Context) {} + + func makeCoordinator() -> Coordinator { + Coordinator(self) + } + + class Coordinator: NSObject, UIDocumentPickerDelegate { + let parent: ModelImporter + + init(_ parent: ModelImporter) { + self.parent = parent + } + + func documentPicker(_ controller: UIDocumentPickerViewController, didPickDocumentsAt urls: [URL]) { + guard let url = urls.first else { return } + parent.onPickDocument(.success(url)) + } + + func documentPickerWasCancelled(_ controller: UIDocumentPickerViewController) { + // User cancelled - no action needed + } + } +} + #Preview { - SettingsView() + SettingsView(arViewModel: ARViewModel(), modelManager: ModelManager()) .environmentObject(AppInformation()) } diff --git a/AnySense/Views/peripheralView.swift b/AnySense/Views/peripheralView.swift index 8ea7db5..892b454 100644 --- a/AnySense/Views/peripheralView.swift +++ b/AnySense/Views/peripheralView.swift @@ -9,8 +9,8 @@ import SwiftUI import CoreBluetooth struct singleBLEPeripheral: View { - @EnvironmentObject var bluetoothManager: BluetoothManager @EnvironmentObject var appStatus: AppInformation + @ObservedObject var arViewModel: ARViewModel @State private var isConnected = false let name: String let uuid: UUID @@ -35,7 +35,7 @@ struct singleBLEPeripheral: View { UIImpactFeedbackGenerator(style: appStatus.hapticFeedbackLevel).impactOccurred() if !isConnected{ - bluetoothManager.connectToPeripheral(withUUID: uuid) { result in + arViewModel.getBLEManagerInstance().connectToPeripheral(withUUID: uuid) { result in switch result { case .success(let connectedPeripheral): print("Successfully connected to: \(connectedPeripheral.name ?? "Unknown Device")") @@ -46,7 +46,7 @@ struct singleBLEPeripheral: View { appStatus.ifBluetoothConnected = true } else { appStatus.ifBluetoothConnected = false - bluetoothManager.disconnectFromDevice() + arViewModel.getBLEManagerInstance().disconnectFromDevice() } isConnected = !isConnected @@ -55,19 +55,21 @@ struct singleBLEPeripheral: View { struct PeripheralView: View { @EnvironmentObject var appStatus : AppInformation - @EnvironmentObject var bluetoothManager: BluetoothManager + @ObservedObject var arViewModel: ARViewModel + @ObservedObject var bluetoothManager: BluetoothManager var body: some View { VStack{ Text("Devices Detected") .font(.body) - .frame(width: 500.0, height: 50) - .ignoresSafeArea() + .frame(maxWidth: .infinity) + .frame(height: 50) .foregroundStyle(.deviceWord) .background(.deviceTop) .padding(.top, 5) List(Array(bluetoothManager.discoveredPeripherals.keys), id: \.self) { uuid in if let peripheral = bluetoothManager.discoveredPeripherals[uuid] { singleBLEPeripheral( + arViewModel: arViewModel, name: peripheral.name ?? "Unknown Device", uuid: peripheral.identifier ) @@ -76,10 +78,7 @@ struct PeripheralView: View { .scrollContentBackground(.hidden) .background(Color.customizedBackground) } + .frame(maxWidth: .infinity, maxHeight: .infinity) .background(Color.customizedBackground) } } - -#Preview { - PeripheralView().environmentObject(AppInformation()).environmentObject(BluetoothManager()) -} diff --git a/AnySense/Views/readView.swift b/AnySense/Views/readView.swift index 9b0592b..62c5e3a 100644 --- a/AnySense/Views/readView.swift +++ b/AnySense/Views/readView.swift @@ -12,24 +12,23 @@ import BackgroundTasks import UserNotifications import Foundation import AVFoundation +import ARKit -enum ActiveAlert { +enum ReadActiveAlert { case first, second } -struct ReadView : View{ - @EnvironmentObject var appStatus : AppInformation - @EnvironmentObject var bluetoothManager: BluetoothManager +// MARK: - ReadView Overlay +struct ReadViewOverlay: View { + @EnvironmentObject var appStatus: AppInformation @ObservedObject var arViewModel: ARViewModel - @State private var isReading = false - @State var showingAlert : Bool = false - @Environment(\.scenePhase) private var phase + @State var showingAlert: Bool = false @State private var fileSetNames: RecordingFiles? - @State var showingExporter = false @State var openFlash = true - @State private var activeAlert: ActiveAlert = .first + @State private var activeAlert: ReadActiveAlert = .first @State private var isRecordedOnce: Bool = false - var body : some View{ + + var body: some View { let paths = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask) GeometryReader { geometry in let screenWidth = geometry.size.width @@ -37,243 +36,169 @@ struct ReadView : View{ let arViewHeight = min(screenWidth * 1.33, 0.75 * screenHeight) let arViewWidth = min(arViewHeight / 1.33, screenWidth) let arViewPadding = 0.2 * arViewHeight - let buttonSize: CGFloat = min(screenWidth * 0.3, 100) -// let buttonPadding: CGFloat = + let buttonSize: CGFloat = min(screenWidth * 0.25, 80) let btBarHeight: CGFloat = 25.0 let gridSize = appStatus.gridProjectionTrigger.rawValue - + ZStack { - // Apply the custom background color - Color.customizedBackground - .ignoresSafeArea() - ZStack{ - ARViewContainer(session: arViewModel.session) - .edgesIgnoringSafeArea(.all) - .frame(width: arViewWidth, height: arViewHeight) - // .clipShape(RoundedRectangle(cornerRadius: 30, style: .continuous)) - .padding(.bottom, arViewPadding) - .opacity(appStatus.rgbdVideoStreaming == .off ? 1 : 0) - .allowsHitTesting(appStatus.rgbdVideoStreaming == .off) // Disable interaction in streaming mode - if appStatus.rgbdVideoStreaming == .off { + Color.clear + // Top bar with notch area + Bluetooth status + VStack(spacing: 0) { + // White bar for notch/safe area + Color.white + .frame(height: geometry.safeAreaInsets.top) + + // Bluetooth status bar Text(appStatus.ifBluetoothConnected ? "bluetooth device connected" : "bluetooth device disconnected") .font(.footnote) .foregroundColor(Color.white) - .frame(width: screenWidth, height: btBarHeight) + .frame(maxWidth: .infinity) + .frame(height: btBarHeight) .background(appStatus.ifBluetoothConnected ? .green : .red) - .padding(.bottom, arViewPadding + arViewHeight + btBarHeight) - .ignoresSafeArea() - if appStatus.gridProjectionTrigger.rawValue > 0 { - VStack { - Path { path in - for col in 1.. 0 { + VStack { + Path { path in + for col in 1.. some View { + configuration.label.scaleEffect(isRecording ? 0.35 : 1) + } } private func initCode() { @@ -289,44 +220,35 @@ struct ReadView : View{ } private func handleStreamingModeChange(from oldMode: StreamingMode, to newMode: StreamingMode) { - if isReading { + if arViewModel.isRecording { toggleRecording(mode: oldMode) } switch (oldMode, newMode) { case (_, .off): arViewModel.killUSBStreaming() - print("Switched to \(newMode): ARView is active.") - - case (_, .wifi): - print("NOT IMPLEMENTED.") case (_, .usb): - print("Switched to \(newMode): ARView is hidden.") arViewModel.setupUSBStreaming() } } func toggleRecording(mode: StreamingMode) { - isReading = !isReading if arViewModel.isOpen { if mode == .off { - if isReading { - fileSetNames = arViewModel.startRecording() - if(bluetoothManager.ifConnected){ - startRecordingBT(targetURL: fileSetNames!.tactileFile) + if !arViewModel.isRecording { + if let files = arViewModel.startRecording() { + fileSetNames = files + if arViewModel.getBLEManagerInstance().ifConnected { + arViewModel.startBluetoothRecording(targetURL: files.tactileFile, fps: appStatus.animationFPS) + } } - -// print(fileSetNames) } else { - if(bluetoothManager.ifConnected){ - stopRecordingBT() - print("This stop recording is when shared bluetooth manager is connected") + if arViewModel.getBLEManagerInstance().ifConnected { + arViewModel.stopBluetoothRecording() } arViewModel.stopRecording() - } - } - else if mode == .usb { - if isReading { + } else if mode == .usb { + if !arViewModel.isUSBStreamingActive { arViewModel.startUSBStreaming() } else { arViewModel.stopUSBStreaming() @@ -334,66 +256,24 @@ struct ReadView : View{ } } UIImpactFeedbackGenerator(style: appStatus.hapticFeedbackLevel).impactOccurred() - } - func toggleFlash() { - guard let device = AVCaptureDevice.default(for: AVMediaType.video) - else {return} + guard let device = AVCaptureDevice.default(for: AVMediaType.video) else { return } if device.hasTorch { do { try device.lockForConfiguration() - if openFlash == true { device.torchMode = .on // set on - } else { - device.torchMode = .off // set off - } + device.torchMode = openFlash ? .on : .off device.unlockForConfiguration() } catch { print("Flash could not be used") } - } else { - print("Flash is not available") } openFlash = !openFlash UIImpactFeedbackGenerator(style: appStatus.hapticFeedbackLevel).impactOccurred() } - - func startRecordingBT(targetURL:URL) { - do { - try createFile(fileURL: targetURL) - } - catch { - print("Error creating tactile file.") - } - - bluetoothManager.startRecording( - targetURL: targetURL, - fps: appStatus.animationFPS - ) - } - - func stopRecordingBT() { - bluetoothManager.stopRecording() - } - - func createDocumentaryFolderFiles(paths: [URL], fileSetNames: RecordingFiles?) -> [FileElement] { - guard let fileSetNames = fileSetNames else { - print("❌ Error: Insufficient paths or fileSetNames elements") - return [] - } - - let rgbFile = FileElement.videoFile(VideoFile(url:fileSetNames.rgbFileName)) - let depthFile = FileElement.videoFile(VideoFile(url:fileSetNames.depthFileName)) - let poseFile = FileElement.textFile(TextFile(url:fileSetNames.poseFile.path)) -// let rgbImageFolder = FileElement.directory(SubLevelDirectory(url:fileSetNames.rgbImagesDirectory)) - let depthImageFolder = FileElement.directory(SubLevelDirectory(url: fileSetNames.depthImagesDirectory)) - - return [rgbFile, depthFile, poseFile, depthImageFolder] - } - - func deleteRecordedData(url: [URL], targetDirect: String){ + func deleteRecordedData(url: [URL], targetDirect: String) { do { let urlToDelete = url[0].appendingPathComponent(targetDirect) try FileManager.default.removeItem(at: urlToDelete) @@ -402,16 +282,10 @@ struct ReadView : View{ print("Error deleting file: \(error)") } } - - - } - - - #Preview { - ReadView(arViewModel: ARViewModel()) + ReadViewOverlay(arViewModel: ARViewModel()) .environmentObject(AppInformation()) } - + \ No newline at end of file diff --git a/AnySense/dataStorage.swift b/AnySense/dataStorage.swift index 739cd68..0386572 100644 --- a/AnySense/dataStorage.swift +++ b/AnySense/dataStorage.swift @@ -75,14 +75,50 @@ struct ImageFile: FileDocument { } + +// MARK: - ML Model File Document +struct MLModelFile: FileDocument { + var url: URL + + static var readableContentTypes: [UTType] { + // Support .mlmodel, .mlmodelc, and .mlpackage files + [ + UTType(filenameExtension: "mlmodel")!, + UTType(filenameExtension: "mlmodelc")!, + UTType(filenameExtension: "mlpackage")! + ] + } + static var writableContentTypes: [UTType] { + [ + UTType(filenameExtension: "mlmodel")!, + UTType(filenameExtension: "mlmodelc")!, + UTType(filenameExtension: "mlpackage")! + ] + } + + init(url: URL) { + self.url = url + } + + init(configuration: ReadConfiguration) throws { + self.url = URL(fileURLWithPath: "") + } + + func fileWrapper(configuration: WriteConfiguration) throws -> FileWrapper { + return try FileWrapper(url: url, options: .immediate) + } +} + enum FileElement { case videoFile(VideoFile) case textFile(TextFile) case directory(SubLevelDirectory) + case mlModelFile(MLModelFile) // Add ML model support } enum FileElementSub { case imageFile(ImageFile) + case mlModelFile(MLModelFile) // Add ML model support to sub-level } struct SubLevelDirectory: FileDocument{ @@ -98,8 +134,10 @@ struct SubLevelDirectory: FileDocument{ do{ let contents = try FileManager.default.contentsOfDirectory(at: url, includingPropertiesForKeys: nil) for content in contents{ - if content.pathExtension.lowercased() == "ipeg" || content.pathExtension.lowercased() == "jpg" { + if content.pathExtension.lowercased() == "jpeg" || content.pathExtension.lowercased() == "jpg" { self.containedFiles.append(.imageFile(ImageFile(url: content))) + } else if ["mlmodel","mlmodelc","mlpackage"].contains(content.pathExtension.lowercased()) { + self.containedFiles.append(.mlModelFile(MLModelFile(url: content))) } } }catch{ @@ -120,6 +158,9 @@ struct SubLevelDirectory: FileDocument{ case .imageFile(let imageFile): let fileWrapper = try imageFile.fileWrapper(configuration: configuration) dirWrapper.addFileWrapper(fileWrapper) + case .mlModelFile(let mlModelFile): + let fileWrapper = try mlModelFile.fileWrapper(configuration: configuration) + dirWrapper.addFileWrapper(fileWrapper) } } return dirWrapper @@ -164,6 +205,10 @@ struct SubLevelDirectory: FileDocument{ let directoryWrapper = try directory.fileWrapper(configuration: configuration) directoryWrapper.preferredFilename = directory.url.lastPathComponent folderWrapper.addFileWrapper(directoryWrapper) + case .mlModelFile(let mlModelFile): + let fileWrapper = try mlModelFile.fileWrapper(configuration: configuration) + fileWrapper.preferredFilename = mlModelFile.url.lastPathComponent + folderWrapper.addFileWrapper(fileWrapper) } } diff --git a/AnySense/gripper_closed.png b/AnySense/gripper_closed.png new file mode 100644 index 0000000..46a1b6e Binary files /dev/null and b/AnySense/gripper_closed.png differ diff --git a/AnySense/gripper_overlay.png b/AnySense/gripper_overlay.png new file mode 100644 index 0000000..98f57f3 Binary files /dev/null and b/AnySense/gripper_overlay.png differ diff --git a/README.md b/README.md index d5aa5cb..0fe373b 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # AnySense -[AnySense](https://anysense.app) is an iPhone application that integrates the iPhone's sensory suite with external multisensory inputs via Bluetooth and wired interfaces, enabling both offline data collection and online streaming to robots. Currently, we record RGB and depth videos, metric depth frames, streamed Bluetooth data appended into a binary file and timestamped pose data as a `.txt` file. Example streaming code for streaming Bluetooth data can be found on [AnySkin](https://any-skin.github.io). We also allow for USB streaming by simply connecting the iPhone to your computer and using this [accompanying library](https://github.com/NYU-robot-learning/anysense-streaming) forked from the excellent [record3d](https://github.com/marek-simonik/record3d) library. +[AnySense](https://anysense.app) is an open-source iPhone app that enables multi-sensory data collection by integrating the iPhone's sensory suite with external sensors via Bluetooth and wired interfaces, enabling both offline data collection and online streaming to robots. The app also supports on-device policy evaluation: you can load a trained CoreML policy, set a 3D goal in AR, and run the policy with AR visualization of predicted actions without a robot. We record RGB and depth videos, metric depth frames, streamed Bluetooth data (e.g. from tactile sensors) into a binary file, and timestamped pose data as a `.txt` file. Example streaming code for Bluetooth data can be found on [AnySkin](https://any-skin.github.io). USB streaming is supported by connecting the iPhone to your computer and using this [accompanying library](https://github.com/NYU-robot-learning/anysense-streaming) (forked from [record3d](https://github.com/marek-simonik/record3d)). ## App Screenshots sdfasdf diff --git a/README_Inference.md b/README_Inference.md new file mode 100644 index 0000000..0c5c93d --- /dev/null +++ b/README_Inference.md @@ -0,0 +1,49 @@ +# Inference View + +The inference view lets you run a point-conditioned CoreML policy on-device, visualize actions in AR, and inspect the gripper state. The purpose of this is to allow for us to be able to test the model without a robot, and get a sense of it's capability(vibe) with you acting as a robot arm, and evaluating in the wild based on the policy's conditioning task. The demo model loaded by default is a object pick-up policy. It is optimized for low-latency inference and small memory churn using shared pixel buffers and Accelerate for preprocessing. + +## Technical Specifications + +**Model Requirements:** +The system accepts CoreML models converted from our pipelines such as RUM/min-stretch. Models must process RGB input at 224×224 pixel resolution, though the original specification supported 256×256 pixels. Researchers requiring alternative resolutions can adjust the `modelInputSize` parameter and corresponding buffer allocations in `MLInferenceManager`. + +The system supports temporal models with up to 3-frame rolling buffers, automatically managing sequence padding for consistency. Input tensors follow standard formats: `[1, 3, H, W]` for single-frame models or `[1, T, 3, H, W]` for temporal architectures, accompanied by goal tensors representing 3D spatial coordinates with shape inferred from model specifications. + +Model outputs must provide 7-element vectors representing 6-DOF manipulation actions plus gripper state. For temporal models, the system extracts predictions from the final timestep, ensuring consistency with training assumptions while accommodating variable sequence lengths. + +## App Pipeline + +1) **Model load**: `ModelManager` provides the active compiled model and metadata (temporal frames, input/output names, goal requirement). A loading overlay is shown while this warms up. +2) **Goal conditioning**: A user-tapped 3D goal in world space is transformed to the camera/labels frame just before inference (mapping: `[-x_cam, -z_cam + first-frame 0.02 offset, -y_cam]`). Required models skip inference until a goal exists. +3) **Frame prep**: AR camera frames → vImage/Core Image resize to 224×224 ARGB → normalized MLMultiArray `[1,3,H,W]`. Gripper overlay is composited into the model input when USB streaming is OFF. +4) **Temporal buffer**: Maintain up to 3 action frames. + - USB streaming ON: always roll the buffer; run by frequency (High/Med/Low/Minute) or first inference. + - USB streaming OFF (recording mode): buffer only when proximity trigger fires or on first/manual inference. +5) **Input packing**: Build `[1,T,3,H,W]` with padding if fewer than T frames; build goal tensor matching the model-declared shape; assemble `MLDictionaryFeatureProvider`. +6) **Inference**: Run off-main on a dedicated queue; track inference time; guard with a pending flag to avoid overlap. +7) **Postprocess**: Extract joint actions (last timestep if temporal). Gripper value updates UI + overlay, and AR visualization updates the target pose (skipped in USB mode). Latest result is published for the UI card. + +## UI/Controls + +- **Set goal**: Tap “Set goal”, then tap in AR to place the 3D target. +- **Start/stop**: Inference is enabled automatically in the Inference tab; press **Play/Stop** to start/stop inference (no video is saved). A loading overlay appears when switching models. +- **Manual step**: "Get next action" calls manual inference using the existing buffered frames. This is specifically useful if there's significant deviation from target, and user wants to realign and restart inferencing. This also helps check for robustness in the performance. +- **Visualization**: AR overlay shows the inferred pose when **not** in USB streaming mode. Gripper overlay image on-screen mirrors predicted gripper open/closed and shows an “iPhone Inference” badge when active. +- **Status card**: `MLInferenceResultsView` shows gripper value and OPEN/CLOSED state. + +## How to Use Quickly + +1) Ensure a compiled model is available (see min-stretch conversion flow) and enter the Inference tab. A demo model is already loaded up, trained on object pick-up tasks. +2) Hold the phone sideways (**landscape**) while using the demo. Wait for "Preparing Model…" to finish if shown. +3) Tap "Set goal" and tap a point to place the target in AR. +4) Press **Play**, then start aligning the blue arrow to the red arrow (that changes from red to green as you get closer). As you align, the next action is inferred. If the next action is far off target, you can retry using the "Get next action" button. +5) When the gripper closes, the episode is finished; press **Stop** to reset. +6) Watch the AR pose updates and the gripper card/overlay for state and timing feedback. + +## USB Streaming for Robot Control + +The app can stream inference results directly to a robot via USB using the Record3D protocol. When USB streaming is enabled, the iPhone performs on-device inference and streams RGB frames, depth maps, camera poses, and predicted joint actions (7-DOF: 6-DOF manipulation + gripper state) to a connected computer. The robot server receives these action predictions in real-time and can execute them directly. To use this feature, enable "USB Streaming mode" in the Settings tab, connect your iOS device to your computer, and use the [anysense-streaming library](https://github.com/NYU-robot-learning/anysense-streaming/tree/dev/Krish) to receive the stream. The library provides Python and C++ bindings for receiving RGBD data, camera poses, and joint actions via `session.get_joint_actions()`. + +## Implementation Notes: PyTorch to CoreML Conversion + +Converting PyTorch models (RUM/min-stretch) to CoreML requires refactoring the loss function's forward method from a training-style interface to an inference-only path (`images, goals` → actions), replacing negative dimension indices with explicit positive ones, using slice notation to preserve tensor dimensions, and adding explicit type casts for `F.one_hot()` operations. A `CoreMLBranchStyleWrapper` encapsulates the model and loss function to provide a clean inference path. The conversion notebook is available in the [min-stretch repository's coreml branch](https://github.com/NYU-robot-learning/min-stretch/tree/coreml).