Problem:
Testing an Airflow DAG locally can be quite complicated. Sometimes you don't want to run the full DAG, only some of its tasks.
Solution:
To solve this problem, we created a test function that runs only some of the tasks of a DAG.
We also made a project with an example: https://github.com/data-banana/example_airflow_in_local
The test_dag function
This is the main function used to test a DAG locally with some of its tasks disabled.
Note: this is not a generic solution; it depends on how you declare your tasks in your DAG.
import os
import sys
from airflow.models.baseoperator import BaseOperator
def test_dag(dag_name, tasks_id_to_run=None):
    """Import a DAG module and disable every task that should not run.

    The module named ``dag_name`` is imported after adding the ``src``
    directory (a sibling of this script's directory) to ``sys.path``. Every
    ``BaseOperator`` found in the module's global variables, either directly
    or inside a global list, and every task reachable from those through
    ``downstream_list``, gets its ``execute`` replaced by a no-op unless its
    ``task_id`` is listed in ``tasks_id_to_run``.

    Note: this is not a generic solution, it depends on how the tasks are
    declared in the DAG module.

    Args:
        dag_name: Importable module name of the DAG file.
        tasks_id_to_run: Task ids that keep their real ``execute``. When
            omitted, every discovered task is disabled.

    Returns:
        The module's global ``dag`` object, with ``'_test'`` appended to its
        ``dag_id``.

    Raises:
        AttributeError: If the module has no global variable named ``dag``.
    """
    import importlib

    # Build a fresh list per call; a mutable default would be shared
    # between every call of the function.
    if tasks_id_to_run is None:
        tasks_id_to_run = []

    def do_nothing(context):
        print('do nothing')

    def recursive_disabled_tasks(tasks, tasks_id_to_run):
        # Walk the downstream graph with an explicit stack and remember the
        # tasks already handled: a task reachable through several paths is
        # processed once instead of once per path.
        visited = {}  # id(task) -> task; keeping the task keeps its id unique
        pending = list(tasks)
        while pending:
            task = pending.pop()
            if id(task) in visited:
                continue
            visited[id(task)] = task
            if task.task_id not in tasks_id_to_run:
                task.execute = do_nothing
            pending.extend(task.downstream_list)

    script_dir = os.path.dirname(os.path.realpath(__file__))
    # Add path where are located DAGs, depend on your project ...
    dags_dir = os.path.normpath(os.path.join(script_dir, os.pardir, 'src'))
    # Avoid stacking the same entry on every call.
    if dags_dir not in sys.path:
        sys.path.append(dags_dir)

    # import DAG (import_module returns the named module itself, even for a
    # dotted name, where __import__ would return the top-level package)
    module = importlib.import_module(dag_name)
    # transform module object to dict, get variables
    variables = {name: getattr(module, name) for name in dir(module)}

    # Fail early, before touching any task, with a message that names the
    # actual problem.
    dag = variables.get('dag')
    if dag is None:
        raise AttributeError(
            "module %r has no global variable named 'dag'" % dag_name
        )

    # Get tasks from simple global variable
    tasks = [
        value for value in variables.values()
        if isinstance(value, BaseOperator)
    ]
    # Get tasks from list of tasks in global variable
    for value in variables.values():
        if isinstance(value, list):
            tasks.extend(
                item for item in value if isinstance(item, BaseOperator)
            )
    # TODO: complete, for example if you declare tasks in other object
    # types, like dict
    # ...

    # Disabled task, add dummy function
    recursive_disabled_tasks(tasks, tasks_id_to_run)

    # Add suffix to the dag name
    dag.dag_id += '_test'
    return dag
