Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added property_layer with altair #2643

Open
wants to merge 19 commits into
base: main
Choose a base branch
from

Conversation

sanika-n
Copy link
Contributor

Implemented Property Layers in Altair
This is how it looks when it runs on a fairly simple model:
image

Copy link

Performance benchmarks:

Model Size Init time [95% CI] Run time [95% CI]
BoltzmannWealth small 🔵 -0.5% [-1.7%, +0.6%] 🔵 +0.0% [-0.2%, +0.2%]
BoltzmannWealth large 🔵 -0.7% [-1.1%, -0.3%] 🔵 +0.0% [-0.6%, +0.7%]
Schelling small 🔵 -1.1% [-1.5%, -0.8%] 🔵 -0.6% [-0.7%, -0.4%]
Schelling large 🔵 -0.9% [-1.3%, -0.6%] 🔵 -0.6% [-1.3%, +0.2%]
WolfSheep small 🔵 -0.2% [-0.4%, +0.1%] 🔵 -0.6% [-0.7%, -0.4%]
WolfSheep large 🔵 +0.6% [+0.3%, +1.0%] 🔵 -0.8% [-2.0%, +0.5%]
BoidFlockers small 🔵 +0.5% [+0.0%, +1.0%] 🔵 +2.5% [+2.3%, +2.7%]
BoidFlockers large 🔵 +0.2% [-0.1%, +0.6%] 🔵 +2.4% [+2.1%, +2.9%]

@tpike3
Copy link
Member

tpike3 commented Jan 26, 2025

@sanika-n also please look at #2644 as you and @nissu99 are working in the same space

As it looks like you two are looking over the entire altair implementation I would also recommend looking at #2642 discussion

It is always good to collaborate and think together you may address some larger visualization challenges.

Appreciating I gave recommended changes to #2641; becuase of these other PRs I would think more holisitically with @nissu99 first.

@EwoutH
Copy link
Member

EwoutH commented Feb 9, 2025

What's the status of this PR and what's needed to move it forward?

@sanika-n
Copy link
Contributor Author

I think we are waiting for PR #2644 to be merged first and then I guess after addressing any conflicting code, we can merge this as mentioned here

@quaquel
Copy link
Member

quaquel commented Feb 10, 2025

I requested changes in #2644, if those don't come soonish, I suggest we merge all other altair stuff that is ready to go.

@nissu99
Copy link
Contributor

nissu99 commented Feb 11, 2025

@EwoutH @quaquel @sanika-n sorry for the delay , will make the changes as early as possible.

@EwoutH EwoutH requested a review from Sahil-Chhoker February 16, 2025 07:25
Copy link
Collaborator

@Sahil-Chhoker Sahil-Chhoker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @sanika-n for this PR and sorry it took so long for it to be reviewed. Please take a look at the review and adjust the code accordingly.

Some other concerns I have:

  1. "UserWarning: Layer test dimensions ((5, 5)) do not match space dimensions (5, 5)."
    This warning is always there.
  2. I think a fitting name for base_width and base_height should be chart_width and chart_height, because currently it confuses if the width is for chart or for the property layer itself.

else:
raise ValueError(
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
)
return chart
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can't return chart here, it doesn't exist in this scope.

base_height=base_height,
)

chart = chart + agent_chart
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same problem here, chart does not exist in this scope

