Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pkg-r/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ Depends:
R (>= 4.1.0)
Imports:
bslib,
callr,
DBI,
duckdb,
ellmer,
ggplot2,
glue,
htmltools,
jsonlite,
Expand Down
90 changes: 70 additions & 20 deletions pkg-r/R/querychat.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,23 @@
#'
#' @export
querychat_init <- function(
df,
...,
table_name = deparse(substitute(df)),
greeting = NULL,
data_description = NULL,
extra_instructions = NULL,
prompt_template = NULL,
system_prompt = querychat_system_prompt(
df,
table_name,
# By default, pass through any params supplied to querychat_init()
...,
data_description = data_description,
extra_instructions = extra_instructions,
prompt_template = prompt_template
),
create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o")
) {
table_name = deparse(substitute(df)),
greeting = NULL,
data_description = NULL,
extra_instructions = NULL,
prompt_template = NULL,
system_prompt = querychat_system_prompt(
df,
table_name,
# By default, pass through any params supplied to querychat_init()
...,
data_description = data_description,
extra_instructions = extra_instructions,
prompt_template = prompt_template
),
create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o")) {
is_table_name_ok <- is.character(table_name) &&
length(table_name) == 1 &&
grepl("^[a-zA-Z][a-zA-Z0-9_]*$", table_name, perl = TRUE)
Expand Down Expand Up @@ -139,7 +138,8 @@ querychat_ui <- function(id) {
htmltools::tagList(
# TODO: Make this into a proper HTML dependency
shiny::includeCSS(system.file("www", "styles.css", package = "querychat")),
shinychat::chat_ui(ns("chat"), height = "100%", fill = TRUE)
shinychat::chat_ui(ns("chat"), height = "100%", fill = TRUE),
shiny::plotOutput(ns("llm_plot"), height = 300)
)
}

Expand Down Expand Up @@ -191,6 +191,47 @@ querychat_server <- function(id, querychat_config) {
session = session
)
}
plot_code <- shiny::reactiveVal(NULL)
# Preload the conversation with the system prompt. These are instructions for
# the chat model, and must not be shown to the end user.
chat <- create_chat_func(system_prompt = system_prompt)
output$llm_plot <- shiny::renderPlot({
code <- plot_code()
if (is.null(code) || !nzchar(code)) {
return(NULL)
}
df <- filtered_df()
forbidden <- c("system", "file", "unlink", "assign", "library", "require")
if (any(sapply(forbidden, grepl, code))) {
stop("Forbidden function detected in plot code.")
}
res <- tryCatch(
{
callr::r(
function(code, df) {
p <- eval(parse(text = code))
if (!inherits(p, "ggplot")) stop("Code did not return a ggplot object.")
p # return the ggplot object
},
args = list(code = code, df = df),
show = TRUE,
stdout = TRUE,
stderr = TRUE,
)
},
error = function(e) {
message(
"Plot error: ", e$message, "\n",
"Code: ", code, "\n",
)
plot.new()
text(0.5, 0.5, "Plot error. See R console for details.")
return(NULL)
}
)
if (inherits(res, "ggplot")) print(res)
invisible(res)
})

# Modifies the data presented in the data dashboard, based on the given SQL
# query, and also updates the title.
Expand Down Expand Up @@ -219,6 +260,18 @@ querychat_server <- function(id, querychat_config) {
}
}

update_plot <- function(ggplot_code) {
plot_code(ggplot_code)
append_output("\n```r\n", ggplot_code, "\n```\n\n")
}
chat$register_tool(ellmer::tool(
update_plot,
"Updates the plot displayed in the data dashboard, based on the given ggplot code.",
ggplot_code = ellmer::type_string(
"A string containing R code that generates a ggplot object."
)
))

# Perform a SQL query on the data, and return the results as JSON.
# @param query A DuckDB SQL query; must be a SELECT statement.
# @return The results of the query as a JSON string.
Expand All @@ -242,9 +295,6 @@ querychat_server <- function(id, querychat_config) {
df |> jsonlite::toJSON(auto_unbox = TRUE)
}

# Preload the conversation with the system prompt. These are instructions for
# the chat model, and must not be shown to the end user.
chat <- create_chat_func(system_prompt = system_prompt)
chat$register_tool(ellmer::tool(
update_dashboard,
"Modifies the data presented in the data dashboard, based on the given SQL query, and also updates the title.",
Expand Down
28 changes: 28 additions & 0 deletions pkg-r/inst/prompt/prompt.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,32 @@ If you find yourself offering example questions to the user as part of your resp

* `percentile_cont` and `percentile_disc` are "ordered set" aggregate functions. These functions are specified using the WITHIN GROUP (ORDER BY sort_expression) syntax, and they are converted to an equivalent aggregate function that takes the ordering expression as the first argument. For example, `percentile_cont(fraction) WITHIN GROUP (ORDER BY column [(ASC|DESC)])` is equivalent to `quantile_cont(column, fraction ORDER BY column [(ASC|DESC)])`.

## Task: Plotting with ggplot2

You can create and update plots in the dashboard using the `update_plot` tool. This tool takes a string of R code that generates a ggplot2 plot using the data frame `df` (which contains the currently filtered data). The code you provide will be evaluated and the resulting plot will be displayed in the dashboard.

* Always use valid R code that creates a ggplot2 plot and assigns it as the last expression (no assignment needed, just return the plot object).
* The data frame available for plotting is named `df`.
* Do not attempt to retrieve or manipulate data outside of `df`.
* Only use plotting code that is safe and reproducible.

## Plotting guardrails

When generating R code for plotting, you must never use or reference any of the following functions or statements: `system`, `file`, `unlink`, `assign`, `library`, `require`, or any function that accesses the file system, environment, or external resources. Only use functions from `ggplot2` and the provided data frame `df`. Any attempt to use forbidden functions will result in an error and your code will not be executed.

Example of plotting:

> [User]
> Show me a scatterplot of x vs y.
> [/User]
> [ToolCall]
> update_plot({ggplot_code: "ggplot2::ggplot(df, ggplot2::aes(x = x, y = y)) + ggplot2::geom_point()"})
> [/ToolCall]
> [ToolResponse]
> null
> [/ToolResponse]
> [Assistant]
> Here is a scatterplot of x vs y.
> [/Assistant]

{{extra_instructions}}
Loading