import os
import itertools
from typing import List, Optional, Any, Literal, Dict, Iterator, Union
import time
import warnings
from concurrent import futures
from tqdm import tqdm
import requests
from ginkgo_ai_client.queries import QueryBase
class RequestError(Exception):
"""An exception raised by a request, due to the query content or a system error.
This exception carries the original query and the result url to enable users to
better handle failure cases.
Parameters
----------
cause: Exception
The original exception that caused the request to fail.
    query: QueryBase (optional)
        The query that failed. This enables users to retrieve and retry the
        failed queries of a batch request.
result_url: str (optional)
The url where the result can be retrieved from. This enables users to get the
result later if the failure cause was a temporary network error or an
accidental timeout.
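    Examples
    --------
    A minimal sketch of handling a failed request (assuming `client` and `query`
    are already defined):
    .. code-block:: python
        try:
            response = client.send_request(query)
        except RequestError as error:
            print(error.cause)             # the original exception
            failed_query = error.query     # can be re-sent later
            result_url = error.result_url  # may still be polled for the result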
"""
def __init__(
self,
cause: Exception,
query: Optional[QueryBase] = None,
result_url: Optional[str] = None,
):
self.cause = cause
self.query = query
self.result_url = result_url
message = self._format_error_message()
super().__init__(message)
def _format_error_message(self):
cause_str = f"{self.cause.__class__.__name__}: {self.cause}"
url_str = (
f"\n\nThis happened while polling this url for results:\n{self.result_url}"
if self.result_url is not None
else ""
)
return f"{cause_str}\n on query: {self.query}{url_str}"
class GinkgoAIClient:
"""A client for the public Ginkgo AI models API.
Parameters
----------
api_key: str (optional)
The API key to use for the Ginkgo AI API. If none is provided, the
`GINKGOAI_API_KEY` environment variable will be used.
polling_delay: float (default: 1)
The delay between polling requests to the Ginkgo AI API, in seconds.
Examples
--------
.. code-block:: python
client = GinkgoAIClient()
query = MaskedInferenceQuery("MPK<mask><mask>RRL", model="ginkgo-aa0-650m")
response = client.send_request(query)
# response["sequence"] == "MPKYLRRL"
        responses = client.send_batch_request([query, other_query])
"""
INFERENCE_URL = "https://api.ginkgobioworks.ai/v1/transforms/run"
BATCH_INFERENCE_URL = "https://api.ginkgobioworks.ai/v1/batches/transforms/run"
def __init__(
self,
api_key: Optional[str] = None,
polling_delay: float = 1,
):
if api_key is None:
api_key = os.environ.get("GINKGOAI_API_KEY")
if api_key is None:
raise ValueError(
"No API key provided. Please provide an API key or set the "
"`GINKGOAI_API_KEY` environment variable."
)
self.api_key = api_key
self.polling_delay = polling_delay
def send_request(
self,
query: QueryBase,
timeout: float = 60,
) -> Any:
"""Send a query to the Ginkgo AI API.
Parameters
----------
query: QueryBase
The query to send to the Ginkgo AI API.
timeout: float (default: 60)
The maximum time to wait for the query to complete, in seconds.
Returns
-------
Response
The response from the Ginkgo AI API, for instance `{"sequence": "ATG..."}`.
            It will be different depending on the query; see the docstrings of the
            query classes in `ginkgo_ai_client.queries`.
Raises
------
RequestError
If the request failed due to the query content or a system error. The
exception carries the original query and (if it reached that stage) the
url it was polling for the results.
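        Examples
        --------
        A minimal sketch, reusing the query from the class docstring:
        .. code-block:: python
            client = GinkgoAIClient()
            query = MaskedInferenceQuery("MPK<mask><mask>RRL", model="ginkgo-aa0-650m")
            response = client.send_request(query, timeout=120)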
"""
# LAUNCH THE REQUEST
headers = {"x-api-key": self.api_key}
json = query.to_request_params()
launch_response = requests.post(self.INFERENCE_URL, headers=headers, json=json)
if not launch_response.ok:
status_code, text = (launch_response.status_code, launch_response.text)
# TODO: add different error types depending on exception
cause = IOError(
f"Request failed at launch with status {status_code}: {text}."
)
raise RequestError(query=query, cause=cause)
# GET THE JOBID FROM THE RESPONSE
launch_response = launch_response.json()
        response_has_result = "result" in launch_response
        if not response_has_result or "?jobId" not in launch_response["result"]:
cause = IOError(f"Unexpected response format at launch: {launch_response}")
raise RequestError(query=query, cause=cause)
result_url = launch_response["result"]
# POLL UNTIL THE JOB COMPLETES, AN ERROR OCCURS, OR WE TIME OUT.
start_time = time.time()
while True:
time.sleep(self.polling_delay)
poll_response = requests.get(result_url, headers=headers)
            if not poll_response.ok:
                cause = IOError(f"Unexpected response: {poll_response}")
                raise RequestError(query=query, cause=cause, result_url=result_url)
poll_response = poll_response.json()
if poll_response["status"] == "COMPLETE":
job_result = poll_response["result"][0]
if job_result["error"] is not None:
cause = IOError(f"Query returned an error: {job_result['error']}")
raise RequestError(query=query, cause=cause, result_url=result_url)
# If we reach this point, we have a result, let's parse it and return it.
response_result = poll_response["result"][0]["result"]
return query.parse_response(response_result)
elif poll_response["status"] in ["PENDING", "IN_PROGRESS"]:
if time.time() - start_time > timeout:
cause = TimeoutError("Request timed out.")
raise RequestError(query=query, cause=cause, result_url=result_url)
else:
cause = IOError(f"Unexpected response status: {poll_response}")
raise RequestError(query=query, cause=cause, result_url=result_url)
def send_batch_request(
self,
queries: List[QueryBase],
        timeout: Optional[float] = None,
on_failed_queries: Literal["ignore", "warn", "raise"] = "ignore",
) -> List[Any]:
"""Send multiple queries at once to the Ginkgo AI API in batch mode.
        All the queries are sent at once and the returned list has results in the
        same order as the queries. Additionally, if the queries have a `query_name`
        attribute, it will be preserved in the `query_name` attribute of the results.
Parameters
----------
        queries: List[QueryBase]
            The queries to send to the Ginkgo AI API. These will typically be
            created using the query classes in `ginkgo_ai_client.queries`.
timeout: float (optional)
The maximum time to wait for the batch to complete, in seconds.
        on_failed_queries: Literal["ignore", "warn", "raise"] (default: "ignore")
            What to do if some queries fail. The default is to ignore failures:
            failed queries are returned among the results as `RequestError`
            entries carrying the corresponding `query_name`, and the user is
            expected to check and handle them. "warn" emits a warning if some
            queries failed, and "raise" raises an exception if at least one
            query failed.
Returns
-------
        responses : List[Any]
            A list of responses from the Ginkgo AI API. The class of the responses
            depends on the class of the queries. If some queries failed (and
            `on_failed_queries` is not "raise"), the corresponding entries are
            `RequestError` instances.
Raises
------
        RequestError
If the request failed due to the query content or a system error.
Examples
--------
.. code-block:: python
client = GinkgoAIClient()
queries = [
MaskedInferenceQuery("MPK<mask><mask>RRL", model="ginkgo-aa0-650m"),
MaskedInferenceQuery("MES<mask><mask>YKL", model="ginkgo-aa0-650m")
]
responses = client.send_batch_request(queries)
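            # With the default on_failed_queries="ignore", failed queries come
            # back as RequestError entries which can be filtered out and retried:
            failures = [r for r in responses if isinstance(r, RequestError)]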
"""
# LAUNCH THE BATCH REQUEST
headers = {"x-api-key": self.api_key}
request_params = [q.to_request_params() for q in queries]
launch_response = requests.post(
self.BATCH_INFERENCE_URL, headers=headers, json={"requests": request_params}
)
        if not launch_response.ok:
            status_code, text = (launch_response.status_code, launch_response.text)
            cause = IOError(
                f"Batch request failed at launch with status {status_code}: {text}."
            )
            raise RequestError(cause=cause)
# GET THE BATCHID FROM THE RESPONSE
launch_response = launch_response.json()
response_has_result = "result" in launch_response
if not response_has_result or "?batchId" not in launch_response["result"]:
raise Exception(f"Unexpected response: {launch_response}")
ordered_job_ids = launch_response["jobIds"]
result_url = launch_response["result"]
# POLL UNTIL THE BATCH COMPLETES, AN ERROR OCCURS, OR WE TIME OUT.
start_time = time.time()
while True:
time.sleep(self.polling_delay)
poll_response = requests.get(result_url, headers=headers)
if not poll_response.ok:
cause = IOError(f"Unexpected response: {poll_response}")
raise RequestError(
cause=cause,
result_url=result_url,
)
poll_response = poll_response.json()
if poll_response["status"] == "COMPLETE":
return self._process_batch_request_results(
queries=queries,
ordered_job_ids=ordered_job_ids,
response=poll_response,
failed_queries_action=on_failed_queries,
)
elif poll_response["status"] in ["PENDING", "IN_PROGRESS"]:
if timeout is not None and (time.time() - start_time > timeout):
cause = TimeoutError(f"Batch request took over {timeout} seconds.")
raise RequestError(
cause=cause,
result_url=result_url,
)
else:
raise Exception(f"Unexpected response status: {poll_response}")
def send_requests_by_batches(
self,
queries: Union[List[QueryBase], Iterator[QueryBase]],
batch_size: int = 20,
        timeout: Optional[float] = None,
on_failed_queries: Literal["ignore", "warn", "raise"] = "ignore",
max_concurrent: int = 3,
show_progress: bool = True,
    ) -> Iterator[List[Any]]:
"""Send multiple queries at once to the Ginkgo AI API in batch mode.
This method is useful for sending large numbers of queries to the Ginkgo AI API
and process results in small batches as they are ready. It avoids running out of
RAM by holding thousands of requests and their results in memory, and avoids
overwhelming the web API servers.
The method divides the queries in small batches, then submits the batches to the
web API (only 3 batches are submitted at the same time by default), and returns
the list of results in each batch as soon as a full batch is ready.
**Important Warning**: this means that the batch results are not returned strictly
in the same order as the batches sent. The best way to attribute results to
inputs is to give each input query a `query_name` attribute, which will be
preserved in the `query_name` attribute of the results. This is done automatically
by some query methods such as `.iter_from_fasta()` which will attribute the
sequence name each query.
Examples
--------
.. code-block:: python
model="esm2-650m"
queries = MeanEmbeddingQuery.iter_from_fasta("sequences.fasta", model=model)
for batch_result in client.send_requests_by_batches(queries, batch_size=10):
for query_result in batch_result:
query_result.write_to_jsonl("results.jsonl")
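                    # `iter_from_fasta` names each query after its fasta record,
                    # and this name is preserved in the result's `query_name`:
                    print(query_result.query_name)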
Parameters
----------
        queries: Union[List[QueryBase], Iterator[QueryBase]]
            The queries to send to the Ginkgo AI API. This can be a list, or any
            iterable or iterator such as a generator.
batch_size: int (default: 20)
The size of the batches to send to the Ginkgo AI API.
timeout: float (optional)
The maximum time to wait for one batch to complete, in seconds.
        on_failed_queries: Literal["ignore", "warn", "raise"] (default: "ignore")
            What to do if some queries fail. The default is to ignore failures:
            failed queries are returned among the results as `RequestError`
            entries carrying the corresponding `query_name`, and the user is
            expected to check and handle them. "warn" emits a warning if some
            queries failed, and "raise" raises an exception if at least one
            query failed.
        max_concurrent: int (default: 3)
            The maximum number of batches submitted to the web API at the same
            time.
        show_progress: bool (default: True)
            Whether to display a progress bar (one tick per completed batch).
        Yields
        ------
        batch_results : List[Any]
            The results of one completed batch, as returned by
            `send_batch_request`.
"""
        # Group the queries into batches (lists of up to `batch_size` queries);
        # the iteration stops when an empty batch is produced.
        query_iterator = iter(queries)
        batches = iter(lambda: list(itertools.islice(query_iterator, batch_size)), [])
try:
            # This only works for lists, ranges, or iterators with a len(), such
            # as the IteratorWithLength objects returned by some utility methods.
n_queries = len(queries)
total = n_queries // batch_size + (1 if n_queries % batch_size else 0)
except Exception:
total = None
if show_progress:
progress_bar = tqdm(total=total, desc="Processing batches", mininterval=1)
else:
progress_bar = None
def send_batch(batch):
return self.send_batch_request(
queries=batch, timeout=timeout, on_failed_queries=on_failed_queries
)
for result in process_with_limited_concurrency(
element_iterator=batches,
function=send_batch,
max_concurrent=max_concurrent,
progress_bar=progress_bar,
):
yield result
@classmethod
def _process_batch_request_results(
cls,
queries: List[QueryBase],
ordered_job_ids: List[str],
response: Dict,
failed_queries_action: Literal["ignore", "warn", "raise"] = "ignore",
) -> List[Any]:
"""Apply parsing and error handling to the list of results of a batch request."""
# Technical note: to understand the parsing, one should know that the
# API returns a list of request results where each element is of the form
# {"jobId": "...", "result": [{"error": None, "result": ...}]}
ordered_results = sorted(
response["requests"],
key=lambda x: ordered_job_ids.index(x["jobId"]),
)
parsed_results = [
cls._parse_batch_request_result(query, request_result["result"][0])
for query, request_result in zip(queries, ordered_results)
]
if failed_queries_action != "ignore":
errored_results = [r for r in parsed_results if isinstance(r, RequestError)]
n_errored = len(errored_results)
n_queries = len(queries)
            if n_errored:
msg = f"{n_errored}/{n_queries} queries in the batch failed"
if failed_queries_action == "warn":
warnings.warn(msg)
else:
raise IOError(msg)
return parsed_results
@staticmethod
def _parse_batch_request_result(query: QueryBase, result: Dict) -> Any:
"""Parse the result of a single query from a batch request (output or error)."""
if result["error"] is not None:
return RequestError(
cause=Exception(result["error"]),
query=query,
)
return query.parse_response(result["result"])
def process_with_limited_concurrency(
    element_iterator, function, max_concurrent: int, progress_bar: Optional[tqdm] = None
):
    """Apply `function` to the iterator's elements using at most `max_concurrent`
    concurrent threads, yielding the results as they complete (and therefore not
    necessarily in the input order).
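    A minimal usage sketch (the squaring function and the inputs are
    illustrative only):
    .. code-block:: python
        results = process_with_limited_concurrency(
            iter(range(10)), lambda x: x**2, max_concurrent=3
        )
        assert sorted(results) == [x**2 for x in range(10)]
    """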
with futures.ThreadPoolExecutor(max_workers=max_concurrent) as executor:
# Initialize active futures
active_futures = []
# Submit initial batch of futures up to max_concurrent
for _ in range(max_concurrent):
try:
element = next(element_iterator)
future = executor.submit(function, element)
active_futures.append(future)
except StopIteration:
break
# Process results and submit new batches as old ones complete
while active_futures:
            # Wait until at least one task completes
            completed, not_done = futures.wait(
                active_futures, return_when=futures.FIRST_COMPLETED
            )
            active_futures = list(not_done)
            # Submit one new task per completed future, to keep up to
            # max_concurrent tasks in flight as long as elements remain
            for _ in completed:
                try:
                    element = next(element_iterator)
                    active_futures.append(executor.submit(function, element))
                except StopIteration:
                    break
# Yield completed results and update progress
for future in completed:
yield future.result()
if progress_bar is not None:
progress_bar.update(1)
if progress_bar is not None:
progress_bar.close()