I originally asked this question; I'm updating here with the solution, thanks to Marvin’s guidance.
Here’s an example:
dask-gateway Helm chart values.yaml:
gateway:
  extraConfig:
    # Python snippet handed to the gateway server as extra configuration.
    options: |
      from dask_gateway_server.options import Options, Integer, Float, String

      def options_handler(options):
          """Map the user-selected cluster options onto backend settings.

          Reads ``worker_cores``, ``worker_memory`` (in GiB) and ``nodepool``
          from ``options`` and returns a dict with the worker cores, the
          worker memory converted to bytes, and extra worker pod config that
          pins the workers to the chosen node pool.
          """
          pool = options.nodepool

          # Tolerate the ``nodepool=<pool>:NoSchedule`` taint.
          toleration = {
              "key": "nodepool",
              "operator": "Equal",
              "value": pool,
              "effect": "NoSchedule",
          }

          # Require nodes whose GKE node-pool label matches the chosen pool.
          pool_expression = {
              "key": "cloud.google.com/gke-nodepool",
              "operator": "In",
              "values": [pool],
          }
          node_affinity = {
              "requiredDuringSchedulingIgnoredDuringExecution": {
                  "nodeSelectorTerms": [{"matchExpressions": [pool_expression]}]
              }
          }

          return {
              "worker_cores": options.worker_cores,
              # worker_memory is given in GiB; convert to a whole number of bytes.
              "worker_memory": int(options.worker_memory * 2 ** 30),
              "worker_extra_pod_config": {
                  "tolerations": [toleration],
                  "affinity": {"nodeAffinity": node_affinity},
              },
          }

      # Options a client may pass when creating a cluster; every request is
      # routed through options_handler above.
      c.Backend.cluster_options = Options(
          String(
              "nodepool",
              default="default-pool",
              label="Worker NodePool",
          ),
          Integer(
              "worker_cores",
              default=1,
              min=1,
              max=16,
              label="Worker Cores",
          ),
          Float(
              "worker_memory",
              default=0.5,
              min=0.1,
              max=32,
              label="Worker Memory (GiB)",
          ),
          handler=options_handler,
      )
Python code on the client side:
import os
from prefect import task, flow, get_run_logger
from prefect_dask import DaskTaskRunner
from dask_gateway import BasicAuth
from platform import node, platform
# The gateway endpoint and password are read from the environment.
gateway_address = os.environ["DASK_GATEWAY_ADDRESS"]
gateway_password = os.environ["DASK_PASSWORD"]
auth = BasicAuth(password=gateway_password)

# Task runner backed by a dask-gateway GatewayCluster. The worker_cores,
# worker_memory and nodepool entries use the same names as the cluster
# options declared in the gateway configuration.
runner = DaskTaskRunner(
    cluster_class="dask_gateway.GatewayCluster",
    adapt_kwargs=dict(minimum=1, maximum=10, active=True),
    cluster_kwargs=dict(
        auth=auth,
        address=gateway_address,
        worker_cores=1,
        worker_memory=1,
        nodepool="t4-cpu-pool",
    ),
)
@task
def check():
    """Log the node name and platform string of the machine running the task."""
    run_log = get_run_logger()
    hostname = node()
    run_log.info(f"Network: {hostname}. ✅")
    platform_id = platform()
    run_log.info(f"Instance: {platform_id}. ✅")
@flow(task_runner=runner)
def poc_flow():
    """Submit the ``check`` task to the Dask task runner and wait for it."""
    future = check.submit()
    # Resolve the future before the flow returns; an unresolved terminal
    # future may be dropped before the task has finished running.
    future.wait()


if __name__ == "__main__":
    poc_flow()
Works really well! One gateway, and I create a GatewayCluster for each nodepool that I need to use.
Thanks Marvin!