Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions python/paddle/base/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def __init__(self):
self._in_to_static_mode_ = False
self._functional_dygraph_context_manager = None
self._dygraph_tracer_ = _dygraph_tracer_
self.thread_expected_place_ = None
env_pir_enabled = os.environ.get("FLAGS_enable_pir_api")

if env_pir_enabled is not None:
Expand Down Expand Up @@ -332,6 +333,10 @@ def __setattr__(self, name, val):
}


def _is_main_thread():
return threading.current_thread() is threading.main_thread()


def in_dygraph_mode() -> bool:
"""

Expand Down Expand Up @@ -814,53 +819,64 @@ def _dygraph_tracer():

def _current_expected_place_():
global _global_expected_place_
if (
_global_expected_place_ is None
or type(_global_expected_place_) is core.Place
):
expected_place: core.Place = None

if _is_main_thread():
expected_place = _global_expected_place_
else:
expected_place = (
global_var.thread_expected_place_ or _global_expected_place_
)

if expected_place is None or type(expected_place) is core.Place:
if core.is_compiled_with_cuda():
try:
device_count = core.get_cuda_device_count()
except Exception as e:
device_count = 0
if device_count > 0:
_global_expected_place_ = core.CUDAPlace(_cuda_ids()[0])
expected_place = core.CUDAPlace(_cuda_ids()[0])
else:
warnings.warn(
"You are using GPU version Paddle, but your CUDA device is not set properly. CPU device will be used by default."
)
_global_expected_place_ = core.CPUPlace()
expected_place = core.CPUPlace()
elif core.is_compiled_with_xpu():
try:
device_count = core.get_xpu_device_count()
except Exception as e:
device_count = 0
if device_count > 0:
_global_expected_place_ = core.XPUPlace(_xpu_ids()[0])
expected_place = core.XPUPlace(_xpu_ids()[0])
else:
warnings.warn(
"You are using XPU version Paddle, but your XPU device is not set properly. CPU device will be used by default."
)
_global_expected_place_ = core.CPUPlace()
expected_place = core.CPUPlace()
elif len(core.get_all_custom_device_type()) > 0:
dev_type = core.get_all_custom_device_type()[0]
try:
device_count = core.get_custom_device_count(dev_type)
except Exception as e:
device_count = 0
if device_count > 0:
_global_expected_place_ = core.CustomPlace(
expected_place = core.CustomPlace(
dev_type, _custom_device_ids(dev_type)[0]
)
else:
warnings.warn(
"You are using CUSTOM_DEVICE version Paddle, but your custom device is not set properly. CPU device will be used by default."
)
_global_expected_place_ = core.CPUPlace()
expected_place = core.CPUPlace()
else:
_global_expected_place_ = core.CPUPlace()
expected_place = core.CPUPlace()

if _is_main_thread():
_global_expected_place_ = expected_place
elif global_var.thread_expected_place_ is not None:
global_var.thread_expected_place_ = expected_place

return _global_expected_place_
return expected_place


def _current_expected_place():
Expand All @@ -876,7 +892,12 @@ def _set_dygraph_tracer_expected_place(place):

def _set_expected_place(place):
global _global_expected_place_
_global_expected_place_ = place

if _is_main_thread():
_global_expected_place_ = place
else:
global_var.thread_expected_place_ = place

_set_dygraph_tracer_expected_place(place)


Expand Down
Loading
Loading