Lightweight mmm - Lightweight (Bayesian) Media Mix Model

Overview

This is not an official Google product.

LightweightMMM 🦇 is a lightweight Bayesian media mix modeling library that allows users to easily train MMMs and obtain channel attribution information. The library also includes capabilities for optimizing media allocation as well as plotting common graphs in the field.

It is built in Python 3 and makes use of NumPyro and JAX.

What you can do with LightweightMMM

  • Scale your data for training.
  • Easily train your media mix model.
  • Evaluate your model.
  • Learn about your media attribution and ROI per media channel.
  • Optimize your budget allocation.

Installation

The recommended way of installing lightweight_mmm is through PyPI:

pip install lightweight_mmm

If you want to use the most recent and slightly less stable version, you can install it from GitHub:

pip install --upgrade git+https://github.com/google/lightweight_mmm.git

The models

For larger countries we recommend a geo-based model, but this is not implemented yet.

We estimate a national weekly model where we use sales revenue (y) as the KPI.

$$\mu_t = a + trend_t + seasonality_t + \beta_m sat(lag(X_{mt}, \phi_m), \theta_m) + \beta_o X_{ot}$$

$$y_t \sim N(\mu_t, \sigma^2)$$

$$\sigma \sim \Gamma(1, 1)$$

$$\beta_m \sim N^+(0, \sigma_m^2)$$

$$X_m$$ is a media matrix and $$X_o$$ is a matrix of other exogenous variables.

Seasonality is a latent sinusoidal parameter with a repeating pattern.
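
As a rough illustration of a sinusoidal seasonality term with a repeating pattern, the sketch below sums sine and cosine pairs over a yearly cycle of 52 weeks. The function name, the coefficient values and the number of degrees are assumptions made for this example; in the model the coefficients would be latent parameters estimated during sampling, and this is not the library's own implementation.

```python
import math


def fourier_seasonality(n_periods, coefficients, frequency=52):
    """Sum of sine/cosine pairs, one (sin_weight, cos_weight) pair per degree.

    All names here are hypothetical; the weights stand in for latent
    parameters that a Bayesian model would estimate.
    """
    seasonality = []
    for t in range(n_periods):
        value = 0.0
        for degree, (sin_weight, cos_weight) in enumerate(coefficients, start=1):
            angle = 2.0 * math.pi * degree * t / frequency
            value += sin_weight * math.sin(angle) + cos_weight * math.cos(angle)
        seasonality.append(value)
    return seasonality


# Two degrees of seasonality over two years of weekly data.
weekly = fourier_seasonality(n_periods=104, coefficients=[(0.5, 0.2), (0.1, -0.3)])
```

Because every term has a period that divides 52, the pattern in the second year repeats the first.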

Media parameter $$\beta_m$$ is informed by costs. It uses a HalfNormal distribution and the scale of the distribution is the total cost of each media channel.

$$sat()$$ is a saturation function and $$lag()$$ is a lagging function, e.g. Adstock. They have parameters $$\theta$$ and $$\phi$$ respectively.

We have three different versions of the MMM with different lagging and saturation functions, and we recommend you compare all three models. The Adstock and carryover models have an exponent for diminishing returns. In the Hill-Adstock model, the Hill function covers that functionality.

  • Adstock: Applies an infinite lag that decreases its weight as time passes.
  • Hill-Adstock: Applies a sigmoid-like function for diminishing returns to the output of the adstock function.
  • Carryover: Applies a causal convolution giving more weight to the near values than distant ones.
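
The three transformations can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the library's code: the function names, parameter names and default values are all chosen for this example only.

```python
def adstock(media, lag_weight=0.9, normalise=True):
    """Infinite lag: each period keeps lag_weight of the previous adstocked value."""
    carried = 0.0
    adstocked = []
    for spend in media:
        carried = spend + lag_weight * carried
        adstocked.append(carried)
    if normalise:
        # Multiplying by (1 - lag_weight) undoes the geometric sum 1 / (1 - lag_weight).
        adstocked = [value * (1.0 - lag_weight) for value in adstocked]
    return adstocked


def diminishing_returns(values, exponent=0.7):
    """Exponent used for saturation in the Adstock and carryover variants."""
    return [value ** exponent for value in values]


def hill(values, half_saturation=1.0, slope=1.0):
    """Sigmoid-like saturation: x^s / (x^s + k^s)."""
    return [v ** slope / (v ** slope + half_saturation ** slope) for v in values]


def carryover(media, retention_rate=0.5, peak_delay=1, number_lags=4):
    """Causal convolution whose weights peak at peak_delay and decay with distance."""
    weights = [retention_rate ** ((lag - peak_delay) ** 2) for lag in range(number_lags)]
    total_weight = sum(weights)
    carried = []
    for t in range(len(media)):
        # Only current and past periods contribute, which keeps the convolution causal.
        window = sum(w * media[t - lag] for lag, w in enumerate(weights) if t - lag >= 0)
        carried.append(window / total_weight)
    return carried


impulse = [1.0, 0.0, 0.0, 0.0, 0.0]
adstocked = adstock(impulse, lag_weight=0.5, normalise=False)
hill_adstocked = hill(adstocked)
carried_over = carryover(impulse)
```

Feeding a single impulse through each function shows the difference: adstock decays from the first period onward, while the carryover weights here place the largest effect one period after the spend.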

Scaling

Scaling is a bit of an art. Bayesian techniques work well when the input data is on a small scale. Variables should not be centered at 0, and sales and media should keep a lower bound of 0.

  1. y can be scaled as $$y / mean_y$$.
  2. media can be scaled as $$X_m / mean_X$$, which means the new column mean will be 1.
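
A small worked example of mean scaling, using made-up weekly sales figures (the numbers and variable names are assumptions for illustration only):

```python
# Hypothetical weekly sales; any non-negative series works the same way.
weekly_sales = [120.0, 80.0, 100.0, 100.0]

mean_sales = sum(weekly_sales) / len(weekly_sales)

# Dividing by the mean keeps the lower bound at 0 and moves the column mean to 1.
scaled_sales = [value / mean_sales for value in weekly_sales]

# Multiplying by the same mean recovers the original units.
restored_sales = [value * mean_sales for value in scaled_sales]
```

Unlike standardization, this keeps zero spend at exactly zero, which matches the lower bound described above.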

Optimization

For optimization we maximize sales by changing the media inputs such that the summed cost of the media stays constant. We can also allow reasonable bounds on each media input (e.g. ±x%). We only optimize across channels and not over time.
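
To make the constraint concrete, here is a toy sketch for two channels: total spend is held constant, each channel may move at most ±20% from its previous spend, and a simple grid search picks the best split. The square-root response function, the effectiveness values and the grid search itself are assumptions for illustration; they are not the library's optimizer or its fitted model.

```python
import math


def toy_sales(spend, effectiveness):
    """Hypothetical concave response: diminishing returns via a square root."""
    return sum(e * math.sqrt(s) for s, e in zip(spend, effectiveness))


def optimise_two_channels(previous_spend, effectiveness, bounds_pct=0.2, steps=400):
    """Grid search over splits that keep the total constant and respect the bounds."""
    total = sum(previous_spend)
    lower = [s * (1 - bounds_pct) for s in previous_spend]
    upper = [s * (1 + bounds_pct) for s in previous_spend]
    best_spend = list(previous_spend)
    best_sales = toy_sales(previous_spend, effectiveness)
    for i in range(steps + 1):
        first = lower[0] + (upper[0] - lower[0]) * i / steps
        second = total - first  # the summed cost stays constant
        if not (lower[1] - 1e-9 <= second <= upper[1] + 1e-9):
            continue
        sales = toy_sales([first, second], effectiveness)
        if sales > best_sales:
            best_spend, best_sales = [first, second], sales
    return best_spend, best_sales


previous = [50.0, 50.0]
spend, sales = optimise_two_channels(previous, effectiveness=[2.0, 1.0])
```

In this toy setup the unconstrained optimum would move far more budget to the first channel, so the search stops at the upper bound of the first channel.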

Getting started

Preparing the data

Here we use simulated data, but it is assumed you have your data cleaned at this point. The necessary data will be:

  • Media data: Containing the metric per channel and time span (e.g. impressions per week). Media values must not contain negative values.
  • Extra features: Any other features that one might want to add to the analysis. These features need to be known ahead of time for optimization or you would need another model to estimate them.
  • Target: Target KPI for the model to predict. This will also be the metric optimized during the optimization phase.
  • Costs: The average cost per media unit per channel.
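
# Imports used by the examples in this section.
import jax.numpy as jnp
import numpy as np
from lightweight_mmm import data_simulation, lightweight_mmm, optimize_media, preprocessing

# Let's assume we have the following datasets with the following shapes:
data_size = 120  # also used below for the train/test split
media_data, extra_features, target, unscaled_costs, _ = data_simulation.simulate_all_data(
    data_size=data_size,
    n_media_channels=3,
    n_extra_features=2)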

We provide a CustomScaler which can apply multiplication and division scaling in case the more widely used scalers don't fit your use case. Scale your data accordingly before fitting the model. Below is an example of usage of this CustomScaler:

# Simple split of the data based on time.
split_point = data_size - data_size // 10
media_data_train = media_data[:split_point, :]
target_train = target[:split_point]
extra_features_train = extra_features[:split_point, :]
extra_features_test = extra_features[split_point:, :]

# Scale data
media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
extra_features_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
target_scaler = preprocessing.CustomScaler(
    divide_operation=jnp.mean)
# scale cost up by N since fit() will divide it by number of weeks
cost_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)

media_data_train = media_scaler.fit_transform(media_data_train)
extra_features_train = extra_features_scaler.fit_transform(
    extra_features_train)
target_train = target_scaler.fit_transform(target_train)
costs = cost_scaler.fit_transform(unscaled_costs)

Training the model

The model requires the media data, the extra features, the costs of each media unit per channel and the target. You can also pass how many samples you would like to use as well as the number of chains.

To run multiple chains in parallel, call numpyro.set_host_device_count with either the number of chains or the number of CPUs available.

See an example below:

# Fit model on the scaled training split prepared above.
mmm = lightweight_mmm.LightweightMMM()
mmm.fit(media=media_data_train,
        extra_features=extra_features_train,
        total_costs=costs,
        target=target_train,
        number_warmup=1000,
        number_samples=1000,
        number_chains=2)

Obtaining media effect and ROI

There are two ways of obtaining the media effect and ROI with LightweightMMM, depending on whether or not you scaled the data prior to training. If you did not scale your data you can simply call:

mmm.get_posterior_metrics()

However, if you scaled your costs, your target or both, it is important that you provide get_posterior_metrics with the necessary information to unscale the data and calculate media effect and ROI.

  • If only costs were scaled, the following two function calls are equivalent:
# Option 1
mmm.get_posterior_metrics(cost_scaler=cost_scaler)
# Option 2
mmm.get_posterior_metrics(unscaled_costs=unscaled_costs)
  • If only the target was scaled:
mmm.get_posterior_metrics(target_scaler=target_scaler)
  • If both were scaled:
mmm.get_posterior_metrics(cost_scaler=cost_scaler,
                          target_scaler=target_scaler)

Running the optimization

For running the optimization one needs the following main parameters:

  • n_time_periods: The number of time periods you want to simulate (e.g. optimize for the next 10 weeks if you trained a model on weekly data).
  • The model that was trained.
  • The budget you want to allocate for the next n_time_periods.
  • The extra features used for training for the following n_time_periods.
  • Price per media unit per channel.
  • media_gap refers to the media data gap between the end of the training data and the start of the out-of-sample media given. E.g. if 100 weeks of data were used for training and prediction starts 2 months after the training data finished, we need to provide the 8 weeks missing between the training data and the prediction data so that data transformations (adstock, carryover, ...) can take place correctly.
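
The effect of the media gap can be seen with a toy adstock calculation. The sketch below is an assumption-laden illustration (a simple geometric adstock with a made-up lag weight and constant spend), not the library's transformation: it compares the adstocked values for the prediction weeks when the 8 missing weeks are supplied against simply appending the prediction weeks to the training data.

```python
def adstock(media, lag_weight=0.8):
    """Simple geometric adstock used only for this illustration."""
    carried = 0.0
    adstocked = []
    for spend in media:
        carried = spend + lag_weight * carried
        adstocked.append(carried)
    return adstocked


training_media = [10.0] * 100  # 100 weeks used for training
media_gap = [0.0] * 8          # hypothetical: no spend in the 8 weeks in between
future_media = [10.0] * 4      # weeks we want to predict

n_future = len(future_media)
with_gap = adstock(training_media + media_gap + future_media)[-n_future:]
without_gap = adstock(training_media + future_media)[-n_future:]
```

Without the gap, the carried-over effect from the end of training is treated as if it were still fresh, so the first predicted weeks are overstated in this example.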

See below an example of optimization:

# Run media optimization.
budget = 40
prices = np.array([0.1, 0.11, 0.12])
extra_features_test = extra_features_scaler.transform(extra_features_test)
solution = optimize_media.find_optimal_budgets(
    n_time_periods=extra_features_test.shape[0],
    media_mix_model=mmm,
    budget=budget,
    extra_features=extra_features_test,
    prices=prices)

Run times

A model with 5 media variables, 1 other variable, 150 weeks, 1500 draws and 2 chains should take 7 minutes per chain to estimate on a CPU machine. This excludes compile time.

Comments
  • .predict vs .trace

    Hi Team,

    What is the difference between '.predict' and '.trace["mu"]'? Below are the plots after running both -

    1. prediction = mmm.predict(
           media=media_data,
           extra_features=extra_features,
           target_scaler=target_scaler)
       prediction_mean = prediction.mean(axis=0)

       plt.figure(figsize=(8,7))
       plt.plot(x, targ, label='Actual')  # targ is the actual target value in the data
       plt.plot(x, prediction_mean, label='Predicted')
       plt.legend()

       [Image: mmm_predict]

       np.sum(prediction_mean) = 24816.469

    2. pred = mmm.trace["mu"]
       predictions = target_scaler.inverse_transform(pred)
       pred_mean = predictions.mean(axis=0)

       plt.figure(figsize=(8,7))
       plt.plot(x, targ, label='Actual')
       plt.plot(x, pred_mean, label='Predicted')
       plt.legend()

       [Image: mmm_trace]

       np.sum(pred_mean) = 43999.086

    The 'predicted' line in both plots follows a similar trend to the 'actual' line. However, with mmm.predict() the predicted values have an offset, whereas with mmm.trace["mu"] that offset is not present and the predicted and actual lines are aligned.

    Also, the sums of the predicted values returned by mmm.predict and mmm.trace["mu"] are different. In my case, the sum of predicted values returned by mmm.trace["mu"] is close to the sum of actual target values in the data. Why is mmm.predict() not giving values close to the actual target values?

    It would be helpful to get clarity on this.

    Thank you!

    question 
    opened by sv09 14
  • find_optimal_budgets current function value returning nan

    Hi Team,

    1. find_optimal_budgets current function value is returning nan.
    2. The previous and optimal budget allocation values are always equal, however much I change the values and range.
    3. The previous budget for one of the channels is returned as 0 even though there is budget present.
    4. Media contribution is higher for channel 1, whereas ROI is higher for channel 0.

    Thanks in advance

    opened by virithavanama 13
  • Examples failed

    I installed all requirements in an env. Running the https://github.com/google/lightweight_mmm/blob/main/examples/simple_end_to_end_demo.ipynb example fails when executing mmm.fit(.....):

        177 mcmc.run(
        178     rng_key=jax.random.PRNGKey(seed),
        179     media_data=jnp.array(media),
        180     extra_features=extra_features,
        181     target_data=jnp.array(target),
        182     cost_prior=jnp.array(total_costs),
        183     degrees_seasonality=degrees_seasonality,
        184     frequency=seasonality_frequency,
        185     transform_function=self._model_transform_function,
        186     weekday_seasonality=weekday_seasonality)
        188 if media_names is not None:
        189   self.media_names = media_names

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/mcmc.py:597, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
        595 else:
        596     if self.chain_method == "sequential":
    --> 597         states, last_state = _laxmap(partial_map_fn, map_args)
        598     elif self.chain_method == "parallel":
        599         states, last_state = pmap(partial_map_fn)(map_args)

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/mcmc.py:160, in _laxmap(f, xs)
        158 for i in range(n):
        159     x = jit(_get_value_from_index)(xs, i)
    --> 160     ys.append(f(x))
        162 return tree_map(lambda *args: jnp.stack(args), *ys)

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
        379 rng_key, init_state, init_params = init
        380 if init_state is None:
    --> 381     init_state = self.sampler.init(
        382         rng_key,
        383         self.num_warmup,
        384         init_params,
        385         model_args=args,
        386         model_kwargs=kwargs,
        387     )
        388 sample_fn, postprocess_fn = self._get_cached_fns()
        389 diagnostics = (
        390     lambda x: self.sampler.get_diagnostics_str(x[0])
        391     if rng_key.ndim == 1
        392     else ""
        393 )  # noqa: E731

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/hmc.py:706, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
        701 # vectorized
        702 else:
        703     rng_key, rng_key_init_model = jnp.swapaxes(
        704         vmap(random.split)(rng_key), 0, 1
        705     )
    --> 706 init_params = self._init_state(
        707     rng_key_init_model, model_args, model_kwargs, init_params
        708 )
        709 if self._potential_fn and init_params is None:
        710     raise ValueError(
        711         "Valid value of init_params must be provided with" " potential_fn."
        712     )

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/hmc.py:652, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
        650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
        651     if self._model is not None:
    --> 652         init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
        653             rng_key,
        654             self._model,
        655             dynamic_args=True,
        656             init_strategy=self._init_strategy,
        657             model_args=model_args,
        658             model_kwargs=model_kwargs,
        659             forward_mode_differentiation=self._forward_mode_differentiation,
        660         )
        661     if self._init_fn is None:
        662         self._init_fn, self._sample_fn = hmc(
        663             potential_fn_gen=potential_fn,
        664             kinetic_fn=self._kinetic_fn,
        665             algo=self._algo,
        666         )

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:654, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
        652     init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
        653 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
    --> 654 (init_params, pe, grad), is_valid = find_valid_initial_params(
        655     rng_key,
        656     substitute(
        657         model,
        658         data={
        659             k: site["value"]
        660             for k, site in model_trace.items()
        661             if site["type"] in ["plate"]
        662         },
        663     ),
        664     init_strategy=init_strategy,
        665     enum=has_enumerate_support,
        666     model_args=model_args,
        667     model_kwargs=model_kwargs,
        668     prototype_params=prototype_params,
        669     forward_mode_differentiation=forward_mode_differentiation,
        670     validate_grad=validate_grad,
        671 )
        673 if not_jax_tracer(is_valid):
        674     if device_get(~jnp.all(is_valid)):

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:395, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
        393 # Handle possible vectorization
        394 if rng_key.ndim == 1:
    --> 395     (init_params, pe, z_grad), is_valid = _find_valid_params(
        396         rng_key, exit_early=True
        397     )
        398 else:
        399     (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:381, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
        377 init_state = (0, rng_key, (prototype_params, 0.0, prototype_params), False)
        378 if exit_early and not_jax_tracer(rng_key):
        379     # Early return if valid params found. This is only helpful for single chain,
        380     # where we can avoid compiling body_fn in while_loop.
    --> 381     _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
        382     if not_jax_tracer(is_valid):
        383         if device_get(is_valid):

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:366, in find_valid_initial_params.<locals>.body_fn(state)
        364     z_grad = jacfwd(potential_fn)(params)
        365 else:
    --> 366     pe, z_grad = value_and_grad(potential_fn)(params)
        367 z_grad_flat = ravel_pytree(z_grad)[0]
        368 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

    [... skipping hidden 8 frame]

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:248, in potential_energy(model, model_args, model_kwargs, params, enum)
        244 substituted_model = substitute(
        245     model, substitute_fn=partial(unconstrain_reparam, params)
        246 )
        247 # no param is needed for log_density computation because we already substitute
    --> 248 log_joint, model_trace = log_density(
        249     substituted_model, model_args, model_kwargs, {}
        250 )
        251 return -log_joint

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:62, in log_density(model, model_args, model_kwargs, params)
         50 """
         51 (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
         52 latent values params.
       (...)
         59 :return: log of joint density and a corresponding model trace
         60 """
         61 model = substitute(model, data=params)
    ---> 62 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
         63 log_joint = jnp.zeros(())
         64 for site in model_trace.values():

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
        163 def get_trace(self, *args, **kwargs):
        164     """
        165     Run the wrapped callable and return the recorded trace.
        166
       (...)
        169     :return: OrderedDict containing the execution trace.
        170     """
    --> 171     self(*args, **kwargs)
        172     return self.trace

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
        103     return self
        104 with self:
    --> 105     return self.fn(*args, **kwargs)

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
        103     return self
        104 with self:
    --> 105     return self.fn(*args, **kwargs)

    [... skipping similar frames: Messenger.__call__ at line 105 (2 times)]

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
        103     return self
        104 with self:
    --> 105     return self.fn(*args, **kwargs)

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/lightweight_mmm/models.py:187, in media_mix_model(media_data, target_data, cost_prior, degrees_seasonality, frequency, transform_function, transform_kwargs, weekday_seasonality, extra_features)
        182 with numpyro.plate(name="beta_trend_plate", size=n_geos):
        183   beta_trend = numpyro.sample(
        184       name="beta_trend",
        185       fn=dist.Normal(loc=0., scale=1.))
    --> 187 expo_trend = numpyro.sample(
        188     name="expo_trend",
        189     fn=dist.Beta(concentration1=1., concentration0=1.))
        191 with numpyro.plate(
        192     name="channel_media_plate",
        193     size=n_channels,
        194     dim=-2 if media_data.ndim == 3 else -1):
        195   beta_media = numpyro.sample(
        196       name="channel_beta_media" if media_data.ndim == 3 else "beta_media",
        197       fn=dist.HalfNormal(scale=cost_prior))

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:219, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
        204 initial_msg = {
        205     "type": "sample",
        206     "name": name,
       (...)
        215     "infer": {} if infer is None else infer,
        216 }
        218 # ...and use apply_stack to send it to the Messengers
    --> 219 msg = apply_stack(initial_msg)
        220 return msg["value"]

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
         45 pointer = 0
         46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
    ---> 47     handler.process_message(msg)
         48     # When a Messenger sets the "stop" field of a message,
         49     # it prevents any Messengers above it on the stack from being applied.
         50     if msg.get("stop"):

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/handlers.py:789, in substitute.process_message(self, msg)
        787     value = self.data.get(msg["name"])
        788 else:
    --> 789     value = self.substitute_fn(msg)
        791 if value is not None:
        792     msg["value"] = value

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:216, in _unconstrain_reparam(params, site)
        213     return p
        214 value = t(p)
    --> 216 log_det = t.log_abs_det_jacobian(p, value)
        217 log_det = sum_rightmost(
        218     log_det, jnp.ndim(log_det) - jnp.ndim(value) + len(site["fn"].event_shape)
        219 )
        220 if site["scale"] is not None:

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/distributions/transforms.py:816, in SigmoidTransform.log_abs_det_jacobian(self, x, y, intermediates)
        815 def log_abs_det_jacobian(self, x, y, intermediates=None):
    --> 816     return -softplus(x) - softplus(-x)

    [... skipping hidden 20 frame]

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/nn/functions.py:66, in softplus(x)
         54 @jax.jit
         55 def softplus(x: Array) -> Array:
         56   r"""Softplus activation function.
         57
         58   Computes the element-wise function
       (...)
         64     x : input array
         65   """
    ---> 66   return jnp.logaddexp(x, 0)

    [... skipping hidden 5 frame]

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/numpy/ufuncs.py:361, in _logaddexp_jvp(primals, tangents)
        359 x1, x2 = primals
        360 t1, t2 = tangents
    --> 361 x1, x2, t1, t2 = _promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
        362 primal_out = logaddexp(x1, x2)
        363 tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
        364                       lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/numpy/util.py:327, in _promote_args_inexact(fun_name, *args)
        325 _check_arraylike(fun_name, *args)
        326 _check_no_float0s(fun_name, *args)
    --> 327 return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/numpy/util.py:262, in _promote_dtypes_inexact(*args)
        258 def _promote_dtypes_inexact(*args):
        259   """Convenience function to apply Numpy argument dtype promotion.
        260
        261   Promotes arguments to an inexact type."""
    --> 262   to_dtype, weak_type = dtypes._lattice_result_type(*args)
        263   to_dtype = dtypes.canonicalize_dtype(to_dtype)
        264   to_dtype_inexact = _to_inexact_dtype(to_dtype)

    [... skipping hidden 2 frame]

    File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/dtypes.py:311, in <genexpr>(.0)
        309 N = set(nodes)
        310 UB = _lattice_upper_bounds
    --> 311 CUB = set.intersection(*(UB[n] for n in N))
        312 LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
        313 if len(LUB) == 1:

    KeyError: dtype([('float0', 'V')])

    dtype([('float0', 'V')])

    opened by blattnem 13
  • plot_media_channel_posteriors running into an error

    I'm running into an error when trying the plot_media_channel_posteriors on the standard simulated data from the instructions. "IndexError: too many indices for array: array is 1-dimensional, but 3 were indexed"

    Here's my colab notebook: https://colab.research.google.com/drive/1S3V8T8CfIFaaGweySDyuQ4Jrqy8QnU1i?usp=sharing

    opened by hammer-mt 13
  • Budget allocation initial values are wrong when the modelling variables are not all costs

    https://github.com/google/lightweight_mmm/blob/main/lightweight_mmm/optimize_media.py#L145

    Here, in the function that generates starting values, the prices for each media channel are not passed in, so the starting values are never the actual values in monetary terms.

    bug p0 
    opened by yanhong-zhao-ef 12
  • AttributeError: 'CustomScaler' object has no attribute 'divide_by'

    I have loaded the model using load_model. When trying to use 'find_optimal_budgets' with

        media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
        target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)

    I'm getting the error: AttributeError: 'CustomScaler' object has no attribute 'divide_by'

    opened by virithavanama 8
  • Example failed - unexpected keyword argument in fit

    Hello,

    I'm running this example on a Databricks cluster (Runtime 10.4LTS, Python 3.8.10) and I'm getting the following issue when running fit:

    mmm.fit(
        media=media_data_train,
        media_prior=costs,
        target=target_train,
        extra_features=extra_features_train,
        number_warmup=number_warmup,
        number_samples=number_samples,
        seed=SEED)
    
    TypeError: fit() got an unexpected keyword argument 'media_prior'
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <command-420949> in <module>
          1 # For replicability in terms of random number generation in sampling
          2 # reuse the same seed for different trainings.
    ----> 3 mmm.fit(
          4     media=media_data_train,
          5     media_prior=costs,
    
    TypeError: fit() got an unexpected keyword argument 'media_prior'
    
    opened by nie-moviestarplanet 8
  • Delay effect in hill adstock transformation

    Hi,

    I'd like to know if it's possible to include the delay effect of the adstock transformation on media data if the variable transform_function in the definition of the model is set to transform_hill_adstock.

    Looking at the function transform_hill_adstock of the script models.py of the repository, it doesn't seem to be possible, as the delay effect is only included in the carryover transformation.

    Thanks, Alessandro

    opened by adavoli91 7
  • Weird MAPE values

    Hi Team,

    1. I'm using the channel's cost as the media data. After training, all the r_hat values are less than 1.1 and the r2 score is 0.95, so the prediction seems to be good, but the MAPE value is weird, as shown in the figure. Am I missing something? [Image: Screen Shot 2022-08-23 at 5 11 25 PM]

    2. Can I use impressions as the extra features?

    Thanks!

    opened by virithavanama 7
  • Difference between RBA and light MMM

    Hi, I am trying to compare the RBA and light MMM on my data and I can't work out the difference between the two models. Both models are regressions, right? Where is MCMC used in the MMM model? @pabloduque0 @cahlheim

    question 
    opened by elisakrammerfiverr 7
  • running mmm_lightweight with M1 MacBooks

    mmm_Lightweight won't install on the new M1 MacBook Pro even though tensorflow is properly installed and running. The following error message appears: ERROR: Could not find a version that satisfies the requirement tensorflow>=2.7.2 (from lightweight-mmm) (from versions: none)

    dependencies 
    opened by Samuelchazy 6
  • ValueError: Normal distribution got invalid loc parameter.

    Hi! I'm attempting to recreate the sample presented in PyData 2022 seen here with some of my own MMM data: https://github.com/takechanman1228/mmm_pydata_global_2022/blob/main/simple_end_to_end_demo_pydataglobal.ipynb

    data = data.tail(150)
    data_size = len(data)
    
    n_media_channels = len(mdsp_cols)
    n_extra_features = len(control_vars)
    media_data = data[mdsp_cols].to_numpy()
    extra_features = data[control_vars].to_numpy()
    target = data['y'].to_numpy()
    costs = data[mdsp_cols].sum().to_numpy()
    
    media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
    extra_features_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
    target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
    cost_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean, multiply_by=0.15)
    
    media_data_train = media_scaler.fit_transform(media_data_train)
    extra_features_train = extra_features_scaler.fit_transform(extra_features_train)
    target_train = target_scaler.fit_transform(target_train)
    costs = cost_scaler.fit_transform(costs)
    
    mmm = lightweight_mmm.LightweightMMM(model_name="hill_adstock")
    
    number_warmup=1000
    number_samples=1000
    
    mmm.fit(
        media=media_data_train,
        media_prior=costs,
        target=target_train,
        extra_features=extra_features_train,
        number_warmup=number_warmup,
        number_samples=number_samples,
        media_names = mdsp_cols,
        seed=105)
    

    The below error is displayed:

    ValueError                                Traceback (most recent call last)
    /tmp/ipykernel_3869/3074029020.py in <module>
         12     number_samples=number_samples,
         13     media_names = mdsp_cols,
    ---> 14     seed=105)
    
    /opt/conda/lib/python3.7/site-packages/lightweight_mmm/lightweight_mmm.py in fit(self, media, media_prior, target, extra_features, degrees_seasonality, seasonality_frequency, weekday_seasonality, media_names, number_warmup, number_samples, number_chains, target_accept_prob, init_strategy, custom_priors, seed)
        370         transform_function=self._model_transform_function,
        371         weekday_seasonality=weekday_seasonality,
    --> 372         custom_priors=custom_priors)
        373 
        374     self.custom_priors = custom_priors
    
    /opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
        595         else:
        596             if self.chain_method == "sequential":
    --> 597                 states, last_state = _laxmap(partial_map_fn, map_args)
        598             elif self.chain_method == "parallel":
        599                 states, last_state = pmap(partial_map_fn)(map_args)
    
    /opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in _laxmap(f, xs)
        158     for i in range(n):
        159         x = jit(_get_value_from_index)(xs, i)
    --> 160         ys.append(f(x))
        161 
        162     return tree_map(lambda *args: jnp.stack(args), *ys)
    
    /opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
        384                 init_params,
        385                 model_args=args,
    --> 386                 model_kwargs=kwargs,
        387             )
        388         sample_fn, postprocess_fn = self._get_cached_fns()
    
    /opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
        705             )
        706         init_params = self._init_state(
    --> 707             rng_key_init_model, model_args, model_kwargs, init_params
        708         )
        709         if self._potential_fn and init_params is None:
    
    /opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
        657                 model_args=model_args,
        658                 model_kwargs=model_kwargs,
    --> 659                 forward_mode_differentiation=self._forward_mode_differentiation,
        660             )
        661             if self._init_fn is None:
    
    /opt/conda/lib/python3.7/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
        674             with numpyro.validation_enabled(), trace() as tr:
        675                 # validate parameters
    --> 676                 substituted_model(*model_args, **model_kwargs)
        677                 # validate values
        678                 for site in tr.values():
    
    /opt/conda/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
        103             return self
        104         with self:
    --> 105             return self.fn(*args, **kwargs)
        106 
        107 
    
    /opt/conda/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
        103             return self
        104         with self:
    --> 105             return self.fn(*args, **kwargs)
        106 
        107 
    
    /opt/conda/lib/python3.7/site-packages/lightweight_mmm/models.py in media_mix_model(media_data, target_data, media_prior, degrees_seasonality, frequency, transform_function, custom_priors, transform_kwargs, weekday_seasonality, extra_features)
        433 
        434   numpyro.sample(
    --> 435       name="target", fn=dist.Normal(loc=mu, scale=sigma), obs=target_data)
    
    /opt/conda/lib/python3.7/site-packages/numpyro/distributions/distribution.py in __call__(cls, *args, **kwargs)
         97             if result is not None:
         98                 return result
    ---> 99         return super().__call__(*args, **kwargs)
        100 
        101 
    
    /opt/conda/lib/python3.7/site-packages/numpyro/distributions/continuous.py in __init__(self, loc, scale, validate_args)
       1700         batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
       1701         super(Normal, self).__init__(
    -> 1702             batch_shape=batch_shape, validate_args=validate_args
       1703         )
       1704 
    
    /opt/conda/lib/python3.7/site-packages/numpyro/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
        177                         raise ValueError(
        178                             "{} distribution got invalid {} parameter.".format(
    --> 179                                 self.__class__.__name__, param
        180                             )
        181                         )
    
    ValueError: Normal distribution got invalid loc parameter.
    

    I've checked for null values in my observation data, and none are present. I also removed zero-cost channels from the data after finding that a few had zero cost after scaling, following the answer in https://github.com/google/lightweight_mmm/issues/115. I also tried reducing the number of rows and columns fed into the model, but none of this got me past the error. Please let me know what I can do to diagnose this model. Thanks in advance.
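One way to narrow this down, sketched under assumptions: an "invalid loc parameter" error most likely means the mean passed to the Normal distribution contained a NaN or infinite value when the model was initialised. It can therefore help to check every array passed to fit for non-finite entries after scaling, not only the raw observations. The helper name `report_non_finite` and the example arrays are hypothetical; only NumPy is used.

```python
import numpy as np

def report_non_finite(**arrays):
    """Return {name: number of NaN/inf entries} for arrays that contain any."""
    problems = {}
    for name, values in arrays.items():
        values = np.asarray(values, dtype=float)
        bad = int((~np.isfinite(values)).sum())
        if bad:
            problems[name] = bad
    return problems

# Hypothetical inputs: the second channel is all zeros, so dividing by its
# mean (a mean-based scaling step) produces NaN for every row.
media = np.array([[1.0, 0.0], [2.0, 0.0]])
with np.errstate(divide="ignore", invalid="ignore"):
    scaled_media = media / media.mean(axis=0)
target = np.array([0.5, 1.5])

problems = report_non_finite(media=scaled_media, target=target)
print(problems)
```

Running the same check on the scaled media, extra features, target and cost arrays shows which input, if any, carries the non-finite values into the model.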

    opened by steven-struglia 21
  • inconsistent Response curves

    I'm seeing strange behaviour on the demo data. I ran the end-to-end demo example twice in the same Jupyter kernel: first as is, then with a different split point, split_point = data_size - 10 (instead of split_point = data_size - 13). The first run gave this: [image]

    and the second run gave this: [image]

    What could be the problem? Could it have something to do with the seed, or is there something else I'm missing? By the way, when I restart the kernel and run option #2 first, I get results similar to the original run #1.

    opened by ohad-monday 1
  • ValueError: Normal distribution got invalid scale parameter During .fit() call

    This is a question related to my attempt to implement the sample presented during PyData 2022 with our own data. I am able to successfully train the sample data following the process outlined in the talk here: https://github.com/takechanman1228/mmm_pydata_global_2022/blob/main/simple_end_to_end_demo_pydataglobal.ipynb

    When I attempt to reproduce the results with some of our own data, I hit a roadblock when fitting the model, and searching for this error online returns no results. Literally none. For one of the first times in my career Google actually returned no results for an error, so I was hoping someone on the board here could point me in the right direction, as my knowledge of how these scalers are supposed to work is limited. The error is: "ValueError: Normal distribution got invalid scale parameter".

    I have a Pandas DataFrame with a head() that prints like below. The only difference I can think of between my model and the sample is that the target in the sample was "sales" (i.e. a currency, like the spend), whereas I'm targeting Conversions, which is a non-currency count of people who converted.

      Week_Beginning  Conversions    Social  Audio  Display  Video  Search  Content
    0     2019-10-27          0.0   112.000    0.0      0.0    0.0     0.0      0.0
    1     2019-11-03          0.0   252.000    0.0      0.0    0.0     0.0      0.0
    2     2019-11-10          0.0  1326.000    0.0      0.0    0.0     0.0      0.0
    3     2019-11-17          0.0  1700.000    0.0      0.0    0.0     0.0      0.0
    4     2019-11-24          0.0  1182.912    0.0      0.0    0.0     0.0      0.0
    

    With CustomScaler instances defined thus:

    media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
    target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
    cost_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean, multiply_by=0.15)
    

    With a LightweightMMM model instantiated like so:

    mmm = lightweight_mmm.LightweightMMM(model_name="hill_adstock")
    mmm.fit(
        media=media_data_train,
        media_prior=costs,
        target=target_train,
        number_warmup=number_warmup,
        number_samples=number_samples,
        media_names = mdsp_cols,
        seed=SEED)
    

    and get the following traceback:

    Traceback (most recent call last):
      File "data.py", line 97, in <module>
        mmm.fit(
      File "/usr/local/lib/python3.8/dist-packages/lightweight_mmm/lightweight_mmm.py", line 362, in fit
        mcmc.run(
      File "/usr/local/lib/python3.8/dist-packages/numpyro/infer/mcmc.py", line 597, in run
        states, last_state = _laxmap(partial_map_fn, map_args)
      File "/usr/local/lib/python3.8/dist-packages/numpyro/infer/mcmc.py", line 160, in _laxmap
        ys.append(f(x))
      File "/usr/local/lib/python3.8/dist-packages/numpyro/infer/mcmc.py", line 381, in _single_chain_mcmc
        init_state = self.sampler.init(
      File "/usr/local/lib/python3.8/dist-packages/numpyro/infer/hmc.py", line 706, in init
        init_params = self._init_state(
      File "/usr/local/lib/python3.8/dist-packages/numpyro/infer/hmc.py", line 652, in _init_state
        init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
      File "/usr/local/lib/python3.8/dist-packages/numpyro/infer/util.py", line 676, in initialize_model
        substituted_model(*model_args, **model_kwargs)
      File "/usr/local/lib/python3.8/dist-packages/numpyro/primitives.py", line 105, in __call__
        return self.fn(*args, **kwargs)
      File "/usr/local/lib/python3.8/dist-packages/numpyro/primitives.py", line 105, in __call__
        return self.fn(*args, **kwargs)
      File "/usr/local/lib/python3.8/dist-packages/lightweight_mmm/models.py", line 350, in media_mix_model
        fn=dist.HalfNormal(scale=media_prior))
      File "/usr/local/lib/python3.8/dist-packages/numpyro/distributions/distribution.py", line 99, in __call__
        return super().__call__(*args, **kwargs)
      File "/usr/local/lib/python3.8/dist-packages/numpyro/distributions/continuous.py", line 465, in __init__
        self._normal = Normal(0.0, scale)
      File "/usr/local/lib/python3.8/dist-packages/numpyro/distributions/distribution.py", line 99, in __call__
        return super().__call__(*args, **kwargs)
      File "/usr/local/lib/python3.8/dist-packages/numpyro/distributions/continuous.py", line 1701, in __init__
        super(Normal, self).__init__(
      File "/usr/local/lib/python3.8/dist-packages/numpyro/distributions/distribution.py", line 177, in __init__
        raise ValueError(
    ValueError: Normal distribution got invalid scale parameter.
    

    Any help debugging would be most appreciated.
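A possible first check, offered as a sketch rather than a confirmed diagnosis: the traceback ends in `dist.HalfNormal(scale=media_prior)`, and a HalfNormal scale must be strictly positive, so any channel whose cost prior is zero, negative or NaN would raise this error. In the head() above several channels show no activity in the first weeks; if a channel's total cost is zero, its scaled cost is zero as well. The helper `find_invalid_priors` and the cost values below are hypothetical.

```python
import numpy as np

def find_invalid_priors(costs, channel_names):
    """Return the names of channels whose prior scale is not a positive, finite number."""
    costs = np.asarray(costs, dtype=float)
    invalid = ~(np.isfinite(costs) & (costs > 0))
    return [name for name, bad in zip(channel_names, invalid) if bad]

channels = ["Social", "Audio", "Display", "Video", "Search", "Content"]
# Made-up totals for illustration only.
costs = np.array([4572.912, 0.0, 0.0, 310.0, np.nan, 25.0])

bad_channels = find_invalid_priors(costs, channels)
print(bad_channels)
```

Channels reported by such a check would need to be dropped or given a small positive prior before calling fit.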

    opened by RogerWebb 4
  • Hill and carryover transformations

    Hi,

    I recently discovered your library and I find it very interesting.

    However, as far as I can see, the combination of Hill transformation + carryover is not implemented (unlike Hill + adstock). Do you plan to include it?

    I think it could be very useful to combine lag and saturation effects.

    Thanks, Kimberly
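To illustrate what combining the two effects could look like, here is a minimal NumPy sketch that applies a carryover (lag) transform and then a Hill saturation. It uses commonly cited functional forms and is not LightweightMMM's implementation; the function names, parameter names and default values are all assumptions.

```python
import numpy as np

def carryover(spend, retention=0.5, peak_delay=1.0, max_lag=13):
    """Spread each period's spend over the following periods (assumed form)."""
    lags = np.arange(max_lag)
    # Weight of spend made `lag` periods ago; largest when lag == peak_delay.
    weights = retention ** ((lags - peak_delay) ** 2)
    weights = weights / weights.sum()
    # Full convolution truncated to the input length, so output[t] only
    # depends on spend[t], spend[t-1], ..., spend[t-max_lag+1].
    return np.convolve(spend, weights)[: len(spend)]

def hill(x, half_max=0.5, slope=1.0):
    """Hill saturation: 0 at x=0, 0.5 at x=half_max, approaching 1 for large x."""
    x = np.asarray(x, dtype=float)
    return x ** slope / (x ** slope + half_max ** slope)

def hill_carryover(spend, **carryover_kwargs):
    """Lag first, then saturate."""
    lagged_spend = carryover(np.asarray(spend, dtype=float), **carryover_kwargs)
    return hill(lagged_spend)

# A single unit of spend in the first period.
impulse = np.zeros(20)
impulse[0] = 1.0
lagged = carryover(impulse)
saturated = hill_carryover(impulse)
```

With these defaults the lagged response to the impulse peaks one period after the spend, and the Hill step then compresses it into the range [0, 1).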

    opened by kimbtowns 1
  • OOM errors with GPU

    I have a use case that requires a fairly large amount of data (27 channels, 644 geos). The CPU approach takes a few days, and the GPU option takes 4-5 hours but always crashes at 100% with OOM errors. I have tried multiple GPUs and memory configurations (16GB, 32GB, 40GB), but it fails every time.

    I have also used the following XLA options but with no luck.

    import os

    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
    os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
    os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

    The only thing that seems to work is to reduce the data points from 365 to 50 or something.

    import jax.numpy as jnp
    import numpyro

    from lightweight_mmm import lightweight_mmm
    from lightweight_mmm import optimize_media
    from lightweight_mmm import plot
    from lightweight_mmm import preprocessing
    from lightweight_mmm import utils

    media_data, extra_features, target, costs = utils.simulate_dummy_data(
        data_size=365,
        n_media_channels=27,
        n_extra_features=1,
        geos=644)

    mmm = lightweight_mmm.LightweightMMM('adstock')
    mmm.fit(media=media_data,
            extra_features=extra_features,
            media_prior=costs,
            target=target,
            number_warmup=1000,
            number_samples=1000,
            weekday_seasonality=False,
            number_chains=1)

    Please see the stack trace below:

    sample: 100%|██████████| 2000/2000 [5:02:36<00:00, 9.08s/it, 1023 steps of size 5.54e-04. acc. prob=0.93

    XlaRuntimeError                           Traceback (most recent call last)
         12     number_warmup=1000,
         13     number_samples=1000,
    ---> 14     weekday_seasonality=False,number_chains=1)

    ~/.local/lib/python3.7/site-packages/lightweight_mmm/lightweight_mmm.py in fit(self, media, media_prior, target, extra_features, degrees_seasonality, seasonality_frequency, weekday_seasonality, media_names, number_warmup, number_samples, number_chains, target_accept_prob, init_strategy, custom_priors, seed)
        370         transform_function=self._model_transform_function,
        371         weekday_seasonality=weekday_seasonality,
    --> 372         custom_priors=custom_priors)
        373
        374     self.custom_priors = custom_priors

    ~/.local/lib/python3.7/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
        592             if self.num_chains == 1:
        593                 states_flat, last_state = partial_map_fn(map_args)
    --> 594                 states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
        595             else:
        596                 if self.chain_method == "sequential":

    ~/.local/lib/python3.7/site-packages/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
        205   leaves, treedef = tree_flatten(tree, is_leaf)
        206   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    --> 207   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
        208
        209 def build_tree(treedef: PyTreeDef, xs: Any) -> Any:

    ~/.local/lib/python3.7/site-packages/jax/_src/tree_util.py in <genexpr>(.0)
        205   leaves, treedef = tree_flatten(tree, is_leaf)
        206   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    --> 207   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
        208
        209 def build_tree(treedef: PyTreeDef, xs: Any) -> Any:

    ~/.local/lib/python3.7/site-packages/numpyro/infer/mcmc.py in <lambda>(x)
        592             if self.num_chains == 1:
        593                 states_flat, last_state = partial_map_fn(map_args)
    --> 594                 states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
        595             else:
        596                 if self.chain_method == "sequential":

    ~/.local/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
       3815   treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
       3816   return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
    -> 3817                  unique_indices, mode, fill_value)
       3818
       3819 # TODO(phawkins): re-enable jit after fixing excessive recompilation for

    ~/.local/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
       3852
       3853   # This adds np.newaxis/None dimensions.
    -> 3854   return expand_dims(y, indexer.newaxis_dims)
       3855
       3856 _Indexer = collections.namedtuple("_Indexer", [

    ~/.local/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in expand_dims(a, axis)
        920   if hasattr(a, "expand_dims"):
        921     return a.expand_dims(axis) # type: ignore
    --> 922   return lax.expand_dims(a, axis)
        923
        924

    ~/.local/lib/python3.7/site-packages/jax/_src/lax/lax.py in expand_dims(array, dimensions)
       1320     result_shape.insert(i, 1)
       1321   broadcast_dims = [i for i in range(ndim_out) if i not in dims_set]
    -> 1322   return broadcast_in_dim(array, result_shape, broadcast_dims)
       1323
       1324

    ~/.local/lib/python3.7/site-packages/jax/_src/lax/lax.py in broadcast_in_dim(operand, shape, broadcast_dimensions)
        822   return broadcast_in_dim_p.bind(
        823       operand, *dyn_shape, shape=tuple(static_shape),
    --> 824       broadcast_dimensions=tuple(broadcast_dimensions))
        825
        826 def broadcast_to_rank(x: Array, rank: int) -> Array:

    ~/.local/lib/python3.7/site-packages/jax/core.py in bind(self, *args, **params)
        327     assert (not config.jax_enable_checks or
        328             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
    --> 329     return self.bind_with_trace(find_top_trace(args), args, params)
        330
        331   def bind_with_trace(self, trace, args, params):

    ~/.local/lib/python3.7/site-packages/jax/core.py in bind_with_trace(self, trace, args, params)
        330
        331   def bind_with_trace(self, trace, args, params):
    --> 332     out = trace.process_primitive(self, map(trace.full_raise, args), params)
        333     return map(full_lower, out) if self.multiple_results else full_lower(out)
        334

    ~/.local/lib/python3.7/site-packages/jax/core.py in process_primitive(self, primitive, tracers, params)
        710
        711   def process_primitive(self, primitive, tracers, params):
    --> 712     return primitive.impl(*tracers, **params)
        713
        714   def process_call(self, primitive, f, tracers, params):

    ~/.local/lib/python3.7/site-packages/jax/_src/dispatch.py in apply_primitive(prim, *args, **params)
        113   compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
        114                                         **params)
    --> 115   return compiled_fun(*args)
        116
        117 # TODO(phawkins,frostig,mattjj): update code referring to

    ~/.local/lib/python3.7/site-packages/jax/_src/dispatch.py in <lambda>(*args, **kw)
        198       prim.name, donated_invars, False, *arg_specs)
        199   if not prim.multiple_results:
    --> 200     return lambda *args, **kw: compiled(*args, **kw)[0]
        201   else:
        202     return compiled

    ~/.local/lib/python3.7/site-packages/jax/_src/dispatch.py in _execute_compiled(name, compiled, input_handler, output_buffer_counts, result_handler, has_unordered_effects, ordered_effects, kept_var_idx, has_host_callbacks, *args)
        893     runtime_token = None
        894   else:
    --> 895     out_flat = compiled.execute(in_flat)
        896   check_special(name, out_flat)
        897   out_bufs = unflatten(out_flat, output_buffer_counts)

    XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 23.64GiB (25386480000B) on device ordinal 0
    BufferAssignment OOM Debugging.
    BufferAssignment stats:
                 parameter allocation:   23.64GiB
                  constant allocation:         0B
            maybe_live_out allocation:   23.64GiB
         preallocated temp allocation:         0B
                     total allocation:   47.29GiB
                  total fragmentation:         0B (0.00%)
    Peak buffers:
    	Buffer 1:
    	Size: 23.64GiB
    	Entry Parameter Subshape: f32[1000,365,27,644]
    	==========================

    	Buffer 2:
    	Size: 23.64GiB
    	XLA Label: copy
    	Shape: f32[1,1000,365,27,644]
    	==========================
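The buffer sizes in the error can be reproduced with a little arithmetic. The sketch below assumes 4 bytes per float32 element and takes the shape f32[1000,365,27,644] from the trace (which matches samples × data points × channels × geos in the reproduction script); the helper name `buffer_gib` is made up for illustration.

```python
import math

def buffer_gib(shape, bytes_per_element=4):
    """Size in GiB of a dense array with the given shape (float32 by default)."""
    return math.prod(shape) * bytes_per_element / 2**30

# Shape reported under "Peak buffers" in the trace.
one_copy = buffer_gib((1000, 365, 27, 644))
# The trace lists this buffer plus a copy of the same size with a leading axis.
both_copies = 2 * one_copy
# Same model with 50 data points instead of 365.
reduced = buffer_gib((1000, 50, 27, 644))

print(f"one copy:    {one_copy:.2f} GiB")
print(f"both copies: {both_copies:.2f} GiB")
print(f"50 points:   {reduced:.2f} GiB")
```

The two 23.64 GiB buffers add up to the 47.29 GiB total in the error, which exceeds every GPU configuration listed above. With 50 data points the same array is about 3.24 GiB, which is consistent with the observation that only the reduced dataset runs to completion.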

    opened by giritatavarty-8451 2
Releases(v0.1.6)
  • v0.1.6(Oct 20, 2022)

    What's Changed

    • Fixed optimization previous media/budget allocation to account for unequal prices.
    • Exposed more of scipy's parameters on optimization and made the Jacobian more accurate.
    • New plot for prior-posterior comparison for all parameters in the model.
    • Minor fixes and improvements.
  • v0.1.5(Aug 9, 2022)

    What's Changed

    • Add new API for custom priors and respective documentation.
    • Extended Python version support to cover 3.7 through 3.10 inclusive.
    • Documentation improvements.
    • Minor fixes.
  • v0.1.4(Jun 30, 2022)

    What's Changed

    • Plotting media contribution over time has been added to plots.
    • Plotting results from optimisation has been added.
    • New utility for going from pd.DataFrame to our expected geo 3D array for geo models
    • Expanded documentation of models.
    • Improved workflows, CI and documentation
    • Fixed edge case for certain tests
    • Minor changes and fixes

    Full Changelog: https://github.com/google/lightweight_mmm/compare/0.1.3...v0.1.4

  • 0.1.3(Jun 7, 2022)

  • v0.1.2(May 3, 2022)

    What's Changed

    • Add geo level models for bigger countries
    • Adapt data scaling to accept geo data
    • Adapt media optimization to geo models
    • Adapt plots to geo level model and data
    • Minor fixes and changes
  • v0.1.1(Feb 18, 2022)

    What's Changed

    • Fix install and import issues
    • Add simple demo notebook
    • Add control over seeds in fit and predict methods
    • Minor changes and fixes
  • release(Feb 14, 2022)

Owner
Google
Google ❤️ Open Source