Comment on lines 182 to 187
unique_colors = list({agent["color"] for agent in all_agent_data})
encoding_dict["color"] = alt.Color(
"color:N",
scale=alt.Scale(domain=unique_colors, range=unique_colors),
legend=None,
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this logic changed when the previous one was fine and it is not even working. I think the bug is Altair is now taking the colors in matplotlib format, stopping agent visualization completely.

A simple fix would be to just use the past code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not exactly sure why it is not working, could you share the code so that I can replicate the error? (I actually fixed some small errors elsewhere, so if it is still not working, I can work on it after you share the code)

If we used the existing syntax, then altair randomly assigns colors to labels> But with the edited version of the code, the name of the color is correctly mapped with the actual color.

Past code with agents of color red and green:

def agent_portrayal(agent):
    color = "green" if agent.unique_id % 2 == 0 else "red"
    return {"Shape": "o", "color": color, "size": 20}

image

Current Code:
image

Just for reference: https://chatgpt.com/share/67bf710b-9e3c-8010-bf61-681fc1aefbec

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are absolutely right, I had some kind of bug on my end, really sorry for that.

Comment on lines +289 to +293
elif "colormap" in portrayal:
cmap = portrayal.get("colormap", "viridis")
cmap_scale = alt.Scale(scheme=cmap, domain=[vmin, vmax])

chart = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you apply the alpha to colormaps as well. I think .mark_rect(opacity=alpha) should do the job.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for pointing this out, I totally forgot to implement it...

Comment on lines 269 to 270
if "color" in portrayal:
df["color"] = df["value"].apply(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both the color and colormaps are not scaling according to the values of vmin and vmax. This feature should be present.

Comment on lines 271 to 276
lambda val,
portrayal=portrayal,
alpha=alpha: f"rgba({int(to_rgba(portrayal['color'], alpha=alpha)[0] * 255)}, {int(to_rgba(portrayal['color'], alpha=alpha)[1] * 255)}, {int(to_rgba(portrayal['color'], alpha=alpha)[2] * 255)}, {to_rgba(portrayal['color'], alpha=alpha)[3]:.2f})"
if val > 0
else "rgba(0, 0, 0, 0)"
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this function also works as expected. The lambda function is calculating the same rgba values for every row because it's using the same portrayal['color'] and alpha values. If you try printing the values yourself, you would also see something like:

0     rgba(0, 0, 255, 1.00)
1     rgba(0, 0, 255, 1.00)
2     rgba(0, 0, 255, 1.00)
3     rgba(0, 0, 255, 1.00)
4     rgba(0, 0, 255, 1.00)
5     rgba(0, 0, 255, 1.00)
6     rgba(0, 0, 255, 1.00)
7     rgba(0, 0, 255, 1.00)
...

After some iterations, I came up with:

color = portrayal.get("color", None)
if color:
  rgba = to_rgba(color, alpha=alpha) # first convert into rgba using matplotlib's func.
  chart = (
      alt.Chart(df)
      .mark_rect(
          color=f"rgba({int(rgba[0] * 255)}, {int(rgba[1] * 255)}, {int(rgba[2] * 255)}, {rgba[3]})"
      )
      .encode(
          x=alt.X("x:O", axis=None),
          y=alt.Y("y:O", axis=None),
          opacity=alt.Opacity(
              "value:Q", # use quantitative here because we are dealing with numerical ranges
              scale=alt.Scale(domain=[vmin, vmax], range=[0, alpha])
          ) # scale the values properly with alpha
      )
      .properties(width=base_width, height=base_height, title=layer_name)
  )
  base = (base + chart) if base is not None else chart

But I don't think the scaling with vmin and vmax are right here, please check that.

@sanika-n
Copy link
Contributor Author

Thank you so much, I look into it as soon as possible, hopefully by the end of this week

@sanika-n
Copy link
Contributor Author

@Sahil-Chhoker, thank you so much for the review, just made all the corrections.... pls let me know if anything more is needed


chart = chart + agent_chart
return chart
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, chart does not exist here, you will have to declare a chart variable above to use it here. The chart declared inside the if statement get destroyed with it.

Copy link
Contributor Author

@sanika-n sanika-n Feb 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could be wrong here but from what I know, in languages like C++ a variable defined inside a loop only exists within that loop’s scope but in python I am fairly sure that variables defined in loops remain accessible outside the loop and since I am defining chart both in the if and else part of the loop, it is definitely going to be defined by the time we reach the return line.
image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While you're right that Python variables from if/else blocks remain accessible afterward, it's safer to initialize chart at the function level first. This ensures it's always defined regardless of execution path. Could you update your code to follow this pattern? It prevents potential undefined variable issues if your conditions change later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okk, that makes sense, will change it 👍

@Sahil-Chhoker
Copy link
Collaborator

You have done a good job @sanika-n, I will do a thorough review this weekend once again.

@Sahil-Chhoker
Copy link
Collaborator

Sahil-Chhoker commented Feb 28, 2025

@quaquel I have some concerns regarding this PR, it would be helpful if they get resolved before taking this PR forward:

  1. tab:<color> does not work in Altair, so should the agent_portrayal be redefined if using Altair?.
  2. Same could be the case for property layer (if not using to_rgba method).
  3. The elegant implementation of the Altair property layer can be done by using matplotlib's method to_rgba but is it okay to use it since we are using Altair?

Comment on lines 269 to 295
if "color" in portrayal:
# any value less than vmin will be mapped to the color corresponding to vmin
# any value more than vmax will be mapped to the color corresponding to vmax
def apply_rgba(val, vmin=vmin, vmax=vmax, alpha=alpha, portrayal=portrayal):
a = (val - vmin) / (vmax - vmin)
a = max(0, min(a, 1)) # to ensure that a is between 0 and 1
a *= alpha # vmax will have an opacity corresponding to alpha
rgb_color = to_rgb(portrayal["color"])
r = int(rgb_color[0] * 255)
g = int(rgb_color[1] * 255)
b = int(rgb_color[2] * 255)

return f"rgba({r}, {g}, {b}, {a:.2f})"

df["color"] = df["value"].apply(apply_rgba)

chart = (
alt.Chart(df)
.mark_rect()
.encode(
x=alt.X("x:O", axis=None),
y=alt.Y("y:O", axis=None),
fill=alt.Fill("color:N", scale=None),
)
.properties(width=chart_width, height=chart_height, title=layer_name)
)
base = (base + chart) if base is not None else chart
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are overcomplicating it a bit, this should work:

Suggested change
if "color" in portrayal:
# any value less than vmin will be mapped to the color corresponding to vmin
# any value more than vmax will be mapped to the color corresponding to vmax
def apply_rgba(val, vmin=vmin, vmax=vmax, alpha=alpha, portrayal=portrayal):
a = (val - vmin) / (vmax - vmin)
a = max(0, min(a, 1)) # to ensure that a is between 0 and 1
a *= alpha # vmax will have an opacity corresponding to alpha
rgb_color = to_rgb(portrayal["color"])
r = int(rgb_color[0] * 255)
g = int(rgb_color[1] * 255)
b = int(rgb_color[2] * 255)
return f"rgba({r}, {g}, {b}, {a:.2f})"
df["color"] = df["value"].apply(apply_rgba)
chart = (
alt.Chart(df)
.mark_rect()
.encode(
x=alt.X("x:O", axis=None),
y=alt.Y("y:O", axis=None),
fill=alt.Fill("color:N", scale=None),
)
.properties(width=chart_width, height=chart_height, title=layer_name)
)
base = (base + chart) if base is not None else chart
import matplotlib.colors as mcolors
if "color" in portrayal:
color = portrayal["color"]
# Convert the user color + alpha to RGBA
rgba = mcolors.to_rgba(color, alpha=alpha)
layer_chart = (
alt.Chart(df)
.mark_rect(
color=f"rgba({int(rgba[0] * 255)}, "
f"{int(rgba[1] * 255)}, "
f"{int(rgba[2] * 255)}, "
f"{rgba[3]})"
)
.encode(
x=alt.X("x:O", axis=None),
y=alt.Y("y:O", axis=None),
opacity=alt.Opacity(
"value:Q",
scale=alt.Scale(domain=[vmin, vmax], range=[0, 1])
)
)
.properties(width=chart_width, height=chart_height, title=layer_name)
)

Colorbar should be implemented in altair as well, though I am also not very sure how will that work, I have been trying for the past hour to get it right, but either its overcomplicating the code or its not working.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually tried this out, but I am not able to create a color bar when I am using the code you have suggested as I don't know the mapping of value to color when the Scale function is used and the inbuilt altair legend is discrete and not continous

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but the matplotlib implementation of just the colorbar is also not right.

@Sahil-Chhoker
Copy link
Collaborator

Ok so I have spent more time in this than I would like to admit, but finally got the proper version working, here it is:

Some things to note:

  1. It changes the chart_property_layers function definition to include agent_chart in it.
  2. This creates a separate chart for the colorbar manually defining the gradient along with ticks and text.

While @sanika-n did most of the work, the colorbar implementation still needs work (I have provided the full code to avoid confusion).
@quaquel I would love to get your opinion here because it changes the API design from that of matplotlib's side, but I see no other way.

My implementation
def _draw_grid(space, agent_portrayal, propertylayer_portrayal):
    ...
    agent_chart = (
        alt.Chart(
            alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict)
        )
        .mark_point(filled=True)
        .properties(width=300, height=300)
    )
    base_chart = None

    # This is the default value for the marker size, which auto-scales according to the grid area.
    if not has_size:
        length = min(space.width, space.height)
        agent_chart = agent_chart.mark_point(size=30000 / length**2, filled=True)

    if propertylayer_portrayal is not None:
        chart_width = agent_chart.properties().width
        chart_height = agent_chart.properties().height
        base_chart = chart_property_layers(
            space=space,
            propertylayer_portrayal=propertylayer_portrayal,
            chart_width=chart_width,
            chart_height=chart_height,
            agent_chart = agent_chart <---
        )
    else:
        base_chart = agent_chart
    return base_chart
    
def chart_property_layers(space, propertylayer_portrayal, chart_width, chart_height, agent_chart):
    """Creates Property Layers in the Altair Components.

    Args:
        space: the ContinuousSpace instance
        propertylayer_portrayal:Dictionary of PropertyLayer portrayal specifications
        chart_width: width of the agent chart to maintain consistency with the property charts
        chart_height: height of the agent chart to maintain consistency with the property charts
        agent_chart: the agent chart to layer with the property layers on the grid
    Returns:
        Altair Chart
    """
    try:
        # old style spaces
        property_layers = space.properties
    except AttributeError:
        # new style spaces
        property_layers = space._mesa_property_layers
    base = agent_chart
    for layer_name, portrayal in propertylayer_portrayal.items():
        layer = property_layers.get(layer_name, None)
        if not isinstance(
            layer,
            PropertyLayer | mesa.discrete_space.property_layer.PropertyLayer,
        ):
            continue

        data = layer.data.astype(float) if layer.data.dtype == bool else layer.data

        if (space.width, space.height) != data.shape:
            warnings.warn(
                f"Layer {layer_name} dimensions ({data.shape}) do not match space dimensions ({space.width}, {space.height}).",
                UserWarning,
                stacklevel=2,
            )
        alpha = portrayal.get("alpha", 1)
        vmin = portrayal.get("vmin", np.min(data))
        vmax = portrayal.get("vmax", np.max(data))
        colorbar = portrayal.get("colorbar", True)

        # Prepare data for Altair (convert 2D array to a long-form DataFrame)
        df = pd.DataFrame(
            {
                "x": np.repeat(np.arange(data.shape[0]), data.shape[1]),
                "y": np.tile(np.arange(data.shape[1]), data.shape[0]),
                "value": data.flatten(),
            }
        )

        if "color" in portrayal:
            # Create a function to map values to RGBA colors with proper opacity scaling
            def apply_rgba(val):
                """
                Maps data values to RGBA colors with opacity based on value magnitude.
                
                Args:
                    val: The data value to convert
                    
                Returns:
                    String representation of RGBA color
                """
                # Normalize value to range [0,1] and clamp
                normalized = max(0, min((val - vmin) / (vmax - vmin), 1))
                
                # Scale opacity by alpha parameter
                opacity = normalized * alpha
                
                # Convert color to RGB components
                rgb_color = to_rgb(portrayal["color"])
                r = int(rgb_color[0] * 255)
                g = int(rgb_color[1] * 255)
                b = int(rgb_color[2] * 255)

                return f"rgba({r}, {g}, {b}, {opacity:.2f})"

            # Apply color mapping to each value in the dataset
            df["color"] = df["value"].apply(apply_rgba)

            # Create chart for the property layer
            chart = (
                alt.Chart(df)
                .mark_rect()
                .encode(
                    x=alt.X("x:O", axis=None),
                    y=alt.Y("y:O", axis=None),
                    fill=alt.Fill("color:N", scale=None),
                )
                .properties(width=chart_width, height=chart_height, title=layer_name)
            )
            base = alt.layer(chart, base) if base is not None else chart

            # Add colorbar if specified in portrayal
            if colorbar:
                # Extract RGB components from base color
                rgb_color = to_rgb(portrayal["color"])
                r_int = int(rgb_color[0] * 255)
                g_int = int(rgb_color[1] * 255)
                b_int = int(rgb_color[2] * 255)
                
                # Define gradient endpoints
                min_color = f"rgba({r_int},{g_int},{b_int},0)"
                max_color = f"rgba({r_int},{g_int},{b_int},{alpha:.2f})"
                
                # Define colorbar dimensions
                colorbar_height = 20
                colorbar_width = chart_width
                
                # Create dataframe for gradient visualization
                df_gradient = pd.DataFrame({'x': [0, 1], 'y': [0, 1]})
                
                # Create evenly distributed tick values
                axis_values = np.linspace(vmin, vmax, 11)
                tick_positions = np.linspace(0, colorbar_width, 11)
                
                # Prepare data for axis and labels
                axis_data = pd.DataFrame({
                    'value': axis_values,
                    'x': tick_positions
                })
                
                # Create colorbar with linear gradient
                colorbar_chart = alt.Chart(df_gradient).mark_rect(
                    x=0, y=0,
                    width=colorbar_width, height=colorbar_height,
                    color=alt.Gradient(
                        gradient='linear',
                        stops=[
                            alt.GradientStop(color=min_color, offset=0),
                            alt.GradientStop(color=max_color, offset=1)
                        ],
                        x1=0, x2=1,  # Horizontal gradient
                        y1=0, y2=0   # Keep y constant
                    )
                ).encode(
                    x=alt.value(chart_width / 2), # Center colorbar
                    y=alt.value(0)
                ).properties(
                    width=colorbar_width,
                    height=colorbar_height
                )
                
                # Add tick marks to colorbar
                axis_chart = alt.Chart(axis_data).mark_tick(
                    thickness=2,
                    size=8
                ).encode(
                    x=alt.X('x:Q', axis=None),
                    y=alt.value(colorbar_height - 2)
                )
                
                # Add value labels below tick marks
                text_labels = alt.Chart(axis_data).mark_text(
                    baseline='top',
                    fontSize=10,
                    dy=0
                ).encode(
                    x=alt.X('x:Q'),
                    text=alt.Text('value:Q', format='.1f'),
                    y=alt.value(colorbar_height + 10)
                )
                
                # Add title to colorbar
                title = alt.Chart(pd.DataFrame([{'text': layer_name}])).mark_text(
                    fontSize=12,
                    fontWeight='bold',
                    baseline='bottom',
                    align='center'
                ).encode(
                    text='text:N',
                    x=alt.value(colorbar_width / 2),
                    y=alt.value(colorbar_height + 40)
                )
                
                # Combine all colorbar components
                combined_colorbar = alt.layer(
                    colorbar_chart,
                    axis_chart,
                    text_labels,
                    title
                ).properties(
                    width=colorbar_width,
                    height=colorbar_height + 50
                )
                
                # Stack main visualization and colorbar vertically
                base = alt.vconcat(
                    base,
                    combined_colorbar,
                    spacing=20
                ).resolve_scale(
                    color='independent'
                ).configure_view(
                    stroke=None # Remove border around colorbar
                )
                
        elif "colormap" in portrayal:
            cmap = portrayal.get("colormap", "viridis")
            cmap_scale = alt.Scale(scheme=cmap, domain=[vmin, vmax])

            chart = (
                alt.Chart(df)
                .mark_rect(opacity=alpha)
                .encode(
                    x=alt.X("x:O", axis=None),
                    y=alt.Y("y:O", axis=None),
                    color=alt.Color(
                        "value:Q",
                        scale=cmap_scale,
                        title=layer_name,
                        legend=alt.Legend(title=layer_name) if colorbar else None,
                    ),
                )
                .properties(width=chart_width, height=chart_height)
            )
            base = alt.layer(chart, base) if base is not None else chart

        else:
            raise ValueError(
                f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
            )
    return base

How it looks:
image

@Sahil-Chhoker
Copy link
Collaborator

Hey @sanika-n, Jan seems very busy lately so I think I will have to judge the PR, I would love his opinion on this as well when it is completed but that can be done later. In the mean time can you incorporate the code given in above comment and check if there is a way of achieving the same thing without passing agent_chart in chart_property_layers.

@sanika-n
Copy link
Contributor Author

sanika-n commented Mar 2, 2025

Sure, have a bit going on today, will finish it by tomorrow

@sanika-n
Copy link
Contributor Author

sanika-n commented Mar 3, 2025

@Sahil-Chhoker sorry for the delay, just updated the file to include your suggestion, also changed the function definition of apply_rgba so that it would pass the ruff format...

@Sahil-Chhoker
Copy link
Collaborator

Great work @sanika-n!, are tests running fine?

@sanika-n
Copy link
Contributor Author

sanika-n commented Mar 3, 2025

Yes :)
Also, the only way I could think of, of not passing agent_chart is by altering this part of the code

