Problem:
It can be quite complicated to test an Airflow DAG locally. Sometimes you don't want to run the full DAG, only some of its tasks.
Solution:
To solve this problem, we wrote a test function that runs only some of the tasks of a DAG.
We also made a project with an example: https://github.com/data-banana/example_airflow_in_local
The test_dag function
This is the main function: it loads a DAG for local testing and disables the tasks you do not want to run.
Note: this is not a generic solution; it depends on how you declare the tasks in your DAG.
import importlib
import os
import sys

from airflow.models.baseoperator import BaseOperator
def test_dag(dag_name, tasks_id_to_run=None):
    """Import a DAG module and return its DAG with unwanted tasks disabled.

    Every task reachable from the module's operator variables (directly, or
    through their ``downstream_list``) whose ``task_id`` is not in
    ``tasks_id_to_run`` gets its ``execute`` replaced by a no-op, and
    ``'_test'`` is appended to the DAG id.

    The module is re-executed (``importlib.reload``) when it has already been
    imported, so repeated calls start from fresh task objects. Otherwise
    tasks disabled by a previous call would stay disabled and the suffix
    would be appended twice.

    Args:
        dag_name: Importable module name of the DAG file. The ``src``
            directory located at ``../src`` relative to this file's directory
            is added to ``sys.path`` first (adapt to your project layout).
        tasks_id_to_run: Iterable of task ids that keep their real
            ``execute``. A single string is treated as one id. ``None``
            (the default) disables every discovered task.

    Returns:
        The DAG object stored in the module-level variable ``dag``.

    Raises:
        ValueError: if the module defines no ``dag`` variable.
    """
    if isinstance(tasks_id_to_run, str):
        tasks_id_to_run = [tasks_id_to_run]
    keep = set(tasks_id_to_run or ())

    def do_nothing(context):
        # Installed as an instance attribute, so it is called with the
        # context only (no ``self``).
        print('do nothing')

    def disable_tasks(roots):
        # Iterative walk with a visited set: a task shared by several
        # upstream paths (e.g. a diamond) is handled once instead of once per
        # path, and long chains cannot exceed the recursion limit.
        seen = set()
        pending = list(roots)
        while pending:
            task = pending.pop()
            if id(task) in seen:
                continue
            seen.add(id(task))
            if task.task_id not in keep:
                task.execute = do_nothing
            pending.extend(task.downstream_list)

    script_dir = os.path.dirname(os.path.realpath(__file__))
    # Add path where are located DAGs, depend on your project ...
    src_dir = os.path.normpath(os.path.join(script_dir, os.pardir, 'src'))
    if src_dir not in sys.path:  # avoid growing sys.path on every call
        sys.path.append(src_dir)

    # import_module returns the named module itself (``__import__`` returns
    # the top-level package for dotted names); reload gives fresh tasks.
    already_loaded = dag_name in sys.modules
    module = importlib.import_module(dag_name)
    if already_loaded:
        module = importlib.reload(module)

    # Map every module attribute name to its value.
    variables = {name: getattr(module, name) for name in dir(module)}

    # Tasks stored directly in a module-level variable.
    tasks = [value for value in variables.values() if isinstance(value, BaseOperator)]
    # Tasks stored inside module-level lists.
    for value in variables.values():
        if isinstance(value, list):
            tasks.extend(item for item in value if isinstance(item, BaseOperator))
    # TODO: complete, for example if you declare tasks in other object types, like dict
    # ...

    # Disable tasks by replacing their execute with a dummy function.
    disable_tasks(tasks)

    dag = variables.get('dag')
    if dag is None:
        raise ValueError(
            f"module {dag_name!r} has no module-level variable named 'dag'"
        )
    # Add suffix to the dag name
    dag.dag_id += '_test'
    return dag
