diff --git a/fedlearner/common/common.py b/fedlearner/common/common.py index 0d3faca4f..19dfaa64a 100644 --- a/fedlearner/common/common.py +++ b/fedlearner/common/common.py @@ -253,6 +253,36 @@ def convert_time_string_to_datetime(value): return date_time +def get_process_dates(start_date, end_date=None, fmt='%Y%m%d'): + today = datetime.date.today() + today_date = datetime.datetime(today.year, today.month, today.day) + if end_date is None or end_date > today_date: + end_date = today_date + if start_date > end_date: + raise ValueError("start_date should be less than or equal to end_date") + process_dates = [] + current_date = start_date + while current_date <= end_date: + process_dates.append(current_date.strftime(fmt)) + current_date += datetime.timedelta(days=1) + return process_dates + + +def end_with_valid_date(path: str) -> bool: + last_field = path.rstrip('/').split('/')[-1] + + def is_valid_date(date_str: str) -> bool: + for fmt in ('%Y-%m-%d', '%Y%m%d', '%Y/%m/%d', '%Y.%m.%d'): + try: + datetime.strptime(date_str, fmt) + return True + except ValueError: + continue + return False + + return is_valid_date(last_field) + + def set_logger(): verbosity = int(os.environ.get('VERBOSITY', 1)) if verbosity == 0: diff --git a/fedlearner/trainer/data_visitor.py b/fedlearner/trainer/data_visitor.py index 5f7dc6ab0..557da212a 100644 --- a/fedlearner/trainer/data_visitor.py +++ b/fedlearner/trainer/data_visitor.py @@ -30,6 +30,8 @@ from fedlearner.common import fl_logging from fedlearner.common import trainer_master_service_pb2 as tm_pb from fedlearner.common.common import convert_time_string_to_datetime +from fedlearner.common.common import end_with_valid_date +from fedlearner.common.common import get_process_dates from fedlearner.data_join.data_block_visitor import DataBlockVisitor from fedlearner.trainer.utils import match_date @@ -351,22 +353,41 @@ def __init__(self, if end_date: end_date = convert_time_string_to_datetime(str(end_date)) datablocks = [] - for dirname, _, filenames in tf.io.gfile.walk(data_path): - for filename in filenames: - if not fnmatch(os.path.join(dirname, filename), wildcard): + if start_date and not end_with_valid_date(data_path): + process_dates = get_process_dates(start_date, end_date) + miss_dates = [] + for process_date in process_dates: + dir_path = os.path.join(data_path, process_date) + if not tf.io.gfile.exists(dir_path): + miss_dates.append(process_date) continue - subdirname = os.path.relpath(dirname, data_path) - try: - cur_date = datetime.strptime(subdirname, '%Y%m%d') - if not match_date(cur_date, start_date, end_date): + for _, _, filenames in tf.io.gfile.walk(dir_path): + for filename in filenames: + if not fnmatch(os.path.join(dir_path, filename), wildcard): + continue + block_id = os.path.join(process_date, filename) + datablock = _RawDataBlock( + id=block_id, data_path=os.path.join(dir_path, filename), + start_time=None, end_time=None, type=tm_pb.JOINED) + datablocks.append(datablock) + fl_logging.info('miss_dates: [%s]', ",".join(miss_dates)) + else: + for dirname, _, filenames in tf.io.gfile.walk(data_path): + for filename in filenames: + if not fnmatch(os.path.join(dirname, filename), wildcard): continue - except Exception: - fl_logging.info('subdirname is not the format of time') - block_id = os.path.join(subdirname, filename) - datablock = _RawDataBlock( - id=block_id, data_path=os.path.join(dirname, filename), - start_time=None, end_time=None, type=tm_pb.JOINED) - datablocks.append(datablock) + subdirname = os.path.relpath(dirname, data_path) + try: + cur_date = datetime.strptime(subdirname, '%Y%m%d') + if not match_date(cur_date, start_date, end_date): + continue + except Exception: + fl_logging.info('subdirname is not the format of time') + block_id = os.path.join(subdirname, filename) + datablock = _RawDataBlock( + id=block_id, data_path=os.path.join(dirname, filename), + start_time=None, end_time=None, type=tm_pb.JOINED) + datablocks.append(datablock) datablocks.sort(key=lambda x: x.id) fl_logging.info("create DataVisitor by local_data_path: %s",