# Stack main visualization and colorbar vertically
                 base = (
                     alt.vconcat(base, combined_colorbar, spacing=20)
                     .resolve_scale(color="independent")
                    .configure_view(
                          stroke=None  # Remove border around colorbar
                    )
                 )

Is there any way to not attach the color bar directly to the main chart and like to keep it as a separate entity?(cuz the fact that this code used concat and configure on base, were causing problems) If we can do that then I think we won't have to pass agent_chart as in the code below:

full code
"""Altair based solara components for visualization mesa spaces."""

import warnings

import altair as alt
import numpy as np
import pandas as pd
import solara
from matplotlib.colors import to_rgb

import mesa
from mesa.discrete_space import DiscreteSpace, Grid
from mesa.space import ContinuousSpace, PropertyLayer, _Grid
from mesa.visualization.utils import update_counter


def make_space_altair(*args, **kwargs):  # noqa: D103
    warnings.warn(
        "make_space_altair has been renamed to make_altair_space",
        DeprecationWarning,
        stacklevel=2,
    )
    return make_altair_space(*args, **kwargs)


def make_altair_space(
    agent_portrayal, propertylayer_portrayal, post_process, **space_drawing_kwargs
):
    """Create an Altair-based space visualization component.

    Args:
        agent_portrayal: Function to portray agents.
        propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications
        post_process :A user specified callable that will be called with the Chart instance from Altair. Allows for fine tuning plots (e.g., control ticks)
        space_drawing_kwargs : not yet implemented

    ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
    "size", "marker", and "zorder". Other field are ignored and will result in a user warning.


    Returns:
        function: A function that creates a SpaceMatplotlib component
    """
    if agent_portrayal is None:

        def agent_portrayal(a):
            return {"id": a.unique_id}

    def MakeSpaceAltair(model):
        return SpaceAltair(
            model, agent_portrayal, propertylayer_portrayal, post_process=post_process
        )

    return MakeSpaceAltair


