컴퓨터 공부/💿 Airflow

[Airflow] BaseBranchOperator로 분기처리하기

letzgorats 2023. 8. 30. 04:59

이번 포스팅에서는 Task 분기처리하는 방법 중에 마지막 방법인 BaseBranchOperator로 분기처리하는 방법에 대해 살펴보겠습니다.

코드를 바로 살펴보겠습니다.

from airflow.operators.branch import BaseBranchOperator
with DAG(...
) as dag:
    class CustomBranchOperator(BaseBranchOperator):
        def choose_branch(self, context):
            import random
            
            item_lst = ['A','B','C']
            selected_item = random.choice(item_lst)
            if selected_item == 'A':
                return 'task_a'
            elif selected_item in ['B','C]:
                return ['task_b','task_c']
    
    custom_branch_operator = CustomBranchOperator(task_id='python_branch_task')
    custom_branch_operator >> [task_a,task_b,task_c]

먼저, ariflow.operators.branch 및에 BaseBranchOperator 라이브러리를 가져와야 합니다.

dag 선언을 해준 후에는 클래스를 직접 생성해야 하는데요, 이름은 아무렇게나 줘도 됩니다. 위 예시에서는 'CustomBranchOperator'라고 이름을 줬습니다. 

이 때, 중요한 것은 클래스 인자로 BaseBranchOperator를 줘야 합니다. 왜냐하면, 파이썬에서는 클래스 상속을 할 때, 클래스명 뒤의 괄호에 상속할 클래스명을 넣어주기 때문입니다.

그러니까, 여기서는 BaseBranchOperator가 부모 클래스가 되겠고 직접 만든 CustomBranchOperator가 자식 클래스가 되는 것이지요.

(참고로, 파이썬에서는 부모클래스를 2개 이상 상속하는 다중상속이 가능하지만, 권고하지는 않습니다. 가급적이면, 부모 클래스는 1개만 상속하는 것이 좋다고 가이드하고 있습니다.)

 

그 다음에, 'choose_branch'라는 함수를 정의하고 있습니다. 이 함수를 만든 이유를 알아보기 위해서는 먼저 BaseBranchOperator에 대한 이해가 필요합니다.

airflow 공식가이드에서 airflow.operators.branch 에 대해 살펴보면, 아래와 같습니다.

BaseBranchOperator클래스에 대한 설명

설명에서는 사용자는 BaseBranchOperator를 상속해야 하고, 반드시 'choose_branch' 함수를 구현해야 한다고 적혀져 있습니다. 그리고 'choose_branch' 함수는 분기로직이 결정되어야 할 비즈니스 로직이 뭐든 간에 실행되어야 한다고 합니다. 이 때, 반환 값은 분기된 task가 1개라면, task_id를 스트링 형태로 리턴해줘야 하고, 복수개라면, 리스트 형태로 리턴해줘야 합니다.

choose_branch 함수형태
인자는 context

이제 클래스를 선언하고 choose_branch 라고 하는 이 함수를 왜 재정의하는지 이해가 되실겁니다.

반드시 "choose_branch"라는 이름의 함수에 파라미터도 'context'를 꼭 넣어줘서 오버라이딩해야 합니다.

(참고로, 객체지향 프로그래밍에서 이런 함수 재정의 한 것을 "오버라이딩"이라고 합니다.)

 

오버라이딩한 내용은 저번 포스팅에서 코드를 짠 것과 동일합니다. 'task_a'와 ['task_b', 'task_c'] 중에서 random 결과에 따라 분기처리되는 코드입니다.

 

함수까지 작성한 것까지가 클래스에 대한 내용을 설계한 것이라면, 이런 설계도를 통해서 실제 실행이 가능한 객체를 만든 것이 'custom_branch_operator'라는 객체입니다.

task_flow를 정의하는 부분을 보면 직접 만든 CustomBranchOperator로부터 만들어진 task 뒤에 task_a, task_b, task_c 가 하나씩 물려있는 그래프를 형성할 것입니다.

 

이제 직접 실습을 해봅시다.

dags폴더에 dags_base_branch_operator.py파일을 생성해 다음과 같이 코드를 작성합니다.

from airflow import DAG
import pendulum
from airflow.operators.branch import BaseBranchOperator
from airflow.operators.python import PythonOperator

with DAG(
    dag_id = "dags_base_branch_operator",
    start_date=pendulum.datetime(2023,8,1,tz="UTC"),
    schedule=None,
    catchup=False
) as dag:
    class CustomBranchOperator(BaseBranchOperator):
        def choose_branch(self, context):
            import random
            print(context)

            item_lst = ['A','B','C']
            selected_item = random.choice(item_lst)
            if selected_item == 'A':
                return 'task_a'
            elif selected_item in ['B','C']:
                return ['task_b','taks_c']
            
    custom_branch_operator = CustomBranchOperator(task_id='python_branch_task')
            
    def common_func(**kwargs):
        print(kwargs['selected'])
    
    task_a = PythonOperator(
        task_id='task_a',
        python_callable=common_func,
        op_kwargs={'selected':'A'}
    )
    task_b = PythonOperator(
        task_id='task_b',
        python_callable=common_func,
        op_kwargs={'selected':'B'}
    )
    task_c = PythonOperator(
        task_id='task_c',
        python_callable=common_func,
        op_kwargs={'selected':'C'}
    )

    custom_branch_operator >> [task_a,task_b,task_c]

CustomBranchOperator라는 클래스를 만들어, task_id를 'python_branch_task'라고 지정해주고, 하나의 task를 만들었습니다 CustomBranchOperator 내부에 choose_branch 라는 함수를 오버라이딩 해주는데, context 인자에 뭐가 있는지 print()해보는 문장도 추가했습니다.random 결과에 따라 어떤 task로 분기가 될지 airflow로 한 번 확인해보겠습니다.

task_b, task_c 가 선택됨

그래프는 위 사진처럼 task_b와 task_c가 선택됐습니다.

xcom

Xcom을 살펴보면, return_value 키값에 ['task_b','task_c'] 가 들어가있고, skipmixin_key 키값에도 선택된 task가 들어가 있는 것을 확인할 수 있습니다. 로그도 확인해보겠습니다.

log

선택된 branch가 ['task_b','task_c'] 이고, 스킵된 branch가 'task_a' 가 잘 나와있습니다.

이 때, context를 출력해봤는데, conf가 이에 해당합니다.

[2023-08-29, 19:42:11 UTC] {logging_mixin.py:150} INFO - {'conf': <***.configuration.AirflowConfigParser object at 0x7f492922a910>, 'dag': <DAG: dags_base_branch_operator>, 'dag_run': <DagRun dags_base_branch_operator @ 2023-08-29 19:42:09.117489+00:00: manual__2023-08-29T19:42:09.117489+00:00, state:running, queued_at: 2023-08-29 19:42:09.147971+00:00. externally triggered: True>, 'data_interval_end': DateTime(2023, 8, 29, 19, 42, 9, 117489, tzinfo=Timezone('UTC')), 'data_interval_start': DateTime(2023, 8, 29, 19, 42, 9, 117489, tzinfo=Timezone('UTC')), 'ds': '2023-08-29', 'ds_nodash': '20230829', 'execution_date': DateTime(2023, 8, 29, 19, 42, 9, 117489, tzinfo=Timezone('UTC')), 'expanded_ti_count': None, 'inlets': [], 'logical_date': DateTime(2023, 8, 29, 19, 42, 9, 117489, tzinfo=Timezone('UTC')), 'macros': <module '***.macros' from '/home/***/.local/lib/python3.7/site-packages/***/macros/__init__.py'>, 'next_ds': '2023-08-29', 'next_ds_nodash': '20230829', 'next_execution_date': DateTime(2023, 8, 29, 19, 42, 9, 117489, tzinfo=Timezone('UTC')), 'outlets': [], 'params': {}, 'prev_data_interval_start_success': DateTime(2023, 8, 29, 19, 39, 18, 177396, tzinfo=Timezone('UTC')), 'prev_data_interval_end_success': DateTime(2023, 8, 29, 19, 39, 18, 177396, tzinfo=Timezone('UTC')), 'prev_ds': '2023-08-29', 'prev_ds_nodash': '20230829', 'prev_execution_date': DateTime(2023, 8, 29, 19, 42, 9, 117489, tzinfo=Timezone('UTC')), 'prev_execution_date_success': DateTime(2023, 8, 29, 19, 39, 18, 177396, tzinfo=Timezone('UTC')), 'prev_start_date_success': DateTime(2023, 8, 29, 19, 39, 19, 161899, tzinfo=Timezone('UTC')), 'run_id': 'manual__2023-08-29T19:42:09.117489+00:00', 'task': <Task(CustomBranchOperator): python_branch_task>, 'task_instance': <TaskInstance: dags_base_branch_operator.python_branch_task manual__2023-08-29T19:42:09.117489+00:00 [running]>, 'task_instance_key_str': 'dags_base_branch_operator__python_branch_task__20230829', 'test_mode': False, 'ti': <TaskInstance: dags_base_branch_operator.python_branch_task manual__2023-08-29T19:42:09.117489+00:00 [running]>, 'tomorrow_ds': '2023-08-30', 'tomorrow_ds_nodash': '20230830', 'triggering_dataset_events': <Proxy at 0x7f490acc9280 with factory <function TaskInstance.get_template_context.<locals>.get_triggering_events at 0x7f490acb4dd0>>, 'ts': '2023-08-29T19:42:09.117489+00:00', 'ts_nodash': '20230829T194209', 'ts_nodash_with_tz': '20230829T194209.117489+0000', 'var': {'json': None, 'value': None}, 'conn': None, 'yesterday_ds': '2023-08-28', 'yesterday_ds_nodash': '20230828'}

이 내용들은 PythonOperator에서 kwargs 값을 출력했던 것과 거의 비슷한 값을 보여줍니다. state, queued_at, data_interval_end, data_interval_start 값 등 kwargs에서도 봤던 파라미터가 있는 것을 볼 수 있습니다.때문에, context 객체 안에서 원하는 값을 꺼내고 싶다면, 코드에서도 context를 활용하면 되겠습니다.


이 포스팅을 끝으로, Task 분기처리하는 3가지 방법을 모두 살펴봤습니다.요약하자면 다음과 같습니다.

  • Task 분기처리 방법
    • BranchPythonOperator
    • task.branch 데커레이터 이용
    • BaseBranchOperator 상속, choose_branch를 재정의

공통적으로 PythonOperator로 task를 만드는 과정은 들어가고, 함수의 리턴 값으로 후행 Task의 id를 str 또는 list로 리턴해야 합니다. 

3가지 분기처리 방법은 방법만 다를 뿐 결과는 동일한데, 보통 3번째 방법보다는 1번방법이나 2번 방법을 사용하긴 합니다.

반응형