Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cirkit/backend/torch/circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _build_address_book(self, fold_idx_info: FoldIndexInfo) -> LayerAddressBook:

def _evaluate_layers(self, x: Tensor | None) -> Tensor:
# Evaluate layers on the given input
y = self.evaluate(x) # (O, B, K)
y = self.evaluate(x)[-1] # (O, B, K)
return y.transpose(0, 1) # (B, O, K)


Expand Down
10 changes: 6 additions & 4 deletions cirkit/backend/torch/graph/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def subgraph(self, *roots: TorchModule) -> "TorchDiAcyclicGraph[TorchModule]":

def evaluate(
self, x: Tensor | None = None, module_fn: ModuleEvalFunctional | None = None
) -> Tensor:
) -> list[Tensor]:
"""Evaluate the Torch graph by following the topological ordering,
and by using the address book information to retrieve the inputs to each module.

Expand All @@ -313,8 +313,8 @@ def evaluate(
the module itself is used.

Returns:
The output tensor of the Torch graph.
If the Torch graph has multiple outputs, then they will be stacked.
A list of Tensors corresponding to the outputs of each layer from input to output.
If the Torch graph has multiple outputs, then output of the last layer will be stacked.

Raises:
RuntimeError: If the address book is somehow not well-formed.
Expand All @@ -326,7 +326,9 @@ def evaluate(
for module, inputs in self._address_book.lookup(module_outputs, in_graph=x):
if module is None:
(output,) = inputs
return output
# return the list of outputs from each layer
module_outputs[-1] = output
return module_outputs
if module_fn is None:
y = module(*inputs)
else:
Expand Down
19 changes: 14 additions & 5 deletions cirkit/backend/torch/layers/inner.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,15 @@ def forward(self, x: Tensor) -> Tensor:
is the number of output units.
"""

def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
def sample(self, x: Tensor, evidence: Tensor = None) -> tuple[Tensor, Tensor | None]:
"""Perform a forward sampling step.

Args:
x: A tensor representing the input variable assignments, having shape
$(F, H, C, K, N, D)$, where $F$ is the number of folds, $H$ is the arity,
$C$ is the number of channels, $K$ is the numbe rof input units, $N$ is the number
of samples, $D$ is the number of variables.
evidence: A tensor representing the evidence for the layer, having same structure as x.

Returns:
Tensor: A new tensor representing the new variable assignements the layers gives
Expand Down Expand Up @@ -123,7 +124,7 @@ def config(self) -> Mapping[str, Any]:
def forward(self, x: Tensor) -> Tensor:
return self.semiring.prod(x, dim=1, keepdim=False) # shape (F, H, B, K) -> (F, B, K).

def sample(self, x: Tensor) -> tuple[Tensor, None]:
def sample(self, x: Tensor, evidence: Tensor = None) -> tuple[Tensor, None]:
# Concatenate samples over disjoint variables through a sum
# x: (F, H, C, K, num_samples, D)
x = torch.sum(x, dim=1) # (F, C, K, num_samples, D)
Expand Down Expand Up @@ -183,7 +184,7 @@ def forward(self, x: Tensor) -> Tensor:
# y0: (F, B, Ko=Ki ** arity)
return y0

def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
def sample(self, x: Tensor, evidence: Tensor = None) -> tuple[Tensor, Tensor | None]:
# x: (F, H, C, K, num_samples, D)
y0 = x[:, 0]
for i in range(1, x.shape[1]):
Expand Down Expand Up @@ -269,7 +270,7 @@ def forward(self, x: Tensor) -> Tensor:
"fbi,foi->fbo", inputs=(x,), operands=(weight,), dim=-1, keepdim=True
) # shape (F, B, K_o).

def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
def sample(self, x: Tensor, evidence: Tensor = None) -> tuple[Tensor, Tensor]:
weight = self.weight()
negative = torch.any(weight < 0.0)
normalized = torch.allclose(torch.sum(weight, dim=-1), torch.ones(1, device=weight.device))
Expand All @@ -283,7 +284,15 @@ def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
d = x.shape[3]

# mixing_distribution: (F, Ko, H * Ki)
mixing_distribution = torch.distributions.Categorical(probs=weight)
if evidence is not None:
prior = torch.log(torch.clamp(weight, min=1e-10))
posterior = prior + evidence
normalized_posterior = torch.exp(
posterior - torch.logsumexp(posterior, 2, keepdim=True)
)
mixing_distribution = torch.distributions.Categorical(probs=normalized_posterior)
else:
mixing_distribution = torch.distributions.Categorical(probs=weight)

# mixing_samples: (num_samples, F, Ko) -> (F, Ko, num_samples)
mixing_samples = mixing_distribution.sample((num_samples,))
Expand Down
2 changes: 1 addition & 1 deletion cirkit/backend/torch/layers/optimized.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def forward(self, x: Tensor) -> Tensor:
"fbi,foi->fbo", inputs=(x,), operands=(weight,), dim=-1, keepdim=True
)

def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
def sample(self, x: Tensor, evidence: Tensor = None) -> tuple[Tensor, Tensor]:
weight = self.weight()
negative = torch.any(weight < 0.0)
if negative:
Expand Down
3 changes: 2 additions & 1 deletion cirkit/backend/torch/parameters/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ def forward(self) -> Tensor:
where F is the number of folds, and (K_1,\ldots,K_n) is the shape
of each parameter tensor slice.
"""
return self.evaluate()
output = self.evaluate()[-1]
return output