Let's try to explain the code
First of all, we will create a trivial example DAG.
# Module name of the throw-away DAG; also used to name the generated file so
# that it can be imported under this name.
dag_name = 'tmp_dag'
# Source code of the example DAG: three chained tasks held in plain global
# variables (task1 >> task2 >> task3) and a fourth task held in a global list.
# PythonOperator is imported from ``airflow.operators.python``, the module the
# class actually lives in; ``airflow.operators.python_operator`` is only a
# deprecated alias in Airflow 2.
content_dag = """
from airflow.operators.python import PythonOperator
from airflow import DAG
from airflow.utils.dates import days_ago
default_args = {
    "start_date": days_ago(1),
    "owner": "Data Banana",
    'retries': 0,
}
dag = DAG(
    'tmp_dag',
    default_args=default_args,
    schedule_interval='1 1 * * *',
    description='Stupid DAG'
)
def do_something(**kwargs):
    print("do something")
t1 = PythonOperator(
    task_id='task1',
    python_callable=do_something,
    dag=dag,
)
t2 = PythonOperator(
    task_id='task2',
    python_callable=do_something,
    dag=dag,
)
t3 = PythonOperator(
    task_id='task3',
    python_callable=do_something,
    dag=dag,
)
list_task = [PythonOperator(
    task_id='task4',
    python_callable=do_something,
    dag=dag,
)]
t1 >> t2 >> t3
"""
# Write the DAG next to the notebook. The file name is derived from dag_name
# so the two cannot drift apart, and the encoding is explicit so the result
# does not depend on the platform default.
with open(f'{dag_name}.py', 'w', encoding='utf-8') as f:
    f.write(content_dag)