Let's walk through the code
First, we create a trivial example DAG.
# Name of the generated DAG module; the file written below is "<dag_name>.py"
# so that it can later be imported as the module `dag_name`.
dag_name = 'tmp_dag'
# Source of a minimal DAG: four PythonOperator tasks, three of them chained
# (task1 >> task2 >> task3) and one (task4) stored inside a list.
content_dag = """
from airflow.operators.python_operator import PythonOperator
from airflow import DAG
from airflow.utils.dates import days_ago
default_args = {
    "start_date": days_ago(1),
    "owner": "Data Banana",
    'retries': 0,
}
dag = DAG(
    'tmp_dag',
    default_args=default_args,
    schedule_interval='1 1 * * *',
    description='Stupid DAG'
)
def do_something(**kwargs):
    print("do something")
t1 = PythonOperator(
    task_id='task1',
    python_callable=do_something,
    dag=dag,
)
t2 = PythonOperator(
    task_id='task2',
    python_callable=do_something,
    dag=dag,
)
t3 = PythonOperator(
    task_id='task3',
    python_callable=do_something,
    dag=dag,
)
list_task = [PythonOperator(
    task_id='task4',
    python_callable=do_something,
    dag=dag,
)]
t1 >> t2 >> t3
"""
# Derive the file name from dag_name so the later import always matches.
with open(f'{dag_name}.py', 'w', encoding='utf-8') as f:
    f.write(content_dag)
