diff --git a/capturemock/__init__.py b/capturemock/__init__.py index 2a5c615e..52d8f04e 100644 --- a/capturemock/__init__.py +++ b/capturemock/__init__.py @@ -12,7 +12,7 @@ import bisect from urllib.request import urlopen -version = "2.8.3" +version = "2.8.4" class CaptureMockManager: fileContents = "import capturemock; capturemock.interceptCommand()\n" @@ -248,10 +248,13 @@ def replay_for_server(rcFile=None, replayFile=None, recordFile=None, serverAddre FileEditTraffic.configure(foptions) from .server import ReplayOnlyDispatcher dispatcher = ReplayOnlyDispatcher(replayFile, recordFile, rcFile) - if serverAddress: - from .clientservertraffic import ClientSocketTraffic - ClientSocketTraffic.setServerLocation(serverAddress, True) - dispatcher.replay_all(**kw) + try: + if serverAddress: + from .clientservertraffic import ClientSocketTraffic + ClientSocketTraffic.setServerLocation(serverAddress, True) + dispatcher.replay_all(**kw) + finally: + dispatcher.shutdown() def add_timestamp_data(data_by_timestamp, given_ts, fn, currText, fn_timestamps): if currText.startswith("->"): # it's a reply, must match with best client traffic @@ -617,6 +620,10 @@ def wrapped_func(*funcargs, **funckw): setUpPython(self.mode, recordFile, replayFile, self.rcFiles, self.pythonAttrs) interceptor = interceptPython(self.mode, recordFile, replayFile, self.rcFiles, self.pythonAttrs) result = func(*funcargs, **funckw) + # Close the file handlers before attempting file operations (needed for Windows) + if interceptor: + interceptor.resetIntercepts() + interceptor = None if self.mode == config.REPLAY: self.checkMatching(recordFile, replayFile) elif os.path.isfile(recordFile): diff --git a/capturemock/capturepython.py b/capturemock/capturepython.py index ba140e9e..a0c5fc02 100644 --- a/capturemock/capturepython.py +++ b/capturemock/capturepython.py @@ -42,8 +42,8 @@ def callerExcluded(self, stackDistance=1, callback=False): # Don't intercept if we've been called from within the standard library self.excludeLevel += 1 - framerecord = inspect.stack()[stackDistance] - fileName = framerecord[1] + frame = sys._getframe(stackDistance) + fileName = frame.f_code.co_filename dirName = self.getDirectory(fileName) moduleName = self.getModuleName(fileName) moduleNames = set([ moduleName, os.path.basename(dirName) ]) @@ -241,6 +241,7 @@ def __init__(self, mode, recordFile, replayFile, rcFiles, pythonAttrs): self.replayInfo = replayinfo.ReplayInfo(mode, replayFile, self.rcHandler) self.recordFile = recordFile self.allAttrNames = self.findAttributeNames(mode, pythonAttrs) + self.trafficHandler = None def findAttributeNames(self, mode, pythonAttrs): rcAttrs = self.rcHandler.getIntercepts("python") @@ -272,14 +273,14 @@ def makeIntercepts(self): return callStackChecker = CallStackChecker(self.rcHandler) from .pythontraffic import PythonTrafficHandler - trafficHandler = PythonTrafficHandler(self.replayInfo, self.recordFile, self.rcHandler, - callStackChecker, self.allAttrNames) + self.trafficHandler = PythonTrafficHandler(self.replayInfo, self.recordFile, self.rcHandler, + callStackChecker, self.allAttrNames) if len(fullIntercepts): - import_handler = ImportHandler(fullIntercepts, callStackChecker, trafficHandler) + import_handler = ImportHandler(fullIntercepts, callStackChecker, self.trafficHandler) if import_handler not in sys.meta_path: sys.meta_path.insert(0, import_handler) for moduleName, attributes in partialIntercepts.items(): - self.interceptAttributes(moduleName, attributes, trafficHandler) + self.interceptAttributes(moduleName, attributes, self.trafficHandler) def splitByModule(self, attrName): if self.canImport(attrName): @@ -338,3 +339,5 @@ def resetIntercepts(self): item.reset() for realObj, attrName, origValue in self.attributesIntercepted: setattr(realObj, attrName, origValue) + if self.trafficHandler is not None: + self.trafficHandler.close() diff --git a/capturemock/pythontraffic.py b/capturemock/pythontraffic.py index 0f609558..2ca827c7 100644 --- a/capturemock/pythontraffic.py +++ b/capturemock/pythontraffic.py @@ -497,6 +497,11 @@ def __init__(self, replayInfo, recordFile, rcHandler, callStackChecker, intercep PythonCallbackWrapper.resetCaches() PythonAttributeTraffic.resetCaches() + def close(self): + if self.recordFileHandler is not None: + self.recordFileHandler.close() + self.recordFileHandler = None + def importModule(self, name, proxy, loadModule): with self.lock: traffic = PythonImportTraffic(name, self.rcHandler) diff --git a/capturemock/recordfilehandler.py b/capturemock/recordfilehandler.py index 4bc77a0c..44e05c7d 100644 --- a/capturemock/recordfilehandler.py +++ b/capturemock/recordfilehandler.py @@ -7,26 +7,40 @@ def __init__(self, file): self.file = file self.lastTruncationPoint = None self.recordedSinceTruncationPoint = [] + self._writeFile = open(file, "a", buffering=8192) if file else None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + + def __del__(self): + self.close() + + def close(self): + if self._writeFile is not None: + self._writeFile.flush() + self._writeFile.close() + self._writeFile = None def record(self, text, truncationPoint=False): - if self.file: + if self._writeFile: if truncationPoint: + self._writeFile.flush() self.lastTruncationPoint = os.path.getsize(self.file) self.recordedSinceTruncationPoint = [] if self.lastTruncationPoint is not None: self.recordedSinceTruncationPoint.append(text) - writeFile = open(self.file, "a") - writeFile.write(text) - writeFile.flush() - writeFile.close() + self._writeFile.write(text) + self._writeFile.flush() def rerecord(self, oldText, newText): - if self.file: - writeFile = open(self.file, "a") - writeFile.truncate(self.lastTruncationPoint) + if self._writeFile: + self._writeFile.truncate(self.lastTruncationPoint) for text in self.recordedSinceTruncationPoint: - writeFile.write(text.replace(oldText, newText)) - writeFile.flush() - writeFile.close() + self._writeFile.write(text.replace(oldText, newText)) + self._writeFile.flush() self.lastTruncationPoint = None self.recordedSinceTruncationPoint = [] diff --git a/capturemock/server.py b/capturemock/server.py index 58ac6118..e1d69dab 100644 --- a/capturemock/server.py +++ b/capturemock/server.py @@ -593,7 +593,9 @@ def getServerClass(self): return AMQPTrafficServer def shutdown(self): - pass + if self.recordFileHandler is not None: + self.recordFileHandler.close() + self.recordFileHandler = None def findFilesAndLinks(self, path): if not os.path.exists(path): @@ -863,6 +865,7 @@ def run(self): def shutdown(self): self.diag.debug("Told to shut down!") self.server.shutdown() + super(ServerDispatcher, self).shutdown() class ReplayOnlyDispatcher(ServerDispatcherBase): @@ -981,7 +984,10 @@ def main(): fileedittraffic.FileEditTraffic.configure(options) server = ServerDispatcher(options) - server.run() + try: + server.run() + finally: + server.shutdown() if __name__ == "__main__":