We load the DAG as a module so that we can access its variables.
# Make the current directory importable so the DAG file can be loaded as a module
sys.path.append('.')
# Import the DAG as a module
module = __import__(dag_name)
# Turn the module object into a dict mapping each attribute name to its value
variables = {attr: getattr(module, attr) for attr in dir(module)}
# Every variable of the module is now available in this dict
variables
{'DAG': airflow.models.dag.DAG, 'PythonOperator': airflow.operators.python.PythonOperator, '__builtins__': {'__name__': 'builtins', '__doc__': "Built-in functions, exceptions, and other objects.\n\nNoteworthy: None is the `nil' object; Ellipsis represents `...' in slices.", '__package__': '', '__loader__': _frozen_importlib.BuiltinImporter, '__spec__': ModuleSpec(name='builtins', loader=<class '_frozen_importlib.BuiltinImporter'>, origin='built-in'), '__build_class__': <function __build_class__>, '__import__': <function __import__>, 'abs': <function abs(x, /)>, 'all': <function all(iterable, /)>, 'any': <function any(iterable, /)>, 'ascii': <function ascii(obj, /)>, 'bin': <function bin(number, /)>, 'breakpoint': <function breakpoint>, 'callable': <function callable(obj, /)>, 'chr': <function chr(i, /)>, 'compile': <function compile(source, filename, mode, flags=0, dont_inherit=False, optimize=-1, *, _feature_version=-1)>, 'delattr': <function delattr(obj, name, /)>, 'dir': <function dir>, 'divmod': <function divmod(x, y, /)>, 'eval': <function eval(source, globals=None, locals=None, /)>, 'exec': <function exec(source, globals=None, locals=None, /)>, 'format': <function format(value, format_spec='', /)>, 'getattr': <function getattr>, 'globals': <function globals()>, 'hasattr': <function hasattr(obj, name, /)>, 'hash': <function hash(obj, /)>, 'hex': <function hex(number, /)>, 'id': <function id(obj, /)>, 'input': <bound method Kernel.raw_input of <ipykernel.ipkernel.IPythonKernel object at 0x7f950006d430>>, 'isinstance': <function isinstance(obj, class_or_tuple, /)>, 'issubclass': <function issubclass(cls, class_or_tuple, /)>, 'iter': <function iter>, 'len': <function len(obj, /)>, 'locals': <function locals()>, 'max': <function max>, 'min': <function min>, 'next': <function next>, 'oct': <function oct(number, /)>, 'ord': <function ord(c, /)>, 'pow': <function pow(base, exp, mod=None)>, 'print': <function print>, 'repr': <function repr(obj, /)>, 'round': <function 
round(number, ndigits=None)>, 'setattr': <function setattr(obj, name, value, /)>, 'sorted': <function sorted(iterable, /, *, key=None, reverse=False)>, 'sum': <function sum(iterable, /, start=0)>, 'vars': <function vars>, 'None': None, 'Ellipsis': Ellipsis, 'NotImplemented': NotImplemented, 'False': False, 'True': True, 'bool': bool, 'memoryview': memoryview, 'bytearray': bytearray, 'bytes': bytes, 'classmethod': classmethod, 'complex': complex, 'dict': dict, 'enumerate': enumerate, 'filter': filter, 'float': float, 'frozenset': frozenset, 'property': property, 'int': int, 'list': list, 'map': map, 'object': object, 'range': range, 'reversed': reversed, 'set': set, 'slice': slice, 'staticmethod': staticmethod, 'str': str, 'super': super, 'tuple': tuple, 'type': type, 'zip': zip, '__debug__': True, 'BaseException': BaseException, 'Exception': Exception, 'TypeError': TypeError, 'StopAsyncIteration': StopAsyncIteration, 'StopIteration': StopIteration, 'GeneratorExit': GeneratorExit, 'SystemExit': SystemExit, 'KeyboardInterrupt': KeyboardInterrupt, 'ImportError': ImportError, 'ModuleNotFoundError': ModuleNotFoundError, 'OSError': OSError, 'EnvironmentError': OSError, 'IOError': OSError, 'EOFError': EOFError, 'RuntimeError': RuntimeError, 'RecursionError': RecursionError, 'NotImplementedError': NotImplementedError, 'NameError': NameError, 'UnboundLocalError': UnboundLocalError, 'AttributeError': AttributeError, 'SyntaxError': SyntaxError, 'IndentationError': IndentationError, 'TabError': TabError, 'LookupError': LookupError, 'IndexError': IndexError, 'KeyError': KeyError, 'ValueError': ValueError, 'UnicodeError': UnicodeError, 'UnicodeEncodeError': UnicodeEncodeError, 'UnicodeDecodeError': UnicodeDecodeError, 'UnicodeTranslateError': UnicodeTranslateError, 'AssertionError': AssertionError, 'ArithmeticError': ArithmeticError, 'FloatingPointError': FloatingPointError, 'OverflowError': OverflowError, 'ZeroDivisionError': ZeroDivisionError, 'SystemError': SystemError, 
'ReferenceError': ReferenceError, 'MemoryError': MemoryError, 'BufferError': BufferError, 'Warning': Warning, 'UserWarning': UserWarning, 'DeprecationWarning': DeprecationWarning, 'PendingDeprecationWarning': PendingDeprecationWarning, 'SyntaxWarning': SyntaxWarning, 'RuntimeWarning': RuntimeWarning, 'FutureWarning': FutureWarning, 'ImportWarning': ImportWarning, 'UnicodeWarning': UnicodeWarning, 'BytesWarning': BytesWarning, 'ResourceWarning': ResourceWarning, 'ConnectionError': ConnectionError, 'BlockingIOError': BlockingIOError, 'BrokenPipeError': BrokenPipeError, 'ChildProcessError': ChildProcessError, 'ConnectionAbortedError': ConnectionAbortedError, 'ConnectionRefusedError': ConnectionRefusedError, 'ConnectionResetError': ConnectionResetError, 'FileExistsError': FileExistsError, 'FileNotFoundError': FileNotFoundError, 'IsADirectoryError': IsADirectoryError, 'NotADirectoryError': NotADirectoryError, 'InterruptedError': InterruptedError, 'PermissionError': PermissionError, 'ProcessLookupError': ProcessLookupError, 'TimeoutError': TimeoutError, 'open': <function io.open(file, mode='r', buffering=-1, encoding=None, errors=None, newline=None, closefd=True, opener=None)>, 'copyright': Copyright (c) 2001-2020 Python Software Foundation. All Rights Reserved. Copyright (c) 2000 BeOpen.com. All Rights Reserved. Copyright (c) 1995-2001 Corporation for National Research Initiatives. All Rights Reserved. Copyright (c) 1991-1995 Stichting Mathematisch Centrum, Amsterdam. All Rights Reserved., 'credits': Thanks to CWI, CNRI, BeOpen.com, Zope Corporation and a cast of thousands for supporting Python development. 
See www.python.org for more information., 'license': Type license() to see the full license text, 'help': Type help() for interactive help, or help(object) for help about object., '__IPYTHON__': True, 'display': <function IPython.core.display.display(*objs, include=None, exclude=None, metadata=None, transient=None, display_id=None, **kwargs)>, 'get_ipython': <bound method InteractiveShell.get_ipython of <ipykernel.zmqshell.ZMQInteractiveShell object at 0x7f950006d5b0>>}, '__doc__': None, '__name__': 'tmp_dag', '__package__': '', '__warningregistry__': {'version': 80}, 'dag': <DAG: tmp_dag>, 'days_ago': <function airflow.utils.dates.days_ago(n, hour=0, minute=0, second=0, microsecond=0)>, 'default_args': {'start_date': datetime.datetime(2022, 2, 9, 0, 0, tzinfo=Timezone('UTC')), 'owner': 'Data Banana', 'retries': 0}, 'do_something': <function tmp_dag.do_something(**kwargs)>, 'list_task': [<Task(PythonOperator): task4>], 't1': <Task(PythonOperator): task1>, 't2': <Task(PythonOperator): task2>, 't3': <Task(PythonOperator): task3>}
Now we will start looking for the task objects among all the variables...
# Keep the variables whose value is itself an operator instance
tasks = [value for value in variables.values() if isinstance(value, BaseOperator)]
tasks
[<Task(PythonOperator): task1>, <Task(PythonOperator): task2>, <Task(PythonOperator): task3>]
# Keep the variables whose value is a list ...
list_var = [value for value in variables.values() if isinstance(value, list)]
# ... and pull the operator instances out of each of those lists
tasks_from_list = [
    item
    for candidates in list_var
    for item in candidates
    if isinstance(item, BaseOperator)
]
tasks_from_list
[<Task(PythonOperator): task4>]
# More declaration styles could be handled here ...
# Merge the directly declared tasks with the ones found inside lists
all_tasks = [*tasks, *tasks_from_list]
all_tasks
[<Task(PythonOperator): task1>, <Task(PythonOperator): task2>, <Task(PythonOperator): task3>, <Task(PythonOperator): task4>]
# Show the callable currently bound to each task's ``execute``
for task in all_tasks:
    bound_execute = task.execute
    print(bound_execute)
