Python3/Apache Airflow

[Airflow] Apache Airflow Test Case 작성하는 법

Razelo 2023. 3. 18. 15:46

최근 Apache Airflow를 만질 일이 생겼다. 

 

여러 Task 들을 작성하기는 했는데 테스트 코드를 어떻게 작성할 지에 대한 감도 잡히지 않았고 이걸 직접 다 돌려보는 풀테스트를 작성해야하나 긴가민가 했다.

 

그리고 짧은 조언을 받았는데 기본적인 테스트 코드들이 있었으면 좋겠다는 리뷰를 받았다. 그러니 앞으로는 테스트는 기본으로 가져가자는 마인드를 가져야겠다. 

 

그래서 airflow 에 대해 간단한 Test Case를 작성하는 법을 알아보고자 한다.


우선 pytest를 사용할 것이기 때문에 pip install을 해주자. mock 테스트 안할 거면 후자는 빼줘도 되요~

pip3 install pytest pytest-mock

 

그리고 아래와 같은 dag와 task 들을 정의해줬다고 가정하자. 

from datetime import datetime, timedelta
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2022, 1, 1),
    'retries': 1,
    'retry_delay': timedelta(minutes=5)
}

dag = DAG(
    'my_dag',
    default_args=default_args,
    description='My example DAG',
    schedule_interval=timedelta(days=1),
    catchup=False
)

start = DummyOperator(
    task_id='start',
    dag=dag
)

end = DummyOperator(
    task_id='end',
    dag=dag
)

start >> end

 

자 그렇다면 아래와 같은 test 파일을 정의해주자. (test_my_dag.py)

import pytest
from airflow.models import DagBag

@pytest.fixture(scope="module")
def test_dag():
    dag_bag = DagBag(dag_folder='.', include_examples=False)
    dag = dag_bag.get_dag(dag_id='my_dag')
    return dag

def test_dag_loaded(test_dag):
    assert test_dag is not None, "DAG not found"
    assert len(test_dag.tasks) == 2, "Wrong number of tasks"

def test_task_dependencies(test_dag):
    start_task = test_dag.get_task('start')
    end_task = test_dag.get_task('end')

    assert start_task.downstream_list == [end_task], "start task not connected to end task"
    assert end_task.upstream_list == [start_task], "end task not connected to start task"

 

한 번 읽어보면 대충 무슨 의미인지 알 것이다. 읽어보면 DagBag, get_dat, get_task, downstream_list, upstream_list 등등 처음 접하는 개념들이 등장하는데 감이 잡히지 않으면 찾아보시면 된다. 

 

자 이제 위 테스트 코드를 테스트하려면 어떻게 해야할까? 아래처럼 타이핑해보자. 

pytest test_my_dag.py

 

airflow 테스트는 이런 식으로 작성하면 된다. 

 

자 그런데 여기까지는 그저 task 의 loading 혹은 import error 들을 테스트하는 코드에 불과하다. 그렇다면 실제 task의 action을 테스트하고 싶다면 어떻게 할 수 있을까? 

 

이때 활용할 수 있는 것이 TaskInstance이다. 

 

from airflow.models import TaskInstance
from airflow.utils import timezone

def test_my_custom_operator_execution(test_dag):
    task = test_dag.get_task(task_id='my_task')
    ti = TaskInstance(task=task, execution_date=timezone.utcnow())

    # Mock the execute method if needed
    with unittest.mock.patch('my_custom_operator.MyCustomOperator.execute') as mock_execute:
        # The mock can return a specific value or raise an exception, depending on your test case
        mock_execute.return_value = "your expected return value"

        # Run the task
        ti.run(ignore_ti_state=True)

        # Check that the execute method was called
        mock_execute.assert_called_once()

        # Add any additional assertions to verify the task's outcome

 

그런데 위의 경우는 custom Operator 라고 볼 수 있다. 만약 CustomOperator가 아니라 PythonOperator를 사용한다고 하면 어떤 코드가 나올까? 

 

아래 dag와 task 정의가 있다. 

from datetime import datetime, timedelta
from airflow import DAG
from airflow.operators.python_operator import PythonOperator
from my_module import my_function

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2022, 1, 1),
}

dag = DAG(
    'my_dag',
    default_args=default_args,
    description='My example DAG',
    schedule_interval=timedelta(days=1),
    catchup=False
)

my_task = PythonOperator(
    task_id='my_task',
    python_callable=my_function,
    dag=dag
)

 

아래처럼 테스트 코드를 작성하면 된다. 기본적으로 mock 테스트를 하는 방식인데 이 부분은 좀 더 알아보고 싶으면 찾아보면 된다. 

import pytest
import unittest.mock
from airflow.models import TaskInstance
from airflow.utils import timezone
from my_dag import dag as test_dag  # Import your actual DAG here

def test_python_operator_execution():
    task = test_dag.get_task(task_id='my_task')
    ti = TaskInstance(task=task, execution_date=timezone.utcnow())

    # Mock the function called by the PythonOperator
    with unittest.mock.patch('my_module.my_function') as mock_function:
        # The mock can return a specific value or raise an exception, depending on your test case
        mock_function.return_value = "your expected return value"

        # Run the task
        ti.run(ignore_ti_state=True)

        # Check that the function was called
        mock_function.assert_called_once()

        # Add any additional assertions to verify the task's outcome

 

그리고 개인적으로도 mock test를 권장하는데 특히 airflow의 경우 외부와의 커넥션이 많은 작업일 수 있다. s3에 접근한다던지 dynamodb와 관련된 작업일 수도 있고 다른 api 서버에 접근할 수 있다. 그렇기 때문에 그걸 테스트로 돌린다는 건 말이 안되고 mock 테스트를 활용해서 테스트하는 게 좋다. 또는 postman에서 제공하는 mock server 를 활용하는 것도 좋은 방법 중 하나이다. 

 


아래 블로그를 참고하시면 더 좋은 내용을 얻으실 수 있습니다. 

 

감사합니다. 

https://docs.astronomer.io/learn/testing-airflow

https://taegyuhan.github.io/python/Python_test_Code/

 

 

 

반응형