We load the DAG as a module to access its variables.
# Make the current directory importable: the generated DAG file lives there.
sys.path.append('.')
# Import the DAG file as a regular Python module.
module = __import__(dag_name)
# Build a name -> value mapping of everything the module defines.
variables = dict((attr_name, getattr(module, attr_name)) for attr_name in dir(module))
# Display the mapping: every module-level variable is now available here.
variables
{'DAG': airflow.models.dag.DAG,
'PythonOperator': airflow.operators.python.PythonOperator,
'__builtins__': {'__name__': 'builtins',
'__doc__': "Built-in functions, exceptions, and other objects.\n\nNoteworthy: None is the `nil' object; Ellipsis represents `...' in slices.",
'__package__': '',
'__loader__': _frozen_importlib.BuiltinImporter,
'__spec__': ModuleSpec(name='builtins', loader=<class '_frozen_importlib.BuiltinImporter'>, origin='built-in'),
'__build_class__': <function __build_class__>,
'__import__': <function __import__>,
'abs': <function abs(x, /)>,
'all': <function all(iterable, /)>,
'any': <function any(iterable, /)>,
'ascii': <function ascii(obj, /)>,
'bin': <function bin(number, /)>,
'breakpoint': <function breakpoint>,
'callable': <function callable(obj, /)>,
'chr': <function chr(i, /)>,
'compile': <function compile(source, filename, mode, flags=0, dont_inherit=False, optimize=-1, *, _feature_version=-1)>,
'delattr': <function delattr(obj, name, /)>,
'dir': <function dir>,
'divmod': <function divmod(x, y, /)>,
'eval': <function eval(source, globals=None, locals=None, /)>,
'exec': <function exec(source, globals=None, locals=None, /)>,
'format': <function format(value, format_spec='', /)>,
'getattr': <function getattr>,
'globals': <function globals()>,
'hasattr': <function hasattr(obj, name, /)>,
'hash': <function hash(obj, /)>,
'hex': <function hex(number, /)>,
'id': <function id(obj, /)>,
'input': <bound method Kernel.raw_input of <ipykernel.ipkernel.IPythonKernel object at 0x7f950006d430>>,
'isinstance': <function isinstance(obj, class_or_tuple, /)>,
'issubclass': <function issubclass(cls, class_or_tuple, /)>,
'iter': <function iter>,
'len': <function len(obj, /)>,
'locals': <function locals()>,
'max': <function max>,
'min': <function min>,
'next': <function next>,
'oct': <function oct(number, /)>,
'ord': <function ord(c, /)>,
'pow': <function pow(base, exp, mod=None)>,
'print': <function print>,
'repr': <function repr(obj, /)>,
'round': <function round(number, ndigits=None)>,
'setattr': <function setattr(obj, name, value, /)>,
'sorted': <function sorted(iterable, /, *, key=None, reverse=False)>,
'sum': <function sum(iterable, /, start=0)>,
'vars': <function vars>,
'None': None,
'Ellipsis': Ellipsis,
'NotImplemented': NotImplemented,
'False': False,
'True': True,
'bool': bool,
'memoryview': memoryview,
'bytearray': bytearray,
'bytes': bytes,
'classmethod': classmethod,
'complex': complex,
'dict': dict,
'enumerate': enumerate,
'filter': filter,
'float': float,
'frozenset': frozenset,
'property': property,
'int': int,
'list': list,
'map': map,
'object': object,
'range': range,
'reversed': reversed,
'set': set,
'slice': slice,
'staticmethod': staticmethod,
'str': str,
'super': super,
'tuple': tuple,
'type': type,
'zip': zip,
'__debug__': True,
'BaseException': BaseException,
'Exception': Exception,
'TypeError': TypeError,
'StopAsyncIteration': StopAsyncIteration,
'StopIteration': StopIteration,
'GeneratorExit': GeneratorExit,
'SystemExit': SystemExit,
'KeyboardInterrupt': KeyboardInterrupt,
'ImportError': ImportError,
'ModuleNotFoundError': ModuleNotFoundError,
'OSError': OSError,
'EnvironmentError': OSError,
'IOError': OSError,
'EOFError': EOFError,
'RuntimeError': RuntimeError,
'RecursionError': RecursionError,
'NotImplementedError': NotImplementedError,
'NameError': NameError,
'UnboundLocalError': UnboundLocalError,
'AttributeError': AttributeError,
'SyntaxError': SyntaxError,
'IndentationError': IndentationError,
'TabError': TabError,
'LookupError': LookupError,
'IndexError': IndexError,
'KeyError': KeyError,
'ValueError': ValueError,
'UnicodeError': UnicodeError,
'UnicodeEncodeError': UnicodeEncodeError,
'UnicodeDecodeError': UnicodeDecodeError,
'UnicodeTranslateError': UnicodeTranslateError,
'AssertionError': AssertionError,
'ArithmeticError': ArithmeticError,
'FloatingPointError': FloatingPointError,
'OverflowError': OverflowError,
'ZeroDivisionError': ZeroDivisionError,
'SystemError': SystemError,
'ReferenceError': ReferenceError,
'MemoryError': MemoryError,
'BufferError': BufferError,
'Warning': Warning,
'UserWarning': UserWarning,
'DeprecationWarning': DeprecationWarning,
'PendingDeprecationWarning': PendingDeprecationWarning,
'SyntaxWarning': SyntaxWarning,
'RuntimeWarning': RuntimeWarning,
'FutureWarning': FutureWarning,
'ImportWarning': ImportWarning,
'UnicodeWarning': UnicodeWarning,
'BytesWarning': BytesWarning,
'ResourceWarning': ResourceWarning,
'ConnectionError': ConnectionError,
'BlockingIOError': BlockingIOError,
'BrokenPipeError': BrokenPipeError,
'ChildProcessError': ChildProcessError,
'ConnectionAbortedError': ConnectionAbortedError,
'ConnectionRefusedError': ConnectionRefusedError,
'ConnectionResetError': ConnectionResetError,
'FileExistsError': FileExistsError,
'FileNotFoundError': FileNotFoundError,
'IsADirectoryError': IsADirectoryError,
'NotADirectoryError': NotADirectoryError,
'InterruptedError': InterruptedError,
'PermissionError': PermissionError,
'ProcessLookupError': ProcessLookupError,
'TimeoutError': TimeoutError,
'open': <function io.open(file, mode='r', buffering=-1, encoding=None, errors=None, newline=None, closefd=True, opener=None)>,
'copyright': Copyright (c) 2001-2020 Python Software Foundation.
All Rights Reserved.
Copyright (c) 2000 BeOpen.com.
All Rights Reserved.
Copyright (c) 1995-2001 Corporation for National Research Initiatives.
All Rights Reserved.
Copyright (c) 1991-1995 Stichting Mathematisch Centrum, Amsterdam.
All Rights Reserved.,
'credits': Thanks to CWI, CNRI, BeOpen.com, Zope Corporation and a cast of thousands
for supporting Python development. See www.python.org for more information.,
'license': Type license() to see the full license text,
'help': Type help() for interactive help, or help(object) for help about object.,
'__IPYTHON__': True,
'display': <function IPython.core.display.display(*objs, include=None, exclude=None, metadata=None, transient=None, display_id=None, **kwargs)>,
'get_ipython': <bound method InteractiveShell.get_ipython of <ipykernel.zmqshell.ZMQInteractiveShell object at 0x7f950006d5b0>>},
'__doc__': None,
'__name__': 'tmp_dag',
'__package__': '',
'__warningregistry__': {'version': 80},
'dag': <DAG: tmp_dag>,
'days_ago': <function airflow.utils.dates.days_ago(n, hour=0, minute=0, second=0, microsecond=0)>,
'default_args': {'start_date': datetime.datetime(2022, 2, 9, 0, 0, tzinfo=Timezone('UTC')),
'owner': 'Data Banana',
'retries': 0},
'do_something': <function tmp_dag.do_something(**kwargs)>,
'list_task': [<Task(PythonOperator): task4>],
't1': <Task(PythonOperator): task1>,
't2': <Task(PythonOperator): task2>,
't3': <Task(PythonOperator): task3>}
Now we look for the task objects among all these variables...
# Keep only the module-level variables that are operators (tasks).
tasks = [value for value in variables.values() if isinstance(value, BaseOperator)]
tasks
[<Task(PythonOperator): task1>, <Task(PythonOperator): task2>, <Task(PythonOperator): task3>]
# Some tasks are stored inside lists: collect the list variables first,
# then every operator found inside them.
list_var = [value for value in variables.values() if isinstance(value, list)]
tasks_from_list = [
    element
    for container in list_var
    for element in container
    if isinstance(element, BaseOperator)
]
tasks_from_list
[<Task(PythonOperator): task4>]
# Other containers (dicts, tuples, ...) could be scanned here as well.
all_tasks = [*tasks, *tasks_from_list]
all_tasks
[<Task(PythonOperator): task1>, <Task(PythonOperator): task2>, <Task(PythonOperator): task3>, <Task(PythonOperator): task4>]
# Show the execute attribute currently attached to each task.
for operator in all_tasks:
    current_execute = operator.execute
    print(current_execute)
