JSON Output Validation#
Introduction#
When setting up ARTKIT pipelines, you’ll often want to parse multiple values from a model’s response. In these cases, prompting the model to return JSON-formatted output is incredibly useful.
Many of the best commercial models are able to reliably produce JSON in a requested format. Some LLM providers can even guarantee a specific output format by setting the probability of invalid tokens to 0 when sampling the next token.
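As a toy illustration of that idea, the sketch below zeroes out invalid tokens and renormalizes before sampling; the token ids, logits, and function name are purely illustrative, not any provider's actual API:

import math
import random

def sample_constrained(logits: dict[int, float], valid_ids: set[int]) -> int:
    """Sample the next token, allowing only tokens that keep the output valid."""
    # Dropping invalid tokens is equivalent to setting their probability to 0;
    # random.choices renormalizes the remaining weights before sampling
    weights = {t: math.exp(l) for t, l in logits.items() if t in valid_ids}
    tokens = list(weights)
    return random.choices(tokens, weights=[weights[t] for t in tokens])[0]

# E.g., if only tokens 12 ('{') and 47 ('[') could start a valid JSON document:
next_token = sample_constrained({12: 1.5, 47: 0.9, 88: 2.3}, valid_ids={12, 47})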
However, if you are using a weaker model, sampling at a high temperature, or requiring a complex custom format for your model response (perhaps not even JSON), it can be helpful to validate the output before passing it to the next step in the pipeline.
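For example, a single missing comma is enough to make Python's built-in JSON parser fail:

import json

try:
    json.loads('[{"state": "California" "population": 39538223}]')
except json.JSONDecodeError as e:
    print(e)  # Expecting ',' delimiter: line 1 column 25 (char 24)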
To support validating and fixing JSON output, ARTKIT provides the parse_json_autofix utility function. Below, we'll show you how it works.
Example#
We’ll start with a simple example that uses a JSON output format to pass model results between pipeline steps. In this case, we’ll ask an LLM for the five most populous states, and then multiply their population by two:
[1]:
import json
import logging

from dotenv import load_dotenv

import artkit.api as ak
from artkit.model.llm.util import parse_json_autofix

load_dotenv()
logging.basicConfig(level=logging.WARNING)

gpt_chat = ak.CachedChatModel(
    ak.OpenAIChat(
        model_id="gpt-3.5-turbo",
        seed=0,
    ),
    database="./cache/json_output_validation.db",
)
[2]:
state_prompt = """\
Given an input integer 'N', list the 'N' most populous states in the United States.
# Output Format
[
{{
"state": "<state1>",
"population": <state1 population>,
}},
{{
"state": "<state2>",
"populate": <state2 population>,
}}
]
"""
async def get_states(n: int, llm: ak.ChatModel):
for response in await llm.get_response(message=str(n)):
json_response = json.loads(response)
for item in json_response:
# Note that the model response will always be a string, so we need to explicitly
# cast "population" as an into to use in later steps
yield {"state": item["state"], "population": int(item["population"])}
def multiply_population(population: int, factor: int):
return {"factor": factor, "multiplied_population": population * factor}
# Run the steps and print results
steps = ak.chain(
ak.step("get_states", get_states, llm=gpt_chat.with_system_prompt(state_prompt)),
ak.step("multiply_population", multiply_population, factor=2)
)
input = [{"n": 5}]
result = ak.run(input=input, steps=steps)
result.to_frame()
[2]:
     input    get_states             multiply_population
         n         state  population              factor  multiplied_population
item
0        5    California    39538223                   2               79076446
1        5         Texas    29145505                   2               58291010
2        5       Florida    21538187                   2               43076374
3        5      New York    20201249                   2               40402498
4        5  Pennsylvania    13002700                   2               26005400
If the model made small errors in the output format, the above pipeline would fail with a JSONDecodeError. We'll handle this by calling parse_json_autofix on the raw model output instead of json.loads; when parsing fails, it will use an LLM to attempt to fix the errors:
[3]:
error_state_prompt = state_prompt + """\
You must introduce a few small errors in the output format, such as adding or removing commas, colons, or quotation marks.
"""


async def get_and_validate_states(n: int, llm: ak.ChatModel):
    for response in await llm.with_system_prompt(error_state_prompt).get_response(message=str(n)):
        # Instead of calling json.loads directly, we'll pass the result to parse_json_autofix
        parsed_response = await parse_json_autofix(json=response, model=llm)
        for item in parsed_response:
            # Cast to an int in case the model quoted the number
            yield {"state": item["state"], "population": int(item["population"])}


# Run the steps and print results
error_steps = ak.chain(
    ak.step("get_states", get_and_validate_states, llm=gpt_chat),
    ak.step("multiply_population", multiply_population, factor=2),
)

result = ak.run(input=input, steps=error_steps)
result.to_frame()
WARNING:artkit.model.llm.util._json:Attempting to fix malformed JSON:
[
    {
        state: "California",
        population: 39538223
    },
    {
        state: "Texas"
        population: 29145505
    },
    {
        "state": "Florida",
        population: "21538187"
    },
    {
        state: "New York",
        "population": 20201249
    },
    {
        state: "Pennsylvania",
        population: 12820878
    }
]
[3]:
     input    get_states             multiply_population
         n         state  population              factor  multiplied_population
item
0        5    California    39538223                   2               79076446
1        5         Texas    29145505                   2               58291010
2        5       Florida    21538187                   2               43076374
3        5      New York    20201249                   2               40402498
4        5  Pennsylvania    12820878                   2               25641756
ARTKIT logs a warning that the original JSON output was malformed. However, after running parse_json_autofix, our final results are the same as in the initial run. Under the hood, parse_json_autofix sends a query to the LLM containing the malformed JSON together with the parser's error message, and asks the LLM to return a corrected version.
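Conceptually, the repair step looks something like the sketch below (simplified; the exact prompt wording and retry behavior inside ARTKIT may differ):

async def autofix_sketch(raw: str, model: ak.ChatModel):
    try:
        return json.loads(raw)
    except json.JSONDecodeError as e:
        # Send the malformed JSON and the parser's error message back to the LLM
        fix_request = (
            f"Fix this malformed JSON:\n{raw}\n\n"
            f"The parser reported: {e}\n"
            "Return only the corrected JSON."
        )
        responses = await model.get_response(message=fix_request)
        return json.loads(responses[0])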
Concluding remarks#
In order to enforce that an LLM's output follows a particular format, the most reliable method is to deterministically set the probability of every invalid token to 0 and renormalize before sampling each token. However, this approach is infeasible for most projects, since it requires direct access to the model's sampling step. For the special case of JSON output with recent OpenAI models, you can use the response_format parameter.
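For instance, with the OpenAI Python SDK (outside of ARTKIT), JSON mode can be enabled as shown below; note that it requires a sufficiently recent model and the word "JSON" appearing in the messages, and it guarantees syntactically valid JSON but not any particular schema:

from openai import OpenAI

client = OpenAI()
completion = client.chat.completions.create(
    model="gpt-3.5-turbo",
    response_format={"type": "json_object"},
    messages=[
        {"role": "system", "content": "List the N most populous US states as JSON."},
        {"role": "user", "content": "N = 5"},
    ],
)
print(completion.choices[0].message.content)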
This notebook uses an approach that works more generally, sketched below:

1. apply a parser to the output
2. show any resulting parsing errors to the LLM
3. ask the LLM to resolve the errors
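Generalizing the earlier sketch to any parser might look like the following; parse_with_autofix is an illustrative name, not part of ARTKIT's API:

from typing import Callable, TypeVar

T = TypeVar("T")

async def parse_with_autofix(text: str, parse_fn: Callable[[str], T], llm: ak.ChatModel) -> T:
    try:
        return parse_fn(text)
    except Exception as e:
        # Show the raw output and the parser's error to the LLM, then retry once
        responses = await llm.get_response(
            message=(
                f"The following output failed to parse:\n{text}\n\n"
                f"Error: {e}\n"
                "Return only the corrected output."
            )
        )
        return parse_fn(responses[0])

# For example, this could repair a number written out in words:
#     await parse_with_autofix("forty-two", int, gpt_chat)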
We have shown how to efficiently implement this approach with ARTKIT's parse_json_autofix, which can be used to validate and fix JSON-formatted output from any LLM.
If you have other ideas on this topic, please consider contributing!