It can be quite complicated to test an Airflow DAG locally. Sometimes you don't want to run the full DAG, but only some of its tasks.
To solve this problem, we created a test function that runs only selected tasks of a DAG.
We also made a project with an example:
The test_dag function
This is the main function used to test a DAG locally with some of its tasks disabled.
Note: this is not a generic solution, it will depend on how you declare your tasks in your DAG
import os
import sys
from airflow.models.baseoperator import BaseOperator
def test_dag(dag_name, tasks_id_to_run=None):
    """Load a DAG module and turn every non-selected task into a no-op.

    The module called ``dag_name`` is imported, its global variables are
    scanned for ``BaseOperator`` instances (either bound directly to a global
    name or stored inside a global ``list``), and the ``execute`` attribute of
    every task whose ``task_id`` is not in ``tasks_id_to_run`` is replaced by
    a function that only prints ``'do nothing'``. Tasks reachable through
    ``downstream_list`` are processed as well.

    Note: this is not a generic solution; it only finds tasks declared in the
    ways described above.

    Args:
        dag_name: Name of the module that declares the DAG. The module must
            expose the DAG object in a global variable named ``dag``.
        tasks_id_to_run: Iterable of task ids that keep their real
            ``execute``. ``None`` (the default) disables every task.

    Returns:
        The module's ``dag`` object, with ``'_test'`` appended to its
        ``dag_id``.

    Raises:
        ImportError: If the module ``dag_name`` cannot be imported.
        AttributeError: If the module has no global variable named ``dag``.
    """
    # Local import so the block does not depend on a new file-level import.
    import importlib

    def do_nothing(context):
        print('do nothing')

    def disable_tasks(tasks, tasks_id_to_run):
        # Iterative walk with a visited set: a task shared by several
        # upstream tasks is handled once instead of once per path.
        visited = set()
        pending = list(tasks)
        while pending:
            task = pending.pop()
            if id(task) in visited:
                continue
            visited.add(id(task))
            if task.task_id not in tasks_id_to_run:
                # Shadow the bound method on the instance with the dummy.
                task.execute = do_nothing
            pending.extend(task.downstream_list)

    # Avoid a shared mutable default argument.
    if tasks_id_to_run is None:
        tasks_id_to_run = []

    SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
    # Add the path where the DAGs are located; this depends on your project ...
    src_dir = os.path.normpath(os.path.join(SCRIPT_DIR, os.pardir, 'src'))
    if src_dir not in sys.path:
        # Do not grow sys.path on every call.
        sys.path.append(src_dir)

    # Import the DAG as a module.
    module = importlib.import_module(dag_name)
    # Turn the module object into a dict of its variables.
    variables = {name: getattr(module, name) for name in dir(module)}

    # Tasks bound directly to a global variable.
    tasks = [value for value in variables.values()
             if isinstance(value, BaseOperator)]
    # Tasks stored inside a global list.
    for value in variables.values():
        if isinstance(value, list):
            tasks.extend(item for item in value
                         if isinstance(item, BaseOperator))
    # TODO: complete, for example if tasks are declared inside other
    # container types, such as a dict
    # ...

    # Disable the tasks by installing the dummy function.
    disable_tasks(tasks, tasks_id_to_run)

    dag = variables.get('dag')
    if dag is None:
        raise AttributeError(
            f"module {dag_name!r} has no global variable named 'dag'"
        )
    # Add a suffix to the DAG name.
    # NOTE: imported modules are cached, so calling this function twice for
    # the same module appends the suffix twice.
    dag.dag_id += '_test'
    return dag
Let's try to explain the code
First of all, we will create a silly example DAG.
dag_name = 'tmp_dag'
content_dag = """
from airflow.operators.python_operator import PythonOperator
from airflow import DAG
from airflow.utils.dates import days_ago
default_args = {
"start_date": days_ago(1),
"owner": "Data Banana",
'retries': 0,
dag = DAG(
schedule_interval='1 1 * * *',
description='Stupid DAG'
def do_something(**kwargs):
print("do something")
t1 = PythonOperator(
t2 = PythonOperator(
t3 = PythonOperator(
list_task = [PythonOperator(
t1 >> t2 >> t3
with open('', 'w') as f:
We load the DAG as a module so that we can access its variables.
# We add the current directory to sys.path so that we can load the DAG as a module
# Import the DAG as a module
module = __import__(dag_name)
# Transform the module object into a dict to get its variables
variables = {a: module.__getattribute__(a) for a in dir(module)}
# Now we have all variables available in this module
Now we will look for the task objects among all these variables...
# Get tasks from simple variable
tasks = list(filter(lambda x: isinstance(x, BaseOperator), variables.values()))
[<Task(PythonOperator): task1>, <Task(PythonOperator): task2>, <Task(PythonOperator): task3>]
# Get tasks from list of tasks
list_var = list(filter(lambda x: isinstance(x, list), variables.values()))
tasks_from_list = []
for l in list_var:
for var in l:
if isinstance(var, BaseOperator):
[<Task(PythonOperator): task4>]
# So here you may handle more scenarios ...
all_tasks = tasks + tasks_from_list
[<Task(PythonOperator): task1>, <Task(PythonOperator): task2>, <Task(PythonOperator): task3>, <Task(PythonOperator): task4>]
for task in all_tasks:
<bound method PythonOperator.execute of <Task(PythonOperator): task1>> <bound method PythonOperator.execute of <Task(PythonOperator): task2>> <bound method PythonOperator.execute of <Task(PythonOperator): task3>> <bound method PythonOperator.execute of <Task(PythonOperator): task4>>
Now we want to change the behaviour of those tasks, and also walk each task's downstream list to be sure we did not miss any task.
def do_nothing(context):
    """Stand-in for a task's ``execute``: print a marker and do no work."""
    print('do nothing')


def recursive_disabled_tasks(tasks, tasks_id_to_run=None):
    """Replace ``execute`` with ``do_nothing`` on every non-selected task.

    Starts from ``tasks`` and follows each task's ``downstream_list`` so that
    tasks which were not collected directly are handled too.

    Args:
        tasks: Iterable of task objects exposing ``task_id`` and
            ``downstream_list``.
        tasks_id_to_run: Container of task ids that keep their original
            ``execute``. ``None`` (the default) disables every task.
    """
    # Avoid a shared mutable default argument.
    if tasks_id_to_run is None:
        tasks_id_to_run = ()

    # Explicit stack instead of recursion, plus a visited set: long chains
    # cannot hit the recursion limit, and a task reachable through several
    # paths is processed only once.
    visited = set()
    pending = list(tasks)
    while pending:
        task = pending.pop()
        if id(task) in visited:
            continue
        visited.add(id(task))
        if task.task_id not in tasks_id_to_run:
            task.execute = do_nothing
        # Navigate the downstream list to be sure no task is forgotten.
        pending.extend(task.downstream_list)
recursive_disabled_tasks(all_tasks, ['task4'])
for task in all_tasks:
<function do_nothing at 0x7f94f0199550> <function do_nothing at 0x7f94f0199550> <function do_nothing at 0x7f94f0199550> <bound method PythonOperator.execute of <Task(PythonOperator): task4>>
# Let's just rename the DAG
dag = variables.get('dag')
dag.dag_id += '_test'