Hi, I think there may be an error in how the mapped result is computed from the coupling matrix pi.
Please see the transform code at lines 338 to 342 of fugw/src/fugw/mappings/dense.py:
pi.T
@ source_features_tensor.T
/ pi.sum(dim=0).reshape(-1, 1)
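This computes $pi^{T} \cdot S^{T}$, i.e. it multiplies pi by the source features.
But I think the formula should be $pi \cdot Target$, using the target data rather than the source data. Please compare with the transform code from POT: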
transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None]
# set nans to 0
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
# compute transported samples
transp_Xs = nx.dot(transp, self.xt_)
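For reference, the POT snippet above amounts to a row-normalized barycentric projection: each source point is sent to the coupling-weighted mean of the target points. Here is a minimal NumPy sketch of that idea, assuming a dense coupling of shape (n_source, n_target). `barycentric_map` is a hypothetical helper name, not part of POT or FUGW, and it avoids dividing by zero with a mask instead of calling `nan_to_num` afterwards.

```python
import numpy as np


def barycentric_map(coupling, target_points):
    """Send each source point to the coupling-weighted mean of the target points.

    coupling: float array of shape (n_source, n_target)
    target_points: array of shape (n_target, d_target)
    """
    row_mass = coupling.sum(axis=1, keepdims=True)
    # Normalize each row by its mass; rows with zero mass stay all-zero
    transp = np.divide(
        coupling,
        row_mass,
        out=np.zeros_like(coupling),
        where=row_mass > 0,
    )
    return transp @ target_points


# Toy data (illustrative only): 3 source points, 2 target points in 2-D
coupling = np.array([[0.5, 0.0], [0.25, 0.25], [0.0, 0.0]])
target_points = np.array([[0.0, 0.0], [2.0, 4.0]])
mapped = barycentric_map(coupling, target_points)
print(mapped.shape)  # one row per source point, one column per target feature
```

The output has one row per source point and one column per target feature, which is the property the rest of this issue relies on.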
I can show this with two proofs, one based on an application and one based on theory.
Proof 1 - application
This is based on your example "Transport distributions using dense solvers".
After fitting and obtaining pi, you can plot the training points together with the mapped points:
# modified from transformed_data = mapping.transform(source_features_test)
transformed_data_train = mapping.transform(source_features_train)
fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot()
ax.set_title("Source and target features")
ax.set_aspect("equal", "datalim")
ax.scatter(source_features_train[0], source_features_train[1], label="Source")
ax.scatter(target_features_train[0], target_features_train[1], label="Target")
ax.scatter(transformed_data_train[0], transformed_data_train[1], label="trans")
ax.legend()
plt.show()
The plot looks like this:
![f9857eae-f839-45e6-a382-ff16c82a59a8](https://private-user-images.githubusercontent.com/24941293/333598729-5ed1d282-23a9-4721-90d5-5a970c31c723.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MjE0MDAxNjgsIm5iZiI6MTcyMTM5OTg2OCwicGF0aCI6Ii8yNDk0MTI5My8zMzM1OTg3MjktNWVkMWQyODItMjNhOS00NzIxLTkwZDUtNWE5NzBjMzFjNzIzLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDA3MTklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwNzE5VDE0Mzc0OFomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWY3ZjAwNzRjYWFkZGJlMzA3M2E5MzlkOWNlMGM3Y2IxM2FhZjc3ZDY2OGYyZGY1M2E0MjBjNWRlNzQ0YjYyMGEmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.ykzK4U-pRK_LhlKYKPT3AIotCIcDIkM-3NVFlNDKHgY)
You can see that the mapped data is actually close to the source data.
If you instead use the POT approach:
# assumes pi and target_features_train are both torch tensors
mapped_data_train = (pi @ target_features_train.T) / pi.sum(dim=1).reshape(-1, 1)
fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot()
ax.set_title("Source and target features")
ax.set_aspect("equal", "datalim")
ax.scatter(source_features_train[0], source_features_train[1], label="Source")
ax.scatter(target_features_train[0], target_features_train[1], label="Target")
ax.scatter(mapped_data_train.T[0], mapped_data_train.T[1], label="trans")
ax.legend()
plt.show()
Then the plot will be:
![download](https://private-user-images.githubusercontent.com/24941293/333599251-599ef211-59f2-45c3-86d0-d026d280cfff.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MjE0MDAxNjgsIm5iZiI6MTcyMTM5OTg2OCwicGF0aCI6Ii8yNDk0MTI5My8zMzM1OTkyNTEtNTk5ZWYyMTEtNTlmMi00NWMzLTg2ZDAtZDAyNmQyODBjZmZmLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDA3MTklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwNzE5VDE0Mzc0OFomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTkyYzFmYTQxOWJiNmZiZTZmZGFhMTEzMDkwOTQzMDVhYWM4MzE0NTlhMGEzYjg3NzY5YTUzMTA1ZWVhODJlMGYmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.XoudaqSFCbmo0S6XyJo-Q0toY1gM1kn9uOs1WZ2cMS8)
You can see that the mapped data is close to the target data.
Proof 2 - theory
Here I show that the shape of the result does not make sense.
Assume:
$S_{s}$ is the source data in the source space. Its shape is [3000, 50]: 3000 points with 50 features each, in a 50-dim space.
$T_{t}$ is the target data in the target space. Its shape is [9000, 100]: 9000 points with 100 features each, in a 100-dim space.
The OT matrix pi has shape [3000, 9000].
POT code
Following the POT code, if we want the source data mapped into the target space, $S_{t}$, we can use:
$$S_{t} = pi \cdot T_{t}$$
The shape of $S_{t}$ will be [3000, 100]. In terms of shapes, the formula above reads:
$$[3000, 100] = [3000, 9000] \cdot [9000, 100]$$
The source data goes from shape [3000, 50] in the 50-dim space to [3000, 100] in the 100-dim space, and the number of points does not change. Each point just moves from the 50-dim space to the 100-dim space.
So the interpretation of the OT algorithm is:
the OT algorithm maps the data from the source space to the target space without changing the number of points.
FUGW code
Following the FUGW code, if we want the source data mapped into the target space, $S_{t}$, we would use:
$$S_{t} = pi^{T} \cdot S_{s}$$
The shape of $S_{t}$ will be [9000, 50]. In terms of shapes, the formula reads:
$$[9000, 50] = [9000, 3000] \cdot [3000, 50]$$
So the source data goes from [3000, 50] to [9000, 50]: the number of points changes, and the result is still in the 50-dim space, not in the 100-dim target space! This does not make sense to me.
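The shape argument above can be checked numerically. This is a minimal NumPy sketch, assuming smaller stand-in sizes (30 and 90 points instead of 3000 and 9000, 5 and 10 features instead of 50 and 100), random data, and a random positive matrix in place of a real coupling. All variable names are illustrative, and both products include the same normalization as the snippets quoted at the top of this issue.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in sizes with the same roles as in the text above
n_source, d_source = 30, 5
n_target, d_target = 90, 10

S_s = rng.normal(size=(n_source, d_source))  # source points, one per row
T_t = rng.normal(size=(n_target, d_target))  # target points, one per row

# Random positive matrix used as a placeholder for a fitted coupling
pi = rng.random((n_source, n_target))
pi /= pi.sum()

# POT-style product: rows of pi normalized, then multiplied by target points
pot_style = (pi / pi.sum(axis=1, keepdims=True)) @ T_t

# Product as in the quoted FUGW lines: columns of pi normalized,
# transposed coupling multiplied by source points
fugw_style = (pi.T / pi.sum(axis=0).reshape(-1, 1)) @ S_s

print(pot_style.shape)   # (n_source, d_target)
print(fugw_style.shape)  # (n_target, d_source)
```

The first result has one row per source point with target-space features, while the second has one row per target point with source-space features.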
Please let me know if I am wrong :)
Btw, thanks a lot for your contribution to FUGW, it has helped me a lot.