<bound method PythonOperator.execute of <Task(PythonOperator): task1>> <bound method PythonOperator.execute of <Task(PythonOperator): task2>> <bound method PythonOperator.execute of <Task(PythonOperator): task3>> <bound method PythonOperator.execute of <Task(PythonOperator): task4>>
Now we want to change the behaviour of those tasks, and also walk through each task's downstream list to be sure we didn't miss any task.
def do_nothing(context):
    """Print a marker line and return ``None``.

    ``context`` is accepted but never read.
    """
    print('do nothing')
    return None
def recursive_disabled_tasks(tasks, tasks_id_to_run=None):
    """Replace ``execute`` with ``do_nothing`` on every task not selected.

    Starting from ``tasks``, every task reachable through ``downstream_list``
    is visited, so tasks that were not passed in directly are handled too.
    The walk uses an explicit stack instead of recursion and processes each
    task once, even when it is reachable through several paths.

    Args:
        tasks: Iterable of task objects exposing ``task_id``,
            ``downstream_list`` and an assignable ``execute``.
        tasks_id_to_run: Task ids that keep their original ``execute``. When
            omitted, every reachable task is disabled.

    Returns:
        None. The tasks are modified in place.
    """
    # Build a fresh list per call; a mutable default would be shared
    # between every call of the function.
    if tasks_id_to_run is None:
        tasks_id_to_run = []

    visited = {}  # id(task) -> task; keeping the task keeps its id unique
    pending = list(tasks)
    while pending:
        task = pending.pop()
        if id(task) in visited:
            # Already handled through another path of the graph.
            continue
        visited[id(task)] = task
        if task.task_id not in tasks_id_to_run:
            task.execute = do_nothing
        # Navigate in downstream_list to be sure no task is forgotten
        pending.extend(task.downstream_list)
# Disable every task except task4, then show what each ``execute`` is now
recursive_disabled_tasks(all_tasks, tasks_id_to_run=['task4'])
for task in all_tasks:
    current_execute = task.execute
    print(current_execute)
<function do_nothing at 0x7f94f0199550> <function do_nothing at 0x7f94f0199550> <function do_nothing at 0x7f94f0199550> <bound method PythonOperator.execute of <Task(PythonOperator): task4>>
# Fetch the DAG object from the module variables and append '_test' to its id
dag = variables.get('dag')
dag.dag_id = dag.dag_id + '_test'