@solara.component
def SpaceAltair(
    model,
    agent_portrayal,
    propertylayer_portrayal,
    dependencies: list[any] | None = None,
    post_process=None,
):
    """Create an Altair-based space visualization component.

    Returns:
        a solara FigureAltair instance
    """
    update_counter.get()
    space = getattr(model, "grid", None)
    if space is None:
        # Sometimes the space is defined as model.space instead of model.grid
        space = model.space

    chart = _draw_grid(space, agent_portrayal, propertylayer_portrayal)
    # Apply post-processing if provided
    if post_process is not None:
        chart = post_process(chart)

    solara.FigureAltair(chart)


def _get_agent_data_old__discrete_space(space, agent_portrayal):
    """Format agent portrayal data for old-style discrete spaces.

    Args:
        space: the mesa.space._Grid instance
        agent_portrayal: the agent portrayal callable

    Returns:
        list of dicts

    """
    all_agent_data = []
    for content, (x, y) in space.coord_iter():
        if not content:
            continue
        if not hasattr(content, "__iter__"):
            # Is a single grid
            content = [content]  # noqa: PLW2901
        for agent in content:
            # use all data from agent portrayal, and add x,y coordinates
            agent_data = agent_portrayal(agent)
            agent_data["x"] = x
            agent_data["y"] = y
            all_agent_data.append(agent_data)
    return all_agent_data


