serverless_dag_decorators.py
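"""Serverless fan-out DAG.

Lists files under an S3 prefix, measures each file's size with a dynamically
mapped task, routes each file to a small/medium/large Lambda function via
expand_kwargs, and aggregates the results.
"""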
import json
import logging

from airflow import DAG
from airflow.decorators import task
from airflow.operators.python import get_current_context
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.lambda_function import LambdaInvokeFunctionOperator
from airflow.providers.amazon.aws.operators.s3 import S3ListOperator
from airflow.utils.dates import days_ago
from airflow.utils.task_group import TaskGroup

logger = logging.getLogger(__name__)
MY_BUCKET = "credence-core-raw"

default_args = {
    'owner': 'airflow',
    'start_date': days_ago(1),
}
with DAG(
    dag_id='serverless_dag_decorators',
    default_args=default_args,
    schedule_interval=None,
    catchup=False,
) as dag:

    # List every object under the data prefix; the resulting XCom feeds the
    # dynamically mapped tasks below.
    list_filenames = S3ListOperator(
        task_id="list_filenames",
        bucket=MY_BUCKET,
        prefix="serverless_dag/data",
    )
    @task(task_id='get_file_size', map_index_template="{{ filename }}")
    def get_file_size(aws_conn_id, bucket, filename):
        # Push the filename into the template context so the map index shown
        # in the UI is the filename rather than a bare integer.
        context = get_current_context()
        context["filename"] = filename
        hook = S3Hook(aws_conn_id=aws_conn_id)
        logger.info(f"Getting file size for {filename}")
        return {filename: hook.get_key(filename, bucket).content_length}
    @task(task_id='add_lines')
    def total(lines):
        # Sum the per-file counts returned by the Lambda invocations.
        logger.info(f"Adding lines {lines}")
        return sum(int(line) for line in lines)
    @task(map_index_template="{{ filename }}")
    def choose_lambda_function(file):
        # Each mapped input is a single-entry dict of {filename: size}.
        filename, filesize = list(file.items())[0]
        context = get_current_context()
        context["filename"] = filename
        d = {
            "payload": json.dumps({
                'bucket': list_filenames.bucket,
                'key': filename,
                'filesize': filesize,
            })
        }
        # Route the file to a Lambda sized for its byte count.
        if filesize < 10000:
            d['function_name'] = 'process_small_csv'
        elif filesize < 20000:
            d['function_name'] = 'process_medium_csv'
        else:
            d['function_name'] = 'process_large_csv'
        return d
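
    # Assumption: the process_{small,medium,large}_csv Lambdas accept the JSON
    # payload built above ({'bucket', 'key', 'filesize'}) and return a value
    # that total() can coerce with int(); they are defined outside this file.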
    # Fan out: one get_file_size task instance per object returned by
    # list_filenames.
    file_sizes = get_file_size.partial(
        aws_conn_id="aws_default",
        bucket=list_filenames.bucket,
    ).expand(
        filename=list_filenames.output,
    )
    with TaskGroup(group_id='process_files') as process_files:
        branch_task = choose_lambda_function.expand(file=file_sizes)
        # expand_kwargs maps over the dicts returned by choose_lambda_function,
        # so each invocation gets its own function_name and payload.
        invoke_lambda_operators = LambdaInvokeFunctionOperator.partial(
            task_id='call_lambdas',
            aws_conn_id='aws_default',
            map_index_template="{{ task.function_name }}",
        ).expand_kwargs(branch_task)
        branch_task >> invoke_lambda_operators

    aggregate_task = total(lines=invoke_lambda_operators.output)

    list_filenames >> file_sizes >> process_files >> aggregate_task
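
# A minimal local smoke test: a sketch, not part of the original DAG, assuming
# Airflow 2.5+ (where DAG.test() exists) and working credentials behind the
# 'aws_default' connection.
if __name__ == "__main__":
    dag.test()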