support report custom args
Former-commit-id: d41254c40a1c5cacf9377096adb27efa9bdb79ea
This commit is contained in:
@@ -171,7 +171,10 @@ class HuggingfaceEngine(BaseEngine):
|
||||
elif not isinstance(value, torch.Tensor):
|
||||
value = torch.tensor(value)
|
||||
|
||||
gen_kwargs[key] = value.to(dtype=model.dtype, device=model.device)
|
||||
if torch.is_floating_point(value):
|
||||
value = value.to(model.dtype)
|
||||
|
||||
gen_kwargs[key] = value.to(model.device)
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
|
||||
Reference in New Issue
Block a user