LLM Client

Provides the get_llm_client function, which configures litellm.completion to communicate with the LLM provider chosen by the user.

get_llm_client(api_key, model, temperature, **kwargs)

Generate a lambda function around litellm.completion to be called from PromptRefiner.refine.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `api_key` | str | API key to access the model. | required |
| `model` | str | Model name to use for refining the prompt. | required |
| `temperature` | float | Temperature for the model. | required |
| `**kwargs` | | Extra arguments to feed into the model. | {} |

Returns:

| Type | Description |
| --- | --- |
| Callable | A lambda function that takes system_prompt and user_prompt as arguments and returns the refined prompt when called. |
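As a rough illustration of the returned callable, the sketch below mirrors its shape without touching the network. Both `fake_completion` (a stand-in for `litellm.completion`) and `make_client` (a stand-in for `get_llm_client`) are hypothetical names introduced only for this example, and the model name is illustrative.

```python
from typing import Callable


def fake_completion(model: str, messages: list, temperature: float, **kwargs) -> dict:
    """Hypothetical stand-in for litellm.completion: echoes the user message."""
    user_text = messages[-1]["content"]
    return {
        "choices": [
            {"message": {"role": "assistant", "content": f"[{model}] {user_text}"}}
        ]
    }


def make_client(model: str, temperature: float, **kwargs) -> Callable[[str, str], str]:
    """Hypothetical stand-in for get_llm_client, built on the stub above."""

    def client(system_prompt: str, user_prompt: str) -> str:
        response = fake_completion(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=temperature,
            **kwargs,
        )
        # Same extraction path the real function applies to the response.
        return response["choices"][0]["message"]["content"]

    return client


# The configuration is captured once; each call only supplies the two prompts.
client = make_client("openai/gpt-4o-mini", temperature=0.2, max_tokens=64)
refined = client("You rewrite prompts.", "tell me about dogs")
print(refined)
```

With the real function, the call pattern should be the same: build the client once with `get_llm_client(api_key, model, temperature)`, then call it with `system_prompt` and `user_prompt` and receive a plain string.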

Source code in promptrefiner/client_factory.py
def get_llm_client(api_key: str, model: str, temperature: float, **kwargs) -> Callable:
    """
    Generate a lambda function around `litellm.completion` to be called
    from `PromptRefiner.refine`.

    Args:
        api_key (str): API key to access model.
        model (str): model name to use for refining prompt.
        temperature (float): Temperature for model.
        **kwargs: Extra arguments to feed into model.

    Returns:
        A lambda function that takes `system_prompt` and `user_prompt`
            as arguments and returns the refined prompt when called.
    """

    # Identify which environment variable `litellm` expects for the chosen model
    provider = model.split("/")[0]
    expected_env_var = MODEL_API_KEY_MAP.get(provider)

    # Set the API key dynamically for the expected environment variable
    if expected_env_var and api_key:
        os.environ[expected_env_var] = api_key

    # Return a function that calls litellm
    return lambda system_prompt, user_prompt: litellm.completion(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=temperature,
        **kwargs,
    )["choices"][0]["message"]["content"]
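The provider lookup in the listing can be explored in isolation with a small sketch. The mapping `EXAMPLE_API_KEY_MAP` and the helpers `resolve_env_var` and `export_api_key` are hypothetical; the actual contents of `MODEL_API_KEY_MAP` are not shown in this section and may differ. The sketch writes to a plain dict rather than `os.environ` so that it has no side effects.

```python
import os
from typing import MutableMapping, Optional

# Hypothetical mapping for illustration; the real MODEL_API_KEY_MAP may differ.
EXAMPLE_API_KEY_MAP = {
    "openai": "OPENAI_API_KEY",
    "anthropic": "ANTHROPIC_API_KEY",
    "gemini": "GEMINI_API_KEY",
}


def resolve_env_var(model: str) -> Optional[str]:
    """Return the env var name for the provider prefix of `model`, if known."""
    provider = model.split("/")[0]  # "openai/gpt-4o-mini" -> "openai"
    return EXAMPLE_API_KEY_MAP.get(provider)


def export_api_key(
    model: str,
    api_key: str,
    environ: Optional[MutableMapping[str, str]] = None,
) -> bool:
    """Store `api_key` under the provider's env var; report whether it was set."""
    target = os.environ if environ is None else environ
    env_var = resolve_env_var(model)
    if env_var and api_key:
        target[env_var] = api_key
        return True
    return False


env: dict = {}
print(export_api_key("openai/gpt-4o-mini", "sk-example", env))  # True
print(export_api_key("gpt-4o-mini", "sk-example", env))         # False
print(env)  # {'OPENAI_API_KEY': 'sk-example'}
```

Under this logic, a model name without a `provider/` prefix is treated as its own provider, so no environment variable is set unless the mapping happens to contain that exact name. In that case the key would need to reach `litellm` some other way, for example through an already exported environment variable.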