def _build_unfold_index_info(self) -> FoldIndexInfo:
return build_unfold_index_info(
Expand Down
169 changes: 145 additions & 24 deletions cirkit/backend/torch/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,32 @@ def __call__(self, x: Tensor, *, integrate_vars: Tensor | Scope | Iterable[Scope
f"was defined over {integrate_vars.shape[1]} != {num_vars} variables"
)
else:
# Convert list of scopes to a boolean mask of dimension (B, N) where
# N is the number of variables in the circuit's scope.
integrate_vars_mask = IntegrateQuery.scopes_to_mask(self._circuit, integrate_vars)
integrate_vars_mask = integrate_vars_mask.to(x.device)
integrate_vars_mask = self.convert_scopes_to_boolean_mask(integrate_vars, x)

# Check batch sizes of input x and mask are compatible
self.check_batch_size_compatibility(integrate_vars_mask, x)

output = self.retrieve_layerwise_outputs(integrate_vars_mask, x)[-1]
# output has shape (O, B, K)
return output.transpose(0, 1) # (B, O, K)

def convert_scopes_to_boolean_mask(self, integrate_vars, x):
# Convert list of scopes to a boolean mask of dimension (B, N) where
# N is the number of variables in the circuit's scope.
integrate_vars_mask = IntegrateQuery.scopes_to_mask(self._circuit, integrate_vars)
integrate_vars_mask = integrate_vars_mask.to(x.device)
return integrate_vars_mask

def retrieve_layerwise_outputs(self, integrate_vars_mask, x) -> list[Tensor]:
return self._circuit.evaluate(
x,
module_fn=functools.partial(
IntegrateQuery._layer_fn, integrate_vars_mask=integrate_vars_mask
),
) # (O, B, K)

@staticmethod
def check_batch_size_compatibility(integrate_vars_mask, x):
if integrate_vars_mask.shape[0] not in (1, x.shape[0]):
raise ValueError(
"The number of scopes to integrate over must "
Expand All @@ -101,14 +121,6 @@ def __call__(self, x: Tensor, *, integrate_vars: Tensor | Scope | Iterable[Scope
f"{x.shape[0]} != {integrate_vars_mask.shape[0]} = len(integrate_vars)"
)

output = self._circuit.evaluate(
x,
module_fn=functools.partial(
IntegrateQuery._layer_fn, integrate_vars_mask=integrate_vars_mask
),
) # (O, B, K)
return output.transpose(0, 1) # (B, O, K)

@staticmethod
def _layer_fn(layer: TorchLayer, x: Tensor, *, integrate_vars_mask: Tensor) -> Tensor:
# Evaluate a layer: if it is not an input layer, then evaluate it in the usual
Expand Down Expand Up @@ -193,8 +205,7 @@ class SamplingQuery(Query):
def __init__(self, circuit: TorchCircuit) -> None:
"""Initialize a sampling query object. Currently, only sampling from the joint distribution
is supported, i.e., sampling won't work in the case of circuits obtained by
marginalization, or by observing evidence. Conditional sampling is currently not
implemented.
marginalization, or by observing evidence.

Args:
circuit: The circuit to sample from.
Expand Down Expand Up @@ -233,27 +244,25 @@ def __call__(self, num_samples: int = 1) -> tuple[Tensor, list[Tensor]]:
# samples: (O, K, num_samples, D)
samples = self._circuit.evaluate(
module_fn=functools.partial(
self._layer_fn,
num_samples=num_samples,
mixture_samples=mixture_samples,
self._layer_fn, num_samples=num_samples, mixture_samples=mixture_samples
),
)
)[-1]
# samples: (num_samples, O, K, D)
samples = samples.permute(2, 0, 1, 3)
# TODO: fix for the case of multi-output circuits, i.e., O != 1 or K != 1
samples = samples[:, 0, 0] # (num_samples, D)
return samples, mixture_samples

def _layer_fn(
self, layer: TorchLayer, *inputs: Tensor, num_samples: int, mixture_samples: list[Tensor]
self,
layer: TorchLayer,
*inputs: Tensor,
num_samples: int,
mixture_samples: list[Tensor],
) -> Tensor:
# Sample from an input layer
if not inputs:
assert isinstance(layer, TorchInputLayer)
samples = layer.sample(num_samples)
samples = self._pad_samples(samples, layer.scope_idx)
mixture_samples.append(samples)
return samples
return self.sample_from_input_layer(layer, mixture_samples, num_samples)

# Sample through an inner layer
assert isinstance(layer, TorchInnerLayer)
Expand All @@ -262,6 +271,13 @@ def _layer_fn(
mixture_samples.append(mix_samples)
return samples

def sample_from_input_layer(self, layer, mixture_samples, num_samples):
assert isinstance(layer, TorchInputLayer)
samples = layer.sample(num_samples)
samples = self._pad_samples(samples, layer.scope_idx)
mixture_samples.append(samples)
return samples

def _pad_samples(self, samples: Tensor, scope_idx: Tensor) -> Tensor:
"""Pads univariate samples to the size of the scope of the circuit (output dimension)
according to scope for compatibility in downstream inner nodes.
Expand All @@ -276,3 +292,108 @@ def _pad_samples(self, samples: Tensor, scope_idx: Tensor) -> Tensor:
fold_idx = torch.arange(samples.shape[0], device=samples.device)
padded_samples[fold_idx, :, :, scope_idx.squeeze(dim=1)] = samples
return padded_samples


class ConditionalSamplingQuery(SamplingQuery):
"""The conditional sampling query object."""

def __init__(self, circuit: TorchCircuit) -> None:
"""Initialize a conditional sampling query object. Currently, only sampling from the joint distribution
is supported, i.e., sampling won't work in the case of circuits obtained by
marginalization, or by observing evidence.

Args:
circuit: The circuit to sample from.

Raises:
ValueError: If the circuit to sample from is not smooth and decomposable.
"""
if not circuit.properties.smooth or not circuit.properties.decomposable:
raise ValueError(
f"The circuit to sample from must be smooth and decomposable, "
f"but found {circuit.properties}"
)
# TODO: add a check to verify the circuit is monotonic and normalized?
super().__init__(circuit=circuit)
self._layerwise_evidence = None

def __call__(
self, num_samples: int = 1, x: Tensor = None, integrate_vars: Scope = None
) -> tuple[Tensor, list[Tensor]]:
"""Sample a number of data points based on the provided evidence.

Args:
num_samples: The number of samples to return.
x: An input batch of shape $(B, D)$, where $B$ is the batch size and $D$ is the number of variables.
integrate_vars: The variables to integrate. It must be a subset of the variables on
which the circuit given in the constructor is defined on. At the moment, only Scope type is supported
and the same integration mask is applied for all entries of the batch.

Return:
A pair (samples, mixture_samples), consisting of (i) an assignment to the observed
variables the circuit is defined on, and (ii) the samples of the finitely-discrete
latent variables associated to the sum units. The samples (i) are returned as a
tensor of shape (num_samples, num_variables).

Raises:
ValueError: if the number of samples is not a positive number or only integrate_vars is specified without x.
"""
if num_samples <= 0:
raise ValueError("The number of samples must be a positive number")
if bool(integrate_vars is None) ^ bool(x is None):
raise ValueError(
"For conditional samples, both input to condition and scope to integrate out must be specified"
)

intgrateQuery = IntegrateQuery(self._circuit)
integrate_vars_mask = intgrateQuery.convert_scopes_to_boolean_mask(integrate_vars, x)

# Check batch sizes of input x and mask are compatible
intgrateQuery.check_batch_size_compatibility(integrate_vars_mask, x)

self._layerwise_evidence = intgrateQuery.retrieve_layerwise_outputs(integrate_vars_mask, x)

mixture_samples: list[Tensor] = []
# samples: (O, K, num_samples, D)
samples = self._circuit.evaluate(
module_fn=functools.partial(
self._layer_fn, num_samples=num_samples, mixture_samples=mixture_samples
),
)[-1]

# samples: (num_samples, O, K, D)
samples = samples.permute(2, 0, 1, 3)
# TODO: fix for the case of multi-output circuits, i.e., O != 1 or K != 1
samples = samples[:, 0, 0] # (num_samples, D)
# combine the conditioned scopes and the observed scopes
marginalized_scope_ids = [i for i in range(x.shape[1]) if i in integrate_vars]
non_marginalized_scope_ids = [i for i in range(x.shape[1]) if i not in integrate_vars]
x[..., marginalized_scope_ids] = 0.0
samples[..., non_marginalized_scope_ids] = 0.0
samples = samples + x
return samples, mixture_samples

def _layerwise_evidence_generator(self):
for evidence in self._layerwise_evidence:
yield evidence

def _layer_fn(
self,
layer: TorchLayer,
*inputs: Tensor,
num_samples: int,
mixture_samples: list[Tensor],
) -> Tensor:
# Sample from an input layer
if not inputs:
return self.sample_from_input_layer(layer, mixture_samples, num_samples)

inner_layer_evidence_generator = self._layerwise_evidence_generator()
inner_layer_evidence = next(inner_layer_evidence_generator)

# Sample through an inner layer
assert isinstance(layer, TorchInnerLayer)
samples, mix_samples = layer.sample(*inputs, evidence=inner_layer_evidence)
if mix_samples is not None:
mixture_samples.append(mix_samples)
return samples
Loading
Loading