def _get_agent_data_new_discrete_space(space: DiscreteSpace, agent_portrayal):
    """Format agent portrayal data for new-style discrete spaces.

    Args:
        space: the mesa.experiment.cell_space.Grid instance
        agent_portrayal: the agent portrayal callable

    Returns:
        list of dicts

    """
    all_agent_data = []

    for cell in space.all_cells:
        for agent in cell.agents:
            agent_data = agent_portrayal(agent)
            agent_data["x"] = cell.coordinate[0]
            agent_data["y"] = cell.coordinate[1]
            all_agent_data.append(agent_data)
    return all_agent_data


def _get_agent_data_continuous_space(space: ContinuousSpace, agent_portrayal):
    """Format agent portrayal data for continuous space.

    Args:
        space: the ContinuousSpace instance
        agent_portrayal: the agent portrayal callable

    Returns:
        list of dicts
    """
    all_agent_data = []
    for agent in space._agent_to_index:
        agent_data = agent_portrayal(agent)
        agent_data["x"] = agent.pos[0]
        agent_data["y"] = agent.pos[1]
        all_agent_data.append(agent_data)
    return all_agent_data


def _draw_grid(space, agent_portrayal, propertylayer_portrayal):
    match space:
        case Grid():
            all_agent_data = _get_agent_data_new_discrete_space(space, agent_portrayal)
        case _Grid():
            all_agent_data = _get_agent_data_old__discrete_space(space, agent_portrayal)
        case ContinuousSpace():
            all_agent_data = _get_agent_data_continuous_space(space, agent_portrayal)
        case _:
            raise NotImplementedError(
                f"visualizing {type(space)} is currently not supported through altair"
            )

    invalid_tooltips = ["color", "size", "x", "y"]

    x_y_type = "ordinal" if not isinstance(space, ContinuousSpace) else "nominal"

    encoding_dict = {
        # no x-axis label
        "x": alt.X("x", axis=None, type=x_y_type),
        # no y-axis label
        "y": alt.Y("y", axis=None, type=x_y_type),
        "tooltip": [
            alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value]))
            for key, value in all_agent_data[0].items()
            if key not in invalid_tooltips
        ],
    }
    has_color = "color" in all_agent_data[0]
    if has_color:
        unique_colors = list({agent["color"] for agent in all_agent_data})
        encoding_dict["color"] = alt.Color(
            "color:N",
            scale=alt.Scale(domain=unique_colors, range=unique_colors),
        )
    has_size = "size" in all_agent_data[0]
    if has_size:
        encoding_dict["size"] = alt.Size("size", type="quantitative")

    agent_chart = (
        alt.Chart(
            alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict)
        )
        .mark_point(filled=True)
        .properties(width=300, height=300)
    )
    base_chart = None

    # This is the default value for the marker size, which auto-scales according to the grid area.
    if not has_size:
        length = min(space.width, space.height)
        agent_chart = agent_chart.mark_point(size=30000 / length**2, filled=True)

    if propertylayer_portrayal is not None:
        chart_width = agent_chart.properties().width
        chart_height = agent_chart.properties().height
        base_chart = chart_property_layers(
            space=space,
            propertylayer_portrayal=propertylayer_portrayal,
            chart_width=chart_width,
            chart_height=chart_height
        )
        
        base_chart=alt.layer(base_chart, agent_chart)
    else:
        base_chart = agent_chart
    return base_chart


