I have been reading your paper and trying to understand the code in this GitHub repository. I have several questions I would like to ask you. In this issue, I will ask about the code corresponding to the loss function.
class StructuredLossWithTurner(nn.Module):
    """Structured loss combining a hinge-style gap with a Turner-model term.

    The returned loss is the length-normalized gap between the
    loss-augmented predicted score and the reference score, plus
    ``sl_weight`` times the squared difference between the reference score
    and the Turner-model score, plus an optional L1 penalty.

    NOTE(review): the margin/max part of the structured hinge loss is not
    visible here — ``self.model`` is called with ``reference=pairs`` and the
    ``loss_pos_*``/``loss_neg_*`` penalties, which presumably performs
    loss-augmented decoding internally; confirm against the model code.
    """

    def __init__(self, model, loss_pos_paired=0, loss_neg_paired=0, loss_pos_unpaired=0, loss_neg_unpaired=0,
                 l1_weight=0., l2_weight=0., sl_weight=1., verbose=False):
        """Store the wrapped model, margin penalties, and regularization weights.

        Args:
            model: scoring/folding model being trained.
            loss_pos_paired: margin penalty for paired positions (positive side).
            loss_neg_paired: margin penalty for paired positions (negative side).
            loss_pos_unpaired: margin penalty for unpaired positions (positive side).
            loss_neg_unpaired: margin penalty for unpaired positions (negative side).
            l1_weight: coefficient of the L1 penalty on model parameters.
            l2_weight: stored but currently unused — the L2 term in
                ``forward`` is commented out.
            sl_weight: weight of the squared (ref - turner_ref) term.
            verbose: if True, print per-example diagnostics in ``forward``.
        """
        super(StructuredLossWithTurner, self).__init__()
        self.model = model
        self.loss_pos_paired = loss_pos_paired
        self.loss_neg_paired = loss_neg_paired
        self.loss_pos_unpaired = loss_pos_unpaired
        self.loss_neg_unpaired = loss_neg_unpaired
        self.l1_weight = l1_weight
        self.l2_weight = l2_weight
        self.sl_weight = sl_weight
        self.verbose = verbose
        # Deferred, function-scope imports — presumably to avoid a circular
        # import at package load time; verify against the package layout.
        from .fold.rnafold import RNAFold
        from . import param_turner2004
        # Reuse the model's own Turner scorer when it exposes one;
        # otherwise build a standalone RNAFold with Turner 2004 parameters
        # on the same device as the model's parameters.
        if getattr(self.model, "turner", None):
            self.turner = self.model.turner
        else:
            self.turner = RNAFold(param_turner2004).to(next(self.model.parameters()).device)

    def forward(self, seq, pairs, fname=None):
        """Compute the loss for a batch of sequences.

        Args:
            seq: batch of sequences (objects supporting ``len``; presumably
                strings — confirm against the data loader).
            pairs: reference structures, used both as the margin reference
                for loss-augmented prediction and as folding constraints.
            fname: optional identifier printed only when the loss diverges.

        Returns:
            Loss tensor (scalar per the ``.item()`` calls below, which
            assume batch size 1 — TODO confirm with callers).
        """
        # Loss-augmented prediction: the reference structure and the
        # loss_* penalties are handed to the model, so any margin-based
        # augmentation happens inside the model call itself.
        pred, pred_s, _, param = self.model(seq, return_param=True, reference=pairs,
                                            loss_pos_paired=self.loss_pos_paired, loss_neg_paired=self.loss_neg_paired,
                                            loss_pos_unpaired=self.loss_pos_unpaired, loss_neg_unpaired=self.loss_neg_unpaired)
        # Score of the reference structure under the same parameters.
        ref, ref_s, _ = self.model(seq, param=param, constraint=pairs, max_internal_length=None)
        # Turner-model score of the reference structure; no gradients needed.
        with torch.no_grad():
            ref2, ref2_s, _ = self.turner(seq, constraint=pairs, max_internal_length=None)
        # Per-example sequence lengths, used to length-normalize the loss.
        l = torch.tensor([len(s) for s in seq], device=pred.device)
        # Length-normalized gap between augmented prediction and reference.
        loss = (pred - ref) / l
        # Squared deviation of the reference score from the Turner score.
        loss += self.sl_weight * (ref - ref2) * (ref - ref2) / l
        if self.verbose:
            print("Loss = {} = ({} - {})".format(loss.item(), pred.item(), ref.item()))
            print(seq)
            print(pred_s)
            print(ref_s)
        # Diagnostic dump when the loss explodes or becomes NaN.
        if loss.item() > 1e10 or torch.isnan(loss):
            print()
            print(fname)
            print(loss.item(), pred.item(), ref.item())
            print(seq)
        if self.l1_weight > 0.0:
            # L1 regularization accumulated over all model parameters.
            for p in self.model.parameters():
                loss += self.l1_weight * torch.sum(torch.abs(p))
        # L2 regularization is currently disabled (kept for reference);
        # note self.l2_weight is therefore unused.
        # if self.l2_weight > 0.0:
        #     l2_reg = 0.0
        #     for p in self.model.parameters():
        #         l2_reg += torch.sum((self.l2_weight * p) ** 2)
        #     loss += torch.sqrt(l2_reg)
        return loss
From the Python code alone, I am not able to match this to the loss function in the paper. I can see both the L1 and L2 regularization terms, but I cannot find the code for the structured hinge loss (the margin term and the max function). I suppose f(x, y) in the paper corresponds to the score (the first returned item) from self.model, so I assume the max-margin part is embedded in self.model, which is implemented in the C++ part of the code. Since the paper uses MixedFold, could you kindly point out which part and lines of the C++ code implement the max-margin function?