<bound method PythonOperator.execute of <Task(PythonOperator): task1>> <bound method PythonOperator.execute of <Task(PythonOperator): task2>> <bound method PythonOperator.execute of <Task(PythonOperator): task3>> <bound method PythonOperator.execute of <Task(PythonOperator): task4>>
Now we change the behaviour of those tasks, and also walk each task's downstream list to make sure we don't miss any task.
def do_nothing(context):
    """Replacement for a task's ``execute``: print a marker and return None.

    ``context`` is accepted and ignored so the function can be called with
    the same argument as the ``execute`` attribute it replaces.
    """
    message = 'do nothing'
    print(message)
def recursive_disabled_tasks(tasks, tasks_id_to_run=None, replacement=None):
    """Replace ``execute`` on every reachable task not in ``tasks_id_to_run``.

    Starts from ``tasks`` and follows each task's ``downstream_list``. Every
    task is visited exactly once. A task shared by several upstream paths
    (e.g. a diamond ``a >> [b, c] >> d``) is therefore not re-traversed once
    per path, and long chains cannot hit the recursion limit. The traversal
    is iterative despite the function's historical name.

    Args:
        tasks: Starting tasks; each needs ``task_id``, ``downstream_list``
            and an assignable ``execute`` attribute.
        tasks_id_to_run: Task ids that keep their real ``execute``. A single
            string is treated as one id. ``None`` (the default) disables
            every reachable task.
        replacement: Callable installed as ``execute`` on disabled tasks.
            Defaults to ``do_nothing``.
    """
    if replacement is None:
        replacement = do_nothing
    if isinstance(tasks_id_to_run, str):
        tasks_id_to_run = [tasks_id_to_run]
    keep = set(tasks_id_to_run or ())

    seen = set()
    pending = list(tasks)
    while pending:
        task = pending.pop()
        if id(task) in seen:
            continue
        seen.add(id(task))
        if task.task_id not in keep:
            task.execute = replacement
        # Navigate the downstream tasks so none are forgotten.
        pending.extend(task.downstream_list)
# Keep only task4 runnable, then show which execute each task now has.
recursive_disabled_tasks(all_tasks, ['task4'])
for operator in all_tasks:
    replaced_execute = operator.execute
    print(replaced_execute)
<function do_nothing at 0x7f94f0199550> <function do_nothing at 0x7f94f0199550> <function do_nothing at 0x7f94f0199550> <bound method PythonOperator.execute of <Task(PythonOperator): task4>>
# Append a '_test' suffix to the DAG id (presumably to tell the test run
# apart from the real DAG).
dag = variables.get('dag')
dag.dag_id = dag.dag_id + '_test'