def chart_property_layers(
    space, propertylayer_portrayal, chart_width, chart_height
):
    """Creates Property Layers in the Altair Components.

    Args:
        space: the ContinuousSpace instance
        propertylayer_portrayal:Dictionary of PropertyLayer portrayal specifications
        chart_width: width of the agent chart to maintain consistency with the property charts
        chart_height: height of the agent chart to maintain consistency with the property charts
        agent_chart: the agent chart to layer with the property layers on the grid
    Returns:
        Altair Chart
    """
    try:
        # old style spaces
        property_layers = space.properties
    except AttributeError:
        # new style spaces
        property_layers = space._mesa_property_layers
    base = None
    for layer_name, portrayal in propertylayer_portrayal.items():
        layer = property_layers.get(layer_name, None)
        if not isinstance(
            layer,
            PropertyLayer | mesa.discrete_space.property_layer.PropertyLayer,
        ):
            continue

        data = layer.data.astype(float) if layer.data.dtype == bool else layer.data

        if (space.width, space.height) != data.shape:
            warnings.warn(
                f"Layer {layer_name} dimensions ({data.shape}) do not match space dimensions ({space.width}, {space.height}).",
                UserWarning,
                stacklevel=2,
            )
        alpha = portrayal.get("alpha", 1)
        vmin = portrayal.get("vmin", np.min(data))
        vmax = portrayal.get("vmax", np.max(data))
        colorbar = portrayal.get("colorbar", True)

        # Prepare data for Altair (convert 2D array to a long-form DataFrame)
        df = pd.DataFrame(
            {
                "x": np.repeat(np.arange(data.shape[0]), data.shape[1]),
                "y": np.tile(np.arange(data.shape[1]), data.shape[0]),
                "value": data.flatten(),
            }
        )

        if "color" in portrayal:
            # Create a function to map values to RGBA colors with proper opacity scaling
            def apply_rgba(val, vmin=vmin, vmax=vmax, alpha=alpha, portrayal=portrayal):
                """Maps data values to RGBA colors with opacity based on value magnitude.

                Args:
                    val: The data value to convert
                    vmin: The smallest value for which the color is displayed in the colorbar
                    vmax: The largest value for which the color is displayed in the colorbar
                    alpha: The opacity of the color
                    portrayal: The specifics of the current property layer in the iterative loop

                Returns:
                    String representation of RGBA color
                """
                # Normalize value to range [0,1] and clamp
                normalized = max(0, min((val - vmin) / (vmax - vmin), 1))

                # Scale opacity by alpha parameter
                opacity = normalized * alpha

                # Convert color to RGB components
                rgb_color = to_rgb(portrayal["color"])
                r = int(rgb_color[0] * 255)
                g = int(rgb_color[1] * 255)
                b = int(rgb_color[2] * 255)

                return f"rgba({r}, {g}, {b}, {opacity:.2f})"

            # Apply color mapping to each value in the dataset
            df["color"] = df["value"].apply(apply_rgba)

            # Create chart for the property layer
            chart = (
                alt.Chart(df)
                .mark_rect()
                .encode(
                    x=alt.X("x:O", axis=None),
                    y=alt.Y("y:O", axis=None),
                    fill=alt.Fill("color:N", scale=None),
                )
                .properties(width=chart_width, height=chart_height, title=layer_name)
            )
            base = alt.layer(chart, base) if base is not None else chart

            # Add colorbar if specified in portrayal
            if colorbar:
                # Extract RGB components from base color
                rgb_color = to_rgb(portrayal["color"])
                r_int = int(rgb_color[0] * 255)
                g_int = int(rgb_color[1] * 255)
                b_int = int(rgb_color[2] * 255)

                # Define gradient endpoints
                min_color = f"rgba({r_int},{g_int},{b_int},0)"
                max_color = f"rgba({r_int},{g_int},{b_int},{alpha:.2f})"

                # Define colorbar dimensions
                colorbar_height = 20
                colorbar_width = chart_width

                # Create dataframe for gradient visualization
                df_gradient = pd.DataFrame({"x": [0, 1], "y": [0, 1]})

                # Create evenly distributed tick values
                axis_values = np.linspace(vmin, vmax, 11)
                tick_positions = np.linspace(0, colorbar_width, 11)

                # Prepare data for axis and labels
                axis_data = pd.DataFrame({"value": axis_values, "x": tick_positions})

                # Create colorbar with linear gradient
                colorbar_chart = (
                    alt.Chart(df_gradient)
                    .mark_rect(
                        x=0,
                        y=0,
                        width=colorbar_width,
                        height=colorbar_height,
                        color=alt.Gradient(
                            gradient="linear",
                            stops=[
                                alt.GradientStop(color=min_color, offset=0),
                                alt.GradientStop(color=max_color, offset=1),
                            ],
                            x1=0,
                            x2=1,  # Horizontal gradient
                            y1=0,
                            y2=0,  # Keep y constant
                        ),
                    )
                    .encode(
                        x=alt.value(chart_width / 2),  # Center colorbar
                        y=alt.value(0),
                    )
                    .properties(width=colorbar_width, height=colorbar_height)
                )

                # Add tick marks to colorbar
                axis_chart = (
                    alt.Chart(axis_data)
                    .mark_tick(thickness=2, size=8)
                    .encode(x=alt.X("x:Q", axis=None), y=alt.value(colorbar_height - 2))
                )

                # Add value labels below tick marks
                text_labels = (
                    alt.Chart(axis_data)
                    .mark_text(baseline="top", fontSize=10, dy=0)
                    .encode(
                        x=alt.X("x:Q"),
                        text=alt.Text("value:Q", format=".1f"),
                        y=alt.value(colorbar_height + 10),
                    )
                )

                # Add title to colorbar
                title = (
                    alt.Chart(pd.DataFrame([{"text": layer_name}]))
                    .mark_text(
                        fontSize=12,
                        fontWeight="bold",
                        baseline="bottom",
                        align="center",
                    )
                    .encode(
                        text="text:N",
                        x=alt.value(colorbar_width / 2),
                        y=alt.value(colorbar_height + 40),
                    )
                )

                # Combine all colorbar components
                combined_colorbar = alt.layer(
                    colorbar_chart, axis_chart, text_labels, title
                ).properties(width=colorbar_width, height=colorbar_height + 50)

                # Stack main visualization and colorbar vertically
                # base = (
                #     alt.vconcat(base, combined_colorbar, spacing=20)
                #     .resolve_scale(color="independent")
                #     .configure_view(
                #          stroke=None  # Remove border around colorbar
                #     )
                # )

        elif "colormap" in portrayal:
            cmap = portrayal.get("colormap", "viridis")
            cmap_scale = alt.Scale(scheme=cmap, domain=[vmin, vmax])

            chart = (
                alt.Chart(df)
                .mark_rect(opacity=alpha)
                .encode(
                    x=alt.X("x:O", axis=None),
                    y=alt.Y("y:O", axis=None),
                    color=alt.Color(
                        "value:Q",
                        scale=cmap_scale,
                        title=layer_name,
                        legend=alt.Legend(title=layer_name) if colorbar else None,
                    ),
                )
                .properties(width=chart_width, height=chart_height)
            )
            base = alt.layer(chart, base) if base is not None else chart

        else:
            raise ValueError(
                f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
            )
    return base

@Sahil-Chhoker
Copy link
Collaborator

Yeah exactly, maybe we can keep the chart of colorbar separate and layer it with the base graph at the end of the chart_property_layer and then agent_chart can also be layered? I am not exactly sure it will work, can you test this?

@sanika-n
Copy link
Contributor Author

sanika-n commented Mar 3, 2025

I think it's done... This is how it looks(pretty much the same)
image

@Sahil-Chhoker
Copy link
Collaborator

Great job @sanika-n! Thanks for this, but it looks like you have also pushed your example model in this PR, can you remove that please.

@sanika-n
Copy link
Contributor Author

sanika-n commented Mar 3, 2025

Sorry about that and thank you!

@Sahil-Chhoker
Copy link
Collaborator

Great work! This PR looks good to me, as everything works as expected. However, I'm not sure about the implementation and best practices. @quaquel, it would be very helpful if you could give the code a quick review!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants