TL;DR: The function must be wrapped with the @task
decorator for it to run in parallel.
The serial version finishes in 14.6 s on my MacBook.
import xarray as xr
import matplotlib.pyplot as plt
def plot(ds, time):
    """Save a plot of ``ds['air']`` to a file named after *time*.

    The output filename is the first 16 characters of ``str(time)``.
    """
    out_name = str(time)[:16]
    plt.figure()  # the new figure becomes pyplot's "current" figure
    ds['air'].plot()
    plt.savefig(out_name)  # saves the current figure
    plt.close()  # close the current figure so figures don't pile up across calls
def process():
    """Plot the first 250 time steps of the ``air_temperature`` tutorial
    dataset one after another, writing one image file per time step."""
    ds = xr.tutorial.open_dataset('air_temperature').isel(
        time=slice(0, 250))
    # A plain loop rather than a list comprehension: plot() is called only
    # for its side effect (saving a file), so there is no list to build.
    for time in ds['time'].values:
        plot(ds.sel(time=time), time)
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    process()
Dask with four worker processes finishes in 7.1 s.
import dask
import xarray as xr
import matplotlib.pyplot as plt
@dask.delayed()
def plot(ds, time):
    """Lazily plot ``ds['air']`` and save it under a time-derived filename.

    Because of ``dask.delayed``, calling this only builds a task; the
    plotting and saving happen when that task is computed. The filename is
    the first 16 characters of ``str(time)``.
    """
    out_name = str(time)[:16]
    plt.figure()
    ds['air'].plot()
    plt.savefig(out_name)
    plt.close()
def dask_process():
    """Render the first 250 time steps in parallel using Dask's
    multiprocessing scheduler with four workers."""
    subset = xr.tutorial.open_dataset('air_temperature').isel(
        time=slice(0, 250))
    # Each plot(...) call returns a lazy task; nothing runs until compute().
    lazy_plots = [
        plot(subset.sel(time=stamp), stamp)
        for stamp in subset['time'].values
    ]
    dask.compute(lazy_plots, scheduler="processes", num_workers=4)
# The guard matters here. The "processes" scheduler may start its workers with
# the "spawn" method, which re-imports the main module in every child. An
# unguarded call would then run dask_process() again inside each worker.
if __name__ == "__main__":
    dask_process()
Prefect with a four-worker DaskTaskRunner finishes in 18.4 s.
import xarray as xr
import matplotlib.pyplot as plt
from prefect import flow, task
from prefect.task_runners import DaskTaskRunner
@task
def plot(ds, time):
    """Plot ``ds['air']`` for one time step and save it to disk.

    Decorated with ``@task`` so the flow's task runner (here the
    DaskTaskRunner) can schedule the calls. Without the decorator, plot()
    is an ordinary function that runs serially inside the flow body.

    Args:
        ds: Dataset containing an ``'air'`` variable. Presumably already
            narrowed to a single time step by the caller.
        time: Time-coordinate value. The first 16 characters of its string
            form become the output filename.
    """
    # Use an explicit figure handle instead of pyplot's implicit "current
    # figure". Tasks may run concurrently in threads of the same worker, and
    # the shared global state would let them draw on, save, or close each
    # other's figures.
    fig, ax = plt.subplots()
    ds['air'].plot(ax=ax)
    fig.savefig(str(time)[:16])
    plt.close(fig)
@flow(task_runner=DaskTaskRunner(cluster_kwargs={"n_workers": 4}))
def prefect_process():
    """Plot the first 250 time steps inside a Prefect flow backed by Dask.

    NOTE(review): ``cluster_kwargs`` is presumably forwarded to the local
    Dask cluster the runner creates. Per the ``distributed.LocalCluster``
    docs, workers are separate processes by default. Adding ``"processes"``
    and ``"threads_per_worker"`` to ``cluster_kwargs`` should control the
    process/thread layout explicitly.
    """
    ds = xr.tutorial.open_dataset('air_temperature').isel(
        time=slice(0, 250))
    # Use a loop, not a list comprehension: the calls are made for their
    # side effects and the resulting list was never used.
    for time in ds['time'].values:
        plot(ds.sel(time=time), time)
# Guard the entry point. A local Dask cluster that starts worker processes
# may re-import the main module in each child. An unguarded call would then
# try to start the flow again inside every worker.
if __name__ == "__main__":
    prefect_process()
I suspect the DaskTaskRunner is using threads rather than processes, but I don’t know how to specify the scheduler here.
With the default ConcurrentTaskRunner, it finishes in 15.4 s.
import xarray as xr
import matplotlib.pyplot as plt
from prefect import flow, task
def plot(ds, time):
    """Save a plot of ``ds['air']`` to a file named after *time*.

    This is a plain function, not a Prefect task. The output filename is
    the first 16 characters of ``str(time)``.
    """
    out_name = str(time)[:16]
    plt.figure()  # the new figure becomes pyplot's "current" figure
    ds['air'].plot()
    plt.savefig(out_name)  # saves the current figure
    plt.close()  # close the current figure so figures don't pile up across calls
@flow()
def prefect_process():
    """Plot the first 250 time steps inside a Prefect flow that uses the
    default task runner.

    In this snippet plot() is an ordinary function (no ``@task``), so the
    calls below run one after another in the flow body.
    """
    ds = xr.tutorial.open_dataset('air_temperature').isel(
        time=slice(0, 250))
    # Use a loop, not a list comprehension: the calls are made only for
    # their side effects.
    for time in ds['time'].values:
        plot(ds.sel(time=time), time)
# Run the flow only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    prefect_process()
Interestingly, if I decorate plot() with @task, the run time doubles to 29.6 s; could it be because I am spamming the logs?
import xarray as xr
import matplotlib.pyplot as plt
from prefect import flow, task
@task
def plot(ds, time):
    """Plot ``ds['air']`` for one time step and save it to disk, as a Prefect task.

    Args:
        ds: Dataset containing an ``'air'`` variable. Presumably already
            narrowed to a single time step by the caller.
        time: Time-coordinate value. The first 16 characters of its string
            form become the output filename.
    """
    # Use an explicit figure handle instead of pyplot's implicit "current
    # figure". The task runner may execute several of these tasks
    # concurrently in threads, and the shared global state would let them
    # draw on, save, or close each other's figures.
    fig, ax = plt.subplots()
    ds['air'].plot(ax=ax)
    fig.savefig(str(time)[:16])
    plt.close(fig)
@flow()
def prefect_process():
    """Plot the first 250 time steps inside a Prefect flow, one task per step.

    NOTE(review): plot() is a ``@task`` in this snippet. Prefect 2.0 betas
    submit a directly called task to the flow's task runner. Later 2.x
    releases run a directly called task synchronously and require
    ``plot.submit(...)`` for concurrency. Confirm against the installed
    version.
    """
    ds = xr.tutorial.open_dataset('air_temperature').isel(
        time=slice(0, 250))
    # Use a loop, not a list comprehension: the calls are made for their
    # side effects and the resulting list was never used.
    for time in ds['time'].values:
        plot(ds.sel(time=time), time)
# Run the flow only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    prefect_process()